#!/usr/bin/env python
import struct

import xxhash
from google.protobuf import text_format

from hailo_sdk_client.allocator import hef_pb2

# Magic number at the start of every HEF file: the byte 0x01 followed by
# ASCII "HEF", read as a big-endian uint32 (== 0x01484546).
HEF_HEADER_MAGIC = int.from_bytes(b"\x01HEF", "big")


class HefBaseHeader:
    """Version-independent prefix of a HEF header: magic, version and proto size.

    The prefix is packed as ``BASE_HEADER_FORMAT`` (three big-endian uint32).
    Versioned subclasses append more fields by setting ``EXTRA_HEADER_FORMAT``;
    those fields follow the base fields directly.
    """

    BASE_HEADER_FORMAT = ">III"
    BASE_HEADER_SIZE = struct.calcsize(BASE_HEADER_FORMAT)
    EXTRA_HEADER_FORMAT = ""

    def __init__(self, data=None):
        """Create a header with default field values, optionally parsed from ``data``.

        Args:
            data: raw HEF bytes starting at the header. When truthy it is handed
                to ``self.parse``, which versioned subclasses override.
        """
        self.magic = HEF_HEADER_MAGIC
        self.version = 0
        self.proto_size = 0
        self.padding_size = 0
        # update(), the versioned subclasses and HefWrapper all use
        # ``ccws_size``; ``ccw_size`` remains available as an alias property.
        self.ccws_size = 0
        self.additional_info_size = 0
        self.hash = 0
        if data:
            self.parse(data)

    def __str__(self):
        return (
            f"----- HEF header -----"
            f"\nmagic: {hex(self.magic)}"
            f"\nversion: {hex(self.version)}"
            f"\nproto_size: {hex(self.proto_size)}"
        )

    def __repr__(self):
        return self.__str__()

    def __len__(self):
        return self.HEADER_SIZE

    @property
    def HEADER_SIZE(self):
        """Packed size in bytes of the base fields plus this version's extra fields."""
        return self.BASE_HEADER_SIZE + struct.calcsize(self.EXTRA_HEADER_FORMAT)

    @property
    def size(self):
        """Alias of ``HEADER_SIZE``."""
        return self.HEADER_SIZE

    @property
    def ccw_size(self):
        """Backward-compatible alias for ``ccws_size``."""
        return self.ccws_size

    @ccw_size.setter
    def ccw_size(self, value):
        self.ccws_size = value

    def parse(self, data):
        """Read magic, version and proto_size from the first bytes of ``data``.

        Raises:
            struct.error: if ``data`` is shorter than ``BASE_HEADER_SIZE``.
        """
        self.magic, self.version, self.proto_size = struct.unpack(
            self.BASE_HEADER_FORMAT, data[: self.BASE_HEADER_SIZE]
        )

    def recalculate_hash(self, proto, padding, ccw):
        """Return the xxh3-64 integer digest over ``proto``, ``padding`` and ``ccw`` in order."""
        xxh = xxhash.xxh3_64()
        xxh.update(proto)
        xxh.update(padding)
        xxh.update(ccw)
        return xxh.intdigest()

    def serialize(self):
        """Pack the base fields into bytes using ``BASE_HEADER_FORMAT``."""
        return struct.pack(self.BASE_HEADER_FORMAT, self.magic, self.version, self.proto_size)

    def update(self, proto, padding, ccw):
        """Refresh sizes and hash from the serialized proto, padding and CCW bytes."""
        self.proto_size = len(proto)
        self.padding_size = len(padding)
        self.ccws_size = len(ccw) + self.padding_size  # padding is included in ccws_size for some reason
        self.hash = self.recalculate_hash(proto, padding, ccw)


class HefHeaderV0(HefBaseHeader):
    """HEF header version 0, which is no longer supported.

    Instantiating it always raises ``NotImplementedError``.
    """

    def __init__(self, data=None):
        message = "HEF version 0 is deprecated"
        raise NotImplementedError(message)


class HefHeaderV1(HefBaseHeader):
    """HEF header version 1, which is no longer supported.

    Instantiating it always raises ``NotImplementedError``.
    """

    def __init__(self, data=None):
        message = "HEF version 1 is deprecated"
        raise NotImplementedError(message)


