import onnx

from hailo_sdk_client import ClientRunner
from hailo_sdk_client.exposed_definitions import States
from hailo_sdk_client.tools.cmd_utils.base_utils import CmdUtilsBaseUtil, CmdUtilsBaseUtilError
from hailo_sdk_client.tools.cmd_utils.cmd_definitions import ClientCommandGroups
from hailo_sdk_common.logger.logger import default_logger


class HarONNXException(CmdUtilsBaseUtilError):
    """Error raised by the HAR-to-ONNX CLI command (e.g. when the input HAR is not compiled)."""


class HarONNXCLI(CmdUtilsBaseUtil):
    """CLI command that exports a compiled HAR as a Hailo ONNX Runtime model.

    The exported model is whatever ``ClientRunner.get_hailo_runtime_model()``
    returns; it is written to disk with ``onnx.save_model``.
    """

    GROUP = ClientCommandGroups.MODEL_CONVERSION_FLOW
    HELP = "Generates ONNX-Runtime model including pre/post processing"

    def __init__(self, parser):
        """Register this command's arguments and bind ``run`` as its handler.

        Args:
            parser: argparse-style (sub)parser for this command. It receives a
                positional ``har_path`` and an optional ``--out-onnx-path``;
                ``run`` is installed via ``set_defaults(func=...)``.
        """
        super().__init__(parser)
        self._logger = default_logger()
        parser.add_argument("har_path", type=str, help="Path to compiled HAR")
        parser.add_argument("--out-onnx-path", type=str, default=None, help="Path to save Hailo ONNX runtime model")
        parser.set_defaults(func=self.run)

    @staticmethod
    def _default_onnx_path(har_path):
        """Derive the default ONNX output path from ``har_path``.

        A trailing ``.har`` extension (matched case-insensitively) is stripped
        and ``_hailo.onnx`` is appended, e.g. ``dir/model.har`` ->
        ``dir/model_hailo.onnx``. Only the suffix is touched, so ``.har``
        appearing elsewhere in the path (such as in a directory name) is left
        intact. Because a suffix is always appended, the result can never equal
        ``har_path``, so the input HAR is never overwritten by default.
        """
        har_suffix = ".har"
        if har_path.lower().endswith(har_suffix):
            har_path = har_path[: -len(har_suffix)]
        return har_path + "_hailo.onnx"

    def run(self, args):
        """Load the HAR at ``args.har_path`` and save its ONNX Runtime model.

        Args:
            args: Parsed arguments providing ``har_path`` and ``out_onnx_path``.
                When ``out_onnx_path`` is not given (or empty), the output path
                is derived from ``har_path`` by ``_default_onnx_path``.

        Raises:
            HarONNXException: If the loaded runner is not in a compiled state
                (``States.COMPILED_MODEL`` or ``States.COMPILED_SLIM_MODEL``).
        """
        runner = ClientRunner(har=args.har_path)
        if runner.state not in [States.COMPILED_MODEL, States.COMPILED_SLIM_MODEL]:
            err_msg = 'The given model must be a compiled HAR file. Please use "hailo compiler --output-har-path"'
            raise HarONNXException(err_msg)
        runtime_model = runner.get_hailo_runtime_model()
        out_path = args.out_onnx_path if args.out_onnx_path else self._default_onnx_path(args.har_path)
        onnx.save_model(runtime_model, out_path)

        self._logger.info(f"Generated Hailo ONNX runtime model saved in: {out_path}")
