#!/usr/bin/env python
import copy
import tempfile
from contextlib import suppress

import numpy as np
import onnx
import onnxruntime as rt
from onnx import TensorProto
from onnx.helper import ValueInfoProto, make_tensor_value_info
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE

from hailo_sdk_client.exposed_definitions import Dims
from hailo_sdk_common.hailo_nn.hailo_nn import HailoNN
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.onnx_tools.definitions import ONNX_LARGE_MODEL_BYTE_COUNT
from hailo_sdk_common.tools.models_translator_helper import map_hn_orig_names_to_orig_names, valid_orig_name


class UnsupportedGraphInputError(Exception):
    """Raised when a graph, its inputs, or its outputs cannot be handled for ONNX inference.

    Within this module it signals dynamic input shapes, contradicting ``net_input_shapes``,
    mismatched output / dimension counts, and start nodes that cannot be located.
    """


class TFBackendInferenceError(Exception):
    """Error type for TensorFlow-backend inference failures (not raised by the code in this module)."""


class OnnxRuntimeInferenceError(Exception):
    """Raised when data supplied for ONNX runtime inference cannot be mapped onto the HN inputs."""


def remove_initializers_for_inference(model):
    """Remove every graph input whose name matches an initializer.

    The model is modified in place; the same object is returned so calls can be chained.
    """
    graph_inputs = model.graph.input
    inputs_by_name = {graph_input.name: graph_input for graph_input in graph_inputs}
    for initializer_name in (initializer.name for initializer in model.graph.initializer):
        if initializer_name not in inputs_by_name:
            continue
        graph_inputs.remove(inputs_by_name[initializer_name])
    return model


def handle_start_node(hn_model, node, model, output_shapes, float_type, nodes_to_delete, names_map):
    """Turn ``node`` into an entry point of the ONNX graph.

    For every HN input layer that feeds ``node`` (or the node's own layer, when it already is an
    input layer) and does not correspond to an existing graph input, a new graph input is created
    and wired into ``node.input``. Every node upstream of the replaced input is queued for deletion.

    Args:
        hn_model: HailoNN used to map ``node`` to its HN layer and to that layer's input layers.
        node: The ONNX node requested as a start node. Its ``input`` list is modified in place.
        model: The ONNX model containing ``node``. Its graph is not restructured here; the caller
            applies the collected deletions.
        output_shapes: Mapping of tensor name to a list whose first item is that tensor's shape.
            May be empty or miss tensors, in which case the HN layer shape is used.
        float_type: ONNX element type given to the new graph inputs.
        nodes_to_delete: List of nodes scheduled for removal, shared across calls; extended in place.
        names_map: Mapping from HN original names to ONNX names.

    Returns:
        A ``(new_inputs, inputs_to_delete)`` tuple: the newly created value infos, and the existing
        graph inputs reached while walking upstream of the cut.
    """
    new_inputs = []
    inputs_to_delete = []
    layer = hn_model.get_layer_by_original_name(node.name)
    input_layers = [layer]
    if layer.op != LayerType.input_layer:
        input_layers = [x for x in hn_model.predecessors(layer) if x.op == LayerType.input_layer]

    initializer_names = {x.name for x in model.graph.initializer}
    model_inputs = [input_vertex for input_vertex in model.graph.input if input_vertex.name not in initializer_names]
    model_input_names = {x.name for x in model_inputs}
    for input_layer in input_layers:
        # An input layer that already matches a real graph input needs no new input. Otherwise there
        # are more nodes before the current node, even though it is given as a start node.
        if any(name in input_layer.original_names for name in model_input_names):
            continue

        input_shape = []
        if output_shapes and input_layer.original_names:
            possible_inputs = [
                x for x in model.graph.node if any(valid_orig_name(x.name) in y for y in input_layer.original_names)
            ]
            if possible_inputs:
                output_name = possible_inputs[0].output[0]
                # Shape inference is best-effort, so the tensor may be missing; in that case fall
                # back to the HN layer shape below instead of raising KeyError.
                if output_name in output_shapes:
                    input_shape = [-1] + output_shapes[output_name][0][1:]

        if not input_shape:
            input_shape = [-1, input_layer.output_features]
            if len(input_layer.output_shape) == 4:
                input_shape.extend([input_layer.output_height, input_layer.output_width])

        # edge case - batch normalization with identity input is ignored
        old_node_inputs = [x for x in model.graph.node if any(y in node.input for y in x.output)]
        old_node_inputs = [
            x for x in old_node_inputs if node.op_type != "BatchNormalization" or x.op_type != "Identity"
        ]

        input_layer_name = input_layer.original_names[0] if input_layer.original_names else input_layer.name
        input_layer_name = names_map[input_layer_name]
        new_model_input = make_tensor_value_info(input_layer_name, float_type, input_shape)
        # Replace the input slot named "<new input>_output_0" if present, otherwise the first one
        relevant_input_idx = [i for i, x in enumerate(node.input) if x == f"{new_model_input.name}_output_0"]
        relevant_input_idx = relevant_input_idx[0] if len(relevant_input_idx) > 0 else 0
        node.input[relevant_input_idx] = new_model_input.name
        new_inputs.append(new_model_input)

        if old_node_inputs:
            nodes_to_delete.append(old_node_inputs[0])
            pending_inputs = list(old_node_inputs[0].input)
            while pending_inputs:
                next_inputs = []
                for inp in pending_inputs:
                    # edge case - handle model inputs to delete, not just model nodes
                    if inp in model_input_names:
                        inputs_to_delete.extend([x for x in model.graph.input if x.name == inp])

                    # Queue each producer once and expand it only then; re-expanding already queued
                    # nodes made this walk exponential on graphs with shared (diamond) paths.
                    for producer in (x for x in model.graph.node if inp in x.output):
                        if producer not in nodes_to_delete:
                            nodes_to_delete.append(producer)
                            next_inputs.extend(producer.input)

                pending_inputs = next_inputs

    return new_inputs, inputs_to_delete


