from typing import Union

import numpy as np

from hailo_model_optimization.acceleras.atomic_ops.activation_op import ActivationOp
from hailo_model_optimization.acceleras.atomic_ops.loot_square_op import LootSquareOp
from hailo_model_optimization.acceleras.atomic_ops.norm_final_op import NormFinalOp
from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.atomic_ops.reduce_mean_norm_op import ReduceMeanNormOp
from hailo_model_optimization.acceleras.atomic_ops.root_variance_op import RootvarOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.lossy_elements.quant_element import QuantElement
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerWeightsClippingConfig,
)
from hailo_model_optimization.acceleras.statistics.statistics_base import BasicTypeTuple, TypeStats
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    BiasMode,
    DataPath,
    EquivClassification,
    LayerHandlerType,
    OptimizationTarget,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams
from hailo_model_optimization.acceleras.utils.opt_utils import get_input_bits_by_precision_mode, limvals_to_zp_scale

# Reference notes (commented out, not used by the code): output bit width per internal stage.
# LAYER_INV1_bits = 16  # x2_mean out
# LAYER_INV2_bits = 10  # mu out
# LAYER_MULT1_bits = 8  # x out
# LAYER_MULT2_bits = 10  # mu out
# LAYER_MULT3_bits = 9  # inv_root out
# ACTIVATION_bits = 20  # pre_activation out


# Bit widths for the `mu` path. The values equal the "reduce sum" width and the
# "after clip" width used for the mu op in
# HailoLayerNormMercury.create_quant_element_custom_behavior.
MU_REDUCE_SUM_BITS = 20
MU_OUT_BITS = 10

# Bit widths for the `x2_mean` path; same pairing as above (reduce sum width,
# then width after clip).
X2_REDUCE_SUM_BITS = 26
X2_OUT_BITS = 16

# NOTE(review): not referenced anywhere in the visible part of this file.
# It equals the difference between each *_REDUCE_SUM_BITS and *_OUT_BITS pair
# above; whether that is its intended meaning needs confirmation.
REDUCE_SUM_BITS = 10


