"""
This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License,
or (at your option) any later version.


This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the GNU General Public License for more details.


You should have received a copy of the GNU General Public License
 along with this program.  If not, see <https://www.gnu.org/licenses/>.

Copyright © 2019 Cloud Linux Software Inc.

This software is also available under ImunifyAV commercial license,
see <https://www.imunify360.com/legal/eula>
"""
import asyncio
import logging
import time
from concurrent.futures.thread import ThreadPoolExecutor
from functools import wraps
from tempfile import NamedTemporaryFile
from typing import Any, Dict
from uuid import uuid4

from defence360agent.contracts.hook_events import HookEvent
from defence360agent.subsys.panels import hosting_panel
from defence360agent.utils import encode_filename
from imav.contracts.config import Malware
from imav.malwarelib.config import MalwareScanType
from imav.malwarelib.utils.user_list import fill_results_owner

logger = logging.getLogger(__name__)

# Module-wide single-worker thread pool.
_executor = ThreadPoolExecutor(max_workers=1)

# Scanner options forwarded by PrepareFileList (and, as a subset of
# DIRECT_SCAN_OPTIONS, by DirectAiBolit).
SCAN_OPTIONS = [
    "intensity_cpu",
    "intensity_io",
    "intensity_ram",
    "detect_elf",
    "use_filters",
]

# DirectAiBolit forwards these on top of SCAN_OPTIONS.
DIRECT_SCAN_OPTIONS = SCAN_OPTIONS + [
    "follow_symlinks",
    "exclude_patterns",
    "file_patterns",
]

# Keyword arguments copied by ``event_hook`` into the ``scan_params`` of the
# MalwareScanningStarted hook event.
SCAN_HOOK_PARAMS = [
    "file_patterns",
    "follow_symlinks",
    "exclude_patterns",
    "intensity_cpu",
    "intensity_io",
    "intensity_ram",
]


class ScanResult:
    """Outcome of a single malware scan, renderable as a summary/results dict.

    Public attributes, filled in by the scan wrappers in this module:
      * ``scans``       -- raw scan records; ``get()`` groups them per file;
      * ``total_files`` -- count of file names written by PrepareFileList;
      * ``error`` / ``errors`` -- failure information for the summary;
      * ``args``        -- optional extra data echoed in the summary.
    """

    def __init__(self, path, scan_id, scan_type):
        self._path = path
        self._scan_id = scan_id
        self._scan_type = scan_type

        self.scans = []
        self.total_files = 0
        self.error = None
        self.errors = []
        self.args = None

        # Filled in later through set_start_stop().
        self._begin_time = None
        self._end_time = None
        # NOTE(review): not read by any code visible in this module.
        self._aggregated_results = {}

    def is_detached(self):
        """Return True if the scan type is background, on-demand or user."""
        detached_types = (
            MalwareScanType.BACKGROUND,
            MalwareScanType.ON_DEMAND,
            MalwareScanType.USER,
        )
        return self._scan_type in detached_types

    def set_start_stop(self, begin_time=None, end_time=None):
        """Store the start and/or end timestamp.

        Falsy arguments (``None``, ``0``) leave the stored value unchanged.
        """
        self._begin_time = begin_time or self._begin_time
        self._end_time = end_time or self._end_time

    def to_dict_initial(self):
        """Return ``{"summary": ..., "results": ...}`` without scan records.

        ``results`` is ``None`` for detached scan types and ``{}`` otherwise,
        because 'complete_scan' needs an empty (not null) result when a scan
        has no files.  ``summary["args"]`` is present only when ``args`` is
        truthy.
        """
        summary = {
            "scanid": self._scan_id,
            "type": self._scan_type,
            "path": self._path,
            "started": self._begin_time,
            "completed": self._end_time,
            "total_files": self.total_files,
            "error": self.error,
            "errors": self.errors,
        }
        results = None if self.is_detached() else {}
        if self.args:
            summary["args"] = self.args
        return {"summary": summary, "results": results}

    def to_dict(self):
        """Like ``to_dict_initial()``, with ``results`` set to ``self.scans``."""
        return {**self.to_dict_initial(), "results": self.scans}

    def _aggregate_result(self):
        """Replace ``self.scans`` with its per-file grouping and return self."""
        # NOTE(review): ``list(*self.scans)`` passes the items of self.scans as
        # positional arguments to list(), so it only succeeds when self.scans
        # holds at most one iterable -- confirm the shape stored here.
        flattened = list(*self.scans)
        self.scans = aggregate_result(flattened)
        return self

    async def get(self):
        """Aggregate the records, pass them to fill_results_owner, return self."""
        await fill_results_owner(self._aggregate_result().scans)
        return self


