"""
This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License,
or (at your option) any later version.


This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the GNU General Public License for more details.


You should have received a copy of the GNU General Public License
 along with this program.  If not, see <https://www.gnu.org/licenses/>.

Copyright © 2019 Cloud Linux Software Inc.

This software is also available under ImunifyAV commercial license,
see <https://www.imunify360.com/legal/eula>
"""
import asyncio
import logging
import time
import queue
import uuid
from collections import defaultdict
from collections.abc import Hashable
from contextlib import suppress

from defence360agent.api import inactivity
from defence360agent.contracts.messages import MessageType
from defence360agent.contracts.plugins import (
    MessageSink,
    MessageSource,
    expect,
)
from defence360agent.utils import batched, nice_iterator, recurring_check
from imav.malwarelib.config import VulnerabilityHitStatus
from imav.malwarelib.model import VulnerabilityHit
from imav.malwarelib.vulnerabilities.patcher import (
    PatchResult,
    VulnerabilityPatcher,
)
from imav.malwarelib.vulnerabilities.storage import PatchStorage

logger = logging.getLogger(__name__)


class PatchQueue:
    """Per-source queue of files waiting to be patched.

    Values put under the same key are unioned into one set, so each key is
    handed out by :meth:`get` at most once until it is put again.
    """

    def __init__(self):
        # key -> set of queued values; unseen keys start out as empty sets
        self._queue = defaultdict(set)

    def put(self, key: Hashable, values: set):
        """Merge *values* into the set already queued under *key*."""
        self._queue[key] |= values

    def get(self) -> tuple[Hashable, set]:
        """Remove and return one ``(key, values)`` pair.

        Pairs come out in LIFO order of the key's first insertion, as
        ``dict.popitem`` does.

        :raises queue.Empty: when nothing is queued.
        """
        try:
            entry = self._queue.popitem()
        except KeyError as missing:
            raise queue.Empty() from missing
        return entry

    def empty(self) -> bool:
        """Return True when no key is queued."""
        return len(self._queue) == 0


class Patch(MessageSink, MessageSource):
    """Patch files that have known vulnerability hits.

    ``VulnerabilityPatchTask`` messages are merged into a :class:`PatchQueue`
    keyed by ``(cause, initiator, manual)``.  A recurring task drains the
    queue.  For each batch of queued files it:

    * stores the originals via :class:`PatchStorage`;
    * runs the patcher once per file owner;
    * emits a ``VulnerabilityPatch`` message with the outcome.
    """

    # Maximum number of queued files handled by one DB query / patch pass.
    BATCH_SIZE = 10_000

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._queue = PatchQueue()
        self._loop = None
        self._sink = None
        self._patcher = None
        self._patch_task = None

    async def create_sink(self, loop):
        pass

    async def create_source(self, loop, sink):
        """Create the patcher and start the recurring queue-draining task."""
        self._loop = loop
        self._sink = sink
        self._patcher = VulnerabilityPatcher(loop=loop, sink=sink)
        self._patch_task = loop.create_task(self.recurring_patch())

    async def shutdown(self):
        """Cancel the recurring patch task and wait for it to finish."""
        if self._patch_task:
            self._patch_task.cancel()
            with suppress(asyncio.CancelledError):
                await self._patch_task

    @expect(MessageType.VulnerabilityPatchTask)
    async def process_patch_task(self, message: dict):
        """Queue the files from a ``VulnerabilityPatchTask`` for patching.

        Files from tasks with the same ``(cause, initiator, manual)`` are
        merged, so they are patched in a single pass.
        """
        source = (
            message.get("cause"),
            message.get("initiator"),
            message.get("manual", False),
        )
        # treat a missing or None ``filelist`` as "nothing to patch"
        files_to_patch = message.get("filelist") or []
        self._queue.put(source, set(files_to_patch))

    async def _patch_vulnerabilities(self):
        """Drain the queue, patching every queued file batch by batch."""
        while not self._queue.empty():
            (cause, initiator, manual), files_to_patch = self._queue.get()
            # Only hits in these statuses are picked up; a manual request
            # may also re-apply a patch that was previously reverted.
            statuses = [VulnerabilityHitStatus.VULNERABLE]
            if manual:
                statuses.append(VulnerabilityHitStatus.REVERTED)
            for files_batch in batched(files_to_patch, n=self.BATCH_SIZE):
                with inactivity.track.task("patch_vulnerabilities"):
                    await self._patch_batch(
                        files_batch, statuses, cause, initiator
                    )

    async def _patch_batch(self, files_batch, statuses, cause, initiator):
        """Store originals of, then patch, the matching hits in a batch."""
        query = VulnerabilityHit.select().where(
            VulnerabilityHit.orig_file.in_(files_batch),
            VulnerabilityHit.status.in_(statuses),
        )
        # Only hits whose original was stored successfully get patched.
        succeeded, failed, not_exist = await PatchStorage.store_all(query)
        if failed:
            for hit in failed:
                await self._sink.process_message(
                    MessageType.VulnerabilityPatchFailed(
                        message=(
                            "Failed to store the original from {}"
                            " to {}".format(hit.orig_file, PatchStorage.path)
                        ),
                        timestamp=int(time.time()),
                    )
                )
        if not_exist:
            # store_all reported these files as non-existent: drop the hits
            VulnerabilityHit.delete_hits([hit.orig_file for hit in not_exist])
        hits_by_owner = VulnerabilityHit.group_by_attribute(
            succeeded,
            attribute="owner",
        )
        for user, user_hits in hits_by_owner.items():
            await self._patch_user_hits(user, user_hits, cause, initiator)

    async def _patch_user_hits(self, user, hits, cause, initiator):
        """Patch the files of *hits* owned by *user* and report the result."""
        started = time.time()
        files = [hit.orig_file for hit in hits]
        # update status to avoid any races
        VulnerabilityHit.set_status(
            hits, VulnerabilityHitStatus.PATCH_IN_PROGRESS
        )
        try:
            result, error, cmd = await self._patcher.start(user, files)
        except Exception:
            # This plugin only selects VULNERABLE/REVERTED hits, so without
            # a rollback these hits would stay PATCH_IN_PROGRESS for good.
            # (Task cancellation is a BaseException and is not caught here.)
            VulnerabilityHit.set_status(
                hits, VulnerabilityHitStatus.VULNERABLE
            )
            raise
        await self._sink.process_message(
            MessageType.VulnerabilityPatch(
                hits=hits,
                result=result,
                cleanup_id=uuid.uuid4().hex,
                started=started,
                error=error,
                cause=cause,
                initiator=initiator,
                args=cmd,
            )
        )

    @recurring_check(1)
    async def recurring_patch(self):
        """Drain the patch queue whenever it has pending work."""
        if not self._queue.empty():
            await self._patch_vulnerabilities()


