from typing import Union

import numpy as np

from hailo_model_optimization.acceleras.atomic_ops.activation_op import ActivationOp
from hailo_model_optimization.acceleras.atomic_ops.matmul_op import MatmulOp
from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerWeightsClippingConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    BiasMode,
    DataPath,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    MatmulCorrectionType,
    OptimizationTarget,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams
from hailo_model_optimization.acceleras.utils.opt_utils import get_kernel_bits_and_sign_by_precision_mode


class HailoMatmul(BaseHailoLayer):
    """
    Implement Hailo matmul layer,
        - takes two inputs,
        - the mac behaves as passthru + zp compensation
        - multiply the inputs in the APU
        - activation in the APU
    """

    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a8_w8,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a8_w8_a16,
    }
    SUPPORTED_BIAS_MODE = {
        BiasMode.double_scale_initialization,
        BiasMode.single_scale_decomposition,
        BiasMode.double_scale_decomposition,
    }
    SUPPORTED_QUANTIZATION_GROUPS = False
    _hn_type = LayerType.MATMUL

    def __init__(
        self,
        name: str,
        transpose_matmul_input: bool = True,
        groups: int = 1,
        zp_comp_added: bool = False,
        activation: Union[str, callable, ActivationType] = ActivationType.LINEAR,
        input_windows: list = None,
        input_tiles: list = None,
        logger=None,
        **kwargs,
    ):
        """
        Create the atomic ops of the layer (matmul, activation, output passthru) and init the base layer.

        Args:
            name: layer name; also used as the prefix of every atomic op name
            transpose_matmul_input: forwarded to the matmul op
            groups: forwarded to the matmul op
            zp_comp_added: forwarded to the matmul op
            activation: activation handed to the activation op
            input_windows: forwarded to the matmul op and also kept on the layer
            input_tiles: forwarded to the matmul op
            logger: logger shared by the layer and its atomic ops
            **kwargs: forwarded as-is to the base layer constructor

        """
        matmul_settings = dict(
            transpose_matmul_input=transpose_matmul_input,
            groups=groups,
            zp_comp_added=zp_comp_added,
            input_windows=input_windows,
            input_tiles=input_tiles,
            logger=logger,
        )
        self.matmul_op = MatmulOp(f"{name}/matmul_op", **matmul_settings)
        self.act_op = ActivationOp(f"{name}/act_op", activation=activation, logger=logger)
        # enabling output quantization even as activation is fully native...
        self.output_op = PassthruOp(f"{name}/passthru_op", logger=logger)
        self._zp_correction_type = MatmulCorrectionType.ZP_COMP_NONE
        # These two go through property setters that write onto matmul_op, so matmul_op must already exist.
        self.transpose_matmul_input = transpose_matmul_input
        self.zp_comp_added = zp_comp_added
        self.input_windows = input_windows
        super().__init__(name=name, logger=logger, **kwargs)

    @property
    def pre_acc_shift(self):
        return self.matmul_op._multiplier_shift

    @property
    def input_tiles(self):
        return self.matmul_op.input_tiles

    def _build_flow(self) -> LayerFlow:
        """
        Describe the layer's internal graph: two inputs -> matmul -> activation -> passthru -> one output.

        Returns:
            the populated LayerFlow

        """
        flow = LayerFlow()
        data_in = flow.add_input()
        weights_in = flow.add_input()
        layer_out = flow.add_output()

        for op in (self.matmul_op, self.act_op, self.output_op):
            flow.add_node(op)

        # Input 0 enters the matmul on the LAYER_IN path, input 1 on the LAYER_IN_WEIGHTS path.
        flow.add_edge(data_in, self.matmul_op, DataPath.LAYER_IN, input_index=0)
        flow.add_edge(weights_in, self.matmul_op, DataPath.LAYER_IN_WEIGHTS, input_index=1)
        flow.add_edge(self.matmul_op, self.act_op, DataPath.ACCUMULATOR)
        flow.add_edge(self.act_op, self.output_op, DataPath.LAYER_OUT)
        flow.add_edge(self.output_op, layer_out, DataPath.LAYER_OUT)
        return flow

    def _enforce_output_encoding(self):
        self.output_op.backward_encoding()
        self.act_op.output_scale = self.output_op.input_scales[0]
        self.act_op.output_zero_point = self.output_op.input_zero_points[0]

    def enforce_internal_encoding(self, training=False, **kwargs):
        """
        Calls infer_encodings for underlying atomic ops.
        """
        # TODO: I don't like that the scales and zp are external properties, and I have to set them explicitly.
        #       Which affects the infer_encoding implicitly
        self._enforce_output_encoding()
        self.matmul_op.enforce_encoding()
        self.act_op.input_scales = [self.matmul_op.output_scale]
        self.act_op.input_zero_points = [self.matmul_op.output_zero_point]
        self.act_op.enforce_encoding(training=training)

    def fast_enforce_internal_encoding(self, **kwargs):
        pass

    def import_weights(self, layer_params: LayerParams):
        """
        load parameters to the layer. currently, it doesn't to anything.

        Args:
            layer_params: layer's params from the npz

        """
        # TODO: Do we want to load kernel and bias values? (instead of the auto-generated values)
        self.act_op.import_weights(layer_params)

    def _export_weights(self):
        return dict()

    @classmethod
    def get_default_params(cls):
        """
        Return a fresh dict with the default HN params of a matmul layer.

        Note that "transpose_matmul_input" is not among the defaults.
        """
        # TODO: this is temporary solution until we have pydantic scheme
        return {
            "activation": "linear",
            "zp_comp_added": False,
            "groups": 1,
            "zp_correction_type": MatmulCorrectionType.ZP_COMP_NONE,
            "input_tiles": [[1, 1, 1], [1, 1, 1]],
        }

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        Build a layer from its HN element.

        The class defaults are overlaid with the element's "params". The zp correction type is read
        from the element's own params (not from the defaults); when absent it falls back to ZP_COMP
        if zp_comp_added is set and to ZP_COMP_NONE otherwise.

        Args:
            lname: name of the new layer
            hn_element: HN description of the layer (mapping with an optional "params" entry)
            logger: optional logger handed to the layer

        Returns:
            the layer, after finalize_from_hn was applied

        Raises:
            KeyError: if "transpose_matmul_input" is missing from the HN params (it has no default)

        """
        merged = cls.get_default_params()
        hn_params = hn_element.get("params", dict())
        merged.update(hn_params)

        zp_comp_added = merged.get("zp_comp_added", False)
        # "windows" is accepted as a fallback key for "input_windows"
        fallback_windows = merged.get("windows", [1, 1, 1])
        input_windows = merged.get("input_windows", fallback_windows)
        input_tiles = merged.get("input_tiles")

        layer = cls(
            name=lname,
            activation=merged["activation"],
            transpose_matmul_input=merged["transpose_matmul_input"],
            groups=merged["groups"],
            zp_comp_added=zp_comp_added,
            input_windows=input_windows,
            input_tiles=input_tiles,
            logger=logger,
        )

        if zp_comp_added:
            fallback_correction = MatmulCorrectionType.ZP_COMP
        else:
            fallback_correction = MatmulCorrectionType.ZP_COMP_NONE
        explicit_params = hn_element.get("params", {})
        layer.zp_correction_type = explicit_params.get("zp_correction_type", fallback_correction)
        layer.finalize_from_hn(hn_element)
        return layer

    def to_hn(self, out_degree=None):
        """
        Refresh the layer's HN params from the current layer / op state, then delegate to the base export.

        Args:
            out_degree: forwarded to the base class to_hn

        Returns:
            whatever the base class to_hn returns

        """
        # need to update since matmul correction is now part of optimization flow
        hn_params = self._hn_element.setdefault("params", dict())
        hn_params["zp_comp_added"] = self.zp_comp_added
        hn_params["transpose_matmul_input"] = self.transpose_matmul_input

        # update MatmulCorrectionType:
        hn_params["zp_correction_type"] = self.zp_correction_type.value
        hn_params["zp_comp_rank"] = self.matmul_op.zp_comp_rank
        hn_params["groups"] = self.matmul_op.groups
        hn_params["input_windows"] = self.input_windows
        hn_params["input_tiles"] = self.matmul_op.input_tiles
        hn_params["activation"] = self.act_op.act_name.value
        return super().to_hn(out_degree=out_degree)

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        self._enforce_output_encoding()
        self.matmul_op.create_hw_params(
            preact_limvals=self.act_op.get_group_input_limvals(0, self.groups), hw_shifts=hw_shifts
        )
        self.act_op.create_hw_params(
            accumulator_scale_candidate=self.matmul_op.output_scale,
            optimization_target=optimization_target,
            nudging=False,
        )
        self.enforce_internal_encoding()

    def enforce_io_encoding(self, training=False, **kwargs):
        pass

    def create_quant_element_custom_behavior(
        self,
        precision_config: LayerPrecisionConfig,
        optimization_target: OptimizationTarget,
    ):
        """
        Create the quant elements of the layer: the MAC data path, then the activation and matmul op weights.

        Args:
            precision_config: its precision_mode determines the bit width used for the MAC data path
            optimization_target: forwarded to the activation op

        """
        precision_mode = precision_config.precision_mode
        # Only the bit width is needed here; the second value returned by the helper is discarded.
        data_bits, _unused = get_kernel_bits_and_sign_by_precision_mode(precision_mode)
        self.create_quant_element_by_data_path(DataPath.MAC_DATA, data_bits)
        self.act_op.create_weight_quant_element(optimization_target)
        self.matmul_op.create_weight_quant_element()

    def get_equalization_handler_type(self, predecessor_index=None):
        """Report the layer as unsupported (and not a source) for equalization, whatever the predecessor."""
        handler = LayerHandlerType.unsupported
        return EquivClassification(handler, is_source=False)

    def get_quarot_handler_type(self, predecessor_index=None):
        """
        Classify the layer for the quarot flow.

        Unsupported when the last tile entry of either input differs from 1. Otherwise matmul_transpose
        when the matmul input is transposed, matmul when asked about predecessor 1, unsupported for the rest.

        Args:
            predecessor_index: index of the predecessor the classification is requested for

        Returns:
            an EquivClassification, never marked as a source

        """
        tiles = self.input_tiles
        untiled = tiles[0][-1] == 1 and tiles[1][-1] == 1
        if not untiled:
            handler = LayerHandlerType.unsupported
        elif self.transpose_matmul_input:
            handler = LayerHandlerType.matmul_transpose
        elif predecessor_index == 1:
            handler = LayerHandlerType.matmul
        else:
            handler = LayerHandlerType.unsupported
        return EquivClassification(handler, is_source=False)

    def _verify_hn_to_keras_input_shapes(self, keras_shapes, hn_shapes):
        if len(keras_shapes) != len(hn_shapes):
            return False
        for index, (keras_shape, hn_shape) in enumerate(zip(keras_shapes, hn_shapes)):
            if not (np.array(keras_shape) == np.array(hn_shape)).all():
                if index == 1:
                    if (np.array(keras_shape) == np.array(hn_shape) + np.array([0, 0, 0, self.groups])).all():
                        continue
                return False
        return True

    @property
    def zp_correction_type(self):
        return self._zp_correction_type

    @zp_correction_type.setter
    def zp_correction_type(self, zp_correction_type: MatmulCorrectionType):
        """
        Store the correction type (coerced to MatmulCorrectionType) and derive the matmul op's zp_comp_rank.

        Rank 2 for the block-2 / block-3 types, rank 1 for block / weights / plain zp-comp, rank 0 otherwise.
        """
        correction = MatmulCorrectionType(zp_correction_type)
        self._zp_correction_type = correction
        # NOTE(review): ZP_COMP_BLOCK_3 maps to rank 2, same as ZP_COMP_BLOCK_2 - confirm this is intended.
        rank_two_types = (MatmulCorrectionType.ZP_COMP_BLOCK_2, MatmulCorrectionType.ZP_COMP_BLOCK_3)
        rank_one_types = (
            MatmulCorrectionType.ZP_COMP_BLOCK,
            MatmulCorrectionType.ZP_COMP_WEIGHTS,
            MatmulCorrectionType.ZP_COMP,
        )
        if correction in rank_two_types:
            rank = 2
        elif correction in rank_one_types:
            rank = 1
        else:
            rank = 0
        self.matmul_op.zp_comp_rank = rank

    def get_bias_mode(self):
        """Bias mode of the layer: always double_scale_initialization."""
        bias_mode = BiasMode.double_scale_initialization
        return bias_mode

    @property
    def zp_comp_added(self):
        return self.matmul_op._zp_comp_added

    @property
    def consumer_input_scale(self):
        """True only when the zp correction type is ZP_COMP_BLOCK_2 or ZP_COMP_BLOCK_3."""
        block_types = [MatmulCorrectionType.ZP_COMP_BLOCK_2, MatmulCorrectionType.ZP_COMP_BLOCK_3]
        return self.zp_correction_type in block_types

    @property
    def homogeneous(self):
        """
        Homogeneity of the layer: never homogeneous with the ZP_COMP_BLOCK_2 / ZP_COMP_BLOCK_3
        correction types, otherwise whatever the base class reports.
        """
        base_homogeneous = super().homogeneous
        block_types = [MatmulCorrectionType.ZP_COMP_BLOCK_2, MatmulCorrectionType.ZP_COMP_BLOCK_3]
        if self.zp_correction_type in block_types:
            return False
        return base_homogeneous

    @zp_comp_added.setter
    def zp_comp_added(self, zp):
        self.matmul_op._zp_comp_added = zp

    @property
    def transpose_matmul_input(self):
        return self.matmul_op._transpose_matmul_input

    @transpose_matmul_input.setter
    def transpose_matmul_input(self, transpose):
        self.matmul_op._transpose_matmul_input = transpose

    @property
    def groups(self):
        return self.matmul_op.groups
