# coding=utf-8
#
# Copyright © Cloud Linux GmbH & Cloud Linux Software, Inc 2010-2019 All Rights Reserved
#
# Licensed under CLOUD LINUX LICENSE AGREEMENT
# http://cloudlinux.com/docs/LICENSE.TXT

from collections import defaultdict
from lvestats.core.plugin import LveStatsPlugin
from lvestats.plugins.generic.analyzers import LVEUsage, FAULTS


class AggregatedLveUsage(LVEUsage):
    """
    Accumulator that folds a series of LVEUsage samples into one record.

    ``add()`` sums the usage fields weighted by each sample's ``time``, sums
    the fault counters, keeps the maximum of every limit and ORs the
    "changed" flags.  ``aggregate()`` then turns the weighted sums into
    time-weighted averages; it divides in place, so call it once.
    """
    __slots__ = ()

    def __init__(self, lve_version=8):
        LVEUsage.__init__(self, lve_version)
        # total weight (sum of sample times) used by aggregate()
        self.time = 0.0

    def add(self, lveusage):
        """
        Merge one sample into the running totals.

        :param LVEUsage lveusage: sample to merge; when its ``time`` is falsy
            (0 or None) it adds nothing to the weighted usage sums, but its
            faults, limits and flags are still merged
        """
        weight = lveusage.time
        self.time += weight or 0

        if weight:
            for name in ('cpu_usage', 'mem_usage', 'mep', 'io_usage',
                         'iops', 'memphy', 'nproc'):
                total = getattr(self, name)
                total += getattr(lveusage, name) * weight
                setattr(self, name, total)

        # iops_fault is deliberately NOT summed here (LVES-159); aggregate()
        # derives it from the averaged IOPS instead.
        for name in ('cpu_fault', 'mem_fault', 'mep_fault', 'io_fault',
                     'memphy_fault', 'nproc_fault'):
            total = getattr(self, name)
            total += getattr(lveusage, name)
            setattr(self, name, total)

        for name in ('lmem', 'lcpu', 'lep', 'io', 'lmemphy', 'lnproc', 'liops'):
            setattr(self, name, max(getattr(lveusage, name), getattr(self, name)))

        for name in ('has_changed_limits', 'has_changed_nproc'):
            setattr(self, name, getattr(self, name) or getattr(lveusage, name))

    def aggregate(self):
        """
        Convert the time-weighted sums into averages over the total time.

        Does nothing when no sample contributed a positive time.  Also sets
        ``iops_fault`` (LVES-159): 1 when the averaged IOPS reached a positive
        IOPS limit, otherwise 0.
        """
        duration = self.time
        if not duration > 0:
            return

        for name in ('cpu_usage', 'mem_usage', 'mep', 'io_usage',
                     'iops', 'memphy', 'nproc'):
            total = getattr(self, name)
            total /= duration
            setattr(self, name, total)

        # LVES-159: a single IOPS fault when aIOPS >= lIOPS (and lIOPS is set)
        self.iops_fault = 1 if self.iops >= self.liops > 0 else 0


