import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.inner_layer import InnerLayer
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class GRULayer(InnerLayer):
    """HN inner layer describing a GRU (gated recurrent unit) operation.

    Stores the GRU configuration passed to ``create``:
    - the hidden size and the ``linear_before_reset`` flag;
    - the input and recurrence kernels (re-laid-out, see ``create``);
    - the bias, the sequence lengths and the initial hidden state.
    """

    def __init__(self):
        super().__init__()
        self._op = LayerType.gru
        # All of the following are populated by ``create`` (or via the setters).
        self._hidden_size = None
        self._linear_before_reset = None
        # NOTE(review): ``_kernel`` has no property in this class although
        # ``create`` assigns ``layer.kernel``; presumably a base class exposes
        # it — verify.
        self._kernel = None
        self._kernel_shape = None
        self._recurrence_kernel = None
        self._bias = None
        self._sequence_lens = None
        self._initial_h = None
        self._number_of_inputs_supported = 2

    @classmethod
    def create(
        cls,
        original_name,
        input_vertex_order,
        hidden_size,
        linear_before_reset,
        kernel,
        recurrence_kernel,
        bias,
        sequence_lens,
        initial_h,
        output_shapes=None,
    ):
        """Build a GRU layer from raw GRU parameters.

        Args:
            original_name: Forwarded to ``InnerLayer.create``.
            input_vertex_order: Forwarded to ``InnerLayer.create``.
            hidden_size: Size of the hidden state; also used to build the
                default ``initial_h``.
            linear_before_reset: Stored unchanged.
            kernel: 3-D input weight array. It is stored with an extra leading
                unit axis and its two trailing axes swapped. ``kernel_shape``
                keeps the shape it had before this re-layout.
            recurrence_kernel: 3-D recurrence weight array, re-laid-out the
                same way as ``kernel``.
            bias: Bias array. If ``None``, a 1-D zero vector of length
                ``2 * kernel.shape[-2]`` is used.
            sequence_lens: Stored unchanged.
            initial_h: Initial hidden state. If ``None``, zeros of shape
                ``(1, 1, hidden_size)`` are used.
            output_shapes: Optional; forwarded to ``InnerLayer.create``.

        Returns:
            The newly created layer.
        """
        layer = super().create(original_name, input_vertex_order, output_shapes)

        # Prepend a unit axis, then swap the last two axes of each weight tensor.
        swap_trailing_axes = [0, 1, 3, 2]
        stored_kernel = np.transpose(kernel[None, ...], swap_trailing_axes)
        stored_recurrence_kernel = np.transpose(recurrence_kernel[None, ...], swap_trailing_axes)

        if bias is None:
            # Default bias: zeros, twice the length of the kernel's second-to-last axis.
            # NOTE(review): for ONNX-style weights (3 * hidden_size rows), the factor
            # of 2 would cover separate input and recurrence biases for all three
            # gates, not a bias "shared" by the reset/update gates — confirm.
            bias = np.zeros(2 * kernel.shape[-2])
        if initial_h is None:
            initial_h = np.zeros((1, 1, hidden_size))

        layer.hidden_size = hidden_size
        layer.linear_before_reset = linear_before_reset
        layer.kernel = stored_kernel
        layer.kernel_shape = kernel.shape  # shape prior to the re-layout above
        layer.recurrence_kernel = stored_recurrence_kernel
        layer.bias = bias
        layer.sequence_lens = sequence_lens
        layer.initial_h = initial_h
        return layer

    @property
    def hidden_size(self):
        """Size of the GRU hidden state."""
        return self._hidden_size

    @hidden_size.setter
    def hidden_size(self, value):
        self._hidden_size = value

    @property
    def linear_before_reset(self):
        """The ``linear_before_reset`` flag as given to ``create``."""
        return self._linear_before_reset

    @linear_before_reset.setter
    def linear_before_reset(self, value):
        self._linear_before_reset = value

    @property
    def kernel_shape(self):
        """Shape of the input kernel before ``create`` re-laid it out."""
        return self._kernel_shape

    @kernel_shape.setter
    def kernel_shape(self, value):
        self._kernel_shape = value

    @property
    def recurrence_kernel(self):
        """Recurrence weight tensor in its stored (re-laid-out) form."""
        return self._recurrence_kernel

    @recurrence_kernel.setter
    def recurrence_kernel(self, value):
        self._recurrence_kernel = value

    @property
    def bias(self):
        """GRU bias array (zeros when none was supplied to ``create``)."""
        return self._bias

    @bias.setter
    def bias(self, value):
        self._bias = value

    @property
    def sequence_lens(self):
        """Sequence-length input as given to ``create`` (may be ``None``)."""
        return self._sequence_lens

    @sequence_lens.setter
    def sequence_lens(self, value):
        self._sequence_lens = value

    @property
    def initial_h(self):
        """Initial hidden state (zeros when none was supplied to ``create``)."""
        return self._initial_h

    @initial_h.setter
    def initial_h(self, value):
        self._initial_h = value

    def _calc_output_shape(self):
        # NOTE(review): the output shape is copied from the first input rather
        # than derived from ``hidden_size`` — confirm this is intended.
        first_input_shape = self.input_shapes[0]
        return first_input_shape

    @staticmethod
    def _gru_unexpected_classification():
        # Every equivalence-set handler query below classifies this layer the same way.
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def get_equalization_handler_type(self, predecessor=None):
        """Return the equalization handler classification; ``predecessor`` is ignored."""
        return self._gru_unexpected_classification()

    def get_params_sorter_handler_type(self, predecessor=None):
        """Return the params-sorter handler classification; ``predecessor`` is ignored."""
        return self._gru_unexpected_classification()

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Return the dead-channels-removal handler classification; ``predecessor`` is ignored."""
        return self._gru_unexpected_classification()

    def ibc_supported(self):
        """Return ``LayerSupportStatus.unexpected`` unconditionally."""
        return LayerSupportStatus.unexpected
