from typing import Union

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.activation_op import ActivationOp
from hailo_model_optimization.acceleras.atomic_ops.bias_add_op import AddBiasOp
from hailo_model_optimization.acceleras.atomic_ops.concat_op import ConcatOp
from hailo_model_optimization.acceleras.atomic_ops.conv_stripped_op import ConvStrippedOp
from hailo_model_optimization.acceleras.atomic_ops.element_wise_add_op import ElementwiseAddDirectOp
from hailo_model_optimization.acceleras.atomic_ops.mock_conv_op import MockConvOp
from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.atomic_ops.proposal_generator_op import ProposalGeneratoOp
from hailo_model_optimization.acceleras.atomic_ops.slice_op import SliceOp
from hailo_model_optimization.acceleras.hailo_layers.hailo_bbox_decoder import POINTS_PER_BOX, HailoBBoxDecoder
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerWeightsClippingConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    DataPath,
    LayerType,
    OptimizationTarget,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasInitializationError
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams
from hailo_model_optimization.acceleras.utils.opt_utils import get_kernel_bits_and_sign_by_precision_mode

# Upper bound on the bit width of this layer's output range. The effective width
# is min(NMS_FORCED_BITS, output_bits): the NMS that follows expects 12-bit box
# coordinates.
NMS_FORCED_BITS: int = 12


