from abc import abstractmethod
from typing import List, Tuple, TypeVar

import numpy as np
from pydantic import BaseModel

from hailo_model_optimization.acceleras.atomic_ops.base_non_arithmetic_op import BaseNonArithmeticAtomicOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_single_atomic import BaseHailoSingleAtomic
from hailo_model_optimization.acceleras.hailo_layers.hailo_concat import HailoConcat
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv import HailoConv
from hailo_model_optimization.acceleras.hailo_layers.hailo_depthwise import HailoDepthwise
from hailo_model_optimization.acceleras.hailo_layers.hailo_matmul import HailoMatmul
from hailo_model_optimization.acceleras.hailo_layers.hailo_precision_split import HailoPrecisionSplit
from hailo_model_optimization.acceleras.hailo_layers.hailo_precision_split_signed import HailoPrecisionSplitSigned
from hailo_model_optimization.acceleras.hailo_layers.hailo_reduce_sum import HailoReduceSum
from hailo_model_optimization.acceleras.hailo_layers.hailo_shortcut import HailoShortcut
from hailo_model_optimization.acceleras.hailo_layers.hailo_standalone_activation import HailoStandaloneActivation
from hailo_model_optimization.acceleras.model.hailo_model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model.hailo_model.model_flow import ModelFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerZeroStaticChannelsConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ZP_FEED_REPEAT,
    BiasMode,
    DataPath,
    LayerFeaturePolicy,
    OptimizationTarget,
    PrecisionMode,
    StatsState,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasImplementationError
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class MMCBlockNames(BaseModel):
    """Full layer names of the layers inserted by the basic matmul correction block.

    The field names double as name fragments: ``BaseMMCorrectionBlock.get_block_names`` formats each field
    name into a template derived from the matmul layer name to produce the value. Renaming a field therefore
    changes the generated layer name. Reordering fields changes the order in which names are iterated.
    """

    shortcut: str  # standalone activation inserted only when the weights input has several consumers
    reduce_sum: str
    linear: str
    concat: str


class MMCBlockNames2(MMCBlockNames):
    """Layer names for the 16-bit correction block (adds a signed precision splitter after the reduce-sum)."""

    precision_splitter: str


class MMCBlockNames3(MMCBlockNames):
    """Layer names for the layer-based 16-bit correction block (splitter, pass-through, concat, depthwise)."""

    precision_splitter: str
    # NOTE: the misspelling is part of the generated layer name (the field name is formatted into it); keep it.
    pass_thrue: str
    precision_concat: str
    depth_wise: str


# Constrained TypeVar over the block-name models; used to annotate ``BaseMMCorrectionBlock.BLOCK_NAMES``.
MMBLOCK = TypeVar("MMBLOCK", MMCBlockNames, MMCBlockNames2, MMCBlockNames3)