def prepare_onnx_model_inference(loaded_model, hn_model, start_node_names=None, end_node_names=None):
    """Restrict an ONNX model to the sub-graph between start and end nodes for onnxruntime.

    ``loaded_model`` is modified in place and returned:

    * graph inputs that duplicate initializers are removed;
    * the first dim of every (non-initializer) graph input, and of every 4-D value_info, is set to -1;
    * start nodes get new graph inputs (see ``handle_start_node``), and nodes upstream of them are deleted;
    * graph outputs are replaced by the end nodes, in ``end_node_names`` order;
    * BatchNormalization ``spatial`` attributes are forced to 1.

    Args:
        loaded_model: ONNX ModelProto. Mutated and returned.
        hn_model: HailoNN matching the model; used for layer shapes and name lookups.
        start_node_names: Optional original names of the nodes that become the new inputs.
        end_node_names: Optional original names (node names or first-output names) of the nodes
            that become the new outputs.

    Returns:
        ``(model, names_map)``: the prepared model and the mapping returned by
        ``map_hn_orig_names_to_orig_names`` (used here to translate the given names to ONNX names).
    """
    model = remove_initializers_for_inference(loaded_model)
    model_orig_names = [x.name for x in model.graph.node]
    names_map = map_hn_orig_names_to_orig_names(hn_model, model_orig_names)

    start_node_names = [names_map.get(x, x) for x in start_node_names] if start_node_names else None
    end_node_names = [names_map.get(x, x) for x in end_node_names] if end_node_names else None
    # Seed shapes/types from the model's own value_info; dim_value is 0 for unknown dims
    out_shapes = {
        value_info.name: [[x.dim_value for x in value_info.type.tensor_type.shape.dim]]
        for value_info in loaded_model.graph.value_info
    }
    out_types = {value_info.name: value_info.type.tensor_type.elem_type for value_info in loaded_model.graph.value_info}
    out_types.update({output.name: output.type.tensor_type.elem_type for output in loaded_model.graph.output})

    # Best-effort: if shape inference fails, keep the value_info-based shapes (run_shape_inference
    # updates the passed dict in place, so it may hold partial results)
    with suppress(Exception):
        out_shapes = run_shape_inference(model, end_node_names, output_shapes=out_shapes)
    # maps the output edges to their nodes
    edges_to_nodes = {output: node.name for node in model.graph.node for output in node.output}
    # maps the output of each node (for multi-output nodes, the last output with a known shape wins)
    nodes_to_shapes = {node: out_shapes[edge][0] for edge, node in edges_to_nodes.items() if edge in out_shapes}

    initializer_names = [x.name for x in model.graph.initializer]
    model_inputs = [input_vertex for input_vertex in model.graph.input if input_vertex.name not in initializer_names]
    # Mark the batch dimension as -1.
    # NOTE(review): assumes every input has at least one dim; a scalar input would raise IndexError.
    for model_input in model_inputs:
        model_input.type.tensor_type.shape.dim[0].dim_value = -1

    for node_input_shape in model.graph.value_info:
        if len(node_input_shape.type.tensor_type.shape.dim) == 4:
            node_input_shape.type.tensor_type.shape.dim[0].dim_value = -1

    # New inputs/outputs use the first graph input's element type, unless it is not FLOAT/FLOAT16
    float_type = model.graph.input[0].type.tensor_type.elem_type
    if float_type not in [int(TensorProto.FLOAT), int(TensorProto.FLOAT16)]:
        float_type = int(TensorProto.FLOAT)

    new_inputs = []
    new_output_nodes = {}
    nodes_to_delete = []
    inputs_to_delete = []

    for node in model.graph.node:
        if start_node_names and node.name in start_node_names:
            curr_inputs = handle_start_node(hn_model, node, model, out_shapes, float_type, nodes_to_delete, names_map)
            new_inputs.extend(curr_inputs[0])
            inputs_to_delete.extend(curr_inputs[1])

        # An end node may be referenced either by node name or by its first output's name
        if end_node_names and (
            (len(node.output) > 0 and node.output[0] in end_node_names) or node.name in end_node_names
        ):
            if node.name in end_node_names:
                out_node_layer = hn_model.get_layer_by_original_name(node.name)
            else:
                out_node_layer = hn_model.get_layer_by_original_name(node.output[0])

            # Fallback output shape from the HN layer, used when inference produced no shape
            output_shape = [-1, out_node_layer.output_shapes[0][-1]]
            if len(out_node_layer.output_shapes[0]) == 4:
                output_shape.extend([out_node_layer.output_height, out_node_layer.output_width])

            # ArgMax outputs are INT64; an elem_type of 0 (UNDEFINED) falls back to float_type
            if node.output[0] in out_types:
                data_type = (
                    out_types[node.output[0]]
                    if (node.op_type != "ArgMax") and (out_types[node.output[0]] != 0)
                    else (onnx.TensorProto.INT64 if node.op_type == "ArgMax" else float_type)
                )
            else:
                data_type = onnx.TensorProto.INT64 if node.op_type == "ArgMax" else float_type
            new_out_node = make_tensor_value_info(
                name=node.output[0],
                elem_type=data_type,
                shape=nodes_to_shapes.get(node.name, output_shape),
            )
            # Key by whichever name the caller used, so the output order below can follow end_node_names
            name_key = node.name if node.name in end_node_names else new_out_node.name
            new_output_nodes[name_key] = new_out_node

        if node.op_type == "BatchNormalization":
            for attribute in node.attribute:
                if attribute.name == "spatial":
                    attribute.i = 1

    # handle new input nodes
    if start_node_names and new_inputs:
        relevant_graph_inputs = [x for x in model.graph.input if x not in inputs_to_delete]
        new_model_inputs = new_inputs + relevant_graph_inputs
        for i, x in reversed(list(enumerate(model.graph.input))):
            del model.graph.input[i]
        for x in new_model_inputs:
            model.graph.input.append(x)

        # Drop references to removed graph inputs from the remaining nodes
        deleted_input_names = [x.name for x in inputs_to_delete]
        for node in model.graph.node:
            for i, inp in reversed(list(enumerate(node.input))):
                if inp in deleted_input_names:
                    del node.input[i]

    # handle new output nodes, and keep order same as end_node_names
    # NOTE(review): an end node name that matched no node raises KeyError here
    if end_node_names:
        for i, x in reversed(list(enumerate(model.graph.output))):
            del model.graph.output[i]
        for end_node_name in end_node_names:
            model.graph.output.append(new_output_nodes[end_node_name])

    # delete all disconnected nodes
    for i, x in reversed(list(enumerate(model.graph.node))):
        if x in nodes_to_delete:
            del model.graph.node[i]

    return model, names_map


