import argparse
import json
import os
import sys
from contextlib import suppress
from json import JSONDecodeError
from tarfile import ReadError

from hailo_sdk_client.exposed_definitions import SUPPORTED_HW_ARCHS, States
from hailo_sdk_client.runner.client_runner import ClientRunner
from hailo_sdk_client.sdk_backend.sdk_backend_exceptions import hailo_tools_exception_handler
from hailo_sdk_client.tools.cmd_utils.base_utils import AllsAppendAction, CmdUtilsBaseUtil, CmdUtilsBaseUtilError
from hailo_sdk_client.tools.cmd_utils.cmd_definitions import ClientCommandGroups
from hailo_sdk_common.logger.logger import default_logger
from hailo_sdk_common.targets.inference_targets import ParamsKinds

# Logger used for the progress messages emitted by this module.
logger = default_logger()


class CompilerCLIException(CmdUtilsBaseUtilError):
    """Error raised by the compiler CLI when the given model input cannot be used."""


class CompilerCLI(CmdUtilsBaseUtil):
    """Command-line utility that compiles a model and writes the resulting HEF and HAR files."""

    GROUP = ClientCommandGroups.MODEL_CONVERSION_FLOW
    HELP = "Compile Hailo model to HEF binary files"

    def __init__(self, parser):
        """Register the compiler's arguments on ``parser`` and bind :meth:`run` as its handler.

        Args:
            parser: The ``argparse`` parser (or sub-parser) to populate.
        """
        super().__init__(parser)
        parser.register("action", "alls_append", AllsAppendAction)
        parser.add_argument("har_path", type=str, help="Path to the HAR of the model to compile")
        parser.add_argument("--hw-arch", type=str, help="Hardware architecture to be used", choices=SUPPORTED_HW_ARCHS)
        # Hidden option: only needed when the positional path points to an HN file.
        parser.add_argument("--quantized-weights-path", type=str, help=argparse.SUPPRESS, default=None)
        parser.add_argument(
            "--model-script",
            type=str,
            default=None,
            action="alls_append",
            help="Path to a model script to use",
        )
        parser.add_argument("--output-dir", type=str, help="Save HEF to specified dir, defaults to current directory")
        auto_model_script_group = parser.add_mutually_exclusive_group(required=False)
        auto_model_script_group.add_argument(
            "--auto-model-script",
            type=str,
            help="Path to save the auto-generated model script",
        )
        parser.add_argument("--output-har-path", type=str, default=None, help="Write the compiled model to this path")
        parser.set_defaults(func=self.run)

    def run(self, args):
        """Compile the model described by the parsed command-line arguments.

        Args:
            args: Namespace produced by the parser configured in ``__init__``.
        """
        self._compile(
            args.har_path,
            args.hw_arch,
            args.model_script,
            args.output_dir,
            args.auto_model_script,
            args.output_har_path,
            args.quantized_weights_path,
        )

    def _compile(
        self,
        model_path,
        hw_arch,
        alls=None,
        output_dir=None,
        auto_alls_path=None,
        output_har_path=None,
        quantized_weights_path=None,
    ):
        """Compile a model and save the HEF, the optional auto model script and the output HAR.

        Side effects: replaces ``sys.excepthook`` and, when ``output_dir`` is given,
        changes the process working directory.

        Args:
            model_path: Path to the model; must end with ``.har`` or ``.hn``.
            hw_arch: Hardware architecture passed to the runner.
            alls: Model script, either a single string or a list of strings that
                are joined with newlines before being loaded.
            output_dir: Directory to switch into before compiling and writing outputs.
            auto_alls_path: Where to save the auto-generated allocation script, if given.
            output_har_path: Where to save the compiled HAR; defaults to
                ``<model_name>_compiled[_slim].har``.
            quantized_weights_path: Params file to load when ``model_path`` is an HN file.

        Raises:
            CompilerCLIException: If the runner cannot be created from ``model_path``.
        """
        runner = self._initialize_runner(model_path, quantized_weights_path, hw_arch)

        sys.excepthook = hailo_tools_exception_handler  # hack to hide the python API traceback here

        if alls is not None:
            if isinstance(alls, list):
                # Join with a plain newline. A "{}\n" separator is never formatted and
                # would leave a literal "{}" between consecutive scripts.
                alls = "\n".join(alls)
            runner.load_model_script(alls)

        logger.info("Compiling network")
        if output_dir:
            # NOTE: this changes the process working directory, so every relative path
            # used below (HEF, auto alls, output HAR) resolves against output_dir.
            os.chdir(output_dir)

        hw_representation = runner.compile()
        logger.info("Compilation complete")

        hef_path = f"{runner.model_name}.hef"
        with open(hef_path, "wb") as f:
            f.write(hw_representation)
        logger.info(f"Saved HEF to: {os.path.abspath(hef_path)}")

        if auto_alls_path is not None:
            runner.save_autogen_allocation_script(auto_alls_path)
            logger.info(f"Saved auto alls to: {os.path.abspath(auto_alls_path)}")

        slim_mode = runner.state == States.COMPILED_SLIM_MODEL
        if not output_har_path:
            slim_str = "_slim" if slim_mode else ""
            output_har_path = f"{runner.model_name}_compiled{slim_str}.har"
        runner.save_har(output_har_path, compilation_only=slim_mode)

    @staticmethod
    def _initialize_runner(model_path, quantized_weights_path, hw_arch):
        """Create a ``ClientRunner`` from a HAR file or from an HN file plus params.

        Args:
            model_path: Path ending with ``.har`` or ``.hn``.
            quantized_weights_path: Params file loaded into the runner for HN input.
            hw_arch: Hardware architecture passed to the runner.

        Returns:
            The initialized ``ClientRunner``.

        Raises:
            CompilerCLIException: If the path has an unsupported extension, the HAR
                raises ``ReadError``, the HN is not valid JSON, or an HN is given
                without ``quantized_weights_path``.
        """
        runner = None
        if model_path.endswith(".har"):
            with suppress(ReadError):
                runner = ClientRunner(har=model_path, hw_arch=hw_arch)

        elif model_path.endswith(".hn"):
            try:
                with open(model_path, encoding="utf-8") as model_file:
                    hn_json = json.load(model_file)
            except JSONDecodeError:
                # Malformed HN: leave runner as None so the generic error below is raised.
                pass
            else:
                # Kept outside the JSON guard so that a decode error raised while building
                # the runner or loading params is not silently swallowed.
                runner = ClientRunner(hn=hn_json, hw_arch=hw_arch)
                if quantized_weights_path is None:
                    raise CompilerCLIException("--quantized-weights-path is required when compiling from HN")
                runner.load_params(quantized_weights_path, params_kind=ParamsKinds.TRANSLATED)

        if runner is None:
            raise CompilerCLIException("The given model must be a valid HAR file")

        return runner


def main():
    """Run the compiler CLI as a standalone program using the process command-line arguments."""
    arg_parser = argparse.ArgumentParser()
    # The CLI object must be built first: its constructor registers the arguments.
    cli = CompilerCLI(arg_parser)
    cli.run(arg_parser.parse_args())


# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
