from collections import Counter
from pprint import pprint

import torch
import torch.nn as nn

from hailo_model_optimization.saitama.translators.base_translator import BaseModelTranslator
from hailo_model_optimization.saitama.translators.torch_translator.torch_modules_translator import (
    TORCH_MODULES_REGISTRY,
)


class TorchTranslator(BaseModelTranslator):
    """Walk a ``torch.nn.Module`` tree and swap registered layers in place.

    A child is handled when its *exact* type (``type(child)``, not
    ``isinstance``) is a key of ``module_translator_registry``; the matching
    registry entry is then asked to translate it or to install a stats hook.
    Any other child is descended into recursively.

    The translator keeps three running lists, filled by :meth:`translate`:

    * ``optimized_modules`` - the replacement layers that were installed.
    * ``skipped_modules``   - visited modules that have no children and were
      not replaced (leaf modules with no registry entry).
    * ``parent_modules``    - visited modules that have at least one child.

    The lists are never cleared by this class, so calling :meth:`translate`
    more than once on the same instance accumulates entries.
    """

    module_translator_registry = TORCH_MODULES_REGISTRY

    def __init__(self, logger=None):
        """Create a translator with empty bookkeeping lists.

        Args:
            logger: Forwarded unchanged to ``BaseModelTranslator.__init__``.
        """
        super().__init__(logger)
        self.skipped_modules = []
        self.optimized_modules = []
        self.parent_modules = []

    def hook_stats(self, model: nn.Module):
        """Enable the registry's stats hook on every registered descendant.

        Args:
            model: Root module whose children are inspected. ``model`` itself
                is never hooked, only its descendants.
        """
        registry = self.module_translator_registry
        for child in model.children():
            child_translator = registry.get(type(child))
            if child_translator is not None:
                child_translator.enable_stats_hook(child)
            else:
                # Not a registered type: look for registered layers deeper down.
                self.hook_stats(child)

    def translate(self, model: nn.Module, dtype=None, device=None):
        """Replace every registered descendant of ``model`` in place.

        Args:
            model: Root module to process. ``model`` itself is never replaced,
                only its descendants.
            dtype: Passed through to each registry entry's ``translate``.
            device: Passed through to each registry entry's ``translate``.

        Returns:
            None. ``model`` is mutated and the bookkeeping lists are extended.
        """
        registry = self.module_translator_registry
        has_children = False
        for name, child in model.named_children():
            has_children = True
            child_translator = registry.get(type(child))
            if child_translator is None:
                # Forward dtype/device so nested layers are translated with
                # the same settings as the top-level ones.
                self.translate(child, dtype=dtype, device=device)
                continue

            # Not sure why, but moving the layer to CPU before getting the
            # weights reduces VRAM usage, but slows everything down.
            child.to("cpu")
            torch.cuda.empty_cache()
            new_layer = child_translator.translate(child, dtype=dtype, device=device)
            # Release cached GPU blocks left over from the translation step.
            torch.cuda.empty_cache()

            # Overwrites the existing entry for ``name``; the set of child
            # names is unchanged while named_children() is being iterated.
            setattr(model, name, new_layer)
            self.optimized_modules.append(new_layer)

        if has_children:
            self.parent_modules.append(model)
        else:
            self.skipped_modules.append(model)

    def _get_persent(self):  # TODO remove this ugly thing and work like a good programer
        """Return the percentage of translated modules among translated + skipped.

        Returns:
            A float in ``[0, 100]``; ``0.0`` when nothing has been translated
            or skipped yet (avoids a division by zero).
        """
        translated = len(self.optimized_modules)
        total = translated + len(self.skipped_modules)
        if total == 0:
            return 0.0
        return translated / total * 100

    def summarize(self):
        """Print per-type counts of skipped and parent modules to stdout."""
        skipped = Counter(type(m).__name__ for m in self.skipped_modules)
        print("Skipped modules types:")
        pprint(skipped)
        # NOTE: the message text (including its spelling) is kept as-is so
        # existing log output does not change.
        print(f"Transalted modules: {len(self.optimized_modules)}, Total skipped: {len(self.skipped_modules)}")
        parents = Counter(type(m).__name__ for m in self.parent_modules)
        print("Parent modules types:")
        pprint(parents)
        print(f"Total Parents: {len(self.parent_modules)}")