from abc import abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from itertools import chain
from logging import Logger
from typing import Dict, List, Tuple, Type, TypeVar

import numpy as np
from pydantic import BaseModel

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_add import HailoElementwiseAdd
from hailo_model_optimization.acceleras.hailo_layers.hailo_matmul import HailoMatmul
from hailo_model_optimization.acceleras.hailo_layers.hailo_precision_split import HailoPrecisionSplit
from hailo_model_optimization.acceleras.hailo_layers.hailo_shortcut import HailoShortcut
from hailo_model_optimization.acceleras.hailo_layers.hailo_standalone_activation import HailoStandaloneActivation
from hailo_model_optimization.acceleras.model.hailo_model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model.hailo_model.model_flow import ModelFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config import ModelOptimizationConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerPrecisionConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    SHIFT_CALCULATE_BUFFER,
    ActivationType,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.opt_utils import calculate_shifts, get_scalar_vector
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm

# Generic placeholder for the pydantic model that holds a block's layer names
# (used to annotate the ``block`` arguments of the decomposition classes).
T = TypeVar("T", bound=BaseModel)


class MMDecomNames816(BaseModel):
    """Layer names for the decomposition of the matmul correction block.

    NOTE(review): no decomposition class in the visible part of this file uses
    this model; the field meanings presumably mirror ``MMDecomNames168``
    (two matmuls, a final element-wise add and a precision split) - confirm.
    """

    mm_h: str
    mm_l: str
    ewa_end: str
    ps_int: str


class MMDecomNames1616(BaseModel):
    """Layer names for the decomposition of the matmul correction block.

    Every field holds the full name of one layer created by
    ``MatmulDecompose1616``. The declaration order is the order in which
    ``model_dump()`` / ``model_fields`` expose the names.
    """

    # HailoPrecisionSplit layers: ps_uint receives the matmul's input 0,
    # ps_int receives the matmul's input 1.
    ps_int: str
    ps_uint: str

    # HailoMatmul layers, one per pair of split operands.
    mm_hh: str
    mm_hl: str
    mm_lh: str
    mm_ll: str

    # HailoShortcut layers placed between the precision splits and the matmuls
    # ("a" side is fed by ps_uint, "w" side by ps_int).
    moc_ah: str
    moc_al: str
    moc_wh: str
    moc_wl: str

    # HailoElementwiseAdd layers that sum pairs of matmul outputs.
    ewa_h: str
    ewa_l: str

    # Final HailoElementwiseAdd; it carries the original matmul's activation
    # and is the source of the block's output edges.
    ewa_end: str


@dataclass
class Component:
    """Everything needed to splice a decomposed block into a model flow."""

    # Edges between the layers of the block itself.
    flow: ModelFlow
    # Layer objects that make up the block.
    layers: List[BaseHailoLayer]
    # (source, target, edge attributes) triples entering the block.
    input_edges: List[Tuple[str, str, dict]]
    # (source, target, edge attributes) triples leaving the block.
    output_edges: List[Tuple[str, str, dict]]


