import copy
from abc import ABC

from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    FormatConversionType,
    IOType,
    LayerHandlerType,
    LayerSupportStatus,
    PostprocessTarget,
)
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer import Layer
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class IOLayer(Layer, ABC):
    """Abstract base for layers sitting on a network boundary.

    Adds three pieces of state on top of ``Layer``:

    * ``io_type`` -- an ``IOType``; ``IOType.STANDARD`` unless overridden.
    * ``cache_id`` -- ``-1`` unless overridden. It is always written to protobuf, but
      written to HN only when ``io_type`` is ``IOType.CACHE``.
    * ``is_real_io`` -- ``True`` unless overridden. It is not serialized by any of the
      to/from methods below.
    """

    def __init__(self):
        super().__init__()
        self._io_type = IOType.STANDARD
        self._cache_id = -1
        self._is_real_io = True

    @property
    def io_type(self):
        """The ``IOType`` of this boundary layer."""
        return self._io_type

    @io_type.setter
    def io_type(self, io_type):
        self._io_type = io_type

    @property
    def cache_id(self):
        """Cache identifier; only meaningful in HN when ``io_type`` is ``IOType.CACHE``."""
        return self._cache_id

    @cache_id.setter
    def cache_id(self, cache_id):
        self._cache_id = cache_id

    @property
    def is_real_io(self):
        """Flag stored on the layer; not part of the HN/protobuf serialization here."""
        return self._is_real_io

    @is_real_io.setter
    def is_real_io(self, value):
        self._is_real_io = value

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize via the base class, then add the IO-specific protobuf fields."""
        pb_node = super().to_pb(pb_wrapper, is_multi_scope)
        pb_node.transposed = self.transposed
        # The wrapper owns the HN -> protobuf enum mapping for IO types.
        pb_node.io_type = pb_wrapper.IO_TYPE_HN_TO_PB_TYPE[self.io_type]
        pb_node.cache_id = self.cache_id
        return pb_node

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Rebuild a layer from protobuf, restoring ``transposed``, ``io_type`` and ``cache_id``."""
        restored = super().from_pb(pb, pb_wrapper)
        restored.transposed = pb.transposed
        restored.io_type = pb_wrapper.IO_TYPE_PB_TO_HN_TYPE[pb.io_type]
        restored.cache_id = pb.cache_id
        return restored

    @classmethod
    def from_layer(cls, old_layer):
        """Create a layer from ``old_layer``, carrying over ``transposed``, ``io_type`` and ``cache_id``."""
        clone = super().from_layer(old_layer)
        clone.transposed = old_layer.transposed
        clone.io_type = old_layer.io_type
        clone.cache_id = old_layer.cache_id
        return clone

    def to_hn(self, should_get_default_params=False):
        """Return the HN dict of the base layer, extended with the IO fields.

        ``cache_id`` is emitted only for ``IOType.CACHE`` layers.
        """
        # Deep copy so the additions below never touch whatever the base class returned.
        hn_dict = copy.deepcopy(super().to_hn(should_get_default_params))
        hn_dict.update(
            transposed=self.transposed,
            engine=self.engine.value,
            io_type=self.io_type.value,
        )
        if self.io_type == IOType.CACHE:
            hn_dict["cache_id"] = self.cache_id
        return hn_dict

    @classmethod
    def from_hn(cls, hn):
        """Build a layer from an HN dict; every IO field is optional and applied only if present."""
        restored = super().from_hn(hn)
        if "transposed" in hn:
            restored.transposed = hn["transposed"]
        if "engine" in hn:
            restored.engine = PostprocessTarget(hn["engine"])
        if "io_type" in hn:
            restored.io_type = IOType(hn["io_type"])
        if "cache_id" in hn:
            restored.cache_id = hn["cache_id"]
        return restored