def run_onnx_runtime_inference(
    hailo_nn: HailoNN,
    input_dataset,
    model_path,
    start_node_names,
    end_node_names,
    output_format=None,
):
    """Run onnxruntime on the part of an ONNX model between the given start and end nodes.

    Args:
        hailo_nn: HailoNN matching the ONNX model.
        input_dataset: Either a dict of input name -> array (keys are translated through the
            names map of ``prepare_onnx_model_inference``) or a sequence of arrays ordered like
            the session inputs. The first axis of each array is treated as the batch axis.
        model_path: Path of the ONNX model; external data is loaded too.
        start_node_names: Original names of the start nodes.
        end_node_names: Original names of the end nodes; also used as result names when applying
            ``output_format``.
        output_format: Optional per-output format forwarded to ``reshape_onnx_results``.

    Returns:
        A list with one array per session output, reshaped by ``reshape_onnx_results``.
    """
    model = onnx.load(model_path, load_external_data=True)
    # Models containing a Resize node are run one sample at a time below
    # (presumably because Resize misbehaves with batch > 1 here — TODO confirm)
    single_batch = any(node.op_type == "Resize" for node in model.graph.node)
    runtime_model, orig_to_onnx_names = prepare_onnx_model_inference(model, hailo_nn, start_node_names, end_node_names)
    onnx_rt_sess = get_onnxrt_session(runtime_model)
    if isinstance(input_dataset, dict):
        input_dataset = {orig_to_onnx_names.get(x, x): y for x, y in input_dataset.items()}

    # this is how the new output nodes are saved in the RT proto - the name is actually an index of the out edge
    start_node_names = [x.name for x in onnx_rt_sess.get_inputs()]
    output_node_indices = [x.name for x in onnx_rt_sess.get_outputs()]
    batch_size = (
        next(iter(input_dataset.values())).shape[0] if isinstance(input_dataset, dict) else input_dataset[0].shape[0]
    )
    if batch_size > 1 and single_batch:
        # Split every input into batch-of-1 slices and run them separately
        single_batch_inputs = []
        for batch in range(batch_size):
            if isinstance(input_dataset, dict):
                single_batch_inputs.append(
                    {node_name: data[batch : batch + 1] for node_name, data in input_dataset.items()},
                )
            else:
                single_batch_inputs.append(
                    {node_name: data[batch : batch + 1] for node_name, data in zip(start_node_names, input_dataset)},
                )
        results = []
        for feed_dict in single_batch_inputs:
            single_batch_result = onnx_rt_sess.run(output_node_indices, feed_dict)
            reshape_onnx_results(single_batch_result, output_format, end_node_names)
            results.append(single_batch_result)

        # Re-stack the per-sample results along a new leading batch axis, one array per output
        return [
            np.asarray([result[end_node_index][0] for result in results])
            for end_node_index in range(len(output_node_indices))
        ]

    feed_dict = input_dataset if isinstance(input_dataset, dict) else dict(zip(start_node_names, input_dataset))

    onnx_rt_results = onnx_rt_sess.run(output_node_indices, feed_dict)
    reshape_onnx_results(onnx_rt_results, output_format, end_node_names)
    return onnx_rt_results


