import itertools
from typing import Union

import numpy as np

from hailo_model_optimization.acceleras.atomic_ops.activation_op import ActivationOp
from hailo_model_optimization.acceleras.atomic_ops.bias_add_op import AddBiasOp
from hailo_model_optimization.acceleras.atomic_ops.concat_op import ConcatOp
from hailo_model_optimization.acceleras.atomic_ops.conv_stripped_op import ConvStrippedOp
from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.atomic_ops.slice_op import SliceOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerWeightsClippingConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    SHIFT_CALCULATE_BUFFER,
    ActivationType,
    BiasMode,
    DataPath,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    OptimizationTarget,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasInitializationError
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams
from hailo_model_optimization.acceleras.utils.opt_utils import (
    calculate_shifts,
    get_decomposition_count_by_bias_mode,
    get_kernel_bits_and_sign_by_precision_mode,
    limvals_to_zp_scale,
)

POINTS_PER_BOX = 4


class HailoBBoxDecoder(BaseHailoLayer):
    """
    OPS:
    Slice       : slice each of the input tensors to the relevant anchor values.
    Concat(list): (optional) if there are two inputs we concat them to a single input
    Conv (list) : transformation from x,y,w,h to x_min, y_min, x_max, y_max (order TBD)
                    there is one conv for each anchor (can also be implemented with group conv)
    Bias (list) : Add the anchor centers. Note that this bias is per spatial index, not only
                    per feature index
    concat      : concat all the anchors to one tensor
    *Note: the non-arithmetic ops (Slice, Concat) are only for the SW implementation and
            are not implemented on the HW

    Activation  : layer activation, should be linear but the HW supports other activation functions
    The order of the channels for one input is [center0, center1, h, w] x num_of_anchors
    The order of the channels for two inputs is [[center0, center1] x num_of_anchors, [h, w] x num_of_anchors]


    Args:
        num_of_anchors - an integer, describing the number of anchors of the layer architecture
        num_inputs - there can be 1 or 2 input tensors as described above.

        If the user wants to create a new layer (w/o loading a trained layer),
        they will need to supply the following args:

        anchors_heights, anchors_widths - vector in the size of num_of_anchors with the wanted anchor sizes.
        x_centers, y_centers - matrix in the size of HxW (of the input sizes) with the offset values.

    Reference:
    https://hailotech.atlassian.net/wiki/spaces/ML/pages/939687998/bbox+decoder


    """

    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a8_w8,
        PrecisionMode.a16_w16,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a8_w8_a16,
        PrecisionMode.a16_w16_a8,
        PrecisionMode.a16_w16_a16,
    }
    SUPPORTED_BIAS_MODE = {
        BiasMode.double_scale_initialization,
    }
    SUPPORTED_QUANTIZATION_GROUPS = False
    _hn_type = LayerType.BBOX_DECODER

    def __init__(
        self,
        name: str,
        num_of_anchors: int = 3,
        num_inputs: int = 2,
        anchors_heights: np.ndarray = None,
        anchors_widths: np.ndarray = None,
        anchors_heights_div_2: np.ndarray = None,
        anchors_widths_div_2: np.ndarray = None,
        anchors_heights_minus_div_2: np.ndarray = None,
        anchors_widths_minus_div_2: np.ndarray = None,
        x_centers: np.ndarray = None,
        y_centers: np.ndarray = None,
        activation: Union[str, callable, ActivationType] = ActivationType.RELU,
        logger=None,
        **kwargs,
    ):
        """
        Create the atomic ops of the bbox decoder and initialize the base layer.

        Args:
            name: layer name; also used as the prefix of every atomic op name.
            num_of_anchors: number of anchors; one slice-set / pre-concat / conv / bias chain is
                created per anchor.
            num_inputs: number of input tensors, must be 1 or 2.
            anchors_heights, anchors_widths, anchors_heights_div_2, anchors_widths_div_2,
            anchors_heights_minus_div_2, anchors_widths_minus_div_2: optional anchor parameters,
                stored as-is on the layer.
            x_centers, y_centers: optional center offsets, stored as-is on the layer.
            activation: activation passed to the activation op.
            logger: logger passed to every atomic op and to the base layer.
            **kwargs: forwarded to the base layer initializer.

        Raises:
            AccelerasInitializationError: if ``num_inputs`` is not 1 or 2.
        """
        if num_inputs not in {1, 2}:
            # The base-class initializer has not run yet at this point, so the message is built from
            # the ``name`` argument rather than from ``self.full_name``.
            raise AccelerasInitializationError(
                f"Number of inputs to bbox decoder {name} must be 1 or 2 but got {num_inputs}",
            )

        self.anchors_heights = anchors_heights
        self.anchors_widths = anchors_widths
        self.anchors_heights_div_2 = anchors_heights_div_2
        self.anchors_widths_div_2 = anchors_widths_div_2
        self.anchors_heights_minus_div_2 = anchors_heights_minus_div_2
        self.anchors_widths_minus_div_2 = anchors_widths_minus_div_2

        self.x_centers = x_centers
        self.y_centers = y_centers

        self._num_of_anchors = num_of_anchors
        self._num_inputs = num_inputs
        self._passthru_ops = []
        for input_index in range(num_inputs):
            self._passthru_ops.append(PassthruOp(f"{name}/passthru_op_in_{input_index}", logger=logger))
        self._slice_ops = []
        self._pre_concat_ops = []
        self._conv_ops = []
        self._bias_ops = []
        # Each input carries POINTS_PER_BOX // num_inputs channels per anchor.
        channels_per_input = POINTS_PER_BOX // num_inputs
        for anchor_index in range(num_of_anchors):
            # create slice ops (stored flat, indexed by anchor_index * num_inputs + input_index)
            for input_index in range(num_inputs):
                slice_start = channels_per_input * anchor_index
                slice_end = channels_per_input * (anchor_index + 1)
                slice_op = SliceOp(
                    f"{name}/slice_op_{input_index}_{anchor_index}",
                    features_slice=(slice_start, slice_end, 1),
                    logger=logger,
                )
                self._slice_ops.append(slice_op)

            # create pre concat op (used for 1 or 2 input tensors but if there is 1 input, it will just output its
            # input as it is)
            pre_concat_op = ConcatOp(
                f"{name}/pre_concat_op_{anchor_index}",
                concat_elements=num_inputs,
                logger=logger,
                vector_zp=True,
            )
            self._pre_concat_ops.append(pre_concat_op)

            # create the per-anchor 1x1 conv op
            conv_op = ConvStrippedOp(
                f"{name}/conv_op_{anchor_index}",
                kernel_size=(1, 1),
                filters=POINTS_PER_BOX,
                is_depthwise=False,
                stride_align="NW",
                strides=(1, 1),
                groups=1,
                padding="VALID",
                dilation_rate=(1, 1),
                trainable=False,
                vector_zp=True,
                logger=logger,
            )
            self._conv_ops.append(conv_op)

            # create the per-anchor bias op
            bias_op = AddBiasOp(
                f"{name}/bias_op_{anchor_index}",
                trainable=False,
                axis=(1, 2, 3),
                logger=logger,
            )
            self._bias_ops.append(bias_op)

        self._concat_op = ConcatOp(f"{name}/concat_op", concat_elements=num_of_anchors, logger=logger)
        self._act_op = ActivationOp(f"{name}/act_op", activation, logger=logger)
        super().__init__(name=name, logger=logger, **kwargs)

    @property
    def pre_acc_shift(self):
        """Return the ``pre_acc_shift`` of the first per-anchor conv op."""
        return self._conv_ops[0].pre_acc_shift

    # region Export HW Params
    def _export_ops_hw_params(self) -> dict:
        """
        #! This is the exception not the rule, we need to refactor this. Dont Copy!
        We dont have a extandar way to export ops hw params for layers that have multiple ops with the same class
        All the decompose"""
        params = {}
        for anchors in range(self._num_of_anchors):
            temp_conv = self._conv_ops[anchors].export_hw_params()
            params[f"kernel_anchor_{anchors}"] = temp_conv["kernel"]
            params[f"padding_const_value_{anchors}"] = temp_conv["padding_const_value"]

            temp_bias = self._bias_ops[anchors].export_hw_params()
            params[f"bias_anchor_{anchors}"] = temp_bias["bias"]
            params[f"bias_factor_{anchors}"] = temp_bias["bias_factor"]
            params[f"bias_feed_repeat_{anchors}"] = temp_bias["bias_feed_repeat"]

        params["zp_kernel"] = temp_conv["zp_kernel"]
        params["output_stage/mult_shift"] = temp_conv["output_stage/mult_shift"]
        params.update(self._act_op.export_hw_params())
        return params

    def _layer_dependent_hw_params_modifications(self, params: dict) -> dict:
        anchors = self.kernels_to_anchors(params)
        centers = self.bias_to_centers(params)
        params.update(anchors)
        params.update(centers)
        return params

    def _build_flow(self) -> LayerFlow:
        """
        Build the internal op graph of the layer.

        Each input goes through its passthru op and is sliced once per anchor; the slices of one
        anchor are joined by that anchor's pre-concat op and fed through its conv and bias ops.
        All bias outputs are concatenated and passed through the activation op to the single output.

        Returns:
            The populated LayerFlow.
        """
        num_inputs = self._num_inputs
        num_anchors = self._num_of_anchors
        flow = LayerFlow()

        # IO nodes: one input node per input tensor, then the single output node
        input_nodes = []
        for _ in range(num_inputs):
            input_nodes.append(flow.add_input())
        output_node = flow.add_output()

        # Op nodes
        for inp in range(num_inputs):
            flow.add_node(self._passthru_ops[inp])
        for anc in range(num_anchors):
            for inp in range(num_inputs):
                # slice ops are stored flat: anchor-major, input-minor
                flow.add_node(self._slice_ops[anc * num_inputs + inp])
        for anc in range(num_anchors):
            flow.add_node(self._pre_concat_ops[anc])
            flow.add_node(self._conv_ops[anc])
            flow.add_node(self._bias_ops[anc])
        flow.add_node(self._concat_op)
        flow.add_node(self._act_op)

        # Edges
        for inp in range(num_inputs):
            flow.add_edge(input_nodes[inp], self._passthru_ops[inp], data_path=DataPath.LAYER_IN)
        for anc in range(num_anchors):
            for inp in range(num_inputs):
                anchor_slice = self._slice_ops[anc * num_inputs + inp]
                flow.add_edge(self._passthru_ops[inp], anchor_slice, data_path=DataPath.LAYER_IN)
                flow.add_edge(anchor_slice, self._pre_concat_ops[anc], input_index=inp, data_path=DataPath.LAYER_IN)
        for anc in range(num_anchors):
            flow.add_edge(self._pre_concat_ops[anc], self._conv_ops[anc], DataPath.LAYER_IN)
            flow.add_edge(self._conv_ops[anc], self._bias_ops[anc], DataPath.ACCUMULATOR)
            flow.add_edge(self._bias_ops[anc], self._concat_op, DataPath.ACCUMULATOR, input_index=anc)
        flow.add_edge(self._concat_op, self._act_op, DataPath.ACCUMULATOR)
        flow.add_edge(self._act_op, output_node, DataPath.LAYER_OUT)

        return flow

    def import_weights(self, layer_params: LayerParams):
        """
        Load the anchor / center parameters and build the per-anchor conv kernels and biases.

        The input to each conv is ty, tx, h_exp, w_exp
        The output is ymin, xmin, ymax, xmax
        ymin = ycenter_a_min + ty * ha - h_exp * ha/2
        xmin = xcenter_a_min + tx * wa - w_exp * wa/2
        ymax = ycenter_a_max + ty * ha + h_exp * ha/2
        xmax = xcenter_a_max + tx * wa + w_exp * wa/2

        Comment - In Yolov5, the behaviour of the anchors is different. The value of anchor
        width/height will be constant 2/number_of_input_pixels_to_bbox. The rest of the parameters
        (anchors_height_div_2 etc.) will behave as expected.
        For more information, look at
        https://github.com/hailo-ai/hailo_model_zoo/blob/master/hailo_model_zoo/core/postprocessing/detection/yolo.py

        Args:
            layer_params: provides the anchors_* and x_centers / y_centers entries; it is also
                handed to the conv ops and to the activation op.
        """
        stored_param_names = (
            "anchors_heights",
            "anchors_widths",
            "anchors_heights_div_2",
            "anchors_widths_div_2",
            "anchors_heights_minus_div_2",
            "anchors_widths_minus_div_2",
            "x_centers",
            "y_centers",
        )
        for param_name in stored_param_names:
            setattr(self, param_name, layer_params[param_name])

        for anchor in range(self._num_of_anchors):
            # load kernel params
            ha = self.anchors_heights[anchor]
            wa = self.anchors_widths[anchor]
            ha_div2 = self.anchors_heights_div_2[anchor]
            wa_div2 = self.anchors_widths_div_2[anchor]
            ha_minus_div2 = self.anchors_heights_minus_div_2[anchor]
            wa_minus_div2 = self.anchors_widths_minus_div_2[anchor]

            # One row per output coordinate, one column per input channel (ty, tx, h_exp, w_exp)
            output_rows = np.array(
                [
                    [ha, 0, ha_minus_div2, 0],  # ymin
                    [0, wa, 0, wa_minus_div2],  # xmin
                    [ha, 0, ha_div2, 0],  # ymax
                    [0, wa, 0, wa_div2],  # xmax
                ],
                dtype=np.float32,
            )
            # transpose to (input, output) and add the 1x1 spatial dims
            kernel = np.reshape(np.transpose(output_rows), [1, 1, POINTS_PER_BOX, POINTS_PER_BOX])
            self._conv_ops[anchor].import_weights(kernel, layer_params)

            # load bias params
            # The centers are written twice in the NPZ, so anchor i is read from column 2 * i
            y_vec = layer_params["y_centers"][:, anchor * 2]
            x_vec = layer_params["x_centers"][:, anchor * 2]
            x_row = np.reshape(x_vec, [1, len(x_vec)])  # TODO: verify the transpose axis
            y_col = np.reshape(y_vec, [len(y_vec), 1])

            bias = np.zeros([len(y_vec), len(x_vec), POINTS_PER_BOX])
            # y varies along axis 0 and x along axis 1; the assignments broadcast over the other axis.
            # Channel order is ymin, xmin, ymax, xmax.
            for channel, grid in enumerate((y_col, x_row, y_col, x_row)):
                bias[:, :, channel] = grid
            self._bias_ops[anchor].import_weights(bias)
        self._act_op.import_weights(layer_params)

    def _export_weights(self):
        weights_dict = dict()
        weights_dict["anchors_heights"] = self.anchors_heights
        weights_dict["anchors_widths"] = self.anchors_widths
        weights_dict["anchors_heights_div_2"] = self.anchors_heights_div_2
        weights_dict["anchors_widths_div_2"] = self.anchors_widths_div_2
        weights_dict["anchors_heights_minus_div_2"] = self.anchors_heights_minus_div_2
        weights_dict["anchors_widths_minus_div_2"] = self.anchors_widths_minus_div_2
        weights_dict["x_centers"] = self.x_centers
        weights_dict["y_centers"] = self.y_centers
        return weights_dict

    @property
    def output_op(self):
        """Return the activation op, the last op before the layer output in the flow."""
        return self._act_op

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        Create the layer from its HN element.

        The number of anchors is the channel count of the first output shape divided by
        POINTS_PER_BOX, and the number of inputs is the number of input shapes.

        Args:
            lname: layer name.
            hn_element: HN description; ``output_shapes`` and ``input_shapes`` are read here and the
                whole element is passed to ``finalize_from_hn``.
            logger: optional logger passed to the layer.

        Returns:
            The constructed layer, after ``finalize_from_hn``.
        """
        num_output_channels = hn_element["output_shapes"][0][-1]
        init_kwargs = {
            "name": lname,
            "num_of_anchors": num_output_channels // POINTS_PER_BOX,
            "num_inputs": len(hn_element["input_shapes"]),
            "logger": logger,
        }
        new_layer = cls(**init_kwargs)
        new_layer.finalize_from_hn(hn_element)
        return new_layer

    def create_quant_element_custom_behavior(
        self,
        precision_config: LayerPrecisionConfig,
        optimization_target: OptimizationTarget,
    ):
        """
        Create the weight quantization elements of the conv, bias and activation ops.

        Args:
            precision_config: supplies the bias mode, precision mode and quantization groups.
            optimization_target: passed to the activation op.
        """
        requested_bias_mode = precision_config.bias_mode
        requested_precision_mode = precision_config.precision_mode
        requested_quant_groups = precision_config.quantization_groups

        # The kernel is always requested as signed for this layer
        bits, is_signed = get_kernel_bits_and_sign_by_precision_mode(
            requested_precision_mode,
            force_signed_kernel=True,
        )
        decomposition_count = get_decomposition_count_by_bias_mode(requested_bias_mode)

        # TODO: slice and pre concat ops

        for anchor_conv, anchor_bias in zip(self._conv_ops, self._bias_ops):
            anchor_conv.create_weight_quant_element(bits, is_signed)
            anchor_bias.create_weight_quant_element(bits, is_signed, decomposition_count)
        self._act_op.create_weight_quant_element(optimization_target)

        # set quantization groups - not supported for now, but maybe in the future?
        self._act_op.set_quantization_groups(requested_quant_groups)

    def get_equalization_handler_type(self, predecessor_index=None):
        """Classify the layer as unsupported and not a source; ``predecessor_index`` is not used."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """
        Compute the kernel scale, shifts and accumulator scale candidates of every per-anchor conv,
        then create the HW params of the activation op and of the bias ops.

        All convs share the same symmetric kernel limvals, taken from the largest absolute kernel
        value over all anchors.

        Args:
            weights_clipping: not referenced in this method body.
            optimization_target: passed to the activation op's ``create_hw_params``.
            hw_shifts: forwarded to ``calculate_shifts`` (only used when the kernel is not 16 bit).
        """
        # Encoding conv scales - need to sync all convs to the same hw parameters
        self.enforce_internal_encoding()

        # Get the maximum value of the kernels
        # NOTE(review): an earlier comment here mentioned factoring the max by 1/2 to keep some buffer
        # in the kernel range; no such factor is applied in the code below.
        max_kernels_value = np.max([np.max(np.abs(conv.kernel)) for conv in self._conv_ops])
        k_limvals = (-max_kernels_value, max_kernels_value)

        # Get the maximum value of all kernels accumulator
        act_stats = self.get_preact_stats()
        all_kernels_max_native_accumulator = np.array([act_stats[0].max.max()] * POINTS_PER_BOX)
        for i in range(self._num_of_anchors):
            kernel_lossy = self._conv_ops[i].weight_lossy_elements.kernel
            kernel_bits = kernel_lossy.bits
            zp, kernel_scale_candidate_scalar, _ = limvals_to_zp_scale(
                k_limvals, kernel_lossy, self.full_name, self._logger
            )
            # the limvals are symmetric, so a zero zero-point is expected
            assert zp == 0.0
            # in case of 16 bit kernel, we change the scale such that only 12 bit for nms requirement in the PPU
            if kernel_bits == 16:
                kernel_scale_candidate_scalar *= 16
            kernel_scale_matrix = kernel_scale_candidate_scalar * np.ones(shape=(POINTS_PER_BOX, POINTS_PER_BOX))

            # We will use the scale of the second input (rows 2,3) to calc the acc scale since the second input has larger scale
            exp_inp_scale = self._conv_ops[i].input_scales[0][2]
            acc_before = exp_inp_scale * kernel_scale_matrix[0]

            acc_scale_before_shift = (
                acc_before / 2 ** self._conv_ops[i].weight_placement_shift
            )  # TODO - not sure if we need this
            expected_max_accumulator = np.max(all_kernels_max_native_accumulator / acc_scale_before_shift)
            # get accumulator
            accumultor_size = self._conv_ops[i].output_lossy_element.bits

            # 16 bit kernels skip the shift calculation and use zero shifts
            if kernel_bits == 16:
                pre_acc_shift = 0
                shift_delta = 0
            else:
                pre_acc_shift, shift_delta = calculate_shifts(
                    expected_max_accumulator,
                    accumultor_size,
                    SHIFT_CALCULATE_BUFFER,
                    hw_shifts=hw_shifts,
                )

            if shift_delta > 0:
                # HW can't provide a shift large enough to avoid final accumulator overflow,
                #  we need smaller numeric values by making kernel range wider
                self._conv_ops[i]._logger.info(
                    f"No shifts available for layer {self._conv_ops[i].full_name}, using max shift instead. "
                    f"delta={shift_delta:.04f}"
                )
                acc_scale_before_shift *= 2**shift_delta

            # no more use but will be exported to qnpz debug info (TODO)
            self._conv_ops[i].shift_delta = shift_delta
            self._conv_ops[i].pre_acc_shift = pre_acc_shift
            self._bias_ops[i].pre_acc_shift = pre_acc_shift

            # Update the accumulator scale candidate after final shifts
            self._conv_ops[i].accumulator_scale_candidate = np.squeeze(
                acc_scale_before_shift * 2 ** self._conv_ops[i].pre_acc_shift,
            )
            self._conv_ops[i].kernel_scale_candidate = kernel_scale_candidate_scalar

        # The activation op gets the first element of the first anchor's accumulator scale candidate
        acc_scale_candidate = self._conv_ops[0].accumulator_scale_candidate[0]
        self._act_op.create_hw_params(acc_scale_candidate, optimization_target)

        # Re-sync the internal encodings before creating the bias HW params
        self.enforce_internal_encoding()
        for i in range(self._num_of_anchors):
            self._bias_ops[i].create_hw_params()

    def enforce_internal_encoding(self, training=False, **kwargs):
        """
        Propagate scales and zero points through the internal ops, from the inputs to the activation.

        The accumulator scale is read from the activation op; each anchor's conv and bias ops get
        the POINTS_PER_BOX-wide chunk of it that belongs to that anchor.

        Args:
            training: forwarded to the conv and activation ops' ``enforce_encoding``.
            **kwargs: not used.
        """
        acc_scale = self._act_op.get_accumulator_scale()
        num_inputs = self._num_inputs

        for inp in range(num_inputs):
            self._passthru_ops[inp].forward_encoding()

        for anc in range(self._num_of_anchors):
            for inp in range(num_inputs):
                source_op = self._passthru_ops[inp]
                anchor_slice = self._slice_ops[anc * num_inputs + inp]

                # Slices
                anchor_slice.input_scales = [source_op.output_scale]
                anchor_slice.input_zero_points = [
                    source_op.output_zero_point,
                ]
                anchor_slice.enforce_encoding()

                # Pre concat
                self._pre_concat_ops[anc].input_scales[inp] = anchor_slice.output_scale
                self._pre_concat_ops[anc].input_zero_points[inp] = anchor_slice.output_zero_point

            joined_op = self._pre_concat_ops[anc]
            joined_op.enforce_encoding()

            # Channels of the accumulator scale that belong to this anchor
            anchor_channels = slice(anc * POINTS_PER_BOX, (anc + 1) * POINTS_PER_BOX)

            # Conv
            anchor_conv = self._conv_ops[anc]
            anchor_conv.input_scales = [joined_op.output_scale]
            anchor_conv.input_zero_points = [joined_op.output_zero_point]
            anchor_conv.output_scale = acc_scale[anchor_channels]
            anchor_conv.enforce_encoding(training=training)

            # Bias
            anchor_bias = self._bias_ops[anc]
            anchor_bias.input_scales = [acc_scale[anchor_channels]]
            anchor_bias.output_scale = acc_scale[anchor_channels]
            anchor_bias.input_zero_points = [anchor_conv.output_zero_point]
            anchor_bias.enforce_encoding()

        # Concat op: one input scale per anchor, in anchor order
        self._concat_op.input_scales = [self._bias_ops[0].output_scale]
        for anc in range(1, self._num_of_anchors):
            self._concat_op.input_scales.append(self._bias_ops[anc].output_scale)

        self._concat_op.enforce_encoding()

        # Activation op
        self._act_op.input_zero_points = [self._concat_op.output_zero_point]
        self._act_op.enforce_encoding(training=training)

    def fast_enforce_internal_encoding(self, **kwargs):
        """No-op for this layer; ``kwargs`` are ignored."""
        pass

    def kernels_to_anchors(self, params: dict):
        anchors = dict()
        anchors_heights = []
        anchors_widths = []
        anchors_heights_div_2 = []
        anchors_widths_div_2 = []
        anchors_heights_minus_div_2 = []
        anchors_widths_minus_div_2 = []
        for index in range(self._num_of_anchors):
            kernel = params[f"kernel_anchor_{index}"].reshape(-1)
            anchors_heights.append(kernel[0])
            anchors_widths.append(kernel[5])
            anchors_heights_div_2.append(kernel[10])
            anchors_widths_div_2.append(kernel[15])
            anchors_heights_minus_div_2.append(kernel[8])
            anchors_widths_minus_div_2.append(kernel[13])

        anchors["anchors_heights"] = np.array(anchors_heights, dtype=np.int32)
        anchors["anchors_widths"] = np.array(anchors_widths, dtype=np.int32)
        anchors["anchors_heights_div_2"] = np.array(anchors_heights_div_2, dtype=np.int32)
        anchors["anchors_widths_div_2"] = np.array(anchors_widths_div_2, dtype=np.int32)
        anchors["anchors_heights_minus_div_2"] = np.array(anchors_heights_minus_div_2, dtype=np.int32)
        anchors["anchors_widths_minus_div_2"] = np.array(anchors_widths_minus_div_2, dtype=np.int32)
        return anchors

    def bias_to_centers(self, params: dict) -> dict:
        centers = dict()
        x_centers = []
        y_centers = []
        for index in range(self._num_of_anchors):
            bias_params = params[f"bias_anchor_{index}"]
            x_min_centers = bias_params[0, :, 1]
            x_max_centers = bias_params[0, :, 3]
            y_min_centers = bias_params[:, 0, 0]
            y_max_centers = bias_params[:, 0, 2]

            x_centers.extend([x_min_centers, x_max_centers])
            y_centers.extend([y_min_centers, y_max_centers])

        centers["x_centers"] = np.stack(x_centers, axis=1).astype(np.int32)
        centers["y_centers"] = np.stack(y_centers, axis=1).astype(np.int32)

        return centers

    def enforce_io_encoding(self, training=False, **kwargs):
        """No-op for this layer; ``training`` and ``kwargs`` are ignored."""
        pass

    def _get_kernel_bits(self):
        """Return the bit width of the kernel lossy element of the first per-anchor conv op."""
        return self._conv_ops[0].weight_lossy_elements.kernel.bits

    def _get_bias_mode(self):
        num_decomposition = self._bias_ops[0].weight_lossy_elements.bias_decompose.num_decomposition
        if num_decomposition == 0:
            return BiasMode.double_scale_initialization
        elif num_decomposition == 1:
            return BiasMode.single_scale_decomposition
        elif num_decomposition == 2:
            return BiasMode.double_scale_decomposition
        return None

    def is_differentiable(self) -> bool:
        """
        Always return False.

        The BBox layer is added as a postprocess block, and we don't want to train it.
        """
        return False

    def _get_precision_mode_supported_in_hw(self, arch):
        """
        Return the precision modes supported in HW for ``arch``.

        MERCURY, SAGE and PLUTO get an explicit set defined here; every other target falls back
        to the base-class implementation.
        """
        targets_with_explicit_set = {OptimizationTarget.MERCURY, OptimizationTarget.SAGE, OptimizationTarget.PLUTO}
        if arch not in targets_with_explicit_set:
            return super()._get_precision_mode_supported_in_hw(arch)
        return {
            PrecisionMode.a8_w8,
            PrecisionMode.a16_w16,
            PrecisionMode.a8_w8_a8,
            PrecisionMode.a8_w8_a16,
            PrecisionMode.a16_w16_a16,
        }