class HefHeaderV2(HefBaseHeader):
    """HEF header version 2.

    After the base fields come four big-endian uint64 values (``>QQQQ``):
    xxhash3_64, ccws_size, padding_size (printed as ``reserved1``) and reserved2.
    """

    EXTRA_HEADER_FORMAT = ">QQQQ"

    def __init__(self, data=None):
        """Create a V2 header, parsed from ``data`` if given, otherwise with defaults."""
        # HefBaseHeader.__init__ already dispatches to self.parse(data), which
        # fills every V2 field; defaults are only needed when nothing was parsed.
        super().__init__(data)
        if not data:
            self.version = 2
            self.ccws_size = 0
            self.reserved2 = 0

    def __str__(self):
        return (
            f"{super().__str__()}"
            f"\nxxhash3_64: {hex(self.hash)}"
            f"\nccws_size: {self.ccws_size}"
            f"\nreserved1: {hex(self.padding_size)}"
            f"\nreserved2: {hex(self.reserved2)}"
        )

    def parse(self, data):
        """Read the base fields and the four V2 uint64 fields from ``data``."""
        super().parse(data)
        self.hash, self.ccws_size, self.padding_size, self.reserved2 = struct.unpack(
            self.EXTRA_HEADER_FORMAT, data[self.BASE_HEADER_SIZE : self.HEADER_SIZE]
        )

    def serialize(self):
        """Pack the base fields followed by the V2 fields."""
        return super().serialize() + struct.pack(
            self.EXTRA_HEADER_FORMAT, self.hash, self.ccws_size, self.padding_size, self.reserved2
        )


class HefHeaderV3(HefBaseHeader):
    """HEF header version 3.

    After the base fields comes ``>QQIQQQ``: xxhash3_64 (u64), ccws_size (u64),
    padding_size (u32), additional_info_size (u64), reserved1 (u64) and
    reserved2 (u64).
    """

    EXTRA_HEADER_FORMAT = ">QQIQQQ"

    def __init__(self, data=None):
        """Create a V3 header, parsed from ``data`` if given, otherwise with defaults."""
        # HefBaseHeader.__init__ already dispatches to self.parse(data), which
        # fills every V3 field; defaults are only needed when nothing was parsed.
        super().__init__(data)
        if not data:
            self.version = 3
            self.ccws_size = 0
            self.reserved1 = 0
            self.reserved2 = 0

    def __str__(self):
        return (
            f"{super().__str__()}"
            f"\nxxhash3_64: {hex(self.hash)}"
            f"\nccws_size: {self.ccws_size}"
            f"\npadding_size: {self.padding_size}"
            f"\nadditional_info_size: {self.additional_info_size}"
            f"\nreserved1: {hex(self.reserved1)}"
            f"\nreserved2: {hex(self.reserved2)}"
        )

    def parse(self, data):
        """Read the base fields and the six V3 fields from ``data``."""
        super().parse(data)
        self.hash, self.ccws_size, self.padding_size, self.additional_info_size, self.reserved1, self.reserved2 = (
            struct.unpack(self.EXTRA_HEADER_FORMAT, data[self.BASE_HEADER_SIZE : self.HEADER_SIZE])
        )

    def serialize(self):
        """Pack the base fields followed by the V3 fields."""
        return super().serialize() + struct.pack(
            self.EXTRA_HEADER_FORMAT,
            self.hash,
            self.ccws_size,
            self.padding_size,
            self.additional_info_size,
            self.reserved1,
            self.reserved2,
        )


