# coding=utf-8
#
# Copyright © Cloud Linux GmbH & Cloud Linux Software, Inc 2010-2019 All Rights Reserved
#
# Licensed under CLOUD LINUX LICENSE AGREEMENT
# http://cloudlinux.com/docs/LICENSE.TXT

import json
import logging
import math
import os
from collections import defaultdict
from datetime import datetime  # NOQA
from typing import Dict, List, Optional, Tuple, Union  # NOQA

import sqlalchemy  # NOQA

import lvestats.lib.commons.decorators
from clcommon import cpapi
from clcommon.cpapi.pluginlib import getuser
from clcommon.lib import GovernorStatus, MySQLGovernor, MySQLGovException
from lvestats.lib import config
from lvestats.lib.commons import dateutil, func
from lvestats.lib.commons.func import deserialize_lve_id, get_users_for_reseller
from lvestats.lib.commons.logsetup import setup_logging
from lvestats.lib.commons.users_manager import g_users_manager
from lvestats.lib.dbengine import make_db_engine
from lvestats.lib.info.lveinfomain import main_
from lvestats.lib.lve_list import LVEList
from lvestats.lib.lveinfolib import HistoryShowUnion
from lvestats.lib.lveinfolib_gov import HistoryShowDBGov
from lvestats.lib.parsers.cloudlinux_statistics_parser import setup_statistics_parser
from lvestats.lib.uidconverter import (
    username_to_uid_local,
    reseller_to_uid,
    uid_to_username,
)

# Placeholder emitted when a username, domain or reseller cannot be resolved.
NOT_AVAILABLE = "N/A"