class InputLayer(IOLayer):
    """Network input layer (``LayerType.input_layer``).

    Serialized to protobuf as ``PROTO_NETWORK_INPUT``. It may carry an optional
    ``conversion_type`` (a ``FormatConversionType``) together with an
    ``emulate_conversion`` flag; both are written to HN only when a conversion is set.
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False

    def __init__(self):
        super().__init__()
        self._op = LayerType.input_layer
        # needed in order to support cases where an input layer has multiple outputs
        self._number_of_inputs_supported = None
        self._transposed = False
        self._conversion_type = None
        self._emulate_conversion = False

    @classmethod
    def create(cls, original_name, output_shapes):
        """Build an input layer from the shapes of its outputs.

        Each shape is normalized: its first dimension is set to ``-1``, and rank-3
        shapes get a ``1`` inserted at index 1. The normalization works on copies,
        so the caller's shape lists are left untouched.

        Args:
            original_name: name of the layer in the original framework model.
            output_shapes: list of shapes (index-assignable sequences), one per output.

        Returns:
            The new layer, whose ``input_shapes`` is ``[first normalized output shape]``.

        Raises:
            UnsupportedModelError: if the normalized output shapes are not all equal.
        """
        layer = cls()
        layer.add_original_name(original_name)
        # Copy each shape so the in-place normalization below cannot leak into the caller's data.
        layer.output_shapes = [list(shape) for shape in output_shapes]
        for shape in layer.output_shapes:
            shape[0] = -1
            if len(shape) == 3:
                shape.insert(1, 1)
        normalized = layer.output_shapes
        # Compare the normalized shapes: an input with several outputs must feed the same tensor.
        if len(normalized) > 1 and any(shape != normalized[0] for shape in normalized[1:]):
            raise UnsupportedModelError(f"{layer.full_name_msg} has different output shapes.")
        layer.input_shapes = [layer.output_shapes[0]]
        return layer

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize via the base class and tag the node as a network input."""
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_INPUT
        return node

    def get_equalization_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def ibc_supported(self):
        return LayerSupportStatus.unsupported

    @property
    def conversion_type(self):
        """Optional ``FormatConversionType`` attached to this input (``None`` if unset)."""
        return self._conversion_type

    @conversion_type.setter
    def conversion_type(self, conversion_type):
        self._conversion_type = conversion_type

    @property
    def emulate_conversion(self):
        """Flag serialized alongside ``conversion_type``; ``False`` by default."""
        return self._emulate_conversion

    @emulate_conversion.setter
    def emulate_conversion(self, emulate_conversion):
        self._emulate_conversion = emulate_conversion

    def to_hn(self, should_get_default_params=False):
        """Return the base HN dict, plus the conversion fields when a conversion is set."""
        result = super().to_hn(should_get_default_params)
        if self.conversion_type:
            result["conversion_type"] = self.conversion_type.value
            result["emulate_conversion"] = self.emulate_conversion
        return result

    @classmethod
    def from_hn(cls, hn):
        """Build an input layer from an HN dict.

        ``emulate_conversion`` defaults to ``False`` when ``conversion_type`` is present
        but the flag is missing, instead of raising ``KeyError``.
        """
        # ``transposed`` is already restored by IOLayer.from_hn.
        layer = super().from_hn(hn)
        if "conversion_type" in hn:
            layer.conversion_type = FormatConversionType(hn["conversion_type"])
            layer.emulate_conversion = hn.get("emulate_conversion", False)
        return layer

    def _conversion_suffix(self):
        """Return the `` +<conversion>`` description suffix, or ``""`` when no conversion is set."""
        return f" +{self.conversion_type.value}" if self.conversion_type else ""

    def __str__(self):
        return super().__str__() + self._conversion_suffix()

    @property
    def short_description(self):
        return super().short_description + self._conversion_suffix()

    def is_zippable(self, other):
        """Input layers can be zipped only when their ``output_features`` match and the base check passes."""
        if self.output_features != other.output_features:
            return False

        return super().is_zippable(other)


