import copy

import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_params import LayerWithParams
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class EWAddLayer(LayerWithParams):
    """Elementwise-add layer (``LayerType.base_ew_add``) supporting two inputs.

    Each input carries a ``[height, width, features]`` repeat triple, stored in
    ``input_repeats``. When every triple is ``[1, 1, 1]`` (the default),
    ``to_hn`` omits the field from the serialized params.
    """

    # This layer declares that it needs neither native nor quantized weights.
    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False

    def __init__(self):
        super().__init__()
        self._op = LayerType.base_ew_add
        self._number_of_inputs_supported = 2
        # One [height, width, features] repeat triple per input (see to_pb).
        self._input_repeats = [[1, 1, 1], [1, 1, 1]]

    @classmethod
    def create(cls, original_name, input_vertex_order, output_shapes=None):
        """Create a layer; forwards all arguments unchanged to the base ``create``."""
        return super().create(original_name, input_vertex_order, output_shapes)

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize to a protobuf node of type ``PROTO_NETWORK_BASE_EW_ADD``.

        One ``input_repeats`` entry (height, width, features) is appended to the
        node per stored repeat triple, in input order.

        Returns:
            The node produced by the base ``to_pb``, with type and repeats set.
        """
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_BASE_EW_ADD
        for repeat in self._input_repeats:
            input_repeat = node.input_repeats.add()
            input_repeat.height, input_repeat.width, input_repeat.features = repeat
        return node

    @property
    def macs(self):
        """Multiply-accumulate count, reported as half of ``ops``."""
        # The /2 is because we don't do multiply
        return self.ops / 2

    @property
    def ops(self):
        """Operation count: the number of output elements, as a float."""
        # NOTE(review): abs() presumably guards against negative (e.g. unknown
        # -1) dimensions in output_shape — confirm.
        return float(np.abs(np.prod(np.array(self.output_shape))))

    @property
    def input_shape(self):
        """Effective input shape.

        With a single input, or when the first input shape has exactly two
        dimensions, the first input shape is returned as-is. Otherwise the
        result is the per-dimension maximum of the first two input shapes.
        """
        input0_shape = self._input_shapes[0]
        if len(self._input_shapes) == 1:
            return input0_shape

        # NOTE(review): rank-2 shapes skip the max-merge; reason not visible here.
        if len(input0_shape) == 2:
            return input0_shape

        input1_shape = self._input_shapes[1]
        return list(map(max, zip(input0_shape, input1_shape)))

    @property
    def input_repeats(self):
        """Per-input ``[height, width, features]`` repeat triples."""
        return self._input_repeats

    @input_repeats.setter
    def input_repeats(self, input_repeats):
        self._input_repeats = input_repeats

    def to_hn(self, should_get_default_params=False):
        """Return the base HN dict, adding ``input_repeats`` only when non-default.

        The returned dict is a deep copy and does not share state with this
        layer, including the ``input_repeats`` lists.
        """
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        # list() normalizes tuples/arrays so e.g. (1, 1, 1) still counts as default.
        if any(list(repeat) != [1, 1, 1] for repeat in self._input_repeats):
            # Copy so callers mutating the result cannot alter this layer.
            result["params"]["input_repeats"] = copy.deepcopy(self._input_repeats)
        return result

    @classmethod
    def from_hn(cls, hn):
        """Build a layer from an HN dict, restoring ``input_repeats``.

        When ``hn["params"]`` has no ``input_repeats`` key, both inputs fall
        back to ``[1, 1, 1]``.

        Returns:
            The constructed layer.
        """
        layer = super().from_hn(hn)
        layer.input_repeats = hn["params"].get("input_repeats", [[1, 1, 1], [1, 1, 1]])
        # Return the layer (as from_layer does); without this callers got None.
        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """Build a layer from ``old_layer``, copying its ``input_repeats`` when applicable.

        Repeats are not copied when ``old_layer`` is an equal/lstm/rnn/conv/dense
        layer. In that case ``input_repeats`` is left as the base ``from_layer``
        produced it.
        """
        layer = super().from_layer(old_layer)
        # NOTE(review): presumably these op types lack a compatible
        # input_repeats attribute — confirm.
        if old_layer.op not in (LayerType.equal, LayerType.lstm, LayerType.rnn, LayerType.conv, LayerType.dense):
            layer.input_repeats = old_layer.input_repeats
        return layer

    def get_equalization_handler_type(self, predecessor=None):
        """Return an ``unexpected``, non-source classification; ``predecessor`` is unused."""
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Return an ``unexpected``, non-source classification; ``predecessor`` is unused."""
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Return an ``unexpected``, non-source classification; ``predecessor`` is unused."""
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def ibc_supported(self):
        """Report IBC support status as ``LayerSupportStatus.unexpected``."""
        return LayerSupportStatus.unexpected
