from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_mult import HailoElementwiseMult
from hailo_model_optimization.acceleras.hailo_layers.hailo_feature_splitter import HailoFeatureSplitter
from hailo_model_optimization.acceleras.hailo_layers.hailo_shortcut import HailoShortcut
from hailo_model_optimization.acceleras.hailo_layers.hailo_softmax_mask import HailoSoftmaxMask
from hailo_model_optimization.acceleras.hailo_layers.hailo_softmax_mask_on_mac import HailoSoftmaxMaskOnMac
from hailo_model_optimization.acceleras.hailo_layers.hailo_standalone_activation import HailoStandaloneActivation
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerTranslationConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    BiasMode,
    LayerType,
    PrecisionMode,
    ThreeWayPolicy,
)
from hailo_model_optimization.algorithms.mixed_precision.create_mixed_precision import CreateMixedPrecision
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm
from hailo_model_optimization.algorithms.smart_softmax_stats.smart_softmax_stats import SmartSoftmaxStats, SoftmaxBlock


class CreateSoftmaxMask(OptimizationAlgorithm):
    """
    Insert a softmax-mask layer into masked softmax blocks.

    Softmax blocks are discovered through ``SmartSoftmaxStats``. A block is handled when the consumer
    of its ``ew_sub`` (precision-change layers skipped) is a single element-wise multiplication, while
    the consumer of its ``matmul`` is not. For such a block, a softmax-mask layer is placed right after
    the ``matmul`` (the original description puts it before the block's reduce-max). It is fed by the
    same mask tensor as the existing multiplication. When at least one layer is inserted,
    ``CreateMixedPrecision`` is run again on the model.
    """

    def __init__(self, model, model_config, logger_level, logger=None):
        super().__init__(model, model_config, name="Create Softmax Mask", logger_level=logger_level, logger=logger)
        # Helper used to find the softmax blocks and to skip precision-change layers while walking the flow.
        self.smart_softmax_stats = SmartSoftmaxStats(model, model_config, logger_level, logger)

    def _setup(self):
        """Set up the helper algorithm first, then return the result of the base-class setup."""
        self.smart_softmax_stats._setup()
        return super()._setup()

    def should_skip_algo(self):
        """Return True when the (shared) smart-softmax-stats policy is ``ThreeWayPolicy.disabled``."""
        return self.get_algo_config().policy == ThreeWayPolicy.disabled

    def get_algo_config(self):
        """Return the config section this algorithm reads: the ``smart_softmax_stats`` section."""
        return self._model_config.smart_softmax_stats

    def log_config(self):
        """Intentionally a no-op: there is nothing of its own to log."""

    def _run_int(self):
        """
        Find the softmax blocks, propagate shapes, and add a mask layer wherever one is needed.

        Blocks are checked and patched one at a time, in order. A patch may change the graph that
        later checks look at.
        """
        stats = self.smart_softmax_stats
        stats.find_and_build_softmax_blocks()

        # Prepend a batch dimension of None to every model input shape before shape propagation.
        batched_input_shapes = [(None,) + input_shape for input_shape in self._model.get_input_shapes()]
        self._model.compute_output_shape(batched_input_shapes)

        inserted = 0
        for block in stats.softmax_blocks:
            if not self._is_contain_mask(block):
                continue
            self._add_softmax_mask_layer(block)
            inserted += 1

        if not inserted:
            return

        # New layers were added. Re-run CreateMixedPrecision (presumably so the new layers are
        # covered by the mixed-precision setup -- confirm).
        CreateMixedPrecision(
            model=self._model,
            model_config=self._model_config,
            logger_level=self._logger_level,
            logger=self._logger,
        ).run()

    def _is_ew_mult_mask(self, successors_list):
        """
        Return True iff ``successors_list`` holds exactly one layer name and that layer is an ``HailoElementwiseMult``.
        """
        if len(successors_list) == 1:
            sole_successor = successors_list[0]
            return isinstance(self._model.layers[sole_successor], HailoElementwiseMult)
        return False

    def _is_contain_mask(self, softmax_block: SoftmaxBlock):
        """
        Return True when a softmax-mask layer must be inserted into ``softmax_block``.

        Two conditions must hold:
        * the consumer of ``ew_sub`` (after skipping precision-change layers) is a single
          element-wise multiplication, i.e. the block is masked;
        * the consumer of ``matmul`` is not, i.e. no mask layer has been placed there yet.
        """

        def _feeds_single_ew_mult(lname):
            # Consumers of ``lname``, with precision-change layers stepped over.
            consumers = self.smart_softmax_stats.skip_precision_change(self._model.flow.successors_sorted(lname))
            return self._is_ew_mult_mask(consumers)

        # Both checks are always evaluated, matmul first, before being combined.
        already_masked = _feeds_single_ew_mult(softmax_block.matmul)
        mask_required = _feeds_single_ew_mult(softmax_block.ew_sub)
        return mask_required and not already_masked

    def _get_mask_input_layer(self, softmax_block: SoftmaxBlock):
        """
        Return the name of the layer that feeds the mask tensor into the block's mask multiplication.

        The walk starts at ``ew_sub`` and follows its first successor. If that successor's name contains
        "precision_change", it steps over that one layer. The layer reached is the multiplication, and
        the result is the multiplication's *other* input. If that input is an ``HailoFeatureSplitter``,
        a shortcut layer is inserted between it and the multiplication, and the shortcut's name is
        returned instead.
        """
        feeder = softmax_block.ew_sub
        mult_lname = self._model.flow.successors_sorted(feeder)[0]
        if "precision_change" in mult_lname:
            # A precision-change layer sits between ew_sub and the multiplication; hop over it.
            feeder, mult_lname = mult_lname, self._model.flow.successors_sorted(mult_lname)[0]

        # The multiplication's inputs are the edge from ``feeder`` and the mask.
        # NOTE(review): ``1 - index`` assumes the feeder edge uses input index 0 or 1.
        other_input_index = 1 - self._model.flow.get_edge_input_index(feeder, mult_lname)
        mask_source = self._model.flow.predecessors_sorted(mult_lname)[other_input_index]

        if not isinstance(self._model.layers[mask_source], HailoFeatureSplitter):
            return mask_source
        return self._add_shortcut_layer(mask_source, mult_lname)

    def _add_shortcut_layer(self, source_name, target_name):
        """
        Insert an ``HailoShortcut`` on the edge ``source_name -> target_name`` and register its configs.

        The shortcut takes the shape of the source output that feeds this edge. It is configured as
        a8_w8 / single_scale_decomposition with a single quantization group, and it gets a fresh
        ``LayerTranslationConfig``.

        Returns:
            The full name of the new shortcut layer.
        """
        edge_out_index = self._model.flow.get_edge_output_index(source_name, target_name)
        edge_shape = self._model.layers[source_name]._hn_element["output_shapes"][edge_out_index]
        shortcut_name = self._get_new_full_name(source_name, f"shortcut{edge_out_index}")
        shortcut_layer = HailoShortcut.from_hn(
            shortcut_name,
            {
                "type": LayerType.SHORTCUT.value,
                "input": [source_name],
                "output": [target_name],
                "input_shapes": [edge_shape],
                "output_shapes": [edge_shape],
            },
        )
        self._model.add_layer(shortcut_layer, [(source_name, target_name)])

        precision_cfg = LayerPrecisionConfig(
            precision_mode=PrecisionMode.a8_w8,
            bias_mode=BiasMode.single_scale_decomposition,
            quantization_groups=1,
        )
        self._model_config.precision_config.layers[shortcut_name] = precision_cfg
        shortcut_layer.import_precision_config(precision_cfg, self.optimization_target)
        self._model_config.translation_config.layers[shortcut_name] = LayerTranslationConfig()
        return shortcut_name

    def _get_new_full_name(self, orig_name, addition):
        """
        Build a new layer name from ``orig_name`` by inserting ``addition``.

        ``orig_name`` must contain exactly one "/"; otherwise the unpacking raises ``ValueError``. The
        part after the "/" is split by ``get_block_and_layer_names`` into a block prefix and a short
        name. The result is ``"<scope>/<block><addition>_<short name>"``.
        """
        scope, layer_part = orig_name.split("/")
        block_prefix, short_name = self.get_block_and_layer_names(layer_part)
        return f"{scope}/{block_prefix}{addition}_{short_name}"

    def _add_softmax_mask_layer(self, softmax_block: SoftmaxBlock):
        """
        Insert a softmax-mask layer directly after ``softmax_block.matmul``.

        The new layer receives the matmul output on input 0 and the mask on input 1. It takes over all
        of the matmul's current consumers. The flavour depends on the matmul's output width:

        * 16-bit output: ``HailoSoftmaxMaskOnMac`` (a16_w16). The mask first passes through an extra
          linear ``HailoStandaloneActivation`` named "precision_change" (a8_w8_a16).
        * otherwise: ``HailoSoftmaxMask`` (a8_w8).

        Finally, the output ranges of the mask source, and of the precision-change layer if present,
        are forced.
        """
        matmul_lname = softmax_block.matmul
        # May itself insert a shortcut layer (see _get_mask_input_layer).
        mask_source = self._get_mask_input_layer(softmax_block)
        mask_lname = self._get_new_full_name(matmul_lname, "softmax_mask")

        matmul_layer = self._model.layers[matmul_lname]
        matmul_shape = matmul_layer.to_hn()["output_shapes"][0]
        matmul_mode = matmul_layer.get_precision_mode()
        is_16bit = matmul_mode.has_output_bits() and matmul_mode.output_bits() == 16

        mask_mode = PrecisionMode.a16_w16 if is_16bit else PrecisionMode.a8_w8
        mask_precision_cfg = LayerPrecisionConfig(
            precision_mode=mask_mode,
            bias_mode=BiasMode.double_scale_initialization,
            quantization_groups=1,
        )

        def _mask_hn_element(second_input):
            # hn description shared by both mask flavours; only the mask-side input differs.
            consumers = self._model.flow.successors_sorted(matmul_lname)
            return {
                "type": "ew_mult",
                "input": [matmul_lname, second_input],
                "output": consumers,
                "input_shapes": [matmul_shape, matmul_shape],
                "output_shapes": [matmul_shape for _ in consumers],
                "compilation_params": {},
                "quantization_params": {},
                "params": {
                    "activation": "linear",
                    "is_softmax_mask": True,
                },
            }

        if is_16bit:
            # Linear activation that carries the mask into the 16-bit, MAC-based mask layer.
            pc = HailoStandaloneActivation.from_hn(
                self._get_new_full_name(matmul_lname, "precision_change"),
                {
                    "type": "activation",
                    "input": [mask_source],
                    "output": [mask_lname],
                    "input_shapes": [matmul_shape],
                    "output_shapes": [matmul_shape],
                    "params": {"activation": "linear"},
                },
            )
            pc_lname = pc.full_name
            self._model.layers[pc_lname] = pc
            self._model.flow.add_node(pc_lname, is_input=False)
            # Register the precision-change layer in the configuration.
            self._model_config.precision_config.layers[pc_lname] = LayerPrecisionConfig(
                precision_mode=PrecisionMode.a8_w8_a16,
                bias_mode=BiasMode.single_scale_decomposition,
                quantization_groups=1,
            )
            mask_layer = HailoSoftmaxMaskOnMac.from_hn(mask_lname, _mask_hn_element(pc_lname))
            mask_layer.import_weights(dict())
        else:
            mask_layer = HailoSoftmaxMask(mask_lname)
            mask_layer._hn_element = _mask_hn_element(mask_source)
            mask_layer.import_weights(dict())
            mask_layer.mock_kernel_values = [2, 2]

        # Re-route every matmul -> consumer edge through the new mask layer.
        self._model.add_layer(
            mask_layer,
            [(matmul_lname, consumer) for consumer in self._model.flow.successors_sorted(matmul_lname)],
        )
        # The new edge out of mask_source uses, as its output index, the number of consumers it already has.
        next_out_slot = len(self._model.flow.successors_sorted(mask_source))

        if is_16bit:
            self._model.flow.add_edge(mask_source, pc_lname, input_index=0, output_index=next_out_slot)
            self._model.flow.add_edge(pc_lname, mask_lname, input_index=1, output_index=0)
        else:
            self._model.flow.add_edge(mask_source, mask_lname, input_index=1, output_index=next_out_slot)

        self._model_config.precision_config.layers[mask_lname] = mask_precision_cfg
        self._model.layers[mask_lname].import_precision_config(mask_precision_cfg, self.optimization_target)

        # Forced ranges are meant to cancel the shift applied inside the softmax-mask layer.
        # NOTE(review): the bounds look like full-scale 15-bit / 8-bit ranges where 1.0 lands on
        # 2**14 and 128 respectively -- confirm against the mask layer's shift.
        if is_16bit:
            self._force_range_out(pc_lname, [0.0, (2**15 - 1) / 2**14])
        self._force_range_out(mask_source, [0.0, 255 / 128])

    def _force_range_out(self, lname, range_out):
        """
        Set ``force_range_out`` of ``lname``'s translation config to ``range_out``, unless it is already set.

        If ``lname`` has no translation config, ``LayerTranslationConfig.get_default()`` is used. In
        either case the config is written back into the model's translation config.
        """
        layer_cfgs = self._model_config.translation_config.layers
        cfg = layer_cfgs.get(lname, LayerTranslationConfig.get_default())
        if cfg.force_range_out is None:
            cfg.force_range_out = range_out
        layer_cfgs[lname] = cfg

    def finalize_global_cfg(self, algo_config):
        """Intentionally a no-op: there is no global config to finalize."""
