import tensorflow as tf

from hailo_model_optimization.acceleras.native_layers.base_native_layer import BaseNativeLayer


class NativeReduceMean(BaseNativeLayer):
    """Native layer that averages its input over a configurable set of axes.

    The reduction is done with ``tf.reduce_mean(..., keepdims=True)``. The
    reduced axes are therefore kept with size 1 rather than dropped. The layer
    holds no weights: ``import_weights`` / ``export_weights`` only log a debug
    message.
    """

    # Single source of truth for the default reduction axes. Both ``__init__``
    # and ``get_default_params`` use it, so the two cannot drift apart. It is
    # stored as a tuple so it cannot be mutated by accident; consumers always
    # receive a fresh list copy.
    # NOTE(review): axis 3 is presumably the channel axis of a rank-4 input;
    # confirm against the tensor layout used by callers.
    _DEFAULT_REDUCE_AXES = (3,)

    def __init__(self, name, reduce_axes=None, logger=None) -> None:
        """Create the layer.

        Args:
            name: Layer name, forwarded to ``BaseNativeLayer``.
            reduce_axes: Axes to average over, passed as-is to
                ``tf.reduce_mean``. ``None`` selects the default ``[3]``.
            logger: Optional logger, forwarded to ``BaseNativeLayer``.
        """
        super().__init__(name=name, logger=logger)
        if reduce_axes is None:
            # Fresh list per instance so instances never share a mutable default.
            reduce_axes = list(self._DEFAULT_REDUCE_AXES)
        self._reduce_axes = reduce_axes

    def import_weights(self, params: dict) -> None:
        """No-op: this layer has no weights. ``params`` is ignored."""
        self._logger.debug("Native reduce mean import_weights was triggered, but nothing happened")

    def export_weights(self) -> dict:
        """Return an empty dict, since this layer has no weights to export."""
        self._logger.debug("Native reduce mean export_weights was triggered, but nothing happened")
        return dict()

    def call(self, inputs):
        """Return the mean of ``inputs`` over ``self._reduce_axes``, keeping reduced dims."""
        return tf.reduce_mean(input_tensor=inputs, axis=self._reduce_axes, keepdims=True)

    @classmethod
    def get_default_params(cls):
        """Return a fresh dict of default constructor params (``{"reduce_axes": [3]}``)."""
        # TODO: this is temporary solution until we have pydantic scheme
        return {"reduce_axes": list(cls._DEFAULT_REDUCE_AXES)}

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """Build the layer from an HN element.

        Args:
            lname: Name to give the layer.
            hn_element: Mapping describing the layer. Only its optional
                ``"params"`` mapping is read here; the element is then passed
                whole to ``finalize_from_hn``.
            logger: Optional logger forwarded to the constructor.

        Returns:
            The constructed and finalized layer. Values under ``"params"``
            override the defaults from ``get_default_params``.
        """
        params = cls.get_default_params()
        # ``or {}`` also tolerates an explicit ``"params": None`` entry, which
        # would otherwise make ``dict.update`` raise TypeError.
        params.update(hn_element.get("params") or {})
        layer = cls(
            name=lname,
            reduce_axes=params["reduce_axes"],
            logger=logger,
        )
        layer.finalize_from_hn(hn_element)
        return layer
