"""
This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License,
or (at your option) any later version.


This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the GNU General Public License for more details.


You should have received a copy of the GNU General Public License
 along with this program.  If not, see <https://www.gnu.org/licenses/>.

Copyright © 2019 Cloud Linux Software Inc.

This software is also available under ImunifyAV commercial license,
see <https://www.imunify360.com/legal/eula>
"""
import hashlib
import logging
import time
from contextlib import suppress
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import List, Optional

from defence360agent.contracts.config import Malware as Config
from defence360agent.contracts.config import UserType
from defence360agent.internals.the_sink import TheSink
from defence360agent.utils import antivirus_mode, safe_fileops
from imav.malwarelib.cleanup.types import CleanupRevertInitiator
from imav.malwarelib.config import (
    MalwareHitStatus,
    MalwareScanResourceType,
)
from imav.malwarelib.model import MalwareHit
from imav.malwarelib.scan.crontab import is_crontab
from imav.malwarelib.scan.mds import restore as mds_restore
from imav.malwarelib.subsys.malware import MalwareAction
from imav.malwarelib.utils import hash_path

logger = logging.getLogger(__name__)


@dataclass
class RestoreReport:
    """
    Describe one file that was reverted from the cleanup storage.

    Numeric fields default to ``-1`` and hash fields to ``""``; a field keeps
    its default when the corresponding value could not be collected (for
    example, the file was missing when it was inspected).
    """

    # Which file was restored, from which scan, and on whose behalf.
    file: str
    scan_id: str
    owner: str
    initiator: str = UserType.ROOT
    # Timestamps: mtime of the stored copy and wall-clock time of the revert.
    cleaned_at: float = -1
    reverted_at: float = -1
    # SHA-256 hex digests of the on-disk file around the revert.
    hash_before_revert: str = ""
    hash_after_revert: str = ""
    # mtime of the on-disk file around the revert.
    mtime_before_revert: float = -1
    mtime_after_revert: float = -1
    # Size in bytes of the on-disk file around the revert.
    size_before_revert: float = -1
    size_after_revert: float = -1

    def to_dict(self, *, dict_factory=dict):
        """Return this report as a plain mapping via ``dataclasses.asdict``."""
        return asdict(self, dict_factory=dict_factory)