class BaseMMCorrectionBlock:
    """Base class for blocks that add online zero-point compensation to a matmul's second input.

    A correction block replaces the edge that feeds input 1 of a ``HailoMatmul``. After the rewrite the matmul
    receives, per group, the original data concatenated with correction channels derived from a reduce-sum of
    that data. Subclasses define which layers are inserted (``add_correction_block``) and how their scales and
    zero points are forced (``enforce_block_encodings``). ``BLOCK_NAMES`` selects the pydantic model listing the
    inserted layers' names.

    Note: this class does not derive from ``abc.ABC``, so ``@abstractmethod`` is not enforced at instantiation.
    """

    BLOCK_NAMES: MMBLOCK

    def __init__(self, model: HailoModel, model_config, logger, addition, initial_feed_repeat=1):
        """
        Args:
            model: Model that is modified in place.
            model_config: Optimization config. Its ``precision_config.layers`` (and, in subclasses,
                ``zero_static_channels.layers``) entries are written for inserted layers.
            logger: Stored for subclasses/callers; not used in this class.
            addition: Stored for subclasses/callers; not used in this class.
            initial_feed_repeat: Feed-repeat value subclasses apply to the matmul / reduce-sum kernel.
        """
        self._model = model
        self._model_config = model_config
        self._logger = logger
        self._addition = addition
        self._initial_feed_repeat = initial_feed_repeat

    @property
    def optimization_target(self):
        """Optimization target of the underlying model."""
        return self._model.optimization_target

    @abstractmethod
    def add_correction_block(self, mm_lname):
        """Insert the correction-block layers in front of input 1 of matmul ``mm_lname``.

        Implementations decide which layers are added and how they are wired.
        """

    @abstractmethod
    def enforce_block_encodings(self, mm_lname: str, scales) -> Tuple[np.ndarray, np.ndarray]:
        """Force scales and zero points on every layer of the block for matmul ``mm_lname``.

        Returns:
            The (scale, zero point) of the tensor that feeds matmul input 1.
        """

    def get_input_name(self, mm_lname: str) -> str:
        """Return the single layer that feeds both the ``linear`` and the ``reduce_sum`` branches.

        Raises:
            AccelerasImplementationError: if the two branches' predecessor lists differ or do not hold
                exactly one layer.
        """
        block = self.get_block_names(mm_lname)
        linear_inputs = self._model.flow.predecessors_sorted(block.linear)
        reduce_sum_inputs = self._model.flow.predecessors_sorted(block.reduce_sum)

        if linear_inputs != reduce_sum_inputs or len(linear_inputs) != 1:
            raise AccelerasImplementationError(f"Matmul Correction block for {mm_lname} is not well define")
        return reduce_sum_inputs[0]

    def get_sub_flow(self, mm_lname) -> ModelFlow:
        """Return ``ModelFlow.get_sub_flow`` between the block's input layer and the matmul."""
        input_name = self.get_input_name(mm_lname)
        return self._model.flow.get_sub_flow(input_name, mm_lname)

    def force_vals_by_edges(self, edges: List[Tuple[str, str]], scales: np.ndarray = None, zp: np.ndarray = None):
        """Apply the same encoding to both ends of every ``(u, v)`` edge.

        For each edge, ``u``'s output and ``v``'s input selected by the edge's ``output_index`` / ``input_index``
        receive ``scales`` and/or ``zp``. A ``None`` value leaves that quantity untouched.
        """
        for u, v in edges:
            edge_data = self._model.flow.get_edge_data(u, v)
            source = self._model.layers[u]
            target = self._model.layers[v]

            if scales is not None:
                source.set_output_scale(scales, edge_data["output_index"])
                target.set_input_scale(scales, edge_data["input_index"])

            if zp is not None:
                source.set_output_zero_point(zp, edge_data["output_index"])
                target.set_input_zero_point(zp, edge_data["input_index"])

    def get_block_names(self, mm_lname: str) -> "MMBLOCK":
        """Build the ``BLOCK_NAMES`` model holding the full layer names of the block for ``mm_lname``.

        Each name is ``{scope}/{block}_{field}_{layer}``, or ``{scope}/{field}_{layer}`` when the layer has no
        block prefix. ``scope`` is everything before the last ``/`` of ``mm_lname``. ``block`` and ``layer``
        come from ``OptimizationAlgorithm.get_block_and_layer_names``.
        """
        mm_scope, _, short_name = mm_lname.rpartition("/")
        mm_block_name, layer_name = OptimizationAlgorithm.get_block_and_layer_names(short_name)
        if mm_block_name:
            template = f"{mm_scope}/{mm_block_name}_{{}}_{layer_name}"
        else:
            template = f"{mm_scope}/{{}}_{layer_name}"
        vals = {key: template.format(key) for key in self.BLOCK_NAMES.model_fields.keys()}
        return self.BLOCK_NAMES(**vals)

    def get_matmul_layers(self, lname):
        """Return ``(matmul_layer, layer_feeding_matmul_input_1)``.

        Raises:
            ValueError: if the matmul does not have exactly two predecessors.
        """
        mtml = self._model.layers[lname]
        preds = self._model.flow.predecessors_sorted(lname)
        if len(preds) != 2:
            raise ValueError(f"Layer {lname} does not have two predecessors")
        input_w = self._model.layers[preds[1]]
        return mtml, input_w

    def get_weight_group_size(self, matmul, input_w):
        """Return the number of ``input_w`` kernel output channels per effective matmul group.

        The effective group count is ``matmul.groups // matmul.input_tiles[1][-1]``.

        Raises:
            ValueError: if ``input_w`` is not a ``HailoConv``.
        """
        if isinstance(input_w, HailoConv):
            weights = input_w.export_weights()
            group_size = int(weights["kernel"].shape[-1] / (matmul.groups // matmul.input_tiles[1][-1]))
        else:
            raise ValueError(f"Layer {input_w.full_name} is not supported for group size calculation")
        return group_size

    def set_precision_config(self, layer, precision_cfg, use_dfault=False):
        """Verify and import a precision config into ``layer`` and record it in the model config.

        Args:
            layer: Target layer.
            precision_cfg: Config to apply; ignored when ``use_dfault`` is True.
            use_dfault: Use ``layer.get_default_precision_config()`` instead of ``precision_cfg``.
        """
        if use_dfault:
            precision_cfg = layer.get_default_precision_config()

        layer.verify_config(precision_cfg)
        layer.import_precision_config(precision_cfg, self.optimization_target)
        self._model_config.precision_config.layers[layer.full_name] = precision_cfg

    def set_layer_io_shapes(self, input_shape, layer):
        """Set ``layer``'s HN input shapes to ``input_shape`` and compute its HN output shapes.

        A dynamic batch dimension (-1) is temporarily replaced by 1 for the shape computation. Afterwards it is
        restored on the inputs that had it, and every output whose batch is 1 gets -1.

        The shapes are copied first, so the caller's lists are never mutated. This matters because they are
        often another layer's ``_hn_element`` entries.
        """
        input_shapes = [list(shape) for shape in input_shape]
        dynamic_batch = [shape[0] == -1 for shape in input_shapes]
        for shape, is_dynamic in zip(input_shapes, dynamic_batch):
            if is_dynamic:
                shape[0] = 1
        layer._hn_element["input_shapes"] = input_shapes

        if len(layer.output_scales) == 1:
            output_shapes = [list(layer.compute_output_shape(input_shapes))]
        else:
            output_shapes = [list(shape) for shape in layer.compute_output_shape(input_shapes)]

        if any(dynamic_batch):
            # Restore only the batches that were dynamic; a genuine batch of 1 on another input stays 1.
            for shape, is_dynamic in zip(input_shapes, dynamic_batch):
                if is_dynamic:
                    shape[0] = -1
            for shape in output_shapes:
                if shape[0] == 1:
                    shape[0] = -1

        layer._hn_element["input_shapes"] = input_shapes
        layer._hn_element["output_shapes"] = output_shapes

    def _add_node(self, new_layer, predecessors, output_index=0):
        """Register ``new_layer`` in the model and connect each predecessor to it.

        Predecessor ``i`` is connected to input ``i`` of the new layer, always from the predecessor's
        ``output_index``.
        """
        self._model._unlock_model()
        try:
            self._model.layers[new_layer.full_name] = new_layer
        finally:
            # Re-lock even if registration fails so the model is not left unlocked.
            self._model._lock_model()
        node = new_layer.full_name
        self._model.flow.add_node(node)

        for input_index, predecessor in enumerate(predecessors):
            self._model.flow.add_edge(predecessor.full_name, node, input_index=input_index, output_index=output_index)

    def _remove_correction_from_stats(self, matmul: HailoMatmul, group_size):
        """Drop the correction channels from the statistics of the matmul's second input.

        Args:
            matmul: Matmul whose input-1 statistics still include the correction channels.
            group_size: Channels per effective group *including* the ``zp_comp_rank`` trailing correction
                channels of that group.
        """
        n_groups = matmul.groups // matmul.input_tiles[1][-1]
        # ``group_size - rank`` (rather than ``[:-rank]``) keeps the whole group when rank is 0.
        n_keep = group_size - matmul.matmul_op.zp_comp_rank

        def _correct_stats(stat):
            return np.concatenate([stat[i * group_size : i * group_size + n_keep] for i in range(n_groups)])

        for in_op, in_index in matmul._input_stats_ops():
            if in_index == 1:
                in_stats = in_op.get_input_stats(in_index)
                updated_input = self.update_stats(_correct_stats, in_stats)
                in_op.stats_managers[f"inputs_{in_index}"]._stats = updated_input

    def _remove_correction(self, matmul: HailoMatmul, concat_lname):
        """Remove the correction block in front of ``matmul`` and reconnect its original weights input.

        Every block layer except the shortcut is removed together with its precision-config entry. The layer
        that fed the ``linear`` branch is reconnected to matmul input 1 through the same output index it used
        for the block. The matmul input shape is restored, and the correction channels are stripped from the
        matmul input statistics.

        Args:
            matmul: The corrected matmul layer.
            concat_lname: Unused; kept for interface compatibility.
        """
        block_names = self.get_block_names(matmul.full_name)
        weights_layer_name = self._model.flow.predecessors_sorted(block_names.linear)[0]
        weights_layer = self._model.layers[weights_layer_name]
        # Read the output index before the block layers (and their edges) are removed, so the weights layer is
        # reconnected through the output it originally used rather than unconditionally through output 0.
        output_index = self._model.flow.get_edge_data(weights_layer_name, block_names.linear)["output_index"]

        for layer_name in block_names.model_dump().values():
            if layer_name != block_names.shortcut and layer_name in self._model.layers:
                self._model.remove_layer(self._model.layers[layer_name])
                # Not every inserted layer necessarily received a precision config.
                self._model_config.precision_config.layers.pop(layer_name, None)

        # reconnect the weights layer to the matmul and restore the matmul input shape
        self._model.flow.add_edge(weights_layer_name, matmul.full_name, input_index=1, output_index=output_index)
        weights_shape = weights_layer._hn_element["output_shapes"][output_index]
        matmul._hn_element["input_shapes"][1] = weights_shape

        # The collected stats still contain ``zp_comp_rank`` correction channels after each group's data.
        n_groups = matmul.groups // matmul.input_tiles[1][-1]
        group_size = int(weights_shape[-1] / n_groups)
        self._remove_correction_from_stats(matmul, group_size + matmul.matmul_op.zp_comp_rank)

    def add_bias_to_matmul(self, mm_lname: str, bias_out: np.array):
        """Add an output bias to the matmul through the zero-point compensation path.

        The bias is written into the ``reduce_sum`` layer's bias as ``-bias_out / (x_scale * zp)``, with
        ``x_scale`` taken per group from the first input scale. The matmul output min/max statistics are then
        shifted by ``bias_out``. Nothing is done when the first input's zero point is 0.
        """
        layers = self._model.layers

        block: MMCBlockNames = self.get_block_names(mm_lname)
        matmul: HailoMatmul = layers[mm_lname]

        min_vals, max_vals = matmul.get_group_output_limvals(matmul.groups)[0]

        x_scales, _ = [np.reshape(scale, [matmul.groups, -1])[:, 0] for scale in matmul.input_scales]
        zp = matmul.input_zero_points[0]
        if zp == 0:
            return  # NOTE(review): silently skipped; consider raising instead
        # NOTE(review): ``zp == 0`` above assumes a scalar zero point; a vector zp would need ``np.all``.
        zp_vals = np.full_like(x_scales, zp)

        bias = -1 * bias_out / (x_scales * zp_vals)
        layers[block.reduce_sum].bias_op.import_weights(bias)

        output_stats = matmul.get_output_stats()[0]
        # ``ndarray.repeat`` needs an integer count; true division would give a float and raise TypeError.
        repeats = output_stats.min.size // matmul.groups
        output_stats.min[...] = (min_vals + bias_out).repeat(repeats)
        output_stats.max[...] = (max_vals + bias_out).repeat(repeats)

    @staticmethod
    def update_stats(correction_func, orig_stats):
        """Return a copy of ``orig_stats`` (namedtuple-like) with ``correction_func`` applied to min/max/energy/mean."""
        updated_stats = orig_stats._replace(
            min=correction_func(orig_stats.min),
            max=correction_func(orig_stats.max),
            energy=correction_func(orig_stats.energy),
            mean=correction_func(orig_stats.mean),
        )
        return updated_stats

    def _fix_unsigned_data_path(self, layer: BaseHailoLayer):
        """Switch ``layer``'s output data path to ``LAYER_OUT``.

        If ``layer`` wraps a single non-arithmetic atomic op, the change is propagated over the connected region
        of such layers instead: every parent gets ``LAYER_OUT`` and every child input gets ``LAYER_IN``.
        """
        if isinstance(layer, BaseHailoSingleAtomic) and isinstance(layer.atomic_op, BaseNonArithmeticAtomicOp):
            sub_flow = self.create_non_arithmecti_sub_flow(layer.full_name, self._model)
            for u_node, v_node, edge_data in sub_flow.edges(data=True):
                parent = self._model.layers[u_node]
                child = self._model.layers[v_node]
                parent.set_output_data_path(DataPath.LAYER_OUT)
                child.set_input_data_path(DataPath.LAYER_IN, edge_data["input_index"])
        else:
            layer.set_output_data_path(DataPath.LAYER_OUT)

    def create_non_arithmecti_sub_flow(self, lname: str, model: HailoModel, *, subflow: ModelFlow = None):
        """Collect the edges around ``lname`` that touch single non-arithmetic-op layers.

        When ``lname`` wraps a single non-arithmetic atomic op, its incoming and outgoing edges are copied into
        ``subflow`` with their edge data. The walk then recurses into each newly reached neighbour. Traversal
        stops at layers that are not such ops.

        Returns:
            ``subflow``, a new ``ModelFlow`` when none is given.
        """
        if subflow is None:
            subflow = ModelFlow()
        layer = model.layers[lname]
        if isinstance(layer, BaseHailoSingleAtomic) and isinstance(layer.atomic_op, BaseNonArithmeticAtomicOp):
            for parent in model.flow.predecessors_sorted(lname):
                edge = (parent, lname)
                if edge not in subflow.edges:
                    subflow.add_edge(*edge, **model.flow.get_edge_data(*edge))
                    self.create_non_arithmecti_sub_flow(parent, model, subflow=subflow)

            for succ in model.flow.successors_sorted(lname):
                edge = (lname, succ)
                if edge not in subflow.edges:
                    subflow.add_edge(*edge, **model.flow.get_edge_data(*edge))
                    self.create_non_arithmecti_sub_flow(succ, model, subflow=subflow)
        return subflow


# region CorrectionBlock0
class MMCorrectionBlock(BaseMMCorrectionBlock):
    """Matmul Correction block.
    Add online Zp Compensation Channel using one 8 bits value

    ┌────────────────────────────────────────────────────────────────┐
    │                                                                │
    │    ┌──────┐          ┌───────┐                                 │
    │    │      │          │ MatMul│                                 │
    │    │ Inp  ├──────────►       │                                 │
    │    │      │          │       │                                 │
    │    └──────┘       ┌─►└───────┘                                 │
    │                   │                                            │
    │    ───────────────┘                                            │
    │                                                                │
    │         │  │  │                                                │
    │         │  │  │                                                │
    │         ▼  ▼  ▼                                                │
    │            ┌────────┐                                          │
    │            │ Reduce ├────────────┐       ┌────────┐            │
    │ ┌──────┐ ┌►│ Sum    │            │       │        │            │
    │ │      │ │ └────────┘            └───────►Concat  │            │
    │ │ Inp  ├─┘                               │        │            │
    │ │      │                                 │        │   ┌─────┐  │
    │ │      │      ┌───────┐           ┌──────►        ├───►     │  │
    │ │      ├──────► linear│           │      │        │   │MatMul  │
    │ └──────┘      │       ├───────────┘      └────────┘   │     │  │
    │               └───────┘                             ┌─►     │  │
    │                                                     │ └─────┘  │
    │                                                     │          │
    │  ───────────────────────────────────────────────────┘          │
    │                                                                │
    └────────────────────────────────────────────────────────────────┘

    The concat joins, per group, the ``linear`` copy of the input with the reduce-sum channel, and its output
    feeds input 1 of the matmul.
    """

    BLOCK_NAMES = MMCBlockNames

    def get_block_names(self, mm_lname: str) -> MMCBlockNames:
        """Return this block's layer names for matmul ``mm_lname`` (narrowed return type)."""
        return super().get_block_names(mm_lname)

    def add_correction_block(self, mm_lname):
        """Insert reduce_sum, linear and concat layers (plus an optional shortcut) in front of matmul input 1.

        A shortcut is inserted between the weights input and both branches only when the weights input has more
        than one successor.

        Args:
            mm_lname: Full name of the matmul layer to correct.
        """
        matmul, input_w = self.get_matmul_layers(mm_lname)
        bn = self.get_block_names(mm_lname)
        n_groups = matmul.groups // matmul.input_tiles[1][-1]

        multiple_indices = len(self._model.flow.successors_sorted(input_w.full_name)) > 1
        output_index = self._model.flow.edges[(input_w.full_name, matmul.full_name)]["output_index"]
        input_shape = input_w._hn_element["output_shapes"][output_index]

        # init layers
        if multiple_indices:
            sc = HailoStandaloneActivation(bn.shortcut)
        rs = HailoReduceSum(bn.reduce_sum, groups=n_groups, reduce_axes=[3])
        gc = HailoConcat(bn.concat, num_inputs=2, group_sizes=[1] * n_groups)
        lin = HailoStandaloneActivation(bn.linear, activation="linear")

        # build flow: both branches read the shortcut (output 0) when present, otherwise input_w directly
        self._model.flow.remove_edge(input_w.full_name, matmul.full_name)
        if multiple_indices:
            self._add_node(sc, [input_w], output_index=output_index)
            branch_src, branch_out_index = sc, 0
        else:
            branch_src, branch_out_index = input_w, output_index
        self._add_node(rs, [branch_src], output_index=branch_out_index)
        self._add_node(lin, [branch_src], output_index=branch_out_index)
        if multiple_indices:
            self.set_layer_io_shapes([input_shape], sc)
            self.set_precision_config(sc, precision_cfg=None, use_dfault=True)
        self._add_node(gc, [lin, rs])
        # The concat has a single output, so it always feeds the matmul from output 0.
        self._model.flow.add_edge(gc.full_name, matmul.full_name, input_index=1, output_index=0)

        # set shapes
        self.set_layer_io_shapes([input_shape], lin)
        self.set_layer_io_shapes([input_shape], rs)
        self.set_layer_io_shapes([lin._hn_element["output_shapes"][0], rs._hn_element["output_shapes"][0]], gc)
        matmul._hn_element["input_shapes"][1] = gc._hn_element["output_shapes"][0]

        # set precision
        self.set_precision_config(lin, precision_cfg=None, use_dfault=True)
        self.set_precision_config(gc, precision_cfg=None, use_dfault=True)

        precision_cfg_rs = rs.get_default_precision_config()
        precision_cfg_rs.bias_mode = BiasMode.double_scale_initialization  # double_scale_decomposition
        self.set_precision_config(rs, precision_cfg_rs, use_dfault=False)  # change to double scale decomposition.

        # reset input_w to layer_data
        self._fix_unsigned_data_path(input_w)
        # The reduce-sum kernel is negated and divided by the feed repeat that the matmul is configured with.
        rs.reduce_sum_op.kernel = -rs.reduce_sum_op.kernel / self._initial_feed_repeat
        matmul.matmul_op.feed_repeat = self._initial_feed_repeat

    def enforce_block_encodings(self, mm_lname: str, scales) -> Tuple[np.ndarray, np.ndarray]:
        """Force the block encodings given the block-input ``scales``.

        Both branches read the input with zero point 128. The reduce_sum -> concat edge uses the per-group mean
        of ``scales`` (``scales`` reshaped to ``[n_groups, -1]``). The linear -> concat edge uses ``scales``.
        Both edges use zero point 0. The concat then enforces its own encoding, which is propagated to matmul
        input 1.

        Returns:
            The concat output scale and zero point.
        """
        bn = self.get_block_names(mm_lname)
        scales = np.array(scales)

        # Layers
        rs = self._model.layers[bn.reduce_sum]
        li = self._model.layers[bn.linear]
        concat = self._model.layers[bn.concat]
        matmul = self._model.layers[mm_lname]

        zp_128 = np.full_like(rs.input_zero_point, 128)
        zp_0 = np.full_like(zp_128, 0)

        # Input scales and zero points of both branches
        rs.set_input_scale(scales, 0)
        rs.set_input_zero_point(zp_128, 0)
        li.set_input_scale(scales, 0)
        li.set_input_zero_point(zp_128, 0)
        n_groups = matmul.groups // matmul.input_tiles[1][-1]
        low_scales = scales.reshape([n_groups, -1]).mean(axis=-1)

        self.force_vals_by_edges([(bn.reduce_sum, bn.concat)], low_scales, zp_0)
        self.force_vals_by_edges([(bn.linear, bn.concat)], scales, zp_0)
        concat.enforce_io_encoding()
        self.force_vals_by_edges([(bn.concat, mm_lname)], concat.output_scale, concat.output_zero_point)

        return concat.output_scale, concat.output_zero_point


# region CorrectionBlock 16 Bits
class MMCorrectionBlock2(BaseMMCorrectionBlock):
    """Matmul Correction block.
    Add online Zp Compensation with 16 bits zp comp

    ┌────────────────────────────────────────────────────────────────────────────────┐
    │                                                                                │
    │       ┌──────┐          ┌───────┐                                              │
    │       │      │          │ MatMul│                                              │
    │       │ Inp  ├──────────►       │                                              │
    │       │      │          │       │                                              │
    │       └──────┘       ┌─►└───────┘                                              │
    │                      │                                                         │
    │       ───────────────┘                                                         │
    │                                                                                │
    │            │  │  │                                                             │
    │            │  │  │                                                             │
    │            ▼  ▼  ▼                                                             │
    │               ┌────────┐  ┌──────────────┐                                     │
    │               │ Reduce ├─►│ Presicion    │  ┌────────┐                         │
    │    ┌──────┐ ┌►│ Sum    │  │ Split        ├─►│        │                         │
    │    │      │ │ └────────┘  │ Signed       │  │Concat  │                         │
    │    │ Inp  ├─┘             └──────────────┘  │        │                         │
    │    │      │                                 │        │   ┌─────┐               │
    │    │      │      ┌───────┐           ┌──────►        ├───►     │               │
    │    │      ├──────► linear│           │      │        │   │MatMul               │
    │    └──────┘      │       ├───────────┘      └────────┘   │     │               │
    │                  └───────┘                             ┌─►     │               │
    │                                                        │ └─────┘               │
    │                                                        │                       │
    │     ───────────────────────────────────────────────────┘                       │
    │                                                                                │
    │                                                                                │
    └────────────────────────────────────────────────────────────────────────────────┘

    The reduce-sum runs in 16-bit output mode and is split by a signed precision splitter. The concat joins, per
    group, the ``linear`` copy of the input with the splitter output, and feeds input 1 of the matmul.
    """

    BLOCK_NAMES = MMCBlockNames2

    def add_correction_block(self, mm_lname):
        """Insert reduce_sum, precision_splitter, linear and concat layers (plus an optional shortcut).

        A shortcut is inserted between the weights input and both branches only when the weights input has more
        than one successor.

        Args:
            mm_lname: Full name of the matmul layer to correct.
        """
        matmul, input_w = self.get_matmul_layers(mm_lname)
        block_names = self.get_block_names(mm_lname)
        n_groups = matmul.groups // matmul.input_tiles[1][-1]

        multiple_indices = len(self._model.flow.successors_sorted(input_w.full_name)) > 1
        output_index = self._model.flow.edges[(input_w.full_name, matmul.full_name)]["output_index"]
        input_shape = input_w._hn_element["output_shapes"][output_index]

        # init layers
        if multiple_indices:
            sc = HailoStandaloneActivation(block_names.shortcut)
        rs = HailoReduceSum(block_names.reduce_sum, groups=n_groups, reduce_axes=[3])
        ps_activation = "pwl" if self.optimization_target != OptimizationTarget.SAGE else "linear"
        ps = HailoPrecisionSplitSigned(block_names.precision_splitter, groups=n_groups, activation=ps_activation)
        gc = HailoConcat(block_names.concat, num_inputs=2, group_sizes=[1] * n_groups)
        lin = HailoStandaloneActivation(block_names.linear, activation="linear")

        # build flow: both branches read the shortcut (output 0) when present, otherwise input_w directly
        self._model.flow.remove_edge(input_w.full_name, matmul.full_name)
        if multiple_indices:
            self._add_node(sc, [input_w], output_index=output_index)
            branch_src, branch_out_index = sc, 0
        else:
            branch_src, branch_out_index = input_w, output_index
        self._add_node(rs, [branch_src], output_index=branch_out_index)
        self._add_node(ps, [rs], output_index=0)
        self._add_node(lin, [branch_src], output_index=branch_out_index)
        if multiple_indices:
            self.set_layer_io_shapes([input_shape], sc)
            self.set_precision_config(sc, precision_cfg=None, use_dfault=True)
        self._add_node(gc, [lin, ps])
        # The concat has a single output, so it always feeds the matmul from output 0.
        self._model.flow.add_edge(gc.full_name, matmul.full_name, input_index=1, output_index=0)

        # set shapes
        self.set_layer_io_shapes([input_shape], lin)
        self.set_layer_io_shapes([input_shape], rs)
        self.set_layer_io_shapes([rs._hn_element["output_shapes"][0]], ps)
        self.set_layer_io_shapes([lin._hn_element["output_shapes"][0], ps._hn_element["output_shapes"][0]], gc)
        matmul._hn_element["input_shapes"][1] = gc._hn_element["output_shapes"][0]

        # set precision
        self.set_precision_config(lin, precision_cfg=None, use_dfault=True)
        self.set_precision_config(gc, precision_cfg=None, use_dfault=True)

        precision_cfg_rs = rs.get_default_precision_config()
        precision_cfg_rs.precision_mode = PrecisionMode.a8_w8_a16  # Set 15 bits output
        precision_cfg_rs.bias_mode = BiasMode.double_scale_initialization  # double_scale_decomposition
        self.set_precision_config(rs, precision_cfg_rs, use_dfault=False)  # change to double scale decomposition.

        self.set_precision_config(ps, precision_cfg=None, use_dfault=True)
        # reset input_w to layer_data
        self._fix_unsigned_data_path(input_w)
        if self.optimization_target != OptimizationTarget.SAGE:
            # Outside SAGE the reduce-sum kernel can be negated and the splitter activation replaced by -x.
            # This avoids the clipping that might occur when the low section is exactly 128.
            rs.neg_weights()
            ps.act_op.import_weights(
                {"thresholds": np.array([0.0]), "offsets": np.array([0.0, 0.0]), "slopes": np.array([-1.0, -1.0])}
            )
        ps.conv_op.kernel = np.array([-1])
        ps.build([1] + ps.input_shape[1:])

    def enforce_block_encodings(self, mm_lname: str, scales) -> Tuple[np.ndarray, np.ndarray]:
        """Force the block encodings given the block-input ``scales``.

        Both branches read the input with zero point 128. The reduce_sum -> precision_splitter edge uses the
        per-group mean of ``scales`` with zero point 16512. The splitter then optimizes its ratio and enforces
        its encoding, and its output 0 encoding is forced onto the splitter -> concat edge. The linear -> concat
        edge uses ``scales`` with zero point 0. Finally the concat encoding is propagated to matmul input 1.

        Returns:
            The concat output scale and zero point.
        """
        bn = self.get_block_names(mm_lname)
        scales = np.array(scales)
        rs = self._model.layers[bn.reduce_sum]
        li = self._model.layers[bn.linear]
        ps = self._model.layers[bn.precision_splitter]
        concat = self._model.layers[bn.concat]
        matmul = self._model.layers[mm_lname]

        zp_128 = np.full_like(rs.input_zero_point, 128)
        # 16512 is 2**14 + 2**7 (the original comment read "2**13 + 2**7"); it puts both the high and the low
        # zero points in the middle of their ranges.
        zp_16512 = np.full_like(rs.output_zero_point, 16512)
        zp_0 = np.full_like(zp_128, 0)

        # Input scales and zero points of both branches
        rs.set_input_scale(scales, 0)
        rs.set_input_zero_point(zp_128, 0)
        li.set_input_scale(scales, 0)
        li.set_input_zero_point(zp_128, 0)
        n_groups = matmul.groups // matmul.input_tiles[1][-1]
        low_scales = scales.reshape([n_groups, -1]).mean(axis=-1)

        self.force_vals_by_edges([(bn.reduce_sum, bn.precision_splitter)], low_scales, zp_16512)

        ps.optimize_ratio()
        ps.enforce_io_encoding()

        self.force_vals_by_edges([(bn.precision_splitter, bn.concat)], ps.output_scales[0], ps.output_zero_points[0])

        self.force_vals_by_edges([(bn.linear, bn.concat)], scales, zp_0)
        concat.enforce_io_encoding()
        self.force_vals_by_edges([(bn.concat, mm_lname)], concat.output_scale, concat.output_zero_point)

        return concat.output_scale, concat.output_zero_point


# region CorrectionBlock 16 Bits Layers
class MMCorrectionBlock3(BaseMMCorrectionBlock):
    """
    MatMul correction block.
    Adds online Zp compensation to the transpose input
     ────────────────────────────────────────────────────────────────────────────────┐
    │                                                                                │
    │       ┌──────┐          ┌───────┐                                              │
    │       │      │          │ MatMul│                                              │
    │       │ Inp  ├──────────►       │                                              │
    │       │      │          │       │                                              │
    │       └──────┘       ┌─►└───────┘                                              │
    │                      │                                                         │
    │       ───────────────┘                                                         │
    │                                                                                │
    │            │  │  │                                                             │
    │            │  │  │                                                             │
    │            ▼  ▼  ▼                                                             │
    │               ┌────────┐  ┌──────┐ ┌─────┐  ┌───┐                              │
    │               │ Reduce ├─►│ Pres ├─► Pas ├─►│con│ ┌─────┐                      │
    │    ┌──────┐ ┌►│ Sum    │  │ Split│ └─────┘  │   │ │deph │                      │
    │    │      │ │ └────────┘  │      │          │cat├─►wise │                      │
    │    │ Inp  ├─┘             └──────┴─────────►└───┘ │     │   ┌────┐             │
    │    │      │                                       │     ├───►Con │   ┌─────┐   │
    │    └──────┴──┐   ┌───────┐                        └─────┘   │    ├───►     │   │
    │              │   │ linear│                                ┌─►Cat │   │MatMul   │
    │              └───►       ├────────────────────────────────┘ └────┘   │     │   │
    │                  └───────┘                                         ┌─►     │   │
    │                                                                    │ └─────┘   │
    │                                                                    │           │
    │     ───────────────────────────────────────────────────────────────┘           │
    │                                                                                │
    │                                                                                │
    └────────────────────────────────────────────────────────────────────────────────┘


    """

    BLOCK_NAMES = MMCBlockNames3

    def add_correction_block(self, mm_lname):
        """Insert the layer-based 16-bit zero-point correction block in front of input 1 of ``mm_lname``.

        The edge from the matmul's second predecessor (``input_w``) is replaced by two branches. Both read
        ``input_w``, or a shortcut layer when ``input_w`` has more than one successor:

        * reduce_sum -> precision_splitter. Splitter output 0, and output 1 routed through ``pass_thrue``, are
          joined by ``precision_concat`` and fed to a 1x1 depthwise layer.
        * linear.

        ``concat`` joins the linear and depthwise outputs per group and feeds matmul input 1.

        Args:
            mm_lname: Full name of the matmul layer to correct.
        """
        matmul, input_w = self.get_matmul_layers(mm_lname)

        multiple_indices = len(self._model.flow.successors_sorted(input_w.full_name)) > 1
        output_index = self._model.flow.edges[(input_w.full_name, matmul.full_name)]["output_index"]
        input_shape = input_w._hn_element["output_shapes"][output_index]
        n_groups = matmul.groups // matmul.input_tiles[1][-1]

        # init layers
        block_names = self.get_block_names(mm_lname)

        if multiple_indices:
            sc = HailoStandaloneActivation(block_names.shortcut)
        rs = HailoReduceSum(block_names.reduce_sum, groups=n_groups, reduce_axes=[3])
        ps = HailoPrecisionSplit(block_names.precision_splitter)
        short = HailoShortcut(block_names.pass_thrue)
        gc0 = HailoConcat(block_names.precision_concat, num_inputs=2, group_sizes=[1] * n_groups)
        depth_activation = "pwl" if self.optimization_target != OptimizationTarget.SAGE else "linear"
        depth = HailoDepthwise(block_names.depth_wise, kernel_size=[1, 1], activation=depth_activation)
        gc1 = HailoConcat(block_names.concat, num_inputs=2, group_sizes=[1] * n_groups)
        lin = HailoStandaloneActivation(block_names.linear, activation="linear")

        # build flow: both branches read the shortcut (output 0) when present, otherwise input_w directly
        self._model.flow.remove_edge(input_w.full_name, matmul.full_name)
        if multiple_indices:
            self._add_node(sc, [input_w], output_index=output_index)
            branch_src, branch_out_index = sc, 0
        else:
            branch_src, branch_out_index = input_w, output_index
        self._add_node(rs, [branch_src], output_index=branch_out_index)

        # Transpose - Reduce Sum Flow
        self._add_node(ps, [rs])
        self._add_node(short, [ps], output_index=1)
        self._add_node(gc0, [ps, short], output_index=0)
        self._add_node(depth, [gc0])

        # Transpose - Linear Flow
        self._add_node(lin, [branch_src], output_index=branch_out_index)
        if multiple_indices:
            self.set_layer_io_shapes([input_shape], sc)
            self.set_precision_config(sc, precision_cfg=None, use_dfault=True)

        self._add_node(gc1, [lin, depth])
        # The concat has a single output, so it always feeds the matmul from output 0.
        self._model.flow.add_edge(gc1.full_name, matmul.full_name, input_index=1, output_index=0)

        # set shapes
        self.set_layer_io_shapes([input_shape], lin)
        self.set_layer_io_shapes([input_shape], rs)
        self.set_layer_io_shapes([rs._hn_element["output_shapes"][0]], ps)
        self.set_layer_io_shapes([ps._hn_element["output_shapes"][1]], short)
        self.set_layer_io_shapes([ps._hn_element["output_shapes"][0], short._hn_element["output_shapes"][0]], gc0)
        self.set_layer_io_shapes([gc0._hn_element["output_shapes"][0]], depth)
        self.set_layer_io_shapes([lin._hn_element["output_shapes"][0], depth._hn_element["output_shapes"][0]], gc1)

        matmul._hn_element["input_shapes"][1] = gc1._hn_element["output_shapes"][0]

        # set precision
        self.set_precision_config(lin, precision_cfg=None, use_dfault=True)

        precision_cfg_rs = rs.get_default_precision_config()

        precision_cfg_rs.precision_mode = PrecisionMode.a8_w8_a16  # Set 15 bits output
        precision_cfg_rs.bias_mode = BiasMode.double_scale_initialization  # double_scale_decomposition
        self.set_precision_config(rs, precision_cfg_rs, use_dfault=False)  # change to double scale decomposition.
        self.set_precision_config(ps, precision_cfg=None, use_dfault=True)
        self.set_precision_config(gc0, precision_cfg=None, use_dfault=True)
        self.set_precision_config(depth, precision_cfg=None, use_dfault=True)
        # gc1 is the concat feeding the matmul; like the concat of the other blocks it gets the default config.
        self.set_precision_config(gc1, precision_cfg=None, use_dfault=True)

        # reset input_w to layer_data
        self._fix_unsigned_data_path(input_w)

        if self.optimization_target != OptimizationTarget.SAGE:
            # Outside SAGE the reduce-sum kernel can be negated and the depthwise activation replaced by -x.
            # This avoids the clipping that might occur when the low section is exactly 128.
            rs.neg_weights()
            depth.act_op.import_weights(
                {"thresholds": np.array([0.0]), "offsets": np.array([0.0, 0.0]), "slopes": np.array([-1.0, -1.0])}
            )
        depth.conv_op.kernel = np.array([-1])
        depth.build([1] + depth.input_shape[1:])
        # The precision splitter output can have a natively-zero channel, so zero-static-channel handling is
        # disabled for the depthwise layer.
        config = LayerZeroStaticChannelsConfig(policy=LayerFeaturePolicy.disabled)
        self._model_config.zero_static_channels.layers[depth.full_name] = config.copy()

    def enforce_block_encodings(self, mm_lname: str, scales) -> Tuple[np.ndarray, np.ndarray]:
        """Force the scales / zero points along the correction block that feeds ``mm_lname``.

        Walks the block in order: reduce_sum / linear inputs, reduce_sum -> precision splitter,
        splitter -> precision concat (directly and through the pass-through layer),
        precision concat -> depthwise, depthwise / linear -> concat, and finally concat -> matmul.
        Each layer's ``enforce_io_encoding`` is called only after its incoming edges were forced,
        so the statement order below matters.

        Args:
            mm_lname: Name of the matmul layer whose correction block is being encoded.
            scales: Input scales of the block (converted with ``np.array``). They must be
                reshapeable to ``(matmul.groups // matmul.input_tiles[1][-1], -1)``.

        Returns:
            The concat layer's ``(output_scale, output_zero_point)``; the same values are
            forced onto the concat -> matmul edge.
        """
        bn = self.get_block_names(mm_lname)
        scales = np.array(scales)
        rs = self._model.layers[bn.reduce_sum]
        li = self._model.layers[bn.linear]
        ps = self._model.layers[bn.precision_splitter]
        psc = self._model.layers[bn.precision_concat]
        concat = self._model.layers[bn.concat]
        matmult = self._model.layers[mm_lname]

        zp_128 = np.full_like(rs.input_zero_point, 128)
        # NOTE(review): 16512 == 128 * 129; the rationale for this reduce_sum output zero point
        # is not visible in this method - confirm.
        zp_16512 = np.full_like(rs.output_zero_point, 16512)
        zp_0 = np.full_like(zp_128, 0)

        # Setting input scales and zero points (reduce_sum and linear share the same input)
        rs.set_input_scale(scales, 0)
        rs.set_input_zero_point(zp_128, 0)
        li.set_input_scale(scales, 0)
        li.set_input_zero_point(zp_128, 0)
        # One scale per weight group: average of the input scales belonging to that group.
        low_scales = scales.reshape([(matmult.groups // matmult.input_tiles[1][-1]), -1]).mean(axis=-1)

        self.force_vals_by_edges([(bn.reduce_sum, bn.precision_splitter)], low_scales, zp_16512)

        ps.enforce_io_encoding()
        ratio = ps.optimize_ratio()  # NOTE: original TODO - this method might belong elsewhere.

        # Splitter output 0 goes straight into the precision concat ...
        self.force_vals_by_edges(
            [(bn.precision_splitter, bn.precision_concat)], ps.output_scales[0], ps.output_zero_points[0]
        )
        # ... while output 1 reaches it through the pass-through layer; both hops share the encoding.
        self.force_vals_by_edges(
            [(bn.precision_splitter, bn.pass_thrue), (bn.pass_thrue, bn.precision_concat)],
            ps.output_scales[1],
            ps.output_zero_points[1],
        )
        psc.enforce_io_encoding()
        self.force_vals_by_edges([(bn.precision_concat, bn.depth_wise)], psc.output_scale, psc.output_zero_point)
        # Depthwise output keeps the precision-concat scales, except the odd-indexed entries,
        # which are replaced with the per-group scale times the splitter ratio.
        new_out_scale = np.array(psc.output_scale)
        new_out_scale[1::2] = low_scales * ratio

        # NOTE(review): this edge gets a scalar zero point (0) while the linear edge below gets the
        # array ``zp_0`` - presumably broadcast by force_vals_by_edges; confirm.
        self.force_vals_by_edges([(bn.depth_wise, bn.concat)], new_out_scale, 0)
        self.force_vals_by_edges([(bn.linear, bn.concat)], scales, zp_0)
        concat.enforce_io_encoding()
        self.force_vals_by_edges([(bn.concat, mm_lname)], concat.output_scale, concat.output_zero_point)
        return concat.output_scale, concat.output_zero_point

    def _remove_correction(self, matmul: HailoMatmul, concat_lname):
        """Remove the correction block of ``matmul`` and its depthwise zero-static-channels entry.

        The parent class performs the structural removal first. Afterwards the per-layer
        ``zero_static_channels`` config entry keyed by this block's depthwise layer is popped.
        NOTE(review): if ``layers`` is a plain dict, a missing entry raises ``KeyError``.
        """
        super()._remove_correction(matmul, concat_lname)
        depthwise_lname = self.get_block_names(matmul.full_name).depth_wise
        zero_static_layers = self._model_config.zero_static_channels.layers
        zero_static_layers.pop(depthwise_lname)


# region CorrectionBlock by weights
class MMCorrectionWeights:
    """Matmul zero-point correction implemented by rewriting a conv layer's weights.

    The conv that feeds the matmul's second input gets one extra output channel per
    weight group (``_add_correction_weights``). The channel can later be stripped again
    (``_remove_correction``). In both cases the conv is swapped for a freshly created
    ``HailoConv`` that holds the new weights.
    """

    def __init__(self, model, model_config, logger, addition):
        self._model = model
        self._model_config = model_config
        self._logger = logger
        # Written to ``matmul.zp_comp_added`` on every conv replacement. When falsy, the new conv
        # is built right away and inherits the original conv's collected statistics.
        self._addition = addition

    @property
    def optimization_target(self):
        """Optimization target read from the model config's precision config."""
        return self._model_config.precision_config.target

    @staticmethod
    def _num_weight_groups(matmul: HailoMatmul) -> int:
        """Number of weight groups the conv's output channels are split into for ``matmul``."""
        return matmul.groups // matmul.input_tiles[1][-1]

    def _add_correction_weights(self, matmul: HailoMatmul, conv_lname):
        """Append the zero-point compensation channel to every weight group of ``conv_lname``.

        For each group, one extra channel is appended:

        - kernel: ``-sum(group_kernel, axis=-1) / ZP_FEED_REPEAT``
        - bias: ``-sum(group_bias) / ZP_FEED_REPEAT``

        The conv is then replaced by a layer holding the extended weights. Nothing happens
        if the conv already reports ``zp_comp_add``.
        """
        conv = self._model.layers[conv_lname]
        if conv.zp_comp_add:
            return  # backwards compatibility
        weights = conv.export_weights()
        num_groups = self._num_weight_groups(matmul)
        group_size = weights["kernel"].shape[-1] // num_groups
        feed_repeat = ZP_FEED_REPEAT

        # Collect [group_0, corr_0, group_1, corr_1, ...] and concatenate once at the end,
        # instead of re-concatenating the growing result every iteration.
        kernels, biases = [], []
        for i in range(num_groups):
            channels = slice(i * group_size, (i + 1) * group_size)
            group_kernel = weights["kernel"][:, :, :, channels]
            group_bias = weights["bias"][channels]

            correction_factor = -1.0 * group_kernel.sum(axis=-1, keepdims=True) / feed_repeat
            # np.tile(..., 1) turns the scalar bias correction into a length-1 array.
            bias_correction_factor = np.tile(-1.0 * np.sum(group_bias) / feed_repeat, 1)

            kernels.extend((group_kernel, correction_factor))
            biases.extend((group_bias, bias_correction_factor))

        weights.update({"kernel": np.concatenate(kernels, axis=-1), "bias": np.concatenate(biases, axis=-1)})
        self._replace_conv(weights, matmul, conv_lname)

    def _create_conv(self, orig_conv: HailoConv, weights):
        """Build a ``HailoConv`` that replaces ``orig_conv`` and holds ``weights``.

        The hn description is taken from ``orig_conv``. Its ``kernel_shape`` and output channel
        count are updated from ``weights["kernel"]``, and ``zp_comp_added`` is set to True.
        The output data path is set to ``DataPath.LAYER_OUT_WEIGHTS``. The original layer's
        precision config is verified and imported.
        """
        # NOTE(review): if ``to_hn`` returns the layer's own hn dict, the edits below also mutate
        # ``orig_conv``'s description - confirm.
        conv_hn = orig_conv.to_hn()
        conv_hn["params"]["kernel_shape"] = list(weights["kernel"].shape)
        conv_hn["params"]["zp_comp_added"] = True
        conv_hn["output_shapes"] = [[*shape[:-1], weights["kernel"].shape[-1]] for shape in conv_hn["output_shapes"]]
        new_conv = HailoConv.from_hn(orig_conv.full_name, conv_hn)
        new_conv.import_weights(weights)
        new_conv.set_output_data_path(DataPath.LAYER_OUT_WEIGHTS)
        precision_cfg = self._model_config.precision_config.layers[orig_conv.full_name]
        new_conv.verify_config(precision_cfg)
        new_conv.import_precision_config(precision_cfg, self.optimization_target)
        return new_conv

    def _replace_conv(self, weights, matmul: HailoMatmul, conv_lname):
        """Swap ``conv_lname`` for a new conv holding ``weights`` and update ``matmul`` accordingly.

        The new conv is wired between the original's predecessors and ``matmul``. The matmul's
        second input shape is updated to the new channel count, and ``matmul.zp_comp_added`` is
        set to ``self._addition``. When ``self._addition`` is falsy, the new conv is built and
        its atomic ops take over the original ops' statistics, which are marked complete.
        """
        orig_conv = self._model.layers[conv_lname]
        conv_preds = self._model.flow.predecessors_sorted(orig_conv.full_name)
        new_conv = self._create_conv(orig_conv, weights)
        self._model.remove_layer(orig_conv)
        self._model.add_layer(new_conv, [(pred, matmul.full_name) for pred in conv_preds])
        # The replacement conv has a different channel count. Derive the matmul's second input
        # shape from the new kernel instead of copying the original conv's (stale) output shape.
        orig_out_shape = orig_conv._hn_element["output_shapes"][0]
        matmul._hn_element["input_shapes"][1] = [*orig_out_shape[:-1], weights["kernel"].shape[-1]]
        matmul.zp_comp_added = self._addition

        if not self._addition:
            new_conv.build(orig_conv.input_shapes)
            for orig_op, new_op in zip(orig_conv.atomic_ops, new_conv.atomic_ops):
                new_op.stats_managers = orig_op.stats_managers
                new_op.stats_collection_state = StatsState.COMPLETE

    def _remove_correction(self, matmul: HailoMatmul, conv_lname):
        """Strip the per-group compensation channel previously appended to ``conv_lname``.

        The last channel of every weight group is removed from both kernel and bias.
        The statistics collected around the original conv are trimmed the same way, and the
        conv is then replaced by a layer holding the reduced weights.
        """
        conv = self._model.layers[conv_lname]
        weights = conv.export_weights()
        num_groups = self._num_weight_groups(matmul)
        corrected_size = weights["kernel"].shape[-1] // num_groups
        # Slice each group and drop its last (compensation) channel.
        kernels = [
            weights["kernel"][..., i * corrected_size : (i + 1) * corrected_size][..., :-1] for i in range(num_groups)
        ]
        biases = [weights["bias"][i * corrected_size : (i + 1) * corrected_size][:-1] for i in range(num_groups)]
        weights.update({"kernel": np.concatenate(kernels, axis=-1), "bias": np.concatenate(biases, axis=-1)})
        self._remove_correction_from_stats(conv, matmul, corrected_size)
        self._replace_conv(weights, matmul, conv_lname)

    @staticmethod
    def update_stats(correction_func, orig_stats):
        """Return ``orig_stats`` with ``correction_func`` applied to its min, max, energy and mean."""
        return orig_stats._replace(
            min=correction_func(orig_stats.min),
            max=correction_func(orig_stats.max),
            energy=correction_func(orig_stats.energy),
            mean=correction_func(orig_stats.mean),
        )

    def _remove_correction_from_stats(self, conv: HailoConv, matmul: HailoMatmul, group_size):
        """Drop the compensation channel from the statistics collected around ``conv``.

        ``group_size`` is the per-group channel count including the compensation channel.
        The last entry of every group is removed from three places:

        - the conv's output statistics;
        - the input statistics of the conv's activation ops;
        - the statistics of the matmul's input with index 1.
        """
        num_groups = self._num_weight_groups(matmul)

        def _correct_stats(stat):
            return np.concatenate([stat[i * group_size : (i + 1) * group_size][:-1] for i in range(num_groups)])

        for out_op, out_index in conv._output_stats_ops():
            output_stats = out_op.get_output_stats(out_index)
            updated_output = self.update_stats(_correct_stats, output_stats)
            out_op.stats_managers[f"outputs_{out_index}"]._stats = updated_output

        for act_op in conv._iterate_act_ops():
            preact_stats = act_op.get_input_stats(0)
            updated_preact = self.update_stats(_correct_stats, preact_stats)
            act_op.stats_managers["inputs_0"]._stats = updated_preact

        for in_op, in_index in matmul._input_stats_ops():
            if in_index != 1:
                continue
            in_stats = in_op.get_input_stats(in_index)
            updated_input = self.update_stats(_correct_stats, in_stats)
            in_op.stats_managers[f"inputs_{in_index}"]._stats = updated_input