"""

Module for mixed precision helper functions and clases
this module provides classes in charge ofm manageing the results
of infers and calculate snr

Raises
    ValueError: _description_

Returns
    _type_: _description_

"""

from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
from pydantic.v1 import BaseModel, Field

from hailo_model_optimization.acceleras.utils.acceleras_definitions import ComprecisionMetric, PrecisionMode
from hailo_model_optimization.algorithms.lat_utils.lat_noise_metrics import SNRNoise


class InferResults(BaseModel):
    """Per-layer inference results for one (output layer, evaluated layer, precision) combination.

    ``values`` holds one list per output; every call to :meth:`update` appends
    one array to each of those lists.
    """

    precision: str = Field("native")
    output_layer: str = Field(description="name of the output layer")
    layer_evaluated: Optional[str] = Field("")
    values: Optional[List[List[np.ndarray]]] = Field(description="Values of the outputs")

    class Config:
        # Needed so pydantic accepts ``np.ndarray`` as a field type.
        arbitrary_types_allowed = True

    def update(self, values: List[np.ndarray]) -> None:
        """Append one batch of outputs.

        Args:
            values: One array per output. ``values[i]`` is appended to
                ``self.values[i]``. When nothing has been stored yet (``values``
                is None or empty), one new list is created per element.
        """
        if self.values:
            for index, val in enumerate(values):
                self.values[index].append(val)
        else:
            self.values = [[val] for val in values]

    def get_values(self) -> List[np.ndarray]:
        """Return the accumulated values as one ``np.array`` per output.

        ``np.array`` is applied to each per-output list, which adds a leading
        axis over the appended batches (assumes the appended arrays share a
        shape -- TODO confirm with callers).

        Raises:
            ValueError: If no values were ever recorded.
        """
        if not self.values:
            raise ValueError("values where never been given")
        return [np.array(val) for val in self.values]


class InferResultsManager:
    """Manage inference results, accumulate them per layer/precision and aggregate SNR.

    Results are stored under ``(output_name, layer_name, precision_name)`` keys.
    Native (reference) results use ``native_layer`` and ``native_precision`` as the
    layer and precision parts of the key. :meth:`calculate_snr` compares every
    other entry against the native entry of the same output.
    """

    native_precision = "native"
    native_layer = "native"

    def __init__(self) -> None:
        # (output_name, layer_name, precision_name) -> accumulated outputs
        self._infer_results: Dict[Tuple[str, str, str], InferResults] = {}
        # (layer_name, precision_name) -> (ops, metric_name), populated by set_ops
        self._compresion_metric: Dict[Tuple[str, str], Tuple[Any, str]] = {}

    def _results_for(self, output_name: str, layer_name: str, precision_name: str) -> InferResults:
        """Return the results container stored under the given key, creating it on first use."""
        key = (output_name, layer_name, precision_name)
        results = self._infer_results.get(key)
        if results is None:
            # Build the model only for new keys (setdefault would build one on every call).
            results = InferResults(precision=precision_name, output_layer=output_name, layer_evaluated=layer_name)
            self._infer_results[key] = results
        return results

    def update_native(self, value: List[np.ndarray], output_name: str) -> None:
        """Append one batch of native (reference) outputs for ``output_name``."""
        self._results_for(output_name, self.native_layer, self.native_precision).update(value)

    def update_precision(
        self, value: List[np.ndarray], output_name: str, layer_name: str, precision: PrecisionMode
    ) -> None:
        """Append one batch of ``output_name`` outputs obtained by evaluating ``layer_name`` at ``precision``.

        The entry is keyed by ``precision.name``, and its ``layer_evaluated`` is ``layer_name``.
        """
        self._results_for(output_name, layer_name, precision.name).update(value)

    def get_native_value(self, output_name: str) -> List[np.ndarray]:
        """Return the accumulated native outputs for ``output_name``."""
        return self.get_values(output_name, self.native_layer, self.native_precision)

    def get_values(
        self, output_name: str, layer_name: str, precision: Union[str, PrecisionMode]
    ) -> List[np.ndarray]:
        """Return the accumulated outputs stored for the given output/layer/precision.

        Args:
            output_name: Name of the output layer.
            layer_name: Evaluated layer name (``native_layer`` for reference results).
            precision: A ``PrecisionMode`` member (its ``.name`` is used as the key) or
                the precision key string itself (e.g. ``native_precision``).

        Raises:
            KeyError: If no results were recorded for this combination.
            ValueError: Propagated from ``InferResults.get_values`` when the entry is empty.
        """
        if isinstance(precision, PrecisionMode):
            precision = precision.name
        key = (output_name, layer_name, precision)
        try:
            results = self._infer_results[key]
        except KeyError:
            raise KeyError(f"no inference results recorded for {key}") from None
        return results.get_values()

    def set_ops(self, ops, layer_name: str, precision: PrecisionMode, metric: ComprecisionMetric) -> None:
        """Record ``ops`` and the compression metric for ``layer_name`` at ``precision``.

        Stored as ``(ops, metric.name)`` under ``(layer_name, precision.name)``.
        :meth:`calculate_snr` needs an entry for every evaluated (layer, precision) pair.
        """
        self._compresion_metric[(layer_name, precision.name)] = (ops, metric.name)

    def clean_results(self) -> None:
        """Drop all accumulated inference results (``set_ops`` entries are kept)."""
        self._infer_results = {}

    @staticmethod
    def _harmonic_mean(values: np.ndarray) -> float:
        """Default SNR aggregation: the harmonic mean of ``values``."""
        return len(values) / np.sum(1.0 / values)

    def calculate_snr(self, agg_funtion: Optional[Callable[[np.ndarray], float]] = None) -> pd.DataFrame:
        """Compute one aggregated SNR value per evaluated (layer, precision) pair.

        Each non-native entry is compared against the native entry of the same
        output via ``SNRNoise``. All SNR values collected for a (layer, precision)
        pair are flattened and reduced with ``agg_funtion``.

        Args:
            agg_funtion: Callable reducing a 1-D array of SNR values to a scalar.
                Defaults to the harmonic mean.

        Returns:
            DataFrame with columns ``layer_name``, ``precision``, ``snr_val``, ``ops`` and
            ``metric``, with one row per evaluated (layer, precision) pair.

        Raises:
            KeyError: If native results are missing for an output, or if ``set_ops`` was
                never called for an evaluated (layer, precision) pair.
        """
        if agg_funtion is None:
            agg_funtion = self._harmonic_mean

        snr_per_key: Dict[Tuple[str, str], list] = {}
        for (out_name, layer_name, precision), results in self._infer_results.items():
            if layer_name == self.native_layer:
                continue  # native entries are the reference, not a candidate
            native = self._infer_results[(out_name, self.native_layer, self.native_precision)]
            snr_single = SNRNoise([out_name]).update(native.get_values(), results.get_values(), out_name).get()
            snr_per_key.setdefault((layer_name, precision), []).append(snr_single[out_name])

        rows = []
        for (layer_name, precision), snr_list in snr_per_key.items():
            # Flatten so the aggregation sees a 1-D array regardless of how SNRNoise shapes its output.
            snr_val = agg_funtion(np.array(snr_list).reshape(-1))
            try:
                ops, metric = self._compresion_metric[(layer_name, precision)]
            except KeyError:
                raise KeyError(
                    f"set_ops was never called for layer {layer_name!r} with precision {precision!r}"
                ) from None
            rows.append([layer_name, precision, snr_val, ops, metric])

        return pd.DataFrame(rows, columns=["layer_name", "precision", "snr_val", "ops", "metric"])
