#!/usr/bin/env python


import numpy as np
from past.utils import old_div

from hailo_sdk_common.hailo_nn.exceptions import HailoNNException
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.logger.logger import default_logger

# Fallback batch-norm epsilon, used when the params hold no "epsilon" entry for a layer.
DEFAULT_EPSILON = 1e-3


def should_skip_bn(params):
    """Return True when ``params.properties`` holds no ``moving_mean`` entry."""
    has_moving_mean = "moving_mean" in params.properties
    return not has_moving_mean


def null_channels(hn_layer, gamma, moving_var, epsilon):
    """
    Null channels that are 'dead' - e.g. have zero variance and are nulled post activation.
    The reason to null them proactively is due to the fact that we fold BN back in the conv kernel.
    For these channels this results in very high multiplication factors for the kernel, which causes
    the limvals to increase greatly for no reason.
    The channels are nulled by setting their gamma to be 0.

    Note: Such "rogue" channels were only observed in Mobilenet_v1

    Args:
        hn_layer: Layer whose ``translation_config.null_channels_cutoff_factor`` scales the cutoff;
            its ``name`` and ``output_shapes`` are only used for the debug log.
        gamma (np.ndarray): BN scale, either per channel or a 0-dim scalar. It is never modified
            in place; a 0-dim value is expanded to the shape of ``moving_var``.
        moving_var (np.ndarray): BN moving variance per channel.
        epsilon: BN epsilon; a channel is dead when its variance is below ``cutoff_factor * epsilon``.

    Returns:
        tuple: ``(gamma, null_channels_count)`` - gamma with the dead channels zeroed, and the
        number of channels that were zeroed.
    """
    cutoff = hn_layer.translation_config.null_channels_cutoff_factor * epsilon
    dead_channels = moving_var < cutoff
    null_channels_count = np.sum(dead_channels)
    if dead_channels.any():
        default_logger().debug(
            f"Nullifying {null_channels_count} out of {dead_channels.shape[0]} channels of layer {hn_layer.name}. Output shape {hn_layer.output_shapes}",
        )
    if gamma.ndim == 0:
        # A scalar gamma has to become per-channel before single channels can be zeroed.
        # np.longdouble is the portable name of the extended type (np.float128 is platform specific).
        gamma = np.full(fill_value=gamma, shape=moving_var.shape, dtype=np.longdouble)
    else:
        # Work on a copy so the caller's array is left untouched, as in the scalar case.
        gamma = gamma.copy()
    gamma[dead_channels] = 0.0
    return gamma, null_channels_count


def calc_batch_norm_params(params, npz_layer, hn_layer, should_null_channels):
    """
    Fold the batch norm of a single layer into an equivalent kernel and bias.

    With ``div = gamma / sqrt(moving_variance + epsilon)`` the batch norm is the affine map
    ``div * x + (beta - div * moving_mean)``. That map is merged with the layer's own kernel/bias
    (an all-ones kernel and no bias are used when the layer has none, or is a standalone batch norm).
    The math runs in extended precision and the results are cast to float32.

    Args:
        params: Model params; ``properties``, ``kernel``, ``bias``, ``epsilon``, ``beta``, ``gamma``,
            ``moving_mean`` and ``moving_variance`` are read for ``npz_layer``.
        npz_layer (str): Name of the layer inside ``params``.
        hn_layer: Matching HN layer; ``op`` and the optional ``pre_layer_bn`` flag select the folding
            mode, and the layer is forwarded to ``null_channels``.
        should_null_channels (bool): Zero the gamma of near-zero-variance channels before folding.

    Returns:
        tuple: ``((new_kernel, new_bias), null_channels_count)`` where both arrays are float32 and
        ``null_channels_count`` is 0 when nulling is disabled.
    """
    # np.longdouble is the portable spelling of NumPy's extended precision type; np.float128 is
    # only defined on some platforms.
    high_precision = np.longdouble

    pre_layer_bn = getattr(hn_layer, "pre_layer_bn", False)
    if hn_layer.op == LayerType.batch_norm:
        # A standalone batch norm layer is folded into the all-ones placeholder kernel below.
        layer_has_kernel = False
        layer_has_bias = False
    else:
        layer_has_kernel = hasattr(params, "kernel") and npz_layer in params.kernel
        layer_has_bias = hasattr(params, "bias") and params.bias.get(npz_layer, None) is not None

    kernel = params.kernel[npz_layer] if layer_has_kernel else np.ones([1, 1, 1])
    kernel = kernel.astype(high_precision)
    bias = params.bias[npz_layer].astype(high_precision) if layer_has_bias else None

    layer_has_epsilon = "epsilon" in params.properties and npz_layer in params.epsilon
    layer_epsilon = params.epsilon[npz_layer] if layer_has_epsilon else DEFAULT_EPSILON
    epsilon = np.array(layer_epsilon).astype(high_precision)

    # astype() copies, so the clamping below never touches the arrays stored in params.
    beta = params.beta[npz_layer].astype(high_precision)
    moving_mean = params.moving_mean[npz_layer].astype(high_precision)
    moving_variance = params.moving_variance[npz_layer].astype(high_precision)
    moving_variance[moving_variance < 0] = 0  # a negative variance would break the sqrt
    gamma = params.gamma[npz_layer].astype(high_precision)
    if should_null_channels:
        gamma, null_channels_count = null_channels(hn_layer, gamma, moving_variance, epsilon)
    else:
        null_channels_count = 0

    sigma = np.sqrt(moving_variance + epsilon)
    # Plain true division: past.utils.old_div only floor-divides when both operands are integers.
    div = gamma / sigma
    new_bias = beta - (div * moving_mean)
    extended_div = div[np.newaxis, np.newaxis, :, np.newaxis]

    # Contribution of the layer's original bias (if any) to the folded bias.
    if bias is None:
        bias_addition = 0
    elif pre_layer_bn:
        # The original bias is added as is in this mode.
        bias_addition = bias
    else:
        # The original bias is scaled by the BN factor in this mode.
        bias_addition = div * bias

    if pre_layer_bn:
        # The BN shift is pushed through the kernel by summing over axis 2, and the kernel is
        # scaled along that same axis.
        # NOTE(review): after flatten() the bias has one entry per remaining kernel element, which
        # presumably equals the number of output channels only for 1x1 kernels - confirm.
        new_bias_extended = new_bias[np.newaxis, np.newaxis, :, np.newaxis]
        new_bias = np.sum(new_bias_extended * kernel, axis=2).flatten()
        new_kernel = extended_div * kernel
    else:
        # handle depthwise layer: the channels sit on the second to last kernel axis there, so the
        # per-channel factor has to be broadcast along that axis instead of the last one.
        should_reshape_div = div.shape[0] != kernel.shape[-1] and div.shape[0] == kernel.shape[-2]
        new_kernel = kernel * (extended_div if should_reshape_div else div)

    new_bias += bias_addition  # add original bias if exists

    return (new_kernel.astype(np.float32), new_bias.astype(np.float32)), null_channels_count


