import copy

import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers import InnerLayer
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class LSTMLayer(InnerLayer):
    """HN layer describing an LSTM, optionally with a second (backward) direction.

    Besides the input kernel/bias handled by ``InnerLayer``, the layer stores:

    * forward-direction recurrent kernel/bias and initial hidden/cell states;
    * a parallel ``bw_*`` set of parameters for the backward direction.

    ``direction`` controls how many directions contribute to the output:
    ``"bidirectional"`` counts as 2, any other value as 1 (default ``"fw"``).
    """

    def __init__(self):
        super().__init__()
        self._op = LayerType.lstm
        # Forward-direction parameters (the input kernel/bias live on InnerLayer;
        # NOTE(review): ``_bias`` is assumed to be initialized by InnerLayer).
        self._recurrent_kernel = None
        self._recurrent_bias = None
        self._initial_h = None
        self._initial_c = None
        # Backward-direction parameters; presumably only consumed when the layer
        # is bidirectional -- TODO confirm against the consumers of ``bw_*``.
        self._bw_kernel = None
        self._bw_bias = None
        self._bw_recurrent_kernel = None
        self._bw_recurrent_bias = None
        self._bw_initial_h = None
        self._bw_initial_c = None
        self._direction = "fw"

    @classmethod
    def create(cls, original_name, input_vertex_order, fw_params, bw_params, direction, output_shapes=None):
        """Build an LSTM layer from its forward and backward parameter tuples.

        Args:
            original_name: Forwarded to ``InnerLayer.create``.
            input_vertex_order: Forwarded to ``InnerLayer.create``.
            fw_params: 6-tuple ``(kernel, recurrent_kernel, bias, recurrent_bias,
                initial_h, initial_c)``. ``kernel`` must expose ``.shape``; it is
                used to set ``kernel_shape``.
            bw_params: 6-tuple ``(bw_kernel, bw_recurrent_kernel, bw_bias,
                bw_recurrent_bias, bw_initial_h, bw_initial_c)``, or ``None`` to
                leave every backward parameter unset (e.g. a forward-only LSTM).
            direction: Stored as ``direction``; ``"bidirectional"`` doubles the
                output channel count.
            output_shapes: Forwarded to ``InnerLayer.create``.

        Returns:
            The layer produced by ``InnerLayer.create``, populated with the
            given parameters.
        """
        layer = super().create(original_name, input_vertex_order, output_shapes)
        kernel, recurrent_kernel, bias, recurrent_bias, initial_h, initial_c = fw_params
        if bw_params is None:
            # No backward direction supplied: store None for all six backward
            # fields instead of forcing callers to build a dummy tuple.
            bw_params = (None,) * 6
        bw_kernel, bw_recurrent_kernel, bw_bias, bw_recurrent_bias, bw_initial_h, bw_initial_c = bw_params
        # Keep the original assignment order: the inherited kernel setter runs
        # before kernel_shape is set explicitly from the array.
        layer.kernel = kernel
        layer.kernel_shape = kernel.shape
        layer.bias = bias
        layer.recurrent_kernel = recurrent_kernel
        layer.recurrent_bias = recurrent_bias
        layer.initial_h = initial_h
        layer.initial_c = initial_c
        layer.bw_kernel = bw_kernel
        layer.bw_bias = bw_bias
        layer.bw_recurrent_kernel = bw_recurrent_kernel
        layer.bw_recurrent_bias = bw_recurrent_bias
        layer.bw_initial_h = bw_initial_h
        layer.bw_initial_c = bw_initial_c
        layer.direction = direction
        return layer

    @property
    def num_directions(self):
        """Number of directions feeding the output: 2 if bidirectional, else 1."""
        if self._direction == "bidirectional":
            return 2

        return 1

    @property
    def bias(self):
        """Forward input bias.

        When unset, returns ``np.zeros(self.kernel_shape[-1])``. That array is
        rebuilt on every access, so in-place edits to it are not stored.
        """
        return self._bias if self._bias is not None else np.zeros(self.kernel_shape[-1])

    @bias.setter
    def bias(self, bias):
        self._bias = bias

    @property
    def recurrent_kernel(self):
        """Forward recurrent kernel (``None`` until set)."""
        return self._recurrent_kernel

    @recurrent_kernel.setter
    def recurrent_kernel(self, recurrent_kernel):
        self._recurrent_kernel = recurrent_kernel

    @property
    def recurrent_bias(self):
        """Forward recurrent bias (``None`` until set; no zero fallback)."""
        return self._recurrent_bias

    @recurrent_bias.setter
    def recurrent_bias(self, recurrent_bias):
        self._recurrent_bias = recurrent_bias

    @property
    def initial_h(self):
        """Forward initial hidden state (``None`` until set)."""
        return self._initial_h

    @initial_h.setter
    def initial_h(self, initial_h):
        self._initial_h = initial_h

    @property
    def initial_c(self):
        """Forward initial cell state (``None`` until set)."""
        return self._initial_c

    @initial_c.setter
    def initial_c(self, initial_c):
        self._initial_c = initial_c

    @property
    def bw_kernel(self):
        """Backward input kernel (``None`` until set)."""
        return self._bw_kernel

    @bw_kernel.setter
    def bw_kernel(self, bw_kernel):
        self._bw_kernel = bw_kernel

    @property
    def bw_bias(self):
        """Backward input bias.

        When unset, returns zeros sized from the *forward* ``kernel_shape[-1]``.
        A fresh array is built on every access.
        """
        return self._bw_bias if self._bw_bias is not None else np.zeros(self.kernel_shape[-1])

    @bw_bias.setter
    def bw_bias(self, bw_bias):
        self._bw_bias = bw_bias

    @property
    def bw_recurrent_kernel(self):
        """Backward recurrent kernel (``None`` until set)."""
        return self._bw_recurrent_kernel

    @bw_recurrent_kernel.setter
    def bw_recurrent_kernel(self, bw_recurrent_kernel):
        self._bw_recurrent_kernel = bw_recurrent_kernel

    @property
    def bw_recurrent_bias(self):
        """Backward recurrent bias (``None`` until set; no zero fallback)."""
        return self._bw_recurrent_bias

    @bw_recurrent_bias.setter
    def bw_recurrent_bias(self, bw_recurrent_bias):
        self._bw_recurrent_bias = bw_recurrent_bias

    @property
    def bw_initial_h(self):
        """Backward initial hidden state (``None`` until set)."""
        return self._bw_initial_h

    @bw_initial_h.setter
    def bw_initial_h(self, bw_initial_h):
        self._bw_initial_h = bw_initial_h

    @property
    def bw_initial_c(self):
        """Backward initial cell state (``None`` until set)."""
        return self._bw_initial_c

    @bw_initial_c.setter
    def bw_initial_c(self, bw_initial_c):
        self._bw_initial_c = bw_initial_c

    @property
    def direction(self):
        """Direction tag; ``"fw"`` by default, ``"bidirectional"`` for two directions."""
        return self._direction

    @direction.setter
    def direction(self, direction):
        self._direction = direction

    def _calc_output_shape(self):
        """Return the input shape with its last axis replaced by the LSTM width.

        The last axis becomes ``(kernel_shape[-1] // 4) * num_directions``.
        Presumably the kernel's last axis packs four gate blocks, so ``// 4``
        yields the hidden size. TODO confirm against the kernel layout.
        ``input_shape`` is deep-copied so it is never mutated. The copy must
        support item assignment (e.g. a list).
        """
        output_shape = copy.deepcopy(self.input_shape)
        output_shape[-1] = (self.kernel_shape[-1] // 4) * self.num_directions

        return output_shape

    # The LSTM layer is classified as ``LayerHandlerType.unexpected`` (and not a
    # source) for equalization, params sorting and dead-channels removal.
    def get_equalization_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def ibc_supported(self):
        """Report IBC support status; always ``LayerSupportStatus.unexpected`` here."""
        return LayerSupportStatus.unexpected
