#!/usr/bin/env python
import numpy as np

from hailo_sdk_common.hailo_nn.hn_definitions import LayerType


# (suffix, counterpart suffix) pairs of bbox decoder parameters whose x-related values must be exchanged
# with the y-related ones.
_XY_SWAPPED_SUFFIXES = (
    ("x_centers:0", "y_centers:0"),
    ("anchors_widths:0", "anchors_heights:0"),
    ("anchors_widths_div_2:0", "anchors_heights_div_2:0"),
    ("anchors_widths_minus_div_2:0", "anchors_heights_minus_div_2:0"),
)


def _pair_swap_matrix(num_features):
    """Return the float32 permutation matrix that swaps features 0<->1, 2<->3, ... under ``x @ matrix``."""
    columns = np.arange(num_features)
    perm_matrix = np.zeros([num_features, num_features], dtype=np.float32)
    # ``i ^ 1`` is the pair partner of ``i``. An odd feature count leaves the last feature without a partner
    # and raises IndexError here.
    perm_matrix[columns ^ 1, columns] = 1
    return perm_matrix


def _swap_with_counterpart(params, param, value, suffix, counterpart_suffix):
    """Exchange the values stored under ``param`` and under the same name ending in ``counterpart_suffix``."""
    # Slice the suffix off instead of splitting on it, so a name that happens to contain the suffix text
    # earlier on is not truncated.
    counterpart = param[: -len(suffix)] + counterpart_suffix
    params[param] = params[counterpart]
    params[counterpart] = value


def transpose_weights(hailo_nn, params):
    """Adapt the parameters of every layer marked ``transposed`` to the transposed network.

    ``params`` is updated in place; the arrays originally stored in it are never written to, each changed
    entry is rebound to a new array (or to a view of the original one). The owning layer of a parameter is
    looked up by the first two ``/``-separated components of the parameter name.

    For layers whose ``transposed`` attribute is true:

    * Layers feeding a bbox decoder: 4-D ``kernel:0`` values get their first two axes swapped and every
      adjacent pair of output features exchanged (y replaced with x); ``bias:0`` values get the same
      feature exchange.
    * Other ``kernel:0`` values: 4-D kernels get their first two axes swapped; 2-D kernels whose predecessor
      has a 4-D output shape (conv to dense) are re-ordered accordingly.
    * Bbox decoder ``x_centers`` / ``anchors_widths*`` values are swapped with their ``y_centers`` /
      ``anchors_heights*`` counterparts.
    * 3-D ``const_data:0`` values get their first two axes swapped.

    Args:
        hailo_nn: Network object providing ``get_layer_by_name``, ``successors`` and ``predecessors``.
        params: Mapping from parameter name to value, or ``None``. The ``"params_kind"`` entry is skipped.

    Returns:
        The ``(hailo_nn, params)`` pair that was passed in.

    Raises:
        KeyError: If the counterpart of an x/width parameter is missing from ``params``.
        IndexError: If a layer feeding a bbox decoder has an odd number of output features.
    """
    if params is None:
        return hailo_nn, params

    for param, value in params.items():
        if param == "params_kind":
            continue
        layer_name = "/".join(param.split("/")[:2])
        layer = hailo_nn.get_layer_by_name(layer_name)
        if not layer.transposed:
            continue

        # Tensors' features from conv layers before bbox decoders should be changed to a different permutation,
        # because y should be replaced with x.
        if any(succ.op == LayerType.bbox_decoder for succ in hailo_nn.successors(layer)):
            if param.endswith("kernel:0"):
                transposed = np.transpose(value, [1, 0, 2, 3])
                # matmul broadcasts over the two leading axes, so kernels whose first two dimensions differ
                # are handled as well, and the result is a new array: the caller's original kernel is left
                # untouched. Cast back so the kernel keeps its original dtype.
                permuted = np.matmul(transposed, _pair_swap_matrix(layer.output_features))
                params[param] = permuted.astype(transposed.dtype, copy=False)
            elif param.endswith("bias:0"):
                params[param] = np.matmul(value, _pair_swap_matrix(layer.output_features))
        # Transpose kernels
        elif param.endswith("kernel:0"):
            rank = len(np.shape(value))
            if rank == 4:
                params[param] = np.transpose(value, [1, 0, 2, 3])
            elif rank == 2:
                prev_layer = next(iter(hailo_nn.predecessors(layer)))
                input_shape = prev_layer.output_shape
                if len(input_shape) == 4:
                    # In case of conv to dense: unflatten the input dimension using the predecessor's output
                    # shape, swap the first two axes and flatten back.
                    unflattened = np.reshape(
                        value,
                        [input_shape[2], input_shape[1], input_shape[3], layer.output_features],
                    )
                    params[param] = np.reshape(
                        np.transpose(unflattened, [1, 0, 2, 3]),
                        [layer.input_features, layer.output_features],
                    )
        elif param.endswith("const_data:0"):
            if len(np.shape(value)) == 3:
                params[param] = np.transpose(value, [1, 0, 2])
        else:
            # For bbox decoders, x parameters should be switched with y parameters from original network
            for suffix, counterpart_suffix in _XY_SWAPPED_SUFFIXES:
                if param.endswith(suffix):
                    _swap_with_counterpart(params, param, value, suffix, counterpart_suffix)
                    break

    return hailo_nn, params
