# -*- coding: utf-8 -*-
#
# Copyright © Cloud Linux GmbH & Cloud Linux Software, Inc 2010-2019 All Rights Reserved
#
# Licensed under CLOUD LINUX LICENSE AGREEMENT
# http://cloudlinux.com/docs/LICENSE.TXT

import io
import logging
import os
import signal
import sys
import time
import traceback
import psutil

from lvestats.lib.commons.func import reboot_lock
from lvestats.lib import config, dbengine
from lvestats.lib.commons.func import LVEVersionError, get_lve_version
from lvestats.lib.commons.logsetup import setup_logging

# Target that the daemon's stdin/stdout/stderr are redirected to.
DEV_NULL = "/dev/null"
# Pidfile path; empty until main() assigns it from --pidfile.
PIDFILE = ''


class EnvironmentException(Exception):
    """
    Raised by the start-up environment checks when lvestats-server
    must not be started.

    :ivar message: human-readable description of the problem
    """
    def __init__(self, message):
        # Forward the message to Exception so that str(exc), exc.args and
        # logged tracebacks carry it (they were empty before).
        super().__init__(message)
        self.message = message


def get_process_pid(pidfile=None):
    """
    Check if lvestats already running.

    Reads a pid from the pidfile and probes it with ``os.kill(pid, 0)``.
    :param pidfile: path of the pidfile to read; defaults to the
        module-level PIDFILE
    :return int|None: None - if there is no pidfile path, the file is missing
        or unreadable, it does not hold a positive integer, or no process with
        that pid can be signalled;
        pid - if a process with the recorded pid exists (assumed to be
        lvestats-server)
    """
    path = PIDFILE if pidfile is None else pidfile
    if not path:
        return None
    try:
        with open(path, 'r', encoding='utf-8') as f:
            pid = int(f.read().strip())
    except (OSError, ValueError):
        # Missing/unreadable pidfile, or content that is not an integer
        # (e.g. an empty or truncated file): treat as "not running".
        return None
    if pid <= 0:
        # os.kill() with 0 or a negative pid targets a process group (or all
        # processes), so it would "succeed" without identifying the server.
        return None
    try:
        os.kill(pid, 0)  # signal 0: existence/permission check only
    except OSError:
        return None  # no such process (or not permitted to signal it)
    return pid


def stop_server():
    """
    Stop the running lvestats-server together with all of its descendants.

    While holding ``reboot_lock`` (timeout 10 minutes):
      1. every still-existing descendant gets SIGUSR2 and up to 3 s to exit;
      2. descendants still alive get SIGTERM and up to 3 s more;
      3. the main server process gets SIGTERM;
      4. descendants that survived step 2 get SIGKILL.
    :return int: 0 if a running server was found and signalled,
                 1 if no running server was found
    """
    def kill_process(_pid: int, _signal: signal.Signals):
        """Send `_signal` to `_pid`; a failing os.kill() is only logged."""
        try:
            os.kill(_pid, _signal)
        except OSError:  # ProcessLookupError is a subclass of OSError
            log.info("Process with pid '%d' is already dead", _pid)

    def on_sigusr2(_proc: psutil.Process):
        """callback for psutil.wait_procs(): a child exited after SIGUSR2"""
        log.info("Signal 'SIGUSR2' sent to child process: %s", _proc)

    def on_terminate(_proc: psutil.Process):
        """callback for psutil.wait_procs(): a child exited after SIGTERM"""
        log.info("Signal 'SIGTERM' sent to child process: %s", _proc)

    log = setup_logging({}, caller_name='stop_server')
    pid = get_process_pid()
    if pid is None:
        return 1
    try:
        process = psutil.Process(pid)
        childs = process.children(recursive=True)
    except psutil.NoSuchProcess:
        # The server exited between the pidfile check and here:
        # there is nothing left to stop.
        log.info("Process with pid '%d' is already dead", pid)
        return 1
    with reboot_lock(timeout=60 * 10):
        for child in childs:
            # There may be stored absent childs, so we need to kill
            # only existing processes
            if psutil.pid_exists(child.pid):
                kill_process(child.pid, signal.SIGUSR2)
        _, alive = psutil.wait_procs(childs, timeout=3, callback=on_sigusr2)
        for p in alive:
            kill_process(p.pid, signal.SIGTERM)
        # Wait only for the processes that were just sent SIGTERM; waiting on
        # `childs` again would also report, via on_terminate, children that
        # had already exited after SIGUSR2.
        _, alive = psutil.wait_procs(alive, timeout=3, callback=on_terminate)
        kill_process(pid, signal.SIGTERM)
        time.sleep(0.15)
        for child in alive:
            kill_process(child.pid, signal.SIGKILL)
    return 0


