from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp


class PassthruOp(BaseAtomicOp):
    """
    Describes a no-op (aka identity aka passthru aka dummy),
    useful to inject tensor quantization outside of layer context.

    Every call mode (native, hw-sim, bit-exact) returns its inputs unchanged, and the
    output encoding (scale & zero point) is kept equal to the input encoding.

    Note:
      1. Quantizes to UINT8 by default ("L3 passthru" if you will), can be changed by passing other quant elements
      2. Make sure to pass the scale & zero_point and to call set_lossy() to actually quantize.

    """

    num_inputs = 1
    num_outputs = 1

    def __init__(self, name: str, logger=None, fully_native=None, **kwargs):
        """
        Args:
            name: op name, forwarded to ``BaseAtomicOp``.
            logger: optional logger, forwarded to ``BaseAtomicOp``.
            fully_native: forwarded to ``BaseAtomicOp`` as-is.
            **kwargs: any extra arguments, forwarded to ``BaseAtomicOp`` unchanged.
        """
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)

    def get_config(self):
        """
        Return the base-class config extended with the keys ``from_config`` needs.

        Adds ``name`` (this op's ``full_name``) and ``fully_native``.
        """
        config = super().get_config()
        config.update(
            {
                "name": self.full_name,
                "fully_native": self.fully_native,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        """
        Rebuild a ``PassthruOp`` from a dict produced by ``get_config``.

        Only the ``name`` and ``fully_native`` entries are used; any other keys are ignored.
        The caller's ``config`` dict is left unmodified (it is read, not popped).

        Raises:
            KeyError: if ``name`` or ``fully_native`` is missing from ``config``.
        """
        return cls(name=config["name"], fully_native=config["fully_native"])

    def _compute_output_shape(self, input_shape):
        """Identity op: the output shape equals the input shape."""
        return input_shape

    def create_weight_quant_element(self, **kwargs):
        """No weights to quantize; intentionally a no-op."""
        pass

    def call_bit_exact(self, inputs, **kwargs):
        """Bit-exact mode: return ``inputs`` unchanged."""
        return inputs

    def call_hw_sim(self, inputs, **kwargs):
        """HW-simulation mode: return ``inputs`` unchanged."""
        return inputs

    def call_native(self, inputs, **kwargs):
        """Native mode: return ``inputs`` unchanged."""
        return inputs

    def export_weights(self):
        """There are no weights to export, so return an empty dict."""
        return dict()

    def create_hw_params(self, *args, **kwargs):
        """No HW params are needed for an identity op; intentionally a no-op."""
        pass

    def enforce_encoding(self, *args, **kwargs):
        """Enforce encoding by forwarding the input encoding to the output (arguments are ignored)."""
        self.forward_encoding()

    def forward_encoding(self):
        """Copy the (single) input's scale and zero point to the output."""
        self.output_scale = self.input_scales[0]
        self.output_zero_point = self.input_zero_points[0]

    def backward_encoding(self):
        """
        Copy the output's scale and zero point back to the input (input := output).

        The tracker is unlocked only around the assignments and is set back to locked
        even if an assignment raises, so a failure cannot leave it unlocked.
        """
        # NOTE(review): presumably the tracker rejects encoding writes while locked,
        # hence the temporary unlock — confirm against BaseAtomicOp.
        self._tracker.locked = False
        try:
            self.input_scale = self.output_scale
            self.input_zero_point = self.output_zero_point
        finally:
            self._tracker.locked = True

    @property
    def bit_exact_supported(self) -> bool:
        """This op always supports bit-exact execution."""
        return True

    def define_constraints(self, enc):
        """
        Register the base-class constraints, then tie output encoding to input encoding.

        Presumably ``enc.identity(a, b)`` declares the two named encoding variables equal
        — TODO confirm against the encoding-solver API.
        """
        super().define_constraints(enc)
        enc.identity(f"{self.full_name}/output_scale:0", f"{self.full_name}/input_scale:0")
        enc.identity(f"{self.full_name}/output_zero_point:0", f"{self.full_name}/input_zero_point:0")
