from typing import List

from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.atomic_ops.slice_op import SliceOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerWeightsClippingConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    BiasMode,
    DataPath,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    OptimizationTarget,
    PrecisionMode,
)


class HailoRowSplitter(BaseHailoLayer):
    """Layer that splits its input along the height axis into ``num_groups`` outputs.

    Internally, the layer input feeds a single ``PassthruOp``. The passthru output is
    routed to ``num_groups`` ``SliceOp`` nodes, and each slice drives its own layer
    output. Slice ``i`` is built with ``height_slice=(i, None, num_groups)``. This is
    presumably (start, stop, step), so output ``i`` would receive rows
    ``i, i + num_groups, ...``; confirm against ``SliceOp``.

    The layer carries no weights: ``_export_weights`` returns an empty dict and
    ``import_weights`` does nothing.
    """

    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a8_w8,
        PrecisionMode.a16_w16,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a8_w8_a16,
        PrecisionMode.a16_w16_a16,
    }
    SUPPORTED_BIAS_MODE = {BiasMode.single_scale_decomposition}
    SUPPORTED_QUANTIZATION_GROUPS = False
    _hn_type = LayerType.ROW_SPLITTER

    def __init__(self, name: str, num_groups: int, logger=None, **kwargs):
        """Create the internal passthru/slice ops, then initialize the base layer.

        Args:
            name: Layer name; also used as the prefix of every internal op name.
            num_groups: Number of output groups; one ``SliceOp`` is created per group.
            logger: Optional logger, forwarded to the internal ops and the base layer.
            **kwargs: Forwarded unchanged to ``BaseHailoLayer.__init__``.
        """
        # Build the ops before the base initializer runs. It presumably calls
        # ``_build_flow``, which reads them. TODO confirm in BaseHailoLayer.
        in_passthru = PassthruOp(f"{name}/in_passthru_op", logger=logger)
        group_slices: List[SliceOp] = [
            SliceOp(f"{name}/slice_op_{group}", height_slice=(group, None, num_groups), logger=logger)
            for group in range(num_groups)
        ]
        # Assign ``_slice_ops`` before ``passthru_op``. If the base class tracks
        # sub-objects in assignment order, that order is significant.
        self._slice_ops = group_slices
        self.passthru_op = in_passthru
        super().__init__(name, logger, **kwargs)

        self.encoding_const = False

    def _build_flow(self) -> LayerFlow:
        """Wire the graph: input -> passthru -> each slice -> that slice's own output."""
        flow = LayerFlow()
        source = flow.add_input()
        sinks = [flow.add_output() for _ in self._slice_ops]

        # Register the passthru node first, then the slices in group order.
        for node in (self.passthru_op, *self._slice_ops):
            flow.add_node(node)

        flow.add_edge(source, self.passthru_op, DataPath.LAYER_IN)
        for group_slice, sink in zip(self._slice_ops, sinks):
            flow.add_edge(self.passthru_op, group_slice, DataPath.LAYER_IN)
            flow.add_edge(group_slice, sink, DataPath.LAYER_OUT)

        return flow

    def enforce_internal_encoding(self, training=False, **kwargs):
        """Propagate the passthru output encoding into every slice and enforce it.

        ``training`` and ``kwargs`` are accepted for interface compatibility and are
        not used here.
        """
        head = self.passthru_op
        head.forward_encoding()
        for group_slice in self._slice_ops:
            # Each slice's input encoding is copied from the passthru output.
            group_slice.input_scale = head.output_scale
            group_slice.input_zero_point = head.output_zero_point
            group_slice.enforce_encoding()

    def create_hw_params(
        self,
        weights_clipping: LayerWeightsClippingConfig,
        optimization_target: OptimizationTarget,
        hw_shifts=None,
    ):
        """No-op: this layer creates no hardware parameters."""

    def fast_enforce_internal_encoding(self, **kwargs):
        """No-op fast path; see ``enforce_internal_encoding`` for the full propagation."""

    def enforce_io_encoding(self, training=False, **kwargs):
        """Delegate I/O encoding enforcement to ``enforce_internal_encoding``."""
        self.enforce_internal_encoding(training=training, **kwargs)

    def _export_weights(self):
        """Return an empty mapping, since the layer has no weights to export."""
        return {}

    def get_equalization_handler_type(self, predecessor_index=None):
        """Classify the layer as transparent, non-source for equalization."""
        return EquivClassification(LayerHandlerType.transparent, is_source=False)

    def import_weights(self, layer_params):
        """No-op: the layer has no weights to import."""

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """Build the layer from an HN element.

        The number of groups is the number of entries in ``hn_element["output_shapes"]``.
        """
        layer = cls(lname, len(hn_element["output_shapes"]), logger=logger)
        layer.finalize_from_hn(hn_element)
        return layer