class CleanupStorage:
    """
    Store files before cleanup and restore them by request.

    Backups are kept flat inside ``path``; each backup is named by
    ``hash_path`` of the original file path (see ``storage_name``).
    """

    path: Path = Path(Config.CLEANUP_STORAGE)

    # Read size used when hashing files for restore reports, so that large
    # files are streamed instead of being loaded into memory at once.
    _HASH_CHUNK_SIZE = 64 * 1024

    @staticmethod
    async def _copy(src: Path, dst: Path, safe_src=False, safe_dst=False):
        """
        Copy *src* over *dst*, leaving *src* in place.

        Delegates to ``safe_fileops.safe_move`` with ``src_unlink=False`` and
        ``dst_overwrite=True``; ``safe_src``/``safe_dst`` are passed through.
        """
        await safe_fileops.safe_move(
            str(src),
            str(dst),
            src_unlink=False,
            dst_overwrite=True,
            safe_src=safe_src,
            safe_dst=safe_dst,
        )

    @classmethod
    def _sha256(cls, path: Path) -> str:
        """Return the hex SHA-256 digest of *path*, read in fixed-size chunks."""
        digest = hashlib.sha256()
        with path.open("rb") as f:
            while chunk := f.read(cls._HASH_CHUNK_SIZE):
                digest.update(chunk)
        return digest.hexdigest()

    @classmethod
    def storage_name(cls, filename: str) -> str:
        """
        Get file name for cleanup storage
        :return: file name
        """
        return hash_path(filename)

    @classmethod
    def get_hit_store_path(cls, hit):
        """Return the location of the backup for *hit*'s original file."""
        return cls.path / cls.storage_name(hit.orig_file)

    @classmethod
    async def store(cls, hit):
        """
        Copy *hit*'s original file into the storage.

        ``safe_src`` is enabled only for crontab files (per ``is_crontab``);
        the destination inside the storage is always handled with
        ``safe_dst=True``.
        """
        src = hit.orig_file_path
        dst = cls.get_hit_store_path(hit)
        safe_src = is_crontab(src)
        await cls._copy(src, dst, safe_src=safe_src, safe_dst=True)

    @classmethod
    async def store_all(cls, hits):
        """
        Back up the original file of every hit before cleanup.

        A missing source file is collected in ``not_exist``; any other
        ``OSError``/``UnsafeFileOperation`` is logged, collected in ``failed``
        and reported via ``MalwareAction.cleanup_failed_store``.

        :return: ``(succeeded, failed, not_exist)`` sets of hits
        """
        # Plain mkdir + suppress instead of exists()/mkdir(): the directory may
        # be created by someone else in between, which must not abort the
        # whole batch.
        with suppress(FileExistsError):
            cls.path.mkdir(0o700)

        succeeded, not_exist, failed = set(), set(), set()

        for hit in hits:
            try:
                await cls.store(hit)
                succeeded.add(hit)
            except FileNotFoundError:
                not_exist.add(hit)
            except (OSError, safe_fileops.UnsafeFileOperation) as e:
                logger.warning(
                    "Failed to store file before cleanup: %r -- %s",
                    str(hit),
                    e,
                )
                failed.add(hit)
                await MalwareAction.cleanup_failed_store(
                    path=hit.orig_file,
                    file_owner=hit.owner,
                    file_user=hit.user,
                    signature_id=hit.signature_id,
                )

        return succeeded, failed, not_exist

    @classmethod
    async def restore(cls, hit: MalwareHit) -> RestoreReport:
        """
        Copy the stored backup of *hit* back over its original location.

        The destination's mtime, size and SHA-256 are recorded right before
        and right after the copy; a value keeps its ``RestoreReport`` default
        when the file does not exist at that moment.

        Errors from the copy itself propagate to the caller (``restore_all``
        handles ``OSError`` and ``safe_fileops.UnsafeFileOperation``).

        :return: report describing the revert
        """
        # NOTE(review): ``owner`` is filled from ``hit.user`` rather than
        # ``hit.owner`` -- confirm this is intended.
        report = RestoreReport(hit.orig_file, hit.scanid_id, hit.user)

        src = cls.get_hit_store_path(hit)
        dst = hit.orig_file_path
        safe_dst = is_crontab(dst)

        # The backup's mtime is when it was stored, i.e. the cleanup time.
        with suppress(FileNotFoundError):
            report.cleaned_at = src.stat().st_mtime

        with suppress(FileNotFoundError):
            st_before = dst.stat()
            report.mtime_before_revert = st_before.st_mtime
            report.size_before_revert = st_before.st_size
            report.hash_before_revert = cls._sha256(dst)

        await cls._copy(src, dst, safe_src=True, safe_dst=safe_dst)

        report.reverted_at = time.time()

        with suppress(FileNotFoundError):
            st_after = dst.stat()
            report.mtime_after_revert = st_after.st_mtime
            report.size_after_revert = st_after.st_size
            report.hash_after_revert = cls._sha256(dst)

        return report

    @classmethod
    async def restore_all(
        cls, hits: List[MalwareHit], initiator: Optional[str] = None
    ):
        """
        Restore every hit from the storage and report each outcome.

        Successful restores are reported via
        ``MalwareAction.cleanup_restored_original`` and failures via
        ``MalwareAction.cleanup_failed_restore``.

        :param initiator: forwarded to the success event (``""`` if None)
        :return: ``(succeeded, failed)`` sets of hits
        """
        succeeded, failed = set(), set()

        for hit in hits:
            try:
                report = await cls.restore(hit)
                await MalwareAction.cleanup_restored_original(
                    path=hit.orig_file,
                    file_owner=hit.owner,
                    file_user=hit.user,
                    signature_id=hit.signature_id,
                    initiator=initiator or "",
                    report=report,
                )
                succeeded.add(hit)
            except (OSError, safe_fileops.UnsafeFileOperation) as e:
                await MalwareAction.cleanup_failed_restore(
                    path=hit.orig_file,
                    file_owner=hit.owner,
                    file_user=hit.user,
                    signature_id=hit.signature_id,
                )
                logger.warning("Failed to restore file: %r -- %s", str(hit), e)
                failed.add(hit)

        return succeeded, failed

    @classmethod
    async def _clear(cls, path: Path, keep: float) -> bool:
        """
        Remove *path* if its mtime is older than *keep*.

        :return: True if the file was removed by this call
        """
        try:
            if path.stat().st_mtime >= keep:
                return False
            path.unlink()
        except FileNotFoundError:
            # The file vanished after iterdir() listed it (e.g. a concurrent
            # clear); nothing left to remove.
            return False
        return True

    @classmethod
    async def clear(cls, keep: float) -> int:
        """
        Clear storage
        :param keep: keep files modified at or after this timestamp
        :return: number of removed files
        """
        cls.path.mkdir(0o700, exist_ok=True)
        cleared = 0
        for path in cls.path.iterdir():
            if await cls._clear(path, keep):
                cleared += 1

        return cleared


async def restore_hits(
    hits: list[MalwareHit],
    sink: TheSink,
    initiator: CleanupRevertInitiator | None = None,
):
    """
    Revert the cleanup of *hits*.

    Only file-type hits are restored from ``CleanupStorage``; the ones that
    were restored successfully get ``MalwareHitStatus.FOUND`` again. When
    antivirus mode is disabled, the complete *hits* list is additionally
    handed to the MDS restore.

    :return: ``(succeeded, failed)`` from the file-storage restore only
    """

    def _is_file_hit(candidate):
        return (
            candidate.resource_type == MalwareScanResourceType.FILE.value
        )

    file_hits = list(filter(_is_file_hit, hits))
    restored, not_restored = await CleanupStorage.restore_all(
        file_hits, initiator
    )
    MalwareHit.set_status(restored, MalwareHitStatus.FOUND)

    if antivirus_mode.disabled:
        await mds_restore.restore_hits(hits, sink, initiator)

    # FIXME: we cannot include db hits here
    return restored, not_restored
