import numpy as np

from hailo_model_optimization.acceleras.atomic_ops.softmax_op import SoftmaxOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_single_atomic import BaseHailoSingleAtomic
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    EquivClassification,
    LayerHandlerType,
    LayerType,
    OptimizationTarget,
    PrecisionMode,
)


class HailoSoftmax(BaseHailoSingleAtomic):
    """
    Softmax layer wrapping a single ``SoftmaxOp`` atomic op.

    Currently degenerate, implemented as fully native activation.
    TODO simulate the PPU- (and/or the upcoming core-) implementation -
        - will include multiple AtomicOps.
    """

    # Precision modes reported as HW-supported for MERCURY / SAGE / PLUTO
    # (see _get_precision_mode_supported_in_hw).
    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a8_w8,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a8_w8_a16,
    }
    _hn_type = LayerType.SOFTMAX
    OP_NAME = "softmax_op"

    def __init__(self, name: str, axis=-1, groups=1, logger=None, **kwargs):
        """
        Build the layer around a single ``SoftmaxOp``.

        Args:
            name: Layer name; the core op is named ``f"{name}/softmax_op"``.
            axis: Axis forwarded to ``SoftmaxOp`` (default -1).
            groups: Group count forwarded to ``SoftmaxOp`` (default 1).
            logger: Optional logger forwarded to the op and to the base layer.
            **kwargs: Forwarded unchanged to ``BaseHailoSingleAtomic.__init__``.
        """
        # NOTE(review): an earlier TODO here asked to "implement it without using
        # ActivationOP", but this layer already constructs a SoftmaxOp — confirm
        # whether any follow-up work remains.
        op = SoftmaxOp(f"{name}/{self.OP_NAME}", logger=logger, axis=axis, groups=groups)
        super().__init__(name=name, core_op=op, logger=logger, **kwargs)

    def create_output_encoding_candidates(self, forced_range=None, translation_config=None):
        """
        Fix the output encoding in advance rather than deriving it from statistics.

        Softmax produces a probability distribution, so its outputs lie in [0, 1].
        The output scale is therefore fixed to 1/255 per channel with a zero point
        of 0, which maps the 8-bit code range [0, 255] onto [0, 1].

        Args:
            forced_range: Accepted for interface compatibility; not used here.
            translation_config: Accepted for interface compatibility; not used here.
        """
        # Only the channel count is taken from the collected output statistics.
        # Assumes ``min`` is a 1-D per-channel array — TODO confirm.
        num_channels = self.get_output_stats()[0].min.shape[0]
        # NOTE(review): the scale targets an 8-bit output range even though
        # a8_w8_a16 is listed in SUPPORTED_PRECISION_MODE — confirm this is
        # intended for 16-bit outputs.
        self.atomic_op.output_scale = np.repeat(1.0 / 255.0, num_channels)
        self.atomic_op.output_zero_point = np.float32(0)

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        Create the layer from an HN element.

        Reads ``logits_axis`` (default -1) and ``groups`` (default 1) from
        ``hn_element["params"]``, constructs the layer, then finalizes it from
        the same element.

        Args:
            lname: Name for the new layer.
            hn_element: HN element; must contain a ``"params"`` mapping.
            logger: Optional logger forwarded to the layer.

        Returns:
            The finalized ``HailoSoftmax`` (or subclass) instance.
        """
        params = hn_element["params"]
        axis = params.get("logits_axis", -1)
        groups = params.get("groups", 1)
        layer = cls(name=lname, axis=axis, groups=groups, logger=logger)
        layer.finalize_from_hn(hn_element)
        return layer

    def get_equalization_handler_type(self, predecessor_index=None):
        """Report that equalization does not support this layer (and it is not a source)."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def _get_hn_input_shapes(self):
        """Return HN-style input shapes built from ``input_shapes`` / ``inputs_dim``."""
        return self._get_hn_shapes(self.input_shapes, self.inputs_dim)

    def _get_hn_output_shapes(self):
        """Return HN-style output shapes built from ``output_shapes`` / ``outputs_dim``."""
        return self._get_hn_shapes(self.output_shapes, self.outputs_dim)

    def _get_hn_shapes(self, shapes, dims):
        """
        Convert internal shapes into HN shapes with a ``-1`` leading dimension.

        For every entry ``dim`` in ``dims``, each shape in ``shapes`` is emitted as
        ``[-1, ...]``: when ``dim == 2`` the trailing dimensions are flattened into
        their product, otherwise they are copied as-is.
        """
        # NOTE(review): this is a dims x shapes cross product, not zip(shapes, dims);
        # the two agree only when ``dims`` has a single entry — confirm the intent
        # for multi-input/output layers.
        hn_shapes = []
        for dim in dims:
            hn_shape = [[-1] + ([np.prod(shape[1:])] if dim == 2 else list(np.array(shape[1:]))) for shape in shapes]
            hn_shapes.extend(hn_shape)
        return hn_shapes

    @property
    def homogeneous(self):
        """This layer always reports itself as non-homogeneous."""
        return False

    def _get_precision_mode_supported_in_hw(self, arch):
        """
        Return the precision modes supported in hardware for ``arch``.

        MERCURY, SAGE and PLUTO use this layer's ``SUPPORTED_PRECISION_MODE``;
        every other target defers to the base-class implementation.
        """
        if arch in {OptimizationTarget.MERCURY, OptimizationTarget.SAGE, OptimizationTarget.PLUTO}:
            return self.SUPPORTED_PRECISION_MODE
        return super()._get_precision_mode_supported_in_hw(arch)
