#!/usr/bin/env python

import argparse
import csv
from contextlib import suppress
from tarfile import ReadError

import numpy as np

from hailo_sdk_client.exposed_definitions import States
from hailo_sdk_client.hailo_archive.hailo_archive import HailoArchiveLoader
from hailo_sdk_client.tools.cmd_utils.base_utils import CmdUtilsBaseUtil
from hailo_sdk_client.tools.cmd_utils.cmd_definitions import ClientCommandGroups
from hailo_sdk_common.hailo_nn.hailo_nn import HailoNN
from hailo_sdk_common.logger.logger import DeprecatedAPI, DeprecationVersion, default_logger
from hailo_sdk_common.targets.inference_targets import ParamsKinds

# Default output file used by the ``--csv-path`` command-line option.
CSV_DEFAULT_PATH = "translated_params.csv"
# HAR states from which translated params may be exported; a HAR in any other
# state is rejected by NpzCsvRunner._get_model_and_params.
ALLOWED_HAR_STATES = [States.QUANTIZED_MODEL, States.COMPILED_MODEL]


class FIELDS:
    """Namespace of translated-param field keys and their CSV column titles.

    Each key is the per-layer suffix looked up as ``"<layer_name>/<key>:0"`` in the
    translated params. ``FIELDS_LIST`` is the CSV column order and
    ``FIELDS_LIST_HEADER`` holds the matching display titles, index for index.
    """

    # Layer identification.
    NAME = "layer_name"
    ORIGINAL_LAYER_NAME = "original_layer_name"

    # Input / output / activation limvals.
    LIMVALS_IN = "limvals_in"
    LIMVALS_OUT = "limvals_out"
    LIMVALS_PRE_ACT = "limvals_pre_act"  # defined but not exported as a CSV column
    LIMVALS_IN_PRE_SCALE_MATCHING = "limvals_in_pre_scale_matching"
    LIMVALS_OUT_PRE_SCALE_MATCHING = "limvals_output_pre_scale_matching"

    # Quantization params.
    QP_IN = "qp_in"
    QP_OUT = "qp_out"

    # Kernel / bias params.
    LIMVALS_KERNEL = "limvals_kernel"
    SCALE_KERNEL = "scale_kernel"
    SCALE_BIAS = "scale_bias"
    BIAS_FACTOR = "bias_factor"
    BIAS_FEED_REPEAT = "bias_feed_repeat"

    # Output stage.
    MULT_SHIFT = "output_stage/mult_shift"
    ZP_APU_COMPENSATION = "output_stage/zp_apu_compensation"
    OUTPUT_FACTOR = "output_stage/output_factor"
    PA_SLOPES = "output_stage/piecewise/slopes"
    PA_OFFSETS = "output_stage/piecewise/offsets"
    PA_X_POINTS = "output_stage/piecewise/x_points"
    NEGATIVE_SLOPES_CORRECTION_FACTOR = "negative_slopes_correction_factor"

    # Elementwise addition.
    ELEMENTWISE_INPUT_FACTOR = "elementwise_addition/input_factor"
    ELEMENTWISE_FEED_REPEAT = "elementwise_addition/feed_repeat"
    ELEMENTWISE_QP = "elementwise_addition/qp_elwa"
    ELEMENTWISE_LIMVALS = "elementwise_addition/limvals_elwa"
    ELEMENTWISE_PRE_SCALE_MATCHING = "elementwise_addition/limvals_elwa_pre_scale_matching"

    SHIFT_DELTA = "output_stage/shift_delta"

    # (field key, CSV column title) pairs in CSV column order. Keeping each key
    # beside its title guarantees the two derived lists stay aligned.
    _COLUMNS = [
        (NAME, NAME),
        (ORIGINAL_LAYER_NAME, ORIGINAL_LAYER_NAME),
        (LIMVALS_IN, LIMVALS_IN),
        (LIMVALS_OUT, LIMVALS_OUT),
        (LIMVALS_IN_PRE_SCALE_MATCHING, LIMVALS_IN_PRE_SCALE_MATCHING),
        (LIMVALS_OUT_PRE_SCALE_MATCHING, LIMVALS_OUT_PRE_SCALE_MATCHING),
        (QP_IN, "qp_in (zp, scale)"),
        (QP_OUT, "qp_out (zp, scale)"),
        (LIMVALS_KERNEL, LIMVALS_KERNEL),
        (SCALE_KERNEL, SCALE_KERNEL),
        (SCALE_BIAS, SCALE_BIAS),
        (BIAS_FACTOR, BIAS_FACTOR),
        (BIAS_FEED_REPEAT, BIAS_FEED_REPEAT),
        (MULT_SHIFT, "mult_shift"),
        (SHIFT_DELTA, "shift_delta"),
        (ZP_APU_COMPENSATION, "zp_apu_compensation"),
        (OUTPUT_FACTOR, "output_factor"),
        (PA_SLOPES, "piecewise/slopes"),
        (PA_OFFSETS, "piecewise/offsets"),
        (PA_X_POINTS, "x_points"),
        (ELEMENTWISE_INPUT_FACTOR, "input_factor_elwa"),
        (ELEMENTWISE_FEED_REPEAT, "feed_repeat_elwa"),
        (ELEMENTWISE_QP, "qp_elwa (zp, scale)"),
        (ELEMENTWISE_LIMVALS, "limvals_elwa"),
        (ELEMENTWISE_PRE_SCALE_MATCHING, "limvals_elwa_pre_scale_matching"),
        (NEGATIVE_SLOPES_CORRECTION_FACTOR, NEGATIVE_SLOPES_CORRECTION_FACTOR),
    ]
    FIELDS_LIST, FIELDS_LIST_HEADER = map(list, zip(*_COLUMNS))
    # The pairs were only scaffolding; keep the class namespace unchanged.
    del _COLUMNS