class BaseDecomposeBlock:
    """Base class for replacing a single layer of a model with a block of layers.

    A subclass sets ``BLOCK_NAMES`` to a pydantic model whose fields name the
    layers of the block and implements the hooks that build those layers, the
    edges between them, and the edges connecting the block to the rest of the
    model.

    NOTE: the ``@abstractmethod`` markers below are documentation only. The
    class does not use ``ABCMeta``, so Python does not prevent instantiating a
    subclass that leaves some of them unimplemented.
    """

    BLOCK_NAMES = BaseModel
    SUFFIX = "decompose"

    def __init__(self, layer_name: str, flow: ModelFlow, logger: Logger):
        """Prepare the block names for ``layer_name`` and snapshot ``flow``.

        Args:
            layer_name: Full name of the layer that the block replaces.
            flow: Model flow containing ``layer_name``; a deep copy is stored
                so the original edges remain available after the model's own
                flow has been modified.
            logger: Logger handed to the layers created for the block.
        """
        self.logger = logger
        self.layer_name = layer_name
        self.block = self.get_block_names(layer_name, self.BLOCK_NAMES)
        self.comp_flow: ModelFlow = None
        self.flow = deepcopy(flow)

    @property
    def block_in_flow(self) -> bool:
        """Whether every layer of the block is a node of the snapshotted flow."""
        return self.check_if_block_exists(self.flow)

    @abstractmethod
    def _create_layers(self, block: T, model: HailoModel) -> List[BaseHailoLayer]:
        """Build the layer objects of the block."""
        ...

    @abstractmethod
    def _create_flow(self, block: T) -> ModelFlow:
        """Build a flow holding the edges between the layers of the block."""
        ...

    @abstractmethod
    def create_input_edges(self, block: T) -> List[Tuple[str, str, dict]]:
        """Return the edges that connect the model to the inputs of the block."""
        ...

    @abstractmethod
    def create_output_edges(self, block: T) -> List[Tuple[str, str, dict]]:
        """Return the edges that connect the block to its consumers in the model."""
        ...

    @abstractmethod
    def update_mo_config(self, mo_config: ModelOptimizationConfig) -> None:
        """Update the model optimization configuration for the layers of the block."""

    @abstractmethod
    def fix_matmuls(self, model: HailoModel) -> None:
        """Adjust the matmul layers of the block after it was added to ``model``."""

    def check_if_block_exists(self, flow: ModelFlow) -> bool:
        """Return True when all the block's layer names are nodes of ``flow``."""
        return all(lname in flow.nodes for lname in self.block.model_dump().values())

    def collect_stats_layers(self) -> List[str]:
        """Return the names of all the layers of the block."""
        return list(self.block.model_dump().values())

    def add_correction_block(self, model: HailoModel):
        """Replace ``self.layer_name`` in ``model`` with the layers of the block.

        Args:
            model: Model whose layers and flow are modified in place.

        Raises:
            ValueError: If all the layers of the block are already in the
                model's flow.
        """
        if self.check_if_block_exists(model.flow):
            raise ValueError(f"Block {self.layer_name} already exists in the model")

        component = self.create_sub_flow(self.block, model)
        model._unlock_model()
        try:
            model.layers.pop(self.layer_name)
            model.layers.update({layer.full_name: layer for layer in component.layers})
        finally:
            # Re-lock even when the swap fails so the model is never left unlocked.
            model._lock_model()
        self._stich_flow(model.flow, component)

    def create_sub_flow(self, block: T, model: HailoModel) -> Component:
        """Assemble the layers, internal edges and boundary edges of the block."""
        layers = self._create_layers(block, model)
        flow = self._create_flow(block)
        input_edges = self.create_input_edges(block)
        output_edges = self.create_output_edges(block)
        return Component(flow, layers, input_edges, output_edges)

    def _stich_flow(self, flow: ModelFlow, component: Component):
        """Remove the replaced layer from ``flow`` and add all edges of ``component``."""
        flow.remove_node(self.layer_name)
        for u_node, v_node, index_info in chain(
            component.input_edges, component.output_edges, component.flow.edges(data=True)
        ):
            flow.add_edge(u_node, v_node, **index_info)

    def remove_correction_block(self):
        """Hook for removing the correction block; the base implementation does nothing."""

    @abstractmethod
    def enforce_block_encodings(self, scales) -> Tuple[np.ndarray, np.ndarray]:
        """Enforce the constraints of all the layers of the block,
        meaning setting scales and zero points for the block.
        """

    def get_block_names(self, lname: str, decompose_block: Type[T]) -> T:
        """Build the names of the block's layers from the replaced layer's name.

        Each field ``key`` of ``decompose_block`` is named
        ``<scope>/<block prefix><key>_<layer>_<SUFFIX>``, where the block
        prefix and the layer part come from
        ``OptimizationAlgorithm.get_block_and_layer_names``.

        Args:
            lname: Full name of the replaced layer (scope and name joined by ``/``).
            decompose_block: Pydantic model class whose fields name the block's layers.

        Returns:
            An instance of ``decompose_block`` with every field set to its layer name.
        """
        mm_scope = "/".join(lname.split("/")[:-1])
        lname = lname.split("/")[-1]
        mm_block_name, lname = OptimizationAlgorithm.get_block_and_layer_names(lname)
        # "{{}}" leaves a "{}" placeholder that is filled with the field name below.
        template = (
            f"{mm_scope}/{mm_block_name}{{}}_{lname}_{self.SUFFIX}"
            if mm_block_name
            else f"{mm_scope}/{{}}_{lname}_{self.SUFFIX}"
        )
        vals = {key: template.format(key) for key in decompose_block.model_fields}
        return decompose_block(**vals)

    @staticmethod
    def filter_edges_by_attribute(edges, attr_name, attr_value):
        """Return the edges whose data dict (third element) maps ``attr_name`` to ``attr_value``."""
        return [edge for edge in edges if edge[2].get(attr_name) == attr_value]


