import json
from collections import OrderedDict

from hailo_sdk_client.sdk_backend.profiler.base_data_extractor import BaseDataExtractor
from hailo_sdk_client.sdk_backend.script_parser.commands import SupportedCommands


class MetadataExtractor(BaseDataExtractor):
    """Collect model-level metadata (sizes, shapes, script modifications) into a profiler export.

    Args:
        hn: The HN model. Iterating it yields layers exposing ``name``, ``weights``, ``macs`` and
            ``ops``; the model itself must also provide ``name``, ``to_hn()``,
            ``get_non_const_input_layers()`` and ``get_real_output_layers()``.
        modifications_meta_data: Model-script modification metadata whose ``inputs`` and
            ``outputs`` attributes map to lists of command configs, each tagged with ``cmd_type``.
        hw_arch: Target hardware architecture; only its ``name`` is read.
    """

    def __init__(self, hn, modifications_meta_data, hw_arch):
        self._hn = hn
        self._modifications_meta_data = modifications_meta_data
        self._hw_arch = hw_arch

    def update(self, export):
        """Add this model's metadata to ``export`` in place.

        Merges the model details into ``export["stats"]["model_details"]`` (which must already
        exist, otherwise ``KeyError`` is raised) and stores the HN model, extended with per-layer
        weights/macs/ops, under ``export["hn_model"]``.

        Args:
            export: The mutable profiler export mapping to update.
        """
        model_details = {
            "model_name": self._hn.name,
            # Start the sums from 0.0 so the totals are floats even for an empty model.
            "weights": sum((hn_layer.weights for hn_layer in self._hn), 0.0),
            "hw_arch": self._hw_arch.name,
            "total_ops_per_frame": sum((hn_layer.ops for hn_layer in self._hn), 0.0),
            "input_shapes": self._get_input_shapes(),
            "output_shapes": self._get_output_shapes(),
            "post_processing": self._get_post_processing(),
        }
        model_details.update(self._get_input_modifications_meta_data())
        export["stats"]["model_details"].update(model_details)
        export["hn_model"] = self._get_extended_hn_model_dict()

    @staticmethod
    def _shape_to_str(shape):
        """Join the dimensions of ``shape`` with ``"x"``, e.g. ``[224, 224, 3]`` -> ``"224x224x3"``."""
        return "x".join(str(dim) for dim in shape)

    def _get_input_shapes(self):
        """Return the shape of each non-const input layer as an ``"x"``-joined string.

        The leading dimension (presumably the batch axis) is dropped and the remaining dimensions
        are cast to ``int`` before formatting.
        """
        # Keyed by layer name so that, if a name ever repeats, it yields a single entry
        # (first-seen position, last-seen shape) -- same semantics as the previous OrderedDict.
        shapes_by_layer = {
            layer.name: [int(dim) for dim in layer.input_shape[1:]]
            for layer in self._hn.get_non_const_input_layers()
        }
        return [self._shape_to_str(shape) for shape in shapes_by_layer.values()]

    def _get_output_shapes(self):
        """Return the first output shape of each real output layer as an ``"x"``-joined string.

        Non-neural-core output layers are included. The leading dimension (presumably the batch
        axis) is dropped and the remaining dimensions are cast to ``int`` before formatting.
        """
        # Keyed by layer name, mirroring _get_input_shapes.
        shapes_by_layer = {
            layer.name: [int(dim) for dim in layer.output_shapes[0][1:]]
            for layer in self._hn.get_real_output_layers(remove_non_neural_core_layers=False)
        }
        return [self._shape_to_str(shape) for shape in shapes_by_layer.values()]

    def _get_input_modifications_meta_data(self):
        """Summarize the model-script commands applied to the model inputs.

        Returns:
            dict with keys ``"input_conversion"`` (list of conversion-type values),
            ``"resize_input"`` (list of ``"x"``-joined resize output shapes), ``"normalization"``
            (bool) and ``"transpose"`` (bool). Commands of other types are ignored.
        """
        input_conversions = []
        resize_list = []
        has_normalization = False
        is_transposed = False

        for configs in self._modifications_meta_data.inputs.values():
            for config in configs:
                if config.cmd_type == SupportedCommands.INPUT_CONVERSION:
                    input_conversions.append(config.conversion_type.value)
                elif config.cmd_type == SupportedCommands.RESIZE:
                    resize_list.append(self._shape_to_str(config.output_shape))
                elif config.cmd_type == SupportedCommands.NORMALIZATION:
                    has_normalization = True
                elif config.cmd_type == SupportedCommands.TRANSPOSE:
                    is_transposed = True

        return {
            "input_conversion": input_conversions,
            "resize_input": resize_list,
            "normalization": has_normalization,
            "transpose": is_transposed,
        }

    def _get_post_processing(self):
        """Return a human-readable description of each model-script command applied to the outputs.

        Handles NMS post-processing, output activation changes, logits layers and resizes; other
        command types are skipped. Entries are in iteration order of the output command configs.
        """
        post_processing = []

        for configs in self._modifications_meta_data.outputs.values():
            for config in configs:
                if config.cmd_type == SupportedCommands.NMS_POSTPROCESS:
                    post_processing.append(f"{config.meta_arch.value} NMS")
                elif config.cmd_type == SupportedCommands.CHANGE_OUTPUT_ACTIVATION:
                    layer_name = config.hn_layer_name
                    orig_activation = config.original_activation.value
                    new_activation = config.new_activation.value
                    post_processing.append(f"{layer_name}: {orig_activation}->{new_activation}")
                elif config.cmd_type == SupportedCommands.LOGITS_LAYER:
                    post_processing.append(config.logits_type.value)
                elif config.cmd_type == SupportedCommands.RESIZE:
                    post_processing.append(f"resize to {self._shape_to_str(config.output_shape)}")

        return post_processing

    def _get_extended_hn_model_dict(self):
        """Return the HN model as a dict with ``weights``, ``macs`` and ``ops`` added per layer.

        The base dict is parsed from the model's own JSON serialization (``to_hn``); every layer
        yielded by iterating the model must have an entry under ``"layers"`` or ``KeyError`` is
        raised.
        """
        extended_hn_model_dict = json.loads(self._hn.to_hn(self._hn.name))

        for layer in self._hn:
            layer_dict = extended_hn_model_dict["layers"][layer.name]
            layer_dict["weights"] = layer.weights
            layer_dict["macs"] = layer.macs
            layer_dict["ops"] = layer.ops

        return extended_hn_model_dict
