import argparse
import json
import os
import sys
from pathlib import Path

from hailo_model_optimization.acceleras.utils.acceleras_exceptions import SubprocessTracebackFailure
from hailo_sdk_client.exposed_definitions import SUPPORTED_HW_ARCHS
from hailo_sdk_client.model_translator.exceptions import (
    ParsingWithRecommendationException,
    UnsupportedInputFormatError,
    UnsupportedInputShapesError,
)
from hailo_sdk_client.runner.client_runner import ClientRunner
from hailo_sdk_client.tools.cmd_utils.base_utils import CmdUtilsBaseUtil, CmdUtilsBaseUtilError
from hailo_sdk_client.tools.cmd_utils.cmd_definitions import ClientCommandGroups
from hailo_sdk_client.tools.cmd_utils.utils import parse_dict_or_list_arg
from hailo_sdk_client.tools.frameworks_inference.dfc_frameworks_inference_tool import HailoParserInferenceTool
from hailo_sdk_client.tools.tf_proto_helper import TF_OPTIONAL_EXTENSIONS
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hailo_nn import HailoNN
from hailo_sdk_common.logger.logger import default_logger

logger = default_logger()


class SavedModelError(Exception):
    """Raised when a network name cannot be derived from the model path."""

    pass


class ParserCLIException(CmdUtilsBaseUtilError):
    """CLI-facing error raised by the parser command when translating a model fails."""

    pass


class ParsingWithRecommendationCLIException(ParsingWithRecommendationException):
    """Recommendation error whose text includes a ready-to-run command line.

    On top of the base exception's data it keeps the command that was executed, so
    ``str()`` can show that command with the recommended start/end node names applied.
    """

    def __init__(
        self,
        message,
        client_message=None,
        recommended_start_node_names=None,
        recommended_end_node_names=None,
        command=None,
    ):
        """Store the executed command next to the base exception's data.

        Args:
            message: Forwarded to the base exception.
            client_message: Forwarded to the base exception.
            recommended_start_node_names: Recommended start nodes, forwarded to the base exception.
            recommended_end_node_names: Recommended end nodes, forwarded to the base exception.
            command: The executed command as a list of tokens; ``None`` is treated as an
                empty command when the message is rendered.
        """
        super().__init__(message, client_message, recommended_start_node_names, recommended_end_node_names)
        self._command = command

    @staticmethod
    def _with_recommended_node_names(command, flag, recommended_node_names):
        """Return a copy of ``command`` in which ``flag`` carries ``recommended_node_names``.

        An existing occurrence of ``flag`` is dropped together with its values (every
        token up to the next token starting with ``--``), then the flag is appended with
        each recommended name wrapped in double quotes.
        """
        updated = list(command)
        if flag in updated:
            flag_index = updated.index(flag)
            values_end = next(
                (i for i in range(flag_index + 1, len(updated)) if updated[i].startswith("--")),
                len(updated),
            )
            del updated[flag_index:values_end]

        updated.append(flag)
        updated.extend(f'"{name}"' for name in recommended_node_names)
        return updated

    def __str__(self):
        """Return the client message, followed by the suggested command when there are recommendations."""
        errors_str = self.client_message
        if self.recommended_start_node_names or self.recommended_end_node_names:
            errors_str += "\nPlease try to parse the model again, using:\n"
            # Work on a copy: rendering the message must not modify the stored command
            # (which is the caller's list), and a missing command must not break str().
            suggested_command = list(self._command or [])
            if self.recommended_start_node_names:
                suggested_command = self._with_recommended_node_names(
                    suggested_command, "--start-node-names", self.recommended_start_node_names
                )
            if self.recommended_end_node_names:
                suggested_command = self._with_recommended_node_names(
                    suggested_command, "--end-node-names", self.recommended_end_node_names
                )
            errors_str += " ".join(suggested_command)
        return errors_str


