from dataclasses import dataclass
from typing import Union

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import InputSpec
from tensorflow.python.eager import context
from tensorflow.python.keras.utils import conv_utils

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp, BaseWeightLossyElements
from hailo_model_optimization.acceleras.encoding.encoding_data import EncodingType
from hailo_model_optimization.acceleras.encoding.encoding_layer import TensorInitializer
from hailo_model_optimization.acceleras.lossy_elements.base_lossy_element import BaseLossyElement
from hailo_model_optimization.acceleras.lossy_elements.identity_element import IdentityElement
from hailo_model_optimization.acceleras.lossy_elements.prune_element import PruneElement
from hailo_model_optimization.acceleras.lossy_elements.quant_element import (
    AdaRoundQuantElement,
    BaseQuantElement,
    MACDataQuantElement,
    QuantElement,
)
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerWeightsClippingConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    SHIFT_CALCULATE_BUFFER,
    SHIFT_OPTIONS_4BIT,
    WEIGHTS_PLACEMENT_SHIFT,
    ZP_LOW_SPLIT_PRECISION_PIXEL,
    CalcKernelMode,
    OptimizationTarget,
    PaddingType,
    StrideAlignType,
    WeightsClippingMode,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasImplementationError,
    AccelerasImportParamConfigMismatch,
    AccelerasInitializationError,
    AccelerasNumerizationError,
    AccelerasPrematureQuantOperation,
)
from hailo_model_optimization.acceleras.utils.opt_utils import (
    bankers_round_int_shift,
    calculate_shifts,
    limvals_to_zp_scale,
    mmse,
)
from hailo_model_optimization.acceleras.utils.padding_utils import get_deconv_padding, handle_padding
from hailo_model_optimization.acceleras.utils.quantization_group_utils import get_quantization_groups_info

# Integer code assigned to each kernel-scale calculation mode, and the inverse (code -> mode) lookup.
dict_map = {
    CalcKernelMode.kernel_vals: 1,
    CalcKernelMode.limvals: 2,
}
reversted_dict_map = {code: mode for mode, code in dict_map.items()}


@dataclass
class ConvWeightsLossy(BaseWeightLossyElements):
    """Lossy elements that ``ConvStrippedOp`` applies to its convolution kernel."""

    # Applied to the scaled, zero-point-shifted kernel (see ``ConvStrippedOp.get_quant_kernel``).
    kernel: BaseLossyElement
    # Applied to the raw kernel before scaling (see ``ConvStrippedOp.kernel_pruned``).
    kernel_prune: BaseLossyElement

