"""
This module provides translators for converting HN (HailoNet) layer definitions into Saitama's
`BaseLayer` objects. Each translator is responsible for:

1. Constructing a new Saitama `BaseLayer` (or subclass) from the provided HN definition.
2. Building and importing weights and encoding information from the given Acceleras parameters.
3. Finalizing the translated layer by enabling quantization and placing it on the correct device.

The module contains specific translators for different layer types (e.g., convolution, normalization,
depth-to-space). Each translator is registered in the HN translation registry via the decorator
`@HN_TRANSLATION_REGISTRY(LayerType.<type>)`, which ties the layer type to its corresponding
translator.

How to Add a New Translator
---------------------------
1. Create a new translator class that inherits from ``HNTranslator``.
2. Implement the ``_build_quant_layer_from_hn(hn_layer, tags, *, dtype=None, device=None)`` method to build the
   quantized Saitama module from the HN dictionary (e.g., layer parameters, shapes, quantization params).
   It returns a ``(module, tags)`` tuple.
3. Implement the ``_build_native_layer_from_hn(hn_layer, tags, *, dtype=None, device=None)`` method to build the
   native (non-quantized) torch module from the same HN dictionary. It also returns a ``(module, tags)`` tuple.
4. Fill in the ``state_dict_translation`` dictionary: for each ``LayerMode``, list the ``KeyHandler`` objects that
   map the built module's state-dict keys to the corresponding Acceleras parameter keys (optionally transforming
   the values).
5. If necessary, implement the ``_build_weights_and_encodings(acceleras_params)`` method to parse
   weights and encoding info. Return them as a tuple of two `TensorDict`s.
6. Add the translator to the registry by decorating the class with:
   ``@HN_TRANSLATION_REGISTRY(LayerType.<YOUR_LAYER_TYPE>)``.
"""

from typing import ClassVar, Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn

from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    BiasMode,
    ConcatAxis,
    EWMultType,
    FeatureMultiplierType,
    LayerType,
    PaddingType,
    PrecisionMode,
    SplittedPrecisionMode,
)
from hailo_model_optimization.saitama.framework.apu_modules.activation_factory import activation_factory
from hailo_model_optimization.saitama.framework.apu_modules.apu_activation import APUActivation
from hailo_model_optimization.saitama.framework.apu_modules.apu_maxpool2d import APUMaxpool2d
from hailo_model_optimization.saitama.framework.apu_modules.apu_reduce_max import APUReduceMax
from hailo_model_optimization.saitama.framework.common.saitama_definitions import DimsInfo, QType
from hailo_model_optimization.saitama.framework.common.utils import parse_explicit_padding
from hailo_model_optimization.saitama.framework.forwarder_modules.forwarder_concat import ForwarderConcat
from hailo_model_optimization.saitama.framework.forwarder_modules.forwarder_slice import ForwarderSlice
from hailo_model_optimization.saitama.framework.forwarder_modules.forwarder_splitter import ForwarderSplitter
from hailo_model_optimization.saitama.framework.fused_modules.fused_base import SubClusterModule
from hailo_model_optimization.saitama.framework.mac_modules.mac_avgpool2d import MACAvgpool2d
from hailo_model_optimization.saitama.framework.mac_modules.mac_conv2d import MACConv2d
from hailo_model_optimization.saitama.framework.mac_modules.mac_dense import MACDense
from hailo_model_optimization.saitama.framework.mac_modules.mac_ew_add import MACEWAdd, MACEWSub
from hailo_model_optimization.saitama.framework.mac_modules.mac_ew_mult import MACEWMult, MACFeatureMultiplier
from hailo_model_optimization.saitama.framework.mac_modules.mac_ew_mult_on_apu import MACEWMultOnAPU
from hailo_model_optimization.saitama.framework.mac_modules.mac_matmul import MACMatmul
from hailo_model_optimization.saitama.framework.mac_modules.mac_norm import MACNorm
from hailo_model_optimization.saitama.framework.mac_modules.mac_reduce import MACReduceSum
from hailo_model_optimization.saitama.framework.other_modules.const_input import SaitamaConstInput
from hailo_model_optimization.saitama.framework.other_modules.io_modules import (
    IOPrecisionConfig,
    SaitamaInput,
    SaitamaOutput,
)
from hailo_model_optimization.saitama.translators.hailo_translator.base_hailo_translator import HNTranslator, LayerMode
from hailo_model_optimization.saitama.translators.hailo_translator.mappings.common_unit_mapping import (
    CommonMappings,
    handle_const_input_equalization,
    handle_const_input_scale,
    reorder_kernel,
)
from hailo_model_optimization.saitama.translators.hailo_translator.mappings.custom_mapping import (
    get_ew_mult_on_apu_mapping,
    get_ew_mult_on_mac_mapping,
)
from hailo_model_optimization.saitama.translators.hailo_translator.native_modules import (
    ConstInput,
    EWAdd,
    EWMult,
    EWSub,
    FeatureMultiplier,
    FusedConvAndAdd,
    GroupedSoftmax,
    MatMul,
    ReduceMax,
    ReduceMean,
    ReduceSum,
    Reshape,
    ResizeNN,
)
from hailo_model_optimization.saitama.translators.translator_registry import Register
from hailo_model_optimization.saitama.translators.translator_utils import KeyHandler, Tags

# Registry tying each HN ``LayerType`` to its translator class. Translators register themselves with the
# ``@HN_TRANSLATION_REGISTRY(LayerType.<type>)`` class decorator used throughout this module.
HN_TRANSLATION_REGISTRY = Register(HNTranslator)


def convert_axes_to_nchw(axes):
    """Translate NHWC axis indices into their NCHW equivalents.

    Supported indices are the non-batch axes: 1 (H) -> 2, 2 (W) -> 3, 3 (C) -> 1.
    Any other index raises ``KeyError``.

    Args:
        axes: Iterable of NHWC axis indices.

    Returns:
        A list of the corresponding NCHW axis indices, in input order.
    """
    nhwc_to_nchw = {1: 2, 2: 3, 3: 1}
    return list(map(nhwc_to_nchw.__getitem__, axes))


@HN_TRANSLATION_REGISTRY(LayerType.ACTIVATION)
class HNActivationTranslator(HNTranslator):
    """Translate a standalone HN ``activation`` layer.

    Quantized mode wraps a ``MACNorm`` and an ``APUActivation`` in a ``SubClusterModule`` flagged as
    activation-only. Native mode uses an ``nn.Identity`` MAC part followed by the torch activation returned
    by ``activation_factory``.
    """

    # Per LayerMode: the KeyHandlers mapping module state-dict keys to Acceleras parameter keys.
    state_dict_translation: ClassVar[Dict[LayerMode, List[KeyHandler]]] = {
        LayerMode.QUANT: [
            *CommonMappings.get_activation_mapping("apu", output_op="passthru_op"),  # adds apu keys
            *CommonMappings.get_kernel_encoding_mapping("mac.kernel", op_name="mock_op"),  # adds kernel encoding
            *CommonMappings.get_bias_mapping("mac.bias"),  # adds bias keys
            # Broadcast the HN kernel to the shape of mock_op's output scale (x + zeros_like(y)).
            KeyHandler(
                "mac.kernel.weight", ("kernel:0", "mock_op/output_scale:0:0"), lambda x, y, **kw: x + np.zeros_like(y)
            ),
            KeyHandler("mac.accumulator_quantizer.scale", "bias_add_op/output_scale:0:0"),
            KeyHandler("mac.accumulator_quantizer.zero_point", "bias_add_op/output_zero_point:0:0"),
        ],
        LayerMode.NATIVE: [
            *CommonMappings.get_native_activation_mapping("apu"),
        ],
    }

    def _build_quant_layer_from_hn(self, hn_layer: dict, tags, *, dtype=None, device=None):
        """Build the quantized ``MACNorm`` + ``APUActivation`` pair.

        Args:
            hn_layer: HN layer dictionary (reads ``output_shapes``, ``params["activation"]`` and
                ``quantization_params``).
            tags: Tag set; receives the activation type and, when applicable,
                ``BiasMode.single_scale_decomposition``.
            dtype, device: Forwarded to the created modules.

        Returns:
            ``(SubClusterModule(norm, act, is_activation_only=True), tags)``.
        """
        channels = hn_layer["output_shapes"][0][-1]
        activation = ActivationType(hn_layer["params"]["activation"])

        mac_cfg, apu_cfg = self._quantization_params_to_precision_config(hn_layer["quantization_params"])
        if mac_cfg.bias_mode == BiasMode.single_scale_decomposition:
            tags.add(BiasMode.single_scale_decomposition)
        tags.add(activation)

        norm = MACNorm(channels, precision_config=mac_cfg, device=device, dtype=dtype)
        act = APUActivation(channels, activation, apu_cfg, dtype=dtype, device=device)
        return SubClusterModule(norm, act, is_activation_only=True), tags

    def _build_native_layer_from_hn(self, hn_elemnt, tags, *, dtype=None, device=None):
        """Build the native identity + torch activation pair.

        Returns:
            ``(SubClusterModule(nn.Identity(), activation, is_activation_only=True), tags)``; ``tags``
            gains the activation type.
        """
        activation = ActivationType(hn_elemnt["params"]["activation"])
        tags.add(activation)
        return SubClusterModule(nn.Identity(), activation_factory(activation), is_activation_only=True), tags


