from typing import Union

from hailo_model_optimization.acceleras.atomic_ops.element_wise_sub_op import ElementwiseSubDirectOp
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_add import HailoElementwiseAdd
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerPrecisionConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import ActivationType, LayerType, OptimizationTarget
from hailo_model_optimization.acceleras.utils.opt_utils import (
    get_decomposition_count_by_bias_mode,
    get_kernel_bits_and_sign_by_precision_mode,
)


class HailoElementwiseSub(HailoElementwiseAdd):
    """
    Layer that subtracts two L3 inputs element by element and then applies a
    (normally trivialized) activation to produce the L3 output.

    The layer reuses the HailoElementwiseAdd machinery and only swaps the
    element-wise op for an ElementwiseSubDirectOp.

    Example:
        >>> ew_sub_layer = HailoElementwiseSub(name="ew_sub")
        >>> rand_data1 = tf.random.normal([200, 200], 0, 1, tf.float32)
        >>> rand_data2 = tf.random.normal([200, 200], 0, 1, tf.float32)
        >>> ew_sub_result = ew_sub_layer(rand_data1, rand_data2)

    TODO: add HN flag to ew_add to denote subtraction, this will make this file redundant
    TODO: change infer encodings and numerization.

    """

    _hn_type = LayerType.ELEMENTWISE_SUB

    def __init__(
        self,
        name: str,
        activation: Union[str, callable, ActivationType] = "linear",
        input_repeats=None,
        logger=None,
        **kwargs,
    ):
        """
        Build the layer through the parent constructor, then swap in the subtraction op.

        Args:
            name: layer name; also used as the prefix of the subtraction op name.
            activation: activation specification, forwarded to the parent constructor.
            input_repeats: forwarded unchanged to the parent constructor; for the
                subtraction op a falsy value is replaced by ``[[1, 1, 1], [1, 1, 1]]``.
            logger: forwarded to the parent constructor and to the subtraction op.
            **kwargs: forwarded to the parent constructor.

        """
        super().__init__(
            name=name,
            activation=activation,
            input_repeats=input_repeats,
            logger=logger,
            **kwargs,
        )

        # The parent receives input_repeats as given; only the sub op gets the default.
        sub_op_repeats = input_repeats or [[1, 1, 1], [1, 1, 1]]
        sub_op_name = f"{name}/elementwise_sub_op"

        # The attribute keeps the name `ew_add_op` while holding the subtraction op
        # (presumably so code inherited from HailoElementwiseAdd uses it - confirm against the parent).
        self.ew_add_op = ElementwiseSubDirectOp(sub_op_name, input_repeats=sub_op_repeats, logger=logger)

        # TODO: a bit hacky - the flow is rebuilt after the op has been replaced.
        self._layer_flow = self._build_flow()

    def create_quant_element_custom_behavior(
        self,
        precision_config: LayerPrecisionConfig,
        optimization_target: OptimizationTarget,
    ):
        """
        Create the weight quant elements of the layer's ops and set the activation quantization groups.

        Args:
            precision_config: supplies ``bias_mode``, ``precision_mode`` and ``quantization_groups``.
            optimization_target: passed to the activation op when creating its quant element.

        """
        bias_mode, precision_mode, groups = (
            precision_config.bias_mode,
            precision_config.precision_mode,
            precision_config.quantization_groups,
        )

        # The kernel is requested with force_signed_kernel=True.
        weight_bits, is_signed = get_kernel_bits_and_sign_by_precision_mode(
            precision_mode,
            force_signed_kernel=True,
        )
        decomposition_count = get_decomposition_count_by_bias_mode(bias_mode)

        self.ew_add_op.create_weight_quant_element(weight_bits, is_signed)
        self.bias_op.create_weight_quant_element(weight_bits, is_signed, decomposition_count)
        self.act_op.create_weight_quant_element(optimization_target)

        # Quantization groups are applied to the activation op only.
        self.act_op.set_quantization_groups(groups)

    def _verify_and_set_hn_io_shapes(self):
        """
        Intentionally do nothing.

        TODO: there is a bug in the broadcast of const_data
        https://hailotech.atlassian.net/browse/SDK-39317
        """
        return None
