from typing import Union

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.activation_op import ActivationOp
from hailo_model_optimization.acceleras.atomic_ops.bias_add_op import AddBiasOp
from hailo_model_optimization.acceleras.atomic_ops.concat_op import ConcatOp
from hailo_model_optimization.acceleras.atomic_ops.conv_stripped_op import ConvStrippedOp
from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.atomic_ops.split_precision_op import (
    PrecisionAddOp,
    SplitPrecisionHigh,
    SplitPrecisionLow,
)
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv import HailoConv
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.lossy_elements.quant_element import (
    AccumulatorQuantElement,
    APUOutputQuantElement,
)
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerWeightsClippingConfig,
)
from hailo_model_optimization.acceleras.statistics.statistics_base import BasicTypeTuple
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    HW_SHIFTS_PLUTO,
    ZP_LOW_SPLIT_PRECISION_PIXEL,
    ActivationType,
    BiasMode,
    DataPath,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    OptimizationTarget,
    PaddingType,
    PrecisionMode,
    StrideAlignType,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasImplementationError
from hailo_model_optimization.acceleras.utils.hn_npz_utils import (
    LayerParams,
    get_hn_padding,
    set_hn_padding_stride_align,
)
from hailo_model_optimization.acceleras.utils.opt_utils import (
    get_decomposition_count_by_bias_mode,
    get_kernel_bits_and_sign_by_precision_mode,
)


