from typing import Tuple, Union

from hailo_model_optimization.acceleras.atomic_ops.conv_stripped_op import ConvStrippedOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_conv import BaseHailoConv
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    BiasMode,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    OptimizationTarget,
    PaddingType,
    StrideAlignType,
)
from hailo_model_optimization.acceleras.utils.hn_npz_utils import get_hn_padding


class HailoDepthwise(BaseHailoConv):
    """Hailo's standard depthwise layer.

    Wraps a ``ConvStrippedOp`` constructed with ``is_depthwise=True`` and
    ``groups=1``. Activation handling and the general conv-layer machinery come
    from ``BaseHailoConv``. This class specializes the following:

    - HN (de)serialization.
    - The per-architecture hardware support tables for precision and bias modes.
    - The equalization handler classification.
    - The MAC estimate.
    """

    # PrecisionMode and BiasMode are the same as BaseHailoConv
    _hn_type = LayerType.DW
    # NOTE(review): this flag is False while _supported_quantization_groups_hw
    # accepts 1-4 groups; confirm how BaseHailoConv combines the two.
    SUPPORTED_QUANTIZATION_GROUPS = False
    # Architectures for which this layer replaces the base-class hardware
    # support tables (see _get_precision_mode_supported_in_hw and
    # _get_bias_mode_supported_in_hw). Kept in one place so both stay in sync.
    _DW_OVERRIDE_ARCHS = frozenset(
        {OptimizationTarget.MERCURY, OptimizationTarget.SAGE, OptimizationTarget.PLUTO}
    )

    def __init__(
        self,
        name: str,
        kernel_size,
        strides=(1, 1),
        padding: Union[str, PaddingType] = "SAME",
        stride_align: Union[str, StrideAlignType] = "NW",
        dilation_rate: Tuple[int, int] = (1, 1),
        activation: Union[str, callable, ActivationType] = "linear",
        transpose_output_width_features=False,
        logger=None,
        **kwargs,
    ):
        """Create the depthwise conv op and initialize the base conv layer.

        Args:
            name: Layer name. The inner op is named ``f"{name}/conv_op"``.
            kernel_size: Kernel size, forwarded to ``ConvStrippedOp``.
            strides: Strides, forwarded to ``ConvStrippedOp``.
            padding: Padding type, forwarded to ``ConvStrippedOp``.
            stride_align: Stride alignment, forwarded to ``ConvStrippedOp``.
            dilation_rate: Dilation rate, forwarded to ``ConvStrippedOp``.
            activation: Activation, forwarded to ``BaseHailoConv``.
            transpose_output_width_features: Forwarded to ``BaseHailoConv``.
                It is also consulted by ``get_equalization_handler_type``.
            logger: Passed to both the conv op and the base layer.
            **kwargs: Extra keyword arguments forwarded to ``BaseHailoConv``.
        """
        conv_op = ConvStrippedOp(
            f"{name}/conv_op",
            kernel_size=kernel_size,
            is_depthwise=True,
            groups=1,  # depthwise layers never use grouped convolution
            strides=strides,
            dilation_rate=dilation_rate,
            padding=padding,
            stride_align=stride_align,
            logger=logger,
        )
        super().__init__(
            name=name,
            conv_op=conv_op,
            activation=activation,
            transpose_output_width_features=transpose_output_width_features,
            logger=logger,
            **kwargs,
        )

        # NOTE(review): the meaning of encoding_const is defined outside this
        # class; it is only reset to False here.
        self.encoding_const = False

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """Build a layer from an HN element.

        Start from ``get_default_params()`` and apply ``hn_element["params"]``
        on top. The kernel size is taken from ``kernel_shape[0:2]``. Strides and
        dilations use elements ``[1:3]`` of their lists; these are presumably
        4-element, NHWC-style lists (TODO confirm). Finally,
        ``finalize_from_hn`` is called on the new layer.

        Args:
            lname: Name of the new layer.
            hn_element: HN dictionary describing the layer.
            logger: Optional logger forwarded to the constructor.

        Returns:
            The constructed layer.

        Raises:
            ValueError: If the HN params specify ``groups`` other than 1.
        """
        params = cls.get_default_params()
        params.update(hn_element.get("params", {}))
        kshape = params["kernel_shape"]
        padding, stride_align = get_hn_padding(params)
        cls._validate_elwa(params["elementwise_add"])
        transpose_output_width_features = params.get("transpose_output_width_features", False)

        if params.get("groups", 1) != 1:
            raise ValueError("Depthwise doesn't support group param")

        layer = cls(
            name=lname,
            kernel_size=kshape[0:2],
            strides=params["strides"][1:3],
            padding=padding,
            stride_align=stride_align,
            dilation_rate=params["dilations"][1:3],
            activation=params["activation"],
            transpose_output_width_features=transpose_output_width_features,
            logger=logger,
        )
        layer.finalize_from_hn(hn_element)
        return layer

    def to_hn(self, *args, **kwargs):
        """Serialize the layer to an HN element.

        Extend the base serialization as follows:

        - Default ``kernel_shape`` to the conv op's kernel shape.
        - Default ``batch_norm`` to False.
        - Always overwrite ``activation`` with the current activation name.
        """
        hn_element = super().to_hn(*args, **kwargs)
        kernel_shape = list(self.conv_op.kernel.shape)
        params = hn_element.setdefault("params", self.get_default_params())
        params.setdefault("kernel_shape", kernel_shape)
        params.setdefault("batch_norm", False)
        params["activation"] = self.get_activation_name().value
        return hn_element

    def _get_precision_mode_supported_in_hw(self, arch):
        """Return the precision modes the hardware supports for ``arch``.

        Every mode in ``SUPPORTED_PRECISION_MODE`` is allowed on the override
        architectures. Otherwise the base class decides.
        """
        if arch in self._DW_OVERRIDE_ARCHS:
            return self.SUPPORTED_PRECISION_MODE
        return super()._get_precision_mode_supported_in_hw(arch)

    def _get_bias_mode_supported_in_hw(self, arch):
        """Return the bias modes the hardware supports for ``arch``.

        On the override architectures, all supported bias modes are allowed
        except ``double_scale_decomposition``. Otherwise the base class decides.
        """
        if arch in self._DW_OVERRIDE_ARCHS:
            return self.SUPPORTED_BIAS_MODE - {BiasMode.double_scale_decomposition}
        return super()._get_bias_mode_supported_in_hw(arch)

    def _supported_quantization_groups_hw(self, quantization_groups, arch) -> bool:
        """Return True iff ``quantization_groups`` is in [1, 4].

        ``arch`` is accepted for interface compatibility but is not used.
        """
        return 1 <= quantization_groups <= 4

    def get_equalization_handler_type(self, predecessor_index=None):
        """Classify this layer for equalization.

        If ``forced_output_scale_scalar_dof`` is set, the layer is transparent
        and not a source. Otherwise it is a consumer, and it is a source unless
        ``transpose_output_width_features`` is enabled.

        Args:
            predecessor_index: Unused here.
        """
        if self.forced_output_scale_scalar_dof is not None:
            return EquivClassification(LayerHandlerType.transparent, is_source=False)
        # TODO: dynamic weights are not handled yet.
        is_source = not self.transpose_output_width_features
        return EquivClassification(LayerHandlerType.consumer, is_source=is_source)

    @classmethod
    def get_default_bias_mode(cls):
        """Return the default bias mode: ``BiasMode.double_scale_initialization``."""
        return BiasMode.double_scale_initialization

    def get_macs(self) -> int:
        """Estimate the number of multiply-accumulate operations of the layer.

        The formula is ``kh * kw * C_in * (H // s_h) * (W // s_w) * 8``. The
        factor of 8 is labelled an inefficiency index in the original code; its
        derivation is not shown here.
        """
        # Assumes input_shapes[0] is (batch, height, width, channels), inferred
        # from the unpacking order. TODO confirm.
        _, height, width, channels_in = self.input_shapes[0]
        kernel_h, kernel_w, _, _ = self.get_kernel().numpy().shape
        stride_h, stride_w = self.conv_op.strides
        # NOTE(review): the output size is approximated by floor division and
        # ignores padding/kernel effects. Confirm the approximation is intended.
        inefficiency_factor = 8
        return (
            kernel_h
            * kernel_w
            * channels_in
            * (height // stride_h)
            * (width // stride_w)
            * inefficiency_factor
        )