@HN_TRANSLATION_REGISTRY(LayerType.DW)
class HNDWTranslator(HNTranslator):
    """Translate an HN depthwise convolution (``LayerType.DW``).

    Both builders reject ``elementwise_add``, require the HN ``groups`` param to be 1 and the output channel
    count to be a multiple of the input channel count. The resulting convolution is always built with
    ``groups=in_channels``. Quantized mode yields ``MACConv2d`` + ``APUActivation``; native mode yields
    ``torch.nn.Conv2d`` + the torch activation returned by ``activation_factory``.
    """

    state_dict_translation = {
        LayerMode.QUANT: [
            *CommonMappings.get_bias_mapping("mac.bias"),
            *CommonMappings.get_activation_mapping("apu"),  # APU activation keys
            *CommonMappings.get_kernel_encoding_mapping("mac.kernel", op_name="conv_op"),
            # The HN kernel axes are permuted with transpose(2, 3, 0, 1) into the torch weight layout.
            KeyHandler("mac.kernel.weight", "kernel:0", lambda x, **kw: x.transpose(2, 3, 0, 1)),
            KeyHandler("mac.padding_value.zero_point", "conv_op/input_zero_point:0:0"),
            KeyHandler("mac.padding_value.scale", "conv_op/input_scale:0:0"),
            KeyHandler("mac.accumulator_quantizer.scale", "bias_add_op/output_scale:0:0"),
            KeyHandler("mac.accumulator_quantizer.zero_point", "bias_add_op/output_zero_point:0:0"),
        ],
        LayerMode.NATIVE: [
            KeyHandler("mac.bias", "bias:0"),
            KeyHandler("mac.weight", "kernel:0", lambda x, **kw: x.transpose(2, 3, 0, 1)),
            *CommonMappings.get_native_activation_mapping("apu"),
        ],
    }

    def _build_native_layer_from_hn(self, hn_elemnt, tags, *, dtype=None, device=None):
        """Build the native depthwise ``torch.nn.Conv2d`` followed by its activation.

        Args:
            hn_elemnt: HN layer dictionary.
            tags: Tag set; the layer's activation type is added to it.
            dtype, device: Accepted for interface compatibility; not used by this builder.

        Returns:
            ``(SubClusterModule(conv, activation), tags)``.

        Raises:
            NotImplementedError: ``elementwise_add`` is set.
            ValueError: HN ``groups`` is not 1, or output channels are not a multiple of input channels.
        """
        hn_params = hn_elemnt["params"]
        if hn_params["elementwise_add"]:
            raise NotImplementedError("Conv&Add is not supported")

        in_shape = hn_elemnt["input_shapes"][0]
        n_in = in_shape[-1]
        n_out = hn_elemnt["output_shapes"][0][-1]
        kernel = tuple(hn_params["kernel_shape"][:2])
        hw = DimsInfo(*in_shape[1:3])
        strides = tuple(hn_params["strides"][1:3])

        if hn_params["groups"] != 1:
            raise ValueError("Depthwise doesn't support group param")
        if n_out % n_in != 0:
            raise ValueError("Depthwise output channels must be a multiple of input channels")

        dilations = tuple(hn_params["dilations"][1:3])
        act_type = ActivationType(hn_params.get("activation", ActivationType.LINEAR))
        pad_name = PaddingType(hn_params["padding"]).value.lower()
        tags.add(act_type)

        # Constructed with "valid" padding; the real padding is patched in right after.
        conv = torch.nn.Conv2d(
            in_channels=n_in,
            out_channels=n_out,
            kernel_size=kernel,
            stride=strides,
            padding="valid",
            dilation=dilations,
            groups=n_in,
        )
        new_padding, padding_mode, explicit_padding = parse_explicit_padding(
            pad_name, kernel, strides, dilations, hw.as_hw_tuple()
        )
        # Ad-hoc support for explicit (possibly asymmetric) padding: overwrite nn.Conv2d's padding attributes,
        # including the private ``_reversed_padding_repeated_twice``, after construction.
        # NOTE(review): torch hands ``_reversed_padding_repeated_twice`` to F.pad as (left, right, top, bottom)
        # when padding_mode != "zeros"; confirm parse_explicit_padding returns that order.
        conv._reversed_padding_repeated_twice = explicit_padding
        conv.padding = new_padding
        conv.padding_mode = padding_mode

        return SubClusterModule(conv, activation_factory(act_type)), tags

    def _build_quant_layer_from_hn(self, hn_layer: dict, tags, *, dtype=None, device=None):
        """Build the quantized depthwise ``MACConv2d`` + ``APUActivation`` pair.

        Args:
            hn_layer: HN layer dictionary (including ``quantization_params``).
            tags: Tag set; receives the activation type and, when applicable,
                ``BiasMode.single_scale_decomposition``.
            dtype, device: Forwarded to the created modules.

        Returns:
            ``(SubClusterModule(mac, apu), tags)``.

        Raises:
            NotImplementedError: ``elementwise_add`` is set, or padding is not SAME / VALID / SAME_TENSORFLOW.
            ValueError: HN ``groups`` is not 1, or output channels are not a multiple of input channels.
        """
        hn_params = hn_layer["params"]
        if hn_params["elementwise_add"]:
            raise NotImplementedError("Conv&Add is not supported")

        n_in = hn_layer["input_shapes"][0][-1]
        n_out = hn_layer["output_shapes"][0][-1]
        kernel = tuple(hn_params["kernel_shape"][:2])
        hw = DimsInfo(*hn_layer["input_shapes"][0][1:3])
        strides = tuple(hn_params["strides"][1:3])
        n_groups = hn_params["groups"]
        dilations = tuple(hn_params["dilations"][1:3])
        act_type = ActivationType(hn_params.get("activation", ActivationType.LINEAR))

        pad_type = PaddingType(hn_params["padding"])
        if pad_type not in (PaddingType.SAME, PaddingType.VALID, PaddingType.SAME_TENSORFLOW):
            raise NotImplementedError(f"Padding type {pad_type} is not supported")
        pad_name = pad_type.value.lower()

        if n_groups != 1:
            raise ValueError("Depthwise doesn't support group param")
        if n_out % n_in != 0:
            raise ValueError("Depthwise output channels must be a multiple of input channels")

        mac_cfg, apu_cfg = self._quantization_params_to_precision_config(hn_layer["quantization_params"])
        if mac_cfg.bias_mode == BiasMode.single_scale_decomposition:
            tags.add(BiasMode.single_scale_decomposition)
        tags.add(act_type)

        mac = MACConv2d(
            n_in,
            n_out,
            kernel,
            stride=strides,
            padding=pad_name,
            dilation=dilations,
            groups=n_in,  # depthwise: one group per input channel
            spatial_shape=hw.as_hw_tuple(),
            precision_config=mac_cfg,
            device=device,
            dtype=dtype,
        )
        apu = APUActivation(n_out, act_type, apu_cfg, dtype, device)
        return SubClusterModule(mac, apu), tags


@HN_TRANSLATION_REGISTRY(LayerType.CONV)
class HNConvTranslator(HNTranslator):
    """Translate an HN ``conv`` layer.

    Quantized mode builds a ``MACConv2d`` + ``APUActivation`` pair. Native mode builds a ``torch.nn.Conv2d``
    (or ``FusedConvAndAdd`` when ``elementwise_add`` is set) followed by the torch activation returned by
    ``activation_factory``.
    """

    state_dict_translation = {
        LayerMode.QUANT: [
            *CommonMappings.get_bias_mapping("mac.bias"),
            *CommonMappings.get_activation_mapping("apu"),  # adds apu keys
            *CommonMappings.get_kernel_encoding_mapping("mac.kernel", op_name="conv_op"),
            KeyHandler("mac.kernel.weight", "kernel:0", reorder_kernel),
            KeyHandler("mac.padding_value.zero_point", "conv_op/input_zero_point:0:0"),
            KeyHandler("mac.padding_value.scale", "conv_op/input_scale:0:0"),
            KeyHandler("mac.accumulator_quantizer.scale", "bias_add_op/output_scale:0:0"),
            KeyHandler("mac.accumulator_quantizer.zero_point", "bias_add_op/output_zero_point:0:0"),
        ],
        LayerMode.NATIVE: [
            KeyHandler("mac.bias", "bias:0"),
            KeyHandler("mac.weight", "kernel:0", reorder_kernel),
            *CommonMappings.get_native_activation_mapping("apu"),
        ],
    }

    def _build_quant_layer_from_hn(self, hn_layer: dict, tags, *, dtype=None, device=None):
        """Build the quantized ``MACConv2d`` + ``APUActivation`` pair.

        Args:
            hn_layer: HN layer dictionary (including ``quantization_params``).
            tags: Tag set; receives the activation type and, when applicable,
                ``BiasMode.single_scale_decomposition``.
            dtype, device: Forwarded to the created modules.

        Returns:
            ``(SubClusterModule(mac, apu), tags)``.

        Raises:
            NotImplementedError: ``elementwise_add`` is set, or padding is not SAME / VALID / SAME_TENSORFLOW.
        """
        params = hn_layer["params"]
        if params["elementwise_add"]:
            raise NotImplementedError("Conv&Add is not supported")
        in_channels = hn_layer["input_shapes"][0][-1]
        out_channels = hn_layer["output_shapes"][0][-1]
        kernel_size = tuple(params["kernel_shape"][:2])
        spatial_shape = tuple(hn_layer["input_shapes"][0][1:3])
        stride = tuple(params["strides"][1:3])
        groups = params["groups"]
        dilation = tuple(params["dilations"][1:3])
        activation = ActivationType(params.get("activation", ActivationType.LINEAR))
        padding = PaddingType(params["padding"])
        # Only these padding types are passed on (as lowercase strings) to MACConv2d.
        if padding in [PaddingType.SAME, PaddingType.VALID, PaddingType.SAME_TENSORFLOW]:
            padding = padding.value.lower()
        else:
            raise NotImplementedError(f"Padding type {padding} is not supported")

        mac_cfg, apu_cfg = self._quantization_params_to_precision_config(hn_layer["quantization_params"])
        if mac_cfg.bias_mode == BiasMode.single_scale_decomposition:
            tags.add(BiasMode.single_scale_decomposition)
        tags.add(activation)

        mac = MACConv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            spatial_shape=spatial_shape,
            precision_config=mac_cfg,
            device=device,
            dtype=dtype,
        )

        apu = APUActivation(
            out_channels,
            activation,
            apu_cfg,
            dtype,
            device,
        )
        return SubClusterModule(mac, apu), tags

    def _build_native_layer_from_hn(self, hn_elemnt, tags, *, dtype=None, device=None):
        """Build the native convolution (optionally fused with an elementwise add) and its activation.

        Args:
            hn_elemnt: HN layer dictionary. A truthy ``params["elementwise_add"]`` selects ``FusedConvAndAdd``.
            tags: Tag set; the layer's activation type is added to it.
            dtype, device: Accepted for interface compatibility; not used by this builder.

        Returns:
            ``(SubClusterModule(mac, activation), tags)``.
        """
        params = hn_elemnt["params"]
        ew_add = params.get("elementwise_add", False)
        in_channels = hn_elemnt["input_shapes"][0][-1]
        out_channels = hn_elemnt["output_shapes"][0][-1]
        kernel_size = tuple(params["kernel_shape"][:2])
        spatial_shape = tuple(hn_elemnt["input_shapes"][0][1:3])
        stride = tuple(params["strides"][1:3])
        groups = params["groups"]
        dilation = tuple(params["dilations"][1:3])
        activation = ActivationType(params.get("activation", ActivationType.LINEAR))
        padding = PaddingType(params["padding"]).value.lower()
        tags.add(activation)

        # The module is constructed with "valid" padding; the real padding is patched in below.
        params_dict = {
            "in_channels": in_channels,
            "out_channels": out_channels,
            "kernel_size": kernel_size,
            "stride": stride,
            "padding": "valid",
            "dilation": dilation,
            "groups": groups,
        }
        mac = torch.nn.Conv2d(**params_dict) if not ew_add else FusedConvAndAdd(**params_dict)
        new_padding, padding_mode, explicit_padding = parse_explicit_padding(
            padding,
            kernel_size,
            stride,
            dilation,
            spatial_shape,
        )
        # Ad-hoc solution to support explicit (possibly asymmetric) padding in nn.Conv2d: overwrite the padding
        # attributes, including the private ``_reversed_padding_repeated_twice``, after construction.
        # NOTE(review): torch hands ``_reversed_padding_repeated_twice`` to F.pad as (left, right, top, bottom)
        # when padding_mode != "zeros", while the maxpool/avgpool translators in this module unpack the same
        # parse_explicit_padding result as (beg_h, end_h, beg_w, end_w). Confirm which order is returned.
        mac._reversed_padding_repeated_twice = explicit_padding
        mac.padding = new_padding
        mac.padding_mode = padding_mode

        apu = activation_factory(activation)
        return SubClusterModule(mac, apu), tags