class HailoConvA16PreActSum(BaseHailoLayer):
    """
    Conv with a 16-bit input (A16_W8 / A16_W4 variants) for the Pluto arch.

    The 16-bit input is split into a "low" and a "high" 8-bit part (see the split-precision ops and their
    lossy elements in ``create_quant_element_custom_behavior``). Each part goes through its own conv + bias,
    the two accumulators are concatenated, and the APU pre-activation adder (``PrecisionAddOp`` with two
    decompositions) sums the high and low results before the activation.
    """

    _hn_type = LayerType.CONV

    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a16_w8,
        PrecisionMode.a16_w4,
        PrecisionMode.a16_w8_a16,
        PrecisionMode.a16_w4_a16,
        PrecisionMode.a16_w8_a8,
        PrecisionMode.a16_w4_a8,
    }
    SUPPORTED_BIAS_MODE = {
        BiasMode.double_scale_initialization,
        BiasMode.double_scale_decomposition,
        BiasMode.single_scale_decomposition,
    }
    SUPPORTED_QUANTIZATION_GROUPS = False

    def __init__(
        self,
        name: str,
        filters,
        kernel_size,
        strides=(1, 1),
        padding: Union[str, PaddingType] = "SAME",
        stride_align: Union[str, StrideAlignType] = "NW",
        dilation_rate=(1, 1),
        groups=1,
        activation: Union[str, callable, ActivationType] = "linear",
        transpose_output_width_features=False,
        logger=None,
        set_scale_by_kernel_only=False,
        **kwargs,
    ):
        """
        Build the atomic ops of the layer.

        Args:
            name: layer name; used as a prefix for every atomic op name.
            filters, kernel_size, strides, padding, stride_align, dilation_rate, groups: forwarded unchanged to
                both the high and the low ``ConvStrippedOp``.
            activation: activation passed to the ``ActivationOp``.
            transpose_output_width_features: stored on the layer.
            logger: logger forwarded to every op and to the base layer.
            set_scale_by_kernel_only: stored on the layer.
            **kwargs: forwarded to ``BaseHailoLayer.__init__``.
        """
        self.transpose_output_width_features = transpose_output_width_features
        self.set_scale_by_kernel_only = set_scale_by_kernel_only
        self.input_op = PassthruOp(f"{name}/input_passthough_op", logger=logger)
        self.split_precision_low = SplitPrecisionLow(f"{name}/split_precision_low_op", logger=logger)
        self.split_precision_high = SplitPrecisionHigh(f"{name}/split_precision_high_op", logger=logger)
        self.conv_h_op = ConvStrippedOp(
            f"{name}/conv_h_op",
            kernel_size=kernel_size,
            is_depthwise=False,
            filters=filters,
            groups=groups,
            strides=strides,
            dilation_rate=dilation_rate,
            padding=padding,
            stride_align=stride_align,
            trainable=False,
            logger=logger,
        )
        self.conv_l_op = ConvStrippedOp(
            f"{name}/conv_l_op",
            kernel_size=kernel_size,
            is_depthwise=False,
            filters=filters,
            groups=groups,
            strides=strides,
            dilation_rate=dilation_rate,
            padding=padding,
            stride_align=stride_align,
            trainable=False,
            logger=logger,
        )
        self.bias_add_h_op = AddBiasOp(f"{name}/bias_add_h_op", trainable=False, logger=logger)
        self.bias_add_l_op = AddBiasOp(f"{name}/bias_add_l_op", trainable=False, logger=logger)

        self.concat_op = ConcatOp(f"{name}/concat_op", concat_elements=2, logger=logger)

        self.precision_add_op = PrecisionAddOp(f"{name}/precision_add_op", num_decompositions=2, logger=logger)
        self.act_op = ActivationOp(f"{name}/act_op", activation, logger=logger)
        self.output_op = PassthruOp(f"{name}/passthru_op", logger=logger)
        # Order matters: index 0 is the low path, index 1 the high path (also the concat input order).
        self.all_convs = [self.conv_l_op, self.conv_h_op]
        self.all_biases = [self.bias_add_l_op, self.bias_add_h_op]
        super().__init__(name=name, logger=logger, **kwargs)
        self.kernel_scale_forced_to_save = False
        for op in self.atomic_ops:
            op.fully_native = True
        self.input_spec = tf.keras.layers.InputSpec(ndim=4)
        self._hw_shifts = HW_SHIFTS_PLUTO

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        Create the layer from an HN element dict.

        Missing params fall back to ``get_default_params``.

        Raises:
            ValueError: if ``elementwise_add`` is truthy.
            AccelerasImplementationError: if ``transpose_output_width_features`` is requested.
        """
        params = cls.get_default_params()
        params.update(hn_element.get("params", dict()))
        kshape = params["kernel_shape"]
        padding, stride_align = get_hn_padding(params)

        cls._validate_elwa(params["elementwise_add"])

        if params.get("transpose_output_width_features", False):
            raise AccelerasImplementationError("transpose_output_width_features is not supported in acceleras yet")
        layer = cls(
            name=lname,
            filters=kshape[-1],
            kernel_size=kshape[0:2],
            padding=padding,
            stride_align=stride_align,
            strides=params["strides"][1:3],
            groups=params["groups"],
            activation=params["activation"],
            dilation_rate=params["dilations"][1:3],
            logger=logger,
        )
        layer.finalize_from_hn(hn_element)
        return layer

    def to_hn(self, out_degree=None):
        """
        Write the layer's conv params back into the stored HN element and serialize it.

        The params are taken from the high conv op; both conv ops are built with identical geometry.
        """
        params = self._hn_element.get("params", dict())
        weights = self.export_weights()
        params["kernel_shape"] = list(weights["kernel"].shape)
        strides = self.conv_h_op.strides
        params["strides"] = [1, strides[0], strides[1], 1]
        params["groups"] = self.conv_h_op.groups
        params["activation"] = self.act_op.act_name.value
        dilation_rate = self.conv_h_op.dilation_rate
        params["dilations"] = [1, dilation_rate[0], dilation_rate[1], 1]
        set_hn_padding_stride_align(params, self.conv_h_op.padding, self.conv_h_op.stride_align)
        params["elementwise_add"] = False
        params["decompose_weights"] = True
        self._hn_element["params"] = params
        return super().to_hn(out_degree=out_degree)

    def _get_kernel_bits(self):
        """Return the bit width of the (high) conv kernel lossy element."""
        return self.conv_h_op.weight_lossy_elements.kernel.bits

    def _build_flow(self) -> LayerFlow:
        """
        Build the op graph.

        input -> split low / split high (high also consumes the low split's output)
        -> conv_l / conv_h -> bias_l / bias_h -> concat -> precision add -> activation -> output
        """
        layer_flow = self._init_flow()
        in1 = layer_flow.add_input()
        out1 = layer_flow.add_output()
        layer_flow.add_edge(in1, self.input_op, DataPath.LAYER_IN)
        layer_flow.add_edge(self.input_op, self.split_precision_low, DataPath.LAYER_IN)
        layer_flow.add_edge(self.input_op, self.split_precision_high, DataPath.LAYER_IN, input_index=0)
        layer_flow.add_edge(self.split_precision_low, self.split_precision_high, DataPath.LAYER_IN, input_index=1)
        layer_flow.add_edge(self.split_precision_low, self.conv_l_op, DataPath.LAYER_IN)
        layer_flow.add_edge(self.split_precision_high, self.conv_h_op, DataPath.LAYER_IN)
        for conv, bias in zip(self.all_convs, self.all_biases):
            layer_flow.add_edge(conv, bias, DataPath.ACCUMULATOR)
        for index, bias in enumerate(self.all_biases):
            layer_flow.add_edge(bias, self.concat_op, DataPath.ACCUMULATOR, input_index=index)

        layer_flow.add_edge(self.concat_op, self.precision_add_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.precision_add_op, self.act_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.act_op, self.output_op, DataPath.LAYER_OUT)
        layer_flow.add_edge(self.output_op, out1, DataPath.LAYER_OUT)
        return layer_flow

    def import_weights(self, layer_params: LayerParams):
        """
        Load native weights.

        Both conv ops receive the same kernel. The real bias goes to the high path only; the low-path bias is
        zeroed so the bias is not added twice when the two paths are summed.
        """
        self._forced_kernel_high_scale = layer_params.get("forced_kernel_high_scale", None)
        self.conv_h_op.import_weights(layer_params["kernel"])
        self.conv_l_op.import_weights(layer_params["kernel"])

        self.bias_add_h_op.import_weights(layer_params["bias"])
        self.bias_add_l_op.import_weights(np.zeros_like(layer_params["bias"]))
        self.act_op.import_weights(layer_params)
        self.precision_add_op.output_channels = layer_params["kernel"].shape[-1]

    def _export_weights(self):
        """Export native weights (kernel/padding from the high conv, bias from the high bias, plus activation)."""
        conv_weights = self.conv_h_op.export_weights()
        dict_params = {
            "kernel": conv_weights["kernel"],
            "bias": self.bias_add_h_op.export_weights(),
            "padding_const_value": conv_weights["padding_const_value"],
        }
        dict_params.update(self.act_op.export_weights())
        return dict_params

    def start_stats_collection(self, stats_cfg: tuple = BasicTypeTuple, output_hist=False, preact_hist=False):
        """Enable stats collection on the layer and on the internal split/conv/bias/add/activation ops."""
        super().start_stats_collection(stats_cfg, output_hist, preact_hist)
        self.split_precision_low.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=True, collect_output=True)
        self.split_precision_high.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=True, collect_output=True)
        self.precision_add_op.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=True, collect_output=True)
        for conv in self.all_convs:
            conv.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=True, collect_output=True)
        for bias in self.all_biases:
            bias.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=False, collect_output=True)

        self.act_op.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=True, collect_output=True)

    def create_splits(self):
        """
        Compute the split-precision encodings.

        The low split's encoding is derived first (non-trivial split, fixed low zero point).
        The high split then takes the low split's input scale and output scale as its two input scales.
        """
        self.split_precision_low.trivial_split = False
        self.split_precision_low.create_input_encoding_candidates(0, split_precision_zp=ZP_LOW_SPLIT_PRECISION_PIXEL)
        self.split_precision_low.enforce_encoding()
        self.split_precision_high.input_scales[0] = self.split_precision_low.input_scale
        self.split_precision_high.input_scales[1] = self.split_precision_low.output_scale
        self.split_precision_high.enforce_encoding()

    @classmethod
    def from_conv(cls, conv_layer: HailoConv, logger=None):
        """
        Build this layer from an existing ``HailoConv``, copying its geometry, activation and weights.
        """
        conv_op = conv_layer.conv_op
        inst = cls(
            name=conv_layer.name,  # TODO: check this
            filters=conv_op.kernel.shape[-1],
            kernel_size=conv_op.kernel.shape[0:2],
            strides=conv_op.strides,
            padding=conv_op.padding,
            stride_align=conv_op.stride_align,
            dilation_rate=conv_op.dilation_rate,
            # Without this a grouped source conv would become an ungrouped one while its grouped kernel is imported.
            groups=conv_op.groups,
            activation=conv_layer.act_op.act_name,
            logger=logger,
        )
        inst.import_weights(conv_layer.export_weights())
        return inst

    def get_numeric_kernel_np(self):
        """
        Not supported for this layer.

        The previous implementation read ``self.avgpool_op``, which this layer never defines. Every call therefore
        failed with an opaque AttributeError; this raises an explicit error instead.

        Raises:
            AccelerasImplementationError: always.
        """
        raise AccelerasImplementationError(f"get_numeric_kernel_np is not supported by {type(self).__name__}")

    def _export_ops_hw_params(self):
        """
        Export HW params.

        Conv params come from the high conv. Bias / bias_q are concatenated low then high on the last axis.
        ``pre_act_sum_shift`` is log2 of the precision-add kernel, cast to uint8.
        """
        params = super()._export_ops_hw_params()
        params.update(self.conv_h_op.export_hw_params())
        for key in ["bias", "bias_q"]:
            bias = np.concatenate([op.export_hw_params()[key] for op in self.all_biases], axis=-1)
            params[key] = bias
        # NOTE(review): assumes the precision-add kernel is a power of two; otherwise the uint8 cast truncates.
        params["pre_act_sum_shift"] = np.log2(self.precision_add_op.export_hw_params()["kernel"]).astype(np.uint8)
        return params

    @property
    def activation_atomic_op(self):
        """The activation op of this layer."""
        return self.act_op

    @classmethod
    def get_default_params(cls):
        """Return a fresh dict of HN param defaults used by ``from_hn``."""
        # TODO: this is temporary solution until we have pydantic scheme
        defaults = {
            "strides": [1, 1, 1, 1],
            "dilations": [1, 1, 1, 1],
            "padding": "SAME",
            "activation": "linear",
            "groups": 1,
            "elementwise_add": False,
        }
        return dict(defaults)

    def create_quant_element_custom_behavior(
        self,
        precision_config: LayerPrecisionConfig,
        optimization_target: OptimizationTarget,
    ):
        """
        Attach lossy elements to the internal ops.

        The layer input is 16-bit and both split outputs are 8-bit. The conv inputs are 8-bit. Kernel bits and
        sign come from the precision mode, and the bias decomposition count comes from the bias mode.
        The accumulator after precision add is 32-bit.
        """
        bias_mode = precision_config.bias_mode
        self.split_precision_low.set_input_lossy_element(APUOutputQuantElement(bits=16), index=0)
        self.split_precision_low.set_output_lossy_element(APUOutputQuantElement(bits=8))
        self.split_precision_high.set_output_lossy_element(APUOutputQuantElement(bits=8))
        self.split_precision_high.set_input_lossy_element(APUOutputQuantElement(bits=8), index=1)
        self.split_precision_high.set_input_lossy_element(APUOutputQuantElement(bits=16), index=0)
        kernel_bits, signed = get_kernel_bits_and_sign_by_precision_mode(precision_config.precision_mode)
        num_decomposition = get_decomposition_count_by_bias_mode(bias_mode)

        for op in self.all_convs:
            op.set_input_lossy_element(APUOutputQuantElement(bits=8))
            op.create_weight_quant_element(kernel_bits=kernel_bits, signed=signed)
        for op in self.all_biases:
            op.create_weight_quant_element(kernel_bits=kernel_bits, signed=signed, num_decomposition=num_decomposition)
        self.precision_add_op.set_output_lossy_element(AccumulatorQuantElement(bits=32))
        self.act_op.set_input_lossy_element(AccumulatorQuantElement(bits=32))
        self.act_op.create_weight_quant_element(optimization_target)

    @staticmethod
    def _max_abs_accumulator(bias_op):
        """Per-channel max(|min|, |max|) of ``bias_op``'s collected output stats, as float32."""
        stats = bias_op.get_output_stats(0)
        return np.maximum(np.abs(stats.min), np.abs(stats.max), dtype=np.float32)

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """
        Compute HW params for both conv paths, the precision adder, the activation and the biases.

        Each conv is sized by the max absolute accumulator value observed after its own bias op.
        """
        self._enforce_output_encoding()

        kernel_scale_matrix_component = self.get_kernel_scale_matrix_component()
        # Both maxima are read before any conv HW params are created (same order as before the refactor).
        max_final_accumulator_by_channel_h = self._max_abs_accumulator(self.bias_add_h_op)
        max_final_accumulator_by_channel_l = self._max_abs_accumulator(self.bias_add_l_op)

        for conv, max_accumulator in (
            (self.conv_h_op, max_final_accumulator_by_channel_h),
            (self.conv_l_op, max_final_accumulator_by_channel_l),
        ):
            conv.create_hw_params(
                max_accumulator,
                weights_clipping,
                optimization_target,
                kernel_scale_matrix_component=kernel_scale_matrix_component,
                hw_shifts=hw_shifts,
            )
        for conv in self.all_convs:
            conv.output_scale = conv.accumulator_scale_candidate
            self._propagate_encoding_forward(conv, enforce_encoding=False)
        for op in self.all_biases + [self.concat_op]:
            self._propagate_encoding_forward(op, enforce_encoding=True)
        self.bias_add_h_op.pre_acc_shift = self.conv_h_op.pre_acc_shift
        self.bias_add_l_op.pre_acc_shift = self.conv_l_op.pre_acc_shift
        self.precision_add_op.create_hw_params()
        self._propagate_encoding_forward(self.precision_add_op, enforce_encoding=True)
        # Don't nudge if the kernel scale is forced to be saved or is set by the kernel only.
        nudging = not self.kernel_scale_forced_to_save and not self.conv_h_op.set_scale_by_kernel_only
        self.act_op.create_hw_params(self.act_op.input_scale, optimization_target, nudging=nudging)
        self.enforce_internal_encoding()
        for bias in self.all_biases:
            bias.create_hw_params()

    def get_kernel_scale_matrix_component(self):
        """
        Compute the high conv's kernel scale for the current input scales and the activation output scale.

        The activation output scale is divided by 2**``_negative_slope_exponent_fix_shift``.
        """
        self._enforce_input_encoding()
        self._enforce_output_encoding()
        output_scale = self.act_op.output_scale / 2**self._negative_slope_exponent_fix_shift
        return self.conv_h_op.calc_kernel_scale(self.conv_h_op.input_scales, output_scale, 0)

    def _enforce_split_precision_encoding(self):
        """
        Push the input encoding through both split ops and into the matching conv inputs.

        The high split's second input takes the low split's output encoding.
        """
        self.input_op.enforce_encoding()
        self.split_precision_low.input_scales[0] = self.input_op.output_scales[0]
        self.split_precision_high.input_scales[0] = self.input_op.output_scales[0]
        self.split_precision_low.input_zero_points[0] = self.input_op.output_zero_points[0]
        self.split_precision_high.input_zero_points[0] = self.input_op.output_zero_points[0]
        self.split_precision_low.enforce_encoding()
        self.split_precision_high.enforce_encoding()
        self.split_precision_high.input_scales[1] = self.split_precision_low.output_scales[0]
        self.split_precision_high.input_zero_points[1] = self.split_precision_low.output_zero_points[0]
        self.conv_h_op.input_scale = self.split_precision_high.output_scale
        self.conv_h_op.input_zero_point = self.split_precision_high.output_zero_point
        self.conv_l_op.input_scale = self.split_precision_low.output_scale
        self.conv_l_op.input_zero_point = self.split_precision_low.output_zero_point

    def _enforce_output_encoding(self):
        """Propagate the output op's encoding backward."""
        self.output_op.backward_encoding()
        self._propagate_encoding_backward(self.output_op)

    def _enforce_input_encoding(self):
        """Propagate encodings forward from both split-precision ops."""
        self._propagate_encoding_forward(self.split_precision_low, enforce_encoding=True)
        self._propagate_encoding_forward(self.split_precision_high, enforce_encoding=True)

    def _accumulator_scale_from_apu(self):
        """
        Note - Accumulator scale is fully defined by output and APU params,
        we resolve it in Activation class and use for all earlier op scales.
        """
        self.act_op.get_accumulator_scale()
        self.act_op.enforce_encoding()
        self._propogate_apu_into_ops()

    def _propogate_apu_into_ops(self):
        """Propagate the activation's accumulator encoding backward into precision add, concat and the biases."""
        self._propagate_encoding_backward(self.act_op)
        for op in [self.precision_add_op, self.concat_op] + self.all_biases:
            self._propagate_encoding_backward(op, enforce_encoding=True)

    def enforce_internal_encoding(self, training=False, **kwargs):
        """
        Re-derive all internal encodings.

        Steps, in order:
        1. Propagate the output encoding backward.
        2. Propagate the split-op encodings forward.
        3. Resolve the accumulator scale from the activation (APU) and push it backward through
           precision add, concat and the biases.
        4. Push the split-precision encodings into the conv inputs.
        5. Propagate the conv encodings forward.
        """
        self._enforce_output_encoding()
        self._enforce_input_encoding()
        self._accumulator_scale_from_apu()
        self._enforce_split_precision_encoding()
        for conv in self.all_convs:
            self._propagate_encoding_forward(conv, enforce_encoding=True)

    def fast_enforce_internal_encoding(self, **kwargs):
        """No-op for this layer."""
        pass

    def enforce_io_encoding(self, training=False, **kwargs):
        """No-op for this layer."""
        pass

    @classmethod
    def _validate_elwa(cls, elwa_value):
        """
        Reject elementwise-add configurations.

        Raises:
            ValueError: if ``elwa_value`` is truthy.
        """
        if elwa_value:
            raise ValueError(
                f"elementwise_add value was {elwa_value}, "
                f"but expected {not elwa_value} in {cls.__name__} initialization",
            )

    def get_equalization_handler_type(self, predecessor_index=None):
        """Equalization is not supported for this layer (unsupported, non-source)."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    @property
    def consumer_input_scale(self):
        """Always False for this layer."""
        return False

    @property
    def homogeneous(self):
        """Always False for this layer."""
        return False

    def _is_precision_config_supported(self, precision_mode, bias_mode, arch):
        """Accept every precision config (no extra restriction at this level)."""
        return True
