from dataclasses import dataclass
from typing import Union

import numpy as np
import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.python.keras.utils import conv_utils

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp, BaseWeightLossyElements
from hailo_model_optimization.acceleras.encoding.encoding_data import EncodingType
from hailo_model_optimization.acceleras.encoding.encoding_layer import TensorInitializer
from hailo_model_optimization.acceleras.lossy_elements.base_lossy_element import BaseLossyElement
from hailo_model_optimization.acceleras.lossy_elements.identity_element import IdentityElement
from hailo_model_optimization.acceleras.lossy_elements.quant_element import (
    AdaRoundQuantElement,
    BaseQuantElement,
    MACDataQuantElement,
    QuantElement,
)
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerWeightsClippingConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    SHIFT_CALCULATE_BUFFER,
    WEIGHTS_PLACEMENT_SHIFT,
    OptimizationTarget,
    PaddingType,
    StrideAlignType,
    WeightsClippingMode,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasImplementationError,
    AccelerasImportParamConfigMismatch,
    AccelerasInitializationError,
    AccelerasNumerizationError,
    AccelerasPrematureQuantOperation,
)
from hailo_model_optimization.acceleras.utils.opt_utils import (
    calculate_shifts,
    get_scalar_vector,
    limvals_to_zp_scale,
    mmse,
)


@dataclass
class ConvWeightsLossy(BaseWeightLossyElements):
    """Lossy elements applied to the weights of a 3D convolution op.

    Attributes:
        kernel: lossy element applied to the numerized kernel. Conv3DOp starts with an
            IdentityElement here and replaces it with a MACDataQuantElement in
            ``create_weight_quant_element``.
    """

    kernel: BaseLossyElement


