from typing import Union

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.conv_stripped_op import ConvStrippedOp
from hailo_model_optimization.acceleras.atomic_ops.element_wise_add_op import ElementwiseAddOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_conv import BaseHailoConv
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerWeightsClippingConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    DataPath,
    OptimizationTarget,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import InvalidInputShape
from hailo_model_optimization.acceleras.utils.opt_utils import get_kernel_bits_and_sign_by_precision_mode


class BaseHailoConvAdd(BaseHailoConv):
    """
    This class extends BaseHailoConv with elementwise-addition functionality,
     to arrive at a "Conv & Add" implementation:
    additional L3 input is multiplied by a constant then added to the accumulator
    (along with the usual conv and bias addends).

    See the datapath and quantization scheme overview through the link below:
    https://hailotech.atlassian.net/wiki/spaces/ML/pages/986185817/H8+Core+numerics+101+-+full+Conv+add+datapath

    This abstract class covers the (extra) Quantization and Emulation functionality,
     while the single concrete HailoConvAdd subclass fills in the methods
     left to be implemented by BaseConv (see comment there).
    """

    SUPPORTED_QUANTIZATION_GROUPS = False

    def __init__(
        self,
        name: str,
        conv_op: ConvStrippedOp,
        activation: Union[str, callable, ActivationType] = ActivationType.LINEAR,
        ew_add_factor=1,
        logger=None,
        **kwargs,
    ):
        """
        Create the elementwise-add op, then initialize the underlying conv layer.

        Args:
            name: layer name; also used as the prefix of the elementwise-add op name.
            conv_op: stripped convolution op, forwarded to the base class.
            activation: activation specification, forwarded to the base class.
            ew_add_factor: handed to ElementwiseAddOp as its second positional argument.
            logger: logger given to both the elementwise-add op and the base class.
            **kwargs: forwarded untouched to the base class initializer.
        """
        # The elementwise-add op is attached before the base initializer runs
        # (presumably because the base __init__ relies on it, e.g. when building
        # the layer flow -- TODO confirm against BaseHailoConv).
        elwa_op_name = f"{name}/elementwise_add_op"
        self.elwa_op = ElementwiseAddOp(elwa_op_name, ew_add_factor, logger=logger)

        super().__init__(name=name, conv_op=conv_op, activation=activation, logger=logger, **kwargs)

        # The layer takes two inputs, each constrained to rank 4.
        self.input_spec = [
            tf.keras.layers.InputSpec(ndim=4),
            tf.keras.layers.InputSpec(ndim=4),
        ]

    def _build_flow(self) -> LayerFlow:
        """
        Assemble the two-input, single-output flow of the conv & add layer.

        The op chain is conv -> elementwise add -> bias add -> activation -> output op.
        The first layer input feeds the conv; the second layer input feeds
        input index 1 of the elementwise-add op, while the conv result feeds index 0.
        """
        flow = LayerFlow()

        conv_input = flow.add_input()
        addend_input = flow.add_input()
        layer_output = flow.add_output()

        # Register the ops in the same order as before.
        for op in (self.conv_op, self.bias_add_op, self.elwa_op, self.act_op, self.output_op):
            flow.add_node(op)

        # Layer inputs.
        flow.add_edge(conv_input, self.conv_op, DataPath.LAYER_IN)
        flow.add_edge(addend_input, self.elwa_op, DataPath.LAYER_IN, input_index=1)

        # Accumulator data path.
        flow.add_edge(self.conv_op, self.elwa_op, DataPath.ACCUMULATOR, input_index=0)
        accumulator_chain = (
            (self.elwa_op, self.bias_add_op),
            (self.bias_add_op, self.act_op),
        )
        for source_op, target_op in accumulator_chain:
            flow.add_edge(source_op, target_op, DataPath.ACCUMULATOR)

        # Layer output path.
        flow.add_edge(self.act_op, self.output_op, DataPath.LAYER_OUT)
        flow.add_edge(self.output_op, layer_output, DataPath.LAYER_OUT)
        return flow

    def _export_bias_kernel_params(self):
        """
        Export the base-class bias/kernel params with the residue corrected for the elementwise add.

        The elementwise-add zero-point change, scaled by 2**pre_acc_shift, is
        subtracted from the exported "residue"; the corrected value is stored
        under both the "residue" and "residue_0" keys.
        """
        params = super()._export_bias_kernel_params()
        elwa_offset = self.elwa_op._zp_change.numpy() * 2**self.pre_acc_shift
        corrected_residue = params["residue"] - elwa_offset
        # Targets are assigned left to right: "residue" first, then "residue_0".
        params["residue"] = params["residue_0"] = corrected_residue
        return params

    def _accumulator_scale_from_apu(self):
        """
        Run the base-class accumulator-scale update, then align the elementwise-add op with it.

        Both the first input scale and the output scale of the elementwise-add op
        are set to the layer's accumulator scale.
        """
        super()._accumulator_scale_from_apu()
        elwa = self.elwa_op
        elwa.input_scales[0] = self.acc_scale
        elwa.output_scale = self.acc_scale

    def enforce_internal_encoding(self, training=False, **kwargs):
        """
        Following conv implementation (see docstring in BaseHailoConv.enforce_internal_encoding),
        injecting appropriate calculation for elementwise op.

        Zero points are chained op by op in the order
        conv -> elementwise add -> bias add -> activation: each op receives the
        output zero point of its predecessor and then has its own encoding enforced.

        Args:
            training: forwarded to the conv, elementwise-add and activation ops;
                the bias-add op is enforced without it.
            **kwargs: accepted but not used in this method.
        """
        # Output encoding and accumulator scale are settled before any op is enforced.
        self._enforce_output_encoding()
        self._accumulator_scale_from_apu()
        self.conv_op.enforce_encoding(training=training)

        # The conv result is input 0 of the elementwise add.
        self.elwa_op.input_zero_points[0] = self.conv_op.output_zero_point
        self.elwa_op.enforce_encoding(training=training)

        # Bias add consumes the elementwise-add output.
        self.bias_add_op.input_zero_points = [self.elwa_op.output_zero_point]
        self.bias_add_op.enforce_encoding()

        # Activation consumes the bias-add output.
        self.act_op.input_zero_points = [self.bias_add_op.output_zero_point]
        self.act_op.enforce_encoding(training=training)

    def fast_enforce_internal_encoding(self, training=False, **kwargs):
        """
        Refresh only the zero points along conv -> elementwise add -> bias add.

        Unlike enforce_internal_encoding, no op's enforce_encoding is called here
        and the activation op is not touched: the conv and elementwise-add output
        zero points are recomputed via compute_output_zp and handed to the next op.

        Args:
            training: forwarded to the conv op's compute_output_zp only.
            **kwargs: accepted but not used in this method.
        """
        self.conv_op.output_zero_point = self.conv_op.compute_output_zp(training=training)
        # The conv result is input 0 of the elementwise add.
        self.elwa_op.input_zero_points[0] = self.conv_op.output_zero_point
        self.elwa_op.output_zero_point = self.elwa_op.compute_output_zp()
        self.bias_add_op.input_zero_points = [self.elwa_op.output_zero_point]

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """
        Create the hardware params of the conv, then those of the elementwise-add op.

        After the base-class creation, the candidate elementwise-add input factor is
        computed as input_scales[1] * 2**pre_acc_shift / accumulator_scale_candidate.
        If its rounded maximum exceeds max_int * max_feed_repeat, the accumulator
        scale candidate is enlarged by that ratio, the activation hw params are
        re-created and the internal encoding is re-enforced. Finally the
        elementwise-add hw params are created and the layer creation is finalized.

        Args:
            weights_clipping: forwarded to the base-class create_hw_params.
            optimization_target: forwarded to the base class and to the activation op.
            hw_shifts: forwarded to the base-class create_hw_params.
        """
        super().create_hw_params(weights_clipping, optimization_target, hw_shifts=hw_shifts)
        # The elementwise-add op uses the same pre-accumulator shift as the layer.
        self.elwa_op.pre_acc_shift = self.pre_acc_shift
        bits = self.elwa_op.weight_lossy_elements.factor.bits
        # 2**(bits - 1) - 1, kept as a float.
        max_int = 2 ** (bits - 1) - 1.0

        input_factor_elwa_candidate_not_rounded = (
            self.input_scales[1] * 2**self.pre_acc_shift / self.conv_op.accumulator_scale_candidate
        )

        # A non-scalar candidate is tolerated, but a relative spread above eps
        # around its mean is reported as a warning.
        if input_factor_elwa_candidate_not_rounded.shape != ():
            eps = 1e-3  # TODO what is the correct scale?
            mean_input_factor = np.mean(input_factor_elwa_candidate_not_rounded)
            diff = np.max(np.abs(input_factor_elwa_candidate_not_rounded - mean_input_factor) / mean_input_factor)
            if eps < diff:
                self._logger.warning(f"input_factor_elwa_candidate should be scalar, there is a diff: {diff}")

        # Collapse to a single value: the largest rounded candidate.
        input_factor_elwa_candidate_rounded = np.max(np.round(input_factor_elwa_candidate_not_rounded))

        # Ratio between the needed factor and what max_int * max_feed_repeat allows.
        feedrepeats_fixfactor = input_factor_elwa_candidate_rounded / (max_int * self.elwa_op.max_feed_repeat)
        if feedrepeats_fixfactor > 1:
            # In order to meet the feed_repeat bound we do the following:
            # accumulator scale grows -> numeric accumulator values shrink ->
            # fewer repeats per same elwa numeric input (as needed).
            # As a side effect, the kernel scale grows, so it's effectively the same as the "shift_delta" effect.
            self.conv_op.accumulator_scale_candidate = self.conv_op.accumulator_scale_candidate * feedrepeats_fixfactor
            self.act_op.create_hw_params(self.conv_op.accumulator_scale_candidate, optimization_target)
            self.enforce_internal_encoding()

        self.elwa_op.create_hw_params()
        # NOTE: the pre-bias accumulator ZP changed by elwa-add, need to recalc "additive" stuff,
        #      so probably self.elwa_op.enforce_encoding(); so self.bias_add_op.enforce_encoding() could suffice
        #  but just to be on the safe side, let's re-run the whole enforce_encoding (+bias decompose..)
        self._create_hw_params_finalize()

    def create_quant_element_custom_behavior(
        self, precision_config: LayerPrecisionConfig, optimization_target: OptimizationTarget
    ):
        """
        Apply the base-class custom behavior, then create the elementwise-add factor quant element.

        The factor bit width and signedness are derived from the precision mode
        of ``precision_config``, with the bit width raised to at least 8.

        Args:
            precision_config: layer precision configuration; its precision_mode is used here.
            optimization_target: forwarded to the base class.
        """
        super().create_quant_element_custom_behavior(precision_config, optimization_target)

        kernel_bits, is_signed = get_kernel_bits_and_sign_by_precision_mode(precision_config.precision_mode)
        # Never go below 8 bits for the elementwise-add factor.
        elwa_factor_bits = max(kernel_bits, 8)
        self.elwa_op.create_weight_quant_element(elwa_factor_bits, is_signed)

    @classmethod
    def _validate_elwa(cls, elwa_value):
        if not elwa_value:
            raise ValueError(
                f"elementwise_add value was {elwa_value}, "
                f"but expected {not elwa_value} in {cls.__name__} initialization",
            )

    def define_constraints(self, enc):
        """
        Add the base-class constraints, then tie the conv and elementwise-add mac shifts.

        The identity constraint between the two ``mac_shift:0`` entries is skipped
        only when the elementwise-add op has a constant encoding and at least one
        of the conv / bias-add ops has one as well.

        Args:
            enc: constraint collector; its ``identity`` method is called here.
        """
        super().define_constraints(enc)

        conv_side_const = self.conv_op.encoding_const or self.bias_add_op.encoding_const
        if conv_side_const and self.elwa_op.encoding_const:
            return

        conv_mac_shift = f"{self.conv_op.full_name}/mac_shift:0"
        elwa_mac_shift = f"{self.elwa_op.full_name}/mac_shift:0"
        enc.identity(conv_mac_shift, elwa_mac_shift)

    def _get_inefficiency_factor(self):
        """
        Return the inefficiency factor of the layer, selected by its precision mode.

        Returns:
            1.0 for the ``a8_w4*`` modes, 8 for the ``a8_w8*`` and ``a16_w16*`` modes.

        Raises:
            ValueError: if the precision mode is none of the listed modes.
        """
        mode = self.get_precision_mode()

        if mode in [PrecisionMode.a8_w4, PrecisionMode.a8_w4_a16, PrecisionMode.a8_w4_a8]:
            return 1.0

        # Built only after the first check fails, as in the original flow.
        factor_eight_modes = [
            PrecisionMode.a8_w8,
            PrecisionMode.a8_w8_a16,
            PrecisionMode.a8_w8_a8,
            PrecisionMode.a16_w16,
            PrecisionMode.a16_w16_a16,
            PrecisionMode.a16_w16_a8,
            PrecisionMode.a16_w16_non_zero,
        ]
        if mode in factor_eight_modes:
            return 8

        raise ValueError(f"precision mode {mode} is not supported")

    def _supported_quantization_groups_hw(self, quantization_groups, arch):
        """
        Always return False, whatever ``quantization_groups`` and ``arch`` are.

        Both arguments are ignored by this implementation.
        """
        return False

    def verify_layer_inputs_shape(self, input_shapes):
        """
        Run the base-class shape checks, then require the second input to match the conv output shape.

        The conv output shape is computed from ``input_shapes[0]`` and compared
        element by element with ``input_shapes[1]``; the pairing stops at the
        shorter of the two shapes.

        Args:
            input_shapes: shapes of the layer inputs; indices 0 and 1 are used.

        Raises:
            InvalidInputShape: if any compared pair of dimensions differs.
        """
        super().verify_layer_inputs_shape(input_shapes)

        expected_shape = self.conv_op.compute_output_shape(input_shapes[0])
        addend_shape = input_shapes[1]
        has_mismatch = any(not (expected == actual) for expected, actual in zip(expected_shape, addend_shape))
        if has_mismatch:
            raise InvalidInputShape(
                f"Input shapes {input_shapes} doesn't match each other in {self.full_name}", self.full_name
            )
