#!/usr/bin/env python

import numpy as np
from past.utils import old_div

from hailo_sdk_client.sdk_backend.sdk_backend_exceptions import BackendQuantizationException
from hailo_sdk_common.logger.logger import default_logger

# Headroom, in log2 units, added to the exact shift estimate. Since 2**0.3 ~= 1.23, the
# observed range is budgeted roughly 23% larger before a supported shift is chosen.
SHIFT_CALCULATE_BUFFER = 0.3
# Added to the log2 shortfall reported when even the largest supported shift is too small.
SHIFT_DELTA_EPSILON = 0.05
# Shortfalls at or above this value are logged as warnings; smaller ones only at debug level.
SHIFT_DELTA_WARNING_TH = 4


class ShiftsCalculator:
    """Pick a supported shift so that a value range fits into a signed accumulator.

    The range is taken from pre-activation limit values expressed in accumulator units
    (``max(|limvals|) / accumulator_scale``) and compared against the largest value of a
    signed ``accumulator_width``-bit integer.
    """

    def __init__(self, supported_shifts, accumulator_width):
        """
        Args:
            supported_shifts: Iterable of shift values that may be returned.
            accumulator_width: Bit width of the signed accumulator.
        """
        self._logger = default_logger()
        self.shifts = np.array(supported_shifts)
        self.max_shift = max(self.shifts)
        self.accumulator_width = accumulator_width
        self.accumulator_max_value = self._get_accumulator_max_value(self.accumulator_width)

    @staticmethod
    def _get_accumulator_max_value(accumulator_width):
        """Return 2**(accumulator_width - 1) - 1 (largest signed value of that width) as a float."""
        return 2.0 ** (accumulator_width - 1) - 1

    def calculate_shift(self, accumulator_scale, limvals_pre_act, name, force_shift=None):
        """Return the shift to use for a layer and how far short the chosen shift falls.

        The target is the smallest supported shift ``s`` with
        ``s >= log2(max_distance / accumulator_max_value) + SHIFT_CALCULATE_BUFFER``,
        where ``max_distance = max(|limvals_pre_act|) / accumulator_scale``.

        Args:
            accumulator_scale: Divisor converting the limit values into accumulator units.
            limvals_pre_act: Pre-activation limit values; only the largest magnitude is used.
            name: Layer name, used only in the log message.
            force_shift: If not None, return this shift instead of computing one. It must
                be one of the supported shifts.

        Returns:
            Tuple ``(shift, shift_delta)``. ``shift_delta`` is 0 unless no supported shift
            reaches the target. In that case ``shift`` is the maximal supported shift and
            ``shift_delta`` is the log2 shortfall plus ``SHIFT_DELTA_EPSILON``.

        Raises:
            BackendQuantizationException: If ``force_shift`` is not a supported shift.
        """
        # Use true division. Python-2 style division (old_div) floors integral operands,
        # e.g. 1 / 2 -> 0, which would under-estimate the range for integer inputs.
        max_distance = np.max(np.abs(limvals_pre_act)) / accumulator_scale
        if max_distance == 0:
            # note that if the max distance is 0 then we need the minimal shift - 1
            # NOTE(review): the returned 1 is hard-coded, not derived from self.shifts; confirm.
            return 1, 0
        calculated_shift = np.log2(max_distance / self.accumulator_max_value) + SHIFT_CALCULATE_BUFFER
        shift_delta = 0

        if force_shift is not None:
            if force_shift not in self.shifts:
                raise BackendQuantizationException(f"Shift by {force_shift} is not supported")
            # A forced shift is trusted as-is: no shortfall is reported even if it is
            # smaller than calculated_shift.
            return force_shift, shift_delta

        available_shifts = [s for s in self.shifts if s >= calculated_shift]
        if available_shifts:
            return int(min(available_shifts)), shift_delta

        # Even the largest supported shift is too small: use it and report the shortfall.
        shift = int(self.max_shift)
        shift_delta = SHIFT_DELTA_EPSILON + calculated_shift - self.max_shift
        msg = f"No shifts available for layer {name}, using max shift instead. delta={shift_delta:.04f}"
        if shift_delta >= SHIFT_DELTA_WARNING_TH:
            self._logger.warning(msg)
        else:
            self._logger.debug(msg)
        return shift, shift_delta
