import numpy as np

from hailo_model_optimization.acceleras.hailo_layers.hailo_depthwise import HailoDepthwise
from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerType


class HailoBatchNorm(HailoDepthwise):
    """
    Our batch normalization layer is a mock depthwise layer for post-training
    inference (rather than an actual normalization layer).

    It reuses the ``HailoDepthwise`` implementation with a default
    ``kernel_shape`` of ``[1, 1, 1, 1]``. Both kernel hooks below drop the
    trailing axis of the kernel array before it is used.
    """

    _hn_type = LayerType.BATCH_NORM

    @classmethod
    def get_default_params(cls):
        """Return the parent's default params with ``kernel_shape`` set to ``[1, 1, 1, 1]``.

        The parent's defaults are copied *before* being overridden. This way the
        mapping returned by ``HailoDepthwise.get_default_params`` is never
        mutated, which matters in case it is shared between calls or read-only.
        The result is always a fresh plain ``dict``.
        """
        # TODO: this is temporary solution until we have pydantic scheme
        defaults = dict(super().get_default_params())
        defaults["kernel_shape"] = [1, 1, 1, 1]
        return defaults

    def _change_native_kernel(self, kernel):
        """Return ``kernel`` with its last axis removed.

        ``np.squeeze`` raises ``ValueError`` if the last axis does not have
        size 1.
        """
        # NOTE(review): presumably the trailing singleton axis is the depthwise
        # channel-multiplier dimension, which BN does not use — confirm.
        return np.squeeze(kernel, axis=-1)

    def _layer_dependent_hw_params_modifications(self, params: dict):
        """Drop the last axis of ``params["kernel"]`` and return ``params``.

        ``params`` is modified in place; the same dict object is returned.
        Raises ``KeyError`` if ``"kernel"`` is missing, and ``ValueError`` (from
        ``np.squeeze``) if the kernel's last axis does not have size 1.
        """
        params["kernel"] = np.squeeze(params["kernel"], axis=-1)
        return params
