"""
This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License,
or (at your option) any later version.


This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the GNU General Public License for more details.


You should have received a copy of the GNU General Public License
 along with this program.  If not, see <https://www.gnu.org/licenses/>.

Copyright © 2019 Cloud Linux Software Inc.

This software is also available under ImunifyAV commercial license,
see <https://www.imunify360.com/legal/eula>
"""
import os
import sqlite3
from abc import ABC
from typing import NamedTuple
from enum import Enum


class HashState(Enum):
    """
    State attached to a hash match; used as the type of `HashMatch.state`.

    The values are plain integers, and `HashMatchTable` declares its `state`
    column as INTEGER.
    """
    # NOTE(review): the exact meaning of each state is not defined in this
    # file -- confirm with the code that produces these values.
    ACTIVE = 0
    BLACKLISTED = 1
    SUPERSEDED = 2


class DefinitionType(Enum):
    """
    Kind of definition a match refers to.

    Used as the type of the `vuln_type` field in `PatchDependencyMatch` and
    `HashMatch`; the corresponding table columns are declared as INTEGER.
    """
    # NOTE(review): the numbering is not contiguous -- 5 and 6 are not
    # assigned in this file. Presumably these numbers mirror an external
    # definition format; confirm before renumbering or reusing the gaps.
    MALWARE = 1
    VULNERABILITY = 2
    APPLICATION = 3
    DRYRUN = 4
    MALWARE_RULE = 7
    MALWARE_RULE_DRYRUN = 8
    VULNERABILITY_ECOMMERCE = 9
    VULNERABILITY_PLUGIN = 10


class Table(ABC):
    """
    Base class for an insert-only SQLite table with a write buffer.

    Subclasses provide three class attributes:

    * ``table_name`` -- name of the table, interpolated into the INSERT
      statement;
    * ``fields`` -- column names, in the order values appear in each row;
    * ``create_table_query`` -- SQL statement executed once on construction.

    Rows are collected in memory and written with a single ``executemany``
    call once ``buffer_size`` rows have accumulated, or when ``flush`` is
    called explicitly.
    """

    table_name: str
    fields: tuple
    create_table_query: str

    def __init__(self, conn: sqlite3.Connection, buffer_size: int = 1000):
        """
        Bind the table to *conn* and create it in the database.

        :param conn: open SQLite connection used for all statements
        :param buffer_size: number of buffered rows that triggers a flush
        """
        self.conn = conn
        self.buffer_size = buffer_size
        self.buffer: list[tuple] = []
        self.create_table()

    def buffered_insert(self, row: tuple):
        """
        Queue *row* for insertion, flushing once the buffer is full.

        :param row: values in the same order as ``fields``
        :raises ValueError: if *row* does not have one value per field
        """
        # Validate with an explicit exception rather than ``assert``:
        # asserts are removed under ``python -O``, and a malformed row that
        # slips into the buffer makes this and every later flush fail,
        # because a failed flush leaves the buffer untouched.
        if len(row) != len(self.fields):
            raise ValueError(
                f"{self.table_name}: expected {len(self.fields)} values "
                f"({', '.join(self.fields)}), got {len(row)}"
            )
        self.buffer.append(row)
        if len(self.buffer) >= self.buffer_size:
            self.flush()

    def flush(self):
        """
        Write all buffered rows to the table and empty the buffer.

        The insert runs inside ``with self.conn``, so it is committed on
        success and rolled back if the statement raises. On failure the
        exception propagates and the buffer is left as it was.
        """
        fields = ", ".join(self.fields)
        # Only the identifiers come from class attributes; the row values
        # are always bound through "?" placeholders.
        with self.conn:
            self.conn.executemany(
                f"INSERT INTO {self.table_name} ({fields}) VALUES"
                f" ({', '.join(['?'] * len(self.fields))})",
                self.buffer,
            )
        self.buffer = []

    def create_table(self):
        """
        Execute ``create_table_query`` in its own transaction.
        """
        with self.conn:
            self.conn.execute(self.create_table_query)


class VersionMatch(NamedTuple):
    """
    One row of the `versions_matches` table.

    `VersionMatchTable` reuses `_fields` of this tuple as its column list, so
    field names and order here must match the columns of that table.
    """
    # NOTE(review): what `id` identifies and which file/digest `path` and
    # `hash` refer to is not stated in this file -- confirm with the caller.
    id: int
    path: str
    hash: str


