import numpy as np

from hailo_sdk_common.hailo_nn.hn_definitions import LayerType, PaddingType
from hailo_sdk_common.hailo_nn.hn_layers.fused_conv2d import FusedConv2DLayer


class FusedEWAddLayer(FusedConv2DLayer):
    """Stand-alone element-wise add layer built as a dummy conv combined with an add.

    Both alternate constructors configure the conv part the same way: a
    ``[1, 1, F, F]`` kernel reshaped from a float32 identity matrix (``F`` being
    the layer's ``output_features``), a zero bias of length ``F``, unit
    strides and dilations, valid padding and a single group.
    """

    _REQUIRES_NATIVE_WEIGHTS = True
    _REQUIRES_QUANTIZED_WEIGHTS = True

    def __init__(self):
        super().__init__()

    @staticmethod
    def _identity_pointwise_kernel(features, kernel_shape):
        """Return a float32 ``features`` x ``features`` identity matrix reshaped to ``kernel_shape``."""
        return np.eye(features, dtype=np.float32).reshape(kernel_shape)

    @property
    def ops(self):
        """Absolute value of the product of ``output_shape``'s entries, as a float."""
        element_count = np.array(self.output_shape).prod()
        return float(np.abs(element_count))

    @property
    def macs(self):
        """Half of ``ops``; the halving is because no multiply is performed."""
        return self.ops / 2

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Build the layer from a protobuf message.

        The parent ``from_pb`` result is turned into the dummy-conv form by
        writing the private attributes directly; the activation is looked up
        in ``pb_wrapper.ACTIVATION_PB_TO_TYPE`` using ``pb.activation``.
        """
        ew_add = super().from_pb(pb, pb_wrapper)
        ew_add._op = LayerType.conv
        ew_add._activation = pb_wrapper.ACTIVATION_PB_TO_TYPE[pb.activation]
        # Separate list objects on purpose, so strides and dilations never alias.
        ew_add._strides = [1] * 4
        ew_add._dilations = [1] * 4
        ew_add._kernel_shape = [1, 1] + [ew_add.output_features] * 2
        ew_add._padding = PaddingType.valid
        ew_add._bias = np.zeros(ew_add.output_features, dtype=np.float32)
        ew_add._kernel = cls._identity_pointwise_kernel(ew_add.output_features, ew_add._kernel_shape)
        ew_add._groups = 1
        return ew_add

    @classmethod
    def from_layer(cls, old_layer):
        """Build the layer from an existing layer.

        Same dummy-conv configuration as ``from_pb``, but applied through the
        public attributes of the object returned by the parent ``from_layer``.
        """
        ew_add = super().from_layer(old_layer)
        # Assignment order is kept as-is; the public attributes may be
        # order-sensitive setters defined on the parent class.
        ew_add.strides = [1] * 4
        ew_add.dilations = [1] * 4
        ew_add.kernel_shape = [1, 1] + [ew_add.output_features] * 2
        ew_add.padding = PaddingType.valid
        ew_add.bias = np.zeros(ew_add.output_features, dtype=np.float32)
        ew_add.kernel = cls._identity_pointwise_kernel(ew_add.output_features, ew_add.kernel_shape)
        ew_add.groups = 1
        return ew_add

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize the layer by delegating unchanged to the parent ``to_pb``."""
        return super().to_pb(pb_wrapper, is_multi_scope)
