from typing import NamedTuple, Optional

import numpy as np
import tensorflow as tf
from tqdm import tqdm

from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_add import HailoElementwiseAdd
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_mult import HailoElementwiseMult
from hailo_model_optimization.acceleras.hailo_layers.hailo_feature_shuffle import HailoFeatureShuffle
from hailo_model_optimization.acceleras.hailo_layers.hailo_io import HailoInputLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_matmul import HailoMatmul
from hailo_model_optimization.acceleras.hailo_layers.hailo_softmax_mars import HailoSoftmaxMars
from hailo_model_optimization.acceleras.hailo_layers.op_factories import gen_acceleras_layers_from_hn
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerEqualizationConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    FormatConversionType,
    OptimizationTarget,
    PrecisionMode,
    SoftmaxBiasOptimizationAlgorithm,
    ThreeWayPolicy,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasResourceError
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm
from hailo_model_optimization.algorithms.smart_softmax_stats.smart_softmax_stats import SmartSoftmaxStats, SoftmaxBlock
from hailo_model_optimization.algorithms.stats_collection.stats_collection import StatsCollector


class RoPEBlock(NamedTuple):
    """
    Names of the layers that make up one RoPE block.

    Every field is a layer name (a key of the model's ``layers`` mapping). The block is matched by
    ``OptimizeSoftmaxBias._get_rope_layers``, walking backwards from the final element-wise add.
    """

    # layer feeding both ew_mult_cos and feature_shuffle; must support changing its bias
    conv: str
    # HailoFeatureShuffle layer whose first predecessor is conv; first input of ew_mult_sin
    feature_shuffle: str
    # HailoElementwiseMult of conv with cos_input; first input of ew_add
    ew_mult_cos: str
    # HailoElementwiseMult of feature_shuffle with sin_input; second input of ew_add
    ew_mult_sin: str
    # HailoInputLayer whose conversion type is FormatConversionType.cos
    cos_input: str
    # HailoInputLayer whose conversion type is FormatConversionType.sin
    sin_input: str
    # HailoElementwiseAdd that produces the block output
    ew_add: str

