from dataclasses import dataclass

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp, BaseWeightLossyElements
from hailo_model_optimization.acceleras.encoding.encoding_data import EncodingType
from hailo_model_optimization.acceleras.encoding.encoding_layer import TensorInitializer
from hailo_model_optimization.acceleras.lossy_elements.base_lossy_element import BaseLossyElement
from hailo_model_optimization.acceleras.lossy_elements.identity_element import IdentityElement
from hailo_model_optimization.acceleras.lossy_elements.quant_element import BaseQuantElement, MACDataQuantElement
from hailo_model_optimization.acceleras.utils.acceleras_definitions import MAX_NUM_REPEATS_ELTWISE
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasDecompositionError,
    AccelerasElementwiseDecompositionError,
    AccelerasImportParamConfigMismatch,
    AccelerasNumerizationError,
    AccelerasPrematureQuantOperation,
)
from hailo_model_optimization.acceleras.utils.opt_utils import int_smallnum_factorize, uint_smallnum_factorize


@dataclass
class ELWAWeightsLossy(BaseWeightLossyElements):
    """Weight lossy-element container used by the elementwise-add ops in this file."""

    # Lossy element applied to the elementwise-add factor / kernel. Both ops start with an
    # IdentityElement and replace it with a MACDataQuantElement in create_weight_quant_element.
    factor: BaseLossyElement