@HN_TRANSLATION_REGISTRY(LayerType.NORMALIZATION)
class HNNormalizationTranslator(HNTranslator):
    """Translate an HN ``normalization`` layer (per-feature scale and bias, followed by an activation).

    Quantized mode builds a ``MACNorm`` + ``APUActivation`` pair. Native mode builds
    ``nn.Sequential(depthwise 1x1 nn.Conv2d, activation)``, so its state-dict keys are prefixed with
    ``"0."`` (conv) and ``"1."`` (activation).
    """

    state_dict_translation = {
        LayerMode.QUANT: [
            *CommonMappings.get_activation_mapping("apu", output_op="output_op"),  # adds apu keys
            *CommonMappings.get_kernel_encoding_mapping("mac.kernel", op_name="conv_op"),  # adds kernel encoding
            *CommonMappings.get_bias_mapping("mac.bias"),  # adds bias keys
            # Flatten the HN kernel into a 1-D weight vector (presumably one value per feature).
            KeyHandler("mac.kernel.weight", "kernel:0", lambda x, **kw: x.reshape(-1)),
            KeyHandler("mac.accumulator_quantizer.scale", "bias_add_op/output_scale:0:0"),
            KeyHandler("mac.accumulator_quantizer.zero_point", "bias_add_op/output_zero_point:0:0"),
        ],
        LayerMode.NATIVE: [
            # Same Acceleras bias key ("bias:0") that every other native mapping in this module uses.
            KeyHandler("0.bias", "bias:0"),
            KeyHandler("0.weight", "kernel:0", lambda x, **kw: x.transpose(2, 3, 0, 1)),
            # The activation sits at index 1 of the native nn.Sequential.
            *CommonMappings.get_native_activation_mapping("1"),
        ],
    }

    def _build_quant_layer_from_hn(self, hn_layer: dict, tags: Tags, *, dtype=None, device=None):
        """Build the quantized ``MACNorm`` + ``APUActivation`` pair.

        Returns:
            ``(SubClusterModule(mac, apu), tags)``; ``tags`` gains the activation type and, when applicable,
            ``BiasMode.single_scale_decomposition``.

        Raises:
            NotImplementedError: ``elementwise_add`` is set.
        """
        params = hn_layer["params"]
        if params["elementwise_add"]:
            raise NotImplementedError("Norm&Add is not supported")
        activation = ActivationType(params["activation"])
        num_features = hn_layer["input_shapes"][0][-1]

        mac_cfg, apu_cfg = self._quantization_params_to_precision_config(hn_layer["quantization_params"])
        if mac_cfg.bias_mode == BiasMode.single_scale_decomposition:
            tags.add(BiasMode.single_scale_decomposition)
        tags.add(activation)

        mac = MACNorm(
            num_features,
            precision_config=mac_cfg,
            dtype=dtype,
            device=device,
        )
        apu = APUActivation(
            num_features,
            activation,
            apu_cfg,
            dtype,
            device,
        )
        return SubClusterModule(mac, apu), tags

    def _build_native_layer_from_hn(self, hn_layer: dict, tags: Tags, *, dtype=None, device=None):
        """Build the native depthwise 1x1 convolution (with bias) followed by its activation.

        Returns:
            ``(nn.Sequential(conv, activation), tags)``; ``tags`` gains the activation type, as in the other
            native builders.

        Raises:
            NotImplementedError: ``elementwise_add`` is set.
        """
        params = hn_layer["params"]
        if params["elementwise_add"]:
            raise NotImplementedError("Norm&Add is not supported")
        activation = ActivationType(params["activation"])
        num_features = hn_layer["input_shapes"][0][-1]
        tags.add(activation)

        # groups == num_features: one scale (weight) and one bias per feature.
        layer = nn.Conv2d(
            num_features,
            num_features,
            1,
            groups=num_features,
            bias=True,
            dtype=dtype,
            device=device,
        )
        activation_module = activation_factory(activation)

        return nn.Sequential(layer, activation_module), tags


@HN_TRANSLATION_REGISTRY(LayerType.DEPTH_TO_SPACE)
class HNDepthToSpaceTranslator(HNTranslator):
    """Translate an HN ``depth_to_space`` layer into ``nn.PixelShuffle`` (quantized mode).

    ``nn.PixelShuffle`` takes a single upscale factor, so only layers whose block sizes are all equal are
    supported.
    """

    def _build_quant_layer_from_hn(self, hn_layer: dict, tags: Tags, *, dtype=None, device=None):
        """Build an ``nn.PixelShuffle`` with the layer's (uniform) block size.

        Returns:
            ``(nn.PixelShuffle(block_size), tags)`` with ``tags`` unchanged.

        Raises:
            NotImplementedError: ``block_sizes`` is empty or its entries differ. Previously only the first entry
                was used, which silently produced a wrong layer for unequal block sizes.
        """
        block_sizes = list(hn_layer["params"]["block_sizes"])
        if not block_sizes or any(size != block_sizes[0] for size in block_sizes):
            raise NotImplementedError(f"Only equal, non-empty block sizes are supported, got {block_sizes}")
        # NOTE(review): nn.PixelShuffle uses the CRD channel ordering; confirm HN depth_to_space uses the same
        # ordering (TF-style depth_to_space is DCR).
        layer = nn.PixelShuffle(block_sizes[0])
        return layer, tags


@HN_TRANSLATION_REGISTRY(LayerType.INPUT_LAYER)
class HNInputLayer(HNTranslator):
    """Translate the HN input layer.

    Quantized mode builds a ``SaitamaInput`` whose output quantizer is loaded from ``output_op``; native mode
    is a plain ``nn.Identity``.
    """

    state_dict_translation = {
        LayerMode.QUANT: [
            KeyHandler("output_quantizer.scale", "output_op/output_scale:0:0"),
            KeyHandler("output_quantizer.zero_point", "output_op/output_zero_point:0:0"),
        ],
        LayerMode.NATIVE: [],
    }

    def _build_quant_layer_from_hn(self, hn_layer: dict, tags: Tags, *, dtype=None, device=None):
        """Build a ``SaitamaInput`` from the layer's ``quantization_params`` (which must be present)."""
        q_params = hn_layer["quantization_params"]
        split_mode = SplittedPrecisionMode.from_precision_mode(PrecisionMode(q_params["precision_mode"]))
        # The output precision value taken from the precision mode is capped at 15.
        io_cfg = IOPrecisionConfig(QType(min(split_mode.output, 15), q_params["signed_output"]))

        return (
            SaitamaInput(
                out_channels=hn_layer["output_shapes"][0][-1],
                precision_config=io_cfg,
                dtype=dtype,
                device=device,
            ),
            tags,
        )

    def _build_native_layer_from_hn(self, hn_elemnt, tags, *, dtype=None, device=None):
        """Native input layers are a no-op pass-through."""
        return nn.Identity(), tags


@HN_TRANSLATION_REGISTRY(LayerType.OUTPUT_LAYER)
class HNOutputLayer(HNTranslator):
    """Translate the HN output layer.

    Quantized mode builds a ``SaitamaOutput``; its precision config is derived from ``quantization_params``
    when present and left as ``None`` otherwise. Native mode is a plain ``nn.Identity``.
    """

    state_dict_translation = {
        LayerMode.QUANT: [
            KeyHandler("output_quantizer.scale", "input_op/output_scale:0:0"),
            KeyHandler("output_quantizer.zero_point", "input_op/output_zero_point:0:0"),
        ],
        LayerMode.NATIVE: [],
    }

    def _build_quant_layer_from_hn(self, hn_layer: dict, tags: Tags, *, dtype=None, device=None):
        """Build a ``SaitamaOutput``; missing or empty ``quantization_params`` yields ``precision_config=None``."""
        q_params = hn_layer.get("quantization_params")
        if q_params:
            split_mode = SplittedPrecisionMode.from_precision_mode(PrecisionMode(q_params["precision_mode"]))
            # The output precision value taken from the precision mode is capped at 15.
            io_cfg = IOPrecisionConfig(QType(min(split_mode.output, 15), q_params["signed_output"]))
        else:
            io_cfg = None

        return (
            SaitamaOutput(
                out_channels=hn_layer["output_shapes"][0][-1],
                precision_config=io_cfg,
                dtype=dtype,
                device=device,
            ),
            tags,
        )

    def _build_native_layer_from_hn(self, hn_elemnt, tags, *, dtype=None, device=None):
        """Native output layers are a no-op pass-through."""
        return nn.Identity(), tags


