#!/usr/bin/env python
import io
import os
import struct
import sys
from enum import IntEnum

import msgpack
import numpy as np

from hailo_sdk_client import ClientRunner
from hailo_sdk_client.exposed_definitions import InferenceContext
from hailo_sdk_common.logger.logger import default_logger
from hailo_sdk_common.serialization.numpy_serialization import hailo_np_savez

# File descriptor number that the ``__main__`` block opens in binary-write mode
# to send the length-prefixed msgpack result.
# NOTE(review): presumably inherited from the launching process — confirm.
HAILODC_FD = 100
# Value passed as ``hw_arch`` to every ``ClientRunner`` created in this file.
HAILODC_HW_ARCH = "hailo8"


class HailoDCStatus(IntEnum):
    """Process exit statuses handed to ``sys.exit`` by this script.

    Each stage of ``build``/``infer`` exits with its own status when it fails,
    so the exit code identifies the stage that went wrong.
    """

    # Normal termination after the result has been written.
    SUCCESS = 0
    # Not raised by the code visible in this file.
    ENVIRONMENT_ERROR = 1
    # ONNX translation stage of build() failed.
    PARSER_ERROR = 2
    # Optimization (calibration) stage of build() failed.
    OPTIMIZE_ERROR = 3
    # Compilation stage of build() failed.
    COMPILER_ERROR = 4
    # Not raised by the code visible in this file.
    PROFILER_ERROR = 5
    # Not raised by the code visible in this file.
    VISUALIZER_ERROR = 6
    # Loading the model script in build() failed.
    MODEL_SCRIPT_ERROR = 7
    # Packaging the build() result, or sending the result from __main__, failed.
    EXECUTION_ERROR = 8
    # Any failure inside infer().
    INFER_ERROR = 9


def _params_to_npz_bytes(params):
    """Serialize a params mapping with ``hailo_np_savez`` and return the raw bytes."""
    buffer = io.BytesIO()
    hailo_np_savez(buffer, **dict(params.items()))
    return buffer.getvalue()


def build(args):
    """Translate, optimize and compile an ONNX model, returning the artifacts.

    The work is done in five stages; when a stage raises, the exception is
    logged through the SDK logger and the process exits with that stage's
    ``HailoDCStatus`` value.

    Args:
        args: Request mapping. Keys read here: ``"onnx_model"``,
            ``"model_script"``, ``"calib_data"``, ``"calib_data_shape"`` and,
            optionally, ``"start_node_names"`` / ``"end_node_names"``.
            ``"fps"`` is accepted but ignored.

    Returns:
        A dict with ``"hef"`` (the ``runner.compile()`` result), ``"hn"``
        (encoded HN text), ``"qnpz"`` and ``"honpz"`` (npz bytes of the
        translated and the hailo-optimized params respectively).
    """
    # Stage 1: parse the ONNX model.
    try:
        onnx_bytes = bytes(args["onnx_model"])
        translate_kwargs = {
            "model": onnx_bytes,
            "start_node_names": args.get("start_node_names"),
            "end_node_names": args.get("end_node_names"),
            "disable_rt_metadata_extraction": True,
        }
        runner = ClientRunner(hw_arch=HAILODC_HW_ARCH)
        runner.translate_onnx_model(**translate_kwargs)
    except Exception:
        default_logger().exception("HailoDC parsing error:")
        sys.exit(HailoDCStatus.PARSER_ERROR.value)

    # Stage 2: apply the model script.
    try:
        model_script = args["model_script"]
        runner.load_model_script(model_script)
    except Exception:
        default_logger().exception("HailoDC model script error:")
        sys.exit(HailoDCStatus.MODEL_SCRIPT_ERROR.value)

    # Stage 3: optimize using the calibration set, which arrives flattened
    # together with its shape.
    try:
        flat_calib = np.array(args["calib_data"])
        runner.optimize(flat_calib.reshape(args["calib_data_shape"]))
    except Exception:
        default_logger().exception("HailoDC optimization error:")
        sys.exit(HailoDCStatus.OPTIMIZE_ERROR.value)

    # Stage 4: compile.
    try:
        if "fps" in args:
            default_logger().info("Ignore fps argument")
        compiled_hef = runner.compile()
    except Exception:
        default_logger().exception("HailoDC compilation error:")
        sys.exit(HailoDCStatus.COMPILER_ERROR.value)

    # Stage 5: gather everything the caller gets back.
    try:
        hn_model = runner.get_hn_model()
        translated_npz = _params_to_npz_bytes(runner.get_params_translated())
        optimized_npz = _params_to_npz_bytes(runner._sdk_backend.get_params_hailo_optimized())
        hn_bytes = hn_model.to_hn(hn_model.name).encode()
        return {
            "hef": compiled_hef,
            "hn": hn_bytes,
            "qnpz": translated_npz,
            "honpz": optimized_npz,
        }
    except Exception:
        default_logger().exception("HailoDC prepare data to send error:")
        sys.exit(HailoDCStatus.EXECUTION_ERROR.value)