def reshape_onnx_results(onnx_results, output_format=None, onnx_names=None):
    """Convert onnxruntime outputs, in place, to the layouts used for HN comparison.

    Each result is paired with the matching entry of ``onnx_names`` (all ``""`` when omitted):

    * rank 1 ``(n,)`` becomes ``(1, n)``;
    * if ``output_format`` has an entry for the name, or the only result is unnamed, the result is
      passed through ``reshape_by_output_format`` (with the last format when the name is not a key);
    * otherwise rank 3 ``(N, A, B)`` becomes ``(N, A, 1, B)``, rank 4 is transposed with
      ``[0, 2, 3, 1]`` and rank 5 with ``[0, 3, 4, 2, 1]``; other ranks are left untouched.

    Raises:
        UnsupportedGraphInputError: if ``onnx_names`` is given with a different length than ``onnx_results``.
    """
    if onnx_names is None:
        onnx_names = [""] * len(onnx_results)
    elif len(onnx_names) != len(onnx_results):
        raise UnsupportedGraphInputError("Can't compare results due to different number of outputs.")

    default_transposes = {4: [0, 2, 3, 1], 5: [0, 3, 4, 2, 1]}
    only_one_result = len(onnx_results) == 1
    for position, (name, result) in enumerate(zip(onnx_names, onnx_results)):
        dims = np.shape(result)
        if len(dims) == 1:
            onnx_results[position] = np.reshape(result, [1, dims[0]])
            continue

        if output_format is not None and (name in output_format or (not name and only_one_result)):
            fmt = output_format[name] if name in output_format else list(output_format.values())[-1]
            onnx_results[position] = reshape_by_output_format(result, fmt)
        elif len(dims) == 3:
            onnx_results[position] = np.transpose(np.expand_dims(result, axis=1), [0, 2, 1, 3])
        elif len(dims) in default_transposes:
            onnx_results[position] = np.transpose(result, default_transposes[len(dims)])


def reshape_by_output_format(res, out_format):
    """Rearrange ``res`` from the per-axis layout ``out_format`` into a rank-4 HN-style array.

    ``out_format`` lists one ``Dims`` entry per axis of ``res``. Any of BATCH, GROUPS, CHANNELS,
    HEIGHT, WIDTH and STACK that is absent is prepended as a size-1 axis. The array is then ordered
    as batch, height, width followed by either (stack, groups, channels) when STACK precedes
    CHANNELS in the format, or (groups, channels, stack) otherwise. Those last three axes are
    flattened into one.

    Raises:
        UnsupportedGraphInputError: if ``out_format`` does not have one entry per axis of ``res``.
    """
    if len(res.shape) != len(out_format):
        raise UnsupportedGraphInputError("Can't compare results due to different number of dimensions.")

    layout = out_format.copy()
    for dim in (Dims.BATCH, Dims.GROUPS, Dims.CHANNELS, Dims.HEIGHT, Dims.WIDTH, Dims.STACK):
        if dim in layout:
            continue
        # Missing axes become leading size-1 axes so every layout carries all six dims
        res = np.expand_dims(res, axis=0)
        layout.insert(0, dim)

    stack_axis = layout.index(Dims.STACK)
    channels_axis = layout.index(Dims.CHANNELS)
    groups_axis = layout.index(Dims.GROUPS)
    order = [layout.index(Dims.BATCH), layout.index(Dims.HEIGHT), layout.index(Dims.WIDTH)]
    if stack_axis < channels_axis:
        order.extend([stack_axis, groups_axis, channels_axis])
    else:
        order.extend([groups_axis, channels_axis, stack_axis])

    transposed = np.transpose(res, order)
    return np.reshape(transposed, [*transposed.shape[:-3], -1])


