import numpy as np


def _add_first_and_end_points(vector):
    # add first point to vector
    first_piece_by_group = np.take(vector, 0, axis=1)
    vector = np.insert(vector, 0, first_piece_by_group, axis=1)

    #  and last point to the vector
    last_piece_by_group = np.take(vector, -1, axis=1)
    return np.insert(vector, -1, last_piece_by_group, axis=1)


def split_to_quantize_groups(vector, base_group_size, offset=0):
    """
    Sample one channel from each quantization group.

    The input is transposed so that channels become the leading axis, then
    every ``base_group_size``-th channel is taken, starting at ``offset``.

    Args:
        vector: the vector to work on, shape ``(num_pieces, num_channels)``.
        base_group_size: the number of channels in each group; used as the
            sampling stride over the channel axis.
        offset: index of the channel to take within each group (default 0,
            i.e. the first channel of every group).

    Returns:
        np.ndarray of shape ``(num_groups, num_pieces)`` with one row per
        sampled channel.
    """
    # TODO: verify that all channels within a group share the same values, or
    # that there is an offset (the relu6 case); currently one channel per
    # group is assumed to be representative.
    vector_transposed = np.transpose(vector)  # shape is now (num_channels, num_pieces)
    vector_groupwise = vector_transposed[offset::base_group_size]  # shape is now (num_groups, num_pieces)
    return vector_groupwise


def get_base_group_size(quantization_groups_num, num_of_channels, validate_shapes=False, alignment=8):
    """Return the number of channels in each quantization group.

    The base size is ``ceil(num_of_channels / quantization_groups_num)``.
    When ``validate_shapes`` is set and there is more than one group, the
    size is instead rounded up to a multiple of ``alignment``, unless
    ``num_of_channels / quantization_groups_num`` already is such a multiple.

    NOTE 1: the last group can have a different (smaller) number of channels.
    NOTE 2: if there is one group then the base group size equals
    ``num_of_channels``.

    Args:
        quantization_groups_num: number of quantization groups.
        num_of_channels: total number of channels to partition.
        validate_shapes: if True, align the base size to ``alignment``.
        alignment: channel multiple enforced when ``validate_shapes`` is True
            (default 8, the value previously hard-coded).

    Returns:
        int: the base group size.
    """
    groups_times_alignment = quantization_groups_num * alignment
    # channels / groups is a multiple of `alignment` exactly when channels is
    # divisible by groups * alignment; integer arithmetic avoids float rounding.
    if validate_shapes and quantization_groups_num > 1 and num_of_channels % groups_times_alignment != 0:
        # ceil(channels / (groups * alignment)) * alignment via exact integer ceiling division.
        base_group_size = -(-num_of_channels // groups_times_alignment) * alignment
    else:
        # ceil(channels / groups) via exact integer ceiling division.
        base_group_size = -(-num_of_channels // quantization_groups_num)
    return int(base_group_size)


def get_quantization_groups_info(quantization_groups_num, num_of_channels, validate_shapes=False):
    """Describe how ``num_of_channels`` channels are partitioned into groups.

    Args:
        quantization_groups_num: number of quantization groups requested.
        num_of_channels: total number of channels.
        validate_shapes: forwarded to ``get_base_group_size``.

    Returns:
        tuple ``(base_group_size, split_points, num_of_channels)``, where
        ``split_points`` lists the start index of every group followed by
        ``num_of_channels`` as the closing boundary, e.g. ``[0, 16, 32, 48, 64]``
        for 4 groups over 64 channels. If the base size was rounded up,
        ``split_points`` may describe fewer groups than requested.
    """
    base_group_size = get_base_group_size(quantization_groups_num, num_of_channels, validate_shapes)
    # A group starts every `base_group_size` channels; the last entry closes the final group.
    split_points = list(range(0, num_of_channels, base_group_size)) + [num_of_channels]
    return base_group_size, split_points, num_of_channels


def get_split_size(quantization_groups_num, base_group_size, num_of_channels):
    """Return the channel count of every quantization group.

    All groups hold ``base_group_size`` channels except the last, which takes
    the remainder.

    Args:
        quantization_groups_num: number of groups (must be >= 1).
        base_group_size: channels per group; ignored when there is one group.
        num_of_channels: total number of channels.

    Returns:
        list of ``quantization_groups_num`` sizes summing to ``num_of_channels``.
        The last size may be 0 when the full-size groups exactly use up
        every channel.

    Raises:
        ValueError: if ``quantization_groups_num < 1``, or if the full-size
            groups alone already exceed ``num_of_channels`` (the last size would
            be negative).
    """
    if quantization_groups_num < 1:
        raise ValueError(f"quantization_groups_num must be >= 1, got {quantization_groups_num}")
    if quantization_groups_num == 1:
        # Single group: base_group_size is not needed (and may not be set).
        return [num_of_channels]
    full_groups = quantization_groups_num - 1
    last_group_size = num_of_channels - base_group_size * full_groups
    if last_group_size < 0:
        raise ValueError(
            f"{full_groups} groups of {base_group_size} channels exceed "
            f"num_of_channels={num_of_channels}"
        )
    return [base_group_size] * full_groups + [last_group_size]