class OutputLayer(IOLayer):
    """Network output layer (``LayerType.output_layer``).

    Serialized to protobuf as ``PROTO_NETWORK_OUTPUT``. It is classified as an
    ``output`` handler for equalization, parameter sorting and dead-channel removal,
    and it reports both IBC and finetune as unsupported. The ``output_scale_per_channel``
    flag is written to HN only when truthy.
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False

    def __init__(self):
        super().__init__()
        self._op = LayerType.output_layer
        self._transposed = False
        self._output_scale_per_channel = False

    @classmethod
    def create(cls, original_name, input_vertex_order, output_shapes=None):
        """Build an output layer.

        Args:
            original_name: name of the layer in the original framework model.
            input_vertex_order: value stored as the layer's ``input_vertex_order``.
            output_shapes: optional list of shapes. When falsy, the layer keeps the
                ``output_shapes`` it was constructed with.

        Rank-3 shapes are expanded by inserting a ``1`` at index 1. Caller-supplied
        shapes are copied first, so the caller's lists are never modified.
        """
        layer = cls()
        layer.add_original_name(original_name)
        layer.input_vertex_order = input_vertex_order
        if output_shapes:
            # Copy so the rank normalization below does not mutate the caller's data.
            layer.output_shapes = [list(shape) for shape in output_shapes]
        for shape in layer.output_shapes:
            if len(shape) == 3:
                shape.insert(1, 1)
        return layer

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize via the base class and tag the node as a network output."""
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_OUTPUT
        return node

    @classmethod
    def from_data(cls, name, output_shape):
        """Build a layer named ``name`` with ``output_shape`` assigned to both ``input_shape`` and ``output_shapes``."""
        layer = cls()
        layer.add_original_name(name)
        layer.input_shape = output_shape
        layer.output_shapes = output_shape
        return layer

    def to_hn(self, should_get_default_params=False):
        """Return the base HN dict, adding ``output_scale_per_channel`` only when it is set."""
        result = super().to_hn(should_get_default_params)
        if self.output_scale_per_channel:
            result["output_scale_per_channel"] = self.output_scale_per_channel
        return result

    @classmethod
    def from_hn(cls, hn):
        """Build an output layer from an HN dict; ``output_scale_per_channel`` is optional."""
        layer = super().from_hn(hn)
        if "output_scale_per_channel" in hn:
            layer.output_scale_per_channel = hn["output_scale_per_channel"]
        return layer

    def get_equalization_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.output, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.output, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.output, is_source=False)

    def ibc_supported(self):
        return LayerSupportStatus.unsupported

    @property
    def finetune_supported(self):
        return False

    @property
    def output_scale_per_channel(self):
        """Flag stored on the layer; ``False`` by default."""
        return self._output_scale_per_channel

    @output_scale_per_channel.setter
    def output_scale_per_channel(self, output_scale_per_channel):
        self._output_scale_per_channel = output_scale_per_channel


class ExternalInputLayer(IOLayer):
    """Boundary layer of type ``LayerType.external_input_layer``.

    Serialized to protobuf as ``PROTO_NETWORK_EXTERNAL_INPUT``. Every handler
    query classifies it as ``unexpected``, and IBC is reported unsupported.
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False

    def __init__(self):
        super().__init__()
        self._op = LayerType.external_input_layer
        self._transposed = False

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize via the base class and tag the node as an external network input."""
        pb_node = super().to_pb(pb_wrapper, is_multi_scope)
        pb_node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_EXTERNAL_INPUT
        return pb_node

    @staticmethod
    def _unexpected_io_classification():
        """Fresh classification marking this layer as an unexpected, non-source node."""
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def get_equalization_handler_type(self, predecessor=None):
        return self._unexpected_io_classification()

    def get_params_sorter_handler_type(self, predecessor=None):
        return self._unexpected_io_classification()

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        return self._unexpected_io_classification()

    def ibc_supported(self):
        return LayerSupportStatus.unsupported