def look_for_bwd_node(predecessors, node, ops_chain):
    """Follow ``ops_chain`` backwards from ``node`` and return the node matching its last entry.

    ``ops_chain[k]`` is the op type expected ``k + 1`` steps upstream of ``node``. A ``None`` entry
    matches only when the current node has no predecessors, in which case that node is returned.
    Branches are explored depth-first in predecessor order.

    Args:
        predecessors: Mapping of node name to the list of its producer nodes.
        node: Node to start from.
        ops_chain: Non-empty sequence of op type strings, optionally terminated by ``None``.

    Returns:
        The matched node, or ``None`` if no path follows the chain.
    """
    parents = predecessors[node.name]
    wanted = ops_chain[0]
    if wanted is None and not parents:
        return node

    matching_parents = (parent for parent in parents if wanted == parent.op_type)
    if len(ops_chain) == 1:
        return next(matching_parents, None)

    remaining_chain = ops_chain[1:]
    for parent in matching_parents:
        hit = look_for_bwd_node(predecessors, parent, remaining_chain)
        if hit is not None:
            return hit
    return None


def prepare_model_for_shape_inference(model, end_names, output_shapes=None):
    """Build a copy of ``model`` whose graph outputs expose intermediate tensors for shape probing.

    The copy keeps only the nodes needed to compute ``end_names`` (or all nodes when none of the
    names matches). Its graph outputs are replaced by every output tensor of the kept nodes whose
    shape is not already fully known in ``output_shapes``. A few op patterns additionally mark
    tensors whose values (not only shapes) should be recorded:

    * the non-data inputs of ``Pad`` / ``Slice`` nodes;
    * outputs of ``Reshape`` / ``Transpose`` / ``Mul`` nodes that sit at the end of specific
      upstream op chains.

    Args:
        model: ONNX ModelProto. Not modified; a deep copy is edited.
        end_names: Optional node names or output-tensor names bounding the sub-graph.
        output_shapes: Optional mapping of tensor name to a list of known shapes; tensors whose shapes
            are all non-empty and free of 0 dims are not re-exposed.

    Returns:
        ``(shape_model, outputs_to_save)``: the prepared copy and the unique names of tensors whose
        values should be kept.
    """
    # Modifying model outputs (to save shapes of all nodes) so it must be a model copy
    model = remove_initializers_for_inference(copy.deepcopy(model))
    outputs_to_save = []
    output_shapes = {} if output_shapes is None else output_shapes

    # Index producers once so predecessor lookup does not rescan the whole graph per node.
    # Predecessors keep graph order, matching a full scan over model.graph.node.
    graph_nodes = list(model.graph.node)
    producer_indices = {}
    for index, graph_node in enumerate(graph_nodes):
        for output_name in graph_node.output:
            producer_indices.setdefault(output_name, []).append(index)
    predecessors = {}
    for node in graph_nodes:
        indices = sorted({i for input_name in node.input for i in producer_indices.get(input_name, ())})
        predecessors[node.name] = [graph_nodes[i] for i in indices]

    if end_names and any(x.name in end_names or any(y in end_names for y in x.output) for x in model.graph.node):
        valid_nodes = []
        for end_node_name in end_names:
            possible_end_nodes = [x for x in model.graph.node if x.name == end_node_name or end_node_name in x.output]
            # Skip missing or ambiguous names before indexing, so an unmatched name cannot raise IndexError
            if len(possible_end_nodes) != 1:
                continue
            end_node = possible_end_nodes[0]
            if end_node in valid_nodes:
                continue
            # Breadth-first walk collecting every ancestor of the end node
            valid_nodes.append(end_node)
            current_nodes = [x for x in predecessors[end_node.name] if x not in valid_nodes]
            while current_nodes:
                new_nodes = []
                for current_node in current_nodes:
                    valid_nodes.append(current_node)
                    all_nodes = valid_nodes + new_nodes + current_nodes
                    new_nodes.extend([x for x in predecessors[current_node.name] if x not in all_nodes])
                current_nodes = new_nodes
    else:
        valid_nodes = list(model.graph.node)

    new_output_nodes = []
    for node in model.graph.node:
        if node not in valid_nodes:
            continue
        if node.op_type == "BatchNormalization":
            for attribute in node.attribute:
                if attribute.name == "spatial":
                    attribute.i = 1
        elif node.op_type in ["Pad", "Slice"]:
            # Keep the values of the parameter inputs (everything after the data input)
            outputs_to_save.extend(node.input[1:])
        elif node.op_type == "Reshape":
            long_chain = ["BatchNormalization", "ReduceMean", "Mul", "Sub", "ReduceMean", None]
            long_reduce_mean = look_for_bwd_node(predecessors, node, long_chain)
            short_reduce_mean = look_for_bwd_node(predecessors, node, ["BatchNormalization", "ReduceMean", None])
            mha_chain = ["Expand", "Reshape", "Where", "Flatten", "Gather", "Cast", "Resize", None]
            multi_head_attention_mask_chain = look_for_bwd_node(predecessors, node, mha_chain)
            if (long_reduce_mean and short_reduce_mean) or multi_head_attention_mask_chain:
                add_output_to_save(new_output_nodes, node, outputs_to_save)
        elif node.op_type == "Transpose":
            pos_embed_chain = look_for_bwd_node(
                predecessors,
                node,
                [
                    "Reshape",
                    "Transpose",
                    "Concat",
                    "Reshape",
                    "Concat",
                    "Unsqueeze",
                    "Sin",
                    "Slice",
                    "Div",
                    "Unsqueeze",
                    "Mul",
                    "Div",
                    "CumSum",
                    "Cast",
                    "Not",
                    "Gather",
                    "Cast",
                    "Resize",
                    None,
                ],
            )
            if pos_embed_chain:
                add_output_to_save(new_output_nodes, node, outputs_to_save)
        elif node.op_type == "Mul":
            mha_chain = ["Where", "Reshape", "Expand", "Reshape", "Flatten", "Gather", "Cast", "Resize", None]
            mha_mask_chain = look_for_bwd_node(predecessors, node, mha_chain)
            if mha_mask_chain:
                add_output_to_save(new_output_nodes, node, outputs_to_save)

        # Expose each output whose shape is not already fully known (non-empty, no 0 dims)
        out_nodes = []
        for output in node.output:
            existing_output_shapes = output_shapes.get(output)
            if existing_output_shapes and all(shape and 0 not in shape for shape in existing_output_shapes):
                continue
            out_node = ValueInfoProto()
            out_node.name = output
            out_nodes.append(out_node)
        new_output_nodes.extend(out_nodes)

    for i, x in reversed(list(enumerate(model.graph.output))):
        del model.graph.output[i]
    for i, x in reversed(list(enumerate(model.graph.node))):
        if x not in valid_nodes:
            del model.graph.node[i]
    for new_output_node in new_output_nodes:
        model.graph.output.append(new_output_node)

    return model, list(set(outputs_to_save))


