from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, NamedTuple, Tuple, Union

import torch
import torch.nn as nn

from hailo_model_optimization.acceleras.model.hailo_model.model_flow import ModelFlow
from hailo_model_optimization.acceleras.utils.acceleras_definitions import BiasMode

# region Datatypes
# Common tensor containers.
TensorList = List[torch.Tensor]
TensorDict = Dict[str, torch.Tensor]
# A single tensor, or a list of tensors.
LayerOutput = Union[torch.Tensor, TensorList]
# A single LayerOutput, or a list of LayerOutput values.
ModelOutput = Union[LayerOutput, List[LayerOutput]]

# An int, or an explicit (int, int) pair.
size_2_int = Union[int, Tuple[int, int]]


class QType(NamedTuple):
    """Numeric type descriptor: bit width plus signedness."""

    bits: int
    signed: bool


class MACPrecisionConfig(NamedTuple):
    """Precision configuration for a MAC (multiply-accumulate) operation.

    Holds the input, weight and accumulator types, the bias mode and the
    quantization group count.
    """

    input_qtype: QType
    weight_qtype: QType
    accumulator_qtype: QType
    # Accepted either as a plain string or as a BiasMode member.
    bias_mode: Union[str, BiasMode]
    quantization_groups: int


class APUPrecisionConfig(NamedTuple):
    """Precision configuration for an APU operation.

    Holds the accumulator type, the output type and the quantization group count.
    """

    accumulator_qtype: QType
    output_qtype: QType
    quantization_groups: int


class IOPrecisionConfig(NamedTuple):
    """Precision configuration for I/O; only the output type is recorded."""

    output_qtype: QType


# Either the MAC or the APU precision config.
# NOTE(review): IOPrecisionConfig is not part of this union — confirm intended.
PrecisionConfigT = Union[MACPrecisionConfig, APUPrecisionConfig]


# endregion
# region Enums
class ParamsMode(Enum):
    """Parameter representation mode: native values or quantized (QUANT) values."""

    NATIVE = 0  # native (non-quantized) parameters
    QUANT = 1  # quantized parameters


class InferMode(Enum):
    """Inference mode: native computation or quantized (QUANT) computation."""

    NATIVE = 0  # native (non-quantized) inference
    QUANT = 1  # quantized inference


# endregion


# region Dataclass
@dataclass
class SaitmaBuilding:
    """Groups a ModelFlow with its layer modules and per-module input shapes."""

    # The model's flow object.
    model_flow: ModelFlow
    # Layer modules keyed by string (presumably the layer name — confirm).
    layers: Dict[str, nn.Module]
    # Input shape per module. Keys are module objects; nn.Module uses the
    # default identity-based hash, so lookups are by instance, not by value.
    input_shapes: Dict[nn.Module, Tuple[int, ...]]


class Encoding(NamedTuple):
    """Quantization encoding parameters, kept both per group and per channel.

    Field order is part of the tuple interface and must not change.
    """

    scale_by_group: torch.Tensor
    # Presumably how many times each group scale repeats when expanded — confirm.
    scale_repeats: int
    zero_point_by_group: torch.Tensor
    # Presumably the repeat count for each group zero point — confirm.
    zero_point_repeats: int
    equalization_vector: torch.Tensor
    factor_by_group: torch.Tensor
    scale_by_channel: torch.Tensor
    zero_point_by_channel: torch.Tensor


@dataclass(frozen=True)
class DimsInfo:
    """Immutable (height, width, channels) triple; every dimension defaults to 1.

    Provides helpers that return the dimensions in common axis orders.
    """

    height: int = 1
    width: int = 1
    channels: int = 1

    def as_hw_tuple(self) -> Tuple[int, int]:
        """Return the spatial pair ``(height, width)``."""
        return (self.height, self.width)

    def as_hwc_tuple(self) -> Tuple[int, int, int]:
        """Return ``(height, width, channels)`` (channels-last order)."""
        return (self.height, self.width, self.channels)

    def as_chw_tuple(self) -> Tuple[int, int, int]:
        """Return ``(channels, height, width)`` (channels-first order)."""
        return (self.channels, self.height, self.width)