@HN_TRANSLATION_REGISTRY(LayerType.ELEMENTWISE_ADD)
class HNElementwiseAddLayer(HNTranslator):
    """Translate an HN ``ew_add`` layer (two-input elementwise add followed by an activation).

    Both builders read the optional ``input_repeats`` param (default: no repetition for either input) and the
    optional ``activation`` param (default: linear).
    """

    state_dict_translation = {
        LayerMode.QUANT: [
            *CommonMappings.get_bias_mapping("mac.bias"),
            *CommonMappings.get_activation_mapping("apu", output_op="passthru_op"),  # adds apu keys
            # One kernel encoding per add input.
            *CommonMappings.get_kernel_encoding_mapping("mac.kernel.0", op_name="elementwise_add_op", in_idx=0),
            *CommonMappings.get_kernel_encoding_mapping("mac.kernel.1", op_name="elementwise_add_op", in_idx=1),
            KeyHandler("mac.accumulator_quantizer.scale", "bias_add_op/output_scale:0:0", lambda x, **kw: x.squeeze()),
            KeyHandler(
                "mac.accumulator_quantizer.zero_point", "bias_add_op/output_zero_point:0:0", lambda x, **kw: x.squeeze()
            ),
        ],
        LayerMode.NATIVE: [
            *CommonMappings.get_native_activation_mapping("apu"),  # adds apu keys
        ],
    }

    def _build_quant_layer_from_hn(self, hn_layer: dict, tags: Tags, *, dtype=None, device=None):
        """Build the quantized ``MACEWAdd`` + ``APUActivation`` pair.

        Returns:
            ``(SubClusterModule(mac, apu), tags)``; ``tags`` gains the activation type and, when applicable,
            ``BiasMode.single_scale_decomposition``.
        """
        params = hn_layer["params"]
        input_repeats = params.get("input_repeats", [[1, 1, 1], [1, 1, 1]])
        input_repeats = tuple(DimsInfo(*repeats) for repeats in input_repeats)
        channels = hn_layer["output_shapes"][0][-1]
        activation = ActivationType(params.get("activation", ActivationType.LINEAR))

        mac_cfg, apu_cfg = self._quantization_params_to_precision_config(hn_layer["quantization_params"])
        if mac_cfg.bias_mode == BiasMode.single_scale_decomposition:
            tags.add(BiasMode.single_scale_decomposition)
        tags.add(activation)

        mac = MACEWAdd(
            channels,
            input_repeats=input_repeats,
            precision_config=mac_cfg,
            dtype=dtype,
            device=device,
        )
        apu = APUActivation(
            channels,
            activation,
            apu_cfg,
            dtype,
            device,
        )
        return SubClusterModule(mac, apu), tags

    def _build_native_layer_from_hn(self, hn_layer: dict, tags: Tags, *, dtype=None, device=None):
        """Build the native ``EWAdd`` + torch activation pair.

        A missing ``activation`` param now defaults to linear, matching the quantized builder; previously this
        raised ``KeyError``.

        Returns:
            ``(SubClusterModule(EWAdd, activation), tags)``; ``tags`` gains the activation type.
        """
        params = hn_layer["params"]
        activation = ActivationType(params.get("activation", ActivationType.LINEAR))
        input_repeats = params.get("input_repeats", [[1, 1, 1], [1, 1, 1]])
        input_repeats = tuple(DimsInfo(*repeats) for repeats in input_repeats)
        tags.add(activation)
        mac = EWAdd(input_repeats)
        apu = activation_factory(activation)
        return SubClusterModule(mac, apu), tags


@HN_TRANSLATION_REGISTRY(LayerType.MAXPOOL)
class HNMaxpoolLayer(HNTranslator):
    """Translate an HN ``maxpool`` layer.

    Only a linear activation is supported; a missing ``activation`` param is treated as linear. Quantized mode
    builds an ``APUMaxpool2d``; native mode builds an ``nn.MaxPool2d`` and supports only symmetric padding.
    """

    def _build_quant_layer_from_hn(self, hn_layer: dict, tags: Tags, *, dtype=None, device=None):
        """Build an ``APUMaxpool2d``.

        Returns:
            ``(APUMaxpool2d, tags)`` with ``tags`` unchanged.

        Raises:
            ValueError: the activation is not linear.
        """
        params = hn_layer["params"]
        kernel_size = params["kernel_shape"][1:3]
        stride = params["strides"][1:3]
        padding = PaddingType(params["padding"]).value.lower()
        spatial_shape = tuple(hn_layer["input_shapes"][0][1:3])
        # A single lookup with a default, so a missing "activation" really means linear. Previously an
        # unconditional params["activation"] read raised KeyError before this default could apply.
        activation = ActivationType(params.get("activation", ActivationType.LINEAR))
        if activation != ActivationType.LINEAR:
            raise ValueError("Only linear activation is supported for maxpool layers")

        layer = APUMaxpool2d(
            kernel_size,
            stride=stride,
            padding=padding,
            spatial_shape=spatial_shape,
            dtype=dtype,
            device=device,
        )
        return layer, tags

    def _build_native_layer_from_hn(self, hn_elemnt, tags, *, dtype=None, device=None):
        """Build an ``nn.MaxPool2d`` with the symmetric padding derived from the HN padding type.

        Returns:
            ``(nn.MaxPool2d, tags)`` with ``tags`` unchanged.

        Raises:
            ValueError: the activation is not linear.
            NotImplementedError: the resolved padding is asymmetric (``nn.MaxPool2d`` cannot express it).
        """
        params = hn_elemnt["params"]
        kernel_size = params["kernel_shape"][1:3]
        stride = params["strides"][1:3]
        padding = PaddingType(params["padding"]).value.lower()
        spatial_shape = tuple(hn_elemnt["input_shapes"][0][1:3])
        activation = ActivationType(params.get("activation", ActivationType.LINEAR))
        if activation != ActivationType.LINEAR:
            raise ValueError("Only linear activation is supported for maxpool layers")

        # Dilation is fixed to (1, 1) for pooling.
        _, _, explicit_padding = parse_explicit_padding(padding, kernel_size, stride, (1, 1), spatial_shape)
        pad_beg_h, pad_end_h, pad_beg_w, pad_end_w = explicit_padding
        if pad_beg_h != pad_end_h or pad_beg_w != pad_end_w:
            raise NotImplementedError("Padding mode is not supported")

        maxpool = nn.MaxPool2d(kernel_size, stride, (pad_beg_h, pad_beg_w))
        return maxpool, tags


@HN_TRANSLATION_REGISTRY(LayerType.AVGPOOL)
class HNAvgpoolLayer(HNTranslator):
    """Translate an HN ``avgpool`` layer.

    Quantized mode builds a ``MACAvgpool2d`` + ``APUActivation`` pair. Native mode builds
    ``nn.Sequential(nn.ConstantPad2d, nn.AvgPool2d, activation)``; the pad constant is loaded from the
    ``padding_const_value:0`` Acceleras parameter through the ``"0.value"`` state-dict entry.
    """

    state_dict_translation = {
        LayerMode.QUANT: [
            *CommonMappings.get_bias_mapping("mac.bias"),
            *CommonMappings.get_activation_mapping("apu", output_op="passthru_op"),  # adds apu keys
            KeyHandler("mac.kernel.scale", "avgpool_op/kernel_scale:0"),
            KeyHandler("mac.kernel.zero_point", "avgpool_op/kernel_zero_point:0"),
            KeyHandler("mac.kernel.mac_shift", "avgpool_op/mac_shift:0"),
            KeyHandler("mac.padding_value.weight", "padding_const_value:0"),
            KeyHandler("mac.padding_value.zero_point", "avgpool_op/input_zero_point:0:0"),
            KeyHandler("mac.padding_value.scale", "avgpool_op/input_scale:0:0"),
            KeyHandler("mac.accumulator_quantizer.scale", "bias_add_op/output_scale:0:0"),
            KeyHandler("mac.accumulator_quantizer.zero_point", "bias_add_op/output_zero_point:0:0"),
        ],
        LayerMode.NATIVE: [
            # The activation sits at index 2 of the native nn.Sequential; the pad module at index 0.
            *CommonMappings.get_native_activation_mapping("2"),
            KeyHandler("0.value", "padding_const_value:0"),
        ],
    }

    def _build_quant_layer_from_hn(self, hn_layer: dict, tags: Tags, *, dtype=None, device=None):
        """Build the quantized ``MACAvgpool2d`` + ``APUActivation`` pair.

        Returns:
            ``(SubClusterModule(mac, apu), tags)``; ``tags`` gains the activation type and, when applicable,
            ``BiasMode.single_scale_decomposition``.
        """
        params = hn_layer["params"]
        channels = hn_layer["output_shapes"][0][-1]
        kernel_size = tuple(params["kernel_shape"][1:3])
        stride = tuple(params["strides"][1:3])
        padding = PaddingType(params["padding"]).value.lower()
        count_include_pad = params["count_include_pad"]
        spatial_shape = tuple(hn_layer["input_shapes"][0][1:3])
        activation = ActivationType(params.get("activation", ActivationType.LINEAR))

        mac_cfg, apu_cfg = self._quantization_params_to_precision_config(hn_layer["quantization_params"])
        if mac_cfg.bias_mode == BiasMode.single_scale_decomposition:
            tags.add(BiasMode.single_scale_decomposition)
        tags.add(activation)

        mac = MACAvgpool2d(
            channels,
            kernel_size,
            stride,
            padding,
            spatial_shape=spatial_shape,
            count_include_pad=count_include_pad,
            precision_config=mac_cfg,
            dtype=dtype,
            device=device,
        )
        apu = APUActivation(
            channels,
            activation,
            apu_cfg,
            dtype,
            device,
        )
        return SubClusterModule(mac, apu), tags

    def _build_native_layer_from_hn(self, hn_elemnt, tags, *, dtype=None, device=None):
        """Build the native constant-pad + average-pool + activation sequence.

        Returns:
            ``(nn.Sequential(pad, pool, activation), tags)``; ``tags`` gains the activation type.
        """
        params = hn_elemnt["params"]
        kernel_size = tuple(params["kernel_shape"][1:3])
        stride = tuple(params["strides"][1:3])
        padding = PaddingType(params["padding"]).value.lower()
        count_include_pad = params["count_include_pad"]
        activation = ActivationType(params.get("activation", ActivationType.LINEAR))
        tags.add(activation)

        spatial_shape = tuple(hn_elemnt["input_shapes"][0][1:3])
        _, _, explicit_padding = parse_explicit_padding(padding, kernel_size, stride, (1, 1), spatial_shape)
        pad_beg_h, pad_end_h, pad_beg_w, pad_end_w = explicit_padding
        # ConstantPad2d takes (left, right, top, bottom); the value is replaced when the state dict is loaded.
        pad = nn.ConstantPad2d((pad_beg_w, pad_end_w, pad_beg_h, pad_end_h), 0)

        def load_pad_value(state_dict, prefix, *args, **kwargs):
            # The pad constant is a plain attribute rather than a parameter/buffer, so it is consumed here.
            # Using ``prefix`` keeps this working when the Sequential is nested in a larger module, and a
            # state dict without the entry (e.g. one produced by ``state_dict()``) no longer raises KeyError.
            value_key = f"{prefix}value"
            if value_key in state_dict:
                pad.value = float(state_dict.pop(value_key))

        pad._register_load_state_dict_pre_hook(load_pad_value)
        # NOTE(review): padding is applied by ``pad`` and AvgPool2d itself has no padding, so
        # ``count_include_pad`` has no effect here and padded elements are always counted. Confirm this matches
        # the quantized MACAvgpool2d when count_include_pad is False.
        pool = nn.AvgPool2d(kernel_size, stride, count_include_pad=count_include_pad)
        act = activation_factory(activation)
        return nn.Sequential(pad, pool, act), tags