class MatmulDecompose1616(BaseDecomposeBlock):
    """Replace one matmul with four matmuls over precision-split operands.

    Both inputs of the original matmul go through a ``HailoPrecisionSplit``
    (``ps_uint`` for input 0, ``ps_int`` for input 1). Every split output
    feeds a shortcut (``moc_*``), the shortcuts feed the four ``mm_*`` matmuls,
    ``ewa_h`` and ``ewa_l`` each sum two matmul outputs, and ``ewa_end`` sums
    those two results. ``ewa_end`` carries the activation of the original
    matmul and is the only layer with edges leaving the block.
    """

    BLOCK_NAMES = MMDecomNames1616

    def _create_layers(self, block: MMDecomNames1616, model: HailoModel) -> List[BaseHailoLayer]:
        """Create the layers of the block, copying the original matmul's settings.

        Args:
            block: Names of the layers to create.
            model: Model that still contains the original matmul layer.

        Returns:
            One layer object per field of ``block``.
        """
        matmul: HailoMatmul = model.layers[self.layer_name]

        # Every matmul of the block shares the configuration of the original one.
        config = {
            "transpose_matmul_input": matmul.transpose_matmul_input,
            "groups": matmul.groups,
            "zp_comp_added": matmul.zp_comp_added,
            "input_windows": matmul.input_windows,
            "input_tiles": matmul.input_tiles,
            "logger": self.logger,
        }

        layers = [
            # Matmul - one per combination of split operands
            HailoMatmul(block.mm_hh, **config),
            HailoMatmul(block.mm_hl, **config),
            HailoMatmul(block.mm_lh, **config),
            HailoMatmul(block.mm_ll, **config),
            # Shortcuts
            HailoShortcut(block.moc_ah, self.logger),
            HailoShortcut(block.moc_al, self.logger),
            HailoShortcut(block.moc_wh, self.logger),
            HailoShortcut(block.moc_wl, self.logger),
            # EWAdd - only the last one applies the original activation
            HailoElementwiseAdd(block.ewa_h, logger=self.logger),
            HailoElementwiseAdd(block.ewa_l, logger=self.logger),
            HailoElementwiseAdd(block.ewa_end, activation=matmul.act_op.act_name, logger=self.logger),
            # PrecisionSplit
            HailoPrecisionSplit(block.ps_int, logger=self.logger),
            HailoPrecisionSplit(block.ps_uint, logger=self.logger),
        ]
        return layers

    def _create_flow(self, block: MMDecomNames1616) -> ModelFlow:
        """Build the edges between the layers of the block.

        Args:
            block: Names of the layers of the block.

        Returns:
            A new flow holding only the block's internal edges.
        """
        flow = ModelFlow()

        # Precision splits to shortcuts: split output 1 feeds the "h" shortcut,
        # split output 0 feeds the "l" shortcut.
        flow.add_edge(block.ps_uint, block.moc_ah, output_index=1, input_index=0)
        flow.add_edge(block.ps_uint, block.moc_al, output_index=0, input_index=0)
        flow.add_edge(block.ps_int, block.moc_wh, output_index=1, input_index=0)
        flow.add_edge(block.ps_int, block.moc_wl, output_index=0, input_index=0)

        # Left value of Matmul A where => (A(uint) @ B(int))
        flow.add_edge(block.moc_ah, block.mm_hh, output_index=0, input_index=0)
        flow.add_edge(block.moc_ah, block.mm_hl, output_index=1, input_index=0)
        flow.add_edge(block.moc_al, block.mm_lh, output_index=0, input_index=0)
        flow.add_edge(block.moc_al, block.mm_ll, output_index=1, input_index=0)

        # Right value of Matmul B where => (A(uint) @ B(int))
        flow.add_edge(block.moc_wh, block.mm_hh, output_index=0, input_index=1)
        flow.add_edge(block.moc_wh, block.mm_hl, output_index=1, input_index=1)
        flow.add_edge(block.moc_wl, block.mm_lh, output_index=0, input_index=1)
        flow.add_edge(block.moc_wl, block.mm_ll, output_index=1, input_index=1)

        # Matmul to EWAdd
        flow.add_edge(block.mm_hh, block.ewa_h, output_index=0, input_index=0)
        flow.add_edge(block.mm_hl, block.ewa_h, output_index=0, input_index=1)
        flow.add_edge(block.mm_lh, block.ewa_l, output_index=0, input_index=0)
        flow.add_edge(block.mm_ll, block.ewa_l, output_index=0, input_index=1)

        # EWAdd to EWAdd end: both partial sums meet in the block's output layer.
        flow.add_edge(block.ewa_h, block.ewa_end, output_index=0, input_index=0)
        flow.add_edge(block.ewa_l, block.ewa_end, output_index=0, input_index=1)

        return flow

    def update_mo_config(self, mo_config: ModelOptimizationConfig):
        """Register a precision mode for every layer of the block.

        Also removes the replaced matmul from all the configurations.

        Args:
            mo_config: Model optimization configuration, modified in place.
        """
        block: MMDecomNames1616 = self.block
        # Precision Config for the block
        configs = {
            # PrecisionSplit
            block.ps_int: PrecisionMode.a16_w16_a8,
            block.ps_uint: PrecisionMode.a16_w16_a8,
            # Shortcuts
            block.moc_ah: PrecisionMode.a8_w8_a8,
            block.moc_al: PrecisionMode.a8_w8_a8,
            block.moc_wh: PrecisionMode.a8_w8_a8,
            block.moc_wl: PrecisionMode.a8_w8_a8,
            # Matmul
            block.mm_hh: PrecisionMode.a8_w8_a16,
            block.mm_hl: PrecisionMode.a8_w8_a16,
            block.mm_lh: PrecisionMode.a8_w8_a16,
            block.mm_ll: PrecisionMode.a8_w8_a16,
            # EWAdd
            block.ewa_h: PrecisionMode.a16_w16_a16,
            block.ewa_l: PrecisionMode.a16_w16_a16,
            block.ewa_end: PrecisionMode.a16_w16,
        }

        for layer, precision in configs.items():
            ps = LayerPrecisionConfig()
            # Same attribute that MatmulDecompose168.update_mo_config sets.
            ps.precision_mode = precision
            mo_config.precision_config.layers[layer] = ps

        mo_config.remove_layer_from_all_configs(self.layer_name)

    def create_splits(self, model: HailoModel):
        """Call ``create_splits`` on both precision-split layers of the block."""
        block: MMDecomNames1616 = self.block
        for layer in [block.ps_int, block.ps_uint]:
            model.layers[layer].create_splits()

    def create_input_edges(self, block: MMDecomNames1616) -> List[Tuple[str, str, dict]]:
        """Route the original matmul's producers into the precision splits.

        The producer of input 0 is connected to ``ps_uint`` and the producer of
        input 1 to ``ps_int``; the producers' output indices are preserved.

        Raises:
            IndexError: If the original matmul lacks an edge for input 0 or 1.
        """
        original_edges = self.flow.in_edges(self.layer_name, data=True)
        a_input = self.filter_edges_by_attribute(original_edges, "input_index", 0)[0]
        b_input = self.filter_edges_by_attribute(original_edges, "input_index", 1)[0]
        out_index_a = a_input[2]["output_index"]
        out_index_b = b_input[2]["output_index"]

        connection_edges = [
            (a_input[0], block.ps_uint, {"input_index": 0, "output_index": out_index_a}),
            (b_input[0], block.ps_int, {"input_index": 0, "output_index": out_index_b}),
        ]
        return connection_edges

    def create_output_edges(self, block: MMDecomNames1616) -> List[Tuple[str, str, dict]]:
        """Re-source every output edge of the original matmul from ``ewa_end``.

        The consumers' input indices are preserved; the output index is always 0.
        """
        original_edges = self.flow.out_edges(self.layer_name, data=True)
        out_edges = [
            (block.ewa_end, edge[1], {"input_index": edge[2]["input_index"], "output_index": 0})
            for edge in original_edges
        ]
        return out_edges

    def fix_matmuls(self, model: HailoModel) -> None:
        """No matmul adjustment is performed for this block."""
        pass