class HailoLayerNormMercury(BaseHailoLayer):
    """
    Layer-norm layer assembled from several atomic ops.

    The ops and their wiring (see ``_build_flow``):

    * ``x``         - ``PassthruOp`` on the layer input.
    * ``x2_mean``   - ``ReduceMeanNormOp(square=True)`` fed by ``x``.
    * ``mu``        - ``ReduceMeanNormOp(square=False)`` fed by ``x``.
    * ``mu_square`` - ``LootSquareOp`` fed by ``mu``.
    * ``inv_root``  - ``RootvarOp`` fed by ``x2_mean`` (input 0) and
      ``mu_square`` (input 1).
    * ``mult``      - ``NormFinalOp`` fed by ``x`` (input 0), ``mu`` (input 1)
      and ``inv_root`` (input 2).
    * ``act_op``    - ``ActivationOp`` fed by ``mult``.
    * ``output_op`` - ``PassthruOp`` producing the layer output.
    """

    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a8_w8,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a8_w8_a16,
    }
    SUPPORTED_BIAS_MODE = {
        BiasMode.single_scale_decomposition,
    }

    SUPPORTED_QUANTIZATION_GROUPS = False

    def __init__(
        self,
        name: str,
        logger=None,
        activation: Union[str, callable, ActivationType] = ActivationType.LINEAR,
        **kwargs,
    ):
        """
        Create the atomic ops and initialise the base layer.

        Args:
            name: layer name; also used as the prefix of every atomic op name.
            logger: logger forwarded to every atomic op and to the base class.
            activation: activation handed to the ``ActivationOp``.
            **kwargs: forwarded unchanged to ``BaseHailoLayer.__init__``.
        """
        # The ops are created before the base-class initialiser runs
        # (presumably because it builds the flow from them - TODO confirm).
        self.x = PassthruOp(f"{name}/passthru_op_in", logger=logger)
        self.x2_mean = ReduceMeanNormOp(f"{name}/x2_mean", square=True, logger=logger)
        self.mu = ReduceMeanNormOp(f"{name}/mu", square=False, logger=logger)
        self.mu_square = LootSquareOp(f"{name}/mu_square", logger=logger)
        self.inv_root = RootvarOp(f"{name}/inv_root", logger=logger)
        self.mult = NormFinalOp(f"{name}/norm_final", logger=logger)
        self.act_op = ActivationOp(f"{name}/act_op", activation, logger=logger)
        self.output_op = PassthruOp(f"{name}/passthru_op", logger=logger)
        super().__init__(name=name, logger=logger, **kwargs)

        # Shift passed to mu.create_hw_params / mult.create_hw_params.
        self._mult_shift = 2

    def _build_flow(self) -> LayerFlow:
        """Build the internal graph connecting the atomic ops (see class docstring)."""
        layer_flow = LayerFlow()
        in1 = layer_flow.add_input()
        out1 = layer_flow.add_output()

        layer_flow.add_node(self.x)
        layer_flow.add_node(self.mu)
        layer_flow.add_node(self.mu_square)
        layer_flow.add_node(self.x2_mean)
        layer_flow.add_node(self.inv_root)
        layer_flow.add_node(self.mult)
        layer_flow.add_node(self.act_op)
        layer_flow.add_node(self.output_op)

        layer_flow.add_edge(in1, self.x, DataPath.LAYER_IN)

        # The input fans out to both reductions.
        layer_flow.add_edge(self.x, self.x2_mean, DataPath.LAYER_IN)
        layer_flow.add_edge(self.x, self.mu, DataPath.LAYER_IN)

        layer_flow.add_edge(self.mu, self.mu_square, DataPath.LAYER_MU)

        # inv_root consumes x2_mean (index 0) and mu_square (index 1).
        layer_flow.add_edge(self.x2_mean, self.inv_root, DataPath.LAYER_IN_INV, input_index=0)
        layer_flow.add_edge(self.mu_square, self.inv_root, DataPath.LAYER_IN_INV, input_index=1)

        # The final op consumes x (index 0), mu (index 1) and inv_root (index 2).
        layer_flow.add_edge(self.x, self.mult, DataPath.LAYER_IN, input_index=0)
        layer_flow.add_edge(self.mu, self.mult, DataPath.LAYER_MU, input_index=1)
        layer_flow.add_edge(self.inv_root, self.mult, DataPath.LAYER_ROOT, input_index=2)

        layer_flow.add_edge(self.mult, self.act_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.act_op, self.output_op, DataPath.LAYER_OUT)
        layer_flow.add_edge(self.output_op, out1, DataPath.LAYER_OUT)
        return layer_flow

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        Alternate constructor: create the layer with default activation and
        finalise it from an HN element via ``finalize_from_hn``.
        """
        layer = cls(name=lname, logger=logger)
        layer.finalize_from_hn(hn_element)
        return layer

    def get_equalization_handler_type(self, predecessor_index=None):
        """Report this layer as unsupported for equalization (and not a source)."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    @property
    def homogeneous(self):
        """Always ``False`` for this layer."""
        return False

    @property
    def epsilon(self):
        """The ``epsilon`` attribute of the ``inv_root`` op."""
        return self.inv_root.epsilon

    def import_weights(self, layer_params: LayerParams):
        """Forward ``layer_params`` to the ``inv_root`` and ``act_op`` ops."""
        self.inv_root.import_weights(layer_params)
        self.act_op.import_weights(layer_params)

    def _export_weights(self):
        """
        Return the exported weights of ``inv_root`` updated with those of ``act_op``.

        On key collisions the ``act_op`` values win (``dict.update`` semantics).
        """
        dict_params = self.inv_root.export_weights()
        activation_params = self.act_op.export_weights()

        dict_params.update(activation_params)
        return dict_params

    def enforce_io_encoding(self, training=False, **kwargs):
        """Intentionally a no-op for this layer."""
        pass

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """
        Create the hardware parameters of every atomic op.

        Args:
            weights_clipping: not used in this implementation.
            optimization_target: forwarded to ``act_op.create_hw_params``.
            hw_shifts: not used in this implementation.
        """
        self.enforce_internal_encoding()

        scale_1 = self.x.input_scale[0]
        scale_2 = self.x2_mean.output_scale[0]
        scale_3 = self.inv_root.output_scale[0]

        zp_1 = self.x.input_zero_point

        # Largest |mu| in native units divided by the input scale.
        max_native_mu = np.abs(self.mu.get_output_limvals(0))
        expected_max_accumulator_mu = np.max(max_native_mu / scale_1)

        self.x2_mean.create_hw_params(scale_1**2, scale_2, shift=0)
        self.mu.create_hw_params(shift=self._mult_shift)

        self.mu_square.create_hw_params(scale_1, scale_2, zp_1, expected_max_accumulator_mu)

        self.inv_root.create_hw_params(scale_2, scale_3)
        self.mult.create_hw_params(shift=self._mult_shift)

        # Re-propagate encodings so act_op sees the mult output scale set above.
        self.enforce_internal_encoding()
        self.act_op.create_hw_params(self.mult.output_scale, optimization_target)

        self.enforce_internal_encoding()

    def _enforce_output_encoding(self):
        """
        Run ``output_op.backward_encoding()`` and copy the resulting input
        scale / zero point of ``output_op`` onto the output of ``act_op``.
        """
        self.output_op.backward_encoding()
        self.act_op.output_scale = self.output_op.input_scales[0]
        self.act_op.output_zero_point = self.output_op.input_zero_points[0]

    def enforce_internal_encoding(self, training=False, **kwargs):
        """
        Propagate scales and zero points along the internal flow.

        For each op, its inputs receive the output scale / zero point of the
        producing op, and then ``enforce_encoding()`` is called on it.
        ``training`` and ``**kwargs`` are accepted but not used here.
        """
        self._enforce_output_encoding()
        self.x.enforce_encoding()

        # ---- mu branch: x -> mu -> mu_square ----
        self.mu.input_scales[0] = self.x.output_scale
        self.mu.input_zero_points[0] = self.x.output_zero_point
        self.mu.enforce_encoding()

        self.mu_square.input_scales[0] = self.mu.output_scale
        self.mu_square.input_zero_points[0] = self.mu.output_zero_point
        self.mu_square.enforce_encoding()

        # ---- x2_mean branch: x -> x2_mean ----
        self.x2_mean.input_scales[0] = self.x.output_scale
        self.x2_mean.input_zero_points[0] = self.x.output_zero_point
        self.x2_mean.enforce_encoding()

        # ---- inv_root: inputs are x2_mean (0) and mu_square (1) ----
        self.inv_root.input_scales[0] = self.x2_mean.output_scale
        self.inv_root.input_zero_points[0] = self.x2_mean.output_zero_point

        self.inv_root.input_scales[1] = self.mu_square.output_scale
        self.inv_root.input_zero_points[1] = self.mu_square.output_zero_point
        self.inv_root.enforce_encoding()

        # ---- mult: inputs are x (0), mu (1) and inv_root (2) ----
        self.mult.input_scales[0] = self.x.output_scale
        self.mult.input_zero_points[0] = self.x.output_zero_point

        self.mult.input_scales[1] = self.mu.output_scale
        self.mult.input_zero_points[1] = self.mu.output_zero_point

        self.mult.input_scales[2] = self.inv_root.output_scale
        self.mult.input_zero_points[2] = self.inv_root.output_zero_point
        self.mult.enforce_encoding()

        # ---- activation: input is mult ----
        self.act_op.input_scales[0] = self.mult.output_scale
        self.act_op.input_zero_points[0] = self.mult.output_zero_point

        self.act_op.enforce_encoding()

    def fast_enforce_internal_encoding(self, **kwargs):
        """Intentionally a no-op for this layer."""
        pass

    def start_stats_collection(self, stats_cfg: tuple = BasicTypeTuple, output_hist=False, preact_hist=False):
        """
        Start statistics collection on the atomic ops.

        Args:
            stats_cfg: base tuple of statistic types to collect.
            output_hist: when true, ``TypeStats.DYNAMIC_HISTOGRAM`` is added
                for the ops returned by ``_output_stats_ops``.
            preact_hist: when true, ``TypeStats.DYNAMIC_HISTOGRAM`` is added
                for the inputs of the ops returned by ``_iterate_act_ops``.
        """
        act_stats_cfg_out = stats_cfg
        if output_hist:
            act_stats_cfg_out = (*stats_cfg, TypeStats.DYNAMIC_HISTOGRAM)

        act_stats_cfg_preact = stats_cfg
        if preact_hist:
            act_stats_cfg_preact = (*stats_cfg, TypeStats.DYNAMIC_HISTOGRAM)

        # The x2_mean output always gets a dynamic histogram.
        x2_mean_cfg_out = (*stats_cfg, TypeStats.DYNAMIC_HISTOGRAM)
        self.inv_root.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=False, collect_output=True)
        self.x2_mean.start_stats_collection(stats_cfg=x2_mean_cfg_out, collect_inputs=False, collect_output=True)
        self.mu.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=False, collect_output=True)
        self.mu_square.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=False, collect_output=True)

        self.mult.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=False, collect_output=True)

        # The index element of each (op, index) pair is not needed here.
        for op, _ in self._input_stats_ops():
            op.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=True, collect_output=False)
        for op, _ in self._output_stats_ops():
            op.start_stats_collection(stats_cfg=act_stats_cfg_out, collect_inputs=False, collect_output=True)
        for op in self._iterate_act_ops():
            op.start_stats_collection(stats_cfg=act_stats_cfg_preact, collect_inputs=True, collect_output=False)

    def create_quant_element_custom_behavior(
        self,
        precision_config: LayerPrecisionConfig,
        optimization_target: OptimizationTarget,
    ):
        """
        Create the quantization (lossy) elements of the internal data paths.

        Args:
            precision_config: layer precision configuration; only its
                ``precision_mode`` is read.
            optimization_target: forwarded to
                ``act_op.create_weight_quant_element``.
        """
        # The returned bit count is not used; the call is retained as in the
        # original (presumably it validates the precision mode - TODO confirm).
        _ = get_input_bits_by_precision_mode(precision_config.precision_mode)

        # NOTE(review): this 32-bit element is replaced by the 20-bit element
        # assigned further down; kept to preserve the original call sequence.
        output_quant_norm = QuantElement(signed=True, bits=32, wraparound=False)
        self.mult.set_output_lossy_element(output_quant_norm, index=0)
        self.act_op.set_input_lossy_element(output_quant_norm, index=0)

        self._mult_shift = 2
        self.mu_square._square_shift = self._mult_shift

        inv_root_out_bits = 9  # inv_root out

        # Bit widths come from the module-level constants instead of repeating
        # the same literals locally.
        self.create_quant_element_by_data_path(DataPath.LAYER_IN_INV, X2_OUT_BITS)
        self.create_quant_element_by_data_path(DataPath.LAYER_MU, MU_OUT_BITS)
        self.create_quant_element_by_data_path(DataPath.LAYER_ROOT, inv_root_out_bits)

        self.inv_root.create_weight_quant_element()
        self.mu.create_weight_quant_element(MU_REDUCE_SUM_BITS, MU_OUT_BITS)
        self.x2_mean.create_weight_quant_element(X2_REDUCE_SUM_BITS, X2_OUT_BITS)

        # Final element shared by the mult output and the activation input.
        output_quant_norm = QuantElement(signed=True, bits=20, wraparound=False)
        self.mult.set_output_lossy_element(output_quant_norm, index=0)
        self.act_op.set_input_lossy_element(output_quant_norm, index=0)

        self.act_op.create_weight_quant_element(optimization_target)
        self.mult.create_weight_quant_element()

    def create_output_encoding_candidates(self, forced_range=None, translation_config=None):
        """
        Create output encoding candidates for the layer and for the
        ``x2_mean`` / ``inv_root`` ops, then align scales via ``_scale_match``.
        """
        # NOTE(review): translation_config is not forwarded to the base class
        # here, unlike the per-op calls below - confirm this is intended.
        super().create_output_encoding_candidates(forced_range)
        self.x2_mean.create_output_encoding_candidates(0, translation_config=translation_config)
        # self.mu_square.create_output_encoding_candidates(0)
        self.inv_root.create_output_encoding_candidates(0, translation_config=translation_config)
        self._scale_match()

    def _scale_match(self):
        """
        Give ``x2_mean`` and ``mu_square`` the same output scale.

        The scale is derived from the union of both ops' output limvals
        (smallest minimum, largest maximum) using the output lossy element of
        ``x2_mean``.
        """
        output_limvals_x2_mean = self.x2_mean.get_output_limvals(0)
        mu2_limvals = self.mu_square.get_output_limvals(0)

        output_limvals = (
            np.min([output_limvals_x2_mean[0], mu2_limvals[0]]),
            np.max([output_limvals_x2_mean[1], mu2_limvals[1]]),
        )
        # Only the scale is needed; the other returned values are discarded.
        _, scale_2, _ = limvals_to_zp_scale(output_limvals, self.x2_mean.output_lossy_element, self.full_name)

        if scale_2.shape == ():
            # Scalar scale: expand it to one entry per output channel.
            # TODO: make sure this property behaves as expected
            output_channels = self.x2_mean.output_shape[-1]
            scale_2 = np.repeat(scale_2, output_channels)
        self.x2_mean.output_scale = scale_2
        self.mu_square.output_scale = scale_2
