from typing import Tuple, Union

import numpy as np

from hailo_model_optimization.acceleras.atomic_ops.activation_op import ActivationOp
from hailo_model_optimization.acceleras.atomic_ops.exp_lut_op import ExpLutOp
from hailo_model_optimization.acceleras.atomic_ops.finalize_softmax_op import FinalizeSoftmaxOp
from hailo_model_optimization.acceleras.atomic_ops.generic_native_op import GenericNativeOp
from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.atomic_ops.reduce_max_op import ReduceMaxOp
from hailo_model_optimization.acceleras.atomic_ops.reduce_sum_ppu_op import ReduceSumPPUOp
from hailo_model_optimization.acceleras.atomic_ops.shift_op import ShiftOp
from hailo_model_optimization.acceleras.atomic_ops.softmax_inp_op import SoftmaxInpOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.lossy_elements.quant_element import APUOutputSignedQuantElement
from hailo_model_optimization.acceleras.model_optimization_config.mo_config import ModelOptimizationLayerConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerPrecisionConfig
from hailo_model_optimization.acceleras.statistics.statistics_base import BasicTypeTuple, TypeStats
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ACTIVATION_CLIP_BITS_HAILO_LAYER_NORM,
    EXP_NUME_BITS,
    EXP_OUT_BITS,
    LUT_IN_BITS,
    ActivationType,
    BiasMode,
    DataPath,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    NativeName,
    OptimizationTarget,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams
from hailo_model_optimization.acceleras.utils.opt_utils import limvals_to_zp_scale


