from hailo_sdk_client.exposed_definitions import NNFramework
from hailo_sdk_client.model_translator.edge_nn_translator import EdgeNNConverter
from hailo_sdk_client.model_translator.exceptions import (
    RecordableParserError,
    UnsupportedEWLayerError,
    UnsupportedModelError,
    UnsupportedOperationError,
    UnsupportedShuffleLayerError,
)
from hailo_sdk_client.model_translator.tflite_translator.tflite_graph import (
    ACTIVATION_OPS,
    ADD_N_OPS,
    ADD_OPS,
    ARGMAX_OPS,
    CONCAT_OPS,
    CONV2D_OPS,
    CUSTOM_SIGN_OPS,
    DENSE_OPS,
    DEPTH_TO_SPACE_OPS,
    DIV_OPS,
    EQUAL_OPS,
    L2_OPS,
    MAX_OPS,
    MIN_OPS,
    MUL_OPS,
    NORMALIZATION_OPS,
    OPTIONAL_NULL_OPS,
    PACK_OPS,
    PAD_OPS,
    POOL_OPS,
    POW_OPS,
    REDUCE_MAX_OPS,
    REDUCE_MIN_OPS,
    RESIZE_OPS,
    SHUFFLE_OPS,
    SKIP_OPS,
    SLICE_OPS,
    SOFTMAX_OPS,
    SPACE_TO_DEPTH_OPS,
    SPLIT_OPS,
    SUB_OPS,
    SUM_OPS,
    SUPPORTED_FUSED_ACTIVATIONS,
    SUPPORTED_OPS_UNION,
    TILE_OPS,
    TFLiteGraph,
)
from hailo_sdk_client.model_translator.tflite_translator.tflite_layer_creator import (
    create_activation_layer,
    create_layer_from_vertex,
)
from hailo_sdk_common.hailo_nn.exceptions import RecordableCreateLayerError
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType, SpaceToDepthType


