import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.element_wise_add_op import ElementwiseAddDirectOp, ElementwiseAddOp


class ElementwiseSubOp(ElementwiseAddOp):
    """
    Subtraction counterpart of the elementwise-add stage of Conv&Add.

    The second (L3) input is scaled by ``self._total_factor`` and accumulated onto the
    first (L1/L2) input, yielding an L1/L2 ("accumulator") result. For subtraction the
    factor is negative, so this scaled accumulation realises ``first - second``.

    TODO: add HN flag to ew_add to denote subtraction, this will make this file redundant
    TODO: change infer encodings and numerization.
    """

    def call_hw_sim(self, inputs, **kwargs):
        """Hardware-simulation path: ``acc + l3 * total_factor``.

        Args:
            inputs: a two-element sequence ``(accumulator_input, L3_input)``.
        """
        acc, l3 = inputs
        # For ew_sub the total factor is negative, so adding the scaled L3 term subtracts it.
        scaled_l3 = l3 * self._total_factor
        return acc + scaled_l3

    def call_native(self, inputs, **kwargs):
        """Native (unquantized) path: the plain difference ``inputs[0] - inputs[1]``."""
        minuend, subtrahend = inputs[0], inputs[1]
        return minuend - subtrahend


class ElementwiseSubDirectOp(ElementwiseAddDirectOp):
    """
    Core of the standalone elementwise-subtract layer (standalone meaning "not conv & add").

    Computes the difference of two L3 inputs, each first multiplied by its own constant,
    producing an L1/L2 ("accumulator") result.

    TODO: add HN flag to ew_add to denote subtraction, this will make this file redundant
    TODO: change infer encodings and numerization.
    """

    def __init__(self, name: str, logger=None, fully_native=None, **kwargs):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        # Flip the sign of the second entry (along axis 0) of the inherited kernel so the
        # add op computes a subtraction.
        # NOTE(review): assumes the base kernel's leading axis indexes the two inputs — confirm.
        sign = np.array([[1], [-1]])
        self._kernel = sign * self._kernel

    def call_native(self, inputs, **kwargs):
        """Native path: ``inputs[0] * kernel[0] + inputs[1] * kernel[1]`` when the kernel is
        preloaded, otherwise the plain difference ``inputs[0] - inputs[1]``."""
        # NOTE(review): the return value of repeat_inputs is discarded, so this relies on it
        # updating ``inputs`` in place — confirm against ElementwiseAddDirectOp.
        self.repeat_inputs(inputs)
        first, second = inputs[0], inputs[1]
        if not self.preload_kernel:
            return first - second
        return first * self.kernel[0] + second * self.kernel[1]

    def _set_mac_shift(self, hw_shifts=None):
        """Set ``pre_acc_shift`` and reset ``shift_delta`` to 0.

        Args:
            hw_shifts: optional sequence; when given, its first element overrides the
                default shift derived from the weight bit-width.
        """
        bits = self.weight_lossy_elements.factor.bits
        # Default: shift of 0 for 16-bit weights, 1 for any other weight width.
        self.pre_acc_shift = tf.convert_to_tensor(0 if bits == 16 else 1)
        # Explicitly supplied hardware shifts take precedence over the default above.
        if hw_shifts is not None:
            self.pre_acc_shift = hw_shifts[0]
        self.shift_delta = 0
