import math
from copy import deepcopy

import numpy as np

from hailo_model_optimization.acceleras.atomic_ops.concat_op import ConcatOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_matmul import HailoMatmul
from hailo_model_optimization.acceleras.statistics.statistics_base import TypeStats, scale_stats, update_stats
from hailo_model_optimization.acceleras.utils.acceleras_definitions import MatmulCorrectionType
from hailo_model_optimization.algorithms.matmul_correction.correction_blocks import (
    MMCorrectionBlock2,
    MMCorrectionBlock3,
)
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class FixZpCompEncoding(OptimizationAlgorithm):
    def __init__(self, model, model_config, logger_level, logger=None):
        """Build the algorithm under the fixed display name "Fix zp_comp Encoding".

        Every argument is forwarded unchanged to the parent constructor.
        """
        super().__init__(
            model,
            model_config,
            name="Fix zp_comp Encoding",
            logger_level=logger_level,
            logger=logger,
        )

    def should_skip_algo(self):
        """Report that this algorithm is never skipped (always ``False``)."""
        return False

    def finalize_global_cfg(self, algo_config):
        """Do nothing: ``algo_config`` is neither read nor modified."""
        return None

    def get_algo_config(self):
        """Return the ``matmul_correction`` attribute of the model configuration."""
        model_config = self._model_config
        return model_config.matmul_correction

    def _setup(self):
        """Run the parent class setup; no additional setup is performed here."""
        super()._setup()

    def _get_valid_layer_cfg(self, lname, cfg):
        """Do nothing and return ``None``; ``lname`` and ``cfg`` are ignored."""
        return None

    def finalize_flat_layers_fields(self, algo_config):
        """Do nothing: ``algo_config`` is neither read nor modified."""
        return None

    def finalize_layer_cfg(self, layers_cfg_dict):
        """Return a new empty dict; ``layers_cfg_dict`` is ignored."""
        return {}

    def _validate_layer_config(self, lname, cfg):
        """Do nothing and return ``None``; ``lname`` and ``cfg`` are ignored."""
        return None

    def _run_int(self):
        """Apply the matching encoding fix to every compensated matmul layer.

        Layers are visited in the order returned by the model flow's
        ``toposort()``. Only ``HailoMatmul`` layers whose ``zp_comp_added``
        flag is truthy are considered, and the fixer is chosen by comparing
        ``zp_correction_type`` against the block correction types below.
        Any other correction type (``ZP_COMP`` included) is left untouched.
        """
        # Checked in this order; the first equal correction type wins.
        fixers = (
            (MatmulCorrectionType.ZP_COMP_BLOCK, self.add_zp_comp_block),
            (MatmulCorrectionType.ZP_COMP_BLOCK_2, self.add_zp_comp_block_2),
            (MatmulCorrectionType.ZP_COMP_BLOCK_3, self.add_zp_comp_block_3),
        )

        for lname in self._model.flow.toposort():
            candidate = self._model.layers[lname]
            if not isinstance(candidate, HailoMatmul):
                continue
            if not candidate.zp_comp_added:
                continue

            for correction_type, fixer in fixers:
                if candidate.zp_correction_type == correction_type:
                    fixer(candidate)
                    break

    def add_zp_comp_block(self, matmul: HailoMatmul):
        """Align the encodings and statistics of the zp_comp block feeding ``matmul``.

        The block is found by walking predecessors of ``matmul``: its second
        sorted predecessor is a concat layer, whose two sorted predecessors are
        a linear layer and a reduce-sum layer; the first sorted predecessor of
        the linear layer is the layer whose output encoding is forced here.

        Steps, in order:

        1. Take a single scale from the linear layer's output and apply it to
           the linear, weights-side, reduce-sum and concat layers and to
           ``matmul`` input 1; set one shared zero point on the linear and
           reduce-sum inputs and on the weights-side output.
        2. Compare the reduce-sum and linear output ranges. If their signs
           disagree, log a warning and return; the scales and zero points from
           step 1 stay applied, but feed-repeat and statistics are not updated.
        3. Otherwise raise ``matmul.matmul_op.feed_repeat`` by the range ratio
           (rounded up) and multiply the reduce-sum kernel by old/new
           feed-repeat.
        4. Rescale the reduce-sum statistics and push the resulting min/max
           through the concat inputs, the concat output and ``matmul`` input 1.

        Args:
            matmul: The matmul layer whose preceding correction block is fixed.
        """
        self._logger.debug(f"Fixing zp_comp encoding for layer {matmul.full_name}")
        # Locate the block layers by their position among sorted predecessors.
        concat_name = self._model.flow.predecessors_sorted(matmul.full_name)[1]
        linear_name, reduce_sum_name = self._model.flow.predecessors_sorted(concat_name)
        weights_layer_name = self._model.flow.predecessors_sorted(linear_name)[0]
        weights_layer = self._model.layers[weights_layer_name]
        reduce_sum = self._model.layers[reduce_sum_name]
        linear = self._model.layers[linear_name]
        concat = self._model.layers[concat_name]

        ## get linear scale candidates
        # orig_scale is kept only for the debug message further down.
        orig_scale = linear.output_scale[0]
        # NOTE(review): presumably this recomputes linear's output encoding; the scale is re-read right after.
        linear.create_output_encoding_candidates()
        scalar_scale = linear.output_scale[0]

        # set the input and output scale of the linear layer
        linear.set_input_scale(scalar_scale * np.ones_like(linear.input_scale), index=0)
        linear.set_output_scale(scalar_scale * np.ones_like(linear.output_scale), index=0)
        # Shared zero point: the negated min_value of linear's first output lossy element.
        center_value = np.float32(-linear.get_output_lossy_elements()[0].min_value)
        linear.set_input_zero_point(center_value, index=0)

        # set the output scale and zero point of the layer feeding linear (referred to as the Transpose layer);
        # the safe_* helpers log a warning if the layer does not keep the requested value
        self.safe_set_output_scales(weights_layer, 0, scalar_scale * np.ones_like(weights_layer.output_scale), matmul)
        self.safe_set_output_zp(weights_layer, 0, center_value, matmul)

        # set the input zero point and the input and output scale of reduce_sum layer
        reduce_sum.set_input_zero_point(center_value, index=0)
        reduce_sum.set_input_scale(scalar_scale * np.ones_like(reduce_sum.input_scale), index=0)
        reduce_sum.set_output_scale(scalar_scale * np.ones_like(reduce_sum.output_scale), index=0)

        # set the input and output scale of the concat layer
        concat.set_input_scale(scalar_scale * np.ones_like(concat.input_scales[0]), index=0)
        concat.set_input_scale(scalar_scale * np.ones_like(concat.input_scales[1]), index=1)
        concat.set_output_scale(scalar_scale * np.ones_like(concat.output_scale), index=0)

        # set the scale of matmul input 1 (the input fed by the concat layer)
        matmul.set_input_scale(scalar_scale * np.ones_like(matmul.input_scales[1]), index=1)

        ## update reduce kernel and matmul feed_repeat
        # Overall extremes of the reduce-sum and linear output statistics.
        max_output_rs = reduce_sum.get_output_stats()[0].max.max()
        max_output_linear = linear.get_output_stats()[0].max.max()
        min_output_rs = reduce_sum.get_output_stats()[0].min.min()
        min_output_linear = linear.get_output_stats()[0].min.min()

        # calc linear to reduce sum ratio.
        # The ratio is only computed when both maxima share a sign and both minima share a sign.
        if np.sign(max_output_linear) == np.sign(max_output_rs) and np.sign(min_output_linear) == np.sign(
            min_output_rs,
        ):
            # A zero extreme (on both layers, given the sign check) uses a neutral ratio of 1
            # instead of dividing by zero.
            if np.sign(max_output_linear) == 0:
                ratio_max = 1
            else:
                ratio_max = max_output_rs / max_output_linear
            if np.sign(min_output_linear) == 0:
                ratio_min = 1
            else:
                ratio_min = min_output_rs / min_output_linear
            max_ratio = max(ratio_max, ratio_min)
        else:
            self._logger.warning(
                f"layer {matmul.full_name} cannot fix reduce sum equalization, skipping feed repeat update",
            )
            # Early exit: the scales and zero points set above remain in effect.
            return

        # Grow feed_repeat by the range ratio (rounded up) and shrink the reduce-sum kernel by old/new feed_repeat.
        current_feed_repeat = matmul.matmul_op.feed_repeat
        new_feed_repeat = math.ceil(current_feed_repeat * max_ratio)
        feed_repeat_ratio = current_feed_repeat / new_feed_repeat
        reduce_sum.reduce_sum_op.kernel *= feed_repeat_ratio

        matmul.matmul_op.feed_repeat = new_feed_repeat

        self._logger.debug(
            f"layer {matmul.full_name} feed repeat updated from {current_feed_repeat} to {new_feed_repeat}"
            + " -- "
            + f"initial_output_scale : {orig_scale:.3f}, updated_output_scale : {scalar_scale:.3f}"
            + f" -- {center_value=}",
        )

        ## set reduce_sum output and preact stats
        reduce_sum_os = reduce_sum.get_output_stats()[0]
        reduce_sum_pas = reduce_sum.get_preact_stats()[0]
        # NOTE(review): presumably scale_stats rescales the stats objects in place by the given factor - confirm.
        scale_stats(reduce_sum_os, 1 / feed_repeat_ratio)
        scale_stats(reduce_sum_pas, 1 / feed_repeat_ratio)

        # set concat limvals
        ## concat input stats [1]: min/max taken from the reduce_sum output stats
        concat_is = concat.get_input_stats()[1]
        update_stats(concat_is, reduce_sum_os.min, TypeStats.MIN, clear_cannot_update=True)
        update_stats(concat_is, reduce_sum_os.max, TypeStats.MAX, clear_cannot_update=True)

        ## concat input stats [0]: min/max taken from the linear output stats
        concat_is = concat.get_input_stats()[0]
        linear_os = linear.get_output_stats()[0]
        update_stats(concat_is, linear_os.min, TypeStats.MIN, clear_cannot_update=True)
        update_stats(concat_is, linear_os.max, TypeStats.MAX, clear_cannot_update=True)

        ## concat output stats: the two input min/max vectors combined by concat_stats
        concat_is = concat.get_input_stats()
        concat_os_min, concat_os_max = self.concat_stats(concat_is, matmul.groups)
        concat_os = concat.get_output_stats()[0]
        update_stats(concat_os, concat_os_min, TypeStats.MIN, clear_cannot_update=True)
        update_stats(concat_os, concat_os_max, TypeStats.MAX, clear_cannot_update=True)

        ## set matmul input stats [1] to the same min/max as the concat output
        matmul_is = matmul.get_input_stats()[1]
        update_stats(matmul_is, concat_os_min, TypeStats.MIN, clear_cannot_update=True)
        update_stats(matmul_is, concat_os_max, TypeStats.MAX, clear_cannot_update=True)

    def add_zp_comp_block_2(self, matmul: HailoMatmul):
        """Fix the encodings of the ``MMCorrectionBlock2`` block around ``matmul``.

        The first sorted predecessor of the block's linear layer gets an
        output zero point of ``2 ** (bits - 1)``; the block then enforces its
        encodings using that layer's first output scale (read before the zero
        point is changed). Finally ``matmul`` feed-repeat becomes
        ``[1, <precision splitter ratio>]``.

        Args:
            matmul: The matmul layer whose correction block is fixed.
        """
        matmul_name = matmul.full_name
        self._logger.debug(f"Fixing zp_comp_2 encoding for layer {matmul_name}")

        block = MMCorrectionBlock2(self._model, self._model_config, self._logger, addition=True)
        names = block.get_block_names(matmul_name)

        source_name = self._model.flow.predecessors_sorted(names.linear)[0]
        layers = self._model.layers
        source_layer = layers[source_name]
        scales_low = source_layer.output_scales[0]
        splitter = layers[names.precision_splitter]

        # Zero point on the layer feeding linear (referred to as the Transpose layer):
        # half of 2 ** bits.
        bits = source_layer.get_output_lossy_elements()[0].bits
        self.safe_set_output_zp(source_layer, 0, 2 ** (bits - 1), matmul)

        block.enforce_block_encodings(matmul_name, scales_low)

        # The splitter ratio is read only after the block encodings were enforced.
        matmul.matmul_op.feed_repeat = [1, splitter.ratio]

    def add_zp_comp_block_3(self, matmul: HailoMatmul):
        """Fix the encodings of the ``MMCorrectionBlock3`` block around ``matmul``.

        The first sorted predecessor of the block's linear layer gets an
        output zero point of ``2 ** (bits - 1)``; the block then enforces its
        encodings using that layer's first output scale (read before the zero
        point is changed). ``matmul`` feed-repeat becomes ``[1, ratio]`` with
        the ratio returned by the precision splitter's ``optimize_ratio()``.
        Finally the depth-wise layer's input, pre-activation and output
        statistics are overwritten with a symmetric range proportional to its
        output scale.

        Args:
            matmul: The matmul layer whose correction block is fixed.
        """
        self._logger.debug(f"Fixing zp_comp_3 encoding for layer {matmul.full_name}")
        layers = self._model.layers
        flow = self._model.flow
        mmcb = MMCorrectionBlock3(self._model, self._model_config, self._logger, addition=True)
        bn = mmcb.get_block_names(matmul.full_name)
        weights_layer_name = flow.predecessors_sorted(bn.linear)[0]

        weights_layer = layers[weights_layer_name]
        scales_low = weights_layer.output_scales[0]

        # Zero point on the layer feeding linear (referred to as the Transpose layer):
        # half of 2 ** bits.
        center_value = 2 ** (weights_layer.get_output_lossy_elements()[0].bits - 1)
        self.safe_set_output_zp(weights_layer, 0, center_value, matmul)

        mmcb.enforce_block_encodings(matmul.full_name, scales_low)
        ratio = layers[bn.precision_splitter].optimize_ratio()
        matmul.matmul_op.feed_repeat = [1, ratio]

        # Fix depth-wise statistics.
        depth = layers[bn.depth_wise]
        stats = depth.export_stats()

        # NOTE(review): 120 looks like a bound in quantized units just inside a signed
        # 8-bit range - confirm the intended value before changing it.
        stats_bound = 120
        vals_max = stats_bound * depth.output_scale
        vals_min = -stats_bound * depth.output_scale
        # np.min/np.max reduce in native code and also accept 0-d or multi-dimensional
        # scales, which the Python builtins min()/max() cannot iterate or compare.
        lim_vals = (np.min(vals_min), np.max(vals_max))

        # The same range is written for the input, the pre-activation and the output.
        for section in ("input_0", "preact", "output_0"):
            stats[f"stats/{section}/max"] = vals_max
            stats[f"stats/{section}/min"] = vals_min
            stats[f"stats/{section}/stats_limvals"] = lim_vals

        depth.import_stats(stats)

    def concat_stats(self, cc_is, gropus):
        """Combine two inputs' min/max statistics with a two-input ``ConcatOp``.

        A throwaway ``ConcatOp`` with ``gropus`` groups of size 1 is built.
        The ``min`` (and separately the ``max``) arrays of ``cc_is[0]`` and
        ``cc_is[1]`` are copied, reshaped to ``(1, 1, 1, -1)`` and passed
        through the op's ``call_native``.

        Args:
            cc_is: Indexable holding at least two stats objects that expose
                ``min`` and ``max`` arrays.
            gropus: Number of size-1 groups given to the concat op.

        Returns:
            Tuple ``(min_cc, max_cc)`` of the squeezed NumPy results.
        """
        concat_op = ConcatOp(name="test", concat_elements=2, logger=None, group_sizes=[1] * gropus)

        def _as_4d(values):
            # Work on a copy so the stored statistics are never touched.
            return values.copy().reshape(1, 1, 1, -1)

        first, second = cc_is[0], cc_is[1]
        mins = [_as_4d(first.min), _as_4d(second.min)]
        maxs = [_as_4d(first.max), _as_4d(second.max)]

        min_cc = concat_op.call_native(mins).numpy().squeeze()
        max_cc = concat_op.call_native(maxs).numpy().squeeze()
        return min_cc, max_cc

    def safe_set_output_scales(self, layer: BaseHailoLayer, index: int, scales, matmul):
        """Set an output scale on ``layer`` and warn if it does not stick.

        After the scale is set and ``enforce_io_encoding()`` has run, the
        layer's resulting output scale is compared (``np.allclose``) against
        a copy of the requested value taken beforehand. A mismatch is only
        logged as a warning; nothing is reverted or raised.

        Args:
            layer: Layer whose output scale is set.
            index: Output index to set and verify.
            scales: Requested scale value(s).
            matmul: Matmul layer; its ``zp_correction_type`` is reported in the warning.
        """
        requested = deepcopy(scales)
        layer.set_output_scale(scales, index)
        layer.enforce_io_encoding()

        if np.allclose(layer.output_scales[index], requested):
            return

        self._logger.warning(
            f"Matmul-correction encountered an unsupported layer: {layer.full_name} "
            f"when adding component {matmul.zp_correction_type}. Please consider using a "
            "different correction_type, or forcing this behivor (at the cost of model accuracy) by using "
            f"pre_quantization_optimization(matmul_correction, layers=[{layer.full_name}], correction_type=zp_comp_none)"
        )

    def safe_set_output_zp(self, layer: BaseHailoLayer, index: int, zp, matmul):
        """Set an output zero point on ``layer`` and warn if it does not stick.

        After the zero point is set and ``enforce_io_encoding()`` has run, the
        layer's resulting output zero point is compared (``np.allclose``)
        against a copy of the requested value taken beforehand. A mismatch is
        only logged as a warning; nothing is reverted or raised.

        Args:
            layer: Layer whose output zero point is set.
            index: Output index to set and verify.
            zp: Requested zero point value.
            matmul: Matmul layer; its ``zp_correction_type`` is reported in the warning.
        """
        requested = deepcopy(zp)
        layer.set_output_zero_point(zp, index)
        layer.enforce_io_encoding()

        if np.allclose(layer.output_zero_points[index], requested):
            return

        self._logger.warning(
            f"Matmul-correction encountered an unsupported layer: {layer.full_name} "
            f"when adding component {matmul.zp_correction_type}. Please consider using a "
            "different correction_type, or forcing this behivor (at the cost of model accuracy) by using "
            f"pre_quantization_optimization(matmul_correction, layers=[{layer.full_name}], correction_type=zp_comp_none)"
        )