def run_shape_inference(model, end_node_names, net_inputs_shapes=None, output_shapes=None):
    """Run the model once on zero-filled inputs and record the shapes of its tensors.

    Args:
        model: ONNX ModelProto (copied by ``prepare_model_for_shape_inference``, not modified).
        end_node_names: Optional end node / tensor names limiting the probed sub-graph.
        net_inputs_shapes: Optional mapping of input name to shape used for the zero inputs. When
            given, any ``output_shapes`` argument is ignored and a fresh dict is used.
        output_shapes: Optional dict of already-known shapes; it is updated in place and returned.

    Returns:
        Dict mapping tensor name to a list of shapes. Each graph input maps to one copy of its shape
        per consuming node; probed tensors map to a single shape. For the tensors selected by
        ``prepare_model_for_shape_inference``, the computed value is also stored under ``"<name>_value"``.
    """
    if net_inputs_shapes is not None or output_shapes is None:
        output_shapes = {}

    shape_model, values_to_keep = prepare_model_for_shape_inference(model, end_node_names, output_shapes)
    feed_dict = get_feed_dict(net_inputs_shapes, shape_model)

    graph_nodes = shape_model.graph.node
    input_shapes = {}
    for input_name, input_value in feed_dict.items():
        consumer_count = sum(1 for graph_node in graph_nodes if input_name in graph_node.input)
        input_shapes[input_name] = [list(input_value.shape)] * consumer_count
    output_shapes.update(input_shapes)

    fetch_names = [graph_output.name for graph_output in shape_model.graph.output]
    if not fetch_names:
        return output_shapes

    session = get_onnxrt_session(shape_model)
    fetched_values = session.run(fetch_names, feed_dict)

    for fetch_name, fetched in zip(fetch_names, fetched_values):
        output_shapes[fetch_name] = [list(fetched.shape)]
        if fetch_name in values_to_keep:
            output_shapes[f"{fetch_name}_value"] = fetched

    return output_shapes


