import numpy as np

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_conv import BaseHailoConv
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_conv_add import BaseHailoConvAdd
from hailo_model_optimization.acceleras.hailo_layers.hailo_const import HailoConst
from hailo_model_optimization.acceleras.hailo_layers.hailo_depthwise import HailoDepthwise
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_mult import HailoElementwiseMult
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_mult_on_mac import HailoElementwiseMultOnMac
from hailo_model_optimization.acceleras.hailo_layers.hailo_precision_split import HailoPrecisionSplitPixels
from hailo_model_optimization.acceleras.hailo_layers.hailo_standalone_activation import HailoStandaloneActivation
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerConvA16W4Config,
    LayerEqualizationConfig,
    LayerPrecisionConfig,
    LayerTranslationConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ZP_LOW_SPLIT_PRECISION_PIXEL,
    BiasMode,
    LayerFeaturePolicy,
    PrecisionMode,
)
from hailo_model_optimization.algorithms.mixed_precision.create_mixed_precision import CreateMixedPrecision
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm
from hailo_model_optimization.algorithms.stats_collection.stats_collection import StatsCollector


class DecomposeConvA16W4(OptimizationAlgorithm):
    """
    DecomposeConvA16W4 decomposes a convolution layer to "simulate" conv with precision a16_w4.
    The algorithm decomposes the conv layer into a precision_splitter layer, a conv layer and a depthwise layer.
    The precision_splitter layer splits the input tensor into two tensors: high and low.
    The depthwise layer combines the high and low tensors into the output tensor.

    """

    def __init__(self, model, model_config, logger_level, dataset, logger=None):
        """
        Set up the algorithm state.

        Args:
            model: model whose layers are going to be decomposed.
            model_config: model optimization configuration object.
            logger_level: logging level forwarded to the base algorithm.
            dataset: dataset kept (as-is) for the stats collection that runs after the decomposition.
            logger: optional logger forwarded to the base algorithm.

        """
        super().__init__(model, model_config, name="Decompose Conv Input", logger_level=logger_level, logger=logger)
        self._unbatched_dataset = dataset

        # Names of the conv layers selected for decomposition and of the splitter layers created for them.
        self._layers_to_decompose = []
        self._precision_split_layers = []

        # Start numbering new precision_change layers after the ones that already exist in the model.
        self._shortcut_added_index = sum(1 for lname in self._model.layers.keys() if "/precision_change" in lname)

    def _setup(self):
        """
        Delegate the setup to the base algorithm and hand back whatever it returns.
        """
        base_result = super()._setup()
        return base_result

    def should_skip_algo(self):
        self._layers_to_decompose = []
        layer_cfg = self.get_algo_config().layers
        for lname in self._model.flow.toposort():
            layer = self._model.layers[lname]
            if isinstance(layer, BaseHailoConv) and not isinstance(layer, BaseHailoConvAdd):
                policy = layer_cfg.get(lname, LayerConvA16W4Config.get_default()).policy
                if policy == LayerFeaturePolicy.enabled and layer.conv_op.kernel_size[1] == 1:
                    self._layers_to_decompose.append(lname)
        return len(self._layers_to_decompose) == 0

    def get_algo_config(self):
        return self._model_config.conv_a16_w4

    def log_config(self):
        pass

    def _run_int(self):
        for lname in self._layers_to_decompose:
            self._decompose_layer(lname)

        self._finalize_decompose_params()

    def _change_ew_mult_to_a16(self, lname):
        """
        Replace the element-wise-mult layer ``lname`` with a HailoElementwiseMultOnMac layer.

        The replacement is built from the original layer's hn description (with the quantization params
        cleared and ``ew_mult_type`` set to "on_mac"), keeps the original name and weights, and gets its
        own default precision config registered in the model configuration.

        Args:
            lname: full name of the element-wise-mult layer to replace.

        """
        original = self._model.layers[lname]

        # Describe the new layer by tweaking the hn element exported from the original one.
        hn_element = original.to_hn()
        hn_element["quantization_params"] = {}
        hn_element["params"]["ew_mult_type"] = "on_mac"

        replacement = HailoElementwiseMultOnMac.from_hn(original.full_name, hn_element)
        replacement.import_weights(original.export_weights())

        # Register the default precision config of the new layer type and apply it to the layer.
        precision_cfg = replacement.get_default_precision_config()
        self._model_config.precision_config.layers[replacement.full_name] = precision_cfg
        replacement.import_precision_config(precision_cfg, self.optimization_target)

        self._model.replace_layer(replacement, original)

    def _change_output_to_a16(self, edge):
        """
        Make the data flowing over ``edge`` 16-bit.

        If the producer layer supports the "<reduced mode>_a16" precision mode and has exactly one
        successor, the producer itself is switched to that mode. Otherwise a precision_change shortcut
        layer (a8_w8_a16) is inserted on the edge and the producer's precision config is left untouched.

        Args:
            edge: ``(producer_name, consumer_name)`` tuple.

        """
        lname, target = edge
        layer = self._model.layers[lname]

        layer_cfg = self._model_config.precision_config.layers.get(lname, None)
        if layer_cfg is None:
            precision_dict = layer.get_layer_precision_config()
            if precision_dict is not None:
                layer_cfg = LayerPrecisionConfig(**precision_dict)
            else:
                layer_cfg = layer.get_default_precision_config()

        # Compute the candidate mode WITHOUT touching layer_cfg: it may be the very object stored in the
        # model configuration, and it must stay unchanged when we fall back to a shortcut layer.
        a16_mode = PrecisionMode(f"{layer_cfg.precision_mode.reduce().name}_a16")

        can_switch_producer = (
            a16_mode in layer.SUPPORTED_PRECISION_MODE and len(self._model.flow.successors_sorted(lname)) == 1
        )
        if can_switch_producer:
            layer_cfg.precision_mode = a16_mode
            self._model_config.precision_config.layers[lname] = layer_cfg
            layer.import_precision_config(layer_cfg, self.optimization_target)
        else:
            self._add_shortcut_layer(lname, target, PrecisionMode.a8_w8_a16)

    def _decompose_layer(self, lname):
        """
        Rewrite the conv layer ``lname`` as: precision_split -> [shift-high ew_mult] -> conv -> depthwise.

        Steps, in order:
          1. Switch the producers feeding the conv to a 16-bit output. When the first producer is an
             element-wise-mult, it is converted to the "on_mac" variant and its own producers are switched.
          2. Insert a pixel precision splitter in front of the conv; its output doubles axis -2.
          3. If ``shift_high`` is enabled for the layer, insert the const * ew_mult pair after the splitter.
          4. Rescale the conv bias and move the conv to an a8_w4 precision mode.
          5. Append a depthwise layer (kernel 1x2, stride 2 on the doubled axis) after the conv.

        Args:
            lname: full name ("<scope>/<layer>") of the conv layer to decompose.

        """
        # Step 1: make the data that reaches the conv 16-bit.
        input_edges = [(pred, lname) for pred in self._model.flow.predecessors_sorted(lname)]
        first_pred = input_edges[0][0]
        if isinstance(self._model.layers[first_pred], HailoElementwiseMult):
            self._logger.debug(f"change ew_mult {first_pred} to a16 on_mac")
            input_edges = [(pred, first_pred) for pred in self._model.flow.predecessors_sorted(first_pred)]
            self._change_ew_mult_to_a16(first_pred)

        for edge in input_edges:
            self._change_output_to_a16(edge)

        scope, short_name = lname.split("/", 1)
        block_name, base_name = self.get_block_and_layer_names(short_name)

        conv = self._model.layers[lname]
        input_shape = [-1, *[int(s) for s in conv.input_shape[1:]]]
        output_shape = [-1, *[int(s) for s in conv.output_shape[1:]]]
        # The splitter doubles axis -2, so the conv consumes and produces tensors doubled on that axis.
        new_input_shape = input_shape.copy()
        new_input_shape[-2] *= 2
        new_output_shape = output_shape.copy()
        new_output_shape[-2] *= 2

        # Same lookup-with-default that should_skip_algo uses, so a layer without an explicit entry
        # does not raise KeyError here.
        layer_algo_cfg = self.get_algo_config().layers.get(lname, LayerConvA16W4Config.get_default())
        shift_high = layer_algo_cfg.shift_high == LayerFeaturePolicy.enabled

        self._logger.debug(f"set precision_split_zp {lname}")
        conv.precision_split_zp = True

        # Step 2
        ps_name = self._insert_precision_split(
            lname,
            f"{scope}/{block_name}precision_split_{base_name}",
            input_shape,
            new_input_shape,
        )

        # Step 3
        if shift_high:
            self._add_shift_high(
                f"{scope}/{block_name}const_input_{base_name}",
                f"{scope}/{block_name}ew_mult_shift_{base_name}",
                ps_name,
                lname,
                new_input_shape,
            )

        # Step 4
        output_bits = self._switch_conv_to_w4(lname, shift_high)

        # Step 5
        self._append_recombine_depthwise(
            lname,
            f"{scope}/{block_name}dw_{base_name}",
            output_shape,
            new_output_shape,
            shift_high,
            output_bits,
        )

    def _insert_precision_split(self, lname, ps_name, input_shape, new_input_shape):
        """
        Create a pixel precision splitter named ``ps_name`` and insert it in front of ``lname``.

        Returns:
            The full name of the new splitter layer (also recorded in ``self._precision_split_layers``).

        """
        predecessors = self._model.flow.predecessors_sorted(lname)
        ps = HailoPrecisionSplitPixels.from_hn(
            ps_name,
            {
                "type": "precision_splitter",
                "input": list(predecessors),
                "output": [lname],
                "input_shapes": [input_shape],
                "output_shapes": [new_input_shape],
                "original_names": [],
                "compilation_params": {},
                "quantization_params": {},
                "params": {
                    "precision_split_mode": "pixels",
                },
            },
            self._logger,
        )
        ps_cfg = ps.get_default_precision_config()
        self._model_config.precision_config.layers[ps.full_name] = ps_cfg
        ps.import_precision_config(ps_cfg, self.optimization_target)
        # The splitter is placed on the edge coming from the first predecessor only.
        self._model.add_layer(ps, [(predecessors[0], lname)])
        self._precision_split_layers.append(ps.full_name)
        return ps.full_name

    def _switch_conv_to_w4(self, lname, shift_high):
        """
        Rescale the native bias of ``lname`` and switch it to an a8_w4 precision mode.

        Returns:
            The output bit width reported by the layer's precision mode before the switch.

        """
        conv = self._model.layers[lname]

        # we multiply the bias by 256/257 to accommodate :
        # (conv*high+bias_d) + (conv*low+bias_d)/256 = conv*(high+low/256) + bias_d + bias_d/256 = conv*(high+low/256) + bias_d*257/256
        # bias_d = bias*256/257
        # same calculation for shift_high results in :
        # bias_d = bias*256/129
        bias_factor = (256.0 / 129.0) if shift_high else (256.0 / 257.0)
        conv.import_native_bias(conv.export_native_bias() * bias_factor)

        layer_cfg = self._model_config.precision_config.layers.get(lname, None)
        if layer_cfg is None:
            precision_dict = conv.get_layer_precision_config()
            if precision_dict is not None:
                layer_cfg = LayerPrecisionConfig(**precision_dict)
            else:
                layer_cfg = conv.get_default_precision_config()

        # Keep the output width of the layer, only the input/weights part of the mode changes.
        output_bits = layer_cfg.precision_mode.output_bits()
        layer_cfg.precision_mode = PrecisionMode.a8_w4_a16 if output_bits == 16 else PrecisionMode.a8_w4_a8
        self._model_config.precision_config.layers[lname] = layer_cfg
        conv.import_precision_config(layer_cfg, self.optimization_target)
        return output_bits

    def _append_recombine_depthwise(self, lname, dw_name, output_shape, new_output_shape, shift_high, output_bits):
        """
        Append the depthwise layer ``dw_name`` between ``lname`` and all of its successors.

        The kernel has two taps along the doubled axis: ``1/256`` and ``0.5`` (shift_high) or ``1``.
        Equalization is disabled for the new layer and its kernel/output scales are forced.
        """
        successors = self._model.flow.successors_sorted(lname)
        channels = output_shape[-1]

        dw = HailoDepthwise.from_hn(
            dw_name,
            {
                "type": "dw",
                "input": [lname],
                "output": list(successors),
                "input_shapes": [new_output_shape],
                "output_shapes": [output_shape for _ in successors],
                "original_names": [],
                "compilation_params": {},
                "quantization_params": {},
                "params": {
                    "kernel_shape": [1, 2, channels, 1],
                    "strides": [1, 1, 2, 1],
                    "dilations": [1, 1, 1, 1],
                    "padding": "VALID",
                    "groups": 1,
                    "disparity": 1,
                    "input_disparity": 1,
                    "batch_norm": False,
                    "elementwise_add": False,
                    "activation": "linear",
                },
            },
            self._logger,
        )
        dw.import_weights(
            {
                "kernel": np.repeat(np.array([[[[1 / 2**8]], [[0.5 if shift_high else 1]]]]), channels, axis=2),
                "bias": np.zeros((channels,)),
            }
        )

        # force kernel scale to 1/256 to avoid scaling the kernel
        dw.conv_op.kernel_scale_forced_to_save = True
        dw.conv_op.kernel_scale_forced = 1 / 2**8
        dw.forced_output_scale_scalar_dof = 1.0  # Force the output scale to be equal to the input scale
        self._model_config.equalization.layers[dw.full_name] = LayerEqualizationConfig(policy="disabled")

        dw_cfg = dw.get_default_precision_config()
        dw_cfg.precision_mode = PrecisionMode.a16_w16_a16 if output_bits == 16 else PrecisionMode.a8_w8_a8
        self._model_config.precision_config.layers[dw.full_name] = dw_cfg
        dw.import_precision_config(dw_cfg, self.optimization_target)
        self._model.add_layer(dw, [(lname, succ) for succ in successors])

    def _add_shift_high(self, const_name, ew_mult_name, ps_name, conv_name, new_input_shape):
        """
        Insert a constant-input layer and an element-wise-mult layer between the splitter and the conv.

        The constant holds the repeating pattern ``[1.0, 2.0]`` along axis 2 of ``new_input_shape`` and is
        multiplied (on the APU) with the splitter output. The direct edge from the splitter to the conv
        is replaced by ``ps_name -> ew_mult -> conv_name``.

        Args:
            const_name: name for the new constant-input layer.
            ew_mult_name: name for the new element-wise-mult layer.
            ps_name: full name of the precision splitter layer.
            conv_name: full name of the conv layer that consumes the result.
            new_input_shape: shape of the splitter output (entries 1, 2 and 3 are read here).

        """
        dim2 = new_input_shape[2]

        # ---- constant input layer ----
        const_hn = {
            "type": "const_input",
            "input": [],
            "output": [ew_mult_name],
            "input_shapes": [[-1, 1, dim2, 1]],
            "output_shapes": [[-1, 1, dim2, 1]],
            "original_names": [],
            "compilation_params": {},
            "quantization_params": {},
            "params": {
                # Currently compiler has a bug with input_tiles & input_repeats together
                "input_tiles": [[1, 1, 1]],
            },
        }
        const_layer = HailoConst.from_hn(const_name, const_hn)
        one_two = np.array([[[1.0], [2.0]]], dtype=np.float32)
        const_layer.import_weights({"const_data": np.tile(one_two, (1, dim2 // 2, 1))})
        const_full_name = const_layer.full_name

        const_precision = const_layer.get_default_precision_config()
        self._model_config.precision_config.layers[const_full_name] = const_precision
        const_layer.import_precision_config(const_precision, self.optimization_target)
        self._model_config.equalization.layers[const_full_name] = LayerEqualizationConfig(policy="disabled")

        translation_layers = self._model_config.translation_config.layers
        const_translation = translation_layers.get(const_full_name, LayerTranslationConfig.get_default())
        if const_translation.force_range_out is None:
            # Force the quantize data to be [64, 128]
            const_translation.force_range_out = [0.0, 255 / 64]
        translation_layers[const_full_name] = const_translation

        self._model.layers[const_full_name] = const_layer
        self._model.flow.add_node(const_full_name, is_input=False)

        # ---- element-wise mult layer ----
        mult_hn = {
            "type": "ew_mult",
            "input": [ps_name, const_full_name],
            "output": [conv_name],
            "input_shapes": [new_input_shape, [-1, 1, dim2, 1]],
            "output_shapes": [new_input_shape],
            "original_names": [],
            "compilation_params": {},
            "quantization_params": {},
            "params": {
                "activation": "linear",
                "is_softmax_mask": False,
                "ew_mult_type": "on_apu",
                "input_repeats": [[1, 1, 1], [new_input_shape[1], 1, new_input_shape[3]]],
            },
        }
        mult_layer = HailoElementwiseMult.from_hn(ew_mult_name, mult_hn)
        mult_layer.import_weights({})
        # Hand-picked quantization overrides for this layer; values are kept exactly as chosen by the authors.
        mult_layer.mock_kernel_values = [2, 2]
        mult_layer.forced_output_scale_scalar_dof = 64.0
        mult_layer.bias_add_op1.merge_residue_into_bias = False
        mult_layer.forced_shift_zp = 2
        mult_full_name = mult_layer.full_name

        mult_precision = mult_layer.get_default_precision_config()
        self._model_config.precision_config.layers[mult_full_name] = mult_precision
        mult_layer.import_precision_config(mult_precision, self.optimization_target)
        self._model.layers[mult_full_name] = mult_layer

        # Wire splitter -> ew_mult (input 0), const -> ew_mult (input 1), ew_mult -> conv, then drop the old edge.
        flow = self._model.flow
        flow.add_node(mult_full_name, is_input=False)
        flow.add_edge(ps_name, mult_full_name, input_index=0, output_index=0)
        flow.add_edge(const_full_name, mult_full_name, input_index=1, output_index=0)
        flow.add_edge(mult_full_name, conv_name, input_index=0, output_index=0)
        flow.remove_edge(ps_name, conv_name)

    def _finalize_decompose_params(self):
        """
        Finish the decomposition once all layers were rewritten.

        Marks the producer of every splitter with the low-split zero point, runs CreateMixedPrecision on
        the model, collects statistics for the splitter layers only, and finally lets every splitter
        create its splits.
        """
        split_names = self._precision_split_layers

        for split_name in split_names:
            producer_name = self._model.flow.predecessors_sorted(split_name)[0]
            self._model.layers[producer_name].output_split_precision_zp = ZP_LOW_SPLIT_PRECISION_PIXEL

        CreateMixedPrecision(
            model=self._model,
            model_config=self._model_config,
            logger_level=self._logger_level,
            logger=self._logger,
        ).run()

        # Statistics are gathered for the new splitter layers only.
        StatsCollector(
            self._model,
            self._model_config,
            self._logger_level,
            self._unbatched_dataset,
            layers_to_handle=split_names,
            logger=self._logger,
        ).run()

        for split_name in split_names:
            self._model.layers[split_name].create_splits()

    def _add_shortcut_layer(self, source, targets, explicit_mode):
        """
        Insert linear standalone-activation ("precision_change") layer(s) between ``source`` and ``targets``.

        When ``source`` has a single output, one shortcut layer feeds all targets; otherwise one shortcut
        layer is created per target, with an ``_<index>`` suffix in its name. Every new layer gets a
        precision config with ``explicit_mode`` registered in the model configuration. The running
        shortcut counter is advanced once per call.

        Args:
            source: full name ("<scope>/<layer>") of the producer layer.
            targets: a consumer layer name, or a list of consumer layer names.
            explicit_mode: precision mode to register for the new layer(s).

        Returns:
            The name of the last shortcut layer that was created.

        """
        if not isinstance(targets, list):
            targets = [targets]
        shape = list(self._model.layers[source].output_shapes[0])
        shape[0] = -1
        hn = {
            "type": "activation",
            "input": source,
            "output": targets,
            "input_shapes": [shape],
            "output_shapes": [shape],
            "params": {"activation": "linear"},
        }
        # add layer to model
        # Split on the first "/" only (as _decompose_layer does), so a layer part that itself contains
        # "/" does not break the unpacking.
        scope_name, layer_name = source.split("/", 1)
        block_name, _ = self.get_block_and_layer_names(layer_name)
        source_num_outputs = self._model.layers[source].num_outputs
        if source_num_outputs == 1:
            target_groups = [targets]
        else:
            target_groups = [[target] for target in targets]
        for index, target_group in enumerate(target_groups):
            output_index_prefix = f"_{index}" if source_num_outputs > 1 else ""
            shortcut_name = (
                f"{scope_name}/{block_name}precision_change{self._shortcut_added_index}{output_index_prefix}"
            )
            shortcut_layer = HailoStandaloneActivation.from_hn(lname=shortcut_name, hn_element=hn)

            edges = [(source, target) for target in target_group]
            self._model.add_layer(shortcut_layer, edges)
            # add layer to configuration
            precision_cfg = LayerPrecisionConfig(
                precision_mode=explicit_mode,
                bias_mode=BiasMode.single_scale_decomposition,
                quantization_groups=1,
            )
            self._model_config.precision_config.layers[shortcut_name] = precision_cfg
        self._shortcut_added_index += 1

        return shortcut_name

    def finalize_global_cfg(self, algo_config):
        pass

    def _get_valid_layer_cfg(self, lname, cfg):
        return cfg
