from hailo_model_optimization.acceleras.atomic_ops.argmax_op import ArgMaxOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_single_atomic import BaseHailoSingleAtomic
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    EquivClassification,
    LayerHandlerType,
    LayerType,
    OptimizationTarget,
    PrecisionMode,
)


class HailoArgMax(BaseHailoSingleAtomic):
    """
    Single op layer wrapping ``ArgMaxOp``. The argmax axis is the last one.

    Args:
        name: Layer name. The wrapped op is named ``f"{name}/argmax_op"``.
        reverse_order: Forwarded unchanged to ``ArgMaxOp``; see that op for
            its exact semantics. Defaults to ``False``.
        logger: Optional logger, forwarded to both the op and the base layer.
        **kwargs: Extra keyword arguments forwarded to
            ``BaseHailoSingleAtomic.__init__``.

    Examples:
        Examples of use
        >>> layer = HailoArgMax(name="argmax1")
    """

    # Precision modes this layer accepts. On MERCURY and PLUTO this is also the
    # full set reported as HW-supported (see _get_precision_mode_supported_in_hw).
    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a8_w8,
        PrecisionMode.a16_w16,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a8_w8_a16,
        PrecisionMode.a16_w16_a16,
    }
    _hn_type = LayerType.ARGMAX
    OP_NAME = "argmax_op"

    def __init__(self, name: str, reverse_order: bool = False, logger=None, **kwargs):
        """Build the wrapped ``ArgMaxOp`` and register it as this layer's core op."""
        op = ArgMaxOp(name=f"{name}/{self.OP_NAME}", reverse_order=reverse_order, logger=logger)
        super().__init__(name=name, core_op=op, logger=logger, **kwargs)
        # NOTE(review): the meaning of encoding_const is defined by the base
        # layer; this layer explicitly sets it to False.
        self.encoding_const = False

    @classmethod
    def get_default_params(cls):
        """Return a fresh dict of default HN params for this layer.

        A new dict is built on every call, so callers may mutate the result.
        """
        # TODO: this is temporary solution until we have pydantic scheme
        return {"reverse_order": False}

    def get_equalization_handler_type(self, predecessor_index=None):
        """Report that equalization is unsupported for this layer.

        ``predecessor_index`` is accepted for interface compatibility and ignored.
        """
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """Construct a layer from an HN element dict.

        The element's ``"params"`` override ``get_default_params()``. A missing
        or null ``"params"`` entry is treated as empty. After construction,
        ``finalize_from_hn`` is applied with the same element.
        """
        params = cls.get_default_params()
        # `or {}` also covers a present-but-null "params" entry, which
        # `.get("params", {})` would pass through as None and crash `update`.
        params.update(hn_element.get("params") or {})
        layer = cls(name=lname, reverse_order=params["reverse_order"], logger=logger)
        layer.finalize_from_hn(hn_element)
        return layer

    def _get_precision_mode_supported_in_hw(self, arch):
        """Return the set of precision modes supported in HW for ``arch``.

        MERCURY and PLUTO support every mode in ``SUPPORTED_PRECISION_MODE``.
        SAGE supports the same modes except ``a8_w8_a16``. Any other target
        defers to the base-class implementation.
        """
        if arch in {OptimizationTarget.MERCURY, OptimizationTarget.PLUTO}:
            # Return a copy so callers cannot mutate the shared class attribute.
            return set(self.SUPPORTED_PRECISION_MODE)
        if arch == OptimizationTarget.SAGE:
            return {
                PrecisionMode.a8_w8,
                PrecisionMode.a16_w16,
                PrecisionMode.a8_w8_a8,
                PrecisionMode.a16_w16_a16,
            }
        return super()._get_precision_mode_supported_in_hw(arch)