def get_feed_dict(net_inputs_shapes, runtime_model):
    """Build a zero-filled input feed for every graph input of ``runtime_model``.

    All non-batch dimensions must be static and positive. The batch size comes from
    ``net_inputs_shapes`` when it has an entry for the input (whose non-batch dims must then agree
    with the model). Otherwise it comes from the model's first dim, replaced by 1 when that dim is
    not positive.

    Args:
        net_inputs_shapes: Optional mapping of input name to full shape (batch first).
        runtime_model: ONNX ModelProto whose ``graph.input`` entries are fed.

    Returns:
        Dict mapping input name to a zeros array whose dtype matches the input's element type
        (float32 when the type is unknown).

    Raises:
        UnsupportedGraphInputError: on dynamic/non-positive non-batch dims, or when
            ``net_inputs_shapes`` contradicts the model's non-batch dims.
    """
    feed_dict = {}
    for start_node in runtime_model.graph.input:
        start_node_shape = [x.dim_value for x in start_node.type.tensor_type.shape.dim]
        # validate shapes are not dynamic at this stage - should have been augmented prior to this
        none_or_negative_cond = any(x is None or x <= 0 for x in start_node_shape[1:])
        dynamic_axis_cond = any(isinstance(x, str) for x in start_node_shape)
        if none_or_negative_cond or dynamic_axis_cond:
            raise UnsupportedGraphInputError(
                f"Unsupported dynamic shape({start_node_shape}) found on input node {start_node.name}. "
                "Please use net_input_shapes, see documentation for additional info.",
            )

        # validate that provided net_input_shapes doesn't contradict original model metadata
        batch_size = start_node_shape[0]
        if net_inputs_shapes and start_node.name in net_inputs_shapes:
            batch_size = net_inputs_shapes[start_node.name][0]
            if not dynamic_axis_cond:
                input_shape = net_inputs_shapes[start_node.name]
                if list(input_shape[1:]) != list(start_node_shape[1:]):
                    # ValueInfoProto has no `.shape`; report the shape read above instead
                    raise UnsupportedGraphInputError(
                        f"Provided net_input_shapes {input_shape} for node {start_node.name} contradicts "
                        f"model original input shape {start_node_shape}.",
                    )
        elif batch_size is None or isinstance(batch_size, str) or batch_size <= 0:
            # Unknown / symbolic / non-positive batch dims are probed with a single sample
            batch_size = 1

        dtype = TENSOR_TYPE_TO_NP_TYPE.get(start_node.type.tensor_type.elem_type, np.dtype("float32"))
        feed_dict[start_node.name] = np.zeros([batch_size] + start_node_shape[1:]).astype(dtype)
    return feed_dict


def get_onnxrt_session(runtime_model):
    """Create a CPU-only onnxruntime session for an in-memory ONNX model.

    Graph optimizations and the CPU memory arena are disabled, and only error-level logs are kept.
    Models whose serialized size exceeds ``ONNX_LARGE_MODEL_BYTE_COUNT`` are written to a temporary
    directory with external data and loaded from disk rather than from one serialized buffer
    (presumably to stay below protobuf's 2GB message limit — confirm against the constant).

    On the large-model path, ``onnx.save_model(..., save_as_external_data=True)`` moves tensor
    payloads out of ``runtime_model`` in place. They are loaded back before the temporary directory
    is removed, so the caller's model keeps its weights inline.

    Args:
        runtime_model: ONNX ModelProto to run.

    Returns:
        An ``onnxruntime.InferenceSession`` using ``CPUExecutionProvider``.
    """
    onnx_session_options = rt.SessionOptions()
    onnx_session_options.log_severity_level = 3  # suppress warnings (3 is level error)
    onnx_session_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_DISABLE_ALL
    onnx_session_options.enable_cpu_mem_arena = False
    large_model_detected = runtime_model.ByteSize() > ONNX_LARGE_MODEL_BYTE_COUNT
    providers = ["CPUExecutionProvider"]
    if large_model_detected:
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = f"{temp_dir}/rt.onnx"
            onnx.save_model(runtime_model, temp_path, save_as_external_data=True)
            try:
                onnx_rt_sess = rt.InferenceSession(temp_path, onnx_session_options, providers)
            finally:
                # Restore the tensor data that save_model moved to external files, while they still exist
                onnx.external_data_helper.load_external_data_for_model(runtime_model, temp_dir)
    else:
        onnx_rt_sess = rt.InferenceSession(runtime_model.SerializeToString(), onnx_session_options, providers)
    return onnx_rt_sess


def add_output_to_save(new_output_nodes, node, outputs_to_save):
    """Expose every output tensor of ``node`` as a graph output and record its name for value capture.

    Args:
        new_output_nodes: List of ``ValueInfoProto`` graph outputs; one name-only entry is appended
            per output of ``node``.
        node: ONNX node whose outputs are recorded.
        outputs_to_save: List of tensor names; extended with ``node.output``.
    """
    outputs_to_save.extend(node.output)
    new_output_nodes.extend(ValueInfoProto(name=output_name) for output_name in node.output)


def _collect_upstream_nodes(graph_nodes, seed_nodes):
    """Return ``seed_nodes`` plus every node they transitively consume from, each listed once."""
    producers = {}
    for graph_node in graph_nodes:
        for output_name in graph_node.output:
            producers.setdefault(output_name, []).append(graph_node)

    collected = []
    pending = list(seed_nodes)
    while pending:
        current = pending.pop()
        # Visit each node once; re-walking shared (diamond) paths used to blow up exponentially
        if current in collected:
            continue
        collected.append(current)
        for input_name in current.input:
            pending.extend(producers.get(input_name, ()))
    return collected