class NetParser(CmdUtilsBaseUtil):
    """Command-line utility that translates a TensorFlow Lite or ONNX model into a Hailo network."""

    GROUP = ClientCommandGroups.MODEL_CONVERSION_FLOW
    HELP = "Translate network to Hailo network"

    def __init__(self, parser):
        """Register the ``tf`` and ``onnx`` sub-commands on ``parser`` and reset the parsing state.

        Args:
            parser: Argument parser of this command. It receives one sub-parser per source
                framework and ``run`` as its default ``func``.
        """
        super().__init__(parser)

        # Help texts that differ between the two source frameworks.
        tf_help = {
            "model_path": "Tensorflow Lite model file path, *.tflite",
            "start_node_names": "List of names for the first Tensorflow nodes to parse",
            "end_node_names": "List of TF nodes, where parsing stops after all of them are parsed",
        }
        onnx_help = {
            "model_path": "Path to the .onnx model file",
            "start_node_names": "List of names of the first ONNX nodes to parse.",
            "end_node_names": "List of ONNX nodes, where parsing stops after all of them are parsed. "
            "The order determines the order of the outputs.",
            "tensor_shapes": "List of start nodes and their shapes(<node_name>=<shape>), or a single "
            "shape for single input networks.",
            "input_format": "List of start nodes and their format(<node_name>=<format>), or a single "
            "format for single input networks. \nFormats should be written by their initials, for example NCHW "
            "(batch, channels, height, width), NGC (batch, groups, channels).",
        }

        # The source framework is a mandatory sub-command.
        framework_parsers = parser.add_subparsers(title="Source format", dest="input_framework")
        framework_parsers.required = True

        tf_parser = framework_parsers.add_parser("tf", help="Tensorflow parser")
        self._add_common_args(tf_parser, tf_help)

        onnx_parser = framework_parsers.add_parser("onnx", help="ONNX parser")
        self._add_common_args(onnx_parser, onnx_help)
        # Options that exist only for ONNX models.
        onnx_parser.add_argument("--augmented-path", type=str, help="Path to save the augmented.onnx file")
        onnx_parser.add_argument(
            "--disable-rt-metadata-extraction",
            action="store_true",
            default=False,
            help="Disable ONNX metadata extraction. When set, the har-onnx-rt cannot be used",
        )
        for option, help_key in (("--input-format", "input_format"), ("--tensor-shapes", "tensor_shapes")):
            onnx_parser.add_argument(option, type=str, nargs="+", default=None, help=onnx_help[help_key])

        parser.set_defaults(func=self.run)

        self._tf_arg_parser = tf_parser
        self._onnx_arg_parser = onnx_parser
        # State filled in while a model is being parsed.
        self._nodes_recommendation = None
        self._end_node_names = None
        self._start_node_names = None
        self._net_input_format = None
        self._runner = None

    @property
    def runner(self):
        """Runner used for translation; ``None`` until one is assigned or ``_parse`` creates it."""
        return self._runner

    @runner.setter
    def runner(self, runner):
        """Replace the runner used for translation."""
        self._runner = runner

    @property
    def start_node_names(self):
        """Start node names currently used for parsing (``None`` until ``run`` sets them)."""
        return self._start_node_names

    @property
    def end_node_names(self):
        """End node names currently used for parsing (``None`` until ``run`` sets them)."""
        return self._end_node_names

    @staticmethod
    def _add_common_args(parser, help_messages):
        """Add the arguments shared by every source-framework sub-parser.

        Args:
            parser: Sub-parser that receives the arguments.
            help_messages: Mapping with the framework-specific help texts under the keys
                ``model_path``, ``start_node_names`` and ``end_node_names``.
        """
        add = parser.add_argument

        def add_node_names_option(flag, help_key, example):
            # Both node-name options take one or more names and share the same shape.
            add(flag, type=str, nargs="+", default=None, help=help_messages[help_key] + example)

        add("model_path", type=str, help=help_messages["model_path"])
        add("--net-name", type=str, default=None, help="Name of the new Hailo network to generate")
        add("--har-path", type=str, default=None, help="Path to the new HAR (defaults to net-name)")
        add_node_names_option(
            "--start-node-names",
            "start_node_names",
            "\nExample: --start-node-names <start_node_name1> <start_node_name2> ...",
        )
        add_node_names_option(
            "--end-node-names",
            "end_node_names",
            "\nExample: --end-node-names <end_node_name1> <end_node_name2> ...",
        )
        add("--hw-arch", type=str, choices=SUPPORTED_HW_ARCHS, help="Hardware architecture to be used")
        # Hidden option: it is kept out of the generated help text.
        add("--compare", action="store_true", default=False, help=argparse.SUPPRESS)
        add(
            "-y",
            action="store_true",
            default=False,
            help='Automatic yes to prompts. Assume "yes" as an answer to all parser recommendations '
            "and run non-interactively.",
        )
        add(
            "--parsing-report-path",
            type=str,
            help="File path to save the parsing report (saved only when provided)",
        )

    @staticmethod
    def _get_network_name(args):
        """Return a valid network name, taken from ``--net-name`` or derived from the model path.

        Args:
            args: Parsed arguments; ``net_name`` and ``model_path`` are read.

        Returns:
            The name after it went through ``HailoNN.get_valid_input_identifier``.

        Raises:
            SavedModelError: If no name was given and the model path has no recognized extension.
        """
        if args.net_name:
            net_name = args.net_name
        else:
            model_path = args.model_path
            if ".tflite" in model_path:
                extension = ".tflite"
            elif model_path.endswith(".onnx"):
                extension = ".onnx"
            else:
                raise SavedModelError("Cant recognize net name from model path")

            # Cut at the LAST occurrence of the extension, so that a directory whose name
            # contains it (e.g. "my.onnx.models/net.onnx") does not truncate the name.
            net_name = os.path.basename(model_path.rpartition(extension)[0])

        return HailoNN.get_valid_input_identifier(net_name, "net_name")

    def run(self, args, save_model=True):
        """Translate the model described by ``args`` and turn parser failures into CLI errors.

        Args:
            args: Parsed command-line namespace of the ``tf`` or ``onnx`` sub-command.
            save_model: Forwarded to ``_finalize_parsing_cli``; when true the HAR is saved.

        Raises:
            ParserCLIException: On argument errors, unsupported models, unsupported input
                formats or dynamic input shapes. Where a recommendation is available the
                message contains an example command.
            ParsingWithRecommendationCLIException: Raised by ``_handle_recommendation_exception``
                when parsing failed with node recommendations that were not applied or did
                not resolve the failure.
        """
        net_name = self._get_network_name(args)
        argv = sys.argv
        # The executed command, reused to build "try again with ..." suggestions.
        command = [os.path.basename(argv[0]), *argv[1:]]
        tensor_shapes = None
        try:
            # Only the onnx sub-parser defines --tensor-shapes.
            if hasattr(args, "tensor_shapes"):
                tensor_shapes = parse_dict_or_list_arg(args.tensor_shapes)
        except CmdUtilsBaseUtilError as err:
            raise ParserCLIException(err.args[0]) from err

        try:
            self._start_node_names = args.start_node_names
            self._end_node_names = args.end_node_names
            if args.input_framework == "onnx":
                self._net_input_format = parse_dict_or_list_arg(args.input_format, arg_name="input_format")

            try:
                self._parse(net_name, args, tensor_shapes)
                self._finalize_parsing_cli(net_name, args, self.runner, tensor_shapes, save_model)

            except SubprocessTracebackFailure as subprocess_err:
                # Re-raise the wrapped error so the handlers below can translate it.
                raise subprocess_err.inner_error from subprocess_err

        except CmdUtilsBaseUtilError as err:
            raise ParserCLIException(err.args[0]) from err

        except UnsupportedInputFormatError as err:
            msg = str(err)
            if err.recommendation:
                input_format_flag = "--input-format"
                recommended_format = "".join([dim[0].upper() for dim in err.recommendation])
                example_command = command.copy()
                # The flag may be missing from argv (not passed, or passed as
                # "--input-format=<value>"); add it instead of failing with ValueError.
                if input_format_flag not in example_command:
                    example_command.append(input_format_flag)
                value_index = example_command.index(input_format_flag) + 1
                # Slice assignment replaces the first value, or appends when there is none.
                example_command[value_index : value_index + 1] = [recommended_format]
                msg += f" Please try parsing the model again, for example: {' '.join(example_command)}"
            raise ParserCLIException(msg) from err

        except UnsupportedInputShapesError as err:
            example_command = " ".join(command) + " --tensor-shapes " + err.recommendation.replace(" ", "")
            msg = f"Please try to parse the model again, using: --tensor-shapes,\n e.g. {example_command}"
            raise ParserCLIException(f"Could not parse the model due to dynamic shapes. {msg}") from err

        except ParsingWithRecommendationException as err:
            self._handle_recommendation_exception(err, args, net_name, tensor_shapes, command, save_model)

        except UnsupportedModelError as err:
            # The CLI option is called input_format, so use that name in the message.
            raise ParserCLIException(str(err).replace("net_input_format", "input_format")) from err

    def _handle_recommendation_exception(self, err, args, net_name, tensor_shapes, command, save_model):
        """Retry parsing with the recommended nodes, or report the recommendation to the user.

        The retry happens when ``err`` recommends start or end nodes and either ``-y``
        was passed or the user confirms interactively.

        Args:
            err: The recommendation error raised by the first parsing attempt.
            args: Parsed command-line namespace.
            net_name: Name of the network being parsed.
            tensor_shapes: Parsed ``--tensor-shapes`` value, or ``None``.
            command: The executed command as a list of tokens, used for the suggested command.
            save_model: Forwarded to ``_finalize_parsing_cli`` on retry.

        Raises:
            ParsingWithRecommendationCLIException: If no retry was made, or the retry
                failed with another recommendation error.
        """
        final_err = None
        self._set_interactive_msg(err)
        if (err.recommended_start_node_names or err.recommended_end_node_names) and (
            args.y or self._interactive_handling()
        ):
            logger.info(f"According to recommendations, retrying parsing with {self._nodes_recommendation}")
            if err.recommended_start_node_names:
                self._start_node_names = err.recommended_start_node_names
            if err.recommended_end_node_names:
                self._end_node_names = err.recommended_end_node_names
            try:
                try:
                    self._parse(net_name, args, tensor_shapes)
                    self._finalize_parsing_cli(net_name, args, self.runner, tensor_shapes, save_model)
                except SubprocessTracebackFailure as subprocess_err:
                    # Same unwrapping as the first attempt in run(): without it a second
                    # recommendation error would escape as the raw subprocess wrapper.
                    raise subprocess_err.inner_error from subprocess_err
            except ParsingWithRecommendationException as err2:
                final_err = err2
        else:
            final_err = err
        if final_err is not None:
            if args.parsing_report_path and final_err.parsing_report is not None:
                with open(args.parsing_report_path, "w") as f:
                    f.write(json.dumps(final_err.parsing_report.dict(), indent=4))
            # "from None" hides the internal exception chain from the CLI user.
            raise ParsingWithRecommendationCLIException(
                final_err.client_message,
                recommended_start_node_names=final_err.recommended_start_node_names,
                recommended_end_node_names=final_err.recommended_end_node_names,
                command=command,
            ) from None

    def _set_interactive_msg(self, err):
        msg = ""
        starts = False
        if err.recommended_start_node_names:
            starts = True
            msg += f"start node names: {err.recommended_start_node_names} "
        if err.recommended_end_node_names:
            msg += "and end node names: " if starts else "end node names: "
            msg += f"{err.recommended_end_node_names}."
        self._nodes_recommendation = msg

    def _interactive_handling(self):
        """Ask the user whether to retry with the recommendation; return True only for a "y" answer."""
        prompt = (
            f"Parsing failed with recommendations for {self._nodes_recommendation}\n"
            "Would you like to parse again with the recommendation? (y/n) \n"
        )
        answer = input(prompt)
        # Spaces and letter case in the answer are ignored.
        return answer.replace(" ", "").lower() == "y"

    def _parse(self, net_name, args, tensor_shapes):
        """Translate the model at ``args.model_path`` with the runner, creating the runner if needed.

        Args:
            net_name: Name given to the translated network.
            args: Parsed command-line namespace.
            tensor_shapes: Parsed ``--tensor-shapes`` value, passed on for ONNX models only.

        Raises:
            UnsupportedModelError: If the model path has neither a TF extension nor ``.onnx``.
        """
        if not self.runner:
            self.runner = ClientRunner(hw_arch=args.hw_arch)

        model_path = args.model_path

        if model_path.endswith(TF_OPTIONAL_EXTENSIONS):
            self.runner.translate_tf_model(
                model_path,
                net_name,
                start_node_names=self._start_node_names,
                end_node_names=self._end_node_names,
            )
            return

        if not model_path.endswith(".onnx"):
            raise UnsupportedModelError(
                "Failed to analyze the model, it appears the model provided is in an unsupported format. "
                "If you are using a TF1.x model (such as .ckpt or .pb), or a TF2.x model "
                "(such as .h5 or saved_model.pb), please refer to the user guide for details on how to "
                "convert to TensorFlow Lite format.",
            )

        # The ONNX-only options are read here because the tf sub-parser does not define them.
        self.runner.translate_onnx_model(
            model_path,
            net_name,
            self._start_node_names,
            self._end_node_names,
            tensor_shapes,
            args.augmented_path,
            disable_rt_metadata_extraction=args.disable_rt_metadata_extraction,
            net_input_format=self._net_input_format,
        )

    def _finalize_parsing_cli(self, net_name, args, runner: ClientRunner, tensor_shapes, save_model=True):
        if runner.nms_config_file:
            msg = "Would you like to "
            detected_end_nodes = runner._original_model_meta.get("detected_anchors", {}).get("end_nodes")
            should_reparse = False
            parsed_end_nodes = {x.original_names[-1] for x in runner.get_hn_model().get_real_output_layers()}
            if detected_end_nodes and set(detected_end_nodes) != parsed_end_nodes:
                msg += "parse the model again with the mentioned end nodes and "
                should_reparse = True
            msg += "add nms postprocess command to the model script? (y/n) \n"
            if args.y or input(msg).lower().replace(" ", "") == "y":
                if should_reparse:  # parse again with the relevant end nodes
                    self._end_node_names = detected_end_nodes
                    self._parse(net_name, args, tensor_shapes)
                nms_cmd = f"nms_postprocess(meta_arch={runner.nms_meta_arch.value})"
                runner.load_model_script(nms_cmd, append=True)
                logger.info("Added nms postprocess command to model script.")

        if save_model:
            self._save_model(args, net_name, runner)

        if args.compare:
            self._compare_results(args)

    def _save_model(self, args, net_name, runner, suffix=""):
        base_path = self._get_base_path(args, net_name, suffix)
        self._save_har(base_path, runner)
        if args.parsing_report_path:
            runner.save_parsing_report(args.parsing_report_path)

    @staticmethod
    def _get_base_path(args, net_name, suffix=""):
        if args.har_path:
            return args.har_path if not args.har_path.endswith(".har") else args.har_path[:-4]

        return net_name + suffix

    @staticmethod
    def _save_har(base_path, runner):
        har_path = base_path + ".har"
        runner.save_har(har_path)

    def _compare_results(self, args):
        """Run the parser inference comparison tool for ONNX and TFLite models; skip other formats."""
        model_path = args.model_path
        # The framework is the file extension without its leading dot.
        framework = Path(model_path).suffix.strip(".")

        if framework in ("onnx", "tflite"):
            inference_tool = HailoParserInferenceTool(
                self,
                model_path=model_path,
                start_node_names=self._start_node_names,
                end_node_names=self._end_node_names,
            )
            inference_tool.dfc_validate_test_results(framework, model_path)
        else:
            logger.info("Compare tool currently supports only onnx and tflite models. Skipping.")
