import copy

import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationTypes, DefuseType, LayerType
from hailo_sdk_common.hailo_nn.hn_layers.activation_layer import ActivationLayer
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_activation import LayerWithActivation
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class FusedStandaloneEWAddLayer(LayerWithActivation):
    """Standalone element-wise add (``LayerType.ew_add``) of two inputs, carrying an activation.

    Each input may be repeated along ``[height, width, features]`` before the add. The
    per-input repeat factors are stored in ``input_repeats`` and default to no repetition.
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = True

    # Repeat factors along [height, width, features] that mean "do not repeat".
    _NO_REPEAT = (1, 1, 1)

    def __init__(self):
        super().__init__()
        self._op = LayerType.ew_add
        self._input_list = []
        self._number_of_inputs_supported = 2
        # [height, width, features] for each input
        self._input_repeats = self._default_input_repeats()

    @staticmethod
    def _default_input_repeats():
        """Return a fresh ``[[1, 1, 1], [1, 1, 1]]`` (no repetition for either input).

        A new list is built on every call so that layers never share one mutable default.
        """
        return [[1, 1, 1], [1, 1, 1]]

    @staticmethod
    def _scale_shape_by_repeats(shape, repeats):
        """Return ``shape`` with its non-batch dims multiplied by the matching repeat factors.

        ``repeats`` is ordered ``[height, width, features]``. Factors are aligned to the
        *trailing* dims of ``shape``, so a rank-2 ``[batch, features]`` shape is scaled by
        the features factor. A leading-aligned ``zip`` would wrongly pair it with the height
        factor. Dims with no matching factor are left unchanged. The batch dim of the result
        is always ``-1``.
        """
        dims = list(shape[1:])
        ratios = list(repeats)
        if len(ratios) < len(dims):
            # More dims than factors: leading dims get no repetition.
            ratios = [1] * (len(dims) - len(ratios)) + ratios
        else:
            # Keep only the factors that line up with the trailing dims.
            ratios = ratios[len(ratios) - len(dims) :]
        return [-1, *[dim * ratio for dim, ratio in zip(dims, ratios)]]

    def _has_double_precision_conv_defuse(self):
        """Return True if ``defuse_params["defuse_types"]`` contains ``DefuseType.double_precision_conv``."""
        defuse_types = self.defuse_params["defuse_types"]
        return defuse_types is not None and DefuseType.double_precision_conv in defuse_types

    @property
    def input_list(self):
        return self._input_list

    @input_list.setter
    def input_list(self, input_list):
        self._input_list = input_list

    @property
    def input_repeats(self):
        """Per-input repeat factors, one ``[height, width, features]`` entry per input."""
        return self._input_repeats

    @input_repeats.setter
    def input_repeats(self, input_repeats):
        self._input_repeats = input_repeats

    @property
    def macs(self):
        # The /2 is because we don't do multiply
        return self.ops / 2

    @property
    def ops(self):
        """Number of output elements (absolute value, since the batch dim is -1), as a float."""
        return float(np.abs(np.prod(np.array(self.output_shape))))

    def append_to_input_list(self, inp):
        self.input_list.append(inp)

    def to_hn(self, should_get_default_params=False):
        """Serialize to an HN dict, adding ``activation`` and (if non-trivial) ``input_repeats``.

        ``input_repeats`` is emitted only when at least one input is actually repeated.
        A copy is emitted, so mutating the returned dict cannot change this layer.
        """
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        result["params"]["activation"] = self._activation.value
        # tuple(...) makes the check work for list, tuple, or numpy rows alike.
        if any(tuple(repeat) != self._NO_REPEAT for repeat in self._input_repeats):
            result["params"]["input_repeats"] = [list(repeat) for repeat in self._input_repeats]
        return result

    @classmethod
    def from_hn(cls, hn):
        """Build a layer from an HN dict; a missing ``input_repeats`` means no repetition."""
        layer = super().from_hn(hn)
        layer.activation = ActivationTypes[hn["params"]["activation"]]
        layer.input_repeats = hn["params"].get("input_repeats", cls._default_input_repeats())
        return layer

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize to a protobuf node; all input repeats are always written."""
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_STANDALONE_EW_ADD
        node.activation = pb_wrapper.ACTIVATION_TYPE_TO_PB[self._activation]

        for repeat in self._input_repeats:
            input_repeat = node.input_repeats.add()
            input_repeat.height, input_repeat.width, input_repeat.features = repeat

        return node

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Build a layer from a protobuf node.

        Raises:
            UnsupportedModelError: if ``pb.input_repeats`` has neither 0 nor 2 entries.
        """
        layer = super().from_pb(pb, pb_wrapper)
        layer._activation = pb_wrapper.ACTIVATION_PB_TO_TYPE[pb.activation]
        if len(pb.input_repeats) == 2:
            layer._input_repeats = [
                [input_repeat.height, input_repeat.width, input_repeat.features] for input_repeat in pb.input_repeats
            ]
        elif len(pb.input_repeats) == 0:
            # Nodes without repeats mean "no repetition".
            layer._input_repeats = cls._default_input_repeats()
        else:
            raise UnsupportedModelError(f"Invalid input_repeats length ({len(pb.input_repeats)}) for {layer.name}")

        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """Build a layer from ``old_layer``, copying its repeats and (if present) its activation."""
        layer = super().from_layer(old_layer)
        # Copy so the two layers do not share (and co-mutate) the same nested lists.
        layer.input_repeats = [list(repeat) for repeat in old_layer.input_repeats]
        if hasattr(old_layer, "activation"):
            layer.activation = old_layer.activation
        return layer

    def _calc_output_shape(self):
        """Compute this layer's output shape from its first two inputs.

        Both input shapes are first scaled by their repeat factors. They must then be equal,
        except that when ``is_from_dense()`` is true only the last (features) dim must match.

        Raises:
            UnsupportedModelError: if the repeated input shapes are incompatible.
        """
        input0 = self.input_list[0]
        input1 = self.input_list[1]

        input0_shape = self._scale_shape_by_repeats(
            self.pred_layer_output_shape(input0, True),
            self._input_repeats[0],
        )
        input1_shape = self._scale_shape_by_repeats(
            self.pred_layer_output_shape(input1, True),
            self._input_repeats[1],
        )

        if input0_shape != input1_shape and not (self.is_from_dense() and input0_shape[-1] == input1_shape[-1]):
            raise UnsupportedModelError(
                f"Unsupported dimensions: {input0.name} - {input0_shape}, {input1.name} - "
                f"{input1_shape} at {self.full_name_msg}",
            )

        # Defused layers (other than none / spatial_w) take their feature count from defuse_features.
        if "defuse_features" in self.defuse_params and self.defuse_type not in [DefuseType.none, DefuseType.spatial_w]:
            output_f = self.defuse_features
        else:
            output_f = input0_shape[-1]

        if len(input0_shape) == 2:
            # Rank-2 shape: [batch, features].
            return [-1, output_f]

        if self._has_double_precision_conv_defuse():
            output_f = int(output_f / 2)

        if (
            self.defuse_type is DefuseType.spatial_w
            and "defuse_input_width" in self.defuse_params
            and self.defuse_input_width != 0
        ):
            output_w = self.defuse_input_width
        else:
            output_w = input0_shape[-2]

        return [-1, input0_shape[1], output_w, output_f]

    @property
    def input_width(self):
        """Width of the first input (1 for rank-2 inputs).

        NOTE(review): unlike ``input_features``, this does not apply ``input_repeats``.
        Confirm that the un-repeated width is the intended value here.
        """
        return 1 if len(self._input_shapes[0]) == 2 else self._input_shapes[0][2]

    @property
    def input_features(self):
        """Features dim of the inputs after applying their repeat factors."""
        input_shapes = [
            self._scale_shape_by_repeats(input_shape, input_repeats)
            for input_repeats, input_shape in zip(self._input_repeats, self._input_shapes)
        ]
        return self._get_shape_single_dim(input_shapes, -1)

    @property
    def output_features(self):
        """Features dim after repeats, halved when a double_precision_conv defuse is present."""
        output_shapes = []
        for input_repeats, input_shape in zip(self._input_repeats, self._input_shapes):
            shape = self._scale_shape_by_repeats(input_shape, input_repeats)
            if self._has_double_precision_conv_defuse():
                shape[-1] = int(shape[-1] / 2)
            output_shapes.append(shape)

        return self._get_shape_single_dim(output_shapes, -1)

    @property
    def requires_native_weights(self):
        """Look up whether this layer's activation needs native weights.

        Unknown activations log a warning and are treated as not requiring them.
        """
        if self._activation not in ActivationLayer._REQUIRES_NATIVE_WEIGHTS:
            self._logger.warning(
                f"Layer {self.name} of activation type {self._activation.value} does not specify "
                "whether native weights are required. Assuming False.",
            )
            return False

        return ActivationLayer._REQUIRES_NATIVE_WEIGHTS[self._activation]

    def get_equalization_handler_type(self, predecessor=None):
        """Equalization treats this layer as a non-source multi_source handler."""
        return EquivClassification(LayerHandlerType.multi_source, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Params sorting is unsupported for this layer."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Dead-channel removal is unsupported for this layer."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def ibc_supported(self):
        """IBC is unsupported for this layer."""
        return LayerSupportStatus.unsupported
