from typing import OrderedDict

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer import Layer
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class ConstInputLayer(Layer):
    """Layer that emits a stored constant tensor, tiled up to the layer's output shape.

    The constant is kept in ``const_values``. ``input_shapes`` describes the constant itself, and
    ``input_tiles`` holds, per dimension, how many times the constant is repeated to form the
    output (see ``_calc_output_shape``).
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False
    _IS_RANK3_SUPPORTED = True

    def __init__(self):
        super().__init__()
        self._op = LayerType.const_input
        # Constant payload; populated by ``create`` / ``move_params``.
        self._const_values = None
        # List of [height, width, features] repeat counts (see ``to_pb``); identity tiling by default.
        self._input_tiles = [[1, 1, 1]]

    @classmethod
    def create(cls, original_name, output_shapes, values):
        """Build a const-input layer holding ``values``, tiled up to ``output_shapes[0]``.

        Args:
            original_name: Name of the source-model vertex; also used in the error message.
            output_shapes: Output shapes of the layer. Only the first shape is used: its leading
                dimension is copied into the input shape, and the remaining dimensions are the
                tiling targets for ``values``.
            values: The constant tensor (anything with a ``shape`` attribute — presumably a numpy
                array, TODO confirm). Stored as ``const_values``.

        Returns:
            The new layer, with ``input_shapes`` and ``input_tiles`` derived from ``values.shape``.

        Raises:
            UnsupportedModelError: If an output dimension is not a whole multiple of the matching
                constant dimension. Zero-sized constant dimensions are rejected the same way.
        """
        layer = super().create(original_name, [], output_shapes)
        layer.const_values = values
        layer.input_shapes = [[output_shapes[0][0], *values.shape]]
        target_dims = output_shapes[0][1:]
        # Integer arithmetic instead of ``(a / b).is_integer()``: avoids a float round-trip, and a
        # zero-sized constant dimension becomes a model error rather than a ZeroDivisionError.
        # NOTE(review): zip() silently truncates if values.shape and the output rank differ.
        if any(
            value_dim == 0 or output_dim % value_dim != 0
            for value_dim, output_dim in zip(values.shape, target_dims)
        ):
            msg = (
                f"In vertex {original_name} the constant value shape {values.shape} must be broadcastable to the "
                f"output shape {output_shapes[0][1:]}"
            )
            raise UnsupportedModelError(msg)
        # Tile ratio per dimension: how many copies of the constant fill the output along that axis.
        layer.input_tiles = [
            [output_dim // value_dim for output_dim, value_dim in zip(target_dims, values.shape)],
        ]
        return layer

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize to a ``PROTO_NETWORK_CONST_INPUT`` protobuf node, including the tiling."""
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_CONST_INPUT
        # NOTE(review): width and features are both taken from input_shape[1]; confirm the index is
        # intended (it reads like a possible copy-paste slip).
        node.kernel_shape.width, node.kernel_shape.features = self.input_shape[1], self.input_shape[1]
        for tile in self._input_tiles:
            pb_tile = node.input_tiles.add()
            pb_tile.height, pb_tile.width, pb_tile.features = tile
        return node

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Deserialize from a protobuf node, restoring ``input_tiles``.

        A node carrying no tiles falls back to the identity tiling ``[[1, 1, 1]]`` (same default as
        ``from_hn``). Otherwise an empty list would be stored and ``_calc_output_shape`` would fail
        on ``input_tiles[0]``.
        """
        layer = super().from_pb(pb, pb_wrapper)
        tiles = [[pb_tile.height, pb_tile.width, pb_tile.features] for pb_tile in pb.input_tiles]
        layer.input_tiles = tiles or [[1, 1, 1]]
        return layer

    # The equalization, params-sorter and dead-channel-removal passes all treat this layer as
    # transparent and not as a source.
    def get_equalization_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.transparent, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.transparent, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.transparent, is_source=False)

    def ibc_supported(self):
        return LayerSupportStatus.unsupported

    @property
    def requires_quantized_weights(self):
        # NOTE(review): returns True although the class attribute _REQUIRES_QUANTIZED_WEIGHTS is
        # False; confirm which one is authoritative.
        return True

    @property
    def const_values(self):
        """The constant tensor emitted by this layer (``None`` until set)."""
        return self._const_values

    @const_values.setter
    def const_values(self, const_values):
        self._const_values = const_values

    @property
    def input_tiles(self):
        """List of [height, width, features] repeat counts applied to the constant."""
        return self._input_tiles

    @input_tiles.setter
    def input_tiles(self, input_tiles):
        self._input_tiles = input_tiles

    def move_params(self, layer):
        """Take over params from ``layer``; the constant comes from a const-input or equal layer."""
        super().move_params(layer)
        if layer.op == LayerType.const_input:
            self._const_values = layer.const_values
        if layer.op == LayerType.equal:
            # An equal layer carries its constant operand as ``constant_input``.
            self._const_values = layer.constant_input

    def _calc_output_shape(self):
        """Return ``[-1, *tiled_dims]``: each constant dim times its count from the first tile entry."""
        return [-1, *[dim * ratio for dim, ratio in zip(self._input_shapes[0][1:], self._input_tiles[0])]]

    @classmethod
    def from_layer(cls, old_layer):
        """Create from another layer, also copying its tiling when it is a const-input layer."""
        layer = super().from_layer(old_layer)
        if old_layer.op == LayerType.const_input:
            layer.input_tiles = old_layer.input_tiles
        return layer

    def to_hn(self, should_get_default_params=False):
        """Serialize to an HN dict with ``params.input_tiles``."""
        result = super().to_hn(should_get_default_params)
        # NOTE(review): this replaces any "params" entry produced by the base class; confirm none is expected.
        result["params"] = OrderedDict()
        result["params"]["input_tiles"] = self._input_tiles
        return result

    @classmethod
    def from_hn(cls, hn):
        """Deserialize from an HN dict; missing ``params.input_tiles`` means identity tiling."""
        layer = super().from_hn(hn)
        layer.input_tiles = (
            hn["params"]["input_tiles"] if "params" in hn and "input_tiles" in hn["params"] else [[1] * 3]
        )
        return layer

    def __str__(self):
        return super().__str__() + f", {self.input_tiles=}"