class HefWrapper:
    """In-memory, editable view of a HEF file.

    The raw bytes are laid out as
    ``header | proto | padding | ccws | additional_info``. The header's
    ``ccws_size`` counts the padding together with the CCW bytes.
    """

    HEF_PAGE_SIZE = 4096

    def __init__(self, raw_hef_data):
        """Wrap and immediately parse ``raw_hef_data`` (the full HEF file contents)."""
        self._hef_raw = raw_hef_data
        self._quiet = False
        self._is_mcw = False
        self._parse_hef()

    @classmethod
    def from_hef_path(cls, hef_path):
        """Read the HEF file at ``hef_path`` and wrap its contents."""
        with open(hef_path, "rb") as f:
            hef_raw = f.read()
        return cls(hef_raw)

    def save(self, path):
        """Write the serialized (possibly modified) HEF to ``path``."""
        with open(path, "wb") as f:
            f.write(self.serialize())

    @staticmethod
    def base_header(data):
        """Parse only the version-independent header prefix of ``data``."""
        return HefBaseHeader(data)

    @staticmethod
    def hef_header_factory(version):
        """Return a default-initialised header object for HEF ``version``.

        Raises:
            NotImplementedError: for the deprecated versions 0 and 1.
            ValueError: for any other unknown version.
        """
        if version == 0:
            return HefHeaderV0()
        elif version == 1:
            return HefHeaderV1()
        elif version == 2:
            return HefHeaderV2()
        elif version == 3:
            return HefHeaderV3()
        else:
            raise ValueError(f"Invalid HEF version: {version}")

    def update_header(self):
        """Recompute the header's sizes and hash from the current proto and CCWs."""
        proto_bytes = self._proto.SerializeToString()
        padding = b"\x00" * self.calc_alignment_padding_size(len(proto_bytes))
        self._header.update(proto_bytes, padding, self._ccws)
        # serialize() always appends _additional_info, so keep the header's size
        # field in step (a header freshly built by the version setter starts at 0).
        self._header.additional_info_size = len(self._additional_info)

    def _parse_hef(self):
        """Split ``self._hef_raw`` into header, proto, CCWs and additional info.

        Raises:
            ValueError: if the magic number is not HEF_HEADER_MAGIC or the
                version is unknown.
            AssertionError: (when assertions are enabled) if the trailing
                bytes do not match the header's additional_info_size.
        """
        base_header = self.base_header(self._hef_raw)
        if base_header.magic != HEF_HEADER_MAGIC:
            raise ValueError(f"Invalid HEF magic: {hex(base_header.magic)}[{hex(HEF_HEADER_MAGIC)}]!")
        self._header = self.hef_header_factory(base_header.version)
        self._header.parse(self._hef_raw)

        proto_start = self._header.size
        proto_end = proto_start + self._header.proto_size
        # notice that padding is included in ccws_size
        ccws_start = proto_end + self._header.padding_size
        ccws_end = proto_end + self._header.ccws_size

        self._proto = hef_pb2.ProtoHEFHef()
        self._proto.ParseFromString(self._hef_raw[proto_start:proto_end])
        self._ccws = self._hef_raw[ccws_start:ccws_end]
        self._additional_info = self._hef_raw[ccws_end:]
        assert len(self._additional_info) == self._header.additional_info_size, "Wrong sizes! failed to parse HEF"

    @property
    def hef_proto(self):
        return self._proto

    @property
    def hef_header(self):
        return self._header

    @property
    def ccws(self):
        return self._ccws

    @ccws.setter
    def ccws(self, ccws_data):
        """Replace the CCW blob and refresh the header."""
        self._ccws = ccws_data
        self.update_header()

    @property
    def padding(self):
        """Zero bytes placed between the proto and the CCWs."""
        return b"\x00" * self.calc_alignment_padding_size()

    @property
    def hef_version(self):
        """Version recorded inside the proto's header message."""
        return self.hef_proto.header.version

    @property
    def version(self):
        """Binary header format version."""
        return self._header.version

    @version.setter
    def version(self, new_version):
        """Switch to a fresh header of ``new_version`` and recompute it."""
        if self._header.version != new_version:
            self._header = self.hef_header_factory(new_version)
            self.update_header()

    @property
    def network_groups(self):
        return self.hef_proto.network_groups

    @property
    def external_resources(self):
        return self.hef_proto.external_resources

    def calc_alignment_padding_size(self, proto_size=None):
        """Return how many zero bytes must follow the proto so the CCWs start page-aligned.

        Headers older than version 3 use no padding and get 0.

        Args:
            proto_size: length of the serialized proto; computed from the
                current proto when omitted.
        """
        if self._header.version < 3:
            return 0
        if proto_size is None:
            proto_size = len(self._proto.SerializeToString())
        # NOTE(review): when header + proto is already page aligned this yields a
        # whole HEF_PAGE_SIZE rather than 0; kept unchanged for output
        # compatibility — confirm against the HEF reader before changing.
        return self.HEF_PAGE_SIZE - (self._header.size + proto_size) % self.HEF_PAGE_SIZE

    def serialize(self):
        """Return the full HEF bytes, refreshing the header first."""
        self.update_header()
        proto_bytes = self._proto.SerializeToString()
        padding = b"\x00" * self.calc_alignment_padding_size(len(proto_bytes))
        return b"".join((self._header.serialize(), proto_bytes, padding, self._ccws, self._additional_info))

    def get_parsed_config_message(self, data, cfg_channel_index, field_seperator, size=None, offset=None):
        """Format a CCW write action for the text dump.

        Args:
            data: payload bytes, printed as hex together with their length.
            cfg_channel_index: value printed verbatim as ``cfg_channel_index``.
            field_seperator: string inserted between fields.
            size: optional declared payload size (CCW-pointer actions).
            offset: optional payload offset inside the CCW blob.
        """
        message = (
            f"data: {data.hex()}{field_seperator}data_length: 0x{len(data):0x}"
            f"{field_seperator}cfg_channel_index: {cfg_channel_index}"
        )
        if offset is not None:
            message += f"{field_seperator}offset: 0x{offset:0x}"
        if size is not None:
            message += f"{field_seperator}size: 0x{size:0x}"
        return message

    def custom_message_formatter(self, message, indent, as_one_line):
        """``text_format`` message_formatter hook that prints write-data buffers as hex.

        Returns None to let the default protobuf formatter handle ``message``.
        """
        # in quiet mode don't print buffers
        # NOTE(review): returning None selects the *default* formatter, which
        # still emits bytes fields; confirm quiet mode suppresses what is intended.
        if self._quiet:
            return None

        field_separator = "\n" + " " * indent

        if isinstance(message, hef_pb2.ProtoHEFActionWriteDataCcw):
            return self.get_parsed_config_message(message.data, message.cfg_channel_index, field_separator)

        elif isinstance(message, hef_pb2.ProtoHEFActionWriteData):
            return (
                f"address: 0x{message.address:0x}{field_separator}data: {message.data.hex()}"
                f"{field_separator}data_length: 0x{len(message.data):0x}"
            )

        elif isinstance(message, hef_pb2.ProtoHEFActionWriteDataCcwPtr):
            # The payload lives in the CCW blob; the action only points into it.
            data = self._ccws[message.offset : message.offset + message.size]
            return self.get_parsed_config_message(
                data,
                message.cfg_channel_index,
                field_separator,
                size=message.size,
                offset=message.offset,
            )

        # In case there is no need for a custom formatter `None` should be returned and a default formatter would be used
        return None

    def _format_hef_proto(self, hef_proto_to_print):
        """Render a HEF proto as text with its ``mapping`` field blanked.

        Works on a copy so the caller's message (typically ``self._proto``,
        which ``serialize`` writes out) keeps its ``mapping`` data.
        """
        proto_copy = type(hef_proto_to_print)()
        proto_copy.CopyFrom(hef_proto_to_print)
        proto_copy.mapping = b""
        formatted_message = "\n----- HEF proto -----\n"
        formatted_message += text_format.MessageToString(
            proto_copy,
            message_formatter=self.custom_message_formatter,
        )
        return formatted_message

    def print_hef(self, pdb=None, quiet=True):
        """Print the header and a text dump of the proto, or open an ipdb session.

        Args:
            pdb: when truthy, start ipdb instead of printing.
            quiet: stored in ``self._quiet`` and consulted by
                ``custom_message_formatter``.
        """
        output = str(self.hef_header)
        self._quiet = quiet
        raw_message = self._proto
        formatted_message = self._format_hef_proto(raw_message)
        output += formatted_message
        if pdb:
            print("Opening ipdb")
            print("HEF is stored at the object 'raw_message'")
            __import__("ipdb").set_trace()
        else:
            print(output)

    def append(self, hef):
        """Append ``hef``'s network groups and CCWs to this HEF.

        Each copied group's CCW-pointer offsets are shifted by this HEF's
        current CCW length so they index into the concatenated blob. ``hef``
        itself is left unmodified.
        """
        assert self.hef_version == hef.hef_version, "HEF versions must be the same"
        ccws_offset = len(self.ccws)
        # Snapshot first so appending a HEF to itself does not iterate the
        # groups that are being added.
        source_groups = list(hef.network_groups)
        for network_group in source_groups:
            new_group = self.network_groups.add()
            new_group.CopyFrom(network_group)
            add_offset_to_ccws(new_group, ccws_offset)
        self.ccws += hef.ccws


def add_offset_to_ccws(network_group, offset):
    """Shift every CCW-pointer write in ``network_group`` by ``offset``, in place.

    Visits the operations of ``preliminary_config`` first, then those of each
    context in order. Actions without a ``write_data_ccw_ptr`` field are left
    untouched.
    """

    def _all_operations():
        yield from network_group.preliminary_config.operation
        for ctx in network_group.contexts:
            yield from ctx.operations

    for op in _all_operations():
        for act in op.actions:
            if not act.HasField("write_data_ccw_ptr"):
                continue
            act.write_data_ccw_ptr.offset += offset
