from collections import OrderedDict
from typing import List

import numpy as np


class SNRNoise:
    """Calculates and aggregates the SNR at each layer over batches.

    For every tracked layer, ``update`` accumulates, per position in the
    batch, the squared L2 norm of the reference (``native``) tensor and the
    squared L2 norm of its difference from the compared (``numeric``) tensor.
    ``get`` converts those accumulated energies into an SNR in decibels for
    each position.
    """

    def __init__(self, layers: List[str]):
        """Create an empty accumulator for ``layers``.

        Args:
            layers: Names of the layers to track. ``update`` only accepts
                names from this collection.
        """
        # Copy so that a one-shot iterable (e.g. a generator) is not exhausted
        # by the loop below, and later mutation of the caller's list does not
        # desynchronise ``self.layers`` from the accumulator dicts.
        self.layers = list(layers)
        self.noises = OrderedDict()
        self.noise2 = OrderedDict()
        self.signal2 = OrderedDict()
        for layer in self.layers:
            # None marks "no update received yet" for this layer.
            self.noise2[layer] = None
            self.signal2[layer] = None

    @staticmethod
    def _as_float(tensor):
        """Return ``tensor`` as a numpy array with at least float64 precision."""
        arr = np.asarray(tensor)
        # Promote before subtracting: unsigned integer inputs would otherwise
        # wrap around (uint8 10 - 12 == 254), and low-precision floats would
        # lose accuracy in the squared sums. Complex dtypes stay complex.
        return arr.astype(np.promote_types(arr.dtype, np.float64), copy=False)

    def update(self, native, numeric, layer):
        """Calculate SNR for each tensor in the batch and update the class.

        Args:
            native: Sequence of reference array-likes, one per batch position.
            numeric: Sequence of array-likes compared against ``native``;
                must have the same length as ``native``.
            layer: Name of one of the layers passed to the constructor.

        Returns:
            ``self``, so calls can be chained (e.g. ``snr.update(...).get()``).

        Raises:
            KeyError: If ``layer`` was not passed to the constructor.
            ValueError: If ``native`` and ``numeric`` differ in length, or if
                the batch length differs from earlier updates of ``layer``.
        """
        if layer not in self.noise2:
            raise KeyError(layer)

        native = [self._as_float(x) for x in native]
        numeric = [self._as_float(y) for y in numeric]
        if len(native) != len(numeric):
            # zip() would silently drop the unmatched tail.
            raise ValueError(
                f"native and numeric batches differ in length for layer "
                f"{layer!r}: {len(native)} != {len(numeric)}"
            )

        noise2 = np.array(
            [np.linalg.norm(x - y) ** 2 for x, y in zip(native, numeric)],
            dtype=np.float64,
        )
        signal2 = np.array(
            [np.linalg.norm(x) ** 2 for x in native], dtype=np.float64
        )

        if self.noise2[layer] is None or self.signal2[layer] is None:
            self.noise2[layer] = np.zeros(len(noise2))
            self.signal2[layer] = np.zeros(len(signal2))
        elif len(self.noise2[layer]) != len(noise2):
            # Energies are accumulated per batch position, so every update of
            # a layer must use the same batch length; otherwise numpy would
            # either fail opaquely or silently broadcast a length-1 batch.
            raise ValueError(
                f"batch length {len(noise2)} for layer {layer!r} does not match "
                f"the accumulated length {len(self.noise2[layer])}"
            )
        self.noise2[layer] += noise2
        self.signal2[layer] += signal2
        return self

    def get(self, epsilon=1e-10):
        """Return the per-position SNR, in dB, for every layer updated so far.

        The SNR is ``10 * log10((signal2 + epsilon) / (noise2 + epsilon))``;
        ``epsilon`` keeps the ratio finite when either energy is zero.

        Args:
            epsilon: Small constant added to both accumulated energies.

        Returns:
            ``self.noises``: an OrderedDict mapping each layer name to a
            float32 array with one SNR value per batch position. Layers that
            never received an update are omitted.
        """
        for layer in self.layers:
            signal2 = self.signal2[layer]
            noise2 = self.noise2[layer]
            if noise2 is None or signal2 is None:
                continue
            snr_db = 10 * np.log10((signal2 + epsilon) / (noise2 + epsilon))
            self.noises[layer] = snr_db.astype(np.float32)
        return self.noises
