from functools import wraps

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.encoding.encoding_flow import EncodingFlowGraph


def arg_func_generator(arg):
    """
    Build a zero-argument callable that always returns ``arg``.

    The value is bound when this function is called, so callables created inside a loop each keep
    their own argument.
    """

    def constant():
        return arg

    return constant


def const_input_support(foo):
    """
    Wrap a method so that its constant (non-string) positional arguments are replaced with dummy encodings.

    String arguments are forwarded untouched. Every other argument gets a fresh dummy encoding from
    ``self.dummy()``; the dummy is tied to the argument's value through a constraint with no inputs, and
    the matching entry of ``temp_nodes`` is set to True before ``foo`` is called.

    Args:
        foo: The method to wrap. It is invoked as ``foo(self, *new_args, temp_nodes=temp_nodes, **kwargs)``.

    Returns:
        The wrapping function. It accepts an optional ``temp_nodes`` sequence holding one flag per
        positional argument (all False when omitted). The sequence given by the caller is copied, never
        modified.

    """

    @wraps(foo)
    def wrap(self, *args, temp_nodes=None, **kwargs):
        # Work on a private copy so the flags set below do not leak into the caller's list.
        temp_nodes = [False] * len(args) if temp_nodes is None else list(temp_nodes)
        new_args = []
        for i, arg in enumerate(args):
            if isinstance(arg, str):
                new_args.append(arg)
                continue
            dummy_arg = self.dummy()
            shape = arg.shape if hasattr(arg, "shape") else ()
            # Small constants are printed verbatim in the constraint string; larger ones only by shape.
            arg_string = f"{arg}" if np.prod(shape) <= 4 else f"outside_var({shape})"
            encoding = self._flow.get_encoding(dummy_arg)
            encoding.shape = shape
            encoding.scalar = False
            self._flow.add_constraint([], dummy_arg, arg_func_generator(arg), func_string=arg_string)
            new_args.append(dummy_arg)
            temp_nodes[i] = True
        return foo(self, *new_args, temp_nodes=temp_nodes, **kwargs)

    return wrap


def get_length(inp):
    """
    Return the size of the first dimension of ``inp``.

    For a ``tf.Tensor`` the statically known leading dimension is returned as a plain Python value when
    it is available; otherwise the first element of ``tf.shape(inp)`` (a tensor) is returned. Any other
    input is measured with ``len``.
    """
    if not isinstance(inp, tf.Tensor):
        return len(inp)

    static_shape = inp.shape
    # ``as_list`` raises ValueError on a shape of unknown rank, so only consult it when the rank is known
    # (a rank of 0 has no leading dimension and also falls through to the dynamic shape).
    if static_shape.rank:
        leading_dim = static_shape.as_list()[0]
        if leading_dim is not None:
            return leading_dim
    return tf.shape(inp)[0]