class CloudLinuxStatistics(object):
    """
    LVE and MySQL Governor history statistics for cloudlinux-statistics.

    generate_statistics() queries both sources and aligns their rows; get_value()
    then builds one output item per aligned row (presumably consumed by LVEList -
    confirm against its implementation).
    """

    # Columns substituted when 'all' is requested. 'nproc' used to be listed twice,
    # which only made the LVE query fetch the same three nproc columns again.
    _ALL_COLUMNS = ('cpu', 'ep', 'vmem', 'pmem', 'nproc', 'io', 'iops', 'mysql')

    def __init__(
            self,
            engine,  # type: sqlalchemy.engine.base.Engine
            show_columns,  # type: List[str]
            order_by='',  # type: str
            server_id=None,  # type: Optional[str]
            by_usage=None,  # type: Optional[str]
            by_usage_percentage=0.9,  # type: float
            by_fault=None,  # type: Optional[str]
            threshold=1,  # type: int
            limit=0,  # type: int
            uid=None,  # type: Optional[Union[int, List[int], Tuple[int]]]
            time_unit=60  # type: int
    ):
        """
        Initializes statistics object used by SPA UI. Note, that if we set order_by to mysql_*, and set
        by_usage or by_fault, end result will not be ordered by mysql_ column, using uid for ordering instead

        :param uid: None or a list/tuple of ids selects "many uids" mode (one row per id);
            a single id selects per-period mode (one row per time_unit)
        :param limit: applied to the governor query when ordering by a mysql_* column,
            otherwise to the LVE query
        """
        self.order_by_lve = None
        self.order_by_gov = None
        self.server_id = server_id
        self.engine = engine
        self.limit_gov = 0
        self.limit_lve = 0
        self.uid = uid
        self.time_unit = time_unit
        self.is_many_uids = self.uid is None or isinstance(self.uid, (list, tuple))
        self._lve_results = []
        self._gov_results = []
        self._resellers_results = []

        if 'all' in show_columns:
            show_columns = list(self._ALL_COLUMNS)
        self.show_columns = [s.lower() for s in show_columns]
        self.by_fault = self.normalize_by_fault(by_fault)
        self.order_by = order_by
        self.init_order_by(order_by, self.by_fault)
        self._show_columns_lve = []
        self._show_columns_gov = []
        self._gov_key_index = None
        self._lve_key_index = None
        self.init_show_columns()
        # the limit goes to the query whose column decides the output order
        if self.order_by_gov:
            self.limit_gov = limit
        else:
            self.limit_lve = limit
        self.limit = limit
        self.by_usage_lve = None
        self.by_usage_gov = None
        self.init_by_usage(by_usage)
        self.by_usage_percentage = by_usage_percentage
        self.threshold = threshold

        self.mysql_governor = MySQLGovernor()
        self.log = logging.getLogger(self.__class__.__name__)
        self.show_mysql = self._show_mysql()
        # admin uids are only resolved when running as root; their rows are skipped later
        self._admins_uids = self._get_admin_uids() if getuser() == 'root' else set()

    @staticmethod
    def _get_admin_uids():
        """Return local uids of control panel admins, or an empty set if they cannot be listed."""
        try:
            return set(username_to_uid_local(admin) for admin in cpapi.admins()) - {None}
        except (cpapi.NotSupported, AttributeError):
            return set()

    def _show_mysql(self):
        # type: () -> bool
        """True when MySQL Governor is present and the 'mysql' column was requested."""
        return self.mysql_governor.is_governor_present() and 'mysql' in self.show_columns

    def _index_of(self, list_, key):
        # type: (List, Union[str,int]) -> Optional[int]
        """Return the index of the first occurrence of key in list_, or None if absent."""
        if key in list_:
            return list_.index(key)
        return None

    def init_show_columns(self):
        """
        Compute the column lists requested from the LVE and governor history queries
        and the index of the row key column in each (id in many-uid mode, 'from' otherwise).
        """
        show_columns_lve = []
        show_columns_gov = []
        show_columns_gov_singleuser = ['from', 'to']
        show_columns_lve_singleuser = ['id', 'from', 'to']
        # each non-mysql column expands to usage/limit/faults columns, e.g. cpu -> acpu, lcpu, cpuf
        columns = []
        for column in self.show_columns:
            if column != 'mysql':
                columns += ['a' + column, 'l' + column, column + 'f']
        if columns:
            show_columns_lve = ['id'] + columns
            show_columns_lve_singleuser += columns

        if 'mysql' in self.show_columns:
            show_columns_gov = ['id', 'cpu', 'lcpu', 'read', 'lread', 'write', 'lwrite']
            show_columns_gov_singleuser = show_columns_gov + ['from', 'to']

        self._show_columns_lve = show_columns_lve if self.is_many_uids else show_columns_lve_singleuser
        self._show_columns_gov = show_columns_gov if self.is_many_uids else show_columns_gov_singleuser
        key_column = 'id' if self.is_many_uids else 'from'
        self._gov_key_index = self._index_of(self._show_columns_gov, key_column)
        self._lve_key_index = self._index_of(self._show_columns_lve, key_column)

    def init_order_by(self, order_by, by_fault):
        # type: (str, Optional[Tuple[str]]) -> None
        """
        Translate the user-facing order_by into a column of the LVE query
        (order_by_lve) or of the governor query (order_by_gov, for 'mysql_*').
        Without order_by, rows are ordered by the by_fault column, if any.
        """
        if order_by:
            order_by = order_by.lower()
            if order_by.endswith('faults'):
                self.order_by_lve = order_by.replace('_faults', '') + 'f'
            elif order_by.startswith('mysql'):
                self.order_by_gov = order_by[len('mysql_'):]
            else:
                self.order_by_lve = 'a' + order_by
        elif by_fault and by_fault[0] == 'any':
            # NOTE(review): __init__ passes by_fault already normalized by
            # normalize_by_fault ('any' -> ('anyf',)), so this comparison never matches
            # there and order_by_lve becomes 'anyf' via the else branch instead.
            # Confirm which ordering is intended before changing it.
            self.order_by_lve = 'cpuf'
        else:
            self.order_by_lve = by_fault and by_fault[0]

    def get_username(self, uid):
        # type: (int) -> str
        """Resolve uid to a username, or NOT_AVAILABLE if it cannot be resolved."""
        username = uid_to_username(uid, self.server_id, self.server_id, self.engine)
        return username or NOT_AVAILABLE

    def _get_default_user_usage(self):
        """
        Return default (zero) usage for every requested LVE column, plus mysql cpu/io
        when governor stats are shown.
        :rtype: dict
        """
        usage_field = defaultdict(dict)
        for column in self.show_columns:
            if column != 'mysql':
                usage_field[column]['lve'] = 0.0
        if self.show_mysql:
            usage_field['cpu']['mysql'] = 0.0
            usage_field['io']['mysql'] = 0.0
        return usage_field

    def _get_default_user_limits(self, username):
        """
        Return default (zero) LVE limits for every requested column; mysql cpu/io limits
        are read from the governor when governor stats are shown.
        :rtype: dict
        """
        limit_field = defaultdict(dict)
        for column in self.show_columns:
            if column != 'mysql':
                limit_field[column]['lve'] = 0.0
        if self.show_mysql:
            cpu_limit, io_limit = self.get_user_limits_safe(username)
            limit_field['cpu']['mysql'] = cpu_limit
            limit_field['io']['mysql'] = io_limit
        return limit_field

    def _get_default_user_faults(self):
        """Return zero LVE fault counters for every requested non-mysql column."""
        fault_field = {}
        for column in self.show_columns:
            if column != 'mysql':
                fault_field[column] = {'lve': 0}
        return fault_field

    def get_item(self, ikey):
        """
        Build an output item with default usage/limits/faults for one row.

        :param ikey: in many-uid mode the serialized lve id of the row; otherwise the
            period start ('from' column) and self.uid identifies the user
        """
        if self.is_many_uids:
            _, is_reseller = deserialize_lve_id(ikey)
            username = self.get_username(ikey)
        else:
            _, is_reseller = deserialize_lve_id(self.uid)
            username = self.get_username(self.uid)
        item = {
            'usage': self._get_default_user_usage(),
            'limits': self._get_default_user_limits(username),
        }
        if self.show_columns != ['mysql']:
            item['faults'] = self._get_default_user_faults()
        if self.is_many_uids:
            if is_reseller:
                item.update({'id': ikey, 'name': username})
            else:
                item.update({
                    'id': ikey,
                    'username': username,
                    'domain': g_users_manager.get_domain(username, raise_exc=False) or NOT_AVAILABLE,
                    'reseller': g_users_manager.get_reseller(username, raise_exc=False) or NOT_AVAILABLE,
                })
        else:
            item.update({
                'from': int(ikey),
                'to': int(ikey) + self.time_unit})
        return item

    def fill_lve_item(self, item, row):
        """
        Copy LVE usage/limit/faults values from row into item (in place) and return it.

        The row starts with 1 key column ('id') in many-uid mode or 3 ('id', 'from', 'to')
        otherwise, followed by (usage, limit, faults) triples in show_columns order.
        """
        if row is None:
            return item
        count = 1 if self.is_many_uids else 1 + 2
        for column in self.show_columns:
            if column == 'mysql':
                continue
            item['usage'][column]['lve'] = row[count]
            item['limits'][column]['lve'] = row[count + 1]
            item['faults'][column]['lve'] = row[count + 2]
            count += 3
        return item

    @staticmethod
    def __get_io_bytes(read, write):
        """
        Get io value by read and write values.
        :param int|float|None read: read speed in MEGAbytes
        :param int|float|None write: write speed in MEGAbytes
        :return int: io, BYTES
        """
        # convert None to zero
        read = read or 0
        write = write or 0

        # calculate io in bytes
        io = (read + write) * 1024 * 1024
        return int(io)

    def fill_gov_item(self, item, row):
        """
        Copy governor cpu/io usage and limits from row into item (in place) and return it.

        Row layout: id, cpu, lcpu, read, lread, write, lwrite[, from, to].
        """
        if row is None:
            return item
        item['usage']['cpu']['mysql'] = round(float(row[1] or 0), 1)
        item['limits']['cpu']['mysql'] = round(float(row[2] or 0), 1)
        item['usage']['io']['mysql'] = self.__get_io_bytes(row[3], row[5])
        item['limits']['io']['mysql'] = self.__get_io_bytes(row[4], row[6])
        return item

    def process_data(self, lve_results, gov_results, period_from, period_to):
        """
        Align LVE and governor rows so that the same position refers to the same key,
        and store them in self._lve_results / self._gov_results.

        :param period_from: local datetime, used only in the mismatch warning
        :param period_to: local datetime, used only in the mismatch warning
        """
        (from_gov_to_lve_indexes, from_lve_to_gov_indexes,
         gov_keys, gov_results,
         lve_keys, lve_results,
         not_in_gov_indexes, not_in_lve_indexes) = self.process_indexes(gov_results, lve_results)
        if not self.show_mysql:
            lve_results = list(lve_results)
        elif self.order_by_gov:
            lve_results = self.sort_by_another_list(from_gov_to_lve_indexes, lve_results, not_in_gov_indexes)
        elif self.order_by_lve:
            gov_results = self.sort_by_another_list(from_lve_to_gov_indexes, gov_results, not_in_lve_indexes)
        else:
            gov_results, lve_results = self.sort_by_key(gov_keys, gov_results, lve_keys, lve_results)
        # Rows are only merged position by position when governor stats are shown.
        # Without them gov_results is always empty, so a length mismatch is expected
        # and warning about it on every call would be noise.
        if self.show_mysql and len(lve_results) != len(gov_results):
            from_formatted = period_from.strftime('%Y-%m-%d %H:%M')
            to_formatted = period_to.strftime('%Y-%m-%d %H:%M')
            self.log.warning(
                'len(lve_results) != len(gov_results), it might cause duplicated timestamps '
                'or mix of lve and governor stats from different periods in aggregated data.\n'
                'Utility was called with the following parameters:\n'
                'cloudlinux-statistics --json --time-unit %ss --id %s '
                '--from %s --to %s '
                '--show %s --server_id %s\n'
                'engine=%s\norder_by=%s\nby_fault=%s\n',
                self.time_unit, self.uid, from_formatted, to_formatted,
                " ".join(self.show_columns), self.server_id,
                self.engine, self.order_by, self.by_fault
            )
        self._lve_results = lve_results
        self._gov_results = gov_results

    @staticmethod
    def sort_by_key(gov_keys, gov_results, lve_keys, lve_results):
        # type: (List[int], List, List[int], List) -> Tuple[List, List]
        """
        Merge two keyed result lists into two equally long lists, putting rows with
        equal keys at the same position and None where one side has no row.

        The merge only aligns correctly if both key lists are sorted ascending.
        :return: (gov_results, lve_results)
        """
        lve_pos = 0
        gov_pos = 0
        lve_len = len(lve_results)
        gov_len = len(gov_results)
        _lve_results = []
        _gov_results = []
        # maximum # of iterations will be in case when
        # lve_keys like [1,2,7] and gov_keys like [4,5,6]
        for _ in range(lve_len + gov_len):
            if lve_pos == lve_len:  # became out of range
                _lve_results.extend([None] * (gov_len - gov_pos))
                # if also gov_pos == gov_len slice returns [] without error
                _gov_results.extend(gov_results[gov_pos:])
                break
            if gov_pos == gov_len:
                _gov_results.extend([None] * (lve_len - lve_pos))
                _lve_results.extend(lve_results[lve_pos:])
                break
            if lve_keys[lve_pos] == gov_keys[gov_pos]:
                _lve_results.append(lve_results[lve_pos])
                _gov_results.append(gov_results[gov_pos])
                lve_pos += 1
                gov_pos += 1
            elif lve_keys[lve_pos] > gov_keys[gov_pos]:
                _lve_results.append(None)
                _gov_results.append(gov_results[gov_pos])
                gov_pos += 1
            else:
                _gov_results.append(None)
                _lve_results.append(lve_results[lve_pos])
                lve_pos += 1
        return _gov_results, _lve_results

    @staticmethod
    def sort_by_another_list(indexes, results, not_in_indexes):
        # type: (List[Optional[int]], List, List[int]) -> List
        """
        Reorder results to follow another list: results[index] for each entry of
        indexes (None stays None), followed by the rows at not_in_indexes.
        """
        ordered = [None if index is None else results[index] for index in indexes]
        ordered.extend(results[index] for index in not_in_indexes)
        return ordered

    def process_indexes(self, gov_results, lve_results):
        # type: (List, List) -> Tuple[List[int],List[int],List,List,List,List,List,List]
        """
        Cross-filter the two result sets when usage/fault filters were requested, then
        compute their keys and the index mappings between them.

        :return: (from_gov_to_lve_indexes, from_lve_to_gov_indexes, gov_keys, gov_results,
                  lve_keys, lve_results, not_in_gov_indexes, not_in_lve_indexes)
        """
        if self.by_usage_gov:
            # keep only LVE rows whose id (first column) passed the governor usage filter
            gov_uids = {row[0] for row in gov_results}
            lve_results = [row for row in lve_results if row[0] in gov_uids]
        if self.by_usage_lve or self.by_fault:
            # keep only governor rows whose id passed the LVE usage/fault filter
            lve_uids = {row[0] for row in lve_results}
            gov_results = [row for row in gov_results if row[0] in lve_uids]
        gov_keys = [row[self._gov_key_index] for row in gov_results]
        lve_keys = [row[self._lve_key_index] for row in lve_results]
        # for each governor row: index of the LVE row with the same key, or None
        from_gov_to_lve_indexes = self.get_indexes(gov_keys, lve_keys)
        # governor rows that have no LVE counterpart
        not_in_lve_indexes = [i for i, index in enumerate(from_gov_to_lve_indexes) if index is None]
        from_lve_to_gov_indexes = self.get_indexes(lve_keys, gov_keys)
        # LVE rows that have no governor counterpart
        not_in_gov_indexes = [i for i, index in enumerate(from_lve_to_gov_indexes) if index is None]
        return (from_gov_to_lve_indexes, from_lve_to_gov_indexes,
                gov_keys, gov_results,
                lve_keys, lve_results,
                not_in_gov_indexes, not_in_lve_indexes)

    def get_indexes(self, first_uids, second_uids):
        # type: (List[int], List[int]) -> List[Optional[int]]
        """
        For each key of first_uids return the index of its first occurrence in
        second_uids, or None when it is absent.
        """
        # map every key to its first position once, instead of a linear search per key
        first_position = {}
        for position, key in enumerate(second_uids):
            first_position.setdefault(key, position)
        return [first_position.get(key) for key in first_uids]

    def get_value(self, key, resellers):
        # type: (int, bool) -> Dict[str, Union[int,float]]
        """
        Build the output item for the key-th aligned row.

        :type resellers: bool
        :param resellers: read from reseller rows instead of user rows
        :type key: int
        :rtype: dict
        :raises IndexError: when key reaches self.limit (user rows, if a limit is set)
            or runs past the stored results
        """
        if resellers:
            item = self.get_item(self._resellers_results[key][self._lve_key_index])
            return self.fill_lve_item(item, self._resellers_results[key])
        elif not self.limit or key < self.limit:
            if not self.show_mysql:
                item = self.get_item(self._lve_results[key][self._lve_key_index])
                return self.fill_lve_item(item, self._lve_results[key])
            elif self.order_by_gov or self.limit_gov:
                # governor rows drive the order; LVE row may be missing (None)
                row = self._gov_results[key]
                if row is not None:
                    item = self.get_item(row[self._gov_key_index])
                    self.fill_gov_item(item, row)
                else:
                    item = self.get_item(self._lve_results[key][self._lve_key_index])
                return self.fill_lve_item(item, self._lve_results[key])
            else:
                # LVE rows drive the order; governor row may be missing (None)
                row = self._lve_results[key]
                if row is not None:
                    item = self.get_item(row[self._lve_key_index])
                    self.fill_lve_item(item, row)
                else:
                    item = self.get_item(self._gov_results[key][self._gov_key_index])
                return self.fill_gov_item(item, self._gov_results[key])
        else:
            raise IndexError

    def get_gov_row(self, key, indexes, result):
        """
        Return result[indexes[key]], or None when key is beyond indexes or
        indexes[key] is None.
        """
        # the former `key > len(indexes) + 1` bound let key == len(indexes) and
        # len(indexes) + 1 through, raising IndexError instead of returning None
        if key >= len(indexes) or indexes[key] is None:
            return None
        return result[indexes[key]]

    @staticmethod
    def _format_gov_error(error):
        """
        Render a MySQLGovException message with its context, falling back to the raw
        message when it cannot be formatted (e.g. context is None or the text has a '%').
        """
        try:
            return error.message % error.context
        except (TypeError, ValueError, KeyError):
            return error.message

    def get_user_limits_safe(self, username, ignore_watched=False):
        # type: (str, bool) -> Tuple[int,int]
        """
        Get cpu and io limit for user.

        Limits are read only for users with governor status 'watched', unless
        ignore_watched is set. (0, 0) is returned when governor is absent or raises
        MySQLGovException.
        :return tuple: cpu limit in %, io limit in bytes/s
        """
        cpu_limit, io_limit = 0, 0
        if self.mysql_governor.is_governor_present():
            try:
                if ignore_watched or self.mysql_governor.get_governor_status_by_username(username) == 'watched':
                    cpu_limit, io_limit = self.mysql_governor.get_limits_by_user(username)
            except MySQLGovException as e:
                self.log.info('Cannot get mysql limits for user %s. Error raised: %s', username,
                              self._format_gov_error(e))
        # governor io limit is scaled by 1024 (presumably KB/s -> bytes/s; confirm)
        return cpu_limit, io_limit * 1024

    @staticmethod
    def normalize_by_fault(by_fault):
        # type: (Optional[str]) -> Optional[Tuple[str]]
        """Map a fault name to a 1-tuple with its faults column ('mem' is an alias of 'vmem')."""
        if by_fault:
            if by_fault == 'mem':
                by_fault = 'vmem'
            return by_fault + 'f',  # it is a tuple, so we need coma at the end
        return None

    def init_by_usage(self, by_usage):
        # type: (Optional[str]) -> None
        """Route the by_usage filter to the governor query ('mysql_*') or the LVE query."""
        if by_usage:
            if by_usage.startswith("mysql"):
                self.by_usage_gov = by_usage[len('mysql_'):]
            else:
                self.by_usage_lve = 'a' + by_usage,  # it is a tuple, so we need coma at the end

    def generate_statistics(self, period_from, period_to):
        # type: (datetime, datetime) -> None
        """
        Generate statistics.

        Queries LVE and (if shown) governor history for the local-time period, splits
        reseller rows out from user rows and skips admins in many-uid mode, then aligns
        the results via process_data.
        """
        # start from a clean reseller list so repeated calls do not accumulate rows
        self._resellers_results = []
        period_from_gm = dateutil.local_to_gm(period_from)
        period_to_gm = dateutil.local_to_gm(period_to)

        if self._show_columns_lve:
            history_show = HistoryShowUnion(
                dbengine=self.engine,
                period_from=period_from_gm, period_to=period_to_gm,
                uid=self.uid,
                time_unit=self.time_unit,
                show_columns=self._show_columns_lve,
                server_id=self.server_id,
                order_by=self.order_by_lve,
                by_usage=self.by_usage_lve,
                by_usage_percentage=self.by_usage_percentage,
                by_fault=self.by_fault,
                threshold=self.threshold,
                limit=self.limit_lve)
            history_show.set_normalised_output()
            lve_results = history_show.proceed()
        else:
            lve_results = []

        if self.show_mysql:
            gov_results = HistoryShowDBGov(
                dbengine=self.engine,
                period_from=period_from_gm,
                period_to=period_to_gm,
                time_unit=self.time_unit,
                server_id=self.server_id,
                show_columns=self._show_columns_gov,
                order_by=self.order_by_gov,
                by_usage=self.by_usage_gov,
                by_usage_percentage=self.by_usage_percentage,
                limit=self.limit_gov,
                reverse=True,
                uid=self.uid,
            ).history_dbgov_show()
        else:
            gov_results = []
        if self.is_many_uids:
            users = []
            for item in lve_results:
                uid, is_reseller = deserialize_lve_id(item[self._lve_key_index])
                if uid in self._admins_uids:
                    continue
                if is_reseller:
                    self._resellers_results.append(item)
                else:
                    users.append(item)
            lve_results = users
        self.process_data(lve_results, gov_results, period_from, period_to)


