from typing import Dict, Optional, Tuple, Union

import numpy as np

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer_decompose import BaseHailoLayerDecompose
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv import HailoConv
from hailo_model_optimization.acceleras.hailo_layers.hailo_reduce_sum import HailoReduceSum
from hailo_model_optimization.acceleras.hailo_layers.layer_decompose_flow import (
    LayerDecomposeFlow,
)
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerWeightsClippingConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    BiasMode,
    DataPath,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    OptimizationTarget,
    PaddingType,
    PrecisionMode,
    StrideAlignType,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasImplementationError,
    AccelerasInitializationError,
)
from hailo_model_optimization.acceleras.utils.hn_npz_utils import (
    LayerParams,
    get_hn_padding,
    set_hn_padding_stride_align,
)


class HailoConvQuantWeightGroup(BaseHailoLayerDecompose):
    """Conv layer decomposed to support group-wise weight quantization.

    When multiple quantization weight groups are used, this decomposition creates
    ``quantization_weight_groups * out_channels`` different scales to quantize the kernel.
    The scales are later absorbed into the APU scale.
    """

    # PrecisionMode and BiasMode are the same as BaseHailoConv
    _hn_type = LayerType.CONV
    # Only weight-group quantization is declared as supported by this layer.
    SUPPORTED_QUANTIZATION_GROUPS = False
    SUPPORTED_QUANTIZATION_WEIGHT_GROUPS = True
    # Bias modes are taken verbatim from HailoConv.
    SUPPORTED_BIAS_MODE = HailoConv.SUPPORTED_BIAS_MODE
    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a8_w8,
        PrecisionMode.a8_w4,
        PrecisionMode.a8_w4_a8,
        PrecisionMode.a8_w4_a16,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a8_w8_a16,
    }

    def __init__(
        self,
        name: str,
        filters: int,
        kernel_size: Tuple[int, int],
        strides=(1, 1),
        padding: Union[str, PaddingType] = "SAME",
        stride_align: Union[str, StrideAlignType] = "NW",
        dilation_rate: Tuple[int, int] = (1, 1),
        activation: Union[str, callable, ActivationType] = "linear",
        quantization_weight_groups: int = 4,
        is_symmetric: bool = True,
        logger=None,
        **kwargs,
    ):
        """Create the conv and reduce-sum sub-layers that make up this decomposed layer.

        Args:
            name: Layer name; sub-layers are named ``{name}/conv_sub_layer`` and
                ``{name}/reduce_sum_sub_layer``.
            filters: Number of output channels of the native (non-decomposed) conv.
            kernel_size: Spatial kernel size; only 1x1 is accepted.
            strides: Forwarded to the conv sub-layer.
            padding: Forwarded to the conv sub-layer.
            stride_align: Forwarded to the conv sub-layer.
            dilation_rate: Forwarded to the conv sub-layer.
            activation: Forwarded to the reduce-sum sub-layer.
            quantization_weight_groups: Number of weight quantization groups; used as the
                conv sub-layer's ``groups`` and ``feature_shuffle_interval`` and as the
                multiplier of its filter count.
            is_symmetric: Stored on the instance as ``is_symmetric``.
            logger: Forwarded to the sub-layers and to the base class.
            **kwargs: Forwarded to the base class constructor.

        Raises:
            AccelerasImplementationError: If ``kernel_size`` is not 1x1.
        """
        kernel_is_1x1 = kernel_size[0] == 1 and kernel_size[1] == 1
        if not kernel_is_1x1:
            raise AccelerasImplementationError("Group weight quantization only supports 1x1 kernel size")

        # Attributes are assigned before the base constructor runs, as in the original flow.
        self.filters = filters
        self.quantization_weight_groups: int = quantization_weight_groups
        self.is_symmetric = is_symmetric

        # The conv sub-layer produces one set of output channels per weight group.
        conv_name = f"{name}/conv_sub_layer"
        expanded_filters = filters * quantization_weight_groups
        self.conv_sub_layer = HailoConv(
            conv_name,
            filters=expanded_filters,
            kernel_size=kernel_size,
            groups=self.quantization_weight_groups,
            strides=strides,
            dilation_rate=dilation_rate,
            padding=padding,
            stride_align=stride_align,
            feature_shuffle_interval=quantization_weight_groups,
            set_scale_by_kernel_only=False,
            logger=logger,
        )

        # The reduce-sum sub-layer carries the activation of the whole layer.
        reduce_sum_name = f"{name}/reduce_sum_sub_layer"
        self.reduce_sum_sub_layer = HailoReduceSum(
            name=reduce_sum_name,
            groups=filters,
            activation=activation,
            logger=logger,
        )
        self.act_op = self.reduce_sum_sub_layer.activation_atomic_op

        super().__init__(name=name, logger=logger, **kwargs)

    def _build_flow(self) -> LayerDecomposeFlow:
        """Build the decomposition graph: input -> conv -> reduce sum -> output.

        The conv output is passed to the reduce sum over the ``DataPath.INTER_BLOCK_16`` data path
        (earlier documentation described this hand-off as a 15-bit activation —
        NOTE(review): confirm the effective bit width of INTER_BLOCK_16).
        """
        flow = self._init_flow()
        source = flow.add_input()
        sink = flow.add_output()
        edges = (
            (source, self.conv_sub_layer, DataPath.LAYER_IN),
            (self.conv_sub_layer, self.reduce_sum_sub_layer, DataPath.INTER_BLOCK_16),
            (self.reduce_sum_sub_layer, sink, DataPath.LAYER_OUT),
        )
        for src, dst, data_path in edges:
            flow.add_edge(src, dst, data_path)
        return flow

    def _create_weight_constraints(self, layer_precision_mode: Dict[str, int]) -> Dict[str, int]:
        """Fill the reduce_sum weights to always be 16 bits"""
        layer_precision_mode[self.reduce_sum_sub_layer.name] = 16
        return layer_precision_mode

    def _apply_precision_config_constraints(
        self, layer_precision_config: Dict[str, LayerPrecisionConfig]
    ) -> Dict[str, LayerPrecisionConfig]:
        """Set the quantization groups for the conv layer, it has to be equal to the new out_channel, so each one will get a different scale"""

        layer_precision_config[self.conv_sub_layer.name].quantization_groups = (
            self.quantization_weight_groups * self.filters
        )
        return layer_precision_config

    def _get_activation_layer(self):
        """The activation of the group weight quant conv is being performed at the reduce sum"""
        return self.reduce_sum_sub_layer

    @property
    def _training_layer(self):
        return self.conv_sub_layer

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """Create the hardware params of both sub-layers, propagating encodings from reduce sum back to conv.

        The reduce-sum input encoding is derived first and copied onto the conv output, then the
        conv hardware params are created, and finally the reduce-sum hardware params.

        Args:
            weights_clipping: Forwarded unchanged to both sub-layers.
            optimization_target: Forwarded unchanged to both sub-layers.
            hw_shifts: Forwarded unchanged to both sub-layers.
        """
        # enforce backwards propagation from reduce sum to conv
        self.reduce_sum_sub_layer.create_input_encoding_candidates()
        self._create_out_in_scale_ratio()  # factor*inputs_scale = output_scale
        self._reduce_sum_output_to_input()
        self._set_predecessor_output_encodings(
            self.conv_sub_layer, self.reduce_sum_sub_layer
        )  # copy the inputs_scale/zp to the conv output_scale/zp

        self.conv_sub_layer.create_hw_params(
            weights_clipping=weights_clipping, optimization_target=optimization_target, hw_shifts=hw_shifts
        )
        # When the helper reports a change, re-sync the reduce-sum inputs with the conv outputs.
        # NOTE(review): presumably the helper adjusts the conv output encoding — confirm in the base class.
        if self._handle_negative_exponent(self.conv_sub_layer):
            self._set_successors_inputs_encodings(self.conv_sub_layer, self.reduce_sum_sub_layer)
        self.reduce_sum_sub_layer.create_hw_params(
            weights_clipping=weights_clipping, optimization_target=optimization_target, hw_shifts=hw_shifts
        )

    def _update_out_in_scale_ratio(self, output_shift_fix):
        output_factor = 2**output_shift_fix
        self.scalar_factor_dof = self.scalar_factor_dof / output_factor

    # region encoding

    def _reduce_sum_output_to_input(self):
        _repeated_scales = np.repeat(
            self.reduce_sum_sub_layer.output_scale.reshape(-1, 1), self.quantization_weight_groups, axis=1
        ).flatten()
        scalar_factor = self.scalar_factor_dof
        scale_candidate = _repeated_scales / scalar_factor
        zp_candidate = -np.min(self.reduce_sum_sub_layer.get_input_limvals()) / np.max(scale_candidate)
        zp_candidate = np.round(zp_candidate)
        self.reduce_sum_sub_layer.set_input_scale(scale_candidate, index=0)
        self.reduce_sum_sub_layer.set_input_zero_point(zp_candidate, index=0)

    def _create_out_in_scale_ratio(self):
        input_range = np.max(self.reduce_sum_sub_layer.get_input_limvals()) - np.min(
            self.reduce_sum_sub_layer.get_input_limvals()
        )
        bins = self.reduce_sum_sub_layer.get_input_lossy_elements()[0].bins_count
        self.scalar_factor_dof = np.max(self.reduce_sum_sub_layer.output_scale) * bins / input_range

    def enforce_io_encoding(self, training=False, **kwargs):
        """
        Same as a conv, this layer represents a degree of freedom for encoding"""
        pass  # TODO should add force_range output

    def enforce_internal_encoding(self, training=False, **kwargs):
        """Propagate the reduce-sum output encoding backwards, then enforce both sub-layers' encodings.

        1) If the reduce-sum output scale is not a scalar (its shape is not ``()``), derive the
           reduce-sum input scale / zero point from it.
        2) Copy the reduce-sum input encoding onto the conv sub-layer output.
        3) Run the conv sub-layer's ``enforce_internal_encoding``.
        4) Run the reduce-sum sub-layer's ``enforce_internal_encoding``.

        Args:
            training: Forwarded to both sub-layers.
            **kwargs: Forwarded to both sub-layers.
        """
        # Skip the back-propagation while the output scale is still a scalar.
        if self.reduce_sum_sub_layer.output_scale.shape != ():
            self._reduce_sum_output_to_input()
        self._set_predecessor_output_encodings(self.conv_sub_layer, self.reduce_sum_sub_layer)

        self.conv_sub_layer.enforce_internal_encoding(training=training, **kwargs)
        self.reduce_sum_sub_layer.enforce_internal_encoding(training=training, **kwargs)

    # region import / export

    @classmethod
    def get_default_params(cls):
        # TODO
        defaults = {
            "strides": [1, 1, 1, 1],
            "dilations": [1, 1, 1, 1],
            "padding": "SAME",
            "activation": "linear",
            "groups": 1,
            "elementwise_add": False,
        }
        return dict(defaults)

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """Build the layer from an HN element.

        Default params are overlaid with the element's ``params``; the number of weight groups is
        read from the element's ``quantization_params``. The merged params are validated before
        the layer is constructed, and ``finalize_from_hn`` is called on the new layer.

        Args:
            lname: Name given to the created layer.
            hn_element: HN description of the layer (dict-like).
            logger: Forwarded to the layer constructor.

        Returns:
            The constructed layer.
        """
        quantization_params = hn_element.get("quantization_params", dict())
        params = cls.get_default_params()
        params.update(hn_element.get("params", dict()))
        cls._check_params_validity(params, quantization_params)

        kshape = params.get("kernel_shape", None)
        padding, stride_align = get_hn_padding(params)
        # is_lora = params.get("is_lora", False)  # TODO
        constructor_kwargs = {
            "name": lname,
            "quantization_weight_groups": quantization_params.get("quantization_weight_groups", None),
            "filters": kshape[-1],
            "kernel_size": kshape[0:2],
            "padding": padding,
            "stride_align": stride_align,
            # HN strides / dilations have four entries; only the middle two are passed on.
            "strides": params["strides"][1:3],
            "activation": params["activation"],
            "dilation_rate": params["dilations"][1:3],
            "logger": logger,
        }
        layer = cls(**constructor_kwargs)
        layer.finalize_from_hn(hn_element)
        return layer

    @classmethod
    def _check_params_validity(cls, params, quantization_params):
        """
        Validates the parameters for the layer.

        Args:
            params (dict): Dictionary containing layer parameters.
            quantization_params (dict): Dictionary containing quantization parameters.

        Raises:
            AccelerasInitializationError: If required parameters are missing or invalid.
        """
        if quantization_params.get("quantization_weight_groups", None) is None:
            raise AccelerasInitializationError("quantization_weight_groups quantization param is missing")
        kshape = params.get("kernel_shape", None)
        if kshape is None or kshape[0:2] != [1, 1]:
            raise AccelerasInitializationError("kernel must be 1x1")

    def to_hn(self, out_degree=None):
        """Export the layer as an HN element, describing it as a native (non-decomposed) conv.

        The base HN element is extended with conv params read from the sub-layers. When the base
        element has no ``params`` entry, one is created with ``batch_norm`` set to ``False``.
        """
        hn_element = super().to_hn(out_degree=out_degree)
        conv = self.conv_sub_layer
        inner_conv_op = conv.conv_op
        strides = self.conv_op.strides
        dilation_rate = inner_conv_op.dilation_rate

        params = {
            "activation": self.reduce_sum_sub_layer.act_op.act_name.value,
            "kernel_shape": list(self.export_weights()["kernel"].shape),
            "groups": self.groups,
            "group_sizes": [1],
            "transpose_output_width_features": conv.transpose_output_width_features,
            "spatial_flatten_output": inner_conv_op.spatial_flatten_output,
            "strides": [1, strides[0], strides[1], 1],
            "dilations": [1, dilation_rate[0], dilation_rate[1], 1],
        }
        set_hn_padding_stride_align(params, inner_conv_op.padding, inner_conv_op.stride_align)
        params["elementwise_add"] = False
        params["zp_comp_added"] = conv.zp_comp_add

        hn_element.setdefault("params", {"batch_norm": False}).update(params)
        return hn_element

    def _get_kernel_bits(self) -> Optional[int]:
        return self.conv_sub_layer._get_kernel_bits()

    def _export_weights(self):
        weights = self.conv_sub_layer._export_weights()
        weights.update(self.reduce_sum_sub_layer.act_op.export_weights())
        weights = self._revert_weights_modifications(weights)
        return weights

    def import_weights(self, layer_params: LayerParams):
        layer_params = self._modify_conv_weights(layer_params, self.quantization_weight_groups)
        self.conv_sub_layer.import_weights(layer_params)
        layer_params.pop("kernel", None)
        layer_params.pop("bias", None)
        self.reduce_sum_sub_layer.import_weights(layer_params)

    def revert_kernel_shape(self, kernel):
        quantization_weight_groups = self.quantization_weight_groups
        w, h, group_size, old_c_out = kernel.shape
        new_c_out = old_c_out // quantization_weight_groups

        new_kernel = (
            kernel.reshape(w, h, group_size, quantization_weight_groups, new_c_out)
            .transpose(0, 1, 3, 2, 4)
            .reshape(w, h, group_size * quantization_weight_groups, new_c_out)
        )
        return new_kernel

    def _revert_weights_modifications(self, layer_params: LayerParams):
        kernel = layer_params.get("kernel")
        bias = layer_params.get("bias")
        new_kernel = self.revert_kernel_shape(kernel)
        new_c_out = kernel.shape[-1] // self.quantization_weight_groups
        new_bias = bias[0:new_c_out]
        layer_params["kernel"] = new_kernel
        layer_params["bias"] = new_bias
        return layer_params

    def _modify_conv_weights(self, layer_params: LayerParams, quantization_weight_groups: int) -> LayerParams:
        kernel = layer_params.get("kernel")
        bias = layer_params.get("bias")

        w, h, c_in, c_out = kernel.shape
        if kernel.shape[2] % quantization_weight_groups != 0:
            raise AccelerasImplementationError(
                "Group weight quantization only supports kernel size that is a multiple of the number of quantization groups"
            )
        group_size = c_in // quantization_weight_groups
        new_kernel = (
            kernel.reshape(w, h, quantization_weight_groups, group_size, c_out)
            .transpose(0, 1, 3, 2, 4)
            .reshape(w, h, group_size, c_out * quantization_weight_groups)
        )

        new_bias = np.concatenate([bias] + [np.zeros_like(bias)] * (quantization_weight_groups - 1))
        layer_params["kernel"] = new_kernel
        layer_params["bias"] = new_bias

        return layer_params

    @property
    def pre_acc_shift(self):
        return self.conv_sub_layer.pre_acc_shift

    def get_bias_ops(self):
        return self.conv_sub_layer.get_bias_ops()

    @classmethod
    def get_default_bias_mode(cls):
        """Return the default bias mode of this layer: ``BiasMode.double_scale_initialization``."""
        default_mode = BiasMode.double_scale_initialization
        return default_mode

    @property
    def bias(self):
        return self.conv_sub_layer.bias

    @property
    def kernel(self):
        return self.conv_sub_layer.kernel

    @classmethod
    def from_conv(cls, conv_layer: HailoConv, quantization_weight_groups=16, is_symmetric=True, logger=None):
        """Create a weight-group-quantized layer from an existing conv layer.

        Geometry and activation are read from ``conv_layer`` and its ``conv_op``; the weights
        exported by ``conv_layer`` are then imported into the new layer.

        Args:
            conv_layer: Source conv layer.
            quantization_weight_groups: Number of weight quantization groups.
            is_symmetric: Forwarded to the constructor.
            logger: Forwarded to the constructor.

        Returns:
            The new layer, with the source layer's weights loaded.
        """
        source_op = conv_layer.conv_op
        kernel_shape = source_op.kernel.shape
        # TODO Add errors if quantization not supported
        constructor_kwargs = {
            "name": conv_layer.full_name,
            "filters": kernel_shape[-1],
            "kernel_size": kernel_shape[0:2],
            "strides": source_op.strides,
            "padding": source_op.padding,
            "stride_align": source_op.stride_align,
            "dilation_rate": source_op.dilation_rate,
            "activation": conv_layer.act_op.act_name,
            "quantization_weight_groups": quantization_weight_groups,
            "is_symmetric": is_symmetric,
            "logger": logger,
        }
        inst = cls(**constructor_kwargs)
        inst.import_weights(conv_layer.export_weights())
        return inst

    @classmethod
    def check_conv_validity():
        if False:
            raise AccelerasImplementationError("Group weight quantization only supports 1x1 kernel size")
        # TODO will update in a future PR

    # region algorithms
    def get_kernel_scale_matrix_component(self):
        return self.conv_sub_layer.get_kernel_scale_matrix_component()

    def get_kernel(self):
        # TODO consider making abstract in baseDecomposeLayer
        return self.conv_sub_layer.get_kernel()

    @property
    def groups(self):
        """We are using conv groups at the conv_sub_layer as part of the numeric scheme
        However, this is not a conv group layer, and as such, algorithms should not relate to it as such"""
        return 1

    @property
    def conv_op(self):
        return self.conv_sub_layer.conv_op

    @property
    def bias_add_op(self):
        return self.conv_sub_layer.bias_add_op

    def export_hw_params(self, include_shared_weights=True):
        """Export the hardware params, with the conv sub-layer kernel in its native shape.

        The compiler team asked for the kernel shape to match its HN shape, and the HN shape has
        to stay the native shape, since the implementation details should stay hidden.
        """
        hw_params = super().export_hw_params(include_shared_weights)
        kernel_key = "conv_sub_layer/kernel"
        hw_params[kernel_key] = self.revert_kernel_shape(hw_params[kernel_key])
        return hw_params

    def get_equalization_handler_type(self, predecessor_index=None):
        """Classify this layer for equalization as a consumer that is also a source.

        ``predecessor_index`` is accepted but not used.
        """
        classification = EquivClassification(LayerHandlerType.consumer, is_source=True)
        return classification

    def _return_scale_summary(self):
        """Useful for debugging"""
        ret = ""
        ret += "----------------\n"
        ret += "Conv Sub Layer:\n"
        ret += f"conv input scale: {self._get_scale_string(self.conv_sub_layer.input_scales)}\n"
        ret += f"conv input zp, len {len(self.conv_sub_layer.input_zero_points)}: {self.conv_sub_layer.input_zero_points}\n"
        ret += f"conv output scale: {self._get_scale_string(self.conv_sub_layer.output_scales)}\n"
        ret += f"conv output zp, len {len(self.conv_sub_layer.output_zero_points)}: {self.conv_sub_layer.output_zero_points}\n"
        ret += "----------------\n"
        ret += "Reduce Sum Layer\n"
        ret += f"reduce sum input scale: {self._get_scale_string(self.reduce_sum_sub_layer.input_scales)}"
        ret += f"reduce sum input zp, len {len(self.reduce_sum_sub_layer.input_zero_points)}: {self.reduce_sum_sub_layer.input_zero_points}\n"
        ret += f"reduce sum output scale: {self._get_scale_string(self.reduce_sum_sub_layer.output_scales)}\n"
        ret += f"reduce sum output zp, len {len(self.reduce_sum_sub_layer.output_zero_points)}: {self.reduce_sum_sub_layer.output_zero_points}\n"
        ret += "----------------\n"
        return ret

    def _get_scale_string(self, scale):
        ret = "\n"
        if isinstance(scale[0], (list, np.ndarray)):
            length_array = [len(s) if hasattr(s, "__len__") else 1 for s in scale]
            unique_array = [len(np.unique(s)) if hasattr(s, "__len__") else 1 for s in scale]
            ret += "length of scales: " + " ".join(str(v) for v in length_array) + "\n"
            ret += "number of unique scales: " + " ".join(str(v) for v in unique_array) + "\n"
        else:
            ret += "Single scale value: "
        ret += f"{scale}\n"
        return ret