def set_model_net_input_shapes(onnx_model, net_input_shapes):
    """Replace the graph inputs of ``onnx_model`` with FLOAT inputs of the given shapes (in place).

    For each key of ``net_input_shapes``:

    * if it names an existing graph input, a FLOAT input with the given shape replaces it;
    * otherwise it must name a graph node (a "start node"). A FLOAT input named after the producer of
      that node's first input (a node, or a graph input of that name) is created and wired into the
      node's ``input[0]``. Every node upstream of the old first input is then removed.

    All previous graph inputs are dropped; only the inputs created here remain.

    Args:
        onnx_model: ONNX ModelProto, modified in place.
        net_input_shapes: Mapping of graph-input or node name to the full input shape.

    Raises:
        UnsupportedGraphInputError: if a name matches neither a graph input nor a node, or if the
            start node's first input has no producer.
    """
    start_node_inputs = []  # Nodes derived from start_node_names
    net_inputs = []
    initializer_names = [x.name for x in onnx_model.graph.initializer]
    for start_name in net_input_shapes:
        possible_inputs = [x for x in onnx_model.graph.input if x.name == start_name]
        if possible_inputs:
            inp = make_tensor_value_info(possible_inputs[0].name, onnx.TensorProto.FLOAT, net_input_shapes[start_name])
            net_inputs.append(inp)
            continue

        graph_vertices = [x for x in onnx_model.graph.node if x.name == start_name]
        if not graph_vertices:
            raise UnsupportedGraphInputError(f"Couldn't find start node {start_name} in the given model.")

        # Create new input vertex
        start_vertex = graph_vertices[0]
        prevs = [
            x
            for x in onnx_model.graph.node
            if x.name not in initializer_names and any(out == start_vertex.input[0] for out in x.output)
        ] + [x for x in onnx_model.graph.input if x.name == start_vertex.input[0]]
        if not prevs:
            raise UnsupportedGraphInputError(f"Can't find predecessors for node {start_name} in the given model.")

        inp = make_tensor_value_info(prevs[0].name, onnx.TensorProto.FLOAT, net_input_shapes[start_name])
        curr_inputs = [x for x in onnx_model.graph.node if any(y == start_vertex.input[0] for y in x.output)]
        start_node_inputs.extend(curr_inputs)

        # Assign new input vertex as an input to start_vertex
        start_vertex.input[0] = inp.name
        net_inputs.append(inp)

    # All nodes feeding the chosen start nodes, together with their own ancestors, are cut away
    start_nodes_to_delete = _collect_upstream_nodes(onnx_model.graph.node, start_node_inputs)

    # Delete old input vertices from graph
    del onnx_model.graph.input[:]

    # Add new inputs to the graph inputs
    onnx_model.graph.input.extend(net_inputs)

    # Remove all start_nodes_to_delete from graph
    for i, x in reversed(list(enumerate(onnx_model.graph.node))):
        if x in start_nodes_to_delete:
            del onnx_model.graph.node[i]


def get_native_input_data_from_onnx(hailo_nn, onnx_input_data, output_format=None):
    """
    Transpose the data from ONNX input order to HN order.

    Args:
        hailo_nn: HailoNN whose input layers receive the data.
        onnx_input_data: Dict of ONNX input name -> array, or (single-input models only) a sequence
            whose first item is the input array.
        output_format: Dict of validated input name -> per-axis format, as consumed by
            ``reshape_by_output_format``. Required for every input.

    Returns:
        Dict mapping HN input-layer name to the reshaped array.

    Raises:
        OnnxRuntimeInferenceError: if non-dict data is given for a multi-input model, a name matches
            no HN input layer, or ``output_format`` lacks an entry for an input.
    """
    input_layers = hailo_nn.get_input_layers()
    hn_input_data = {}

    if not isinstance(onnx_input_data, dict):  # unify data manipulation for non dict inputs, by converting to dict
        if len(input_layers) > 1:
            raise OnnxRuntimeInferenceError(
                "For multiple input tensors, the dataset should be given "
                "as a dictionary, where each key is the original name of the input node.",
            )
        onnx_input_data = {input_layers[0].original_names[0]: onnx_input_data[0]}

    for vertex_name, vertex_data in onnx_input_data.items():
        valid_name = valid_orig_name(vertex_name)
        hn_layer = next((layer for layer in input_layers if valid_name in layer.original_names), None)
        # A bare next() would leak StopIteration for an unknown name; report it explicitly instead
        if hn_layer is None:
            raise OnnxRuntimeInferenceError(f"Couldn't find an input layer matching vertex {vertex_name}.")

        if not output_format or valid_name not in output_format:
            raise OnnxRuntimeInferenceError(
                f"Either output format is missing or it has no entry for vertex {vertex_name}. "
                "Please provide output format for all vertices.",
            )
        hn_input_data[hn_layer.name] = reshape_by_output_format(vertex_data, output_format[valid_name])

    return hn_input_data
