from typing import Callable, Union

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops._misc_internals import get_tf_same_padding
from hailo_model_optimization.acceleras.atomic_ops.activation_op import ActivationOp
from hailo_model_optimization.acceleras.atomic_ops.bias_add_op import AddBiasOp
from hailo_model_optimization.acceleras.atomic_ops.conv_stripped_op import ConvStrippedOp
from hailo_model_optimization.acceleras.atomic_ops.feature_permute_op import FeaturePermuteOp
from hailo_model_optimization.acceleras.atomic_ops.format_conversion_op import TransposeWidthFeaturesOp
from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerWeightsClippingConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    BiasMode,
    DataPath,
    OptimizationTarget,
    PaddingType,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasImplementationError,
    InvalidInputShape,
)
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams
from hailo_model_optimization.acceleras.utils.opt_utils import (
    get_decomposition_count_by_bias_mode,
    get_kernel_bits_and_sign_by_precision_mode,
)


class BaseHailoConv(BaseHailoLayer):
    """
    Base class for most conv-based layers (e.g. conv, depthwise, avgpool).

    These layers use ConvStrippedOp as their core op, then (optionally) add a bias and apply
    an activation function. In HW they are implemented by the following sequence of
    computational elements:
        L3 --> Subcluster [<-> L1/2] --> APU --> L3

    See the datapath and quantization scheme overview at:
    https://hailotech.atlassian.net/wiki/spaces/ML/pages/986185817/H8+Core+numerics+101+-+full+Conv+add+datapath

    The layer flow (see ``_build_flow``) is::

        conv_op -> bias_add_op -> act_op -> output_op

    where ``output_op`` is a TransposeWidthFeaturesOp, a FeaturePermuteOp or a PassthruOp,
    depending on the constructor arguments.

    This class covers the common quantization and emulation functionality.
    Child classes should implement:
        1. from_hn
        2. __init__ with suitable parameters, initializing conv_op
        3. verify_precision_config (TODO explain)
        4. import_config
    """

    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a8_w8,
        PrecisionMode.a8_w4,
        PrecisionMode.a16_w16,
        PrecisionMode.a8_w4_a8,
        PrecisionMode.a8_w4_a16,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a8_w8_a16,
        PrecisionMode.a16_w8_a8,
        PrecisionMode.a16_w8_a16,
        PrecisionMode.a16_w16_a16,
        PrecisionMode.a16_w4_a8,
        PrecisionMode.a16_w4_a16,
    }

    SUPPORTED_BIAS_MODE = {
        BiasMode.double_scale_initialization,
        BiasMode.double_scale_decomposition,
        BiasMode.single_scale_decomposition,
    }
    SUPPORTED_QUANTIZATION_GROUPS = True

    def __init__(
        self,
        name: str,
        conv_op: ConvStrippedOp,
        activation: Union[str, Callable, ActivationType] = ActivationType.LINEAR,
        bias_initializer=None,
        trainable=True,
        transpose_output_width_features=False,
        bias_add_op=None,
        feature_shuffle_interval=None,
        logger=None,
        **kwargs,
    ):
        """
        Args:
            name: layer name; also used as the prefix of the internal op names.
            conv_op: the core ConvStrippedOp, constructed by the child class.
            activation: activation spec forwarded to ActivationOp.
            bias_initializer: forwarded to the default AddBiasOp (unused if bias_add_op is given).
            trainable: forwarded to the default AddBiasOp (unused if bias_add_op is given).
            transpose_output_width_features: if True, output_op is a TransposeWidthFeaturesOp.
            bias_add_op: optional pre-built bias op; a default AddBiasOp is created otherwise.
            feature_shuffle_interval: if set, output_op is a FeaturePermuteOp whose feature
                order is computed in ``_build``.
            logger: logger passed to every internal op and to the base layer.
            **kwargs: forwarded to BaseHailoLayer.__init__.

        Raises:
            AccelerasImplementationError: if both feature_shuffle_interval and
                transpose_output_width_features are requested.
        """
        # The feature permutation and the width/features transpose both occupy output_op, so
        # they are mutually exclusive. Validate before constructing any op.
        if feature_shuffle_interval is not None and transpose_output_width_features:
            raise AccelerasImplementationError(
                "Feature shuffle is not supported with transpose_output_width_features=True"
            )

        self.conv_op = conv_op
        if bias_add_op is not None:
            self.bias_add_op = bias_add_op
        else:
            self.bias_add_op = AddBiasOp(
                f"{name}/bias_add_op",
                bias_initializer=bias_initializer,
                trainable=trainable,
                logger=logger,
            )
        self.act_op = ActivationOp(
            f"{name}/act_op",
            activation=activation,
            logger=logger,
        )
        self.transpose_output_width_features = transpose_output_width_features
        self.feature_shuffle_interval = feature_shuffle_interval
        if transpose_output_width_features:
            # enabling output quantization even as activation is fully native...
            output_op = TransposeWidthFeaturesOp(f"{name}/output_op", logger=logger)
        elif feature_shuffle_interval is not None:
            output_op = FeaturePermuteOp(f"{name}/output_op", logger=logger)
        else:
            # enabling output quantization even as activation is fully native...
            output_op = PassthruOp(f"{name}/output_op", logger=logger)

        self.output_op = output_op
        super().__init__(name=name, logger=logger, **kwargs)
        self.input_spec = tf.keras.layers.InputSpec(ndim=4)

        self._forced_output_factor = None  # degree of freedom
        self._nudging = False
        self._precision_split_zp = False
        self._forced_output_scale_scalar_dof = None  # degree of freedom

    @property
    def forced_output_scale_scalar_dof(self):
        """Optional scalar ratio between output and input scale (see ``enforce_io_encoding``)."""
        return self._forced_output_scale_scalar_dof

    @forced_output_scale_scalar_dof.setter
    def forced_output_scale_scalar_dof(self, forced_output_scale_scalar_dof):
        self._forced_output_scale_scalar_dof = forced_output_scale_scalar_dof

    @property
    def precision_split_zp(self):
        """Whether the zero-point precision split is enabled for the conv and bias ops."""
        return self._precision_split_zp

    @precision_split_zp.setter
    def precision_split_zp(self, value: bool):
        # Propagate the flag to the ops that consume it; the conv kickback residual shift
        # delta is enabled exactly when the split is disabled.
        self._precision_split_zp = value
        self.conv_op._precision_split_zp = value
        self.bias_add_op._precision_split_zp = value
        self.conv_op._kickback_residual_shift_delta = not value

    @property
    def pre_acc_shift(self):
        """The conv op's pre-accumulator shift."""
        return self.conv_op.pre_acc_shift

    @property
    def forced_output_factor(self):
        """Optional forced ratio between the accumulator scale and the output scale."""
        return self._forced_output_factor

    @forced_output_factor.setter
    def forced_output_factor(self, forced_output_factor):
        self._forced_output_factor = forced_output_factor

    @property
    def groups(self):
        """The conv op's groups."""
        return self.conv_op.groups

    @property
    def group_sizes(self):
        """The conv op's group sizes."""
        return self.conv_op.group_sizes

    @property
    def transpose_width_features(self):
        """True when the output op transposes the width and features axes."""
        return isinstance(self.output_op, TransposeWidthFeaturesOp)

    def get_activation_name(self):
        """Return the activation op's name."""
        return self.act_op.act_name

    def get_weights_clipping(self):
        """Return the conv op's weights clipping."""
        return self.conv_op.get_weights_clipping()

    def _build_flow(self) -> LayerFlow:
        """Wire input -> conv -> bias_add -> act -> output_op -> output."""
        layer_flow = LayerFlow()

        in1 = layer_flow.add_input()
        out1 = layer_flow.add_output()

        layer_flow.add_node(self.conv_op)
        layer_flow.add_node(self.bias_add_op)
        layer_flow.add_node(self.act_op)
        layer_flow.add_node(self.output_op)

        layer_flow.add_edge(in1, self.conv_op, DataPath.LAYER_IN)
        layer_flow.add_edge(self.conv_op, self.bias_add_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.bias_add_op, self.act_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.act_op, self.output_op, DataPath.LAYER_OUT)
        layer_flow.add_edge(self.output_op, out1, DataPath.LAYER_OUT)
        return layer_flow

    def _build(self, input_shape):
        """
        Compute the output channel permutation when feature shuffling is enabled.

        The output channels ``0..C-1`` (C = ``conv_op.kernel.shape[-1]``) are laid out as a
        ``(feature_shuffle_interval, -1)`` grid and read column-wise. For example, C=6 with
        an interval of 2 gives the order ``[0, 3, 1, 4, 2, 5]``. The reshape raises if C is
        not divisible by the interval.
        """
        if self.feature_shuffle_interval is None:
            return
        num_channels = self.conv_op.kernel.shape[-1]
        grid = np.arange(num_channels).reshape(self.feature_shuffle_interval, -1)
        self.output_op.feature_order = grid.T.flatten()

    def _change_native_kernel(self, kernel):
        """Hook for subclasses to transform the exported native kernel; identity here."""
        return kernel

    def import_weights(self, layer_params: LayerParams, **kwargs):
        """Load the kernel, the bias and the activation params from ``layer_params``."""
        self.import_native_kernel(layer_params["kernel"], layer_params)
        self.import_native_bias(layer_params["bias"])
        self._load_activation(layer_params)

    def neg_weights(self):
        """Negate both the native kernel and the native bias in place."""
        self.import_native_kernel(-self.export_native_kernel())
        self.import_native_bias(-self.export_native_bias())

    def _export_weights(self):
        """Return the native kernel, the bias, padding_const_value and the activation params as a dict."""
        kernel = self.export_native_kernel()
        bias = self.export_native_bias()
        activation_params = self._export_activation()
        dict_params = {
            "kernel": kernel,
            "bias": bias,
            "padding_const_value": self.conv_op.export_weights()["padding_const_value"],
        }
        dict_params.update(activation_params)
        return dict_params

    def _export_layer_metadata(self):
        """Extend the base metadata with the forced output DOFs, when they are set."""
        export_vals = super()._export_layer_metadata()
        if self.forced_output_scale_scalar_dof is not None:
            export_vals["forced_output_scale_scalar_dof"] = self.forced_output_scale_scalar_dof
        if self.forced_output_factor is not None:
            export_vals["forced_output_factor"] = self.forced_output_factor
        return export_vals

    def _import_layer_metadata(self, npz):
        """Restore the forced output DOFs (None when absent), then delegate to the base class."""
        self.forced_output_scale_scalar_dof = npz.get("forced_output_scale_scalar_dof", None)
        self.forced_output_factor = npz.get("forced_output_factor", None)
        return super()._import_layer_metadata(npz)

    @property
    def is_changing_bias_supported(self):
        return True

    def export_native_kernel(self):
        """Return the conv op's kernel after the ``_change_native_kernel`` hook."""
        kernel = self.conv_op.export_weights()["kernel"]
        return self._change_native_kernel(kernel)

    def export_native_bias(self):
        """Return the bias op's exported weights."""
        return self.bias_add_op.export_weights()

    def _export_activation(self):
        return self.act_op.export_weights()

    def import_native_kernel(self, kernel, layer_params=None):
        self.conv_op.import_weights(kernel, layer_params)

    def import_native_bias(self, bias):
        self.bias_add_op.import_weights(bias)

    def _load_activation(self, layer_params):
        self.act_op.import_weights(layer_params)

    def get_numeric_kernel_np(self):
        """Return the final numeric kernel scaled by ``2**conv_op.weight_placement_shift``."""
        # TODO remove this in the future and use the export
        # shifts the kernel_scale that we calculate in acceleras if weights = 4-bits
        numeric_kernel = self.conv_op.final_numeric_kernel.numpy()
        return numeric_kernel * 2**self.conv_op.weight_placement_shift

    def export_bias(self):
        """Return the bias op's HW params."""
        # TODO: needs to be removed.
        #   Currently used in toynet tests, and hn npz utils in old flow
        return self.bias_add_op._export_hw_params()

    def get_kernel_np(self):
        return self.conv_op.kernel.numpy()

    def get_bias_np(self):
        return self.bias_add_op.bias.numpy()

    def _set_accumulator_scale_into_ops(self, acc_scale):
        """Store ``acc_scale`` on the layer and use it as the conv output and bias I/O scale."""
        self.acc_scale = acc_scale
        self.conv_op.output_scale = acc_scale
        self.bias_add_op.input_scales[0] = acc_scale
        self.bias_add_op.output_scale = acc_scale

    def _accumulator_scale_from_apu(self):
        """
        Note - Accumulator scale is fully defined by output and APU params,
        we resolve it in Activation class and use it for all earlier op scales.
        """
        self.act_op.get_accumulator_scale()
        self._set_accumulator_scale_into_ops(self.act_op.input_scales[0])

    def enforce_internal_encoding(self, training=False, **kwargs):
        """
        See BaseHailoLayer.enforce_internal_encoding comment.
        Here the DOFs include, in addition to I/O scales, also the rescale factor in APU,
        from which the accumulator_scale and the (doubly-vectorial) kernel_scale are computed.

        This can be viewed as a "backward" calculation for the "multiplicative" (scales) part,
        followed by a normal forward calculation for the "additive" (zero-points) part.
        When the kernel scale is forced (``kernel_scale_forced_to_save``), a purely forward
        encoding is used instead (``enforce_forward_encoding``).
        """
        if self.kernel_scale_forced_to_save:
            self.enforce_forward_encoding(training=training, **kwargs)
        else:
            self._enforce_output_encoding()
            self._accumulator_scale_from_apu()
            self.conv_op.enforce_encoding(training=training)

            self.bias_add_op.input_zero_points = [self.conv_op.output_zero_point]
            self.bias_add_op.enforce_encoding()
            self.act_op.input_zero_points = [self.bias_add_op.output_zero_point]
            self.act_op.enforce_encoding(training=training)

    def enforce_forward_encoding(self, training=False, **kwargs):
        """Propagate scales and zero points forward: conv -> bias_add -> act."""
        self.conv_op.enforce_encoding(training=training)
        self.bias_add_op.input_scales = self.bias_add_op.output_scales = [self.conv_op.output_scale]
        self.bias_add_op.input_zero_points = [self.conv_op.output_zero_point]
        self.bias_add_op.enforce_encoding()

        self.act_op.input_scales = [self.bias_add_op.output_scale]
        self.act_op.input_zero_points = [self.bias_add_op.output_zero_point]
        self._enforce_output_encoding()  # for encoding
        self.act_op.enforce_encoding(training=training)  # this needs input and output scale

    def fast_enforce_internal_encoding(self, training=False, **kwargs):
        """Refresh only the conv output zero point (and bias input zp) when the kernel scale isn't forced."""
        if not self.kernel_scale_forced_to_save:
            self.conv_op.output_zero_point = self.conv_op.compute_output_zp(training=training)
            self.bias_add_op.input_zero_points = [self.conv_op.output_zero_point]

    def get_forced_accumulator_scale(self):
        return self.conv_op.get_forced_accumulator_scale()

    def _force_output_scale(self):
        """
        If ``forced_output_factor`` is set and both the input and output scales are
        non-scalar and not all ones, set the output scale to
        ``forced accumulator scale / forced_output_factor``.

        NOTE(review): the original docstring also mentioned nudging the output scale when
        ``self._nudging`` is True, but this method does not read ``_nudging``.
        """
        non_trivial_scales = (
            self.output_scales[0].shape != ()
            and not np.all(self.output_scales[0] == 1)
            and self.input_scales[0].shape != ()
            and not np.all(self.input_scales[0] == 1)
        )
        if self.forced_output_factor is not None and non_trivial_scales:
            accumulator_scale = self.get_forced_accumulator_scale()
            self.set_output_scale(accumulator_scale / self.forced_output_factor, 0)

    def update_scale_scalar_dof(self, shift):
        """
        Multiply ``forced_output_factor`` (when set) by ``2**shift``.

        This is meant to keep the output scale unchanged even though the input scale is
        changed, so there is no degradation.
        """
        output_factor = 2**shift
        if self.forced_output_factor is not None:
            self.forced_output_factor *= output_factor

    def enforce_io_encoding(self, training=False, **kwargs):
        """
        Apply the forced output DOFs to the layer I/O encoding.

        - ``forced_output_factor`` set: output scale = forced accumulator scale / factor.
        - otherwise, ``forced_output_scale_scalar_dof`` set and a non-scalar output scale:
          output scale = input scale * dof, and output zero point = input zero point.
        """
        if self.forced_output_factor is not None:
            self.output_op.output_scale = self.get_forced_accumulator_scale() / self.forced_output_factor
        elif self.forced_output_scale_scalar_dof is not None and self.output_scale.shape != ():
            self.set_output_scale(self.input_scales[0] * self.forced_output_scale_scalar_dof, 0)
            self.set_output_zero_point(self.input_zero_points[0], 0)

    def _enforce_output_encoding(self):
        """Back-propagate output_op's encoding and copy its input scale/zp to act_op's output."""
        self.output_op.backward_encoding()
        # act_op's tracker is unlocked only for the duration of this copy.
        self.act_op._tracker.locked = False
        self.act_op.output_scale = self.output_op.input_scales[0]
        self.act_op.output_zero_point = self.output_op.input_zero_points[0]
        self.act_op._tracker.locked = True

    @property
    def kernel_scale_forced_to_save(self):
        return self.conv_op.kernel_scale_forced_to_save

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """
        WIP.
        Implementing basic scalar case.

        NOTE: the create_hw_params() methods of atomic_ops usually create "candidates" for
              the actual numerization params, only finalizing "independent" params,
              to later be used in finalization as performed in "infer_encodings".

              The comments below try to carefully specify what and where is finalized.
        """
        # TODO to think about here - we generally wanted to only use scale&zp "candidates" and not
        #  limvals at this point (after they have been consumed for all layers&ops in
        #  model.create_io_encoding_candidates)

        self._enforce_output_encoding()
        pre_act_stats = self.get_preact_stats()[0]
        # Per-channel max absolute pre-activation value.
        max_final_accumulator_by_channel = np.maximum(
            np.abs(pre_act_stats.min),
            np.abs(pre_act_stats.max),
            dtype=np.float32,
        )
        kernel_scale_matrix_component = self.get_kernel_scale_matrix_component()

        self.conv_op.create_hw_params(
            max_final_accumulator_by_channel,
            weights_clipping,
            optimization_target,
            kernel_scale_matrix_component=kernel_scale_matrix_component,
            hw_shifts=hw_shifts,
        )

        self.bias_add_op.pre_acc_shift = self.conv_op.pre_acc_shift

        # From accumulator scale candidate, create the "ideal" output factor (*finalized*).
        # Don't nudge if the kernel scale is forced or the scale is set by the kernel only.
        nudging = not self.kernel_scale_forced_to_save and not self.conv_op.set_scale_by_kernel_only
        self.act_op.create_hw_params(self.conv_op.accumulator_scale_candidate, optimization_target, nudging=nudging)
        self.enforce_internal_encoding()
        # This MOSTLY finalizes the "independent" params, so rest of job can be done by infer_encodings():
        # The "ideal" output factor leads to the "numeric" (M/E decomposed, @slope=1) output factor,
        # and then the *finalized* (aka, "mantissa-nudged") accumulator scale, followed by kernel scales (L&R),
        # and then all intermediate zero-points, and then final APU parameters.
        # An exception is bias (& elwa, see subclasses) decompositions
        self._create_hw_params_finalize()
        self._has_hw_params = True

    def _create_hw_params_finalize(self):
        self.enforce_internal_encoding()
        # One last thing is bias decomposition (if any)
        self.bias_add_op.create_hw_params()

    @property
    def kernel(self):
        return self.conv_op.kernel

    @property
    def bias(self):
        return self.bias_add_op.bias

    @property
    def output_factors(self):
        return self.act_op.output_factor_by_group

    @classmethod
    def _validate_elwa(cls, elwa_value):
        """Raise ValueError if ``elwa_value`` is truthy (elementwise add is unsupported here)."""
        if elwa_value:
            raise ValueError(
                f"elementwise_add value was {elwa_value}, "
                f"but expected {not elwa_value} in {cls.__name__} initialization",
            )

    @classmethod
    def get_default_params(cls):
        """Return a fresh dict of default layer params."""
        # TODO: this is temporary solution until we have pydantic scheme
        return {
            "strides": [1, 1, 1, 1],
            "dilations": [1, 1, 1, 1],
            "padding": "SAME",
            "activation": "linear",
            "groups": 1,
            "elementwise_add": False,
        }

    def create_quant_element_custom_behavior(
        self, precision_config: LayerPrecisionConfig, optimization_target: OptimizationTarget
    ):
        """Create the weight quant elements of the ops and set the quantization groups."""
        bias_mode = precision_config.bias_mode
        precision_mode = precision_config.precision_mode
        quant_groups = precision_config.quantization_groups
        kernel_bits, signed = get_kernel_bits_and_sign_by_precision_mode(precision_mode)
        num_decomposition = get_decomposition_count_by_bias_mode(bias_mode)

        self.conv_op.create_weight_quant_element(kernel_bits, signed)
        self.bias_add_op.create_weight_quant_element(kernel_bits, signed, num_decomposition)
        self.act_op.create_weight_quant_element(optimization_target)

        # set quantization groups
        self.conv_op.quantization_groups_num = quant_groups
        self.act_op.set_quantization_groups(quant_groups)

    def get_kernel(self):
        return self.conv_op.kernel

    def get_kernel_scale_matrix_component(self):
        """
        Return the conv op's kernel scale, computed from the current input scales and the
        activation output scale divided by ``2**_negative_slope_exponent_fix_shift``.
        """
        self._enforce_input_encoding()
        self._enforce_output_encoding()
        output_scale = self.act_op.output_scale / 2**self._negative_slope_exponent_fix_shift
        return self.conv_op.calc_kernel_scale(self.conv_op.input_scales, output_scale, 0)

    def _is_precision_config_supported(self, precision_mode, bias_mode, arch):
        """Extend the base check: a8_w4_a16 / a8_w8_a16 are rejected with the width/features transpose."""
        if not super()._is_precision_config_supported(precision_mode, bias_mode, arch):
            return False
        if (
            precision_mode in {PrecisionMode.a8_w4_a16, PrecisionMode.a8_w8_a16}
            and self.transpose_output_width_features
        ):
            return False
        return True

    def _supported_quantization_groups_hw(self, quantization_groups, arch):
        """1-4 groups are always supported; one group per channel only on MERCURY, PLUTO and MARS."""
        if 1 <= quantization_groups <= 4:
            return True
        if (quantization_groups == self.activation_atomic_op.num_of_channels) and (
            arch in {OptimizationTarget.MERCURY, OptimizationTarget.PLUTO, OptimizationTarget.MARS}
        ):
            return True
        return False

    def _get_kernel_bits(self):
        return self.conv_op.weight_lossy_elements.kernel.bits

    def get_encoding_flow(self):
        """Delegate to the base class; with the width/features transpose, warn and mark the encoding const."""
        if self.transpose_width_features:
            self._logger.warning(
                f"get_encoding_flow for base_hailo_conv {self.full_name} with transpose_width_features isn't supported yet.",
            )
            self.encoding_const = True
        return super().get_encoding_flow()

    def define_constraints(self, enc):
        """Add the base constraints and tie the conv and bias mac_shift unless both are const."""
        super().define_constraints(enc)
        if not (self.conv_op.encoding_const and self.bias_add_op.encoding_const):
            enc.identity(f"{self.conv_op.full_name}/mac_shift:0", f"{self.bias_add_op.full_name}/mac_shift:0")

    def update_encoding(self, encodings):
        """Apply ``encodings`` and refresh ``acc_scale`` from the activation op's input scale."""
        super().update_encoding(encodings)
        self._tracker.locked = False
        self.acc_scale = encodings[f"{self.act_op.full_name}/input_scale:0"]
        self._tracker.locked = True

    def verify_layer_inputs_shape(self, input_shapes):
        """
        Raise InvalidInputShape when the (padded) input height or width yields a
        non-positive output size for the conv kernel, strides and dilations.
        """
        # slicing for the case of conv3d
        kernel_size = self.conv_op.kernel_size[:2]
        strides = self.conv_op.strides[:2]
        dilations = self.conv_op.dilation_rate[:2]
        h_in = input_shapes[0][1]
        w_in = input_shapes[0][2]
        if self.conv_op.padding != PaddingType.VALID:
            pads = get_tf_same_padding(dilations, h_in, w_in, kernel_size[0], kernel_size[1], strides[0], strides[1])
            # pads = pad_beg_h, pad_end_h, pad_beg_w, pad_end_w
            h_in = h_in + pads[0] + pads[1]
            w_in = w_in + pads[2] + pads[3]

        h_out = (h_in - dilations[0] * (kernel_size[0] - 1) - 1) / strides[0] + 1
        w_out = (w_in - dilations[1] * (kernel_size[1] - 1) - 1) / strides[1] + 1
        if np.floor(h_out) <= 0 or np.floor(w_out) <= 0:
            raise InvalidInputShape(
                f"Input shapes {input_shapes} doesn't match layer's parameters in {self.full_name}",
                self.full_name,
            )

    def get_macs(self) -> int:
        """
        Approximate number of multiply-accumulate operations of the convolution.

        The output spatial size is approximated as ``input // stride``; padding is ignored.
        """
        _, height, width, _ = self.input_shapes[0]
        kh, kw, kc, n_kernels = self.get_kernel().numpy().shape
        s_h, s_w = self.conv_op.strides
        return kh * kw * kc * (height // s_h) * (width // s_w) * n_kernels

    def enable_force_pruning(self):
        self.conv_op.enable_force_pruning()

    def disable_force_pruning(self):
        self.conv_op.disable_force_pruning()

    @classmethod
    def get_default_bias_mode(cls):
        return BiasMode.double_scale_initialization

    def get_nudged_kernel(self, kernel):
        return self.conv_op.get_nudged_kernel(kernel)