def batch_norm_rescale_params(model, params, keep_normalization_params=True):
    """
    Fuse batch normalization into the kernel and bias of every layer that has BN statistics.

    Args:
        model (HN): Network representation; each layer with BN statistics is looked up in it by name,
            and layers missing from it are skipped with a warning.
        params (:class:`~hailo_sdk_common.model_params.model_params.ModelParams`):
            The params on which to apply batch norm fusing.
        keep_normalization_params (bool): When True the native BN/Normalization params are copied to
            the result as well. When False they are dropped and nulling of dead channels is enabled.

    Returns:
        ``params`` itself, unchanged, when it holds no ``moving_mean`` property. Otherwise a dict with
        ``"params_kind"`` plus ``"<layer>/<param>"`` entries, where ``kernel:0`` and ``bias:0`` of
        the fused layers hold the rescaled values.

    """
    if should_skip_bn(params):
        return params

    new_params = {"params_kind": params["params_kind"]}
    null_channels_count = 0
    for npz_layer in params:
        # Carry over the layer's own entries, optionally leaving out the native BN ones.
        new_params.update(
            (f"{npz_layer}/{param_key}", param_value)
            for param_key, param_value in params[npz_layer].items()
            if keep_normalization_params or not is_bn_info_param_key(param_key)
        )

        layer_has_bn = hasattr(params, "moving_mean") and npz_layer in params.moving_mean
        if not layer_has_bn:
            continue

        try:
            hn_layer = model.get_layer_by_name(npz_layer)
        except HailoNNException:
            default_logger().warning(f"Layer {npz_layer} from npz file doesn't exist in the HN")
            continue

        if not hn_layer:
            continue

        fused_params, nulled_in_layer = calc_batch_norm_params(
            params,
            npz_layer,
            hn_layer,
            not keep_normalization_params,
        )
        null_channels_count += nulled_in_layer
        fused_kernel, fused_bias = fused_params
        new_params[f"{npz_layer}/kernel:0"] = fused_kernel
        new_params[f"{npz_layer}/bias:0"] = fused_bias

    if null_channels_count > 0:
        default_logger().debug(f"Nullified {null_channels_count} of the channels in the model")

    return new_params


def is_bn_info_param_key(key):
    """
    Tell whether ``key`` is one of the native batch-norm parameter keys.

    Args:
        key: Per-layer param key, e.g. ``"gamma:0"``.

    Returns:
        bool: True for ``gamma:0``, ``epsilon:0``, ``beta:0``, ``moving_mean:0`` and
        ``moving_variance:0``, False for anything else.
    """
    # A tuple (not a set) keeps unhashable keys working: they simply compare unequal.
    return key in ("gamma:0", "epsilon:0", "beta:0", "moving_mean:0", "moving_variance:0")
