from typing import Union

import numpy as np

from hailo_model_optimization.acceleras.atomic_ops.activation_op import ActivationOp
from hailo_model_optimization.acceleras.atomic_ops.bias_add_op import AddBiasOp
from hailo_model_optimization.acceleras.atomic_ops.finalize_norm_op import FinalizeNormOp
from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.atomic_ops.reduce_sum_ppu_op import ReduceSumPPUOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.lossy_elements.quant_element import APUOutputSignedQuantElement
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerWeightsClippingConfig,
)
from hailo_model_optimization.acceleras.statistics.statistics_base import BasicTypeTuple, TypeStats
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    BiasMode,
    DataPath,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    OptimizationTarget,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams


class HailoLayerNorm(BaseHailoLayer):
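    """
    Layer normalization (optionally RMS norm, see ``rms_norm``) of the Pluto PPU.

    The layer is composed of several atomic ops: an input pass-through bias op,
    two PPU reduce-sum ops (``x_sum`` and ``x2_sum``, the latter created with
    ``square=True``), a finalize-norm op, an activation op and an output
    pass-through op.

    NOTE(review): this docstring used to describe the layer as "currently
    degenerate, implemented as fully native activation", with a TODO to simulate
    the PPU- (and/or the upcoming core-) implementation using multiple AtomicOps.
    The flow already uses multiple AtomicOps; confirm whether that TODO is still open.
    """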

    # Precision modes declared as supported by this layer.
    # Each member is listed exactly once (a set literal silently drops repeats).
    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a16_w16,
        PrecisionMode.a16_w16_a8,
        PrecisionMode.a16_w16_a16,
        PrecisionMode.a8_w8,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a8_w8_a16,
    }
    # Bias modes declared as supported by this layer.
    SUPPORTED_BIAS_MODE = {
        BiasMode.double_scale_initialization,
        BiasMode.double_scale_decomposition,
        BiasMode.single_scale_decomposition,
    }

    SUPPORTED_QUANTIZATION_GROUPS = False
    # HN layer type this class corresponds to.
    _hn_type = LayerType.LAYER_NORM

    def __init__(
        self,
        name: str,
        rms_norm=False,
        logger=None,
        activation: Union[str, callable, ActivationType] = ActivationType.LINEAR,
        **kwargs,
    ):
        """
        Create the atomic ops of the layer, then initialize the base layer.

        Args:
            name: layer name; every op is named ``<name>/<op suffix>``.
            rms_norm: forwarded to the ``x_sum`` reduce-sum op.
            logger: logger handed to every op and to the base class.
            activation: activation passed to the activation op.
            **kwargs: forwarded unchanged to the base class initializer.

        """

        def op_name(suffix):
            return f"{name}/{suffix}"

        # The ops are assigned before the base initializer is invoked.
        self.input_op = AddBiasOp.get_passthru_bias(op_name("input_op"), logger=logger)
        self.x_sum = ReduceSumPPUOp(op_name("x_sum"), rms_norm=rms_norm, square=False, logger=logger)
        self.x2_sum = ReduceSumPPUOp(op_name("x2_sum"), square=True, logger=logger)
        self.norm_op = FinalizeNormOp(op_name("norm_op"), logger=logger)
        self.act_op = ActivationOp(op_name("act_op"), activation, logger=logger)
        self.output_op = PassthruOp(op_name("passthru_op"), logger=logger)
        super().__init__(name=name, logger=logger, **kwargs)

    @classmethod
    def get_default_precision_mode(cls):
        """Return the default precision mode of this layer type (``a16_w16``)."""
        default_mode = PrecisionMode.a16_w16
        return default_mode

    @property
    def rms_norm(self):
        """The ``rms_norm`` flag, as held by the ``x_sum`` op."""
        reduce_op = self.x_sum
        return reduce_op.rms_norm

    def _build_flow(self) -> LayerFlow:
        """
        Build the internal op graph of the layer.

        Wiring: layer input -> ``input_op``; ``input_op`` feeds ``x_sum``,
        ``x2_sum`` and ``norm_op`` (input 0); ``x_sum`` and ``x2_sum`` feed
        ``norm_op`` (inputs 1 and 2); ``norm_op`` -> ``act_op`` ->
        ``output_op`` -> layer output.

        Returns:
            The populated ``LayerFlow``.

        """
        flow = LayerFlow()
        flow_in = flow.add_input()
        flow_out = flow.add_output()

        all_ops = (self.input_op, self.x_sum, self.x2_sum, self.norm_op, self.act_op, self.output_op)
        for op in all_ops:
            flow.add_node(op)

        flow.add_edge(flow_in, self.input_op, DataPath.LAYER_IN)

        # input_op fans out on the same data path to both reductions and to norm_op.
        for reduce_op in (self.x_sum, self.x2_sum):
            flow.add_edge(self.input_op, reduce_op, DataPath.LAYER_IN_WEIGHTS_16)
        flow.add_edge(self.input_op, self.norm_op, DataPath.LAYER_IN_WEIGHTS_16, input_index=0)

        # The two reduction results are the second and third inputs of norm_op.
        flow.add_edge(self.x_sum, self.norm_op, DataPath.LAYER_X_SUM, input_index=1)
        flow.add_edge(self.x2_sum, self.norm_op, DataPath.LAYER_X2_SUM, input_index=2)

        flow.add_edge(self.norm_op, self.act_op, DataPath.ACCUMULATOR)
        flow.add_edge(self.act_op, self.output_op, DataPath.LAYER_OUT)
        flow.add_edge(self.output_op, flow_out, DataPath.LAYER_OUT)
        return flow

    @property
    def epsilon(self):
        """The ``epsilon`` value, as held by ``norm_op``."""
        finalize_op = self.norm_op
        return finalize_op.epsilon

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        Build a layer from its HN element.

        Args:
            lname: name given to the new layer.
            hn_element: mapping describing the layer; ``params.rms_norm`` is read
                from it (``False`` when either key is missing).
            logger: logger handed to the new layer.

        Returns:
            The new layer, after ``finalize_from_hn(hn_element)`` was called on it.

        """
        hn_params = hn_element.get("params", {})
        use_rms_norm = hn_params.get("rms_norm", False)
        new_layer = cls(name=lname, rms_norm=use_rms_norm, logger=logger)
        new_layer.finalize_from_hn(hn_element)
        return new_layer

    def _get_precision_mode_supported_in_hw(self, arch):
        """
        Return the set of precision modes supported for ``arch``.

        For ``OptimizationTarget.PLUTO`` only the three ``a16_w16*`` modes are
        returned; for any other target the three ``a8_w8*`` modes are included too.
        """
        a16_w16_modes = {PrecisionMode.a16_w16, PrecisionMode.a16_w16_a8, PrecisionMode.a16_w16_a16}
        if arch in {OptimizationTarget.PLUTO}:
            return a16_w16_modes

        a8_w8_modes = {PrecisionMode.a8_w8, PrecisionMode.a8_w8_a8, PrecisionMode.a8_w8_a16}
        return a16_w16_modes | a8_w8_modes

    def get_equalization_handler_type(self, predecessor_index=None):
        """
        Classify this layer for equalization: always ``unsupported`` and not a source.

        ``predecessor_index`` is accepted but not used.
        """
        handler_type = LayerHandlerType.unsupported
        return EquivClassification(handler_type, is_source=False)

    def get_quarot_handler_type(self, predecessor_index=None):
        """
        Classify this layer for QuaRot handling.

        The layer is ``transparent`` when ``rms_norm`` is set and ``unsupported``
        otherwise; it is never a source. ``predecessor_index`` is not used.
        """
        handler_type = LayerHandlerType.transparent if self.rms_norm else LayerHandlerType.unsupported
        return EquivClassification(handler_type, is_source=False)

    @property
    def homogeneous(self):
        """Always ``False`` for this layer."""
        is_homogeneous = False
        return is_homogeneous

    def import_weights(self, layer_params: LayerParams):
        """Load ``layer_params`` into ``norm_op`` and then into ``act_op``."""
        for weighted_op in (self.norm_op, self.act_op):
            weighted_op.import_weights(layer_params)

    def _export_weights(self):
        """
        Collect the exported weights of ``norm_op`` and ``act_op``.

        Returns:
            The mapping exported by ``norm_op``, updated in place with the
            entries exported by ``act_op`` (activation entries win on key clashes).

        """
        exported = self.norm_op.export_weights()
        exported.update(self.act_op.export_weights())
        return exported

    def enforce_io_encoding(self, training=False, **kwargs):
        """No-op for this layer; all arguments are ignored."""
        return None

    def _verify_and_set_io_shapes(self):
        """No-op for this layer: returns immediately (see the TODO below)."""
        # TODO: there is a bug in the broadcast of const_data
        # https://hailotech.atlassian.net/browse/SDK-39317
        return None

    def _enforce_output_encoding(self):
        """
        Propagate the output encoding backwards into ``act_op``.

        Runs ``output_op.backward_encoding()`` and then copies the first input
        scale and zero point of ``output_op`` onto the output of ``act_op``.
        """
        tail_op = self.output_op
        tail_op.backward_encoding()

        activation = self.act_op
        activation.output_scale = tail_op.input_scales[0]
        activation.output_zero_point = tail_op.input_zero_points[0]

    def _enforce_input_encoding(self):
        """
        Propagate the input encoding forward from ``input_op``.

        ``input_op`` gets its output scale from its own first input scale and
        enforces its encoding. Its output scale and zero point are then copied to
        input 0 of ``norm_op``, ``x_sum`` and ``x2_sum``; each reduce-sum op
        enforces its encoding right after being updated.
        """
        head_op = self.input_op
        head_op.output_scale = head_op.input_scales[0]
        head_op.enforce_encoding()

        self.norm_op.input_scales[0] = head_op.output_scale
        self.norm_op.input_zero_points[0] = head_op.output_zero_point

        # Attributes of input_op are re-read for every consumer, x_sum first.
        for reduce_op in (self.x_sum, self.x2_sum):
            reduce_op.input_scales[0] = head_op.output_scale
            reduce_op.input_zero_points[0] = head_op.output_zero_point
            reduce_op.enforce_encoding()

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """
        Create the hardware params of every op, re-enforcing encodings in between.

        Args:
            weights_clipping: not used by this implementation.
            optimization_target: forwarded to ``act_op.create_hw_params``.
            hw_shifts: not used by this implementation.

        """
        self.enforce_internal_encoding()

        self.input_op.create_hw_params()

        self.x_sum.create_hw_params()
        self.x_sum.enforce_encoding()

        # x2_sum is given a forced shift: 0 for RMS norm, otherwise twice the shift configured on x_sum.
        forced_shift = 0 if self.rms_norm else 2 * self.x_sum._shift_cfg
        self.x2_sum.create_hw_params(force_shift=forced_shift)
        self.x2_sum.enforce_encoding()
        self.enforce_internal_encoding()

        self.norm_op.create_hw_params(rms_norm=self.rms_norm)
        self.norm_op.enforce_encoding()
        self.enforce_internal_encoding()

        self.act_op.create_hw_params(self.norm_op.output_scale, optimization_target)
        self.act_op.enforce_encoding()
        self.enforce_internal_encoding()

    def enforce_internal_encoding(self, training=False, **kwargs):
        """
        Make the encodings of all internal ops consistent with each other.

        After the input-side and output-side passes, the outputs of ``x_sum`` and
        ``x2_sum`` are copied to inputs 1 and 2 of ``norm_op``, and the output of
        ``norm_op`` to the input of ``act_op``.

        Args:
            training: forwarded to ``act_op.enforce_encoding``.
            **kwargs: ignored.

        """
        self._enforce_input_encoding()
        self._enforce_output_encoding()

        finalize_op = self.norm_op
        for norm_input, reduce_op in enumerate((self.x_sum, self.x2_sum), start=1):
            finalize_op.input_scales[norm_input] = reduce_op.output_scale
            finalize_op.input_zero_points[norm_input] = reduce_op.output_zero_point
        finalize_op.enforce_encoding()

        activation = self.act_op
        activation.input_scales[0] = finalize_op.output_scale
        activation.input_zero_points[0] = finalize_op.output_zero_point
        activation.enforce_encoding(training=training)

    def fast_enforce_internal_encoding(self, **kwargs):
        """No-op for this layer; ``kwargs`` are ignored."""
        return None

    def get_bias_mode(self):
        """Return the bias mode of the layer, which is always the default bias mode."""
        bias_mode = self.get_default_bias_mode()
        return bias_mode

    def create_quant_element_custom_behavior(
        self,
        precision_config: LayerPrecisionConfig,
        optimization_target: OptimizationTarget,
    ):
        """
        Create the quantization (lossy) elements of the layer with fixed bit widths.

        Args:
            precision_config: not used by this implementation.
            optimization_target: forwarded to ``act_op.create_weight_quant_element``.

        """
        # Bit widths used below; the "clipped" values are the ones applied to the data paths.
        norm_weight_bits = 16
        x_sum_bits, x_sum_clipped_bits = 28, 20
        x2_sum_bits, x2_sum_clipped_bits = 56, 56

        self.input_op.create_weight_quant_element(16, True, 0, 16)

        data_path_bits = (
            (DataPath.LAYER_IN_WEIGHTS_16, 16),
            (DataPath.LAYER_X_SUM, x_sum_clipped_bits),
            (DataPath.LAYER_X2_SUM, x2_sum_clipped_bits),
        )
        for data_path, bits in data_path_bits:
            self.create_quant_element_by_data_path(data_path, bits)

        self.x_sum.create_weight_quant_element(x_sum_bits, x_sum_clipped_bits)
        self.x2_sum.create_weight_quant_element(x2_sum_bits, x2_sum_clipped_bits)

        # One shared 32-bit element sits between norm_op's output and act_op's input.
        shared_pre_act = APUOutputSignedQuantElement(bits=32)
        self.norm_op.set_output_lossy_element(shared_pre_act, index=0)
        self.act_op.set_input_lossy_element(shared_pre_act, index=0)

        self.norm_op.create_weight_quant_element(norm_weight_bits)
        self.act_op.create_weight_quant_element(optimization_target)

        # act_op adopts the numpy float/int types of norm_op.
        self.act_op.FLOAT_TYPE_NP = self.norm_op.FLOAT_TYPE_NP
        self.act_op.INT_TYPE_NP = self.norm_op.INT_TYPE_NP

    def start_stats_collection(self, stats_cfg: tuple = BasicTypeTuple, output_hist=False, preact_hist=False):
        """
        Start statistics collection on the ops of the layer.

        Args:
            stats_cfg: statistic types to collect.
            output_hist: when set, ``TypeStats.DYNAMIC_HISTOGRAM`` is added for the
                output-stats ops.
            preact_hist: when set, ``TypeStats.DYNAMIC_HISTOGRAM`` is added for the
                inputs of the activation ops.

        """
        output_cfg = (*stats_cfg, TypeStats.DYNAMIC_HISTOGRAM) if output_hist else stats_cfg
        preact_cfg = (*stats_cfg, TypeStats.DYNAMIC_HISTOGRAM) if preact_hist else stats_cfg

        # Only the outputs of the two reductions are collected (x2_sum first).
        for reduce_op in (self.x2_sum, self.x_sum):
            reduce_op.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=False, collect_output=True)

        self.norm_op.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=True, collect_output=True)

        for in_op, _ in self._input_stats_ops():
            in_op.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=True, collect_output=False)
        for out_op, _ in self._output_stats_ops():
            out_op.start_stats_collection(stats_cfg=output_cfg, collect_inputs=False, collect_output=True)
        for activation in self._iterate_act_ops():
            activation.start_stats_collection(stats_cfg=preact_cfg, collect_inputs=True, collect_output=False)

    def _layer_dependent_hw_params_modifications(self, params):
        """
        Add the layer-norm specific entries to ``params`` and return it.

        Entries written: ``vector_size`` (``x2_sum.f_out`` as uint16),
        ``rms_mode_enable`` (``rms_norm`` as uint8) and ``activation_ebias_mode``
        (uint8, set when the precision mode is ``a16_w16_a16``).
        """
        uses_ebias = self.get_precision_mode() == PrecisionMode.a16_w16_a16
        params["vector_size"] = np.array(self.x2_sum.f_out, dtype=np.uint16)  # in x_2_sum
        params["rms_mode_enable"] = np.array(self.rms_norm, dtype=np.uint8)
        params["activation_ebias_mode"] = np.array(uses_ebias, dtype=np.uint8)
        return params
