#!/usr/bin/env python
import copy
import os
import os.path

import pyparsing as pp

from hailo_model_optimization.acceleras.model_optimization_config.mo_config import (
    ModelOptimizationConfig,
    update_nested,
)
from hailo_sdk_client.allocator.pb_wrapper import PbWrapper
from hailo_sdk_client.sdk_backend.script_parser.commands import (
    DICT_RETURN_COMMANDS,
    GLOB_COMMANDS,
    MULTIPLE_RETURN_COMMANDS,
    SINGLE_RETURN_COMMANDS,
    VOID_COMMANDS,
    AddResourceCommand,
    AllocatorParamCommand,
    BucketCommand,
    BufferCalcCommand,
    BuffersCommand,
    CascadeCommand,
    CollapseCommand,
    CommandsGroups,
    CompilationParamCommand,
    ConcatCommand,
    ContextCommand,
    ContextCompilationParamCommand,
    ContextPerformanceParamCommand,
    ContextPlaceCommand,
    ContextResourcesParamCommand,
    ContextSwitchParamCommand,
    ConvertToDenseCommand,
    DefuseBlockCommand,
    DefuseCommand,
    ForceMappingCommand,
    ForceRouteCommand,
    FormatConversionCommand,
    FromTFCommand,
    HefParamCommand,
    InternalAllocatorParamCommand,
    InternalContextSwitchParamCommand,
    LoggerParamCommand,
    MergeCommand,
    MirrorCommand,
    MuxDemuxCommand,
    NetworkGroupCommand,
    OptimizeBuffersCommand,
    OutputLayerCommand,
    OutputMuxCommand,
    PerformanceParamCommand,
    PlaceCommand,
    PlatformParamCommand,
    PrintBuffersCommand,
    RemoveNodeCommand,
    ResourcesParamCommand,
    ShapeSplitterCommand,
    ShareConfigCommand,
    ShortcutCommand,
    StrategyCommand,
    SupportedCommands,
    TransposeConcatCommand,
    decode_prefix,
    get_scopes_set_from_layers,
)
from hailo_sdk_client.sdk_backend.script_parser.input_conversion_commands import InputConversionCommand
from hailo_sdk_client.sdk_backend.script_parser.mo_config_commands import ModelOptimizationConfigCommand
from hailo_sdk_client.sdk_backend.script_parser.model_modifications_commands import (
    ChangeOutputActivationCommand,
    LogitsLayerCommand,
    ModelModificationsOnInputLayerCommand,
    NormalizationCommand,
    ResizeCommand,
    SetKVCacheGlobalParamsCommand,
    SetKVCachePairCommand,
    SetSeedCommand,
    TransposeCommand,
)
from hailo_sdk_client.sdk_backend.script_parser.model_optimization_commands import (
    CompressionParamsCommands,
    ModelOptimizationFlavorCommand,
    QuantizationParamCommand,
)
from hailo_sdk_client.sdk_backend.script_parser.nms_postprocess_command import NMSPostprocessCommand
from hailo_sdk_client.sdk_backend.script_parser.post_quantization_commands import PostQuantizationOptimizationCommand
from hailo_sdk_client.sdk_backend.script_parser.pre_quantization_commands import PreQuantizationOptimizationCommand
from hailo_sdk_client.sdk_backend.sdk_backend_exceptions import BackendScriptParserException
from hailo_sdk_common.logger.logger import default_logger
from hailo_sdk_common.model_script_parser.model_script_modes import ModelScriptModes
from hailo_sdk_common.paths_manager.paths import SDKPaths


def get_auto_alls_path(name):
    """Return the location of ``<name>.auto.alls`` resolved through ``SDKPaths().join_build_sdk``."""
    auto_alls_file_name = name + ".auto.alls"
    return SDKPaths().join_build_sdk(auto_alls_file_name)


def get_funcs_list(commands_list):
    """Build a pyparsing alternation that matches any of the given command names.

    The matched name is exposed under the ``function_name`` results name.
    """
    alternatives = " ".join(commands_list)
    matcher = pp.oneOf(alternatives)
    return matcher("function_name")


def sort_by_key(cmd):
    """Sort key that delegates to the command's own ``sort_key()`` method."""
    key = cmd.sort_key()
    return key


def sort_by_group(cmd):
    """Sort key that orders commands by the integer value of their ``group``."""
    group = cmd.group
    return int(group)


