import copy

import numpy as np

from hailo_model_optimization.acceleras.hailo_layers.hailo_concat import HailoConcat
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv import HailoConv
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_add import HailoElementwiseAdd
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    LayerFeaturePolicy,
)
from hailo_model_optimization.algorithms.mixed_precision.create_mixed_precision import CreateMixedPrecision
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class SwitchConcatWithAdd(OptimizationAlgorithm):
    """Rewrite ``concat -> conv`` pairs as two partial convs followed by an element-wise add.

    For a two-input ``HailoConcat`` whose single consumer is a ``HailoConv``:
    - The conv kernel is split along axis 2. Each slice is sized by the last dimension of the
      matching concat input, giving one linear conv per concat input.
    - The outputs of the partial convs are summed by a ``HailoElementwiseAdd``.
    - The add takes over the original conv's activation and its exported weights other than
      kernel/bias.

    Only concats whose entry in ``switch_concat_with_add.layers`` equals
    ``LayerFeaturePolicy.enabled`` are rewritten. After all rewrites, ``CreateMixedPrecision``
    is run on the modified model.
    """

    def __init__(self, model: HailoModel, model_config, logger_level, **kwargs):
        super().__init__(model, model_config, "switch_concat_with_add", logger_level, **kwargs)
        # Not read within this class; reset to {} in _setup.
        self.layers_degrees = None
        # Names of concat layers selected by should_skip_algo, consumed by _run_int.
        self._layers_to_switch = []

    def get_algo_config(self):
        """Return the ``switch_concat_with_add`` section of the model config."""
        return self._model_config.switch_concat_with_add

    def finalize_global_cfg(self, algo_config):
        """No global config post-processing is needed for this algorithm."""
        pass

    def should_skip_algo(self):
        """Collect eligible concat layers and return True when none were found.

        ``self._layers_to_switch`` is rebuilt from scratch on every call. Repeated calls
        therefore cannot queue the same concat twice, which would make ``_run_int`` try to
        rewrite a layer it has already removed.
        """
        self._layers_to_switch = []
        layers_cfg = self.get_algo_config().layers
        for lname in self._model.flow.toposort():
            if self._is_switchable_concat(lname, layers_cfg):
                self._layers_to_switch.append(lname)
        return not self._layers_to_switch

    def _is_switchable_concat(self, lname, layers_cfg):
        """Return True if ``lname`` is an enabled 2-input concat feeding exactly one HailoConv."""
        layer = self._model.layers[lname]
        if not isinstance(layer, HailoConcat):
            return False
        if layers_cfg.get(lname, None) != LayerFeaturePolicy.enabled:
            return False
        if len(layer.input_shapes) != 2 or len(layer.output_shapes) != 1:
            return False
        # The concat node is removed entirely by the rewrite. Any additional consumer
        # would be left without an input.
        successors = list(self._model.flow.successors_sorted(lname))
        if len(successors) != 1:
            return False
        # Each concat input must come from its own producer, so it can be rewired to its
        # own partial conv. Otherwise indexing the second producer would fail mid-rewrite.
        if len(list(self._model.flow.predecessors_sorted(lname))) != 2:
            return False
        return isinstance(self._model.layers[successors[0]], HailoConv)

    def _setup(self):
        """Initialise per-run state."""
        self.layers_degrees = {}

    def _run_int(self):
        """Rewrite every queued concat, then run CreateMixedPrecision on the modified model."""
        for lname in self._layers_to_switch:
            self._switch_single_concat(lname)
        # NOTE(review): presumably re-run so the newly created layers get precision settings.
        algo = CreateMixedPrecision(
            model=self._model,
            model_config=self._model_config,
            logger_level=self._logger_level,
            logger=self._logger,
        )
        algo.run()

    def _switch_single_concat(self, lname):
        """Replace ``concat(a, b) -> conv`` with ``add(conv_s0(a), conv_s1(b))``.

        The add carries the original conv's activation.
        """
        flow = self._model.flow
        concat = self._model.layers[lname]
        conv_name = flow.successors_sorted(lname)[0]
        conv = self._model.layers[conv_name]
        # NOTE(review): assumes predecessors_sorted orders producers by concat input index,
        # so preds[i] matches concat.input_shapes[i] — confirm against the flow API.
        preds = list(flow.predecessors_sorted(lname))
        conv_sucs = list(flow.successors_sorted(conv_name))
        original_weights = conv.export_weights()
        act_name = conv.hn_element["params"]["activation"]
        new_convs = self._split_conv_by_concat_inputs(conv, concat)

        # Detach and drop the original concat and conv.
        for pred in preds:
            flow.remove_edge(pred, lname)
        flow.remove_edge(lname, conv_name)
        self._remove_layer_node(conv)
        self._remove_layer_node(concat)

        # The activation is applied once, after the partial results are summed.
        ew_add_hn = {
            "params": {"activation": act_name},
            "input_shapes": [new_conv.output_shape for new_conv in new_convs],
            "output_shapes": [new_convs[0].output_shape],
        }
        ew_add_layer = HailoElementwiseAdd.from_hn(f"{conv.full_name}_add", ew_add_hn)
        # Everything the conv exported besides kernel/bias is handed to the add layer.
        original_weights.pop("kernel")
        original_weights.pop("bias")
        ew_add_layer.import_weights(original_weights)
        flow.add_node(ew_add_layer.full_name)
        self._model.layers[ew_add_layer.full_name] = ew_add_layer

        for input_index, (pred, new_conv) in enumerate(zip(preds, new_convs)):
            self._model.layers[new_conv.full_name] = new_conv
            flow.add_node(new_conv.full_name)
            flow.add_edge(pred, new_conv.full_name, input_index=0, output_index=0)
            flow.add_edge(new_conv.full_name, ew_add_layer.full_name, input_index=input_index, output_index=0)

        # NOTE(review): every former conv consumer is connected at input_index=0. The input
        # index of the original conv -> consumer edge is not preserved.
        for conv_succ in conv_sucs:
            flow.add_edge(ew_add_layer.full_name, conv_succ, input_index=0, output_index=0)

    def _split_conv_by_concat_inputs(self, conv, concat):
        """Build one linear HailoConv per concat input from the matching kernel slice.

        Slice ``i`` keeps the kernel entries along axis 2 that belong to concat input ``i``.
        It is sized by that input's last dimension. Only the first slice keeps the original
        bias; later slices get a zero bias, so the summed output contains the bias exactly once.
        """
        new_convs = []
        acc_channels = 0
        for i, input_shape in enumerate(concat.input_shapes):
            channels = input_shape[-1]
            hn_new = copy.deepcopy(conv.hn_element)
            weights_new = conv.export_weights()
            weights_new["kernel"] = weights_new["kernel"][:, :, acc_channels : acc_channels + channels, :]
            if i > 0:
                weights_new["bias"] = np.zeros_like(weights_new["bias"])
            hn_new["input_shapes"][0][-1] = channels
            # Partial convs stay linear; the original activation moves to the add layer.
            hn_new["params"]["activation"] = ActivationType.LINEAR.value
            hn_new["params"]["kernel_shape"][-2] = channels
            new_conv = HailoConv.from_hn(f"{conv.full_name}_s{i}", hn_new)
            new_conv.import_weights(weights_new)
            new_convs.append(new_conv)
            acc_channels += channels
        return new_convs

    def _remove_layer_node(self, layer):
        """Remove ``layer`` from the flow graph, the model's layer map and all model configs."""
        self._model.flow.remove_node(layer.full_name)
        self._model.layers.pop(layer.full_name, None)
        self._model_config.remove_layer_from_all_configs(layer.full_name)

    def _get_valid_layer_cfg(self, lname, cfg):
        """Return ``cfg`` if ``lname`` is still a layer of the model, otherwise None."""
        if lname in self._model.layers:
            return cfg
        return None
