from typing import TYPE_CHECKING

from hailo_model_optimization.acceleras.model_optimization_config.mo_config import ModelOptimizationConfig
from hailo_model_optimization.acceleras.utils.modification_tracker.modification_tracker_utils import (
    AlgoModificationTracker,
)
from hailo_sdk_client.hw_consts.hw_arch import HWArch
from hailo_sdk_client.post_fuser.algorithms import (
    ActivationsFolding,
    AddPreLnNormalization,
    ArgmaxMapping,
    DeadChannelsRemoval,
    InputFeaturesDefuse,
    MHADefuse,
    MulByScalarAfterMatMulFolding,
    NormalizationOptimizer,
    SlicedConvSplitter,
    SoftmaxMapping,
    SplitFusedActivation,
    SplitLeakyAndPReLUWithNegSlope,
    TiledSEOptimizer,
)

if TYPE_CHECKING:
    from hailo_sdk_client.post_fuser.algorithms.fuser_algorithm import FuserAlgorithm
from hailo_sdk_client.post_fuser.algorithms.fuse_softmax_additive_mask import FuseSoftmaxAdditiveMask
from hailo_sdk_client.post_fuser.algorithms.fuser_algorithm import FuserAlgorithm
from hailo_sdk_common.hailo_nn.hailo_nn import HailoNN
from hailo_sdk_common.model_params.model_params import ModelParams


class HailoNNPostFuser:
    """Apply a fixed, ordered sequence of post-fuser algorithms to a model.

    Every algorithm is constructed from the current model, the current params,
    the optimization config and the hardware architecture. After an algorithm
    runs, the model and params held by this object are replaced with that
    algorithm's results, and its statistics and modification metadata are
    merged into the accumulated ones.
    """

    def __init__(
        self,
        model: HailoNN,
        params: ModelParams,
        params_statistics: ModelParams,
        config: ModelOptimizationConfig,
        hw_arch: HWArch,
    ):
        """Store the inputs and start with an empty modification tracker."""
        # Passed unchanged to every algorithm constructor in run().
        self._config = config
        self._hw_arch = hw_arch

        # Rebound or updated in place as the algorithms execute.
        self._model = model
        self._params = params
        self._params_statistics = params_statistics
        self._modifications_meta_data = AlgoModificationTracker()

    @property
    def model(self) -> HailoNN:
        """The current model (the result of the last algorithm that ran, if any)."""
        return self._model

    @property
    def params(self) -> ModelParams:
        """The current params (re-wrapped after each algorithm that ran)."""
        return self._params

    @property
    def params_statistics(self) -> ModelParams:
        """The statistics object, updated with every algorithm's statistics."""
        return self._params_statistics

    @property
    def modifications_meta_data(self) -> AlgoModificationTracker:
        """The tracker accumulating modification metadata from every algorithm."""
        return self._modifications_meta_data

    def run(self):
        """Execute all post-fuser algorithms in order, updating this object's state.

        Nothing is returned; results are exposed through the ``model``,
        ``params``, ``params_statistics`` and ``modifications_meta_data``
        properties.
        """
        # Order matters: some algorithms appear twice on purpose (see notes).
        pipeline = (
            MulByScalarAfterMatMulFolding,
            SlicedConvSplitter,
            NormalizationOptimizer,
            MulByScalarAfterMatMulFolding,  # second time for ViT case
            MHADefuse,
            SlicedConvSplitter,  # second time because MHA might create FS after conv
            FuseSoftmaxAdditiveMask,
            SoftmaxMapping,
            ArgmaxMapping,
            DeadChannelsRemoval,
            TiledSEOptimizer,
            InputFeaturesDefuse,
            ActivationsFolding,
            SplitFusedActivation,
            AddPreLnNormalization,
            SplitLeakyAndPReLUWithNegSlope,
        )

        for stage_cls in pipeline:
            # Each stage sees the model/params produced by the previous one.
            stage: FuserAlgorithm = stage_cls(self.model, self.params, self._config, self._hw_arch)
            self._model = stage.run()
            # Wrap the stage's raw params in a new ModelParams instance.
            self._params = ModelParams(stage.params.params)
            # Fold this stage's outputs into the accumulated statistics and tracker.
            self._params_statistics.update(stage.get_statistics())
            self._modifications_meta_data.update(stage.get_modifications_meta_data())