class MMDecomNames168(BaseModel):
    """Layer names for the decomposition of the matmul correction block.

    Every field holds the full name of one layer created by
    ``MatmulDecompose168``.
    """

    # HailoMatmul layers: mm_h is fed by output 1 of the precision split,
    # mm_l by output 0 (through linear_l).
    mm_h: str
    mm_l: str
    # HailoStandaloneActivation with a linear activation, placed before mm_l.
    linear_l: str
    # Final HailoElementwiseAdd; carries the original matmul's activation.
    ewa_end: str
    # HailoPrecisionSplit applied to the matmul's input 0.
    ps_uint: str


class MatmulDecompose168(BaseDecomposeBlock):
    """Replace one matmul with two matmuls over a precision-split first operand.

    Input 0 of the original matmul goes through ``ps_uint``: split output 1
    feeds ``mm_h`` directly and split output 0 reaches ``mm_l`` through the
    linear activation ``linear_l``. Input 1 of the original matmul feeds both
    matmuls unchanged. ``ewa_end`` sums the two products, carries the original
    activation and is the source of the block's output edges.
    """

    BLOCK_NAMES = MMDecomNames168
    block: MMDecomNames168

    def _create_layers(self, block: MMDecomNames168, model: HailoModel) -> List[BaseHailoLayer]:
        """Instantiate the five layers of the block from the original matmul's settings."""
        source_mm: HailoMatmul = model.layers[self.layer_name]
        # Settings shared by both matmuls of the block.
        mm_kwargs = dict(
            transpose_matmul_input=source_mm.transpose_matmul_input,
            groups=source_mm.groups,
            zp_comp_added=source_mm.zp_comp_added,
            input_windows=source_mm.input_windows,
            input_tiles=source_mm.input_tiles,
            logger=self.logger,
        )

        high_mm = HailoMatmul(block.mm_h, **mm_kwargs)
        low_mm = HailoMatmul(block.mm_l, **mm_kwargs)
        low_linear = HailoStandaloneActivation(block.linear_l, activation=ActivationType.LINEAR, logger=self.logger)
        final_add = HailoElementwiseAdd(block.ewa_end, activation=source_mm.act_op.act_name, logger=self.logger)
        splitter = HailoPrecisionSplit(block.ps_uint, logger=self.logger)
        return [high_mm, low_mm, low_linear, final_add, splitter]

    def _create_flow(self, block: MMDecomNames168) -> ModelFlow:
        """Build the internal edges of the block as a fresh flow."""
        # (source, target, output_index, input_index)
        wiring = (
            # Left value of Matmul A where => (A(uint) @ B(int))
            (block.ps_uint, block.mm_h, 1, 0),
            (block.ps_uint, block.linear_l, 0, 0),
            (block.linear_l, block.mm_l, 0, 0),
            # Matmuls into the final element-wise add
            (block.mm_h, block.ewa_end, 0, 0),
            (block.mm_l, block.ewa_end, 0, 1),
        )

        sub_flow = ModelFlow()
        for source, target, out_idx, in_idx in wiring:
            sub_flow.add_edge(source, target, output_index=out_idx, input_index=in_idx)
        return sub_flow

    def update_mo_config(self, mo_config: ModelOptimizationConfig):
        """Set the precision mode of each block layer and drop the replaced matmul's config."""
        names: MMDecomNames168 = self.block
        mode_by_layer = {
            # PrecisionSplit and the linear activation after it
            names.ps_uint: PrecisionMode.a16_w16_a8,
            names.linear_l: PrecisionMode.a8_w8_a8,
            # Matmul
            names.mm_h: PrecisionMode.a8_w8_a16,
            names.mm_l: PrecisionMode.a8_w8_a16,
            # EWAdd
            names.ewa_end: PrecisionMode.a16_w16,
        }

        layer_configs = mo_config.precision_config.layers
        for lname, mode in mode_by_layer.items():
            entry = LayerPrecisionConfig()
            entry.precision_mode = mode
            layer_configs[lname] = entry

        mo_config.remove_layer_from_all_configs(self.layer_name)

    def create_splits(self, model: HailoModel):
        """Call ``create_splits`` on the block's precision-split layer."""
        splitter = model.layers[self.block.ps_uint]
        splitter.create_splits()

    def create_input_edges(self, block: MMDecomNames168) -> List[Tuple[str, str, dict]]:
        """Connect the original matmul's producers to the block.

        The producer of input 0 goes to ``ps_uint``; the producer of input 1
        goes to input 1 of both ``mm_h`` and ``mm_l``.
        """
        incoming = self.flow.in_edges(self.layer_name, data=True)
        first_operand = self.filter_edges_by_attribute(incoming, "input_index", 0)[0]
        second_operand = self.filter_edges_by_attribute(incoming, "input_index", 1)[0]

        producer_a, port_a = first_operand[0], first_operand[2]["output_index"]
        producer_b, port_b = second_operand[0], second_operand[2]["output_index"]

        return [
            (producer_a, block.ps_uint, {"input_index": 0, "output_index": port_a}),
            (producer_b, block.mm_h, {"input_index": 1, "output_index": port_b}),
            (producer_b, block.mm_l, {"input_index": 1, "output_index": port_b}),
        ]

    def create_output_edges(self, block: T) -> List[Tuple[str, str, dict]]:
        """Re-source each output edge of the original matmul from ``ewa_end`` (output 0)."""
        rewired = []
        for edge in self.flow.out_edges(self.layer_name, data=True):
            consumer = edge[1]
            attrs = {"input_index": edge[2]["input_index"], "output_index": 0}
            rewired.append((block.ewa_end, consumer, attrs))
        return rewired

    def fix_matmuls(self, model: HailoModel) -> None:
        """Rescale input 0 of the low matmul and mirror the scale onto ``linear_l``.

        Only the low matmul is handled here; the high matmul is left untouched.
        """
        low_mm = model.layers[self.block.mm_l]
        shift_delta = _get_matmul_shifts(low_mm)[1]
        rescaled = low_mm.input_scales[0] * 2**shift_delta
        low_mm.set_input_scale(rescaled, 0)
        # The layer feeding the low matmul must emit the same scale.
        model.layers[self.block.linear_l].set_output_scale(rescaled, 0)