def aggregate_result(scans):
    """Group raw scan records by ``file_name``.

    Every record needs ``file_name``, ``signature``, ``suspicious`` and
    ``timestamp``.  Kept records also need ``size`` and ``hash``.
    ``extended_suspicious``, ``curable``, ``ctime``, ``modification_time``
    and ``ignore`` are optional.  Records with a truthy ``ignore`` are
    logged and dropped.

    Returns ``{file_name: {"hits": [...], "size", "hash", "ctime",
    "modification_time"}}``.  The file-level fields come from the first kept
    record for that file.  Non-suspicious hits precede suspicious ones.
    Each new non-suspicious hit goes to the very front, so those appear in
    reverse input order, while suspicious hits keep input order.
    """
    per_file: Dict[str, Any] = {}

    def _new_entry(rec):
        # Built for every kept record (setdefault evaluates its default
        # eagerly), so "size" and "hash" are required even for repeat files.
        return {
            "hits": [],
            "size": rec["size"],
            "hash": rec["hash"],
            "ctime": rec.get("ctime", 0),
            "modification_time": rec.get("modification_time", 0),
        }

    for rec in scans:
        hit = {
            "matches": rec["signature"],
            "suspicious": rec["suspicious"],
            "extended_suspicious": rec.get("extended_suspicious", False),
            "timestamp": rec["timestamp"],
        }
        if rec.get("ignore"):
            logger.info(
                "File match for %s will be ignored: %s",
                rec["file_name"],
                hit,
            )
            continue

        entry = per_file.setdefault(rec["file_name"], _new_entry(rec))
        if rec.get("curable"):
            hit["curable"] = True

        hits = entry["hits"]
        # Suspicious hits go to the back, everything else to the front.
        slot = len(hits) if rec["suspicious"] else 0
        hits.insert(slot, hit)

    return per_file


def event_hook(sink):
    """Build a decorator that wraps an async scan coroutine with hook events.

    The wrapper works as follows:
      * fills in ``scan_id`` (random hex) and ``started`` (current time) when
        they are missing or falsy;
      * when ``scan_type`` is set, emits ``HookEvent.MalwareScanningStarted``
        via ``sink.process_message``, with the SCAN_HOOK_PARAMS subset of the
        keyword arguments as ``scan_params``;
      * awaits the wrapped coroutine (``started`` is not forwarded to it).

    Cancellation propagates.  Any other exception raised by the wrapped
    coroutine is logged and replaced by an error result dict.
    """

    def _extract_scan_hook_params(kwargs):
        return {
            name: value
            for name, value in kwargs.items()
            if name in SCAN_HOOK_PARAMS
        }

    def wrap(f):
        @wraps(f)
        async def wrapper(
            path, scan_id=None, scan_type=None, started=None, **kwargs
        ):
            if not scan_id:
                scan_id = uuid4().hex
            if not started:
                started = time.time()
            hook_params = _extract_scan_hook_params(kwargs)

            if scan_type:
                await sink.process_message(
                    HookEvent.MalwareScanningStarted(
                        scan_id=scan_id,
                        scan_type=scan_type,
                        path=path,
                        started=started,
                        scan_params=hook_params,
                    )
                )

            # The error summary reports when the wrapped call began, which
            # may differ from the ``started`` value sent in the hook event.
            call_started = time.time()
            try:
                return await f(
                    path, scan_id=scan_id, scan_type=scan_type, **kwargs
                )
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                logger.exception("Scan wrapper task failed")
                return {
                    "summary": {
                        "scanid": scan_id,
                        "type": scan_type,
                        "path": path,
                        "total_files": 0,
                        "started": call_started,
                        "completed": time.time(),
                        "error": repr(exc),
                    },
                    "results": {},
                }

        return wrapper

    return wrap


