import copy

import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    EWMultType,
    LayerHandlerType,
    LayerSupportStatus,
)
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationTypes, DefuseType, LayerType
from hailo_sdk_common.hailo_nn.hn_layers.activation_layer import ActivationLayer
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_activation import LayerWithActivation
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification
from hailo_sdk_common.tools.models_translator_helper import is_feature_repeats, is_spatial_broadcast


class EWMultLayer(LayerWithActivation):
    """Two-input ``LayerType.ew_mult`` layer with an attached activation.

    Every input carries a ``[height, width, features]`` repeat factor. The
    factors are multiplied into the input's non-batch dimensions before the two
    inputs are compared and combined into the output shape.
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = True

    def __init__(self):
        super().__init__()
        self._op = LayerType.ew_mult
        self._input_list = []
        self._number_of_inputs_supported = 2
        self._is_softmax_mask = False
        self._dynamic_weights = False
        self._reduce_sum_groups = None
        self._ew_mult_type = EWMultType.on_apu
        # [height, width, features] for each input
        self._input_repeats = self._default_input_repeats()

    @staticmethod
    def _default_input_repeats():
        """Return a fresh "no repetition" table: one ``[1, 1, 1]`` entry per input."""
        return [[1, 1, 1], [1, 1, 1]]

    @staticmethod
    def _apply_repeats(shape, repeats):
        """Multiply each non-batch dim of ``shape`` by its repeat factor.

        The leading (batch) entry of ``shape`` is dropped and replaced by ``-1``.
        """
        return [-1, *[dim * ratio for dim, ratio in zip(shape[1:], repeats)]]

    def _repeated_pred_output_shapes(self):
        """Return the two predecessor output shapes after applying the input repeats."""
        return [
            self._apply_repeats(self.pred_layer_output_shape(self.input_list[idx], True), self._input_repeats[idx])
            for idx in range(2)
        ]

    def _default_reduce_sum_groups(self):
        """Return the fallback ``reduce_sum_groups``: repeated feature count of input 0."""
        return self._input_shapes[0][-1] * self._input_repeats[0][-1]

    @property
    def input_repeats(self):
        """Per-input ``[height, width, features]`` repeat factors."""
        return self._input_repeats

    @input_repeats.setter
    def input_repeats(self, input_repeats):
        self._input_repeats = input_repeats

    @property
    def macs(self):
        """Multiply-accumulate count, derived from :attr:`ops`."""
        # The /2 is because we don't do accumulate
        return self.ops / 2

    @property
    def ops(self):
        """Absolute value of the product of all output shape entries, as a float."""
        return float(np.abs(np.prod(np.array(self.output_shape))))

    @classmethod
    def create(cls, original_name, input_vertex_order, output_shapes=None):
        """Build a layer by forwarding all arguments to the parent ``create``."""
        return super().create(original_name, input_vertex_order, output_shapes)

    @property
    def input_list(self):
        """Inputs of this layer, in the order they were appended."""
        return self._input_list

    @input_list.setter
    def input_list(self, input_list):
        self._input_list = input_list

    @property
    def is_softmax_mask(self):
        """Flag stored under ``is_softmax_mask`` in the HN params."""
        return self._is_softmax_mask

    @is_softmax_mask.setter
    def is_softmax_mask(self, is_softmax_mask):
        self._is_softmax_mask = is_softmax_mask

    def append_to_input_list(self, inp):
        """Append ``inp`` to :attr:`input_list`."""
        self.input_list.append(inp)

    def to_hn(self, should_get_default_params=False):
        """Serialize the layer to an HN dictionary.

        Starts from a deep copy of the parent's HN dict and adds the activation,
        softmax-mask flag and ew-mult type. ``reduce_sum_groups`` is written only
        when set, and ``input_repeats`` only when some input actually repeats.
        """
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        params = result["params"]
        params["activation"] = self._activation.value
        params["is_softmax_mask"] = self._is_softmax_mask
        params["ew_mult_type"] = self._ew_mult_type.value
        if self._reduce_sum_groups is not None:
            params["reduce_sum_groups"] = self._reduce_sum_groups
        if any(repeat != [1, 1, 1] for repeat in self._input_repeats):
            # Copy so that edits to the returned dict cannot reach the layer's own state.
            params["input_repeats"] = copy.deepcopy(self._input_repeats)

        return result

    @classmethod
    def from_hn(cls, hn):
        """Build a layer from an HN dictionary.

        Missing optional params fall back to: no softmax mask, no input repeats
        and ``EWMultType.on_apu``. When ``reduce_sum_groups`` is absent it is
        ``None`` for ``on_apu`` layers and the repeated feature count of the
        first input otherwise.
        """
        layer = super().from_hn(hn)
        params = hn["params"]
        layer.activation = ActivationTypes[params["activation"]]
        layer.is_softmax_mask = params.get("is_softmax_mask", False)
        layer.input_repeats = params.get("input_repeats", cls._default_input_repeats())
        layer._ew_mult_type = EWMultType(params.get("ew_mult_type", EWMultType.on_apu))
        on_apu = layer._ew_mult_type == EWMultType.on_apu

        # The shape-based fallback is computed only when it is really needed, so an
        # explicit value in the HN never depends on the input shapes being available.
        if "reduce_sum_groups" in params:
            layer._reduce_sum_groups = params["reduce_sum_groups"]
        elif on_apu:
            layer._reduce_sum_groups = None
        else:
            layer._reduce_sum_groups = layer._default_reduce_sum_groups()

        layer._dynamic_weights = not on_apu
        return layer

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize the layer to a protobuf node.

        ``reduce_sum_groups`` is always written: the stored value when set,
        otherwise the repeated feature count of the first input. One
        ``input_repeats`` entry is added per input.
        """
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_EW_MULT
        node.activation = pb_wrapper.ACTIVATION_TYPE_TO_PB[self._activation]
        if self._reduce_sum_groups is not None:
            node.reduce_sum_groups = self._reduce_sum_groups
        else:
            node.reduce_sum_groups = self._default_reduce_sum_groups()

        for repeat in self._input_repeats:
            input_repeat = node.input_repeats.add()
            input_repeat.height, input_repeat.width, input_repeat.features = repeat

        return node

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Build a layer from a protobuf node.

        A node without any ``input_repeats`` entries yields the default
        "no repetition" table, so later shape calculations can still index it.
        """
        layer = super().from_pb(pb, pb_wrapper)
        layer.activation = pb_wrapper.ACTIVATION_PB_TO_TYPE[pb.activation]
        repeats = [
            [input_repeat.height, input_repeat.width, input_repeat.features] for input_repeat in pb.input_repeats
        ]
        layer.input_repeats = repeats or cls._default_input_repeats()
        layer._reduce_sum_groups = pb.reduce_sum_groups
        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """Build a layer from ``old_layer``.

        Activation and input repeats are taken from ``old_layer`` unless it is
        an LSTM layer, in which case the defaults set by ``__init__`` are kept.
        """
        layer = super().from_layer(old_layer)
        if old_layer.op != LayerType.lstm:
            layer.activation = old_layer.activation
            # Copy so the two layers do not share one mutable repeat table.
            layer.input_repeats = copy.deepcopy(old_layer.input_repeats)

        return layer

    def update_output_shapes(self, **kwargs):
        """Validate the two (repeated) input shapes, then update the output shapes.

        Raises:
            UnsupportedModelError: if the layer does not have exactly two inputs,
                or if the shapes differ and are neither a spatial broadcast, a
                feature repeat, nor a case where one input has a single feature.
        """
        if len(self.input_list) != 2:
            raise UnsupportedModelError(f"{self.full_name_msg} expects 2 inputs but found {len(self._input_list)}")
        input0_shape, input1_shape = self._repeated_pred_output_shapes()

        if input0_shape != input1_shape:
            spatial_cond = is_spatial_broadcast(
                input0_shape,
                input1_shape,
                is_two_sided=True,
            ) or is_feature_repeats(input0_shape, input1_shape)
            features_cond = input0_shape[-1] == 1 or input1_shape[-1] == 1
            if not (spatial_cond or features_cond):
                raise UnsupportedModelError(
                    f"{self.full_name_msg} expects both inputs to have the same dimensions, but"
                    f" got {input0_shape} and {input1_shape} instead.",
                )

        super().update_output_shapes(**kwargs)

    def _calc_output_shape(self):
        """Compute the output shape from the repeated input shapes and defuse settings."""
        input0_shape, input1_shape = self._repeated_pred_output_shapes()

        # Element-wise maximum of the two shapes.
        result = list(map(max, zip(input0_shape, input1_shape)))
        if self._reduce_sum_groups is not None:
            result[-1] = self._reduce_sum_groups
        if "defuse_features" in self.defuse_params and self.defuse_type not in [
            DefuseType.none,
            DefuseType.spatial_w,
            DefuseType.ew_mult_on_mac,
        ]:
            result[-1] = self.defuse_features

        if (
            self.defuse_type is DefuseType.spatial_w
            and "defuse_input_width" in self.defuse_params
            and self.defuse_input_width != 0
        ):
            result[-2] = self.defuse_input_width

        if self.defuse_types is not None and DefuseType.ew_mult_on_mac in self.defuse_types:
            result[-2] *= 2  # in b0 should be 4

        return result

    def _calculate_input_shapes(self):
        """Return the stored input shapes with the matching input repeats applied."""
        return [
            self._apply_repeats(input_shape, input_repeats)
            for input_repeats, input_shape in zip(self._input_repeats, self._input_shapes)
        ]

    @property
    def input_features(self):
        """Last-dimension value resolved from the repeated input shapes."""
        return self._get_shape_single_dim(self._calculate_input_shapes(), -1)

    @property
    def requires_native_weights(self):
        """Look up the native-weights requirement of the current activation.

        Returns ``False`` and logs a warning when the activation has no entry in
        ``ActivationLayer._REQUIRES_NATIVE_WEIGHTS``.
        """
        if self._activation not in ActivationLayer._REQUIRES_NATIVE_WEIGHTS:
            self._logger.warning(
                f"Layer {self.name} of activation type {self._activation.value} does not specify whether native weights are "
                "required. Assuming False.",
            )
            return False

        return ActivationLayer._REQUIRES_NATIVE_WEIGHTS[self._activation]

    def get_equalization_handler_type(self, predecessor=None):
        """Return the ``transparent`` handler classification (not a source)."""
        return EquivClassification(LayerHandlerType.transparent, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Return the ``unsupported`` handler classification (not a source)."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Return the ``unsupported`` handler classification (not a source)."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def ibc_supported(self):
        """Return ``LayerSupportStatus.unsupported``."""
        return LayerSupportStatus.unsupported