class ExternalOutputLayer(IOLayer):
    """Boundary layer of type ``LayerType.external_output_layer``.

    Serialized to protobuf as ``PROTO_NETWORK_EXTERNAL_OUTPUT``. Every handler
    query classifies it as ``output``, and both IBC and finetune are reported unsupported.
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False

    def __init__(self):
        super().__init__()
        self._op = LayerType.external_output_layer
        self._transposed = False

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize via the base class and tag the node as an external network output."""
        pb_node = super().to_pb(pb_wrapper, is_multi_scope)
        pb_node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_EXTERNAL_OUTPUT
        return pb_node

    @classmethod
    def from_data(cls, name, output_shape):
        """Build a layer named ``name`` with ``output_shape`` assigned to both ``input_shape`` and ``output_shapes``."""
        instance = cls()
        instance.add_original_name(name)
        instance.input_shape = output_shape
        instance.output_shapes = output_shape
        return instance

    @staticmethod
    def _output_io_classification():
        """Fresh classification marking this layer as a non-source ``output`` node."""
        return EquivClassification(LayerHandlerType.output, is_source=False)

    def get_equalization_handler_type(self, predecessor=None):
        return self._output_io_classification()

    def get_params_sorter_handler_type(self, predecessor=None):
        return self._output_io_classification()

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        return self._output_io_classification()

    def ibc_supported(self):
        return LayerSupportStatus.unsupported

    @property
    def finetune_supported(self):
        return False


class PPInputLayer(IOLayer):
    """Boundary layer of type ``LayerType.pp_input_layer``.

    Serialized to protobuf as ``PROTO_NETWORK_PP_INPUT``. Every handler query
    classifies it as ``unexpected``. IBC support is reported as
    ``LayerSupportStatus.unexpected``, not ``unsupported``.
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False

    def __init__(self):
        super().__init__()
        self._op = LayerType.pp_input_layer
        self._transposed = False

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize via the base class and tag the node as a PP network input."""
        pb_node = super().to_pb(pb_wrapper, is_multi_scope)
        pb_node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_PP_INPUT
        return pb_node

    @staticmethod
    def _unexpected_io_classification():
        """Fresh classification marking this layer as an unexpected, non-source node."""
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def get_equalization_handler_type(self, predecessor=None):
        return self._unexpected_io_classification()

    def get_params_sorter_handler_type(self, predecessor=None):
        return self._unexpected_io_classification()

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        return self._unexpected_io_classification()

    def ibc_supported(self):
        return LayerSupportStatus.unexpected


class PPOutputLayer(IOLayer):
    """Boundary layer of type ``LayerType.pp_output_layer``.

    Serialized to protobuf as ``PROTO_NETWORK_PP_OUTPUT``. Every handler query
    classifies it as ``unexpected``. IBC support is reported as
    ``LayerSupportStatus.unexpected``, and finetune is not supported.
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False

    def __init__(self):
        super().__init__()
        self._op = LayerType.pp_output_layer
        self._transposed = False

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize via the base class and tag the node as a PP network output."""
        pb_node = super().to_pb(pb_wrapper, is_multi_scope)
        pb_node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_PP_OUTPUT
        return pb_node

    @classmethod
    def from_data(cls, name, output_shape):
        """Build a layer named ``name`` with ``output_shape`` assigned to both ``input_shape`` and ``output_shapes``."""
        instance = cls()
        instance.add_original_name(name)
        instance.input_shape = output_shape
        instance.output_shapes = output_shape
        return instance

    @staticmethod
    def _unexpected_io_classification():
        """Fresh classification marking this layer as an unexpected, non-source node."""
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def get_equalization_handler_type(self, predecessor=None):
        return self._unexpected_io_classification()

    def get_params_sorter_handler_type(self, predecessor=None):
        return self._unexpected_io_classification()

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        return self._unexpected_io_classification()

    def ibc_supported(self):
        return LayerSupportStatus.unexpected

    @property
    def finetune_supported(self):
        return False
