#!/usr/bin/env python

import os
import socket
import subprocess
import sys
from collections import OrderedDict
from enum import Enum
from traceback import format_exception

import numpy as np
import py

from hailo_model_optimization.acceleras.utils.acceleras_definitions import PostprocessTarget
from hailo_sdk_client.allocator import integrated_hw_graph_pb2
from hailo_sdk_client.allocator.allocator_params import (
    AllocatorAgent,
    AllocatorParams,
    AllocatorStrategy,
    BuilderExitPoint,
)
from hailo_sdk_client.allocator.estimator import Estimator
from hailo_sdk_client.allocator.pb_wrapper import PbWrapper
from hailo_sdk_client.allocator.performance_params import get_max_compiler_optimization_level
from hailo_sdk_client.runner.exceptions import InvalidParserInputException
from hailo_sdk_client.sdk_backend.script_parser.commands import FromTFCommand, PrintBuffersCommand
from hailo_sdk_client.sdk_backend.script_parser.model_modifications_commands import ModelModificationsCommand
from hailo_sdk_client.sdk_backend.script_parser.model_optimization_commands import QuantizationParamCommand
from hailo_sdk_client.sdk_backend.script_parser.model_script_parser import ModelScriptParser
from hailo_sdk_client.sdk_backend.sdk_backend_exceptions import (
    BackendAllocatorException,
    BackendInternalException,
    HailoToolsException,
    ProfilingException,
)
from hailo_sdk_common.hailo_nn.hailo_nn import HailoNN
from hailo_sdk_common.hailo_nn.tools_params import AutoInt
from hailo_sdk_common.logger.logger import default_logger
from hailo_sdk_common.model_script_parser.model_script_modes import ModelScriptModes
from hailo_sdk_common.paths_manager.config import get_parsed_config_from_path
from hailo_sdk_common.paths_manager.paths import SDKPaths
from hailo_sdk_common.profiler.profiler_common import ProfilerModes
from hailo_sdk_common.serialization import client_server_api_pb2
from hailo_sdk_common.serialization.numpy_serialization import hailo_np_savez


class MappingException(BackendAllocatorException):
    """Allocator exception subtype for mapping-related failures.

    NOTE(review): not raised anywhere in this module; presumably raised or caught
    by callers elsewhere in the SDK.
    """


class HailoToolsTargets(Enum):
    """Targets that can be requested from the hailo_tools binaries.

    Members:
        BUILDER: perform a build of a network.
        ESTIMATOR: estimate the resources of a network.
    """

    BUILDER = "builder"
    ESTIMATOR = "estimator"


def create_env():
    """Return a copy of the current environment for launching hailo_tools binaries.

    The or-tools shared-object directories shipped with hailo_tools are appended to
    ``LD_LIBRARY_PATH`` so the child process can resolve them. ``os.environ`` itself
    is left untouched.

    Returns:
        dict: environment mapping suitable for ``subprocess.Popen(env=...)``.
    """
    env = dict(os.environ)

    # Add the libraries for the or tools shared objects
    sdk_paths = SDKPaths()
    lib_dirs = [
        sdk_paths.join_hailo_tools_path("or-tools/dependencies/install/lib"),
        sdk_paths.join_hailo_tools_path("or-tools/lib"),
    ]

    # An empty LD_LIBRARY_PATH component (e.g. the leading ':' produced when the
    # variable is unset) makes the dynamic loader search the current working
    # directory, so only keep the existing value when there is one.
    existing = env.get("LD_LIBRARY_PATH", "")
    if existing:
        lib_dirs.insert(0, existing)

    env["LD_LIBRARY_PATH"] = ":".join(map(str, lib_dirs))
    return env


def establish_server_client_channels():
    """Create a connected loopback TCP channel for talking to a hailo_tools child.

    A temporary listener is bound to an OS-chosen port on 127.0.0.1. A client socket
    connects to it, and the accepted connection becomes the server side. The
    listener is always closed before returning. On any failure, every socket
    created here is closed as well.

    Returns:
        tuple: ``(server, client, addr, host)``.
            ``server`` is the accepted socket and ``client`` the connecting socket.
            ``addr`` is the peer address reported by ``accept()``, and ``host`` is
            the loopback address used.

    Raises:
        HailoToolsException: if the channel cannot be set up, or if the accepted
            peer is not the client socket created here.
    """
    host = "127.0.0.1"  # loopback only
    server = None
    client = None
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as listener:
            listener.bind((host, 0))  # port 0: any available port
            port = listener.getsockname()[1]
            listener.listen(1)

            client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            client.connect((host, port))
            server, addr = listener.accept()

        # Another local process could connect to the listening port before us;
        # only trust the channel if the accepted peer is our own client socket.
        if addr != client.getsockname():
            raise ConnectionError(f"accepted unexpected peer {addr}")
    except Exception as e:
        for sock in (server, client):
            if sock is not None:
                sock.close()
        raise HailoToolsException("Couldn't launch Hailo tool securely") from e

    return server, client, addr, host


