from typing import Callable, Union

import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp
from hailo_model_optimization.acceleras.utils.acceleras_definitions import NativeName


class GenericNativeOp(BaseAtomicOp):
    """
    Describes a generic native operation.

    - This op is dedicated to a purely native operation that has no scale change.
    - This op is used for gathering statistics on the operation.

    The function executed by the op is resolved once, at construction time, from
    ``op_func_name`` (see ``get_function_by_native``). The same function is used
    for both ``call_native`` and ``call_hw_sim``.
    """

    def __init__(
        self,
        name: str,
        op_func_name: Union[NativeName, str, Callable],
        num_inputs=1,
        num_outputs=1,
        config_params=None,
        logger=None,
        fully_native=None,
        **kwargs,
    ):
        """
        Args:
            name: op name, forwarded to ``BaseAtomicOp.__init__``.
            op_func_name: a callable (used as-is), a ``NativeName`` member, or the
                name of a ``NativeName`` member given as a string.
            num_inputs: value reported by the ``num_inputs`` property.
            num_outputs: value reported by the ``num_outputs`` property.
            config_params: configuration dict read by the resolved function
                (``EW_SUB`` reads ``"input_repeats"``). Defaults to a new empty dict.
            logger: forwarded to ``BaseAtomicOp.__init__``.
            fully_native: forwarded to ``BaseAtomicOp.__init__``.
            **kwargs: forwarded to ``BaseAtomicOp.__init__`` and ``get_function_by_native``.
        """
        # Assigned before super().__init__ -- presumably the base constructor may read
        # num_inputs / num_outputs; TODO confirm before reordering.
        self._num_inputs = num_inputs
        self._num_outputs = num_outputs

        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        # Fresh dict per instance: a shared `{}` default would leak state between ops.
        self.config_params = {} if config_params is None else config_params
        self.function, self.function_name = self.get_function_by_native(op_func_name, **kwargs)

    @property
    def num_inputs(self) -> int:
        return self._num_inputs

    @property
    def num_outputs(self) -> int:
        return self._num_outputs

    def get_function_by_native(self, op_func_name, **kwargs):
        """
        Resolve ``op_func_name`` to the function executed by this op.

        Args:
            op_func_name: a callable, a ``NativeName`` member, or the name of a
                ``NativeName`` member.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            A ``(function, function_name)`` tuple. ``function_name`` is ``"callable"``
            for a user-supplied callable; otherwise it is the ``NativeName`` member.

        Raises:
            KeyError: if a string is not a ``NativeName`` member name, or if the
                ``NativeName`` member has no native implementation here.
            TypeError: if ``op_func_name`` is of an unsupported type.
        """
        if callable(op_func_name):
            return op_func_name, "callable"

        def ew_sub(inputs, **kwargs):
            """Return ``inputs[0] - inputs[1]`` after repeating each input.

            ``self.config_params["input_repeats"][i][dim]`` is the repeat count for
            input ``i`` along axis ``dim + 1``. Axis 0 is never repeated
            (presumably the batch axis -- confirm).
            """
            # Work on a copy so the caller's sequence is not mutated (and tuples work).
            # Assumes `inputs` is a sequence of tensors -- TODO confirm with BaseAtomicOp.
            inputs = list(inputs)
            input_repeats = self.config_params["input_repeats"]
            for i, repeats in enumerate(input_repeats):
                for dim, repeat in enumerate(repeats):
                    inputs[i] = tf.repeat(inputs[i], repeat, axis=dim + 1)

            return inputs[0] - inputs[1]

        op_func_by_name = {
            NativeName.EW_SUB: ew_sub,
        }

        # The guard keeps NativeName members out of the name lookup. This matters if
        # NativeName is a str-based Enum (not visible from here).
        if isinstance(op_func_name, str) and not isinstance(op_func_name, NativeName):
            op_func_name = NativeName[op_func_name]

        if isinstance(op_func_name, NativeName):
            return op_func_by_name[op_func_name], op_func_name

        raise TypeError(
            f"Unsupported op_func_name of type {type(op_func_name).__name__!r}; "
            "expected a callable, a NativeName, or a NativeName member name"
        )

    def create_weight_quant_element(self, **kwargs):
        # Non-arithmetic ops shouldn't have any weights.
        pass

    def create_hw_params(self, *args):
        """No quantization parameters are created; only ``enable_lossy()`` is called."""
        self.enable_lossy()

    def call_native(self, inputs, **kwargs):
        """Apply the resolved function to ``inputs``."""
        return self.function(inputs, **kwargs)

    def call_hw_sim(self, inputs, **kwargs):
        """HW simulation is identical to the native computation for this op."""
        return self.call_native(inputs, **kwargs)

    def export_weights(self):
        """This op has no weights, so an empty dict is exported."""
        return dict()

    @property
    def bit_exact_supported(self) -> bool:
        """This layer supports bit exact emulation."""
        return True

    def import_weights(self, layer_params, **kwargs):
        # No weights to import. `layer_params` is ignored.
        # NOTE(review): func_params is not read anywhere in this file; confirm whether
        # it is used elsewhere.
        self.func_params = {}

    def _compute_output_shape(self, input_shape):
        # The output takes the second input's shape. This assumes input_repeats
        # tiles the first input up to the second input's shape.
        return input_shape[1]
