import math

import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    LayerEquivType,
    LayerHandlerType,
    LayerSupportStatus,
)
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.inner_layer import InnerLayer
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class DenseLayer(InnerLayer):
    """
    HN representation of a fully connected (dense) layer.

    The layer performs a matrix multiplication between its data input and a weights kernel, in the
    linear form y=xA'+b (A being the kernel and b the bias component). ``kernel_shape`` is handled
    throughout this class as ``[input features, output features]``.
    """

    _REQUIRES_NATIVE_WEIGHTS = True
    _REQUIRES_QUANTIZED_WEIGHTS = True
    _IS_REAL_LAYER = True

    def __init__(self):
        """
        Initialize the layer as a ``base_dense`` op whose kernel needs no transposition.
        """
        super().__init__()
        self._op = LayerType.base_dense
        self._should_transpose_kernel = False

    @classmethod
    def create(cls, original_name, input_vertex_order, bias, kernel, should_transpose_kernel=False, output_shapes=None):
        """
        Alternate constructor that builds a dense layer from its parameters.

        Args:
            original_name: forwarded as-is to the parent ``create``.
            input_vertex_order: forwarded as-is to the parent ``create``.
            bias: value stored on the layer's ``bias`` attribute.
            kernel: weights array; its ``shape`` must have at least two entries.
            should_transpose_kernel: when truthy, ``kernel_shape`` is recorded with the two kernel
                dimensions swapped, and the kernel data itself is transposed later on, by
                :meth:`reshape_input`.
            output_shapes: forwarded as-is to the parent ``create``.

        Returns:
            The newly created layer.
        """
        layer = super().create(original_name, input_vertex_order, output_shapes)
        layer.bias = bias
        layer.kernel = kernel
        if should_transpose_kernel:
            # The kernel data is still stored untransposed at this point; only the recorded shape is swapped.
            layer.kernel_shape = [kernel.shape[1], kernel.shape[0]]
        else:
            layer.kernel_shape = kernel.shape
        layer.should_transpose_kernel = should_transpose_kernel
        return layer

    @property
    def should_transpose_kernel(self):
        """
        Whether the stored kernel still has to be transposed (see :meth:`reshape_input`).
        """
        return self._should_transpose_kernel

    @should_transpose_kernel.setter
    def should_transpose_kernel(self, should_transpose_kernel):
        self._should_transpose_kernel = should_transpose_kernel

    def to_pb(self, pb_wrapper, is_multi_scope):
        """
        Serialize the layer into a protobuf node.

        The parent implementation builds the node; this method then marks it as a base dense layer,
        writes the two ``kernel_shape`` entries into ``width`` and ``features`` (in that order) and
        sets both strides to 1.
        """
        pb_node = super().to_pb(pb_wrapper, is_multi_scope)
        pb_node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_BASE_DENSE

        # Unpack first, so that a malformed kernel shape fails before any field is written.
        shape_width, shape_features = self.kernel_shape
        pb_node.kernel_shape.width = shape_width
        pb_node.kernel_shape.features = shape_features

        pb_node.strides.height = 1
        pb_node.strides.width = 1
        return pb_node

    def reshape_input(self, input_shape):
        """
        Apply the pending kernel transposition, if any, and return ``input_shape`` unchanged.

        This covers an edge case for onnx models with different flatten behavior pre-dense layers.
        The transposition is applied at most once: the flag is cleared right after it.
        """
        if not self.should_transpose_kernel:
            return input_shape

        self.update_kernel_from_input_shape(input_shape)
        self.should_transpose_kernel = False
        return input_shape

    def update_kernel_from_input_shape(self, input_shape):
        """
        Rewrite ``self.kernel`` in its transposed form, according to the rank of ``input_shape``.

        For any rank other than 4 the kernel is simply transposed. For a rank-4 input, whose last
        three entries are read as (h, w, f_in), the kernel is reordered so that f_in becomes the
        fastest-changing index of the flattened input axis, and the result has the shape
        ``[h * w * f_in, f_out]``.
        """
        if len(input_shape) != 4:
            self.kernel = np.transpose(self.kernel)
            return

        height, width, in_channels = input_shape[1:]
        out_channels = self.kernel_shape[1]

        # Split the flattened input axis of the kernel, with the channels ahead of the spatial dimensions.
        split_kernel = np.reshape(self.kernel, [out_channels, in_channels, width, height])
        # Reverse an edge case where the input is given with its features ahead of h and w:
        # move the channels axis to the end.
        channels_last = np.transpose(split_kernel, axes=[0, 2, 3, 1])
        flat_kernel = np.reshape(channels_last, [out_channels, height * width * in_channels])
        self.kernel = np.transpose(flat_kernel)

    def _calc_output_shape(self):
        """
        Return the output shape: ``[-1, 1, 1, f_out]``, where f_out is the second kernel shape entry.
        """
        out_channels = self._kernel_shape[1]
        return [-1, 1, 1, out_channels]

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """
        Deserialize a layer from a protobuf node, restoring the kernel shape from the node's
        ``kernel_shape.width`` and ``kernel_shape.features`` fields.
        """
        layer = super().from_pb(pb, pb_wrapper)
        pb_shape = pb.kernel_shape
        layer._kernel_shape = [pb_shape.width, pb_shape.features]
        return layer

    @property
    def input_features(self):
        """
        Number of input features: the product of all input dimensions except the first one.

        Raises:
            UnsupportedModelError: if that product differs from the first kernel shape entry.
        """
        input_features = math.prod(self.input_shape[1:])
        if input_features == self.kernel_shape[0]:
            return input_features

        raise UnsupportedModelError(
            f"Invalid kernel shape for {self.full_name_msg}. Kernel input features: "
            f"{self.kernel_shape[0]}, Input features: {input_features}",
        )

    @property
    def output_features(self):
        """
        Number of output features: the last dimension of the output shape.

        Raises:
            UnsupportedModelError: if any output dimension between the first and the last is not 1,
                or if the last output dimension differs from the second kernel shape entry.
        """
        output_features = self.output_shape[-1]

        # Every dimension between the first and the features dimension has to be 1.
        if any(dim != 1 for dim in self.output_shape[1:-1]):
            raise UnsupportedModelError(
                f"Invalid output shape for {self.full_name_msg} dense layer. Output shape: {self.output_shape}",
            )

        if output_features != self.kernel_shape[1]:
            raise UnsupportedModelError(
                f"Invalid kernel shape for {self.full_name_msg} Kernel output features: "
                f"{self.kernel_shape[1]}, Output features: {output_features}",
            )
        return output_features

    @staticmethod
    def _unexpected_handler_classification():
        """
        Build the classification returned by all the handler-type queries of this layer:
        an ``unexpected`` handler type which is not a source.
        """
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def get_equalization_handler_type(self, predecessor=None):
        """
        Return the equalization classification of this layer; ``predecessor`` is not used.
        """
        return self._unexpected_handler_classification()

    def get_params_sorter_handler_type(self, predecessor=None):
        """
        Return the params sorter classification of this layer; ``predecessor`` is not used.
        """
        return self._unexpected_handler_classification()

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """
        Return the dead channels removal classification of this layer; ``predecessor`` is not used.
        """
        return self._unexpected_handler_classification()

    def ibc_supported(self):
        """
        Return the IBC support status of this layer, which is always ``unexpected``.
        """
        return LayerSupportStatus.unexpected

    def get_axes_mask(self, type_of_layer=None):
        """
        Return a two-entry boolean mask: ``[True, False]`` when the layer is queried as a producer,
        ``[False, True]`` for any other value of ``type_of_layer``.

        NOTE(review): the entries presumably correspond to the two kernel axes - confirm against the callers.
        """
        queried_as_producer = type_of_layer == LayerEquivType.producer
        return [True, False] if queried_as_producer else [False, True]

    @property
    def kernel_height(self):
        """
        Kernel height, which is always 1 for this layer.
        """
        return 1

    @property
    def kernel_width(self):
        """
        Kernel width, which is always 1 for this layer.
        """
        return 1
