"""
This program is free software: you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License,
or (at your option) any later version.


This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the GNU General Public License for more details.


You should have received a copy of the GNU General Public License
 along with this program.  If not, see <https://www.gnu.org/licenses/>.

Copyright © 2019 Cloud Linux Software Inc.

This software is also available under ImunifyAV commercial license,
see <https://www.imunify360.com/legal/eula>


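Generic Sensor plugin

- Creates a listening Unix domain socket at GenericSensor.SOCKET_PATH
  (config value GENERIC_SENSOR_SOCKET_PATH)
- Expects alert data as JSON objects, one per line, each carrying a
  "method" key that selects the message type (see Protocol.METHOD2MSGTYPE)
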
"""

import asyncio
import base64
import codecs
import json
import os
import time
from logging import getLogger

from imav.contracts.messages import MSGS_WITHOUT_IP
from defence360agent.contracts.config import (
    SimpleRpc,
    GENERIC_SENSOR_SOCKET_PATH,
)
from defence360agent import files
from defence360agent.contracts.messages import MessageType
from defence360agent.contracts.plugins import Sensor
from defence360agent.internals.global_scope import g
from defence360agent.internals.logger import getNetworkLogger
from defence360agent.simple_rpc import RpcServerAV
from defence360agent.utils import Scope
from defence360agent.utils.buffer import LineBuffer

# Module logger for regular diagnostics, plus a dedicated logger for
# per-connection network traffic.
logger = getLogger(__name__)
network_logger = getNetworkLogger(__name__)


class Protocol(asyncio.Protocol):
    """Per-connection protocol for the generic sensor socket.

    Incoming bytes are decoded as UTF-8 and accumulated in a ``LineBuffer``.
    Every non-empty line it yields is parsed as JSON. A decoded JSON object
    is stamped with the receive time, and its ``method`` key is mapped to a
    message class through :attr:`METHOD2MSGTYPE`. The resulting message is
    scheduled on the event loop via ``sink.process_message``. Malformed input
    is logged and dropped instead of tearing down the connection.
    """

    # Wire-level ``method`` name -> message class built from the payload.
    # A method missing from this table is rejected by ``_process_msg``.
    METHOD2MSGTYPE = {
        "NOOP": MessageType.Noop,
        "MALWARE_SCAN": MessageType.MalwareScan,
        "MALWARE_SCAN_TASK": MessageType.MalwareScanTask,
        "MALWARE_SCAN_COMPLETE": MessageType.MalwareScanComplete,
        "MALWARE_CLEAN_COMPLETE": MessageType.MalwareCleanComplete,
        "MALWARE_RESTORE_COMPLETE": MessageType.MalwareRestoreComplete,
        "MALWARE_CHECK_DETACHED_SCANS": MessageType.CheckDetachedScans,
        "MALWARE_SEND_FILES": MessageType.MalwareSendFiles,
    }

    def __init__(self, loop, sink, *_):
        """Set up per-connection state.

        :param loop: event loop on which ``sink.process_message`` tasks are
            created
        :param sink: object providing the ``process_message`` coroutine
        Any further positional arguments are accepted and ignored.
        """
        self._loop = loop
        self._sink = sink
        self._line_buffer = LineBuffer()
        # A stream transport may deliver a multi-byte UTF-8 character split
        # across two data_received() calls. The incremental decoder holds
        # the incomplete tail until the rest arrives, instead of raising
        # UnicodeDecodeError on an otherwise valid stream.
        self._decoder = codecs.getincrementaldecoder("utf-8")()
        self.transport = None

    def connection_made(self, transport):
        """Remember *transport* so some handlers can be given the connection."""
        self.transport = transport
        network_logger.debug("Connection made")

    def data_received(self, data):
        """Decode *data*, buffer it and process every complete line."""
        msgs = self._decoder.decode(data)
        if not msgs:
            # Only part of a multi-byte character arrived; the decoder keeps
            # it until the next chunk completes it.
            return
        if not msgs.strip():
            # NOTE(review): a whitespace-only chunk (e.g. a lone newline
            # terminating a line sent in an earlier chunk) is discarded here
            # without reaching the line buffer; confirm this is intended.
            logger.error("Empty message received <%s>", msgs)
            return

        self._line_buffer.append(msgs)
        for msg in self._line_buffer:
            if not msg:
                continue
            network_logger.debug("data_received: %r", msg)
            tokens = self._parse_msg(msg)
            if tokens:
                tokens["timestamp"] = time.time()
                self._process_msg(tokens)

    def _parse_msg(self, msg):
        """Decode one line as JSON.

        :param msg: a single line of text
        :return: the decoded JSON object (a ``dict``), or ``None`` when the
            line is not valid JSON or is valid JSON of a non-object type
        """
        try:
            tokens = json.loads(msg)
        except json.JSONDecodeError:
            logger.exception("data_received(%s): unable to decode", repr(msg))
            return None
        if not isinstance(tokens, dict):
            # Lists, strings and numbers are valid JSON, but the caller
            # item-assigns into the result, which would raise TypeError.
            logger.error(
                "data_received(%s): expected a JSON object", repr(msg)
            )
            return None
        return tokens

    def _process_msg(self, tokens):
        """Turn a decoded JSON object into a message and hand it to the sink.

        :param tokens: decoded JSON object; its ``method`` key must be present
            in :attr:`METHOD2MSGTYPE`. Some methods mutate it in place
            before the message is built. Invalid input is logged and dropped.
        """
        # map 'method' to appropriate Message type
        try:
            method = tokens["method"]
            msgtype = self.METHOD2MSGTYPE[method]
        except (KeyError, TypeError) as e:
            # TypeError: 'method' holds an unhashable JSON value (list/object)
            logger.error(
                "data_received(%s): Wrong or missing 'method' [%s]",
                repr(tokens),
                repr(e),
            )
            return

        # Only MALWARE_SCAN_TASK is in this class's METHOD2MSGTYPE; the other
        # named branches are reachable only when the table maps those methods
        # (presumably in a subclass).
        if method == "MALWARE_SCAN_TASK":
            # File names arrive base64-encoded; turn them back into str paths.
            try:
                tokens["filelist"] = [
                    os.fsdecode(base64.b64decode(f))
                    for f in tokens["filelist"]
                ]
            except (KeyError, TypeError, ValueError) as error:
                # Missing/non-iterable filelist or invalid base64
                # (binascii.Error is a ValueError subclass).
                logger.error(
                    "Error during processing %s: %s", method, str(error)
                )
                return
        elif method == "SYNCLIST":  # added to avoid using pyrasite
            logger.info("Received test SynclistResponse")
        elif method in (
            "IP_LISTS_UPDATE",
            "BLOCKED_PORT_UPDATE",
            "BLOCKED_PORT_IP_UPDATE",
            "HEALTH",
        ):
            # The handler receives the live transport (presumably to reply).
            tokens["transport"] = self.transport
            logger.info("Received %s", method)
        elif method in (
            "WHITELIST_CACHE_UPDATE",
            "IPSET_UPDATE",
            "UPDATE_RULES",
            "UPDATE_CUSTOM_LISTS",
        ):
            logger.info("Received %s", method)
        elif method == "FILES_UPDATE":
            try:
                index = files.Index(tokens["files_type"])
            except files.IntegrityError as error:
                logger.error(
                    "Error during processing %s: %s", method, str(error)
                )
            else:
                self._loop.create_task(
                    self._sink.process_message(
                        msgtype(tokens["files_type"], index)
                    )
                )
            return
        elif not tokens.get("attackers_ip") and method not in MSGS_WITHOUT_IP:
            logger.error(
                "Method type is %s but empty or no <attackers_ip> "
                "in message <%s>",
                method,
                tokens,
            )
            return
        self._loop.create_task(self._sink.process_message(msgtype(tokens)))

    def connection_lost(self, transport):
        """Forget the transport once the connection is closed.

        asyncio passes the exception that closed the connection (or ``None``
        on a regular close/EOF) as the only argument. The parameter keeps its
        historical name ``transport`` for compatibility.
        """
        self.transport = None
        network_logger.debug("Disconnected")


class GenericSensor(Sensor):
    """Sensor plugin that listens for alerts on a Unix domain socket.

    The socket is bound at :attr:`SOCKET_PATH`, and every accepted connection
    is served by an instance of :attr:`PROTOCOL_CLASS`.
    """

    SOCKET_PATH = GENERIC_SENSOR_SOCKET_PATH
    PROTOCOL_CLASS = Protocol
    SCOPE = Scope.AV

    async def create_sensor(self, loop, sink):
        """Start the listening server, publish it in ``g`` and return it.

        When ``SimpleRpc.SOCKET_ACTIVATION`` is set, the server is created
        through an ``RpcServerAV`` subclass bound to this sensor's socket path
        and protocol. Otherwise a plain asyncio Unix server is bound at
        ``SOCKET_PATH`` after any leftover file at that path is removed.

        :param loop: event loop to serve on
        :param sink: passed to every protocol instance
        :return: the created server object (also stored in
            ``g.sensor_server``)
        """
        if SimpleRpc.SOCKET_ACTIVATION:

            class GenericSensorSocket(RpcServerAV):
                SOCKET_PATH = self.SOCKET_PATH
                PROTOCOL_CLASS = self.PROTOCOL_CLASS

            server = await GenericSensorSocket.create(loop, sink)
        else:
            # FIXME: make sure root can write to the socket directory
            # (the original note was truncated; confirm intent).
            os.makedirs(os.path.dirname(self.SOCKET_PATH), exist_ok=True)
            # A Unix socket cannot be bound to a path that already exists,
            # so remove a leftover file. EAFP avoids the exists()/unlink()
            # race and also removes a dangling symlink, which exists()
            # would report as absent.
            try:
                os.unlink(self.SOCKET_PATH)
            except FileNotFoundError:
                pass

            server = await loop.create_unix_server(
                lambda: self.PROTOCOL_CLASS(loop, sink), self.SOCKET_PATH
            )
        g.sensor_server = server
        return server
