import numpy as np

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_conv import BaseHailoConv
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_add import HailoElementwiseAdd
from hailo_model_optimization.acceleras.model.hailo_model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.utils.acceleras_definitions import FeaturePolicy
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class Deequalize(OptimizationAlgorithm):
    """
    Deequalize the kernel and bias of the model after equalization.

    Recomputed values are collected in ``self.deequalize_params`` under the
    keys ``"<layer name>/kernel:0"`` and ``"<layer name>/bias:0"``; only
    layers whose recomputed kernel differs from the reference are recorded.
    """

    def __init__(self, model: HailoModel, model_config, logger_level, **kwargs):
        """Initialise the algorithm under the name "Deequalize Weights"."""
        super().__init__(model, model_config, "Deequalize Weights", logger_level, **kwargs)
        # Parameter name -> recomputed value; populated by _run_int.
        self.deequalize_params = {}

    def _get_valid_layer_cfg(self, lname, cfg):
        """Return the layer configuration unchanged."""
        return cfg

    def get_algo_config(self):
        """Return the ``globals`` section of the model configuration."""
        return self._model_config.globals

    def finalize_global_cfg(self, algo_config):
        """No-op: there is nothing to finalise for this algorithm."""
        return None

    def should_skip_algo(self):
        """Return True unless the global ``deequalize`` policy is enabled."""
        policy = self.get_algo_config().deequalize
        return policy != FeaturePolicy.enabled

    def _setup(self):
        """No-op: this algorithm needs no setup."""
        return None

    def _run_int(self):
        """
        Walk the model layers and record recomputed kernels (and biases).

        Only ``BaseHailoConv`` and ``HailoElementwiseAdd`` layers are
        processed; every other layer is skipped untouched.
        """

        def as_array(value):
            # Objects exposing .numpy() are converted; anything else passes through.
            return value.numpy() if hasattr(value, "numpy") else value

        def differs(reference, candidate):
            # True when at least one element has a non-zero absolute difference.
            return np.max(abs(reference - candidate)) > 0

        for lname, layer in self._model.layers.items():
            is_conv = isinstance(layer, BaseHailoConv)
            if not (is_conv or isinstance(layer, HailoElementwiseAdd)):
                continue

            layer.enforce_internal_encoding()

            if is_conv:
                reference_kernel = layer.export_native_kernel()
                conv = layer.conv_op

                # Kernel recomputed from the op's kernel, scale, zero point and shift.
                scaled = conv.kernel.numpy() / conv.kernel_scale.numpy()
                offset = scaled + as_array(conv.zp_kernel)
                kernel_candidate = offset / 2**conv.total_rshift

                bias_candidate = layer.bias_add_op.encode_bias().numpy()
                kernel_candidate = kernel_candidate.reshape(reference_kernel.shape)

                if differs(reference_kernel, kernel_candidate):
                    self.deequalize_params[f"{lname}/kernel:0"] = kernel_candidate
                    self.deequalize_params[f"{lname}/bias:0"] = bias_candidate
                continue

            # Remaining case: element-wise add layer (kernel only, no bias entry).
            ew_add = layer.ew_add_op
            stored_kernel = ew_add.kernel.numpy()
            kernel_candidate = stored_kernel / ew_add.kernel_scale.numpy()

            if differs(stored_kernel, kernel_candidate):
                self.deequalize_params[f"{lname}/kernel:0"] = kernel_candidate
