# coding=utf-8
#
# Copyright © Cloud Linux GmbH & Cloud Linux Software, Inc 2010-2023 All Rights Reserved
#
# Licensed under CLOUD LINUX LICENSE AGREEMENT
# http://cloudlinux.com/docs/LICENSE.TXT
from typing import Callable, TypedDict

import sqlalchemy as sa

from lvestats.orm import BurstingEventType, bursting_events_table

from ..common import LveId, Timestamp
from ..history import LveHistory, IntervalType, get_interval_type_after
from ..utils import bootstrap_gen
from .._logs import logger


# Typing-only description of a bursting-event row as it is read, by key,
# from the query results in this module: which LVE the event belongs to,
# when it happened and what kind of bursting event it was.
_OutBurstingEventRow = TypedDict('_OutBurstingEventRow', {
    'lve_id': LveId,
    'timestamp': Timestamp,
    'event_type': BurstingEventType,
})


def load_bursting_enabled_intervals_from_db(
    engine: sa.engine.Engine,
    cutoff: Timestamp,
    server_id: str,
) -> dict[LveId, LveHistory]:
    """
    Load the bursting history of every LVE of `server_id` starting around `cutoff`.

    For each LVE the starting point is its latest bursting event at or before
    `cutoff`, so the state in effect at `cutoff` is known. LVEs without such
    an event start at `cutoff` itself. All events of the server from that
    starting point onwards are read in ascending timestamp order. They are
    grouped per LVE and assembled into an `LveHistory`.

    :param engine: SQLAlchemy engine used to query `bursting_events_table`.
    :param cutoff: lower bound of the period of interest.
    :param server_id: only events recorded for this server are considered.
    :return: mapping of LVE id to its assembled history. LVEs that have no
        matching events do not appear in the mapping.
    """
    lve_to_asm = {}
    with engine.begin() as conn:
        # Per LVE: timestamp of the most recent event at or before `cutoff`.
        first_events_stmt = (
            sa.select([
                bursting_events_table.c.lve_id,
                sa.func.max(bursting_events_table.c.timestamp).label("timestamp"),
            ])
            .where(sa.and_(
                bursting_events_table.c.timestamp <= cutoff,
                bursting_events_table.c.server_id == server_id,
            ))
            .group_by(
                bursting_events_table.c.lve_id
            )
            .alias('first_events')
        )

        # All events from that per-LVE starting point onwards. The outer join
        # plus COALESCE makes LVEs with no event before `cutoff` start at
        # `cutoff` instead.
        all_events_stmt = (
            sa.select([
                bursting_events_table.c.server_id,
                bursting_events_table.c.lve_id,
                bursting_events_table.c.timestamp,
                bursting_events_table.c.event_type,
            ])
            .select_from(
                bursting_events_table.outerjoin(
                    first_events_stmt,
                    bursting_events_table.c.lve_id == first_events_stmt.c.lve_id,
                )
            )
            .where(sa.and_(
                bursting_events_table.c.server_id == server_id,
                bursting_events_table.c.timestamp >= sa.func.coalesce(
                    first_events_stmt.c.timestamp,
                    cutoff,
                )
            ))
            .order_by(
                bursting_events_table.c.timestamp.asc(),
            )
        )

        result = conn.execute(all_events_stmt)
        assert isinstance(result, sa.engine.ResultProxy)
        rows = result.fetchall()
        assert rows is not None
    for row in rows:
        row: _OutBurstingEventRow
        lve_id = row['lve_id']
        asm = lve_to_asm.get(lve_id)
        if asm is None:
            # The first row of an LVE seeds the assembler as its initial event.
            # Passing the same row to `process_row` as well would only get it
            # discarded as a "consecutive" duplicate, which logs a misleading
            # debug message.
            lve_to_asm[lve_id] = _LveHistoryAssembler(lve_id, row)
            continue
        asm.process_row(row)
    return {lve_id: asm.create() for lve_id, asm in lve_to_asm.items()}


class _LveHistoryAssembler:
    """
    Incrementally build the `LveHistory` of a single LVE from its bursting events.

    The first event fixes the type of the first interval and provides the
    first timestamp. A `STARTED` event gives `OVERUSING`; any other event type
    gives `NORMAL`. Each later event whose type differs from the previously
    accepted event adds a boundary timestamp. A run of events with the same
    type keeps only the first one. Events are used in the order they are
    passed in; nothing is sorted here.
    """

    def __init__(self, lve_id: LveId, first_event: _OutBurstingEventRow) -> None:
        self._lve_id = lve_id
        if first_event['event_type'] == BurstingEventType.STARTED:
            self._first_interval_type = IntervalType.OVERUSING
        else:
            self._first_interval_type = IntervalType.NORMAL
        self._timestamps: list[Timestamp] = [first_event['timestamp']]
        # Type of the most recently accepted event, used to drop repeats.
        self._prev_event_type = first_event['event_type']

    def process_row(self, event: _OutBurstingEventRow) -> None:
        """Register the next event of this LVE, ignoring same-type repeats."""
        event_type = event['event_type']
        if event_type == self._prev_event_type:
            logger.debug(
                'Skipping consecutive "%s" events for lve_id=%s',
                event_type, self._lve_id,
            )
            return
        self._timestamps.append(event['timestamp'])
        self._prev_event_type = event_type

    def create(self) -> LveHistory:
        """
        Build the `LveHistory` from the collected timestamps.

        `get_interval_type_after` may report that the interval after the last
        recorded timestamp is `OVERUSING`, i.e. still open. In that case the
        last timestamp is dropped and a warning is logged. If no timestamps
        remain, an empty `LveHistory()` is returned.
        """
        boundaries = self._timestamps
        trailing_type = get_interval_type_after(self._first_interval_type, len(boundaries))
        if trailing_type == IntervalType.OVERUSING:
            logger.warning("Last interval for lve_id=%s is not closed!", self._lve_id)
            boundaries = boundaries[:-1]
        if not boundaries:
            return LveHistory()
        return LveHistory(self._first_interval_type, tuple(boundaries))