def setup_default_exception_hook():
    """
    Install a ``sys.excepthook`` that logs uncaught exceptions.

    The hook logs the exception type, message and formatted traceback on the
    root logger at ERROR level, then defers to the interpreter's default hook
    (``sys.__excepthook__``) so the usual traceback still reaches stderr.
    """
    def hook(_type, _ex, _trace):
        with io.StringIO() as sio:
            traceback.print_tb(_trace, file=sio)
            tb_text = sio.getvalue()
        # Every dynamic part is passed as a logging argument. Interpolating
        # _type into the format string (as an f-string did) breaks
        # %-formatting, and loses the record, if its repr contains a '%'.
        logging.error("Uncaught exception %s\nmessage='%s'\n:%s",
                      _type, str(_ex), tb_text)
        sys.__excepthook__(_type, _ex, _trace)

    sys.excepthook = hook


def run_main(cnf, singleprocess, plugins, profiling_log, times_):
    """
    Hand control to ``lvestats.main.main`` with the given arguments.

    The module is imported at call time rather than at module level
    (NOTE(review): the reason is not visible here; presumably to keep this
    control module light for stop/restart). The alias avoids shadowing this
    module's own main().
    """
    from lvestats import main as lvestats_main  # pylint: disable=import-outside-toplevel
    lvestats_main.main(cnf, singleprocess, plugins, profiling_log, times_)


def sigterm_handler(signum, frame):
    """
    SIGTERM handler: log the shutdown and terminate immediately.

    os._exit() ends the process with status 0 at once. Python cleanup
    (atexit handlers, finally blocks, stdio buffer flushing) is skipped.
    """
    logging.getLogger('sigterm_handler').info('SIGTERM handler. Shutting Down.')
    os._exit(0)


def _check_db_connection(cnf, log_):
    """
    Verify that the SQL server is reachable and the database schema is valid.

    Builds an engine from ``cnf``, runs ``SELECT 1;`` and then
    dbengine.validate_database(). If validation reports column or table
    errors, the process exits with status 1 via sys.exit(). SystemExit is not
    an Exception subclass, so the handler below does not intercept it. Any
    other failure is logged and re-raised as EnvironmentException.
    :type cnf: dict
    :type log_: logging.Logger
    :raises EnvironmentException: on engine creation, query or validation errors
    """
    log_.debug('Check for running SQL server')
    try:
        # NOTE(review): the engine is not disposed before daemonize() forks;
        # if make_db_engine returns a pooled engine, confirm no open
        # connection is inherited by the daemon.
        db_engine = dbengine.make_db_engine(cnf)
        db_engine.execute("SELECT 1;")
        report = dbengine.validate_database(db_engine)
        schema_broken = report['column_error'] or report['table_error']
        if schema_broken:
            sys.exit(1)
    except Exception as err:
        log_.fatal(str(err))
        log_.exception(err)
        header = "Error occurred during connecting to SQL server:"
        raise EnvironmentException(f"\n{header}\n{err}\n") from err


def _check_valid_lve_version(log_):
    """
    Refuse to start when the LVE version is too old.

    Only the value returned by get_lve_version() is checked: versions <= 4
    are rejected. Exceptions raised by get_lve_version() itself are not
    caught here.
    :type log_: logging.Logger
    :raises EnvironmentException: if the LVE version is <= 4
    """
    log_.debug('Check for valid LVE version')

    lve_version = get_lve_version()
    if lve_version <= 4:
        msg = "LVE version <= 4"
        log_.fatal('LVE version <= 4. Please, update.')
        raise EnvironmentException(f"\n{msg}\n")


def _check_running_process(log_):
    """
    Refuse to start while another lvestats-server instance is alive.

    :type log_: logging.Logger
    :raises EnvironmentException: if get_process_pid() returns a truthy pid
    """
    log_.debug('Check for running lvestats-server')
    running_pid = get_process_pid()
    if not running_pid:
        return
    msg = f"Lvestats-server already running with pid {running_pid}. Exiting"
    log_.warning(msg)
    raise EnvironmentException(msg)


def _is_environment_ok(cnf, log_):
    """
    Run the start-up checks in order: running instance, DB, LVE version.

    The first EnvironmentException stops the sequence. Its message is written
    to stderr and False is returned; True means every check passed.
    :type cnf: dict
    :type log_: logging.Logger
    :rtype: bool
    """
    try:
        _check_running_process(log_)
        _check_db_connection(cnf, log_)
        _check_valid_lve_version(log_)
    except EnvironmentException as err:
        stream = sys.stderr
        stream.write(err.message)
        stream.flush()
        healthy = False
    else:
        healthy = True
    return healthy


