from typing import Union

import numpy as np

from hailo_model_optimization.acceleras.atomic_ops.activation_op import ActivationOp
from hailo_model_optimization.acceleras.atomic_ops.bias_add_op import AddBiasOp
from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.atomic_ops.reduce_sum_op import ReduceSumOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerWeightsClippingConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    BiasMode,
    DataPath,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    OptimizationTarget,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams
from hailo_model_optimization.acceleras.utils.opt_utils import (
    get_decomposition_count_by_bias_mode,
    get_kernel_bits_and_sign_by_precision_mode,
)


class HailoReduceSum(BaseHailoLayer):
    """
    Implement Hailo reduce_sum layer:
        - takes one input,
        - multiplies the input by the weight,
        - sums along the given axis in the accumulator,
        - applies the activation in the APU.

    The internal flow (see ``_build_flow``) is a linear chain:
    ``reduce_sum_op -> bias_op -> act_op -> output_op``.
    """

    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a8_w8,
        PrecisionMode.a16_w16,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a8_w8_a16,
        PrecisionMode.a16_w16_a16,
    }
    SUPPORTED_BIAS_MODE = {
        BiasMode.single_scale_decomposition,
        BiasMode.double_scale_initialization,
    }
    SUPPORTED_QUANTIZATION_GROUPS = False
    _hn_type = LayerType.REDUCE_SUM

    def __init__(
        self,
        name: str,
        groups: int = 1,
        reduce_axes: tuple = (3,),
        height_groups: int = 1,
        activation: Union[str, callable, ActivationType] = ActivationType.LINEAR,
        logger=None,
        **kwargs,
    ):
        """
        Args:
            name: layer name; also used as the prefix of the atomic ops' names.
            groups: forwarded to ``ReduceSumOp``.
            reduce_axes: axes to sum over, forwarded to ``ReduceSumOp``.
            height_groups: forwarded to ``ReduceSumOp``.
            activation: activation applied after the bias, forwarded to ``ActivationOp``.
            logger: logger shared by the layer and its atomic ops.
            **kwargs: forwarded to ``BaseHailoLayer.__init__``.
        """
        # The atomic ops are created before ``super().__init__`` — presumably because the base
        # initializer builds the layer flow (``_build_flow``), which references them. TODO confirm.
        self.reduce_sum_op = ReduceSumOp(
            f"{name}/reduce_sum_op",
            groups=groups,
            reduce_axes=reduce_axes,
            height_groups=height_groups,
            logger=logger,
        )
        self.bias_op = AddBiasOp.get_passthru_bias(f"{name}/bias_add_op", logger=logger)
        self.act_op = ActivationOp(f"{name}/act_op", activation, logger=logger)
        # Passthrough op that enables output quantization even though the activation is fully native.
        self.output_op = PassthruOp(
            f"{name}/passthru_op",
            logger=logger,
        )

        super().__init__(name=name, logger=logger, **kwargs)
        self._groups = groups
        self._reduce_axes = reduce_axes
        self._height_groups = height_groups

    @property
    def pre_acc_shift(self):
        """Pre-accumulator shift of the underlying ``reduce_sum_op``."""
        return self.reduce_sum_op.pre_acc_shift

    def _build_flow(self) -> LayerFlow:
        """
        Build the layer's op graph: one input feeding the chain
        ``reduce_sum_op -> bias_op -> act_op -> output_op`` into one output.
        """
        layer_flow = LayerFlow()
        in1 = layer_flow.add_input()
        out1 = layer_flow.add_output()
        layer_flow.add_node(self.reduce_sum_op)
        layer_flow.add_node(self.bias_op)
        layer_flow.add_node(self.act_op)
        layer_flow.add_node(self.output_op)

        layer_flow.add_edge(in1, self.reduce_sum_op, DataPath.LAYER_IN)
        layer_flow.add_edge(self.reduce_sum_op, self.bias_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.bias_op, self.act_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.act_op, self.output_op, DataPath.LAYER_OUT)
        layer_flow.add_edge(self.output_op, out1, DataPath.LAYER_OUT)
        return layer_flow

    def neg_weights(self):
        """Negate the kernel of ``reduce_sum_op`` and the bias of ``bias_op`` by re-importing them."""
        self.reduce_sum_op.import_weights({"kernel": -np.array(self.reduce_sum_op.kernel)})
        self.bias_op.import_weights(-self.bias_op.export_weights())

    def enforce_internal_encoding(self, training=False, **kwargs):
        """
        Propagate scales and zero points through the internal ops.

        The output encoding is first pulled backward from ``output_op`` into ``act_op``.
        Then ``reduce_sum_op`` enforces its encoding, and its output scale / zero point
        become the inputs of ``bias_op``. The output scale of ``bias_op`` equals its input
        scale, and its output becomes the input of ``act_op``.

        Args:
            training: forwarded to ``act_op.enforce_encoding``.
        """
        # TODO: I don't like that the scales and zp are external properties, and I have to set them explicitly.
        #       Which affects the infer_encoding implicitly
        self._enforce_output_encoding()
        self.reduce_sum_op.enforce_encoding()
        self.bias_op.input_scales = [self.reduce_sum_op.output_scale]
        self.bias_op.input_zero_points = [self.reduce_sum_op.output_zero_point]
        self.bias_op.output_scale = self.bias_op.input_scales[0]
        self.bias_op.enforce_encoding()
        self.act_op.input_scales = [self.bias_op.output_scale]
        self.act_op.input_zero_points = [self.bias_op.output_zero_point]
        self.act_op.enforce_encoding(training=training)

    def fast_enforce_internal_encoding(self, **kwargs):
        """No-op for this layer."""
        pass

    def import_weights(self, layer_params: LayerParams):
        """
        Load the layer's parameters into its atomic ops.

        Args:
            layer_params: layer's params from the npz. The whole mapping is passed to
                ``reduce_sum_op`` and ``act_op``. The ``"bias"`` entry (``0.0`` when
                absent) is passed to ``bias_op``.
        """
        # TODO: Do we want to load kernel and bias values? (instead of the auto-generated values)
        self.reduce_sum_op.import_weights(layer_params)
        self.bias_op.import_weights(layer_params.get("bias", np.array(0.0)))
        self.act_op.import_weights(layer_params)

    def _export_weights(self):
        """
        Collect the weights of the internal ops into one dict.

        The exports of ``reduce_sum_op`` and ``act_op`` are merged in that order (later
        keys win). The bias is stored last, under the ``"bias"`` key.
        """
        weights = {}
        weights.update(self.reduce_sum_op.export_weights())
        weights.update(self.act_op.export_weights())
        weights["bias"] = self.bias_op.export_weights()
        return weights

    @classmethod
    def get_default_params(cls):
        """Return a fresh dict of default hn params (the base that ``from_hn`` overrides)."""
        # TODO: this is temporary solution until we have pydantic scheme
        return {
            "reduce_axes": [3],
            "groups": 1,
            "activation": "linear",
            "height_groups": 1,
        }

    def to_hn(self, out_degree=None):
        """
        Write the layer's params into its hn element, then delegate to the base ``to_hn``.

        The values are read back from the atomic ops, so the hn reflects the layer's
        current state.
        """
        # Reuse (and update in place) the existing params dict, or insert a new one.
        params = self._hn_element.setdefault("params", {})
        params["groups"] = self.reduce_sum_op._groups
        params["activation"] = self.act_op.act_name.value
        # NOTE(review): ``from_hn`` shifts a single reduce axis by +2 for rank-2 inputs, while this
        # writes back ``reduce_sum_op._reduce_axes`` (presumably the shifted value). An
        # hn -> layer -> hn round trip may therefore not reproduce the original axis — confirm.
        params["reduce_axes"] = self.reduce_sum_op._reduce_axes
        params["height_groups"] = self.reduce_sum_op._height_groups
        return super().to_hn(out_degree=out_degree)

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        Build a layer from an hn element.

        Args:
            lname: layer name.
            hn_element: hn dict. Its ``"params"`` override ``get_default_params()``, and its
                ``"input_shapes"`` are inspected for the rank-2 case below.
            logger: optional logger.

        Returns:
            The constructed layer, after ``finalize_from_hn(hn_element)``.
        """
        params = cls.get_default_params()
        params.update(hn_element.get("params", {}))
        input_shapes = hn_element.get("input_shapes", [])
        # A single rank-2 input with a single reduce axis gets that axis shifted by 2
        # (presumably to map it into the rank-4 layout used internally — confirm).
        if len(input_shapes) == 1 and len(input_shapes[0]) == 2 and len(params["reduce_axes"]) == 1:
            params["reduce_axes"] = [params["reduce_axes"][0] + 2]
        layer = cls(
            name=lname,
            groups=params["groups"],
            activation=params["activation"],
            reduce_axes=params["reduce_axes"],
            height_groups=params["height_groups"],
            logger=logger,
        )
        layer.finalize_from_hn(hn_element)
        return layer

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """
        Create the hardware params of the internal ops from the collected statistics.

        Args:
            weights_clipping: not used by this layer.
            optimization_target: forwarded to ``act_op.create_hw_params``.
            hw_shifts: forwarded to ``reduce_sum_op.create_hw_params``.
        """
        self._enforce_output_encoding()
        # Bound on the final accumulator: max(|min|, |max|) of the activation's input statistics.
        pre_act_stats = self.act_op.get_input_stats(0)
        max_final_accumulator_by_channel = np.maximum(np.abs(pre_act_stats.min), np.abs(pre_act_stats.max))
        utilize_wraparound = self.act_op.is_negative_input()
        self.reduce_sum_op.create_hw_params(
            max_output_per_channel=max_final_accumulator_by_channel,
            utilize_wraparound=utilize_wraparound,
            hw_shifts=hw_shifts,
        )
        self.bias_op.pre_acc_shift = self.reduce_sum_op.pre_acc_shift
        if utilize_wraparound:  # fixes an edge case where the resulting accumulator is exactly -2**15
            self.bias_op.output_zero_points = [np.float32(1)]
        self.act_op.create_hw_params(
            self.reduce_sum_op.output_scale, optimization_target, nudging=False, utilize_wraparound=utilize_wraparound
        )
        self.enforce_internal_encoding()
        self.bias_op.create_hw_params()

    def create_quant_element_custom_behavior(
        self,
        precision_config: LayerPrecisionConfig,
        optimization_target: OptimizationTarget,
    ):
        """
        Create the quantization elements of the internal ops.

        The kernel bit-width and sign come from the precision mode (15-bit unsigned is
        promoted to 16-bit signed). The number of bias decompositions comes from the
        bias mode.
        """
        bias_mode = precision_config.bias_mode
        precision_mode = precision_config.precision_mode
        quant_groups = precision_config.quantization_groups

        kernel_bits, signed = get_kernel_bits_and_sign_by_precision_mode(precision_mode)
        if kernel_bits == 15 and not signed:
            kernel_bits, signed = 16, True
        num_decomposition = get_decomposition_count_by_bias_mode(bias_mode)
        self.reduce_sum_op.create_weight_quant_element(kernel_bits, signed)
        self.bias_op.create_weight_quant_element(kernel_bits, signed, num_decomposition)
        self.act_op.create_weight_quant_element(optimization_target)

        # Quantization groups are not supported yet (SUPPORTED_QUANTIZATION_GROUPS = False),
        # but the configured value is still forwarded to the activation op.
        self.act_op.set_quantization_groups(quant_groups)

    def _enforce_output_encoding(self):
        """Pull the output encoding backward from ``output_op`` and set it as ``act_op``'s output."""
        self.output_op.backward_encoding()
        self.act_op.output_scale = self.output_op.input_scales[0]
        self.act_op.output_zero_point = self.output_op.input_zero_points[0]

    def get_equalization_handler_type(self, predecessor_index=None):
        """Equalization is unsupported for this layer, and it is never a source."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def enforce_io_encoding(self, training=False, **kwargs):
        """Enforce the layer's encodings by running ``enforce_internal_encoding``."""
        # NOTE(review): ``training`` is not forwarded, so ``act_op.enforce_encoding`` always runs
        # with training=False from here. Confirm whether this is intentional.
        self.enforce_internal_encoding()

    def _get_precision_mode_supported_in_hw(self, arch):
        """
        Return the supported precision modes for ``arch``.

        For MERCURY, SAGE and PLUTO this is ``SUPPORTED_PRECISION_MODE`` minus ``a16_w16_a8``.
        Other architectures defer to the base class.
        """
        if arch in {OptimizationTarget.MERCURY, OptimizationTarget.SAGE, OptimizationTarget.PLUTO}:
            # ``a16_w16_a8`` is not in this class's set, so here this yields a copy of the full set.
            return self.SUPPORTED_PRECISION_MODE - {PrecisionMode.a16_w16_a8}
        else:
            return super()._get_precision_mode_supported_in_hw(arch)

    def _is_precision_config_supported(self, precision_mode, bias_mode, arch):
        """
        Reject the ``a8_w4*`` precision modes unless the bias mode is double_scale_initialization.

        Every other combination is accepted, and ``arch`` is not used. Note that the ``a8_w4*``
        modes are not listed in this class's SUPPORTED_PRECISION_MODE.
        """
        if (
            precision_mode in {PrecisionMode.a8_w4_a8, PrecisionMode.a8_w4_a16, PrecisionMode.a8_w4}
            and bias_mode != BiasMode.double_scale_initialization
        ):
            return False
        return True