class ElementwiseAddOp(BaseAtomicOp):
    """
    The elementwise-add part of Conv&Add:
    The 2nd (L3) input is multiplied by a constant, then added to the 1st (L1/L2) input,
    to arrive at a L1/L2 ("accumulator") result.

    In the quantized flow the multiplying constant is realized as
    ``elwa_factor * elwa_feed_repeat / 2**pre_acc_shift`` (see ``enforce_encoding``).
    """

    weight_lossy_elements: ELWAWeightsLossy

    num_inputs = 2
    num_outputs = 1

    # Debug Tensors, created after build
    _desired_factor: tf.Tensor
    _zp_change: tf.Tensor
    _total_factor: tf.Tensor

    def __init__(self, name: str, ew_add_factor=1, logger=None, fully_native=None, bit_exact=None, **kwargs):
        """
        Args:
            name: op name, forwarded to the base class.
            ew_add_factor: native multiplier applied to the 2nd input (see ``call_native``).
            logger, fully_native, bit_exact, **kwargs: forwarded unchanged to the base class.
        """
        super().__init__(name, logger=logger, fully_native=fully_native, bit_exact=bit_exact, **kwargs)
        # No quantization of the factor until create_weight_quant_element is called.
        self.weight_lossy_elements = ELWAWeightsLossy(
            factor=IdentityElement(name=f"{self.full_name}/ie:element_wise_add")
        )
        self.pre_acc_shift = tf.constant(0, self.FLOAT_TYPE_TF)
        self.elwa_factor = None  # set by enforce_encoding / update_encoding
        self.elwa_feed_repeat = 1
        # TODO we need to load this params from alls config
        self._max_feed_repeat = MAX_NUM_REPEATS_ELTWISE
        self.shift_delta = 0
        self._total_factor = 1
        self._ew_add_factor = ew_add_factor

    def create_weight_quant_element(self, factor_bits=8, signed=True):
        """Replace the factor's lossy element with a ``MACDataQuantElement`` of the given width/signedness."""
        self.weight_lossy_elements = ELWAWeightsLossy(
            factor=MACDataQuantElement(bits=factor_bits, signed=signed, name=f"{self.full_name}/qe:factor"),
        )

    def create_hw_params(self, *args, **kwargs):
        """Compute and store the feed-repeat; all positional/keyword arguments are ignored."""
        self.elwa_feed_repeat = self._compute_feed_repeat()

    def compute_output_zp(self, zp_change=None):
        """
        Return the output zero point: the 1st input's zero point plus ``zp_change``.

        When ``zp_change`` is None the cached ``self._zp_change`` is used; that attribute only
        exists after ``enforce_encoding`` or ``update_encoding`` has run.
        """
        if zp_change is None:
            zp_change = self._zp_change
        return self.input_zero_points[0] + zp_change

    def enforce_encoding(self, training=False):
        """
        Recompute the quantized factor, total factor, zero-point change and output zero point
        from the current input scales / zero points.

        NOTE1: should seamlessly support vectorized elwa-factor,
            obtained by channelwise division of vectorized scales.

        NOTE2: the #feed-repeats is NOT updated here, computed once in the
                "create_numerization" stage by a non-differentiable method.
                This creates an assumption that the changes during fine-tuning algos
                won't be large enough to make the current feed-repeat unworkable.
        """
        desired_factors = self._get_desired_factors()
        # The lossy element sees the per-repeat factor; the repeat is multiplied back afterwards.
        elwa_factor = self.weight_lossy_elements.factor(desired_factors / self.elwa_feed_repeat, training=training)
        zp_change = self.input_zero_points[1] * elwa_factor * self.elwa_feed_repeat / 2**self.pre_acc_shift
        self.output_zero_point = self.compute_output_zp(zp_change)

        total_factor = elwa_factor * self.elwa_feed_repeat / 2**self.pre_acc_shift
        self._total_factor = total_factor
        self.elwa_factor = elwa_factor
        self._zp_change = zp_change
        self.kernel_zero_point = np.float32(0)

    def set_max_feed_repeat(self, max_feed_repeat):
        """Override the maximal feed-repeat; a ``None`` argument leaves the current value untouched."""
        if max_feed_repeat is not None:
            self._max_feed_repeat = max_feed_repeat

    @property
    def max_feed_repeat(self):
        """Upper bound passed as ``maxsmallnum`` to the factorization in ``_compute_feed_repeat``."""
        return self._max_feed_repeat

    def _compute_feed_repeat(self):
        """
        decompose the total factors into integers multiplication of factor, and feed repeat.
        The inputs_factor is an int and the repeats are
        max_elementwise_feed_repeat or smaller.

        Only the first element of the (rounded) desired factors is factorized.

        Returns:
            The feed-repeat, i.e. the second value returned by the factorization helper.

        Raises:
            AccelerasElementwiseDecompositionError: when the factorization helper raises
                ``AccelerasDecompositionError``; the original error is chained as the cause.
        """
        desired_factors = tf.round(self._get_desired_factors())
        # TODO we may want to avoid rounding here, but for now we sync with legacy
        bits = self.weight_lossy_elements.factor.bits
        # 15-bit factors go through the unsigned factorization, everything else through the signed one.
        factorize = uint_smallnum_factorize if bits == 15 else int_smallnum_factorize
        try:
            _, feed_repeat_elwa = factorize(
                desired_factors[0],
                bits=bits,
                maxsmallnum=self.max_feed_repeat,
            )
        except AccelerasDecompositionError as e:
            raise AccelerasElementwiseDecompositionError(self.full_name, e) from e
        return feed_repeat_elwa

    def _get_desired_factors(self):
        """
        get the desired factors from input scales

        Returns:
            ``ew_add_factor * (input_scales[1] / input_scales[0]) * 2**pre_acc_shift`` cast to float32.
        """
        desired_factor = self._ew_add_factor * (self.input_scales[1] / self.input_scales[0] * 2**self.pre_acc_shift)
        return tf.cast(desired_factor, tf.float32)

    @property
    def bit_exact_supported(self) -> bool:
        """This layer supports bit exact emulation."""
        return True

    def call_hw_sim(self, inputs, training=False, **kwargs):
        """
        Quantized-flow emulation: add the scaled 2nd input to the 1st (accumulator) input.

        Without ``bit_exact`` the 2nd input is multiplied by the cached total factor. With
        ``bit_exact`` the shifted product with ``elwa_factor`` goes through the first output
        lossy element before being multiplied by the feed-repeat.
        """
        accumulator_input, L3_input = inputs

        if not self.bit_exact:
            L3_input_total = L3_input * tf.cast(self._total_factor, L3_input.dtype)
        else:
            L3_input_total = (
                self.output_lossy_elements[0](
                    L3_input * tf.cast(self.elwa_factor, L3_input.dtype) / 2**self.pre_acc_shift,
                    training=training,
                )
                * self.elwa_feed_repeat
            )

        res = accumulator_input + L3_input_total
        return res

    def call_native(self, inputs, **kwargs):
        """Native computation: ``inputs[0] + ew_add_factor * inputs[1]``."""
        return inputs[0] + self._ew_add_factor * inputs[1]

    def _compute_output_shape(self, input_shape):
        """The output has the shape of the first input."""
        return input_shape[0]

    def export_independent_params(self):
        """Export feed-repeat, factor bit-width and pre-accumulator shift as numpy arrays."""
        return {
            "feed_repeat": np.array(self.elwa_feed_repeat, np.float32),
            "weight_bits": np.array(self.weight_lossy_elements.factor.bits),
            "mac_shift": np.array(self.pre_acc_shift, np.float32),
        }

    def import_independent_params(self, params):
        """
        Load feed-repeat and pre-accumulator shift from ``params``.

        Raises:
            AccelerasPrematureQuantOperation: the factor lossy element is not a quant element yet.
            AccelerasImportParamConfigMismatch: imported ``weight_bits`` differ from the configured bits.
        """
        if not isinstance(self.weight_lossy_elements.factor, BaseQuantElement):
            raise AccelerasPrematureQuantOperation("import_independent_params", self.full_name)
        kernel_bits = self.weight_lossy_elements.factor.bits
        imported_kernel_bits = params["weight_bits"]
        if kernel_bits != imported_kernel_bits:
            raise AccelerasImportParamConfigMismatch("factor_bits", kernel_bits, imported_kernel_bits, self.full_name)
        self.elwa_feed_repeat = params["feed_repeat"]
        self.pre_acc_shift = params["mac_shift"]

    def export_quant_weights(self):
        """Export the first element of the quantized factor as float32."""
        return {
            # TODO: why index 0?
            "factor": np.array(self.elwa_factor[0], np.float32),
        }

    def export_hw_params(self):
        """Export the first element of the quantized factor (int16) and the feed-repeat (uint8)."""
        return {
            "elementwise_addition/input_factor": np.array(self.elwa_factor[0], np.int16),
            "elementwise_addition/feed_repeat": np.array(self.elwa_feed_repeat, np.uint8),
        }

    def define_encodings(self, flow):
        """Register this op's encodings (mac_shift, desired_factors, factor, total_factor, zp_change) on ``flow``."""
        super().define_encodings(flow)
        flow.add_encoding(
            f"{self.full_name}/mac_shift:0",
            EncodingType.Scale,
            scalar=False,
            shape=(),
            quant=True,
            quant_min=tf.float32.min,
            quant_max=tf.float32.max,
        )
        # All per-channel encodings below are sized by the last dimension of the first input.
        flow.add_encoding(
            f"{self.full_name}/desired_factors:0",
            EncodingType.Scale,
            scalar=False,
            shape=(self.input_shapes[0][-1],),
            initializer=TensorInitializer(self._get_desired_factors(), eps=1e-5),
        )
        flow.add_encoding(
            f"{self.full_name}/factor:0", EncodingType.Scale, scalar=False, shape=(self.input_shapes[0][-1],)
        )
        flow.add_encoding(
            f"{self.full_name}/total_factor:0",
            EncodingType.Scale,
            scalar=False,
            shape=(self.input_shapes[0][-1],),
        )
        flow.add_encoding(
            f"{self.full_name}/zp_change:0",
            EncodingType.ZeroPoint,
            scalar=False,
            shape=(self.input_shapes[0][-1],),
        )

    def define_constraints(self, enc):
        """
        Register the relations between this op's encodings on ``enc``.

        The relations are intended to mirror ``enforce_encoding``.
        NOTE(review): ``_get_desired_factors`` multiplies by ``_ew_add_factor`` while no such term
        is registered here - confirm whether that is intentional.
        """
        super().define_constraints(enc)

        # Compute factor (elwa_factor)
        enc.div(enc.dummy(0), f"{self.full_name}/input_scale:1", f"{self.full_name}/input_scale:0", inverse=True)
        enc.shift(f"{self.full_name}/desired_factors:0", enc.dummy(0), f"{self.full_name}/mac_shift:0")
        enc.div(enc.dummy(1), f"{self.full_name}/desired_factors:0", self.elwa_feed_repeat, inverse=True)
        enc.lossy_element(f"{self.full_name}/factor:0", enc.dummy(1), self.weight_lossy_elements.factor)

        # Compute total_factor
        enc.mul(enc.dummy(2), f"{self.full_name}/factor:0", self.elwa_feed_repeat)
        enc.shift(enc.dummy(2), f"{self.full_name}/total_factor:0", f"{self.full_name}/mac_shift:0")

        # Compute zp_change
        enc.mul(enc.dummy(3), f"{self.full_name}/input_zero_point:1", f"{self.full_name}/factor:0")
        enc.mul(enc.dummy(4), enc.dummy(3), self.elwa_feed_repeat)
        enc.shift(enc.dummy(4), f"{self.full_name}/zp_change:0", f"{self.full_name}/mac_shift:0")

        # Compute output_scale
        enc.identity(f"{self.full_name}/output_scale:0", f"{self.full_name}/input_scale:0")

        # Compute output_zero_point
        enc.add(
            f"{self.full_name}/output_zero_point:0",
            f"{self.full_name}/input_zero_point:0",
            f"{self.full_name}/zp_change:0",
        )

    def define_const_constraints(self, enc):
        """Pin every encoding of this op to the value currently stored on the op."""
        super().define_const_constraints(enc)
        enc.identity(f"{self.full_name}/mac_shift:0", self.pre_acc_shift)
        enc.identity(f"{self.full_name}/desired_factors:0", self._get_desired_factors())
        enc.identity(f"{self.full_name}/factor:0", self.elwa_factor)
        enc.identity(f"{self.full_name}/total_factor:0", self._total_factor)
        enc.identity(f"{self.full_name}/zp_change:0", self._zp_change)

    def update_encoding(self, encodings):
        """Copy mac_shift, factor, total_factor and zp_change from ``encodings`` back onto the op."""
        super().update_encoding(encodings)
        self.pre_acc_shift = encodings[f"{self.full_name}/mac_shift:0"]
        self.elwa_factor = encodings[f"{self.full_name}/factor:0"]
        self._total_factor = encodings[f"{self.full_name}/total_factor:0"]
        self._zp_change = encodings[f"{self.full_name}/zp_change:0"]


