from typing import Union

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.activation_op import ActivationOp
from hailo_model_optimization.acceleras.atomic_ops.bias_add_op import AddBiasOp
from hailo_model_optimization.acceleras.atomic_ops.conv_stripped_op import ConvStrippedOp
from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.atomic_ops.shift_add_op import ShiftAddOp
from hailo_model_optimization.acceleras.atomic_ops.split_precision_op import (
    SplitPrecisionHigh,
    SplitPrecisionLow,
)
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.lossy_elements.quant_element import (
    AccumulatorQuantElement,
    APUOutputQuantElement,
)
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerWeightsClippingConfig,
)
from hailo_model_optimization.acceleras.statistics.statistics_base import BasicTypeTuple
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    BiasMode,
    DataPath,
    EquivClassification,
    LayerHandlerType,
    OptimizationTarget,
    PaddingType,
    PrecisionMode,
    StrideAlignType,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasImplementationError
from hailo_model_optimization.acceleras.utils.export.export_utils import add_prefix
from hailo_model_optimization.acceleras.utils.hn_npz_utils import (
    LayerParams,
    get_hn_padding,
    set_hn_padding_stride_align,
)
from hailo_model_optimization.acceleras.utils.opt_utils import get_decomposition_count_by_bias_mode

# Bit widths of the (low, high) kernel parts produced by the native-weight decomposition
# (the high width is further reduced by the kernel high shift in import_weights).
WEIGHT_LOW_BITS, WEIGHT_HIGH_BITS = 8, 8

# Bit widths of the (low, high) bias parts produced by the native-bias decomposition.
BIAS_LOW_BITS, BIAS_HIGH_BITS = 15, 16

# Forwarded as `force_zero_output_when_quant` to the LL convolution.
REMOVE_LL = False
# When True, create_weight_split skips the LH conv hw-params pass and uses a shift delta of 0.
ALLOW_H_WEIGHT_SHIFT_DELTA = True
# Bit width passed to the split-precision ops' import_weights.
LOW_BITS = 8