class DirectAiBolit:
    """Decorator for scan coroutines that scan a path directly.

    The wrapped coroutine is called with ``None`` as its first argument,
    ``scan_path=path`` and the DIRECT_SCAN_OPTIONS subset of the keyword
    arguments.  It must return a ``(scans, error)`` pair, which is stored
    on a fresh ScanResult.
    """

    def __init__(self, *_, **__):
        """Accept and ignore any constructor arguments."""

    @staticmethod
    async def _add_db_dir(home_dir, scan_options):
        """Set ``scan_options["db_dir"]`` when Malware.RAPID_SCAN is enabled."""
        if not Malware.RAPID_SCAN:
            return
        panel = hosting_panel.HostingPanel()
        scan_options["db_dir"] = panel.get_rapid_scan_db_dir(home_dir)

    @staticmethod
    def _extract_scan_options(kwargs):
        """Return a new dict with only the DIRECT_SCAN_OPTIONS keys of kwargs."""
        return {
            name: value
            for name, value in kwargs.items()
            if name in DIRECT_SCAN_OPTIONS
        }

    @staticmethod
    def _update_scan_options(kwargs):
        """Comma-join non-None pattern lists in place and return kwargs."""
        for key in ("exclude_patterns", "file_patterns"):
            patterns = kwargs.get(key)
            if patterns is not None:
                kwargs[key] = ",".join(patterns)
        return kwargs

    def __call__(self, f):
        @wraps(f)
        async def wrapper(
            path, scan_id=None, scan_type=None, begin_time=None, **kwargs
        ):
            options = self._update_scan_options(
                self._extract_scan_options(kwargs)
            )
            # Only user and background scans get the rapid-scan db_dir option.
            if scan_type in (MalwareScanType.USER, MalwareScanType.BACKGROUND):
                await self._add_db_dir(path, options)

            result = ScanResult(path, scan_id, scan_type)
            result.set_start_stop(begin_time=begin_time)

            scans, error = await f(
                None,
                scan_type=scan_type,
                scan_id=scan_id,
                scan_path=path,
                **options,
            )
            result.error = error
            result.scans = list(*scans)
            return result

        return wrapper


class PrepareFileList:
    """Decorator for scan coroutines that consume a list file.

    The wrapper writes the file names passed as ``path`` to a
    NamedTemporaryFile in ``tmpdir``.  It then calls the wrapped coroutine
    with that open temp file and the SCAN_OPTIONS subset of the keyword
    arguments.  The coroutine must return a ``(scans, error)`` pair.
    """

    def __init__(self, tmpdir):
        self._tmpdir = tmpdir

    async def prepare_file(self, fname, files, **kwargs) -> int:
        """Write ``files`` to ``fname`` and return how many were written.

        Extra keyword arguments are accepted and ignored.  The write itself
        is synchronous.
        """
        return self._write_list_to_file(fname, files)

    @staticmethod
    def _write_list_to_file(fname, files):
        """Write each name (via encode_filename) to ``fname``; return the count."""
        written = 0
        with open(fname, "wb") as out:
            for written, name in enumerate(files, start=1):
                out.write(encode_filename(name))
        return written

    @staticmethod
    def _extract_scan_options(kwargs):
        """Return a new dict with only the SCAN_OPTIONS keys of kwargs."""
        return {
            name: value
            for name, value in kwargs.items()
            if name in SCAN_OPTIONS
        }

    def __call__(self, f):
        @wraps(f)
        async def wrapper(
            path, scan_id=None, scan_type=None, begin_time=None, **kwargs
        ):
            options = self._extract_scan_options(kwargs)
            result = ScanResult(path, scan_id, scan_type)
            result.set_start_stop(begin_time=begin_time)

            # The list file only exists for the duration of the scan call.
            with NamedTemporaryFile(dir=self._tmpdir) as list_file:
                count = await self.prepare_file(list_file.name, path, **kwargs)
                scans, error = await f(
                    list_file, scan_type=scan_type, scan_id=scan_id, **options
                )

            result.error = error
            result.total_files = count
            result.scans = list(*scans)
            return result

        return wrapper
