# coding=utf-8
#
# Copyright © Cloud Linux GmbH & Cloud Linux Software, Inc 2010-2019 All Rights Reserved
#
# Licensed under CLOUD LINUX LICENSE AGREEMENT
# http://cloudlinux.com/docs/LICENSE.TXT

import argparse
import logging
import datetime
import sys
import os
from bisect import bisect_left
from collections import defaultdict
from typing import Iterable, List

from lxml import etree
from sqlalchemy.exc import SQLAlchemyError

import lvestats
import lvestats.lib.commons.decorators
from clcommon.cpapi.pluginlib import getuser
from lvestats.lib.cloudlinux_statistics import _get_uid_for_select
from lvestats.lib import config, dbengine, lveinfolib, uidconverter
from lvestats.lib.chart.svggraph import SvgChart
from lvestats.lib.chart.rdp import ramerdouglas
from lvestats.lib.commons import dateutil
from lvestats.lib.commons.argparse_utils import period_type2, ParseDatetime
from lvestats.lib.commons.logsetup import setup_logging
from lvestats.lib.chart.polysimplify import VWSimplifier
from lvestats.lib.chart.svg2png import svg_to_png

__author__ = "shaman"

# Server id used when the configuration does not provide ``server_id``.
DEFAULT_SERVER_ID: str = "localhost"

# Index into the X-axis label formats used by Renderer._ts_to_str:
# 1 (initial value) -> date and time; 0 -> date only (used for long periods).
format_index: int = 1