@HN_TRANSLATION_REGISTRY(LayerType.DENSE)
class HNDenseTranslator(HNTranslator):
    """Translate an HN ``dense`` (fully connected) layer.

    Quantized mode builds a ``MACDense`` + ``APUActivation`` pair. Native mode builds a 1x1 ``nn.Conv2d`` +
    torch activation; the native kernel mapping transposes the HN kernel and appends two singleton axes so it
    fits the conv weight.
    """

    state_dict_translation = {
        LayerMode.QUANT: [
            *CommonMappings.get_bias_mapping("mac.bias"),
            *CommonMappings.get_kernel_encoding_mapping("mac.kernel", op_name="conv_op"),
            *CommonMappings.get_activation_mapping("apu"),
            KeyHandler("mac.kernel.weight", "kernel:0", lambda x, **kw: x.transpose(1, 0)),
            KeyHandler("mac.accumulator_quantizer.scale", "bias_add_op/output_scale:0:0"),
            KeyHandler("mac.accumulator_quantizer.zero_point", "bias_add_op/output_zero_point:0:0"),
        ],
        LayerMode.NATIVE: [
            KeyHandler("mac.bias", "bias:0"),
            # Transposed 2-D kernel with trailing (1, 1) axes; presumably (in, out) -> (out, in, 1, 1).
            KeyHandler(
                "mac.weight", "kernel:0", lambda x, **kw: np.expand_dims(np.expand_dims(x.transpose(1, 0), 2), 3)
            ),
            *CommonMappings.get_native_activation_mapping("apu"),
        ],
    }

    def _build_quant_layer_from_hn(self, hn_layer: dict, tags: Tags, *, dtype=None, device=None):
        """Build the quantized ``MACDense`` + ``APUActivation`` pair.

        Returns:
            ``(SubClusterModule(mac, apu), tags)``; ``tags`` gains the activation type and, when applicable,
            ``BiasMode.single_scale_decomposition``.
        """
        params = hn_layer["params"]
        in_channels = hn_layer["input_shapes"][0][-1]
        out_channels = hn_layer["output_shapes"][0][-1]
        activation = ActivationType(params["activation"])

        mac_cfg, apu_cfg = self._quantization_params_to_precision_config(hn_layer["quantization_params"])
        if mac_cfg.bias_mode == BiasMode.single_scale_decomposition:
            tags.add(BiasMode.single_scale_decomposition)
        tags.add(activation)

        mac = MACDense(
            in_channels,
            out_channels,
            bias=True,
            precision_config=mac_cfg,
            device=device,
            dtype=dtype,
        )

        apu = APUActivation(
            out_channels,
            activation,
            apu_cfg,
            dtype,
            device,
        )
        return SubClusterModule(mac, apu), tags

    def _build_native_layer_from_hn(self, hn_elemnt, tags, *, dtype=None, device=None):
        """Build the native 1x1 ``nn.Conv2d`` + torch activation pair.

        Returns:
            ``(SubClusterModule(conv, activation), tags)``; ``tags`` gains the activation type, as in the
            quantized builder and the other native builders.
        """
        params = hn_elemnt["params"]
        in_channels = hn_elemnt["input_shapes"][0][-1]
        out_channels = hn_elemnt["output_shapes"][0][-1]
        activation = ActivationType(params["activation"])
        tags.add(activation)

        dense = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)
        activation_module = activation_factory(activation)
        return SubClusterModule(dense, activation_module), tags


@HN_TRANSLATION_REGISTRY(LayerType.CONCAT)
class HNConcatTranslator(HNTranslator):
    """Translate HN concat layers into a ``ForwarderConcat``; native and quant modes build the same module."""

    state_dict_translation = {
        LayerMode.QUANT: [],
        LayerMode.NATIVE: [],
    }

    def _build_native_layer_from_hn(self, hn_elemnt, tags, *, dtype=None, device=None):
        return self._build_forwarder_layer(hn_elemnt, tags, dtype=dtype, device=device)

    def _build_quant_layer_from_hn(self, hn_elemnt, tags, *, dtype=None, device=None):
        return self._build_forwarder_layer(hn_elemnt, tags, dtype=dtype, device=device)

    def _build_forwarder_layer(self, hn_elemnt, tags, *, dtype=None, device=None):
        """Map the HN concat axis (default ``"features"``) to an NCHW dim and build the forwarder.

        Raises:
            ValueError: if the concat axis has no NCHW mapping here.
        """
        layer_params = hn_elemnt["params"]
        concat_axis = ConcatAxis(layer_params.get("concat_axis", "features"))
        nchw_dim_by_axis = {
            ConcatAxis.features: 1,
            ConcatAxis.spatial_h: 2,
            ConcatAxis.spatial_w: 3,
        }
        if concat_axis not in nchw_dim_by_axis:
            raise ValueError(f"Unsupported axis: {concat_axis}")
        nchw_dim = nchw_dim_by_axis[concat_axis]
        return ForwarderConcat(nchw_dim, layer_params.get("group_sizes", (1,))), tags


@HN_TRANSLATION_REGISTRY(LayerType.SOFTMAX)
class HNSoftmaxTranslator(HNTranslator):
    """Translate HN softmax layers (native mode only) into a ``GroupedSoftmax`` over NCHW dim 1."""

    state_dict_translation = {
        LayerMode.NATIVE: [],
    }

    def _build_native_layer_from_hn(self, hn_elemnt, tags, *, dtype=None, device=None):
        """Build a ``GroupedSoftmax`` on dim 1 with ``params["groups"]`` groups (default 1).

        Raises:
            NotImplementedError: if ``logits_axis`` is anything other than 3 (the default).
        """
        params = hn_elemnt["params"]
        # HN logits_axis 3 is the last (features) axis of the HN shape, which is dim 1 in NCHW.
        if params.get("logits_axis", 3) != 3:
            raise NotImplementedError("Softmax axis is not supported")
        groups = params.get("groups", 1)
        return GroupedSoftmax(1, groups), tags


@HN_TRANSLATION_REGISTRY(LayerType.ELEMENTWISE_MULT)
class HNEWMultTranslator(HNTranslator):
    """Translate HN element-wise multiplication layers.

    Quant mode builds a MAC stage selected by ``ew_mult_type`` (``MACEWMultOnAPU`` or ``MACEWMult``)
    followed by an ``APUActivation``. Native mode builds a bare ``EWMult`` and only accepts a
    linear activation.
    """

    state_dict_translation = {
        LayerMode.QUANT: [
            *CommonMappings.get_activation_mapping("apu", output_op="passthru_op"),  # adds apu keys
            *get_ew_mult_on_apu_mapping(tags=EWMultType.on_apu),
            *get_ew_mult_on_mac_mapping(tags=EWMultType.on_mac),
        ],
        LayerMode.NATIVE: [
            *CommonMappings.get_native_activation_mapping("apu"),
        ],
    }

    @staticmethod
    def _parse_input_repeats(layer_params):
        """Convert HN ``input_repeats`` (default: two ``[1, 1, 1]`` entries) into a tuple of ``DimsInfo``."""
        raw_repeats = layer_params.get("input_repeats", [[1, 1, 1], [1, 1, 1]])
        return tuple(DimsInfo(*entry) for entry in raw_repeats)

    def _build_quant_layer_from_hn(self, hn_layer: dict, tags: Tags, *, dtype=None, device=None):
        n_in = hn_layer["input_shapes"][0][-1]
        n_out = hn_layer["output_shapes"][0][-1]
        layer_params = hn_layer["params"]
        repeats = self._parse_input_repeats(layer_params)
        sum_groups = layer_params.get("reduce_sum_groups", 1)
        mac_precision, apu_precision = self._quantization_params_to_precision_config(
            hn_layer["quantization_params"]
        )
        mult_type = EWMultType(layer_params.get("ew_mult_type", EWMultType.on_apu))
        tags.add(mult_type)

        if mult_type == EWMultType.on_apu:
            mac_stage = MACEWMultOnAPU(n_out, repeats, mac_precision, device, dtype)
        elif mult_type == EWMultType.on_mac:
            mac_stage = MACEWMult(n_in, n_out, sum_groups, repeats, mac_precision, device, dtype)
        else:
            raise ValueError(f"Unsupported EWMultType: {mult_type}")

        act_type = ActivationType(layer_params.get("activation", ActivationType.LINEAR))
        tags.add(act_type)
        apu_stage = APUActivation(n_out, act_type, apu_precision, dtype, device)
        return SubClusterModule(mac_stage, apu_stage), tags

    def _build_native_layer_from_hn(self, hn_layer: dict, tags: Tags, *, dtype=None, device=None):
        layer_params = hn_layer["params"]
        act_type = ActivationType(layer_params.get("activation", ActivationType.LINEAR))
        assert act_type == ActivationType.LINEAR, "Only linear activation is supported for elementwise mult layers"
        return EWMult(self._parse_input_repeats(layer_params)), tags


