from typing import Callable, Union

from hailo_model_optimization.acceleras.atomic_ops.conv_stripped_op import ConvStrippedOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_conv import BaseHailoConv
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    PaddingType,
    StrideAlignType,
)
from hailo_model_optimization.acceleras.utils.hn_npz_utils import get_hn_padding, set_hn_padding_stride_align


class HailoConv(BaseHailoConv):
    """Hailo's standard (non-depthwise) conv layer.

    Wraps a ``ConvStrippedOp`` built with ``is_depthwise=False`` and hands it to
    ``BaseHailoConv`` together with the activation / output-layout options.
    The layer can be built from, and exported back to, its HN representation
    via ``from_hn`` / ``to_hn``.
    """

    _hn_type = LayerType.CONV

    def __init__(
        self,
        name: str,
        filters,
        kernel_size,
        strides=(1, 1),
        padding: Union[str, PaddingType] = "SAME",
        stride_align: Union[str, StrideAlignType] = "NW",
        dilation_rate=(1, 1),
        groups=1,
        group_sizes=None,
        activation: Union[str, Callable, ActivationType] = "linear",
        transpose_output_width_features=False,
        spatial_flatten_output=False,
        logger=None,
        zp_comp_add=False,
        set_scale_by_kernel_only=False,
        **kwargs,
    ):
        """Create the conv layer.

        Args:
            name: Layer name. The inner conv op is named ``f"{name}/conv_op"``.
            filters: Number of output filters, forwarded to ``ConvStrippedOp``.
            kernel_size: Spatial kernel size, forwarded to ``ConvStrippedOp``.
            strides: Spatial strides, forwarded to ``ConvStrippedOp``.
            padding: Padding mode (string or ``PaddingType``).
            stride_align: Stride alignment (string or ``StrideAlignType``).
            dilation_rate: Spatial dilation rate, forwarded to ``ConvStrippedOp``.
            groups: Number of convolution groups.
            group_sizes: Optional per-group sizes, forwarded to ``ConvStrippedOp``.
            activation: Activation spec, forwarded to ``BaseHailoConv``.
            transpose_output_width_features: Forwarded to ``BaseHailoConv``; also
                affects the equalization / QuaRot handler classification.
            spatial_flatten_output: Forwarded to ``ConvStrippedOp``.
            logger: Optional logger shared by the layer and its conv op.
            zp_comp_add: Stored as ``self.zp_comp_add``; exported to HN as
                ``zp_comp_added``.
            set_scale_by_kernel_only: Forwarded to ``ConvStrippedOp``.
            **kwargs: Extra options forwarded to ``BaseHailoConv``.
        """
        stripped_conv_op = ConvStrippedOp(
            f"{name}/conv_op",
            kernel_size=kernel_size,
            is_depthwise=False,
            filters=filters,
            groups=groups,
            group_sizes=group_sizes,
            strides=strides,
            dilation_rate=dilation_rate,
            padding=padding,
            stride_align=stride_align,
            spatial_flatten_output=spatial_flatten_output,
            set_scale_by_kernel_only=set_scale_by_kernel_only,
            logger=logger,
        )
        super().__init__(
            name=name,
            conv_op=stripped_conv_op,
            activation=activation,
            transpose_output_width_features=transpose_output_width_features,
            logger=logger,
            **kwargs,
        )

        self.encoding_const = False
        self.zp_comp_add = zp_comp_add

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """Build a ``HailoConv`` from an HN layer element.

        The effective params are ``cls.get_default_params()`` overridden by
        ``hn_element["params"]`` (if present). Strides and dilations are read
        from entries ``[1:3]``, matching the ``[1, h, w, 1]`` layout written by
        ``to_hn``. The filter count is taken from the last kernel-shape entry
        and the kernel size from the first two.

        Args:
            lname: Name for the new layer.
            hn_element: HN layer element (mapping with an optional ``"params"``).
            logger: Optional logger passed to the layer.

        Returns:
            The constructed layer, after ``finalize_from_hn(hn_element)``.
        """
        params = cls.get_default_params()
        params.update(hn_element.get("params", dict()))
        kshape = params["kernel_shape"]
        padding, stride_align = get_hn_padding(params)
        cls._validate_elwa(params["elementwise_add"])
        transpose_output_width_features = params.get("transpose_output_width_features", False)
        spatial_flatten_output = params.get("spatial_flatten_output", False)

        layer = cls(
            name=lname,
            filters=kshape[-1],
            kernel_size=kshape[0:2],
            padding=padding,
            stride_align=stride_align,
            strides=params["strides"][1:3],
            groups=params["groups"],
            group_sizes=params.get("group_sizes"),
            activation=params["activation"],
            transpose_output_width_features=transpose_output_width_features,
            spatial_flatten_output=spatial_flatten_output,
            dilation_rate=params["dilations"][1:3],
            zp_comp_add=params.get("zp_comp_added", False),
            logger=logger,
        )
        layer.finalize_from_hn(hn_element)
        return layer

    def to_hn(self, out_degree=None):
        """Serialize this layer to its HN representation.

        Merges the conv-specific params into ``self._hn_element["params"]``,
        mutating ``self._hn_element`` in place. If that dict is missing, it is
        created with ``batch_norm=False``. The rest is delegated to
        ``BaseHailoConv.to_hn``.

        Args:
            out_degree: Forwarded to ``BaseHailoConv.to_hn``.

        Returns:
            Whatever ``BaseHailoConv.to_hn`` returns.
        """
        weights = self.export_weights()
        group_sizes = self.conv_op.group_sizes
        params = {}
        params["kernel_shape"] = list(weights["kernel"].shape)
        # Rescale kernel-shape entry 2 (presumably the input-channel axis; TODO
        # confirm) by sum/max of the group sizes. Integer floor division keeps
        # this exact instead of round-tripping a channel count through float;
        # int() keeps the exported value a plain Python int.
        params["kernel_shape"][2] = int(
            params["kernel_shape"][2] * sum(group_sizes) // max(group_sizes),
        )
        strides = self.conv_op.strides
        params["strides"] = [1, strides[0], strides[1], 1]
        params["groups"] = self.conv_op.groups
        params["group_sizes"] = group_sizes
        params["activation"] = self.act_op.act_name.value
        params["transpose_output_width_features"] = self.transpose_output_width_features
        params["spatial_flatten_output"] = self.conv_op.spatial_flatten_output
        dilation_rate = self.conv_op.dilation_rate
        params["dilations"] = [1, dilation_rate[0], dilation_rate[1], 1]
        set_hn_padding_stride_align(params, self.conv_op.padding, self.conv_op.stride_align)
        params["elementwise_add"] = False
        params["zp_comp_added"] = self.zp_comp_add
        if "params" not in self._hn_element:
            self._hn_element["params"] = {}
            self._hn_element["params"]["batch_norm"] = False
        self._hn_element["params"].update(params)
        return super().to_hn(out_degree=out_degree)

    def get_equalization_handler_type(self, predecessor_index=None):
        """Classify this layer for equalization.

        Args:
            predecessor_index: Not used by this implementation.

        Returns:
            ``EquivClassification``. It is transparent and non-source when a
            forced scalar output-scale DOF is set. Otherwise it is a consumer,
            and also a source unless the output is width/feature transposed or
            the first output lossy element is signed.
        """
        if self.forced_output_scale_scalar_dof is not None:
            return EquivClassification(LayerHandlerType.transparent, is_source=False)
        # TODO: dynamic weights are not handled yet.
        is_signed_output = self.get_output_lossy_elements()[0].signed
        is_source = not (self.transpose_output_width_features or is_signed_output)
        return EquivClassification(LayerHandlerType.consumer, is_source=is_source)

    def get_quarot_handler_type(self, predecessor_index=None):
        """Classify this layer for QuaRot handling.

        Args:
            predecessor_index: Not used by this implementation.

        Returns:
            ``EquivClassification``. It is transparent and non-source when a
            forced scalar output-scale DOF is set. Otherwise it is a consumer,
            and also a source unless the output is width/feature transposed.
        """
        if self.forced_output_scale_scalar_dof is not None:
            return EquivClassification(LayerHandlerType.transparent, is_source=False)
        is_source = not self.transpose_output_width_features
        return EquivClassification(LayerHandlerType.consumer, is_source=is_source)
