from typing import Union

import numpy as np

from hailo_model_optimization.acceleras.atomic_ops.bias_add_3d_op import AddBias3DOp
from hailo_model_optimization.acceleras.atomic_ops.conv3d_op import Conv3DOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_conv import BaseHailoConv
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    EquivClassification,
    LayerHandlerType,
    PaddingType,
    StrideAlignType,
)
from hailo_model_optimization.acceleras.utils.hn_npz_utils import get_hn_padding, set_hn_padding_stride_align


class HailoConv3D(BaseHailoConv):
    """Hailo's conv3d layer.

    The layer is assembled from a ``Conv3DOp`` and an ``AddBias3DOp`` (created
    as non-trainable and non-correctable) which are handed to
    ``BaseHailoConv``. Quantization groups, equalization, forced pruning and
    forced kernel scale are all reported as unsupported / no-ops by this class.
    """

    SUPPORTED_QUANTIZATION_GROUPS = False

    def __init__(
        self,
        name: str,
        filters,
        kernel_size,
        input_features,
        strides=(1, 1, 1),
        padding: Union[str, PaddingType] = "VALID",
        stride_align: Union[str, StrideAlignType] = "NW",
        dilation_rate=(1, 1, 1),
        groups=1,
        activation: Union[str, callable, ActivationType] = "linear",
        transpose_output_width_features=False,
        disparity=1,
        input_disparity=1,
        logger=None,
        **kwargs,
    ):
        """Build the conv and bias ops and initialise the base conv layer.

        Args:
            name: Layer name; the ops are named ``<name>/conv_op`` and
                ``<name>/bias_add_op``.
            filters, kernel_size, input_features, strides, dilation_rate,
                groups, padding, stride_align, disparity, input_disparity:
                Forwarded unchanged to ``Conv3DOp``.
            activation: Forwarded to ``BaseHailoConv``.
            transpose_output_width_features: Forwarded to ``BaseHailoConv``.
            logger: Shared by both ops and the base layer.
            **kwargs: Extra keyword arguments for ``BaseHailoConv``.
        """
        conv3d_op = Conv3DOp(
            f"{name}/conv_op",
            kernel_size=kernel_size,
            filters=filters,
            input_features=input_features,
            groups=groups,
            strides=strides,
            dilation_rate=dilation_rate,
            padding=padding,
            stride_align=stride_align,
            disparity=disparity,
            input_disparity=input_disparity,
            logger=logger,
        )
        # The bias op takes its output disparity from the conv op it follows.
        bias_add_op = AddBias3DOp(
            f"{name}/bias_add_op",
            output_disparity=conv3d_op.output_disparity,
            trainable=False,
            is_correctable=False,
            logger=logger,
        )
        super().__init__(
            name=name,
            conv_op=conv3d_op,
            bias_add_op=bias_add_op,
            activation=activation,
            transpose_output_width_features=transpose_output_width_features,
            logger=logger,
            **kwargs,
        )

        self.encoding_const = False

    @classmethod
    def get_default_params(cls):
        """Return a fresh dict of the HN params assumed when the HN omits them."""
        return {
            "strides": [1, 1, 1, 1],
            "dilations": [1, 1, 1, 1],
            "padding": "VALID",
            "activation": "linear",
            "groups": 1,
            "elementwise_add": False,
        }

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """Create the layer from an HN element.

        Args:
            lname: Name given to the new layer.
            hn_element: Mapping with an optional ``"params"`` dict (which must
                provide ``"kernel_shape"``) and an ``"input_shapes"`` list.
            logger: Optional logger forwarded to the layer.

        Returns:
            The constructed layer, after ``finalize_from_hn(hn_element)``.

        Raises:
            KeyError: If ``"kernel_shape"`` is missing from the params.
        """
        params = cls.get_default_params()
        params.update(hn_element.get("params", {}))
        kshape = params["kernel_shape"]
        padding, stride_align = get_hn_padding(params)
        cls._validate_elwa(params["elementwise_add"])

        # "disparity" wins over "layer_disparity"; "input_disparity" falls
        # back to the layer disparity when it is not given explicitly.
        layer_disparity = params.get("disparity", params.get("layer_disparity", 1))
        input_disparity = params.get("input_disparity", layer_disparity)

        layer = cls(
            name=lname,
            filters=kshape[-1],
            kernel_size=kshape[0:3],
            # The last input dimension is divided by the input disparity to
            # obtain the per-disparity feature count.
            input_features=hn_element["input_shapes"][0][-1] // input_disparity,
            # The leading entry of the HN strides / dilations is dropped.
            strides=params["strides"][1:],
            padding=padding,
            stride_align=stride_align,
            groups=params["groups"],
            activation=params["activation"],
            transpose_output_width_features=params.get("transpose_output_width_features", False),
            dilation_rate=params["dilations"][1:],
            disparity=layer_disparity,
            input_disparity=input_disparity,
            logger=logger,
        )
        layer.finalize_from_hn(hn_element)
        return layer

    def to_hn(self, out_degree=None):
        """Write this layer's params into its HN element and export it.

        Args:
            out_degree: Forwarded to ``BaseHailoConv.to_hn``.

        Returns:
            Whatever ``BaseHailoConv.to_hn`` returns.
        """
        weights = self.export_weights()
        strides = self.conv_op.strides
        dilation_rate = self.conv_op.dilation_rate
        # from_hn() reads dilations[1:] as a three-element rate, so the third
        # component is exported too (it used to be hard-coded to 1, which lost
        # it on a to_hn -> from_hn round trip). Fall back to 1 when the op
        # only carries two components.
        third_dilation = dilation_rate[2] if len(dilation_rate) > 2 else 1

        params = {
            "kernel_shape": list(weights["kernel"].shape),
            "strides": [1, strides[0], strides[1], strides[2]],
            "groups": self.conv_op.groups,
            "activation": self.act_op.act_name.value,
            "transpose_output_width_features": self.transpose_output_width_features,
            "dilations": [1, dilation_rate[0], dilation_rate[1], third_dilation],
            "elementwise_add": False,
            "disparity": self.conv_op.disparity,
        }

        set_hn_padding_stride_align(params, self.conv_op.padding, self.conv_op.stride_align)

        self._hn_element["params"].update(params)
        return super().to_hn(out_degree=out_degree)

    def get_equalization_handler_type(self, predecessor_index=None):
        """Report that equalization is unsupported for this layer."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    @property
    def homogeneous(self):
        """Always ``False`` for this layer."""
        return False

    def _layer_dependent_hw_params_modifications(self, params: dict) -> dict:
        """Trim the exported bias to the length of the op's short bias.

        The layer keeps a short bias that is repeated inside the op, so only
        the first ``short_bias.shape[0]`` entries are exported.

        Args:
            params: Exported params dict; modified in place.

        Returns:
            The same ``params`` dict with ``"bias"`` truncated, ``"bias_q"``
            set to the truncated bias, and, when already present,
            ``"bias_q_int8_vec_a"`` / ``"bias_q_int8_vec_b"`` replaced by an
            int8 copy of the truncated bias.
        """
        bias_length = self.bias_add_op.short_bias.shape[0]
        params["bias"] = params["bias"][:bias_length]
        params["bias_q"] = params["bias"]
        # Only keys that already exist are refreshed; none are introduced.
        for key in ("bias_q_int8_vec_a", "bias_q_int8_vec_b"):
            if key in params:
                params[key] = np.array(params["bias"], dtype=np.int8)
        return params

    def get_macs(self):
        """Return an approximate MAC count for the convolution.

        Padding is not accounted for. The disparity-axis stride is derived as
        ``s_d // input_features``.
        """
        _, height, width, _ = self.input_shapes[0]
        kernel_size = self.get_kernel().numpy().size
        disparity = self.conv_op.disparity
        s_h, s_w, s_d = self.conv_op.strides
        # s_d // input_features is 0 whenever s_d < input_features (e.g. the
        # default strides of (1, 1, 1)); clamp to 1 so the estimate does not
        # raise ZeroDivisionError.
        stride_disp = max(1, s_d // self.conv_op.input_features)
        return kernel_size * ((height // s_h) * (width // s_w) * (disparity // stride_disp))

    def enable_force_pruning(self):
        """No-op for this layer."""

    def disable_force_pruning(self):
        """No-op for this layer."""

    def _supported_quantization_groups_hw(self, quantization_groups, arch):
        """Quantization groups are never reported as supported, whatever the arguments."""
        return False

    @staticmethod
    def _add_qnpz_conv_relations(quant_qnpz_relations):
        """No-op for this layer; ``quant_qnpz_relations`` is left untouched."""

    @property
    def kernel_scale_forced_to_save(self):
        """Always ``False``: kernel q forced is not implemented for 3d conv."""
        return False