class HailoFusedBboxDecoder(HailoBBoxDecoder):
    """
    This class combines HailoBBoxDecoder and ProposalGeneratoOp into one layer. Note
    that the activation in the HailoBBoxDecoder module has been moved to the output
    stage.

    Data flow (see ``_build_flow``):
      * input 0 (regression): passthru -> slice of the first POINTS_PER_BOX channels
        -> concat -> 1x1 conv (scale factors) -> bias (anchor centers) -> concat
        -> proposal generator (input 0).
      * input 1 (classes): passthru -> two mock convs (kernel 0.5 each)
        -> element-wise add -> proposal generator (input 1).
      * proposal generator -> activation -> layer output.
    """

    _hn_type = LayerType.FUSED_BBOX_DECODER

    def __init__(
        self,
        name: str,
        num_of_anchors: int = 1,
        num_inputs: int = 2,
        input_division_factor: int = 1,
        x_centers: np.ndarray = None,
        y_centers: np.ndarray = None,
        activation: Union[str, callable, ActivationType] = ActivationType.RELU,
        logger=None,
        **kwargs,
    ):
        """
        Create all internal ops of the fused bbox decoder.

        Args:
            name: layer name, also used as the prefix of every internal op name.
            num_of_anchors: number of anchors; only 1 is supported.
            num_inputs: number of layer inputs; must be 2 (regression, classes).
            input_division_factor: forwarded to the proposal generator op.
            x_centers: initial x anchor centers (replaced by ``import_weights``).
            y_centers: initial y anchor centers (replaced by ``import_weights``).
            activation: output-stage activation (replaced by CLIP in ``create_hw_params``).
            logger: logger forwarded to every op and to the base layer.
            **kwargs: forwarded to the base layer constructor.

        Raises:
            AccelerasInitializationError: if ``num_inputs != 2`` or ``num_of_anchors != 1``.
        """
        # The base layer constructor runs last (see the end of this method), so
        # ``self.full_name`` is not available yet; error messages use ``name``.
        if num_inputs != 2:
            raise AccelerasInitializationError(
                f"Number of inputs to bbox decoder {name} must be 2 but got {num_inputs}",
            )

        self.x_centers = x_centers
        self.y_centers = y_centers
        self._width_scale_factor = None
        self._height_scale_factor = None
        self._width_scale_factor_minus = None
        self._height_scale_factor_minus = None
        if num_of_anchors != 1:
            raise AccelerasInitializationError(
                f"Number of anchors to fused bbox decoder {name} must be 1 but got {num_of_anchors}",
            )
        self._num_of_anchors = num_of_anchors

        # Creates the ops:
        self._passthru_ops = []
        for input_index in range(num_inputs):
            self._passthru_ops.append(PassthruOp(f"{name}/passthru_op_in_{input_index}", logger=logger))

        # Note: _num_inputs is defined and used only for the parent HailoBBoxDecoder class
        self._num_inputs = 1

        # * The regression input channel
        self._slice_ops = [SliceOp(f"{name}/slice_op", features_slice=(0, POINTS_PER_BOX, 1), logger=logger)]
        self._pre_concat_ops = [ConcatOp(f"{name}/pre_concat_op", concat_elements=1, logger=logger, vector_zp=True)]
        self._conv_ops = [
            ConvStrippedOp(
                f"{name}/conv_op_0",
                kernel_size=(1, 1),
                filters=POINTS_PER_BOX,
                is_depthwise=False,
                stride_align="NW",
                strides=(1, 1),
                groups=1,
                padding="VALID",
                dilation_rate=(1, 1),
                trainable=False,
                vector_zp=True,
                logger=logger,
            ),
        ]
        self._bias_ops = [
            AddBiasOp(f"{name}/bias_op_0", trainable=False, axis=(1, 2, 3), is_correctable=False, logger=logger),
        ]
        self._concat_op = ConcatOp(f"{name}/concat_op", concat_elements=num_of_anchors, logger=logger)
        self._act_op = ActivationOp(f"{name}/act_op", activation, logger=logger)
        self._proposal_generator_op = ProposalGeneratoOp(
            f"{name}/proposal_generator_op",
            proposals_per_output=4,
            input_division_factor=input_division_factor,
            logger=logger,
        )
        # * The classes input channel: two mock convs with kernel 0.5 whose outputs
        #   are summed, so the class range is split between them.
        self._mock_op1 = MockConvOp(f"{name}/mock_op1", logger=logger)
        self._mock_op1.kernel = 0.5
        self._mock_op2 = MockConvOp(f"{name}/mock_op2", logger=logger)
        self._mock_op2.kernel = 0.5
        self._ew_add_op = ElementwiseAddDirectOp(f"{name}/elementwise_add_op", logger=logger)

        # calls BaseHailoLayer's init (deliberately skipping HailoBBoxDecoder.__init__)
        super(HailoBBoxDecoder, self).__init__(name=name, logger=logger, **kwargs)

    @property
    def pre_acc_shift(self):
        """Pre-accumulator shift of the element-wise add op on the classes branch."""
        return self._ew_add_op.pre_acc_shift

    @staticmethod
    def to_single_element(value):
        """
        Unwrap a one-element container into its scalar.

        A list/tuple of length 1 returns its only item, and an ndarray of size 1
        returns ``value.item()``. Any other value is returned unchanged.
        """
        if isinstance(value, (list, tuple)) and len(value) == 1:
            return value[0]
        elif isinstance(value, np.ndarray) and value.size == 1:
            return value.item()
        else:
            return value

    def import_weights(self, layer_params: LayerParams):
        """
        Load scale factors and anchor centers, and build the decoding conv kernel and bias.

        The input to the conv is (shift_left_x, shift_left_y, shift_right_x, shift_right_y).
        The output is (ymin, xmin, ymax, xmax), computed as:
        ymin = y_center - height_scale_factor * shift_left_y
        xmin = x_center - width_scale_factor * shift_left_x
        ymax = y_center + height_scale_factor * shift_right_y
        xmax = x_center + width_scale_factor * shift_right_x

        The multiplicative terms come from the 1x1 conv kernel, and the centers from the bias.

        Args:
            layer_params: must contain "width_scale_factor", "height_scale_factor",
                "x_centers" and "y_centers", plus whatever the conv and activation
                ops read in their own ``import_weights``.
        """
        self._width_scale_factor = layer_params["width_scale_factor"]
        self._height_scale_factor = layer_params["height_scale_factor"]
        # np.multiply negates list/tuple values element-wise; ``-1 * [v]`` would
        # silently yield an empty list. Results for ndarrays and scalars are unchanged.
        self._width_scale_factor_minus = np.multiply(-1, self._width_scale_factor)
        self._height_scale_factor_minus = np.multiply(-1, self._height_scale_factor)
        self.x_centers = layer_params["x_centers"]
        self.y_centers = layer_params["y_centers"]

        height_scale = self.to_single_element(self._height_scale_factor)
        width_scale = self.to_single_element(self._width_scale_factor)

        # builds conv kernel for the decoding operation: row r holds the weights of
        # output channel r; the transpose turns it into an (in, out) matrix.
        kernel_np = np.zeros([POINTS_PER_BOX, POINTS_PER_BOX], dtype=np.float32)
        kernel_np[0, :] = [0, -1 * height_scale, 0, 0]  # ymin
        kernel_np[1, :] = [-1 * width_scale, 0, 0, 0]  # xmin
        kernel_np[2, :] = [0, 0, 0, height_scale]  # ymax
        kernel_np[3, :] = [0, 0, width_scale, 0]  # xmax
        kernel_np = np.transpose(kernel_np)
        kernel_np = np.reshape(kernel_np, [1, 1, POINTS_PER_BOX, POINTS_PER_BOX])
        self._conv_ops[0].import_weights(kernel_np, layer_params)

        # load bias params (y_centers, x_centers)
        # The centers are written twice in the NPZ, so only column 0 is used.
        x = self.x_centers[:, 0]
        y = self.y_centers[:, 0]
        x = np.reshape(x, [1, len(x)])
        y = np.reshape(y, [len(y), 1])
        # Both grids have shape (len(y), len(x)).
        x_grid = np.tile(x, y.shape)
        y_grid = np.tile(y, x.shape)
        bias_np = np.zeros(list(x_grid.shape) + [POINTS_PER_BOX])
        bias_np[:, :, 0] = y_grid  # ymin
        bias_np[:, :, 1] = x_grid  # xmin
        bias_np[:, :, 2] = y_grid  # ymax
        bias_np[:, :, 3] = x_grid  # xmax
        self._bias_ops[0].import_weights(bias_np)
        self._act_op.import_weights(layer_params)

    def _export_weights(self):
        """Return the native params loaded by ``import_weights``, plus the negated scale factors."""
        weights_dict = dict()
        weights_dict["x_centers"] = self.x_centers
        weights_dict["y_centers"] = self.y_centers
        weights_dict["height_scale_factor"] = self._height_scale_factor
        weights_dict["width_scale_factor"] = self._width_scale_factor
        weights_dict["height_scale_factor_minus"] = self._height_scale_factor_minus
        weights_dict["width_scale_factor_minus"] = self._width_scale_factor_minus
        return weights_dict

    # region Export Hw Params
    def _export_ops_hw_params(self) -> dict:
        """
        Export the HW params of the parent layer and the classes-branch ops.

        The element-wise add params are merged as-is. Each mock op's kernel,
        kernel zero point and output-stage shift are re-keyed under
        ``mock_op{1,2}/...`` names, following the legacy export layout.
        """
        params = super()._export_ops_hw_params()
        params.update(self._ew_add_op.export_hw_params())

        for prefix, mock_op in (("mock_op1", self._mock_op1), ("mock_op2", self._mock_op2)):
            mock_params = mock_op.export_hw_params()
            params[f"{prefix}/quant_kernel"] = mock_params["kernel"]
            params[f"{prefix}/kernel_zero_point"] = mock_params["zp_kernel"]
            params[f"{prefix}/mac_shift"] = mock_params["output_stage/mult_shift"]

        return params

    def kernels_to_anchors(self, qnpz):
        """
        Extract the quantized per-anchor scale factors from the quantized kernels.

        Each ``kernel_anchor_{i}`` is flattened; with the (in, out) layout built in
        ``import_weights``, indices 1, 4, 11 and 14 hold -width, -height, +width and
        +height, respectively.

        Returns:
            dict of int32 arrays, with one entry per anchor in each array.
        """
        anchors = dict()
        anchors_height_scale_factor = []
        anchors_width_scale_factor = []
        anchors_height_scale_factor_minus = []
        anchors_width_scale_factor_minus = []
        for i in range(self._num_of_anchors):
            kernel = qnpz[f"kernel_anchor_{i}"].reshape(-1)
            anchors_width_scale_factor_minus.append(kernel[1])
            anchors_height_scale_factor_minus.append(kernel[4])
            anchors_width_scale_factor.append(kernel[11])
            anchors_height_scale_factor.append(kernel[14])

        anchors["anchors_height_scale_factor"] = np.array(anchors_height_scale_factor, dtype=np.int32)
        anchors["anchors_width_scale_factor"] = np.array(anchors_width_scale_factor, dtype=np.int32)
        anchors["anchors_height_scale_factor_minus"] = np.array(anchors_height_scale_factor_minus, dtype=np.int32)
        anchors["anchors_width_scale_factor_minus"] = np.array(anchors_width_scale_factor_minus, dtype=np.int32)
        return anchors

    # endregion

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """
        Quantize the layer: parent conv path, mock ops, element-wise add and output activation.

        Note: ``hw_shifts`` is accepted for interface compatibility but ignored; the
        parent is always called with ``hw_shifts=[1]``.
        """
        self.enforce_internal_encoding()
        super().create_hw_params(weights_clipping, optimization_target, hw_shifts=[1])
        # --> kernel_scale = output_scale * 2**-pre_acc_shift / input_scale
        max_mocks_input_scale = max([*self._mock_op1.input_scale, *self._mock_op2.input_scale])
        kernel_scale_without_shift = max(self._concat_op.output_scale.numpy()) / max_mocks_input_scale
        # Force the same pre_acc_shift for the conv_op and the mock_ops.
        pre_acc_shift = self._conv_ops[0].pre_acc_shift
        # Divide the input range of the classes input between the two mock_ops
        # >>> the kernel of both mock_op1 & mock_op2 is 0.5
        self._mock_op1.create_hw_params(factor=kernel_scale_without_shift, pre_acc_shift=pre_acc_shift)
        self._mock_op2.create_hw_params(factor=kernel_scale_without_shift, pre_acc_shift=pre_acc_shift)
        self._propogate_mock_op_encoding()
        self._ew_add_op.create_hw_params(forced_ratio=1.0)
        self._ew_add_op.enforce_encoding()

        # Change the activation to CLIP so that the quantized output stays within
        # [0, 2**min(NMS_FORCED_BITS, output_bits) - 1] for the NMS layer
        # (see create_output_encoding_candidates).
        self._act_op.act_name, self._act_op.act_func = self._act_op.create_act_name_and_func(ActivationType.CLIP)
        act_native_params = self._act_op.act_numeric_params
        output_bits = self.output_op.output_lossy_elements[0].bits
        act_native_params["clip_min"] = 0.0
        act_native_params["clip_max"] = np.max(self.output_scale) * (2 ** np.minimum(NMS_FORCED_BITS, output_bits) - 1)
        self._act_op.import_weights(act_native_params)

        # Update hw params for activation_op:
        acc_scale_candidate = self._conv_ops[0].accumulator_scale_candidate[0]
        self._act_op.create_hw_params(acc_scale_candidate, optimization_target)
        self.enforce_internal_encoding()

    def _propogate_mock_op_encoding(self, training=False):
        """
        Propagate the classes passthru encoding through both mock ops into the element-wise add inputs.
        """
        self._passthru_ops[1].enforce_encoding()
        self._mock_op1.input_scales = [self._passthru_ops[1].output_scale]
        self._mock_op2.input_scales = [self._passthru_ops[1].output_scale]
        self._mock_op1.input_zero_point = [self._passthru_ops[1].output_zero_point]
        self._mock_op2.input_zero_point = [self._passthru_ops[1].output_zero_point]
        self._mock_op1.enforce_encoding(training=training)
        self._mock_op2.enforce_encoding(training=training)
        self._ew_add_op.input_scales[0] = self._mock_op1.output_scale
        self._ew_add_op.input_scales[1] = self._mock_op2.output_scale

    def enforce_internal_encoding(self, training=False, **kwargs):
        """
        Propagate encodings through all internal ops.

        The parent handles the regression path. This method then propagates the
        classes branch, the proposal generator and the output activation.
        """
        super().enforce_internal_encoding(training=training, **kwargs)
        # Forward ``training`` like the ew_add call below; previously it was dropped here.
        self._propogate_mock_op_encoding(training=training)
        self._ew_add_op.enforce_encoding(training=training)

        # Propagate through the proposal generator op part
        self._proposal_generator_op.input_scales[0] = self._concat_op.output_scale  # from concat_op

        proposal_generator_input_scales = tf.cast(tf.reshape(self._ew_add_op.output_scale, [-1]), tf.float32)
        self._proposal_generator_op.input_scales[1] = proposal_generator_input_scales
        self._proposal_generator_op.enforce_encoding()

        # Update the activation op
        self._act_op.input_scale = self._proposal_generator_op.output_scale
        self._act_op.input_zero_points = [self._proposal_generator_op.output_zero_point]
        self._act_op.enforce_encoding(training=training)

    def create_quant_element_custom_behavior(
        self, precision_config: LayerPrecisionConfig, optimization_target: OptimizationTarget
    ):
        """
        Create weight quant elements for the mock ops and the element-wise add op.

        The element-wise add op uses the kernel bits and sign derived from the
        precision mode.
        """
        super().create_quant_element_custom_behavior(precision_config, optimization_target)
        self._mock_op1.create_weight_quant_element()
        self._mock_op2.create_weight_quant_element()

        precision_mode = precision_config.precision_mode
        kernel_bits, signed = get_kernel_bits_and_sign_by_precision_mode(precision_mode)
        self._ew_add_op.create_weight_quant_element(kernel_bits, signed)

    def is_differentiable(self) -> bool:
        """
        This is a postprocess block, and we don't want to train it.
        """
        return False

    def _build_flow(self) -> LayerFlow:
        """Wire the internal ops into a LayerFlow with two inputs (regression, classes) and one output."""
        layer_flow = LayerFlow()

        # adds io nodes
        reg_layer = layer_flow.add_input()
        cls_layer = layer_flow.add_input()
        out1 = layer_flow.add_output()

        # Add ops to layer flow
        layer_flow.add_node(self._passthru_ops[0])
        layer_flow.add_node(self._passthru_ops[1])
        layer_flow.add_node(self._slice_ops[0])
        layer_flow.add_node(self._pre_concat_ops[0])
        layer_flow.add_node(self._conv_ops[0])
        layer_flow.add_node(self._bias_ops[0])
        layer_flow.add_node(self._concat_op)
        layer_flow.add_node(self._act_op)
        layer_flow.add_node(self._proposal_generator_op)
        layer_flow.add_node(self._mock_op1)
        layer_flow.add_node(self._mock_op2)
        layer_flow.add_node(self._ew_add_op)

        # Builds bbox decoder flow
        layer_flow.add_edge(reg_layer, self._passthru_ops[0], data_path=DataPath.LAYER_IN)
        layer_flow.add_edge(self._passthru_ops[0], self._slice_ops[0], data_path=DataPath.LAYER_IN)
        layer_flow.add_edge(self._slice_ops[0], self._pre_concat_ops[0], input_index=0, data_path=DataPath.LAYER_IN)
        layer_flow.add_edge(self._pre_concat_ops[0], self._conv_ops[0], DataPath.LAYER_IN)
        layer_flow.add_edge(self._conv_ops[0], self._bias_ops[0], DataPath.ACCUMULATOR)
        layer_flow.add_edge(self._bias_ops[0], self._concat_op, DataPath.ACCUMULATOR, input_index=0)
        layer_flow.add_edge(self._concat_op, self._proposal_generator_op, input_index=0, data_path=DataPath.ACCUMULATOR)

        # Build classes to proposal generator flow
        layer_flow.add_edge(cls_layer, self._passthru_ops[1], data_path=DataPath.LAYER_IN)
        layer_flow.add_edge(self._passthru_ops[1], self._mock_op1, data_path=DataPath.LAYER_IN)
        layer_flow.add_edge(self._passthru_ops[1], self._mock_op2, data_path=DataPath.LAYER_IN)
        layer_flow.add_edge(self._mock_op1, self._ew_add_op, input_index=0, data_path=DataPath.ACCUMULATOR)
        layer_flow.add_edge(self._mock_op2, self._ew_add_op, input_index=1, data_path=DataPath.ACCUMULATOR)
        layer_flow.add_edge(self._ew_add_op, self._proposal_generator_op, input_index=1, data_path=DataPath.ACCUMULATOR)
        layer_flow.add_edge(self._proposal_generator_op, self._act_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self._act_op, out1, input_index=0, data_path=DataPath.LAYER_OUT)

        return layer_flow

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        Build the layer from its HN element.

        ``input_division_factor`` is read from ``hn_element["params"]`` when present
        (default 1). ``num_inputs`` is the number of ``input_shapes``.
        """
        input_division_factor = 1
        if "params" in hn_element:
            if "input_division_factor" in hn_element["params"]:
                input_division_factor = hn_element["params"]["input_division_factor"]
        num_inputs = len(hn_element["input_shapes"])
        layer = cls(
            name=lname,
            num_inputs=num_inputs,
            input_division_factor=input_division_factor,
            logger=logger,
        )
        layer.finalize_from_hn(hn_element)
        return layer

    def create_output_encoding_candidates(self, forced_range=None, translation_config=None):
        """
        Forcing 12 bits on the output scales of this layer leads to 12 bits at the
        output of the model, because the following layers (concat & NMS) do not
        change the scales.

        Note that one row of the NMS output on the HW is [xmin ymin xmax ymax score].
        Here, the first 4 entries must be 12 bits and the last entry is 16 bits.
        For the last entry, 12 bits are occupied and the last 4 bits are zeros.

        Args:
            forced_range: optional (min, max) pair that overrides the collected output stats.
            translation_config: unused.
        """
        output_index = 0
        if forced_range is not None:
            min_stats, max_stats = forced_range
        else:
            max_stats = np.max(self.get_output_stats()[output_index].max)
            min_stats = 0.0  # force zero for zp
        value_range = max_stats - min_stats
        output_bits = self.output_op.output_lossy_elements[0].bits
        output_scale = np.array(value_range / (2 ** (np.minimum(NMS_FORCED_BITS, output_bits)) - 1))
        if output_scale.shape == ():
            output_channels = self.output_shape[-1]
            output_scale = np.repeat(output_scale, output_channels)
        self._act_op.output_scales[0] = output_scale
        # NOTE(review): zp = min / scale. It is 0 on the default path (min_stats = 0);
        # confirm the sign convention when forced_range has a nonzero minimum.
        self._act_op.output_zero_points[0] = min_stats / output_scale[0]
