import numpy as np

from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.ew_add import EWAddLayer


class EWSubLayer(EWAddLayer):
    """Element-wise subtraction layer.

    Built on top of the element-wise add layer, but it tags itself with
    ``LayerType.base_ew_sub`` and serializes with the
    ``PROTO_NETWORK_BASE_EW_SUB`` node type. It accepts exactly two inputs.
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False

    def __init__(self):
        # NOTE: ``super(EWAddLayer, self)`` starts the MRO lookup *after*
        # EWAddLayer, so EWAddLayer.__init__ itself is intentionally not run;
        # only the next initializer in the MRO is. Presumably the add-specific
        # setup is unwanted here -- TODO confirm against ew_add.py.
        super(EWAddLayer, self).__init__()
        self._op = LayerType.base_ew_sub
        input_count = 2
        self._number_of_inputs_supported = input_count
        # One fresh [1, 1, 1] list per input; presumably "no repetition" along
        # each of three axes -- TODO confirm semantics in the base layer.
        self._input_repeats = [[1, 1, 1] for _ in range(input_count)]

    @classmethod
    def create(cls, original_name, input_vertex_order, output_shapes=None):
        """Build an instance through the parent factory.

        Only ``original_name``, ``input_vertex_order`` and ``output_shapes``
        are forwarded; the latter two are passed by keyword.
        """
        forwarded = {
            "input_vertex_order": input_vertex_order,
            "output_shapes": output_shapes,
        }
        return super().create(original_name, **forwarded)

    @property
    def macs(self):
        """MAC count, reported as half of ``ops``.

        Subtraction has no multiply step, so only half of the op count is
        attributed to multiply-accumulates.
        """
        total_ops = self.ops
        return total_ops / 2

    @property
    def ops(self):
        """Number of output elements, as a float.

        This is the absolute value of the product of ``self.output_shape``.
        """
        shape = np.array(self.output_shape)
        element_count = np.prod(shape)
        # abs() presumably neutralizes a negative placeholder dimension
        # (e.g. -1 for an unknown batch size) -- TODO confirm.
        return float(np.abs(element_count))

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize via the parent, then retag the node as an EW-sub op."""
        proto_node = super().to_pb(pb_wrapper, is_multi_scope)
        base_pb2 = pb_wrapper.integrated_hw_graph_base_pb2
        proto_node.type = base_pb2.PROTO_NETWORK_BASE_EW_SUB
        return proto_node

    def get_equalization_handler_type(self, predecessor=None):
        """Delegate unchanged to the parent implementation."""
        handler_type = super().get_equalization_handler_type(predecessor)
        return handler_type

    def get_params_sorter_handler_type(self, predecessor=None):
        """Delegate unchanged to the parent implementation."""
        handler_type = super().get_params_sorter_handler_type(predecessor)
        return handler_type

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Delegate unchanged to the parent implementation."""
        handler_type = super().get_dead_channels_removal_handler_type(predecessor)
        return handler_type

    def ibc_supported(self):
        """Delegate unchanged to the parent implementation."""
        supported = super().ibc_supported()
        return supported