class Renderer(object):
    """
    Renders chart data to an SVG document through ``SvgChart``.

    Lines with more than ``max_points_on_graph`` points are simplified before
    plotting.  Optional "fault" markers are drawn as segments between a usage
    line and its limit line at the times where a fault counter is positive.
    """

    def __init__(self, max_points_on_graph: int = 800, fault_color: str = 'aquamarine'):
        """
        :param max_points_on_graph: upper bound on points kept per plotted line
        :param fault_color: color of fault segments and of their legend entry
        """
        self.log = logging.getLogger('Chart Renderer')
        self.svg_chart = SvgChart()
        self._max_points_on_graph = max_points_on_graph
        self._fault_color = fault_color

    @staticmethod
    def _nop(arg):
        """Identity function; default value modifier for legend entries."""
        return arg

    @staticmethod
    def _ts_to_str(ts):
        """
        Convert a unix timestamp into a local-time X-axis label.

        The format is selected by the module-level ``format_index`` (see
        ``set_period_sec``): index 0 gives e.g. "Feb-05", index 1 gives
        e.g. "Aug-15 02:35PM" (the AM/PM marker is locale dependent).
        """
        formats = ['%b-%d', '%b-%d %I:%M%p']
        gm = dateutil.unixtimestamp_to_gm_datetime(ts)
        lo = dateutil.gm_to_local(gm)
        return lo.strftime(formats[format_index])

    def set_period_sec(self, period_sec):
        """
        Select the X-axis label format for a chart covering ``period_sec`` seconds.

        Periods of 8 days (691200 s) or more show only the date; shorter
        periods also show the time.  The choice lives in the module-level
        ``format_index`` read by ``_ts_to_str``, so it is assigned on every
        call: otherwise a single long-period chart would leave every later
        chart rendered in the same process stuck on date-only labels.
        """
        global format_index
        # 8 days = 8 * 24 * 3600 = 691200 sec
        format_index = 0 if period_sec >= 691200 else 1

    def _optimisation_ramerdouglas(self, points, epsilon_max=2.0, epsilon_min=0):
        """
        Simplify ``points`` with ``ramerdouglas``, searching for an epsilon that
        keeps as many points as possible while staying below
        ``self._max_points_on_graph``.

        Up to 10 bisection steps over epsilon in [epsilon_min, epsilon_max] are
        performed, starting from epsilon_max.  Returns early when the point
        count stops changing, or when even the first (largest) epsilon leaves
        too many points (that first result is returned unchanged).
        """
        epsilon_default = epsilon_max
        points_output = points
        points_len = 0
        for i in range(10):
            points_optimised = ramerdouglas(points, epsilon_default)
            points_len_previous = points_len
            points_len = len(points_optimised)
            if points_len == points_len_previous:
                # return result if changing epsilon not change points number
                return points_output
            if points_len >= self._max_points_on_graph:
                if i == 0:
                    # return result if in first iteration points more than max
                    return points_optimised
                # too many points: the epsilon must grow
                epsilon_min = epsilon_default
            else:
                # acceptable result: remember it and try a smaller epsilon
                points_output = points_optimised
                epsilon_max = epsilon_default
            # correcting epsilon
            epsilon_default = (epsilon_max + epsilon_min) / 2.
        return points_output

    def _optimisation_polysimplify(self, points: Iterable[Iterable[float]]) -> List[List[float]]:
        """Reduce ``points`` to ``self._max_points_on_graph`` points via ``VWSimplifier.from_number``."""
        simplifier = VWSimplifier(points)
        result = simplifier.from_number(self._max_points_on_graph)
        return result

    def optimise_points(self, line: Iterable[Iterable[float]]) -> List[List[float]]:
        """
        Return ``line`` unchanged if it is short enough, otherwise a simplified
        copy (``ramerdouglas`` first, then ``VWSimplifier``).

        ``line`` must support ``len()`` (a list of (x, y) pairs in practice).
        """
        if len(line) > self._max_points_on_graph:
            line = self._optimisation_ramerdouglas(line, epsilon_max=1.4)
            line = self._optimisation_polysimplify(line)
        return line

    def get_two_closest(self, line, x_value):
        """
        Locate the neighbours of ``x_value`` in the sorted sequence ``line``.

        :type x_value: float
        :type line: list
        :return: (index_before, value_before, index_after, value_after) where
                 index_after is ``bisect_left(line, x_value)``.  Callers must
                 ensure ``x_value`` lies strictly inside the range of ``line``.
        """
        pos = bisect_left(line, x_value)
        before = line[pos - 1]
        after = line[pos]
        return pos - 1, before, pos, after

    def add_graph(
            self,
            data,
            title,
            legend,
            x_values,
            min_y=None,
            max_y=None,
            x_labels=None,
            generate_n_xlabels=None,
            y_labels=None,
            y_legend_converter=lambda v: v,
            unit=None,
            message=None,
            faults=None):
        """
        Add one graph (possibly several lines) to the SVG chart.

        :type message: None|str
        :type unit: None|str
        :type y_legend_converter: (float) -> float
        :type max_y: None|float
        :type min_y: None|float
        :type title: str
        :type legend: dict  -- key -> (title, color[, value_modifier])
        :type x_values: list
        :type faults: (str, str, str) -- (fault key, usage key, limit key)
        :type data: defaultdict
        """
        colors, datasets, fault_lines, names = self._add_graph(data, faults, legend, x_values)

        self.svg_chart.add_graph(
            datasets,
            colors,
            title=title,
            minimum_y=min_y,
            maximum_y=max_y,
            x_legend=x_labels,
            x_legend_generate=generate_n_xlabels,
            y_legend=y_labels,
            y_legend_converter=y_legend_converter,
            x_legend_converter=self._ts_to_str,
            names=names,
            unit=unit,
            message=message,
            fault_lines=fault_lines,
            fault_color=self._fault_color)

    def _add_graph(self, data, faults, legend, x_values):
        """
        Build the per-line datasets, legend names, colors and fault segments.

        :type legend: dict
        :type x_values: list
        :type faults: (str, str, str)
        :type data: defaultdict
        :rtype (list, list, list, list)
        :return: (colors, datasets, fault_lines, names)
        """
        datasets = []
        names = []
        colors = []
        datasets_dictionary = {}
        for (key, metainfo) in legend.items():
            # legend entries are (title, color) or (title, color, modifier)
            try:
                legend_title, color, modifier = metainfo
            except ValueError:
                legend_title, color = metainfo
                modifier = self._nop
            data_key = key.lower()
            y_values = [modifier(y) for y in data[data_key]]
            line = list(zip(x_values, y_values))
            line_optimised = self.optimise_points(line)
            datasets.append(line_optimised)
            datasets_dictionary[data_key] = line_optimised
            names.append(legend_title)
            colors.append(color)
        fault_lines = self.get_faults_lines(data, datasets_dictionary, faults, x_values)

        # add legend
        if fault_lines:
            names.append('faults')
            colors.append(self._fault_color)
        return colors, datasets, fault_lines, names

    def get_faults_lines(self, data, datasets_dictionary, faults, x_values):
        """
        Build fault segments: for every x where ``data[fault_key] > 0``, a
        segment from the usage line up to the limit line at that x (only kept
        when usage is below the limit; duplicates are dropped, order kept).

        :type x_values: list
        :type faults: (str, str, str) -- (fault key, usage key, limit key)
        :type datasets_dictionary: dict -- lower-cased key -> list of (x, y)
        :type data: defaultdict
        :rtype list
        """
        fault_lines = []
        if faults is None:
            return fault_lines
        fault_name, data_name, limit_name = [x.lower() for x in faults]
        if limit_name not in datasets_dictionary or data_name not in datasets_dictionary:
            return fault_lines
        usage_line = datasets_dictionary[data_name]
        limit_line = datasets_dictionary[limit_name]
        # An empty line cannot be unpacked into (times, values) below, and
        # there is nothing to attach a fault segment to anyway.  len() is used
        # instead of truthiness so array-like simplifier output also works.
        if len(usage_line) == 0 or len(limit_line) == 0:
            return fault_lines
        faults_x = [x for x, y in zip(x_values, data[fault_name]) if y > 0]
        if not faults_x:
            return fault_lines
        average_times, average = list(zip(*usage_line))
        limit_times, limit = list(zip(*limit_line))
        seen = set()
        for fault in faults_x:
            try:
                average_dot = self.get_dot(fault, average, average_times)
                limit_dot = self.get_dot(fault, limit, limit_times)
            except IndexError:
                self.log.error("Can't get fault line: %s", str(fault))
                continue
            segment = (average_dot, limit_dot)
            if average_dot[1] < limit_dot[1] and segment not in seen:
                seen.add(segment)
                fault_lines.append(segment)
        return fault_lines

    def get_dot(self, fault_time, line, line_times):
        """
        Return the point ``(fault_time, y)`` on a line sampled at ``line_times``.

        Times outside the sampled range use the first/last y value, exact
        matches use the sampled value, otherwise y is linearly interpolated
        between the neighbouring samples.  ``line_times`` is assumed sorted
        ascending (required by ``bisect_left`` in ``get_two_closest``).

        :type line_times: list
        :type line: list
        :type fault_time: float
        :rtype (float, float)
        """
        if fault_time >= line_times[-1]:
            return fault_time, line[-1]
        if fault_time <= line_times[0]:
            return fault_time, line[0]
        try:
            dot = (fault_time, line[line_times.index(fault_time)])
        except ValueError:
            before_index, before, after_index, after = self.get_two_closest(line_times, fault_time)
            dot_y = (
                (fault_time - before)
                / (after - before)
                * (line[after_index] - line[before_index])
                + line[before_index]
            )
            dot = (fault_time, dot_y)
        return dot

    def add_common_x_legend(self, x_values, n):
        """Add an X legend with ``n`` labels built from ``x_values`` via ``_ts_to_str``."""
        self.svg_chart._add_x_legend(  # pylint: disable=protected-access
            x_values=x_values,
            number=n,
            x_legend_converter=self._ts_to_str,
        )

    def add_text_box(self, text, font_size=None):
        """
        add empty rectangle with text
        """
        self.svg_chart.add_text_box(text, font_size)

    def render(self):
        """Return the chart as produced by ``SvgChart.dump()``."""
        return self.svg_chart.dump()