class HailoConvDecomposePluto(BaseHailoLayer):
    """
    16-bit convolution built out of 8-bit atomic ops.

    The computation is decomposed as follows:

    * ``split_precision_low`` / ``split_precision_high`` split the input into a low and a high part.
    * Four convolutions, one per (input part, kernel part) pair: ``conv_hh_op``, ``conv_hl_op``,
      ``conv_lh_op`` and ``conv_ll_op``. The first letter is the input part and the second the kernel
      part, e.g. ``conv_lh_op`` convolves the low input part with the high kernel part.
    * Two bias ops carry the high (``bias_add_hh_op``) and low (``bias_add_ll_op``) bias parts;
      ``bias_add_hl_op`` and ``bias_add_lh_op`` are pass-through biases.
    * ``shift_add_op`` sums the four precision components into a 32-bit accumulator.
    * ``act_op`` (APU) maps that 32-bit accumulator to the layer output.

    The layer is only meant to be used once calibration data is available: an algorithm is expected
    to swap a regular conv for this layer after its hw_params are created.
    """

    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a16_w16,
        PrecisionMode.a16_w16_a8,
        PrecisionMode.a16_w16_a16,
    }
    SUPPORTED_BIAS_MODE = {
        BiasMode.double_scale_initialization,
        BiasMode.double_scale_decomposition,
        BiasMode.single_scale_decomposition,
    }
    SUPPORTED_QUANTIZATION_GROUPS = False

    def __init__(
        self,
        name: str,
        filters,
        kernel_size,
        strides=(1, 1),
        padding: Union[str, PaddingType] = "SAME",
        stride_align: Union[str, StrideAlignType] = "NW",
        dilation_rate=(1, 1),
        groups=1,
        activation: Union[str, callable, ActivationType] = "linear",
        logger=None,
        **kwargs,
    ):
        # All atomic ops are created before super().__init__() — presumably BaseHailoLayer needs
        # them to build the layer flow; keep this order.
        self.input_op = PassthruOp(f"{name}/input_passthough_op", logger=logger)
        self.split_precision_low = SplitPrecisionLow(f"{name}/split_precision_low_op", logger=logger)
        self.split_precision_high = SplitPrecisionHigh(f"{name}/split_precision_high_op", logger=logger)

        # The four convolutions share every setting; only the LL conv also takes the REMOVE_LL switch.
        conv_kwargs = {
            "kernel_size": kernel_size,
            "is_depthwise": False,
            "filters": filters,
            "groups": groups,
            "strides": strides,
            "dilation_rate": dilation_rate,
            "padding": padding,
            "stride_align": stride_align,
            "trainable": False,
            "logger": logger,
        }
        self.conv_hh_op = ConvStrippedOp(f"{name}/conv_hh_op", **conv_kwargs)
        self.conv_hl_op = ConvStrippedOp(f"{name}/conv_hl_op", **conv_kwargs)
        self.conv_lh_op = ConvStrippedOp(f"{name}/conv_lh_op", **conv_kwargs)
        self.conv_ll_op = ConvStrippedOp(
            f"{name}/conv_ll_op",
            force_zero_output_when_quant=REMOVE_LL,
            **conv_kwargs,
        )

        # Bias ops: only HH and LL carry a bias (high / low part); HL and LH are pass-through.
        self.bias_add_hh_op = AddBiasOp(f"{name}/bias_add_hh_op", trainable=False, logger=logger)
        self.bias_add_hl_op = AddBiasOp.get_passthru_bias(f"{name}/bias_add_hl_op", logger=logger)
        self.bias_add_lh_op = AddBiasOp.get_passthru_bias(f"{name}/bias_add_lh_op", logger=logger)
        self.bias_add_ll_op = AddBiasOp(f"{name}/bias_add_ll_op", trainable=False, logger=logger)

        # Tail: sum of the components, activation, output passthrough.
        self.shift_add_op = ShiftAddOp(f"{name}/shift", logger)
        self.act_op = ActivationOp(f"{name}/act_op", activation, logger=logger)
        self.output_op = PassthruOp(f"{name}/passthru_op", logger=logger)

        # The [LL, HL, LH, HH] order defines the shift_add_op input indices; all_convs and
        # all_biases are zipped together, so both lists must keep the same order.
        self.all_convs = [self.conv_ll_op, self.conv_hl_op, self.conv_lh_op, self.conv_hh_op]
        self.all_biases = [self.bias_add_ll_op, self.bias_add_hl_op, self.bias_add_lh_op, self.bias_add_hh_op]

        super().__init__(name=name, logger=logger, **kwargs)

        # Every atomic op starts fully native; create_hw_params switches this off.
        for op in self.atomic_ops:
            op.fully_native = True
        self._kernel_high_shift = 0
        self._bias_high_scale = None
        self._forced_kernel_high_scale = None
        self.input_spec = tf.keras.layers.InputSpec(ndim=4)
        # Candidate shifts handed to the conv ops' create_hw_params.
        self._hw_shifts = [0, 1, 2, 4, 6]

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None, pluto_mode=True):
        """
        Build the layer from an HN element.

        Args:
            lname: layer name.
            hn_element: HN dict of the layer; its ``params`` override :meth:`get_default_params`.
            logger: optional logger forwarded to the atomic ops.
            pluto_mode: not used by this implementation.

        Returns:
            The constructed layer, finalized with ``finalize_from_hn``.

        Raises:
            ValueError: if ``elementwise_add`` is set.
            AccelerasImplementationError: if ``transpose_output_width_features`` is set.
        """
        params = cls.get_default_params()
        params.update(hn_element.get("params", {}))
        kshape = params["kernel_shape"]
        padding, stride_align = get_hn_padding(params)

        cls._validate_elwa(params["elementwise_add"])

        if params.get("transpose_output_width_features", False):
            raise AccelerasImplementationError("transpose_output_width_features is not supported in acceleras yet")
        layer = cls(
            name=lname,
            filters=kshape[-1],
            kernel_size=kshape[0:2],
            padding=padding,
            stride_align=stride_align,
            strides=params["strides"][1:3],
            groups=params["groups"],
            activation=params["activation"],
            dilation_rate=params["dilations"][1:3],
            logger=logger,
        )
        layer.finalize_from_hn(hn_element)
        return layer

    def to_hn(self, out_degree=None):
        """
        Write the conv geometry and activation into the stored HN element and serialize it.

        The geometry is read from ``conv_hh_op`` (all four convs share it). ``decompose_weights``
        is always exported as True and ``elementwise_add`` as False.
        """
        params = {}
        weights = self.export_weights()
        params["kernel_shape"] = list(weights["kernel"].shape)
        strides = self.conv_hh_op.strides
        params["strides"] = [1, strides[0], strides[1], 1]
        params["groups"] = self.conv_hh_op.groups
        params["activation"] = self.act_op.act_name.value
        dilation_rate = self.conv_hh_op.dilation_rate
        params["dilations"] = [1, dilation_rate[0], dilation_rate[1], 1]
        set_hn_padding_stride_align(params, self.conv_hh_op.padding, self.conv_hh_op.stride_align)
        params["elementwise_add"] = False
        params["decompose_weights"] = True
        self._hn_element["params"].update(params)
        return super().to_hn(out_degree=out_degree)

    def _build_flow(self) -> LayerFlow:
        """
              ┌──────────────────────────────────────────────────────────────────────────────────────┐
              │                                                                                      │
              │                                                                                      |
              |                                                                                      │
              │                              ┌────────┐   ┌────────┐                                 │
              │   ┌────────────────────┐ ┌───►conv_ll ├───►bias_ll ├────────┐                        │
        input─┼┬─►│ precision_low      ├─┤   ┌────────┤   ┌────────┤        |                        |
              ││  └────────────────────┤ └───►conv_lh │───►bias_lh │─────►──▼─────┐   ┌───────────┐  │
              ││   ┌───────────────────┤     ┌────────┼   ┌────────┼     |ShiftAdd├───►apu        ├──┼────►
              ├┴───►precision_high     ├─┬───►conv_hl │───►bias_hl │─────►──▲─────┘   └───────────┘  Output
              │    └───────────────────┘ │   ┌────────┤   ┌────────┤        │                        │
              │                          └───►conv_hh ├───►bias_hh ├────────┘                        │
              │                              └────────┘   └────────┘                                 │
              │                                                                                      │
              │                                                                                      │
              └──────────────────────────────────────────────────────────────────────────────────────┘

        Besides the edges drawn above, ``precision_low`` also feeds ``precision_high`` (input index 1),
        and each bias feeds ``shift_add_op`` at its index in ``self.all_biases``.
        """
        # The data paths below are provisional; the lossy elements are set explicitly in
        # create_quant_element_custom_behavior.
        layer_flow = self._init_flow()
        in1 = layer_flow.add_input()
        out1 = layer_flow.add_output()

        layer_flow.add_edge(in1, self.input_op, DataPath.LAYER_IN)
        layer_flow.add_edge(self.input_op, self.split_precision_low, DataPath.LAYER_IN)
        layer_flow.add_edge(self.input_op, self.split_precision_high, DataPath.LAYER_IN, input_index=0)

        layer_flow.add_edge(
            self.split_precision_low, self.split_precision_high, DataPath.LAYER_SPLIT_INPUT, input_index=1
        )
        layer_flow.add_edge(self.split_precision_low, self.conv_lh_op, DataPath.LAYER_SPLIT_INPUT)
        layer_flow.add_edge(self.split_precision_low, self.conv_ll_op, DataPath.LAYER_SPLIT_INPUT)
        layer_flow.add_edge(self.split_precision_high, self.conv_hh_op, DataPath.LAYER_SPLIT_INPUT)
        layer_flow.add_edge(self.split_precision_high, self.conv_hl_op, DataPath.LAYER_SPLIT_INPUT)

        for conv, bias in zip(self.all_convs, self.all_biases):
            layer_flow.add_edge(conv, bias, DataPath.ACCUMULATOR)
        for index, bias in enumerate(self.all_biases):
            layer_flow.add_edge(bias, self.shift_add_op, DataPath.ACCUMULATOR, input_index=index)

        layer_flow.add_edge(self.shift_add_op, self.act_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.act_op, self.output_op, DataPath.LAYER_OUT)
        layer_flow.add_edge(self.output_op, out1, DataPath.LAYER_OUT)
        return layer_flow

    def create_quant_element_custom_behavior(
        self,
        precision_config: LayerPrecisionConfig,
        optimization_target: OptimizationTarget,
    ):
        """
        Assign the lossy (quantization) elements of every atomic op.

        Bit widths used:
        * 15-bit layer input;
        * 8-bit split components;
        * 16-bit conv and bias accumulators;
        * 32-bit shift-add accumulator feeding the activation.

        The bias weight quant elements are decomposed according to ``precision_config.bias_mode``.
        """
        bias_mode = precision_config.bias_mode
        num_decomposition = get_decomposition_count_by_bias_mode(bias_mode)
        self.input_op.set_output_lossy_element(APUOutputQuantElement(bits=15))
        # Both split ops read the 15-bit layer input on index 0 (see _build_flow). The high split
        # additionally reads the 8-bit low component on index 1.
        # Previously, the low op's index 0 was set twice and the high op's index 0 was never set.
        self.split_precision_low.set_input_lossy_element(APUOutputQuantElement(bits=15), index=0)
        self.split_precision_high.set_input_lossy_element(APUOutputQuantElement(bits=15), index=0)
        self.split_precision_low.set_output_lossy_element(APUOutputQuantElement(bits=8))
        self.split_precision_high.set_output_lossy_element(APUOutputQuantElement(bits=8))
        self.split_precision_high.set_input_lossy_element(APUOutputQuantElement(bits=8), index=1)

        for op in self.all_convs:
            op.set_input_lossy_element(APUOutputQuantElement(bits=8))
            op.set_output_lossy_element(AccumulatorQuantElement(bits=16))
            op.create_weight_quant_element(kernel_bits=8, signed=True)

        for op in self.all_biases:
            op.set_input_lossy_element(AccumulatorQuantElement(bits=16))
            op.set_output_lossy_element(AccumulatorQuantElement(bits=16))
            op.create_weight_quant_element(kernel_bits=16, signed=True, num_decomposition=num_decomposition)

        # One 16-bit accumulator input per precision component (same indices as in _build_flow).
        for index in range(len(self.all_biases)):
            self.shift_add_op.set_input_lossy_element(AccumulatorQuantElement(bits=16), index=index)
        self.shift_add_op.set_output_lossy_element(AccumulatorQuantElement(bits=32))
        self.act_op.set_input_lossy_element(AccumulatorQuantElement(bits=32))
        self.act_op.create_weight_quant_element(optimization_target)

    def import_weights(self, layer_params: LayerParams):
        """
        Load native weights and split them into high/low parts for the convs and biases.

        Reads ``kernel`` (required) from ``layer_params``. The optional keys are
        ``kernel_high_shift``, ``forced_kernel_high_scale``, ``bias`` and ``bias_high_scale``.
        The full ``layer_params`` is also forwarded to ``act_op.import_weights``.
        """
        self.split_precision_low.import_weights(LOW_BITS)
        self.split_precision_high.import_weights(LOW_BITS)

        self.kernel = layer_params["kernel"]
        self._kernel_high_shift = layer_params.get("kernel_high_shift", 0)
        self._forced_kernel_high_scale = layer_params.get("forced_kernel_high_scale", None)
        # The high kernel part loses one bit per unit of kernel_high_shift.
        weights_high, weights_low, s_high, s_low = self._decompose_native_weights(
            layer_params["kernel"],
            low_bits=WEIGHT_LOW_BITS,
            high_bits=WEIGHT_HIGH_BITS - self._kernel_high_shift,
            high_scale=self._forced_kernel_high_scale,
        )
        # NOTE(review): with the default self._hw_shifts (minimum 0), this divisor is 1 for any
        # non-negative kernel_high_shift. The original author marked it as removable.
        self.weight_scale_high = s_high / 2 ** np.minimum(
            self._kernel_high_shift,
            np.min(self._hw_shifts),
        )
        self.weight_scale_low = s_low
        # The high kernel part feeds HH and LH; the low part feeds HL and LL.
        self.conv_hh_op.import_weights(weights_high)
        self.conv_lh_op.import_weights(weights_high)
        self.conv_hl_op.import_weights(weights_low)
        self.conv_ll_op.import_weights(weights_low)

        if "bias" in layer_params.keys():
            self._bias_high_scale = layer_params.get("bias_high_scale", None)
            bias_high, bias_low, _, _ = self._decompose_native_weights(
                layer_params["bias"],
                low_bits=BIAS_LOW_BITS,
                high_bits=BIAS_HIGH_BITS,
                high_scale=self._bias_high_scale,
            )
            self.bias_add_hh_op.import_weights(bias_high)
            self.bias_add_ll_op.import_weights(bias_low)
        self.act_op.import_weights(layer_params)

    def _export_weights(self):
        """
        Return the native (recombined) weights plus the split metadata.

        This is the format :meth:`import_weights` reads back. The forced scales are only included
        when they are truthy.
        """
        kernel = self.export_native_kernel()
        bias = self.export_native_bias()
        dict_params = {
            "kernel": kernel,
            "bias": bias,
            "kernel_high_shift": self._kernel_high_shift,
            "padding_const_value": self.conv_hh_op.export_weights()["padding_const_value"],
            **({"forced_kernel_high_scale": self._forced_kernel_high_scale} if self._forced_kernel_high_scale else {}),
            **({"bias_high_scale": self._bias_high_scale} if self._bias_high_scale else {}),
        }
        dict_params.update(self.act_op.export_weights())
        return dict_params

    def _export_ops_hw_params(self):
        """
        Collect the hw params of the activation, shift-add, all convs and the two real biases.

        Conv and bias entries are prefixed with ``"<op name>/"``.
        """
        val = dict()
        val.update(self.act_op.export_hw_params())
        val.update(self.shift_add_op.export_hw_params())

        for conv in self.all_convs:
            val.update(add_prefix(conv.export_hw_params(), f"{conv.name}/"))

        for bias in [self.bias_add_ll_op, self.bias_add_hh_op]:
            val.update(add_prefix(bias.export_hw_params(), f"{bias.name}/"))

        # HL shares LL's (low) kernel and LH shares HH's (high) kernel: drop the duplicates.
        val.pop("conv_hl_op/kernel")
        val.pop("conv_lh_op/kernel")

        # bias_q is never used, so it is not exported.
        val.pop("bias_add_hh_op/bias_q")
        val.pop("bias_add_ll_op/bias_q")
        return val

    def export_native_kernel(self):
        """Return the native kernel as the sum of the high (HH) and low (LL) kernel parts."""
        kernel_high = self.conv_hh_op.export_weights()["kernel"]
        kernel_low = self.conv_ll_op.export_weights()["kernel"]
        return kernel_high + kernel_low

    def export_native_bias(self):
        """Return the native bias as the sum of the high (HH) and low (LL) bias parts."""
        bias_high = self.bias_add_hh_op.export_weights()
        bias_low = self.bias_add_ll_op.export_weights()
        return bias_high + bias_low

    def start_stats_collection(self, stats_cfg: tuple = BasicTypeTuple, output_hist=False, preact_hist=False):
        """
        Start stats collection on the layer and on its internal ops.

        In addition to the base stats, this collects the input/output stats of the split and conv
        ops and the output stats of the bias ops. The bias output stats are what
        _create_conv_bias_hw_params reads.
        """
        super().start_stats_collection(stats_cfg, output_hist, preact_hist)
        self.split_precision_low.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=True, collect_output=True)
        self.split_precision_high.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=True, collect_output=True)
        for conv in self.all_convs:
            conv.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=True, collect_output=True)
        for bias in self.all_biases:
            bias.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=False, collect_output=True)

    def create_splits(self):
        """
        Build the split-precision encodings.

        The low split's input encoding is created from its candidates. It is then copied to the
        high split: input 0 gets the layer input scale and input 1 gets the low component scale.
        """
        self.split_precision_low.trivial_split = False
        self.split_precision_low.create_input_encoding_candidates(0)
        self.split_precision_low.enforce_encoding()
        self.split_precision_high.input_scales[0] = self.split_precision_low.input_scale
        self.split_precision_high.input_scales[1] = self.split_precision_low.output_scale
        self.split_precision_high.enforce_encoding()

    def _decompose_native_weights(self, weights, low_bits, high_bits, high_scale=None):
        """
        Split signed ``weights`` into a coarse high part and a residual low part.

        The high part is ``weights`` rounded to a grid of step ``high_scale``. The low part is the
        residual ``weights - high``.

        Args:
            weights: native weights, assumed signed.
            low_bits: bits of the low part; ``low_scale = high_scale / 2**low_bits``.
            high_bits: bits of the high part; used to derive ``high_scale`` when it is not given.
            high_scale: optional forced scale of the high part. A falsy value (None or 0) means
                "derive it from max(|weights|)". Plain Python numbers are accepted.

        Returns:
            Tuple ``(weights_high, weights_low, high_scale, low_scale)``, all float32.
        """
        if not high_scale:
            max_level = 2 ** (high_bits - 1) - 1
            weight_max = np.max(np.abs(weights))
            if weight_max > 0:
                high_scale = weight_max / max_level
            else:
                # All-zero weights: fall back to the scale of a unit range.
                high_scale = np.array(1 / max_level)
        weights_high_q = np.round(weights / high_scale) * high_scale
        weights_low = weights - weights_high_q
        # np.asarray also handles plain Python floats/ints (e.g. a forced scale read from params),
        # which have no .astype method.
        high_scale = np.asarray(high_scale).astype(np.float32)
        low_scale = np.array(high_scale / 2**low_bits).astype(np.float32)
        return weights_high_q.astype(np.float32), weights_low.astype(np.float32), high_scale, low_scale

    @classmethod
    def get_default_params(cls):
        """Return a fresh dict of the HN params defaults used by :meth:`from_hn`."""
        # TODO: this is temporary solution until we have pydantic scheme
        defaults = {
            "strides": [1, 1, 1, 1],
            "dilations": [1, 1, 1, 1],
            "padding": "SAME",
            "activation": "linear",
            "groups": 1,
            "elementwise_add": False,
        }
        return dict(defaults)

    def _create_conv_bias_hw_params(
        self,
        conv: ConvStrippedOp,
        bias: AddBiasOp,
        clipping_config,
        force_rounded_shift_delta: bool = True,
        *,
        optimization_target: OptimizationTarget = OptimizationTarget.PLUTO,
    ):
        """
        Create the hw params of one conv/bias pair: kernel quantization, MAC shift and shift delta.

        The kernel scale already set on ``conv`` is forced, and the accumulator range comes from the
        output statistics of ``bias``. Afterwards the conv output scale is updated to absorb the
        resulting shift delta and pre-accumulator shift.
        """
        bias.create_output_encoding_candidates(0)
        # The accumulator max (from the collected stats) drives the MAC shift and the shift delta.
        output_stats = bias.get_output_stats(0)
        max_native_accumulator = np.maximum(output_stats.max, np.abs(output_stats.min))
        conv.force_rounded_shift_delta = force_rounded_shift_delta

        # Dummy call: only the shape of the result is used.
        dummy_kernel_component = conv.calc_kernel_scale(
            conv.input_scales,
            bias.output_scale,
            0,
        )

        # Broadcast the (forced) kernel scale to the shape of the kernel scale matrix component.
        kernel_scale_matrix_component = np.array(conv.kernel_scale) * np.ones_like(dummy_kernel_component)
        kernel_scale_matrix_component = kernel_scale_matrix_component.astype(np.float32)

        # Quantizes the kernel (kernel scale is forced) and yields the MAC shift and the shift delta.
        conv.create_hw_params(
            max_native_accumulator=max_native_accumulator,
            weight_clip_cfg=clipping_config,
            kernel_scale_matrix_component=kernel_scale_matrix_component,
            hw_shifts=self._hw_shifts,
            force_scale=True,
            optimization_target=optimization_target,
        )

        # Propagate the shift delta into the output scale. NOTE: this may overflow the accumulator.
        kernel_scale_with_shift_delta = np.mean(conv.kernel_scale * 2 ** np.array(conv.shift_delta))
        conv.output_scale = (
            np.ones(conv.kernel.shape[-1]) * conv.input_scale[0] * kernel_scale_with_shift_delta * 2**conv.pre_acc_shift
        )

        # Only the output scale changed (the kernel_q scale is unchanged); recompute the zero point.
        conv.enforce_encoding()

    def import_quant(self, params: dict):
        """Import quantized params and rebuild each conv's kernel as ``quant_kernel * kernel_scale``."""
        super().import_quant(params=params)
        for conv in self.all_convs:
            kernel_k = params[conv.full_name + "/quant_kernel"]
            kernel_scale = params[conv.full_name + "/kernel_scale"]
            kernel = kernel_k * kernel_scale
            conv.kernel = kernel

    def create_weight_split(
        self,
        weights_clipping: LayerWeightsClippingConfig,
        optimization_target: OptimizationTarget,
    ):
        """
        External method, used exclusively by algorithms.

        Makes the HH conv lossless by reducing the bits of the high kernel part by the required
        shift, then recomputes the weight split. The result is stored as ``_kernel_high_shift``
        and ``_bias_high_scale`` and re-imported.
        """
        self.enable_lossy()
        self.create_io_encoding_candidates()
        self.enforce_io_encoding()
        self._enforce_split_precision_encoding()
        self._set_kernel_scales()

        # Sets the weight scale so that no pre-accumulator shift is needed.
        self._create_conv_bias_hw_params(
            self.conv_hh_op,
            self.bias_add_hh_op,
            weights_clipping,
            force_rounded_shift_delta=False,
            optimization_target=optimization_target,
        )
        if ALLOW_H_WEIGHT_SHIFT_DELTA:
            shift_delta = 0
        else:
            self._create_conv_bias_hw_params(
                self.conv_lh_op,
                self.bias_add_lh_op,
                weights_clipping,
                force_rounded_shift_delta=False,
                optimization_target=optimization_target,
            )
            shift_delta = np.max(self.conv_lh_op.shift_delta) + 0.1
        self._kernel_high_shift = np.maximum(self.conv_hh_op.desired_pre_acc_shift, shift_delta)
        self._bias_high_scale = self.conv_hh_op.output_scale[0]
        self.import_weights(self.export_weights())

    def _set_kernel_scales(self):
        """Assign the high weight scale to the LH/HH convs and the low one to the LL/HL convs."""
        for op in [self.conv_lh_op, self.conv_hh_op]:
            op.kernel_scale = self.weight_scale_high
        for op in [self.conv_ll_op, self.conv_hl_op]:
            op.kernel_scale = self.weight_scale_low

    def _calculate_activation_scale_canditate(
        self,
        weights_clipping: LayerWeightsClippingConfig,
        optimization_target: OptimizationTarget,
    ):
        """
        Forward pass of the hw-params creation.

        The steps are:
        1. Create the conv/bias hw params.
        2. Propagate the resulting scales through the biases into shift_add_op.
        3. Create the activation hw params from the shift-add output scale.
        """
        self._enforce_output_encoding()
        self._enforce_split_precision_encoding()
        self._set_kernel_scales()

        # Conv/bias pairs are processed from HH down to LL (the reverse of all_convs order).
        for conv, bias in zip(reversed(self.all_convs), reversed(self.all_biases)):
            self._create_conv_bias_hw_params(conv, bias, weights_clipping, optimization_target=optimization_target)

        for index, (conv, bias) in enumerate(zip(self.all_convs, self.all_biases)):
            bias.pre_acc_shift = conv.pre_acc_shift
            bias.input_scale = conv.output_scale
            bias.input_zero_point = conv.output_zero_point
            bias.output_scale = bias.input_scale
            bias.create_hw_params()
            self.shift_add_op.input_scales[index] = bias.output_scale
            self.shift_add_op.input_zero_points[index] = bias.output_zero_point

        # The shift-add output scale is taken from its first (LL) input.
        self.shift_add_op.output_scale = self.shift_add_op.input_scales[0]
        self.shift_add_op.create_hw_params()

        self.act_op.input_scales = [self.shift_add_op.output_scale]
        self.act_op.input_zero_points = [self.shift_add_op.output_zero_point]
        self.act_op.create_hw_params(self.shift_add_op.output_scale, optimization_target)
        self.act_op.get_accumulator_scale()

        # Change the mantissa quant element to avoid rounding errors due to numerical instability.
        self.act_op.create_hw_params(self.act_op.input_scale, optimization_target)
        self.act_op.enforce_encoding()

    def _back_propogate_scales(self):
        """
        Back-propagate scales from the APU (act_op) input to the conv outputs.

        Afterwards, pin the high kernel scale derived from the resulting HH kernel scale.
        """
        self.shift_add_op.output_scales = [self.act_op.input_scale]
        self.shift_add_op.output_zero_points = [self.act_op.input_zero_point]
        self.shift_add_op.enforce_encoding(forward=False)

        for index, bias in enumerate(self.all_biases):
            bias.input_scale = self.shift_add_op.input_scales[index]
            bias.output_scale = self.shift_add_op.input_scales[index]
            bias.create_hw_params()
        for index, conv in enumerate(self.all_convs):
            conv.output_scale = self.all_biases[index].input_scale
            conv.enforce_encoding()

        # The 2**min(hw_shifts) factor appears to cancel the divisor applied to weight_scale_high
        # in import_weights.
        self._forced_kernel_high_scale = np.array(
            self.conv_hh_op.kernel_scale[0][0] * 2 ** np.min(self._hw_shifts),
        )

    def create_hw_params(
        self,
        weights_clipping: LayerWeightsClippingConfig,
        optimization_target: OptimizationTarget,
        hw_shifts=None,
    ):
        """
        Create the hw params of the whole layer.

        The steps are:
        1. Run the forward scale computation.
        2. Back-propagate the scales.
        3. Re-split the weights with the forced scales.
        4. Turn off native mode on every atomic op.

        ``hw_shifts`` is not used; the shifts come from ``self._hw_shifts``.
        """
        self._calculate_activation_scale_canditate(weights_clipping, optimization_target)
        self._back_propogate_scales()
        self.import_weights(self.export_weights())
        for op in self.atomic_ops:
            op.fully_native = False
        self._has_hw_params = True

    def _enforce_split_precision_encoding(self):
        """
        Propagate the input encoding through the split ops to the conv inputs.

        The LL/LH convs get the low split's encoding; the HL/HH convs get the high split's encoding.
        """
        self.input_op.enforce_encoding()
        self.split_precision_low.input_scales[0] = self.input_op.output_scales[0]
        self.split_precision_high.input_scales[0] = self.input_op.output_scales[0]
        self.split_precision_low.input_zero_points[0] = self.input_op.output_zero_points[0]
        self.split_precision_high.input_zero_points[0] = self.input_op.output_zero_points[0]
        self.split_precision_low.enforce_encoding()
        self.split_precision_high.enforce_encoding()
        self.split_precision_high.input_scales[1] = self.split_precision_low.output_scales[0]
        self.split_precision_high.input_zero_points[1] = self.split_precision_low.output_zero_points[0]
        for op in [self.conv_lh_op, self.conv_ll_op]:
            op.input_scales = [self.split_precision_low.output_scale]
            op.input_zero_points = [self.split_precision_low.output_zero_point]
        for op in [self.conv_hl_op, self.conv_hh_op]:
            op.input_scales = [self.split_precision_high.output_scale]
            op.input_zero_points = [self.split_precision_high.output_zero_point]

    def _enforce_output_encoding(self):
        """Pull the output encoding back from output_op into act_op's output."""
        self.output_op.backward_encoding()
        self.act_op.output_scale = self.output_op.input_scales[0]
        self.act_op.output_zero_point = self.output_op.input_zero_points[0]

    def enforce_internal_encoding(self, training=False, **kwargs):
        """
        Forward-propagate the encodings through the internal atomic ops.

        The steps are:
        1. Pull the output encoding back into act_op.
        2. Set the split and conv input encodings from input_op.
        3. Enforce each conv/bias pair. The bias output scale is tied to the conv output scale and
           its zero point is forced to 0.
        4. Feed those scales into shift_add_op and enforce it.
        5. Enforce act_op from the shift-add output.
        """
        self._enforce_output_encoding()
        self._enforce_split_precision_encoding()

        for index, (conv, bias) in enumerate(zip(self.all_convs, self.all_biases)):
            conv.enforce_encoding()
            bias.input_scales = [conv.output_scale]
            bias.input_zero_points = [conv.output_zero_point]
            bias.output_scale = conv.output_scale
            bias.output_zero_point = np.array(0.0)
            bias.enforce_encoding()

            self.shift_add_op.input_scales[index] = conv.output_scale
            self.shift_add_op.input_zero_points[index] = np.array(0.0)

        self.shift_add_op.enforce_encoding()
        self.act_op.input_scale = self.shift_add_op.output_scale
        self.act_op.input_zero_point = self.shift_add_op.output_zero_point
        self.act_op.enforce_encoding()

    def fast_enforce_internal_encoding(self, **kwargs):
        """No-op for this layer."""
        pass

    def enforce_io_encoding(self, training=False, **kwargs):
        """No-op for this layer."""
        pass

    @classmethod
    def _validate_elwa(cls, elwa_value):
        """Raise ValueError when elementwise_add is requested; this layer does not support it."""
        if elwa_value:
            raise ValueError(
                f"elementwise_add value was {elwa_value}, "
                f"but expected {not elwa_value} in {cls.__name__} initialization",
            )

    def get_equalization_handler_type(self, predecessor_index=None):
        """This layer is always classified as unsupported (non-source) for equalization."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    @property
    def consumer_input_scale(self):
        return False

    @property
    def homogeneous(self):
        return False

    def _is_precision_config_supported(self, precision_mode, bias_mode, arch):
        """Accept every precision config; this method does not consult the SUPPORTED_* sets."""
        return True