def infer(args):
    """Run SDK-quantized inference on a dataset using a previously built model.

    Any exception is logged through the SDK logger and the process exits with
    ``HailoDCStatus.INFER_ERROR``.

    Args:
        args: Request mapping. Keys read here: ``"hn"``, ``"qnpz"``,
            ``"honpz"``, ``"dataset"`` and ``"dataset_shape"``.

    Returns:
        A dict with ``"inference_results"`` (one flattened list per output),
        ``"inference_results_shapes"`` (the matching array shapes) and
        ``"output_layers"`` (the output layer names).
    """
    try:
        runner = ClientRunner(hw_arch=HAILODC_HW_ARCH)
        hn_text = bytes(args["hn"]).decode()
        runner.set_hn(hn_text)

        # Load both parameter sets, in the same order as they were produced.
        for params_key in ("qnpz", "honpz"):
            runner.load_params(io.BytesIO(bytes(args[params_key])))

        flat_dataset = np.array(args["dataset"])
        dataset = flat_dataset.reshape(args["dataset_shape"])

        with runner.infer_context(InferenceContext.SDK_QUANTIZED) as ctx:
            outputs = runner.infer(ctx, dataset)

        hn_model = runner.get_hn_model()
        output_layers = hn_model.get_real_output_layers(remove_non_neural_core_layers=False)

        # With exactly one output layer the result is wrapped in a list so the
        # code below can treat every case as a sequence of arrays.
        if len(output_layers) == 1:
            outputs = [outputs]

        flattened = [output.flatten().tolist() for output in outputs]
        shapes = [output.shape for output in outputs]
        layer_names = [layer.name for layer in output_layers]

        return {
            "inference_results": flattened,
            "inference_results_shapes": shapes,
            "output_layers": layer_names,
        }

    except Exception:
        default_logger().exception("HailoDC infer error:")
        sys.exit(HailoDCStatus.INFER_ERROR.value)


if __name__ == "__main__":
    # Protocol: the whole request is one msgpack document on stdin; the reply
    # is written to file descriptor HAILODC_FD as a little-endian uint32 byte
    # count followed by the msgpack-encoded result.
    try:
        data = sys.stdin.buffer.read()
        args = msgpack.loads(data)
        method = args.get("method", "build")
    except Exception:
        # Empty or malformed input used to escape as an uncaught exception,
        # i.e. a bare traceback and exit status 1. Log it like every other
        # failure in this script and keep the same exit status.
        # NOTE(review): status 1 happens to be ENVIRONMENT_ERROR — confirm that
        # is the status the launcher should see for a bad request.
        default_logger().exception("HailoDC input error:")
        sys.exit(HailoDCStatus.ENVIRONMENT_ERROR.value)

    # Any method other than "build" is dispatched to infer(). Both functions
    # exit the process themselves on failure, so reaching the next statement
    # means a result is available.
    result = build(args) if method == "build" else infer(args)

    try:
        data_to_send = msgpack.dumps(result)
        # The context manager closes (and therefore flushes) the pipe even when
        # a write fails part-way; a failure while closing is still caught below.
        with os.fdopen(HAILODC_FD, mode="wb") as pipe:
            pipe.write(struct.pack("<I", len(data_to_send)))
            pipe.write(data_to_send)
    except Exception:
        default_logger().exception("HailoDC communication error:")
        sys.exit(HailoDCStatus.EXECUTION_ERROR.value)

    sys.exit(HailoDCStatus.SUCCESS.value)