class OptimizeSoftmaxBias(OptimizationAlgorithm):
    """
    This class is responsible for adding a bias to the k layer (second matmul input) such that the pre-exponent
    activation range will be as small as possible.
    """

    def __init__(self, model, model_config, logger_level, dataset, logger=None):
        """
        Create the algorithm together with the SmartSoftmaxStats helper it relies on.

        Args:
            model: model to optimize; forwarded to the base class and to SmartSoftmaxStats.
            model_config: optimization configuration; forwarded to the base class and to SmartSoftmaxStats.
            logger_level: logging level forwarded to the base class and to SmartSoftmaxStats.
            dataset: unbatched dataset; kept as-is and batched later by the sampling methods.
            logger: optional logger forwarded to the base class and to SmartSoftmaxStats.
        """
        super().__init__(
            model,
            model_config,
            name="Optimize Softmax Bias",
            logger_level=logger_level,
            logger=logger,
        )
        self.smart_softmax_stats = SmartSoftmaxStats(model, model_config, logger_level, logger)
        # per-layer arrays of collected activations, filled by the _infer_model_samples* methods
        self._samples = {}
        self._unbatched_dataset = dataset

    def _setup(self):
        """Set up the SmartSoftmaxStats helper first, then run the base-class setup and return its result."""
        self.smart_softmax_stats._setup()
        base_result = super()._setup()
        return base_result

    def should_skip_algo(self):
        """Return True when either the smart-softmax-stats policy or the ``optimize_bias`` switch is disabled."""
        cfg = self.get_algo_config()
        policy, optimize_bias = cfg.policy, cfg.optimize_bias
        if policy == ThreeWayPolicy.disabled:
            return True
        return optimize_bias == ThreeWayPolicy.disabled

    def get_algo_config(self):
        return self._model_config.smart_softmax_stats

    def log_config(self):
        pass

    def _get_softmax_input_matmul_name(self, lname):
        return self._model.flow.predecessors_sorted(lname)[0]

    def _run_int_mars(self):
        """
        Optimize the bias of every eligible HailoSoftmaxMars layer of the model (MARS flow).

        The softmax layer names are collected into a list before any block is handled. Handling a block may
        register new layers in ``self._model.layers`` (the AC_DC path creates layers through ``_gen_layer``), and
        lazily iterating a container that grows inside the loop is not safe.
        """
        self.softmax_blocks = []
        layers = self._model.layers
        softmax_names = [lname for lname in layers if isinstance(layers[lname], HailoSoftmaxMars)]
        for lname in softmax_names:
            if self._should_skip_block_mars(lname):
                continue
            self._logger.info(f"Optimizing bias for softmax layer {lname}")
            self.add_softmax_bias_mars(lname)

    def add_softmax_bias_mars(self, softmax_name):
        """
        Apply the configured bias optimization to the block that ends at ``softmax_name`` (MARS flow).

        With AC_DC the matmul feeding the softmax is split via ``_split_ac_dc`` and its result is returned.
        Otherwise the bias of the k layer is updated, using the softmax name as the key of the collected
        reduce-max samples.
        """
        matmul_name = self._get_softmax_input_matmul_name(softmax_name)
        algorithm = self.get_algo_config().optimize_bias_algorithm
        if algorithm == SoftmaxBiasOptimizationAlgorithm.AC_DC:
            return self._split_ac_dc(matmul_name)

        k_layer = self._model.layers[self._get_k_lname(matmul_name)]
        self._change_bias(k_layer, matmul_name, softmax_name)

    def _should_skip_block_mars(self, softmax_name):
        """
        Tell whether the block ending at ``softmax_name`` cannot be optimized (MARS flow).

        A block is skipped when the softmax is not fed by a HailoMatmul. For AC_DC it is also skipped when either
        the k or the q input of the matmul is not the output of a RoPE block; for the other algorithms it is
        skipped when the k layer does not support changing its bias.
        """
        matmul_name = self._get_softmax_input_matmul_name(softmax_name)
        if not isinstance(self._model.layers[matmul_name], HailoMatmul):
            return True

        k_lname = self._get_k_lname(matmul_name)
        k_layer = self._model.layers[k_lname]
        algorithm = self.get_algo_config().optimize_bias_algorithm
        if algorithm != SoftmaxBiasOptimizationAlgorithm.AC_DC:
            return not k_layer.is_changing_bias_supported

        q_lname = self._get_q_lname(matmul_name)
        if self._get_rope_layers(k_lname) is None:
            return True
        return self._get_rope_layers(q_lname) is None

    def _run_int(self):
        """Prepare the required data and optimize the bias of every eligible softmax block for the active target."""
        if self.optimization_target == OptimizationTarget.MARS:
            # MARS flow: blocks are discovered directly from the HailoSoftmaxMars layers
            self._prepare_data_mars()
            self._run_int_mars()
            return

        stats = self.smart_softmax_stats
        stats.find_and_build_softmax_blocks()

        # collect whatever the configured algorithm needs, then handle block by block
        self._prepare_data()
        for softmax_block in stats.softmax_blocks:
            if self._should_skip_block(softmax_block):
                continue
            self.add_softmax_bias(softmax_block)

    def _get_k_lname(self, matmul_layer):
        k_lname = self._model.flow.predecessors_sorted(matmul_layer)[1]
        return k_lname

    def _get_q_lname(self, matmul_layer):
        q_lname = self._model.flow.predecessors_sorted(matmul_layer)[0]
        return q_lname

    def _get_rope_layers(self, output_lname: str) -> Optional[RoPEBlock]:
        """
        Match a RoPE block that ends at ``output_lname`` and return the names of its layers.

        The expected pattern, walking backwards from the output:
          * ``output_lname`` is a HailoElementwiseAdd with two HailoElementwiseMult predecessors;
          * the first mult is fed by (conv, cos input), the second by (feature shuffle, sin input);
          * the cos / sin inputs are HailoInputLayer instances with the matching conversion type;
          * the feature shuffle is a HailoFeatureShuffle whose first predecessor is the same conv;
          * the conv supports changing its bias.

        Returns:
            The matched RoPEBlock, or None when any of the conditions above does not hold.
        """
        layers = self._model.layers
        flow = self._model.flow

        ew_add = output_lname
        if not isinstance(layers[ew_add], HailoElementwiseAdd):
            return None

        ew_mult_cos, ew_mult_sin = flow.predecessors_sorted(ew_add)
        both_are_mults = isinstance(layers[ew_mult_cos], HailoElementwiseMult) and isinstance(
            layers[ew_mult_sin], HailoElementwiseMult
        )
        if not both_are_mults:
            return None

        conv, cos_input = flow.predecessors_sorted(ew_mult_cos)
        cos_is_valid = (
            isinstance(layers[cos_input], HailoInputLayer)
            and layers[cos_input].conversion_type == FormatConversionType.cos
        )
        if not cos_is_valid:
            return None

        feature_shuffle, sin_input = flow.predecessors_sorted(ew_mult_sin)
        sin_is_valid = (
            isinstance(layers[sin_input], HailoInputLayer)
            and layers[sin_input].conversion_type == FormatConversionType.sin
        )
        if not sin_is_valid:
            return None

        if not isinstance(layers[feature_shuffle], HailoFeatureShuffle):
            return None

        # both branches must originate from the same conv, and that conv must allow a bias change
        shares_conv = conv == flow.predecessors_sorted(feature_shuffle)[0]
        if not (shares_conv and layers[conv].is_changing_bias_supported):
            return None

        return RoPEBlock(
            conv=conv,
            feature_shuffle=feature_shuffle,
            ew_mult_cos=ew_mult_cos,
            ew_mult_sin=ew_mult_sin,
            cos_input=cos_input,
            sin_input=sin_input,
            ew_add=ew_add,
        )

    def _should_skip_block(self, softmax_block: SoftmaxBlock):
        """
        Tell whether ``softmax_block`` cannot be optimized.

        For AC_DC the block is skipped when either the k or the q input of its matmul is not the output of a RoPE
        block; for the other algorithms it is skipped when the k layer does not support changing its bias.
        """
        matmul_name = softmax_block.matmul
        k_lname = self._get_k_lname(matmul_name)
        k_layer = self._model.layers[k_lname]
        algorithm = self.get_algo_config().optimize_bias_algorithm
        if algorithm != SoftmaxBiasOptimizationAlgorithm.AC_DC:
            return not k_layer.is_changing_bias_supported

        q_lname = self._get_q_lname(matmul_name)
        if self._get_rope_layers(k_lname) is None:
            return True
        return self._get_rope_layers(q_lname) is None

    def _prepare_data(self):
        """
        Collect the data the configured algorithm needs before the blocks are handled.

        ZERO_MEAN collects statistics of the k layers, MSE records model samples, AC_DC needs nothing.

        Raises:
            ValueError: if the configured algorithm is none of the above.
        """
        algorithm = self.get_algo_config().optimize_bias_algorithm
        if algorithm == SoftmaxBiasOptimizationAlgorithm.ZERO_MEAN:
            self._collect_k_stats()
            return
        if algorithm == SoftmaxBiasOptimizationAlgorithm.MSE:
            self._infer_model_samples()
            return
        if algorithm == SoftmaxBiasOptimizationAlgorithm.AC_DC:
            return
        raise ValueError(f"Unknown SoftmaxBiasOptimizationAlgorithm {algorithm}")

    def _prepare_data_mars(self):
        """
        Collect the data the configured algorithm needs before the blocks are handled (MARS flow).

        ZERO_MEAN collects statistics of the k layers, MSE records model samples, AC_DC needs nothing.

        Raises:
            ValueError: if the configured algorithm is none of the above.
        """
        algorithm = self.get_algo_config().optimize_bias_algorithm
        if algorithm == SoftmaxBiasOptimizationAlgorithm.ZERO_MEAN:
            self._collect_k_stats_mars()
            return
        if algorithm == SoftmaxBiasOptimizationAlgorithm.MSE:
            self._infer_model_samples_mars()
            return
        if algorithm == SoftmaxBiasOptimizationAlgorithm.AC_DC:
            return
        raise ValueError(f"Unknown SoftmaxBiasOptimizationAlgorithm {algorithm}")

    def _get_softmax_from_matmul(self, matmul_layer):
        layer_to_return = None
        for lname in self._model.flow.successors_sorted(matmul_layer):
            if isinstance(self._model.layers[lname], HailoSoftmaxMars):
                layer_to_return = lname
        if layer_to_return is None:
            raise ValueError(f"Could not find softmax layer for matmul layer {matmul_layer}")
        return layer_to_return

    def _reduce_max_calculation(self, matmul_layer, value_to_save, save_internal_list, internal_outputs):
        # the lname that is saved is the matmtul op. we need to get the softmax layer name for some things.
        softmax_name = self._get_softmax_from_matmul(matmul_layer)  # get the softmax layer name
        predeccesors = self._model.flow.predecessors_sorted(softmax_name)
        if len(predeccesors) == 2:
            mask_n = predeccesors[1]
            mask_vals = internal_outputs[save_internal_list.index(mask_n)]
            value_to_save = np.where(mask_vals == 0, -np.inf, value_to_save)

        groups = self._model.layers[softmax_name].groups
        shape = value_to_save.shape
        value_to_save = np.max(value_to_save.reshape(*shape[:-1], groups, -1), axis=-1)
        matmul_layer = softmax_name
        return matmul_layer, value_to_save

    def _infer_model_samples_mars(self):
        """
        Run the model on ``cfg.sample_size`` samples and store the tensors the MSE algorithm needs (MARS flow).

        For every softmax block that is not skipped, ``self._samples`` receives:
          * the output of the q layer (first predecessor of the matmul feeding the softmax), keyed by the q layer
            name;
          * the per-group maximum of the matmul output (see ``_reduce_max_calculation``), keyed by the softmax
            layer name.

        The model's internal-output configuration is always restored and the progress bar is always closed, also
        when inference fails.

        Raises:
            AccelerasResourceError: if TensorFlow reports resource exhaustion during inference.
        """
        layers = self._model.layers
        flow = self._model.flow

        layers_to_collect = {}
        calc_reduce_max = set()
        masking = set()
        for lname in layers:
            if not isinstance(layers[lname], HailoSoftmaxMars) or self._should_skip_block_mars(lname):
                continue
            matmul_layer_n = self._get_softmax_input_matmul_name(lname)
            predecessors = flow.predecessors_sorted(lname)
            q_lname = flow.predecessors_sorted(matmul_layer_n)[0]
            out_index = flow.get_edge_output_index(q_lname, matmul_layer_n)
            out_index = layers[q_lname].resolve_output_index(out_index)
            layers_to_collect[q_lname] = out_index

            # add all layers that are needed for the reduce_max calculation
            layers_to_collect[matmul_layer_n] = 0
            calc_reduce_max.add(matmul_layer_n)
            if len(predecessors) == 2:
                mask_n = predecessors[1]
                layers_to_collect[mask_n] = 0
                masking.add(mask_n)

        if not layers_to_collect:
            return

        cfg = self.get_algo_config()
        dataset = self._unbatched_dataset.take(cfg.sample_size).batch(1).map(lambda image, info: image)

        # compute the output shapes with the sample count as the leading dimension, so they can be used to
        # allocate the per-layer sample buffers below
        shapes = [(cfg.sample_size,) + shape for shape in self._model.get_input_shapes()]
        self._model.compute_output_shape(shapes)

        for lname, output_index in layers_to_collect.items():
            if lname in calc_reduce_max:
                # matmul outputs are stored already reduced, under the name of the following softmax
                softmax_name = self._get_softmax_from_matmul(lname)
                groups = layers[softmax_name].groups
                output_shape = [*layers[softmax_name].output_shapes[output_index][:-1], groups]
                lname = softmax_name
            else:
                output_shape = layers[lname].output_shapes[output_index]
            # NOTE(review): np.empty leaves rows uninitialized if the dataset yields fewer than sample_size entries
            self._samples[lname] = np.empty(output_shape)

        save_internal_list = list(layers_to_collect)
        self._model.set_output_interal_layers(save_internal_list)
        try:
            with tqdm(
                total=cfg.sample_size, dynamic_ncols=True, unit="entries", desc="Softmax Bias Optimization"
            ) as pbar:
                for i, preprocessed_data in enumerate(dataset):
                    _, internal_outputs = self._model.predict_on_batch(preprocessed_data)
                    for lname, value in zip(save_internal_list, internal_outputs):
                        if lname in masking:
                            # this is the mask layer. we need only its tensor to calculate the reduce_max
                            continue
                        if layers[lname].num_outputs == 1:
                            value_to_save = value
                        else:
                            value_to_save = value[layers_to_collect[lname]]
                        if lname in calc_reduce_max:
                            lname, value_to_save = self._reduce_max_calculation(
                                lname, value_to_save, save_internal_list, internal_outputs
                            )
                        self._samples[lname][i] = value_to_save

                    pbar.update(1)
                pbar.refresh()  # flush the final pbar state
        except tf.errors.ResourceExhaustedError as err:
            raise AccelerasResourceError(
                "GPU memory has been exhausted while optimizing softmax bias. Please try useing diffrent optimize bias "
                "algorithm (optimize_bias_algorithm=ZERO_MEAN), disableing this algorithm (optimize_bias=disabled), or "
                "runing on CPU."
            ) from err
        finally:
            # never leave the model configured to emit internal outputs, even when inference failed
            self._model.reset_output_interal_layers()

    def _collect_k_stats_mars(self):
        """Run a StatsCollector restricted to the k layers of all softmax blocks that are not skipped (MARS flow)."""
        layers = self._model.layers
        k_layers = {
            self._get_k_lname(self._get_softmax_input_matmul_name(lname))
            for lname in layers
            if isinstance(layers[lname], HailoSoftmaxMars) and not self._should_skip_block_mars(lname)
        }
        if not k_layers:
            return

        StatsCollector(
            self._model,
            self._model_config,
            self._logger_level,
            self._unbatched_dataset,
            layers_to_handle=k_layers,
            logger=self._logger,
        ).run()

    def _collect_k_stats(self):
        k_layers = set()
        for softmax_block in self.smart_softmax_stats.softmax_blocks:
            if not self._should_skip_block(softmax_block):
                k_layers.add(self._get_k_lname(softmax_block.matmul))

        if len(k_layers) > 0:
            stats_collector = StatsCollector(
                self._model,
                self._model_config,
                self._logger_level,
                self._unbatched_dataset,
                layers_to_handle=k_layers,
                logger=self._logger,
            )
            stats_collector.run()

    def _infer_model_samples(self):
        """
        Run the model on ``cfg.sample_size`` samples and store the tensors the MSE algorithm needs.

        For every softmax block that is not skipped, ``self._samples`` receives the output of the q layer (first
        predecessor of the block's matmul) and the output of the block's reduce_max layer, each keyed by its layer
        name.

        The model's internal-output configuration is always restored and the progress bar is always closed, also
        when inference fails.

        Raises:
            AccelerasResourceError: if TensorFlow reports resource exhaustion during inference.
        """
        layers = self._model.layers
        flow = self._model.flow

        layers_to_collect = {}
        for softmax_block in self.smart_softmax_stats.softmax_blocks:
            if self._should_skip_block(softmax_block):
                continue
            q_lname = flow.predecessors_sorted(softmax_block.matmul)[0]
            out_index = flow.get_edge_output_index(q_lname, softmax_block.matmul)
            out_index = layers[q_lname].resolve_output_index(out_index)
            layers_to_collect[q_lname] = out_index
            layers_to_collect[softmax_block.reduce_max] = 0

        if not layers_to_collect:
            return

        cfg = self.get_algo_config()
        dataset = self._unbatched_dataset.take(cfg.sample_size).batch(1).map(lambda image, info: image)

        # compute the output shapes with the sample count as the leading dimension, so they can be used to
        # allocate the per-layer sample buffers below
        shapes = [(cfg.sample_size,) + shape for shape in self._model.get_input_shapes()]
        self._model.compute_output_shape(shapes)

        for lname, output_index in layers_to_collect.items():
            # NOTE(review): np.empty leaves rows uninitialized if the dataset yields fewer than sample_size entries
            self._samples[lname] = np.empty(layers[lname].output_shapes[output_index])

        save_internal_list = list(layers_to_collect)
        self._model.set_output_interal_layers(save_internal_list)
        try:
            with tqdm(
                total=cfg.sample_size, dynamic_ncols=True, unit="entries", desc="Softmax Bias Optimization"
            ) as pbar:
                for i, preprocessed_data in enumerate(dataset):
                    _, internal_outputs = self._model.predict_on_batch(preprocessed_data)
                    for lname, value in zip(save_internal_list, internal_outputs):
                        if layers[lname].num_outputs == 1:
                            self._samples[lname][i] = value
                        else:
                            self._samples[lname][i] = value[layers_to_collect[lname]]
                    pbar.update(1)
                pbar.refresh()  # flush the final pbar state
        except tf.errors.ResourceExhaustedError as err:
            raise AccelerasResourceError(
                "GPU memory has been exhausted while optimizing softmax bias. Please try useing diffrent optimize bias "
                "algorithm (optimize_bias_algorithm=ZERO_MEAN), disableing this algorithm (optimize_bias=disabled), or "
                "runing on CPU."
            ) from err
        finally:
            # never leave the model configured to emit internal outputs, even when inference failed
            self._model.reset_output_interal_layers()

    def _get_optimal_bias(self, q_output, reduce_max_output, force_zero_centered=True, version=0):
        """
        Compute optimal bias to minimize reduce_max output range.

        version is an argument used for development purposes to test different algorithms, and should not be used in
        production.
        """
        if version == 0:
            q_inv = np.linalg.pinv(q_output)

            if force_zero_centered:
                post_matmul_bias = np.zeros((q_output.shape[0], 1, 1))
            else:
                # Compute the orthogonal projection matrix onto the kernel of q.
                proj = np.expand_dims(np.eye(q_output.shape[1]), 0) - q_output @ q_inv  # (groups, samples, samples)
                proj_sum_inv = np.linalg.pinv(np.sum(proj, axis=-1, keepdims=True))  # (groups, 1, samples)

                post_matmul_bias = proj_sum_inv @ proj @ reduce_max_output  # (groups, 1, 1)

            bias_change = q_inv @ (post_matmul_bias - reduce_max_output)  # (groups, features, 1)
        elif version == 1:
            p = np.linalg.pinv(
                q_output.transpose(0, 2, 1)
                @ np.expand_dims(
                    np.eye(q_output.shape[1]) - np.ones((q_output.shape[1], q_output.shape[1])) / q_output.shape[1], 0
                )
                @ q_output
            )
            q = (
                -q_output.transpose(0, 2, 1)
                @ np.expand_dims(
                    np.eye(q_output.shape[1]) - np.ones((q_output.shape[1], q_output.shape[1])) / q_output.shape[1], 0
                )
                @ reduce_max_output
            )
            bias_change = p @ q
            post_matmul_bias = np.mean(q_output @ bias_change + reduce_max_output, axis=1, keepdims=True)
        elif version == 2:
            p = np.linalg.pinv(
                np.expand_dims(
                    np.eye(q_output.shape[1]) - np.ones((q_output.shape[1], q_output.shape[1])) / q_output.shape[1], 0
                )
                @ q_output
            )
            q = (
                -np.expand_dims(
                    np.eye(q_output.shape[1]) - np.ones((q_output.shape[1], q_output.shape[1])) / q_output.shape[1], 0
                )
                @ reduce_max_output
            )
            bias_change = p @ q
            post_matmul_bias = np.mean(q_output @ bias_change + reduce_max_output, axis=1, keepdims=True)
        elif version == 3:
            p = np.linalg.pinv(q_output)
            q = -reduce_max_output
            bias_change = p @ q
            post_matmul_bias = np.mean(q_output @ bias_change + reduce_max_output, axis=1, keepdims=True)
        else:
            raise ValueError(f"Unknown version {version} for optimize bias MSE algorithm.")

        return bias_change, post_matmul_bias

    def _loss(self, q_output, reduce_max_output, bias_change, post_matmul_bias):
        return np.mean(np.square(q_output @ bias_change + reduce_max_output - post_matmul_bias))

    def add_softmax_bias(self, softmax_block: SoftmaxBlock):
        """
        Apply the configured bias optimization to ``softmax_block``.

        With AC_DC the block's matmul is split via ``_split_ac_dc`` and its result is returned. Otherwise the bias
        of the k layer is updated using the samples of the block's reduce_max layer.
        """
        matmul_name = softmax_block.matmul
        algorithm = self.get_algo_config().optimize_bias_algorithm
        if algorithm == SoftmaxBiasOptimizationAlgorithm.AC_DC:
            return self._split_ac_dc(matmul_name)

        k_layer = self._model.layers[self._get_k_lname(matmul_name)]
        self._change_bias(k_layer, matmul_name, softmax_block.reduce_max)

    def _change_bias(self, k_layer, matmul_layer_n, reduce_max_n):
        """
        Recompute the native bias of ``k_layer`` with the configured algorithm and import it back into the layer.

        Args:
            k_layer: layer object whose bias is read with ``export_native_bias`` and overwritten with
                ``import_native_bias``.
            matmul_layer_n: name of the matmul layer. Its ``input_tiles``, ``groups`` and ``input_shapes`` drive the
                reshapes of the MSE branch, and its first predecessor is used as the q layer.
            reduce_max_n: key into ``self._samples`` holding the collected reduce-max values (MSE branch only).

        Raises:
            ValueError: if the configured algorithm is neither ZERO_MEAN nor MSE.
        """
        original_bias = k_layer.export_native_bias()

        cfg = self.get_algo_config()
        if cfg.optimize_bias_algorithm == SoftmaxBiasOptimizationAlgorithm.ZERO_MEAN:
            # subtract the mean taken from the first entry of the layer's collected output statistics
            bias = original_bias - k_layer.get_output_stats()[0].mean
        elif cfg.optimize_bias_algorithm == SoftmaxBiasOptimizationAlgorithm.MSE:
            # Compute optimal bias to minimize reduce_max output range.
            # That is, find bias* = argmin {||q_output @ bias + reduce_max_output - post_matmul_bias||^2}
            # tile factors of the two matmul inputs (last entry of each input's tiles)
            q_tile = self._model.layers[matmul_layer_n].input_tiles[0][-1]
            k_tile = self._model.layers[matmul_layer_n].input_tiles[1][-1]
            groups = self._model.layers[matmul_layer_n].groups
            # number of q channels per (untiled) q group
            features = self._model.layers[matmul_layer_n].input_shapes[0][-1] // (groups // q_tile)

            q_lname = self._model.flow.predecessors_sorted(matmul_layer_n)[0]

            # q_output (groups, samples, features)
            # NOTE(review): the leading axis produced here is ((groups // q_tile) // k_tile) * q_tile, which equals
            # `groups` only for particular tile combinations - confirm against the supported configurations.
            q_output = np.repeat(
                self._samples[q_lname]
                .reshape(-1, (groups // q_tile) // k_tile, k_tile, features)
                .transpose(1, 0, 2, 3)
                .reshape((groups // q_tile) // k_tile, -1, features),
                q_tile,
                axis=0,
            )
            # reduce_max_output (groups, samples, 1)
            # NOTE(review): the leading axis produced here is groups // k_tile.
            reduce_max_output = (
                self._samples[reduce_max_n]
                .reshape(-1, groups // k_tile, k_tile, 1)
                .transpose(1, 0, 2, 3)
                .reshape(groups // k_tile, -1, 1)
            )

            bias_change, post_matmul_bias = self._get_optimal_bias(q_output, reduce_max_output, cfg.force_zero_centered)

            self._logger.debug(
                f"Softmax bias optimization Loss: {self._loss(q_output, reduce_max_output, bias_change, post_matmul_bias)}"
            )

            # flatten the per-group bias change so it can be added to the layer's flat bias vector
            bias = original_bias + bias_change.reshape(-1)
        else:
            raise ValueError(f"Unknown SoftmaxBiasOptimizationAlgorithm {cfg.optimize_bias_algorithm}")

        k_layer.import_native_bias(bias)

    def _get_name(self, lname, unique_name):
        scope, short_name = lname.split("/", 1)
        block_name, base_name = self.get_block_and_layer_names(short_name)
        return f"{scope}/{block_name}{unique_name}_{base_name}"

    def _duplicate_rope(self, rope_block: RoPEBlock, unique_name):
        """Return a RoPEBlock in which every layer name of ``rope_block`` is mapped through ``_get_name``."""
        renamed = {
            field: self._get_name(getattr(rope_block, field), unique_name) for field in RoPEBlock._fields
        }
        return RoPEBlock(**renamed)

    def _get_post_rope_k_dc(self, k_dc, k_token_counts, k_rope, num_of_channels, k_groups, groups, version=0):
        """
        Compute and optimize the post rope dc component of k.

        version is an argument used for development purposes to test different algorithms, and should not be used in
        production.
        """
        k_position_ids = np.arange(-k_token_counts, 0).reshape(-1, 1, 1, 1) + 1

        k_cos_input = self._model.layers[k_rope.cos_input]
        k_sin_input = self._model.layers[k_rope.sin_input]
        k_cos_theta = k_cos_input.conversion_weights.theta.reshape(1, 1, 2, -1)[..., -num_of_channels:]
        k_sin_theta = k_sin_input.conversion_weights.theta.reshape(1, 1, 2, -1)[..., -num_of_channels:]

        # Apply RoPE with potisions 1-k_token_counts to 0
        post_rope_k_dc = k_dc * np.cos(k_position_ids * k_cos_theta) + (
            k_dc[..., ::-1, :] * np.sin(k_position_ids * k_sin_theta)
        )

        # optimize the post_rope_k_dc
        if version == 0:  # subtruct mean
            post_rope_k_dc -= np.mean(post_rope_k_dc, axis=0, keepdims=True)
        elif version == 1:  # subtruct last
            post_rope_k_dc -= post_rope_k_dc[-1:]
        elif version == 2:  # subtruct midpoint
            post_rope_k_dc -= (
                np.min(post_rope_k_dc, axis=0, keepdims=True) + np.max(post_rope_k_dc, axis=0, keepdims=True)
            ) / 2
        else:
            pass

        # Repeat and reshape to match the expected weights shape
        post_rope_k_dc = np.repeat(post_rope_k_dc.reshape(k_token_counts, k_groups, -1), groups // k_groups, axis=1)
        post_rope_k_dc = post_rope_k_dc.transpose(2, 1, 0).reshape(1, 1, 2 * num_of_channels, -1)
        return post_rope_k_dc

    def _get_const_data_rope_q(self, q_token_counts, q_rope, num_of_channels):
        q_position_ids = np.arange(-q_token_counts, 0).reshape(-1, 1, 1, 1) + 1

        q_cos_input = self._model.layers[q_rope.cos_input]
        q_sin_input = self._model.layers[q_rope.sin_input]
        q_cos_theta = q_cos_input.conversion_weights.theta.reshape(1, 1, 2, -1)[..., -num_of_channels:]
        q_sin_theta = q_sin_input.conversion_weights.theta.reshape(1, 1, 2, -1)[..., -num_of_channels:]

        q_cos_const_data = np.cos(q_position_ids * q_cos_theta).reshape(1, q_token_counts, -1)
        q_sin_const_data = np.sin(q_position_ids * q_sin_theta).reshape(1, q_token_counts, -1)
        return q_cos_const_data, q_sin_const_data

    def _gen_layer(
        self, lname, hn_layer, weights=None, precision_mode=PrecisionMode.a8_w8_a8, disable_equalization=False
    ):
        """
        Create a new acceleras layer from an HN description and register it in the model and its configuration.

        Args:
            lname: name of the new layer.
            hn_layer: HN dictionary describing the layer.
            weights: weights imported into the new layer; ``None`` (default) imports an empty dict. A fresh dict is
                created on every call so that no dict object is shared between calls.
            precision_mode: precision mode written into the layer's default precision config.
            disable_equalization: when True, an equalization config with policy "disabled" is registered for the
                layer.
        """
        if weights is None:
            weights = {}

        layer = gen_acceleras_layers_from_hn(lname, hn_layer, self.optimization_target, logger=self._logger)[lname]
        layer.import_weights(weights)

        # start from the layer's default precision config and override only the precision mode
        layer_cfg = layer.get_default_precision_config()
        layer_cfg.precision_mode = precision_mode
        self._model_config.precision_config.layers[lname] = layer_cfg
        layer.import_precision_config(layer_cfg, self.optimization_target)

        # register the layer in the model; edges are not added here
        self._model.layers[lname] = layer
        self._model.flow.add_node(lname, is_input=False)
        if disable_equalization:
            self._model_config.equalization.layers[lname] = LayerEqualizationConfig(policy="disabled")
    def _split_ac_dc(self, matmul_layer_n):
        num_of_channels = self.get_algo_config().dc_channels
        k_rope = self._get_rope_layers(self._get_k_lname(matmul_layer_n))
        q_rope = self._get_rope_layers(self._get_q_lname(matmul_layer_n))

        matmul_layer = self._model.layers[matmul_layer_n]

        # set the new layers names
        conv_name = self._get_name(matmul_layer_n, "dc")
        const_q_rope = self._duplicate_rope(q_rope, "dc")
        ew_add_name = self._get_name(matmul_layer_n, "ew_add_ac_dc")

        output_16bit = (
            matmul_layer.get_precision_mode().has_output_bits()
            and matmul_layer.get_precision_mode().output_bits() == 16
        )

        # init split parameters
        groups = matmul_layer.groups
        k_groups = groups // matmul_layer.input_tiles[1][-1]
        q_groups = groups // matmul_layer.input_tiles[0][-1]
        if q_groups != groups:
            raise ValueError(
                f"AC_DC optimization is not supported in layer {matmul_layer_n} where number of groups != q_groups ({groups} != {q_groups})"
            )

        k_bias = self._model.layers[k_rope.conv].export_native_bias().reshape(1, k_groups, 2, -1)
        k_dc = k_bias[..., -num_of_channels:].copy()
        k_bias[..., -num_of_channels:] = 0
        self._model.layers[k_rope.conv].import_native_bias(k_bias.reshape(-1))

        k_token_counts = matmul_layer.input_shapes[1][-2]
        q_token_counts = matmul_layer.input_shapes[0][-2]
        q_head_channels = self._model.layers[q_rope.conv].output_shapes[0][-1] // q_groups

        post_rope_k_dc = self._get_post_rope_k_dc(k_dc, k_token_counts, k_rope, num_of_channels, k_groups, groups)
        q_cos_const_data, q_sin_const_data = self._get_const_data_rope_q(q_token_counts, q_rope, num_of_channels)

        # create the new layers
        self._gen_layer(
            const_q_rope.cos_input,
            {
                "type": "const_input",
                "input": [],
                "output": [const_q_rope.ew_mult_cos],
                "input_shapes": [[-1, 1, q_token_counts, q_cos_const_data.shape[-1]]],
                "output_shapes": [[-1, 1, q_token_counts, q_cos_const_data.shape[-1] * q_groups]],
                "original_names": [],
                "compilation_params": {},
                "quantization_params": {},
                "params": {
                    "input_tiles": [[1, 1, q_groups]],
                },
            },
            weights={"const_data": q_cos_const_data},
        )
        self._gen_layer(
            const_q_rope.sin_input,
            {
                "type": "const_input",
                "input": [],
                "output": [const_q_rope.ew_mult_sin],
                "input_shapes": [[-1, 1, q_token_counts, q_sin_const_data.shape[-1]]],
                "output_shapes": [[-1, 1, q_token_counts, q_sin_const_data.shape[-1] * q_groups]],
                "original_names": [],
                "compilation_params": {},
                "quantization_params": {},
                "params": {
                    "input_tiles": [[1, 1, q_groups]],
                },
            },
            weights={"const_data": q_sin_const_data},
        )

        self._gen_layer(
            const_q_rope.conv,
            {
                "type": "slice",
                "input": [q_rope.conv],
                "output": [const_q_rope.ew_mult_cos, const_q_rope.feature_shuffle],
                "input_shapes": [[-1, 1, q_token_counts, q_head_channels * q_groups]],
                "output_shapes": [
                    [-1, 1, q_token_counts, 2 * num_of_channels * q_groups],
                    [-1, 1, q_token_counts, 2 * num_of_channels * q_groups],
                ],
                "compilation_params": {},
                "quantization_params": {},
                "params": {
                    "height_slice": [0, 1, 1],
                    "width_slice": [0, q_token_counts, 1],
                    "features_slice": [q_head_channels // 2 - num_of_channels, q_head_channels // 2, 1],
                    "groups": 2 * q_groups,
                },
            },
        )

        self._gen_layer(
            const_q_rope.feature_shuffle,
            {
                "type": "feature_shuffle",
                "input": [const_q_rope.conv],
                "output": [const_q_rope.ew_mult_sin],
                "input_shapes": [[-1, 1, q_token_counts, 2 * num_of_channels * q_groups]],
                "output_shapes": [[-1, 1, q_token_counts, 2 * num_of_channels * q_groups]],
                "original_names": [],
                "compilation_params": {},
                "quantization_params": {},
                "params": {"groups": q_groups, "groups_slice": [num_of_channels, 2 * num_of_channels, 1]},
            },
        )

        self._gen_layer(
            const_q_rope.ew_mult_cos,
            {
                "type": "ew_mult",
                "input": [const_q_rope.conv, const_q_rope.cos_input],
                "output": [const_q_rope.ew_add],
                "input_shapes": [
                    [-1, 1, q_token_counts, 2 * num_of_channels * q_groups],
                    [-1, 1, q_token_counts, 2 * num_of_channels * q_groups],
                ],
                "output_shapes": [[-1, 1, q_token_counts, 2 * num_of_channels * q_groups]],
                "original_names": [],
                "compilation_params": {},
                "quantization_params": {},
                "params": {"activation": "linear", "is_softmax_mask": False, "ew_mult_type": "on_apu"},
            },
        )
        self._gen_layer(
            const_q_rope.ew_mult_sin,
            {
                "type": "ew_mult",
                "input": [const_q_rope.feature_shuffle, const_q_rope.sin_input],
                "output": [const_q_rope.ew_add],
                "input_shapes": [
                    [-1, 1, q_token_counts, 2 * num_of_channels * q_groups],
                    [-1, 1, q_token_counts, 2 * num_of_channels * q_groups],
                ],
                "output_shapes": [[-1, 1, q_token_counts, 2 * num_of_channels * q_groups]],
                "original_names": [],
                "compilation_params": {},
                "quantization_params": {},
                "params": {"activation": "linear", "is_softmax_mask": False, "ew_mult_type": "on_apu"},
            },
        )

        self._gen_layer(
            const_q_rope.ew_add,
            {
                "type": "ew_add",
                "input": [const_q_rope.ew_mult_cos, const_q_rope.ew_mult_sin],
                "output": [conv_name],
                "input_shapes": [
                    [-1, 1, q_token_counts, 2 * num_of_channels * q_groups],
                    [-1, 1, q_token_counts, 2 * num_of_channels * q_groups],
                ],
                "output_shapes": [[-1, 1, q_token_counts, 2 * num_of_channels * q_groups]],
                "original_names": [],
                "compilation_params": {},
                "quantization_params": {},
                "params": {"activation": "linear"},
            },
        )

        self._gen_layer(
            conv_name,
            {
                "type": "conv",
                "input": [const_q_rope.ew_add],
                "output": [ew_add_name],
                "input_shapes": [[-1, 1, q_token_counts, 2 * num_of_channels * q_groups]],
                "output_shapes": [[-1, 1, q_token_counts, k_token_counts * groups]],
                "original_names": [],
                "compilation_params": {},
                "quantization_params": {},
                "params": {
                    "kernel_shape": [1, 1, 2 * num_of_channels, k_token_counts * groups],
                    "strides": [1, 1, 1, 1],
                    "dilations": [1, 1, 1, 1],
                    "padding": "VALID",
                    "groups": groups,
                    "layer_disparity": 1,
                    "input_disparity": 1,
                    "batch_norm": False,
                    "elementwise_add": False,
                    "activation": "linear",
                },
            },
            weights={"kernel": post_rope_k_dc, "bias": np.zeros(k_token_counts * groups)},
            precision_mode=PrecisionMode.a8_w8_a16 if output_16bit else PrecisionMode.a8_w8_a8,
            disable_equalization=True,
        )

        self._gen_layer(
            ew_add_name,
            {
                "type": "ew_add",
                "input": [matmul_layer_n, conv_name],
                "output": self._model.flow.successors_sorted(matmul_layer_n),
                "input_shapes": [
                    [-1, 1, q_token_counts, k_token_counts * groups],
                    [-1, 1, q_token_counts, k_token_counts * groups],
                ],
                "output_shapes": [
                    [-1, 1, q_token_counts, k_token_counts * groups]
                    for _ in self._model.flow.successors_sorted(matmul_layer_n)
                ],
                "original_names": [],
                "compilation_params": {},
                "quantization_params": {},
                "params": {"activation": "linear"},
            },
            precision_mode=PrecisionMode.a16_w16_a16 if output_16bit else PrecisionMode.a8_w8_a8,
        )

        # update the flow
        self._model.flow.add_edge(
            q_rope.conv,
            const_q_rope.conv,
            input_index=0,
            output_index=len(self._model.flow.successors_sorted(q_rope.conv)),
        )
        self._model.flow.add_edge(const_q_rope.conv, const_q_rope.ew_mult_cos, input_index=0, output_index=0)
        self._model.flow.add_edge(const_q_rope.conv, const_q_rope.feature_shuffle, input_index=0, output_index=1)
        self._model.flow.add_edge(const_q_rope.cos_input, const_q_rope.ew_mult_cos, input_index=1, output_index=0)
        self._model.flow.add_edge(const_q_rope.feature_shuffle, const_q_rope.ew_mult_sin, input_index=0, output_index=0)
        self._model.flow.add_edge(const_q_rope.sin_input, const_q_rope.ew_mult_sin, input_index=1, output_index=0)
        self._model.flow.add_edge(const_q_rope.ew_mult_cos, const_q_rope.ew_add, input_index=0, output_index=0)
        self._model.flow.add_edge(const_q_rope.ew_mult_sin, const_q_rope.ew_add, input_index=1, output_index=0)
        self._model.flow.add_edge(const_q_rope.ew_add, conv_name, input_index=0, output_index=0)
        self._model.flow.add_edge(conv_name, ew_add_name, input_index=1, output_index=0)
        output_names = list(self._model.flow.successors_sorted(matmul_layer_n))
        for outp in output_names:
            succ_input_index = self._model.flow.get_edge_input_index(matmul_layer_n, outp)
            pred_output_index = self._model.flow.get_edge_output_index(matmul_layer_n, outp)
            self._model.flow.add_edge(ew_add_name, outp, input_index=succ_input_index, output_index=pred_output_index)
            self._model.flow.remove_edge(matmul_layer_n, outp)
        self._model.flow.add_edge(matmul_layer_n, ew_add_name, input_index=0, output_index=0)

        # make sure the model will be recompiled
        self._model.built = False
        self._model.predict_function = None
        self._model.train_function = None
        self._model.test_function = None

    def finalize_global_cfg(self, algo_config):
        """
        Finalize the global configuration of this algorithm.

        Delegates to ``check_dataset_length``, handing it the algorithm config, the
        ``"sample_size"`` key and the unbatched dataset held by this instance.
        NOTE(review): ``check_dataset_length`` is defined outside this chunk; presumably it
        validates the configured sample size against the dataset length - confirm.

        Args:
            algo_config: the algorithm configuration object, forwarded as-is to the check.

        """
        size_field = "sample_size"
        unbatched_data = self._unbatched_dataset
        self.check_dataset_length(algo_config, size_field, unbatched_data)