def _get_uid_for_select(for_reseller, uid):
    # type: (Optional[str], Optional[int]) -> Union[List[int], int, None]
    """
    Get uid(s) that should be selected from history table.

    :param for_reseller: reseller name whose users should be selected, or None
    :param uid: id requested by the caller, or None
    :return: the current process uid for non-root callers; when for_reseller is set,
        either the requested uid (if it belongs to the reseller or is the reseller's own id),
        an empty list (it does not belong - LVESTATS-97), or, without a requested uid, the
        reseller's user uids plus the reseller's own id; otherwise `uid` unchanged
    """
    if getuser() != 'root':
        # non-root callers only ever get their own uid
        return os.getuid()

    g_users_manager.build_users_cache(for_reseller)
    if not for_reseller:
        return uid

    # for_reseller is non-empty here, so no fallback to getuser() is needed
    users_list = get_users_for_reseller(for_reseller)
    # falsy results (unresolved usernames, uid 0) are dropped
    id_ = [user_uid for user_uid in map(username_to_uid_local, users_list) if user_uid]
    reseller_uid = reseller_to_uid(for_reseller)
    # LVESTATS-97
    # If uid not in users_list, we return empty list.
    # NOTE(review): a truthiness test, so uid 0 is treated as "not requested".
    if uid:
        return uid if uid in id_ or uid == reseller_uid else []
    # select also information about reseller's containers (if exists)
    if reseller_uid is not None:
        id_.append(reseller_uid)
    return id_


