import copy

from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    LayerHandlerType,
    LayerSupportStatus,
    MatmulCorrectionType,
)
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationType, ActivationTypes, LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_activation import LayerWithActivation
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class MatmulLayer(LayerWithActivation):
    """
    HN representation of a matrix multiplication layer, implementing a general matrix multiplication between
    two data inputs in a linear form y=A'B'+c (where A and B are the data inputs, and c is the bias component).

    Input 0 is the data operand and input 1 is the weights operand. Shapes are indexed as
    [batch, height, width, features] (see ``macs`` for how a grouped ONNX MatMul is mapped onto them).
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False
    _IS_REAL_LAYER = True

    def __init__(self):
        super().__init__()
        self._op = LayerType.matmul
        # Data operand + weights operand.
        self._number_of_inputs_supported = 2
        self._transpose_matmul_input = True
        self._zp_correction_type = MatmulCorrectionType.ZP_COMP_NONE
        self._zp_comp_rank: int = 0
        # None means "derive from the input shapes" (see the ``kernel_shape`` property).
        self._kernel_shape = None
        self._dynamic_weights = True
        self._groups = 1
        self._input_windows = [1, 1, 1]
        # Per-input [height, width, features] repeat factors; all ones means no tiling.
        self._input_tiles = self._default_input_tiles()

    @staticmethod
    def _default_input_tiles():
        """Return a fresh tile spec (one [height, width, features] triplet per input) meaning "no tiling"."""
        return [[1, 1, 1], [1, 1, 1]]

    @property
    def input_tiles(self):
        """Per-input [height, width, features] repeat factors applied before the multiplication."""
        return self._input_tiles

    @input_tiles.setter
    def input_tiles(self, input_tiles):
        self._input_tiles = input_tiles

    @property
    def macs(self):
        # Example: MatMul of A*B, when the ONNX shape of A is [GxMxL] and B is [GxLxN] and output is [GxMxN].
        # We will parse A     as [1x1xMx(L*G)]
        # and B               as [1x1xNx(L*G)]
        # and the "kernel"    as [1x1x(L*G)x(N*G)]
        # and the output will be [1x1xMx(N*G)]
        # Moreover, The compiler has the freedom to split the width into the rows dimension.
        # The number of multiple-and-accumulate operations is M*L*N * G.
        # To extract the number of MACs, we should extract G, M, L and N:
        #
        # self.groups is G
        # M is rows * width A
        # L is dim2 of the kernel, divided by groups
        # N is dim3 of the kernel, divided by groups
        #
        # Use the property (not ``_kernel_shape``) so a layer without an explicit kernel
        # shape falls back to the derived one instead of failing on ``None``.
        kernel_shape = self.kernel_shape
        data_w = self.input_shapes[0][1] * self.input_shapes[0][2]
        weight_w = kernel_shape[2] / self.groups
        weight_f = kernel_shape[3] / self.groups
        return data_w * weight_w * weight_f * self.groups

    @property
    def ops(self):
        """Number of operations: every MAC counts as one multiply plus one add."""
        return self.macs * 2

    @property
    def kernel_shape(self):
        """Explicit kernel shape, or [1, 1, input0 features, output features] when none was set."""
        if self._kernel_shape is None:
            return [1, 1, self.input_shapes[0][3], self._calc_output_shape()[-1]]
        return self._kernel_shape

    @kernel_shape.setter
    def kernel_shape(self, value):
        self._kernel_shape = value

    @property
    def zp_correction_type(self):
        """The ``MatmulCorrectionType`` used for zero-point correction."""
        return self._zp_correction_type

    @zp_correction_type.setter
    def zp_correction_type(self, zp_correction_type):
        self._zp_correction_type = zp_correction_type

    @property
    def zp_comp_rank(self):
        """Zero-point compensation rank; subtracted from the weights features in ``set_input_shapes``."""
        return self._zp_comp_rank

    @zp_comp_rank.setter
    def zp_comp_rank(self, zp_comp_rank):
        self._zp_comp_rank = zp_comp_rank

    @classmethod
    def create(cls, original_name, input_vertex_order, should_transpose_input=True, groups=1, output_shapes=None):
        """Create a new matmul layer with the given transpose flag and group count."""
        layer = super().create(original_name, input_vertex_order, output_shapes)
        layer.transpose_matmul_input = should_transpose_input
        layer.groups = groups

        return layer

    @classmethod
    def from_hn(cls, hn):
        """
        Build a layer from its HN dict representation.

        Raises:
            UnsupportedModelError: if ``input_tiles`` is not a (2, 3) nested list, or if ``output_windows``
                differs from ``input_windows``.
        """
        layer = super().from_hn(hn)
        params = hn["params"]

        # Default mixed_mem to False only when the HN does not specify it explicitly.
        if "compilation_params" not in hn or "mixed_mem" not in hn["compilation_params"]:
            layer.set_compilation_params(mixed_mem=False)

        layer.transpose_matmul_input = params["transpose_matmul_input"]
        layer.zp_comp_rank = params.get("zp_comp_rank", 1 if layer.zp_comp_added else 0)
        layer.activation = ActivationTypes[params.get("activation", ActivationType.linear.value)]

        input_tiles = params.get("input_tiles")
        # Deep-copy so the layer never shares (and later mutates) lists owned by the caller's HN dict.
        layer.input_tiles = cls._default_input_tiles() if input_tiles is None else copy.deepcopy(input_tiles)
        if len(layer.input_tiles) != 2 or any(len(tile) != 3 for tile in layer.input_tiles):
            raise UnsupportedModelError(
                f"Matmul input tiles should always be with shape (2,3) for {layer.full_name_msg}"
            )

        if "input_windows" in params:
            layer.input_windows = list(params["input_windows"])
            if params.get("output_windows", layer.input_windows) != layer.input_windows:
                raise UnsupportedModelError(
                    f"Output windows must be the same as input windows for {layer.full_name_msg}"
                )
        layer.zp_correction_type = MatmulCorrectionType(
            params.get(
                "zp_correction_type",
                MatmulCorrectionType.ZP_COMP if layer.zp_comp_added else MatmulCorrectionType.ZP_COMP_NONE,
            ),
        )

        if "groups" in params:
            layer.groups = params["groups"]
        if "kernel_shape" in params:
            layer.kernel_shape = params["kernel_shape"]
        else:
            # Groups/tiles/windows are already set, so the derived output features are final here.
            layer.kernel_shape = [1, 1, layer.input_shapes[0][3], layer._calc_output_shape()[-1]]

        return layer

    @property
    def input_windows(self):
        """Number of windows per [height, width, features] dimension (output windows must match)."""
        return self._input_windows

    @input_windows.setter
    def input_windows(self, input_windows):
        self._input_windows = input_windows

    def to_hn(self, should_get_default_params=False):
        """Serialize to an HN dict; the returned lists are copies, independent of this layer."""
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        params = result["params"]
        params["transpose_matmul_input"] = self._transpose_matmul_input
        params["kernel_shape"] = list(self.kernel_shape)
        params["groups"] = self._groups
        params["input_windows"] = list(self.input_windows)
        params["zp_correction_type"] = self._zp_correction_type.value
        params["zp_comp_rank"] = self._zp_comp_rank
        params["activation"] = self._activation.value
        # input_tiles is only serialized when at least one input is actually tiled.
        if any(tile != [1, 1, 1] for tile in self._input_tiles):
            params["input_tiles"] = copy.deepcopy(self._input_tiles)
        return result

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize to a protobuf node of type ``PROTO_NETWORK_MATMUL``."""
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_MATMUL
        # The third (input-features) entry is not written; from_pb re-derives it from input 0's shape.
        node.kernel_shape.height, node.kernel_shape.width, _, node.kernel_shape.features = self.kernel_shape
        node.strides.height, node.strides.width = [1, 1]
        node.dilations.height, node.dilations.width = [1, 1]
        node.groups = self._groups
        node.input_windows.extend(self.input_windows)
        node.transpose_matmul_input = self._transpose_matmul_input
        node.zp_comp_rank = self._zp_comp_rank
        for tile in self._input_tiles:
            input_tile = node.input_tiles.add()
            input_tile.height, input_tile.width, input_tile.features = tile
        return node

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """
        Build a layer from its protobuf node.

        Raises:
            UnsupportedModelError: if the node's output windows differ from its input windows.
        """
        layer = super().from_pb(pb, pb_wrapper)
        layer.transpose_matmul_input = pb.transpose_matmul_input
        layer.zp_comp_rank = pb.zp_comp_rank
        layer.kernel_shape = [
            pb.kernel_shape.height,
            pb.kernel_shape.width,
            layer.input_shapes[0][3],
            pb.kernel_shape.features,
        ]
        layer.groups = pb.groups
        # Materialize the repeated field as a plain list so it copies/serializes like the HN path.
        layer.input_windows = list(pb.input_windows)
        if pb.output_windows and list(pb.output_windows) != layer.input_windows:
            raise UnsupportedModelError(
                f"Output windows must be the same as input windows for {layer.full_name_msg}"
            )

        # An empty tile spec would make _calc_output_shape drop every dimension; fall back to "no tiling".
        layer.input_tiles = [
            [tile.height, tile.width, tile.features] for tile in pb.input_tiles
        ] or cls._default_input_tiles()
        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """Clone ``old_layer``; mutable attributes are copied so the two layers do not share state."""
        layer = super().from_layer(old_layer)
        layer.transpose_matmul_input = old_layer.transpose_matmul_input
        layer.kernel_shape = list(old_layer.kernel_shape)
        layer.groups = old_layer.groups
        layer.input_windows = list(old_layer.input_windows)
        layer.zp_comp_rank = old_layer.zp_comp_rank
        layer.zp_correction_type = old_layer.zp_correction_type
        layer.input_tiles = copy.deepcopy(old_layer.input_tiles)
        return layer

    @property
    def number_of_windows(self):
        """Product of the first two input-window counts (the features window count is not included)."""
        return self._input_windows[0] * self._input_windows[1]

    def calc_output_shape(self):
        """Public wrapper around ``_calc_output_shape``."""
        return self._calc_output_shape()

    def _calc_output_shape(self):
        """
        Return [-1, height, width, features] of the output.

        Height and width come from the (tiled) data input. Features come from the explicit kernel
        shape when set; otherwise they are derived from the (tiled) weights input.
        """
        input0_shape = [-1, *[dim * ratio for dim, ratio in zip(self.input_shapes[0][1:], self._input_tiles[0])]]
        input1_shape = [-1, *[dim * ratio for dim, ratio in zip(self.input_shapes[1][1:], self._input_tiles[1])]]
        if not self._kernel_shape:
            if self._transpose_matmul_input:
                kernel_features = input1_shape[1] * input1_shape[2] * self.groups // self.number_of_windows
            else:
                kernel_features = input1_shape[3]
        else:
            kernel_features = self._kernel_shape[3]

        return [
            input0_shape[0],
            input0_shape[1],
            input0_shape[2],
            kernel_features,
        ]

    def set_input_shapes(self, input_shapes, validate=True):
        """
        Set the input shapes, validating the operands when a kernel shape is already known.

        Raises:
            UnsupportedModelError: on a wrong number of inputs, or when the data features do not match
                the weights features.
        """
        if validate and self._kernel_shape:
            if len(input_shapes) != self.number_of_inputs_supported:
                raise UnsupportedModelError(
                    f"Unexpected number of inputs at {self.full_name_msg}, expected "
                    f"{self.number_of_inputs_supported}, found {len(input_shapes)}",
                )
            input0_shape, input1_shape = input_shapes
            # Features-axis tile factors of the data and weights inputs.
            data_tile = self.input_tiles[0][2]
            weight_tile = self.input_tiles[1][2]

            # NOTE(review): groups // tile is 0 (ZeroDivisionError) if a features tile exceeds groups.
            data_features = input0_shape[3] // (self.groups // data_tile)
            if self._transpose_matmul_input:
                weights_features = input1_shape[3] // (self.groups // weight_tile) - self.zp_comp_rank
            else:
                weights_features = input1_shape[1] * input1_shape[2] // (self.number_of_windows)

            if data_features != weights_features:
                raise UnsupportedModelError(
                    f"Unexpected input shapes at {self.full_name_msg}, "
                    f"input_shapes={input_shapes} (type={type(input_shapes)})",
                )

        super().set_input_shapes(input_shapes, validate)

    @property
    def transpose_matmul_input(self):
        """Whether the weights input is consumed transposed."""
        return self._transpose_matmul_input

    @transpose_matmul_input.setter
    def transpose_matmul_input(self, transpose_matmul_input):
        self._transpose_matmul_input = transpose_matmul_input

    @property
    def groups(self):
        """Number of independent matmul groups (G)."""
        return self._groups

    @groups.setter
    def groups(self, groups):
        self._groups = groups

    @property
    def input_features(self):
        return self._get_shape_single_dim(self._input_shapes, 3, validate=False)

    @property
    def input_width(self):
        return self._get_shape_single_dim(self._input_shapes, 2, validate=False)

    def get_equalization_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def ibc_supported(self):
        return LayerSupportStatus.unsupported