@HN_TRANSLATION_REGISTRY(LayerType.ELEMENTWISE_SUB)
class HNEWSubTranslator(HNTranslator):
    """Translate HN element-wise subtraction layers.

    Native mode builds ``SubClusterModule(EWSub, activation)``; quant mode builds
    ``SubClusterModule(MACEWSub, APUActivation)`` with a kernel encoding for each of the two inputs.
    """

    state_dict_translation = {
        LayerMode.QUANT: [
            *CommonMappings.get_bias_mapping("mac.bias"),
            *CommonMappings.get_activation_mapping("apu", output_op="passthru_op"),  # adds apu keys
            *CommonMappings.get_kernel_encoding_mapping(
                "mac.kernel.0", op_name="elementwise_sub_op", in_idx=0, is_channelwise=True
            ),
            *CommonMappings.get_kernel_encoding_mapping(
                "mac.kernel.1", op_name="elementwise_sub_op", in_idx=1, is_channelwise=True
            ),
            # Converters accept (and ignore) extra keyword arguments, matching the other
            # ``lambda x, **kw`` converters in this module.
            KeyHandler(
                "mac.accumulator_quantizer.scale", "bias_add_op/output_scale:0:0", lambda x, **kw: x.squeeze()
            ),
            KeyHandler(
                "mac.accumulator_quantizer.zero_point",
                "bias_add_op/output_zero_point:0:0",
                lambda x, **kw: x.squeeze(),
            ),
        ],
        LayerMode.NATIVE: [
            *CommonMappings.get_native_activation_mapping("apu"),
        ],
    }

    def _build_native_layer_from_hn(self, hn_layer: dict, tags: Tags, *, dtype=None, device=None):
        """Build ``SubClusterModule(EWSub(input_repeats), activation)``; the activation is added to ``tags``."""
        params = hn_layer["params"]
        activation = ActivationType(params.get("activation", ActivationType.LINEAR))
        input_repeats = params.get("input_repeats", [[1, 1, 1], [1, 1, 1]])
        input_repeats = tuple(DimsInfo(*repeats) for repeats in input_repeats)
        tags.add(activation)

        mac = EWSub(input_repeats)
        apu = activation_factory(activation)
        return SubClusterModule(mac, apu), tags

    def _build_quant_layer_from_hn(self, hn_layer, tags, *, dtype=None, device=None):
        """Build ``SubClusterModule(MACEWSub, APUActivation)`` sized by the output channel count.

        The activation type, and the single-scale-decomposition bias mode when used, are added to ``tags``.
        """
        params = hn_layer["params"]
        input_repeats = params.get("input_repeats", [[1, 1, 1], [1, 1, 1]])
        input_repeats = tuple(DimsInfo(*repeats) for repeats in input_repeats)
        channels = hn_layer["output_shapes"][0][-1]
        activation = ActivationType(params.get("activation", ActivationType.LINEAR))

        mac_cfg, apu_cfg = self._quantization_params_to_precision_config(hn_layer["quantization_params"])
        if mac_cfg.bias_mode == BiasMode.single_scale_decomposition:
            tags.add(BiasMode.single_scale_decomposition)
        tags.add(activation)

        mac = MACEWSub(
            channels,
            input_repeats=input_repeats,
            precision_config=mac_cfg,
            dtype=dtype,
            device=device,
        )
        apu = APUActivation(
            channels,
            activation,
            apu_cfg,
            dtype,
            device,
        )
        return SubClusterModule(mac, apu), tags


@HN_TRANSLATION_REGISTRY(LayerType.REDUCE_MAX)
class HNReduceMaxTranslator(HNTranslator):
    """Translate HN reduce-max layers into ``SubClusterModule(nn.Identity(), <reduce-max>)``.

    The reduction sits in the APU slot: ``ReduceMax`` in native mode, ``APUReduceMax`` in quant mode.
    Neither mode has state-dict entries to translate.
    """

    state_dict_translation = {
        LayerMode.QUANT: [],
        LayerMode.NATIVE: [],
    }

    @staticmethod
    def _get_axes_and_groups(hn_layer: dict):
        """Return ``(nchw_axes, [groups, 1, 1])`` for ``hn_layer``.

        ``reduce_axes`` defaults to ``[1]`` and ``groups`` defaults to 1, matching the reduce-sum and
        reduce-mean translators, so a missing ``groups`` key no longer yields ``[None, 1, 1]``.
        """
        params = hn_layer["params"]
        axes = convert_axes_to_nchw(params.get("reduce_axes", [1]))
        groups = [params.get("groups", 1), 1, 1]
        return axes, groups

    def _build_native_layer_from_hn(self, hn_layer: dict, tags: Tags, *, dtype=None, device=None):
        axes, groups = self._get_axes_and_groups(hn_layer)
        apu = ReduceMax(axes, groups)
        return SubClusterModule(nn.Identity(), apu), tags

    def _build_quant_layer_from_hn(self, hn_elemnt, tags, *, dtype=None, device=None):
        axes, groups = self._get_axes_and_groups(hn_elemnt)
        apu = APUReduceMax(axes, groups)
        return SubClusterModule(nn.Identity(), apu), tags


@HN_TRANSLATION_REGISTRY(LayerType.REDUCE_SUM)
class HNReduceSumTranslator(HNTranslator):
    """Translate HN reduce-sum layers.

    Native mode builds ``SubClusterModule(ReduceSum, activation)``; quant mode builds
    ``SubClusterModule(MACReduceSum, APUActivation)``. Grouping is ``[groups, height_groups, 1]``,
    each defaulting to 1.
    """

    state_dict_translation = {
        LayerMode.QUANT: [
            *CommonMappings.get_activation_mapping("apu", output_op="passthru_op"),  # adds apu keys
            KeyHandler("mac.kernel.weight", "kernel:0"),
            KeyHandler("mac.kernel.scale", "reduce_sum_op/kernel_scale:0"),
            KeyHandler("mac.kernel.zero_point", "reduce_sum_op/kernel_zero_point:0"),
            KeyHandler("mac.kernel.mac_shift", "reduce_sum_op/mac_shift:0"),
            KeyHandler(
                "mac.accumulator_quantizer.scale",
                ["bias_add_op/output_scale:0:0", "bias:0"],
                # Broadcast the accumulator scale to the shape of ``bias:0``.
                lambda x, y, **kw: x + np.zeros_like(y),
            ),
            KeyHandler("mac.accumulator_quantizer.zero_point", "bias_add_op/output_zero_point:0:0"),
            *CommonMappings.get_bias_mapping("mac.bias"),  # adds bias keys
        ],
        LayerMode.NATIVE: [
            *CommonMappings.get_native_activation_mapping("apu"),
            KeyHandler("mac.kernel", "kernel:0", default_factory=lambda: 1.0),
        ],
    }

    def _build_native_layer_from_hn(self, hn_layer: dict, tags: Tags, *, dtype=None, device=None):
        layer_params = hn_layer["params"]
        nchw_axes = convert_axes_to_nchw(layer_params.get("reduce_axes", [1]))
        group_spec = [layer_params.get("groups", 1), layer_params.get("height_groups", 1), 1]
        reducer = ReduceSum(nchw_axes, group_spec)

        act_type = ActivationType(layer_params.get("activation", ActivationType.LINEAR))
        tags.add(act_type)
        return SubClusterModule(reducer, activation_factory(act_type)), tags

    def _build_quant_layer_from_hn(self, hn_layer, tags, *, dtype=None, device=None):
        n_out = hn_layer["output_shapes"][0][-1]
        layer_params = hn_layer["params"]
        group_spec = [layer_params.get("groups", 1), layer_params.get("height_groups", 1), 1]
        nchw_axes = convert_axes_to_nchw(layer_params.get("reduce_axes", [1]))
        mac_precision, apu_precision = self._quantization_params_to_precision_config(
            hn_layer["quantization_params"]
        )

        mac_stage = MACReduceSum(n_out, group_spec, nchw_axes, precision_config=mac_precision, device=device, dtype=dtype)

        act_type = ActivationType(layer_params.get("activation", ActivationType.LINEAR))
        tags.add(act_type)
        apu_stage = APUActivation(n_out, act_type, apu_precision, dtype, device)
        return SubClusterModule(mac_stage, apu_stage), tags


