from hailo_model_optimization.acceleras.atomic_ops.reduce_max_op import ReduceMaxOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_single_atomic import BaseHailoSingleAtomic
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    EquivClassification,
    LayerHandlerType,
    LayerType,
)
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams


class HailoReduceMax(BaseHailoSingleAtomic):
    """
    Implement the Hailo reduce_max layer.

    The layer wraps a single ``ReduceMaxOp`` atomic op, named ``"<name>/reduce_max_op"``,
    configured with ``groups`` and ``reduce_axes``. The layer has no weights: importing
    weights is a no-op and exporting returns an empty dict. The equalization handler
    reports it as unsupported.
    """

    _hn_type = LayerType.REDUCE_MAX
    OP_NAME = "reduce_max_op"

    def __init__(
        self,
        name: str,
        groups=1,
        axis=None,
        logger=None,
        **kwargs,
    ):
        """
        Create the layer and its single ``ReduceMaxOp`` atomic op.

        Args:
            name: Layer name. It is also used as the prefix of the atomic op name.
            groups: Forwarded to ``ReduceMaxOp`` as ``groups``.
            axis: Forwarded to ``ReduceMaxOp`` as ``reduce_axes``. How ``None`` is
                interpreted is decided by ``ReduceMaxOp`` (not visible here).
                ``from_hn`` passes ``[-1]`` when the HN omits ``reduce_axes``.
            logger: Forwarded to both the atomic op and the base layer.
            **kwargs: Forwarded unchanged to ``BaseHailoSingleAtomic.__init__``.
        """
        atomic_op = ReduceMaxOp(
            f"{name}/{self.OP_NAME}",
            groups=groups,
            reduce_axes=axis,
            logger=logger,
        )
        super().__init__(name=name, core_op=atomic_op, logger=logger, **kwargs)

    def enforce_internal_encoding(self, training=False, **kwargs):
        """Delegate encoding enforcement to the atomic op (``training`` and kwargs are unused)."""
        self.atomic_op.enforce_encoding()

    def import_weights(self, layer_params: LayerParams):
        """No-op: this layer has no weights to import."""
        pass

    def _export_weights(self):
        """Return an empty dict: this layer has no weights to export."""
        return {}

    @classmethod
    def get_default_params(cls):
        """Return a fresh dict of default HN params for this layer."""
        # TODO: this is temporary solution until we have pydantic scheme
        return {"groups": 1}

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        Build the layer from an HN element.

        Args:
            lname: Name to give the layer.
            hn_element: HN mapping. Its ``"params"`` entry is optional. Missing keys fall
                back to ``get_default_params()``, and ``reduce_axes`` falls back to ``[-1]``.
            logger: Forwarded to the constructor.

        Returns:
            The constructed layer, after ``finalize_from_hn(hn_element)`` has been called on it.
        """
        params = cls.get_default_params()
        params.update(hn_element.get("params", {}))
        # Read from the merged ``params`` dict rather than ``hn_element["params"]``.
        # A missing "params" entry (already tolerated above) must not raise KeyError here.
        axis = params.get("reduce_axes", [-1])
        layer = cls(name=lname, groups=params["groups"], axis=axis, logger=logger)
        layer.finalize_from_hn(hn_element)
        return layer

    def get_equalization_handler_type(self, predecessor_index=None):
        """Report this layer as unsupported (and not a source) for equalization."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)
