import operator
from functools import reduce

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_non_arithmetic_op import BaseNonArithmeticAtomicOp


class FlattenOp(BaseNonArithmeticAtomicOp):
    """
    Collapse all non-batch axes of the input tensor into one trailing axis.

    The output shape is ``[batch, 1, ..., 1, prod(input_shape[1:])]`` with ``ndim - 2`` singleton axes
    between the batch axis and the flattened axis. The op is implemented with tf.reshape, which does not
    change the order of or the total number of elements in the tensor.
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(self, name: str, ndim=4, logger=None, fully_native=None, **kwargs):
        """
        Args:
            name: Name of the op; forwarded to the base class.
            ndim: Rank of the output tensor. ``ndim - 2`` singleton axes are placed between the
                leading axis and the flattened axis.
            logger: Forwarded to the base class.
            fully_native: Forwarded to the base class.
            **kwargs: Any further keyword arguments, forwarded to the base class.

        """
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        self._ndim = ndim

    def call_native(self, inputs, **kwargs):
        """Reshape ``inputs[0]`` so that every axis after the first is merged into the last axis."""
        tensor = inputs[0]
        # Should be replaced with math.prod once we upgrade to python 3.8
        flat_size = reduce(operator.mul, tensor.shape[1:], 1)
        singleton_axes = [1] * (self._ndim - 2)
        # -1 lets tf.reshape infer the leading dimension.
        target_shape = [-1] + singleton_axes + [flat_size]
        return tf.reshape(tensor, target_shape)

    def _compute_output_shape(self, input_shape):
        """Return the flattened shape as a list, keeping ``input_shape[0]`` as the leading entry."""
        flat_size = reduce(operator.mul, input_shape[1:], 1)
        singleton_axes = [1] * (self._ndim - 2)
        return [input_shape[0]] + singleton_axes + [flat_size]

    def export_independent_params(self):
        """Export the input shape without its leading axis under the ``"input_shape"`` key."""
        exported = {"input_shape": self.input_shape[1:]}
        return exported

    def import_independent_params(self, params):
        """Restore the single input shape from ``params["input_shape"]``, with ``None`` as leading axis."""
        restored_shape = [None] + list(params["input_shape"])
        self._input_shapes = [restored_shape]

    def enforce_encoding(self):
        """
        Derive the output zero point and scale from the first input.

        The zero point is copied unchanged. The scale is copied unchanged when it is a scalar or when the
        input is 2-dimensional; otherwise it is tiled ``prod(input_shape[1:-1])`` times.
        """
        self.output_zero_point = self.input_zero_points[0]
        scale = self.input_scales[0]
        passthrough = scale.shape == () or len(self.input_shape) == 2
        if not passthrough:
            # Repeat the scale vector once per position that gets merged into the last axis.
            repeats = tf.reduce_prod(self.input_shape[1:-1])
            scale = tf.tile(scale, [repeats])
        self.output_scale = scale

    def _get_stats_axes(self, data_shape):
        """Return axis 0 as the only stats axis; ``data_shape`` is not consulted."""
        axes = [0]
        return np.array(axes)

    def define_constraints(self, enc):
        """
        Register the encoding constraints that tie the output encoding to the input encoding.

        The output zero point is constrained to equal the input zero point. The output scale is constrained
        to equal the input scale for 2-dimensional inputs; otherwise it is registered as a ``tf.tile``
        callback of the input scale, repeated ``prod(input_shape[1:-1])`` times and declared with
        shape ``(output_shape[-1],)``.
        """
        super().define_constraints(enc)
        enc.identity(f"{self.full_name}/output_zero_point:0", f"{self.full_name}/input_zero_point:0")

        if len(self.input_shape) == 2:
            # Nothing is merged into the last axis, so the scale passes straight through.
            enc.identity(f"{self.full_name}/output_scale:0", f"{self.full_name}/input_scale:0")
            return

        scale_out_key = f"{self.full_name}/output_scale:0"
        scale_in_key = f"{self.full_name}/input_scale:0"
        repeats = tf.reduce_prod(self.input_shape[1:-1])
        expected_shape = (self.output_shape[-1],)
        enc.callback(
            scale_out_key,
            scale_in_key,
            tf.tile,
            callback_name="tf.tile",
            multiples=[repeats],
            outs_shape=expected_shape,
        )