class ModelScriptParser:
    """Parse model-script text (and allocator proto scripts) into command objects and render them back.

    Commands can be rendered as script text (``to_str`` / ``save``) or serialized to a proto (``to_pb``).
    """

    # Version stamped into ProtoAllocatorScript.version by to_pb().
    DEFAULT_VERSION = 1
    # Function names of commands dropped by remove_internal_commands(); judging by the name these are
    # kept out of the auto-generated model script.
    EXCLUDE_FROM_AUTO_ALLS = [
        SupportedCommands.INTERNAL_ALLOCATOR_PARAM.value,
        SupportedCommands.HEF_PARAM.value,
        SupportedCommands.INTERNAL_CONTEXT_SWITCH_PARAM.value,
        SupportedCommands.CONTEXT_SWITCH_PARAM.value,
    ]
    # Command groups that to_pb() does not serialize into the allocator proto script.
    MODEL_OPTIMIZATION_COMMAND_GROUPS = [
        CommandsGroups.MODEL_MODIFICATIONS,
        CommandsGroups.QUANTIZATION,
        CommandsGroups.QUANTIZATION_FLAVOR,
    ]

    def __init__(self, hn, mode=ModelScriptModes.OPTIMIZATION_MODE, alls_ignore_invalid_cmds=False):
        """Create a parser bound to the model ``hn``.

        Args:
            hn: The model the script refers to; its layers are used for glob expansion and validation.
            mode: Parsing mode, stored as ``self._mode``.
            alls_ignore_invalid_cmds: Stored flag; presumably lets later stages skip invalid commands
                (TODO confirm at its use site).
        """
        # Model-derived state, populated by update_model().
        self._hn = None
        self._layers_scope = None
        self._layers_in_hn = None
        self.update_model(hn)

        self._mode = mode
        self._alls_ignore_invalid_cmds = alls_ignore_invalid_cmds
        self._commands = []

        # Bookkeeping about where the script came from.
        self._original_script = ""
        self._path = None
        self._har = None
        self._nms_config_file = None

        # Output-ordering switches consulted by _build_script_str().
        self._sort_auto_alls = True
        self._sorting_disabled = False

        self._define_base_grammar()

    @property
    def original_script(self):
        """Script text given to parse_script; appended scripts are joined with a newline."""
        script_text = self._original_script
        return script_text

    @property
    def is_multi_scope(self):
        """Delegates to the bound model's ``is_multi_scope`` flag."""
        model = self._hn
        return model.is_multi_scope

    @property
    def is_single_scope(self):
        """Delegates to the bound model's ``is_single_scope`` flag."""
        model = self._hn
        return model.is_single_scope

    def _build_script_str(self, is_remove_scope, is_to_auto_alls=False):
        """Render the stored commands as script text.

        Args:
            is_remove_scope: Strip scope prefixes; honoured only when the model is single-scope. The
                commands are shallow-copied first, so ``remove_scope()`` is never called on stored objects.
            is_to_auto_alls: Render with each command's ``str_to_alls()`` instead of ``str()``.

        Returns:
            The script text produced by ``to_str``.
        """
        commands = list(self._commands)
        if not self._sorting_disabled:
            commands.sort(key=sort_by_key if self._sort_auto_alls else sort_by_group)

        if is_remove_scope and self.is_single_scope:
            unscoped = []
            for original in commands:
                duplicate = copy.copy(original)
                duplicate.remove_scope()
                unscoped.append(duplicate)
            commands = unscoped

        return self.to_str(commands, is_to_auto_alls)

    @staticmethod
    def to_str(ordered_commands, is_to_auto_alls=False):
        """Join commands into script text, one command per line.

        An extra blank line is emitted between two consecutive commands of different groups, so groups are
        visually separated.

        Args:
            ordered_commands: Indexable sequence of commands, already in output order.
            is_to_auto_alls: Use each command's ``str_to_alls()`` rendering instead of ``str()``.

        Returns:
            The script text; every command line is newline-terminated.
        """
        # Collect the pieces and join once; rebuilding the accumulated string on every iteration is
        # quadratic in the size of the script.
        pieces = []
        for i, command in enumerate(ordered_commands):
            pieces.append(command.str_to_alls() if is_to_auto_alls else str(command))
            pieces.append("\n")
            next_group_differs = i + 1 < len(ordered_commands) and ordered_commands[i + 1].group != command.group
            if next_group_differs:
                pieces.append("\n")
        return "".join(pieces)

    def __str__(self):
        """Script text of every stored command, with scopes kept and plain ``str()`` rendering."""
        return self._build_script_str(is_remove_scope=False)

    def export_auto_alls(self):
        """Script text in auto-alls rendering; scopes are stripped for single-scope models."""
        strip_scope = self.is_single_scope
        return self._build_script_str(strip_scope, is_to_auto_alls=True)

    @property
    def layers_scope(self):
        """Layer names used when validating commands (set by update_model)."""
        scope = self._layers_scope
        return scope

    @property
    def commands(self):
        """The live list of parsed command objects (not a copy)."""
        stored = self._commands
        return stored

    @property
    def model(self):
        """The model (HN) this parser is currently bound to."""
        bound_model = self._hn
        return bound_model

    @staticmethod
    def get_stable_toposort_func(stable_toposort_dict):
        """Build a key function that maps a layer name to its entry in ``stable_toposort_dict``.

        The returned function raises ``ValueError`` for non-string input; a missing name propagates the
        mapping's lookup error (``KeyError`` for a dict).
        """

        def toposort_index(layer):
            if not isinstance(layer, str):
                raise ValueError("layer must be string to get his table_toposort_index")
            return stable_toposort_dict[layer]

        return toposort_index

    @staticmethod
    def get_stable_toposort_func_mellow(stable_toposort_dict):
        """Like ``get_stable_toposort_func``, but tolerant of unknown layer names.

        Names missing from ``stable_toposort_dict`` map to ``len(stable_toposort_dict)`` (computed at call
        time). Assuming the indices are 0..len-1, that sorts them after every known layer (TODO confirm).
        Non-string input still raises ``ValueError``.
        """

        def toposort_index(layer):
            if not isinstance(layer, str):
                raise ValueError("layer must be string to get his table_toposort_index")
            return stable_toposort_dict.get(layer, len(stable_toposort_dict))

        return toposort_index

    def load_pb(self, script_pb):
        """Append the commands stored in an allocator proto script to this parser.

        Args:
            script_pb: Proto script whose ``commands`` entries each carry an ``op`` code.

        Ops without a handler are logged as a warning and skipped. Afterwards, when the model is not
        multi-scope, every stored command (including ones that existed before this call) is given the
        model's single scope.
        """
        pb2 = PbWrapper().integrated_hw_graph_base_pb2
        # Ops rebuilt through the matching command class's ``from_pb`` constructor.
        from_pb_factories = {
            pb2.PROTO_CMD_DEFUSE: DefuseCommand,
            pb2.PROTO_CMD_DEFUSE_BLOCK: DefuseBlockCommand,
            pb2.PROTO_CMD_CASCADE: CascadeCommand,
            pb2.PROTO_CMD_CONCAT: ConcatCommand,
            # Several shortcut-like ops are all represented by ShortcutCommand.
            pb2.PROTO_CMD_SHORTCUT: ShortcutCommand,
            pb2.PROTO_CMD_PORTAL: ShortcutCommand,
            pb2.PROTO_CMD_DDR: ShortcutCommand,
            pb2.PROTO_CMD_L4_PORTAL: ShortcutCommand,
            pb2.PROTO_CMD_OUTPUT_MUX: OutputMuxCommand,
            pb2.PROTO_CMD_MUX_DEMUX: MuxDemuxCommand,
            pb2.PROTO_CMD_OUTPUT_LAYER: OutputLayerCommand,
            pb2.PROTO_CMD_COMPILATION_PARAM: CompilationParamCommand,
            pb2.PROTO_CMD_BUFFERS: BuffersCommand,
            pb2.PROTO_CMD_PLACE: PlaceCommand,
            pb2.PROTO_CMD_CONTEXT_PLACE: ContextPlaceCommand,
            pb2.PROTO_CMD_ALLOCATOR_PARAM: AllocatorParamCommand,
            pb2.PROTO_CMD_CONTEXT: ContextCommand,
            pb2.PROTO_CMD_MERGE: MergeCommand,
            pb2.PROTO_CMD_FORMAT_CONVERSION: FormatConversionCommand,
            pb2.PROTO_CMD_CONTEXT_SWITCH_PARAM: ContextSwitchParamCommand,
            pb2.PROTO_CMD_CONTEXT_COMPILATION_PARAM: ContextCompilationParamCommand,
            pb2.PROTO_CMD_CONTEXT_RESOURCES_PARAM: ContextResourcesParamCommand,
            pb2.PROTO_CMD_HEF_PARAM: HefParamCommand,
            pb2.PROTO_CMD_COLLAPSE: CollapseCommand,
            pb2.PROTO_CMD_TRANSPOSE_CONCAT: TransposeConcatCommand,
            pb2.PROTO_CMD_RESOURCES_PARAM: ResourcesParamCommand,
            pb2.PROTO_CMD_NETWORK_GROUP: NetworkGroupCommand,
            pb2.PROTO_CMD_PERFORMANCE_PARAM: PerformanceParamCommand,
            pb2.PROTO_CMD_CONTEXT_PERFORMANCE_PARAM: ContextPerformanceParamCommand,
            pb2.PROTO_CMD_REMOVE_NODE: RemoveNodeCommand,
            pb2.PROTO_CMD_CONVERT_TO_DENSE: ConvertToDenseCommand,
            pb2.PROTO_CMD_SHAPE_SPLITTER: ShapeSplitterCommand,
            pb2.PROTO_CMD_MIRROR: MirrorCommand,
            pb2.PROTO_CMD_BUCKET: BucketCommand,
            pb2.PROTO_CMD_SHARE_CONFIG: ShareConfigCommand,
            pb2.PROTO_CMD_ADD_RESOURCE: AddResourceCommand,
            pb2.PROTO_CMD_PLATFORM_PARAM: PlatformParamCommand,
        }
        # Ops deliberately not restored as commands, presumably because these settings are set by the
        # ALLOCATOR_PARAM command (TODO confirm).
        skipped_ops = {
            pb2.PROTO_CMD_STRATEGY,
            pb2.PROTO_CMD_TIMEOUT,
            pb2.PROTO_CMD_OPTIMIZE_BUFFERS,
        }

        for command in script_pb.commands:
            op = command.op
            factory = from_pb_factories.get(op)
            if factory is not None:
                self._commands.append(factory.from_pb(command))
            elif op == pb2.PROTO_CMD_PRINT_BUFFERS:
                # Constructed directly: no payload is read from the proto for this command.
                self._commands.append(PrintBuffersCommand())
            elif op in skipped_ops:
                continue
            else:
                # Include the op so an unsupported proto command can actually be identified.
                default_logger().warning(f"Unexpected op {op} while loading proto script.")

        if not self.is_multi_scope:
            net_scopes = self._hn.net_params.net_scopes
            assert len(net_scopes) == 1
            for command in self._commands:
                command.add_scope(net_scopes)

    def parse_script_from_file(self, input_script_path, nms_config_file=None, append=False):
        """Read a model script from disk and parse it with parse_script.

        The absolute script path is remembered in ``self._path``; it is handed to NMS postprocess commands.

        Returns:
            The list of stored commands returned by parse_script.
        """
        self._path = os.path.abspath(input_script_path)
        with open(input_script_path) as script_file:
            script_text = script_file.read()
        return self.parse_script(script_text, append=append, nms_config_file=nms_config_file)

    def _define_base_grammar(self):
        """Create the reusable grammar elements shared by every statement form.

        Stores on ``self``:
            - the ``_empty_line`` and ``_comment`` patterns;
            - the ``_left_par``/``_right_par``/``_eq`` punctuation;
            - the ``_identifier``/``_identifier_val`` tokens;
            - the ``_args``/``_glob_syntax_args`` argument lists, both under the results name
              ``function_args``.

        get_script_grammar combines these into full statements.

        NOTE: ``setDefaultWhitespaceChars`` changes pyparsing's class-wide default whitespace. It affects
        elements created afterwards anywhere in the process, not only this parser.
        """
        # handle spaces, tabs, new lines
        pp.ParserElement.setDefaultWhitespaceChars(" \t\n")
        new_line = pp.LineEnd().suppress()
        # Two consecutive line ends (a blank line); get_script_grammar ignores this pattern.
        self._empty_line = new_line + new_line
        # '#' at the start of a line, followed by the rest of that line.
        self._comment = pp.LineStart() + pp.Literal("#").suppress() + pp.Optional(pp.restOfLine)

        # basic characters grammar (suppressed punctuation is matched but not kept in the tokens)
        self._left_par, self._right_par = pp.Literal("(").suppress(), pp.Literal(")").suppress()
        left_brck, right_brck = pp.Literal("[").suppress(), pp.Literal("]").suppress()
        self._eq = pp.Literal("=").suppress()
        # pyparsing's generic signed number (optional decimal point / exponent), converted to float.
        any_number = pp.pyparsing_common.fnumber

        # variables grammar
        # Names: first char alphanumeric, then alphanumerics, '_', '-' or '/'.
        self._identifier = pp.Word(pp.alphanums, pp.alphanums + "_" + "-" + "/")
        # Bare values: may also start with '_' and may additionally contain '.'.
        self._identifier_val = pp.Word(pp.alphanums + "_", pp.alphanums + "_" + "-" + "/" + ".")

        basic_variable_val = any_number | pp.quotedString | self._identifier_val
        # "[a,b,c]" -> plain Python list (see _parse_list_args).
        list_of_args = pp.Group(left_brck + pp.delimitedList(basic_variable_val, ",") + right_brck)
        list_of_args.setParseAction(self._parse_list_args)
        # One comma-separated endpoint of a range, kept as its own list (see _parse_range_endpoint).
        range_endpoint = pp.delimitedList(basic_variable_val, ",")
        range_endpoint.setParseAction(self._parse_range_endpoint)
        # "[a,b:c,d]" -> group holding one list per ':'-separated endpoint.
        ranges = pp.Group(left_brck + pp.delimitedList(range_endpoint, ":") + right_brck)
        # "{pattern}" glob value; the pattern may contain '*', '?', '!', '[' and ']'.
        glob_syntax_variable_val = (
            pp.Literal("{").suppress()
            + (pp.Word(pp.alphanums + "_" + "*", pp.alphanums + "-/_*?![]") | pp.Literal("*"))
            + pp.Literal("}").suppress()
        )
        # A number followed by one or more unit letters from "smhd" (e.g. "30s").
        timeout = any_number + pp.Word("smhd")
        # arguments grammar
        # '|' tries alternatives in order. timeout presumably precedes the plain values so that "30s" is
        # not consumed as the bare number 30.
        variable_val = list_of_args | ranges | timeout | basic_variable_val
        # "key=value" -> {key: value} (see _parse_assigned_arg).
        assigned_arg = self._identifier + self._eq + variable_val
        assigned_arg.setParseAction(self._parse_assigned_arg)
        # An argument is either keyword ("key=value") or positional.
        argument = assigned_arg | variable_val
        # Glob-capable commands additionally accept "key={pattern}" and a bare "{pattern}".
        assigned_glob = self._identifier + self._eq + glob_syntax_variable_val
        assigned_glob.setParseAction(self._parse_assigned_arg)
        glob_syntax_argument = assigned_glob | argument | glob_syntax_variable_val
        # Comma-separated argument lists, exposed under the results name "function_args".
        self._args = pp.delimitedList(argument, ",")("function_args")
        self._glob_syntax_args = pp.delimitedList(glob_syntax_argument, ",")("function_args")

    @staticmethod
    def _parse_assigned_arg(s, loc, tokens):
        """Parse action for ``key=value`` arguments; returns ``{key: value}``.

        A single value token is used as-is. A multi-token value, such as a timeout's number and unit, is
        kept as the sequence of remaining tokens.
        """
        key = tokens[0]
        remaining = tokens[1:]
        value = remaining[0] if len(remaining) == 1 else remaining
        return {key: value}

    @staticmethod
    def _parse_list_args(s, loc, tokens):
        """Parse action for ``[a, b, ...]``: replace the matched tokens with a plain Python list."""
        as_plain_list = tokens.asList()
        return as_plain_list

    @staticmethod
    def _parse_range_endpoint(s, loc, tokens):
        """Parse action for one range endpoint: wrap its values in a list so each endpoint stays one item."""
        endpoint_values = tokens.asList()
        return [endpoint_values]

    def parse_script(self, input_script, append=False, nms_config_file=None):
        """Parse model-script text and store the resulting commands.

        Args:
            input_script: Script text, as ``str`` or ASCII-encoded ``bytes``.
            append: Keep previously parsed commands and extend ``original_script``, instead of replacing both.
            nms_config_file: Stored on the parser and handed to NMS postprocess commands created while parsing.

        Returns:
            The full list of stored commands.

        Raises:
            BackendScriptParserException: If the script does not match the grammar, or if it contains more
                than 8 network-group commands.
        """
        self._nms_config_file = nms_config_file
        # Decode before recording the script. Concatenating bytes onto the stored str would raise TypeError,
        # and original_script must always hold text.
        if isinstance(input_script, bytes):
            input_script = input_script.decode("ascii")
        if append:
            self._original_script += "\n" + input_script
        else:
            self._original_script = input_script
            self._commands = []

        try:
            script_grammar = self.get_script_grammar()
            # Spaces and tabs are removed before parsing (this also removes them inside quoted strings).
            compact_script = input_script.replace(" ", "").replace("\t", "")
            script_grammar.parseString(compact_script, parseAll=True)
        except pp.ParseException as e:
            raise BackendScriptParserException(f"Parsing failed at:\n{e.markInputline()}") from e

        if not self.commands:
            default_logger().warning("Model script is empty")
        network_group_count = sum(isinstance(cmd, NetworkGroupCommand) for cmd in self.commands)
        if network_group_count > 8:
            raise BackendScriptParserException("More than 8 network group command is prohibited.")
        return self.commands

    def parse_script_from_har(self, har):
        """Remember ``har`` and parse its model script, if it has one.

        A falsy ``har.nms_config_file`` is passed on as ``None``.
        """
        self._har = har
        nms_config = har.nms_config_file or None
        if not har.model_script:
            return
        self.parse_script(har.model_script, nms_config_file=nms_config)

    def get_script_grammar(self):
        """Build the full model-script grammar from the elements created by _define_base_grammar.

        Every statement form gets a parse action (the ``_parse_*`` handlers, e.g. _parse_void_func, which
        builds command objects). Parsing a script with the returned grammar therefore creates the commands as
        a side effect.

        Returns:
            A pyparsing expression matching zero or more statements, ignoring ``#`` comments and blank lines.
        """
        # functions grammar
        # Each alternation matches one of the known command names and exposes it as ``function_name``.
        glob_syntax_void_funcs = get_funcs_list(GLOB_COMMANDS)
        void_funcs = get_funcs_list(VOID_COMMANDS)
        single_return_funcs = get_funcs_list(SINGLE_RETURN_COMMANDS)
        multiple_return_funcs = get_funcs_list(MULTIPLE_RETURN_COMMANDS)
        dict_return_funcs = get_funcs_list(DICT_RETURN_COMMANDS)

        # Statement forms:
        #   func(args)                 -> void_exp
        #   func(args with {pattern})  -> glob_syntax_args_void_exp (GLOB_COMMANDS only)
        #   func()                     -> void_exp_no_args
        #   obj.func(args)             -> identifier_void_exp (object name under results name "object")
        #   ret=func(args)             -> single_return_val_exp
        #   a,b=func(args) / a,b=func() -> multiple_return_vals_exp / multiple_return_vals_exp_no_args
        args_exp = self._left_par + self._args + self._right_par
        glob_syntax_args_exp = self._left_par + self._glob_syntax_args + self._right_par
        empty_args_exp = self._left_par + self._right_par
        void_exp = void_funcs + args_exp
        identifier_void_exp = self._identifier("object") + pp.Literal(".") + void_funcs + args_exp
        glob_syntax_args_void_exp = glob_syntax_void_funcs + glob_syntax_args_exp
        void_exp_no_args = void_funcs + empty_args_exp
        single_return_val_exp = self._identifier("single_return_val") + self._eq + single_return_funcs + args_exp
        multiple_return_vals = (
            pp.delimitedList(self._identifier, ",")("multiple_return_vals") + self._eq + multiple_return_funcs
        )
        multiple_return_vals_exp = multiple_return_vals + args_exp
        multiple_return_vals_exp_no_args = multiple_return_vals + empty_args_exp

        LBRACE, RBRACE, LBRACK, RBRACK, COMMA, COLON = map(pp.Literal, "{}[],:")

        # Define elements for key-value pairs with variable names
        # key:[v1,v2,...] -> group of (key, group of values)
        value = pp.Group(pp.delimitedList(self._identifier_val, delim=COMMA))
        key_value = pp.Group(self._identifier_val + COLON.suppress() + LBRACK.suppress() + value + RBRACK.suppress())
        # The pair list is optional, so "{}" is accepted; pairs are exposed as "dict_return_vals".
        key_value_pairs = pp.Optional(pp.Group(pp.delimitedList(key_value, delim=COMMA))("dict_return_vals"))

        # Define the main expression
        # {key:[v1,v2],...}=func(args)
        dict_return_vals_exp = (
            LBRACE.suppress() + key_value_pairs + RBRACE.suppress() + self._eq + dict_return_funcs + args_exp
        )

        # '|' builds a MatchFirst: alternatives are tried in this order and the first match wins, so the
        # glob-argument form is attempted before the plain void form.
        command = pp.LineStart() + (
            glob_syntax_args_void_exp
            | void_exp
            | void_exp_no_args
            | identifier_void_exp
            | single_return_val_exp
            | multiple_return_vals_exp
            | multiple_return_vals_exp_no_args
            | dict_return_vals_exp
        )
        script_grammar = pp.ZeroOrMore(command).ignore(self._comment).ignore(self._empty_line)

        # setParseAction modifies each element in place, so attaching the handlers after composing
        # ``command`` still takes effect.
        void_exp.setParseAction(self._parse_void_func)
        identifier_void_exp.setParseAction(self._parse_identifier_void_func)
        void_exp_no_args.setParseAction(self._parse_void_func)
        glob_syntax_args_void_exp.setParseAction(self._parse_void_func)
        single_return_val_exp.setParseAction(self._parse_single_return_val_func)
        multiple_return_vals_exp.setParseAction(self._parse_multiple_return_val_func)
        multiple_return_vals_exp_no_args.setParseAction(self._parse_multiple_return_val_func)
        dict_return_vals_exp.setParseAction(self._parse_dict_return_val_func)
        return script_grammar

    def to_pb(self, pb_wrapper):
        """Serialize the allocator-related commands into a ``ProtoAllocatorScript`` message.

        Commands whose group is in MODEL_OPTIMIZATION_COMMAND_GROUPS are skipped. A command's ``to_pb`` may
        return either one proto command or a list of them; both are accepted.

        NOTE(review): when the model is not multi-scope, this calls ``remove_scope()`` on the stored command
        objects themselves. They stay unscoped afterwards, unlike _build_script_str, which works on copies.
        Confirm whether callers rely on this before changing it.
        """
        script_pb = pb_wrapper.integrated_hw_graph_base_pb2.ProtoAllocatorScript()
        script_pb.version = type(self).DEFAULT_VERSION
        collected = []
        for command in self._commands:
            if command.group in self.MODEL_OPTIMIZATION_COMMAND_GROUPS:
                continue
            if not self.is_multi_scope:
                command.remove_scope()
            serialized = command.to_pb(pb_wrapper)
            if isinstance(serialized, list):
                collected.extend(serialized)
            else:
                collected.append(serialized)

        script_pb.commands.extend(collected)
        return script_pb

    def filter_commands_by_types(self, cmd_types):
        """Keep only the stored commands that are instances of at least one of ``cmd_types``.

        Args:
            cmd_types: Iterable of command classes to keep. It is consumed once, so a generator works too.
        """
        # Materialize once. The previous per-command ``sum(...)`` re-iterated cmd_types for every command,
        # which silently dropped everything after the first command when cmd_types was a one-shot
        # iterator. ``isinstance`` with a tuple is the idiomatic "any of these types" test.
        allowed_types = tuple(cmd_types)
        self._commands = [cmd for cmd in self._commands if isinstance(cmd, allowed_types)]

    def remove_internal_commands(self):
        """Drop every command whose function name is listed in EXCLUDE_FROM_AUTO_ALLS."""
        excluded_names = self.EXCLUDE_FROM_AUTO_ALLS
        self._commands = [cmd for cmd in self._commands if cmd.function_name.value not in excluded_names]

    @property
    def sort_auto_alls(self):
        """When truthy, rendered scripts are ordered by each command's ``sort_key()``; otherwise by group."""
        enabled = self._sort_auto_alls
        return enabled

    @sort_auto_alls.setter
    def sort_auto_alls(self, sort_auto_alls):
        self._sort_auto_alls = sort_auto_alls

    @property
    def sorting_disabled(self):
        """When truthy, rendered scripts keep the stored (insertion) order of the commands."""
        disabled = self._sorting_disabled
        return disabled

    @sorting_disabled.setter
    def sorting_disabled(self, sorting_disabled):
        self._sorting_disabled = sorting_disabled

    def save_auto_alls(self):
        """Write the auto-alls rendering of the script to the path from get_auto_alls_path(model name)."""
        target_path = get_auto_alls_path(self._hn.name)
        header = "### Auto-generated model script ###\n"
        self.save(target_path, header, is_to_auto_alls=True)

    def save(self, path, title=None, is_to_auto_alls=False):
        """Write the rendered script to ``path``, overwriting it.

        Scopes are stripped from the output when the model is single-scope.

        Args:
            path: Output file path.
            title: Optional header, written on its own line before the script.
            is_to_auto_alls: Render commands with ``str_to_alls()`` instead of ``str()``.
        """
        # Render first, so a rendering error does not leave a truncated file behind.
        content = self._build_script_str(self.is_single_scope, is_to_auto_alls)
        with open(path, "w") as f:
            # print(None) would write the literal text "None" as the first line of the script.
            if title is not None:
                print(title, file=f)
            print(content, file=f)

    def add_scope_to_commands(self, scope_names):
        """Apply ``scope_names`` to every stored command, forcing it (``add_scope(..., force=True)``)."""
        for stored_command in self._commands:
            stored_command.add_scope(scope_names, force=True)

    def command_group_iterator(self, group):
        """Return a lazy iterator over the stored commands whose ``group`` equals ``group``."""

        def belongs_to_group(command):
            return command.group == group

        return filter(belongs_to_group, self._commands)

    def update_model_optimization_commands(self):
        """Finalize loading of the optimization commands (see optimization_commands_iterator).

        For each command, in order:
            1. expand its glob patterns against the model's layer names and network scopes;
            2. validate it against the known layer scope.
        """
        for mo_command in self.optimization_commands_iterator():
            mo_command.expand_glob(self._layers_in_hn, self._hn.net_params.net_scopes)
            mo_command.validate_command(self._layers_scope)

    def export_model_optimization_commands(self):
        """Merge the ``export()`` of every optimization command into one nested mo-config dict.

        Call update_model_optimization_commands first, so globs are expanded and commands validated.
        """
        merged_config = {}
        for mo_command in self.optimization_commands_iterator():
            update_nested(merged_config, mo_command.export())
        return merged_config

    def optimization_commands_iterator(self):
        """Return the QUANTIZATION-group commands, excluding post-fuser pre-quantization commands.

        Post-fuser commands are handled separately by export_post_fuser_commands.

        Returns:
            A list of commands, in stored order.
        """
        # Materialize the post-fuser commands once. post_fuser_commands_iterator() returns a one-shot
        # ``filter``, and each ``in`` test consumes it. After the first miss the iterator was exhausted,
        # so every later ``not in`` check passed and post-fuser commands leaked into this list.
        post_fuser_cmds = list(self.post_fuser_commands_iterator())
        quantization_cmds = self.command_group_iterator(CommandsGroups.QUANTIZATION)
        return [cmd for cmd in quantization_cmds if cmd not in post_fuser_cmds]

    def post_fuser_commands_iterator(self):
        """Return a lazy iterator over pre-quantization commands whose ``POST_FUSER_COMMAND`` is truthy."""

        def is_post_fuser(command):
            return isinstance(command, PreQuantizationOptimizationCommand) and command.POST_FUSER_COMMAND

        return filter(is_post_fuser, self._commands)

    def export_post_fuser_commands(self):
        """Expand, validate and export the post-fuser commands as one nested mo-config dict.

        These commands are excluded from optimization_commands_iterator, so their glob expansion and
        validation happen here rather than in update_model_optimization_commands.
        """
        merged_config = {}

        for post_fuser_command in self.post_fuser_commands_iterator():
            post_fuser_command.expand_glob(self._layers_in_hn, self._hn.net_params.net_scopes)
            post_fuser_command.validate_command(self._layers_scope)
            update_nested(merged_config, post_fuser_command.export())

        return merged_config

    def export_optimization_goal_fps(self):
        """Return the FPS of the first global ``PERFORMANCE_PARAM`` command that sets one, else ``None``.

        A command is global when it lists no components, or only the wildcard ``"*"``.
        """
        performance_cmds = (
            cmd for cmd in self._commands if cmd.function_name == SupportedCommands.PERFORMANCE_PARAM
        )
        for cmd in performance_cmds:
            if not cmd.is_fps_set:
                continue
            if len(cmd.components) == 0 or (len(cmd.components) == 1 and cmd.components[0] == "*"):
                return cmd.fps
        return None

    def export_model_optimization_flavor(self):
        """Merge the ``export()`` of every QUANTIZATION_FLAVOR command into one nested dict.

        Intended to be called after update_model_optimization_commands.
        """
        merged_flavor = {}
        for flavor_command in self.command_group_iterator(CommandsGroups.QUANTIZATION_FLAVOR):
            update_nested(merged_flavor, flavor_command.export())
        return merged_flavor

    def reload_mo_commands(self, config: ModelOptimizationConfig, exclude_defaults=True):
        """Replace the QUANTIZATION-group commands with ones regenerated from ``config``.

        Args:
            config: Verified model-optimization config; the script lines from its ``to_commands`` are
                re-parsed with the script grammar.
            exclude_defaults: Forwarded to ``config.to_commands``.

        The regenerated commands are then expanded and validated via update_model_optimization_commands.
        NOTE(review): unlike parse_script, this does not strip spaces/tabs, and it lets a
        ``pp.ParseException`` propagate unwrapped.
        """
        regenerated_lines = config.to_commands(exclude_defaults)
        # Drop the current quantization commands; parsing below appends their replacements.
        self._commands = [cmd for cmd in self.commands if cmd.group != CommandsGroups.QUANTIZATION]
        grammar = self.get_script_grammar()
        grammar.parseString("\n".join(regenerated_lines), parseAll=True)
        self.update_model_optimization_commands()

    def export_model_modifications_commands(self):
        """Return a lazy iterator over the MODEL_MODIFICATIONS-group commands."""
        return self.command_group_iterator(CommandsGroups.MODEL_MODIFICATIONS)

    def export_quantization_param_commands(self):
        """Return a lazy iterator over the ``QUANTIZATION_PARAM`` commands."""

        def is_quantization_param(command):
            return command.function_name == SupportedCommands.QUANTIZATION_PARAM

        return filter(is_quantization_param, self._commands)

    def export_performance_param_commands(self):
        """Return a lazy iterator over the stored ``PerformanceParamCommand`` instances."""

        def is_performance_param(command):
            return isinstance(command, PerformanceParamCommand)

        return filter(is_performance_param, self._commands)

    def update_model(self, hn):
        """Rebind the parser to ``hn`` and refresh the cached layer-name lists.

        ``_layers_scope`` and ``_layers_in_hn`` each get their own list of every layer name in ``hn``. They
        are used for glob expansion and command validation.
        """

        def names_of(model):
            return [layer.name for layer in model]

        self._hn = hn
        self._layers_scope = names_of(self._hn)
        self._layers_in_hn = names_of(self._hn)

    def _parse_void_func(self, s, loc, tokens):
        """Parse action for void statements: ``func(args)``, ``func()`` and glob-argument commands.

        Builds the command object selected by ``tokens.function_name`` and hands it to _process_void_func.
        Names without a handler here yield no command (``None``).

        Args:
            s: The whole (space/tab-stripped) script being parsed.
            loc: Offset of this statement in ``s``.
            tokens: Matched tokens; ``function_name`` picks the command class.
        """
        # Statement text and 1-based line number, forwarded to the optimization commands' from_tokens.
        command_string = s[loc:].split("\n")[0]
        line_number = len(s[:loc].split("\n"))
        function_name = tokens.function_name

        # Commands built from the tokens alone.
        token_only_builders = {
            SupportedCommands.PLACE.value: PlaceCommand,
            SupportedCommands.COMPILATION_PARAM.value: CompilationParamCommand,
            SupportedCommands.PERFORMANCE_PARAM.value: PerformanceParamCommand,
            SupportedCommands.ALLOCATOR_PARAM.value: AllocatorParamCommand,
            SupportedCommands.BUFFER_CALC_PARAM.value: BufferCalcCommand,
            SupportedCommands.BUFFERS.value: BuffersCommand,
            SupportedCommands.STRATEGY.value: StrategyCommand,
            SupportedCommands.FORCE_MAPPING.value: ForceMappingCommand,
            SupportedCommands.FORCE_ROUTE.value: ForceRouteCommand,
            SupportedCommands.CONTEXT_SWITCH_PARAM.value: ContextSwitchParamCommand,
            SupportedCommands.HEF_PARAM.value: HefParamCommand,
            SupportedCommands.LOGGER_PARAM.value: LoggerParamCommand,
            SupportedCommands.INTERNAL_ALLOCATOR_PARAM.value: InternalAllocatorParamCommand,
            SupportedCommands.COLLAPSE.value: CollapseCommand,
            SupportedCommands.TRANSPOSE_CONCAT.value: TransposeConcatCommand,
            SupportedCommands.INTERNAL_CONTEXT_SWITCH_PARAM.value: InternalContextSwitchParamCommand,
            SupportedCommands.RESOURCES_PARAM.value: ResourcesParamCommand,
            SupportedCommands.PLATFORM_PARAM.value: PlatformParamCommand,
            SupportedCommands.TRANSPOSE.value: TransposeCommand,
            SupportedCommands.SET_SEED.value: SetSeedCommand,
            SupportedCommands.CHANGE_OUTPUT_ACTIVATION.value: ChangeOutputActivationCommand,
            SupportedCommands.REMOVE_NODE.value: RemoveNodeCommand,
            SupportedCommands.CONVERT_TO_DENSE.value: ConvertToDenseCommand,
            SupportedCommands.MIRROR.value: MirrorCommand,
            SupportedCommands.SHARE_CONFIG.value: ShareConfigCommand,
            SupportedCommands.SET_KV_CACHE_PAIR.value: SetKVCachePairCommand,
            SupportedCommands.SET_KV_CACHE_GLOBAL_PARAMS.value: SetKVCacheGlobalParamsCommand,
            SupportedCommands.ADD_RESOURCE.value: AddResourceCommand,
        }
        # Commands that also receive the statement text and its line number.
        located_builders = {
            SupportedCommands.QUANTIZATION_PARAM.value: QuantizationParamCommand,
            SupportedCommands.COMPRESSION_PARAMS.value: CompressionParamsCommands,
            SupportedCommands.PRE_QUANTIZATION_OPTIMIZATION.value: PreQuantizationOptimizationCommand,
            SupportedCommands.POST_QUANTIZATION_OPTIMIZATION.value: PostQuantizationOptimizationCommand,
            SupportedCommands.MODEL_OPTIMIZATION_FLAVOR.value: ModelOptimizationFlavorCommand,
            SupportedCommands.MODEL_OPTIMIZATION_CONFIG.value: ModelOptimizationConfigCommand,
        }
        # Commands constructed without any arguments.
        argumentless_builders = {
            SupportedCommands.PRINT_BUFFERS.value: PrintBuffersCommand,
            SupportedCommands.OPTIMIZE_BUFFERS.value: OptimizeBuffersCommand,
        }

        if function_name in token_only_builders:
            new_cmd = token_only_builders[function_name].from_tokens(tokens)
        elif function_name in located_builders:
            new_cmd = located_builders[function_name].from_tokens(tokens, command_string, line_number)
        elif function_name in argumentless_builders:
            new_cmd = argumentless_builders[function_name]()
        elif function_name == SupportedCommands.NMS_POSTPROCESS.value:
            # NMS postprocess also needs the script path and the NMS config file given to the parser.
            new_cmd = NMSPostprocessCommand.from_tokens(tokens, self._path, self._nms_config_file)
        else:
            new_cmd = None

        self._process_void_func(new_cmd)

    def _process_void_func(self, new_cmd, index=None):
        if new_cmd is not None:
            self._add_scope(new_cmd)
            self._handle_allocation_commands(new_cmd, index=index)

    def _parse_identifier_void_func(self, s, loc, tokens):
        """Build one of the identifier-style void commands from ``tokens`` and register it.

        Handles PLACE (as a context placement), the CONTEXT_*_PARAM commands, REMOVE_NODE and
        CONVERT_TO_DENSE. ``s`` and ``loc`` are unused. Every candidate is checked (there is no early
        exit), so if several matched, the last one would win. An unrecognized name yields ``None``,
        which ``_process_void_func`` ignores.
        """
        candidates = (
            (SupportedCommands.PLACE, ContextPlaceCommand),
            (SupportedCommands.CONTEXT_COMPILATION_PARAM, ContextCompilationParamCommand),
            (SupportedCommands.CONTEXT_RESOURCES_PARAM, ContextResourcesParamCommand),
            (SupportedCommands.CONTEXT_PERFORMANCE_PARAM, ContextPerformanceParamCommand),
            (SupportedCommands.REMOVE_NODE, RemoveNodeCommand),
            (SupportedCommands.CONVERT_TO_DENSE, ConvertToDenseCommand),
        )
        new_cmd = None
        for supported, command_cls in candidates:
            if tokens.function_name == supported.value:
                new_cmd = command_cls.from_tokens(tokens)

        self._process_void_func(new_cmd)

    def _process_cmd(self, new_cmd, index=None):
        """Route an already-built command to the ``_process_*`` handler of its category.

        Categories are tried in order: glob, single-return, multiple-return, void, dict-return.
        The first category containing ``new_cmd.function_name`` wins. A name found in no category
        is ignored.

        Args:
            new_cmd: the command to process.
            index: optional insertion position forwarded to the handler.
        """
        routes = (
            (GLOB_COMMANDS, self._process_void_func),
            (SINGLE_RETURN_COMMANDS, self._process_single_return_val_func),
            (MULTIPLE_RETURN_COMMANDS, self._process_multiple_return_val_func),
            (VOID_COMMANDS, self._process_void_func),
            (DICT_RETURN_COMMANDS, self._process_dict_return_val_func),
        )
        command_name = new_cmd.function_name
        for category, handler in routes:
            if command_name in category:
                handler(new_cmd, index=index)
                break

    def _parse_single_return_val_func(self, s, loc, tokens):
        """Build a single-return-value command from ``tokens`` and register it.

        ``s`` and ``loc`` are unused. Candidates are tried in order and the first matching function
        name wins. An unrecognized name yields ``None``, which
        ``_process_single_return_val_func`` ignores.
        """
        factories = (
            # shortcut, portal, ddr and l4_portal are all represented by ShortcutCommand
            (SupportedCommands.SHORTCUT, ShortcutCommand.from_tokens),
            (SupportedCommands.PORTAL, ShortcutCommand.from_tokens),
            (SupportedCommands.DDR, ShortcutCommand.from_tokens),
            (SupportedCommands.L4_PORTAL, ShortcutCommand.from_tokens),
            (SupportedCommands.OUTPUT_MUX, OutputMuxCommand.from_tokens),
            (SupportedCommands.CONCAT, ConcatCommand.from_tokens),
            # from_tf additionally needs the HN; self._hn is only read when this entry matches
            (SupportedCommands.FROM_TF, lambda toks: FromTFCommand.from_tokens(toks, self._hn)),
            (SupportedCommands.CONTEXT, ContextCommand.from_tokens),
            (SupportedCommands.OUTPUT_LAYER, OutputLayerCommand.from_tokens),
            (SupportedCommands.MERGE, MergeCommand.from_tokens),
            (SupportedCommands.FEATURE_SPLITTER, ShapeSplitterCommand.from_tokens),
            (SupportedCommands.NETWORK_GROUP, NetworkGroupCommand.from_tokens),
            (SupportedCommands.SHAPE_SPLITTER, ShapeSplitterCommand.from_tokens),
            (SupportedCommands.BUCKET, BucketCommand.from_tokens),
        )
        requested = tokens.function_name
        new_cmd = next(
            (build(tokens) for supported, build in factories if requested == supported.value),
            None,
        )

        self._process_single_return_val_func(new_cmd)

    def _process_single_return_val_func(self, new_cmd, index=None):
        """Scope and register a command that binds a single return value.

        The command's return values are added to the tracked layers scope, except for CONTEXT,
        NETWORK_GROUP and BUCKET commands, whose return values are excluded.

        Args:
            new_cmd: the parsed command; ``None`` (unrecognized function name) is ignored.
            index: optional insertion position in ``self._commands``; ``None`` appends.
        """
        if new_cmd is None:
            return
        self._add_scope(new_cmd)
        # Compare against ``.value`` like every other function-name check in this parser does.
        # A plain string never equals an ``Enum`` member unless ``SupportedCommands`` mixes in
        # ``str``, so comparing with the members themselves would silently skip this exclusion.
        # NOTE(review): confirm nothing downstream relies on context/network-group/bucket names
        # being present in ``_layers_scope``.
        excluded_commands = (
            SupportedCommands.CONTEXT.value,
            SupportedCommands.NETWORK_GROUP.value,
            SupportedCommands.BUCKET.value,
        )
        layers_to_append = None
        if new_cmd.function_name not in excluded_commands:
            layers_to_append = new_cmd.function_return_vals
        self._handle_allocation_commands(new_cmd, layers_to_append=layers_to_append, index=index)

    def _parse_dict_return_val_func(self, s, loc, tokens):
        """Build a dict-return command (currently only DEFUSE_BLOCK) from ``tokens`` and register it.

        ``s`` and ``loc`` are unused. Any other function name is ignored.
        """
        if tokens.function_name != SupportedCommands.DEFUSE_BLOCK.value:
            return
        self._process_dict_return_val_func(DefuseBlockCommand.from_tokens(tokens))

    def _parse_multiple_return_val_func(self, s, loc, tokens):
        """Build a multiple-return-value command from ``tokens`` and register it.

        ``s`` and ``loc`` are unused. Candidates are tried in order and the first matching function
        name wins. An unrecognized name yields ``None``, which
        ``_process_multiple_return_val_func`` ignores.
        """
        command_classes = (
            (SupportedCommands.DEFUSE, DefuseCommand),
            (SupportedCommands.MUX_DEMUX, MuxDemuxCommand),
            (SupportedCommands.CASCADE, CascadeCommand),
            (SupportedCommands.NORMALIZATION, NormalizationCommand),
            (SupportedCommands.TRANSPOSE_CONCAT, TransposeConcatCommand),
            (SupportedCommands.FORMAT_CONVERSION, FormatConversionCommand),
            (SupportedCommands.INPUT_CONVERSION, InputConversionCommand),
            (SupportedCommands.LOGITS_LAYER, LogitsLayerCommand),
            (SupportedCommands.RESIZE, ResizeCommand),
        )
        requested = tokens.function_name
        new_cmd = None
        for supported, command_cls in command_classes:
            if requested == supported.value:
                new_cmd = command_cls.from_tokens(tokens)
                break

        self._process_multiple_return_val_func(new_cmd)

    def _process_multiple_return_val_func(self, new_cmd, index=None):
        if new_cmd is not None:
            self._add_scope(new_cmd)
            layers_to_append, layers_to_remove = (None,) * 2
            if new_cmd.function_name == SupportedCommands.DEFUSE.value:
                layers_to_remove = new_cmd.layers_to_remove
                layers_to_append = new_cmd.defused_layers
            elif new_cmd.function_name == SupportedCommands.MUX_DEMUX.value:
                layers_to_append = new_cmd.output_mux_demux_layers
            elif new_cmd.function_name == SupportedCommands.CASCADE.value:
                layers_to_append = new_cmd.cascade_layers
            elif new_cmd.function_name == SupportedCommands.TRANSPOSE_CONCAT.value:
                layers_to_append = new_cmd.format_conversions
            elif new_cmd.function_name == SupportedCommands.FORMAT_CONVERSION.value:
                layers_to_append = new_cmd.function_return_vals
            self._handle_allocation_commands(
                new_cmd,
                layers_to_remove=layers_to_remove,
                layers_to_append=layers_to_append,
                index=index,
            )

    def _process_dict_return_val_func(self, new_cmd, index=None):
        if new_cmd is not None:
            self._add_scope(new_cmd)
            layers_to_append, layers_to_remove = (None,) * 2
            if new_cmd.function_name == SupportedCommands.DEFUSE_BLOCK.value:
                layers_to_remove = new_cmd.layers_to_remove
                layers_to_append = new_cmd.defused_layers

            self._handle_allocation_commands(
                new_cmd,
                layers_to_remove=layers_to_remove,
                layers_to_append=layers_to_append,
                index=index,
            )

    def _get_relevant_scope(self, new_cmd):
        """Choose the scope(s) that ``_add_scope`` should apply to ``new_cmd``.

        For a single-scope network this is always ``self._hn.net_params.net_scopes``. For a
        multi-scope network, let S be the scopes referenced by the command without the empty
        ("scope name missing") marker. The result is then:

        * ``None`` if the command references no scopes or has no layer with a missing scope;
        * S if it contains exactly one scope;
        * ``None`` if the command's group is not a model-optimization group (allocator commands
          are validated by the command validation instead);
        * for ``ModelModificationsOnInputLayerCommand`` with an empty S, the scopes of the HN
          input layers;
        * S if it now contains exactly one scope.

        Raises:
            BackendScriptParserException: in the model-optimization case above, when S ends up
                empty ("All layer scope names are missing") or holds more than one scope
                ("Some layers scope names are missing").
        """
        if self.is_single_scope:
            return self._hn.net_params.net_scopes

        # In case of multi scope network
        scopes_in_command = new_cmd.scopes_in_command()

        # All layers have scope name (or no layers)
        if scopes_in_command is None or "" not in scopes_in_command:
            return None

        # Some layer scope names are missing
        # NOTE(review): this mutates the collection returned by ``scopes_in_command()``; presumably a
        # fresh collection is built per call — confirm, otherwise the command's own state changes here.
        scopes_in_command.remove("")

        # Exactly one explicit scope among the command's layers: use it for the unscoped ones.
        if len(scopes_in_command) == 1:
            return scopes_in_command

        # Allocator commands validation is done in the command validation
        if new_cmd.group not in self.MODEL_OPTIMIZATION_COMMAND_GROUPS:
            return None

        # In case of model modification commands on input we can take the scope name from the input layers in the HN
        if isinstance(new_cmd, ModelModificationsOnInputLayerCommand) and len(scopes_in_command) == 0:
            input_layer_names = [input_layer.name for input_layer in self._hn.get_input_layers()]
            scopes_in_command = get_scopes_set_from_layers(input_layer_names)

        # A model-optimization command must resolve to exactly one scope.
        if len(scopes_in_command) == 0:
            raise BackendScriptParserException(f"All layer scope names are missing: {new_cmd!s}")
        elif len(scopes_in_command) > 1:
            raise BackendScriptParserException(f"Some layers scope names are missing: {new_cmd!s}")
        else:
            return scopes_in_command

    def _add_scope(self, new_cmd):
        scope_name = self._get_relevant_scope(new_cmd)
        if scope_name:
            if new_cmd.scopes_in_command():
                unfound_scopes = [scope for scope in new_cmd.scopes_in_command() if scope not in (["", *scope_name])]
                if len(unfound_scopes) > 0 and self._alls_ignore_invalid_cmds:
                    return
            new_cmd.add_scope(scope_name)

    def _gen_context_name(self, orig, old_prefix, new_prefix):
        """Derive the mirrored context name for ``orig``.

        The remainder and scope decoded from ``old_prefix`` are removed from ``orig``. The result
        is rebuilt as ``"<new scope>_<new remainder>_<base name>"`` from the parts decoded from
        ``new_prefix``.
        """
        new_scope, new_remainder = decode_prefix(new_prefix)
        old_scope, old_remainder = decode_prefix(old_prefix)
        base_name = orig.replace(old_remainder, "").replace(old_scope, "")
        return "_".join((new_scope, new_remainder, base_name))

    def _fix_mirrored_command(self, new_cmd, old_prefix, new_prefix):
        """Rewrite a copied command so it refers to ``new_prefix`` instead of ``old_prefix``.

        CONTEXT commands are also renamed, and PLACE commands get their target context renamed,
        using ``_gen_context_name``.
        """
        new_cmd.replace_prefix(old_prefix, new_prefix)

        command_name = new_cmd.function_name
        is_context = command_name == SupportedCommands.CONTEXT.value
        is_place = command_name == SupportedCommands.PLACE.value
        if not (is_context or is_place):
            return

        mirrored_name = self._gen_context_name(new_cmd.context_name, old_prefix, new_prefix)
        if is_context:
            new_cmd.set_name(mirrored_name)
        else:
            new_cmd.set_context_name(mirrored_name)

    def _apply_commands_chunk(self, cmd_chunk, index, count):
        for prefix in cmd_chunk:
            for cmd in cmd_chunk[prefix]:
                self._process_cmd(cmd, index=(index + count))
                count += 1
        return count

    def _should_not_mirror(self, cmd):
        return_vals = cmd.function_return_vals
        return return_vals and any(
            [return_val in self._layers_scope and "placeholder" not in return_val for return_val in return_vals]
        )

    def _python_side_mirror(self, cmd):
        src_prefix = cmd.src_prefix
        dst_prefixes = cmd.dest_prefixes

        working_commands = self._commands[:]

        count = 0
        last_cmd_type = working_commands[0].function_name
        last_cmd_chunk = {}
        for index, old_cmd in enumerate(working_commands):
            is_prefix_in_command = old_cmd.prefix_in_command(src_prefix)
            if last_cmd_type != old_cmd.function_name or not is_prefix_in_command:
                count = self._apply_commands_chunk(last_cmd_chunk, index, count)
                last_cmd_chunk = {}
                last_cmd_type = old_cmd.function_name

            if is_prefix_in_command:
                for new_prefix in dst_prefixes:
                    temp = copy.deepcopy(old_cmd)
                    self._fix_mirrored_command(temp, src_prefix, new_prefix)

                    if self._should_not_mirror(temp):
                        continue

                    if new_prefix not in last_cmd_chunk:
                        last_cmd_chunk[new_prefix] = []

                    last_cmd_chunk[new_prefix].append(temp)

        self._apply_commands_chunk(last_cmd_chunk, index, count)

    def _handle_allocation_commands(self, cmd, layers_to_remove=None, layers_to_append=None, index=None):
        """Validate ``cmd``, update the tracked layers scope and insert it into ``self._commands``.

        Args:
            cmd: the parsed command to register.
            layers_to_remove: layer names removed from ``self._layers_scope`` after validation.
            layers_to_append: layer names appended to ``self._layers_scope`` after validation.
            index: insertion position in ``self._commands``; ``None`` appends (the length is
                taken before any mirroring adds commands).
        """
        insert_at = len(self._commands) if index is None else index

        # Optimization mode: commands are stored as-is, without any scope bookkeeping.
        if self._mode == ModelScriptModes.OPTIMIZATION_MODE:
            self._commands.insert(insert_at, cmd)
            return

        # When invalid commands are tolerated, a command with unfound layers may be dropped entirely.
        if (
            self._alls_ignore_invalid_cmds
            and cmd.has_unfound_layers(self._layers_scope)
            and not cmd.handle_unfound_layers(self._layers_scope)
        ):
            return

        # Model-optimization commands skip mirroring, glob expansion and validation.
        if cmd.group in self.MODEL_OPTIMIZATION_COMMAND_GROUPS:
            self._commands.insert(insert_at, cmd)
            return

        if cmd.function_name == SupportedCommands.MIRROR.value:
            self._python_side_mirror(cmd)

        cmd.expand_glob(self._layers_scope, self._hn.net_params.net_scopes)
        cmd.validate_command(self._layers_scope)

        for layer in layers_to_remove or ():
            self._layers_scope.remove(layer)
        if layers_to_append:
            self._layers_scope.extend(layers_to_append)

        self._commands.insert(insert_at, cmd)
