"""
This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License,
or (at your option) any later version.


This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the GNU General Public License for more details.


You should have received a copy of the GNU General Public License
 along with this program.  If not, see <https://www.gnu.org/licenses/>.

Copyright © 2019 Cloud Linux Software Inc.

This software is also available under ImunifyAV commercial license,
see <https://www.imunify360.com/legal/eula>
"""
import glob
import os
import re
from typing import Generator, Union

from defence360agent.contracts.messages import MessageType
from defence360agent.contracts.plugins import MessageSink, expect
from defence360agent.utils import nice_iterator
from imav.malwarelib.config import MalwareScanType, VulnerabilityHitStatus
from imav.contracts.plugins import ProcessOrder
from imav.malwarelib.model import MalwareScan, VulnerabilityHit


class StoreVulnerabilities(MessageSink):
    """
    Persist vulnerability hits reported in ``MalwareScan`` messages.

    For every started scan this sink first drops ``VulnerabilityHit``
    entries that may be outdated for the scanned path(s), then stores a new
    ``VulnerabilityHit`` for every reported file whose first hit is
    recognized by ``VulnerabilityHit.match``.
    """

    PROCESSING_ORDER = ProcessOrder.AFTER_STORE_SCAN

    async def create_sink(self, loop):
        """No setup is required for this sink."""
        pass

    @staticmethod
    def get_outdated_entries(
        path_obj: Union[str, list],
        scan_type: str | None = None,
    ) -> Generator[str, None, None]:
        """
        Return files that may already not be vulnerable, yet we still
        consider them such.

        For example, a vulnerable file might have been removed manually.

        :param path_obj: a single scanned path or a list of them; each entry
            is expanded with ``glob.iglob``.
        :param scan_type: type of the scan; for ``MalwareScanType.REALTIME``
            the given paths are yielded as-is without querying the database.
        :return: generator of file paths whose ``VulnerabilityHit`` entries
            should be considered outdated.
        """
        # NOTE: this logic was taken by analogy with StoreMalwareHits
        # consider optimizing this code
        paths = [path_obj] if isinstance(path_obj, str) else path_obj
        if scan_type == MalwareScanType.REALTIME:
            # to avoid duplicates (DEF-10404)
            yield from iter(paths)
            return
        # Statuses under which a file is still considered vulnerable.
        active_statuses = [
            VulnerabilityHitStatus.VULNERABLE,
            VulnerabilityHitStatus.REVERTED,
        ]
        for target_path in paths:
            for path in glob.iglob(target_path):
                path = os.path.realpath(path)
                # The conditions are separate where() arguments (AND-ed by
                # peewee): writing ``orig_file == path & status.in_(...)``
                # would parse as ``orig_file == (path & status.in_(...))``
                # because ``&`` binds tighter than ``==``.
                if (
                    os.path.isfile(path)
                    and VulnerabilityHit.select()
                    .where(
                        VulnerabilityHit.orig_file == path,
                        VulnerabilityHit.status.in_(active_statuses),
                    )
                    .first()
                ):
                    yield path
                else:
                    # Either a directory or a path that no longer exists:
                    # match the path itself and everything below it.
                    scanned_dir = re.escape(path) + r"(/.*|\b)"
                    yield from (
                        i.orig_file
                        for i in VulnerabilityHit.select().where(
                            VulnerabilityHit.orig_file.regexp(scanned_dir),
                            VulnerabilityHit.status.in_(active_statuses),
                        )
                    )

    def _delete_outdated_entries(self, summary: dict) -> None:
        """
        Delete possibly outdated hits for the path(s) covered by a scan.

        ``file_patterns`` and ``exclude_patterns`` are popped from *summary*
        in place. Deletion happens only when the scan reported no error and
        neither pattern was given, i.e. the scan covered ``summary["path"]``
        without filtering.

        :param summary: the ``summary`` part of a ``MalwareScan`` message;
            must contain ``"path"`` and ``"type"`` when deletion happens.
        """
        file_patterns = summary.pop("file_patterns", None)
        exclude_patterns = summary.pop("exclude_patterns", None)
        if (
            summary.get("error") is None
            and file_patterns is None
            and exclude_patterns is None
        ):
            outdated_entries = self.get_outdated_entries(
                summary["path"], scan_type=summary["type"]
            )
            VulnerabilityHit.delete_hits(outdated_entries)

    @expect(MessageType.MalwareScan)
    async def process_hits(self, message):
        """
        Store vulnerability hits from a ``MalwareScan`` message.

        Messages for scans that have not started, or that carry no
        ``results``, are ignored. If the summary includes ``"path"``,
        outdated entries for that path are removed first. Then a
        ``VulnerabilityHit`` with status ``VULNERABLE`` is created for each
        result whose first hit's ``matches`` is accepted by
        ``VulnerabilityHit.match``.
        """
        if not message["summary"].get("started") or message["results"] is None:
            # Scan is queued/aborted.
            return

        scan = MalwareScan.get(scanid=message["summary"]["scanid"])
        # get('path') indicates that this is the second message,
        # even if they are out of order
        if message["summary"].get("path") is not None:
            # keep the same logic as for malware hits
            self._delete_outdated_entries(message["summary"])
        if results := {
            filename: data
            for filename, data in message["results"].items()
            if VulnerabilityHit.match(data["hits"][0]["matches"])
        }:
            # TODO: handle possible races when we implement patch/revert
            async for filename, data in nice_iterator(results.items()):
                # the latest detection type is relevant
                VulnerabilityHit.create(
                    scanid=scan.scanid,
                    owner=data["owner"],
                    user=data["user"],
                    size=data["size"],
                    hash=data["hash"],
                    orig_file=filename,
                    type=data["hits"][0]["matches"],
                    timestamp=data["hits"][0]["timestamp"],
                    status=VulnerabilityHitStatus.VULNERABLE,
                )