class LveUsageAggregator5S(LveStatsPlugin):
    """
    Plugin that collapses raw per-iteration usage samples into one
    AggregatedLveUsage per LVE id.

    Reads the list of ``{lve_id: usage}`` mappings stored under
    ``lve_data[self.get_data_from]`` and publishes the filtered aggregates as
    ``lve_data[self.set_data_to]``.  When snapshots are enabled it also
    accumulates per-LVE fault counters in ``lve_data['faults']``.
    """

    def __init__(self):
        self.get_data_from = 'lve_usages_5s'
        self.set_data_to = 'lve_usage_5s'
        self.snapshots_enabled = True

    def set_config(self, config):
        """
        Read plugin options.

        Snapshots are disabled only when ``disable_snapshots`` equals "true"
        (case-insensitive, surrounding whitespace ignored).  Non-string
        values are converted with ``str()`` first, so ``True`` also disables
        them instead of raising AttributeError.

        :param dict config:
        """
        disable_snapshots = str(config.get('disable_snapshots', "false")).strip().lower()
        self.snapshots_enabled = disable_snapshots != "true"

    def dict_sum_values(self, dict1, dict2):
        """
        Return a new dict with the values of both dicts summed per key.

        :param dict dict1:
        :param dict dict2:
        :rtype: dict
        """
        result = {}
        for source in (dict1, dict2):
            for attr_name, count in source.items():
                result[attr_name] = result.get(attr_name, 0) + count
        return result

    def aggregate_faults(self, lve_data, lve_usage):
        """
        Add the positive fault counters of every aggregated usage to
        ``lve_data['faults']`` (``{lve_id: {fault_name: count}}``).

        No-op when snapshots are disabled.  An existing ``faults`` mapping is
        updated in place; otherwise a new one is created.

        :type lve_data: dict
        :type lve_usage: dict[int, AggregatedLveUsage]
        """
        if not self.snapshots_enabled:
            return

        faults = lve_data.get('faults', {})
        for lve_id, usage in lve_usage.items():
            for attr_name in FAULTS:
                fault = getattr(usage, attr_name)
                if fault > 0:
                    # only LVEs that actually faulted get an entry
                    per_lve = faults.setdefault(lve_id, {})
                    per_lve[attr_name] = per_lve.get(attr_name, 0) + fault
        lve_data['faults'] = faults

    def execute(self, lve_data):
        """
        Aggregate the raw samples and publish the result.

        :param dict lve_data: shared plugin data; must contain
            ``self.get_data_from``
        """
        aggregated = defaultdict(AggregatedLveUsage)
        for iteration_data in lve_data[self.get_data_from]:
            for lve_id, lve_usage in iteration_data.items():
                aggregated[lve_id].add(lve_usage)

        for aggregated_lve_usage in aggregated.values():
            aggregated_lve_usage.aggregate()

        # LVE id 0 is always kept; any other LVE only when it has interesting
        # values or its limits / nproc changed.
        result = {
            lve_id: usage
            for lve_id, usage in aggregated.items()
            if lve_id == 0
            or usage.has_interesting_values()
            or usage.has_changed_limits
            or usage.has_changed_nproc
        }

        lve_data[self.set_data_to] = result
        self.aggregate_faults(lve_data, result)
        # FIXME(vlebedev): Dirty hack to get raw usage measurement in plugins after this one.
        # The raw samples are intentionally NOT cleared here:
        # lve_data[self.get_data_from] = []


class LveUsageAggregator(LveUsageAggregator5S):
    """
    Periodic variant of LveUsageAggregator5S.

    Aggregates ``lve_data['lve_usages']`` into ``lve_data['lve_usage']``,
    never accumulates fault counters, and empties the consumed raw samples
    once aggregation is done.
    """

    def __init__(self):  # pylint: disable=super-init-not-called
        # The parent __init__ is skipped on purpose, so snapshots_enabled is
        # not set.  In this class it would only be read by aggregate_faults,
        # which is overridden below as a no-op.
        self.get_data_from = 'lve_usages'
        self.set_data_to = 'lve_usage'
        self.period = 60

    def set_config(self, config):
        """
        Override the period with ``aggregation_period`` when configured.

        NOTE(review): the value is stored exactly as provided (possibly a
        string); confirm how the plugin runner consumes ``period``.

        :param dict config:
        """
        self.period = config.get('aggregation_period', self.period)

    def aggregate_faults(self, lve_data, lve_usage):
        """Do nothing: fault counters are left to LveUsageAggregator5S."""

    # FIXME(vlebedev): Dirty hack to get raw usage measurement in plugins after this one.
    def execute(self, lve_data):
        """
        Run the parent aggregation, then reset the raw sample list.

        :param dict lve_data:
        """
        super().execute(lve_data)
        lve_data[self.get_data_from] = []

