from typing import Union

import numpy as np

from hailo_model_optimization.acceleras.atomic_ops.activation_op import ActivationOp
from hailo_model_optimization.acceleras.atomic_ops.bias_add_op import AddBiasOp
from hailo_model_optimization.acceleras.atomic_ops.element_wise_add_op import ElementwiseAddDirectOp
from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerWeightsClippingConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    BiasMode,
    DataPath,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    OptimizationTarget,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasImplementationError,
    AccelerasNumerizationError,
    InvalidInputShape,
)
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams
from hailo_model_optimization.acceleras.utils.opt_utils import (
    get_decomposition_count_by_bias_mode,
    get_kernel_bits_and_sign_by_precision_mode,
)


class HailoElementwiseAdd(BaseHailoLayer):
    """
    Layer that adds two L3 inputs elementwise, followed by a bias add and a
    (normally trivialized) activation that produce the L3 output.

    Internal flow: ``ew_add_op`` -> ``bias_op`` -> ``act_op`` -> ``output_op``.
    ``input_repeats`` holds, for each of the two inputs, repeat factors applied to
    its non-batch dimensions so that differently-sized inputs can be matched.

    Example (assumes TensorFlow is imported as ``tf``):
        >>> ew_add_layer = HailoElementwiseAdd(name="ew_add")
        >>> rand_data1 = tf.random.normal([200, 200], 0, 1, tf.float32)
        >>> rand_data2 = tf.random.normal([200, 200], 0, 1, tf.float32)
        >>> ew_add_result = ew_add_layer(rand_data1, rand_data2)
    """

    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a8_w8,
        PrecisionMode.a8_w4,
        PrecisionMode.a16_w16,
        PrecisionMode.a8_w4_a8,
        PrecisionMode.a8_w4_a16,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a8_w8_a16,
        PrecisionMode.a16_w16_a16,
    }
    SUPPORTED_BIAS_MODE = {
        BiasMode.double_scale_initialization,
        BiasMode.single_scale_decomposition,
        BiasMode.double_scale_decomposition,
    }
    SUPPORTED_QUANTIZATION_GROUPS = False
    _hn_type = LayerType.ELEMENTWISE_ADD

    # Targets on which the hardware supports only a subset of the modes above
    # (see _get_precision_mode_supported_in_hw / _get_bias_mode_supported_in_hw).
    _HW_RESTRICTED_TARGETS = frozenset(
        {OptimizationTarget.MERCURY, OptimizationTarget.SAGE, OptimizationTarget.PLUTO}
    )

    def __init__(
        self,
        name: str,
        activation: Union[str, callable, ActivationType] = "linear",
        input_repeats=None,
        logger=None,
        **kwargs,
    ):
        """
        Args:
            name: layer name; also used as the prefix of every internal op name.
            activation: activation applied after the addition (``"linear"`` by default).
            input_repeats: per-input repeat factors. A falsy value (e.g. ``None``) is stored
                as ``[[1, 1, 1], [1, 1, 1]]``, i.e. no repetition.
            logger: optional logger forwarded to the internal ops and the base layer.
            **kwargs: forwarded to ``BaseHailoLayer.__init__``.
        """
        # The ops are created before super().__init__ - presumably because the base
        # constructor builds the layer flow from them (see _build_flow).
        # NOTE(review): the op receives the raw ``input_repeats`` (possibly None) while the
        # layer stores the defaulted value below - confirm the op applies the same default.
        self.ew_add_op = ElementwiseAddDirectOp(f"{name}/elementwise_add_op", input_repeats, logger=logger)

        self.bias_op = AddBiasOp.get_passthru_bias(f"{name}/bias_add_op", logger=logger)
        self.act_op = ActivationOp(f"{name}/act_op", activation=activation, logger=logger)
        self.output_op = PassthruOp(f"{name}/passthru_op", logger=logger)
        self.input_repeats = input_repeats if input_repeats else [[1, 1, 1], [1, 1, 1]]
        super().__init__(name=name, logger=logger, **kwargs)
        # Output/input scale ratio; only meaningful for transparent layers
        # (see _create_out_in_scale_ratio).
        self.output_scale_scalar_dof = 1
        self.transparent = False

        self.encoding_const = False

    @property
    def pre_acc_shift(self):
        """Pre-accumulation shift, delegated to the elementwise-add op."""
        return self.ew_add_op.pre_acc_shift

    @property
    def groups(self):
        """Elementwise add always has a single group."""
        return 1

    @property
    def consumer_input_scale(self):
        """Always True for this layer (semantics defined by the base layer)."""
        return True

    def _build_flow(self) -> LayerFlow:
        """
        Build the op graph: in1/in2 -> ew_add -> bias -> act -> passthru -> out.

        The two layer inputs feed ``ew_add_op`` at input indices 0 and 1; the add and bias
        ops communicate over the accumulator data path.
        """
        layer_flow = LayerFlow()
        in1 = layer_flow.add_input()
        in2 = layer_flow.add_input()
        out1 = layer_flow.add_output()
        layer_flow.add_node(self.ew_add_op)
        layer_flow.add_node(self.bias_op)
        layer_flow.add_node(self.act_op)
        layer_flow.add_node(self.output_op)

        layer_flow.add_edge(in1, self.ew_add_op, DataPath.LAYER_IN, input_index=0)
        layer_flow.add_edge(in2, self.ew_add_op, DataPath.LAYER_IN, input_index=1)
        layer_flow.add_edge(self.ew_add_op, self.bias_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.bias_op, self.act_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.act_op, self.output_op, DataPath.LAYER_OUT)
        layer_flow.add_edge(self.output_op, out1, DataPath.LAYER_OUT)
        return layer_flow

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        Create a HailoElementwiseAdd from its HN description.

        Elementwise add has no parameters to read here beyond the optional ``params``
        entries ``activation`` and ``input_repeats``; the layer is then finalized with
        ``finalize_from_hn``.

        Args:
            lname: name of the new layer.
            hn_element: mapping describing the layer (its entry in a parsed HN). A missing
                ``params``, ``activation`` or ``input_repeats`` falls back to the constructor
                defaults (``"linear"`` and ``[[1, 1, 1], [1, 1, 1]]``).
            logger: optional logger.

        Returns:
            The constructed layer.
        """
        params = hn_element.get("params", dict())
        # Previously ``params["activation"]`` raised KeyError whenever "params" (or its
        # "activation" entry) was absent; fall back to the constructor's default instead.
        activation = params.get("activation", "linear")
        input_repeats = params.get("input_repeats", [[1, 1, 1], [1, 1, 1]])

        layer = cls(
            name=lname,
            activation=activation,
            input_repeats=input_repeats,
            logger=logger,
        )
        layer.finalize_from_hn(hn_element)
        return layer

    def to_hn(self, out_degree=None):
        """Serialize to HN, adding ``activation`` and ``input_repeats`` under ``params``."""
        hn_element = super().to_hn(out_degree)
        hn_element.setdefault("params", dict()).update(
            {
                "activation": self.act_op.act_name.value,
                "input_repeats": self.input_repeats,
            }
        )
        return hn_element

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """
        Create hardware params for all internal ops.

        Order: resolve the output encoding; let the ew-add op create its hw params against
        the activation's candidate input scale (with an optional kernel cap and
        ``hw_shifts``); create the activation's hw params from the ew-add output scale;
        re-enforce internal encodings; finally create the bias op's params using the ew-add
        op's pre-accumulation shift.

        ``weights_clipping`` is accepted for interface compatibility but not used here.
        """
        self._enforce_output_encoding()
        accumulator_scale_candidate = self.act_op.get_input_scale_candidate()
        target_max_kernel = self._get_target_max_kernel_base_on_activation(self.act_op.act_name)
        self.ew_add_op.create_hw_params(
            output_scale_candidate=accumulator_scale_candidate, target_max_kernel=target_max_kernel, hw_shifts=hw_shifts
        )
        self.ew_add_op.enforce_encoding()
        self.act_op.create_hw_params(self.ew_add_op.output_scale, optimization_target, nudging=False)
        self.enforce_internal_encoding()
        self.bias_op.pre_acc_shift = self.ew_add_op.pre_acc_shift
        self.bias_op.create_hw_params()

    def _get_target_max_kernel_base_on_activation(self, activation_name):
        """
        Return the kernel cap for the ew-add op: ``2**14`` for a16_w16_a16 input precision
        with a SHIFT activation, otherwise ``None`` (the op then uses its own default).
        """
        if (
            self.get_precision_mode().input_precision_mode() == PrecisionMode.a16_w16_a16
            and activation_name == ActivationType.SHIFT
        ):
            return 2**14
        return None  # Default value determined by the op

    def _enforce_output_encoding(self):
        """Back-propagate the output op's input encoding into the activation op's output."""
        self.output_op.backward_encoding()
        self.act_op.output_scale = self.output_op.input_scales[0]
        self.act_op.output_zero_point = self.output_op.input_zero_points[0]

    def _set_accumulator_scale_into_ops(self, acc_scale):
        """Store ``acc_scale`` on the layer and as the ew-add output / bias input+output scale."""
        self.acc_scale = acc_scale
        self.ew_add_op.output_scale = acc_scale
        self.bias_op.input_scales[0] = acc_scale
        self.bias_op.output_scale = acc_scale

    def _accumulator_scale_from_apu(self):
        """
        The accumulator scale is fully defined by the output encoding and the APU
        (activation op) params: resolve it in the activation op and propagate it to all
        earlier ops.
        """
        # The return value is unused; presumably get_accumulator_scale() updates
        # act_op.input_scales as a side effect - confirm in ActivationOp.
        self.act_op.get_accumulator_scale()
        self._set_accumulator_scale_into_ops(self.act_op.input_scales[0])

    def enforce_internal_encoding(self, training=False, **kwargs):
        """
        Propagate encodings through the layer's ops.

        The accumulator scale is first derived from the output encoding and the activation
        op. The ew-add op then derives its kernel from its two input scales and its output
        scale. Finally the bias and activation encodings are enforced on top of that.
        """
        self._enforce_output_encoding()
        self._accumulator_scale_from_apu()
        self.ew_add_op.enforce_encoding(training=training)
        self.bias_op.input_zero_points = [self.ew_add_op.output_zero_point]
        self.bias_op.enforce_encoding()
        self.act_op.input_scales = [self.bias_op.output_scale]
        self.act_op.enforce_encoding(training=training)

    def fast_enforce_internal_encoding(self, **kwargs):
        """Intentionally a no-op for this layer."""
        pass

    def _repeated_input_scale(self):
        """Scale of input 0 repeated along its last axis by input 0's last repeat factor."""
        return np.repeat(self.input_scales[0], self.input_repeats[0][-1], axis=-1)

    def _create_out_in_scale_ratio(self):
        """
        Set ``output_scale_scalar_dof`` from the output/input scale ratio (transparent only).

        The ratio between the output scale and the (repeated) scale of input 0 must be the
        same for every channel, up to a relative tolerance of 1e-6.

        Raises:
            AccelerasNumerizationError: if the ratio is not a scalar.
        """
        if not self.transparent:
            return
        ratio = self.output_scale / self._repeated_input_scale()
        eps = 1e-6
        if ratio.shape == ():
            self.output_scale_scalar_dof = ratio
            return
        if eps < np.max(np.abs(ratio - ratio[0]) / ratio[0]):
            # Possible fail case: coming from concat, so input scale is scalar while output is vector.
            raise AccelerasNumerizationError(
                f"output_scale - input_scale ratio of {self.full_name} should be a scalar"
            )
        # Kept as an attribute so a scales-training context can use it.
        self.output_scale_scalar_dof = ratio[0]

    def enforce_io_encoding(self, training=False, **kwargs):
        """For transparent layers, derive the output scale from input 0 times the scalar DOF."""
        if self.transparent:
            self.output_op.output_scale = self._repeated_input_scale() * self.output_scale_scalar_dof

    def update_scale_scalar_dof(self, shift):
        """Multiply the output scale DOF by ``2**shift`` (transparent layers only)."""
        output_factor = 2**shift
        if self.transparent:
            self.output_scale_scalar_dof *= output_factor

    def export_hn(self):
        """Not supported: always raises AccelerasImplementationError."""
        raise AccelerasImplementationError("acceleras export hn is not supported yet")

    def import_weights(self, layer_params: LayerParams):
        """
        Load the optional ``kernel`` (into the ew-add op) and ``bias`` (into the bias op),
        then let the activation op import whatever it needs from ``layer_params``.
        """
        kernel = layer_params.get("kernel", None)
        if kernel is not None:
            self.ew_add_op.import_weights(kernel)
        bias = layer_params.get("bias", None)
        if bias is not None:
            self.import_native_bias(bias)
        self.act_op.import_weights(layer_params)

    @property
    def is_changing_bias_supported(self):
        """The bias of this layer may be changed."""
        return True

    def export_native_bias(self):
        """Return the bias op's exported weights."""
        return self.bias_op.export_weights()

    def import_native_bias(self, bias):
        """Load ``bias`` into the bias op."""
        self.bias_op.import_weights(bias)

    def _export_weights(self):
        """Collect ew-add weights, the bias (under ``"bias"``) and activation weights."""
        weights = dict()
        weights.update(self.ew_add_op.export_weights())
        weights["bias"] = self.export_native_bias()
        weights.update(self.act_op.export_weights())
        return weights

    def _export_layer_metadata(self):
        """Base metadata plus ``transparent`` (only written when True)."""
        export_vals = super()._export_layer_metadata()
        if self.transparent:
            export_vals["transparent"] = self.transparent
        return export_vals

    def _import_layer_metadata(self, npz):
        """Restore ``transparent`` (default False) before delegating to the base class."""
        self.transparent = npz.get("transparent", False)
        return super()._import_layer_metadata(npz)

    def create_quant_element_custom_behavior(
        self,
        precision_config: LayerPrecisionConfig,
        optimization_target: OptimizationTarget,
    ):
        """
        Create the weight quantization elements of the ew-add, bias and activation ops.

        Kernel bits/sign come from the precision mode; the bias decomposition count comes
        from the bias mode. The activation op also receives the quantization groups.
        """
        bias_mode = precision_config.bias_mode
        precision_mode = precision_config.precision_mode
        quant_groups = precision_config.quantization_groups

        kernel_bits, signed = get_kernel_bits_and_sign_by_precision_mode(precision_mode)
        num_decomposition = get_decomposition_count_by_bias_mode(bias_mode)

        self.ew_add_op.create_weight_quant_element(kernel_bits, signed)
        self.bias_op.create_weight_quant_element(kernel_bits, signed, num_decomposition)
        self.act_op.create_weight_quant_element(optimization_target)

        # set quantization groups
        self.act_op.set_quantization_groups(quant_groups)

    def get_equalization_handler_type(self, predecessor_index=None):
        """
        Transparent layers act as a transparent handler toward predecessor 0; every other
        case is a consumer. The layer is a source only when it is not transparent.
        """
        is_source = not self.transparent
        if predecessor_index == 0 and self.transparent:
            return EquivClassification(LayerHandlerType.transparent, is_source=is_source)
        return EquivClassification(LayerHandlerType.consumer, is_source=is_source)

    def get_quarot_handler_type(self, predecessor_index=None):
        """Always a non-source multi-source handler, regardless of predecessor."""
        return EquivClassification(LayerHandlerType.multi_source, is_source=False)

    def _get_precision_mode_supported_in_hw(self, arch):
        """Restricted precision modes on MERCURY/SAGE/PLUTO; otherwise the base-class set."""
        if arch in self._HW_RESTRICTED_TARGETS:
            return {
                PrecisionMode.a8_w8_a8,
                PrecisionMode.a8_w8_a16,
                PrecisionMode.a8_w8,
                PrecisionMode.a16_w16_a16,
                PrecisionMode.a16_w16,
            }
        return super()._get_precision_mode_supported_in_hw(arch)

    def _get_bias_mode_supported_in_hw(self, arch):
        """Only double-scale initialization on MERCURY/SAGE/PLUTO; otherwise the base-class set."""
        if arch in self._HW_RESTRICTED_TARGETS:
            return {BiasMode.double_scale_initialization}
        return super()._get_bias_mode_supported_in_hw(arch)

    def _is_precision_config_supported(self, precision_mode, bias_mode, arch):
        """Every combination is accepted."""
        return True  # currently no limitations, assuming bias_mode only supports double scale initialization

    def _get_kernel_bits(self):
        """Bit width of the ew-add op's kernel quantization factor."""
        return self.ew_add_op.weight_lossy_elements.factor.bits

    def get_kernel_scale_matrix_component(self):
        """Kernel scale derived from the ew-add op's input/output scales."""
        # TODO(review): the original note here was truncated ("if we want to do equalization
        # for"); confirm which equalization support is still missing.
        return self.ew_add_op.calc_kernel_scale_from_io()

    def get_kernel(self):
        """Kernel of the ew-add op."""
        return self.ew_add_op.kernel

    def define_encodings(self, flow):
        """No layer-specific encodings; delegate to the base class."""
        super().define_encodings(flow)

    def define_constraints(self, enc):
        """Tie the ew-add and bias ``mac_shift`` unless both ops have constant encodings."""
        super().define_constraints(enc)
        if not (self.ew_add_op.encoding_const and self.bias_op.encoding_const):
            enc.identity(f"{self.ew_add_op.full_name}/mac_shift:0", f"{self.bias_op.full_name}/mac_shift:0")

    def get_macs(self):
        """Rough MAC estimate from the first input shape: height * width * channels * 8."""
        _, height, width, channels_in = self.input_shapes[0]
        return height * width * channels_in * 8  # inefficiency factor

    @classmethod
    def get_default_bias_mode(cls):
        """Default bias mode: double-scale initialization."""
        return BiasMode.double_scale_initialization

    def verify_layer_inputs_shape(self, input_shapes):
        """
        Check that the two inputs agree once each is expanded by its ``input_repeats``.

        The first (batch) dimension is skipped. Each remaining dimension of input ``i`` is
        multiplied by the matching factor in ``self.input_repeats[i]``, and the resulting
        shapes must be equal. Expanded shapes of different length (e.g. inputs of different
        rank) are rejected rather than silently truncated by ``zip``. Nothing is checked
        unless exactly two shapes are given.

        Raises:
            InvalidInputShape: if the expanded shapes differ.
        """
        if len(input_shapes) != 2:
            return
        # factorizes the input shapes according to the input repeats
        factored_shapes = [
            [dim * ratio for dim, ratio in zip(input_shape[1:], input_repeat)]
            for input_shape, input_repeat in zip(input_shapes, self.input_repeats)
        ]
        first, second = factored_shapes
        same_length = len(first) == len(second)
        if not same_length or not all(x == y for x, y in zip(first, second)):
            raise InvalidInputShape(
                f"Input shapes {input_shapes} doesn't match each other in {self.full_name}", self.full_name
            )