class EncodingSubOp:
    """
    Helpers that register arithmetic-style constraints between encodings of an ``EncodingFlowGraph``.

    Every public operation translates a relation such as ``node_0 = node_1 * node_2`` into one or more
    ``add_constraint`` calls on the wrapped flow. Methods decorated with ``const_input_support`` also accept
    constant (non-string) arguments, which the decorator replaces with dummy encodings; the ``temp_nodes``
    flags tell each method for which arguments no "solve for this node" constraint should be registered.
    """

    def __init__(self, flow: EncodingFlowGraph):
        self._flow = flow
        # Cache of dummy encodings created through ``dummy(id=...)``, keyed by the given id.
        self._dummies = {}

    def dummy(self, id=None):
        """
        Create a dummy encoding with a unique id.
        If the id already exists, return the dummy encoding associated with that id.
        If id is None, a new dummy encoding is always created (and not cached).

        Args:
            id (Any, optional): A unique id for the dummy encoding. Defaults to None.

        Returns:
            The dummy encoding returned by the flow's ``add_dummy_encoding``.

        """
        if id is None:
            return self._flow.add_dummy_encoding()
        if id not in self._dummies:
            self._dummies[id] = self._flow.add_dummy_encoding()
        return self._dummies[id]

    @staticmethod
    def _passthrough_state(inputs_enc, outputs_enc):
        """Copy a missing shape / scalar flag from the first input encoding to the output encoding."""
        inp = inputs_enc[0]
        out = outputs_enc[0]
        changed = False
        if out.shape is None and inp.shape is not None:
            changed = True
            out.shape = inp.shape
        if (out.scalar is None and inp.scalar is not None) or (inp.scalar and not out.scalar):
            changed = True
            out.scalar = inp.scalar
        return [changed]

    @staticmethod
    def _same_shape_status(inputs_enc, outputs_enc):
        """Report whether the first input encoding and the output encoding have equal shapes."""
        return inputs_enc[0].shape == outputs_enc[0].shape

    @staticmethod
    def _broadcast_state(inputs_enc, outputs_enc):
        """Fill a missing output shape / scalar flag of a binary op from its two input encodings."""
        inp1, inp2 = inputs_enc
        out = outputs_enc[0]
        shapes_known = inp1.shape is not None and inp2.shape is not None
        scalars_known = inp1.scalar is not None and inp2.scalar is not None
        changed = False
        if out.shape is None and shapes_known:
            changed = True
            out.shape = np.broadcast_shapes(inp1.shape, inp2.shape)
        if out.scalar is None and scalars_known:
            changed = True
            out.scalar = inp1.scalar and inp2.scalar
        if scalars_known and shapes_known and not out.scalar:
            # Promote the output to scalar when both inputs are scalar, or when one is scalar and the
            # other holds a single element.
            if (
                (inp1.scalar and inp2.scalar)
                or (inp1.scalar and np.prod(inp2.shape) == 1)
                or (np.prod(inp1.shape) == 1 and inp2.scalar)
            ):
                changed = True
                out.scalar = True
        return [changed]

    @staticmethod
    def _broadcast_status(inputs_enc, outputs_enc):
        """Report whether the output shape equals the broadcast of the two input shapes."""
        return np.broadcast_shapes(inputs_enc[0].shape, inputs_enc[1].shape) == outputs_enc[0].shape

    @staticmethod
    def _align_operands(x, y):
        """If exactly one operand is a tensor, convert the other one to a tensor of the same dtype."""
        if tf.is_tensor(x) or tf.is_tensor(y):
            if not tf.is_tensor(x):
                x = tf.convert_to_tensor(x, dtype=y.dtype)
            if not tf.is_tensor(y):
                y = tf.convert_to_tensor(y, dtype=x.dtype)
        return x, y

    @const_input_support
    def identity(self, *args, temp_nodes=None):
        """
        Add constraint s.t node_0 = node_1

        A constraint is registered in each direction, except towards a node flagged in ``temp_nodes``.
        """
        if temp_nodes is None:
            temp_nodes = [False] * len(args)
        node_0, node_1 = args

        if not temp_nodes[0]:
            self._flow.add_constraint(
                [node_1],
                node_0,
                lambda x: x,
                state_function=self._passthrough_state,
                status_function=self._same_shape_status,
                func_string="{0}",
            )
        if not temp_nodes[1]:
            self._flow.add_constraint(
                [node_0],
                node_1,
                lambda x: x,
                state_function=self._passthrough_state,
                status_function=self._same_shape_status,
                func_string="{0}",
            )

    @const_input_support
    def shift(self, *args, temp_nodes=None):
        """
        Add constraint s.t node_0 = node_1 * (2 ** node_2)

        Constraints are registered for node_0 and for node_1 only; none is registered for the shift node.
        """
        if temp_nodes is None:
            temp_nodes = [False] * len(args)
        node_0, node_1, shift = args

        if not temp_nodes[0]:
            self._flow.add_constraint(
                [node_1, shift],
                node_0,
                lambda x, y: x * tf.cast(2**y, dtype=x.dtype),
                state_function=self._passthrough_state,
                status_function=self._same_shape_status,
                func_string="{0} * (2 ** {1})",
            )
        if not temp_nodes[1]:
            self._flow.add_constraint(
                [node_0, shift],
                node_1,
                lambda x, y: x / tf.cast(2**y, dtype=x.dtype),
                state_function=self._passthrough_state,
                status_function=self._same_shape_status,
                func_string="{0} / (2 ** {1})",
            )

    @const_input_support
    def _mul(self, *args, temp_nodes=None):
        """
        Add constraint s.t node_0 = node_1 * node_2

        One constraint per node is registered (product for node_0, quotients for node_1 and node_2),
        skipping every node flagged in ``temp_nodes``.
        """
        if temp_nodes is None:
            temp_nodes = [False] * len(args)
        node_0, node_1, node_2 = args

        def safe_divide(x, y):
            x, y = self._align_operands(x, y)
            return tf.divide(x, y)

        def safe_multiply(x, y):
            x, y = self._align_operands(x, y)
            return tf.multiply(x, y)

        if not temp_nodes[0]:
            self._flow.add_constraint(
                [node_1, node_2],
                node_0,
                safe_multiply,
                state_function=self._broadcast_state,
                status_function=self._broadcast_status,
                func_string="{0} * {1}",
            )
        if not temp_nodes[1]:
            self._flow.add_constraint(
                [node_0, node_2],
                node_1,
                safe_divide,
                state_function=self._broadcast_state,
                status_function=self._broadcast_status,
                func_string="{0} / {1}",
            )
        if not temp_nodes[2]:
            self._flow.add_constraint(
                [node_0, node_1],
                node_2,
                safe_divide,
                state_function=self._broadcast_state,
                status_function=self._broadcast_status,
                func_string="{0} / {1}",
            )

    def mul(self, *args, inverse=False):
        """
        Add constraint s.t node_0 = node_1 * node_2

        Set inverse to True if it is possible to calculate the inverse function of this constraint (a.k.a both node_1
        and node_2 != 0). By default inverse is False.
        """
        temp_nodes = [False, False, False] if inverse else [False, True, True]
        return self._mul(*args, temp_nodes=temp_nodes)

    def div(self, *args, inverse=False):
        """
        Add constraint s.t node_0 = node_1 / node_2

        This assume that always node_2 != 0. If this assumption doesn't hold, use `self.mul` instead.

        Set inverse to True if it is possible to calculate the inverse function of this constraint (a.k.a node_0 != 0).
        By default inverse is False.
        """
        # Expressed as node_1 = node_0 * node_2, so the flags below follow the (node_1, node_0, node_2) order.
        temp_nodes = [False, False, False] if inverse else [False, False, True]
        return self._mul(args[1], args[0], args[2], temp_nodes=temp_nodes)

    @const_input_support
    def add(self, *args, temp_nodes=None):
        """
        Add constraint s.t node_0 = node_1 + node_2

        One constraint per node is registered (sum for node_0, differences for node_1 and node_2),
        skipping every node flagged in ``temp_nodes``.
        """
        if temp_nodes is None:
            temp_nodes = [False] * len(args)
        node_0, node_1, node_2 = args

        if not temp_nodes[0]:
            self._flow.add_constraint(
                [node_1, node_2],
                node_0,
                lambda x, y: x + y,
                state_function=self._broadcast_state,
                status_function=self._broadcast_status,
                func_string="{0} + {1}",
            )
        if not temp_nodes[1]:
            self._flow.add_constraint(
                [node_0, node_2],
                node_1,
                lambda x, y: x - y,
                state_function=self._broadcast_state,
                status_function=self._broadcast_status,
                func_string="{0} - {1}",
            )
        if not temp_nodes[2]:
            self._flow.add_constraint(
                [node_0, node_1],
                node_2,
                lambda x, y: x - y,
                state_function=self._broadcast_state,
                status_function=self._broadcast_status,
                func_string="{0} - {1}",
            )

    def sub(self, *args):
        """
        Add constraint s.t node_0 = node_1 - node_2
        """
        # Expressed as node_1 = node_0 + node_2.
        return self.add(args[1], args[0], args[2])

    def inv(self, *args):
        """
        Add constraint s.t node_0 = 1. / node_1
        """
        return self.div(args[0], 1.0, args[1], inverse=True)

    def callback(
        self,
        out_nodes,
        inp_nodes,
        callback,
        callback_name="callback",
        outs_scalar=None,
        outs_shape=None,
        *args,
        **kwargs,
    ):
        """
        Add constraint s.t out_nodes = callback(*inp_nodes, *args, **kwargs)

        Args:
            out_nodes: Output node, or list of output nodes.
            inp_nodes: Input node, or list of input nodes.
            callback: Callable invoked with the input values followed by ``*args`` and ``**kwargs``.
            callback_name: Name used for the callable in the constraint's ``func_string``.
            outs_scalar: Optional scalar flag per output (a single value or a list). An entry of None
                means the flag is derived from the inputs.
            outs_shape: Optional shape per output (a single value or a list). An entry of None means
                the shape of the first input is used once all input shapes are known.
            *args: Extra positional arguments forwarded to ``callback``.
            **kwargs: Extra keyword arguments forwarded to ``callback``.

        """
        inp_nodes = inp_nodes if isinstance(inp_nodes, list) else [inp_nodes]
        out_nodes = out_nodes if isinstance(out_nodes, list) else [out_nodes]
        outs_scalar = outs_scalar if outs_scalar is not None else [None] * len(out_nodes)
        outs_shape = outs_shape if outs_shape is not None else [None] * len(out_nodes)
        outs_scalar = outs_scalar if isinstance(outs_scalar, list) else [outs_scalar]
        outs_shape = outs_shape if isinstance(outs_shape, list) else [outs_shape]

        def state_function(inputs_enc, outputs_enc):
            changed = [False] * len(outputs_enc)
            for i, out in enumerate(outputs_enc):
                wanted_shape = outs_shape[i]
                wanted_scalar = outs_scalar[i]
                if wanted_shape is not None and out.shape != wanted_shape:
                    changed[i] = True
                    out.shape = wanted_shape
                if (
                    wanted_shape is None
                    and out.shape is None
                    and all(inp.shape is not None for inp in inputs_enc)
                    and out.shape != inputs_enc[0].shape
                ):
                    changed[i] = True
                    out.shape = inputs_enc[0].shape
                # Only report a change when the stored flag really differs; otherwise an explicit
                # ``False`` request would be flagged as "changed" on every single call.
                if (not out.scalar) and wanted_scalar is not None and out.scalar != wanted_scalar:
                    changed[i] = True
                    out.scalar = wanted_scalar
                if out.scalar is None and all(inp.scalar is not None for inp in inputs_enc):
                    changed[i] = True
                    out.scalar = all(inp.scalar for inp in inputs_enc)
                if (not out.scalar) and all(inp.scalar for inp in inputs_enc) and wanted_scalar is None:
                    changed[i] = True
                    out.scalar = True
            return changed

        def inner_callback(*nodes, **more_kwargs):
            return callback(*nodes, *args, **kwargs, **more_kwargs)

        # Human-readable form: placeholders for the inputs followed by the fixed extra arguments.
        args_string = [f"{{{i}}}" for i in range(len(inp_nodes))]
        args_string.extend(f"{v}" for v in args)
        args_string.extend(f"{k}={v}" for k, v in kwargs.items())
        func_string = f'{callback_name}({", ".join(args_string)})'

        self._flow.add_constraint(
            inp_nodes,
            out_nodes,
            inner_callback,
            state_function=state_function,
            func_string=func_string,
        )

    def lossy_element(self, node_0, node_1, lossy_element):
        """
        Add constraint s.t node_0 = lossy_element(node_1)
        """
        return self.callback(node_0, node_1, lossy_element, callback_name="lossy")

    def cast(self, node_0, node_1, dtype=tf.float32):
        """
        Add constraint s.t node_0 = tf.cast(node_1, dtype)
        """
        return self.callback(node_0, node_1, tf.cast, callback_name="tf.cast", dtype=dtype)

    def concat(self, *args, group_sizes=None):
        """
        Add constraint s.t out_node = concat(inp_nodes)

        ``args[0]`` is the output node and the remaining positional arguments are the input nodes. Two
        constraints are registered: a forward one (inputs -> output) and a backward one that slices the
        output back into the inputs.

        Args:
            *args: The output node followed by the input nodes.
            group_sizes: Optional sequence of group sizes. When given, each input is cut into consecutive
                chunks proportional to the group sizes and the output is laid out group by group (for
                every group, that group's chunk of each input in order). When None, the inputs are
                concatenated as they are.

        """
        out_node, inp_nodes = args[0], list(args[1:])

        # Cumulative input lengths: [0, len_0, len_0 + len_1, ...]. The status functions refresh it in place
        # and the backward callbacks read it, so the same list object has to be shared by all closures.
        lengths = []

        def leading_dim(enc):
            # An encoding with an empty shape counts as a single element.
            return enc.shape[0] if len(enc.shape) > 0 else 1

        def refresh_lengths(concat_inputs_enc):
            lengths.clear()
            lengths.append(0)
            for inp in concat_inputs_enc:
                lengths.append(leading_dim(inp) + lengths[-1])

        def forward_state_function(inputs_enc, outputs_enc):
            out = outputs_enc[0]
            changed = False
            if out.shape is None and all(inp.shape is not None for inp in inputs_enc):
                changed = True
                out.shape = (sum(leading_dim(inp) for inp in inputs_enc),)
            if out.scalar is None:
                changed = True
                out.scalar = False
            return [changed]

        def forward_status_function(inputs_enc, outputs_enc):
            refresh_lengths(inputs_enc)
            if lengths[-1] != outputs_enc[0].shape[0]:
                return False
            return not outputs_enc[0].scalar

        if group_sizes is None:

            def forward_callback(*inputs):
                return tf.cast(tf.concat(inputs, axis=0), tf.float32)
        else:

            def forward_callback(*inputs):
                concat_scales_by_group = []
                total_size = sum(group_sizes)
                for ii, group_size in enumerate(group_sizes):
                    prev_group_sizes_sum = sum(group_sizes[:ii])
                    for inp in inputs:
                        # Number of elements of this input that correspond to one unit of group size.
                        size_scale_common_denominator = get_length(inp) // total_size
                        start = size_scale_common_denominator * prev_group_sizes_sum
                        size_scale_for_curr_group = size_scale_common_denominator * group_size
                        end = start + size_scale_for_curr_group
                        concat_scales_by_group.append(inp[start:end])
                return tf.cast(tf.concat(concat_scales_by_group, axis=0), tf.float32)

        forward_string = f'concat({", ".join(f"{{{i}}}" for i in range(len(inp_nodes)))})'

        # For the backward constraint the roles are swapped: the flow hands over the concat output first
        # (it is the constraint's input) and the concat inputs second (they are the constraint's outputs).
        def backward_state_function(outputs_enc, inputs_enc):
            changed = [False] * len(inputs_enc)
            # Exactly one input with an unknown shape can be deduced from the output length.
            if outputs_enc[0].shape is not None and sum(inp.shape is None for inp in inputs_enc) == 1:
                missing_shape = (
                    outputs_enc[0].shape[0]
                    - sum(leading_dim(inp) for inp in inputs_enc if inp.shape is not None),
                )
                for i, inp in enumerate(inputs_enc):
                    if inp.shape is None:
                        changed[i] = True
                        inp.shape = missing_shape
            for i, inp in enumerate(inputs_enc):
                if outputs_enc[0].scalar and not inp.scalar:
                    changed[i] = True
                    inp.scalar = True
                elif inp.scalar is None:
                    changed[i] = True
                    inp.scalar = False
            return changed

        def backward_status_function(outputs_enc, inputs_enc):
            refresh_lengths(inputs_enc)
            if lengths[-1] != outputs_enc[0].shape[0]:
                return False
            return outputs_enc[0].scalar or not any(inp.scalar for inp in inputs_enc)

        if group_sizes is None:

            def backward_callback(out):
                return tuple(out[lengths[index] : lengths[index + 1]] for index in range(len(lengths) - 1))
        else:

            def backward_callback(out):
                total_size = sum(group_sizes)
                # Start offset in ``out`` of every (group, input) chunk, in output order, plus the total
                # length as a closing sentinel.
                start_outs = [
                    (lengths[-1] // total_size) * sum(group_sizes[:ii]) + (lengths[i] // total_size) * group_sizes[ii]
                    for ii in range(len(group_sizes))
                    for i in range(len(lengths) - 1)
                ] + [lengths[-1]]
                # For input ``j``, collect the index ranges of its chunk in every group.
                indices = [
                    tf.concat(
                        [
                            tf.range(start_outs[i], start_outs[i + 1])
                            for i in range(j, len(start_outs) - 1, len(lengths) - 1)
                        ],
                        axis=0,
                    )
                    for j in range(len(lengths) - 1)
                ]
                return tuple(tf.gather(out, indices[i]) for i in range(len(lengths) - 1))

        backward_string = ", ".join(f"{{0}}[slice_{i}]" for i in range(len(inp_nodes)))

        self._flow.add_constraint(
            inp_nodes,
            out_node,
            forward_callback,
            state_function=forward_state_function,
            status_function=forward_status_function,
            func_string=forward_string,
        )
        self._flow.add_constraint(
            out_node,
            inp_nodes,
            backward_callback,
            state_function=backward_state_function,
            status_function=backward_status_function,
            func_string=backward_string,
        )
