import copy

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers import InnerLayer
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class RNNLayer(InnerLayer):
    """Inner layer whose op is ``LayerType.rnn``.

    In addition to whatever ``InnerLayer`` provides, the layer stores a
    recurrent kernel, a recurrent bias and an initial hidden state
    (``initial_h``). All three start out as ``None`` and are exposed through
    plain read/write properties.
    """

    def __init__(self):
        """Initialise the base layer, set the op and clear the RNN fields."""
        super().__init__()
        self._op = LayerType.rnn
        # RNN-specific parameters; populated later (e.g. by ``create``).
        self._recurrent_kernel = self._recurrent_bias = self._initial_h = None

    # ------------------------------------------------------------------
    # RNN-specific parameters (simple pass-through accessors)
    # ------------------------------------------------------------------
    @property
    def recurrent_kernel(self):
        """The stored recurrent kernel, or ``None`` if it was never set."""
        return self._recurrent_kernel

    @recurrent_kernel.setter
    def recurrent_kernel(self, value):
        self._recurrent_kernel = value

    @property
    def recurrent_bias(self):
        """The stored recurrent bias, or ``None`` if it was never set."""
        return self._recurrent_bias

    @recurrent_bias.setter
    def recurrent_bias(self, value):
        self._recurrent_bias = value

    @property
    def initial_h(self):
        """The stored initial hidden state, or ``None`` if it was never set."""
        return self._initial_h

    @initial_h.setter
    def initial_h(self, value):
        self._initial_h = value

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------
    @classmethod
    def create(
        cls,
        original_name,
        input_vertex_order,
        kernel,
        bias,
        recurrent_kernel,
        recurrent_bias,
        initial_h,
        output_shapes=None,
    ):
        """Build a layer through the parent ``create`` and attach its weights.

        Args:
            original_name: Forwarded unchanged to the parent ``create``.
            input_vertex_order: Forwarded unchanged to the parent ``create``.
            kernel: Stored as ``kernel``; its ``.shape`` becomes
                ``kernel_shape``, so it must expose a ``shape`` attribute.
            bias: Stored as ``bias``.
            recurrent_kernel: Stored as ``recurrent_kernel``.
            recurrent_bias: Stored as ``recurrent_bias``.
            initial_h: Stored as ``initial_h``.
            output_shapes: Forwarded unchanged to the parent ``create``.

        Returns:
            The layer object returned by the parent ``create``, with the
            attributes above assigned.
        """
        rnn = super().create(original_name, input_vertex_order, output_shapes)

        # Input-side weights; kernel_shape is taken from the kernel itself.
        rnn.kernel = kernel
        rnn.kernel_shape = kernel.shape
        rnn.bias = bias

        # Recurrent-side weights and the starting hidden state.
        rnn.recurrent_kernel = recurrent_kernel
        rnn.recurrent_bias = recurrent_bias
        rnn.initial_h = initial_h

        return rnn

    # ------------------------------------------------------------------
    # Shape inference
    # ------------------------------------------------------------------
    def _calc_output_shape(self):
        """Return a copy of ``input_shape`` with the last entry replaced.

        The last entry of the copy is set to the last entry of
        ``kernel_shape``. A deep copy is used so ``self.input_shape`` itself
        is left untouched.
        """
        shape = copy.deepcopy(self.input_shape)
        shape[-1] = self.kernel_shape[-1]
        return shape

    # ------------------------------------------------------------------
    # Optimization-algorithm queries: every one reports "unexpected".
    # ``predecessor`` is accepted but not used by any of them.
    # ------------------------------------------------------------------
    def get_equalization_handler_type(self, predecessor=None):
        """Return an ``unexpected``, non-source classification."""
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Return an ``unexpected``, non-source classification."""
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Return an ``unexpected``, non-source classification."""
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def ibc_supported(self):
        """Return ``LayerSupportStatus.unexpected``."""
        return LayerSupportStatus.unexpected