class VersionMatchTable(Table):
    """
    Buffered table `versions_matches` holding `VersionMatch` rows.
    """
    # Columns written on insert are exactly VersionMatch._fields; `created_at`
    # is not among them, so it takes its DEFAULT CURRENT_TIMESTAMP value.
    table_name = "versions_matches"
    fields = VersionMatch._fields
    create_table_query = f"""
        CREATE TABLE {table_name} (
        id INTEGER,
        path TEXT,
        hash TEXT,
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
    """


class PatchDependencyMatch(NamedTuple):
    """
    One row of the `patch_dependencies` table.

    `PatchDependencyTable` reuses `_fields` of this tuple as its column list.
    `dependencies_met` defaults to False when not supplied.
    """
    # Field order differs from HashMatch: here `vuln_id` precedes `vuln_type`.
    # Take care when building either tuple positionally.
    path: str
    hash: str
    vuln_id: int
    vuln_type: DefinitionType
    dependencies_met: bool = False


class PatchDependencyTable(Table):
    """
    Buffered table `patch_dependencies` holding `PatchDependencyMatch` rows.
    """
    # Columns written on insert are exactly PatchDependencyMatch._fields;
    # `created_at` is not among them, so it takes its
    # DEFAULT CURRENT_TIMESTAMP value.
    table_name = "patch_dependencies"
    fields = PatchDependencyMatch._fields
    create_table_query = f"""
        CREATE TABLE {table_name} (
        path TEXT,
        hash TEXT,
        vuln_id INTEGER,
        vuln_type INTEGER,
        dependencies_met BOOLEAN,
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
    """


class HashMatch(NamedTuple):
    """
    One row of the `hashes_matches` table.

    `HashMatchTable` reuses `_fields` of this tuple as its column list.
    """
    # Field order differs from PatchDependencyMatch: here `vuln_type` precedes
    # `vuln_id`. Take care when building either tuple positionally.
    path: str
    hash: str
    vuln_type: DefinitionType
    vuln_id: int
    state: HashState


class HashMatchTable(Table):
    """
    Buffered table `hashes_matches` holding `HashMatch` rows.
    """
    # Columns written on insert are exactly HashMatch._fields; `created_at`
    # is not among them, so it takes its DEFAULT CURRENT_TIMESTAMP value.
    table_name = "hashes_matches"
    fields = HashMatch._fields
    create_table_query = f"""
        CREATE TABLE {table_name} (
        path TEXT,
        hash TEXT,
        vuln_type INTEGER,
        vuln_id INTEGER,
        state INTEGER,
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
    """


class DB:
    """
    A freshly created SQLite database holding the three match tables.

    Usable as a context manager: leaving the ``with`` block flushes every
    table buffer and then closes the connection.
    """

    def __init__(self, db_name: str, buffer_size: int = 1000):
        """
        Create a new database at *db_name* and its tables.

        An existing file at *db_name* is deleted first, so the database
        always starts empty.

        :param db_name: path passed to ``sqlite3.connect``
        :param buffer_size: per-table number of buffered rows that triggers
            a flush
        """
        if os.path.exists(db_name):
            os.remove(db_name)

        self.conn = sqlite3.connect(db_name)

        # If creating any table fails, __exit__ will never run for this
        # object, so release the connection here instead of leaking it.
        try:
            self.versions_matches = VersionMatchTable(
                self.conn, buffer_size=buffer_size
            )
            self.patch_dependencies = PatchDependencyTable(
                self.conn, buffer_size=buffer_size
            )
            self.hashes_matches = HashMatchTable(
                self.conn, buffer_size=buffer_size
            )
        except BaseException:
            self.conn.close()
            raise

        self._tables = (
            self.versions_matches,
            self.patch_dependencies,
            self.hashes_matches,
        )

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Rows buffered so far are flushed whether or not the with-body
        # raised. The connection must be closed even if that flush fails.
        try:
            self.flush()
        finally:
            self.close()

    def flush(self):
        """
        Flush tables buffers
        """
        for table in self._tables:
            table.flush()

    def close(self):
        """
        Close the underlying connection; buffered rows are not flushed here.
        """
        self.conn.close()