def prepare_row(layer, layer_name, params):
    """Build one CSV row for ``layer`` from its translated params.

    Args:
        layer: Layer object; only its ``original_names`` attribute is read.
        layer_name: Value of the ``layer_name`` column, and the prefix of the
            param keys looked up in ``params`` (``"<layer_name>/<field>:0"``).
        params: Mapping from param keys to values (supports ``in`` and ``[]``).

    Returns:
        dict mapping entries of ``FIELDS.FIELDS_LIST`` to display strings.
        Fields whose key is absent from ``params`` are left out of the dict.
    """
    scale_matching_fields = (
        FIELDS.LIMVALS_IN_PRE_SCALE_MATCHING,
        FIELDS.LIMVALS_OUT_PRE_SCALE_MATCHING,
        FIELDS.ELEMENTWISE_PRE_SCALE_MATCHING,
    )
    qp_fields = (FIELDS.QP_IN, FIELDS.QP_OUT, FIELDS.ELEMENTWISE_QP)
    limvals_fields = (FIELDS.LIMVALS_IN, FIELDS.LIMVALS_OUT, FIELDS.LIMVALS_KERNEL, FIELDS.ELEMENTWISE_LIMVALS)
    scale_fields = (FIELDS.SCALE_KERNEL, FIELDS.OUTPUT_FACTOR, FIELDS.SCALE_BIAS)

    def _remove_brackets(val):
        # Flatten the str() of lists/arrays into a plain comma-separated cell.
        return str(val).replace("[", "").replace("]", "").replace("'", "")

    def _prettify_scale_matching_group(field, val):
        # First two entries are formatted to 5 decimals (a zero first entry is kept
        # as-is); an optional third entry is shown as "(N)" or dropped when None.
        if field not in scale_matching_fields:
            return val
        val = list(val)
        if val[0] != 0:
            val[0] = f"{val[0]:.5f}"
        val[1] = f"{val[1]:.5f}"
        if len(val) > 2:
            if val[2] is not None:
                val[2] = f"({int(val[2])})"
            else:
                del val[2]
        return val

    def _prettify_qp(field, val):
        # Quantization params are (zero point, scale): zp as int, scale to 5 decimals.
        if field not in qp_fields:
            return val
        val = list(val)
        val[0] = int(val[0])
        val[1] = f"{val[1]:.5f}"
        return val

    def _prettify_one_dim_array(val):
        # Arrays with more than one element become lists whose non-zero entries
        # are formatted to 5 decimals; smaller arrays are returned untouched.
        if val.size > 1:
            val = [f"{v:.5f}" if v != 0 else v for v in val.tolist()]
        return val

    def _prettify_limvals(field, val):
        if field in limvals_fields:
            return _prettify_one_dim_array(val)
        if field == FIELDS.PA_SLOPES:
            # Piecewise slopes may be multi-dimensional; format each row separately.
            if val.ndim < 2:
                return _prettify_one_dim_array(val)
            return [_prettify_one_dim_array(arr) for arr in val]
        return val

    def _prettify_scale(field, val):
        # Scale-like fields are shown as a single value with 6 decimals, or as
        # "Err - vector" when they cannot be reduced to one value.
        if field not in scale_fields:
            return val
        # np.asarray also accepts plain Python scalars, which have no ``.shape``.
        arr = np.asarray(val)
        if arr.shape == ():
            return f"{arr.item():.6f}"
        # A non-0-d ndarray cannot be formatted with ".6f" directly (TypeError),
        # so reduce to the single representative element before formatting.
        if field == FIELDS.SCALE_BIAS and arr.size > 0 and np.all(arr == arr.flat[0]):
            return f"{arr.flat[0]:.6f}"
        if field == FIELDS.OUTPUT_FACTOR and arr.size == 1:
            return f"{arr.flat[0]:.6f}"
        return "Err - vector"

    prettifiers = (_prettify_scale_matching_group, _prettify_qp, _prettify_limvals, _prettify_scale)

    row = {FIELDS.NAME: layer_name, FIELDS.ORIGINAL_LAYER_NAME: _remove_brackets(layer.original_names)}

    for field in FIELDS.FIELDS_LIST:
        key = f"{layer_name}/{field}:0"
        if key not in params:
            continue
        val = params[key]
        # Each prettifier only touches its own family of fields and passes others through.
        for prettify in prettifiers:
            val = prettify(field, val)
        row[field] = _remove_brackets(val)
    return row


