#!/usr/bin/env python
import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.utils.acceleras_definitions import BiasCorrectionPolicy, LayerSupportStatus
from hailo_sdk_common.logger.logger import default_logger


class IBCError(Exception):
    """Raised by IBC when a layer's support status is unexpected or unknown."""


def _extract_layer_name(tensor_name):
    return tensor_name.split("/")[:2]


class NumericGraphReassigner:
    """Assign an existing TF graph with new biases.

    Bias variables are looked up in the session's graph on the first call and
    cached in ``self._tensors``; later calls reuse that cache unchanged.
    """

    def __init__(self, session):
        """Initialize the re-assigner from an existing TF session.

        Args:
            session: session whose graph holds the bias variables to update.
        """
        self._session = session
        self._graph = session.graph
        # param name -> TF variable, filled lazily by _get_variables
        self._tensors = {}

    def _get_variables(self, params):
        """Return (and cache) the TF variables backing the bias entries of ``params``.

        Only keys containing "bias" are considered. Each such key is used as the
        ``scope`` filter of the graph's GLOBAL_VARIABLES collection (a regex
        matched against the start of each variable name); keys that match no
        variable are skipped.

        Args:
            params: mapping of parameter name -> value.

        Returns:
            dict mapping param name -> TF variable.

        Raises:
            ValueError: if a key matches more than one variable.
            AssertionError: if no bias variable could be found at all.
        """
        if not self._tensors:
            # using keys() explicitly because of ModelParams issues
            for param_name in params.keys():
                if "bias" not in param_name:
                    continue
                # Query this session's graph directly: tf.compat.v1.global_variables()
                # would read whatever graph is the default at call time.
                var = self._graph.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope=param_name)
                if len(var) > 1:
                    raise ValueError("Multiple tensors with the same name")
                if not var:
                    continue
                self._tensors[param_name] = var[0]
        assert self._tensors, "The tensors to updated list is empty"
        return self._tensors

    def reassign(self, params, layer_name):
        """Assign TF bias variable of a given layer.

        Args:
            params: mapping of parameter name -> new value.
            layer_name: tensor/layer name; only cached bias params whose first two
                "/"-separated components equal those of ``layer_name`` are assigned.
        """
        tensors = self._get_variables(params)
        target = _extract_layer_name(layer_name)
        with self._session.as_default(), self._graph.as_default():
            # Build the assign ops inside the session's graph so they never land in
            # whatever graph the caller has as default.
            # NOTE(review): new assign ops are added to the graph on every call.
            assign_ops = [
                tensors[param_name].assign(param_value)
                for param_name, param_value in params.items()
                if param_name in tensors and _extract_layer_name(param_name) == target
            ]
            self._session.run(assign_ops)