def execute(engine, options, log=None):
    # type: (sqlalchemy.engine.base.Engine, object, Optional[logging.Logger]) -> Tuple[str, int]
    """
    Run cloudlinux-statistics for already parsed options.

    :type log: None | logging.Logger
    :type engine: sqlalchemy.engine.base.Engine
    :return: (JSON-encoded statistics, 0), or (error message, -1) when JSON
        output was not requested
    :rtype: (str, int)
    """
    if options.json:
        selected_ids = _get_uid_for_select(options.for_reseller, options.id)
        return dump_json(get_statistics(engine, selected_ids, log, options)), 0
    return "Only JSON mode supported for now", -1


def dump_json(statistics):
    # type: (Dict) -> str
    """Serialize the statistics dict to a JSON string using default json.dumps settings."""
    return json.dumps(statistics)


def get_statistics(
        engine,  # type: sqlalchemy.engine.base.Engine
        id_,  # type: Optional[Union[int, List[int], Tuple[int]]]
        log,  # type: Optional[logging.Logger]
        options,
        timestamp=None  # type: Optional[int]
):
    # type: (...) -> Dict[str, Union[LVEList, str, int]]
    """
    Collect statistics for the parsed options into a JSON-serializable dict.

    :param id_: id(s) selected by _get_uid_for_select
    :param timestamp: value for the 'timestamp' key; when omitted, the current
        GMT unix timestamp at call time is used
    :return: dict with 'result' ('success' or an error text), 'timestamp',
        'users' + 'resellers' (many ids) or 'user' (single id), 'mySqlGov' and,
        on governor error, 'warning' and 'context'
    """
    if timestamp is None:
        # computed per call: a call in the signature was evaluated only once, at import
        timestamp = dateutil.gm_datetime_to_unixtimestamp()
    result = {
        'result': 'success',
        'timestamp': timestamp
    }
    # LVESTATS-97
    # If both params (for_reseller and id) are used
    # and id_ list is empty, we return `Permission denied` message
    if options.for_reseller is not None and options.id is not None and id_ == []:
        error_msg = (
            'Permission denied. User with id '
            f'{options.id} does not belong reseller `{options.for_reseller}`'
        )
        result['result'] = error_msg
        return result
    try:
        stat = CloudLinuxStatistics(
            engine,
            show_columns=options.show,
            order_by=options.order_by,
            server_id=options.server_id,
            by_usage=options.by_usage,
            by_usage_percentage=options.percentage / 100.,
            by_fault=options.by_fault,
            threshold=options.threshold,
            limit=options.limit,
            uid=id_,
            time_unit=options.time_unit)
        stat.generate_statistics(getattr(options, 'from'), options.to)

        lvelist = LVEList(stat)
        lvelist_resellers = LVEList(stat, resellers=True)
        if stat.is_many_uids:
            result['users'] = lvelist
            result['resellers'] = lvelist_resellers
        else:
            result['user'] = lvelist
    except Exception as e:
        # top-level boundary: report the error in the response, keep the traceback in the log
        result["result"] = str(e)
        if log:
            log.exception(str(e))

    # add mysql status info
    status, error = func.get_governor_status()
    result['mySqlGov'] = status
    if status == GovernorStatus.ERROR:
        result['warning'] = error.message
        result['context'] = error.context
    return result