class Conv3DOp(BaseAtomicOp):
    """3D convolution operator built from HN 2D convolution with disparity > 1.

    Tensors are kept in a "folded" rank-4 layout in which the depth (disparity) axis is packed
    into the channel axis:

    * inputs / outputs: ``[B x H x W x (D x C)]``
    * kernel: ``[Kh x Kw x (Kd x C_in // groups) x C_out]``

    ``call_native`` unfolds them into rank-5 tensors and runs ``tf.nn.conv3d`` per group.
    Only scalar input/output scales are currently supported for the kernel scale computation.
    """

    supported_paddings = [PaddingType.VALID, PaddingType.SAME]
    weight_lossy_elements: ConvWeightsLossy

    num_inputs = 1
    num_outputs = 1

    def __init__(
        self,
        name: str,
        kernel_size,
        filters=None,
        input_features=None,
        stride_align: Union[str, StrideAlignType] = "NW",
        strides=(1, 1, 1),
        groups=1,
        padding: Union[str, PaddingType] = "VALID",
        dilation_rate=(1, 1, 1),
        disparity=1,
        input_disparity=1,
        kernel_initializer=None,
        trainable=True,
        logger=None,
        fully_native=None,
        spatial_flatten_output=False,
        **kwargs,
    ):
        """
        Args:
            name: name of the op, forwarded to the base atomic op constructor.
            kernel_size: kernel size, normalized to a 3-tuple ``(Kh, Kw, Kd * C_in)``. The third
                element is expressed on the folded (depth x channels) axis.
            filters: number of output channels (C_out).
            input_features: number of input channels per depth slice (C_in).
            stride_align: alignment of SAME padding, ``"NW"`` or ``"SE"``.
            strides: strides, normalized to a 3-tuple ``(Sh, Sw, Sd * C_in)``.
            groups: number of convolution groups.
            padding: ``"VALID"`` or ``"SAME"``.
            dilation_rate: only ``(1, 1, 1)`` is supported.
            disparity: stored on the instance (``output_disparity`` only checks it is set).
            input_disparity: number of depth slices (D) packed into the input channel axis.
            kernel_initializer: use in the (future) case of training from scratch with acceleras,
                e.g. ``keras.glorot_uniform`` etc. When None, the kernel variable is initialized
                from the current value of ``self.kernel`` (see ``import_weights``).
            trainable: whether the kernel variable is trainable.
            logger: forwarded to the base atomic op constructor.
            fully_native: forwarded to the base atomic op constructor.
            spatial_flatten_output: stored on the instance.
            kwargs: additional arguments forwarded to the base atomic op constructor.

        Raises:
            AccelerasInitializationError: for an unsupported padding type, or a dilation rate other
                than ``(1, 1, 1)``.
        """
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        # Until quantization elements are created, the kernel passes through unchanged.
        self.weight_lossy_elements = ConvWeightsLossy(
            kernel=IdentityElement(name=f"{self.full_name}/ie:conv_weight_lossy")
        )
        self.padding = PaddingType(padding)
        if self.padding not in self.supported_paddings:
            raise AccelerasInitializationError(f"Padding type {padding} is not supported in {type(self)}")
        self.padding_const_value = 0

        self.stride_align = StrideAlignType(stride_align)
        self.pre_acc_shift = 0
        self.kernel_initializer = kernel_initializer
        self.trainable = trainable
        self.groups = groups

        self.filters = filters
        self.input_features = input_features
        spatial_dims = 3
        self.kernel_size = conv_utils.normalize_tuple(kernel_size, spatial_dims, "kernel_size")
        self.strides = conv_utils.normalize_tuple(strides, spatial_dims, "strides")
        self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, spatial_dims, "dilation_rate")
        if list(self.dilation_rate) != [1, 1, 1]:
            raise AccelerasInitializationError(f"Dilation rate {self.dilation_rate} is not supported in {type(self)}")
        self.disparity = disparity
        self.input_disparity = input_disparity
        self.kernel_scale = 1
        self.kernel = None
        self.shift_delta = 0
        self.zp_kernel = 0
        self.out_zp_comp_groups = None
        self.feed_repeat = 1
        self.force_rounded_shift_delta = False
        self.spatial_flatten_output = spatial_flatten_output
        self.axes2reduce = (0, 1, 2)  # might change if the scales and zeros are to be vectors
        self.quantization_groups_num = 1
        self._weight_limvals = None

    @property
    def output_disparity(self):
        """Number of depth slices (D_out) produced by the op.

        With ``Kd = kernel_size[2] // input_features`` and ``Sd = strides[2] // input_features``:
        VALID padding gives ``ceil((D_in - Kd + 1) / Sd)``; SAME padding gives ``ceil(D_in / Sd)``.
        """
        if not hasattr(self, "disparity"):
            raise AccelerasInitializationError(
                f"Op {self.full_name}: you first need set a disparity for this instance.",
            )
        output_disparity = self.input_disparity
        k_d = self.kernel_size[2] // self.input_features
        disparity_strides = self.strides[2] // self.input_features
        if self.padding == PaddingType.VALID:
            output_disparity -= k_d - 1
        return int(np.ceil(output_disparity / float(disparity_strides)))

    @property
    def padding_const_value_q(self):
        """Padding constant expressed in the quantized input domain.

        Computed as ``padding_const_value / input_scale + mean(input_zero_points[0])`` and passed
        through the input lossy element. When the input scale is not a scalar, only its first
        element is used.
        """
        input_scale = self.input_scale if self.input_scale.shape == () else self.input_scale[0]
        zp = tf.reduce_mean(self.input_zero_points[0])
        quantized_val = self.padding_const_value / input_scale + zp
        return self.input_lossy_element(quantized_val)

    def _build(self, input_shape):
        """Validate group divisibility and create the kernel variable if it does not exist yet."""
        if self.kernel_size[2] % self.groups != 0 or input_shape[-1] % self.groups != 0:
            raise AccelerasImplementationError(
                f"Layer {self.full_name}: input_features and output feature must be divisible by groups",
            )
        # [Kh x Kw x (Kd x C_in // groups) x C_out]
        kernel_shape = (
            *self.kernel_size[:2],
            self.kernel_size[2] // self.groups,
            self.filters,
        )

        if self.kernel not in self.weights:
            if self.kernel_initializer is None:
                # Initialize from whatever was stored in self.kernel (e.g. by import_weights).
                kernel_initializer = tf.keras.initializers.Constant(self.kernel)
            else:
                kernel_initializer = self.kernel_initializer
            self.kernel = self.add_weight(
                name="kernel",
                shape=kernel_shape,
                initializer=kernel_initializer,
                trainable=self.trainable,
            )

    @property
    def weight_placement_shift(self):
        """Left shift applied to 4-bit quantized kernels; 0 for any other kernel lossy element."""
        kernel_lossy = self.weight_lossy_elements.kernel
        if isinstance(kernel_lossy, (QuantElement, AdaRoundQuantElement)) and kernel_lossy.bits == 4:
            return WEIGHTS_PLACEMENT_SHIFT
        return 0

    @property
    def total_rshift(self):
        """
        Kernel is shifted left by weight_placement_shift before MAC,
        then Multiplication result is shifted right by pre_acc_shift before Add.
        """
        return self.pre_acc_shift - self.weight_placement_shift

    def calc_kernel_scale(self, input_scales=None, output_scale=None, total_rshift=None):
        """
        Compute the kernel scale from the op's (or the given) scales and shift.

        Args:
            input_scales: input scales of the layer; defaults to ``self.input_scales``.
            output_scale: output scale (aka accumulator scale); defaults to ``self.output_scale``.
            total_rshift: total right shift; defaults to ``self.total_rshift``.

        Returns: the kernel_scale up to scalar

        Raises:
            AccelerasNumerizationError: if a scale cannot be reduced to a single scalar value.
        """
        input_scales = self.input_scales if input_scales is None else input_scales
        output_scale = self.output_scale if output_scale is None else output_scale
        kernel_shape = self.kernel.shape
        total_rshift = self.total_rshift if total_rshift is None else total_rshift

        # NOTE: only scalar scales are currently supported for conv3d
        input_scales = self.reduce_to_scalar(input_scales)
        output_scale = self.reduce_to_scalar(output_scale)

        return self.calc_kernel_scale_external(input_scales, output_scale, total_rshift, kernel_shape)

    @staticmethod
    def reduce_to_scalar(elem, rtol=1e-3, atol=1e-3):
        """
        Reduce an input element into a list of len one, as in [T];
        T is a tensor of size 1 holding the first element of ``elem``.

        Args:
            elem: a list (of length at most one), a numpy array or a tensor.
            rtol, atol: tolerances used to decide whether all values are the same.

        Raises:
            AccelerasNumerizationError: raises an error if elem is a list longer than one, or
                if it contains more than one distinct value.

        """
        if isinstance(elem, list):
            if len(elem) > 1:
                raise AccelerasNumerizationError(
                    "A list longer than one (i.e., len(elem) > 1) is currently not supported.",
                )
            elem = elem[0]
        elem_tensor = tf.convert_to_tensor(elem)
        elem_flatten_tensor = tf.reshape(elem_tensor, [-1])

        # Verify that all values in elem_flatten_tensor are the same.
        # In eager mode tf.debugging.Assert raises InvalidArgumentError immediately.
        try:
            assert_op = tf.debugging.Assert(
                tf.experimental.numpy.allclose(elem_flatten_tensor[0], elem_flatten_tensor, rtol=rtol, atol=atol),
                [""],
            )
            with tf.control_dependencies([assert_op]):
                # Truncate vector/tensor to scalar
                elem_scalar_tensor = [elem_flatten_tensor[0]]
        except tf.errors.InvalidArgumentError as err:
            raise AccelerasNumerizationError(
                "Cannot reduce the vector elem into a scalar because it contains more than one distinct value.",
            ) from err
        return elem_scalar_tensor

    def calc_kernel_scale_external(self, input_scales, output_scale, total_rshift, kernel_shape):
        """
        The kernel scale is a dependent tensor that is a ratio of the op input scales to output and the
        accumulator scales.

        Args:
            input_scales: input scales of the layer; only ``input_scales[0]`` is used.
            output_scale: output scale (aka accumulator scale); only ``output_scale[0]`` is used.
            total_rshift: total right shift between the MAC and the accumulator.
            kernel_shape: unused here; kept for the encoding-callback interface.

        Returns: the kernel_scale up to scalar, as float32 (computed in float64).

        Notes:
        (1) As a sanity check (at least for the scalar case),
                1/self.output_scale == 1/kernel_scale_right * 1/kernel_scale_left * 1/self.input_scale / 2**self.pre_acc_shift
                --> kernel_scale_left * kernel_scale_right =  output_scale / input_scale / 2**self.pre_acc_shift
        (2) Kernel scale is defined as a scalar.

        """
        kernel_scale_left = tf.cast(1 / input_scales[0], tf.float64)

        output_scale = output_scale[0]  # a kernel scale is supported for scalar
        kernel_scale_right = tf.cast(output_scale / 2**total_rshift, tf.float64)
        kernel_scale = kernel_scale_left * kernel_scale_right
        return tf.cast(kernel_scale, tf.float32)

    @property
    def final_numeric_kernel(self):
        """Quantized (numeric) kernel, computed in inference mode."""
        return self.get_quant_kernel(training=False)

    def get_quant_kernel(self, training=False):
        """Return ``kernel / kernel_scale + zp_kernel`` passed through the kernel lossy element."""
        numerized_kernel = self.kernel / self.kernel_scale + self.zp_kernel
        return self.weight_lossy_elements.kernel(numerized_kernel, training=training)

    def compute_output_zp(self, training=False):
        """
        The output's encoding asymmetry (Formerly called 'residue'). This feature
        can be computed at compile time from the input's encoding asymmetry.

        The per-output-channel value ``input_zp * sum(kernel over axes2reduce) / 2**total_rshift``
        is tiled ``output_disparity`` times to match the folded output channel axis.
        """
        kernel = self.get_quant_kernel(training=training) - self.zp_kernel
        numeric_kernel_summed = tf.reduce_sum(kernel, axis=self.axes2reduce)
        zp = get_scalar_vector(self.input_zero_point, name=self.full_name)

        _output_zp = tf.tile(
            zp * numeric_kernel_summed / 2**self.total_rshift,
            [self.output_disparity],
        )
        return _output_zp

    def get_weights_clipping(self):
        """Return the (min, max) kernel limits computed by the last make_optimal_kernel_scale call."""
        return self._weight_limvals

    def make_optimal_kernel_scale(
        self,
        kernel: tf.Tensor,  # TODO: we might want to change it
        clip_cfg: LayerWeightsClippingConfig,
    ):
        """
        Find the kernel_scale based on given configuration. This function assumes the layer is in eager mode.

        Side effect: stores the chosen (min, max) kernel limits in ``self._weight_limvals``.

        Args:
            kernel [H x W x (D x C_in) x C_out] : the folded kernel (rank=4) for conv3d.
            clip_cfg : weight clipping configuration

        Returns:
            (kernel_scale, zero_point): calculation is based on the bits configured in the kernel lossy element

        Raises:
            AccelerasImplementationError: for an unsupported clipping mode.
        """
        kernel = kernel.numpy()
        kernel_lossy = self.weight_lossy_elements.kernel
        mode = clip_cfg.mode
        clipping_values = clip_cfg.clipping_values
        if mode == WeightsClippingMode.mmse_if4b:
            if kernel_lossy == MACDataQuantElement(4):
                mode = WeightsClippingMode.mmse
            else:
                mode = WeightsClippingMode.disabled

        if mode == WeightsClippingMode.mmse:
            # TODO need to change it to vectorized
            kernel_max = mmse(kernel)  # TODO other options ("clipping API")
            if tf.math.is_nan(kernel_max):
                # MMSE failed: fall back to the symmetric full range of the kernel.
                kernel_max = tf.math.reduce_max(np.abs(kernel))
            kernel_min = -kernel_max
        elif mode == WeightsClippingMode.manual:
            kernel_min, kernel_max = clipping_values
        elif mode == WeightsClippingMode.disabled:
            kernel_max = tf.math.reduce_max(kernel)
            kernel_min = tf.math.reduce_min(kernel)
        elif mode == WeightsClippingMode.percentile:
            # NOTE(review): clipping_values are assumed to be (low, high) percentiles in [0, 100].
            # tf.contrib no longer exists in TF2, so numpy is used; cast back to the kernel dtype.
            kernel_min, kernel_max = np.percentile(kernel, clipping_values).astype(kernel.dtype)
        else:
            raise AccelerasImplementationError(
                f"Layer {self.full_name}: unsupported scale mode (clip_cfg.mode={clip_cfg.mode})",
            )

        self._weight_limvals = tf.convert_to_tensor((kernel_min, kernel_max))
        zp, kernel_scale, _ = limvals_to_zp_scale(self._weight_limvals, kernel_lossy, self.full_name, self._logger)
        return kernel_scale, zp

    @staticmethod
    def _verify_acc_scale(acc_mtx_part, eps=2e-3):
        """Raise AccelerasNumerizationError unless all values are within ``eps`` of the first one."""
        acc_mtx_part_flatten = np.array(acc_mtx_part).flatten()
        if not np.allclose(acc_mtx_part_flatten, acc_mtx_part_flatten[0], rtol=0.0, atol=eps):
            raise AccelerasNumerizationError(
                "Cannot reduce the vector x into a scalar because it contains more than one distinct value.",
            )

    def create_hw_params(
        self,
        max_native_accumulator,
        weight_clip_cfg: LayerWeightsClippingConfig,
        optimization_target: OptimizationTarget,
        kernel_scale_matrix_component=1,
        shift_calculate_buffer=None,
        hw_shifts=None,
        force_scale=False,
    ):
        """
        Compute the kernel zero point, the pre-accumulation shift and the accumulator scale candidate.

        (A) Find an optimal kernel scale, minding the classic tradeoff of precision vs. dynamic-range (clipping).
        (B) Calculate the post-multiply/pre-accumulate shift-right needed to avoid accumulator wraparound.
            If necessary, make numeric kernel smaller, increasing the scale (& range) from optimum found in (A).

        Sets ``zp_kernel``, ``shift_delta``, ``pre_acc_shift``, ``desired_pre_acc_shift`` and
        ``accumulator_scale_candidate`` on the instance.

        Args:
            max_native_accumulator: maximal accumulator value(s) in the native domain; divided by the
                accumulator scale to obtain the expected numeric maximum.
            weight_clip_cfg: weights clipping configuration used to choose the kernel scale.
            optimization_target: not used by this op.
            kernel_scale_matrix_component: factor the kernel is divided by before the scale search and
                that is multiplied back into the resulting kernel scale.
            shift_calculate_buffer: buffer passed to ``calculate_shifts``; ``SHIFT_CALCULATE_BUFFER``
                is used when it is falsy.
            hw_shifts: forwarded to ``calculate_shifts``.
            force_scale: not supported for conv3d; must be False.

        Raises:
            RuntimeError: when not executing eagerly.
            AccelerasImplementationError: when ``force_scale`` is True.
            AccelerasNumerizationError: when the accumulator scale differs between input-scale rows.
        """
        if not context.executing_eagerly():
            raise RuntimeError(f"Layer {self.full_name}: create_hw_params is only supported for Eager mode.")

        if force_scale:
            raise AccelerasImplementationError(f"Layer {self.full_name}: conv3d does not support force_scale=True")

        kernel_numeric_up_to_scalar = self.kernel / kernel_scale_matrix_component
        kernel_scale_candidate_scalar, zero_points = self.make_optimal_kernel_scale(
            kernel=kernel_numeric_up_to_scalar,
            clip_cfg=weight_clip_cfg,
        )

        kernel_scale_matrix = kernel_scale_candidate_scalar * kernel_scale_matrix_component
        acc_before_mtx = tf.cast(tf.expand_dims(self.input_scale, 1), dtype=tf.float32) * kernel_scale_matrix

        # Verify all rows are the same.
        self._verify_acc_scale(acc_before_mtx)

        acc_before = acc_before_mtx[0]
        acc_scale_before_shift = acc_before / 2**self.weight_placement_shift
        expected_max_accumulator = np.max(max_native_accumulator / acc_scale_before_shift)
        accumulator_size = self.output_lossy_element.bits  # get accumulator
        buffer = shift_calculate_buffer if shift_calculate_buffer else SHIFT_CALCULATE_BUFFER
        pre_acc_shift, shift_delta, desired_pre_acc_shift = calculate_shifts(
            expected_max_accumulator,
            accumulator_size,
            buffer,
            hw_shifts=hw_shifts,
            return_needed_shift=True,
        )

        if shift_delta > 0:
            # HW can't provide a shift large enough to avoid final accumulator overflow,
            #  we need smaller numeric values by making kernel range wider
            self._logger.info(
                f"No shifts available for layer {self.full_name}, using max shift instead. delta={shift_delta:.04f}"
            )
            acc_scale_before_shift *= 2**shift_delta
        self.desired_pre_acc_shift = desired_pre_acc_shift
        accumulator_scale = np.squeeze(acc_scale_before_shift)
        self.zp_kernel = zero_points
        self.shift_delta = shift_delta
        self.pre_acc_shift = pre_acc_shift

        # Update the accumulator scale candidate after the final shift
        self.accumulator_scale_candidate = tf.concat(accumulator_scale, axis=-1) * 2**pre_acc_shift

    def enforce_encoding(self, training=False):
        """
        Here we finalize the inward scales propagation by resolving kernel scale,
        and begin the forward ZP propagation
        """
        self.kernel_scale = self.calc_kernel_scale()
        self.output_zero_point = self.compute_output_zp(training=training)
        # NOTE: final numeric kernel computation is in self.final_numeric_kernel(),
        #        but conceptually its also part of offline comp.

    def create_weight_quant_element(self, kernel_bits=8, signed=True):
        """Replace the kernel lossy element with a MACDataQuantElement of ``kernel_bits`` bits."""
        self.weight_lossy_elements = ConvWeightsLossy(
            kernel=MACDataQuantElement(bits=kernel_bits, signed=signed, name=f"{self.full_name}/qe:kernel"),
        )

    def import_weights(self, kernel, layer_params=None, **kwargs):
        """
        Set the float kernel and, optionally, the padding constant.

        Args:
            kernel: folded kernel; assigned to the variable if the op is built, otherwise stored as a
                constant that ``_build`` uses as initial value.
            layer_params: optional mapping; its ``"padding_const_value"`` entry (if any) replaces
                ``self.padding_const_value`` as float32.
        """
        if self.built:
            self.kernel.assign(kernel)
        else:
            self.kernel = tf.constant(kernel)
        if layer_params is not None:
            padding_const_value = layer_params.get("padding_const_value", self.padding_const_value)
            self.padding_const_value = np.float32(padding_const_value)

    def call_native(self, inputs, kernel=None, padding_const_value=None, **kwargs):
        """
        Calculates a full-precision conv3d using the inputs and the kernel.

        Args:
            inputs: sequence whose first element is a [B x H x W x (D x C_in)] tensor.
            kernel (tf.Tensor): [Kh x Kw x (Kd x C_in // groups) x C_out]. Defaults to ``self.kernel``.
            padding_const_value (optional): padding const value. Defaults to ``self.padding_const_value``.

        Returns:
            rank4_result: [B x H' x W' x (D_out x C_out)] the folded result for the conv3d operation.

        """
        padding_const_value = padding_const_value if padding_const_value is not None else self.padding_const_value
        kernel = self.kernel if kernel is None else kernel
        # kernel: [Kh x Kw x (Kd x C_in // groups) x C_out]
        k_d = kernel.shape[2] // (self.input_features // self.groups)

        # conv3d_kernel : [Dk x Hk x Wk x C_in x C_out]
        conv3d_kernel = tf.concat(tf.split(tf.expand_dims(kernel, axis=0), k_d, axis=3), axis=0)

        # strides: [1, Sd, Sh, Sw, 1]
        strides_3d = [self.strides[2] // self.input_features, self.strides[0], self.strides[1]]
        strides_3d_extended = [1, *strides_3d, 1]

        # Expand from 2D to 3D (from rank 4 to rank 5 tensors)
        splits = tf.split(inputs[0], self.input_disparity, axis=-1)
        rank5_splits = [tf.expand_dims(split, axis=1) for split in splits]
        # rank5_input: [B x H x W x (D x C_in)] -> [B x D x H x W x C_in]
        rank5_input = tf.concat(rank5_splits, axis=1)

        if self.padding == PaddingType.SAME:
            rank5_input = self.handle_padding_3d(
                rank5_input,
                self.padding.value,
                conv3d_kernel.shape,
                strides_3d,
                padding_const_value,
                self.stride_align.value,
                self.dilation_rate,
            )

        # Padding is already applied explicitly, so the convolution itself is VALID.
        result = self.grouped_conv3d(rank5_input, conv3d_kernel, strides_3d_extended)

        # Back to 2 dims (rank 4 tensor)
        splits = tf.split(result, result.shape[1], axis=1)
        rank4_splits = [tf.squeeze(split, axis=1) for split in splits]
        return tf.concat(rank4_splits, axis=-1)  # [B x H x W x (D x C_out)]

    def grouped_conv3d(self, rank5_input, conv3d_kernel, strides):
        """
        Perform grouped 3D convolution using tf.nn.conv3d with VALID padding.

        The input and the kernel are split into ``self.groups`` parts along their last axis, each
        pair is convolved separately and the results are concatenated along the channel axis.
        """
        input_groups = tf.split(rank5_input, self.groups, axis=-1)
        kernel_groups = tf.split(conv3d_kernel, self.groups, axis=-1)

        output_groups = [
            tf.nn.conv3d(input_group, kernel_group, strides=strides, padding="VALID")
            for input_group, kernel_group in zip(input_groups, kernel_groups)
        ]

        return tf.concat(output_groups, axis=-1)

    def call_hw_sim(self, inputs, training=False, **kwargs):
        """
        Simulating the core numeric functionality of Hailo MAC.
        NOTE: Applying both shifts:
          (A) the pre_acc_shift between mult and add
          (B) the "placement shift" for 4b weights into 8b field.

        as it's what Hailo MAC actually does
           (in contrast to (de)numerization ops which are virtual)
        """
        kernel = self.get_quant_kernel(training=training) - self.zp_kernel
        kernel_shifted = kernel * 2**self.weight_placement_shift
        # If input_zero_point is a vector, and padding is not VALID, all of it's value should be the same.
        _pre_shift_out = self.call_native(inputs, kernel_shifted, self.padding_const_value_q)
        ret = _pre_shift_out / 2**self.pre_acc_shift

        return ret

    def export_independent_params(self):
        """Export the op's numeric parameters as float32 numpy arrays."""
        return {
            "mac_shift": np.array(self.pre_acc_shift, np.float32),
            "shift_delta": np.array(self.shift_delta, np.float32),
            "kernel_scale": np.array(self.kernel_scale, np.float32),
            "kernel_zero_point": np.array(self.zp_kernel, np.float32),
            "weight_bits": np.array(self.weight_lossy_elements.kernel.bits, np.float32),
        }

    def import_independent_params(self, params):
        """
        Import parameters produced by ``export_independent_params``.

        Raises:
            AccelerasPrematureQuantOperation: if the kernel lossy element is not a quant element yet.
            AccelerasImportParamConfigMismatch: if the imported weight bits differ from the configured ones.
        """
        if not isinstance(self.weight_lossy_elements.kernel, BaseQuantElement):
            raise AccelerasPrematureQuantOperation("import_independent_params", self.full_name)
        kernel_bits = self.weight_lossy_elements.kernel.bits
        imported_kernel_bits = params["weight_bits"]
        if kernel_bits != imported_kernel_bits:
            raise AccelerasImportParamConfigMismatch("kernel_bits", kernel_bits, imported_kernel_bits, self.full_name)
        self.pre_acc_shift = params["mac_shift"]
        self.shift_delta = params["shift_delta"]
        self.zp_kernel = params["kernel_zero_point"]

    def export_quant_weights(self):
        """Export the shifted quantized kernel and the quantized padding constant."""
        kernel_q = self.final_numeric_kernel.numpy() * 2**self.weight_placement_shift  # shifted kernel
        return {"quant_kernel": np.float32(kernel_q), "padding_const_value": self.padding_const_value_q.numpy()}

    def export_weights(self):
        """Export the float kernel and the float padding constant."""
        return {"kernel": self.kernel.numpy(), "padding_const_value": self.padding_const_value}

    def export_hw_params(self):
        """Export the integer parameters: an int8 (shifted) kernel for <= 8 bits, otherwise int16."""
        if self.weight_lossy_elements.kernel.bits <= 8:
            kernel = np.array(self.final_numeric_kernel.numpy() * 2**self.weight_placement_shift, np.int8)
        else:
            kernel = np.array(self.final_numeric_kernel.numpy(), np.int16)
        return {
            "kernel": kernel,
            "zp_kernel": np.array(self.zp_kernel, np.int32),
            "output_stage/mult_shift": np.array(self.pre_acc_shift, np.uint8),
            "padding_const_value": np.uint16(self.padding_const_value_q.numpy()),
        }

    @staticmethod
    def get_conv_same_padding(input_size, kernel_size, strides):
        """
        Calculates the padding for the beginning and ending of a given
            input_size (1 dimension), following TensorFlow's "SAME" rule.

        Args:
            input_size (int), kernel_size (int), strides (int): input
            size, kernel size, and stride size along a given tensor
            dimension.

        Returns:
            pad_beg_size, pad_end_size; when the total is odd, the extra element goes at the end.

        """
        adjusted_strides = strides if input_size % strides == 0 else (input_size % strides)
        total_pad_size = max(kernel_size - adjusted_strides, 0)
        pad_beg_size = total_pad_size // 2
        pad_end_size = total_pad_size - pad_beg_size
        return pad_beg_size, pad_end_size

    @staticmethod
    def handle_padding_3d(inputs, padding, kernel_size, strides, padding_const_value, stride_align, dilation_rate):
        """
        Pad a rank-5 input ahead of a VALID 3D convolution.

        Args:
            inputs: [B, D, H, W, Cin] tensor.
            padding (str): {'VALID', 'SAME'}; VALID returns the inputs unchanged.
            kernel_size: kernel shape [Dk, Hk, Wk, Cin, Cout]; only the first three entries are used.
            strides: [Ds, Hs, Ws]
            padding_const_value: constant value written into the padded elements.
            stride_align (str): {'NW', 'SE'}. 'SE' keeps TensorFlow's default "SAME" split; 'NW'
                swaps the begin/end padding amounts.
            dilation_rate (list): supports only [1,1,1]

        Raises:
            AccelerasInitializationError: for an unsupported dilation rate.
            AccelerasImplementationError: for an unsupported padding or stride_align.

        Returns:
            The padded [B, D', H', W', Cin] tensor.

        """
        if list(dilation_rate) != [1, 1, 1]:
            raise AccelerasInitializationError(
                f"dilation rate = {dilation_rate} is not supported for conv3d.",
            )

        if padding.upper() == "VALID":
            padded_input = inputs

        elif padding.upper() == "SAME":  # DIY padding
            input_shape = inputs.shape  # [B, D, H, W, Cin]
            pad_beg_d, pad_end_d = Conv3DOp.get_conv_same_padding(input_shape[1], kernel_size[0], strides[0])
            pad_beg_h, pad_end_h = Conv3DOp.get_conv_same_padding(input_shape[2], kernel_size[1], strides[1])
            pad_beg_w, pad_end_w = Conv3DOp.get_conv_same_padding(input_shape[3], kernel_size[2], strides[2])

            # all set, this is the default tensorflow "SAME" pad
            if stride_align.upper() == "SE":
                pass
            # Reverse the padding order (i.e., end<->begin)
            elif stride_align.upper() == "NW":
                pad_beg_d, pad_end_d = pad_end_d, pad_beg_d
                pad_beg_h, pad_end_h = pad_end_h, pad_beg_h
                pad_beg_w, pad_end_w = pad_end_w, pad_beg_w
            else:
                raise AccelerasImplementationError(
                    f"Atomic op conv3d_op does not support stride_align = {stride_align}",
                )
            # padding_tensor: [5 x (pad_beg, pad_end)]
            padding_tensor = [
                [0, 0],
                [pad_beg_d, pad_end_d],
                [pad_beg_h, pad_end_h],
                [pad_beg_w, pad_end_w],
                [0, 0],
            ]
            padded_input = tf.pad(inputs, padding_tensor, mode="CONSTANT", constant_values=padding_const_value)
        else:
            raise AccelerasImplementationError(
                f"Atomic op conv3d_op does not support padding = {padding}",
            )

        return padded_input

    def define_encodings(self, flow):
        """Register kernel scale/zero-point and shift encodings, and mark scalar encodings."""
        super().define_encodings(flow)
        flow.add_encoding(
            f"{self.full_name}/kernel_scale:0",
            EncodingType.Scale,
            scalar=False,
            shape=self.kernel_scale.shape,
        )
        flow.add_encoding(
            f"{self.full_name}/kernel_zero_point:0",
            EncodingType.ZeroPoint,
            scalar=False,
            shape=(),
            quant=True,
            quant_min=tf.float32.min,
            quant_max=tf.float32.max,
            initializer=TensorInitializer(self.zp_kernel),
        )
        flow.add_encoding(
            f"{self.full_name}/mac_shift:0",
            EncodingType.Scale,
            scalar=False,
            shape=(),
            initializer=TensorInitializer(self.pre_acc_shift),
            quant=True,
            quant_min=1.0,
            quant_max=4.0,
        )
        flow.add_encoding(f"{self.full_name}/total_rshift:0", EncodingType.Scale, scalar=False, shape=())

        # With padding, a single padding constant is used, so the input zero point must be scalar.
        if self.padding != PaddingType.VALID:
            flow.get_encoding(f"{self.full_name}/input_zero_point:0").scalar = True

        flow.get_encoding(f"{self.full_name}/input_scale:0").scalar = True
        flow.get_encoding(f"{self.full_name}/output_scale:0").scalar = True

    def define_constraints(self, enc):
        """Define how total_rshift, kernel_scale and output_zero_point derive from other encodings."""
        super().define_constraints(enc)
        # 16 bit quantization doesn't support activation shift and in that case it's set to zero.
        if self.output_lossy_element.bits == 32:
            enc.identity(f"{self.full_name}/mac_shift:0", 0.0)

        enc.sub(f"{self.full_name}/total_rshift:0", f"{self.full_name}/mac_shift:0", self.weight_placement_shift)

        # compute kernel_scale
        enc.callback(
            f"{self.full_name}/kernel_scale:0",
            [f"{self.full_name}/input_scale:0", f"{self.full_name}/output_scale:0", f"{self.full_name}/total_rshift:0"],
            self.calc_kernel_scale_external,
            callback_name="mat_mul",
            outs_shape=self.kernel_scale.shape,
            outs_scalar=True,
            kernel_shape=self.kernel.shape,
        )

        # compute output_zero_point (mirrors compute_output_zp using encoding values)
        def output_zp_callback(kernel_scale, kernel_zero_point, input_zero_point, total_rshift, training=False):
            numerized_kernel = self.kernel / kernel_scale + kernel_zero_point
            final_numeric_kernel = self.weight_lossy_elements.kernel(numerized_kernel, training=training)
            kernel = final_numeric_kernel - kernel_zero_point
            numeric_kernel_summed = tf.reduce_sum(kernel, axis=self.axes2reduce)
            return tf.tile(input_zero_point * numeric_kernel_summed / 2**total_rshift, [self.output_disparity])

        enc.callback(
            f"{self.full_name}/output_zero_point:0",
            [
                f"{self.full_name}/kernel_scale:0",
                f"{self.full_name}/kernel_zero_point:0",
                f"{self.full_name}/input_zero_point:0",
                f"{self.full_name}/total_rshift:0",
            ],
            output_zp_callback,
            callback_name="output_zp_callback",
            outs_shape=(self.output_shape[-1],),
        )

    def define_const_constraints(self, enc):
        """Pin the op's encodings to their current values."""
        super().define_const_constraints(enc)
        enc.identity(f"{self.full_name}/kernel_scale:0", self.kernel_scale)
        enc.identity(f"{self.full_name}/kernel_zero_point:0", self.zp_kernel)
        enc.identity(f"{self.full_name}/mac_shift:0", self.pre_acc_shift)
        enc.identity(f"{self.full_name}/total_rshift:0", self.total_rshift)

    def update_encoding(self, encodings):
        """Read kernel scale, kernel zero point and mac shift back from ``encodings``."""
        super().update_encoding(encodings)
        self.kernel_scale = encodings[f"{self.full_name}/kernel_scale:0"]
        self.zp_kernel = encodings[f"{self.full_name}/kernel_zero_point:0"]
        self.pre_acc_shift = encodings[f"{self.full_name}/mac_shift:0"]

    def _compute_output_shape(self, input_shape):
        """Compute the folded output shape; SAME padding is handled here, others by the base class."""
        n, h, w = input_shape[:3]
        if list(self.dilation_rate) != [1, 1, 1]:
            return super()._compute_output_shape(input_shape)

        if self.padding == PaddingType.SAME:
            # Depth stride in depth-slice units; the kernel size does not affect "same" output lengths.
            strides = list(self.strides[:-1]) + [self.strides[-1] // self.input_features]
            h, w, d = (
                conv_utils.conv_output_length(
                    length,
                    self.kernel_size[i],
                    padding="same",
                    stride=strides[i],
                    dilation=self.dilation_rate[i],
                )
                for i, length in enumerate([h, w, self.input_disparity])
            )
            output_shape = (n, h, w, d * self.kernel.shape[-1])
        else:
            # VALID (and any other padding) is delegated to the base implementation.
            output_shape = super()._compute_output_shape(input_shape)

        return output_shape

    @property
    def set_scale_by_kernel_only(self):
        # set_scale_by_kernel_only not implemented for 3d conv
        return False