class IBC:
    """Layer-by-layer bias correction of a numeric model against a native model.

    For every layer eligible for correction, the mean outputs of the native and
    numeric models are compared. A user callback turns that difference into
    updated params, and the numeric graph's biases are reassigned before the
    next layer is processed.
    """

    def __init__(self, hn_model, native_model, numeric_model, ibc_config=None):
        """Create the corrector.

        Args:
            hn_model: model providing ``get_layer_by_name`` for per-layer metadata.
            native_model: reference model exposing ``session`` and ``all_layers``.
            numeric_model: model to correct, exposing ``session``, ``graph`` and ``all_layers``.
            ibc_config: optional config with ``policy`` (default policy) and ``layers``
                (per-layer overrides keyed by layer name). When omitted, the default
                policy is ``disabled`` and there are no per-layer overrides.
        """
        self._hn_model = hn_model
        self._native_model = native_model
        self._numeric_model = numeric_model
        self._reassigner = NumericGraphReassigner(numeric_model.session)
        if ibc_config is None or ibc_config.policy is None:
            self._default_policy = BiasCorrectionPolicy.disabled
        else:
            self._default_policy = ibc_config.policy
        # ibc_config is optional: without it (or without layer overrides) use an empty
        # mapping instead of dereferencing None.
        if ibc_config is None or ibc_config.layers is None:
            self._layers_config = {}
        else:
            self._layers_config = ibc_config.layers

    def _get_results(self, graph_model, initializer, num_batch, layer):
        """Run ``initializer`` once, then evaluate ``layer`` for ``num_batch`` batches.

        Args:
            graph_model: model exposing a ``session``.
            initializer: op run once before the batches are evaluated.
            num_batch: number of times ``layer`` is evaluated.
            layer: tensor to evaluate.

        Returns:
            The per-batch results concatenated along axis 0.

        Raises:
            ValueError: if ``num_batch`` is smaller than 1 (nothing to concatenate).
        """
        if num_batch < 1:
            raise ValueError(f"num_batch must be at least 1, got {num_batch}")
        session = graph_model.session
        with session.as_default(), session.graph.as_default():
            session.run(initializer)
            batches = [session.run(layer) for _ in range(num_batch)]
        return np.concatenate(batches, axis=0)

    def run(self, native_initializer, numeric_initializer, num_batch, update_layer_bias_callback):
        """Apply bias correction to every eligible layer, in ``all_layers`` order.

        Args:
            native_initializer: initializer op run before evaluating each native layer.
            numeric_initializer: initializer op run before evaluating each numeric layer.
            num_batch: number of batches evaluated per layer.
            update_layer_bias_callback: called as ``callback(output_diff, layer_numeric)``;
                must return the params mapping that is then reassigned into the graph.

        Returns:
            The params returned by the last callback invocation, or None when no
            layer was corrected.
        """
        # TODO: fix iteration
        params = None
        native_tensors = self._native_model.all_layers
        numeric_tensors = self._numeric_model.all_layers
        for layer_native, layer_numeric in zip(native_tensors, numeric_tensors):
            # skip layers without bias and not supported
            to_correct, explanation = self._does_require_correction(layer_native.name)
            layer_name = "/".join(_extract_layer_name(layer_native.name))
            if not to_correct:
                if self._default_policy == BiasCorrectionPolicy.enabled:
                    default_logger().debug(f"No bias correction for layer {layer_name} because {explanation}")
                continue

            # run this iteration
            res_native = self._get_results(self._native_model, native_initializer, num_batch, layer_native)
            res_numeric = self._get_results(self._numeric_model, numeric_initializer, num_batch, layer_numeric)
            with self._numeric_model.graph.as_default(), self._numeric_model.session.as_default():
                # layer_name is already normalized to its first two components
                hn_layer = self._hn_model.get_layer_by_name(layer_name)
                output_diff = self._get_layer_output_diff(res_native, res_numeric, hn_layer)
                params = update_layer_bias_callback(output_diff, layer_numeric)

                # update the graph for the next iterations
                self._reassigner.reassign(params, layer_numeric.name)
        return params

    def _does_require_correction(self, layer_name):
        """Decide whether a layer should be bias-corrected.

        A per-layer policy of ``allowed`` (or no per-layer entry) falls back to the
        default policy.

        Returns:
            ``(should_correct, explanation)``; the explanation is empty when True.

        Raises:
            IBCError: if the layer's support status is ``unexpected`` or unknown.
        """
        layer = self._hn_model.get_layer_by_name("/".join(_extract_layer_name(layer_name)))
        support_status = layer.ibc_supported()
        if layer.name in self._layers_config:
            policy = self._layers_config[layer.name].policy
        else:
            policy = BiasCorrectionPolicy.allowed
        if policy == BiasCorrectionPolicy.allowed:
            policy = self._default_policy
        if support_status == LayerSupportStatus.supported:
            if policy == BiasCorrectionPolicy.disabled:
                return False, "layer is not chosen for bias correction"
            return True, ""
        if support_status == LayerSupportStatus.unsupported:
            return False, f"layer of type {layer.op.value} is not supported by IBC"
        if support_status == LayerSupportStatus.unexpected:
            raise IBCError(f"of an unexpected status during IBC. type - {layer.op.value}")
        raise IBCError(f"of an unknown support status - {support_status.value}")

    @staticmethod
    def _get_layer_output_diff(res_native, res_numeric, hn_layer):
        """Return mean(native) - mean(numeric) per output feature.

        4-D results are averaged over axes 0-2 (presumably batch/height/width with
        features last — confirm). Other results are averaged over axis 0 only. When
        ``hn_layer.transpose_output_width_features`` is truthy, the last two axes of
        the (4-D) results are swapped first.
        """
        if getattr(hn_layer, "transpose_output_width_features", False):
            res_native = np.transpose(res_native, axes=[0, 1, 3, 2])
            res_numeric = np.transpose(res_numeric, axes=[0, 1, 3, 2])
        if res_numeric.ndim == 4:
            return np.mean(res_native, axis=(0, 1, 2)) - np.mean(res_numeric, axis=(0, 1, 2))
        return np.mean(res_native, axis=0) - np.mean(res_numeric, axis=0)