import os
import random
from abc import ABC
from typing import ClassVar

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.utils.acceleras_definitions import IOType, PostprocessTarget
from hailo_sdk_client.sdk_backend.modification_config import (
    ChangeOutputActivationConfig,
    LogitsLayerConfig,
    NormalizationConfig,
    ResizeConfig,
    SetKVCachePairsConfig,
    TransposeConfig,
)
from hailo_sdk_client.sdk_backend.script_parser.commands import (
    CommandsGroups,
    ModelScriptCommand,
    ModelScriptCommandWithScope,
    SupportedCommands,
)
from hailo_sdk_client.sdk_backend.sdk_backend_exceptions import AllocatorScriptParserException
from hailo_sdk_client.tools.logits_layer_addition import LogitsLayersAdder
from hailo_sdk_client.tools.normalization_layers_addition import NormalizationLayersAdder
from hailo_sdk_client.tools.output_activation_modification import OutputActivationModifier
from hailo_sdk_client.tools.resize_layers_addition import ResizeLayersAdder
from hailo_sdk_client.tools.transpose_model import transpose_weights
from hailo_sdk_common.hailo_nn.hailo_nn import hn_to_npz_key
from hailo_sdk_common.hailo_nn.hn_definitions import (
    ActivationType,
    HWLayerType,
    LayerType,
    ResizeBilinearPixelsMode,
    ResizeMethod,
)
from hailo_sdk_common.hailo_nn.hn_layers import ActivationLayer
from hailo_sdk_common.logger.logger import default_logger

# Module-level logger; SetKVCachePairCommand logs through it.
logger = default_logger()


