import json
import os

# Clock frequency per architecture, in MHz. Keys carrying a "_ppu" or
# "_shmifo" suffix are looked up first by HWArch's PPU / SHMIFO clock
# properties; when absent, the base architecture's entry is used instead.
_CLK_FREQS = dict(
    zaatar=62.5,
    nutmeg=62.5,
    mustard=62.5,
    mustard_b0=62.5,
    paprika_b0=100.0,
    jasmine=15.0,
    sage_b0=400.0,
    hailo8=400.0,
    hailo8p=400.0,
    hailo8v=1.0,
    hailo8r=200.0,
    mercury=600.0,
    mercury_ppu=400.0,
    hailo15h_ppu=400.0,
    pluto_ppu=600.0,
    hailo15l_ppu=600.0,
    pluto_shmifo=600.0,
    hailo15l_shmifo=600.0,
    hailo15h=600.0,
    hailo15m=600.0,
    ginger=50.0,
    lavender=50.0,
    hailo8l=400.0,
    pluto=750.0,
    hailo15l=750.0,
    hailo10h=600.0,
    mars=750.0,
    mars_ppu=750.0,
)

# Directory containing this module; the "<arch>.<version>.json" consts files are read from here.
HW_CONSTS_DIR = os.path.split(__file__)[0]


