import argparse
import sys
from contextlib import suppress
from json import JSONDecodeError
from tarfile import ReadError

from hailo_sdk_client.exposed_definitions import SUPPORTED_HW_ARCHS, States
from hailo_sdk_client.runner.client_runner import ClientRunner
from hailo_sdk_client.tools.cmd_utils.base_utils import CmdUtilsBaseUtil, CmdUtilsBaseUtilError
from hailo_sdk_client.tools.cmd_utils.cmd_definitions import ClientCommandGroups
from hailo_sdk_common.tools.weights_generator import generate_random_weights_for_model


class OptimizeCLIException(CmdUtilsBaseUtilError):
    """Error raised by the optimize command-line utility defined in this module."""


class OptimizeCLI(CmdUtilsBaseUtil):
    """Command-line utility that optimizes a model and saves the result as a HAR.

    The constructor registers the command's arguments on the supplied argparse
    parser and sets ``run`` as the parser's ``func`` default. ``run`` builds a
    ``ClientRunner`` from the input file, optionally loads random weights and a
    model script, runs the optimization stages and writes the output HAR.
    """

    GROUP = ClientCommandGroups.MODEL_CONVERSION_FLOW
    HELP = "Optimize model"

    def __init__(self, parser):
        """Register the optimize command's arguments on ``parser``.

        Args:
            parser: the argparse parser (or sub-parser) to populate.
        """
        super().__init__(parser)
        # Created lazily by _initialize_runner() once the arguments are known.
        self._runner = None
        parser.description = "Optimize the network model and parameters using a pre-processed calibration set."
        parser.formatter_class = argparse.RawDescriptionHelpFormatter

        parser.add_argument("har_path", type=str, help="Path to the HAR of the model to quantize")
        parser.add_argument("--hw-arch", type=str, choices=SUPPORTED_HW_ARCHS, help="Hardware architecture to be used")

        # One of the calib set arguments is required when not using --full-precision-only.
        # The arguments are declared before parsing happens, so the raw command line is
        # inspected here. NOTE(review): only the exact spelling of the flag is detected;
        # an argparse prefix abbreviation of it would leave the group required.
        is_required = "--full-precision-only" not in sys.argv
        optimization_mode_group = parser.add_mutually_exclusive_group(required=is_required)
        optimization_mode_group.add_argument(
            "--calib-set-path",
            type=str,
            dest="calibration_set_path",
            help="Path to the calibration set .npy file containing a numpy array of "
            "pre-processed images with shapes (calib_size, h, w, c)",
        )
        optimization_mode_group.add_argument(
            "--use-random-calib-set",
            action="store_true",
            default=False,
            help="The calibset-size is batch-size * calib-num-batch, it ignores the "
            "model script calibset-size (64 by default). In a case where the model "
            "has 3 input channels, real RGB images will be used instead of random "
            "data.",
        )
        parser.add_argument(
            "--full-precision-only",
            action="store_true",
            default=False,
            help="Run full-precision optimizations only. This stage is the "
            "intermediate stage between parsing (native) and optimizing the model "
            "(quantized). Running in this mode will apply model modifications and "
            "perform some pre-quantization algorithms.",
        )

        parser.add_argument(
            "--calib-random-max",
            type=int,
            default=1,
            help="Max value for the random calibration set. If the max value is greater than 127, "
            "the datatype will be int",
        )

        # --weights-path is hidden from the help output (argparse.SUPPRESS); it is
        # consumed only when the input is an .hn file (see _initialize_runner).
        weights_arg_group = parser.add_mutually_exclusive_group(required=False)
        weights_arg_group.add_argument("--weights-path", type=str, help=argparse.SUPPRESS)
        weights_arg_group.add_argument(
            "--use-random-weights",
            action="store_true",
            default=False,
            help="Whether to generate random weights. Disclaimer: random weights might fail "
            "to quantize on some occasions",
        )

        parser.add_argument("--work-dir", type=str, help="If given, dump quantization debug outputs to this directory")
        model_script_group = parser.add_mutually_exclusive_group()
        model_script_group.add_argument(
            "--model-script",
            type=str,
            help="Path to model script allowing the modification of the current model, "
            "prior to quantization. If present, the script is parsed and a modified "
            "model is set, where each layer has (possibly) new quantization params.",
        )

        parser.add_argument(
            "--output-har-path",
            type=str,
            help="Write the quantized HAR to this path, the default is <model_name>_optimized.har",
        )

        parser.add_argument(
            "--compilation-only-har",
            action="store_true",
            default=False,
            help="Save a reduced size har, containing only compilation related data.",
        )
        parser.set_defaults(func=self.run)

    def run(self, args):
        """Execute the optimize flow for the parsed arguments and save the output HAR.

        Args:
            args: the argparse namespace produced by the parser configured in ``__init__``.

        Raises:
            OptimizeCLIException: propagated from ``_initialize_runner`` when the input
                model cannot be loaded or is in a state that is not allowed.
        """
        self._initialize_runner(args)

        if args.use_random_weights:
            self._runner.load_params(generate_random_weights_for_model(self._runner.get_hn_model()))

        if args.model_script:
            self._runner.load_model_script(args.model_script)

        dataset = self.get_input_dataset(args)
        # The full-precision stage is run only for a model that has not been through it yet.
        if self._runner.state == States.HAILO_MODEL:
            self._runner.optimize_full_precision(calib_data=dataset)
        if args.full_precision_only:
            output_har_suffix = "_fp_optimized.har"
        else:
            output_har_suffix = "_optimized_slim.har" if args.compilation_only_har else "_optimized.har"
            if args.use_random_calib_set or args.use_random_weights:
                # With random calibration data or random weights the default flavor is
                # set to compression level 0 and optimization level 0.
                self._runner._sdk_backend.set_default_optimization_flavor(compression_level=0, optimization_level=0)
            self._runner.optimize(dataset, work_dir=args.work_dir)

        # An explicit --output-har-path wins; otherwise derive the name from the model name.
        output_har_path = args.output_har_path or f"{self._runner.model_name}{output_har_suffix}"

        self._runner.save_har(output_har_path, compilation_only=args.compilation_only_har)

    def _initialize_runner(self, args):
        """Create the ``ClientRunner`` for ``args.har_path`` and store it on ``self._runner``.

        A path ending in ``.har`` is loaded as a HAR. A path ending in ``.hn`` is read
        as text and passed to ``ClientRunner``; in that case weights must come from
        ``--weights-path`` or ``--use-random-weights``.

        Args:
            args: the parsed argparse namespace.

        Raises:
            OptimizeCLIException: when an ``.hn`` input has no weights source, when no
                runner could be created from the given path, or when the runner's
                state is not one of the states allowed for the requested mode.
        """
        runner = None
        if args.har_path.endswith(".har"):
            # A tarfile ReadError leaves ``runner`` as None and is reported below.
            with suppress(ReadError):
                runner = ClientRunner(har=args.har_path, hw_arch=args.hw_arch)

        elif args.har_path.endswith(".hn"):
            # Validate before touching the file; this error is never suppressed.
            if not args.weights_path and not args.use_random_weights:
                raise OptimizeCLIException(
                    "either of --weights-path or --use-random-weights must be used when using quantize from hn",
                )
            # A JSONDecodeError raised in this block is swallowed; if it happens before
            # ``runner`` is assigned, the failure is reported below.
            with suppress(JSONDecodeError):
                # Decode explicitly as UTF-8 instead of relying on the locale default.
                with open(args.har_path, encoding="utf-8") as hn_file:
                    runner = ClientRunner(hn_file.read(), hw_arch=args.hw_arch)
                if not args.use_random_weights:
                    runner.load_params(args.weights_path)

        if runner is None:
            raise OptimizeCLIException("The given model must be a valid HAR file")

        err_msg_prefix = "Optimize"
        allowed_states = [States.HAILO_MODEL.value, States.FP_OPTIMIZED_MODEL.value]
        if args.full_precision_only:
            # Full-precision-only mode accepts only a model in the HAILO_MODEL state.
            err_msg_prefix = "Full-precision optimize"
            allowed_states = [States.HAILO_MODEL.value]

        if runner.state.value not in allowed_states:
            raise OptimizeCLIException(
                f'{err_msg_prefix} is allowed only when the state is one of {", ".join(allowed_states)}. '
                f'Current state is {runner.state.value}',
            )

        self._runner = runner

    def get_input_dataset(self, args):
        """Configure the runner backend's calibration data and return it.

        Args:
            args: the parsed argparse namespace.

        Returns:
            ``None`` when ``--use-random-calib-set`` is set, otherwise the value of
            ``args.calibration_set_path`` (which may itself be ``None``).
        """
        if args.use_random_calib_set:
            # the calibration dataset will be generated on the fly
            self._runner._sdk_backend.calibration_data = None
            # A falsy max value (0 or None) falls back to 1.
            self._runner._sdk_backend.calibration_data_random_max = args.calib_random_max or 1
        else:
            self._runner._sdk_backend.calibration_data = args.calibration_set_path

        return self._runner._sdk_backend.calibration_data


def main():
    """Build a standalone parser for the optimize command, parse the command line and run it."""
    arg_parser = argparse.ArgumentParser()
    # The CLI object must be created before parsing: its constructor registers the arguments.
    optimize_cli = OptimizeCLI(arg_parser)
    optimize_cli.run(arg_parser.parse_args())


# Run the command only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
