from typing import Tuple

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp
from hailo_model_optimization.acceleras.utils.acceleras_definitions import DEFAULT_PADDING_NEG_INF_VALUE
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams


class PaddingOp(BaseAtomicOp):
    """Constant-pad a 4-D tensor along axes 1, 2 and 3 (axis 0 is never padded).

    ``padding`` is ``(axis1_before, axis1_after, axis2_before, axis2_after,
    axis3_before, axis3_after)``; axis 3 is treated as the feature axis when
    adapting scales / zero points.

    The fill value is ``padding_const_value``. The sentinel
    ``DEFAULT_PADDING_NEG_INF_VALUE`` requests ``-inf`` padding in native mode
    and ``0`` in the quantized (hw-sim) domain.
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(
        self,
        name: str,
        padding: Tuple[int, int, int, int, int, int],
        logger=None,
        fully_native=None,
        **kwargs,
    ):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)

        self.padding = padding
        # Fill value in the native (float) domain; replaced by import_weights.
        self.padding_const_value = 0

    def _is_neg_inf_padding(self) -> bool:
        """Return True when the configured fill value is the -inf sentinel.

        Both sides are compared as float32 because ``import_weights`` stores the
        value as ``np.float32``. Comparing that against a wider-precision sentinel
        could otherwise miss a match.
        """
        return bool(np.float32(self.padding_const_value) == np.float32(DEFAULT_PADDING_NEG_INF_VALUE))

    def call_native(self, inputs, **kwargs):
        """Pad ``inputs[0]`` in the float domain, mapping the -inf sentinel to a real -inf."""
        if self._is_neg_inf_padding():
            const_value = -np.float32(np.inf)
        else:
            const_value = self.padding_const_value
        return self._call_internal(inputs[0], const_value)

    def call_hw_sim(self, inputs, **kwargs):
        """Pad ``inputs[0]`` with the fill value expressed in the quantized domain."""
        return self._call_internal(inputs[0], self.padding_const_value_q)

    def _call_internal(self, inputs, padding_val):
        """Apply constant padding ``padding_val`` to ``inputs`` according to ``self.padding``."""
        pad = self.padding
        # Axis 0 is left untouched; axes 1-3 take (before, after) pairs from self.padding.
        paddings = tf.constant([[0, 0], [pad[0], pad[1]], [pad[2], pad[3]], [pad[4], pad[5]]])
        return tf.pad(inputs, paddings, "CONSTANT", constant_values=padding_val)

    def _compute_output_shape(self, input_shape):
        """Return ``input_shape`` grown by the padding on axes 1-3.

        Axis 0 is passed through unchanged. Unknown (``None``) dimensions stay
        ``None`` instead of raising ``TypeError``.
        """
        pad = self.padding
        extra = (pad[0] + pad[1], pad[2] + pad[3], pad[4] + pad[5])
        shape = [input_shape[0]]
        for dim, add in zip((input_shape[1], input_shape[2], input_shape[3]), extra):
            shape.append(None if dim is None else dim + add)
        return shape

    def create_weight_quant_element(self, **kwargs):
        """No weight quantization element is needed for this op."""
        pass

    @staticmethod
    def _scalar_or_first(value):
        """Return ``value`` if it has scalar shape, otherwise its first element."""
        return value if value.shape == () else value[0]

    @property
    def padding_const_value_q(self):
        """Fill value expressed in the quantized output domain.

        The -inf sentinel maps to ``0``. Otherwise the result is
        ``padding_const_value / scale + zero_point``, using the first element of
        vector scales / zero points, passed through ``input_lossy_element``.
        """
        if self._is_neg_inf_padding():
            return tf.constant(0, dtype=tf.float32)

        output_scale = self._scalar_or_first(self.output_scale)
        output_zero_point = self._scalar_or_first(self.output_zero_point)
        quantized_val = self.padding_const_value / output_scale + output_zero_point
        return self.input_lossy_element(quantized_val)

    def enforce_encoding(self):
        """Derive the output scale / zero point from the input ones (see ``update_zp_or_scale``)."""
        self.output_scale = self.update_zp_or_scale(self.input_scales[0])
        self.output_zero_point = self.update_zp_or_scale(self.input_zero_points[0])

    def update_zp_or_scale(self, input_value):
        """Adapt an input scale / zero point to the padded output.

        - Scalars are returned unchanged.
        - Vectors are returned unchanged when the feature axis (padding[4:6]) is
          not padded.
        - Otherwise, the result is ``input_value[0]`` repeated to the padded
          feature length.

        NOTE(review): the repeat path discards any per-channel variation. It
        presumably assumes uniform per-channel values; confirm.
        """
        if input_value.shape == ():
            return input_value

        features_padded = self.padding[4] != 0 or self.padding[5] != 0
        if not features_padded:
            return input_value

        padded_length = input_value.shape[0] + self.padding[4] + self.padding[5]
        return tf.repeat(input_value[0], padded_length)

    @property
    def bit_exact_supported(self) -> bool:
        """This layer supports bit exact emulation."""
        return True

    def export_hw_params(self):
        """Export the quantized fill value cast to ``np.uint16``."""
        return {"padding_const_value": self.padding_const_value_q.numpy().astype(np.uint16)}

    def export_quant_weights(self):
        """Export the quantized fill value without casting."""
        return {"padding_const_value": self.padding_const_value_q.numpy()}

    def export_weights(self):
        """Export the native-domain fill value."""
        return {"padding_const_value": self.padding_const_value}

    def import_weights(self, layer_params: LayerParams):
        """Load ``padding_const_value`` (stored as ``np.float32``), keeping the current value if absent."""
        param_dict = dict(layer_params)
        pad_const_value = param_dict.get("padding_const_value", self.padding_const_value)
        self.padding_const_value = np.float32(pad_const_value)
