from typing import List

from hailo_model_optimization.acceleras.atomic_ops.concat_op import ConcatOp
from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.atomic_ops.slice_op import SliceOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerWeightsClippingConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    DEFAULT_RANK2_SLICE,
    BiasMode,
    DataPath,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    OptimizationTarget,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasInitializationError


class HailoSlice(BaseHailoLayer):
    """Represents `slice` layer in the hn.

    The layer is built from an input passthru op feeding ``groups`` parallel ``SliceOp``s,
    whose outputs are concatenated in group order by a ``ConcatOp``. Every group applies the
    same height/width slice; group ``i`` applies the features slice shifted by
    ``i * group_size`` channels.
    """

    _hn_type = LayerType.SLICE
    OP_NAME = "slice_op"
    SUPPORTED_QUANTIZATION_GROUPS = False
    SUPPORTED_BIAS_MODE = {BiasMode.single_scale_decomposition}
    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a8_w8,
        PrecisionMode.a16_w16,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a16_w16_a16,
    }

    def __init__(
        self,
        name: str,
        height_slice: tuple = None,
        width_slice: tuple = None,
        features_slice: tuple = None,
        logger=None,
        groups: int = 1,
        group_size: int = 0,
        **kwargs,
    ):
        """
        Args:
            name: layer name; also used as the prefix of the internal op names.
            height_slice: (begin, end, stride) slice along height, passed unchanged to every SliceOp.
            width_slice: (begin, end, stride) slice along width, passed unchanged to every SliceOp.
            features_slice: (begin, end, stride) slice along features, relative to the start of a
                single group. Required in practice: it is indexed unconditionally despite the
                ``None`` default.
            logger: logger forwarded to the internal ops and to the base layer.
            groups: number of feature groups; one SliceOp is created per group.
            group_size: channels per group; group ``i``'s features slice is offset by ``i * group_size``.
            **kwargs: forwarded to ``BaseHailoLayer.__init__``.
        """
        self._input_pass_op = PassthruOp(name=f"{name}/in_passthru_op", logger=logger)
        self._slice_ops: List[SliceOp] = []
        for i in range(groups):
            # Shift the per-group features slice into group i's channel range.
            f_slice_start = (i * group_size) + features_slice[0]
            f_slice_end = (i * group_size) + features_slice[1]
            slice_i = SliceOp(
                f"{name}/slice_op_{i}",
                height_slice=height_slice,
                width_slice=width_slice,
                features_slice=(f_slice_start, f_slice_end, features_slice[2]),
                logger=logger,
            )
            self._slice_ops.append(slice_i)
        self._concat_op = ConcatOp(f"{name}/concat_op", concat_elements=groups, logger=logger)
        # NOTE(review): the ops are created before super().__init__, presumably because the base
        # initializer builds the layer flow (_build_flow) from them — confirm.
        super().__init__(name, logger, **kwargs)

        self.encoding_const = False
        self._groups = groups
        self._group_size = group_size

    @property
    def height_slice(self):
        """Height slice of the first SliceOp (all groups share the same height slice)."""
        return self._slice_ops[0].height_slice

    @property
    def width_slice(self):
        """Width slice of the first SliceOp (all groups share the same width slice)."""
        return self._slice_ops[0].width_slice

    @property
    def features_slice(self):
        """Layer-level features slice reconstructed from the per-group SliceOps.

        Removes the ``(groups - 1) * group_size`` offset from the last group's end, then scales
        both bounds of the per-group slice by ``groups``. The result is
        ``[begin * groups, end * groups, 1]``, where ``(begin, end)`` is the per-group features slice
        (assuming ``SliceOp.features_slice`` returns the slice it was constructed with). The
        stride is always reported as 1.
        """
        orig_start = self._slice_ops[0].features_slice[0] * self._groups
        last_slice = self._slice_ops[-1]
        orig_end = (last_slice.features_slice[1] - (self._groups - 1) * self._group_size) * self._groups
        return [orig_start, orig_end, 1]

    def _build_flow(self) -> LayerFlow:
        """Wire: input -> passthru -> every SliceOp -> ConcatOp (input ``idx`` per group) -> output."""
        layer_flow = LayerFlow()
        in1 = layer_flow.add_input()
        out1 = layer_flow.add_output()
        layer_flow.add_node(self._input_pass_op)
        layer_flow.add_node(self._concat_op)

        layer_flow.add_edge(in1, self._input_pass_op, DataPath.LAYER_IN)

        for idx, slice_op in enumerate(self._slice_ops):
            layer_flow.add_node(slice_op)
            layer_flow.add_edge(self._input_pass_op, slice_op, DataPath.LAYER_IN)
            # The input index keeps the concatenation in group order.
            layer_flow.add_edge(slice_op, self._concat_op, DataPath.LAYER_IN, input_index=idx)
        layer_flow.add_edge(self._concat_op, out1, DataPath.LAYER_OUT)
        return layer_flow

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """Build a ``HailoSlice`` from its hn element.

        Two-element ``(begin, end)`` slices are padded with a stride of 1. ``groups`` is read from
        ``params["groups"]`` (default 1), and the group size is the last input dimension
        floor-divided by it. The params are validated before the layer is constructed, and
        ``finalize_from_hn`` is then applied to the new layer. The hn params are not modified.

        Raises:
            AccelerasInitializationError: if the slice params are not supported.
        """
        params = hn_element["params"]
        height_slice = params["height_slice"]
        width_slice = params["width_slice"]
        features_slice = params["features_slice"]
        # Build new padded lists instead of extending in place with ``+=``: in-place extension
        # would mutate the caller's hn_element (and raise TypeError for tuple params).
        if len(height_slice) == 2:
            height_slice = [*height_slice, 1]
        if len(width_slice) == 2:
            width_slice = [*width_slice, 1]
        if len(features_slice) == 2:
            features_slice = [*features_slice, 1]

        input_shape = hn_element["input_shapes"][0]
        groups = params.get("groups", 1)
        group_size = input_shape[-1] // groups
        cls._validate_params_from_hn(height_slice, width_slice, features_slice, input_shape)
        layer = cls(
            height_slice=height_slice,
            width_slice=width_slice,
            features_slice=features_slice,
            name=lname,
            logger=logger,
            groups=groups,
            group_size=group_size,
        )
        layer.finalize_from_hn(hn_element)
        return layer

    @classmethod
    def _validate_params_from_hn(cls, height_slice, width_slice, features_slice, input_shape):
        """Validate hn slice params against the layer input shape (used only by ``from_hn``).

        The features slice is checked against the last input dimension. For rank-4 inputs, the
        height and width slices are checked against ``input_shape[1]`` and ``input_shape[2]``.
        For any other rank, they must equal ``DEFAULT_RANK2_SLICE``.

        Raises:
            AccelerasInitializationError: on any unsupported slice value.
        """
        # this is a function only for the hn
        # NOTE(review): with groups > 1 the features slice is per group, but it is validated against
        # the full channel count rather than the group size — confirm this is intended.
        cls._validate_slice_params_hn(features_slice, input_shape[-1])
        if len(input_shape) == 4:
            cls._validate_slice_params_hn(height_slice, input_shape[1])
            cls._validate_slice_params_hn(width_slice, input_shape[2])
        elif height_slice != DEFAULT_RANK2_SLICE or width_slice != DEFAULT_RANK2_SLICE:
            # rank2 slice supports only features.
            raise AccelerasInitializationError("Unsupported slice values for layer slice")

    @staticmethod
    def _validate_slice_params_hn(slice_params, relevant_input_shape):
        """Validate a single (begin, end, stride) slice against one input dimension.

        To match the kernel implementation, the slice params must satisfy:
            1. the begin is not negative
            2. the end is not equal to the begin
            3. the stride is 1 (for now)
            4. the end does not exceed the input dimension

        TODO we may change the slice params to a smart named tuple

        Raises:
            AccelerasInitializationError: if ``slice_params`` does not have exactly 3 elements or
                violates any rule above.
        """
        if len(slice_params) != 3:
            raise AccelerasInitializationError("Unsupported slice values for layer slice")
        begin, end, stride = slice_params
        if begin < 0 or end == begin or stride != 1 or end > relevant_input_shape:
            raise AccelerasInitializationError("Unsupported slice values for layer slice")

    def get_equalization_handler_type(self, predecessor_index=None):
        """Slice is transparent for equalization and is not a source layer."""
        return EquivClassification(LayerHandlerType.transparent, is_source=False)

    def get_quarot_handler_type(self, predecessor_index=None):
        """Slice is transparent for QuaRot and is not a source layer."""
        return EquivClassification(LayerHandlerType.transparent, is_source=False)

    def _get_precision_mode_supported_in_hw(self, arch):
        """MERCURY/SAGE/PLUTO support the four 8/16-bit modes; other targets defer to the base class."""
        if arch in {OptimizationTarget.MERCURY, OptimizationTarget.SAGE, OptimizationTarget.PLUTO}:
            return {PrecisionMode.a8_w8_a8, PrecisionMode.a16_w16_a16, PrecisionMode.a8_w8, PrecisionMode.a16_w16}
        else:
            return super()._get_precision_mode_supported_in_hw(arch)

    def import_weights(self, layer_params):
        """No-op: this layer imports no weights (``_export_weights`` also returns an empty dict)."""
        pass

    def enforce_io_encoding(self, training=False, **kwargs):
        """Delegate to ``enforce_internal_encoding`` (``training``/``kwargs`` are not forwarded)."""
        self.enforce_internal_encoding()

    def enforce_internal_encoding(self, training=False, **kwargs):
        """Propagate encodings from the input passthru through every SliceOp into the ConcatOp.

        Each SliceOp inherits the passthru output scale/zero-point. Its resulting output
        scale/zero-point is written into the ConcatOp input slot of the same group index, and
        the ConcatOp encoding is enforced last.
        """
        self._input_pass_op.enforce_encoding()
        for i, slice_op in enumerate(self._slice_ops):
            slice_op.input_scale = self._input_pass_op.output_scale
            slice_op.input_zero_point = self._input_pass_op.output_zero_point
            slice_op.enforce_encoding()
            self._concat_op.input_scales[i] = slice_op.output_scale
            self._concat_op.input_zero_points[i] = slice_op.output_zero_point
        self._concat_op.enforce_encoding()

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """No-op: this layer creates no HW params."""
        pass

    def _export_weights(self):
        """Return an empty dict: this layer has no weights to export."""
        return dict()
