#!/usr/bin/env python
import argparse
import json
import os
import pathlib
from typing import Dict, List, Optional, OrderedDict, Tuple

import numpy as np
import onnx
from prettytable import PrettyTable
from pydantic.v1 import BaseModel, Field

from hailo_sdk_client.exposed_definitions import States
from hailo_sdk_client.hailo_archive.hailo_archive import HailoArchive, HailoArchiveLoader
from hailo_sdk_client.tools.cmd_utils.base_utils import CmdUtilsBaseUtil
from hailo_sdk_client.tools.cmd_utils.cmd_definitions import ClientCommandGroups
from hailo_sdk_common.hailo_nn.hailo_nn import HailoNN
from hailo_sdk_common.logger.logger import default_logger
from hailo_sdk_common.serialization.numpy_serialization import hailo_np_savez
from hailo_sdk_common.targets.inference_targets import ParamsKinds

# Module-level logger shared by the HAR CLI commands in this file.
logger = default_logger()


class DiffStateConfig(BaseModel):
    """Snapshot of one model state of a HAR, used when diffing two archives.

    An instance whose fields are all left at their (empty) defaults means the
    state is not available in the archive; see :meth:`is_real`.
    """

    # ``original_names`` of each real input layer, one list per layer.
    start_names: List[List[str]] = Field([])
    # ``original_names`` of each real output layer, one list per layer.
    end_names: List[List[str]] = Field([])
    # (layer name without scope, output shapes) per layer, in stable topological order.
    topo_struct: List[Tuple[str, List[List[int]]]] = Field([])
    # Per-layer HN "params" dicts, aligned index-by-index with ``topo_struct``.
    hn_params: List[OrderedDict] = Field([])
    # Per-layer HN "quantization_params" dicts, aligned with ``topo_struct``.
    # Default is a list (was a dict) to match the annotation and sibling fields.
    quant_params: List[OrderedDict] = Field([])
    # Layer name -> list of modification command types (filled for the FP-optimized state).
    model_modifications: Optional[Dict]

    class Config:
        extra = "forbid"

    def is_real(self):
        """Return True if at least one field holds a truthy (non-empty) value."""
        # Iterating a pydantic model yields (field_name, value) pairs.
        return any(value for _, value in self)


