import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_params import LayerWithParams
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class BiasAddLayer(LayerWithParams):
    """Graph layer for the ``LayerType.bias_add`` operation.

    The bias values live in ``self._bias``. Until something is assigned, the
    ``bias`` property reports an all-zeros vector whose length is the last
    entry of ``output_shape``. Every handler-type query on this layer answers
    ``LayerHandlerType.unexpected`` with ``is_source=False``. ``ibc_supported``
    answers ``LayerSupportStatus.unexpected``.
    """

    # Class-level capability flags consumed by the parent layer machinery;
    # their precise meaning is defined in the base classes (not visible here).
    _REQUIRES_NATIVE_WEIGHTS = True
    _REQUIRES_QUANTIZED_WEIGHTS = True
    _IS_REAL_LAYER = True
    _IS_RANK3_SUPPORTED = True

    def __init__(self):
        super().__init__()
        self._op = LayerType.bias_add
        # No bias until one is assigned; the ``bias`` getter supplies zeros meanwhile.
        self._bias = None

    @classmethod
    def create(cls, original_name, input_vertex_order, bias, output_shapes=None):
        """Build a layer through the parent factory, then attach ``bias``.

        Args:
            original_name: Passed straight through to the parent ``create``.
            input_vertex_order: Passed straight through to the parent ``create``.
            bias: Bias values assigned to the new layer via the ``bias`` setter.
            output_shapes: Passed straight through to the parent ``create``.

        Returns:
            The layer instance returned by the parent factory, with ``bias`` set.
        """
        new_layer = super().create(original_name, input_vertex_order, output_shapes)
        new_layer.bias = bias
        return new_layer

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize using the parent logic, then tag the node as a bias-add node.

        Returns:
            The protobuf node produced by the parent ``to_pb``, with its ``type``
            overwritten to ``PROTO_NETWORK_BIAS_ADD``.
        """
        pb_node = super().to_pb(pb_wrapper, is_multi_scope)
        proto_module = pb_wrapper.integrated_hw_graph_base_pb2
        pb_node.type = proto_module.PROTO_NETWORK_BIAS_ADD
        return pb_node

    @property
    def bias(self):
        """Bias values of this layer.

        When nothing has been assigned, a new ``np.zeros(output_shape[-1])``
        vector is returned on every access. It uses NumPy's default float64
        dtype and is not cached on the instance.
        """
        stored = self._bias
        if stored is None:
            return np.zeros(self.output_shape[-1])
        return stored

    @bias.setter
    def bias(self, value):
        # Kept exactly as given: no copy, dtype conversion, or shape check here.
        self._bias = value

    @staticmethod
    def _unexpected_classification():
        # Build a new instance on each call, just as each handler method did inline.
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def get_equalization_handler_type(self, predecessor=None):
        """Equalization handler type: always ``unexpected`` and non-source."""
        return self._unexpected_classification()

    def get_params_sorter_handler_type(self, predecessor=None):
        """Params-sorter handler type: always ``unexpected`` and non-source."""
        return self._unexpected_classification()

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Dead-channels-removal handler type: always ``unexpected`` and non-source."""
        return self._unexpected_classification()

    def ibc_supported(self):
        """IBC support status for this layer: always ``LayerSupportStatus.unexpected``."""
        return LayerSupportStatus.unexpected