def search_matmul_decomp_blocks(model: HailoModel) -> Dict[str, BaseDecomposeBlock]:
    """Find the matmul decomposition blocks that are already part of ``model``.

    Every ``HailoMatmul`` whose name contains ``BaseDecomposeBlock.SUFFIX`` is
    treated as a member of a block. The name of the matmul it replaced is
    rebuilt from the member's name, and each supported block type is kept only
    if all of its layers are nodes of the model's flow.

    NOTE(review): the original layer name is taken as the text between the last
    two underscores and the scope as the text before the first ``/``, so
    original names containing ``_``, nested scopes or a block prefix are
    presumably not reconstructed exactly - confirm against the naming rules.

    Args:
        model: Model to inspect; it is not modified.

    Returns:
        Mapping from the reconstructed full name of the replaced matmul to the
        decomposition block found for it.
    """
    supported_blocks = (MatmulDecompose1616, MatmulDecompose168)
    decompose_blocks = {}
    for lname, layer in model.layers.items():
        if not isinstance(layer, HailoMatmul) or BaseDecomposeBlock.SUFFIX not in lname:
            continue

        # Gets the matmul layer name, then checks for each block if all the layers are there
        name_parts = lname.split("_")
        if len(name_parts) < 2:
            # Not a generated "<key>_<layer>_<suffix>" name.
            continue
        full_name = "/".join([lname.split("/")[0], name_parts[-2]])

        # All matmuls of one block map to the same name; building a candidate
        # deep-copies the flow, so do it only once per block.
        if full_name in decompose_blocks:
            continue

        for block in supported_blocks:
            compose_block = block(full_name, model.flow, model._logger)
            if compose_block.block_in_flow:
                decompose_blocks[full_name] = compose_block
    return decompose_blocks


