from hailo_model_optimization.acceleras.utils.acceleras_definitions import FormatConversionType
from hailo_sdk_client.post_fuser.algorithms.fuser_algorithm import FuserAlgorithm
from hailo_sdk_client.tools.fuser.fuser_helper import FuserHelper
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationType, LayerType
from hailo_sdk_common.hailo_nn.hn_layers import FusedConv2DLayer


class SlicedConvSplitter(FuserAlgorithm):
    """Split a conv whose output is consumed feature-wise into one conv per consumer.

    Two graph patterns are rewritten:

    1. ``conv -> {slice, slice, ...}``: every successor of a ``groups == 1`` conv is a slice
       that only selects a feature range (step 1 on every axis, full height and width), and
       the feature ranges do not overlap. Each slice is replaced by a new conv built from the
       original conv. The original conv and the slices are removed.
    2. ``conv [-> spatial_reshape] -> feature_splitter``: the splitter either feeds a matmul
       (see ``_is_matmul_ancestor``) or all but at most two of its successors are shortcut
       layers. The conv and the splitter are replaced by one conv per splitter output. An
       intermediate spatial_reshape is reconnected in front of the new convs. Shortcut layers
       fed directly by one of these new convs are removed at the end of the pass.
    """

    NAME = "sliced_conv_splitting"

    def __init__(self, model, params, model_config, hw_arch, **kwargs):
        super().__init__(model, params, model_config, hw_arch)
        self._fuser_helper = FuserHelper(self.model)
        # Shortcut layers left directly after a conv created by _handle_conv_and_feature_splitter;
        # _remove_shortcuts drops them once both rewrites have run.
        self._shortcuts_to_remove = []

    def get_algo_config(self):
        return self._model_config

    def _setup(self):
        pass

    def should_skip_algo(self):
        return False

    def log_config(self):
        pass

    def _run_int(self):
        """Apply the slice rewrite, then the feature-splitter rewrite, then remove collected shortcuts."""
        self._handle_conv_and_slices()
        self._handle_conv_and_feature_splitter()
        self._remove_shortcuts()

    def _handle_conv_and_slices(self):
        """Replace each ``conv -> {feature slices}`` pattern with one conv per slice.

        Layer removal and ``relax_new_layer_into_graph`` are deferred until the scan over the
        snapshotted layer list has finished.
        """
        layers_to_remove = []
        new_layers = []
        successors_meta_data = {}
        for layer in list(self._model):
            if layer.op != LayerType.conv or layer.groups != 1:
                continue

            # Every successor must be a plain feature slice, and there must be at least two of them.
            conv_succs = list(self._model.successors(layer))
            if len(conv_succs) < 2 or not all(self._is_plain_feature_slice(succ) for succ in conv_succs):
                continue

            if self._features_slices_overlap(conv_succs):
                continue

            # Same single-predecessor guard as _handle_conv_and_feature_splitter. With zero
            # predecessors the conv would be removed without any replacement (leaving the slices
            # dangling). With several, each slice would be rewired once per predecessor.
            conv_preds = list(self._model.predecessors(layer))
            if len(conv_preds) != 1:
                continue
            pred = conv_preds[0]

            layers_to_remove.append(layer)
            src_params = self.params[layer.name]
            # Keep the original successor order so new-layer indices are assigned as before.
            for succ in conv_succs:
                self._add_sliced_conv(
                    layer,
                    layers_to_remove,
                    new_layers,
                    pred,
                    succ,
                    successors_meta_data,
                    src_params,
                )
            self._model.remove_edge(pred, layer)

        # finalize graph manipulation outside toposort iteration
        for layer in layers_to_remove:
            self._model.remove_layer(layer)

        for layer in new_layers:
            self._model.relax_new_layer_into_graph(layer, successors_meta_data)

    def _is_plain_feature_slice(self, layer):
        """Return True if ``layer`` is a slice that only selects a feature range.

        Requirements:
        - step 1 on the feature, width and height axes;
        - height and width kept whole (start at 0, end at the input size);
        - ``groups == 1``;
        - not an ancestor of a postprocess layer (see ``_is_postprocess_ancestor``).

        The op check comes first so slice-specific attributes are only read on slice layers.
        """
        return (
            layer.op == LayerType.slice
            and layer.features_slice[2] == layer.width_slice[2] == layer.height_slice[2] == 1
            and layer.width_slice[0] == layer.height_slice[0] == 0
            and layer.height_slice[1] == layer.input_height
            and layer.width_slice[1] == layer.input_width
            and not self._is_postprocess_ancestor(layer)
            and layer.groups == 1
        )

    @staticmethod
    def _features_slices_overlap(slices):
        """Return True if any two of ``slices`` select overlapping feature ranges.

        ``features_slice`` is read as ``[start, end, step]`` with an exclusive end, so ranges
        that merely touch (``end == next start``) do not overlap.
        """
        ordered = sorted(slices, key=lambda s: s.features_slice[0])
        return any(prev.features_slice[1] > nxt.features_slice[0] for prev, nxt in zip(ordered, ordered[1:]))

    def _handle_conv_and_feature_splitter(self):
        """Replace each ``conv [-> spatial_reshape] -> feature_splitter`` with one conv per splitter output.

        Each new conv takes the next contiguous run of the original conv's output channels, sized
        by the matching splitter output shape. Shortcut layers fed by a new conv are queued in
        ``self._shortcuts_to_remove``.
        """
        layers_to_remove = []
        new_layers = []
        successors_meta_data = {}

        for layer in list(self._model):
            if layer.op != LayerType.conv or layer.groups != 1:
                continue

            # Exactly one consumer is required. A conv with no successors must be skipped too;
            # indexing conv_succs[0] would otherwise raise IndexError.
            conv_succs = list(self._model.successors(layer))
            if len(conv_succs) != 1:
                continue
            direct_succ = conv_succs[0]

            # A spatial_reshape between the conv and the splitter is only accepted when:
            # - the conv has a 1x1 kernel, stride 1, dilation 1 and layer_disparity == 1
            #   (presumably so the conv acts per spatial position and commutes with the reshape);
            # - the reshape itself has a single consumer.
            spatial_reshape = None
            feature_split = direct_succ
            if (
                direct_succ.op == LayerType.format_conversion
                and direct_succ.conversion_type == FormatConversionType.spatial_reshape
                and layer.kernel_shape[:2] == layer.strides[1:3] == layer.dilations[1:3] == [1, 1]
                and layer.layer_disparity == 1
                and len(list(self._model.successors(direct_succ))) == 1
            ):
                spatial_reshape = direct_succ
                feature_split = list(self._model.successors(spatial_reshape))[0]

            if feature_split.op != LayerType.feature_splitter:
                continue

            conv_preds = list(self._model.predecessors(layer))
            if len(conv_preds) != 1:
                continue

            fs_succs = list(self.model.successors(feature_split))
            num_shortcut_succs = sum(succ.op == LayerType.shortcut for succ in fs_succs)
            are_almost_all_succs_shortcut = num_shortcut_succs >= len(fs_succs) - 2
            if not (self._is_matmul_ancestor(feature_split) or are_almost_all_succs_shortcut):
                continue

            conv_pred = conv_preds[0]
            src_params = self.params[layer.name]
            layers_to_remove.extend([layer, feature_split])
            self._model.remove_edge(layer, direct_succ)
            self._model.remove_edge(conv_pred, layer)

            if spatial_reshape is not None:
                # Reconnect conv_pred -> spatial_reshape directly. The new convs are fed by the
                # reshape and take feature_split's place in its outputs.
                self._model.remove_edge(spatial_reshape, feature_split)
                self._model.add_edge(conv_pred, spatial_reshape)
                spatial_reshape.replace_input_index(layer.index, conv_pred.index)
                spatial_reshape.replace_input_layer(layer.name, conv_pred.name)
                conv_pred.replace_output_index(layer.index, spatial_reshape.index)
                conv_pred.replace_output_layer(layer.name, spatial_reshape.name)
                new_convs_src, replaced_layer = spatial_reshape, feature_split
            else:
                # The new convs are fed by conv_pred and take the removed conv's place in its outputs.
                new_convs_src, replaced_layer = conv_pred, layer

            block_name, succ_name = self.get_block_and_layer_names(feature_split.name_without_scope)
            groups = feature_split.groups
            # Indices for the new convs are allocated consecutively from a single get_next_index() call.
            new_index = self._model.get_next_index()
            end = 0
            for i, output in enumerate(feature_split.outputs):
                output_index = feature_split.output_indices[i]
                output_shape = feature_split.output_shapes[i]

                sliced_conv = FusedConv2DLayer.from_layer(layer)
                new_layers.append(sliced_conv)
                sliced_conv.name = f"{feature_split.scope}/{block_name}conv_{succ_name}_{i + 1}"
                sliced_conv.index = new_index + i
                sliced_conv.move_params(layer)
                sliced_conv.move_params(feature_split)

                # Each splitter output takes the next contiguous run of output channels, [start, end).
                slice_size = output_shape[-1]
                sliced_conv.kernel_shape[-1] = slice_size
                start = end
                end = start + slice_size
                self._move_fused_slice_params(sliced_conv, layer, src_params, [start, end], groups=groups)

                # Connect source -> new conv. If the removed layer is still listed in the source's
                # outputs, the new conv takes over that slot; otherwise it is appended.
                self._model.add_edge(new_convs_src, sliced_conv)
                if replaced_layer.index in new_convs_src.output_indices:
                    new_convs_src.replace_output_index(replaced_layer.index, sliced_conv.index)
                    new_convs_src.replace_output_layer(replaced_layer.name, sliced_conv.name)
                else:
                    new_convs_src.append_output_index(sliced_conv.index)
                    new_convs_src.append_output_layer(sliced_conv.name)

                sliced_conv.outputs = [output]
                sliced_conv.output_indices = [output_index]
                sliced_conv.output_shapes = [output_shape]

                # Connect new conv -> the layer that consumed this splitter output.
                sliced_conv_succ = self._model.get_layer_by_name(output)
                sliced_conv_succ.replace_input_index(feature_split.index, sliced_conv.index)
                sliced_conv_succ.replace_input_layer(feature_split.name, sliced_conv.name)
                self._model.update_successors_meta_data(sliced_conv_succ, successors_meta_data)
                self._model.remove_edge(feature_split, sliced_conv_succ)
                self._model.add_edge(sliced_conv, sliced_conv_succ)
                if sliced_conv_succ.op == LayerType.shortcut:
                    self._shortcuts_to_remove.append(sliced_conv_succ)

        for layer in layers_to_remove:
            self._model.remove_layer(layer)

        for layer in new_layers:
            self._model.relax_new_layer_into_graph(layer, successors_meta_data)

    def _add_sliced_conv(self, layer, layers_to_remove, new_layers, pred, succ, successors_meta_data, src_params):
        """Replace slice ``succ`` of conv ``layer`` with a new conv fed by ``pred``.

        How the new conv is built:
        - it is created with ``FusedConv2DLayer.from_layer(layer)``;
        - its kernel output dimension is set to ``succ``'s output width;
        - its params come from ``src_params`` via ``_move_fused_slice_params`` using
          ``succ.features_slice``;
        - it takes over ``succ``'s outputs, successors and position in
          ``output_layers_order``.

        Graph edges are rewired immediately. ``succ`` is queued in ``layers_to_remove`` and the
        new conv in ``new_layers``; the caller finalizes both.
        """
        # switch each slice to a conv layer and copy the conv attributes
        sliced_conv = FusedConv2DLayer.from_layer(layer)
        block_name, succ_name = self.get_block_and_layer_names(succ.name_without_scope)
        sliced_conv.name = f"{succ.scope}/{block_name}conv_{succ_name}"
        sliced_conv.index = self._model.get_next_index()
        sliced_conv.kernel_shape[-1] = succ.output_shape[-1]

        sliced_conv.move_params(layer)
        sliced_conv.move_params(succ)

        self._move_fused_slice_params(sliced_conv, layer, src_params, succ.features_slice)

        # pred now outputs to the new conv instead of the original conv
        pred.replace_output_index(layer.index, sliced_conv.index)
        pred.replace_output_layer(layer.name, sliced_conv.name)

        # the new conv (copied from the original conv) now reads from pred and writes where succ did
        sliced_conv.replace_input_index(layer.index, pred.index)
        sliced_conv.replace_input_layer(layer.name, pred.name)
        sliced_conv.outputs = succ.outputs
        sliced_conv.output_indices = succ.output_indices
        sliced_conv.output_shapes = succ.output_shapes

        layers_to_remove.append(succ)
        new_layers.append(sliced_conv)

        for sliced_conv_succ in list(self._model.successors(succ)):
            sliced_conv_succ.replace_input_index(succ.index, sliced_conv.index)
            sliced_conv_succ.replace_input_layer(succ.name, sliced_conv.name)
            self._model.update_successors_meta_data(sliced_conv_succ, successors_meta_data)
            self._model.remove_edge(succ, sliced_conv_succ)
            self._model.add_edge(sliced_conv, sliced_conv_succ)

        self._model.remove_edge(layer, succ)
        self._model.add_edge(pred, sliced_conv)

        # keep the network's output ordering stable when the replaced slice was an output layer
        if succ.name in self.model.net_params.output_layers_order:
            index = self.model.net_params.output_layers_order.index(succ.name)
            self.model.net_params.output_layers_order[index] = sliced_conv.name

    @staticmethod
    def _should_keep_param(x):
        """Return True unless ``x`` contains ``"kernel"`` or ``"bias"``.

        Exception: anything containing ``"activation_delta_bias"`` is always kept. Membership is
        tested with ``in``, so for a string ``x`` (presumably a param key) this is a substring check.
        """
        if "activation_delta_bias" in x:
            return True

        return all(k not in x for k in ["kernel", "bias"])

    def _is_matmul_ancestor(self, layer):
        """Return True if ``layer`` feeds a matmul, directly or through pass-through layers.

        Only normalization, batch_norm and format_conversion layers are walked through. A matmul
        edge ``curr -> matmul`` counts when either:
        - ``curr`` is the matmul's second input (``inputs[1]``); or
        - ``curr`` is a format_conversion layer.

        Every outgoing edge is explored. The previous dict-based frontier used a constant key, so
        it silently kept only the last successor of each node.
        """
        pass_through_ops = [LayerType.normalization, LayerType.batch_norm, LayerType.format_conversion]
        pending = [(layer, succ) for succ in self._model.successors(layer)]
        expanded = set()
        while pending:
            curr, succ = pending.pop()
            if succ.op == LayerType.matmul and (curr.name == succ.inputs[1] or curr.op == LayerType.format_conversion):
                return True
            if succ.op in pass_through_ops and succ not in expanded:
                expanded.add(succ)
                pending.extend((succ, new_succ) for new_succ in self._model.successors(succ))

        return False

    def _is_postprocess_ancestor(self, slice):
        """Return True if a postprocess layer is reachable from ``slice`` through sigmoid activations only.

        This covers a postprocess child, or a postprocess reached via one or more sigmoid
        activation layers. Other branches are ignored rather than ending the search, so the
        result no longer depends on the arbitrary order in which a ``set`` is popped.
        """
        pending = list(self._model.successors(slice))
        expanded = set()
        while pending:
            succ = pending.pop()
            if succ.op == LayerType.postprocess:
                return True
            if succ.op == LayerType.activation and succ.activation == ActivationType.sigmoid and succ not in expanded:
                expanded.add(succ)
                pending.extend(self._model.successors(succ))
        return False

    def _remove_shortcuts(self):
        """Remove the shortcut layers collected during the feature-splitter rewrite.

        Each shortcut is handed to ``FuserHelper.remove_layer``, which may queue layers in
        ``layers_to_remove``. Every queued layer is then removed from the model.
        """
        layers_to_remove = []

        for layer in self._shortcuts_to_remove:
            self._fuser_helper.remove_layer(layer, layers_to_remove, self.model)
            self._logger.debug(f"Removed null shortcut layer {layer.name}.")

        for layer in layers_to_remove:
            self.model.remove_layer(layer)