class HailoSoftmaxMars(BaseHailoLayer):
    """
    Softmax on the Mars PPU.

    The layer is decomposed into atomic ops that are wired together in ``_build_flow``:

    * ``input_op`` (``SoftmaxInpOp``) - receives the layer input(s).
    * ``reduce_max_op`` - per-group max of the input over ``axis``.
    * ``ew_sub_op`` - element-wise subtraction combining the reduced max (input 0) with the
      layer input (input 1).
    * ``exp_denominator`` (``ExpLutOp``) - exp of the subtraction result.
    * ``shift_exp_numerator`` - shifted copy of the exp, used as the softmax numerator.
    * ``e_x_sum`` - per-group sum of the exp, used as the softmax denominator.
    * ``softmax_op`` - combines numerator and denominator.
    * ``act_op`` / ``output_op`` - activation and output passthrough.
    """

    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a16_w16,
        PrecisionMode.a16_w16_a8,
        PrecisionMode.a16_w16_a16,
        PrecisionMode.a8_w8,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a8_w8_a16,
    }
    SUPPORTED_BIAS_MODE = {
        BiasMode.double_scale_initialization,
        BiasMode.double_scale_decomposition,
        BiasMode.single_scale_decomposition,
    }

    SUPPORTED_QUANTIZATION_GROUPS = False
    _hn_type = LayerType.SOFTMAX

    def __init__(
        self,
        name: str,
        num_inputs: int,
        axis=-1,
        groups=1,
        logger=None,
        input_repeats=None,
        activation: Union[str, callable, ActivationType] = ActivationType.LINEAR,
        **kwargs,
    ):
        """
        Create all atomic ops of the softmax layer.

        Args:
            name: layer name; used as a prefix for every atomic op name.
            num_inputs: number of inputs of ``input_op``.
            axis: reduction axis (or axes) for ``reduce_max_op`` and ``e_x_sum``.
            groups: number of independent softmax groups.
            logger: optional logger forwarded to every op.
            input_repeats: forwarded to ``ew_sub_op`` as its ``input_repeats`` config param.
            activation: activation applied by ``act_op`` after the softmax.
            **kwargs: forwarded to ``BaseHailoLayer.__init__``.
        """
        self.groups = groups
        self.input_op = SoftmaxInpOp(f"{name}/input_op", num_inputs=num_inputs, logger=logger)

        self.reduce_max_op = ReduceMaxOp(f"{name}/x_max", groups=groups, reduce_axes=axis, logger=logger)

        self.ew_sub_op = GenericNativeOp(
            f"{name}/elementwise_sub_op",
            NativeName.EW_SUB,
            num_inputs=2,
            config_params={"input_repeats": input_repeats},
            logger=logger,
        )

        # Exp that feeds the denominator (and, through a shift, the numerator).
        self.exp_denominator = ExpLutOp(f"{name}/exp_denominator", logger=logger)

        self.e_x_sum = ReduceSumPPUOp(
            f"{name}/e_x_sum",
            is_softmax=True,
            groups=groups,
            reduce_axes=axis,
            rms_norm=False,
            square=False,
            logger=logger,
        )

        # Exp that goes into the numerator.
        self.shift_exp_numerator = ShiftOp(f"{name}/exp_numerator", logger=logger)

        self.softmax_op = FinalizeSoftmaxOp(f"{name}/softmax_op", logger=logger)
        self.act_op = ActivationOp(f"{name}/act_op", activation, logger=logger)
        self.output_op = PassthruOp(f"{name}/passthru_op", logger=logger)
        super().__init__(name=name, logger=logger, **kwargs)

    @classmethod
    def get_default_precision_mode(cls):
        """Return the default precision mode of this layer (``a16_w16``)."""
        return PrecisionMode.a16_w16

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        Build the layer from its HN description.

        ``params["logits_axis"]`` defaults to ``[-1]`` and is always normalised to a list.
        ``params["groups"]`` defaults to 1. The ew_sub input repeats are derived from the last
        dimension of the first input shape divided by ``groups``.
        """
        params = hn_element["params"]
        axis = params.get("logits_axis", [-1])
        if not isinstance(axis, list):
            axis = [axis]
        groups = params.get("groups", 1)
        # Presumably repeats the per-group max (ew_sub input 0) across the channels of its group.
        number_of_repeats = hn_element["input_shapes"][0][-1] // groups
        input_repeats = [[1, 1, number_of_repeats], [1, 1, 1]]
        layer = cls(
            name=lname,
            num_inputs=len(hn_element["input_shapes"]),
            axis=axis,
            groups=groups,
            input_repeats=input_repeats,
            logger=logger,
        )
        layer.finalize_from_hn(hn_element)
        return layer

    def _build_flow(self) -> LayerFlow:
        """Wire the atomic ops into the layer's data-flow graph."""
        layer_flow = LayerFlow()
        in_nodes = [layer_flow.add_input() for _ in range(self.input_op.num_inputs)]
        out1 = layer_flow.add_output()

        layer_flow.add_node(self.input_op)
        for index, node in enumerate(in_nodes):
            layer_flow.add_edge(node, self.input_op, DataPath.LAYER_IN, input_index=index)

        layer_flow.add_node(self.reduce_max_op)
        layer_flow.add_node(self.ew_sub_op)

        layer_flow.add_node(self.exp_denominator)
        layer_flow.add_node(self.shift_exp_numerator)

        layer_flow.add_node(self.e_x_sum)

        layer_flow.add_node(self.softmax_op)
        layer_flow.add_node(self.act_op)
        layer_flow.add_node(self.output_op)

        layer_flow.add_edge(self.input_op, self.reduce_max_op, DataPath.LAYER_IN_MASK)
        # ew_sub inputs: 0 <- reduced max, 1 <- layer input.
        layer_flow.add_edge(self.reduce_max_op, self.ew_sub_op, DataPath.LAYER_IN, input_index=0)
        layer_flow.add_edge(self.input_op, self.ew_sub_op, DataPath.LAYER_IN_MASK, input_index=1)

        layer_flow.add_edge(self.ew_sub_op, self.exp_denominator, DataPath.LAYER_IN_MASK)
        layer_flow.add_edge(self.exp_denominator, self.shift_exp_numerator, DataPath.EXP_DENO)
        layer_flow.add_edge(self.exp_denominator, self.e_x_sum, DataPath.EXP_DENO)

        layer_flow.add_edge(self.shift_exp_numerator, self.softmax_op, DataPath.EXP_NUME, input_index=0)
        # LAYER_E_X_SUM is used because this edge carries the exp sum.
        layer_flow.add_edge(self.e_x_sum, self.softmax_op, DataPath.LAYER_E_X_SUM, input_index=1)

        layer_flow.add_edge(self.softmax_op, self.act_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.act_op, self.output_op, DataPath.LAYER_OUT)
        layer_flow.add_edge(self.output_op, out1, DataPath.LAYER_OUT)
        return layer_flow

    def _get_precision_mode_supported_in_hw(self, arch):
        """Return the precision modes supported by the HW for ``arch`` (PLUTO/MARS are 16-bit only)."""
        if arch in {OptimizationTarget.PLUTO, OptimizationTarget.MARS}:
            return {PrecisionMode.a16_w16, PrecisionMode.a16_w16_a8, PrecisionMode.a16_w16_a16}
        return {
            PrecisionMode.a8_w8,
            PrecisionMode.a16_w16_a8,
            PrecisionMode.a16_w16,
            PrecisionMode.a8_w8_a8,
            PrecisionMode.a8_w8_a16,
            PrecisionMode.a16_w16_a16,
        }

    def get_equalization_handler_type(self, predecessor_index=None):
        """Equalization is not supported through this layer."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_quarot_handler_type(self, predecessor_index=None):
        """QuaRot is not supported through this layer."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    @property
    def homogeneous(self):
        """This layer is never homogeneous."""
        return False

    def import_weights(self, layer_params: LayerParams):
        """Import weights into ``softmax_op`` and ``act_op``."""
        self.softmax_op.import_weights(layer_params)
        self.act_op.import_weights(layer_params)

    def _export_weights(self):
        """Return the merged exported weights of ``softmax_op`` and ``act_op``."""
        dict_params = self.softmax_op.export_weights()
        dict_params.update(self.act_op.export_weights())
        return dict_params

    def enforce_io_encoding(self, training=False, **kwargs):
        """No-op: IO encoding is handled inside ``enforce_internal_encoding``."""
        pass

    def _verify_and_set_io_shapes(self):
        # TODO: there is a bug in the broadcast of const_data
        # https://hailotech.atlassian.net/browse/SDK-39317
        return

    def _enforce_output_encoding(self):
        """Propagate the output op's encoding backwards into ``act_op``'s output."""
        self.output_op.backward_encoding()
        self.act_op.output_scale = self.output_op.input_scales[0]
        self.act_op.output_zero_point = self.output_op.input_zero_points[0]

    def _enforce_exp_encoding(self):
        """Currently identical to ``_enforce_output_encoding``."""
        # NOTE(review): this body duplicates _enforce_output_encoding; confirm whether an
        # exp-specific encoding step was intended here.
        self.output_op.backward_encoding()
        self.act_op.output_scale = self.output_op.input_scales[0]
        self.act_op.output_zero_point = self.output_op.input_zero_points[0]

    def _enforce_input_encoding(self):
        """Propagate the input op's encoding into ``reduce_max_op`` and ``ew_sub_op`` input 1."""
        self.input_op.output_scale = self.input_op.input_scales[0]
        self.input_op.enforce_encoding()

        self.reduce_max_op.input_scales[0] = self.input_op.output_scale
        self.reduce_max_op.input_zero_points[0] = self.input_op.output_zero_point

        self.ew_sub_op.input_scales[1] = self.input_op.output_scale
        self.ew_sub_op.input_zero_points[1] = self.input_op.output_zero_point

        self.reduce_max_op.enforce_encoding()

    def create_hw_params(
        self, layer_cfg: ModelOptimizationLayerConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """
        Create HW params for every atomic op.

        ``layer_cfg`` and ``hw_shifts`` are accepted for interface compatibility and are unused here.
        """
        self.enable_lossy()
        # Shift from the EXP_DENO bit width (EXP_OUT_BITS) down to the EXP_NUME width (EXP_NUME_BITS).
        shift = EXP_OUT_BITS - EXP_NUME_BITS

        for op in self.atomic_ops:
            # NOTE(review): ops are constructed with names like f"{name}/act_op"; confirm that
            # `op.name` compares equal to the bare suffix, otherwise these branches never match.
            if op.name == "act_op":
                op.create_hw_params(self.softmax_op.output_scale, optimization_target)
            elif op.name == "exp_numerator":
                op.create_hw_params(shift=shift)
            else:
                op.create_hw_params()
            # Encodings are re-propagated after each op so later ops see the updated scales.
            self.enforce_internal_encoding()

    def enforce_internal_encoding(self, training=False, **kwargs):
        """Propagate scales/zero-points through the op chain from input to activation."""
        self._enforce_input_encoding()
        self._enforce_output_encoding()

        # ew_sub input 0 <- reduce_max output.
        self.ew_sub_op.input_scales[0] = self.reduce_max_op.output_scale
        self.ew_sub_op.input_zero_points[0] = self.reduce_max_op.output_zero_point

        # NOTE(review): the ew_sub output scale is taken from its input 1 instead of calling
        # ew_sub_op.enforce_encoding(); confirm this is intended.
        self.ew_sub_op.output_scale = self.ew_sub_op.input_scales[1]

        # exp_denominator <- ew_sub output.
        self.exp_denominator.input_scales[0] = self.ew_sub_op.output_scale
        self.exp_denominator.input_zero_points[0] = self.ew_sub_op.output_zero_point
        self.exp_denominator.enforce_encoding()

        # shift_exp_numerator <- exp_denominator output.
        self.shift_exp_numerator.input_scales[0] = self.exp_denominator.output_scale
        self.shift_exp_numerator.input_zero_points[0] = self.exp_denominator.output_zero_point
        self.shift_exp_numerator.enforce_encoding()

        # e_x_sum <- exp_denominator output.
        self.e_x_sum.input_scales[0] = self.exp_denominator.output_scale
        self.e_x_sum.input_zero_points[0] = self.exp_denominator.output_zero_point
        self.e_x_sum.enforce_encoding()

        # softmax_op <- numerator (input 0) and exp sum (input 1).
        self.softmax_op.input_scales[0] = self.shift_exp_numerator.output_scale
        self.softmax_op.input_zero_points[0] = self.shift_exp_numerator.output_zero_point
        self.softmax_op.input_scales[1] = self.e_x_sum.output_scale
        self.softmax_op.input_zero_points[1] = self.e_x_sum.output_zero_point
        self.softmax_op.enforce_encoding()

        # act_op <- softmax_op output.
        self.act_op.input_scales[0] = self.softmax_op.output_scale
        self.act_op.input_zero_points[0] = self.softmax_op.output_zero_point
        self.act_op.enforce_encoding(training=training)

    def fast_enforce_internal_encoding(self, **kwargs):
        """No-op for this layer."""
        pass

    def get_bias_mode(self):
        """Return the default bias mode."""
        return self.get_default_bias_mode()

    def create_quant_element_custom_behavior(
        self, precision_config: LayerPrecisionConfig, optimization_target: OptimizationTarget
    ):
        """Create the lossy (quantization) elements for all data paths and atomic ops."""
        exp_nume = EXP_NUME_BITS
        exp_deno = EXP_OUT_BITS
        e_x_reduce_sum_bits = 56

        self.input_op.create_weight_quant_element()
        self.create_quant_element_by_data_path(DataPath.LAYER_IN, 15)
        self.create_quant_element_by_data_path(DataPath.LAYER_IN_MASK, 15)

        self.create_quant_element_by_data_path(DataPath.EXP_NUME, exp_nume)
        self.create_quant_element_by_data_path(DataPath.EXP_DENO, exp_deno)

        self.shift_exp_numerator.create_weight_quant_element(bits=exp_nume)

        self.exp_denominator.create_weight_quant_element(clip_bits=LUT_IN_BITS)

        self.create_quant_element_by_data_path(DataPath.LAYER_E_X_SUM, e_x_reduce_sum_bits)
        self.e_x_sum.create_weight_quant_element(e_x_reduce_sum_bits, e_x_reduce_sum_bits)

        clip_bits = ACTIVATION_CLIP_BITS_HAILO_LAYER_NORM
        self.softmax_op.create_weight_quant_element(clip_bits)

        # The same lossy element instance is shared by softmax output and activation input.
        pre_act_lossy_element = APUOutputSignedQuantElement(bits=32)
        self.softmax_op.set_output_lossy_element(pre_act_lossy_element, index=0)
        self.act_op.set_input_lossy_element(pre_act_lossy_element, index=0)
        self.act_op.create_weight_quant_element(optimization_target)

        self.act_op.FLOAT_TYPE_NP = self.softmax_op.FLOAT_TYPE_NP
        self.act_op.INT_TYPE_NP = self.softmax_op.INT_TYPE_NP

    def start_stats_collection(self, stats_cfg: tuple = BasicTypeTuple, output_hist=False, preact_hist=False):
        """
        Start statistics collection on the internal ops.

        Args:
            stats_cfg: base statistics types to collect.
            output_hist: also collect a dynamic histogram on the layer outputs.
            preact_hist: also collect a dynamic histogram on the exp ops and activation inputs.
        """
        act_stats_cfg_out = (*stats_cfg, TypeStats.DYNAMIC_HISTOGRAM) if output_hist else stats_cfg
        act_stats_cfg_preact = (*stats_cfg, TypeStats.DYNAMIC_HISTOGRAM) if preact_hist else stats_cfg

        self.exp_denominator.start_stats_collection(
            stats_cfg=act_stats_cfg_preact, collect_inputs=True, collect_output=True
        )
        self.shift_exp_numerator.start_stats_collection(
            stats_cfg=act_stats_cfg_preact, collect_inputs=True, collect_output=True
        )

        self.reduce_max_op.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=False, collect_output=True)

        self.ew_sub_op.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=False, collect_output=True)
        self.e_x_sum.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=False, collect_output=True)
        self.softmax_op.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=True, collect_output=True)
        for op, _ in self._input_stats_ops():
            op.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=True, collect_output=False)
        for op, _ in self._output_stats_ops():
            op.start_stats_collection(stats_cfg=act_stats_cfg_out, collect_inputs=False, collect_output=True)
        for op in self._iterate_act_ops():
            op.start_stats_collection(stats_cfg=act_stats_cfg_preact, collect_inputs=True, collect_output=False)

    @staticmethod
    def _calc_new_min(min_of_reduce_max, bits, groups: int = 1):
        """
        Compute a per-group lower bound for the softmax input range.

        Args:
            min_of_reduce_max: per-channel minimum of the reduce_max output stats; it is
                reshaped to ``(groups, -1)``.
            bits: precision used for the offset ``log(1 / 2 ** (bits + 1))``.
            groups: number of softmax groups.

        Returns:
            Array with one entry per group: the group's minimum plus ``log(1 / 2 ** (bits + 1))``.
        """
        offset = np.log(1 / 2 ** (bits + 1))
        return np.min(min_of_reduce_max.reshape(groups, -1), axis=1) + offset

    def get_softmax_new_range(self):
        """
        Compute a tighter lower bound for the softmax input range from the reduce_max stats.

        For every vector x the max is subtracted before the exp, so the relevant value is
        ``x - max(x)``. Instead of the regular min/max limvals::

            limvals = min(min(x)), max(max(x))

        we use::

            limvals = min(max(x)) + np.log(1 / 2 ** (bits + 1)), max(max(x))

        Returns:
            Tuple ``(new_min, groups)``; ``new_min`` has one entry per group.
        """
        output_stats_reduce_max = self.reduce_max_op.get_output_stats(0)
        groups = self.reduce_max_op._groups
        bits = 16
        new_min = self._calc_new_min(output_stats_reduce_max.min, bits, groups)
        return new_min, groups

    def change_softmax_stats_layer(self, new_min, groups):
        """
        Tighten the collected softmax stats in place using ``new_min``.

        The per-group input minimum is raised to ``max(new_min, observed group min)`` and
        broadcast back to every channel of its group. The max of the exp input stats and of the
        ew_sub output stats is then tightened via ``_new_max_of_ew_sub``.
        """
        # Per-channel stats vector with channels_per_group * groups entries.
        input_stats = self.input_op.get_input_stats(0)
        group_mins = input_stats.min.reshape(groups, -1).min(axis=1)

        # Never go below the observed group minimum.
        best_min = np.maximum(new_min, group_mins)

        # Integer count: np.repeat rejects float repeat counts.
        channels_per_group = input_stats.min.size // groups
        input_stats.min[...] = best_min.repeat(channels_per_group)

        exp_input_stats = self.exp_denominator.get_input_stats(0)
        reduce_max_stats = self.reduce_max_op.get_output_stats(0)

        # New max of the exp input, based on the input minimum updated above.
        new_max_input_to_exp = self._new_max_of_ew_sub(
            reduce_max_stats, input_stats, exp_input_stats, channels_per_group
        )

        exp_input_stats.max[...] = new_max_input_to_exp
        output_stats_ew_sub = self.ew_sub_op.get_output_stats(0)
        output_stats_ew_sub.max[...] = new_max_input_to_exp

    def _new_max_of_ew_sub(self, r, x, r_minus_x, repates):
        """
        Return an upper bound on ``max(r - x)``.

        Since ``r - x <= r.max - x.min`` and ``r - x <= (r - x).max``::

            max(r - x) <= min(r.max - x.min, (r - x).max)

        Args:
            r: reduce_max output stats.
            x: input stats.
            r_minus_x: stats of ``r - x``.
            repates: integer number of times each ``r.max`` entry is repeated to match ``x.min``.
        """
        # np.repeat and the subtraction both return new arrays, so the stats are not aliased.
        return np.minimum(np.repeat(r.max, repates) - x.min, r_minus_x.max)

    def _softmax_smart_clipping_hack(self):
        """Replace the softmax input stats with the tighter range from ``get_softmax_new_range``."""
        new_min, groups = self.get_softmax_new_range()
        self.change_softmax_stats_layer(new_min, groups)

    def _layer_dependent_hw_params_modifications(self, params):
        """Set ``activation_ebias_mode`` to 1 iff the precision mode is ``a16_w16_a16``."""
        activation_ebias_mode = self.get_precision_mode() == PrecisionMode.a16_w16_a16
        params["activation_ebias_mode"] = np.array(activation_ebias_mode, np.uint8)
        return params

    def _groups_softmax_hack(self):
        """Recompute per-group scales/zero-points and set them on layer input 0 and output 0."""
        input_scales, inp_zero_points = self.get_scales_inputs()
        output_scales, out_zero_points = self.get_scales_output()
        self.set_input_scale(input_scales, 0)
        self.set_input_zero_point(inp_zero_points, 0)

        self.set_output_scale(output_scales, 0)
        self.set_output_zero_point(out_zero_points, 0)

    def _helper_scale_zp(self, lossy_element, min_vals, max_vals):
        """
        Convert per-group limvals into per-channel scales and zero points.

        Args:
            lossy_element: lossy element passed to ``limvals_to_zp_scale``.
            min_vals: per-group minimums.
            max_vals: per-group maximums.

        Returns:
            Tuple ``(scales, zp)``; each per-group value is repeated
            ``output_shape[-1] // groups`` times.
        """
        # limvals_to_zp_scale(...)[:2] is (zero_point, scale).
        zp_scale_pairs = np.array(
            [limvals_to_zp_scale(limvals, lossy_element)[:2] for limvals in zip(min_vals, max_vals)]
        )
        # Integer count: np.repeat rejects float repeat counts.
        channels_per_group = self.output_shape[-1] // self.groups
        zp = zp_scale_pairs[:, 0].repeat(channels_per_group)
        scales = zp_scale_pairs[:, 1].repeat(channels_per_group)
        return scales, zp

    def get_scales_output(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return per-channel ``(scales, zero_points)`` for output 0 from per-group output limvals."""
        lossy_element = self.get_output_lossy_elements()[0]
        min_vals, max_vals = self.get_group_output_limvals(self.groups)[0]
        return self._helper_scale_zp(lossy_element, min_vals, max_vals)

    def get_scales_inputs(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return per-channel ``(scales, zero_points)`` for input 0 from per-group input limvals."""
        lossy_element = self.get_input_lossy_elements()[0]
        min_vals, max_vals = self.get_group_input_limvals(self.groups)[0]
        return self._helper_scale_zp(lossy_element, min_vals, max_vals)