class TFLiteConverter(EdgeNNConverter):
    """Translate a TFLite model graph into Hailo layers.

    Builds a ``TFLiteGraph`` from the model and relies on ``EdgeNNConverter`` for the generic
    traversal; this subclass supplies the TFLite-specific hooks (which vertices to skip, how to
    create a layer from a vertex, how fused activations are appended).
    """

    def __init__(self, model, values, start_node_names=None, end_node_names=None):
        """Create the converter.

        Args:
            model: TFLite flatbuffer model object (accessed via ``Subgraphs(0)``); also passed to
                ``TFLiteGraph``.
            values: Forwarded to ``TFLiteGraph`` (meaning not visible here).
            start_node_names: Optional start node names, forwarded to the base class.
            end_node_names: Optional end node names; when falsy they are derived from the first
                subgraph's outputs by ``_get_real_end_node_names``.
        """
        end_node_names = self._get_real_end_node_names(model, end_node_names)
        super().__init__(
            graph=TFLiteGraph(model, values),
            start_node_names=start_node_names,
            end_node_names=end_node_names,
        )

        # vertex -> fused activation op, filled by _layer_callback_from_vertex.
        self._fused_activations = {}
        self._nn_framework = NNFramework.TENSORFLOW_LITE

    @staticmethod
    def _get_real_end_node_names(tflite_model, end_node_names):
        """Return the end node names to use for the conversion.

        If ``end_node_names`` is truthy it is returned unchanged. Otherwise names are derived from
        the first subgraph: for every subgraph output tensor (in order), each operator listing that
        tensor among its outputs contributes the name of the operator's *first* output tensor
        (presumably how ``TFLiteGraph`` names vertices — verify).

        The ``*Length()`` / indexed flatbuffer accessors are used rather than ``*AsNumpy()``: the
        generated ``*AsNumpy()`` helpers return the int ``0`` when a vector is absent, which made
        the previous membership test raise ``TypeError``. Producers are indexed in a single pass
        instead of rescanning every operator for each graph output.
        """
        if end_node_names:
            return end_node_names

        subgraph = tflite_model.Subgraphs(0)

        # tensor index -> indices of operators that list it as an output (operator order kept).
        # A set per operator mirrors the old membership test: one hit per operator per tensor.
        producers = {}
        for op_idx in range(subgraph.OperatorsLength()):
            node = subgraph.Operators(op_idx)
            for tensor_idx in {node.Outputs(k) for k in range(node.OutputsLength())}:
                producers.setdefault(tensor_idx, []).append(op_idx)

        real_end_node_names = []
        for out_pos in range(subgraph.OutputsLength()):
            for op_idx in producers.get(subgraph.Outputs(out_pos), ()):
                node = subgraph.Operators(op_idx)
                # Resolved lazily so unnamed tensors of unrelated operators cannot fail here.
                real_end_node_names.append(subgraph.Tensors(node.Outputs(0)).Name().decode())
        return real_end_node_names

    def _add_output_layers(self, fused_activations=None):
        """Add output layers, always using this converter's recorded fused activations.

        The ``fused_activations`` argument is ignored; ``self._fused_activations`` is forwarded
        instead. NOTE(review): confirm no caller relies on passing a different mapping.
        """
        super()._add_output_layers(fused_activations=self._fused_activations)

    def _handle_fused_layers(self):
        """Append activation layers for the fused activations recorded during conversion."""
        self._add_fused_activations()

    def _should_skip_vertex(self, vertex):
        """Return True if no layer should be created directly from ``vertex``.

        Skipped vertices are:
            * ops in ``SKIP_OPS``;
            * flatten RESHAPEs, which are marked consumed without being assigned to a layer
              (raises ``UnsupportedModelError`` if such a vertex was chosen as an end node);
            * GREATER/EXP/LOGISTIC/MUL vertices matching the threshold / mish / silu / swish
              (first MUL) patterns — presumably the leading nodes of activation patterns that
              ``_layer_callback_from_vertex`` builds from a later MUL vertex;
            * UNPACK that is not a null operation, SQUEEZE that is neither a flatten nor a
              rank-4-to-rank-3 reshape, and PACK that is not a pack-resize.
        """
        if vertex.op in SKIP_OPS:
            return True
        if vertex.op == "RESHAPE" and vertex.is_flatten_reshape():
            if self._end_node_names and vertex.name in self._end_node_names:
                raise UnsupportedModelError(
                    f"Reshape node {vertex.name} (flatten operator) was selected as an end "
                    f"node, but such operations aren't supported at graph end. Please retry "
                    f"with a different end node.",
                )
            self._update_consumed_vertices_states([vertex], should_assign_vertex_to_layer=False)
            return True

        # Exactly one op can match, so at most one pattern predicate is evaluated — the same
        # short-circuit behavior as the former combined and/or expression.
        op = vertex.op
        if op == "GREATER":
            return bool(vertex.is_threshold_activation())
        if op == "EXP":
            return bool(vertex.is_mish_activation())
        if op == "LOGISTIC":
            return bool(vertex.is_silu_activation())
        if op == "MUL":
            return bool(vertex.is_swish_activation_first_mul())
        if op == "UNPACK":
            return not vertex.is_null_operation()
        if op == "SQUEEZE":
            return not (vertex.is_flatten_reshape() or vertex.is_rank4_to_rank3_reshape())
        if op == "PACK":
            return not vertex.is_pack_resize()
        return False

    def _add_fused_activations(self):
        """Push a fused activation layer after every layer whose vertex has a recorded activation.

        Iterates ``self._vertices_to_layers`` in its own order; for each vertex present in
        ``self._fused_activations`` an activation layer is created and pushed with the vertex's
        layer as predecessor.
        """
        for vertex, layer in self._vertices_to_layers.items():
            if vertex in self._fused_activations:
                act, _, _ = create_activation_layer(
                    vertex,
                    is_fused_activation=True,
                    op=self._fused_activations[vertex],
                )
                self._layers_graph.push_layer(act, preds=[layer])

    @staticmethod
    def _add_original_names(layer, vertices):
        """Record the name of every vertex in ``vertices`` as an original name of ``layer``."""
        for vertex in vertices:
            layer.add_original_name(vertex.name)

    def _layer_callback_from_vertex(self, vertex):
        """Create the layer for ``vertex`` and register it in the layers graph.

        The op type (and, for several ops, pattern predicates on the vertex) selects the
        ``LayerType`` passed to ``create_layer_from_vertex``. The ``elif`` order is significant:
        the first matching branch wins (e.g. MUL vertices are tested for activation patterns
        before ew_mult and square).

        Raises (unless the type derives from ``RecordableParserError`` — hierarchy not visible
        here):
            UnsupportedOperationError: ``vertex.op`` is not in ``SUPPORTED_OPS_UNION``.
            UnsupportedEWLayerError, UnsupportedShuffleLayerError, UnsupportedModelError: no
                layer type could be determined for the vertex.

        ``RecordableParserError`` / ``RecordableCreateLayerError`` are handed to
        ``self._handle_recordable_parser_error`` and the method returns without adding a layer.
        """
        # Only the base_ew_add and concat creators return a constant input layer.
        const_input_layer = None
        try:
            if vertex.op not in SUPPORTED_OPS_UNION:
                raise UnsupportedOperationError(f"{vertex.op} operation is unsupported")
            if vertex.op in CONV2D_OPS:
                layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.conv, vertex)
            elif (
                vertex.op in ACTIVATION_OPS
                or (vertex.op in CUSTOM_SIGN_OPS and vertex.is_biased_delta_activation())
                or (
                    vertex.op in MUL_OPS
                    and (
                        vertex.is_threshold_activation()
                        or vertex.is_mish_activation()
                        or vertex.is_silu_activation()
                        or vertex.is_swish_activation_second_mul()
                    )
                )
                or (vertex.op in DIV_OPS and vertex.is_inv_pos_activation())
                or (vertex.op in MIN_OPS + MAX_OPS and not vertex.is_ew_max() and not vertex.is_ew_min())
            ):
                layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.base_activation, vertex)
            elif vertex.op in DENSE_OPS:
                layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.dense, vertex)
            elif vertex.op in L2_OPS:
                layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.l2_normalization, vertex)
            elif vertex.op in POOL_OPS or vertex.is_global_max_pool():
                layer, consumed_vertices, activation = create_layer_from_vertex("pool", vertex)
            elif vertex.op in SOFTMAX_OPS:
                layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.softmax, vertex)
            elif vertex.op in ARGMAX_OPS:
                layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.argmax, vertex)
            elif vertex.op in ADD_OPS:
                if vertex.is_ew_add():
                    layer, consumed_vertices, activation, const_input_layer = create_layer_from_vertex(
                        LayerType.base_ew_add,
                        vertex,
                    )
                elif vertex.is_mul_by_2_ew_add() or vertex.is_normalization():
                    layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.normalization, vertex)
                else:
                    raise UnsupportedEWLayerError(
                        f"Failed to determine type of layer to create in node {vertex.name} ({vertex.op})",
                    )
            elif vertex.op in ADD_N_OPS:
                layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.base_ew_add_n, vertex)
            elif vertex.op in SUB_OPS and vertex.is_ew_sub():
                layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.base_ew_sub, vertex)
            elif vertex.op in MUL_OPS and vertex.is_ew_mult():
                layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.ew_mult, vertex)
            elif vertex.op in DIV_OPS and vertex.is_ew_div():
                layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.ew_div, vertex)
            elif vertex.op in MAX_OPS and vertex.is_ew_max():
                layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.ew_max, vertex)
            elif vertex.op in MIN_OPS and vertex.is_ew_min():
                layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.ew_min, vertex)
            elif vertex.op in MUL_OPS and vertex.is_square():
                if vertex.is_reduce_l2():
                    layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.reduce_l2, vertex)
                else:
                    layer, consumed_vertices, activation = create_layer_from_vertex(
                        LayerType.feature_multiplier,
                        vertex,
                    )
            elif vertex.op in POW_OPS:
                if vertex.is_null_operation():
                    # power of 1
                    layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.null, vertex)
                else:
                    vertex.verify_pow_is_square()
                    layer, consumed_vertices, activation = create_layer_from_vertex(
                        LayerType.feature_multiplier,
                        vertex,
                    )
            elif vertex.op in NORMALIZATION_OPS and vertex.is_normalization():
                layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.normalization, vertex)
            elif vertex.op in PAD_OPS or (vertex.op in CONCAT_OPS and vertex.is_external_pad_concat()):
                # Must precede the generic CONCAT branch below.
                layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.external_pad, vertex)
            elif vertex.op in CONCAT_OPS:
                layer, consumed_vertices, activation, const_input_layer = create_layer_from_vertex(
                    LayerType.concat,
                    vertex,
                )
            elif vertex.op in SPLIT_OPS:
                layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.feature_splitter, vertex)
            elif vertex.op in RESIZE_OPS or (vertex.op in PACK_OPS and vertex.is_pack_resize()):
                layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.resize, vertex)
            elif vertex.op in REDUCE_MAX_OPS:
                layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.reduce_max, vertex)
            elif vertex.op in REDUCE_MIN_OPS:
                layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.reduce_min, vertex)
            elif vertex.op in SUM_OPS:
                layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.reduce_sum, vertex)
            elif vertex.op in OPTIONAL_NULL_OPS and vertex.is_null_operation():
                layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.null, vertex)
            elif vertex.op in DEPTH_TO_SPACE_OPS or vertex.is_depth_to_space_reshape_block()[0]:
                layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.depth_to_space, vertex)
            elif vertex.op in SPACE_TO_DEPTH_OPS:
                layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.space_to_depth, vertex)
            elif vertex.op in SHUFFLE_OPS:
                if vertex.is_shuffle():
                    layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.feature_shuffle, vertex)
                elif (
                    vertex.is_features_reshape()
                    or vertex.is_flat_to_frames_reshape()
                    or vertex.is_supported_transpose()
                ):
                    layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.format_conversion, vertex)
                else:
                    raise UnsupportedShuffleLayerError(
                        f"Failed to determine type of layer to create in node {vertex.name} ({vertex.op})",
                    )
            elif vertex.op in SLICE_OPS:
                if vertex.is_space_to_depth_slice_block():
                    layer, consumed_vertices, activation = create_layer_from_vertex(
                        LayerType.space_to_depth,
                        vertex,
                        space_to_depth_type=SpaceToDepthType.focus,
                    )
                else:
                    layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.base_slice, vertex)
            elif vertex.op in EQUAL_OPS:
                layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.equal, vertex)
            elif vertex.op in TILE_OPS:
                layer, consumed_vertices, activation = create_layer_from_vertex(LayerType.tile, vertex)
            else:
                raise UnsupportedModelError(f"Unexpected node {vertex.name} ({vertex.op})")

        except (RecordableParserError, RecordableCreateLayerError) as e:
            self._handle_recordable_parser_error(vertex, e)
            return

        if const_input_layer is not None:
            # The constant input is added without an incoming edge, then wired into the layer and
            # registered as a named input (named after its first original name).
            self._add_layer(const_input_layer, has_edge=False)
            self._add_layer(layer)
            self._add_original_names(layer, consumed_vertices)
            self._layers_graph.add_edge(const_input_layer, layer)
            layer.add_input_by_vertex(const_input_layer, input_name=const_input_layer.original_names[0])
        else:
            self._add_layer(layer)
            self._add_original_names(layer, consumed_vertices)

        consumed_vertices.append(vertex)
        self._handle_consumed_vertices(consumed_vertices)

        # Remembered for _add_fused_activations. NOTE(review): activations outside
        # SUPPORTED_FUSED_ACTIVATIONS are dropped silently here — confirm that is intended.
        if activation and activation in SUPPORTED_FUSED_ACTIVATIONS:
            self._fused_activations.update({vertex: activation})

    def _consume_flatten_chain(self, pred, layer):
        """Consume ``pred`` into ``layer`` via ``_consume_pre_layer_op`` when it is a RESHAPE."""
        if pred.op == "RESHAPE":
            self._consume_pre_layer_op(pred, layer)

    def _prevent_transpose_hw_suggestion(self):
        """No-op override for TFLite (not implemented yet, see TODO)."""
        # TODO: SDK-45504
        pass
