from typing import Union

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.activation_op import ActivationOp
from hailo_model_optimization.acceleras.atomic_ops.bias_add_op import AddBiasOp
from hailo_model_optimization.acceleras.atomic_ops.element_wise_mult_op import ElementwiseMultOp
from hailo_model_optimization.acceleras.atomic_ops.mock_conv_op import MockConvOp
from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.encoding.encoding_data import EncodingType
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerWeightsClippingConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    BiasMode,
    DataPath,
    EquivClassification,
    LayerHandlerType,
    OptimizationTarget,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasNumerizationError,
    InvalidInputShape,
)
from hailo_model_optimization.acceleras.utils.export.layer_export_utils import add_sufix_to_keys
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams
from hailo_model_optimization.acceleras.utils.opt_utils import (
    get_accumulator_bits_by_precision_mode,
    get_decomposition_count_by_bias_mode,
    get_kernel_bits_and_sign_by_precision_mode,
    get_scalar_vector,
)


class HailoElementwiseMult(BaseHailoLayer):
    """
    Hailo element-wise multiplication (ew_mult) layer.

    Behaviour, as modelled by this class:
        - the layer takes two inputs,
        - the MAC behaves as a passthrough plus zero-point compensation
          (one mock conv op followed by one bias-add op per input),
        - the two inputs are multiplied in the APU,
        - the activation is applied in the APU as well.

    The output scale is tied to the product of the two (repeated) input
    scales through a single scalar factor, ``output_scale_scalar_dof``.
    """

    SUPPORTED_PRECISION_MODE = {PrecisionMode.a8_w8, PrecisionMode.a8_w8_a8}
    SUPPORTED_BIAS_MODE = {BiasMode.double_scale_initialization}
    SUPPORTED_QUANTIZATION_GROUPS = False

    def __init__(
        self,
        name: str,
        activation: Union[str, callable, ActivationType] = ActivationType.LINEAR,
        input_repeats=None,
        logger=None,
        **kwargs,
    ):
        """
        Create the atomic ops of the layer and initialise its bookkeeping state.

        Args:
            name: layer name; also used as the prefix of every atomic op name.
            activation: activation specification forwarded to ``ActivationOp``.
            input_repeats: per-input repeat factors. Forwarded unchanged to
                ``ElementwiseMultOp``; a falsy value makes the layer itself
                fall back to ``[[1, 1, 1], [1, 1, 1]]``.
            logger: logger handed to every atomic op and to the base class.
            **kwargs: forwarded to ``BaseHailoLayer.__init__``.

        """
        # The ops are assigned before the base constructor is invoked
        # (presumably because it builds the flow from them - TODO confirm).
        self.conv_op1 = MockConvOp(f"{name}/conv_op_a", logger=logger)
        self.conv_op2 = MockConvOp(f"{name}/conv_op_b", logger=logger)
        self.bias_add_op1 = AddBiasOp.get_passthru_bias(f"{name}/bias_add_op_a", logger=logger)
        self.bias_add_op2 = AddBiasOp.get_passthru_bias(f"{name}/bias_add_op_b", logger=logger)
        # NOTE(review): the op receives the raw `input_repeats` (possibly None),
        # while the layer attribute below gets a default - confirm the op copes with None.
        self.ew_mult_op = ElementwiseMultOp(f"{name}/elementwise_mult_op", input_repeats, logger=logger)
        self.act_op = ActivationOp(f"{name}/act_op", activation, logger=logger)
        self.output_op = PassthruOp(f"{name}/passthru_op", logger=logger)
        super().__init__(name=name, logger=logger, **kwargs)

        self.output_scale_scalar_dof = 1
        self.encoding_const = False
        # One mock kernel value per input branch.
        self._mock_kernel_values = [2, 2]
        self.input_repeats = input_repeats or [[1, 1, 1], [1, 1, 1]]
        # Optional overrides; None means "not forced".
        self._forced_output_scale_scalar_dof = None  # degree of freedom
        self._forced_shift_zp = None

    @property
    def pre_acc_shift(self):
        """Pre-accumulator shift, as reported by the first conv op."""
        return self.conv_op1.pre_acc_shift

    @property
    def forced_output_scale_scalar_dof(self):
        """Forced value for the output-scale scalar factor, or None when not forced."""
        return self._forced_output_scale_scalar_dof

    @forced_output_scale_scalar_dof.setter
    def forced_output_scale_scalar_dof(self, forced_output_scale_scalar_dof):
        self._forced_output_scale_scalar_dof = forced_output_scale_scalar_dof

    @property
    def forced_shift_zp(self):
        """Forced zero-point factor, or None when not forced."""
        return self._forced_shift_zp

    @forced_shift_zp.setter
    def forced_shift_zp(self, forced_shift_zp):
        self._forced_shift_zp = forced_shift_zp

    def _export_ops_hw_params(self) -> dict:
        """
        Collect the hw params of the atomic ops into a single dict.

        The conv kernel keys get an ``_a`` / ``_b`` suffix and the bias key an
        ``_in_a`` / ``_in_b`` suffix, so that the two input branches do not
        overwrite each other.
        """

        def scalarized_bias_params(bias_op: AddBiasOp):
            # Pass both bias entries through get_scalar_vector before exporting.
            exported = bias_op.export_hw_params()
            for key in ("bias", "bias_q"):
                exported[key] = get_scalar_vector(exported[key])
            return exported

        hw_params = {}
        hw_params.update(self.act_op.export_hw_params())

        kernel_keys = ("zp_kernel", "kernel")
        for conv_op, suffix in ((self.conv_op1, "_a"), (self.conv_op2, "_b")):
            hw_params.update(add_sufix_to_keys(conv_op.export_hw_params(), kernel_keys, suffix))

        for bias_op, suffix in ((self.bias_add_op1, "_in_a"), (self.bias_add_op2, "_in_b")):
            hw_params.update(add_sufix_to_keys(scalarized_bias_params(bias_op), ("bias",), suffix))

        hw_params.update(self.ew_mult_op.export_hw_params())
        return hw_params

    def _build_flow(self) -> LayerFlow:
        """
        Build the internal graph of the layer.

        Each input goes through its own conv op and bias-add op; both branches
        meet in the element-wise mult op, followed by the activation op and the
        output passthrough op.
        """
        flow = LayerFlow()
        first_input = flow.add_input()
        second_input = flow.add_input()
        layer_output = flow.add_output()

        ops = (
            self.conv_op1,
            self.conv_op2,
            self.bias_add_op1,
            self.bias_add_op2,
            self.ew_mult_op,
            self.act_op,
            self.output_op,
        )
        for op in ops:
            flow.add_node(op)

        # (source, destination, data path, extra keyword arguments) - registered in this order.
        edges = (
            (first_input, self.conv_op1, DataPath.LAYER_IN, {}),
            (second_input, self.conv_op2, DataPath.LAYER_IN, {}),
            (self.conv_op1, self.bias_add_op1, DataPath.ACCUMULATOR, {}),
            (self.conv_op2, self.bias_add_op2, DataPath.ACCUMULATOR, {}),
            (self.bias_add_op1, self.ew_mult_op, DataPath.DATA_MULT, {"input_index": 0}),
            (self.bias_add_op2, self.ew_mult_op, DataPath.DATA_MULT, {"input_index": 1}),
            (self.ew_mult_op, self.act_op, DataPath.POST_DATA_MULT, {}),
            (self.act_op, self.output_op, DataPath.LAYER_OUT, {}),
            (self.output_op, layer_output, DataPath.LAYER_OUT, {}),
        )
        for source, destination, data_path, extra in edges:
            flow.add_edge(source, destination, data_path, **extra)

        return flow

    def enforce_internal_encoding(self, training=False, **kwargs):
        """
        Propagate scales and zero points through the atomic ops.

        The output encoding is enforced first; afterwards each op receives its
        input encoding from its predecessor and has ``enforce_encoding`` called,
        going from the conv ops down to the activation op.
        """
        # TODO: the scales and zero points are external properties that have to be set
        #       explicitly here, which implicitly affects the encoding inference.
        self._enforce_output_encoding()

        conv_a, conv_b = self.conv_op1, self.conv_op2
        bias_a, bias_b = self.bias_add_op1, self.bias_add_op2
        mult_op, activation_op = self.ew_mult_op, self.act_op

        conv_a.enforce_encoding(training=training)
        conv_b.enforce_encoding(training=training)

        # The bias-add ops take over the encoding of the conv op that feeds them.
        bias_a.input_scales[0] = conv_a.output_scale
        bias_b.input_scales[0] = conv_b.output_scale
        bias_a.input_zero_points = [conv_a.output_zero_point]
        bias_b.input_zero_points = [conv_b.output_zero_point]
        bias_a.output_scale = bias_a.input_scales[0]
        bias_b.output_scale = bias_b.input_scales[0]
        bias_a.enforce_encoding()
        bias_b.enforce_encoding()

        mult_op.input_scales = [bias_a.output_scale, bias_b.output_scale]
        mult_op.input_zero_points = [bias_a.output_zero_point, bias_b.output_zero_point]
        mult_op.enforce_encoding()

        activation_op.input_zero_points = [mult_op.output_zero_point]
        activation_op.input_scales = [mult_op.output_scale]
        activation_op.enforce_encoding(training=training, zp_factor=self.forced_shift_zp)

    def fast_enforce_internal_encoding(self, **kwargs):
        """Intentionally a no-op for this layer."""

    def import_weights(self, layer_params: LayerParams):
        """
        Load parameters into the layer; only the activation op imports anything.

        Args:
            layer_params: layer's params from the npz

        """
        # TODO: Do we want to load kernel and bias values? (instead of the auto-generated values)
        self.act_op.import_weights(layer_params)

    def _export_weights(self):
        """Return the weights exported by the activation op."""
        return self.act_op.export_weights()

    def _export_layer_metadata(self):
        """Extend the base metadata with whichever forced values are set (not None)."""
        metadata = super()._export_layer_metadata()
        for key in ("forced_shift_zp", "forced_output_scale_scalar_dof"):
            forced_value = getattr(self, key)
            if forced_value is not None:
                metadata[key] = forced_value
        return metadata

    def _import_layer_metadata(self, npz):
        """Restore the forced values (None when absent), then defer to the base class."""
        self.forced_shift_zp = npz.get("forced_shift_zp", None)
        self.forced_output_scale_scalar_dof = npz.get("forced_output_scale_scalar_dof", None)
        return super()._import_layer_metadata(npz)

    @classmethod
    def get_default_params(cls):
        """Return a fresh dict holding the default hn params of the layer."""
        # TODO: this is temporary solution until we have pydantic scheme
        return {
            "activation": "linear",
            "input_repeats": [[1, 1, 1], [1, 1, 1]],
        }

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        Build the layer from an hn element.

        Args:
            lname: name given to the new layer.
            hn_element: hn description; its optional "params" entry overrides the defaults.
            logger: logger handed to the layer.

        Raises:
            NotImplementedError: when "reduce_sum_groups" is set in the params.

        """
        hn_params = cls.get_default_params()
        hn_params.update(hn_element.get("params", {}))
        if hn_params.get("reduce_sum_groups") is not None:
            raise NotImplementedError("reduce_sum_groups is not supported with ew_mult on apu")

        new_layer = cls(
            name=lname,
            activation=hn_params["activation"],
            input_repeats=hn_params["input_repeats"],
            logger=logger,
        )
        new_layer.finalize_from_hn(hn_element)
        return new_layer

    @property
    def mock_kernel_values(self):
        """Kernel values handed to the two conv ops in ``create_hw_params``."""
        return self._mock_kernel_values

    @mock_kernel_values.setter
    def mock_kernel_values(self, mock_kernel_values):
        self._mock_kernel_values = mock_kernel_values

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """
        Create the hw params of every atomic op.

        Args:
            weights_clipping: not used by this layer.
            optimization_target: forwarded to the mult op and the activation op.
            hw_shifts: optional sequence; its first item is used as the
                pre-accumulator shift of both conv ops (None when omitted).

        """
        self._enforce_output_encoding()
        forced_pre_acc_shift = None if hw_shifts is None else hw_shifts[0]
        self.conv_op1.create_hw_params(self.mock_kernel_values[0], pre_acc_shift=forced_pre_acc_shift)
        self.conv_op2.create_hw_params(self.mock_kernel_values[1], pre_acc_shift=forced_pre_acc_shift)
        self.enforce_internal_encoding()

        # The bias ops reuse the shift of the conv op that precedes them.
        self.bias_add_op1.pre_acc_shift = self.conv_op1.pre_acc_shift
        self.bias_add_op2.pre_acc_shift = self.conv_op2.pre_acc_shift
        self.bias_add_op1.create_hw_params()
        self.bias_add_op2.create_hw_params()
        self.ew_mult_op.create_hw_params(optimization_target)
        # Second pass, after the bias and mult hw params have been created.
        self.enforce_internal_encoding()

        self.act_op.create_hw_params(self.ew_mult_op.output_scale, optimization_target, nudging=False)
        self.output_op.create_hw_params()

    def create_quant_element_custom_behavior(
        self, precision_config: LayerPrecisionConfig, optimization_target: OptimizationTarget
    ):
        """
        Create the quantization elements of the data paths and of the op weights.

        The DATA_MULT path gets ``accumulator_bits // 2 + 1`` bits and the
        POST_DATA_MULT path gets the full accumulator width.
        """
        precision_mode = precision_config.precision_mode
        accumulator_bits = get_accumulator_bits_by_precision_mode(precision_mode)
        kernel_bits, signed = get_kernel_bits_and_sign_by_precision_mode(precision_mode)

        self.create_quant_element_by_data_path(DataPath.DATA_MULT, (accumulator_bits // 2) + 1)
        self.create_quant_element_by_data_path(DataPath.POST_DATA_MULT, accumulator_bits)

        num_decomposition = get_decomposition_count_by_bias_mode(precision_config.bias_mode)
        for conv_op in (self.conv_op1, self.conv_op2):
            conv_op.create_weight_quant_element(kernel_bits)
        for bias_op in (self.bias_add_op1, self.bias_add_op2):
            bias_op.create_weight_quant_element(kernel_bits, signed, num_decomposition)
        self.act_op.create_weight_quant_element(optimization_target)

    def get_equalization_handler_type(self, predecessor_index=None):
        """Classify the layer as transparent and not a source, for every predecessor."""
        return EquivClassification(LayerHandlerType.transparent, is_source=False)

    def _force_output_scale(self):
        """
        Apply the forced scalar factor, when one is set.

        The output scale becomes
        ``input_scales[0] * input_scales[1] * forced_output_scale_scalar_dof``
        (using the repeated input scales).
        """
        forced_dof = self.forced_output_scale_scalar_dof
        if forced_dof is None:
            return
        # NOTE(review): this compares a shape object with the int 0 - possibly meant
        # to be `()` or a size check; confirm the intent before changing it.
        if self.output_scale.shape != 0:
            scale_product = tf.math.multiply(*self._get_repeated_input_scales())
            self.set_output_scale(scale_product * forced_dof, 0)

    def _create_out_in_scale_ratio(self):
        """
        Derive ``output_scale_scalar_dof`` from the current scales.

        The ratio between the output scale and the product of the repeated
        input scales has to be the same for all entries, up to a relative
        tolerance of 1e-6.

        Raises:
            AccelerasNumerizationError: when the ratio is not uniform.

        """
        scale_ratio = self.output_scale / tf.math.multiply(*self._get_repeated_input_scales())
        tolerance = 1e-6
        if scale_ratio.shape != ():
            reference = scale_ratio[0]
            relative_spread = np.max(np.abs(scale_ratio - reference) / reference)
            if relative_spread > tolerance:
                # Possible fail case: coming from concat, so input scale is scalar while output is vector..
                raise AccelerasNumerizationError(
                    f"output_scale - input_scale ratio of {self.full_name} should be a scalar"
                )
            # Stored as an attribute so it can be used in a scales-training context, should it come.
            self.output_scale_scalar_dof = reference
        else:
            self.output_scale_scalar_dof = scale_ratio

    def enforce_io_encoding(self, training=False, **kwargs):
        """
        Set the output encoding from the input encoding.

        The output scale is the product of the repeated input scales times
        ``output_scale_scalar_dof``. The output zero point is only written when
        ``forced_shift_zp`` is set.
        """
        scale_product = tf.math.multiply(*self._get_repeated_input_scales())
        self.output_op.output_scale = scale_product * self.output_scale_scalar_dof

        zp_factor = self.forced_shift_zp
        if zp_factor is not None:
            self.output_op.output_zero_point = np.float32(self.input_zero_points[0] * zp_factor)

    def _get_repeated_input_scales(self):
        """
        Return both input scales, repeated along the last axis.

        Each scale is repeated by the last entry of its ``input_repeats`` item;
        the second result is cast to the dtype of the first one.
        """
        first_scale = tf.repeat(self.input_scales[0], self.input_repeats[0][-1], axis=-1)
        second_scale = tf.repeat(self.input_scales[1], self.input_repeats[1][-1], axis=-1)
        return first_scale, tf.cast(second_scale, first_scale.dtype)

    def _enforce_output_encoding(self):
        """Run backward encoding on the output op and copy its input encoding to the activation output."""
        passthru = self.output_op
        passthru.backward_encoding()
        self.act_op.output_scale = passthru.input_scales[0]
        self.act_op.output_zero_point = passthru.input_zero_points[0]

    def update_scale_scalar_dof(self, shift):
        """Divide ``output_scale_scalar_dof`` by ``2**shift``."""
        self.output_scale_scalar_dof /= 2**shift

    def define_encodings(self, flow):
        """Add the scalar-factor encoding on top of the base encodings and mark the activation factor as scalar."""
        super().define_encodings(flow)
        flow.add_encoding(f"{self.full_name}/output_scale_scalar_dof:0", EncodingType.Scale, scalar=False, shape=())
        activation_factor = f"{self.act_op.full_name}/output_factor_by_group:0"
        flow.nodes[activation_factor]["encoding"].scalar = True

    def define_constraints(self, enc):
        """
        Register the encoding constraints of the layer on top of the base ones.

        - each conv op shares its mac shift with the bias-add op after it,
        - both conv ops share mac shift and kernel scale,
        - the scalar factor is either pinned to its current value (when both
          conv ops and the output op have constant encoding) or derived as
          output_scale / (input_scale_a * input_scale_b).
        """
        super().define_constraints(enc)
        conv_a_name = self.conv_op1.full_name
        conv_b_name = self.conv_op2.full_name
        dof_name = f"{self.full_name}/output_scale_scalar_dof:0"

        if not self.conv_op1.encoding_const or not self.bias_add_op1.encoding_const:
            enc.identity(f"{conv_a_name}/mac_shift:0", f"{self.bias_add_op1.full_name}/mac_shift:0")
        if not self.conv_op2.encoding_const or not self.bias_add_op2.encoding_const:
            enc.identity(f"{conv_b_name}/mac_shift:0", f"{self.bias_add_op2.full_name}/mac_shift:0")

        if not self.conv_op1.encoding_const or not self.conv_op2.encoding_const:
            enc.identity(f"{conv_a_name}/mac_shift:0", f"{conv_b_name}/mac_shift:0")
            enc.identity(f"{conv_a_name}/kernel_scale:0", f"{conv_b_name}/kernel_scale:0")

        if self.conv_op1.encoding_const and self.conv_op2.encoding_const and self.output_op.encoding_const:
            enc.identity(dof_name, self.output_scale_scalar_dof)
        else:
            # dummy(0) = product of the input scales, dummy(1) = output scale / dummy(0)
            enc.mul(enc.dummy(0), f"{conv_a_name}/input_scale:0", f"{conv_b_name}/input_scale:0")
            enc.div(enc.dummy(1), f"{self.output_op.full_name}/output_scale:0", enc.dummy(0))
            # Reduce the ratio to a scalar by taking its first entry when it is not one already.
            enc.callback(
                dof_name,
                enc.dummy(1),
                lambda x: x[0] if x.shape != () else x,
                outs_scalar=True,
                outs_shape=(),
            )

    def update_encoding(self, encodings):
        """Update the base encodings, then read the scalar factor back from `encodings`."""
        super().update_encoding(encodings)
        self.output_scale_scalar_dof = encodings[f"{self.full_name}/output_scale_scalar_dof:0"]

    def _get_precision_mode_supported_in_hw(self, arch):
        """Return the precision modes supported for `arch`."""
        restricted_targets = {OptimizationTarget.MERCURY, OptimizationTarget.SAGE, OptimizationTarget.PLUTO}
        if arch not in restricted_targets:
            return super()._get_precision_mode_supported_in_hw(arch)
        # NOTE(review): a8_w8_a16 is not listed in SUPPORTED_PRECISION_MODE of this class,
        # so the subtraction currently leaves the set unchanged.
        return self.SUPPORTED_PRECISION_MODE - {PrecisionMode.a8_w8_a16}

    @classmethod
    def get_default_bias_mode(cls):
        """Return the default bias mode of the layer."""
        return BiasMode.double_scale_initialization

    def verify_layer_inputs_shape(self, input_shapes):
        """
        Check that the two input shapes agree once the input repeats are applied.

        Only the case of exactly two input shapes is checked; the first
        dimension of each shape is skipped.

        Raises:
            InvalidInputShape: when the repeated shapes differ.

        """
        if len(input_shapes) != 2:
            return

        # Multiply each dimension (after the first) by its repeat factor.
        repeated_shapes = [
            [dim * ratio for dim, ratio in zip(shape[1:], repeats)]
            for shape, repeats in zip(input_shapes, self.input_repeats)
        ]
        if not all(x == y for x, y in zip(*repeated_shapes)):
            raise InvalidInputShape(
                f"Input shapes {input_shapes} doesn't match each other in {self.full_name}", self.full_name
            )