def validate_communication_steup(client, server, integrated_pb_string):
    """Sanity-check the channel and payload before launching a hailo_tools child.

    Args:
        client: socket that will be handed to the child process.
        server: socket kept by this process.
        integrated_pb_string: serialized graph to send; must not be ``None``.

    Raises:
        HailoToolsException: if either socket has no valid (non-negative integer)
            file descriptor, or the payload is ``None``.
    """

    def _has_open_fd(sock):
        descriptor = sock.fileno()
        return isinstance(descriptor, int) and descriptor > -1

    if _has_open_fd(client) and _has_open_fd(server) and integrated_pb_string is not None:
        return
    raise HailoToolsException("Hailo tool couldn't be launched successfully")


def recvall(sock, n):
    """Read exactly ``n`` bytes from ``sock``, looping over partial reads.

    Returns:
        bytearray: the received bytes (empty when ``n`` <= 0).

    Raises:
        HailoToolsException: if the peer closes the connection before ``n`` bytes
            have arrived.
    """
    received = bytearray()
    missing = n
    while missing > 0:
        chunk = sock.recv(missing)
        if not chunk:  # orderly shutdown by the peer
            raise HailoToolsException("Hailo tool failed to transfer full pb")
        received += chunk
        missing -= len(chunk)
    return received


def send_pb(socket, msg):
    """Send ``msg`` framed by a zero-padded, 10-digit ASCII length header.

    This is the sending counterpart of ``recv_pb``.
    """
    length_header = f"{len(msg):010d}".encode("ascii")
    socket.sendall(length_header)
    socket.sendall(msg)


def send_ack(socket):
    """Acknowledge a finished upload by sending a single zero byte."""
    socket.sendall(b"\x00")


def recv_pb(socket):
    """Receive one message framed by ``send_pb`` and return its payload.

    The 10-byte ASCII length header is read first, followed by exactly that many
    payload bytes.

    Returns:
        bytearray: the payload without the header.
    """
    payload_size = int(recvall(socket, 10))
    payload = recvall(socket, payload_size)
    return payload


def recv_request(socket):
    """Read one fixed-size (100-byte) request from the child and decode it to text."""
    raw_request = recvall(socket, 100)
    return raw_request.decode()


def server_loop(server, builder_pb_input=None):
    """Serve proto-transfer requests from a hailo_tools child until it stops sending.

    Each request is a fixed-size text command read with ``recv_request``:

    * ``"download builder graph"``: send ``builder_pb_input`` to the child.
    * ``"upload graph"``: receive the compiled graph and acknowledge it.
    * ``"upload map"``: receive the HEF data and acknowledge it.

    The loop ends when no further request can be read (the child closed or reset
    its end). ``server`` is closed on every exit path.

    Args:
        server: connected socket shared with the child process.
        builder_pb_input: serialized graph to hand to the child on request.

    Returns:
        tuple: ``(hef_data, output_proto_graph_bytes)``. Either item is ``None``
        if the child never uploaded it.

    Raises:
        HailoToolsException: if a transfer fails after its request was read.
        AssertionError: if the child sends an unknown request.
    """
    hef_data = None
    output_proto_graph_bytes = None

    try:
        while True:
            try:
                req = recv_request(server)
            except (HailoToolsException, ConnectionError):
                # No (complete) request: the child is done talking to us. Its exit
                # status is inspected by the caller.
                break
            try:
                if req.startswith("download builder graph"):
                    send_pb(server, builder_pb_input)
                elif req.startswith("upload graph"):
                    output_proto_graph_bytes = recv_pb(server)
                    send_ack(server)
                elif req.startswith("upload map"):
                    hef_data = recv_pb(server)
                    send_ack(server)
                else:
                    raise AssertionError(f"Unexpected request from Hailo tool: {req!r}")
            except (HailoToolsException, ConnectionError) as e:
                # if a transaction has started we expect it to finish
                raise HailoToolsException("Unexpected proto-transfer transaction failure") from e
    finally:
        # Close on failure paths too, not only on the normal end-of-stream path.
        server.close()

    return hef_data, output_proto_graph_bytes


def run_tool_from_binary(binary_path, tool_args, client=None, server=None, builder_pb_input=None):
    """
    Run a tool binary in another process and serve its proto-transfer requests.

    binary_path - The path of the binary which would be executed
    tool_args - Additional arguments for executing the tool
    client - socket whose file descriptor is inherited by the child; this process's
        copy is closed right after the child is spawned
    server - this process's end of the channel, served by ``server_loop``
    builder_pb_input - serialized graph sent to the child when it requests it

    Returns ``(hef_data, output_proto_graph_bytes)`` as collected by ``server_loop``.
    Raises HailoToolsException if the child exits with a non-zero status or the user
    interrupts the run. If serving the child fails for any reason, the child is
    stopped before the error propagates.
    """
    cmd_args = [binary_path, *tool_args]

    # Add the libraries for the or tools shared objects
    env = create_env()
    process = subprocess.Popen(cmd_args, env=env, pass_fds=(client.fileno(),))
    # Drop our copy of the child's descriptor: server_loop only sees end-of-stream
    # once every copy of the client end has been closed.
    client.close()

    try:
        # Most of the run is spent blocked in server_loop, so interrupts there must
        # be handled just like interrupts during communicate().
        hef_data, output_proto_graph_bytes = server_loop(server, builder_pb_input=builder_pb_input)
        _, err = process.communicate()
    except BaseException as e:
        # Whatever went wrong, don't leave the channel open or the child running.
        if server is not None:
            server.close()
        _stop_process(process)
        if isinstance(e, KeyboardInterrupt):
            raise HailoToolsException("Hailo tool has been interrupted by the user")
        raise

    if process.returncode != 0:
        _handle_hailo_tools_error(err, process)

    return hef_data, output_proto_graph_bytes