class ModelModificationsCommand(ModelScriptCommand, ABC):
    """Base class for model-script commands that modify the model.

    Concrete commands override :meth:`apply` and record, in :attr:`meta_data`,
    a configuration object describing each modification they performed.
    """

    def __init__(self, function_name, function_args=None, function_return_vals=None, sort_key_func=None):
        super().__init__(function_name, function_args, function_return_vals, sort_key_func)
        # Populated by the subclasses' ``apply``; starts out empty.
        self._meta_data = {}

    @property
    def group(self):
        """Commands group of this command (always ``MODEL_MODIFICATIONS``)."""
        return CommandsGroups.MODEL_MODIFICATIONS

    def apply(self, hailo_nn, params, **kwargs):
        """Apply the modification to the model; must be overridden.

        ``**kwargs`` is accepted so the base signature matches the overrides
        in this file, all of which take ``(hailo_nn, params, **kwargs)``.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    @property
    def meta_data(self):
        """Dictionary of configuration objects recorded by ``apply``."""
        return self._meta_data


class ModelModificationsOnInputLayerCommand(ModelModificationsCommand, ModelScriptCommandWithScope, ABC):
    """Model-modification command that may target a specific input layer.

    The check below reads ``self._input_layer``, which concrete subclasses set.
    """

    def has_layers(self):
        """Return ``True`` when an input layer was given to the command."""
        target = self._input_layer
        return True if target else False


class ResizeCommand(ModelModificationsCommand):
    """Model-script ``resize`` command: adds resize layers to the model.

    The command either targets one specific layer (``layer`` given, exactly
    one new layer name required) or is given without a layer, in which case
    ``resize_shapes`` must be an ``[h, w]`` pair. The actual graph surgery is
    delegated to ``ResizeLayersAdder``.
    """

    def __init__(
        self,
        resize_layers_names,
        layer,
        resize_shapes,
        resize_method=None,
        pixels_mode=None,
        hw_layer_type=None,
        engine=None,
    ):
        super().__init__(SupportedCommands.RESIZE)
        self._resize_layers_names = resize_layers_names
        self._layer = layer
        self._resize_shapes = resize_shapes
        self._resize_method = resize_method
        self._pixels_mode = pixels_mode
        self._hw_layer_type = hw_layer_type
        self._engine = engine

    def __str__(self):
        """Render the command in model-script syntax.

        Arguments are collected in a list and joined, so a missing
        ``resize_shapes`` never produces a dangling or doubled separator.
        """
        args = []
        if self._layer:
            args.append(f"{self._layer}")
        if self._resize_shapes:
            args.append(f"resize_shapes={self._resize_shapes}")
        if self._resize_method:
            args.append(f"resize_method={self._resize_method}")
        if self._pixels_mode:
            args.append(f"pixels_mode={self._pixels_mode}")
        if self._hw_layer_type:
            args.append(f"hw_layer_type={self._hw_layer_type}")
        if self._engine:
            args.append(f"engine={self._engine}")

        return_vals = ", ".join(self._resize_layers_names)
        rendered_args = ", ".join(args)
        return f"{return_vals} = {self.function_name.value}({rendered_args})"

    @classmethod
    def from_tokens(cls, tokens):
        """Build the command from parsed script tokens.

        String arguments are positional (the layer); dict arguments are named
        ones. ``input_shape`` is accepted as an alias of ``resize_shapes``.

        Raises:
            AllocatorScriptParserException: when a positional argument follows
                a named one, or when a named argument is unknown.
        """
        resize_layers_names = tokens.multiple_return_vals.asList()
        layer = None
        named = dict.fromkeys(("resize_shapes", "resize_method", "pixels_mode", "hw_layer_type", "engine"))

        named_arg_seen = False
        for arg in tokens.function_args:
            if isinstance(arg, str):  # positional argument
                if named_arg_seen:
                    raise AllocatorScriptParserException(
                        f"Must supply positional argument '{arg}' before named arguments.",
                    )
                layer = arg
            elif isinstance(arg, dict):  # named argument
                named_arg_seen = True
                arg_name = next(iter(arg))
                if arg_name in ("input_shape", "resize_shapes"):
                    named["resize_shapes"] = [int(dim) for dim in arg[arg_name]]
                elif arg_name in named:
                    named[arg_name] = arg[arg_name]
                else:
                    raise AllocatorScriptParserException(
                        f"No argument named {arg_name}. Please make sure to use the "
                        f"argument name as it appears in the command description.",
                    )
        return cls(resize_layers_names=resize_layers_names, layer=layer, **named)

    @property
    def resize_layers_names(self):
        return self._resize_layers_names

    @property
    def input_shapes(self):
        return self._resize_shapes

    @property
    def pixels_mode(self):
        return self._pixels_mode

    def get_layers(self):
        return [self._layer]

    @staticmethod
    def _validate_choice(arg_name, value, enum_cls):
        """Raise if ``value`` is set and is not one of the values of ``enum_cls``."""
        if not value:
            return
        valid_options = [item.value for item in enum_cls]
        if value not in valid_options:
            raise AllocatorScriptParserException(
                f'Invalid {arg_name} {value}. Must be one either one of: {", ".join(valid_options)}',
            )

    def validate_command(self, layer_scope_from_hn):
        """Validate the command arguments against the layers of the model.

        Raises:
            AllocatorScriptParserException: on an invalid enum-like argument,
                an unknown target layer, a wrong number of new layer names, a
                missing or non ``[h, w]`` shape when no layer is given, or
                when a new layer name already exists in the model.
        """
        self._validate_choice("pixels_mode", self._pixels_mode, ResizeBilinearPixelsMode)
        self._validate_choice("hw_layer_type", self._hw_layer_type, HWLayerType)
        self._validate_choice("resize_method", self._resize_method, ResizeMethod)

        if self._layer:
            if self._layer not in layer_scope_from_hn:
                raise AllocatorScriptParserException(f"Given layer {self._layer} not exist in the HN")

            if len(self._resize_layers_names) != 1:
                raise AllocatorScriptParserException(
                    f"Given {len(self._resize_layers_names)} names for the new layer when one name is required",
                )
        elif not self._resize_shapes or len(self._resize_shapes) != 2:
            # ``_resize_shapes`` is None when the script gave no shape at all;
            # report that as a script error instead of failing on ``len(None)``.
            raise AllocatorScriptParserException("The input shape must be a list of h, w")

        invalid_layer_names = [
            layer_name for layer_name in self._resize_layers_names if layer_name in layer_scope_from_hn
        ]
        if invalid_layer_names:
            raise AllocatorScriptParserException(
                f"Given layer names {invalid_layer_names} exist in the model. Please use different names",
            )

    def add_scope(self, scope_name, force=False):
        if self._layer:
            self._layer = self.add_scope_to_layer(scope_name, self._layer, force)
        self._resize_layers_names = [
            self.add_scope_to_layer(scope_name, resize_layer, force) for resize_layer in self._resize_layers_names
        ]

    def _all_layers(self):
        """Return the new resize layer names followed by the target layer (if any)."""
        layers = [*self._resize_layers_names]
        if self._layer:
            layers.append(self._layer)
        return layers

    def _replace_all_layers(self, new_values):
        """Inverse of ``_all_layers``: take the names back in the same order."""
        count = len(self._resize_layers_names)
        self._resize_layers_names = new_values[:count]
        if self._layer:
            self._layer = new_values[count]

    def remove_scope(self):
        self._resize_layers_names = self._remove_scope(self._resize_layers_names)
        if self._layer is not None:
            self._layer = self._remove_scope(self._layer)

    def apply(self, hailo_nn, params, **kwargs):
        """Add the resize layers and record a ``ResizeConfig`` per affected layer.

        Returns:
            The model returned by ``ResizeLayersAdder.add_resize_layers`` and
            the unchanged ``params``.
        """
        # With an explicit layer the adder receives a {layer: shape} mapping,
        # otherwise the bare shape.
        resize_shapes = {self._layer: self._resize_shapes} if self._layer else self._resize_shapes
        resize_adder = ResizeLayersAdder(
            hailo_nn=hailo_nn,
            resize_layers=self._resize_layers_names,
            resize_shapes=resize_shapes,
            resize_method=self._resize_method,
            pixels_mode=self._pixels_mode,
            hw_layer_type=self._hw_layer_type,
            engine=self._engine,
        )

        # The meta data is recorded before the layers are actually added.
        for idx, input_layer in enumerate(resize_adder.layers_to_shapes):
            config = ResizeConfig(
                cmd_type=self._function_name,
                input_shape=input_layer.output_shape,
                output_shape=resize_adder.layers_to_shapes[input_layer],
                pixels_mode=resize_adder.pixels_mode,
                hw_layer_type=resize_adder.hw_layer_type,
                resize_layer_name=resize_adder.resize_layers[idx],
                interpolation_method=resize_adder.resize_method,
            )
            self.meta_data[input_layer.name] = config
        hailo_nn = resize_adder.add_resize_layers()

        return hailo_nn, params

    def has_layers(self):
        return bool(self._layer)


class NormalizationCommand(ModelModificationsOnInputLayerCommand):
    """Model-script ``normalization`` command.

    Holds the mean/std values, the names of the normalization layers to create
    and an optional input layer; the layers themselves are added by
    ``NormalizationLayersAdder``.
    """

    def __init__(self, mean, std, normalization_layers, input_layer):
        super().__init__(SupportedCommands.NORMALIZATION)
        self._mean = mean
        self._std = std
        self._normalization_layers = normalization_layers
        self._input_layer = input_layer

    def __str__(self):
        """Render the command in model-script syntax."""
        rendered_args = [f"{self._mean}", f"{self._std}"]
        if self._input_layer:
            rendered_args.append(f"{self._input_layer}")
        return_vals = ", ".join(self._normalization_layers)
        joined_args = ", ".join(rendered_args)
        return f"{return_vals} = {self.function_name.value}({joined_args})"

    @property
    def normalization_layers(self):
        return self._normalization_layers

    @classmethod
    def from_tokens(cls, tokens):
        """Build the command from tokens: ``(mean, std[, input_layer])``."""
        fn_args = tokens.function_args
        mean = fn_args[0]
        std = fn_args[1]
        # The input layer is taken only when exactly three arguments were given.
        input_layer = None
        if len(fn_args) == 3:
            input_layer = fn_args[2]
        new_layer_names = tokens.multiple_return_vals.asList()
        return cls(mean, std, new_layer_names, input_layer)

    def get_layers(self):
        return self._normalization_layers

    def validate_command(self, layer_scope_from_hn):
        """Check the input layer exists and the new names are not already taken.

        Raises:
            AllocatorScriptParserException: on any violated condition.
        """
        target = self._input_layer
        if target:
            if target not in layer_scope_from_hn:
                raise AllocatorScriptParserException(f"Given layer {self._input_layer} not exist in the HN")
            if len(self._normalization_layers) != 1:
                raise AllocatorScriptParserException("When normalizing one layer, one normalization name is required")

        already_used = [name for name in self._normalization_layers if name in layer_scope_from_hn]
        if already_used:
            raise AllocatorScriptParserException(
                f"Given layer names {already_used} exist in the model. Please use different names",
            )

    def add_scope(self, scope_name, force=False):
        if self._input_layer:
            self._input_layer = self.add_scope_to_layer(scope_name, self._input_layer, force)

        scoped_names = []
        for name in self._normalization_layers:
            scoped_names.append(self.add_scope_to_layer(scope_name, name, force))
        self._normalization_layers = scoped_names

    def apply(self, hailo_nn, params, **kwargs):
        """Record a ``NormalizationConfig`` per input layer, then add the layers.

        Returns:
            Whatever ``NormalizationLayersAdder.add_normalization_layers`` returns.
        """
        adder = NormalizationLayersAdder(
            hailo_nn,
            params,
            self._mean,
            self._std,
            self._normalization_layers,
            self._input_layer,
        )
        for position, source_layer in enumerate(adder.input_layers):
            self.meta_data[source_layer.name] = NormalizationConfig(
                cmd_type=SupportedCommands.NORMALIZATION,
                mean=adder.mean,
                std=adder.std,
                normalization_layer=adder.normalization_names[position],
            )
        return adder.add_normalization_layers()

    def _all_layers(self):
        """Return the normalization layer names followed by the input layer (if any)."""
        if self._input_layer:
            return [*self._normalization_layers, self._input_layer]
        return [*self._normalization_layers]

    def _replace_all_layers(self, new_values):
        """Inverse of ``_all_layers``: take the names back in the same order."""
        boundary = len(self._normalization_layers)
        self._normalization_layers = new_values[:boundary]
        if self._input_layer:
            self._input_layer = new_values[boundary]

    def remove_scope(self):
        self._normalization_layers = self._remove_scope(self._normalization_layers)
        if self._input_layer is None:
            return
        self._input_layer = self._remove_scope(self._input_layer)


class TransposeCommand(ModelModificationsOnInputLayerCommand):
    """Model-script ``transpose`` command with an optional input layer.

    Without an input layer the transpose meta data is recorded for every
    input layer of the model.
    """

    def __init__(self, input_layer):
        super().__init__(SupportedCommands.TRANSPOSE)
        self._input_layer = input_layer

    def __str__(self):
        """Render the command in model-script syntax."""
        rendered_layer = self._input_layer or ""
        return f"{self.function_name.value}({rendered_layer})"

    @classmethod
    def from_tokens(cls, tokens):
        """Build the command from tokens; the only argument is optional."""
        if tokens.function_args:
            return cls(tokens.function_args[0])
        return cls(None)

    def add_scope(self, scope_name, force=False):
        if not self._input_layer:
            return
        self._input_layer = self.add_scope_to_layer(scope_name, self._input_layer, force)

    def get_layers(self):
        return [self._input_layer]

    def validate_command(self, layer_scope_from_hn):
        """Raise when an input layer was given but is not found in the model."""
        if not self._input_layer:
            return
        if self.has_unfound_layers(layer_scope_from_hn):
            raise AllocatorScriptParserException(f"Given layer {self._input_layer} not exist in the HN")

    def apply(self, hailo_nn, params, **kwargs):
        """Transpose height/width in the model and record a config per layer.

        Returns:
            The result of ``transpose_weights(hailo_nn, params)``.
        """
        explicit_inputs = [self._input_layer] if self._input_layer else None
        hailo_nn.transpose_layers_height_width(explicit_inputs)

        # Fall back to all model inputs when no explicit layer was requested.
        if explicit_inputs:
            recorded_names = explicit_inputs
        else:
            recorded_names = [model_input.name for model_input in hailo_nn.get_input_layers()]

        for name in recorded_names:
            self.meta_data[name] = TransposeConfig(cmd_type=SupportedCommands.TRANSPOSE)
        return transpose_weights(hailo_nn, params)

    def _all_layers(self):
        """Return the input layer in a list, or an empty list when unset."""
        return [self._input_layer] if self._input_layer else []

    def _replace_all_layers(self, new_values):
        if not self._input_layer:
            return
        self._input_layer = new_values[0]

    def remove_scope(self):
        if self._input_layer is None:
            return
        self._input_layer = self._remove_scope(self._input_layer)


class ChangeActivationCommand(ModelModificationsCommand):
    """Common base for commands that change or add an activation on a layer.

    Stores the activation type and an optional target layer name (``None``
    means no specific layer was given).
    """

    def __init__(self, function_name, activation_type, layer):
        super().__init__(function_name)
        self._activation_type = activation_type
        self._layer = layer

    def add_scope(self, scope_name, force=False):
        if self._layer is not None:
            self._layer = self.add_scope_to_layer(scope_name, self._layer, force)

    def _all_layers(self):
        """Return the target layer in a list, or an empty list when unset.

        Named ``_all_layers`` like the matching hook of the other commands in
        this file, where it is paired with ``_replace_all_layers``.
        """
        if self._layer is not None:
            return [self._layer]
        return []

    def _alls(self):
        """Backward-compatible alias of ``_all_layers`` (previous, misspelled name)."""
        return self._all_layers()

    def _replace_all_layers(self, new_values):
        """Inverse of ``_all_layers``: take the target layer name back."""
        if self._layer is not None:
            self._layer = new_values[0]

    def remove_scope(self):
        if self._layer is not None:
            self._layer = self._remove_scope(self._layer)


class ChangeOutputActivationCommand(ChangeActivationCommand):
    """Model-script command that changes the activation of output layers.

    When no output layer is given, the real output layers of the model are
    the ones modified.
    """

    def __init__(self, output_layer, activation_type):
        super().__init__(SupportedCommands.CHANGE_OUTPUT_ACTIVATION, activation_type, output_layer)

    def __str__(self):
        """Render the command in model-script syntax."""
        rendered = f"{self._activation_type.value}"
        if self._layer is not None:
            rendered = f"{self._layer}," + rendered
        return f"{self.function_name.value}({rendered})"

    @classmethod
    def from_tokens(cls, tokens):
        """Build the command from ``(activation)`` or ``(output_layer, activation)``.

        Raises:
            AllocatorScriptParserException: when no argument was supplied.
        """
        fn_args = tokens.function_args
        arg_count = len(fn_args)
        if arg_count == 0:
            raise AllocatorScriptParserException(f"{tokens.function_name} model script command must get arguments")

        if arg_count == 1:
            # only the activation type was specified
            target_layer, activation_name = None, fn_args[0]
        else:
            # output layer and activation function were specified
            target_layer, activation_name = fn_args[0], fn_args[1]

        return cls(target_layer, ActivationType(activation_name))

    @property
    def activation_type(self):
        return self._activation_type

    def get_layers(self):
        return [self._layer]

    def validate_command(self, layers_scope_from_hn):
        """Validate the target layer and the requested activation.

        Raises:
            AllocatorScriptParserException: when the layer is not found, the
                activation is not an ``ActivationType`` member, or the
                activation requires weights.
        """
        layer_given = self._layer is not None
        if layer_given and self.has_unfound_layers(layers_scope_from_hn):
            raise AllocatorScriptParserException(f"Given layers {self._layer} not exist in the HN")

        requested = self._activation_type
        if requested not in ActivationType:
            raise AllocatorScriptParserException(f"The activation {self._activation_type.value} is not supported")

        if ActivationLayer.requires_weights(requested):
            raise AllocatorScriptParserException(
                f"The activation {self._activation_type.value} requires native weights, which is not supported",
            )

    def apply(self, hailo_nn, params, **kwargs):
        """Modify the output activations and record a config per output layer.

        Returns:
            The model returned by the modifier and the unchanged ``params``.
        """
        if self._layer is None:
            targets = hailo_nn.get_real_output_layers()
        else:
            targets = [hailo_nn.get_layer_by_name(self._layer)]

        modifier = OutputActivationModifier(
            hailo_nn,
            params,
            self._activation_type,
            targets,
        )
        hailo_nn, original_activations = modifier.apply_activation_modification()

        for position, modified_layer in enumerate(modifier.output_layers):
            self.meta_data[modified_layer.name] = ChangeOutputActivationConfig(
                cmd_type=SupportedCommands.CHANGE_OUTPUT_ACTIVATION,
                original_activation=original_activations[position],
                new_activation=modifier.activation_type,
                hn_layer_name=modifier.activation_layer_names[position],
            )

        return hailo_nn, params


class LogitsLayerCommand(ChangeActivationCommand):
    """Model-script ``logits_layer`` command: adds softmax/argmax layers.

    The layers are added after ``output_layer`` when it is given, otherwise
    after the real output layers of the model.
    """

    def __init__(self, output_layer, logits_layers, activation_type, axis, engine=None):
        super().__init__(SupportedCommands.LOGITS_LAYER, activation_type, output_layer)
        # ``None`` (or an empty value) means "not given"; anything else,
        # including 0, is an explicit axis and is kept as an int.
        self._axis = int(axis) if axis not in (None, "") else None
        self._logits_layers = logits_layers
        self._engine = PostprocessTarget(engine) if engine else engine

    @classmethod
    def from_tokens(cls, tokens):
        """Build the command from one to four positional arguments.

        Raises:
            AllocatorScriptParserException: on any other number of arguments.
        """
        # full tokens - (layer, logits_layer_type, axis, engine)
        # partial tokens 3 parameters - (layer, logits_layer_type, axis)
        # partial tokens 2 parameters - (logits_layer_type, axis)
        # partial tokens 1 parameters - (logits_layer_type)
        output_layer = activation_type = axis = engine = None
        num_of_args = len(tokens.function_args)
        logits_layers = tokens.multiple_return_vals.asList()
        if num_of_args == 1:
            activation_type = tokens.function_args[0]
        elif num_of_args == 2:
            activation_type = tokens.function_args[0]
            axis = tokens.function_args[1]
        elif num_of_args == 3:
            output_layer = tokens.function_args[0]
            activation_type = tokens.function_args[1]
            axis = tokens.function_args[2]
        elif num_of_args == 4:
            output_layer = tokens.function_args[0]
            activation_type = tokens.function_args[1]
            axis = tokens.function_args[2]
            engine = tokens.function_args[3]
        else:
            # Up to four arguments are accepted (see the branches above).
            raise AllocatorScriptParserException(
                f"{tokens.function_name} model script command must get at least one "
                "argument and up to four arguments.",
            )
        return cls(output_layer, logits_layers, LayerType(activation_type), axis, engine)

    @property
    def activation_type(self):
        return self._activation_type

    def get_layers(self):
        return [self._layer]

    def validate_command(self, layers_scope_from_hn):
        """Check the target layer exists and the logits type is softmax/argmax.

        Raises:
            AllocatorScriptParserException: on any violated condition.
        """
        if self._layer is not None and self._layer not in layers_scope_from_hn:
            raise AllocatorScriptParserException(f"Given layers {self._layer} not exist in the HN")

        if self._activation_type not in [LayerType.softmax, LayerType.argmax]:
            raise AllocatorScriptParserException(f"The logit function {self._activation_type.value} is not supported.")

    def __str__(self):
        """Render the command in model-script syntax.

        NOTE: this calls ``remove_scope()``, so the scope is stripped from the
        stored layer names as a side effect of rendering.
        """
        self.remove_scope()
        args = f"{self._activation_type.value}"
        if self._axis is not None:
            args += f", {self._axis!s}"

        if self._engine is not None:
            args += f", {self._engine.value!s}"

        if self._layer is not None:
            args = f"{self._layer}, " + args
        return f'{", ".join(self._logits_layers)} = {self.function_name.value}({args})'

    def apply(self, hailo_nn, params, **kwargs):
        """Add the logits layers and record a ``LogitsLayerConfig`` for each.

        Returns:
            The model returned by ``LogitsLayersAdder.add_logits_layers`` and
            the unchanged ``params``.
        """
        if self._axis is not None:
            # An explicit axis is honoured, 0 included.
            axis = self._axis
        else:
            # Defaults: last axis for softmax, axis 0 otherwise.
            axis = -1 if self._activation_type == LayerType.softmax else 0
        layers = (
            [hailo_nn.get_layer_by_name(self._layer)] if self._layer is not None else hailo_nn.get_real_output_layers()
        )
        engine = self._engine if self._engine else self.get_default_engine(layers)
        logits_layers_adder = LogitsLayersAdder(
            hailo_nn,
            layers,
            self._logits_layers,
            self._activation_type,
            axis,
            engine,
        )
        hailo_nn = logits_layers_adder.add_logits_layers()

        for logit_name in logits_layers_adder.logits_layers_name:
            logit_layer = hailo_nn.get_layer_by_name(logit_name)
            config = LogitsLayerConfig(
                cmd_type=SupportedCommands.LOGITS_LAYER,
                logit_layer_name=logit_name,
                logits_type=logits_layers_adder.activation_type,
                axis=logits_layers_adder.axis,
                engine=logits_layers_adder.engine,
            )
            # The config is keyed by the first output of the new logits layer.
            self.meta_data[logit_layer.outputs[0]] = config
        return hailo_nn, params

    def get_default_engine(self, layers):
        """Return the engine to use when none was given in the command."""
        # if the predecessor of the logits layer is a postprocess layer the default value is `cpu` otherwise `nn_core`
        return (
            PostprocessTarget.CPU
            if any(layer.op == LayerType.postprocess for layer in layers)
            else PostprocessTarget.NN_CORE
        )

    def add_scope(self, scope_name, force=False):
        # ``force`` must reach the parent too, so the target layer and the
        # logits layers are scoped under the same rule.
        super().add_scope(scope_name, force)
        self._logits_layers = [
            self.add_scope_to_layer(scope_name, logits_layer, force) for logits_layer in self._logits_layers
        ]

    def remove_scope(self):
        super().remove_scope()
        self._logits_layers = [self._remove_scope(logits_layer) for logits_layer in self._logits_layers]


class SetSeedCommand(ModelModificationsCommand):
    """Model-script ``set_seed`` command: seeds the random generators in use."""

    def __init__(self, seed):
        super().__init__(SupportedCommands.SET_SEED)
        self._seed = int(seed)
        self._logger = default_logger()

    @classmethod
    def from_tokens(cls, tokens):
        """Build the command from a single positional or named seed argument.

        Raises:
            AllocatorScriptParserException: unless exactly one argument is given.
        """
        fn_args = tokens.function_args
        if len(fn_args) != 1:
            raise AllocatorScriptParserException(
                f"{tokens.function_name} model script command must get up to one argument.",
            )

        seed = fn_args[0]
        if isinstance(seed, dict):
            # named form: take the value of the single key/value pair
            seed = next(iter(seed.values()))
        return cls(seed)

    def get_layers(self):
        return []

    def validate_command(self, layers_scope_from_hn):
        """Raise unless the stored seed is an int or an integral float."""
        seed = self._seed
        if isinstance(seed, int):
            return
        if isinstance(seed, float) and seed.is_integer():
            return
        raise AllocatorScriptParserException("The value of `seed` must be an integer")

    def __str__(self):
        return f"{self.function_name.value}({self._seed})"

    def apply(self, hailo_nn, params, **kwargs):
        """Seed Python, NumPy and TensorFlow and enable TF op determinism.

        Returns:
            ``hailo_nn`` and ``params`` unchanged.
        """
        seed = self._seed
        os.environ["PYTHONHASHSEED"] = str(seed)
        np.random.seed(seed)
        random.seed(seed)
        tf.random.set_seed(seed)
        tf.experimental.numpy.random.seed(seed)
        tf.config.experimental.enable_op_determinism()
        self._logger.info(f"The seeds of python, python.random, Tensorflow and numpy are set to {self._seed}")
        return hailo_nn, params

    def add_scope(self, scope_name, force=False):
        # nothing to scope: the command takes no layer arguments
        pass

    def remove_scope(self):
        # nothing to un-scope: the command takes no layer arguments
        pass


class SetKVCachePairCommand(ModelModificationsCommand):
    """
    Marks an (output layer, input layer) pair as a key-value cache pair.

    Args:
        output_layer: name of the cache output layer.
        input_layer: name of the cache input layer.

    Note:
    * In the script the pair is ordered as: output_layer_name, input_layer_name
    * When the names are not given, the pairs are auto-detected from the model
    """

    def __init__(self, output_layer="", input_layer=""):
        super().__init__(SupportedCommands.SET_KV_CACHE_PAIR)
        self._output_layer = output_layer
        self._input_layer = input_layer

    def __str__(self):
        """
        Render the command as set_kv_cache_pair(output_layer, input_layer)
        """
        rendered = f"{self._output_layer}" if self._output_layer else ""
        if self._input_layer:
            rendered = f"{rendered}, {self._input_layer}"

        return f"{self.function_name.value}({rendered})"

    @classmethod
    def from_tokens(cls, tokens):
        """
        Build the command from two arguments (output layer first) or none
        """
        in_args = tokens.function_args
        arg_count = len(in_args)
        if arg_count not in (0, 2):
            msg = f"{tokens.function_name} model script command must have two arguments or no arguments at all"
            raise AllocatorScriptParserException(msg)

        if arg_count == 0:
            # auto-detect the pairs
            return cls("", "")

        # the first script argument is the output layer, the second the input layer
        output_layer, input_layer = in_args
        return cls(output_layer, input_layer)

    def validate_command(self, layers_scope_from_hn):
        """
        When both names are given, check that each exists in the model
        (names are compared without their scope prefix)
        """
        if not (self._input_layer and self._output_layer):
            return
        scopeless_names = [name.rsplit("/", 1)[-1] for name in layers_scope_from_hn]
        for layer in (self._output_layer, self._input_layer):
            if layer in scopeless_names:
                continue
            msg = f"Given layer name {layer} does not exist in the HN"
            raise AllocatorScriptParserException(msg)

    def apply(self, hailo_nn, params, **kwargs):
        self._validate_pair(hailo_nn)
        self._set_pair(hailo_nn)

        return hailo_nn, params

    def _validate_pair(self, hailo_nn):
        """
        When both names are given, check that the input layer is an input layer
        and the output layer is an output layer
        """
        if not (self._input_layer and self._output_layer):
            return

        given_input = hailo_nn.get_layer_by_name(self._input_layer)
        if given_input.op != LayerType.input_layer:
            msg = f"The given input layer {given_input.name_without_scope} is not an input layer"
            raise AllocatorScriptParserException(msg)

        given_output = hailo_nn.get_layer_by_name(self._output_layer)
        if given_output.op != LayerType.output_layer:
            msg = f"The given output layer {given_output.name_without_scope} is not an output layer"
            raise AllocatorScriptParserException(msg)

    def _assign_kv_cache_pair(self, hailo_nn, input_layer, output_layer):
        """
        Gives both layers the next free cache id, marks them as CACHE I/O and
        records the pair in the meta data under both layer names
        """
        # the next id is the number of input layers already marked as cache
        next_cache_id = sum(
            1 for candidate in hailo_nn.get_layers_by_type(LayerType.input_layer) if candidate.io_type == IOType.CACHE
        )

        input_layer.cache_id = next_cache_id
        output_layer.cache_id = next_cache_id
        input_layer.io_type = IOType.CACHE
        output_layer.io_type = IOType.CACHE
        logger.info(f"Setting key-value cache pair: {output_layer.name} -> {input_layer.name}")
        logger.info("Key-value cache pair set successfully")

        pair_config = SetKVCachePairsConfig(
            cmd_type=SupportedCommands.SET_KV_CACHE_PAIR,
            pair_names=(output_layer.name, input_layer.name),
            cache_id=input_layer.cache_id,
            cache_type=input_layer.io_type,
        )
        self.meta_data[input_layer.name] = pair_config
        self.meta_data[output_layer.name] = pair_config

    def _set_pair(self, hailo_nn):
        """
        sets the key-value cache pair/s
        if a pair is not set, detects the following structure and assigns cache input and output accordingly:
                      +---------------------+
                      |  cache input layer  |
                      +-----+---------------+           +---------------------+
                            |          +----------------|  real output layer  |
                            |          |                +----------+----------+
                      +-----+----------+----+                      |
                      |  real input layer   |                      |
                      +---------------------+            +---------+----------+
                                                         | cache output layer |
                                                         +--------------------+
        meaning, in kv cache pair, the real output layer has a successor which is the real input layer
        """
        if self._input_layer and self._output_layer:
            explicit_input = hailo_nn.get_layer_by_name(self._input_layer)
            explicit_output = hailo_nn.get_layer_by_name(self._output_layer)
            self._assign_kv_cache_pair(hailo_nn, explicit_input, explicit_output)
            return

        # auto-detect the pairs
        real_input_layers = hailo_nn.get_real_input_layers()
        for output_layer in hailo_nn.get_output_layers():
            # the real output layer has a successor which is the real input layer
            real_output_layer = next(iter(hailo_nn.predecessors(output_layer)))
            succs = list(hailo_nn.successors(real_output_layer))
            if len(succs) != 2:
                continue
            cache_consumers = [succ for succ in succs if succ in real_input_layers]
            if not cache_consumers:
                continue
            # found a kv cache pair
            input_layer = next(
                pred for pred in hailo_nn.predecessors(cache_consumers[0]) if pred.op == LayerType.input_layer
            )
            self._assign_kv_cache_pair(hailo_nn, input_layer, output_layer)


class SetKVCacheGlobalParamsCommand(ModelModificationsCommand):
    """
    Configures global parameters for KV cache usage.

    Args:
        prefill_size (int): The size of the prefill.
        cache_size (int): The size of the cache.
    """

    # Defaults used for arguments the script does not supply. Shared by all
    # instances, so it must never be modified in place.
    DEFAULT_ARGS: ClassVar[dict] = {
        "prefill_size": 64,
        "cache_size": 1024,
    }

    def __init__(self, prefill_size, cache_size):
        super().__init__(SupportedCommands.SET_KV_CACHE_GLOBAL_PARAMS)
        self._prefill_size = int(prefill_size)
        self._cache_size = int(cache_size)

    def __str__(self):
        return f"{self.function_name.value}({self._prefill_size}, {self._cache_size})"

    @classmethod
    def from_tokens(cls, tokens):
        """
        Builds the command from named arguments (each a one-item dict) or from
        positional values, which are matched to ``DEFAULT_ARGS`` keys in order.
        Arguments that are not supplied take their ``DEFAULT_ARGS`` value.

        Raises:
            AllocatorScriptParserException: when a named argument is unknown.
        """
        # Work on a copy: updating ``cls.DEFAULT_ARGS`` itself would leak the
        # values of this command into the defaults of every later command.
        args_dict = dict(cls.DEFAULT_ARGS)
        args_list = tokens.function_args.asList()

        arg_values_dict = (
            {arg_name: arg_value for arg_dict in args_list for arg_name, arg_value in arg_dict.items()}
            if all(isinstance(arg, dict) for arg in args_list)  # list of dicts each arg is a dict
            else dict(zip(cls.DEFAULT_ARGS.keys(), args_list))  # list of values each arg is a value
        )

        for arg_name in arg_values_dict:
            if arg_name not in args_dict:
                raise AllocatorScriptParserException(
                    f"Invalid argument {arg_name}. Must be one of: {', '.join(args_dict.keys())}",
                )
        args_dict.update(arg_values_dict)
        return cls(**args_dict)

    def validate_command(self, layers_scope_from_hn):
        # the required validation is done in `from_tokens`
        pass

    def apply(self, hailo_nn, params, **kwargs):
        """
        Stores the prefill size and the cache size in ``hailo_nn.net_params``.

        The propagation of a new prefill size to the layers' shapes is
        currently disabled by the hard-coded ``update_hn = False`` flag, so
        the method returns right after storing the two values.
        """
        current_prefill_size = hailo_nn.get_input_shapes()[0][-2]
        hailo_nn.net_params.prefill_size = self._prefill_size
        hailo_nn.net_params.cache_size = self._cache_size
        update_hn = False
        if not update_hn:
            return hailo_nn, params

        # updates input shapes
        for layer in hailo_nn:
            if layer.op == LayerType.input_layer:
                # updates the prefill size in inputs
                if layer.input_shapes[0][2] == layer.input_shapes[0][-1]:  # attention mask is square
                    layer.output_shapes[0][-1] = layer.input_shapes[0][-1] = self._prefill_size
                layer.output_shapes[0][2] = layer.input_shapes[0][2] = self._prefill_size

            elif layer.op == LayerType.const_input:
                # update the positional embedding const values
                layer.output_shapes[0][2] = layer.input_shapes[0][2] = self._prefill_size
                # updates the const data, keeping its last `prefill_size` entries along the third axis
                const_key = hn_to_npz_key(layer.name, "const_data")
                const_data_to_reshape = params[const_key]
                params.update({const_key: const_data_to_reshape[:, :, -self._prefill_size :]})

            elif (
                layer.op in [LayerType.ew_sub, LayerType.ew_mult] and layer.input_repeats[1][-1] == current_prefill_size
            ):
                # updates the input repeats of the elementwise layers
                layer.input_repeats[1][-1] = self._prefill_size

            elif layer.op == LayerType.matmul:
                # updates the kernel shape since it affects the output size
                index_to_update = layer.kernel_shape.index(current_prefill_size * layer.groups)
                layer.kernel_shape[index_to_update] = int(
                    layer.kernel_shape[index_to_update] / (current_prefill_size / self._prefill_size),
                )
        # propagates the changes to the model
        hailo_nn.calculate_shapes()
        return hailo_nn, params