class HailoArchiveCLI(CmdUtilsBaseUtil):
    """CLI utility to extract files from, print info about, and diff Hailo Archive (HAR) files."""

    GROUP = ClientCommandGroups.ANALYSIS_AND_VISUALIZATION
    HELP = "Query and extract information from Hailo Archive file"
    # Params kind whose weights are compared by ``diff`` for each model state.
    STATE_TO_PARAMS_KIND = {
        States.HAILO_MODEL: ParamsKinds.NATIVE,
        States.FP_OPTIMIZED_MODEL: ParamsKinds.FP_OPTIMIZED,
        States.QUANTIZED_MODEL: ParamsKinds.TRANSLATED,
        States.COMPILED_MODEL: ParamsKinds.TRANSLATED,
    }
    # Human-readable state names used in the ``diff`` report.
    STATE_TO_STR = {
        States.HAILO_MODEL: "Native model",
        States.FP_OPTIMIZED_MODEL: "FP optimized model",
        States.QUANTIZED_MODEL: "Quantized model",
        States.COMPILED_MODEL: "Compiled model",
    }

    def __init__(self, parser):
        """Register the ``extract``, ``info`` and ``diff`` sub-commands on ``parser``."""
        super().__init__(parser)
        subparsers = parser.add_subparsers(dest="action")
        subparsers.required = True

        extract_parser = subparsers.add_parser(
            "extract",
            help="Extract the files from the given archive file. "
            "When using flags extract the specific files, "
            "otherwise extract all files.",
        )
        extract_parser.add_argument("har_path", type=str, help="Path for the Hailo archive")
        extract_parser.add_argument(
            "--original-model-dir",
            type=str,
            help="Directory path to extract the original model (TF/ONNX)",
        )
        extract_parser.add_argument("--hn-path", type=str, help="Path for the extracted HN")
        extract_parser.add_argument("--native-hn-path", type=str, help="Path for the extracted native HN")
        extract_parser.add_argument("--fp-hn-path", type=str, help="Path for the extracted full-precision HN")
        extract_parser.add_argument("--params-path", type=str, help="Path for the extracted native params")
        extract_parser.add_argument("--params-after-bn-path", type=str, help=argparse.SUPPRESS)
        extract_parser.add_argument("--params-fp-path", type=str, help="Path for the extracted full-precision params")
        extract_parser.add_argument(
            "--quantized-params-path",
            type=str,
            help="Path for the extracted translated params",
        )
        extract_parser.add_argument(
            "--statistics-params-path",
            type=str,
            help="Path for the extracted statistics params",
        )
        extract_parser.add_argument("--model-script-path", type=str, help="Path for the extracted model script")
        extract_parser.add_argument(
            "--auto-model-script-path",
            type=str,
            help="Path for the extracted auto model script",
        )
        extract_parser.add_argument("--hef-path", type=str, help="Path for the extracted HEF")
        extract_parser.add_argument(
            "--preprocess-model-path",
            type=str,
            help="Path for the extracted pre-process model",
        )
        extract_parser.add_argument(
            "--postprocess-model-path",
            type=str,
            help="Path for the extracted post-process model",
        )
        extract_parser.add_argument("--nms-config-path", type=str, help="Path for the extracted NMS configuration file")
        extract_parser.add_argument(
            "--modifications-meta-data-path",
            type=str,
            help="Path for the extracted modifications metadata file",
        )

        info_parser = subparsers.add_parser("info", help="Print information for the given Hailo archive.")
        info_parser.add_argument("har_path", type=str, help="Path for the Hailo archive")
        info_parser.add_argument(
            "--verbose",
            action="store_true",
            default=False,
            help="Print extra stats about the HAR file",
        )

        diff_parser = subparsers.add_parser("diff", help="Compares two given Hailo archive files.")
        diff_parser.add_argument("har1_path", type=str, help="Path for the first Hailo archive.")
        diff_parser.add_argument("har2_path", type=str, help="Path for the second Hailo archive.")

        parser.set_defaults(func=self.run)

    def run(self, args):
        """Dispatch the parsed command-line ``args`` to the selected sub-command."""
        if args.action == "extract":
            # Keyword arguments so a reordering of extract_data's parameters cannot silently mis-route paths.
            self.extract_data(
                args.har_path,
                original_model_dir=args.original_model_dir,
                hn_path=args.hn_path,
                params_path=args.params_path,
                params_after_bn_path=args.params_after_bn_path,
                params_fp_path=args.params_fp_path,
                quantized_params_path=args.quantized_params_path,
                statistics_params_path=args.statistics_params_path,
                model_script_path=args.model_script_path,
                auto_model_script_path=args.auto_model_script_path,
                hef_path=args.hef_path,
                native_hn_path=args.native_hn_path,
                fp_hn_path=args.fp_hn_path,
                preprocess_model_path=args.preprocess_model_path,
                postprocess_model_path=args.postprocess_model_path,
                nms_config_path=args.nms_config_path,
                modifications_meta_data_path=args.modifications_meta_data_path,
            )
        elif args.action == "info":
            self.print_info(args.har_path, args.verbose)
        elif args.action == "diff":
            self.diff_har([args.har1_path, args.har2_path])

    def extract_data(
        self,
        har_path,
        original_model_dir=None,
        hn_path=None,
        params_path=None,
        params_after_bn_path=None,
        params_fp_path=None,
        quantized_params_path=None,
        statistics_params_path=None,
        model_script_path=None,
        auto_model_script_path=None,
        hef_path=None,
        native_hn_path=None,
        fp_hn_path=None,
        preprocess_model_path=None,
        postprocess_model_path=None,
        nms_config_path=None,
        modifications_meta_data_path=None,
    ):
        """Extract files stored in the HAR at ``har_path``.

        Each ``*_path`` argument is the output path of one artifact (``original_model_dir``
        is a directory). Artifacts whose path is falsy are skipped, and so are artifacts
        the archive does not contain. When every path argument is None, all artifacts are
        extracted into the current working directory under ``<model_name>``-based names.
        """
        with HailoArchiveLoader(har_path) as har_loader:
            model_name = har_loader.get_model_name()

            requested_paths = (
                original_model_dir,
                hn_path,
                params_path,
                params_after_bn_path,
                params_fp_path,
                quantized_params_path,
                statistics_params_path,
                model_script_path,
                auto_model_script_path,
                hef_path,
                native_hn_path,
                fp_hn_path,
                preprocess_model_path,
                postprocess_model_path,
                nms_config_path,
                modifications_meta_data_path,
            )
            if all(path is None for path in requested_paths):
                # No specific file was requested: extract everything using default names.
                original_model_dir = os.getcwd()
                hn_path = f"{model_name}.hn"
                native_hn_path = f"{model_name}.native.hn"
                fp_hn_path = f"{model_name}.fp.hn"
                params_path = f"{model_name}.npz"
                params_after_bn_path = f"{model_name}.bn.npz"
                params_fp_path = f"{model_name}.fpo.npz"
                quantized_params_path = f"{model_name}.q.npz"
                statistics_params_path = f"{model_name}.stats.npz"
                model_script_path = f"{model_name}.alls"
                auto_model_script_path = f"{model_name}.auto.alls"
                hef_path = f"{model_name}.hef"
                preprocess_model_path = f"{model_name}_preprocess.onnx"
                postprocess_model_path = f"{model_name}_postprocess.onnx"
                nms_config_path = f"{model_name}_nms_config.json"
                modifications_meta_data_path = f"{model_name}_modifications_meta_data.json"

            if original_model_dir:
                original_model_path = har_loader.extract_original_model(original_model_dir)
                if original_model_path:
                    logger.info(f"Original model (TF/ONNX) extracted to: {original_model_path}")

            # (output path, content getter, writer, description for the log), in extraction order.
            # Getters are called lazily, only for artifacts that were actually requested.
            extractions = (
                (hn_path, har_loader.get_hn, self._write_text_file, "HN"),
                (native_hn_path, har_loader.get_native_hn, self._write_text_file, "Native HN"),
                (fp_hn_path, har_loader.get_fp_hn, self._write_text_file, "Full-precision HN"),
                (params_path, har_loader.get_params, self._save_params, "Params"),
                (
                    params_after_bn_path,
                    lambda: har_loader.get_params(ParamsKinds.NATIVE_FUSED_BN),
                    self._save_params,
                    "Params after BN folding",
                ),
                (
                    params_fp_path,
                    lambda: har_loader.get_params(ParamsKinds.FP_OPTIMIZED),
                    self._save_params,
                    "Full Precision params",
                ),
                (
                    quantized_params_path,
                    lambda: har_loader.get_params(ParamsKinds.TRANSLATED),
                    self._save_params,
                    "Quantized params",
                ),
                (
                    statistics_params_path,
                    lambda: har_loader.get_params(ParamsKinds.STATISTICS),
                    self._save_params,
                    "Statistics params",
                ),
                (model_script_path, har_loader.get_model_script, self._write_text_file, "Model script"),
                (
                    auto_model_script_path,
                    har_loader.get_auto_model_script,
                    self._write_text_file,
                    "Auto model script",
                ),
                (hef_path, har_loader.get_hef, self._write_binary_file, "HEF"),
                (
                    preprocess_model_path,
                    har_loader.get_preprocess_model,
                    self._save_onnx_model,
                    "pre-process model",
                ),
                (
                    postprocess_model_path,
                    har_loader.get_postprocess_model,
                    self._save_onnx_model,
                    "post-process model",
                ),
                (nms_config_path, har_loader.get_nms_config_file, self._write_json_file, "NMS config file"),
                (
                    modifications_meta_data_path,
                    har_loader.get_modifications_meta_data_file,
                    self._write_json_file,
                    "Modifications metadata file",
                ),
            )
            for path, get_content, write, description in extractions:
                if not path:
                    continue
                content = get_content()
                if content:
                    write(path, content)
                    logger.info(f"{description} extracted to: {os.path.abspath(path)}")

    @staticmethod
    def _write_text_file(path, content):
        """Write the string ``content`` to ``path``."""
        with open(path, "w") as out_file:
            out_file.write(content)

    @staticmethod
    def _write_binary_file(path, content):
        """Write the bytes ``content`` to ``path``."""
        with open(path, "wb") as out_file:
            out_file.write(content)

    @staticmethod
    def _write_json_file(path, content):
        """Serialize ``content`` as indented JSON into ``path``."""
        with open(path, "w") as out_file:
            json.dump(content, out_file, indent=4)

    @staticmethod
    def _save_onnx_model(path, model):
        """Save the ONNX ``model`` to ``path``."""
        onnx.save_model(model, path)

    @staticmethod
    def _save_params(params_path, params):
        """Save the ``params`` mapping (name -> array) to ``params_path`` in npz format."""
        with open(params_path, "wb") as save:
            hailo_np_savez(save, **dict(params.items()))

    def print_info(self, har_path, verbose=False):
        """Log general information about the HAR at ``har_path``.

        With ``verbose``, also log the archive size, an HN summary and the original-model metadata.
        """
        with HailoArchiveLoader(har_path) as har_loader:
            model_name = har_loader.get_model_name()
            logger.info(f"Model Name: {model_name}")

            state = har_loader.get_state()
            logger.info(f'State: {state.value.replace("_", " ").title()}')

            nms_meta_arch = har_loader.get_nms_meta_arch()
            if nms_meta_arch:
                logger.info(f"NMS Meta Architecture: {nms_meta_arch.value.title()}")

            nms_engine = har_loader.get_nms_engine()
            if nms_engine:
                logger.info(f"NMS Target Device: {nms_engine.value.title()}")
            sdk_version = har_loader.get_sdk_version()
            if sdk_version:
                logger.info(f"SDK Version (that created the HAR): {sdk_version}")

            hw_arch = har_loader.get_hw_arch()
            if not hw_arch:
                hw_arch = f"missing, {HailoArchive.LEGACY_DEFAULT_HW_ARCH} will be used by default"
            logger.info(f"Hardware Architecture: {hw_arch}")

            logger.info("Files in HAR:")
            har_loader.list(verbose)

            if not verbose:
                return

            logger.info(f"HAR Size: {pathlib.Path(har_path).stat().st_size} Bytes")

            hn = har_loader.get_hn()
            if hn:
                hailo_nn = HailoNN.from_hn(hn)
                logger.info(f'Output Layers Order: {", ".join(hailo_nn.net_params.output_layers_order)}')
                logger.info(f'Net Scopes: {", ".join(hailo_nn.net_params.net_scopes)}')
                logger.info("Model Summary:")
                hailo_nn.summary()

            original_model_meta = har_loader.get_original_model_meta()
            if original_model_meta:
                logger.info(f"Original Model Info:{self._format_original_model_meta(original_model_meta)}")

    @staticmethod
    def _format_original_model_meta(original_model_meta):
        """Render the original-model metadata as tab-indented ``Key: value`` lines.

        The ``parsing_report`` entry is skipped, empty lists/dicts are shown as ``None``
        and list values are comma-joined.
        """
        lines = []
        for key, value in original_model_meta.items():
            if key == "parsing_report":
                continue
            if isinstance(value, (list, dict)) and len(value) == 0:
                value = "None"
            elif isinstance(value, list):
                value = ", ".join(str(item) for item in value)
            lines.append(f'{key.replace("_", " ").title()}: {value}')
        return "\n\t" + "\n\t".join(lines)

    def diff_har(self, paths):
        """Structurally compare the two HARs in ``paths`` and print a report table.

        States are compared in order (native, FP optimized, quantized); the comparison
        stops at the first state that differs or that is not available in both archives.
        """
        info1 = self._get_diff_tool_basic_info(paths[0])
        info2 = self._get_diff_tool_basic_info(paths[1])
        table = PrettyTable(["", "har1", "har2"], align="l")
        table.add_rows([[key, info1.get(key), info2.get(key)] for key in info1])

        success_msg = "HARs structural comparison was completed successfully."
        for state in (States.HAILO_MODEL, States.FP_OPTIMIZED_MODEL, States.QUANTIZED_MODEL):
            break_next_state, found_diff = self._add_state_block(table, paths, info1, info2, state)
            if break_next_state:
                print(table)
                if found_diff:
                    logger.warning("Found differences between the HARs, see details above.")
                else:
                    logger.info(success_msg)
                return
        print(table)
        if info1["State"] != info2["State"]:
            logger.warning("HARs are structurally equal but in different states.")
        else:
            logger.info(success_msg)

    def _get_diff_tool_basic_info(self, har_path):
        """Return the name, state and SDK version of the HAR at ``har_path``."""
        with HailoArchiveLoader(har_path) as har_loader:
            return {
                "Name": har_loader.get_model_name(),
                "State": har_loader.get_state(),
                "Version": har_loader.get_sdk_version(),
            }

    def _add_state_block(self, table, paths, info1, info2, state):
        """Append the comparison of ``state`` to ``table``.

        Returns ``(break_next_state, found_diff)``.
        """
        state_info1 = self._get_diff_tool_info_by_state(info1, paths[0], state)
        state_info2 = self._get_diff_tool_info_by_state(info2, paths[1], state)
        real1, real2 = state_info1.is_real(), state_info2.is_real()
        if real1 and real2:
            return self._same_state_comparison(state_info1, state_info2, table, state, paths)
        if real1 or real2:
            # Use the readable state name; formatting the enum itself varies between Python versions.
            msg = f"Only one of the files is a {self.STATE_TO_STR[state]}."
            self._add_state_rows_to_table(table, msg, state, state_info1, state_info2)
            return True, True
        return True, False

    def _get_diff_tool_info_by_state(self, basic_info, har_path, state):
        """Collect the :class:`DiffStateConfig` of ``state`` from the HAR at ``har_path``.

        ``basic_info["State"]`` is the archive's current state. An empty (non-real) config
        is returned when ``state`` is not available given that current state.
        """
        har_state = basic_info["State"]
        with HailoArchiveLoader(har_path) as har_loader:
            hn = har_loader.get_hn()
            if state != har_state:
                if state == States.HAILO_MODEL:
                    hn = har_loader.get_native_hn()
                elif state == States.FP_OPTIMIZED_MODEL:
                    if har_state == States.HAILO_MODEL:
                        return DiffStateConfig()
                    hn = har_loader.get_fp_hn()
                elif har_state not in (States.QUANTIZED_MODEL, States.COMPILED_MODEL):
                    return DiffStateConfig()

            model = HailoNN.from_hn(hn)
            # Toposort and serialize each layer once; list() in case a generator is returned.
            layers = list(model.stable_toposort())
            layers_hn = [layer.to_hn() for layer in layers]
            state_info = DiffStateConfig(
                start_names=[layer.original_names for layer in model.get_real_input_layers()],
                end_names=[layer.original_names for layer in model.get_real_output_layers()],
                topo_struct=[(layer.name_without_scope, layer.output_shapes) for layer in layers],
                hn_params=[layer_hn.get("params", {}) for layer_hn in layers_hn],
                quant_params=[layer_hn.get("quantization_params", {}) for layer_hn in layers_hn],
            )

            if state == States.FP_OPTIMIZED_MODEL:
                modif_metadata = har_loader.get_modifications_meta_data_file()
                if modif_metadata:
                    info_metadata = {}
                    # Missing sections are treated as empty instead of raising KeyError.
                    for section in ("inputs", "outputs"):
                        for lname, commands in modif_metadata.get(section, {}).items():
                            info_metadata[lname] = [cmd["cmd_type"] for cmd in commands]
                    state_info.model_modifications = info_metadata
        return state_info

    def _same_state_comparison(self, state_info1: DiffStateConfig, state_info2: DiffStateConfig, table, state, paths):
        """Compare structure, HN params and weights of a state present in both HARs.

        Rows are appended to ``table``; returns ``(break_next_state, found_diff)``.
        """
        msg = f"{self.STATE_TO_STR[state]} comparison"
        self._add_state_rows_to_table(table, msg, state, state_info1, state_info2)

        struct_diff = self._find_structure_diff(state_info1.topo_struct, state_info2.topo_struct)
        if struct_diff:
            idx, structure1, structure2 = struct_diff
            table.add_row([f"Structural diff starts at index {idx}", structure1, structure2])
            return True, True
        # toposort params are equal for both models, so layer indices line up from here on
        table.add_row(["Structures are identical.", "", ""])

        params_diff = self._find_params_diff(state_info1, state_info2)
        if params_diff:
            kind, idx = params_diff
            table.add_row([f"{kind} diff starts at layer {state_info1.topo_struct[idx]} (idx={idx})", "", ""])
            return True, True
        table.add_row(["HN params are identical.", "", ""])

        weight_diff = self._find_weights_diff(paths, state, state_info1.topo_struct)
        if weight_diff:
            lname, key = weight_diff
            table.add_row([f'Weights diff starts at layer {lname} at {key.split(":")[0]}', "", ""])
            return True, True
        table.add_row(["Weights are identical.", "", ""])
        return False, False

    def _add_state_rows_to_table(self, table, msg, state, info1: DiffStateConfig, info2: DiffStateConfig):
        """Append a state header row and the start/end node names (and modifications) of both HARs."""
        table.add_row(["", "", ""])
        table.add_row([msg, "", ""])
        table.add_row(["Start Node Names", info1.start_names, info2.start_names])
        table.add_row(["End Node Names", info1.end_names, info2.end_names])
        if state == States.FP_OPTIMIZED_MODEL:
            table.add_row(["Model Modifications", info1.model_modifications, info2.model_modifications])

    def _find_structure_diff(self, topo_struct1, topo_struct2):
        """Return ``(index, entry1, entry2)`` of the first differing toposort entry, or None.

        When one structure is a prefix of the other, the missing side's entry is None.
        """
        # enumerate gives the true position; list.index() would return an earlier equal entry.
        for idx, (structure1, structure2) in enumerate(zip(topo_struct1, topo_struct2)):
            if structure1 != structure2:
                return idx, structure1, structure2
        common_len = min(len(topo_struct1), len(topo_struct2))
        if len(topo_struct2) > common_len:
            return common_len, None, topo_struct2[common_len]
        if len(topo_struct1) > common_len:
            return common_len, topo_struct1[common_len], None
        return None

    def _find_params_diff(self, state_info1: DiffStateConfig, state_info2: DiffStateConfig):
        """Return ``(kind, layer_index)`` of the first differing HN/quantization params, or None."""
        # enumerate is required: many layers share equal params (e.g. {}), so list.index()
        # would report the first layer with equal params instead of the differing one.
        for idx, (hn_params1, hn_params2) in enumerate(zip(state_info1.hn_params, state_info2.hn_params)):
            if hn_params1 != hn_params2:
                return "Params", idx
        for idx, (quant_params1, quant_params2) in enumerate(zip(state_info1.quant_params, state_info2.quant_params)):
            if quant_params1 != quant_params2:
                return "Quantization params", idx
        return None

    @staticmethod
    def _get_layer_params(params, lname):
        """Return the params of layer ``lname`` (name without scope), or None if absent.

        The layer is looked up under each network scope of ``params`` and the first
        non-empty match wins; for a single-network model this is ``params.get(f"{scope}/{lname}")``.
        """
        if not params:
            return None
        for scope in params.network_names:
            layer_params = params.get(f"{scope}/{lname}")
            if layer_params:
                return layer_params
        return None

    def _find_weights_diff(self, paths, state, toposort_hn_struct):
        """Return ``(layer_name, param_key)`` of the first differing weight, or None.

        ``param_key`` is ``"all"`` when only one HAR has params for the layer.
        """
        params_kind = self.STATE_TO_PARAMS_KIND[state]
        with HailoArchiveLoader(paths[0]) as har_loader1:
            params1 = har_loader1.get_params(params_kind)
        with HailoArchiveLoader(paths[1]) as har_loader2:
            params2 = har_loader2.get_params(params_kind)

        for lname, _ in toposort_hn_struct:
            layer_params1 = self._get_layer_params(params1, lname)
            layer_params2 = self._get_layer_params(params2, lname)
            if layer_params1 and layer_params2:
                keys1 = set(layer_params1)
                keys2 = set(layer_params2)
                for key in layer_params1:
                    if key not in keys2:
                        return lname, key
                    value1 = np.asarray(layer_params1[key])
                    value2 = np.asarray(layer_params2[key])
                    # Compare shapes first: allclose broadcasts, so [1.] vs [1., 1.] would look equal.
                    if value1.shape != value2.shape or not np.allclose(value1, value2, equal_nan=True):
                        return lname, key
                extra_key = next((key for key in layer_params2 if key not in keys1), None)
                if extra_key is not None:
                    return lname, extra_key
            elif layer_params1 or layer_params2:
                return lname, "all"
        return None