@HN_TRANSLATION_REGISTRY(LayerType.REDUCE_MEAN)
class HNReduceMeanTranslator(HNTranslator):
    """Translate HN reduce-mean layers.

    Native mode builds ``SubClusterModule(ReduceMean, activation)``. Quant mode reuses
    ``MACReduceSum`` (with the same state-dict mapping as the reduce-sum translator) followed by an
    ``APUActivation``.
    """

    state_dict_translation = {
        LayerMode.QUANT: [
            *CommonMappings.get_activation_mapping("apu", output_op="passthru_op"),  # adds apu keys
            KeyHandler("mac.kernel.weight", "kernel:0"),
            KeyHandler("mac.kernel.scale", "reduce_sum_op/kernel_scale:0"),
            KeyHandler("mac.kernel.zero_point", "reduce_sum_op/kernel_zero_point:0"),
            KeyHandler("mac.kernel.mac_shift", "reduce_sum_op/mac_shift:0"),
            KeyHandler(
                "mac.accumulator_quantizer.scale",
                ["bias_add_op/output_scale:0:0", "bias:0"],
                # Broadcast the accumulator scale to the shape of ``bias:0``. Accepts extra keyword
                # arguments like the identical converter in the reduce-sum translator.
                lambda x, y, **kw: x + np.zeros_like(y),
            ),
            KeyHandler("mac.accumulator_quantizer.zero_point", "bias_add_op/output_zero_point:0:0"),
            *CommonMappings.get_bias_mapping("mac.bias"),  # adds bias keys
        ],
        LayerMode.NATIVE: [
            *CommonMappings.get_native_activation_mapping("apu"),
            KeyHandler("mac.kernel", "kernel:0", default_factory=lambda: 1.0),
        ],
    }

    def _build_native_layer_from_hn(self, hn_layer: dict, tags: Tags, *, dtype=None, device=None):
        """Build ``SubClusterModule(ReduceMean, activation)``; the activation is added to ``tags``."""
        params = hn_layer["params"]
        axes = params.get("reduce_axes", [1])
        axes = convert_axes_to_nchw(axes)
        groups = [params.get("groups", 1), params.get("height_groups", 1), 1]
        mac = ReduceMean(axes, groups)

        activation = ActivationType(params.get("activation", ActivationType.LINEAR))
        tags.add(activation)

        apu = activation_factory(activation)
        return SubClusterModule(mac, apu), tags

    def _build_quant_layer_from_hn(self, hn_layer, tags, *, dtype=None, device=None):
        """Build ``SubClusterModule(MACReduceSum, APUActivation)`` sized by the output channel count."""
        output_shape = hn_layer["output_shapes"][0]
        params = hn_layer["params"]
        groups = [params.get("groups", 1), params.get("height_groups", 1), 1]

        axes = params.get("reduce_axes", [1])
        axes = convert_axes_to_nchw(axes)
        mac_cfg, apu_cfg = self._quantization_params_to_precision_config(hn_layer["quantization_params"])
        mac = MACReduceSum(
            output_shape[-1],
            groups,
            axes,
            precision_config=mac_cfg,
            device=device,
            dtype=dtype,
        )

        activation = ActivationType(params.get("activation", ActivationType.LINEAR))
        tags.add(activation)
        apu = APUActivation(
            output_shape[-1],
            activation,
            apu_cfg,
            dtype,
            device,
        )
        return SubClusterModule(mac, apu), tags


@HN_TRANSLATION_REGISTRY(LayerType.FEATURE_MULTIPLIER)
class HNFeatureMultiplierTranslator(HNTranslator):
    """Translate HN feature-multiplier layers; only ``FeatureMultiplierType.square`` is supported.

    Native mode builds ``SubClusterModule(FeatureMultiplier, activation)``; quant mode builds
    ``SubClusterModule(MACFeatureMultiplier, APUActivation)``.
    """

    state_dict_translation = {
        LayerMode.QUANT: [
            *CommonMappings.get_activation_mapping("apu", output_op="passthru_op"),  # adds apu keys
            KeyHandler(
                "mac.accumulator_quantizer.scale",
                ["bias_add_op/output_scale:0:0", "bias:0"],
                # Broadcast the accumulator scale to the shape of ``bias:0``.
                lambda x, y, **kw: x + np.zeros_like(y),
            ),
            KeyHandler("mac.accumulator_quantizer.zero_point", "bias_add_op/output_zero_point:0:0"),
            *CommonMappings.get_bias_mapping("mac.bias"),
            KeyHandler(
                "mac.zero_point",
                [
                    "elementwise_mult_op/input_zero_point:0:0",
                    "elementwise_mult_op/input_zero_point:1:0",
                    "elementwise_mult_op/output_scale:0:0",
                ],
                # Broadcast each input's zero point to the shape of the output scale and stack them
                # (row 0: input 0, row 1: input 1). Both rows must scale the ones, not add to them.
                lambda x, y, z, **kw: np.stack([np.ones_like(z) * x, np.ones_like(z) * y], axis=0),
            ),
            KeyHandler("mac.mac_shift", "elementwise_mult_op/mult_shift:0"),
        ],
        LayerMode.NATIVE: [
            *CommonMappings.get_native_activation_mapping("apu"),
        ],
    }

    def _build_native_layer_from_hn(self, hn_layer: dict, tags: Tags, *, dtype=None, device=None):
        """Build ``SubClusterModule(FeatureMultiplier, activation)``.

        Raises:
            ValueError: if the feature multiplier type is not ``square``.
        """
        params = hn_layer["params"]
        activation = ActivationType(params.get("activation", ActivationType.LINEAR))
        feature_multiplier_type = FeatureMultiplierType(params.get("feature_multiplier_type"))
        if feature_multiplier_type != FeatureMultiplierType.square:
            raise ValueError(f"Unsupported feature multiplier type: {feature_multiplier_type}")
        reduce_sum_groups = params.get("reduce_sum_groups", 1)
        tags.add(activation)
        mac = FeatureMultiplier(feature_multiplier_type, reduce_sum_groups)
        apu = activation_factory(activation)
        return SubClusterModule(mac, apu), tags

    def _build_quant_layer_from_hn(self, hn_layer, tags, *, dtype=None, device=None):
        """Build ``SubClusterModule(MACFeatureMultiplier, APUActivation)``.

        ``reduce_sum_groups`` defaults to 1, as in the native builder.

        Raises:
            ValueError: if the feature multiplier type is not ``square``.
        """
        params = hn_layer["params"]
        feature_multiplier_type = FeatureMultiplierType(params.get("feature_multiplier_type"))
        if feature_multiplier_type != FeatureMultiplierType.square:
            raise ValueError(f"Unsupported feature multiplier type: {feature_multiplier_type}")
        activation = ActivationType(params.get("activation", ActivationType.LINEAR))
        reduce_sum_groups = params.get("reduce_sum_groups", 1)
        in_channels = hn_layer["input_shapes"][0][-1]
        out_channels = hn_layer["output_shapes"][0][-1]
        tags.add(activation)

        mac_cfg, apu_cfg = self._quantization_params_to_precision_config(hn_layer["quantization_params"])

        mac = MACFeatureMultiplier(
            feature_multiplier_type,
            reduce_sum_groups,
            in_channels,
            out_channels,
            None,
            True,
            precision_config=mac_cfg,
            device=device,
            dtype=dtype,
        )

        apu = APUActivation(
            out_channels,
            activation,
            apu_cfg,
            dtype,
            device,
        )
        return SubClusterModule(mac, apu), tags


@HN_TRANSLATION_REGISTRY(LayerType.SHORTCUT)
class HNShortcutTranslator(HNTranslator):
    """Translate HN shortcut layers (native mode only) into a pass-through ``nn.Identity``."""

    state_dict_translation = {LayerMode.NATIVE: []}

    def _build_native_layer_from_hn(self, hn_layer: dict, tags: Tags, *, dtype=None, device=None):
        passthrough = nn.Identity()
        return passthrough, tags


@HN_TRANSLATION_REGISTRY(LayerType.SLICE)
class HNSliceTranslator(HNTranslator):
    """Translate HN slice layers into a ``ForwarderSlice``; native and quant modes build the same module."""

    state_dict_translation = {
        LayerMode.NATIVE: [],
        LayerMode.QUANT: [],
    }

    def _build_forwarder_layer(self, hn_layer: dict, tags: Tags, *, dtype=None, device=None):
        """Build a ``ForwarderSlice`` covering only the dimensions that are actually cropped.

        Each HN slice is read as ``[start, end, ..., stride]``; only ``groups == 1`` and unit strides
        are accepted (asserted). A dimension is emitted only when its slice differs from
        ``[0, full_size]``, in the order height (NCHW 2), width (NCHW 3), features (NCHW 1).
        """
        layer_params = hn_layer["params"]
        h_spec = layer_params["height_slice"]
        w_spec = layer_params["width_slice"]
        f_spec = layer_params["features_slice"]
        groups = layer_params["groups"]
        assert groups == 1
        for spec in (h_spec, w_spec, f_spec):
            assert spec[-1] == 1, "stride not supported"

        hn_shape = hn_layer["input_shapes"][0]
        # (slice spec, index into the HN input shape, NCHW axis of the module)
        per_dim = ((h_spec, 1, 2), (w_spec, 2, 3), (f_spec, 3, 1))
        axes, starts, ends = [], [], []
        for spec, hn_dim, nchw_axis in per_dim:
            is_full_range = spec[0] == 0 and spec[1] == hn_shape[hn_dim]
            if is_full_range:
                continue
            axes.append(nchw_axis)
            starts.append(spec[0])
            ends.append(spec[1])
        return ForwarderSlice(axes, starts, ends), tags

    def _build_native_layer_from_hn(self, hn_layer: dict, tags: Tags, *, dtype=None, device=None):
        return self._build_forwarder_layer(hn_layer, tags, dtype=dtype, device=device)

    def _build_quant_layer_from_hn(self, hn_elemnt, tags, *, dtype=None, device=None):
        return self._build_forwarder_layer(hn_elemnt, tags, dtype=dtype, device=device)


@HN_TRANSLATION_REGISTRY(LayerType.FORMAT_CONVERSION)
class HNFormatConversionTranslator(HNTranslator):
    """Translate HN format-conversion layers (native mode only); only plain spatial reshapes are supported."""

    state_dict_translation = {
        LayerMode.NATIVE: [],
    }

    def _build_native_layer_from_hn(self, hn_layer: dict, tags: Tags, *, dtype=None, device=None):
        """Build a ``Reshape`` to ``(-1, F, H, W)`` from ``spatial_reshape_sizes`` given as ``(H, W, F)``.

        Asserts a ``spatial_reshape`` conversion with unit input and output windows.
        """
        layer_params = hn_layer["params"]
        assert layer_params["conversion_type"] == "spatial_reshape"
        for window_key in ("input_windows", "output_windows"):
            assert tuple(layer_params[window_key]) == (1, 1, 1)
        height, width, features = tuple(layer_params["spatial_reshape_sizes"])
        target_shape = (-1, features, height, width)
        return Reshape(target_shape), tags


