# coding=utf-8
#
# Copyright © Cloud Linux GmbH & Cloud Linux Software, Inc 2010-2023 All Rights Reserved
#
# Licensed under CLOUD LINUX LICENSE AGREEMENT
# http://cloudlinux.com/docs/LICENSE.TXT
import contextlib
import time
import queue
from queue import Queue
from datetime import timedelta
from threading import Event
from typing import Generator, Callable, Sequence

import sqlalchemy as sa
import sqlalchemy.exc

from lvestats.orm import bursting_events_table

from .._logs import logger
from .base import InBurstingEventRow, thread_running


@contextlib.contextmanager
def events_saver_running(
    engine: sa.engine.Engine,
    server_id: str,
    dump_interval: timedelta,
    run_period: timedelta = timedelta(seconds=5),
    fail_fast: bool = True,
) -> Generator[Callable[[InBurstingEventRow], None], None, None]:
    """
    Run a background saver that batches submitted events and writes them to DB.

    While the context is active, the yielded callable (the ``put_nowait`` of an
    internal unbounded queue) accepts event rows without blocking. The saver
    loop, run via ``thread_running``, moves queued rows into a local buffer on
    every iteration and writes the buffer with ``save_events_to_db`` once more
    than ``dump_interval`` has passed since the last successful write.

    :param engine: SQLAlchemy engine passed to ``save_events_to_db``.
    :param server_id: value stored in the ``server_id`` column of every row.
    :param dump_interval: minimal time between two successful DB writes.
    :param run_period: how long the loop waits between iterations.
    :param fail_fast: when true, a ``DBAPIError`` raised by a periodic write
        is re-raised from the saver loop; otherwise it is logged and
        the buffered events are kept to be retried on the next iteration.
        A DB error during the final flush always propagates.
    :return: context manager yielding a callable that enqueues one event row.
    """
    messages = Queue()

    def main(terminate: Event):
        # TODO(vlebedev): Implement some kind of buffer size monitoring.
        dump_interval_sec = dump_interval.total_seconds()
        run_period_sec = run_period.total_seconds()
        # NOTE: Monotonic clock is immune to wall-clock adjustments; `-inf` makes
        #       the very first iteration eligible for a write.
        prev_db_write_time, events = float('-inf'), []
        while not terminate.is_set():
            now = time.monotonic()
            events.extend(_pull_events(messages))
            if (now - prev_db_write_time) > dump_interval_sec:
                try:
                    save_events_to_db(engine, server_id, events)
                except sqlalchemy.exc.DBAPIError as e:
                    if fail_fast:
                        raise
                    logger.error('Failed to save events to DB!', exc_info=e)
                else:
                    events.clear()
                    prev_db_write_time = now
            # NOTE: Unlike `time.sleep()`, waiting on the event returns as soon as
            #       termination is requested, so shutdown is not delayed by `run_period`.
            terminate.wait(run_period_sec)

        # NOTE(vlebedev): Write events remaining in the queue.
        # The local buffer may also hold events pulled earlier but not yet dumped
        # (dump interval not elapsed or a previous write failed) - flush those too.
        events.extend(_pull_events(messages))
        save_events_to_db(engine, server_id, events)
        logger.debug('Stopping events saving thread.')

    with thread_running('bursting-saver', main):
        yield messages.put_nowait


def _pull_events(messages: Queue) -> list[InBurstingEventRow]:
    """
    Remove and return everything currently available in ``messages``.

    Never blocks: items are taken with ``get_nowait`` until the queue reports
    ``queue.Empty``. Items are returned in the order they were retrieved;
    an empty list is returned when nothing was queued.
    """
    drained = []
    while True:
        try:
            drained.append(messages.get_nowait())
        except queue.Empty:
            # Queue is exhausted for now - hand over what was collected.
            return drained


def save_events_to_db(
    engine: sa.engine.Engine,
    server_id: str,
    events: Sequence[InBurstingEventRow],
) -> None:
    """
    Insert ``events`` into ``bursting_events_table`` with one INSERT statement.

    Does nothing when ``events`` is empty. Otherwise every event mapping is
    expanded into a row that also carries ``server_id`` (a ``server_id`` key
    present in the event itself takes precedence), and all rows are executed
    as a single statement inside an ``engine.begin()`` block.
    """
    count = len(events)
    if count == 0:
        return
    logger.debug('Saving %d events to DB', count)
    with engine.begin() as conn:
        rows = [{'server_id': server_id, **event} for event in events]
        insert_stmt = sa.insert(bursting_events_table).values(rows)
        conn.execute(insert_stmt)
