#!/usr/bin/env python
"""
@purpose: This module quantizes both params and inputs. When it is used to quantize
          params, it takes the preprocessed params and the quantization statistics,
          and quantizes those parameters based on the statistics.
"""

import collections
import contextlib

import numpy as np

# Lightweight record holding the two quantization parameters; instances are
# normally created through the QuantParams() factory defined in this module.
QuantParamsT = collections.namedtuple("QuantParams", ["zero_point", "scale"])

# Types treated as "floating point" by the in-place checks in this module.
FloatingPointTypes = [float, np.float16, np.float32, np.float64]

# np.float128 is not exposed by every NumPy build; register it only when the
# attribute lookup succeeds.
with contextlib.suppress(AttributeError):
    FloatingPointTypes += [np.float128]


def QuantParams(zero_point, scale):
    """Build a ``QuantParamsT`` whose two fields are coerced to ``np.float32``.

    Args:
        zero_point: Quantization zero point; converted with ``np.float32``.
        scale: Quantization scale; converted with ``np.float32``.

    Returns:
        A ``QuantParamsT(zero_point, scale)`` named tuple of float32 values.
    """
    zero_point_f32 = np.float32(zero_point)
    scale_f32 = np.float32(scale)
    return QuantParamsT(zero_point=zero_point_f32, scale=scale_f32)


def get_quantized_int(array, qp, limvals, sym_flag, bits, quantize_inplace=False):
    """Quantize ``array`` using the given scale / zero point.

    The values are clipped to ``limvals``, divided by ``qp.scale`` and rounded
    to the nearest integer value. In the symmetric case the rounded values are
    additionally clipped to ``[-(2**(bits-1) - 1), 2**(bits-1) - 1]``; in the
    asymmetric case ``qp.zero_point`` is added before rounding.

    Args:
        array: NumPy array to quantize.
        qp: Object exposing ``scale`` and ``zero_point`` attributes.
        limvals: Indexable pair; ``limvals[0]`` is the lower and ``limvals[1]``
            the upper clipping bound applied before quantization.
        sym_flag: When truthy, use the symmetric (signed) scheme, otherwise the
            asymmetric (unsigned) scheme.
        bits: Bit width; only used for the symmetric clipping range.
        quantize_inplace: When true, reuse ``array`` as the output buffer.
            This is honoured only for floating point arrays (and for the uint8
            pass-through shortcut below).

    Returns:
        The quantized values. Except for the uint8 shortcut, the result holds
        rounded values in a floating point array; it is not cast to an integer
        dtype here.
    """
    lower, upper = limvals[0], limvals[1]

    if (array.dtype == np.uint8) and (qp.scale == 1) and (qp.zero_point == 0) and not sym_flag:
        # In this case the rest of the quantization won't do anything
        # We just clip the values to limvals
        return np.clip(array, lower, upper, out=(array if quantize_inplace else None))

    # quantize_inplace will only work if the array is of a floating point dtype
    in_place = quantize_inplace and np.issubdtype(array.dtype, np.floating)

    # Truncate data before quantization to avoid illegal values
    output = np.clip(array, lower, upper, out=(array if in_place else None))

    if np.issubdtype(output.dtype, np.floating):
        # The clipped buffer can hold the float32 quotient, so reuse it.
        output = np.divide(output, qp.scale, dtype=np.float32, out=output)
    else:
        # Clipping an integer array with integer limits yields an integer
        # array; writing the float32 quotient into it would raise a ufunc
        # casting error, so let NumPy allocate a float32 result instead.
        output = np.divide(output, qp.scale, dtype=np.float32)

    if sym_flag:
        # For weights - return signed (e.g. int8)
        max_abs = 2 ** (bits - 1) - 1
        output = np.rint(output, out=output)
        output = np.clip(output, -max_abs, max_abs, out=output)
    else:
        # For data - return unsigned (e.g. uint8)
        output = np.add(output, qp.zero_point, out=output)
        output = np.rint(output, out=output)
    return output


def quantize_data(data, qp, limvals, sym_flag, bits, quantize_inplace=False):
    """Quantize ``data`` given raw quantization parameters.

    Wraps ``qp`` into a ``QuantParams`` record and delegates to
    ``get_quantized_int``.

    Args:
        data: Array to quantize.
        qp: Indexable pair; ``qp[0]`` is the zero point and ``qp[1]`` the scale.
        limvals: Clipping limits, forwarded unchanged.
        sym_flag: Symmetric-quantization flag, forwarded unchanged.
        bits: Bit width, forwarded unchanged.
        quantize_inplace: Forwarded unchanged as a keyword argument.

    Returns:
        Whatever ``get_quantized_int`` returns for these arguments.
    """
    zero_point, scale = qp[0], qp[1]
    return get_quantized_int(
        data,
        QuantParams(zero_point, scale),
        limvals,
        sym_flag,
        bits,
        quantize_inplace=quantize_inplace,
    )


def rescale_output(q_data, qp_scale, qp_zp, quantize_inplace=False):
    """De-quantize ``q_data``: compute ``(q_data - qp_zp) * qp_scale``.

    Args:
        q_data: NumPy array of quantized values.
        qp_scale: Scale to multiply by after the zero point is removed.
        qp_zp: Zero point subtracted from ``q_data``.
        quantize_inplace: When true, reuse ``q_data`` as the output buffer.
            This is honoured only for floating point arrays.

    Returns:
        The rescaled array. It is ``q_data`` itself when the operation ran in
        place, otherwise a newly allocated array.
    """
    # quantize_inplace will only work if the array is of a floating point dtype
    in_place = quantize_inplace and np.issubdtype(q_data.dtype, np.floating)

    # NOTE(review): for unsigned integer q_data combined with an integer
    # qp_zp the subtraction stays in the unsigned dtype and may wrap around -
    # confirm callers always pass a floating point zero point in that case.
    shifted = np.subtract(q_data, qp_zp, out=(q_data if in_place else None))

    # Reuse the intermediate buffer only when the product fits its dtype.
    # An integer difference multiplied by a float scale would otherwise raise
    # a ufunc casting error when written back into the integer buffer.
    product_dtype = np.result_type(shifted, qp_scale)
    if np.can_cast(product_dtype, shifted.dtype, casting="same_kind"):
        return np.multiply(shifted, qp_scale, out=shifted)
    return np.multiply(shifted, qp_scale)