@HN_TRANSLATION_REGISTRY(LayerType.RESIZE)
class HNResizeTranslator(HNTranslator):
    """Translate HN resize layers into ``SubClusterModule(ResizeNN, nn.Identity())`` for both modes.

    Only integer-ratio nearest-neighbor resizes placed on the ``lcu`` are supported (asserted).
    """

    state_dict_translation = {
        LayerMode.QUANT: [],
        LayerMode.NATIVE: [],
    }

    def _build_mac_from_hn(self, params, hw_layer_type_list, *, dtype=None, device=None):
        """Validate the resize configuration and build a ``ResizeNN`` from the h/w/f ratio lists."""
        assert params["method"] == "nearest_neighbor"
        assert params["resize_bilinear_pixels_mode"] == "disabled"
        assert tuple(hw_layer_type_list) == ("lcu",)
        ratios_h = params["resize_h_ratio_list"]
        ratios_w = params["resize_w_ratio_list"]
        ratios_f = params["resize_f_ratio_list"]
        for ratio_list in (ratios_h, ratios_w, ratios_f):
            assert all(int(ratio) == ratio for ratio in ratio_list)

        return ResizeNN(ratios_h, ratios_w, ratios_f)

    def _build_resize_sub_cluster(self, hn_layer, tags):
        """Shared builder: the resize occupies the MAC slot and the APU slot is an identity."""
        resize = self._build_mac_from_hn(hn_layer["params"], hn_layer["compilation_params"]["hw_layer_type_list"])
        return SubClusterModule(resize, nn.Identity()), tags

    def _build_native_layer_from_hn(self, hn_layer: dict, tags: Tags, *, dtype=None, device=None):
        return self._build_resize_sub_cluster(hn_layer, tags)

    def _build_quant_layer_from_hn(self, hn_elemnt, tags, *, dtype=None, device=None):
        return self._build_resize_sub_cluster(hn_elemnt, tags)


@HN_TRANSLATION_REGISTRY(LayerType.FEATURE_SPLITTER)
class HNFeatureSplitterTranslator(HNTranslator):
    """Translate HN feature-splitter layers into a ``ForwarderSplitter`` on NCHW dim 1 (both modes)."""

    state_dict_translation = {
        LayerMode.NATIVE: [],
        LayerMode.QUANT: [],
    }

    def _build_forwarder_layer(self, hn_layer: dict, tags: Tags, *, dtype=None, device=None):
        """Split size of each output is the last dimension of the matching HN output shape."""
        split_groups = hn_layer["params"]["groups"]
        split_sizes = [out_shape[-1] for out_shape in hn_layer["output_shapes"]]
        return ForwarderSplitter(1, split_sizes, split_groups), tags

    def _build_native_layer_from_hn(self, hn_layer: dict, tags: Tags, *, dtype=None, device=None):
        return self._build_forwarder_layer(hn_layer, tags, dtype=dtype, device=device)

    def _build_quant_layer_from_hn(self, hn_layer: dict, tags: Tags, *, dtype=None, device=None):
        return self._build_forwarder_layer(hn_layer, tags, dtype=dtype, device=device)


@HN_TRANSLATION_REGISTRY(LayerType.LAYER_NORM)
class HNLayerNormTranslator(HNTranslator):
    """Translate HN layer-norm layers (native mode only) into an RMS, group or layer norm module."""

    state_dict_translation = {
        LayerMode.NATIVE: [],
    }

    def _build_native_layer_from_hn(self, hn_layer: dict, tags: Tags, *, dtype=None, device=None):
        """Pick the normalization module from the HN params.

        * ``rms_norm`` set: ``nn.RMSNorm`` over the per-sample ``[C, H, W]`` shape.
        * ``groups != 1``: ``nn.GroupNorm`` with that many groups over ``C`` channels.
        * otherwise: ``nn.LayerNorm`` over ``[C, H, W]``.
        """
        layer_params = hn_layer["params"]
        hn_shape = hn_layer["output_shapes"][0]
        # HN shapes are indexed (N, H, W, C); reorder to NCHW and drop the batch dim.
        _, *chw_shape = [hn_shape[0], hn_shape[3], hn_shape[1], hn_shape[2]]
        if layer_params["rms_norm"]:
            return nn.RMSNorm(chw_shape), tags
        if layer_params["groups"] != 1:
            return nn.GroupNorm(layer_params["groups"], chw_shape[0]), tags
        return nn.LayerNorm(chw_shape), tags


@HN_TRANSLATION_REGISTRY(LayerType.MATMUL)
class HNMatmulTranslator(HNTranslator):
    """Translate HN matmul layers.

    Native mode builds a ``MatMul``; quant mode builds ``SubClusterModule(MACMatmul, APUActivation)``
    with zero-point compensation weights derived from input 0's zero point and scale.
    """

    state_dict_translation = {
        LayerMode.NATIVE: [],
        LayerMode.QUANT: [
            *CommonMappings.get_activation_mapping("apu", output_op="passthru_op"),  # adds apu keys
            KeyHandler("mac.mac_shift", "matmul_op/mac_shift:0"),
            KeyHandler("mac.feed_repeat", "matmul_op/zero_point_feed_repeat:0"),
            KeyHandler(
                "mac.zp.weight",
                ("matmul_op/input_zero_point:0:0", "matmul_op/input_scale:0:0"),
                # zero point of input 0 multiplied by its scale
                lambda x, y, **kw: x * y,
            ),
            KeyHandler("mac.zp.scale", "matmul_op/input_scale:0:0"),
            # The source key is only a placeholder; the zero point is always 0.
            KeyHandler("mac.zp.zero_point", "matmul_op/input_scale:0:0", lambda x, **kw: 0),
            KeyHandler("mac.accumulator_quantizer.scale", "matmul_op/output_scale:0:0"),
            KeyHandler("mac.accumulator_quantizer.zero_point", "matmul_op/output_zero_point:0:0"),
        ],
    }

    @staticmethod
    def _get_basic_vals(hn_elemnt) -> Tuple[int, DimsInfo, List[DimsInfo], bool]:
        """Return ``(groups, window, input_tiles, transposed)`` read from the HN params.

        ``input_windows`` defaults to ``(1, 1, 1)`` and ``input_tiles`` to two ``(1, 1, 1)`` tiles.
        """
        layer_params = hn_elemnt["params"]
        win_h, win_w, win_c = layer_params.get("input_windows", (1, 1, 1))
        window = DimsInfo(height=win_h, width=win_w, channels=win_c)
        tiles = [DimsInfo(*tile) for tile in layer_params.get("input_tiles", [(1, 1, 1), (1, 1, 1)])]
        return layer_params["groups"], window, tiles, layer_params["transpose_matmul_input"]

    def _build_native_layer_from_hn(self, hn_elemnt: dict, tags: Tags, *, dtype=None, device=None):
        n_groups, window, tiles, is_transposed = self._get_basic_vals(hn_elemnt)
        return MatMul(n_groups, window, tiles, transposed=is_transposed), tags

    def _build_quant_layer_from_hn(self, hn_elemnt, tags, *, dtype=None, device=None):
        n_groups, window, tiles, is_transposed = self._get_basic_vals(hn_elemnt)
        layer_params = hn_elemnt["params"]
        act_type = ActivationType(layer_params["activation"])

        mac_precision, apu_precision = self._quantization_params_to_precision_config(
            hn_elemnt["quantization_params"]
        )
        if mac_precision.bias_mode == BiasMode.single_scale_decomposition:
            tags.add(BiasMode.single_scale_decomposition)
        tags.add(act_type)

        n_out = hn_elemnt["output_shapes"][0][-1]
        mac_stage = MACMatmul(
            chanels_in_0=hn_elemnt["input_shapes"][0][-1],
            channels_out=n_out,
            transpose_input=is_transposed,
            groups=n_groups,
            zp_comp_rank=layer_params["zp_comp_rank"],
            window=window,
            input_tiles=tiles,
            bias=False,
            precision_config=mac_precision,
            device=device,
            dtype=dtype,
        )
        apu_stage = APUActivation(n_out, act_type, apu_precision, dtype, device)
        return SubClusterModule(mac_stage, apu_stage), tags


@HN_TRANSLATION_REGISTRY(LayerType.CONST_INPUT)
class HNConstInputTranslator(HNTranslator):
    """Translate HN constant-input layers into ``ConstInput`` (native) or ``SaitamaConstInput`` (quant)."""

    state_dict_translation = {
        LayerMode.NATIVE: [
            # move the last axis of const_data to the front
            KeyHandler("value", "const_data:0", lambda x, **kw: x.transpose(2, 0, 1)),
        ],
        LayerMode.QUANT: [
            KeyHandler("value.weight", "const_data:0", lambda x, **kw: x.transpose(2, 0, 1)),
            KeyHandler("value.scale", ("const_op/output_scale:0", "const_data:0"), handle_const_input_scale),
            KeyHandler("value.zero_point", "const_op/output_zero_point:0"),
            KeyHandler(
                "value.equalization_vector_out",
                ("const_op/output_scale:0", "const_data:0"),
                handle_const_input_equalization,
            ),
        ],
    }

    @staticmethod
    def get_basics(hn_elemnt):
        """Return ``([C, H, W], tile)`` for the constant: the HN input shape reordered and its single tile.

        Asserts that exactly one input tile is given.
        """
        hn_shape = list(hn_elemnt["input_shapes"][0])
        chw_shape = [hn_shape[3], hn_shape[1], hn_shape[2]]
        tiles = hn_elemnt["params"]["input_tiles"]
        assert len(tiles) == 1
        return chw_shape, DimsInfo(*tiles[0])

    def _build_native_layer_from_hn(self, hn_elemnt: dict, tags: Tags, *, dtype=None, device=None):
        chw_shape, tile = self.get_basics(hn_elemnt)
        return ConstInput(chw_shape, tile, dtype=dtype, device=device), tags

    def _build_quant_layer_from_hn(self, hn_elemnt, tags, *, dtype=None, device=None):
        chw_shape, tile = self.get_basics(hn_elemnt)

        q_params = hn_elemnt["quantization_params"]
        split_mode = SplittedPrecisionMode.from_precision_mode(PrecisionMode(q_params["precision_mode"]))
        # The output bit width is capped at 15.
        out_qtype = QType(min(split_mode.output, 15), q_params["signed_output"])
        io_precision = IOPrecisionConfig(out_qtype)

        return SaitamaConstInput(chw_shape, tile, io_precision, dtype=dtype, device=device), tags
