from typing import Tuple, Union

import numpy as np

from hailo_model_optimization.acceleras.atomic_ops.conv_stripped_op import ConvStrippedOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_conv_add import BaseHailoConvAdd
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    PaddingType,
    StrideAlignType,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasImplementationError,
    AccelerasNumerizationError,
)
from hailo_model_optimization.acceleras.utils.hn_npz_utils import get_hn_padding


class HailoDepthwiseAdd(BaseHailoConvAdd):
    """Hailo depthwise convolution layer fused with an elementwise add.

    Input 0 feeds the depthwise convolution and input 1 is the elementwise-add ("elwa") operand.
    While the elwa factor is a scalar (``vector_elwa_factor_supported = False``), the output scale
    is kept proportional to the elwa-input scale. The proportionality constant is stored in
    ``output_scale_scalar_dof``.
    """

    # PrecisionMode and BiasMode are the same as BaseHailoConv
    SUPPORTED_QUANTIZATION_GROUPS = False
    # When False, the output scale is tied to the elwa-input scale (see enforce_io_encoding).
    vector_elwa_factor_supported = False

    _hn_type = LayerType.DW

    def __init__(
        self,
        name: str,
        kernel_size,
        strides=(1, 1),
        padding: Union[str, PaddingType] = "SAME",
        stride_align: Union[str, StrideAlignType] = "NW",
        dilation_rate: Tuple[int, int] = (1, 1),
        activation: Union[str, callable, ActivationType] = "linear",
        logger=None,
        **kwargs,
    ):
        """Create the layer around a depthwise ``ConvStrippedOp``.

        Args:
            name: Layer name; the conv op is named ``f"{name}/conv_op"``.
            kernel_size: Spatial kernel size, forwarded to ``ConvStrippedOp``.
            strides: Spatial strides, forwarded to ``ConvStrippedOp``.
            padding: Padding type, forwarded to ``ConvStrippedOp``.
            stride_align: Stride alignment, forwarded to ``ConvStrippedOp``.
            dilation_rate: Spatial dilation, forwarded to ``ConvStrippedOp``.
            activation: Activation, forwarded to ``BaseHailoConvAdd``.
            logger: Optional logger shared by the layer and its conv op.
            **kwargs: Extra arguments forwarded to ``BaseHailoConvAdd.__init__``.
        """
        conv_op = ConvStrippedOp(
            f"{name}/conv_op",
            kernel_size=kernel_size,
            is_depthwise=True,
            groups=1,
            strides=strides,
            dilation_rate=dilation_rate,
            padding=padding,
            stride_align=stride_align,
            logger=logger,
        )
        super().__init__(
            name=name,
            conv_op=conv_op,
            activation=activation,
            logger=logger,
            **kwargs,
        )
        # Neutral scalar ratio until _create_out_in_scale_ratio computes the real one.
        self.output_scale_scalar_dof = 1

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """Build the layer from an HN element and finalize it.

        Args:
            lname: Name for the new layer.
            hn_element: HN dict. Its ``"params"`` entry overrides ``cls.get_default_params()``.
            logger: Optional logger.

        Returns:
            The constructed layer, after ``finalize_from_hn(hn_element)``.

        Raises:
            AccelerasImplementationError: If ``transpose_output_width_features`` is set.
            ValueError: If a ``groups`` parameter other than 1 is given.
        """
        params = cls.get_default_params()
        params.update(hn_element.get("params", {}))
        kshape = params["kernel_shape"]
        padding, stride_align = get_hn_padding(params)

        cls._validate_elwa(params["elementwise_add"])

        if params.get("transpose_output_width_features", False):
            raise AccelerasImplementationError("transpose_output_width_features is not supported in acceleras yet")

        if params.get("groups", 1) != 1:
            raise ValueError("Depthwise doesn't support group param")

        # NOTE(review): strides/dilations are assumed to be 4-element lists whose spatial
        # entries sit at [1:3] (NHWC-style) -- confirm against the HN spec.
        layer = cls(
            name=lname,
            kernel_size=kshape[0:2],
            strides=params["strides"][1:3],
            padding=padding,
            stride_align=stride_align,
            dilation_rate=params["dilations"][1:3],
            activation=params["activation"],
            logger=logger,
        )
        layer.finalize_from_hn(hn_element)
        return layer

    def _create_out_in_scale_ratio(self):
        """Compute ``output_scale / input_scales[1]`` and store it as ``output_scale_scalar_dof``.

        If the ratio is not a scalar (non-empty shape), its first element is stored. When the
        ratio is a numpy array, every element must first agree with the first one within a
        relative tolerance of 1e-6.

        Raises:
            AccelerasNumerizationError: If a numpy ratio vector is not (near-)constant.
        """
        ratio = self.output_scale / self.input_scales[1]
        rel_tol = 1e-6
        if ratio.shape != ():
            # Only numpy arrays are validated. Other array types fall through to ratio[0] unchecked.
            if isinstance(ratio, np.ndarray):
                reference = ratio[0]
                max_rel_deviation = np.max(np.abs(ratio - reference) / reference)
                # Written as `tol < dev` so a NaN deviation does not raise (matches prior behavior).
                if rel_tol < max_rel_deviation:
                    # Possible fail case: coming from concat, so input scale is scalar while output is vector..
                    raise AccelerasNumerizationError(
                        f"output_scale - input_scale ratio of {self.full_name} should be a scalar"
                    )
            # Attribute used in the scales-training context, should it come.
            self.output_scale_scalar_dof = ratio[0]
        else:
            self.output_scale_scalar_dof = ratio

    def enforce_io_encoding(self, training=False, **kwargs):
        """Tie the output scale to the elwa-input scale while the elwa factor is a scalar.

        With a scalar elwa factor, the output scale vector is only partially free. It has a single
        scalar degree of freedom (DOF) and must stay proportional to the elwa-input scale vector.
        That proportionality is checked offline (``_create_out_in_scale_ratio``). Here it is
        enforced by setting ``output_op.output_scale = input_scales[1] * output_scale_scalar_dof``.
        Nothing is done when ``vector_elwa_factor_supported`` is True.

        TODO verify or modify APU to have a single quantization group -
          otherwise above condition can't hold.. (currently create_hw_params doesn't do groups)
        """
        if self.vector_elwa_factor_supported:
            return  # output scale is decoupled from the elwa-input scale
        self.output_op.output_scale = self.input_scales[1] * self.output_scale_scalar_dof

    def get_equalization_handler_type(self, predecessor_index=None):
        """Classify this layer for equalization relative to the given predecessor.

        Args:
            predecessor_index: 0 for the conv input, 1 for the elwa input.

        Returns:
            A non-source ``EquivClassification`` with one of these handler types:
            ``ew_bouncer`` for the elwa input, ``unsupported`` for the elwa input when
            ``transpose_output_width_features`` is set, and ``consumer`` otherwise.
        """
        is_elwa_input = predecessor_index == 1
        if not is_elwa_input:
            handler_type = LayerHandlerType.consumer
        elif self.transpose_output_width_features:
            handler_type = LayerHandlerType.unsupported
        else:
            handler_type = LayerHandlerType.ew_bouncer
        return EquivClassification(handler_type, is_source=False)

    def get_macs(self):
        """Estimate the MAC count: depthwise-conv MACs plus one op per input element for the add.

        The total is scaled by ``self._get_inefficiency_factor()``.
        """
        # Rank-4 input; named as (batch, height, width, channels) like the original -- assumed NHWC.
        _, height, width, channels_in = self.input_shapes[0]
        kernel_h, kernel_w, _, _ = self.get_kernel().numpy().shape
        stride_h, stride_w = self.conv_op.strides
        # NOTE(review): floor division undercounts output size for SAME padding when the input
        # dims are not divisible by the stride -- confirm this matches other layers' estimates.
        conv_macs = kernel_h * kernel_w * channels_in * (height // stride_h) * (width // stride_w)
        add_macs = height * width * channels_in
        return (conv_macs + add_macs) * self._get_inefficiency_factor()
