from hailo_sdk_client.post_fuser.algorithms.exceptions import ArgmaxMappingException
from hailo_sdk_client.post_fuser.algorithms.fuser_algorithm import FuserAlgorithm
from hailo_sdk_client.tools.fuser.fuser_helper import FuserHelper, PostprocessAdditionMode
from hailo_sdk_client.tools.logits_layer_addition import LogitsLayersAdder
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType


class ArgmaxMapping(FuserAlgorithm):
    """Fuser algorithm that moves unsupported argmax layers off the PPU.

    On non-mercury architectures (hailo8), each argmax layer that fails
    :meth:`validate_ppu_argmax_conditions` is replaced by a logits postprocess
    layer mapped from the neural core to the CPU (HRT-pp). Only an argmax whose
    single successor is an output layer can be remapped this way.
    """

    NAME = "argmax_mapping"

    def get_algo_config(self):
        """Return ``self._model_config`` as this algorithm's configuration."""
        return self._model_config

    def _run_int(self):
        """Remap argmax layers the PPU cannot execute to a CPU logits postprocess layer.

        Does nothing on mercury architectures.

        Raises:
            ArgmaxMappingException: if an argmax layer that needs remapping does not
                feed exactly one successor, or that successor is not an output layer.
        """
        if self._hw_arch.is_mercury_arch:
            return

        # Materialize the list before the loop: the loop removes layers from the graph.
        argmax_layers_to_remap = [
            layer
            for layer in self.model.nodes
            if layer.op == LayerType.argmax and not self.validate_ppu_argmax_conditions(layer)
        ]

        for argmax_to_remap in argmax_layers_to_remap:
            pred_layer = next(iter(self._model.predecessors(argmax_to_remap)))  # argmax has only one predecessor
            succ_layers = list(self._model.successors(argmax_to_remap))
            # The argmax is deleted below, so it must be a pure network output: exactly
            # one successor, and that successor is an output layer. Checking only the
            # first successor could leave other consumers dangling in the graph.
            if len(succ_layers) != 1 or succ_layers[0].op != LayerType.output_layer:
                raise ArgmaxMappingException(
                    f"Can't map argmax layer {argmax_to_remap.name}. Argmax layer must be an output layer.",
                )
            succ_layer = succ_layers[0]

            # Add a logits postprocess layer (argmax over the last axis) fed by the
            # argmax's predecessor, mapped from the neural core to the CPU.
            new_logits_layer_name = FuserHelper.add_logits_as_postprocess_layer_to_hn(
                self.model,
                [pred_layer],
                LayerType.argmax,
                axis=-1,
                remove_succ_output=False,
                postprocess_addition_mode=PostprocessAdditionMode.MAPPING_FROM_NN_CORE_TO_CPU,
            )
            # The new layer takes the argmax's slot in the network's output ordering.
            output_index = self._model.net_params.output_layers_order.index(argmax_to_remap.name)
            self._model.net_params.output_layers_order[output_index] = new_logits_layer_name[0]
            # remove_succ_output=False above, so the old output layer is removed here
            # explicitly, followed by the argmax itself.
            self._model.remove_layer(succ_layer)
            self._model.remove_layer(argmax_to_remap)
            self._logger.info(
                f"Mapping the argmax layer {argmax_to_remap.name_without_scope} from the neural "
                "core to CPU due to availability of resources",
            )

    def validate_ppu_argmax_conditions(self, layer):
        """Return a truthy value when ``layer`` can keep its argmax on the PPU.

        The argmax stays on the PPU when its output rank equals
        ``LogitsLayersAdder.ARGMAX_SUPPORTED_RANK`` and its input channel count
        (last dimension of ``input_shape``) does not exceed
        ``LogitsLayersAdder.MAX_CHANNELS_SUPPORTED_ARGMAX``. A layer with a
        truthy ``reverse_order`` is also accepted, regardless of those limits.
        """
        output_rank_cond = len(layer.output_shape) == LogitsLayersAdder.ARGMAX_SUPPORTED_RANK
        input_channels_cond = layer.input_shape[-1] <= LogitsLayersAdder.MAX_CHANNELS_SUPPORTED_ARGMAX
        return (output_rank_cond and input_channels_cond) or layer.reverse_order

    def _setup(self):
        """No setup is required for this algorithm."""
        pass

    def log_config(self):
        """Nothing to log for this algorithm."""
        pass

    def should_skip_algo(self):
        """Never skip; the architecture check is performed inside ``_run_int``."""
        return False
