from typing import Optional, Union

import numpy as np

from hailo_model_optimization.acceleras.atomic_ops.activation_op import ActivationOp
from hailo_model_optimization.acceleras.atomic_ops.bias_add_op import AddBiasOp
from hailo_model_optimization.acceleras.atomic_ops.concat_op import ConcatOp
from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.atomic_ops.reduce_sum_op import (
    ReduceSumOp,
)
from hailo_model_optimization.acceleras.atomic_ops.split_precision_op import (
    PrecisionAddOp,
    SplitPrecisionHigh,
    SplitPrecisionLow,
)
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.lossy_elements.quant_element import (
    AccumulatorQuantElement,
    APUOutputQuantElement,
)
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerWeightsClippingConfig,
)
from hailo_model_optimization.acceleras.statistics.statistics_base import BasicTypeTuple
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    HW_SHIFTS_PLUTO,
    ZP_LOW_SPLIT_PRECISION_PIXEL,
    ActivationType,
    BiasMode,
    DataPath,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    OptimizationTarget,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.hn_npz_utils import (
    LayerParams,
)
from hailo_model_optimization.acceleras.utils.opt_utils import (
    get_decomposition_count_by_bias_mode,
    get_kernel_bits_and_sign_by_precision_mode,
)


class HailoReduceSumA16PreActSum(BaseHailoLayer):
    """
    Implements reduce sum for A16_W8 or A16_W4 on the Pluto architecture, where the APU has an adder
    that can sum the high and low parts.
    """

    _hn_type = LayerType.REDUCE_SUM

    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a16_w8,
        PrecisionMode.a16_w8_a16,
        PrecisionMode.a16_w8_a8,
    }
    SUPPORTED_BIAS_MODE = {
        BiasMode.double_scale_initialization,
        BiasMode.double_scale_decomposition,
        BiasMode.single_scale_decomposition,
    }
    SUPPORTED_QUANTIZATION_GROUPS = False

    def __init__(
        self,
        name: str,
        groups: int = 1,
        reduce_axes: tuple = (3,),
        height_groups: int = 1,
        activation: Union[str, callable, ActivationType] = ActivationType.LINEAR,
        logger=None,
        **kwargs,
    ):
        """
        Create the atomic ops of the layer and initialize the base layer.

        Args:
            name: layer name; also used as the prefix of every atomic op name.
            groups: forwarded to both reduce-sum ops.
            reduce_axes: forwarded to both reduce-sum ops.
            height_groups: forwarded to both reduce-sum ops.
            activation: activation passed to the activation op.
            logger: forwarded to every atomic op and to the base layer.
            **kwargs: forwarded to the base layer initializer.
        """
        # NOTE(review): all ops are created before super().__init__() - presumably the base
        # initializer needs them to exist (e.g. to build the flow); confirm before reordering.
        self.input_op = PassthruOp(f"{name}/input_passthough_op", logger=logger)
        # The input is split into a low and a high part, each with its own reduce-sum branch.
        self.split_precision_low = SplitPrecisionLow(f"{name}/split_precision_low_op", logger=logger)
        self.split_precision_high = SplitPrecisionHigh(f"{name}/split_precision_high_op", logger=logger)
        self.reduce_sum_h_op = ReduceSumOp(
            f"{name}/reduce_sum_h_op",
            groups=groups,
            reduce_axes=reduce_axes,
            height_groups=height_groups,
            logger=logger,
        )
        self.reduce_sum_l_op = ReduceSumOp(
            f"{name}/reduce_sum_l_op",
            groups=groups,
            reduce_axes=reduce_axes,
            height_groups=height_groups,
            logger=logger,
        )

        # One bias op per branch, created through the pass-through bias factory.
        self.bias_add_h_op = AddBiasOp.get_passthru_bias(f"{name}/bias_add_h_op", logger=logger)
        self.bias_add_l_op = AddBiasOp.get_passthru_bias(f"{name}/bias_add_l_op", logger=logger)

        # The two branches are concatenated and then combined by the precision-add op.
        self.concat_op = ConcatOp(f"{name}/concat_op", concat_elements=2, logger=logger)

        self.precision_add_op = PrecisionAddOp(f"{name}/precision_add_op", num_decompositions=2, logger=logger)
        self.act_op = ActivationOp(f"{name}/act_op", activation, logger=logger)
        self.output_op = PassthruOp(f"{name}/passthru_op", logger=logger)
        # Both lists are ordered [low, high]; they are zipped together elsewhere in the class.
        self.all_reduce_sums = [self.reduce_sum_l_op, self.reduce_sum_h_op]
        self.all_biases = [self.bias_add_l_op, self.bias_add_h_op]
        super().__init__(name=name, logger=logger, **kwargs)

        self._groups = groups
        self._reduce_axes = reduce_axes
        self._height_groups = height_groups
        # Default HW shifts used by create_hw_params when the caller does not pass any.
        self._hw_shifts = HW_SHIFTS_PLUTO

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        Create the layer from its HN element.

        Values missing from ``hn_element["params"]`` fall back to ``get_default_params()``.
        When the element has exactly one input shape of length 2 and a single reduce axis,
        that axis is shifted by 2.

        Args:
            lname: name of the new layer.
            hn_element: HN description of the layer (``params`` and ``input_shapes`` are read).
            logger: forwarded to the layer constructor.

        Returns:
            The new layer, after ``finalize_from_hn(hn_element)`` was called on it.
        """
        hn_params = cls.get_default_params()
        hn_params.update(hn_element.get("params", {}))

        shapes = hn_element.get("input_shapes", [])
        single_rank2_input = len(shapes) == 1 and len(shapes[0]) == 2
        if single_rank2_input and len(hn_params["reduce_axes"]) == 1:
            # NOTE(review): presumably maps an axis of a rank-2 input onto the layer's
            # internal layout - confirm.
            hn_params["reduce_axes"] = [hn_params["reduce_axes"][0] + 2]

        new_layer = cls(
            name=lname,
            groups=hn_params["groups"],
            activation=hn_params["activation"],
            reduce_axes=hn_params["reduce_axes"],
            height_groups=hn_params["height_groups"],
            logger=logger,
        )
        new_layer.finalize_from_hn(hn_element)
        return new_layer

    @classmethod
    def from_reduce_sum_op(cls, reduce_sum_op: ReduceSumOp, logger=None):
        """
        Create a layer that mirrors an existing reduce-sum op.

        The name, groups, reduce axes and height groups are copied from ``reduce_sum_op``;
        the activation is always linear.
        """
        return cls(
            name=reduce_sum_op.name,
            groups=reduce_sum_op.groups,
            activation=ActivationType.LINEAR,
            reduce_axes=reduce_sum_op.reduce_axes,
            height_groups=reduce_sum_op.height_groups,
            logger=logger,
        )

    def to_hn(self, out_degree: Optional[int] = None) -> dict:
        """
        Write the layer parameters into the HN element and return the base-class HN export.

        The groups, reduce axes and height groups are taken from the low reduce-sum op,
        and ``decompose_weights`` is always set to True.
        """
        low_reduce = self.reduce_sum_l_op
        layer_overrides = {
            "groups": low_reduce.groups,
            "activation": self.act_op.act_name.value,
            "reduce_axes": low_reduce.reduce_axes,
            "height_groups": low_reduce.height_groups,
            "decompose_weights": True,
        }
        # Create the "params" entry when it is missing, then overwrite the layer-owned keys.
        self._hn_element.setdefault("params", {}).update(layer_overrides)
        return super().to_hn(out_degree=out_degree)

    def _build_flow(self) -> LayerFlow:
        """
        Build the internal op graph of the layer.

        input -> split precision (low / high) -> reduce_sum per branch -> bias per branch
        -> concat -> precision add -> activation -> output

        The high split op receives the layer input on input 0 and the output of the low
        split op on input 1.

        Returns:
            LayerFlow: the flow returned by ``self._init_flow()``, populated with the edges above.
        """
        # The flow comes from _init_flow(); no separate LayerFlow instance is needed.
        layer_flow = self._init_flow()
        in1 = layer_flow.add_input()
        out1 = layer_flow.add_output()
        layer_flow.add_edge(in1, self.input_op, DataPath.LAYER_IN)
        layer_flow.add_edge(self.input_op, self.split_precision_low, DataPath.LAYER_IN)
        layer_flow.add_edge(self.input_op, self.split_precision_high, DataPath.LAYER_IN, input_index=0)
        layer_flow.add_edge(self.split_precision_low, self.split_precision_high, DataPath.LAYER_IN, input_index=1)
        layer_flow.add_edge(self.split_precision_low, self.reduce_sum_l_op, DataPath.LAYER_IN)
        layer_flow.add_edge(self.split_precision_high, self.reduce_sum_h_op, DataPath.LAYER_IN)
        # all_reduce_sums / all_biases are ordered [low, high], so low feeds concat input 0.
        for index, (reduce_sum, bias) in enumerate(zip(self.all_reduce_sums, self.all_biases)):
            layer_flow.add_edge(reduce_sum, bias, DataPath.ACCUMULATOR)
            layer_flow.add_edge(bias, self.concat_op, DataPath.ACCUMULATOR, input_index=index)

        layer_flow.add_edge(self.concat_op, self.precision_add_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.precision_add_op, self.act_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.act_op, self.output_op, DataPath.LAYER_OUT)
        layer_flow.add_edge(self.output_op, out1, DataPath.LAYER_OUT)
        return layer_flow

    def import_weights(self, layer_params: LayerParams):
        """
        Load the layer weights into the atomic ops.

        Both reduce-sum ops and the activation op receive ``layer_params`` as is; both bias
        ops receive the ``"bias"`` entry, or a zero scalar array when it is missing.
        """
        for reduce_sum in (self.reduce_sum_l_op, self.reduce_sum_h_op):
            reduce_sum.import_weights(layer_params)

        # The lookup is repeated per op so that each one gets its own default array.
        for bias_op in (self.bias_add_h_op, self.bias_add_l_op):
            bias_op.import_weights(layer_params.get("bias", np.array(0.0)))

        self.act_op.import_weights(layer_params)

    def _export_weights(self) -> dict:
        """
        Collect the layer weights into a single dict.

        Only the high branch is exported (reduce-sum weights and bias), together with the
        activation weights. Later entries override earlier ones on key collisions.
        """
        return {
            **self.reduce_sum_h_op.export_weights(),
            **self.act_op.export_weights(),
            "bias": self.bias_add_h_op.export_weights(),
        }

    def neg_weights(self):
        """Negate the kernels of both reduce-sum ops and the weights of both bias ops."""
        for reduce_sum in (self.reduce_sum_h_op, self.reduce_sum_l_op):
            negated_kernel = -np.array(reduce_sum.kernel)
            reduce_sum.import_weights({"kernel": negated_kernel})

        for bias_op in (self.bias_add_l_op, self.bias_add_h_op):
            bias_op.import_weights(-bias_op.export_weights())

    def start_stats_collection(
        self, stats_cfg: tuple = BasicTypeTuple, output_hist: bool = False, preact_hist: bool = False
    ):
        """
        Start statistics collection on the layer and on its internal ops.

        After the base-class collection is started, the split-precision, precision-add,
        reduce-sum and activation ops collect both inputs and output, while the bias ops
        collect their output only.
        """
        super().start_stats_collection(stats_cfg, output_hist, preact_hist)

        inputs_and_output = dict(stats_cfg=stats_cfg, collect_inputs=True, collect_output=True)
        output_only = dict(stats_cfg=stats_cfg, collect_inputs=False, collect_output=True)

        for op in (self.split_precision_low, self.split_precision_high, self.precision_add_op):
            op.start_stats_collection(**inputs_and_output)

        for reduce_sum, bias in zip(self.all_reduce_sums, self.all_biases):
            reduce_sum.start_stats_collection(**inputs_and_output)
            bias.start_stats_collection(**output_only)

        self.act_op.start_stats_collection(**inputs_and_output)

    def create_splits(self):
        """
        Set up the encodings of the two split-precision ops.

        The low split is marked as non-trivial, gets its input encoding candidates and is
        enforced first; its input and output scales then become the two input scales of the
        high split, which is enforced afterwards.
        """
        low = self.split_precision_low
        high = self.split_precision_high

        low.trivial_split = False
        low.create_input_encoding_candidates(0, split_precision_zp=ZP_LOW_SPLIT_PRECISION_PIXEL)
        low.enforce_encoding()

        high.input_scales[0] = low.input_scale
        high.input_scales[1] = low.output_scale
        high.enforce_encoding()

    def _export_ops_hw_params(self) -> dict:
        """
        Export the HW parameters of the internal ops.

        On top of the base-class export and the high reduce-sum op export, the per-branch
        values are merged in [low, high] order: ``bias``, ``bias_q`` and ``kernel`` are
        concatenated along the last axis, and ``output_stage/mult_shift`` is stacked into a
        uint8 array. ``pre_act_sum_shift`` is log2 of the precision-add kernel, cast to uint8.
        """

        def per_branch(ops, key):
            # One exported value per op, in the order of the given op list.
            return [op.export_hw_params()[key] for op in ops]

        hw_params = super()._export_ops_hw_params()
        hw_params.update(self.reduce_sum_h_op.export_hw_params())

        hw_params["bias"] = np.concatenate(per_branch(self.all_biases, "bias"), axis=-1)
        hw_params["bias_q"] = np.concatenate(per_branch(self.all_biases, "bias_q"), axis=-1)
        hw_params["kernel"] = np.concatenate(per_branch(self.all_reduce_sums, "kernel"), axis=-1)

        precision_add_kernel = self.precision_add_op.export_hw_params()["kernel"]
        hw_params["pre_act_sum_shift"] = np.log2(precision_add_kernel).astype(np.uint8)

        hw_params["output_stage/mult_shift"] = np.array(
            per_branch(self.all_reduce_sums, "output_stage/mult_shift"), dtype=np.uint8
        )
        return hw_params

    @classmethod
    def get_default_params(cls):
        """Return a fresh dict with the default HN parameters of the layer."""
        return {
            "reduce_axes": [3],
            "groups": 1,
            "activation": "linear",
            "height_groups": 1,
        }

    def create_quant_element_custom_behavior(
        self,
        precision_config: LayerPrecisionConfig,
        optimization_target: OptimizationTarget,
    ):
        """
        Attach the lossy (quantization) elements to the internal ops.

        The 16-bit layer input enters both split ops on input 0; every split output, the
        second input of the high split and the reduce-sum inputs are 8 bit. Kernel and bias
        quantization follow ``precision_config``; the precision-add output and the
        activation input are 32-bit accumulator elements.

        Args:
            precision_config: supplies the precision mode and the bias mode.
            optimization_target: forwarded to the activation op weight quant element.
        """
        bias_mode = precision_config.bias_mode
        low = self.split_precision_low
        high = self.split_precision_high

        low.set_input_lossy_element(APUOutputQuantElement(bits=16), index=0)
        low.set_output_lossy_element(APUOutputQuantElement(bits=8))
        high.set_output_lossy_element(APUOutputQuantElement(bits=8))
        high.set_input_lossy_element(APUOutputQuantElement(bits=8), index=1)
        high.set_input_lossy_element(APUOutputQuantElement(bits=16), index=0)

        kernel_bits, signed = get_kernel_bits_and_sign_by_precision_mode(precision_config.precision_mode)
        num_decomposition = get_decomposition_count_by_bias_mode(bias_mode)
        bias_quant_kwargs = dict(kernel_bits=kernel_bits, signed=signed, num_decomposition=num_decomposition)

        for reduce_sum, bias in zip(self.all_reduce_sums, self.all_biases):
            reduce_sum.set_input_lossy_element(APUOutputQuantElement(bits=8))
            reduce_sum.create_weight_quant_element(kernel_bits, signed=signed)
            bias.create_weight_quant_element(**bias_quant_kwargs)

        self.precision_add_op.set_output_lossy_element(AccumulatorQuantElement(bits=32))
        self.act_op.set_input_lossy_element(AccumulatorQuantElement(bits=32))
        self.act_op.create_weight_quant_element(optimization_target)

    def _get_kernel_bits(self) -> int:
        """Return the bit width of the ``factor`` weight lossy element of the high reduce-sum op."""
        factor_element = self.reduce_sum_h_op.weight_lossy_elements.factor
        return factor_element.bits

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """
        Create the HW parameters of all internal ops.

        Args:
            weights_clipping: not used in this method.
            optimization_target: forwarded to the activation op.
            hw_shifts: HW shifts for the reduce-sum ops; ``self._hw_shifts`` is used when None.
        """
        self._enforce_output_encoding()

        # Per-branch accumulator range, taken from the output statistics of the bias ops.
        accumulator_stats_high = self.bias_add_h_op.get_output_stats(0)
        accumulator_stats_low = self.bias_add_l_op.get_output_stats(0)
        # Largest absolute accumulator value of each branch: max(|min|, |max|).
        max_final_accumulator_by_channel_h = np.maximum(
            np.abs(accumulator_stats_high.min),
            np.abs(accumulator_stats_high.max),
            dtype=np.float32,
        )
        max_final_accumulator_by_channel_l = np.maximum(
            np.abs(accumulator_stats_low.min),
            np.abs(accumulator_stats_low.max),
            dtype=np.float32,
        )
        if hw_shifts is None:
            hw_shifts = self._hw_shifts

        self.reduce_sum_h_op.create_hw_params(
            max_final_accumulator_by_channel_h,
            hw_shifts=hw_shifts,
        )

        self.reduce_sum_l_op.create_hw_params(
            max_final_accumulator_by_channel_l,
            hw_shifts=hw_shifts,
        )
        # Push the encodings forward: reduce sums first (without enforcing), then biases and concat.
        for reduce_sum in self.all_reduce_sums:
            self._propagate_encoding_forward(reduce_sum, enforce_encoding=False)
        for op in self.all_biases + [self.concat_op]:
            self._propagate_encoding_forward(op, enforce_encoding=True)

        # Each bias op uses the pre-accumulator shift of the reduce-sum op of its own branch.
        self.bias_add_h_op.pre_acc_shift = self.reduce_sum_h_op.pre_acc_shift
        self.bias_add_l_op.pre_acc_shift = self.reduce_sum_l_op.pre_acc_shift
        self.precision_add_op.create_hw_params()
        self._propagate_encoding_forward(self.precision_add_op, enforce_encoding=True)

        self.act_op.create_hw_params(self.act_op.input_scale, optimization_target)
        self.enforce_internal_encoding()
        # The bias HW params are created last, after the internal encodings were enforced.
        for bias in self.all_biases:
            bias.create_hw_params()

    def _enforce_split_precision_encoding(self):
        """
        Enforce the encodings from the input op through the split ops into the reduce-sum ops.

        Both split ops take the output encoding of the input op on input 0. The low split is
        enforced first because its output encoding is input 1 of the high split. Finally each
        reduce-sum op takes the output encoding of the split op of its own branch.
        """
        source = self.input_op
        low = self.split_precision_low
        high = self.split_precision_high

        source.enforce_encoding()
        for split_op in (low, high):
            split_op.input_scales[0] = source.output_scales[0]
        for split_op in (low, high):
            split_op.input_zero_points[0] = source.output_zero_points[0]

        low.enforce_encoding()

        high.input_scales[1] = low.output_scales[0]
        high.input_zero_points[1] = low.output_zero_points[0]
        high.enforce_encoding()

        for reduce_sum, split_op in ((self.reduce_sum_h_op, high), (self.reduce_sum_l_op, low)):
            reduce_sum.input_scale = split_op.output_scale
            reduce_sum.input_zero_point = split_op.output_zero_point

    def _enforce_output_encoding(self):
        """Run the backward encoding of the output op and propagate it backward from that op."""
        out_op = self.output_op
        out_op.backward_encoding()
        self._propagate_encoding_backward(out_op)

    def _enforce_input_encoding(self):
        """Propagate the encoding forward, with enforcement, from the low and then the high split op."""
        for split_op in (self.split_precision_low, self.split_precision_high):
            self._propagate_encoding_forward(split_op, enforce_encoding=True)

    def _accumulator_scale_from_apu(self):
        """
        Resolve the accumulator scale in the activation op and propagate it to the earlier ops.

        Note - the accumulator scale is fully defined by the output and APU params, so it is
        resolved in the Activation class and used for all earlier op scales.
        """
        activation = self.act_op
        activation.get_accumulator_scale()
        activation.enforce_encoding()
        self._propagate_apu_into_ops()

    def _propagate_apu_into_ops(self):
        """
        Propagate the encoding backward from the activation op, then (with enforcement)
        from the precision-add, concat and bias ops, in that order.
        """
        self._propagate_encoding_backward(self.act_op)
        upstream_ops = [self.precision_add_op, self.concat_op, *self.all_biases]
        for upstream_op in upstream_ops:
            self._propagate_encoding_backward(upstream_op, enforce_encoding=True)

    def enforce_internal_encoding(self, training: bool = False, **kwargs):
        """
        Forward-path enforcement of the encodings of the internal atomic ops.

        After the output encoding is enforced, the ops are handled in their natural order:
        input, split precision, reduce sums, biases, concat, precision add and activation.

        NOTE(review): the previous text of this docstring stated that the output scale is
        enforced to equal the input scale, and that no APU encodings need to be set because
        the "avgpool op weight" is an exact power of two. The avgpool wording looks copied
        from another layer - confirm whether these assumptions hold for this layer.

        Args:
            training: forwarded only to the activation op's ``enforce_encoding``.
            **kwargs: not used in this method.
        """
        self._enforce_output_encoding()
        # The same call is also the first statement of _enforce_split_precision_encoding.
        self.input_op.enforce_encoding()

        self._enforce_split_precision_encoding()

        self.reduce_sum_l_op.enforce_encoding()
        self.reduce_sum_h_op.enforce_encoding()

        # Each bias op takes the output encoding of the reduce-sum op of its own branch and
        # keeps the same scale on its output.
        self.bias_add_l_op.input_scales = [self.reduce_sum_l_op.output_scale]
        self.bias_add_h_op.input_scales = [self.reduce_sum_h_op.output_scale]
        self.bias_add_l_op.input_zero_points = [self.reduce_sum_l_op.output_zero_point]
        self.bias_add_h_op.input_zero_points = [self.reduce_sum_h_op.output_zero_point]
        self.bias_add_l_op.output_scale = self.bias_add_l_op.input_scales[0]
        self.bias_add_h_op.output_scale = self.bias_add_h_op.input_scales[0]
        self._propagate_encoding_forward(self.bias_add_l_op, enforce_encoding=True)
        self._propagate_encoding_forward(self.bias_add_h_op, enforce_encoding=True)
        self._propagate_encoding_forward(self.concat_op, enforce_encoding=True)
        self._propagate_encoding_forward(self.precision_add_op, enforce_encoding=True)
        self.act_op.enforce_encoding(training=training)

    def fast_enforce_internal_encoding(self, **kwargs):
        """No-op for this layer; ``kwargs`` are ignored."""

    def enforce_io_encoding(self, training=False, **kwargs):
        """No-op for this layer; ``training`` and ``kwargs`` are ignored."""

    def get_equalization_handler_type(self, predecessor_index: Optional[int] = None) -> EquivClassification:
        """
        Return the equalization classification of this layer.

        Args:
            predecessor_index: not used by this layer.

        Returns:
            An ``EquivClassification`` built with ``LayerHandlerType.unsupported`` and
            ``is_source=False``.
        """
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def _is_precision_config_supported(self, precision_mode, bias_mode, arch) -> bool:
        """
        Always return True; none of the arguments is inspected.

        NOTE(review): SUPPORTED_PRECISION_MODE and SUPPORTED_BIAS_MODE are not consulted
        here - presumably they are checked elsewhere; confirm.
        """
        return True
