from hailo_model_optimization.acceleras.hailo_layers.base_hailo_none_nn_core_layer import BaseHailoNonNNCoreLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_add import HailoElementwiseAdd
from hailo_model_optimization.acceleras.hailo_layers.hailo_io import HailoOutputLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_resize_bilinear_mac import HailoResizeBilinearMac
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasUnsupportedError
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class MixedPrecision(OptimizationAlgorithm):
    """
    Assigns a cross-layer precision mode to every layer. The precision mode is passed to the compiler as a triplet
    (input_bit_width, weight_bit_width, output_bit_width), encoded as the string ``"a<bits>_w<bits>_a<bits>"``.
    Determined by a traversal of the model graph edges and the per-layer precision modes provided by the alls.
    """

    # TODO: Remove this algorithm completely. CreateMixedPrecision does most of the logic already...

    def __init__(self, model, model_config, logger_level: int, params_statistics=None, logger=None):
        super().__init__(
            model=model,
            model_config=model_config,
            name="Mixed Precision",
            logger_level=logger_level,
            logger=logger,
        )
        self._params_statistics = params_statistics
        # Returned by export_statistics(); nothing in this class populates it.
        self._generated_statistics = {}

    def get_algo_config(self):
        """Return the precision section of the model configuration."""
        return self._model_config.precision_config

    def log_config(self):
        """No configuration is logged for this algorithm."""
        pass

    def _setup(self):
        """No setup is required for this algorithm."""
        pass

    def _run_int(self):
        """
        Assign ``cross_layer_precision_mode`` to the model layers, one flow edge at a time.

        For every edge ``u -> v`` in ``self._model.flow.edges`` whose source ``u`` is not a
        ``BaseHailoNonNNCoreLayer``, the candidate mode
        ``a<u activation bits>_w<u weight bits>_a<v activation bits>`` is assigned to ``u``.
        Output layers (``HailoOutputLayer``) reached through such an edge get
        ``a<v bits>_w<v bits>_a<v bits>``, built from their own activation bit width.

        Raises:
            AccelerasUnsupportedError: if ``u`` already holds a cross-layer precision mode that differs
                from the candidate computed for the current edge (e.g. two successors of ``u`` with
                different activation bit widths).
        """
        for u_layer_name, v_layer_name in self._model.flow.edges:
            u_layer = self._model._acceleras_layers[u_layer_name]
            v_layer = self._model._acceleras_layers[v_layer_name]
            if isinstance(u_layer, BaseHailoNonNNCoreLayer):
                # NOTE: this also skips the HailoOutputLayer assignment at the end of the loop body
                # for edges whose source is a non-NN-core layer.
                continue

            u_precision_mode = u_layer.get_precision_mode()
            v_precision_mode = v_layer.get_precision_mode()
            u_activation_bits = self._parse_precision_mode_activation_bits(u_precision_mode)
            u_weights_bits = self._parse_precision_mode_weight_bits(u_precision_mode)
            v_activation_bits = self._parse_precision_mode_activation_bits(v_precision_mode)
            candidate_mode = self._create_cross_layer_precision_mode(
                u_activation_bits,
                u_weights_bits,
                v_activation_bits,
            )

            if u_layer.cross_layer_precision_mode:
                # A mode was already assigned (earlier edge or earlier stage); it must agree with this edge.
                if u_layer.cross_layer_precision_mode != candidate_mode:
                    raise AccelerasUnsupportedError(
                        f"Two cross layer precision mode for {u_layer.full_name} are currently unsupported - "
                        f"{u_layer.cross_layer_precision_mode}, {candidate_mode}",
                    )
            else:
                u_layer.cross_layer_precision_mode = candidate_mode
                if u_activation_bits == 16 and isinstance(u_layer, (HailoElementwiseAdd, HailoResizeBilinearMac)):
                    # 16-bit add / bilinear-resize layers are flagged for a double-scale bias-mode initialization.
                    u_layer.bias_mode_double_scale_initializtation_needed = True

            if isinstance(v_layer, HailoOutputLayer):
                v_layer.cross_layer_precision_mode = self._create_cross_layer_precision_mode(
                    v_activation_bits,
                    v_activation_bits,
                    v_activation_bits,
                )

    def _parse_precision_mode_activation_bits(self, precision_mode):
        """
        Return the activation bit width encoded in the first token (``a<bits>``) of ``precision_mode``.

        Raises:
            ValueError: if the first token does not start with ``"a"``.
        """
        return self._parse_prefixed_bits(precision_mode, 0, "a")

    def _parse_precision_mode_weight_bits(self, precision_mode):
        """
        Return the weight bit width encoded in the second token (``w<bits>``) of ``precision_mode``.

        Raises:
            ValueError: if the second token does not start with ``"w"``.
        """
        return self._parse_prefixed_bits(precision_mode, 1, "w")

    def _parse_prefixed_bits(self, precision_mode, index, prefix):
        """
        Return the integer following ``prefix`` in token ``index`` of ``precision_mode``.

        Explicit validation is used instead of ``assert`` so that malformed modes are still rejected
        when Python runs with ``-O``.
        """
        token = self._parse(precision_mode)[index]
        if not token.startswith(prefix):
            raise ValueError(
                f"Malformed precision mode {precision_mode!r}: expected token {index} to start with "
                f"{prefix!r}, got {token!r}",
            )
        return int(token[len(prefix):])

    def _parse(self, precision_mode):
        """
        Split ``precision_mode.value`` on ``"_"`` and return the list of tokens.

        ``precision_mode`` is presumably an enum whose ``value`` is a string such as ``"a8_w8"`` — verify
        against the precision-mode definition.
        """
        return precision_mode.value.split("_")

    def _create_cross_layer_precision_mode(self, u_activation_bits, u_weights_bits, v_activation_bits):
        """Build the ``"a<u_activation>_w<u_weights>_a<v_activation>"`` cross-layer precision mode string."""
        return f"a{u_activation_bits}_w{u_weights_bits}_a{v_activation_bits}"

    def should_skip_algo(self):
        """This algorithm always runs."""
        return False

    def export_statistics(self):
        """Return the statistics dictionary created in ``__init__``."""
        return self._generated_statistics

    def finalize_global_cfg(self, algo_config):
        """No global configuration finalization is needed."""
        pass

    def _get_valid_layer_cfg(self, lname, cfg):
        """Return the per-layer configuration unchanged."""
        return cfg