def _get_matmul_shifts(matmul: HailoMatmul):
    """Compute the shift values of a matmul layer from its scales and limits.

    The product of the two input scales is used as the candidate scale in
    front of the activation; the largest absolute input limit of the
    activation op divided by that scale is the expected maximum value handed
    to ``calculate_shifts``.

    Args:
        matmul: Matmul layer with input scales and activation limits set.

    Returns:
        The ``(mac_shift, shift_delta, need_shift)`` triple produced by
        ``calculate_shifts``.
    """
    NUMERICAL_ERROR_FACTOR = 1.0001

    first_scale = np.array(get_scalar_vector(matmul.input_scales[0]))
    second_scale = np.array(get_scalar_vector(matmul.input_scales[1]))
    scale_product = first_scale * second_scale

    activation = matmul.act_op
    lower_limit, upper_limit = activation.get_input_limvals(0)
    lossy = matmul.act_op.input_lossy_element

    # Largest magnitude the activation input may take, in units of the scale product.
    abs_limit = np.maximum(np.abs(upper_limit), np.abs(lower_limit))
    expected_peak = abs_limit / scale_product

    # Slightly enlarge the buffer to absorb floating-point error.
    buffer = SHIFT_CALCULATE_BUFFER * NUMERICAL_ERROR_FACTOR
    shifts = calculate_shifts(expected_peak, lossy.bits, buffer, return_needed_shift=True)
    mac_shift, shift_delta, need_shift = shifts
    return mac_shift, shift_delta, need_shift
