import copy

import numpy as np

from hailo_model_optimization.acceleras.hailo_layers.hailo_conv import HailoConv
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv_add import HailoConvAdd
from hailo_model_optimization.acceleras.hailo_layers.hailo_depthwise import HailoDepthwise
from hailo_model_optimization.acceleras.hailo_layers.hailo_depthwise_add import HailoDepthwiseAdd
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_add import HailoElementwiseAdd
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_sub import HailoElementwiseSub
from hailo_model_optimization.acceleras.hailo_layers.hailo_normalization import HailoNormalization
from hailo_model_optimization.acceleras.hailo_layers.hailo_normalization_add import HailoNormalizationAdd
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerTranslationConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    FeaturePolicy,
    InfusibleEWAddType,
    LayerType,
    PaddingType,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import EWAddFusingError
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class EWAddFusing(OptimizationAlgorithm):
    """
    Fusing EWAdd layers with conv, normalization or depthwise layers.
    Can be disabled using the command: pre_quantization_optimization(ew_add_fusing, policy=disabled).

    The model is traversed in topological order. For every element-wise add
    layer, one of its two predecessors may be replaced by the matching fused
    layer class from ``FUSING_TYPES``; the add layer itself is then removed.
    """

    # A layer is considered fusible only when all of its conv strides equal this value.
    SUPPORTED_CONV_AND_ADD_STRIDE = 1
    # Maps each standalone layer class to the fused class that replaces it.
    # Looked up by the exact type of the layer being fused.
    FUSING_TYPES = {
        HailoConv: HailoConvAdd,
        HailoNormalization: HailoNormalizationAdd,
        HailoDepthwise: HailoDepthwiseAdd,
    }

    def __init__(self, model: HailoModel, model_config, logger_level, **kwargs):
        """
        Initialise the base algorithm and reset the layer-degree table.

        Args:
            model: model whose element-wise add layers are candidates for fusing.
            model_config: optimization configuration, forwarded to the base class.
            logger_level: logging level, forwarded to the base class.
            **kwargs: extra keyword arguments forwarded unchanged to the base class.
        """
        algo_name = "Element-wise Add Fusing"
        super().__init__(model, model_config, algo_name, logger_level, **kwargs)
        # Replaced by an empty dict in _setup(); maps layer name -> degree.
        self.layers_degrees = None

    def get_algo_config(self):
        """Return the ``ew_add_fusing`` section of the model configuration."""
        model_cfg = self._model_config
        return model_cfg.ew_add_fusing

    def finalize_global_cfg(self, algo_config):
        """No-op: ``algo_config`` is left untouched by this algorithm."""
        return None

    def should_skip_algo(self):
        """Return True unless the configured policy is ``FeaturePolicy.enabled``."""
        policy = self.get_algo_config().policy
        return policy != FeaturePolicy.enabled

    def _setup(self):
        """Start the run with an empty layer-name -> degree table."""
        self.layers_degrees = dict()

    def _run_int(self):
        """
        Visit the layers in topological order, recording each layer's degree and
        attempting to fuse every element-wise add layer that is encountered.
        """
        flow = self._model.flow
        # Every input layer starts with degree 1.
        self.layers_degrees.update(dict.fromkeys(flow.input_nodes, 1))

        for layer_name in flow.toposort():
            self._calculate_layer_degree(layer_name)
            current = self._model.layers[layer_name]
            if not isinstance(current, HailoElementwiseAdd):
                continue
            # Try to fuse conv and add, or fall back to a standalone ew add layer according to policy.
            self._find_valid_conv_and_fuse(current)

    def _calculate_layer_degree(self, lname):
        """
        Record the degree of ``lname`` unless it is already known.

        The degree is one more than the largest degree among the layer's
        predecessors (or 1 when it has none). Predecessors must already have an
        entry in ``self.layers_degrees``.
        """
        if lname in self.layers_degrees:
            return

        pred_degrees = [self.layers_degrees[pred_name] for pred_name in self._model.flow.predecessors(lname)]
        # The leading 0 keeps the result well-defined for a layer without predecessors.
        self.layers_degrees[lname] = max([0, *pred_degrees]) + 1

    def _find_valid_conv_and_fuse(self, ew_add):
        """
        Fuse ``ew_add`` into one of its two predecessors when possible.

        The predecessors are tried in order; the first one accepted by
        ``_should_fuse_conv`` is replaced by the fused layer, the other
        predecessor is rewired as that layer's second input, and ``ew_add`` is
        removed from the model and from all configs. When neither predecessor
        qualifies, the layer is left as is, unless the configured
        ``infusible_ew_add_type`` is ``InfusibleEWAddType.conv`` and the layer is
        not an element-wise sub, in which case ``_handle_infusible_ew_add`` is used.

        Args:
            ew_add: the element-wise add (or sub) layer to handle.

        Raises:
            EWAddFusingError: if ``ew_add`` does not have exactly two predecessors.
        """
        # verifying ew_add has only two predecessors
        preds = [self._model.layers[lname] for lname in self._model.flow.predecessors(ew_add.full_name)]
        if len(preds) != 2:
            # "original_names" may be absent from the hn element; default to an empty
            # list so the intended error is raised rather than a KeyError.
            original_names = ew_add.hn_element.get("original_names", [])
            raise EWAddFusingError(
                f'elementwise-add layer {ew_add.full_name} must have exactly two predecessors. '
                f'(translated from [{", ".join(original_names)}]).',
            )

        for i in range(2):
            pred = preds[i]
            neighbor = preds[1 - i]
            if not self._should_fuse_conv(pred, neighbor, ew_add):
                continue

            # In case i == 1 the case is ew_sub(x, conv(y)), which is translated to conv_add(x, -y),
            # so negating the kernel is required and the layer is then treated as a regular ew_add:
            #  x - (Ay + b) = x + ((-A)y + (-b))
            neg_output = i == 1 and isinstance(ew_add, HailoElementwiseSub)

            # replace standalone conv and add with fused ConvAndAdd
            conv_preds_names = list(self._model.flow.predecessors(pred.full_name))
            self._add_conv_add_layer([conv_preds_names[0], neighbor.full_name], pred, ew_add, neg_output)
            # Move the neighbor's edge from the ew_add to the fused layer (as its second input).
            out_index = self._model.flow.get_edge_output_index(neighbor.full_name, ew_add.full_name)
            self._model.flow.remove_edge(neighbor.full_name, ew_add.full_name)
            self._model.flow.add_edge(neighbor.full_name, pred.full_name, input_index=1, output_index=out_index)
            self._model.remove_layer(ew_add)
            self._model_config.remove_layer_from_all_configs(ew_add.full_name)
            # If the ew_add was a model output, the fused layer takes its place in the output order.
            for j, output in enumerate(self._model.flow.output_layer_order):
                if output == ew_add.full_name:
                    self._model.flow.output_layer_order[j] = pred.full_name
            return

        self._logger.debug(
            f"Could not fuse element-wise-add layer {ew_add.full_name} with any predecessor. compatible "
            "options are: exactly two predecessors, one of which  is a conv layer, with one successor "
            "only, and a linear activation (activation must be after the EW-add op).",
        )
        # Fuse ew-add layer as an identity conv fallback standalone ew-add layer
        if self.get_algo_config().infusible_ew_add_type == InfusibleEWAddType.conv and not isinstance(
            ew_add,
            HailoElementwiseSub,
        ):
            self._handle_infusible_ew_add(ew_add)

    def _should_fuse_conv(self, pred, neighbor, ew_add):
        """
        Decide whether ``ew_add`` should be fused into ``pred``.

        ``pred`` is chosen when it is fusible, it is preferred over ``neighbor``
        (``neighbor`` is not fusible, or ``pred``'s degree is at least
        ``neighbor``'s), ``neighbor`` is not itself an input of ``pred``, and the
        precision modes of ``pred`` and ``ew_add`` are compatible.

        Raises:
            EWAddFusingError: propagated from ``_precision_modes_fusible``.
        """
        neighbor_fusible = self._is_layer_fusible(neighbor)
        # Degrees only break the tie when both predecessors are fusible.
        pred_preferred = (
            not neighbor_fusible
            or self.layers_degrees[pred.full_name] >= self.layers_degrees[neighbor.full_name]
        )
        pred_inputs = list(self._model.flow.predecessors(pred.full_name))

        if not self._is_layer_fusible(pred):
            return False
        if not pred_preferred:
            return False
        # The ew_add neighbor must not be the input of the conv.
        if neighbor.full_name in pred_inputs:
            return False
        return self._precision_modes_fusible(pred, ew_add)

    def _is_layer_fusible(self, layer):
        """
        Return True when ``layer`` can absorb an element-wise add.

        The layer must be exactly a ``HailoConv``, ``HailoDepthwise`` or
        ``HailoNormalization`` (subclasses do not qualify), have every stride
        equal to ``SUPPORTED_CONV_AND_ADD_STRIDE``, have a single successor, not
        transpose its output width/features, and use a linear activation.
        """
        # Exact type match on purpose: subclasses of these layers are rejected.
        if type(layer) not in (HailoConv, HailoDepthwise, HailoNormalization):
            return False

        supported_stride = self.SUPPORTED_CONV_AND_ADD_STRIDE
        checks = (
            all(stride == supported_stride for stride in layer.conv_op.strides),
            len(list(self._model.flow.successors(layer.full_name))) == 1,
            not layer.transpose_output_width_features,
            layer.act_op.act_name == ActivationType.LINEAR,
        )
        return all(checks)

    def _precision_modes_fusible(self, conv, ew_add):
        """
        Tell whether the precision modes of ``conv`` and ``ew_add`` allow fusing.

        Precision mode values are split on "_" into tokens. Fusing is allowed when
        the first token of ``conv``, the first token of ``ew_add`` and the last
        token of ``conv`` are all equal, i.e.:
        1. The input precision modes match, so mixed precision weights can be fused.
        2. The conv output and ew_add input precision modes match, to avoid edge
           cases of mismatching precision modes.

        Raises:
            EWAddFusingError: if ``conv`` is a8_w8_a16 or a16_w16_a16 while the
                ew_add's mode, without its last token, equals a8_w8.
        """
        conv_mode = conv.get_precision_mode().value
        ew_add_mode = ew_add.get_precision_mode().value
        conv_tokens = conv_mode.split("_")
        ew_add_tokens = ew_add_mode.split("_")

        a8_w8_tokens = PrecisionMode.a8_w8.value.split("_")
        forbidden_pairs = [
            (PrecisionMode.a8_w8_a16.value.split("_"), a8_w8_tokens),
            (PrecisionMode.a16_w16_a16.value.split("_"), a8_w8_tokens),
        ]
        if (conv_tokens, ew_add_tokens[:-1]) in forbidden_pairs:
            raise EWAddFusingError(
                f"Illegal precision modes for nodes: {conv.full_name} with precision mode: "
                f"{conv_mode}, {ew_add.full_name} with precision mode: "
                f"{ew_add_mode}."
            )

        conv_in, conv_out = conv_tokens[0], conv_tokens[-1]
        ew_add_in = ew_add_tokens[0]
        return conv_in == ew_add_in and ew_add_in == conv_out

    def _add_conv_add_layer(self, sources, orig_conv, ew_add, neg_output):
        """
        Build the fused layer that replaces ``orig_conv`` and install it in the model.

        The fused layer's hn element is a deep copy of ``orig_conv``'s, with inputs
        set to ``sources``, outputs and activation taken from ``ew_add``, and the
        element-wise add parameters enabled. The fused layer keeps
        ``orig_conv``'s name.

        Args:
            sources: names of the fused layer's inputs (conv input first, then
                the other add operand).
            orig_conv: the standalone layer being fused; its exact type selects
                the fused class from ``FUSING_TYPES``.
            ew_add: the element-wise add (or sub) layer being absorbed.
            neg_output: True when the conv weights are negated to express a
                subtraction; the add factor is then kept at 1.
        """
        hn = copy.deepcopy(orig_conv.hn_element)
        hn["input"] = sources
        in_shapes = [orig_conv.input_shape, ew_add.input_shapes[0]]
        # Batch dimension is stored as -1.
        hn["input_shapes"] = [[-1, *in_shape[1:]] for in_shape in in_shapes]
        hn["output"] = ew_add.hn_element["output"]
        hn["output_shapes"] = ew_add.hn_element["output_shapes"]
        orig_names = hn.get("original_names", []) + ew_add.hn_element.get("original_names", [])
        # De-duplicate while keeping first-seen order; a set would yield an order that
        # varies between interpreter runs because of string hash randomization.
        hn["original_names"] = list(dict.fromkeys(orig_names))
        hn["params"]["activation"] = ew_add.act_op.act_name.value
        hn["params"]["elementwise_add"] = True
        # A subtraction that is not already expressed by negated weights uses factor -1.
        hn["params"]["elementwise_add_factor"] = -1 if isinstance(ew_add, HailoElementwiseSub) and not neg_output else 1
        conv_add_layer = self.FUSING_TYPES[type(orig_conv)].from_hn(lname=orig_conv.full_name, hn_element=hn)
        self._set_conv_add_layer(conv_add_layer, orig_conv, ew_add, neg_output)

    def _set_conv_add_layer(
        self,
        conv_add: HailoConvAdd,
        conv: HailoConv,
        ew_add: HailoElementwiseAdd,
        neg_output: bool,
    ):
        """
        Swap ``conv`` for ``conv_add`` in the model and carry over precision
        config, translation config and weights from ``conv`` and ``ew_add``.

        Args:
            conv_add: the newly built fused layer.
            conv: the standalone layer being replaced.
            ew_add: the element-wise add (or sub) layer being absorbed.
            neg_output: when True, every exported weight of ``conv`` is negated
                before being imported into ``conv_add``.
        """
        translation_cfg = self._model_config.translation_config
        self._model.replace_layer(conv_add, conv)

        # Precision config: everything comes from the conv, except the last token of the
        # precision mode which comes from the ew_add -> (conv-a)_(conv-w)_(ewadd-a).
        precision_layers = self._model_config.precision_config.layers
        ew_add_cfg = precision_layers[ew_add.full_name]
        conv_cfg = precision_layers[conv.full_name]
        conv_mode_prefix = "_".join(conv_cfg.precision_mode.value.split("_")[:-1])
        ew_add_mode_suffix = ew_add_cfg.precision_mode.value.split("_")[-1]
        fused_cfg = LayerPrecisionConfig(
            meta=conv_cfg.meta,
            quantization_groups=conv_cfg.quantization_groups,
            precision_mode=f"{conv_mode_prefix}_{ew_add_mode_suffix}",
            bias_mode=conv_cfg.bias_mode,
        )
        conv_add.import_precision_config(fused_cfg, self.optimization_target)
        precision_layers[conv.full_name] = fused_cfg

        # Translation config: the ew_add's entries (looked up by full name, then by the name
        # without its scope prefix, the latter taking precedence) are stored under the conv's name.
        scopeless_name = ew_add.full_name.split("/", 1)[-1]
        merged_translation = {}
        for cfg_key in (ew_add.full_name, scopeless_name):
            if cfg_key in translation_cfg.layers:
                merged_translation.update(translation_cfg.layers[cfg_key].dict())
        translation_cfg.layers.update({conv.full_name: LayerTranslationConfig(**merged_translation)})

        # Weights: start from the conv's (negated if requested), then overlay the ew_add's.
        fused_weights = conv.export_weights()
        if neg_output:
            fused_weights = {name: -tensor for name, tensor in fused_weights.items()}
        add_weights = ew_add.export_weights()
        if "bias" in add_weights:
            # Combine both biases so the overlay below does not drop the conv's bias.
            add_weights["bias"] += fused_weights.get("bias", 0)
        fused_weights.update(add_weights)
        conv_add.import_weights(fused_weights)

    def _handle_infusible_ew_add(self, ew_add):
        """
        Replace a non-fusible ``ew_add`` with a standalone ``HailoConvAdd``.

        The replacement keeps the layer's name, uses a 1x1 kernel whose input and
        output channel counts both equal the last dimension of the ew_add's output
        shape, and reuses the ew_add's precision config. Its weights are set by
        ``_set_dummy_conv_weights``.
        """
        channels = ew_add.output_shape[-1]

        hn = copy.deepcopy(ew_add.hn_element)
        hn["type"] = LayerType.CONV.value
        # Defaults first, then the ew_add's own params, then the conv-specific overrides.
        conv_params = HailoConvAdd.get_default_params()
        conv_params.update(hn["params"])
        conv_params.update(
            {
                "kernel_shape": [1, 1, channels, channels],
                "padding": PaddingType.VALID.value,
                "elementwise_add": True,
                "batch_norm": False,
            }
        )
        hn["params"] = conv_params

        standalone = HailoConvAdd.from_hn(ew_add.full_name, hn)
        self._model.replace_layer(standalone, ew_add)
        ew_add_precision_cfg = self._model_config.precision_config.layers[ew_add.full_name]
        standalone.import_precision_config(ew_add_precision_cfg, self.optimization_target)
        self._set_dummy_conv_weights(ew_add, standalone)

        self._logger.debug(f"Elementwise-add layer {ew_add.full_name} was fused as stand-alone layer.")

    def _set_dummy_conv_weights(self, ew_add, conv_add):
        """
        Load ``conv_add`` with an identity kernel, a zero bias and the ew_add's weights.

        The kernel has shape (1, 1, C, C) with ones on the diagonal of its last two
        axes, where C is the last dimension of ``ew_add.output_shape``. Any weights
        exported by ``ew_add`` override the dummy ones of the same name.
        """
        channels = ew_add.output_shape[-1]
        identity_kernel = np.eye(channels, dtype=np.float32)[np.newaxis, np.newaxis, ...]
        dummy_weights = {
            "kernel": identity_kernel,
            "bias": np.zeros(channels, dtype=np.float32),
        }
        dummy_weights.update(ew_add.export_weights())
        conv_add.import_weights(dummy_weights)