def _stop_process(process, grace_period=1):
    """Stop ``process``: give it ``grace_period`` seconds to exit, then terminate, then kill."""
    try:
        # On Ctrl-C the child usually receives the signal too; let it exit on its own.
        process.wait(grace_period)
        return
    except subprocess.TimeoutExpired:
        pass
    process.terminate()
    try:
        process.wait(grace_period)
    except subprocess.TimeoutExpired:
        process.kill()
        process.wait()  # reap the killed child so it does not linger as a zombie


def _handle_hailo_tools_error(err, process):
    """Raise a HailoToolsException describing a failed hailo_tools run.

    If the tool left an ``errors.pb`` file in the current working directory, the file
    is parsed as a ``HailoToolsError`` message and removed, and the parsed message
    is attached to the exception. The file is removed even if parsing fails.
    Otherwise a generic exception carrying the exit code (and stderr, when it was
    captured) is raised.

    Args:
        err: stderr captured from the process, or ``None`` if it was not captured.
        process: the finished ``subprocess.Popen`` object.

    Raises:
        HailoToolsException: always.
    """
    if os.path.exists("errors.pb"):
        hailo_tools_error = client_server_api_pb2.HailoToolsError()
        try:
            with open("errors.pb", "rb") as f:
                hailo_tools_error.ParseFromString(f.read())
        finally:
            # Never leave a stale report behind to be attributed to a later run.
            os.remove("errors.pb")
        raise HailoToolsException(
            hailo_tools_error.error_message,
            returncode=process.returncode,
            hailo_tools_error=hailo_tools_error,
        )

    msg = f"Hailo tools failed with return code {process.returncode}"
    if err:
        details = err.decode(errors="replace") if isinstance(err, (bytes, bytearray)) else str(err)
        msg = f"{msg}: {details}"
    raise HailoToolsException(msg, returncode=process.returncode)


def run_hailo_tools(tool_args, exe_name, client=None, server=None, builder_pb_input=None):
    """Run the hailo_tools executable ``exe_name`` from the hailo_tools ``build`` directory.

    This is a thin wrapper over ``run_tool_from_binary``: the binary path is resolved
    through ``SDKPaths`` and every other argument is forwarded unchanged.
    """
    binary_path = SDKPaths().join_hailo_tools_path("build/" + exe_name)
    forwarded = dict(client=client, server=server, builder_pb_input=builder_pb_input)
    return run_tool_from_binary(binary_path, tool_args, **forwarded)


class ContextKernelConnections:
    """Pairs a context with the kernel connections that belong to it.

    Attributes:
        context: the context object. In this module it comes from the ``context``
            field of each entry passed to ``deserialize_kernel_connections``.
        kernel_connections: mapping of ``(src_layer, dst_layer)`` to the kernel
            connection between those layers.
    """

    def __init__(self, context, kernel_connections):
        self.context = context
        self.kernel_connections = kernel_connections


def _hailo_tools_exception_hook(exception_type, exception, traceback):
    """``sys.excepthook`` replacement that reports uncaught exceptions via the SDK logger.

    A one-line ``"<ExceptionType>: <message>"`` summary is logged with ``error``.
    The full formatted traceback is then passed to the logger's ``command`` method.
    """
    summary = f"{exception_type.__name__}: {exception}"
    default_logger().error(summary)
    full_report = "".join(format_exception(exception_type, exception, traceback))
    default_logger().command(full_report)