@lvestats.lib.commons.decorators.no_sigpipe
def main(engine, argv_, server_id='localhost', log=None):
    # type: (sqlalchemy.engine.base.Engine, List[str], str, Optional[logging.Logger]) -> int
    """
    Entry point of cloudlinux-statistics: parse argv_, print the response and
    return the exit code.

    Without an explicit logger, one is created via setup_logging with
    file_level=WARNING and console_level=FATAL.
    """
    logger = log
    if logger is None:
        logger = setup_logging(
            {},
            caller_name="CloudLinuxStatistics",
            file_level=logging.WARNING,
            console_level=logging.FATAL,
        )
    parsed_options = setup_statistics_parser(argv_, server_id)

    output, exit_code = execute(engine, parsed_options, log=logger)
    print(output)
    return exit_code


def get_users_and_resellers_with_faults(period='1d'):
    """
    Return amount of users and resellers with faults.

    Runs cloudlinux-statistics with `--by-fault any` over `period` and counts the
    entries of the 'users' and 'resellers' lists in its JSON response.

    :param str period: value passed as --period (e.g. '1d')
    :return tuple: (number of users, number of resellers)
    :raises KeyError: when the response has no 'users'/'resellers' lists, e.g. the
        statistics query failed or the caller is not root (single-uid response)
    """
    cnf = config.read_config()
    dbengine = make_db_engine(cnf)
    try:
        server_id = cnf.get('server_id', 'localhost')
        opt_list = ['--by-fault', 'any', '--json', '--period=' + period]
        options = setup_statistics_parser(opt_list, server_id)
        # execute() returns an already serialized string, so the engine is not needed afterwards
        json_str, _ = execute(dbengine, options, log=None)
    finally:
        # the engine is created only for this call: release its connection pool
        dbengine.dispose()
    json_data = json.loads(json_str)
    resellers = json_data['resellers']
    users = json_data['users']
    return len(users), len(resellers)


def get_max_memory_for_lve(lve_id, period_seconds):
    """
    Return the maximum 'mPMem' value reported by lve-info for lve_id.

    :param lve_id: id passed as --id
    :param period_seconds: look-back period, rounded up to whole minutes for --period
    :return: the maximum 'mPMem' over the returned rows, or None when lve-info returned
        nothing, the response has no 'data'/'mPMem' keys, or 'data' is empty
    """
    data = main_(
        config.read_config(),
        argv_=['--json', '--id', str(lve_id), '--show-columns',
               'mPMem', '--period', str(int(math.ceil(period_seconds / 60))) + 'm'],
    )
    if not data:
        # previously `parsed_data` stayed unbound here and the max() line raised
        # UnboundLocalError, which the except clause below does not catch
        return None
    parsed_data = json.loads(data)
    try:
        return max(x['mPMem'] for x in parsed_data['data'])
    except (ValueError, KeyError):
        # ValueError: empty 'data'; KeyError: missing 'data' or 'mPMem'
        return None
