from typing import Union

from hailo_model_optimization.acceleras.atomic_ops.reduce_mean_op import ReduceMeanOp
from hailo_model_optimization.acceleras.hailo_layers.hailo_reduce_sum import HailoReduceSum
from hailo_model_optimization.acceleras.utils.acceleras_definitions import ActivationType, LayerType


class HailoReduceMean(HailoReduceSum):
    """
    Hailo reduce_mean layer.

    Built on top of ``HailoReduceSum``: the parent constructor sets the layer
    up, then the reduction op is replaced by a ``ReduceMeanOp`` and the layer
    flow is rebuilt so that it uses the new op.

        - takes a single input,
        - reduces it along ``reduce_axes`` with ``ReduceMeanOp``,
        - applies the activation (in the APU).
    """

    # HN layer type associated with this layer class.
    _hn_type = LayerType.REDUCE_MEAN

    def __init__(
        self,
        name: str,
        groups: int = 1,
        reduce_axes: tuple = (3,),
        activation: Union[str, callable, ActivationType] = ActivationType.LINEAR,
        logger=None,
        **kwargs,
    ):
        """
        Create the layer and swap in a mean reduction op.

        Args:
            name: layer name; also used as the prefix of the reduction op's name.
            groups: number of groups, forwarded to both the parent layer and the op.
            reduce_axes: axes to reduce over; defaults to ``(3,)``.
            activation: activation forwarded to ``HailoReduceSum``.
            logger: optional logger, forwarded to the parent layer and the op.
            **kwargs: extra keyword arguments forwarded to ``HailoReduceSum.__init__``.
        """
        # Keyword arguments passed unchanged to both the parent layer and the op.
        op_args = {"groups": groups, "reduce_axes": reduce_axes, "logger": logger}

        super().__init__(name=name, activation=activation, **op_args, **kwargs)

        # Replace the reduction op with a mean op. The attribute name and the
        # op-name suffix keep the "reduce_sum" spelling, presumably so that the
        # parent's flow-building code still finds it — TODO confirm before renaming.
        mean_op = ReduceMeanOp(f"{name}/reduce_sum_op", **op_args)
        self.reduce_sum_op = mean_op

        # Rebuild the flow so it references the op assigned above; presumably the
        # parent built it around its own op. TODO: a bit hacky
        self._layer_flow = self._build_flow()

        self.encoding_const = True
