#!/usr/bin/env python
import math

from hailo_model_optimization.acceleras.utils.acceleras_definitions import PreQuantizationDefuseType
from hailo_sdk_client.post_fuser.algorithms.fuser_algorithm import FuserAlgorithm
from hailo_sdk_client.tools.fuser.fuser_helper import FuserHelper
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationType, LayerType
from hailo_sdk_common.hailo_nn.hn_layers import EWAddNLayer, FeatureSplitterLayer


class InputFeaturesDefuse(FuserAlgorithm):
    """Split ("defuse") dense/conv layers along their input-features axis.

    For every eligible layer the graph is rewritten from::

        pred -> layer -> successors

    into::

        pred -> feature splitter -> layer_d0 .. layer_d{N-1} -> ew_add_n -> successors

    Each ``layer_d{i}`` receives one slice of the predecessor's last
    (features) axis together with the matching slice of the original
    parameters.  The original activation is taken off the split layers (they
    become linear) and recorded for the ``ew_add_n`` layer instead.
    """

    NAME = "input_features_defuse"
    # Granularity (in features) to which split sizes are aligned when the
    # per-split share is at least this large.
    HW_FRAME = 8

    def __init__(self, model, params, model_config, hw_arch, **kwargs):
        """Create the algorithm and its graph helper.

        Args:
            model: Model graph to be modified in place.
            params: Mapping of layer name to that layer's parameters.
            model_config: Configuration object; its ``defuse`` attribute is
                this algorithm's configuration.
            hw_arch: Hardware description; ``consts["DEFUSE_MAX_SIZE"]`` is
                read when the number of splits is computed automatically.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__(model, params, model_config, hw_arch)
        self._fuser_helper = FuserHelper(self.model)

    def get_algo_config(self):
        """Return the ``defuse`` section of the model configuration."""
        return self._model_config.defuse

    def _setup(self):
        """No setup is required for this algorithm."""
        pass

    def _run_int(self):
        """Defuse the eligible layers, then split the created ew-add-n layers.

        The parameters returned by the helper for each ew-add layer are merged
        into ``self.params`` under ``"<ew_add>/<param name>"`` keys.
        """
        act_dict = self._defuse_input_features()
        ew_adds_to_params = self._fuser_helper.split_ew_add_n_layers(split_to_base_layers=False, act_dict=act_dict)
        for ew_add, ew_add_params in ew_adds_to_params.items():
            self.params.update({f"{ew_add}/{name}": value for name, value in ew_add_params.items()})

    def should_skip_algo(self):
        """This algorithm is never skipped."""
        return False

    def log_config(self):
        """Nothing is logged for this algorithm's configuration."""
        pass

    @staticmethod
    def _should_keep_param(x):
        """Always return ``False``, regardless of ``x``."""
        return False

    def _compute_split_sizes(self, total_features, num_splits):
        """Divide ``total_features`` into ``num_splits`` chunk sizes.

        When the even share (``total_features // num_splits``) is at least
        ``HW_FRAME``, every chunk starts as that share rounded down to a
        multiple of ``HW_FRAME``; otherwise the alignment unit is 1.  The
        leftover features are then handed out one alignment unit at a time,
        starting from the first chunk, and whatever is left at the end (at
        most one alignment unit) is added to the next chunk in line.

        Example with ``HW_FRAME == 8``: 120 features in 4 splits gives
        ``[32, 32, 32, 24]``.

        Args:
            total_features: Size of the features axis being split.
            num_splits: Number of chunks; must be a positive integer.

        Returns:
            A list of ``num_splits`` integers that sums to ``total_features``.
        """
        initial_split_size = total_features // num_splits
        align_size = 1 if initial_split_size < self.HW_FRAME else self.HW_FRAME
        aligned_split_size = (initial_split_size // align_size) * align_size
        split_sizes = [aligned_split_size] * num_splits
        remaining_features = total_features - (aligned_split_size * num_splits)
        idx = 0
        while remaining_features > align_size:
            split_sizes[idx] += align_size
            remaining_features -= align_size
            idx = (idx + 1) % num_splits
        split_sizes[idx] += remaining_features
        return split_sizes

    def _defuse_input_features(self):
        """Rewrite every eligible dense/conv layer into input-feature splits.

        A layer is processed when it is a dense layer, or a conv layer whose
        ``layer_disparity`` is not greater than 1, it has exactly one
        predecessor, and the number of splits is greater than 1.  The number
        of splits comes from the algorithm configuration when the layer is
        listed there with the ``INPUT_FEATURES`` defuse type; otherwise it is
        ``ceil(in_features * kernel_h * kernel_w / DEFUSE_MAX_SIZE)``.

        Returns:
            dict mapping each created ew-add-n layer name to a tuple of
            ``(original activation, original non-kernel/non-bias params)``.
        """
        activations = {}
        algo_cfg = self.get_algo_config()
        # Iterate over a snapshot, since nodes are added to the model below.
        for layer in list(self.model):
            new_layers = []
            successors_meta_data = {}

            # Only dense layers and convs with layer_disparity <= 1 are handled.
            if layer.op not in (LayerType.dense, LayerType.conv) or (
                layer.op == LayerType.conv and layer.layer_disparity > 1
            ):
                continue
            if (
                layer.name in algo_cfg.layers
                and algo_cfg.layers[layer.name].defuse_type == PreQuantizationDefuseType.INPUT_FEATURES
            ):
                num_splits = algo_cfg.layers[layer.name].num_splits
            else:
                in_features = (
                    layer.input_features // layer.groups if layer.op == LayerType.conv else layer.input_features
                )
                num_splits = math.ceil(
                    in_features * layer.kernel_height * layer.kernel_width / self._hw_arch.consts["DEFUSE_MAX_SIZE"]
                )

            # Nothing to split.  Values below 1 are skipped as well: they would
            # otherwise reach the integer division in _compute_split_sizes and
            # fail there (ZeroDivisionError for 0).
            if num_splits <= 1:
                continue

            preds = list(self.model.predecessors(layer))
            if len(preds) != 1:
                continue

            pred = preds[0]
            succs = list(self.model.successors(layer))

            layer_index = pred.outputs.index(layer.name)
            pred_output_shape = pred.output_shapes[layer_index]

            # For dense layers every feature slice spans all the middle axes of
            # the predecessor output, so kernel rows scale by their product.
            flatten_spatial_size = math.prod(pred_output_shape[1:-1]) if layer.op == LayerType.dense else 1
            split_sizes = self._compute_split_sizes(pred_output_shape[-1], num_splits)

            # The original layer becomes split 0 and is renamed accordingly;
            # its parameters are still keyed by the original name.
            orig_layer_name = layer.name
            layer.name += "_d0"
            block_name, layer_name = self.get_block_and_layer_names(layer.name_without_scope)

            # Feature splitter placed between the predecessor and the splits.
            fs_layer = FeatureSplitterLayer()
            self.model.add_node(fs_layer)
            new_layers.append(fs_layer)
            fs_layer.name = f"{layer.scope}/{block_name}fs_{layer_name}"
            fs_layer.index = self.model.get_next_index()
            fs_layer.inputs = layer.inputs.copy()
            fs_layer.input_shapes = layer.input_shapes.copy()
            fs_layer.input_indices = layer.input_indices.copy()
            fs_layer.outputs = [layer.name]
            fs_layer.output_shapes = [pred_output_shape[:-1] + [split] for split in split_sizes]
            fs_layer.output_indices = [layer.index]
            fs_layer.split_sizes = split_sizes
            self.model.add_edge(fs_layer, layer)

            pred.outputs[layer_index] = fs_layer.name
            pred.output_indices[layer_index] = fs_layer.index

            # Element-wise add-n layer that takes over the original layer's
            # outputs and (via the returned dict) its activation.
            add_n_layer = EWAddNLayer.from_layer(layer)
            self.model.add_node(add_n_layer)
            new_layers.append(add_n_layer)
            add_n_layer.name = f"{layer.scope}/{block_name}ew_add_n_{layer_name}"
            add_n_layer.index = fs_layer.index + 1
            add_n_layer.inputs = [layer.name]
            add_n_layer.input_shapes = [layer.output_shape for _ in range(num_splits)]
            add_n_layer.input_indices = [layer.index]
            add_n_layer.outputs = layer.outputs.copy()
            add_n_layer.output_shapes = layer.output_shapes.copy()
            add_n_layer.output_indices = layer.output_indices.copy()
            add_n_layer.move_params(layer)
            act_params = {
                x: y for x, y in self.params[orig_layer_name].items() if "kernel" not in x and "bias" not in x
            }
            activations[add_n_layer.name] = (layer.activation, act_params)

            # Successors now consume the add-n layer instead of the original.
            for succ in succs:
                self.model.remove_edge(layer, succ)
                self.model.add_edge(add_n_layer, succ)
                succ.replace_input_index(layer.index, add_n_layer.index)
                succ.replace_input_layer(orig_layer_name, add_n_layer.name)
                self.model.update_successors_meta_data(succ, successors_meta_data)

            layer_splits = [shape[-1] for shape in fs_layer.output_shapes]
            # [start, end) walks the features axis; split 0 (the original
            # layer, handled after the loop) owns [0, layer_splits[0]).
            start = layer_splits[0]
            end = layer_splits[0]
            src_params = self.params[orig_layer_name]
            for i in range(1, num_splits):
                end += layer_splits[i]
                layer_split = type(layer).from_layer(layer)
                self.model.add_node(layer_split)
                new_layers.append(layer_split)
                layer_split.name = f"{orig_layer_name}_d{i}"
                layer_split.index = add_n_layer.index + i
                layer_split.kernel_shape[-2] = layer_splits[i] * flatten_spatial_size
                layer_split.inputs = [fs_layer.name]
                layer_split.input_shapes = [fs_layer.output_shapes[i].copy()]
                layer_split.input_indices = [fs_layer.index]
                layer_split.outputs = [add_n_layer.name]
                layer_split.output_shapes = [layer.output_shapes[0].copy()]
                layer_split.output_indices = [add_n_layer.index]
                layer_split.move_params(layer)
                # The activation is recorded for the add-n layer instead.
                layer_split.activation = ActivationType.linear
                self._move_fused_slice_params(
                    layer_split,
                    layer,
                    src_params,
                    [start, end],
                    input_defuse=True,
                    pred_output_shape=pred_output_shape,
                )
                start = end

                self.model.add_edge(fs_layer, layer_split)
                fs_layer.append_output_layer(layer_split.name)
                fs_layer.append_output_index(layer_split.index)
                self.model.add_edge(layer_split, add_n_layer)
                add_n_layer.append_input_index(layer_split.index)
                add_n_layer.append_input_layer(layer_split.name)

            # Finally rewire the original layer itself as split 0.
            self.model.add_edge(pred, fs_layer)
            self.model.add_edge(layer, add_n_layer)
            self.model.remove_edge(pred, layer)
            layer.kernel_shape[-2] = layer_splits[0] * flatten_spatial_size
            layer.inputs = [fs_layer.name]
            layer.input_shapes = [fs_layer.output_shapes[0].copy()]
            layer.input_indices = [fs_layer.index]
            layer.outputs = [add_n_layer.name]
            layer.output_indices = [add_n_layer.index]
            layer.activation = ActivationType.linear
            self._move_fused_slice_params(layer, layer, src_params, [0, layer_splits[0]], True, pred_output_shape)

            # If the original layer was a network output, the add-n layer
            # takes its place in the output order.
            for i, output in enumerate(self.model.net_params.output_layers_order):
                if orig_layer_name == output:
                    self.model.net_params.output_layers_order[i] = add_n_layer.name

            self._logger.info(f"{orig_layer_name} splitted into {num_splits} {layer.op.value} layers.")

            # A distinct name keeps the outer loop variable `layer` intact.
            for new_layer in new_layers:
                self.model.relax_new_layer_into_graph(new_layer, successors_meta_data)
        return activations