def _redirect_std_streams_to_devnull():
    """
    Re-point the stdin/stdout/stderr file descriptors at DEV_NULL.

    Buffered Python output is flushed first. Each temporary descriptor is
    closed after dup2() so the daemon keeps no stray descriptors open. It is
    kept only when os.open() returned the very slot being redirected, i.e.
    that std descriptor had already been closed.
    """
    sys.stdout.flush()
    sys.stderr.flush()
    redirects = (
        (sys.stdin, os.O_RDONLY),
        (sys.stdout, os.O_RDWR | os.O_APPEND),
        (sys.stderr, os.O_RDWR | os.O_APPEND),
    )
    for stream, flags in redirects:
        target_fd = stream.fileno()
        devnull_fd = os.open(DEV_NULL, flags)
        if devnull_fd != target_fd:
            os.dup2(devnull_fd, target_fd)
            os.close(devnull_fd)


def daemonize(cnf, singleprocess, plugins, profiling_log, times):
    """
    Detach from the terminal and run lvestats main() as a daemon.

    Double-fork sequence:
      1. the launching process runs the environment checks (exit status 1 on
         failure), forks, waits for the intermediate child and exits;
      2. the intermediate child calls setsid(), installs the SIGTERM handler,
         forks the daemon, writes the daemon's pid to PIDFILE (when set) and
         exits;
      3. the daemon lowers its priority (nice 10), moves to a new process
         group, chdir('/'), redirects std streams to DEV_NULL and runs main().
    :raises RuntimeError: if a fork() fails
    """
    def fork():
        try:
            return os.fork()
        except OSError as e:
            raise RuntimeError(f"{e.strerror} [{e.errno}]") from e

    setup_logging(cnf, console_level=logging.CRITICAL)
    setup_default_exception_hook()

    log_ = logging.getLogger('server')
    # check for issues with environment
    if not _is_environment_ok(cnf, log_):
        sys.exit(1)

    log_.debug('Starting server')

    pid = fork()
    if pid:
        log_.debug('First fork, pid=%d', pid)
        # Wait for the intermediate child instead of sleeping a fixed time.
        # It exits only after writing PIDFILE, so the pidfile is in place
        # once the launcher returns.
        try:
            os.waitpid(pid, 0)
        except ChildProcessError:
            pass  # already reaped (e.g. SIGCHLD ignored); nothing to wait for
        os._exit(0)

    os.setsid()
    signal.signal(signal.SIGTERM, sigterm_handler)

    pid = fork()
    if pid:
        log_.debug('Second fork, pid=%d', pid)
        if PIDFILE:
            log_.debug('Writing pid to file %s', PIDFILE)
            with open(PIDFILE, 'w', encoding='utf-8') as pidfile:
                pidfile.write(str(pid))
        # exit intermediate process
        log_.debug('Child daemon fork ok')
        os._exit(0)

    os.nice(10)
    os.setpgrp()
    os.chdir('/')
    _redirect_std_streams_to_devnull()

    log_.info('Starting main() in daemon')
    run_main(cnf, singleprocess, plugins, profiling_log, times)


def process_opts(cnf, _opts, _times):
    """
    Dispatch on ``_opts.action``.

    * ``stop``    -- stop a running server and exit with stop_server()'s code;
    * ``restart`` -- stop a running server, then start it as a daemon;
    * ``start``   -- start as a daemon or, when ``_opts.nodaemon`` is set,
                     run main() in the foreground with console and file
                     logging at DEBUG (``_opts.debug_mode``) or INFO.
    Any other action does nothing.
    """
    action = _opts.action
    if action == 'stop':
        sys.exit(stop_server())

    if action == 'restart':
        stop_server()
    elif action != 'start':
        return
    elif _opts.nodaemon:
        level = logging.DEBUG if _opts.debug_mode else logging.INFO
        setup_logging(cnf, console_level=level, file_level=level)
        setup_default_exception_hook()
        run_main(cnf, _opts.singleprocess, _opts.plugins, _opts.profiling_log, _times)
        return

    daemonize(cnf, _opts.singleprocess, _opts.plugins, _opts.profiling_log, _times)


def main(opts):
    """
    Command-line entry point of lvestats-server.

    Requires ``opts.pidfile`` (prints a hint and exits with status 1
    otherwise) and stores it in the module-level PIDFILE. Converts
    ``opts.times`` to int when it is truthy, reads the config and dispatches
    to process_opts(). An LVEVersionError raised there is logged and turned
    into exit status 1.
    """
    global PIDFILE  # pylint: disable=global-statement
    if not opts.pidfile:
        print("--pidfile should be specified")
        sys.exit(1)
    PIDFILE = opts.pidfile

    times = int(opts.times) if opts.times else None

    cfg = None
    try:
        cfg = config.read_config()
    except config.ConfigError as cfg_error:
        cfg_error.log_and_exit()

    try:
        process_opts(cfg, opts, times)
    except LVEVersionError as lve_error:
        config.log.error(str(lve_error))
        sys.exit(1)
