import copy

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_params import LayerWithParams
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class ExternalPadLayer(LayerWithParams):
    """HN layer for an external padding operation.

    The layer grows axes 1, 2 and 3 of its input by (top + bottom),
    (left + right) and (front + back) respectively (see
    ``_calc_output_shape``). ``padding_const_value`` is presumably the
    constant written into the padded region — TODO confirm against the
    consumer of this layer.
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    # Quantize this layer so that zp_in gets initialized.
    _REQUIRES_QUANTIZED_WEIGHTS = True

    def __init__(self):
        super().__init__()
        self._op = LayerType.external_pad
        # Per-side padding amounts; all default to no padding.
        self._top = 0
        self._bottom = 0
        self._left = 0
        self._right = 0
        self._front = 0
        self._back = 0
        self._padding_const_value = 0

    @classmethod
    def create(cls, original_name, input_vertex_order, padding_const_value=0, output_shapes=None, padding_vals=None):
        """Build a new external-pad layer.

        Args:
            original_name: Name forwarded to the base ``create``.
            input_vertex_order: Input ordering forwarded to the base ``create``.
            padding_const_value: Value stored in ``padding_const_value``.
            output_shapes: Forwarded to the base ``create``.
            padding_vals: Optional 4- or 6-element pad sequence, applied via
                ``set_pad``. When ``None`` all pads stay 0.

        Returns:
            The new layer instance.
        """
        layer = super().create(original_name, input_vertex_order, output_shapes)
        if padding_vals is not None:
            layer.set_pad(padding_vals)
        layer.padding_const_value = padding_const_value
        return layer

    def set_pad(self, pad_array):
        """Set all padding amounts from a sequence.

        Args:
            pad_array: Either ``(top, bottom, left, right)`` or
                ``(top, bottom, left, right, front, back)``. With the 4-value
                form, front/back are reset to 0 so that padding from an earlier
                call cannot leak into the new configuration.

        Raises:
            ValueError: If ``pad_array`` does not have 4 or 6 elements.
        """
        pad_values = tuple(pad_array)
        if len(pad_values) == 4:
            # 2-D padding only: explicitly clear any previous front/back padding.
            pad_values = (*pad_values, 0, 0)
        elif len(pad_values) != 6:
            raise ValueError(
                "external pad expects 4 or 6 values (top, bottom, left, right[, front, back]), "
                f"got {len(pad_values)}"
            )
        self._top, self._bottom, self._left, self._right, self._front, self._back = pad_values

    @property
    def top(self):
        """Padding added before axis 1 of the input."""
        return self._top

    @property
    def bottom(self):
        """Padding added after axis 1 of the input."""
        return self._bottom

    @property
    def left(self):
        """Padding added before axis 2 of the input."""
        return self._left

    @property
    def right(self):
        """Padding added after axis 2 of the input."""
        return self._right

    @property
    def front(self):
        """Padding added before axis 3 of the input."""
        return self._front

    @property
    def back(self):
        """Padding added after axis 3 of the input."""
        return self._back

    @property
    def padding_const_value(self):
        """Constant associated with the padded region."""
        return self._padding_const_value

    @padding_const_value.setter
    def padding_const_value(self, new_const_value):
        self._padding_const_value = new_const_value

    def _calc_output_shape(self):
        """Return the input shape with each padded axis grown by its two pad amounts."""
        return [
            self.input_shape[0],
            self.input_shape[1] + self._top + self._bottom,
            self.input_shape[2] + self._left + self._right,
            self.input_shape[3] + self._front + self._back,
        ]

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize to a protobuf node of type PROTO_NETWORK_EXTERNAL_PAD.

        Only the six pad amounts are written to ``external_pad_params``;
        ``padding_const_value`` is not written here.
        """
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_EXTERNAL_PAD
        # NOTE(review): both width and features are taken from input_shape[1], while
        # _calc_output_shape pads index 1 with top/bottom, index 2 with left/right and
        # index 3 with front/back. This may have been meant as input_shape[2] /
        # input_shape[3] — confirm against the HW graph consumer before changing.
        node.kernel_shape.width, node.kernel_shape.features = self.input_shape[1], self.input_shape[1]
        node.external_pad_params.top = self._top
        node.external_pad_params.bottom = self._bottom
        node.external_pad_params.left = self._left
        node.external_pad_params.right = self._right
        node.external_pad_params.front = self._front
        node.external_pad_params.back = self._back
        return node

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Rebuild a layer from a protobuf node.

        Restores the six pad amounts from ``external_pad_params``.
        ``padding_const_value`` is not restored here.
        """
        layer = super().from_pb(pb, pb_wrapper)
        layer._top = pb.external_pad_params.top
        layer._bottom = pb.external_pad_params.bottom
        layer._left = pb.external_pad_params.left
        layer._right = pb.external_pad_params.right
        layer._front = pb.external_pad_params.front
        layer._back = pb.external_pad_params.back
        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """Copy-construct from another external-pad layer: pads and ``padding_const_value``."""
        layer = super().from_layer(old_layer)
        layer._top = old_layer.top
        layer._bottom = old_layer.bottom
        layer._left = old_layer.left
        layer._right = old_layer.right
        layer._front = old_layer.front
        layer._back = old_layer.back
        layer._padding_const_value = old_layer.padding_const_value
        return layer

    def move_params(self, layer):
        """Take parameters from ``layer``.

        Bypasses ``LayerWithParams.move_params`` and calls the next
        implementation in the MRO. It then copies ``padding_const_value``,
        but only when ``layer`` is itself an external-pad layer.
        """
        super(LayerWithParams, self).move_params(layer)
        if layer.op == LayerType.external_pad:
            self._padding_const_value = layer.padding_const_value

    @classmethod
    def from_hn(cls, hn):
        """Rebuild a layer from its HN dict.

        Reads ``hn["params"]["external_pad_params"]`` (4 or 6 values) when it is
        present. ``padding_const_value`` is not read from the HN here.
        """
        layer = super().from_hn(hn)
        if "params" in hn and "external_pad_params" in hn["params"]:
            layer.set_pad(hn["params"]["external_pad_params"])

        return layer

    def to_hn(self, should_get_default_params=False):
        """Serialize to an HN dict.

        The pads are written as ``[top, bottom, left, right, front, back]``
        under ``params.external_pad_params``.
        """
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        result["params"]["external_pad_params"] = [
            self._top,
            self._bottom,
            self._left,
            self._right,
            self._front,
            self._back,
        ]

        return result

    def _get_equiv_handler(self):
        """Return ``transparent`` when no front/back padding is set, else ``unsupported``."""
        return LayerHandlerType.transparent if self._front == 0 and self._back == 0 else LayerHandlerType.unsupported

    def get_equalization_handler_type(self, predecessor=None):
        """Equalization classification (non-source); see ``_get_equiv_handler``."""
        handler = self._get_equiv_handler()
        return EquivClassification(handler, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Params-sorter classification (non-source); see ``_get_equiv_handler``."""
        handler = self._get_equiv_handler()
        return EquivClassification(handler, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Dead-channels-removal classification (non-source); see ``_get_equiv_handler``."""
        handler = self._get_equiv_handler()
        return EquivClassification(handler, is_source=False)

    def ibc_supported(self):
        """Report IBC support for this layer type; always ``unsupported``."""
        return LayerSupportStatus.unsupported
