from typing import List

import numpy as np

from hailo_model_optimization.acceleras.atomic_ops.concat_op import ConcatOp
from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.atomic_ops.slice_op import SliceOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerWeightsClippingConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    BiasMode,
    DataPath,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    OptimizationTarget,
    PrecisionMode,
)


class HailoFeatureSplitter(BaseHailoLayer):
    """Split the input feature axis into several outputs, optionally per group.

    The layer has one input and ``len(split_sizes)`` outputs.

    * With ``groups == 1``, output ``i`` is the contiguous feature range
      ``[sum(split_sizes[:i]), sum(split_sizes[:i + 1]))`` of the input.
    * With ``groups > 1``, the input features are treated as ``groups`` equal
      contiguous blocks. Every split size is divided evenly among the groups.
      Output ``i`` is the concatenation, in group order, of the ``i``-th
      sub-range taken from every block.

    Example: ``split_sizes=[2, 6], groups=2`` (8 input features, 4 per group)
    gives output 0 = features ``[0:1] + [4:5]`` and output 1 = features
    ``[1:4] + [5:8]``.

    Internal graph: ``PassthruOp -> SliceOp (one per split x group) ->
    ConcatOp (one per split)``. The layer has no weights. It forwards the
    input encoding through its internal ops and reports itself as precision
    transparent.
    """

    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a8_w8,
        PrecisionMode.a16_w16,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a16_w16_a16,
    }
    SUPPORTED_BIAS_MODE = {
        BiasMode.single_scale_decomposition,
    }
    SUPPORTED_QUANTIZATION_GROUPS = False
    _hn_type = LayerType.FEATURE_SPLITTER

    def __init__(self, name: str, split_sizes: List[int], groups: int = 1, logger=None, **kwargs):
        """Create the internal passthru/slice/concat ops for the split.

        Args:
            name: Layer name; used as the prefix of every internal op name.
            split_sizes: Number of output features of each output, in output order.
            groups: Number of equal contiguous feature groups the input is made of.
                Every entry of ``split_sizes`` must be divisible by it.
            logger: Forwarded to every internal op and to the base layer.
            **kwargs: Forwarded to ``BaseHailoLayer.__init__``.

        Raises:
            ValueError: If ``groups`` is not positive, or a split size is not
                divisible by ``groups``.
        """
        group_slices = self._group_slices(split_sizes, groups)

        # NOTE(review): the ops are created before BaseHailoLayer.__init__,
        # presumably because the base constructor builds the flow
        # (_build_flow), which references them. Keep this ordering.
        self.passthru_op = PassthruOp(f"{name}/in_passthru_op", logger=logger)
        self._slice_ops: List[List[SliceOp]] = []
        self._concat_ops: List[ConcatOp] = []
        for i, ranges in enumerate(group_slices):
            self._slice_ops.append(
                [
                    SliceOp(
                        f"{name}/slice_op_{i}_{j}",
                        features_slice=(start, end, 1),
                        logger=logger,
                    )
                    for j, (start, end) in enumerate(ranges)
                ]
            )
            # One concat per output, gathering that output's slice from every group.
            self._concat_ops.append(ConcatOp(f"{name}/concat_op_{i}", concat_elements=groups, logger=logger))
        super().__init__(name, logger, **kwargs)

        self.encoding_const = False
        self._groups = groups

    @staticmethod
    def _group_slices(split_sizes, groups):
        """Return, per split, the ``(start, end)`` feature range to take from each group.

        The result has one list per split. Each list holds one ``(start, end)``
        tuple of plain ``int`` per group, in group order.

        Raises:
            ValueError: If ``groups < 1``, or any split size is not divisible by
                ``groups``. Integer division would otherwise silently truncate
                the per-group slices and yield outputs of the wrong width.
        """
        if groups < 1:
            raise ValueError(f"groups must be a positive integer, got {groups}")
        if any(size % groups != 0 for size in split_sizes):
            raise ValueError(f"split sizes {list(split_sizes)} must all be divisible by groups={groups}")

        per_group_sizes = [int(size) // groups for size in split_sizes]
        group_size = sum(per_group_sizes)  # features contributed by one group

        slices = []
        offset = 0  # start of the current split inside a single group
        for per_group in per_group_sizes:
            slices.append(
                [(g * group_size + offset, g * group_size + offset + per_group) for g in range(groups)]
            )
            offset += per_group
        return slices

    def _build_flow(self) -> LayerFlow:
        """Wire the input through the passthru op, the slices and the per-output concats."""
        layer_flow = LayerFlow()
        in1 = layer_flow.add_input()
        outputs = [layer_flow.add_output() for _ in self._concat_ops]
        layer_flow.add_node(self.passthru_op)
        layer_flow.add_edge(in1, self.passthru_op, DataPath.LAYER_IN)

        for concat_op, slice_ops, output in zip(self._concat_ops, self._slice_ops, outputs):
            layer_flow.add_node(concat_op)
            for slice_index, slice_op in enumerate(slice_ops):
                layer_flow.add_node(slice_op)
                layer_flow.add_edge(self.passthru_op, slice_op, DataPath.LAYER_IN)
                # input_index keeps the group order inside the concat.
                layer_flow.add_edge(slice_op, concat_op, data_path=DataPath.LAYER_IN, input_index=slice_index)
            layer_flow.add_edge(concat_op, output, DataPath.LAYER_OUT)

        return layer_flow

    def _export_weights(self) -> dict:
        """The layer has no weights; export an empty mapping."""
        return dict()

    def create_hw_params(
        self,
        weights_clipping: LayerWeightsClippingConfig,
        optimization_target: OptimizationTarget,
        hw_shifts=None,
    ):
        """No-op: the layer has no weights, so there are no HW params to create."""
        ...

    def enforce_internal_encoding(self, training=False, **kwargs):
        """Propagate the input encoding through the slice and concat ops.

        Each slice op takes the passthru op's output scale/zero-point. Each
        slice's resulting output encoding becomes the matching input encoding
        of its split's concat op. ``training`` and ``kwargs`` are unused.
        """
        self.passthru_op.forward_encoding()
        for concat_op, slice_ops in zip(self._concat_ops, self._slice_ops):
            for j, slice_op in enumerate(slice_ops):
                slice_op.input_scale = self.passthru_op.output_scale
                slice_op.input_zero_point = self.passthru_op.output_zero_point
                slice_op.enforce_encoding()
                concat_op.input_scales[j] = slice_op.output_scale
                concat_op.input_zero_points[j] = slice_op.output_zero_point
            concat_op.enforce_encoding()

    def fast_enforce_internal_encoding(self, **kwargs):
        """No-op for this layer."""
        pass

    def enforce_io_encoding(self, training=False, **kwargs):
        """Delegate to enforce_internal_encoding; ``training``/``kwargs`` are not forwarded."""
        self.enforce_internal_encoding()

    def get_equalization_handler_type(self, predecessor_index=None):
        """Report the layer as transparent (and not a source) for equalization."""
        return EquivClassification(LayerHandlerType.transparent, is_source=False)

    def import_weights(self, layer_params):
        """No-op: the layer has no weights to import."""
        pass

    @property
    def is_precision_transparent(self) -> bool:
        """Always True for this layer."""
        return True

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """Build the layer from an HN element.

        The split sizes are the last dimension of each entry of
        ``hn_element["output_shapes"]``. ``groups`` is read from
        ``hn_element["params"]["groups"]`` and defaults to 1 when absent.
        """
        split_sizes = [output_shape[-1] for output_shape in hn_element["output_shapes"]]
        groups = hn_element.get("params", {}).get("groups", 1)
        layer = cls(lname, split_sizes, groups, logger=logger)
        layer.finalize_from_hn(hn_element)
        return layer

    def to_hn(self, out_degree=None):
        """Record ``groups`` in the stored HN params, then defer to the base export.

        Creates ``self._hn_element["params"]`` if it does not exist yet.
        """
        self._hn_element.setdefault("params", {})["groups"] = self._groups
        return super().to_hn(out_degree=out_degree)