def prepare_header():
    """Return the header row: each CSV field key mapped to its display title.

    Titles are taken positionally from ``FIELDS.FIELDS_LIST_HEADER``.
    """
    titles = FIELDS.FIELDS_LIST_HEADER
    return {key: titles[idx] for idx, key in enumerate(FIELDS.FIELDS_LIST)}


class NpzCsvRunnerError(Exception):
    """Raised when the given HAR cannot be read or is in a disallowed state."""


class NpzCsvRunner(CmdUtilsBaseUtil):
    """CLI utility that writes a HAR's translated params to CSV, one row per layer.

    The command entry point ``run`` is deprecated and always raises
    ``DeprecatedAPI``; ``convert_translated_params_to_csv`` holds the conversion.
    """

    GROUP = ClientCommandGroups.ANALYSIS_AND_VISUALIZATION
    HELP = "Convert translated params to csv"

    def __init__(self, parser):
        super().__init__(parser)
        self._logger = default_logger()
        parser.add_argument("har_path", type=str, help="Path to the HAR file")
        parser.add_argument("--csv-path", type=str, default=CSV_DEFAULT_PATH, help="Path to the CSV file")
        parser.set_defaults(func=self.run)

    def convert_translated_params_to_csv(self, model_path, csv_path):
        """Write the translated params of the HAR at ``model_path`` to ``csv_path``.

        The first row holds the column titles from ``prepare_header``; each layer
        (in ``layers_by_index`` order) then contributes one row from ``prepare_row``.

        Raises:
            NpzCsvRunnerError: propagated from ``_get_model_and_params``.
        """
        model, params = self._get_model_and_params(model_path)

        # The csv module requires newline=""; otherwise an extra "\r" is written
        # after every row on platforms whose text mode translates "\n" to "\r\n".
        with open(csv_path, "w", newline="") as csv_file:
            csv_writer = csv.DictWriter(csv_file, FIELDS.FIELDS_LIST)
            # Header uses the display titles rather than the raw field keys,
            # so writeheader() is not used.
            csv_writer.writerow(prepare_header())

            for layer in model.layers_by_index.values():
                row = prepare_row(layer, layer.name, params)
                if row is not None:
                    csv_writer.writerow(row)
        self._logger.info(f"{csv_path} created")

    @staticmethod
    def _get_model_and_params(model_path):
        """Load the HN model and the translated params from a HAR file.

        Returns:
            tuple of (HailoNN model, translated params).

        Raises:
            NpzCsvRunnerError: if the HAR state is not in ``ALLOWED_HAR_STATES``, or if
                a ``tarfile.ReadError`` is raised while reading the archive.
        """
        # A ReadError (presumably from HailoArchiveLoader on a non-tar/corrupt file)
        # is suppressed so control falls through to the "valid HAR" error below.
        with suppress(ReadError), HailoArchiveLoader(model_path) as har_loader:
            state = har_loader.get_state()
            if state not in ALLOWED_HAR_STATES:
                allowed = ", ".join([allowed_state.value for allowed_state in ALLOWED_HAR_STATES])
                raise NpzCsvRunnerError(f"The HAR state must be one of {allowed}")
            model = HailoNN.from_hn(har_loader.get_hn())
            params = har_loader.get_params(ParamsKinds.TRANSLATED)
            return model, params

        raise NpzCsvRunnerError("The given model must be a valid HAR file")

    def run(self, args):
        """Command entry point; the command is deprecated and always raises ``DeprecatedAPI``."""
        raise DeprecatedAPI("`params-csv` command has been deprecated.", DeprecationVersion.OCT2024)


if __name__ == "__main__":
    # Standalone entry point: build the CLI, parse sys.argv and invoke the runner.
    cli_parser = argparse.ArgumentParser()
    runner = NpzCsvRunner(cli_parser)
    runner.run(cli_parser.parse_args())
