import argparse
import os
from contextlib import suppress
from tarfile import ReadError

from hailo_sdk_client.exposed_definitions import JoinAction
from hailo_sdk_client.runner.client_runner import ClientRunner
from hailo_sdk_client.tools.cmd_utils.base_utils import CmdUtilsBaseUtil, CmdUtilsBaseUtilError
from hailo_sdk_client.tools.cmd_utils.cmd_definitions import ClientCommandGroups


class JoinException(CmdUtilsBaseUtilError):
    """Error raised by the join CLI when an input model is not a valid HAR file."""


class JoinCLI(CmdUtilsBaseUtil):
    """``join`` command: combine two HAR models into a single HAR.

    Both HARs are loaded with :class:`ClientRunner`, the second model is joined
    into the first (with optional per-model scope names and a
    :class:`JoinAction`), and the first runner is saved as the output HAR.
    """

    GROUP = ClientCommandGroups.MODEL_CONVERSION_FLOW
    HELP = "Join two Hailo models to a single model"

    def __init__(self, parser):
        """Register the join command's arguments on ``parser``.

        Args:
            parser: The ``argparse`` parser (or sub-parser) to populate. Its
                formatter is switched to ``RawTextHelpFormatter`` so the
                multi-line ``--join-action`` help text keeps its layout.
        """
        super().__init__(parser)
        parser.formatter_class = argparse.RawTextHelpFormatter

        parser.add_argument("har_path1", type=str, help="Path to HAR for the first model to join")
        parser.add_argument("har_path2", type=str, help="Path to HAR for the second model to join")
        parser.add_argument(
            "--join-action",
            type=JoinAction,
            default=JoinAction.NONE,
            choices=[JoinAction.NONE, JoinAction.AUTO_JOIN_INPUTS, JoinAction.AUTO_CHAIN_NETWORKS],
            help=f"""Type of action to run in addition to joining the models:
* {JoinAction.NONE}: Join the graphs without any connection
  between them.
* {JoinAction.AUTO_JOIN_INPUTS}: Automatically detects inputs
  for both graphs and combines them into one. This only works when
  both networks have a single input of the same shape.
* {JoinAction.AUTO_CHAIN_NETWORKS}: Automatically detects
  the output of this model and the input of the other model, and
  connects them. This only works when the model has a single
  output, and the other model has a single input, of the
  same shape.""",
        )
        parser.add_argument("--scope-name1", type=str, help="Scope name to use for all layers of the first model")
        parser.add_argument("--scope-name2", type=str, help="Scope name to use for all layers of the second model")
        parser.add_argument("--output-path", type=str, help="Path to HAR of the joined model")
        # Lets a sub-command dispatcher call ``args.func(args)``.
        parser.set_defaults(func=self.run)

    def run(self, args):
        """Join the two HARs named in ``args`` and save the joined model.

        Args:
            args: Parsed namespace providing ``har_path1``, ``har_path2``,
                ``scope_name1``, ``scope_name2``, ``join_action`` and
                ``output_path``.

        Raises:
            JoinException: If either input cannot be read as a HAR file.
        """
        runner1 = self._initialize_runner(args.har_path1)
        runner2 = self._initialize_runner(args.har_path2)
        # runner1 is the runner that gets saved below, so after join() it is
        # expected to hold the combined model.
        runner1.join(runner2, args.scope_name1, args.scope_name2, args.join_action)

        # Without --output-path, name the file after the joined model.
        output_path = args.output_path or f"{runner1.model_name}.har"
        # Save before reporting success so a failing save_har() does not leave
        # a misleading "Saved HAR" line in the log.
        runner1.save_har(output_path)
        self._logger.info(f"Saved HAR to: {os.path.abspath(output_path)}")

    @staticmethod
    def _initialize_runner(model_path):
        """Load a :class:`ClientRunner` from the HAR at ``model_path``.

        Args:
            model_path: Filesystem path to a HAR file.

        Returns:
            The loaded ``ClientRunner``.

        Raises:
            JoinException: If loading raises ``tarfile.ReadError`` (the file
                could not be read as an archive). The original error is kept
                as ``__cause__``. Other exceptions propagate unchanged.
        """
        try:
            return ClientRunner(har=model_path)
        except ReadError as err:
            # Chain the ReadError so the underlying reason is not lost.
            raise JoinException(f"The given model {model_path} must be a valid HAR file") from err


def main():
    """Command-line entry point: register the join arguments, parse argv, run the join."""
    arg_parser = argparse.ArgumentParser()
    cli = JoinCLI(arg_parser)
    cli.run(arg_parser.parse_args())


if __name__ == "__main__":
    main()