class HWArch:
    """Description of a hardware architecture and its constants.

    Construction loads the ``<arch>.<version>.json`` consts file from
    ``HW_CONSTS_DIR`` (several product names are mapped onto a shared base
    file) and exposes the resulting consts together with clock frequencies
    taken from ``_CLK_FREQS`` and a few feature queries.
    """

    SAGE_B0_ARCHS = ["sage_b0", "mustard_b0", "paprika_b0", "hailo8p", "hailo8r", "hailo8l", "hailo8", "hailo8v"]
    # NOTE(review): MARS_ARCHS used to be assigned twice in this class body:
    # first ["mars"] (only used to build PLUTO_ARCHS), later ["hailo10h2"],
    # which silently replaced it. The effective values are preserved here;
    # confirm whether "mars" should also be a member of MARS_ARCHS.
    PLUTO_ARCHS = ["pluto", "hailo15l", "mars"]
    MERCURY_ARCHS = ["mercury", "hailo15h", "ginger", "lavender", "hailo15m", "hailo10h", *PLUTO_ARCHS]
    HAILO15_ARCHS = ["hailo15h", "hailo15m"]
    MARS_ARCHS = ["hailo10h2"]
    POWER_PROFILING_ARCHS = []
    # TODO: this should be added to the stuff exported from the core
    # NOTE(review): a previous comment claimed paprika_b0 is not in this list,
    # but it is; confirm the intended membership.
    LIMITED_SHIFTS_ARCHS = [
        "nutmeg",
        "zaatar",
        "mustard",
        "mustard_b0",
        "paprika_b0",
        "ginger",
        "lavender",
    ]
    LIMITED_SHIFTS = [2, 4]
    UNLIMITED_SHIFTS = [1, 2, 3, 4]
    HAILO15_LARGE_MODEL_PARAMS_TH = 10**7

    def __init__(self, arch, version=None):
        """Create the description of ``arch`` and load its consts.

        Args:
            arch: Architecture name (e.g. "hailo8", "hailo15h", "pluto").
            version: Consts file version; ``None`` selects "default".

        Raises:
            OSError: If the consts JSON file cannot be opened.
            KeyError: If the JSON has no "consts" section.
        """
        self._arch = arch
        self._version = version
        self._is_mercury_arch = arch in self.MERCURY_ARCHS
        self._is_pluto_arch = arch in self.PLUTO_ARCHS
        self._does_support_power_profliling = arch in self.POWER_PROFILING_ARCHS
        arch_json = self._load_arch_json(arch, version)
        self._consts = self._load_consts(arch_json)

    def _load_consts(self, arch_json):
        """Return ``arch_json["consts"]`` after per-arch adjustments (mutated in place).

        Always sets "CORE_PKG::ACTUAL_CLUSTERS" and defaults
        "CORE_PKG::N_PREPOST_CLUSTERS" to 1 when absent.
        """
        consts = arch_json["consts"]
        if self._arch in ["hailo8p", "hailo8v", "hailo8r", "hailo8l"]:
            # A hack to support 8 clusters in N_CLUSTERS
            consts["CORE_PKG::N_CLUSTERS"] = consts["CORE_PKG::MAX_CLUSTERS"]
        if self._arch in ["hailo8l"]:
            # hailo8l exposes half of the clusters set just above.
            consts["CORE_PKG::N_CLUSTERS"] = int(consts["CORE_PKG::N_CLUSTERS"] / 2)
        if self._arch in ["sage_b0"]:
            consts["CORE_PKG::ACTUAL_CLUSTERS"] = consts["CORE_PKG::MAX_CLUSTERS"]
        else:
            consts["CORE_PKG::ACTUAL_CLUSTERS"] = consts["CORE_PKG::N_CLUSTERS"]
        consts.setdefault("CORE_PKG::N_PREPOST_CLUSTERS", 1)
        return consts

    def _load_arch_json(self, arch, version):
        """Load and return the parsed consts JSON for ``arch``/``version``.

        Product names share base consts files: the hailo8 family maps to
        "sage_b0", hailo15h/hailo15m/hailo10h (and hailo10h2) map to
        "mercury", and hailo15l maps to "pluto".
        """
        version_name = version or "default"
        if arch in ["hailo8", "hailo8p", "hailo8v", "hailo8r", "hailo8l"]:
            arch = "sage_b0"
        elif arch in ["hailo15h", "hailo15m", "hailo10h"]:
            arch = "mercury"
        elif arch in ["hailo15l"]:
            arch = "pluto"
        elif arch in ["hailo10h2"]:  # TODO this is a hack to support mars we need to change it to "mars" in the future
            arch = "mercury"
        json_path = os.path.join(HW_CONSTS_DIR, f"{arch}.{version_name}.json")
        # Use a context manager so the file handle is closed deterministically.
        with open(json_path, encoding="utf-8") as json_file:
            return json.load(json_file)

    def _domain_clk_freq_m(self, suffix):
        """Return the MHz clock for ``self._arch + suffix``, falling back to the base arch clock."""
        key = self._arch + suffix
        return _CLK_FREQS[key] if key in _CLK_FREQS else _CLK_FREQS[self._arch]

    @property
    def name(self):
        """Architecture name as passed to the constructor."""
        return self._arch

    @property
    def version(self):
        """Consts version as passed to the constructor (may be ``None``)."""
        return self._version

    @property
    def is_mercury_arch(self):
        """True if the arch is in MERCURY_ARCHS (which includes PLUTO_ARCHS)."""
        return self._is_mercury_arch

    @property
    def is_pluto_arch(self):
        """True if the arch is in PLUTO_ARCHS."""
        return self._is_pluto_arch

    @property
    def consts(self):
        """Dictionary of hardware constants loaded from the consts JSON."""
        return self._consts

    @property
    def clk_freq_m(self):
        """Core clock frequency in MHz."""
        return _CLK_FREQS[self._arch]

    @property
    def clk_freq(self):
        """Core clock frequency in Hz."""
        return self.clk_freq_m * 1e6

    @property
    def ppu_clk_freq_m(self):
        """PPU clock frequency in MHz (the core clock if no "_ppu" entry exists)."""
        return self._domain_clk_freq_m("_ppu")

    @property
    def ppu_clk_freq(self):
        """PPU clock frequency in Hz."""
        return self.ppu_clk_freq_m * 1e6

    @property
    def shmifo_clk_freq(self):
        """SHMIFO clock frequency in Hz (the core clock if no "_shmifo" entry exists)."""
        return self._domain_clk_freq_m("_shmifo") * 1e6

    @property
    def does_support_power_profiling(self):
        """True if the arch is in POWER_PROFILING_ARCHS."""
        return self._does_support_power_profliling

    @property
    def supported_shifts(self):
        """Shift values allowed on this arch (LIMITED_SHIFTS or UNLIMITED_SHIFTS)."""
        return self.LIMITED_SHIFTS if self._arch in self.LIMITED_SHIFTS_ARCHS else self.UNLIMITED_SHIFTS

    @staticmethod
    def get_real_hw_arch(hw_arch):
        """Map the "hailo10p" alias to "hailo15m"; any other name is returned unchanged."""
        return hw_arch if hw_arch != "hailo10p" else "hailo15m"
