from enum import Enum

import tensorflow as tf


class ParamState(Enum):
    """Closed set of parameter states, each carrying a lowercase string value.

    Iteration order is ``DEFAULT``, ``ENFORCED``, ``VARIABLE``.

    NOTE(review): the member names presumably mirror the three forms an
    ``AccelerasParam`` value can take (raw default, ``enforce()`` result,
    ``create()`` result); this enum is not referenced by the visible code,
    so confirm against callers.
    """

    DEFAULT = "default"    # value looked up via ParamState("default")
    ENFORCED = "enforced"  # value looked up via ParamState("enforced")
    VARIABLE = "variable"  # value looked up via ParamState("variable")


class AccelerasParam:
    """A named parameter whose held value starts as a plain default.

    The held value can be replaced in two ways:

    * :meth:`enforce` stores ``tf.identity(value)`` (a ``tf.Tensor``);
    * :meth:`create` stores a new non-trainable ``tf.Variable``.

    Args:
        default_value: Initial value returned by :attr:`value` until
            :meth:`enforce` or :meth:`create` replaces it.
        name: Stored as ``full_name``; also used as the TF op/variable name
            and as the key in :meth:`export_variable`.
        op_name: Name of the owning op, kept on ``_op_name`` (not read by
            this class itself).
    """

    def __init__(self, default_value, name, op_name):
        self._op_name = op_name
        self.full_name = name
        self._value = default_value

    @property
    def value(self):
        """Return the currently held value (default, tensor, or variable)."""
        return self._value

    def enforce(self, value):
        """Hold ``tf.identity(value)``, named after ``full_name``."""
        enforced = tf.identity(value, name=self.full_name)
        self._value = enforced

    def create(self, value):
        """Hold a freshly created, non-trainable ``tf.Variable`` of ``value``."""
        variable = tf.Variable(value, trainable=False, name=self.full_name)
        self._value = variable

    def is_variable(self):
        """Return True when the held value is a ``tf.Variable``."""
        held = self._value
        return isinstance(held, tf.Variable)

    def is_tensor(self):
        """Return True when the held value is a ``tf.Tensor``."""
        held = self._value
        return isinstance(held, tf.Tensor)

    def export_variable(self):
        """Return the exportable variable mapping for this parameter.

        Returns:
            ``{full_name: variable}`` when the held value is a ``tf.Variable``;
            an empty dict when it is a ``tf.Tensor`` (enforced values are not
            exported).

        Raises:
            ValueError: if the held value is still neither a variable nor a
                tensor (i.e. neither ``enforce`` nor ``create`` replaced it
                with one).
        """
        if self.is_variable():
            return {self.full_name: self._value}
        if not self.is_tensor():
            raise ValueError(f"HW params {self.full_name} cannot be exported before update")
        return {}


class MACShift(AccelerasParam):
    """``AccelerasParam`` named ``"mac_shift"`` whose default value is ``0``."""

    def __init__(self, op_name):
        super().__init__(
            op_name=op_name,
            name="mac_shift",
            default_value=0,
        )


class ShiftDelta(AccelerasParam):
    """``AccelerasParam`` named ``"shift_delta"`` whose default value is ``0``."""

    def __init__(self, op_name):
        super().__init__(
            op_name=op_name,
            name="shift_delta",
            default_value=0,
        )


class OutputFactors(AccelerasParam):
    """``AccelerasParam`` named ``"output_factors"`` defaulting to ``[1]``.

    A new ``[1]`` list is built on every construction, so instances never
    share their default.
    """

    def __init__(self, op_name):
        super().__init__(
            op_name=op_name,
            name="output_factors",
            default_value=[1],
        )


class InputScale(AccelerasParam):
    """``AccelerasParam`` named ``"input_scale_<index>"`` defaulting to ``1``."""

    def __init__(self, op_name, index):
        param_name = f"input_scale_{index}"
        super().__init__(
            op_name=op_name,
            name=param_name,
            default_value=1,
        )


class InputZeroPoint(AccelerasParam):
    """``AccelerasParam`` named ``"input_zero_point_<index>"`` defaulting to ``0``."""

    def __init__(self, op_name, index):
        param_name = f"input_zero_point_{index}"
        super().__init__(
            op_name=op_name,
            name=param_name,
            default_value=0,
        )


class OutputScale(AccelerasParam):
    """``AccelerasParam`` named ``"output_scale_<index>"`` defaulting to ``1``."""

    def __init__(self, op_name, index):
        param_name = f"output_scale_{index}"
        super().__init__(
            op_name=op_name,
            name=param_name,
            default_value=1,
        )


class OutputZeroPoint(AccelerasParam):
    """``AccelerasParam`` named ``"output_zero_point_<index>"`` defaulting to ``0``."""

    def __init__(self, op_name, index):
        param_name = f"output_zero_point_{index}"
        super().__init__(
            op_name=op_name,
            name=param_name,
            default_value=0,
        )