class HailoToolsRunner:
    """Drive the hailo_tools ``compiler`` binary for a single network.

    The runner works in four stages:

    1. Export the network to the integrated protobuf format, together with the
       allocator params, the parsed allocator script and optional NMS metadata.
    2. Launch the compiler in a child process.
    3. Serve the child's proto transfers over a loopback socket channel.
    4. Deserialize the compiled graph back into ``hn``, together with the
       auto-generated model script (``_auto_alls``).
    """

    INTEGRATED_HW_GRAPH_MODULE = integrated_hw_graph_pb2

    def __init__(
        self,
        hw_arch,
        model=None,
        max_cluster_util=1.0,
        fps=None,
        clk_freq=62500000.0,
        seed=0,
        number_of_clusters=None,
        number_of_preposts=None,
        timeout=None,
        auto_find_macros=False,
        dump_statistics=False,
    ):
        """
        Args:
            hw_arch: hardware architecture object. This class reads its ``consts``,
                ``name``, ``ppu_clk_freq`` and ``shmifo_clk_freq``.
            model: network to compile (anything exposing ``layers_by_index``), or
                ``None`` to set it later through ``hn``.
            max_cluster_util: forwarded to ``AllocatorParams``.
            fps: forwarded to the integrated-pb exporter.
            clk_freq: clock frequency forwarded to the exporter and the estimator.
            seed: stored as ``self.seed``.
            number_of_clusters: defaults to ``hw_arch.consts["CORE_PKG::N_CLUSTERS"]``.
            number_of_preposts: defaults to ``hw_arch.consts["CORE_PKG::N_PREPOST_CLUSTERS"]``.
            timeout: allocator timeout. ``None`` means automatic, and plain ints are
                wrapped in ``AutoInt``.
            auto_find_macros: stored as ``self._auto_find_macros``.
            dump_statistics: NOTE(review): currently ignored. Statistics dumping is
                decided solely by ``should_dump_statistics()``.
        """
        self.hw_arch = hw_arch
        self._hn = model
        if self.hn is not None:
            self.layers_by_index = model.layers_by_index
        if number_of_clusters is None:
            number_of_clusters = hw_arch.consts["CORE_PKG::N_CLUSTERS"]
        if number_of_preposts is None:
            number_of_preposts = hw_arch.consts["CORE_PKG::N_PREPOST_CLUSTERS"]
        self.kernel_connections = {}
        self.seed = seed
        self._max_cluster_util = max_cluster_util
        self._fps = fps
        self._clk_freq = clk_freq
        self._number_of_clusters = number_of_clusters
        self._number_of_preposts = number_of_preposts
        if timeout is None:
            timeout = AutoInt("automatic")
        elif isinstance(timeout, int):
            timeout = AutoInt(timeout)
        self._timeout = timeout
        self._script_parser = None
        self._auto_find_macros = auto_find_macros
        self._params_dir = None
        self._auto_alls = None
        self._dump_statistics = self.should_dump_statistics()
        self._input_pb_string = None
        self._output_integrated_pb_graph = None
        self._output_hef_data = None
        self._server = None
        self._client = None
        self._logger = default_logger()

    @property
    def script_parser(self):
        """The ``ModelScriptParser`` created by the last script parse, or ``None``."""
        return self._script_parser

    @script_parser.setter
    def script_parser(self, script_parser):
        self._script_parser = script_parser

    @property
    def hn(self):
        """The network being compiled. Assigning it also refreshes ``layers_by_index``."""
        return self._hn

    @hn.setter
    def hn(self, hn):
        self._hn = hn
        if hn is not None:
            self.layers_by_index = hn.layers_by_index

    def should_dump_statistics(self):
        """Return True only for internal SDK builds whose config sets ``sdk_server.dump_statistics``.

        The config value must be the string ``"True"``.
        """
        if not SDKPaths().is_internal:
            return False

        config = get_parsed_config_from_path()
        if "sdk_server" in config and "dump_statistics" in config["sdk_server"]:
            return config["sdk_server"]["dump_statistics"] == "True"

        return False

    def _run_hailo_tools(self, tool_args):
        """Run the ``compiler`` binary over the current channels and store the outputs it uploads."""
        self._output_hef_data, self._output_integrated_pb_graph = run_hailo_tools(
            tool_args,
            "compiler",
            client=self._client,
            server=self._server,
            builder_pb_input=self._input_pb_string,
        )

    def run_estimator(
        self,
        network_graph_filename,
        output_filename,
        profiling_mode=None,
        original_hailo_nn=None,
        should_use_logical_layers=True,
        allocator_script=None,
        allocator_script_mode=ModelScriptModes.ALLOCATION_ONLY_MODE,
        translated_params=None,
        compilation_output_proto="",
        script_parser=None,
        flavor_config=None,
        mo_flavor=None,
        params=None,
        har=None,
        stream_fps=None,
        accuracy_data=None,
    ):
        """Profile the network and return ``(estimator, auto_alls)``.

        With ``ProfilerModes.PRE_PLACEMENT``, the builder is run up to the
        post-partition exit point (see ``create_post_partition``). Otherwise, the
        current network is exported to ``network_graph_filename`` and the tool is
        run on it directly. In non-internal builds, both intermediate files are
        removed afterwards.

        Raises:
            ProfilingException: if the hailo_tools run fails on the non pre-placement path.
        """
        if profiling_mode is ProfilerModes.PRE_PLACEMENT:
            self.create_post_partition(
                network_graph_filename,
                output_filename,
                allocator_script=allocator_script,
                allocator_script_mode=allocator_script_mode,
                params=translated_params,
                compilation_output_proto=compilation_output_proto,
                har=har,
            )
        else:
            # NOTE: be sure to pass the required allocator params
            pb_wrapper = PbWrapper()

            # NOTE: https://hailotech.atlassian.net/browse/SDK-13704
            pb_wrapper.protobuf_exporter(self.hn).save_integrated_pb(
                network_graph_filename,
                self._fps,
                self._clk_freq,
                ppu_clk_freq=self.hw_arch.ppu_clk_freq,
            )
            paths = [self.hw_arch.name, network_graph_filename, output_filename]

            # NOTE(review): no server/client channel is established on this path, so
            # self._client may be None (or an already-closed socket) when the binary
            # is launched. Confirm that this path is still reachable and working.
            try:
                self._run_hailo_tools(paths)
            except HailoToolsException as e:
                raise ProfilingException(f"Hailo estimator tool failed - {e}")
        estimator = self._init_estimator(
            self._output_integrated_pb_graph,
            profiling_mode,
            original_hailo_nn,
            should_use_logical_layers,
            translated_params,
            script_parser,
            flavor_config,
            mo_flavor,
            params,
            stream_fps,
            accuracy_data=accuracy_data,
        )
        if not SDKPaths().is_internal:
            self._safe_remove(network_graph_filename)
            self._safe_remove(output_filename)

        return estimator, self._auto_alls

    def _init_estimator(
        self,
        mapped_graph_data,
        profiling_mode,
        original_hailo_nn,
        should_use_logical_layers,
        translated_params,
        script_parser,
        flavor_config,
        mo_flavor,
        params,
        stream_fps,
        accuracy_data,
    ):
        """Build an ``Estimator`` for this runner's hw_arch, network and clock frequency."""
        return Estimator(
            self.hw_arch,
            mapped_graph_data,
            self.hn,
            self._clk_freq,
            profiling_mode,
            original_hailo_nn,
            should_use_logical_layers=should_use_logical_layers,
            translated_params=translated_params,
            script_parser=script_parser,
            flavor_config=flavor_config,
            mo_flavor=mo_flavor,
            params=params,
            stream_fps=stream_fps,
            accuracy_data=accuracy_data,
        )

    def create_script_parser(
        self,
        hn,
        model_script_mode=ModelScriptModes.OPTIMIZATION_MODE,
        alls_ignore_invalid_cmds=False,
    ):
        """Return a new ``ModelScriptParser`` for ``hn`` in the given mode."""
        return ModelScriptParser(hn, mode=model_script_mode, alls_ignore_invalid_cmds=alls_ignore_invalid_cmds)

    def _extract_tensors(
        self, paths, params, expected_output_tensor, expected_pre_act_tensor, network_inputs, network_outputs
    ):
        """Append the params ``.npz`` path, or the placeholder ``'""'`` when there are no params, to ``paths``.

        The remaining tensor arguments are accepted for interface compatibility but
        are not used here. The saved path is remembered in ``_params_dir`` so it can
        be removed after the build.
        """
        if params:
            params_dir = self._save_to_npz(params, "params")
            self._params_dir = params_dir
            if params_dir is not None:
                paths.append(params_dir)
        else:
            paths.append('""')

    def establish_server_client_channels(self):
        """Open a new server/client channel and remember both ends on the runner."""
        server, client, addr, host = establish_server_client_channels()
        self._server = server
        self._client = client
        return server, client, addr, host

    def parse_allocation_script(
        self,
        allocator_script,
        allocator_script_mode,
        pb_wrapper,
        har=None,
        alls_ignore_invalid_cmds=False,
    ):
        """Parse the allocator/model script and return it converted to protobuf.

        A fresh ``ModelScriptParser`` is always created for the current ``hn``. The
        script is read from ``har`` when one is given, otherwise from
        ``allocator_script``. In the latter case the parsed script is also mirrored
        to ``/tmp/<user>/reflection.alls``. In ``FULL_MODE``, the parser's model
        replaces ``hn``.
        """
        self.script_parser = self.create_script_parser(
            self.hn,
            model_script_mode=allocator_script_mode,
            alls_ignore_invalid_cmds=alls_ignore_invalid_cmds,
        )
        if allocator_script is not None:
            if har:
                self._script_parser.parse_script_from_har(har)
            else:
                self._script_parser.parse_script(allocator_script)
                # USER is not always set (e.g. some containers/service accounts);
                # fall back to the login name instead of failing the whole build.
                import getpass

                user_dir = os.environ.get("USER") or getpass.getuser()
                reflections_file_path = f"/tmp/{user_dir}/reflection.alls"
                os.makedirs(os.path.dirname(reflections_file_path), exist_ok=True)
                self._script_parser.save(reflections_file_path, "### Mirror Validation")

            if allocator_script_mode == ModelScriptModes.FULL_MODE:
                # Go through the setter so layers_by_index follows the new model.
                self.hn = self._script_parser.model
        return self._script_parser.to_pb(pb_wrapper)

    def parse_allocator_params(self, pb_wrapper, agent, strategy, exit_point):
        """Return a ``ProtoAllocatorParams`` filled from this runner's allocator settings."""
        allocator_params_pb = pb_wrapper.integrated_hw_graph_base_pb2.ProtoAllocatorParams()
        allocator_params = self._init_allocator_params(agent, strategy, exit_point)
        allocator_params.to_pb(allocator_params_pb, pb_wrapper)
        return allocator_params_pb

    def run_builder(
        self,
        network_graph_filename,
        output_filename,
        compilation_output_proto="",
        agent=AllocatorAgent.AUTOMATIC,
        strategy=AllocatorStrategy.POSITIVE_SEARCH,
        exit_point=BuilderExitPoint.POST_COMPILATION,
        params=None,
        expected_output_tensor=None,
        expected_pre_acts=None,
        network_inputs=None,
        network_outputs=None,
        allocator_script=None,
        allocator_script_mode=ModelScriptModes.ALLOCATION_ONLY_MODE,
        compiler_statistics_path="",
        is_debug=False,
        nms_metadata=None,
        har=None,
        alls_ignore_invalid_cmds=False,
    ):
        """Export the network and run the compiler binary up to ``exit_point``.

        Results end up in ``_output_integrated_pb_graph`` and ``_output_hef_data``.
        If the binary did not upload them, they are read back from
        ``output_filename`` and ``compilation_output_proto`` when those files exist.

        Raises:
            The exception type returned by ``HailoToolsException.internal_exception``
            when the compiler fails.
        """
        # Forget the results of any previous run: the file fallback below only
        # triggers when these are None, and call_builder deserializes them even
        # after a failure.
        self._output_integrated_pb_graph = None
        self._output_hef_data = None

        pb_wrapper = PbWrapper()
        allocator_script_pb = self.parse_allocation_script(
            allocator_script,
            allocator_script_mode,
            pb_wrapper,
            har,
            alls_ignore_invalid_cmds=alls_ignore_invalid_cmds,
        )

        performance_params_commands = list(self._script_parser.export_performance_param_commands())
        if not performance_params_commands or all(
            command.optimization_level.value < get_max_compiler_optimization_level()
            for command in performance_params_commands
        ):
            self._logger.info(
                'To achieve optimal performance, set the compiler_optimization_level to "max" by adding performance_param(compiler_optimization_level=max) to the model script. Note that this may increase compilation time.'
            )
        exporter = pb_wrapper.protobuf_exporter(self.hn)
        allocator_params_pb = self.parse_allocator_params(pb_wrapper, agent, strategy, exit_point)

        # saves nms_metadata to proto only if nms should be run on CPU
        nms_metadata_pb = (
            nms_metadata.to_pb(pb_wrapper)
            if (nms_metadata is not None and nms_metadata.engine == PostprocessTarget.CPU)
            else None
        )

        server, client, addr, host = self.establish_server_client_channels()
        if is_debug:
            self._input_pb_string = exporter.save_integrated_pb(
                network_graph_filename,
                self._fps,
                self._clk_freq,
                allocator_params_pb,
                allocator_script_pb,
                nms_metadata_pb,
                ppu_clk_freq=self.hw_arch.ppu_clk_freq,
                shmifo_clk_freq=self.hw_arch.shmifo_clk_freq,
            )
            mode = "debug_mode"
        else:
            self._input_pb_string = exporter.save_integrated_pb(
                network_graph_filename,
                self._fps,
                self._clk_freq,
                allocator_params_pb,
                allocator_script_pb,
                nms_metadata_pb,
                ppu_clk_freq=self.hw_arch.ppu_clk_freq,
                shmifo_clk_freq=self.hw_arch.shmifo_clk_freq,
                dumps=True,
            ).SerializeToString()
            validate_communication_steup(client, server, self._input_pb_string)
            mode = "release_mode"

        paths = [
            self.hw_arch.name,
            self._hn.name,
            network_graph_filename,
            output_filename,
            compilation_output_proto,
            compiler_statistics_path,
            str(self._client.fileno()),
            mode,
        ]

        self._extract_tensors(paths, params, expected_output_tensor, expected_pre_acts, network_inputs, network_outputs)

        try:
            self._run_hailo_tools(paths)
        except HailoToolsException as e:
            print("\033[?25h")  # Bring back the cursor if it's still hidden
            compiler_msg = e.hailo_tools_error
            if compiler_msg:
                raise e.internal_exception("Compilation failed:", hailo_tools_error=compiler_msg) from None
            else:
                raise e.internal_exception("Compilation failed with unexpected crash") from None
        finally:
            # Nothing was uploaded over the socket: fall back to whatever the binary
            # may have written to the output files.
            if self._output_integrated_pb_graph is None and self._output_hef_data is None:
                try:
                    with open(output_filename, "rb") as f:
                        self._output_integrated_pb_graph = bytearray(f.read())
                except Exception:
                    self._logger.error("Failed to produce compiled graph")
                if compilation_output_proto:
                    try:
                        with open(compilation_output_proto, "rb") as f:
                            self._output_hef_data = bytearray(f.read())
                    except Exception:
                        pass  # we don't always need the map

    def _init_allocator_params(self, agent, strategy, exit_point):
        """Build ``AllocatorParams`` from the runner's timeout, cluster utilization and statistics flag."""
        return AllocatorParams(
            self._timeout,
            agent,
            strategy,
            self._max_cluster_util,
            exit_point,
            dump_statistics=self._dump_statistics,
        )

    @staticmethod
    def _safe_remove(file_path):
        """Remove ``file_path``, logging a warning instead of raising on ``OSError``."""
        try:
            os.remove(file_path)
        except OSError:
            default_logger().warning(f"Failed removing file {file_path}")

    def _save_to_npz(self, tensors_dict, file_name, keys_prefix="", overload_file=True):
        """Save ``tensors_dict`` as float64 arrays to ``<file_name>.npz`` under the SDK build dir.

        Args:
            tensors_dict: mapping of name to array-like; every value is cast to float64.
            file_name: base name of the ``.npz`` file.
            keys_prefix: if given, keys are stored as ``"<prefix>/<key>"``.
            overload_file: when True, or when the file does not exist yet, the file is
                (re)written with ``hailo_np_savez``. Otherwise the new tensors are
                merged into the existing file, replacing same-named entries.

        Returns:
            The path of the ``.npz`` file.
        """
        npz_file = SDKPaths().join_build_sdk(file_name + ".npz")

        if keys_prefix:
            tensors_dict = {f"{keys_prefix}/{key}": value for key, value in tensors_dict.items()}

        tensors_dict = {key: np.array(value).astype("float64") for key, value in tensors_dict.items()}

        if not os.path.isfile(npz_file) or overload_file:
            py.path.local(npz_file).ensure()
            hailo_np_savez(npz_file, **tensors_dict)
        else:
            # File exists: merge into it. Read every array while the archive is
            # still open, then close the handle before rewriting the same path.
            with np.load(npz_file) as existing:
                data = {**existing, **tensors_dict}
            np.savez(npz_file, **data)

        return npz_file

    def deserialize_kernel_connections(self, proto_contexts_kernel_connections):
        """Map context name to ``ContextKernelConnections`` keyed by ``(src_layer, dst_layer)``."""
        contexts_kernel_connections = {}
        for context_kernel_connections in proto_contexts_kernel_connections:
            kernel_connections = {}
            for kernel_connection in context_kernel_connections.kernel_connections.connections:
                src_layer = self.layers_by_index[kernel_connection.src_layer_index]
                dst_layer = self.layers_by_index[kernel_connection.dst_layer_index]
                kernel_connections[(src_layer, dst_layer)] = kernel_connection
            contexts_kernel_connections[context_kernel_connections.context.context_name] = ContextKernelConnections(
                context_kernel_connections.context,
                kernel_connections,
            )
        return contexts_kernel_connections

    def deserialize(self, output_filename=None, blind=False, pb=None):
        """Replace ``hn`` with the network rebuilt from a compiled integrated graph.

        The graph comes from ``pb`` when given, otherwise it is parsed from
        ``output_filename``. The previous network is kept in ``_original_hn`` and
        its name is carried over to the new one. The embedded allocator script is
        then loaded by ``_deserialize``.

        Raises:
            HailoToolsException: if neither ``pb`` nor ``output_filename`` is given.
        """
        if output_filename is not None and pb is None:
            # No pb was passed, parse pb out of output_filename
            integrated_hw_graph = self.INTEGRATED_HW_GRAPH_MODULE.ProtoIntegratedHWGraph()
            with open(output_filename, "rb") as f:
                integrated_hw_graph.ParseFromString(f.read())
                updated_hn = HailoNN.from_integrated_pb_file(output_filename, PbWrapper())
        elif output_filename is None and pb is None:
            raise HailoToolsException("deserialize must get pb or a output_filename")
        else:
            # Use the given pb
            integrated_hw_graph = pb
            updated_hn = HailoNN.from_integrated_pb_data(pb.SerializePartialToString(), PbWrapper())
        updated_hn.name = self.hn.name
        self._original_hn = self._hn
        self.hn = updated_hn

        self._deserialize(integrated_hw_graph, blind)

    def _stable_toposort_dict(self):
        """Map each layer name (and its ``defuse_name``, first occurrence wins) to its index in a name-stable toposort."""
        stable_toposort_dict = OrderedDict()
        for idx, layer in enumerate(self.hn.stable_toposort(key="name")):
            stable_toposort_dict[layer.name] = idx
            if layer.defuse_name and layer.defuse_name not in stable_toposort_dict:
                stable_toposort_dict[layer.defuse_name] = idx
        return stable_toposort_dict

    def _deserialize(self, integrated_hw_graph, blind=False):
        """Load the allocator script embedded in the compiled graph and export ``_auto_alls``.

        Nothing happens when the graph carries no script commands. ``blind`` is
        accepted for interface compatibility and is not used here.
        """
        # Deserialize allocation script
        script_pb = integrated_hw_graph.allocator_script
        if len(script_pb.commands) > 0:
            if not self._script_parser:
                self.script_parser = self.create_script_parser(self.hn)
            self._script_parser.filter_commands_by_types(
                [QuantizationParamCommand, FromTFCommand, PrintBuffersCommand, ModelModificationsCommand],
            )
            self._script_parser.load_pb(script_pb)
            self._script_parser.remove_internal_commands()
            self._auto_alls = self._script_parser.export_auto_alls()

    def _log_post_call_error(self, msg):
        """Hook for reporting a failed best-effort deserialization after a failed build (no-op here)."""
        pass

    def call_builder(self, network_graph_path, output_path, blind_deserialize=False, **kwargs):
        """Run ``run_builder`` with ``kwargs`` and deserialize the resulting graph.

        While the build runs, ``sys.excepthook`` is replaced by
        ``_hailo_tools_exception_hook``. It is only restored after a successful
        build, so an exception escaping from here is reported through the SDK
        logger. On ``BackendInternalException``, deserializing whatever output
        exists is attempted (failures go to ``_log_post_call_error``) before the
        exception is re-raised.
        """

        def _post_call():
            if self._params_dir:
                self._safe_remove(self._params_dir)
            if self._output_integrated_pb_graph is None:
                raise InvalidParserInputException("No output graph, deserialization failed.")
            integrated_hw_graph = self.INTEGRATED_HW_GRAPH_MODULE.ProtoIntegratedHWGraph()
            integrated_hw_graph.ParseFromString(self._output_integrated_pb_graph)
            self.deserialize(blind=blind_deserialize, pb=integrated_hw_graph)

        sys.excepthook = _hailo_tools_exception_hook
        try:
            self.run_builder(network_graph_path, output_path, **kwargs)
        except BackendInternalException:
            try:
                _post_call()
            except Exception as deserialize_e:
                self._log_post_call_error(deserialize_e)
            raise
        sys.excepthook = sys.__excepthook__
        _post_call()

    def _assert_clusters_placement_fits(self):
        """Check that an explicit cluster placement does not use more clusters than available."""
        clusters_placement = self.hn.net_params.clusters_placement
        # [[]] presumably means "no explicit placement requested"; confirm with net_params.
        if clusters_placement != [[]]:
            assert (
                len(clusters_placement) <= self._number_of_clusters
            ), "Number of clusters in layer placements is larger than allowed number of clusters"

    def create_mapping_and_jlfs(
        self,
        network_graph_path,
        output_path,
        compilation_output_proto="",
        agent=AllocatorAgent.AUTOMATIC,
        strategy=AllocatorStrategy.POSITIVE_SEARCH,
        auto_mapping=True,
        params=None,
        allocator_script=None,
        allocator_script_mode=ModelScriptModes.ALLOCATION_ONLY_MODE,
        compiler_statistics_path="",
    ):
        """Build up to ``POST_COMPILATION`` and return ``(auto_alls, hef_data, integrated_pb_graph)``.

        ``auto_mapping`` is accepted for interface compatibility and is not used.
        """
        self._assert_clusters_placement_fits()

        self.call_builder(
            network_graph_path,
            output_path,
            compilation_output_proto=compilation_output_proto,
            agent=agent,
            strategy=strategy,
            exit_point=BuilderExitPoint.POST_COMPILATION,
            params=params,
            allocator_script=allocator_script,
            allocator_script_mode=allocator_script_mode,
            compiler_statistics_path=compiler_statistics_path,
        )

        return self._auto_alls, self._output_hef_data, self._output_integrated_pb_graph

    def create_mapping_and_full_build_hef(
        self,
        network_graph_path,
        output_path,
        compilation_output_proto="",
        agent=AllocatorAgent.AUTOMATIC,
        strategy=AllocatorStrategy.POSITIVE_SEARCH,
        auto_mapping=True,
        params=None,
        expected_output_tensor=None,
        expected_pre_acts=None,
        network_inputs=None,
        network_outputs=None,
        allocator_script=None,
        allocator_script_mode=ModelScriptModes.ALLOCATION_ONLY_MODE,
        compiler_statistics_path="",
        nms_metadata=None,
        har=None,
        alls_ignore_invalid_cmds=False,
    ):
        """Build up to ``POST_CAT`` and return ``(auto_alls, hef_data, integrated_pb_graph)``.

        ``auto_mapping`` is accepted for interface compatibility and is not used.
        """
        self._assert_clusters_placement_fits()

        self.call_builder(
            network_graph_path,
            output_path,
            compilation_output_proto=compilation_output_proto,
            agent=agent,
            strategy=strategy,
            exit_point=BuilderExitPoint.POST_CAT,
            params=params,
            expected_output_tensor=expected_output_tensor,
            expected_pre_acts=expected_pre_acts,
            network_inputs=network_inputs,
            network_outputs=network_outputs,
            allocator_script=allocator_script,
            allocator_script_mode=allocator_script_mode,
            compiler_statistics_path=compiler_statistics_path,
            nms_metadata=nms_metadata,
            har=har,
            alls_ignore_invalid_cmds=alls_ignore_invalid_cmds,
        )

        return self._auto_alls, self._output_hef_data, self._output_integrated_pb_graph

    def create_post_partition(
        self,
        network_graph_path,
        output_path,
        allocator_script=None,
        allocator_script_mode=ModelScriptModes.ALLOCATION_ONLY_MODE,
        params=None,
        compilation_output_proto="",
        har=None,
    ):
        """Build up to ``POST_PARTITION`` (blind deserialize) and return ``(auto_alls, hef_data, graph)``."""
        self.call_builder(
            network_graph_path,
            output_path,
            blind_deserialize=True,
            exit_point=BuilderExitPoint.POST_PARTITION,
            allocator_script=allocator_script,
            allocator_script_mode=allocator_script_mode,
            params=params,
            compilation_output_proto=compilation_output_proto,
            har=har,
        )
        return self._auto_alls, self._output_hef_data, self._output_integrated_pb_graph

    def create_expansion(
        self,
        network_graph_path,
        output_path,
        allocator_script=None,
        allocator_script_mode=ModelScriptModes.ALLOCATION_ONLY_MODE,
        params=None,
        compilation_output_proto="",
    ):
        """Build up to ``POST_EXPANSION`` (blind deserialize) and return ``(auto_alls, hef_data, graph)``."""
        self.call_builder(
            network_graph_path,
            output_path,
            blind_deserialize=True,
            exit_point=BuilderExitPoint.POST_EXPANSION,
            allocator_script=allocator_script,
            allocator_script_mode=allocator_script_mode,
            params=params,
            compilation_output_proto=compilation_output_proto,
        )
        return self._auto_alls, self._output_hef_data, self._output_integrated_pb_graph
