#!/usr/bin/env python

import networkx as nx
import numpy as np
from past.utils import old_div

from hailo_sdk_common.hailo_nn.hn_definitions import (
    ActivationType,
    FeatureMultiplierType,
    LayerType,
    NMSMetaArchitectures,
)
from hailo_sdk_common.numeric_utils.numeric_utils import get_bbox_centers_for_ssd, get_bbox_centers_for_yolo


class WeightsGenerator:
    """Generate random parameters for every layer of a model graph.

    ``generate`` returns a flat dict mapping variable names of the form
    ``"{scope}/{layer}/{param}:0"`` to generated values: NumPy arrays for
    tensors (kernels, biases, batch-norm statistics, anchors, const data, ...)
    and plain Python scalars for activation hyper-parameters and epsilons.
    """

    def __init__(
        self,
        leaky_alpha=0.1,
        activation_threshold=0.3,
        activation_delta_bias=-1.0,
        random_prune_rate=0,
        hardsigmoid=(1.0, 0.5),
        swish_beta=1.0,
        random_gen=None,
        seed=None,
    ):
        """Configure the generator.

        Args:
            leaky_alpha: Value stored for layers with a leaky activation. Note that
                conv-like layers with a leaky activation overwrite it with a random
                alpha (see ``_conv_kernel_from_shape``).
            activation_threshold: Value stored for threshold activations.
            activation_delta_bias: Value stored for biased-delta activations.
            random_prune_rate: Probability of zeroing each element of every
                generated array; ``0`` disables pruning.
            hardsigmoid: ``(alpha, beta)`` pair stored for hard-sigmoid activations.
            swish_beta: Value stored for swish activations.
            random_gen: A ``numpy.random.Generator`` to draw from. When ``None``,
                one is created with ``numpy.random.default_rng(seed)``.
            seed: Seed used only when ``random_gen`` is not supplied.
        """
        self._leaky_alpha = leaky_alpha
        self._activation_threshold = activation_threshold
        self._activation_delta_bias = activation_delta_bias
        self._swish_beta = swish_beta
        self._hardsigmoid = hardsigmoid
        self._activation_less_values = 6
        self._clip_min = 0
        self._clip_max = 1
        self._activation_greater_values = 0
        self._pow_exponent = 0.3
        self._random_prune_rate = random_prune_rate
        self._random_gen = random_gen if random_gen is not None else np.random.default_rng(seed)

    @staticmethod
    def _get_params_keys_for_layer(model_name, layer_name, layer_op):
        """Return ``{param: "{model_name}/{layer_name}/{param}:0"}`` for the params of *layer_op*.

        bbox decoders, fused bbox decoders and layer normalization have dedicated
        parameter lists; every other op gets the generic list (kernel, bias,
        batch-norm and activation parameters). Feature multipliers additionally
        get ``power_table``.
        """
        if layer_op == LayerType.bbox_decoder:
            keys = [
                "anchors_heights",
                "anchors_widths",
                "anchors_heights_div_2",
                "anchors_heights_minus_div_2",
                "anchors_widths_div_2",
                "anchors_widths_minus_div_2",
                "y_centers",
                "x_centers",
            ]
        elif layer_op == LayerType.fused_bbox_decoder:
            keys = ["y_centers", "x_centers", "height_scale_factor", "width_scale_factor"]
        elif layer_op == LayerType.layer_normalization:
            keys = ["epsilon"]
        else:
            keys = [
                "kernel",
                "bias",
                "beta",
                "gamma",
                "moving_mean",
                "moving_variance",
                "epsilon",
                "leaky_alpha",
                "activation_threshold",
                "activation_delta_bias",
                "const_data",
                "hardsigmoid_alpha",
                "hardsigmoid_beta",
                "swish_beta",
                "activation_less_values",
                "activation_greater_values",
                "clip_min",
                "clip_max",
                "pow_exponent",
            ]

        if layer_op == LayerType.feature_multiplier:
            keys.append("power_table")

        generic_key_format = "{model}/{layer}/{key}:0"
        return {key: generic_key_format.format(model=model_name, layer=layer_name, key=key) for key in keys}

    def generate(self, model):
        """Return a dict of randomly generated parameters for every layer of *model*.

        Layers are visited in topological order (``networkx.topological_sort``).
        For each layer, the scalar activation hyper-parameters are emitted first,
        followed by the op-specific tensors. When ``random_prune_rate`` is
        positive, array-valued entries are then randomly sparsified.
        """
        weights = {}
        for layer in nx.topological_sort(model):
            keys = self._get_params_keys_for_layer(layer.scope, layer.name_without_scope, layer.op)

            if hasattr(layer, "activation"):
                self._generate_activation_params(keys, layer.activation, weights)

            if layer.op in {LayerType.deconv, LayerType.conv, LayerType.dense, LayerType.dw, LayerType.normalization}:
                self._generate_conv_weights(keys, layer, weights)

            elif layer.op == LayerType.batch_norm:
                weights.update(self._generate_batch_norm_weights(keys, layer.input_shape[-1]))

            elif layer.op == LayerType.bbox_decoder:
                self._generate_bbox_weights(keys, layer, weights)

            elif layer.op == LayerType.fused_bbox_decoder:
                self._generate_fused_bbox_weights(keys, layer, weights)

            elif layer.op == LayerType.feature_multiplier:
                self._generate_feature_multiplier_weights(keys, layer, weights)

            elif layer.op == LayerType.const_input:
                self._generate_const_input(keys, layer, weights)

            elif layer.op == LayerType.layer_normalization:
                self._generate_epsilon(keys, layer, weights)

        if self._random_prune_rate > 0:
            self._prune_weights(weights)

        return weights

    def _generate_activation_params(self, keys, activation, weights):
        """Store into *weights* the scalar hyper-parameters required by *activation*."""
        if activation == ActivationType.leaky:
            weights[keys["leaky_alpha"]] = self._leaky_alpha
        elif activation == ActivationType.hardsigmoid:
            weights[keys["hardsigmoid_alpha"]], weights[keys["hardsigmoid_beta"]] = self._hardsigmoid
        elif activation == ActivationType.threshold:
            weights[keys["activation_threshold"]] = self._activation_threshold
        elif activation == ActivationType.biased_delta:
            weights[keys["activation_delta_bias"]] = self._activation_delta_bias
        elif activation == ActivationType.swish:
            weights[keys["swish_beta"]] = self._swish_beta
        elif activation == ActivationType.less:
            weights[keys["activation_less_values"]] = self._activation_less_values
        elif activation == ActivationType.greater:
            weights[keys["activation_greater_values"]] = self._activation_greater_values
        elif activation == ActivationType.clip:
            weights[keys["clip_min"]] = self._clip_min
            weights[keys["clip_max"]] = self._clip_max
        elif activation == ActivationType.pow:
            weights[keys["pow_exponent"]] = self._pow_exponent

    def _prune_weights(self, weights):
        """Randomly zero elements of every NumPy array in *weights*, in place.

        Each array element is kept with probability ``1 - random_prune_rate``.
        Non-array entries, such as scalar hyper-parameters like ``leaky_alpha`` or
        ``epsilon``, are left untouched. They have no ``shape``, and zeroing
        them (e.g. an epsilon) would not be meaningful.

        Returns:
            The same *weights* dict, for convenience.
        """
        keep_probability = 1 - self._random_prune_rate
        # Only values are reassigned (no keys added or removed), so mutating
        # the dict while iterating over it is safe.
        for key, value in weights.items():
            if not isinstance(value, np.ndarray):
                continue
            keep_mask = self._random_gen.uniform(0, 1, size=value.shape) < keep_probability
            weights[key] = value * keep_mask
        return weights

    def _generate_bbox_weights(self, keys, layer, weights):
        """Generate anchors (and their +/- halves) and grid centers for a bbox decoder."""
        # NOTE(review): assumes axis 3 of the output holds 4 values per anchor — confirm.
        num_of_anchors = int(layer.output_shape[3] / 4)
        y_centers, x_centers = self._generate_centers(layer, num_of_anchors)
        anchors_heights, anchors_widths = self._generate_anchors(num_of_anchors)
        weights[keys["anchors_heights"]] = anchors_heights
        weights[keys["anchors_widths"]] = anchors_widths
        weights[keys["anchors_heights_div_2"]] = anchors_heights / 2
        weights[keys["anchors_widths_div_2"]] = anchors_widths / 2
        weights[keys["anchors_heights_minus_div_2"]] = -anchors_heights / 2
        weights[keys["anchors_widths_minus_div_2"]] = -anchors_widths / 2
        weights[keys["y_centers"]] = y_centers
        weights[keys["x_centers"]] = x_centers

    def _generate_fused_bbox_weights(self, keys, layer, weights):
        """Generate YOLOv6-style centers and random scale factors for a fused bbox decoder."""
        # The first input is treated as the boxes tensor; axes 1 and 2 are passed
        # as the grid dimensions and axis 3 is divided by 4 to get anchors.
        boxes_input = layer.input_shapes[0]
        num_of_anchors = int(boxes_input[3] / 4)
        y_centers, x_centers = get_bbox_centers_for_yolo(
            NMSMetaArchitectures.YOLOV6,
            boxes_input[1],
            boxes_input[2],
            num_of_anchors,
        )
        anchors_heights, anchors_widths = self._generate_anchors(num_of_anchors)
        weights[keys["height_scale_factor"]] = anchors_heights
        weights[keys["width_scale_factor"]] = anchors_widths
        weights[keys["y_centers"]] = y_centers
        weights[keys["x_centers"]] = x_centers

    def _generate_feature_multiplier_weights(self, keys, layer, weights):
        """Add the YOLOv5 power table for yolov5-type feature multipliers."""
        # Add power table only if not user specified, otherwise, assume weights already in npz:
        if layer.feature_multiplier_type == FeatureMultiplierType.yolov5:
            weights[keys["power_table"]] = layer.init_power_table(layer.yolov5())

    def _generate_anchors(self, num_of_anchors):
        """Return random float32 anchor heights and widths drawn from U[0.01, 1)."""
        anchors_heights = self._random_gen.uniform(0.01, 1, size=num_of_anchors).astype("f")
        anchors_widths = self._random_gen.uniform(0.01, 1, size=num_of_anchors).astype("f")
        return anchors_heights, anchors_widths

    def _generate_centers(self, layer, num_of_anchors):
        """Return SSD-style (y, x) centers computed from input axes 1 and 2."""
        return get_bbox_centers_for_ssd(layer.input_shape[1], layer.input_shape[2], num_of_anchors)

    def _generate_conv_weights(self, keys, layer, weights):
        """Generate kernel, bias and (optionally) batch-norm params for conv-like layers."""
        kernel_shape = list(layer.kernel_shape)
        if layer.op in [LayerType.conv, LayerType.deconv]:
            # Grouped convs: rescale axis 2 to the size of the largest group.
            kernel_shape[2] = old_div(kernel_shape[2], sum(layer.group_sizes)) * max(layer.group_sizes)
        # Depthwise/normalization layers take their bias size from the
        # second-to-last kernel axis; the others from the last.
        bias_shape = kernel_shape[-2] if layer.op in [LayerType.dw, LayerType.normalization] else kernel_shape[-1]
        self._conv_weights_from_shape(bias_shape, kernel_shape, keys, layer, weights)
        if layer.bn_enabled:
            weights.update(self._generate_batch_norm_weights(keys, bias_shape))

    def _conv_weights_from_shape(self, bias_shape, kernel_shape, keys, layer, weights):
        """Generate kernel and bias, except for depthwise layers with dynamic weights."""
        T = 1
        if not (layer.op == LayerType.dw and layer.dynamic_weights):
            T = self._conv_kernel_from_shape(T, kernel_shape, keys, layer, weights)
            self._conv_bias_from_shape(T, bias_shape, keys, weights)

    def _conv_bias_from_shape(self, T, bias_shape, keys, weights):
        """Draw a float32 bias from N(0, T) with shape *bias_shape*."""
        weights[keys["bias"]] = self._random_gen.normal(0, 1 * T, size=bias_shape).astype("f")

    def _conv_kernel_from_shape(self, T, kernel_shape, keys, layer, weights):
        """Draw a float32 kernel from N(0, std) and return the std used.

        The std depends on the layer activation, with fan-in taken as the product
        of all kernel dims except the last:
        ``sqrt(2 / fan_in)`` for relu-like/linear/threshold, ``sqrt(1 / fan_in)``
        for sigmoid/exp, ``sqrt(2 / (fan_in * (1 + a**2)))`` for leaky (with a
        random alpha ``a``), and the incoming *T* otherwise.
        """
        if layer.activation in [
            ActivationType.relu,
            ActivationType.linear,
            ActivationType.relu6,
            ActivationType.relu1,
            ActivationType.threshold,
        ]:
            T = np.sqrt(2.0 / np.prod(kernel_shape[:-1]))
        if layer.activation in [ActivationType.sigmoid, ActivationType.exp]:
            T = np.sqrt(1.0 / np.prod(kernel_shape[:-1]))
        if layer.activation == ActivationType.leaky:
            # Overrides the configured leaky_alpha with a random one so the
            # kernel std matches the alpha actually stored.
            A = self._random_gen.random()
            weights[keys["leaky_alpha"]] = A
            T = np.sqrt(2.0 / (np.prod(kernel_shape[:-1]) * (1 + A**2)))

        weights[keys["kernel"]] = self._random_gen.normal(0, 1 * T, size=kernel_shape).astype("f")
        return T

    def _generate_batch_norm_weights(self, keys, bn_out_shape):
        """Return random float32 batch-norm params of shape *bn_out_shape*.

        mean, beta and gamma are drawn from U[-1, 1); variance from U[0, 1).
        """
        # TODO: fix formulas to be normal
        return {
            keys["moving_mean"]: self._random_gen.uniform(-1, 1, size=bn_out_shape).astype("f"),
            keys["beta"]: self._random_gen.uniform(-1, 1, size=bn_out_shape).astype("f"),
            keys["gamma"]: self._random_gen.uniform(-1, 1, size=bn_out_shape).astype("f"),
            keys["moving_variance"]: self._random_gen.uniform(0, 1, size=bn_out_shape).astype("f"),
        }

    def _generate_const_input(self, keys, layer, weights):
        """Generate random uint8 constant data shaped like the input without its first axis."""
        # NOTE(review): first axis presumably the batch dimension — confirm.
        input_shape = list(layer.input_shape)
        weights[keys["const_data"]] = self._random_gen.integers(0, 256, size=input_shape[1:], dtype=np.uint8)

    def _generate_epsilon(self, keys, layer, weights):
        """Store a fixed epsilon of 1e-6 (used for layer normalization)."""
        weights[keys["epsilon"]] = 1e-6


def generate_random_weights_for_model(model, random_gen=None, seed=None):
    """Build a ``WeightsGenerator`` with default settings and run it on *model*.

    Args:
        model: The model graph whose layers should receive random parameters.
        random_gen: Optional ``numpy.random.Generator`` to draw from.
        seed: Seed used to create a generator when *random_gen* is not given.

    Returns:
        The dict produced by ``WeightsGenerator.generate``.
    """
    generator = WeightsGenerator(random_gen=random_gen, seed=seed)
    generated = generator.generate(model)
    return generated