class ConvStrippedOp(BaseAtomicOp):
    """
    Stripped-down convolution - without activation or element-wise add (but with bias), as implemented by Hailo MAC.

    NOTE 1: The output of this op is the accumulator contents, with the appropriate scale
            (taking into account the pre-accumulator shift).
    NOTE 2: The output's zero-point (aka "residue") is thus a vector of channel-dependent numbers.

    """

    # Padding modes accepted by __init__; any other PaddingType raises AccelerasInitializationError.
    supported_paddings = [PaddingType.VALID, PaddingType.SAME, PaddingType.DECONV]
    # Lossy elements applied to the kernel (pruning + quantization).
    weight_lossy_elements: ConvWeightsLossy

    num_inputs = 1
    num_outputs = 1

    # Debug tensors
    _padded_input: tf.Tensor
    _pre_shift_out: tf.Tensor

    BIT_EXACT_USE_HSIM = True

    # NOTE - alternative "SAME" flavor(s) are reflected in stride_align and converted into SAME externally to this class

    def __init__(
        self,
        name: str,
        kernel_size,
        filters=None,
        is_depthwise=False,
        stride_align: Union[str, StrideAlignType] = "NW",
        strides=(1, 1),
        groups=1,
        group_sizes=None,
        padding: Union[str, PaddingType] = "VALID",
        dilation_rate=(1, 1),
        kernel_initializer=None,
        trainable=True,
        vector_zp=False,
        logger=None,
        fully_native=None,
        bit_exact=None,
        force_zero_output_when_quant=False,
        spatial_flatten_output=False,
        set_scale_by_kernel_only=False,
        scale_calc_mode=CalcKernelMode.kernel_vals,
        **kwargs,
    ):
        """
        Configure the stripped convolution op. The kernel weight itself is created later, in ``_build``.

        Args:
            name: op name, forwarded to the base atomic-op constructor.
            kernel_size: int or pair of ints; normalized to a 2-tuple.
            filters: size of the last kernel dimension for a non-depthwise kernel.
            is_depthwise: selects the depthwise kernel layout and the axes reduced when summing the kernel.
            stride_align: value convertible to ``StrideAlignType`` (used by the op's own padding handling).
            strides: int or pair of ints; normalized to a 2-tuple.
            groups: number of convolution groups.
            group_sizes: per-group sizes; defaults to ``[1] * groups``. Its length must equal ``groups``.
            padding: value convertible to ``PaddingType``; must be one of ``supported_paddings``.
            dilation_rate: int or pair of ints; normalized to a 2-tuple.
            kernel_initializer: initializer used by ``_build`` when given (e.g. keras glorot_uniform),
                for the (future) case of training from scratch with acceleras.
            trainable: whether the kernel weight is created as trainable.
            vector_zp: stored on the instance as ``vector_zp``.
            logger, fully_native, bit_exact: forwarded to the base atomic-op constructor.
            force_zero_output_when_quant: stored on the instance as ``_force_zero_output_when_quant``.
            spatial_flatten_output: stored on the instance as ``spatial_flatten_output``.
            set_scale_by_kernel_only: when True, ``create_hw_params`` delegates to the kernel-only flow.
            scale_calc_mode: ``CalcKernelMode`` used by the kernel-only flow.
            kwargs: additional arguments forwarded to the base atomic-op constructor.

        Raises:
            AccelerasInitializationError: if ``padding`` is not in ``supported_paddings``.
            AccelerasImplementationError: if ``len(group_sizes)`` differs from ``groups``.

        """
        super().__init__(name, logger=logger, fully_native=fully_native, bit_exact=bit_exact, **kwargs)
        # Start with pass-through elements; create_weight_quant_element() installs the real quantizer.
        self.weight_lossy_elements = ConvWeightsLossy(
            kernel=IdentityElement(name=f"{self.full_name}/ie:mantissa"), kernel_prune=IdentityElement()
        )

        self.is_depthwise = is_depthwise
        self.axes2reduce = (0, 1, 3) if self.is_depthwise else (0, 1, 2)
        self.padding = PaddingType(padding)
        if self.padding not in self.supported_paddings:
            raise AccelerasInitializationError(f"Padding type {padding} is not supported in {type(self)}")
        self.padding_const_value = 0

        self.stride_align = StrideAlignType(stride_align)  # for our DIY padding

        self.pre_acc_shift = 0
        self.kernel_initializer = kernel_initializer
        self.trainable = trainable

        self.groups = groups
        self.group_sizes = group_sizes if group_sizes else [1] * groups
        if self.groups != len(self.group_sizes):
            # The exception used to be constructed without `raise`, so the mismatch went unnoticed.
            raise AccelerasImplementationError(
                f"In {self.full_name}: number of groups (={self.groups}) does not match "
                f"group_sizes (={self.group_sizes})",
            )
        self.filters = filters
        # ..some rip-off from keras/layers/convolutional.py ..
        spatial_dims = 2
        self.kernel_size = conv_utils.normalize_tuple(kernel_size, spatial_dims, "kernel_size")
        self.strides = conv_utils.normalize_tuple(strides, spatial_dims, "strides")  # , allow_zero=True)
        self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, spatial_dims, "dilation_rate")
        self.input_spec = InputSpec(min_ndim=spatial_dims + 2)
        self.kernel_scale = 1  # TODO reconsider to avoid masking errors.
        self.kernel = None
        self.shift_delta = 0  # TODO
        self.zp_kernel = np.float32(0)
        self.vector_zp = vector_zp
        self.out_zp_comp_groups = None
        self.feed_repeat = 1
        self.force_rounded_shift_delta = False
        self.spatial_flatten_output = spatial_flatten_output

        self.quantization_groups_num = 1
        # when the activation is for a dense layer, we need to validate the shapes of quantization groups
        self.validate_shapes = False
        self._force_zero_output_when_quant = force_zero_output_when_quant
        self._weight_limvals = None

        self.kernel_q_forced = 1
        self.kernel_scale_forced = 1
        self.kernel_scale_forced_to_save = False
        self.set_scale_by_kernel_only = set_scale_by_kernel_only
        self.scale_calc_mode = scale_calc_mode
        self._precision_split_zp = False
        self._kickback_residual_shift_delta = True

    def create_weight_quant_element(self, kernel_bits=8, signed=True):
        """
        Install a ``MACDataQuantElement`` (``symmetric=False``) as the kernel lossy element.

        The currently installed ``kernel_prune`` element is carried over unchanged.

        Args:
            kernel_bits: bit width passed to the quant element.
            signed: signedness passed to the quant element.

        """
        kernel_quantizer = MACDataQuantElement(
            bits=kernel_bits,
            signed=signed,
            symmetric=False,
            name=f"{self.full_name}/qe:kernel",
        )
        current_prune = self.weight_lossy_elements.kernel_prune
        self.weight_lossy_elements = ConvWeightsLossy(kernel=kernel_quantizer, kernel_prune=current_prune)

    def _build(self, input_shape):
        """
        Create the kernel weight, unless ``self.kernel`` is already one of this op's weights.

        .. follows keras/layers/convolutional.py ..

        Args:
            input_shape: shape of the op input; only the last entry (input channels) is used.

        Raises:
            ValueError: for a non-depthwise op, when the number of input channels is not divisible
                by the sum of ``group_sizes``.

        """
        input_channels = int(input_shape[-1])
        if self.is_depthwise:
            kernel_shape = self.kernel_size + (input_channels, 1)
        else:
            group_sizes_sum = sum(self.group_sizes)
            if input_channels % group_sizes_sum != 0:
                raise ValueError(
                    f"The number of input channels must be evenly divisible by the sum "
                    f"of group sizes. Received group_sizes={self.group_sizes},  input_shape={input_shape}",
                )
            # For asymmetric group convolution, the kernel is concatenated on the output features dimension of all its
            # convolutions. The input features dim is the maximum of input features per group - the other kernels are
            # padded with zeros.
            kernel_shape = self.kernel_size + (input_channels // group_sizes_sum * max(self.group_sizes), self.filters)

        # Identity check (same as import_weights): `in` would fall back to `==`, which is element-wise
        # for TF variables/tensors and can raise or yield an ambiguous truth value.
        if any(self.kernel is w for w in self.weights):
            return

        if self.kernel_initializer is not None:
            kernel_initializer = self.kernel_initializer
        elif self.kernel is None:
            kernel_initializer = tf.keras.initializers.zeros()
        else:
            # Values were imported before build; seed the new weight with them.
            kernel_initializer = tf.keras.initializers.Constant(self.kernel)
        self.kernel = self.add_weight(
            name="kernel",
            shape=kernel_shape,
            initializer=kernel_initializer,
            trainable=self.trainable,
        )

    def compute_output_zp(self, training=False):
        """
        Compute the output zero-point (formerly called 'residue') from the input zero-point.

        The quantized kernel, with its own zero-point removed, is multiplied by the input zero-point,
        summed over ``axes2reduce`` and divided by ``2**total_rshift``. When ``_precision_split_zp`` is set,
        a second sum using ``ZP_LOW_SPLIT_PRECISION_PIXEL`` is stacked in front of the regular one.
        """
        quantized = self.get_quant_kernel(training=training)
        centered_kernel = quantized - tf.cast(self.zp_kernel, quantized.dtype)

        input_zp = self.input_zero_point
        if input_zp.shape == () or self.groups == 1:
            # scalar input zero-point, or a single group
            zp_broadcast = tf.reshape(input_zp, [1, 1, -1, 1])
        else:
            # assuming the input zero points are all the same per group
            in_features = self.kernel.shape[-2]
            out_features_per_group = self.kernel.shape[-1] // self.groups
            zp_per_output = tf.repeat(input_zp[::in_features], out_features_per_group)
            zp_broadcast = tf.reshape(zp_per_output, [1, 1, 1, -1])

        summed = tf.reduce_sum(centered_kernel * zp_broadcast, axis=self.axes2reduce)
        if self._precision_split_zp:
            low_zp = tf.reshape(ZP_LOW_SPLIT_PRECISION_PIXEL, [1, 1, 1, -1])
            summed_low = tf.reduce_sum(centered_kernel * low_zp, axis=self.axes2reduce)
            summed = tf.stack([summed_low, summed], axis=0)

        return summed / 2**self.total_rshift

    def import_weights(self, kernel, layer_params=None, **kwargs):
        """
        Load kernel values (and optionally the padding constant) into the op.

        Args:
            kernel: kernel values; a rank-3 kernel gets a trailing axis appended.
            layer_params: optional mapping; its ``"padding_const_value"`` entry, when present,
                replaces the current padding constant.
            kwargs: accepted and ignored.

        """
        # TODO rename to load_init_params, load initializer or smth like that..
        if len(kernel.shape) == 3:
            kernel = np.expand_dims(kernel, axis=-1)

        # Assign in place when the kernel is already one of the op's weights, otherwise hold a constant.
        kernel_is_weight = any(self.kernel is w for w in self.weights)
        if kernel_is_weight:
            self.kernel.assign(kernel)
        else:
            self.kernel = tf.constant(kernel)
        self.kernel = tf.cast(self.kernel, tf.float32)

        if layer_params is None:
            return
        new_padding_value = layer_params.get("padding_const_value", self.padding_const_value)
        self.padding_const_value = np.float32(new_padding_value)

    @property
    def kernel_pruned(self):
        """The kernel after passing through the ``kernel_prune`` lossy element."""
        prune_element = self.weight_lossy_elements.kernel_prune
        return prune_element(self.kernel)

    def get_quant_kernel(self, training=False):
        """
        Return the numeric kernel: pruned kernel divided by ``kernel_scale``, shifted by ``zp_kernel``,
        then passed through the kernel lossy element.
        """
        scaled = self.kernel_pruned / self.kernel_scale
        # Plain Python/numpy scalars may lack a ``dtype`` attribute; derive one from the value's type.
        if hasattr(scaled, "dtype"):
            zp_dtype = scaled.dtype
        else:
            zp_dtype = np.dtype(type(scaled))
        shifted = scaled + tf.cast(self.zp_kernel, zp_dtype)
        kernel_element = self.weight_lossy_elements.kernel
        return kernel_element(shifted, training=training)

    @property
    def kernel_q(self):
        """Quantized kernel in inference mode; a more readable alias of ``final_numeric_kernel``."""
        quantized_kernel = self.get_quant_kernel(training=False)
        return quantized_kernel

    @property
    def final_numeric_kernel(self):
        """Quantized kernel as returned by ``get_quant_kernel`` with ``training=False``."""
        numeric_kernel = self.get_quant_kernel(training=False)
        return numeric_kernel

    def split_feed_repeat(self, tensor_sum):
        return tensor_sum / self.feed_repeat

    @property
    def weight_placement_shift(self):
        """``WEIGHTS_PLACEMENT_SHIFT`` when the kernel element is a 4-bit (AdaRound)QuantElement, else 0."""
        element = self.weight_lossy_elements.kernel
        if isinstance(element, (QuantElement, AdaRoundQuantElement)) and (element.bits == 4):
            return WEIGHTS_PLACEMENT_SHIFT
        return 0

    @property
    def total_rshift(self):
        """
        Net right shift applied to the multiplication result.

        The kernel is shifted left by ``weight_placement_shift`` before the MAC, and the multiplication
        result is shifted right by ``pre_acc_shift`` before the add, so the net shift is their difference.
        """
        right_shift = self.pre_acc_shift
        left_shift = self.weight_placement_shift
        return right_shift - left_shift

    @property
    def bit_exact_supported(self) -> bool:
        """This layer supports bit exact emulation."""
        return True

    @property
    def padding_const_value_q(self):
        """
        Padding constant expressed in the input's numeric domain.

        The native padding value is divided by the input scale (its first entry when the scale is not
        a scalar), offset by the mean input zero-point, and passed through the input lossy element.
        """
        scale = self.input_scale
        if scale.shape != ():
            scale = self.input_scale[0]
        mean_zp = tf.reduce_mean(self.input_zero_point)
        return self.input_lossy_element(self.padding_const_value / scale + mean_zp)

    def _get_kernel_scalar_vector_by_groups(self, kernel_numeric_up_to_scalar, weight_clip_cfg):
        """
        Compute the optimal kernel scale of every quantization group and expand it to a per-channel vector.

        Args:
            kernel_numeric_up_to_scalar: the kernel to split into quantization groups.
            weight_clip_cfg: clipping configuration forwarded to ``make_optimal_kernel_scale``.

        Returns:
            A tuple ``(scale, zp)``: the float32 per-channel scale vector, and the zero-point of the first group.

        """
        split_dim = -2 if self.is_depthwise else -1
        base_group_size, split_points, num_of_channels = get_quantization_groups_info(
            self.quantization_groups_num,
            self.kernel.shape[split_dim],
            self.validate_shapes,
        )

        group_scales = []
        group_zps = []
        for group_idx in range(self.quantization_groups_num):
            channel_ids = np.arange(split_points[group_idx], split_points[group_idx + 1]).astype("int")
            group_kernel = tf.gather(kernel_numeric_up_to_scalar, indices=channel_ids, axis=split_dim)
            group_scale, group_zp = self.make_optimal_kernel_scale(kernel=group_kernel, clip_cfg=weight_clip_cfg)
            group_scales.append(group_scale)
            group_zps.append(group_zp)

        # Repeat each group's scale over its channels, then trim to the actual channel count.
        repeated_scale = tf.cast(tf.repeat(group_scales, base_group_size), tf.float32)
        zp = group_zps[0]  # TODO check all the zp are the same?
        all_channels = np.arange(num_of_channels).astype("int")
        return tf.gather(repeated_scale, indices=all_channels), zp

    def create_hw_params(
        self,
        max_native_accumulator,
        weight_clip_cfg: LayerWeightsClippingConfig,
        optimization_target: OptimizationTarget,
        kernel_scale_matrix_component=1,
        shift_calculate_buffer=None,
        hw_shifts=None,
        force_scale=False,
    ):
        """
        (A) Find an optimal kernel scale, minding the classic tradeoff of precision vs. dynamic-range (clipping).
        (B) Calculate the post-multiply/pre-accumulate shift-right needed to avoid accumulator wraparound.
            If necessary, make the numeric kernel smaller, increasing the scale (& range) from the optimum found in (A).

        When ``set_scale_by_kernel_only`` is set, the whole computation is delegated to
        ``_create_hw_params_by_kernel_only``. Otherwise this method sets ``zp_kernel``, ``shift_delta``,
        ``accumulator_scale_candidate`` and ``pre_acc_shift`` on the op.

        Args:
            max_native_accumulator: sliced per quantization group and forwarded to the per-group /
                channel-wise helpers (presumably per-output-channel accumulator maxima - TODO confirm).
            weight_clip_cfg: weights clipping configuration, forwarded to the helpers.
            optimization_target: forwarded to the helpers.
            kernel_scale_matrix_component: the kernel is divided by it before the scale search; replaced by
                ``kernel_scale_forced`` when ``kernel_scale_forced_to_save`` is set.
            shift_calculate_buffer: forwarded to the helpers.
            hw_shifts: forwarded to the helpers; defaults to ``SHIFT_OPTIONS_4BIT`` for a 4-bit kernel.
            force_scale: forwarded to the helpers; forced to True when ``kernel_scale_forced_to_save`` is set.

        """
        # ignore the weight placement shift if 4bit. (when the shift is ignored, the shift options are 0-3)
        # It allows a simplified calculation for the 4bit enhanced in mercury
        is_4bit_kernel = self.weight_lossy_elements.kernel.bits == 4
        if is_4bit_kernel and hw_shifts is None:
            hw_shifts = SHIFT_OPTIONS_4BIT

        if self.set_scale_by_kernel_only:
            self._create_hw_params_by_kernel_only(
                shift_calculate_buffer, max_native_accumulator, hw_shifts, kernel_scale_matrix_component
            )

            return

        if self.kernel_scale_forced_to_save:
            # A previously forced kernel scale overrides the caller-provided component.
            kernel_scale_matrix_component = np.ones_like(kernel_scale_matrix_component) * self.kernel_scale_forced
            force_scale = True

        split_dim = -2 if self.is_depthwise else -1
        kernel_numeric_up_to_scalar = self.kernel / kernel_scale_matrix_component

        _, split_points, _ = get_quantization_groups_info(
            self.quantization_groups_num,
            self.kernel.shape[split_dim],
            self.validate_shapes,
        )
        accumulator_scale = []
        pre_act_shifts = []
        shift_deltas = []
        zero_points = []

        if self.quantization_groups_num == self.kernel.shape[split_dim]:
            # One quantization group per channel: handled by the channel-wise helper in a single call.
            accumulator_scale, pre_act_shifts, shift_deltas, zero_points = self._create_channel_wise_hw_params(
                max_native_accumulator,
                kernel_numeric_up_to_scalar,
                kernel_scale_matrix_component,
                weight_clip_cfg,
                optimization_target,
                shift_calculate_buffer=shift_calculate_buffer,
                hw_shifts=hw_shifts,
                force_scale=force_scale,
            )
        else:
            # Process each quantization group separately, then concatenate the per-group accumulator scales.
            for i in range(self.quantization_groups_num):
                max_native_accumulator_part = max_native_accumulator[split_points[i] : split_points[i + 1]]
                kernel_numeric_up_to_scalar_part = tf.gather(
                    kernel_numeric_up_to_scalar,
                    indices=np.arange(split_points[i], split_points[i + 1]).astype("int"),
                    axis=split_dim,
                )
                kernel_scale_matrix_component_part = tf.gather(
                    kernel_scale_matrix_component,
                    indices=np.arange(split_points[i], split_points[i + 1]).astype("int"),
                    axis=split_dim,
                )
                accumulator_scale_before_shift, pre_acc_shift, shift_delta, zp = self._create_group_hw_params(
                    max_native_accumulator_part,
                    kernel_numeric_up_to_scalar_part,
                    kernel_scale_matrix_component_part,
                    weight_clip_cfg,
                    optimization_target,
                    shift_calculate_buffer=shift_calculate_buffer,
                    hw_shifts=hw_shifts,
                    force_scale=force_scale,
                )
                accumulator_scale.append(accumulator_scale_before_shift)
                pre_act_shifts.append(pre_acc_shift)
                shift_deltas.append(shift_delta)
                zero_points.append(zp)
            accumulator_scale = tf.concat(accumulator_scale, axis=-1)
        # A single shift is used for the whole op: the largest one required by any group/channel.
        shift = np.max(pre_act_shifts)

        # (TODO) we dont support quantization groups with 16 bit kernel yet
        self.zp_kernel = zero_points[0]
        # no more use but will be exported to qnpz debug info (TODO)
        self.shift_delta = shift_deltas

        # Update the accumulator scale candidate after final shifts
        self.accumulator_scale_candidate = accumulator_scale * 2**shift

        if is_4bit_kernel:
            shift += WEIGHTS_PLACEMENT_SHIFT
        self.pre_acc_shift = shift

    def _create_hw_params_by_kernel_only(
        self, shift_calculate_buffer, max_native_accumulator, hw_shifts, kernel_scale_matrix_component
    ):
        """
        Derive the kernel scale from the kernel itself, then the accumulator scale and shifts.

        For every quantization group the scale is computed according to ``scale_calc_mode``: from the
        group's min/max values (``CalcKernelMode.limvals``) or from the spacing of its distinct values
        (``CalcKernelMode.kernel_vals``). The method then sets ``kernel_scale``, ``shift_delta``,
        ``output_scale``, ``pre_acc_shift``, ``accumulator_scale_candidate``, ``kernel_q_forced``,
        ``kernel_scale_forced`` and marks ``kernel_scale_forced_to_save``.

        Args:
            shift_calculate_buffer: buffer passed to ``calculate_shifts``; ``SHIFT_CALCULATE_BUFFER`` when falsy.
            max_native_accumulator: sliced per quantization group; the maximum of each slice is used.
            hw_shifts: forwarded to ``calculate_shifts``.
            kernel_scale_matrix_component: provides the shape of ``kernel_scale`` and divides the kernel to
                produce ``kernel_q_forced``.

        """
        kernel_lossy = self.weight_lossy_elements.kernel
        accumultor_size = self.output_lossy_element.bits  # get accumulator
        buffer = shift_calculate_buffer if shift_calculate_buffer else SHIFT_CALCULATE_BUFFER
        split_dim = -2 if self.is_depthwise else -1
        # pre_acc_shift / shift_delta / desired_shift are rebound by calculate_shifts() further down.
        _, zp, kernel_scale, pre_acc_shift, shift_delta, desired_shift = [], [], [], [], [], []

        _, split_points, _ = get_quantization_groups_info(
            self.quantization_groups_num,
            self.kernel.shape[split_dim],
            self.validate_shapes,
        )

        max_native_accumulator_vector = []
        self._logger.debug(
            f"{self.full_name} max values per channel : {np.max([len(np.unique(self.kernel[0, 0, :, i])) for i in range(self.kernel.shape[-1])])}"
        )
        for i in range(self.quantization_groups_num):
            max_native_accumulator_part = np.max(max_native_accumulator[split_points[i] : split_points[i + 1]])
            kernel_part = tf.gather(
                self.kernel,
                indices=np.arange(split_points[i], split_points[i + 1]).astype("int"),
                axis=split_dim,
            )
            # NOTE(review): any other scale_calc_mode leaves kernel_scale_i / zp_i unbound on the first iteration.
            if self.scale_calc_mode == CalcKernelMode.limvals:
                _weight_limvals_i = np.array((np.min(kernel_part), np.max(kernel_part)))
                zp_i, kernel_scale_i, _ = limvals_to_zp_scale(
                    _weight_limvals_i, kernel_lossy, self.full_name, self._logger
                )
            elif (
                self.scale_calc_mode == CalcKernelMode.kernel_vals
            ):  # calculate scale from kernel values (currently default)
                kernel_scale_i = self.calc_scale_from_quantized_native_kernel(kernel_part, kernel_lossy.bits)
                zp_i = 0
            # NOTE(review): zp is collected here but not read again within this method.
            zp.append(zp_i)
            kernel_scale.append(kernel_scale_i)
            max_native_accumulator_vector.append(max_native_accumulator_part)

        # Reshape the per-group scales so they broadcast along split_dim of the component's shape.
        kernel_shape = np.ones(len(kernel_scale_matrix_component.shape))
        kernel_shape[split_dim] = -1
        kernel_scale = np.array(kernel_scale).reshape(*kernel_shape.astype(np.int8))

        self.kernel_scale = np.ones_like(kernel_scale_matrix_component) * kernel_scale
        if isinstance(self.input_scale, tf.Tensor):
            self.input_scale = self.input_scale.numpy()
        acc_before_mtx = self._cacluclate_accumulator_mtx_by_groups(self.kernel_scale)
        if self.is_depthwise:
            # in the case of depthwise we need to transpose the accumulator scale from (channels, 1)
            # to (1, channels)
            acc_before_mtx = tf.transpose(acc_before_mtx)

        # Verify all rows same.
        self._verify_acc_scale(acc_before_mtx)

        acc_scale_before_shift = acc_before_mtx[0]
        expected_max_accumulator = max_native_accumulator_vector / acc_scale_before_shift

        pre_acc_shift, shift_delta, desired_shift = calculate_shifts(
            expected_max_accumulator,
            accumultor_size,
            buffer,
            force_rounded_shift_delta=self.force_rounded_shift_delta,
            hw_shifts=hw_shifts,
            return_needed_shift=True,
        )
        self.shift_delta = np.max(shift_delta)
        self._logger.debug(
            f"{self.full_name} np.unique(pre_acc_shift): {np.unique(pre_acc_shift)} max values per channel : {np.max([len(np.unique(self.kernel[0, 0, :, i])) for i in range(self.kernel.shape[-1])])}"
        )

        output_scale = acc_scale_before_shift * 2**pre_acc_shift
        self.output_scale = output_scale

        # The stored shift also includes the weight placement shift (non-zero only for 4-bit kernels).
        pre_acc_shift += self.weight_placement_shift
        self.pre_acc_shift = pre_acc_shift
        self.accumulator_scale_candidate = self.output_scale
        self.kernel_scale_forced_to_save = True

        kernel_q = self.kernel / kernel_scale_matrix_component
        self.kernel_q_forced = kernel_q
        self.kernel_scale_forced = self.kernel_scale

    def calc_scale_from_quantized_native_kernel(self, kernel_part, bits=None):
        """
        Infer the quantization step of an already-quantized kernel from the spacing of its values.

        Args:
            kernel_part: kernel values to inspect.
            bits: optional bit width; when given, the largest magnitude divided by the step
                must not exceed ``2 ** (bits - 1)`` (with a small tolerance).

        Returns:
            The smallest gap between distinct values (zero is always included as a value),
            or ``1.0`` when zero is the only value.

        """
        # Zero always counts as a level, so a kernel with a single distinct value still yields a step.
        levels = np.unique(np.append(np.unique(kernel_part), 0))
        scale = 1.0 if levels.size == 1 else np.min(np.diff(levels))
        if bits is None:
            return scale
        max_quant = np.abs(levels).max() / scale
        assert max_quant <= (2 ** (bits - 1) + 0.001), f"max quant it wrong = {max_quant}"
        return scale

    def _create_group_hw_params(
        self,
        max_native_accumulator,
        kernel_numeric_up_to_scalar,
        kernel_scale_matrix_component,
        weight_clip_cfg: LayerWeightsClippingConfig,
        optimization_target: OptimizationTarget,
        shift_calculate_buffer,
        hw_shifts=None,
        force_scale=False,
    ):
        """
        Create the hw_params for a single quantization group.

        Args:
            max_native_accumulator: accumulator maxima of this group; divided by the accumulator scale to
                estimate the largest numeric accumulator value.
            kernel_numeric_up_to_scalar: this group's kernel, already divided by the scale matrix component.
            kernel_scale_matrix_component: this group's part of the kernel scale matrix component.
            weight_clip_cfg (LayerWeightsClippingConfig): forwarded to ``make_optimal_kernel_scale``.
            optimization_target (OptimizationTarget): ``OptimizationTarget.SAGE`` disables the 4-bit
                "sacrifice input bits" adjustment.
            shift_calculate_buffer: buffer passed to ``calculate_shifts``; ``SHIFT_CALCULATE_BUFFER`` when falsy.
            hw_shifts: forwarded to ``calculate_shifts``.
            force_scale: when True, use a kernel scale of 1 instead of searching for an optimal one.

        Raises:
            AccelerasNumerizationError: if the kernel scale matrix does not have as many rows as the kernel,
                or if the rows of the accumulator scale matrix differ (see ``_verify_acc_scale``).
            ValueError: with ``force_scale`` and an unsigned kernel element, if the kernel mixes
                positive and negative values.

        Returns:
            A tuple ``(acc_scale_before_shift, pre_acc_shift, shift_delta, zp)``.

        """
        if force_scale:
            kernel_scale_candidate_scalar = 1
            if self.weight_lossy_elements.kernel.signed:
                zp = 0
            else:
                # Unsigned kernel element: the zero-point sits at one end of the range, per the kernel's sign.
                if np.all(kernel_numeric_up_to_scalar >= 0):
                    zp = 0
                elif np.all(kernel_numeric_up_to_scalar <= 0):
                    zp = 2 ** (self.weight_lossy_elements.kernel.bits) - 1
                else:
                    raise ValueError(
                        "forced scale requires only positive or only negative kernel when 16bit mode is enabled"
                    )

        else:
            kernel_scale_candidate_scalar, zp = self.make_optimal_kernel_scale(
                kernel=kernel_numeric_up_to_scalar,
                clip_cfg=weight_clip_cfg,
            )

        kernel_scale_matrix = kernel_scale_candidate_scalar * kernel_scale_matrix_component

        # Verify the kernel_scale_matrix has the number of rows(input_features) of the kernel.
        if len(kernel_scale_matrix) != self.kernel.shape[-2]:
            raise AccelerasNumerizationError("the kernel_scale_matrix must have the same number of rows as the kernel")

        acc_before_mtx = self._cacluclate_accumulator_mtx_by_groups(kernel_scale_matrix)

        if self.is_depthwise:
            # in the case of depthwise we need to transpose the accumulator scale from (channels, 1)
            # to (1, channels)
            acc_before_mtx = tf.transpose(acc_before_mtx)

        # Verify all rows same.
        self._verify_acc_scale(acc_before_mtx)

        acc_before = acc_before_mtx[0]
        acc_scale_before_shift = acc_before
        expected_max_accumulator = np.max(max_native_accumulator / acc_scale_before_shift)
        accumultor_size = self.output_lossy_element.bits  # get accumulator
        buffer = shift_calculate_buffer if shift_calculate_buffer else SHIFT_CALCULATE_BUFFER
        pre_acc_shift, shift_delta, desired_shift = calculate_shifts(
            expected_max_accumulator,
            accumultor_size,
            buffer,
            force_rounded_shift_delta=self.force_rounded_shift_delta,
            hw_shifts=hw_shifts,
            return_needed_shift=True,
        )
        is_4bit_kernel = self.weight_lossy_elements.kernel.bits == 4
        if (
            is_4bit_kernel
            and pre_acc_shift > 0
            and desired_shift < 2
            and optimization_target not in {OptimizationTarget.SAGE}
        ):
            # Small required shift with a 4-bit kernel: drop the HW shift and fold the whole
            # desired shift into the scale instead.
            shift_delta = desired_shift
            pre_acc_shift = 0
            self._logger.info(
                f"Sacrificing {shift_delta:.3f} bits from the input data of layer {self.full_name}, "
                f"to enhance the kernel using 4bit weights",
            )
        elif shift_delta > 0:
            # HW can't provide a shift large enough to avoid final accumulator overflow,
            #  we need smaller numeric values by making kernel range wider
            self._logger.info(
                f"No shifts available for layer {self.full_name}, using max shift instead. delta={shift_delta:.04f}"
            )
        acc_scale_before_shift *= 2**shift_delta
        acc_scale_before_shift = np.reshape(acc_scale_before_shift, (acc_scale_before_shift.shape[-1]))
        self.desired_pre_acc_shift = desired_shift
        return acc_scale_before_shift, pre_acc_shift, shift_delta, zp

    def _create_channel_wise_hw_params(
        self,
        max_native_accumulator,
        kernel_numeric_up_to_scalar,
        kernel_scale_matrix_component,
        weight_clip_cfg: LayerWeightsClippingConfig,
        optimization_target: OptimizationTarget,
        shift_calculate_buffer,
        hw_shifts=None,
        force_scale=False,
    ):
        """
        Create the hw_params when every channel along the split dimension is its own quantization group.

        Args:
            max_native_accumulator: accumulator maxima; divided by the accumulator scale to estimate the
                largest numeric accumulator values.
            kernel_numeric_up_to_scalar: the kernel, already divided by the scale matrix component.
            kernel_scale_matrix_component: the kernel scale matrix component.
            weight_clip_cfg (LayerWeightsClippingConfig): forwarded to ``make_optimal_kernel_scale``.
            optimization_target (OptimizationTarget): ``OptimizationTarget.SAGE`` disables the 4-bit
                shift kick-back adjustment.
            shift_calculate_buffer: buffer passed to ``calculate_shifts``; ``SHIFT_CALCULATE_BUFFER`` when falsy.
            hw_shifts: forwarded to ``calculate_shifts``.
            force_scale: when True, use a kernel scale of 1 instead of searching for an optimal one.

        Raises:
            AccelerasNumerizationError: if the kernel scale matrix does not have as many rows as the kernel,
                or if the rows of the accumulator scale matrix differ (see ``_verify_acc_scale``).
            ValueError: with ``force_scale`` and an unsigned kernel element, if the kernel mixes
                positive and negative values.

        Returns:
            A tuple ``(acc_scale_before_shift, pre_acc_shift, shift_delta, zp)``, where ``zp`` is a sequence
            of zero-points (a single-item list when ``force_scale`` is set).

        """
        split_dim = -2 if self.is_depthwise else -1
        if force_scale:
            kernel_scale_candidate_scalar = 1
            if self.weight_lossy_elements.kernel.signed:
                zp = [0]
            else:
                # Unsigned kernel element: the zero-point sits at one end of the range, per the kernel's sign.
                if np.all(kernel_numeric_up_to_scalar >= 0):
                    zp = [0]
                elif np.all(kernel_numeric_up_to_scalar <= 0):
                    zp = [2 ** (self.weight_lossy_elements.kernel.bits) - 1]
                else:
                    raise ValueError(
                        "forced scale requires only positive or only negative kernel when 16bit mode is enabled"
                    )
        else:
            # One optimal scale / zero-point per channel along split_dim.
            kernel_scale_candidate_scalar, zp = zip(
                *[
                    self.make_optimal_kernel_scale(
                        kernel=tf.gather(kernel_numeric_up_to_scalar, indices=[i], axis=split_dim),
                        clip_cfg=weight_clip_cfg,
                    )
                    for i in range(self.kernel.shape[split_dim])
                ],
            )
            # The new axis goes at 0 for split_dim == -1 and at 1 for split_dim == -2.
            kernel_scale_candidate_scalar = tf.expand_dims(kernel_scale_candidate_scalar, -3 - split_dim)
        kernel_scale_matrix = kernel_scale_candidate_scalar * kernel_scale_matrix_component

        # Verify the kernel_scale_matrix has the number of rows(input_features) of the kernel.
        if len(kernel_scale_matrix) != self.kernel.shape[-2]:
            raise AccelerasNumerizationError("the kernel_scale_matrix must have the same number of rows as the kernel")

        # group-convolution special treatment:
        # we need to join the blocks of the kernel_scale to a diagonal:
        acc_before_mtx = self._cacluclate_accumulator_mtx_by_groups(kernel_scale_matrix)

        if self.is_depthwise:
            # in the case of depthwise we need to transpose the accumulator scale from (channels, 1)
            # to (1, channels)
            acc_before_mtx = tf.transpose(acc_before_mtx)

        # Verify all rows same.
        self._verify_acc_scale(acc_before_mtx)

        acc_before = acc_before_mtx[0]
        acc_scale_before_shift = acc_before
        expected_max_accumulator = max_native_accumulator / acc_scale_before_shift
        accumultor_size = self.output_lossy_element.bits  # get accumulator
        buffer = shift_calculate_buffer if shift_calculate_buffer else SHIFT_CALCULATE_BUFFER
        pre_acc_shift, shift_delta, desired_shift = calculate_shifts(
            expected_max_accumulator,
            accumultor_size,
            buffer,
            force_rounded_shift_delta=self.force_rounded_shift_delta,
            hw_shifts=hw_shifts,
            return_needed_shift=True,
        )

        is_4bit_kernel = self.weight_lossy_elements.kernel.bits == 4
        if (
            self._kickback_residual_shift_delta
            and is_4bit_kernel
            and optimization_target not in {OptimizationTarget.SAGE}
        ):
            # Per channel: where a small (< 2) shift is required, drop the HW shift and fold the
            # whole desired shift into the scale instead.
            shift_delta = np.where(np.logical_and(pre_acc_shift > 0, desired_shift < 2), desired_shift, shift_delta)
            pre_acc_shift = np.where(np.logical_and(pre_acc_shift > 0, desired_shift < 2), 0, pre_acc_shift)
        acc_scale_before_shift *= 2**shift_delta
        acc_scale_before_shift = np.reshape(acc_scale_before_shift, (acc_scale_before_shift.shape[-1]))
        self.desired_pre_acc_shift = desired_shift
        return acc_scale_before_shift, pre_acc_shift, shift_delta, zp

    def _cacluclate_accumulator_mtx_by_groups(self, kernel_scale_matrix):
        """
        Combine the input scales with the kernel scale matrix into the accumulator scale matrix.

        Group convolution gets special treatment: each group's block of kernel-scale columns is
        multiplied by that group's input scales, and the resulting blocks are concatenated along
        the last axis.
        """
        num_groups = self.groups
        if num_groups <= 1:
            input_scales_column = tf.cast(tf.expand_dims(self.input_scales[0], 1), dtype=tf.float32)
            return input_scales_column * kernel_scale_matrix

        input_indices, output_indices = self.group_kernel_indices()
        group_blocks = []
        for group_idx in range(num_groups):
            # kernel_scale_matrix: the first dim is of the size of the group with maximum input channels. The second
            # dim concatenates the output channels of all groups with respect to their sizes (as in output_indices).
            scale_block = kernel_scale_matrix[:, output_indices[group_idx] : output_indices[group_idx + 1]]

            group_input_scales = self.input_scales[0][input_indices[group_idx] : input_indices[group_idx + 1]]
            scales_column = tf.expand_dims(group_input_scales, 1)
            # Using tile instead of pad to keep the kernel values the same (and avoid assertion errors)
            tile_count = max(self.group_sizes) - self.group_sizes[group_idx] + 1
            tiled_scales = tf.tile(scales_column, [tile_count, 1])
            group_blocks.append(tiled_scales * scale_block)

        return tf.concat(group_blocks, axis=-1)

    @staticmethod
    def _verify_acc_scale(acc_mtx_part, eps=2e-3):
        mean_rows = np.mean(acc_mtx_part, axis=0, keepdims=True)
        diff = np.max(np.abs(acc_mtx_part - mean_rows) / mean_rows)
        if eps < diff:
            raise AccelerasNumerizationError(
                f"not all the rows of the kernel_scale are the same. the kernel_scale is not well defined with max diff {diff}",
            )

    def get_weights_clipping(self):
        """Return the stored weight limit values (``self._weight_limvals``)."""
        limvals = self._weight_limvals
        return limvals

    def get_forced_accumulator_scale(self):
        """
        Compute the accumulator scale implied by the forced kernel scale and the input scale.

        For grouped layers with a non-scalar input scale, the input scale is sub-sampled with a step of
        ``self.kernel.shape[-2]``. When the forced kernel scale has a non-scalar shape, the product with
        the input scale is averaged over axis 0. The result is multiplied by ``2**self.pre_acc_shift``.

        Returns:
            The forced accumulator (output) scale.

        """
        use_input_scale_as_is = self.groups == 1 or self.input_scale.shape == ()
        input_scale = self.input_scale if use_input_scale_as_is else self.input_scale[:: self.kernel.shape[-2]]

        shift_factor = 2**self.pre_acc_shift
        forced_scale_is_tensor = hasattr(self.kernel_scale_forced, "shape") and self.kernel_scale_forced.shape != ()
        if not forced_scale_is_tensor:
            return input_scale * self.kernel_scale_forced * shift_factor

        scale_products = np.expand_dims(input_scale, -1) * self.kernel_scale_forced
        return np.mean(scale_products, axis=0) * shift_factor

    def verification_of_vector(self, vector_to_check):
        """
        Verify that the given vector holds a single repeated value inside every group.

        The vector is cut into ``self.groups`` consecutive chunks of ``self.kernel.shape[-2]`` entries;
        all entries of a chunk must equal the chunk's first entry.

        Args:
            vector_to_check: per-channel vector to check

        Raises:
            AccelerasImplementationError: if a chunk holds differing values

        """
        channels_per_group = self.kernel.shape[-2]
        for g in range(self.groups):
            start = channels_per_group * g
            vector_by_group = vector_to_check[start : start + channels_per_group]
            if np.all(vector_by_group == vector_by_group[0]):
                continue
            raise AccelerasImplementationError(f"input_scale_by_group is not the same {vector_by_group}")

    def make_optimal_kernel_scale(
        self,
        kernel: tf.Tensor,  # TODO: we might want to change it
        clip_cfg: LayerWeightsClippingConfig,
    ):
        """
        Pick the kernel limit values for the configured clipping mode and derive the kernel scale from them.

        Must run in eager mode, since the kernel is converted with ``kernel.numpy()``. The chosen limits
        are stored in ``self._weight_limvals``.

        Args:
            kernel: kernel tensor to analyse
            clip_cfg: weight clipping configuration (its ``mode`` and ``clipping_values`` are read)

        Returns:
            Tuple ``(kernel_scale, zp)`` as computed by ``limvals_to_zp_scale`` for the kernel lossy element.

        Raises:
            RuntimeError: when called outside of eager execution
            AccelerasImplementationError: when the clipping mode is not supported

        """
        if not context.executing_eagerly():
            # TODO: this check should be in the create_hw_params
            raise RuntimeError(f"Tried to calculate kernel scale in graph mode in op {self.full_name}")

        kernel_values = kernel.numpy()
        kernel_lossy = self.weight_lossy_elements.kernel
        mode, clipping_values = clip_cfg.mode, clip_cfg.clipping_values

        if mode == WeightsClippingMode.mmse_if4b:
            # mmse_if4b resolves to mmse only when the kernel lossy element equals MACDataQuantElement(4);
            # otherwise clipping is disabled.
            # NOTE(review): this relies on the lossy elements defining equality - confirm.
            matches_4bit = kernel_lossy == MACDataQuantElement(4)
            mode = WeightsClippingMode.mmse if matches_4bit else WeightsClippingMode.disabled

        if mode == WeightsClippingMode.mmse:
            # TODO need to change it to vectorized
            upper = mmse(kernel_values, kernel_lossy.bits)  # TODO other options ("clipping API")
            if np.isnan(upper):
                # fall back to the largest absolute kernel value
                upper = np.max(np.abs(kernel_values))
            limits = (-upper, upper)
        elif mode == WeightsClippingMode.manual:
            lower, upper = clipping_values
            limits = (lower, upper)
        elif mode == WeightsClippingMode.disabled:
            upper = np.max(kernel_values)
            lower = np.min(kernel_values)
            limits = (lower, upper)
        elif mode == WeightsClippingMode.percentile:
            lower, upper = np.percentile(kernel_values, clipping_values)
            limits = (lower, upper)
        else:
            raise AccelerasImplementationError("unsupported scale mode")

        self._weight_limvals = np.array(limits)
        zp, kernel_scale, _ = limvals_to_zp_scale(self._weight_limvals, kernel_lossy, self.full_name, self._logger)
        return kernel_scale, zp

    def _smart_floor(self, kernel_q_candidate):
        kernel_q_floor = np.floor(kernel_q_candidate)
        return np.where(kernel_q_floor == 0, 1, kernel_q_floor)

    def get_nudged_kernel(self, kernel_candidate):
        """
        Snap a candidate kernel onto the quantization grid given by its own optimal scale.

        The scale is computed with the default weights clipping configuration; the candidate is divided
        by it, passed through ``_smart_floor`` and multiplied by the scale again.

        Args:
            kernel_candidate: candidate kernel

        Returns:
            The nudged kernel.

        """
        default_clip_cfg = LayerWeightsClippingConfig.get_default()
        scale, _ = self.make_optimal_kernel_scale(kernel_candidate, default_clip_cfg)
        quantized = self._smart_floor(kernel_candidate / scale)
        return quantized * scale

    def _kernel_scale_from_io_scales(self):
        """
        The kernel scale is a dependent tensor,
                    it's computed from input scale and output (aka, accumulator for this aop) scales,
                    which are imposed from the upper (composite) level.

        Returns

        """
        return self.calc_kernel_scale(self.input_scales, self.output_scale)

    def calc_kernel_scale(self, input_scales, output_scale, total_rshift=None):
        """
        Compute the kernel scale of this op from the given input and output scales.

        If the op has no kernel yet, a random float32 kernel of the expected shape is created and
        stored in ``self.kernel`` (the filter count falls back to ``self.output_shape[-1]`` when
        ``self.filters`` is None).

        Args:
            input_scales: input_scales - (input scales of layer)
            output_scale: output_scale - (aka accumulator scale)
            total_rshift: total right shift; ``self.total_rshift`` is used when None

        Returns: the kernel_scale up to scalar

        """
        if self.kernel is not None:
            kernel_shape = self.kernel.shape
        else:
            filters = self.output_shape[-1] if self.filters is None else self.filters
            kernel_shape = [*self.kernel_size, self.input_shape[-1], filters]
            self.kernel = np.random.rand(*kernel_shape).astype("float32")

        rshift = self.total_rshift if total_rshift is None else total_rshift

        return self.calc_kernel_scale_external(
            input_scales,
            output_scale,
            rshift,
            kernel_shape,
            self.is_depthwise,
            self.group_sizes,
            self.group_kernel_indices(),
        )

    def group_kernel_indices(self):
        """
        Create the lists of input and output channel boundaries of the convolution groups.

        Works for the regular group-convolution case (all groups of equal size) and for asymmetric
        group-convolution, in which groups may have different sizes. The boundaries are the cumulative
        group sizes (starting at 0) multiplied by a base size: the kernel's input channels divided by
        the largest group size for the inputs, and the kernel's output channels divided by the sum of
        the group sizes for the outputs.

        For example, with group_sizes=[1, 2] (2 groups) and a base output group size of 3,
        output_indices=[0, 3, 9]. The same principle holds for input_indices with its own base size.

        Returns
            Tuple[List, List] or None: two lists for the input and output indices, respectively. If
                group-convolution is not used (``self.groups == 1``), returns None.

        Raises
            AccelerasNumerizationError: if the kernel's input channels are not divisible by the largest
                group size, or its output channels are not divisible by the sum of the group sizes.

        """
        if self.groups == 1:
            return None

        input_channels, output_channels = self.kernel.shape[-2:]
        group_sizes = self.group_sizes
        # boundaries in units of "base group size": [0, s0, s0 + s1, ...]
        boundaries = np.cumsum(np.r_[0, group_sizes])

        largest_group = max(group_sizes)
        if input_channels % largest_group > 0:
            raise AccelerasNumerizationError(
                "input_channels must divide by input_group_size "
                + f"(input_channels %  max(group_sizes) = {input_channels %  max(group_sizes)})",
            )
        input_indices = (input_channels // largest_group) * boundaries

        total_groups_size = sum(group_sizes)
        if output_channels % total_groups_size > 0:
            raise AccelerasNumerizationError(
                "output_channels must divide by sum(group_sizes) "
                + f"(output_channels % sum(group_sizes) = {output_channels % sum(group_sizes)})",
            )
        output_indices = (output_channels // total_groups_size) * boundaries

        return list(input_indices), list(output_indices)

    @staticmethod
    def calc_kernel_scale_external(
        input_scales,
        output_scale,
        total_rshift,
        kernel_shape,
        is_depthwise,
        group_sizes,
        group_kernel_indices=None,
    ):
        """
        Compute the kernel scale from input/output scales without relying on instance state.

        The "left" factor is ``1 / input_scales[0]`` and the "right" factor is
        ``output_scale / 2**total_rshift``; both are handled in float64 and the result is cast to float32.

        Args:
            input_scales: input_scales - (input scales of layer); only the first entry is used
            output_scale: output_scale - (aka accumulator scale)
            total_rshift: total right shift applied to the output scale
            kernel_shape: kernel shape; its last two entries are the input and output channels
            is_depthwise: whether the kernel scale is channelwise (elementwise left * right)
            group_sizes: sizes of the convolution groups
            group_kernel_indices: (input_indices, output_indices) of the groups, used only when there
                is more than one group

        Returns: the kernel_scale up to scalar

        """
        left = tf.cast(1 / input_scales[0], tf.float64)
        right = tf.cast(output_scale / 2**total_rshift, tf.float64)
        if left.shape == (1,):
            left = left[0]

        # NOTE: sanity check (at least for scalar case) how the constraint checks out:
        # 1/self.output_scale == 1/kernel_scale_right * 1/kernel_scale_left * 1/self.input_scale / 2**self.pre_acc_shift
        #  --> kernel_scale_left * kernel_scale_right =  output_scale / input_scale / 2**self.pre_acc_shift
        if left.shape == ():
            # scalar, let's upgrade to vector (one entry per input channel)
            left = tf.repeat(left, kernel_shape[-2])
        if right.shape == ():
            # scalar right factor: one entry per input channel for depthwise, per output channel otherwise
            right = tf.repeat(right, kernel_shape[-2] if is_depthwise else kernel_shape[-1])

        if is_depthwise:
            # (!) Feature-preserving, "vector"/channelwise kernel.
            # Channels are the same, kernel scale is the elementwise product L x R as a column.
            return tf.cast(tf.expand_dims(left * right, 1), tf.float32)

        # (!) Feature-mixing, "matrix" kernel.
        # Kernel scale is a matrix: Scale[i][j] = Left[i] * Right[j]
        kernel_scale = tf.matmul(tf.expand_dims(left, 1), tf.expand_dims(right, 0))

        # Note:
        # (1) group-convolution special treatment:
        #     we need to squash the diagonal tiles as done for the kernel itself.
        # (2) Asymmetric group-convolution: groups can have different sizes.
        #     This is an expansion over the regular group-convolution case, in which
        #     all groups have the same size.
        if len(group_sizes) > 1 and kernel_scale.shape != kernel_shape[-2:]:
            input_indices, output_indices = group_kernel_indices
            tiled_blocks = []
            for g, size in enumerate(group_sizes):
                diagonal_block = kernel_scale[
                    input_indices[g] : input_indices[g + 1],
                    output_indices[g] : output_indices[g + 1],
                ]
                # Add redundant values along the input channels (rank 0) to enable concatenation.
                # Doing tf.tile and not tf.pad; this avoids failing over AccelerasImplementationError
                # at other parts of the code.
                repeats = max(group_sizes) - size + 1
                tiled_blocks.append(tf.tile(diagonal_block, [repeats, 1]))
            kernel_scale = tf.concat(tiled_blocks, axis=-1)

        return tf.cast(kernel_scale, tf.float32)

    def enforce_encoding(self, training=False):
        """
        Finalize the inward scales propagation by resolving the kernel scale,
        and begin the forward zero-point propagation.

        When ``self.kernel_scale_forced_to_save`` is set, the kernel scale is the forced kernel scale
        broadcast to the kernel shape, and (unless ``self.set_scale_by_kernel_only``) the output scale
        is recomputed from it. Otherwise the kernel scale is derived from the input/output scales.

        Args:
            training: forwarded to ``self.compute_output_zp``

        """
        if not self.kernel_scale_forced_to_save:
            self.kernel_scale = self._kernel_scale_from_io_scales()
        else:
            # calculate the kernel scale and the accumulator scale from the forced value
            filters = self.output_shape[-1] if self.filters is None else self.filters
            if self.kernel is None:
                kernel_shape = [*self.kernel_size, self.input_shape[-1], filters]
            else:
                kernel_shape = self.kernel.shape
            forced_scale = np.broadcast_to(self.kernel_scale_forced, kernel_shape)
            self.kernel_scale = tf.cast(forced_scale, dtype=tf.float32)
            if not self.set_scale_by_kernel_only:
                self.output_scale = self.get_forced_accumulator_scale()
            self.output_scale = tf.cast(self.output_scale, dtype=tf.float32)

        self.output_zero_point = self.compute_output_zp(training=training)
        # NOTE: final numeric kernel computation is in self.kernel_q(),
        #        but conceptually its also part of offline comp.

    def call_native(self, inputs, **kwargs):
        """Run the convolution on the first input with the native kernel and native padding value."""
        (data, kernel, pad_value) = (inputs[0], self.kernel, self.padding_const_value)
        return self._call_conv_internal(data, kernel, pad_value)

    def call_hw_sim(self, inputs, training=False, **kwargs):
        """
        Simulate the core numeric functionality of the Hailo MAC.

        NOTE: The padding is applied separately, the backend invocation is always with "valid";
              might be considered for moving into a separate atomic op.
        NOTE: Both shifts are applied:
          A.) the pre_acc_shift (here as a division of the convolution result by 2**pre_acc_shift)
          B.) the "placement shift" for 4b weights into 8b field.

        as it's what the Hailo MAC actually does
           (in contrast to (de)numerization ops which are virtual)
        """
        quant_kernel = self.get_quant_kernel(training=training)
        centered_kernel = quant_kernel - tf.cast(self.zp_kernel, quant_kernel.dtype)
        placed_kernel = centered_kernel * 2**self.weight_placement_shift

        acc = self._call_conv_internal(inputs[0], placed_kernel, self.padding_const_value_q)
        acc = acc / tf.cast(2**self.pre_acc_shift, acc.dtype)

        # For hailo_conv_decompose
        if self.output_lossy_elements[0].is_lossless:
            return acc
        if self._force_zero_output_when_quant:
            acc = acc * 0.0
        return acc

    def call_bit_exact(self, inputs, training=False, **kwargs):
        """
        Run the convolution with integer tensors using the lossy conv implementations.

        The quantized kernel (minus its zero point, times the weight placement shift), the first input
        and the quantized padding value are cast to ``self.INT_TYPE_TF``. Padding is applied explicitly
        and the convolution itself always runs with "VALID" padding. Depthwise layers with equal strides
        use ``lossy_depthwise_conv2d``, ungrouped layers use ``lossy_conv2d``, and all other cases are
        computed group by group with ``lossy_conv2d`` and concatenated on the channel axis.
        """
        quant_kernel = self.get_quant_kernel(training=training)
        kernel = quant_kernel - tf.cast(self.zp_kernel, quant_kernel.dtype)
        kernel = kernel * 2**self.weight_placement_shift  # kernel_shifted

        inputs = tf.cast(inputs[0], self.INT_TYPE_TF)
        kernel = tf.cast(kernel, self.INT_TYPE_TF)
        padding_const_value_q = tf.cast(self.padding_const_value_q, self.INT_TYPE_TF)
        pre_acc_shift = tf.cast(self.pre_acc_shift, inputs.dtype)

        padded_input = handle_padding(
            inputs,
            self.padding,
            self.kernel_size,
            self.strides,
            padding_const_value_q,
            self.stride_align,
            self.dilation_rate,
        )
        if self.spatial_flatten_output:
            flattened_shape = [inputs.shape[0], 1, inputs.shape[1] * inputs.shape[2], inputs.shape[3]]
            padded_input = tf.reshape(padded_input, shape=flattened_shape)

        # keyword arguments shared by all the lossy convolution invocations below
        lossy_kwargs = {
            "padding": "VALID",
            "pre_acc_shift": pre_acc_shift,
            "hsim": self._hsim if self.BIT_EXACT_USE_HSIM else None,
            "accumulator_size": self.output_lossy_element.bits,
        }

        # depthwise_conv2d in tf is only supported with equal strides
        if self.is_depthwise and self.strides[0] == self.strides[1]:
            output = self.lossy_depthwise_conv2d(
                padded_input,
                kernel,
                strides=(1, self.strides[0], self.strides[1], 1),
                dilations=self.dilation_rate,
                **lossy_kwargs,
            )
        elif self.groups == 1 and not self.is_depthwise:
            output = self.lossy_conv2d(
                padded_input,
                kernel,
                strides=self.strides,
                dilation=self.dilation_rate,
                **lossy_kwargs,
            )
        else:
            if self.is_depthwise:
                # depthwise as a group conv: one single-channel group per input channel
                grsize_inp = 1
                grsize_out = 1
                groups_sizes = [1] * self.kernel.shape[-2]
                kernel = tf.transpose(kernel, (0, 1, 3, 2))
            else:
                grsize_inp = self.kernel.shape[-2] // max(self.group_sizes)
                grsize_out = self.kernel.shape[-1] // sum(self.group_sizes)
                groups_sizes = self.group_sizes

            group_outputs = []
            sizes_so_far = 0  # sum of the sizes of all the previous groups
            for size in groups_sizes:
                sizes_end = sizes_so_far + size
                group_input = padded_input[..., grsize_inp * sizes_so_far : grsize_inp * sizes_end]
                # For asymmetric group convolution, the kernel is concatenated on the output features dimension
                # of all its convolutions. The input features dim is the maximum of input features per group -
                # the other kernels are padded with zeros.
                group_kernel = kernel[:, :, : grsize_inp * size, grsize_out * sizes_so_far : grsize_out * sizes_end]
                group_outputs.append(
                    self.lossy_conv2d(
                        group_input,
                        group_kernel,
                        strides=self.strides,
                        dilation=self.dilation_rate,
                        **lossy_kwargs,
                    )
                )
                sizes_so_far = sizes_end

            output = tf.concat(group_outputs, axis=-1)

        if self.debug_mode:
            self._padded_input = padded_input
            self._output = output

        # For hailo_conv_decompose
        if not self.output_lossy_elements[0].is_lossless and self._force_zero_output_when_quant:
            output = output * 0

        return output

    @staticmethod
    def lossy_conv2d(inputs, kernel, strides, dilation, padding, pre_acc_shift=0, hsim=None, accumulator_size=None):
        """
        A fast implementation of a lossy Conv2d for the Hailo8 hardware.
            Note that both patches and pre_acc_shift must be of the same dtype.

        Without ``hsim``, the input patches are multiplied by one flattened output filter at a time, every
        product is shifted with ``bankers_round_int_shift`` (the lossy step) and only then accumulated.
        With ``hsim``, the computation is delegated to ``hsim.h_conv2d`` (or ``hsim.h_conv2d_dilation``
        when a dilation is larger than 1) on int32 tensors; that path always uses "VALID" padding.

        Args:
            inputs: input tensor whose last dimension is the channels
            kernel: kernel of shape (k_height, k_width, ch_in, ch_out); cast to the dtype of ``inputs``
            strides (list): 2d strides, e.g., [1, 1]
            dilation (list): 2d dilations, e.g., [1, 1]
            padding (str): can be 'SAME' or 'VALID' (used for the patch extraction of the non-hsim path)
            pre_acc_shift (int, optional): pre accumulator shift (the lossy element in Hailo8). Defaults to 0.
            hsim (optional): module providing the ``h_conv2d`` ops; when None the TF implementation is used.
            accumulator_size (optional): accumulator size, forwarded to the hsim op only.

        Returns:
            The convolution result, with ``ch_out`` channels in the last dimension for the non-hsim path.

        Raises:
            AccelerasImplementationError: if the channels of ``inputs`` differ from the kernel's ``ch_in``.

        """
        strides = [1, strides[0], strides[1], 1]
        dilation = [1, dilation[0], dilation[1], 1]
        k_height, k_width, ch_in, ch_out = kernel.shape
        if inputs.shape[-1] != ch_in:
            # The exception used to be constructed without being raised, so a mismatch went unnoticed here.
            raise AccelerasImplementationError(f"num_channels ({inputs.shape[-1]}) != ch_in ({ch_in})")
        inputs_dtype = inputs.dtype
        kernel = tf.cast(kernel, inputs_dtype)

        if hsim is not None:
            hsim_op = hsim.h_conv2d_dilation if dilation[1] > 1 or dilation[2] > 1 else hsim.h_conv2d
            inputs = tf.cast(inputs, dtype=tf.int32)
            kernel = tf.cast(kernel, dtype=tf.int32)
            return hsim_op(
                inputs,
                kernel,
                tf.zeros([1], dtype=tf.int32),
                tf.cast(pre_acc_shift, dtype=tf.int8),
                strides,
                dilation,
                "VALID",
                accumulator_size=accumulator_size,
                use_fp16_acc=False,
                name="op",
            )

        # one column per output channel, rows ordered like the flattened patches
        kernel_flat = tf.reshape(kernel, [k_height * k_width * ch_in, ch_out])
        patches = tf.image.extract_patches(
            inputs,
            sizes=[1, k_height, k_width, 1],
            strides=strides,
            rates=dilation,
            padding=padding,
        )

        @tf.function
        def run_over_patches(patches, kernel_flat, pre_acc_shift, ch_out):
            ta = tf.TensorArray(dtype=patches.dtype, size=0, dynamic_size=True)
            for ii in tf.range(ch_out):
                mul = patches * kernel_flat[:, ii]
                # the shift is applied on every product, before the accumulation
                mul_lossy = bankers_round_int_shift(mul, pre_acc_shift)
                ta = ta.write(ii, tf.math.reduce_sum(mul_lossy, axis=3))
            return ta.stack()

        res = run_over_patches(patches, kernel_flat, pre_acc_shift, ch_out)
        # the stacked results are channel-first; move the channels to the last dimension
        res = tf.transpose(res, perm=[1, 2, 3, 0])
        return tf.ensure_shape(res, [*res.shape[:-1], ch_out])

    @staticmethod
    def lossy_depthwise_conv2d(
        inputs, kernel, strides, padding, dilations, pre_acc_shift=0, hsim=None, accumulator_size=None
    ):
        """
        Lossy depthwise Conv2d: every input channel is convolved with its own filter.

        Without ``hsim``, patches are extracted per channel, multiplied by that channel's flattened
        filter, shifted with ``bankers_round_int_shift`` and then accumulated. With ``hsim``, the
        computation is delegated to ``hsim.h_depth_wise`` on int32 tensors with "VALID" padding.

        Args:
            inputs: input tensor whose last dimension is the channels
            kernel: kernel whose first three dimensions are (k_height, k_width, ch_in); cast to the
                dtype of ``inputs``
            strides: strides, passed as is to the patch extraction / hsim op
            padding (str): padding of the patch extraction in the non-hsim path
            dilations: 2d dilations, e.g., [1, 1]
            pre_acc_shift (int, optional): pre accumulator shift. Defaults to 0.
            hsim (optional): module providing the ``h_depth_wise`` op; when None the TF implementation is used.
            accumulator_size (optional): accumulator size, forwarded to the hsim op only.

        Returns:
            The depthwise convolution result with the channels in the last dimension.

        """
        rates = [1, dilations[0], dilations[1], 1]
        k_height, k_width, ch_in, _ = kernel.shape
        kernel = tf.cast(kernel, inputs.dtype)

        if hsim is not None:
            hsim_kernel = tf.cast(tf.reshape(kernel, [k_height, k_width, 1, ch_in]), dtype=tf.int32)
            return hsim.h_depth_wise(
                tf.cast(inputs, dtype=tf.int32),
                hsim_kernel,
                tf.zeros([1], dtype=tf.int32),
                tf.cast(pre_acc_shift, dtype=tf.int8),
                strides,
                rates,
                "VALID",
                accumulator_size=accumulator_size,
                use_fp16_acc=False,
                name="op",
            )

        # one column per channel
        filters_flat = tf.reshape(kernel, [k_height * k_width, ch_in])

        @tf.function
        def accumulate_per_channel(inputs, filters_flat, pre_acc_shift):
            per_channel = tf.TensorArray(dtype=inputs.dtype, size=0, dynamic_size=True)
            for channel in range(ch_in):
                channel_patches = tf.image.extract_patches(
                    inputs[:, :, :, channel : channel + 1],
                    sizes=[1, k_height, k_width, 1],
                    strides=strides,
                    rates=rates,
                    padding=padding,
                )
                products = channel_patches * filters_flat[:, channel]
                # the shift is applied on every product, before the accumulation
                shifted = bankers_round_int_shift(products, pre_acc_shift)
                per_channel = per_channel.write(channel, tf.reduce_sum(shifted, axis=-1))
            return per_channel.stack()

        stacked = accumulate_per_channel(inputs, filters_flat, pre_acc_shift)
        # the stacked results are channel-first; move the channels to the last dimension
        return tf.transpose(stacked, perm=(1, 2, 3, 0))

    def _call_conv_internal(self, inputs, kernel, padding_const_value):
        """
        Pad the input explicitly and run the convolution with "VALID" padding.

        Depthwise layers with equal strides use ``tf.nn.depthwise_conv2d``, ungrouped layers use
        ``tf.nn.conv2d``, and all other cases are computed group by group with ``tf.nn.conv2d`` and
        concatenated on the channel axis.

        Args:
            inputs: input tensor
            kernel: kernel to convolve with; cast to the dtype of ``inputs``
            padding_const_value: value used by ``handle_padding`` for the padded elements

        Returns:
            The convolution output.

        """
        kernel = tf.cast(kernel, inputs.dtype)
        padded_input = handle_padding(
            inputs,
            self.padding,
            self.kernel_size,
            self.strides,
            padding_const_value,
            self.stride_align,
            self.dilation_rate,
        )
        if self.spatial_flatten_output:
            flattened_shape = [inputs.shape[0], 1, inputs.shape[1] * inputs.shape[2], inputs.shape[3]]
            padded_input = tf.reshape(padded_input, shape=flattened_shape)

        # depthwise_conv2d in tf is only supported with equal strides
        if self.is_depthwise and self.strides[0] == self.strides[1]:
            output = tf.nn.depthwise_conv2d(
                padded_input,
                kernel,
                strides=(1, self.strides[0], self.strides[1], 1),
                padding="VALID",
                dilations=self.dilation_rate,
            )
        elif self.groups == 1 and not self.is_depthwise:
            output = tf.nn.conv2d(
                padded_input,
                kernel,
                strides=self.strides,
                padding="VALID",
                dilations=self.dilation_rate,
            )
        else:
            if self.is_depthwise:
                # depthwise as a group conv: one single-channel group per input channel
                grsize_inp = 1
                grsize_out = 1
                groups_sizes = [1] * self.kernel.shape[-2]
                kernel = tf.transpose(kernel, (0, 1, 3, 2))
            else:
                grsize_inp = self.kernel.shape[-2] // max(self.group_sizes)
                grsize_out = self.kernel.shape[-1] // sum(self.group_sizes)
                groups_sizes = self.group_sizes

            group_outputs = []
            sizes_so_far = 0  # sum of the sizes of all the previous groups
            for size in groups_sizes:
                sizes_end = sizes_so_far + size
                group_input = padded_input[..., grsize_inp * sizes_so_far : grsize_inp * sizes_end]
                # For asymmetric group convolution, the kernel is concatenated on the output features dimension
                # of all its convolutions. The input features dim is the maximum of input features per group -
                # the other kernels are padded with zeros.
                group_kernel = kernel[:, :, : grsize_inp * size, grsize_out * sizes_so_far : grsize_out * sizes_end]
                group_outputs.append(
                    tf.nn.conv2d(
                        group_input,
                        group_kernel,
                        strides=self.strides,
                        padding="VALID",
                        dilations=self.dilation_rate,
                    )
                )
                sizes_so_far = sizes_end
            output = tf.concat(group_outputs, axis=-1)

        if self.debug_mode:
            self._padded_input = padded_input
            self._output = output

        return output

    def _compute_output_shape(self, input_shape):
        """
        Compute the (batch, height, width, features) output shape for the given input shape.

        With ``self.spatial_flatten_output`` the input is first treated as (batch, 1, height * width,
        channels). SAME/VALID paddings use ``_spatial_output_shape`` directly; DECONV padding first
        enlarges the spatial dims by the deconv padding and then uses "valid"; any other padding is
        delegated to the parent class.

        Args:
            input_shape: input shape, indexed as (batch, height, width, channels)

        Returns:
            The output shape tuple (or the parent's result for other padding types).

        """
        if self.spatial_flatten_output:
            input_shape = [input_shape[0], 1, input_shape[1] * input_shape[2], input_shape[3]]

        if self.padding in {PaddingType.SAME, PaddingType.VALID}:
            spatial_out = self._spatial_output_shape(input_shape[1:3])
        elif self.padding == PaddingType.DECONV:
            deconv_pads = get_deconv_padding(
                self.dilation_rate,
                input_shape,
                self.kernel_size,
                self.strides,
            )
            pad_beg_h, pad_end_h, pad_beg_w, pad_end_w = deconv_pads
            padded_spatial = [
                input_shape[1] + (pad_beg_h + pad_end_h),
                input_shape[2] + (pad_beg_w + pad_end_w),
            ]
            spatial_out = self._spatial_output_shape(padded_spatial, padding="valid")
        else:
            return super()._compute_output_shape(input_shape)

        h_out, w_out = spatial_out
        # depthwise keeps the input channels, otherwise the features are the filters
        features_out = input_shape[3] if self.is_depthwise else self.filters
        return (input_shape[0], h_out, w_out, features_out)

    def _spatial_output_shape(self, spatial_input_shape, padding=None):
        """
        Compute the output length of every spatial dimension with ``conv_utils.conv_output_length``.

        Args:
            spatial_input_shape: the spatial input lengths; entry ``i`` is paired with entry ``i`` of
                the kernel size, strides and dilation rate
            padding: padding name; when None, the lower-cased value of ``self.padding`` is used

        Returns:
            List of the output lengths, one per entry of ``spatial_input_shape``.

        """
        if padding is None:
            padding = self.padding.value.lower()

        output_lengths = []
        for axis, length in enumerate(spatial_input_shape):
            axis_output_length = conv_utils.conv_output_length(
                length,
                self.kernel_size[axis],
                padding=padding,
                stride=self.strides[axis],
                dilation=self.dilation_rate[axis],
            )
            output_lengths.append(axis_output_length)
        return output_lengths

    def export_hw_params(self):
        """
        Export the hardware parameters of the op as numpy values.

        Kernels of up to 8 bits are exported as int8 after applying the weight placement shift; wider
        kernels are exported as int16 without it.

        Returns:
            dict with the keys "kernel", "zp_kernel", "output_stage/mult_shift" and "padding_const_value".

        """
        fits_in_int8 = self.weight_lossy_elements.kernel.bits <= 8
        if fits_in_int8:
            placed_kernel = self.kernel_q * 2**self.weight_placement_shift
            kernel = np.array(placed_kernel, np.int8)
        else:
            kernel = np.array(self.kernel_q, np.int16)

        hw_params = {"kernel": kernel}
        hw_params["zp_kernel"] = np.array(self.zp_kernel, np.int32)
        hw_params["output_stage/mult_shift"] = np.array(self.pre_acc_shift, np.uint8)
        hw_params["padding_const_value"] = np.uint16(self.padding_const_value_q.numpy())
        return hw_params

    def export_independent_params(self):
        """
        Export the independent quantization parameters of the op as numpy arrays.

        Numeric values are exported as float32 and flags as bool; the scale calculation mode is
        exported through its ``dict_map`` value.

        Returns:
            dict mapping parameter names to numpy arrays.

        """

        def as_float(value):
            return np.array(value, np.float32)

        def as_flag(value):
            return np.array(value, bool)

        return {
            "mac_shift": as_float(self.pre_acc_shift),
            "shift_delta": as_float(self.shift_delta),
            "kernel_scale": as_float(self.kernel_scale),
            "kernel_zero_point": as_float(self.zp_kernel),
            "kernel_q_forced": as_float(self.kernel_q_forced),
            "kernel_scale_forced": as_float(self.kernel_scale_forced),
            "kernel_scale_forced_to_save": as_flag(self.kernel_scale_forced_to_save),
            "weight_bits": as_float(self.weight_lossy_elements.kernel.bits),
            "set_scale_by_kernel_only": as_flag(self.set_scale_by_kernel_only),
            "scale_calc_mode": as_float(dict_map[self.scale_calc_mode]),
            "precision_split_zp": as_flag(self._precision_split_zp),
            "kickback_residual_shift_delta": as_flag(self._kickback_residual_shift_delta),
        }

    def import_independent_params(self, params):
        """
        Load the independent quantization parameters of the op from ``params``.

        "weight_bits", "mac_shift", "shift_delta" and "kernel_zero_point" are mandatory keys; the
        remaining keys are optional and fall back to the defaults used below.

        Args:
            params: mapping of parameter names to values

        Raises:
            AccelerasPrematureQuantOperation: if the kernel lossy element is not a quantization element
            AccelerasImportParamConfigMismatch: if the imported weight bits differ from the configured ones

        """
        kernel_element = self.weight_lossy_elements.kernel
        if not isinstance(kernel_element, BaseQuantElement):
            raise AccelerasPrematureQuantOperation("import_independent_params", self.full_name)

        kernel_bits = kernel_element.bits
        imported_kernel_bits = params["weight_bits"]
        if kernel_bits != imported_kernel_bits:
            raise AccelerasImportParamConfigMismatch("kernel_bits", kernel_bits, imported_kernel_bits, self.full_name)

        self.pre_acc_shift = params["mac_shift"]
        self.shift_delta = params["shift_delta"]
        self.zp_kernel = params["kernel_zero_point"]
        self.kernel_q_forced = params.get("kernel_q_forced", 1)
        self.kernel_scale_forced = params.get("kernel_scale_forced", 1)

        # optional boolean flags: (attribute name, params key, default)
        optional_flags = (
            ("kernel_scale_forced_to_save", "kernel_scale_forced_to_save", False),
            ("set_scale_by_kernel_only", "set_scale_by_kernel_only", False),
            ("_precision_split_zp", "precision_split_zp", False),
            ("_kickback_residual_shift_delta", "kickback_residual_shift_delta", True),
        )
        for attribute, key, default in optional_flags:
            setattr(self, attribute, bool(params.get(key, default)))

        self.scale_calc_mode = reversted_dict_map[float(params.get("scale_calc_mode", 1))]

    def export_quant_weights(self):
        """
        Export the quantized weights of the op.

        Returns:
            dict with "quant_kernel" (the final numeric kernel after the weight placement shift, as
            float32) and "padding_const_value" (the quantized padding value).

        """
        numeric_kernel = self.final_numeric_kernel.numpy()
        shifted_kernel = numeric_kernel * 2**self.weight_placement_shift
        quant_weights = {"quant_kernel": np.float32(shifted_kernel)}
        quant_weights["padding_const_value"] = self.padding_const_value_q.numpy()
        return quant_weights

    def _remove_shared_params(self, params):
        params.pop("quant_kernel", None)
        params.pop("kernel_zero_point", None)
        params.pop("kernel_scale", None)
        params.pop("kernel_q_forced", None)
        params.pop("kernel_scale_forced", None)
        params.pop("kernel_scale_forced_to_save", None)
        return params

    def export_weights(self):
        """Export the native weights: the kernel as a numpy value and the padding constant value."""
        native_kernel = self.kernel.numpy()
        return {"kernel": native_kernel, "padding_const_value": self.padding_const_value}

    def enable_force_pruning(self):
        """Install an enabled ``PruneElement`` whose mask marks the non-zero entries of the current kernel."""
        nonzero_mask = self.kernel.numpy() != 0
        lossy_elements = self.weight_lossy_elements
        lossy_elements.kernel_prune = PruneElement(nonzero_mask)
        lossy_elements.kernel_prune.enable()

    def disable_force_pruning(self):
        """Disable the kernel prune element of the weight lossy elements."""
        prune_element = self.weight_lossy_elements.kernel_prune
        prune_element.disable()

    def define_encodings(self, flow):
        """Register the kernel-related encodings of this op on ``flow``.

        After delegating to the parent implementation, four encodings are added, all
        named ``<full_name>/<suffix>:0`` and all non-scalar:

          * ``kernel_scale``      - Scale, shaped like ``self.kernel_scale``.
          * ``kernel_zero_point`` - ZeroPoint, shape ``()``, initialised from ``self.zp_kernel``,
            quantized with bounds ``[tf.float32.min, tf.float32.max]``.
          * ``mac_shift``         - Scale, shape ``()``, initialised from ``self.pre_acc_shift``,
            quantized with bounds ``[1.0, 4.0]``.
          * ``total_rshift``      - Scale, shape ``()``.

        When the padding is anything other than VALID, the ``input_zero_point`` encoding
        is additionally marked as scalar.

        Args:
            flow: encoding container exposing ``add_encoding`` and ``get_encoding``.
        """
        super().define_encodings(flow)
        prefix = self.full_name

        flow.add_encoding(f"{prefix}/kernel_scale:0", EncodingType.Scale, scalar=False, shape=self.kernel_scale.shape)

        flow.add_encoding(
            f"{prefix}/kernel_zero_point:0",
            EncodingType.ZeroPoint,
            scalar=False,
            shape=(),
            initializer=TensorInitializer(self.zp_kernel),
            quant=True,
            quant_min=tf.float32.min,
            quant_max=tf.float32.max,
        )

        flow.add_encoding(
            f"{prefix}/mac_shift:0",
            EncodingType.Scale,
            scalar=False,
            shape=(),
            initializer=TensorInitializer(self.pre_acc_shift),
            quant=True,
            quant_min=1.0,
            quant_max=4.0,
        )

        flow.add_encoding(f"{prefix}/total_rshift:0", EncodingType.Scale, scalar=False, shape=())

        padded = self.padding != PaddingType.VALID
        if padded:
            input_zp_encoding = flow.get_encoding(f"{prefix}/input_zero_point:0")
            input_zp_encoding.scalar = True

    def define_constraints(self, enc):
        """Register the constraints that tie this op's encodings together.

        On top of the parent constraints this adds, in order:

          1. An optional identity pinning ``mac_shift`` (see the inline comments).
          2. ``total_rshift = mac_shift - weight_placement_shift`` via ``enc.sub``.
          3. A callback computing ``kernel_scale`` from the input scale, output scale and
             ``total_rshift`` through ``calc_kernel_scale_external``.
          4. A callback computing ``output_zero_point`` from the kernel scale, kernel zero
             point, input zero point and ``total_rshift``.

        Args:
            enc: constraint container exposing ``identity``, ``sub`` and ``callback``.
        """
        super().define_constraints(enc)

        prefix = self.full_name
        mac_shift_name = f"{prefix}/mac_shift:0"
        total_rshift_name = f"{prefix}/total_rshift:0"
        kernel_scale_name = f"{prefix}/kernel_scale:0"

        if self.output_lossy_element.bits == 32:
            # mac_shift is pinned to zero when the output lossy element is 32 bits wide.
            # The original note ties this to 16 bit quantization not supporting activation shift.
            # NOTE(review): confirm the relationship between "16 bit quantization" and the 32-bit check.
            enc.identity(mac_shift_name, np.float32(0.0))
        elif self.weight_lossy_elements.kernel.bits == 4:
            # Currently self.weight_placement_shift depends on pre_acc_shift, so pre_acc_shift can't be trained.
            # TODO: remove this line once we set weight_placement_shift constant. (SDK-43794)
            enc.identity(mac_shift_name, self.pre_acc_shift)

        enc.sub(total_rshift_name, mac_shift_name, self.weight_placement_shift)

        # kernel_scale is derived from the input/output scales and the total right shift.
        def kernel_scale_callback(input_scale, output_scale, total_rshift, *args, **kwargs):
            input_scales = [input_scale]
            return self.calc_kernel_scale_external(input_scales, output_scale, total_rshift, *args, **kwargs)

        kernel_scale_inputs = [f"{prefix}/input_scale:0", f"{prefix}/output_scale:0", total_rshift_name]
        enc.callback(
            kernel_scale_name,
            kernel_scale_inputs,
            kernel_scale_callback,
            callback_name="mat_mul",
            outs_shape=self.kernel_scale.shape,
            kernel_shape=self.kernel.shape,
            is_depthwise=self.is_depthwise,
            group_sizes=self.group_sizes,
            group_kernel_indices=self.group_kernel_indices(),
        )

        # output_zero_point is the input zero point pushed through the quantized kernel.
        def output_zp_callback(kernel_scale, kernel_zero_point, input_zero_point, total_rshift, training=False):
            # Numerize the native kernel, pass it through the kernel lossy element,
            # then remove the kernel zero point again.
            shifted_numeric = self.kernel / kernel_scale + kernel_zero_point
            lossy_numeric = self.weight_lossy_elements.kernel(shifted_numeric, training=training)
            centered_kernel = lossy_numeric - kernel_zero_point
            # Lay the input zero point along the third kernel axis so it broadcasts against the kernel.
            broadcast_zp = tf.reshape(input_zero_point, [1, 1, -1, 1])
            zp_contribution = tf.reduce_sum(centered_kernel * broadcast_zp, axis=self.axes2reduce)
            return zp_contribution / 2**total_rshift

        output_zp_inputs = [
            kernel_scale_name,
            f"{prefix}/kernel_zero_point:0",
            f"{prefix}/input_zero_point:0",
            total_rshift_name,
        ]
        enc.callback(
            f"{prefix}/output_zero_point:0",
            output_zp_inputs,
            output_zp_callback,
            callback_name="output_zp_callback",
            outs_shape=(self.output_shape[-1],),
        )

    def define_const_constraints(self, enc):
        """Pin this op's encodings to the values currently stored on the op.

        After the parent constraints, each of ``kernel_scale``, ``kernel_zero_point``,
        ``mac_shift`` and ``total_rshift`` is tied by ``enc.identity`` to the matching
        attribute (``kernel_scale``, ``zp_kernel``, ``pre_acc_shift``, ``total_rshift``).

        Args:
            enc: constraint container exposing ``identity``.
        """
        super().define_const_constraints(enc)
        # (encoding suffix, attribute holding its constant value)
        pinned_encodings = (
            ("kernel_scale", "kernel_scale"),
            ("kernel_zero_point", "zp_kernel"),
            ("mac_shift", "pre_acc_shift"),
            ("total_rshift", "total_rshift"),
        )
        for encoding_suffix, attr_name in pinned_encodings:
            enc.identity(f"{self.full_name}/{encoding_suffix}:0", getattr(self, attr_name))

    def update_encoding(self, encodings):
        """Copy solved encoding values back onto the op.

        After the parent update, ``kernel_scale``, ``zp_kernel`` and ``pre_acc_shift`` are
        overwritten with the ``kernel_scale``, ``kernel_zero_point`` and ``mac_shift``
        entries of ``encodings`` (keys of the form ``<full_name>/<suffix>:0``).

        Args:
            encodings: mapping from encoding name to value.

        Raises:
            KeyError: if one of the three expected keys is missing (assuming ``encodings``
                behaves like a dict on lookup).
        """
        super().update_encoding(encodings)
        # (attribute to set, encoding suffix to read)
        targets = (
            ("kernel_scale", "kernel_scale"),
            ("zp_kernel", "kernel_zero_point"),
            ("pre_acc_shift", "mac_shift"),
        )
        for attr_name, encoding_suffix in targets:
            setattr(self, attr_name, encodings[f"{self.full_name}/{encoding_suffix}:0"])

    def launch_online_bias_correction(self):
        """EXPERIMENTAL - install an in-graph, per-layer bias correction as ``post_action``.

        This is a "batchnorm-inspired" approach to bias correction that is compatible with,
        and conducive to, training: the quantized layer output is always mean-corrected
        with respect to the teacher (native) output, so the gradient is not dominated by
        the first moment.

        Another way to look at it: take Krishnamoorthi's approach to BN in QAT, drop the
        multiplicative part, and replace the alignment to "beta" with an alignment to a
        dynamic reference.

        NOTE: Empirically found to be necessary for scales training.
        """

        def align_mean_to_reference(inputs, outputs, training=False):
            # Reference path: run the native op on native-domain inputs.
            native_inputs = inputs if self.internal_encoding_enabled else self._decode_inputs(inputs)
            native_outputs = [self.call_native(native_inputs)]

            # Don't trigger extreme wraparound by the fix (comply with the desired clip).
            encoded_ref = self._encode_outputs(native_outputs)[0]
            # align_to_range either clips or applies wraparound.
            # The reference should behave the same way as the quant element does during training.
            quantizer: QuantElement = self.output_lossy_element
            encoded_ref = quantizer.align_to_range(encoded_ref, training)

            # Bring the reference into the same domain as the actual outputs.
            reference = self._decode_output([encoded_ref])[0] if self.internal_decoding_enabled else encoded_ref

            # Mean gap over axes (0, 1, 2), added back onto the actual output.
            mean_gap = tf.reduce_mean(reference - outputs[0], axis=(0, 1, 2))
            return [outputs[0] + mean_gap]

        self.post_action = align_mean_to_reference

    def finalize_online_bias_correction(self):
        """EXPERIMENTAL - stop the online bias correction.

        Clears ``post_action``, undoing what ``launch_online_bias_correction`` installed.
        """
        self.post_action = None

    def _encode_inputs(self, inputs):
        """Encode native inputs, optionally using a split zero point.

        When ``_precision_split_zp`` is off, the parent implementation is used unchanged.
        Otherwise only ``inputs[0]`` is encoded: it is divided by ``input_scale``, and
        even positions along axis 2 get ``ZP_LOW_SPLIT_PRECISION_PIXEL`` added while odd
        positions get ``input_zero_point``. The two halves are then interleaved back into
        their original order.

        Args:
            inputs: list of input tensors; ``inputs[0]`` is indexed as a rank-4 tensor.
                Presumably laid out as (batch, height, width, channels) - TODO confirm.

        Returns:
            Single-element list holding the encoded tensor.
        """
        if not self._precision_split_zp:
            return super()._encode_inputs(inputs)

        data = inputs[0]
        res_even = data[:, :, ::2, :] / self.input_scale + ZP_LOW_SPLIT_PRECISION_PIXEL
        res_odd = data[:, :, 1::2, :] / self.input_scale + self.input_zero_point

        # Stacking on axis -2 yields a rank-5 tensor whose axis 3 is the (even, odd) pair;
        # merging axes 2 and 3 restores the original interleaved ordering along axis 2.
        res = tf.stack([res_even, res_odd], axis=-2)
        # The batch axis is passed as -1 so the reshape also works when the static batch
        # size is unknown (res.shape[0] is None); the result is identical when it is known.
        res = tf.reshape(res, [-1, res.shape[1], 2 * res.shape[2], res.shape[4]])
        return [res]

    def _decode_output(self, outputs):
        """Decode encoded outputs, optionally using a split zero point.

        When ``_precision_split_zp`` is off, the parent implementation is used unchanged.
        Otherwise only ``outputs[0]`` is decoded: even positions along axis 2 have
        ``output_zero_point[0]`` subtracted, odd positions have ``output_zero_point[1]``
        subtracted, and both are multiplied by ``output_scale``. The two halves are then
        interleaved back into their original order.

        Args:
            outputs: list of output tensors; ``outputs[0]`` is indexed as a rank-4 tensor.
                Presumably laid out as (batch, height, width, channels) - TODO confirm.

        Returns:
            Single-element list holding the decoded tensor.
        """
        if not self._precision_split_zp:
            return super()._decode_output(outputs)

        data = outputs[0]
        res_even = (data[:, :, ::2, :] - self.output_zero_point[0]) * self.output_scale
        res_odd = (data[:, :, 1::2, :] - self.output_zero_point[1]) * self.output_scale

        # Stacking on axis -2 yields a rank-5 tensor whose axis 3 is the (even, odd) pair;
        # merging axes 2 and 3 restores the original interleaved ordering along axis 2.
        res = tf.stack([res_even, res_odd], axis=-2)
        # The batch axis is passed as -1 so the reshape also works when the static batch
        # size is unknown (res.shape[0] is None); the result is identical when it is known.
        res = tf.reshape(res, [-1, res.shape[1], 2 * res.shape[2], res.shape[4]])
        return [res]