class PatchResultProcessor(MessageSink):
    """Persist the outcome of a ``VulnerabilityPatch`` run.

    Every hit in the message ends up in exactly one bucket:

    * patched: status ``PATCHED``, with ``patched_at`` set to now;
    * missing: the result says the file does not exist and it really is
      gone, so the hit is deleted;
    * unable to patch: status reset to ``VULNERABLE``.  This covers hits
      absent from the result, hits reported missing while the file still
      exists, and hits not explicitly reported as patched.
    """

    async def create_sink(self, loop):
        pass

    @staticmethod
    def _set_hit_status(
        hits: list[VulnerabilityHit], status: str, patched_at=None
    ):
        """Set *status*/*patched_at* in the DB and on the hit objects."""
        VulnerabilityHit.set_status(hits, status, patched_at)
        for hit in hits:
            hit.status = status
            hit.patched_at = patched_at

    @expect(MessageType.VulnerabilityPatch)
    async def process_patch_result(self, message: dict):
        hits: list[VulnerabilityHit] = message["hits"]
        result: PatchResult = message["result"]
        now = time.time()

        # hits the patcher did not report on could not be patched
        unable_to_patch = [hit for hit in hits if hit not in result]
        processed = [hit for hit in hits if hit in result]
        patched, not_exist = [], []
        # Classify each processed hit exactly once, so that deleted hits are
        # not also given a status and no hit is updated twice.
        async for hit in nice_iterator(processed, chunk_size=100):
            hit_result = result[hit]
            if hit_result.not_exist():
                # procu2.php reports a file it cannot access (e.g. a user
                # file under a root-owned dir -> 'Permission denied') as
                # non-existent, which confuses users.  Only trust that when
                # the file is really gone; otherwise it is unpatchable.
                if hit.orig_file_path.exists():
                    unable_to_patch.append(hit)
                else:
                    not_exist.append(hit)
            elif hit_result.is_patched():
                patched.append(hit)
            else:
                # treat as failed unless success is explicitly stated
                unable_to_patch.append(hit)

        if not_exist:
            VulnerabilityHit.delete_hits([hit.orig_file for hit in not_exist])
        if patched:
            self._set_hit_status(patched, VulnerabilityHitStatus.PATCHED, now)
        if unable_to_patch:
            self._set_hit_status(
                unable_to_patch, VulnerabilityHitStatus.VULNERABLE
            )
