from typing import Callable, Union

import numpy as np

from hailo_model_optimization.acceleras.atomic_ops.conv_stripped_op import ConvStrippedOp
from hailo_model_optimization.acceleras.encoding.encoding_data import EncodingType
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_conv_add import BaseHailoConvAdd
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    EquivClassification,
    LayerHandlerType,
    PaddingType,
    StrideAlignType,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasNumerizationError
from hailo_model_optimization.acceleras.utils.hn_npz_utils import get_hn_padding


class HailoConvAdd(BaseHailoConvAdd):
    """
    Hailo's standard conv layer with elementwise addition.

    Input 0 feeds the convolution; input 1 is the elementwise-add ("elwa") input.
    While ``vector_elwa_factor_supported`` is False, the output scale is constrained to be
    ``input_scales[1] * output_scale_scalar_dof``, i.e. proportional to the elwa-input scale
    with a single scalar degree of freedom.

    TODO do we really need the separation between this and the base?  maybe at least put in same module?
    """

    # !! Set to True when LCU implementation supports per-channel multiplier.
    vector_elwa_factor_supported = False
    # NOTE(review): this annotation names `_output_scale_scalar_dof`, but the attribute this class
    # actually reads/writes is `output_scale_scalar_dof` — confirm which one is intended.
    _output_scale_scalar_dof: float

    def __init__(
        self,
        name: str,
        filters,
        kernel_size,
        strides=(1, 1),
        padding: Union[str, PaddingType] = "SAME",
        stride_align: Union[str, StrideAlignType] = "NW",
        dilation_rate=(1, 1),
        ew_add_factor=1,
        groups=1,
        activation: Union[str, Callable, ActivationType] = "linear",
        transpose_output_width_features=False,
        logger=None,
        **kwargs,
    ):
        """
        Build the layer around a non-depthwise ``ConvStrippedOp`` named ``{name}/conv_op``.

        Conv-specific arguments (kernel_size, filters, groups, strides, dilation_rate, padding,
        stride_align) go to the conv op; the rest go to ``BaseHailoConvAdd.__init__``.
        """
        stripped_conv_op = ConvStrippedOp(
            f"{name}/conv_op",
            kernel_size=kernel_size,
            is_depthwise=False,
            filters=filters,
            groups=groups,
            strides=strides,
            dilation_rate=dilation_rate,
            padding=padding,
            stride_align=stride_align,
            logger=logger,
        )
        super().__init__(
            name=name,
            conv_op=stripped_conv_op,
            activation=activation,
            ew_add_factor=ew_add_factor,
            logger=logger,
            transpose_output_width_features=transpose_output_width_features,
            **kwargs,
        )

        # Scalar DOF linking the output scale to the elwa-input scale; recomputed by
        # _create_out_in_scale_ratio / update_encoding.
        self.output_scale_scalar_dof = 1
        self.encoding_const = False

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        Construct the layer from an HN element dict and finalize it from that element.

        ``hn_element["params"]`` overrides ``cls.get_default_params()``. The kernel shape
        supplies the kernel size (first two entries) and the filter count (last entry).
        Strides and dilations are taken from entries ``[1:3]``.
        """
        params = cls.get_default_params()
        params.update(hn_element.get("params", {}))
        kshape = params["kernel_shape"]
        padding, stride_align = get_hn_padding(params)
        transpose_output_width_features = params.get("transpose_output_width_features", False)
        cls._validate_elwa(params["elementwise_add"])

        layer = cls(
            name=lname,
            filters=kshape[-1],
            kernel_size=kshape[0:2],
            padding=padding,
            stride_align=stride_align,
            strides=params["strides"][1:3],
            groups=params["groups"],
            activation=params["activation"],
            transpose_output_width_features=transpose_output_width_features,
            dilation_rate=params["dilations"][1:3],
            ew_add_factor=params.get("elementwise_add_factor", 1),
            logger=logger,
        )
        layer.finalize_from_hn(hn_element)
        return layer

    def _create_out_in_scale_ratio(self):
        """
        Create ``output_scale_scalar_dof`` from the ratio ``output_scale / input_scales[1]``.

        A scalar ratio is stored as-is. A non-scalar ratio is stored as its first element.
        If the ratio is an ndarray, it is first checked to be uniform.

        Raises:
            AccelerasNumerizationError: if the ratio is an ndarray whose maximum relative
                deviation from its first element exceeds 1e-5.
        """
        out_in_ratio = self.output_scale / self.input_scales[1]
        eps = 1e-5
        if out_in_ratio.shape != ():
            if isinstance(out_in_ratio, np.ndarray):
                # Largest relative deviation of any channel's ratio from channel 0's ratio.
                ratio_max = np.max(np.abs(out_in_ratio - out_in_ratio[0]) / out_in_ratio[0])
                if eps < ratio_max:
                    # Possible fail case: coming from concat, so input scale is scalar while output is vector..
                    raise AccelerasNumerizationError(
                        f"output_scale - input_scale ratio of {self.full_name} should be a scalar {ratio_max}"
                    )
            # create attribute to be used in scales-training context should it come
            self.output_scale_scalar_dof = out_in_ratio[0]
        else:
            self.output_scale_scalar_dof = out_in_ratio

    def enforce_io_encoding(self, training=False, **kwargs):
        """
        Handle the (hopefully) temporary *scalar* elwa factor.

        The output scale vector is only partly constrained: it has a single scalar degree
        of freedom (DOF) but must stay proportional to the elwa-input scale vector. Unless
        ``vector_elwa_factor_supported`` is set, the output scale is reset to
        ``input_scales[1] * output_scale_scalar_dof``.

        This is verified offline (see ``_create_out_in_scale_ratio``) and enforced in the
        scales-training context through the scalar DOF.

        TODO verify or modify APU to have a single quantization group -
          otherwise above condition can't hold.. (currently create_hw_params doesn't do groups)
        """
        if self.vector_elwa_factor_supported:
            return  # nothing to do here - output scale is decoupled.
        self.output_op.output_scale = self.input_scales[1] * self.output_scale_scalar_dof

    def get_equalization_handler_type(self, predecessor_index=None):
        """
        Return the equalization handler for the predecessor at ``predecessor_index``.

        Index 1 (the elwa input) gets ``ew_bouncer``, or ``unsupported`` when
        ``transpose_output_width_features`` is set. Any other index (the conv input) gets
        ``consumer``. The layer is never reported as a source.
        """
        is_elwa_input = predecessor_index == 1  # conv input is 0, elwa input is 1
        if not is_elwa_input:
            handler_type = LayerHandlerType.consumer
        elif self.transpose_output_width_features:
            handler_type = LayerHandlerType.unsupported
        else:
            handler_type = LayerHandlerType.ew_bouncer
        return EquivClassification(handler_type, is_source=False)

    def get_quarot_handler_type(self, predecessor_index=None):
        """
        Return the QuaRot handler for the predecessor at ``predecessor_index``.

        Index 1 (the elwa input) gets ``transparent``, or ``unsupported`` when
        ``transpose_output_width_features`` is set. Any other index gets ``consumer``.
        Unlike ``get_equalization_handler_type``, this reports ``is_source=True``.
        """
        is_elwa_input = predecessor_index == 1  # conv input is 0, elwa input is 1
        if not is_elwa_input:
            handler_type = LayerHandlerType.consumer
        elif self.transpose_output_width_features:
            handler_type = LayerHandlerType.unsupported
        else:
            handler_type = LayerHandlerType.transparent
        return EquivClassification(handler_type, is_source=True)

    def define_encodings(self, flow):
        """
        Extend the base encodings for the scalar-elwa-factor case.

        When the vector elwa factor is not supported, this method:
        - registers a ``output_scale_scalar_dof`` Scale encoding with shape ``()``;
        - marks the act op's ``output_factor_by_group`` encoding as scalar;
        - marks the elwa op's ``desired_factors`` encoding as scalar.
        """
        super().define_encodings(flow)
        if not self.vector_elwa_factor_supported:
            flow.add_encoding(f"{self.full_name}/output_scale_scalar_dof:0", EncodingType.Scale, scalar=False, shape=())
            flow.nodes[f"{self.act_op.full_name}/output_factor_by_group:0"]["encoding"].scalar = True
            flow.nodes[f"{self.elwa_op.full_name}/desired_factors:0"]["encoding"].scalar = True

    def define_constraints(self, enc):
        """
        Extend the base constraints with the definition of ``output_scale_scalar_dof``.

        If both the elwa op and output op encodings are constant, the DOF is pinned to its
        current value. Otherwise it is derived as ``output_scale / elwa input_scale:1``,
        taking element ``[0]`` when that ratio is not a scalar.
        """
        super().define_constraints(enc)
        if self.vector_elwa_factor_supported:
            return
        if self.elwa_op.encoding_const and self.output_op.encoding_const:
            enc.identity(f"{self.full_name}/output_scale_scalar_dof:0", self.output_scale_scalar_dof)
        else:
            enc.div(
                enc.dummy(0),
                f"{self.output_op.full_name}/output_scale:0",
                f"{self.elwa_op.full_name}/input_scale:1",
            )
            enc.callback(
                f"{self.full_name}/output_scale_scalar_dof:0",
                enc.dummy(0),
                lambda x: x[0] if x.shape != () else x,
                outs_scalar=True,
                outs_shape=(),
            )

    def update_encoding(self, encodings):
        """Apply base encodings, then read back ``output_scale_scalar_dof`` if it is tracked."""
        super().update_encoding(encodings)
        if not self.vector_elwa_factor_supported:
            self.output_scale_scalar_dof = encodings[f"{self.full_name}/output_scale_scalar_dof:0"]

    def get_macs(self):
        """
        Return the conv MACs plus the elementwise-add cost.

        The add cost is ``height * width * channels`` of the elwa input (``input_shapes[1]``),
        scaled by ``_get_inefficiency_factor()``.
        """
        conv_macs = super().get_macs()
        _, height, width, channels_in = self.input_shapes[1]
        add_macs = height * width * channels_in * self._get_inefficiency_factor()
        return conv_macs + add_macs