class UserNotFound(Exception):
    """Raised when the requested user name cannot be resolved."""

    def __init__(self, user_name, *args, **kwargs):
        # The formatted message is always the first exception argument;
        # any extra positional arguments follow it unchanged.
        message = f'User {user_name} not found'
        super().__init__(message, *args, **kwargs)


class ChartMain(object):
    """
    Base class for chart command line tools.

    Subclasses implement ``customize_parser``, ``get_chart_data`` and
    ``add_graphs``.  ``main`` parses the arguments, creates the database
    engine, checks that the caller may see the requested user's data, fetches
    the data, renders it with ``Renderer`` and writes the result.
    """

    def __init__(self, prog_name, prog_desc, cnf):
        """
        :param prog_name: program name for argparse and logging setup
        :param prog_desc: program description shown by --help
        :param cnf: configuration mapping (used with ``.get`` and item access)
        """
        self.prog_name = prog_name
        self.prog_desc = prog_desc
        self.log = setup_logging(cnf, prog_name, console_level=logging.ERROR)
        self.cfg = cnf
        # analog of previous parameter but get from config file
        self.is_normalized_user_cpu = config.is_normalized_user_cpu()

    @staticmethod
    def convert_dbdata_to_dict(data, show_columns):
        """
        Pivot database rows into per-column lists of floats.

        :param data: iterable of rows, each aligned with ``show_columns``
        :param show_columns: column names for the row positions
        :return: defaultdict(list) column name -> list of floats
                 (None/falsy values become 0.0)
        """
        data_collected = defaultdict(list)
        for row in data:
            row_dict = dict(zip(show_columns, row))
            for (k, v) in row_dict.items():
                data_collected[k].append(float(v or 0))
        return data_collected

    @staticmethod
    def convert_lvedata_to_dict(data):
        """
        Convert a mapping of field name -> values into a defaultdict of floats.

        :return: defaultdict(list) field name -> list of floats
                 (None/falsy values become 0.0)
        """
        # key -- field name, like 'cpu', 'id', 'created', etc.. value -- list of values
        by_key = defaultdict(list)
        for (k, values) in data.items():
            by_key[k] = [float(v or 0) for v in values]
        return by_key

    def make_parser(self):
        """Create the argument parser with the options shared by all chart tools."""
        current_server_id = self.cfg.get('server_id', DEFAULT_SERVER_ID)
        datetime_now = datetime.datetime.now()
        parser = argparse.ArgumentParser(prog=self.prog_name,
                                         add_help=True,
                                         description=self.prog_desc)

        parser.add_argument('--version', version=lvestats.__version__,
                            help='Version number',
                            dest='version',
                            action='version')
        parser.add_argument('--period',
                            help='Time period\n'
                                 'specify minutes with m,  h - hours, days with d, and values: today, yesterday\n'
                                 '5m - last 5 minutes, 4h - last four hours, 2d - last 2 days, as well as today',
                            type=lambda value: period_type2(value, datetime_now),
                            default=None
                            )

        parser.add_argument('--from',
                            help='Run report from date and time in YYYY-MM-DD HH:MM format\n'
                                 'if not present last 10 minutes are assumed',
                            action=ParseDatetime,
                            nargs='+',
                            dest='ffrom')
        parser.add_argument('--to',
                            help='Run report up to date and time in YYYY-MM-DD HH:MM format\n'
                                 'if not present, reports results up to now',
                            action=ParseDatetime,
                            nargs='+',
                            dest='to')

        parser.add_argument('--server',
                            help='With username or LVE id show only record for that user at given server',
                            dest='server',
                            default=current_server_id)
        parser.add_argument('--output',
                            help='Filename to save chart as, if not present, output will be sent to STDOUT',
                            default='-', dest='output')
        parser.add_argument('--show-all',
                            help='Show all graphs (by default shows graphs for which limits are set)',
                            dest='show_all',
                            action='store_true',
                            default=False)
        parser.add_argument('--style',
                            help='Set graph style',
                            choices=['admin', 'user'],
                            dest='style',
                            default=None)
        parser.add_argument('--format',
                            help='Set graph output format',
                            choices=['svg', 'png'],
                            dest='format',
                            default='png')
        parser.add_argument('--for-reseller',
                            help='Show graph for specific reseller and user (used only with --id key)',
                            type=str,
                            dest='for_reseller',
                            default=None)

        # add opt --dpi, --width, --height for backward compatible with old lvechart
        parser.add_argument('--dpi', help=argparse.SUPPRESS)
        parser.add_argument('--width', help=argparse.SUPPRESS)
        parser.add_argument('--height', help=argparse.SUPPRESS)

        return parser

    def customize_parser(self, parser):
        """
        Add tool-specific options; must be implemented by subclasses.

        :type parser: argparse.ArgumentParser
        :rtype : argparse.ArgumentParser
        """
        raise NotImplementedError()

    def get_chart_data(self,
                       engine,
                       from_ts,
                       to_ts,
                       server,
                       user_id,
                       show_all=False):
        """

        Extracts data from database , in form of a tuple:
        (dict, list)
        where dict is

        {
            'cpu': [0, 5, 10 , 95.2, ...] -- values
            'io': [0, 5, 10 , 95.2, ...] -- values
            'foo': ,
        }

        and a list of timestamps for all that collected values:
        2) [1, 2 ,3 ...]

        NOTE(review): ``main`` unpacks three values from this call
        (data, times, period_sec) -- implementations must return a 3-tuple.

        :param show_all:
        :rtype : tuple
        :type engine: sqlalchemy.engine
        :type from_ts: datetime.datetime
        :type to_ts: datetime.datetime
        :type server: str
        :type user_id: int
        """
        raise NotImplementedError()

    def add_graphs(self, renderer, data_collected, times, lve_version, show_all, is_user=False):
        """
        Add the tool's graphs to ``renderer``; must be implemented by subclasses.

        :rtype : None
        :type renderer: Renderer
        :type data_collected: dict
        :type times: list
        :type lve_version: str
        :type show_all: bool
        :type is_user: bool. True for user, False - admin
        """
        raise NotImplementedError()

    @lvestats.lib.commons.decorators.no_sigpipe
    def main(self, args, debug_engine=None):
        """
        Command line entry point.

        :param args: command line arguments (without the program name)
        :param debug_engine: optional ready-made database engine used instead
                             of creating one from the configuration
        :return: process exit code (0 on success, 1 on error)
        """
        parser = self.customize_parser(self.make_parser())
        # Parse arguments before touching the database so that --help,
        # --version and usage errors work even when the database is down.
        opts = parser.parse_args(list(args))

        if opts.period and any((opts.ffrom, opts.to)):
            # This is a usage error: report it on stderr with a failing exit code.
            sys.stderr.write("--period and [--from, --to] are mutually exclusive\n")
            return 1

        if debug_engine is not None:
            engine = debug_engine
        else:
            try:
                engine = dbengine.make_db_engine(self.cfg)
            except dbengine.MakeDbException as e:
                sys.stderr.write(str(e) + '\n')
                return 1

        if opts.period:
            f, t = opts.period
        else:
            datetime_now = datetime.datetime.now()
            f = opts.ffrom or datetime_now - datetime.timedelta(minutes=10)
            t = opts.to or datetime_now

        own_uid = os.getuid()
        user_id = self._obtain_user_id(engine, opts) if self._has_user_arguments(opts) else own_uid
        # Non-root users may only chart their own id.  The comparison is
        # unconditional: a truthiness guard would let an explicit id 0 through.
        if getuser() != 'root' and user_id != own_uid:
            sys.stderr.write('Permission denied\n')
            return 1

        # LVESTATS-97
        # If both params (for_reseller and id or username) are used
        # and user_id not in resellers_users, we return `Permission denied` message
        if opts.for_reseller:
            reseller_name = opts.for_reseller
            reseller_users = _get_uid_for_select(reseller_name, user_id)
            # An int result means everything is ok; otherwise it is a
            # collection of the reseller's ids that must contain user_id.
            if not isinstance(reseller_users, int) and user_id not in reseller_users:
                error_msg = (
                    f'Permission denied. User with id {user_id} does not belong reseller `{reseller_name}`\n'
                )
                sys.stderr.write(error_msg)
                return 1
        try:
            data_collected, times, period_sec = self.get_chart_data(engine,
                                                                    f,
                                                                    t,
                                                                    opts.server,
                                                                    user_id=user_id)
        except (UserNotFound, SQLAlchemyError) as ex:
            sys.stderr.write(str(ex) + '\n')
            return 1

        rendered_graph = self._render(data_collected, engine, opts, period_sec, times)
        self._output(opts, rendered_graph)
        return 0

    def _render(self, data_collected, engine, opts, period_sec, times):
        """Create a Renderer, let the subclass add graphs, and return the rendered chart."""
        show_all = opts.show_all
        style = getattr(opts, 'style', None)
        if style is None:
            style = 'admin'
        renderer = Renderer()
        renderer.set_period_sec(period_sec)
        lve_version = lveinfolib.get_lve_version(engine, opts.server)
        self.add_graphs(renderer, data_collected, times, lve_version, show_all,
                        is_user=style == 'user')
        return renderer.render()

    def _output(self, opts, rendered_graph):
        """
        Write the rendered chart to stdout (``--output -``) or to a file.

        SVG sent to stdout is pretty-printed on a best-effort basis; PNG output
        is converted with ``svg_to_png``.  File write errors are logged.
        """
        if opts.format == 'svg' and opts.output == '-':
            # output as SVG; pretty printing is optional, so any lxml parse
            # error (XMLSyntaxError is not a ValueError) falls back to raw text
            try:
                root_node = etree.fromstring(rendered_graph.encode('utf8'))
                rendered_graph = etree.tostring(root_node, pretty_print=True)
            except (ImportError, ValueError, UnicodeEncodeError, etree.LxmlError) as e:
                self.log.debug('Can not use pretty print for svg xml formatting; %s', str(e))
        if opts.format == 'png':
            # output as PNG
            rendered_graph = svg_to_png(rendered_graph)
        if opts.output == '-':
            if isinstance(rendered_graph, bytes):
                # Flush pending text-layer output first so it is not emitted
                # after the binary payload, then write and flush raw bytes.
                sys.stdout.flush()
                sys.stdout.buffer.write(rendered_graph)
                sys.stdout.buffer.flush()
            else:
                sys.stdout.write(rendered_graph)
        else:
            is_binary = isinstance(rendered_graph, bytes)
            write_mode = 'wb' if is_binary else 'w'
            # text output is written as UTF-8, matching the encoding used for
            # pretty printing above, instead of the locale-dependent default
            encoding = None if is_binary else 'utf-8'
            try:
                with open(opts.output, write_mode, encoding=encoding) as output:
                    output.write(rendered_graph)
            except IOError:
                self.log.error("Unable to create file: %s", opts.output)

    def _obtain_user_id(self, engine, opts):
        """
        Resolve the requested user id from ``opts.user_id`` or ``opts.user_name``.

        :return: the explicit id if given, else the id mapped from the user
                 name, or -1 when the lookup returns a falsy value
        """
        user_id = getattr(opts, 'user_id', None)
        if user_id is None:
            # use the same default as make_parser when server_id is not configured
            user_id = uidconverter.username_to_uid(
                username=getattr(opts, 'user_name', None),
                local_server_id=self.cfg.get('server_id', DEFAULT_SERVER_ID),
                server_id=opts.server, db_engine=engine) or -1
        return user_id

    def _has_user_arguments(self, opts):
        """Return a truthy value when a user id or user name was supplied."""
        return (hasattr(opts, 'user_id') and opts.user_id) or (hasattr(opts, 'user_name') and opts.user_name)