class ElementwiseAddDirectOp(BaseAtomicOp):
    """
    The core part of "standalone elwa-add" layer (standalone as in "NOT conv & add").
    Represents an elementwise-addition of two L3 inputs,
    after multiplying each by a separate constant, to arrive at a L1/L2 ("accumulator") result
    """

    weight_lossy_elements: ELWAWeightsLossy

    num_inputs = 2
    num_outputs = 1

    def __init__(self, name: str, input_repeats=None, logger=None, fully_native=None, **kwargs):
        """
        Args:
            name: op name, forwarded to the base class.
            input_repeats: per input, the repeat count of every non-batch axis
                (see ``repeat_inputs``); a falsy value means no repetition.
            logger, fully_native, **kwargs: forwarded unchanged to the base class.
        """
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        # No quantization of the kernel until create_weight_quant_element is called.
        self.weight_lossy_elements = ELWAWeightsLossy(
            factor=IdentityElement(name=f"{self.full_name}/ie:element_wise_add")
        )
        self._kernel = tf.ones((2, 1), name="ew_factors")
        self.kernel_scale = 1
        self.kernel_zero_point = 0
        self.pre_acc_shift = tf.constant(0, dtype=self.FLOAT_TYPE_TF)
        # Initialized here so export_independent_params works even before _set_mac_shift /
        # import_independent_params has been called.
        self.shift_delta = 0
        self.input_repeats = input_repeats if input_repeats else [[1, 1, 1], [1, 1, 1]]
        self.preload_kernel = False

    def get_config(self):
        """Return the base config extended with this op's constructor arguments and state."""
        config = super().get_config()
        config.update(
            {
                "input_repeats": self.input_repeats,
                "preload_kernel": self.preload_kernel,
                "kernel_scale": self.kernel_scale,
                "kernel_zero_point": self.kernel_zero_point,
                "pre_acc_shift": self.pre_acc_shift,
                "name": self.full_name,
                "fully_native": self.fully_native,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        """
        Build an op from a config produced by ``get_config``.

        ``name``, ``input_repeats`` and ``fully_native`` are passed to the constructor (a missing
        key raises ``KeyError``). The state entries written by ``get_config`` are then restored
        onto the new instance. The caller's ``config`` is left unmodified.
        """
        config = dict(config)  # shallow copy: the pops below must not consume the caller's dict
        instance = cls(
            name=config.pop("name"),
            input_repeats=config.pop("input_repeats"),
            fully_native=config.pop("fully_native"),
        )

        # These are instance attributes (assigned in __init__), so they never appear in
        # cls.__dict__ and must be listed explicitly to be restored.
        restorable_state = ("preload_kernel", "kernel_scale", "kernel_zero_point", "pre_acc_shift")
        for key, value in config.items():
            if key in restorable_state or key in cls.__dict__:
                setattr(instance, key, value)

        return instance

    def create_weight_quant_element(self, factor_bits=8, signed=True):
        """Replace the kernel's lossy element with a ``MACDataQuantElement`` of the given width/signedness."""
        self.weight_lossy_elements = ELWAWeightsLossy(
            factor=MACDataQuantElement(bits=factor_bits, signed=signed, name=f"{self.full_name}/qe:factor"),
        )

    def _set_mac_shift(self, hw_shifts=None):
        """
        Set ``pre_acc_shift`` (0 for 15-bit weights, 1 otherwise) and reset ``shift_delta``.

        When ``hw_shifts`` is given, its first entry overrides the default shift.
        """
        weight_bits = self.weight_lossy_elements.factor.bits
        if weight_bits == 15:
            self.pre_acc_shift = tf.constant(0, dtype=self.FLOAT_TYPE_TF)
        else:
            self.pre_acc_shift = tf.constant(1, dtype=self.FLOAT_TYPE_TF)
        if hw_shifts is not None:
            self.pre_acc_shift = hw_shifts[0]
        self.shift_delta = 0

    def create_hw_params(
        self, output_scale_candidate=None, forced_ratio=None, target_max_kernel=None, hw_shifts=None, *args
    ):
        """
        Set the pre-accumulator shift and the output scale.

        if forced_ratio is not None it will set the ratio output_scale/input_scale = forced_ratio
        In that case, both input scales must be the same

        Otherwise ``output_scale_candidate`` is rescaled so that the largest quantized kernel
        magnitude equals ``target_max_kernel`` (default: the quant element's ``max_value - 1``).

        Raises:
            AccelerasNumerizationError: ``forced_ratio`` is given but the input scales differ, or
                the weight lossy element is still an ``IdentityElement``.
        """
        # 16 bit quantization doesn't support activation shift and in that case it's set to zero.

        self._set_mac_shift(hw_shifts=hw_shifts)
        if forced_ratio is not None:
            if not np.allclose(self.input_scales[0], self.input_scales[1]):
                raise AccelerasNumerizationError(
                    f"for setting forced_ratio in {self.full_name}, both input scales must be the same",
                )
            self.output_scale = self.input_scales[0] * forced_ratio

            return
        factor_quant_element = self.weight_lossy_elements.factor
        if isinstance(factor_quant_element, IdentityElement):
            raise AccelerasNumerizationError(
                f"Can't quantize op {self.full_name} becuase the weight lossy element is IdentityElement",
            )

        kernel_scale_candidate = self._calc_kernel_scale(
            self.input_scale_matrix,
            tf.reshape(tf.cast(output_scale_candidate, tf.float32), [1, -1]),
            self.pre_acc_shift,
            self.kernel,
        )

        kernel_q_candidate = tf.abs(self.kernel) / kernel_scale_candidate

        target_max_kernel = factor_quant_element.max_value - 1 if target_max_kernel is None else target_max_kernel
        # A single global ratio is used because the proportion between channels can't change.
        propotions = tf.reduce_max(kernel_q_candidate) / target_max_kernel

        self.output_scale = output_scale_candidate * propotions

    @property
    def input_scale_matrix(self):
        """
        Both input scales stacked and reshaped to a float32 tensor of shape ``[2, -1]``.

        A scalar scale is repeated to the length of the other (vector) scale; when both are
        vectors each one is repeated by the last entry of its ``input_repeats``.
        """
        if self.input_scale_is_scalar(0) and not self.input_scale_is_scalar(1):
            input_scale0 = tf.repeat(self.input_scales[0], len(self.input_scales[1]))
            input_scale_mat = tf.stack([input_scale0, self.input_scales[1]], axis=0)
        elif self.input_scale_is_scalar(1) and not self.input_scale_is_scalar(0):
            input_scale1 = tf.repeat(self.input_scales[1], len(self.input_scales[0]))
            input_scale_mat = tf.stack([self.input_scales[0], input_scale1], axis=0)
        elif self.input_scale_is_scalar(0) and self.input_scale_is_scalar(1):
            input_scale_mat = tf.stack([self.input_scales[0], self.input_scales[1]], axis=0)
        else:
            input_scales = [
                tf.repeat(input_scale, repeat[-1], axis=-1)
                for input_scale, repeat in zip(self.input_scales, self.input_repeats)
            ]
            input_scale_mat = tf.stack(input_scales, axis=0)
        return tf.reshape(tf.cast(input_scale_mat, tf.float32), [2, -1])

    @staticmethod
    def _calc_kernel_scale(input_scales, output_scale, shift, kernel):
        """
        Return ``output_scale / 2**shift / input_scales`` in the dtype of ``kernel``.

        We have to force that:
            S_in0*S_w0*2**shift=S_out
            S_in1*S_w1*2**shift=S_out

        ``kernel`` is only used for its dtype.
        """

        kernel_scale_candidate = tf.reshape(
            tf.cast(output_scale / tf.cast(2**shift, output_scale.dtype), kernel.dtype), [1, -1]
        ) / tf.cast(input_scales, kernel.dtype)

        return kernel_scale_candidate

    def calc_kernel_scale_from_io(self):
        """Kernel scale implied by the current input scales, output scale and pre-accumulator shift."""
        return self._calc_kernel_scale(self.input_scale_matrix, self.output_scale, self.pre_acc_shift, self.kernel)

    @staticmethod
    def _calc_output_zero_point(input_zero_point0, input_zero_point1, kernel_q, pre_acc_shift):
        """Return ``(zp0 * kernel_q[0] + zp1 * kernel_q[1]) / 2**pre_acc_shift``."""
        zp_comp0 = input_zero_point0 * kernel_q[0, :]
        zp_comp1 = input_zero_point1 * kernel_q[1, :]
        return (zp_comp0 + zp_comp1) / tf.cast(2**pre_acc_shift, zp_comp0.dtype)

    def get_vectorize_zp(self, index) -> np.ndarray:
        """
        Return the zero point of input ``index``, repeated to match the repeated features.

        A rank-0 zero point is returned as is; otherwise it is repeated by the last entry of
        ``input_repeats[index]`` and cast to ``FLOAT_TYPE_TF``.
        NOTE(review): the non-scalar branch returns a tf.Tensor, not the annotated np.ndarray.
        """
        if len(tf.convert_to_tensor(self.input_zero_points[index]).shape) == 0:
            zp = self.input_zero_points[index]
        else:
            zp = tf.cast(
                tf.repeat(self.input_zero_points[index], self.input_repeats[index][-1]), dtype=self.FLOAT_TYPE_TF
            )
        return zp

    def enforce_encoding(self, training=False):
        """Recompute the kernel scale from the I/O scales, then the output zero point."""
        # calc kernel scale based on output_scale and input_scale
        self.kernel_scale = self.calc_kernel_scale_from_io()
        # update ZP
        zp_0 = self.get_vectorize_zp(0)
        zp_1 = self.get_vectorize_zp(1)
        self.output_zero_point = self._calc_output_zero_point(
            zp_0,
            zp_1,
            self.get_quant_kernel(training=training),
            self.pre_acc_shift,
        )

    @property
    def kernel(self):
        """
        The kernel with one column per (repeated) feature of the first input.

        The stored kernel is returned as is when its second dimension already equals the feature
        count; otherwise it is repeated along axis 1.
        """
        features = self.input_shapes[0][-1] * self.input_repeats[0][-1]
        if self._kernel.shape[1] == features:
            return self._kernel
        return tf.repeat(self._kernel, features, axis=1)

    def import_weights(self, kernel, **kwargs):
        """Store ``kernel`` (cast to ``FLOAT_TYPE_TF``) and mark the op as having a preloaded kernel."""
        self.preload_kernel = True
        self._kernel = tf.cast(kernel, self.FLOAT_TYPE_TF)

    def export_weights(self):
        """Return the stored kernel when one was imported, otherwise an empty dict."""
        if self.preload_kernel:
            return {"kernel": self._kernel}
        return {}

    def get_quant_kernel(self, training=False):
        """Return ``kernel / kernel_scale`` passed through the weight lossy element, as float32."""
        weight_bit_reducer = self.weight_lossy_elements.factor
        return tf.cast(weight_bit_reducer(self.kernel / self.kernel_scale, training=training), tf.float32)

    def _compute_output_shape(self, input_shape):
        """Shape of the first input with every non-batch dim multiplied by its ``input_repeats[0]`` entry."""
        batch = input_shape[0][0]
        return [batch, *[dim * ratio for dim, ratio in zip(input_shape[0][1:], self.input_repeats[0])]]

    def export_hw_params(self):
        """Export the quantized kernel (int8 for <=8 weight bits, else int16), kernel zero point and shift."""
        kernel = self.export_quant_weights()["quant_kernel"]
        kernel = kernel.astype(np.int8) if self.weight_lossy_elements.factor.bits <= 8 else kernel.astype(np.int16)

        return {
            "kernel": kernel,
            "zp_kernel": np.array(self.kernel_zero_point, np.int32),
            "output_stage/mult_shift": np.array(self.pre_acc_shift, np.uint8),
        }

    def repeat_inputs(self, inputs):
        """
        Repeat every input along its non-batch axes according to ``input_repeats``.

        ``inputs`` is modified IN PLACE (entries are reassigned), so it must be a mutable sequence.
        """
        for i, repeats in enumerate(self.input_repeats):
            for dim, repeat in enumerate(repeats):
                # axis 0 is skipped: repeats are indexed from the first non-batch axis
                inputs[i] = tf.repeat(inputs[i], repeat, axis=dim + 1)

    def call_native(self, inputs, **kwargs):
        """Native computation: sum of the repeated inputs, weighted by the kernel rows if one was imported."""
        # Work on a copy so the caller's sequence is not repeated in place (and again on reuse).
        inputs = list(inputs)
        self.repeat_inputs(inputs)
        if self.preload_kernel:
            return inputs[0] * self.kernel[0] + inputs[1] * self.kernel[1]
        return inputs[0] + inputs[1]

    def call_hw_sim(self, inputs, training=False, **kwargs):
        """
        Quantized-flow emulation: each repeated input is multiplied by its quantized kernel row
        and by ``2**-pre_acc_shift``, then the two results are summed.

        With ``bit_exact`` each product passes through the output lossy element before the sum.
        """
        # Work on a copy so the caller's sequence is not repeated in place (and again on reuse).
        inputs = list(inputs)
        self.repeat_inputs(inputs)

        kernel_q = self.get_quant_kernel(training=training)
        shift_dtype = (inputs[0] * kernel_q[0]).dtype
        shift_val = tf.cast(
            tf.pow(tf.constant(2.0, dtype=tf.float32), tf.cast(-self.pre_acc_shift, dtype=tf.float32)), shift_dtype
        )
        in0 = inputs[0] * kernel_q[0] * shift_val
        in1 = inputs[1] * kernel_q[1] * shift_val
        if self.bit_exact:
            res0 = self.output_lossy_element(in0)
            res1 = self.output_lossy_element(in1)
        else:
            res0 = in0
            res1 = in1
        return res0 + res1

    def call_bit_exact(self, inputs, training=False, **kwargs):
        """
        Bit-exact emulation: integer kernel products are shifted with ``bankers_round_with_shift``
        and passed through ``hw_simulation_by_lossy_element`` before being summed.
        """
        # Work on a copy so the caller's sequence is not repeated in place (and again on reuse).
        inputs = list(inputs)
        self.repeat_inputs(inputs)
        kernel_q = tf.cast(self.get_quant_kernel(training=training), self.INT_TYPE_TF)
        in0 = self.bankers_round_with_shift(
            inputs[0] * kernel_q[0], tf.cast(self.pre_acc_shift, inputs[0].dtype), signed=True
        )
        in1 = self.bankers_round_with_shift(
            inputs[1] * kernel_q[1], tf.cast(self.pre_acc_shift, inputs[1].dtype), signed=True
        )
        res0 = self.hw_simulation_by_lossy_element(in0, self.output_lossy_element)
        res1 = self.hw_simulation_by_lossy_element(in1, self.output_lossy_element)
        return res0 + res1

    def export_independent_params(self):
        """Export shift, shift delta, kernel zero point / scale and weight bits as float32 arrays."""
        weight_bits = self.weight_lossy_elements.factor.bits
        return {
            "mac_shift": np.array(self.pre_acc_shift, np.float32),
            "shift_delta": np.array(self.shift_delta, np.float32),
            "kernel_zero_point": np.array(self.kernel_zero_point, np.float32),
            "kernel_scale": np.array(self.kernel_scale, np.float32),
            "weight_bits": np.array(weight_bits, np.float32),
        }

    def import_independent_params(self, params):
        """
        Load shift, shift delta, kernel scale and kernel zero point from ``params``.

        Raises:
            AccelerasPrematureQuantOperation: the factor lossy element is not a quant element yet.
            AccelerasImportParamConfigMismatch: imported ``weight_bits`` differ from the configured bits.
        """
        if not isinstance(self.weight_lossy_elements.factor, BaseQuantElement):
            raise AccelerasPrematureQuantOperation("import_independent_params", self.full_name)
        kernel_bits = self.weight_lossy_elements.factor.bits
        imported_kernel_bits = params["weight_bits"]
        if kernel_bits != imported_kernel_bits:
            raise AccelerasImportParamConfigMismatch("factor_bits", kernel_bits, imported_kernel_bits, self.full_name)
        self.pre_acc_shift = params["mac_shift"]
        self.shift_delta = params["shift_delta"]
        self.kernel_scale = params["kernel_scale"]
        self.kernel_zero_point = params["kernel_zero_point"]

    def export_quant_weights(self):
        """Export the flattened quantized kernel as a float32 numpy array under ``"quant_kernel"``."""
        output_features = self.output_shape[-1]
        quant_kernel = np.squeeze(np.array([np.float32(self.get_quant_kernel()).flatten()]))
        # The per-input halves below are not part of the result; they are only split and reshaped
        # (which raises if the sizes are inconsistent). Kept as is to preserve existing behavior.
        kernel_a, kernel_b = quant_kernel[:output_features], quant_kernel[output_features:]

        input_features_a = output_features // self.input_repeats[0][-1]
        input_features_b = output_features // self.input_repeats[1][-1]

        kernel_a = kernel_a.reshape([-1, input_features_a])
        kernel_b = kernel_b.reshape([-1, input_features_b])

        # IMPORTANT: verify that the kernel is repeated (check currently disabled)
        # assert np.all(kernel_a[0] == kernel_a) or np.all(np.isnan(kernel_a))
        # assert np.all(kernel_b[0] == kernel_b) or np.all(np.isnan(kernel_b))

        return {
            # TODO: decide whether the full repeated kernel or only the per-input rows should be exported
            "quant_kernel": quant_kernel,
        }

    def define_encodings(self, flow):
        """Register this op's ``mac_shift`` and ``kernel_scale`` encodings on ``flow``."""
        super().define_encodings(flow)
        flow.add_encoding(f"{self.full_name}/mac_shift:0", EncodingType.Scale, scalar=False, shape=())
        flow.add_encoding(
            f"{self.full_name}/kernel_scale:0",
            EncodingType.Scale,
            scalar=False,
            shape=(2, self.input_shapes[0][-1]),
        )

    def define_constraints(self, enc):
        """
        Register the relations between this op's encodings on ``enc``: kernel scale from the
        I/O scales and shift, quantized kernel from the kernel scale, and output zero point
        from the input zero points, quantized kernel and shift.
        """
        super().define_constraints(enc)
        enc.identity(f"{self.full_name}/mac_shift:0", self.pre_acc_shift)

        def _safe_stack(x, y):
            # Ensure that x and y are not KerasVariables, as tf.stack does not support them
            return tf.stack([tf.convert_to_tensor(x), tf.convert_to_tensor(y)], axis=0)

        # Compute kernel_scale
        enc.callback(
            enc.dummy("input_scale_mat"),
            [f"{self.full_name}/input_scale:0", f"{self.full_name}/input_scale:1"],
            _safe_stack,
            callback_name="tf.stack",
            outs_scalar=False,
            outs_shape=(2, self.input_shapes[0][-1]),
        )
        enc.callback(
            f"{self.full_name}/kernel_scale:0",
            [enc.dummy("input_scale_mat"), f"{self.full_name}/output_scale:0", f"{self.full_name}/mac_shift:0"],
            self._calc_kernel_scale,
            callback_name="calc_kernel_scale",
            kernel=self.kernel,
        )

        # Compute kernel_q
        enc.div(enc.dummy(1), self.kernel, f"{self.full_name}/kernel_scale:0")
        enc.lossy_element(enc.dummy(2), enc.dummy(1), self.weight_lossy_elements.factor)
        enc.cast(enc.dummy("kernel_q:0"), enc.dummy(2))

        # Compute output_zero_point
        enc.callback(
            f"{self.full_name}/output_zero_point:0",
            [
                f"{self.full_name}/input_zero_point:0",
                f"{self.full_name}/input_zero_point:1",
                enc.dummy("kernel_q:0"),
                f"{self.full_name}/mac_shift:0",
            ],
            self._calc_output_zero_point,
            callback_name="calc_output_zero_point",
        )

    def define_const_constraints(self, enc):
        """Pin ``mac_shift`` and ``kernel_scale`` to the values currently stored on the op."""
        super().define_const_constraints(enc)
        enc.identity(f"{self.full_name}/mac_shift:0", self.pre_acc_shift)
        enc.identity(f"{self.full_name}/kernel_scale:0", self.kernel_scale)

    def update_encoding(self, encodings):
        """Copy ``mac_shift`` and ``kernel_scale`` from ``encodings`` back onto the op."""
        super().update_encoding(encodings)
        self.pre_acc_shift = encodings[f"{self.full_name}/mac_shift:0"]
        self.kernel_scale = encodings[f"{self.full_name}/kernel_scale:0"]

    @property
    def bit_exact_supported(self) -> bool:
        """This layer supports bit exact emulation."""
        return True
