from typing import Union

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    SHIFT_CALCULATE_BUFFER,
    IgnoreHwLimitationAssertionPolicy,
    PaddingType,
    StrideAlignType,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasImplementationError,
    AccelerasInitializationError,
    AccelerasUnsupportedError,
)
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams
from hailo_model_optimization.acceleras.utils.opt_utils import calculate_shifts
from hailo_model_optimization.acceleras.utils.padding_utils import handle_padding


class CrossCorrelationDWOp(BaseAtomicOp):
    """
    Atomic op for a depth-wise cross-correlation layer.

    Description:
    This operation computes a cross-correlation (i.e., convolution in ML jargon) between an input
    tensor and a filter (kernel) tensor that is supplied externally by the user as a second,
    dynamic input (rather than being a stored weight of the op).

    Args:
        inputs  : (tensor) [batch_shape, in_height, in_width, in_channels] <--> NHWC
        filters : (tensor) [batch_shape, filter_height, filter_width, in_channels]

    """

    num_inputs = 2
    num_outputs = 1

    def __init__(
        self,
        name: str,
        strides=(1, 1),
        dilations=(1, 1),
        stride_align: Union[str, StrideAlignType] = "NW",
        padding: Union[str, PaddingType] = "VALID",
        fully_native=None,
        logger=None,
        **kwargs,
    ):
        """
        Args:
            name: name of the op, forwarded to the base class.
            strides: spatial strides, either (h, w) or a 4-element sequence from which
                elements [1:3] are used.
            dilations: dilation rate, either of length 2 or 4 (elements [1:3] are used).
                Only all-ones dilations are accepted.
            stride_align: value convertible to StrideAlignType.
            padding: value convertible to PaddingType; only VALID and SAME are accepted.
            fully_native, logger, kwargs: forwarded to the base class constructor.

        Raises:
            AccelerasInitializationError: on unsupported padding, or on strides / dilations
                whose length is neither 2 nor 4.
            AccelerasImplementationError: when any dilation differs from 1.

        """
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)

        self.padding = PaddingType(padding)
        if self.padding not in {PaddingType.VALID, PaddingType.SAME}:
            raise AccelerasInitializationError(f"Padding type {padding} is not supported in {type(self)}")
        # Native (unquantized) value used to fill the padded border.
        self.padding_const_value = 0

        if len(strides) != 2 and len(strides) != 4:
            raise AccelerasInitializationError(
                f"This {name} op supports only strides of length 2 or 4, but strides={strides} was given.",
            )
        # A 4-element stride is reduced to its two middle entries.
        self.strides = strides if len(strides) == 2 else strides[1:3]
        self.stride_align = StrideAlignType(stride_align)

        if np.any(np.array(dilations) != 1):
            raise AccelerasImplementationError(f"layer {self.full_name} do not support dilations other than (1,1)")
        if len(dilations) != 2 and len(dilations) != 4:
            raise AccelerasInitializationError(
                f"This {name} op supports only dilations of length 2 or 4, but dilations={dilations} was given.",
            )
        self.dilation_rate = dilations if len(dilations) == 2 else dilations[1:3]

        # Shift (in bits) applied before the accumulator; set by create_hw_params().
        self.pre_acc_shift = 0

    @property
    def padding_const_value_q(self):
        """
        Quantized padding constant for input 0.

        Computed as ``padding_const_value / input_scale + zero_point`` and then passed through the
        lossy element of input 0. A non-scalar input scale contributes only its first entry, and
        the zero point is reduced to its mean.
        """
        input_scale = self.input_scales[0] if self.input_scales[0].shape == () else self.input_scales[0][0]
        zp = tf.reduce_mean(self.input_zero_points[0])
        quantized_val = self.padding_const_value / input_scale + zp
        return self.input_lossy_elements[0](quantized_val)

    def create_hw_params(self, preact_limvals, hw_shifts=None, **kwargs):
        """
        Calculates the pre accumulator shift, in bits, and updates the output encoding.

        Args:
            preact_limvals: limit values of the pre-activation output; only the largest absolute
                value is used.
            hw_shifts: forwarded to calculate_shifts().

        Raises:
            AccelerasImplementationError: when any zero point of either input is non-zero.
            AccelerasUnsupportedError: when calculate_shifts() reports a non-zero shift delta and
                the hw-limitation assertion is not set to be ignored.

        """
        # np.any() keeps the check valid for zero points holding more than one element; a bare
        # `array != 0` inside `if ... or ...` raises "truth value of an array is ambiguous".
        if np.any(self.input_zero_points[0] != 0) or np.any(self.input_zero_points[1] != 0):
            raise AccelerasImplementationError(f"layer {self.full_name} do not support zero point != 0")

        limvals = np.abs(preact_limvals).max()
        # Largest expected output magnitude expressed in units of the product of the input scales.
        expected_max_output = np.max(limvals / (self.input_scales[0] * self.input_scales[1]))

        accumulator_size = self.output_lossy_element.bits  # get accumulator

        pre_acc_shift, shift_delta = calculate_shifts(
            expected_max_output, accumulator_size, SHIFT_CALCULATE_BUFFER, hw_shifts=hw_shifts
        )
        if shift_delta != 0 and self._ignore_hw_limitation_assertion != IgnoreHwLimitationAssertionPolicy.enabled:
            # Build an actionable message: report both current input ranges and suggest ranges
            # scaled by sqrt(factor) each, so that their product grows by `factor`.
            name_to_display = "/".join(self.full_name.split("/")[:-1])
            factor = self._trunc_plus(2**shift_delta, 3)
            factor_sqrt = np.sqrt(factor)
            range_min0, range_max0 = self.get_input_limvals(0)
            range_min0, range_max0 = self._trunc_plus(range_min0, 3), self._trunc_plus(range_max0, 3)
            range0_str = f"[{range_min0:.03f}, {range_max0:.03f}]"
            range_min1, range_max1 = self.get_input_limvals(1)
            range_min1, range_max1 = self._trunc_plus(range_min1, 3), self._trunc_plus(range_max1, 3)
            range1_str = f"[{range_min1:.03f}, {range_max1:.03f}]"
            range0_fix_str = f"[{range_min0*factor_sqrt:.03f}, {range_max0*factor_sqrt:.03f}]"
            range1_fix_str = f"[{range_min1*factor_sqrt:.03f}, {range_max1*factor_sqrt:.03f}]"
            raise AccelerasUnsupportedError(
                f"layer {name_to_display} does not support shift delta. To overcome this issue you should "
                f"force larger range at the inputs of the layer using command "
                f"quantization_param([layer_name], force_range_in=[range_min, range_max], force_range_index=index) "
                f"current range of input 0 is {range0_str} and input 1 is {range1_str}. "
                f"You should increase the multiplication of these ranges by a factor of {factor:.03f}, "
                f"e.g. you can apply factor of sqrt({factor:.03f}) to both inputs:\n"
                f"quantization_param([{name_to_display}], force_range_in={range0_fix_str}, force_range_index=0)\n"
                f"quantization_param([{name_to_display}], force_range_in={range1_fix_str}, force_range_index=1)\n",
            )

        self.pre_acc_shift = pre_acc_shift
        self.enforce_encoding()

    @staticmethod
    def _trunc_plus(value, decimals=0):
        """
        Truncate a float to a certain number of decimal places, then push the last kept digit one
        step away from zero (the step carries the sign of `value`; zero is returned unchanged).
        """
        return (np.trunc(value * 10**decimals) + np.sign(value)) / 10**decimals

    def call_native(self, inputs, padding_const_value=None, **kwargs):
        """
        Args:
            inputs: (list) [input tensor, filters]
            padding_const_value: value used for padding; defaults to self.padding_const_value.

        Description: Implements a depth-wise convolution via tensorflow 2 (depth-wise means each
        channel is convolved with its own 2D kernel). Every batch element is convolved with its
        own filter tensor.

        """
        kernel_size = (inputs[1].shape[1], inputs[1].shape[2])
        padding_const_value = self.padding_const_value if padding_const_value is None else padding_const_value
        # Padding is applied explicitly here, hence the convolution below always runs as "VALID".
        padded_input = handle_padding(
            inputs[0],
            self.padding,
            kernel_size,
            self.strides,
            padding_const_value,
            self.stride_align,
            self.dilation_rate,
        )
        # NOTE(review): newer TF versions deprecate map_fn's `dtype` argument in favour of
        # `fn_output_signature` - consider migrating once the minimal TF version allows it.
        output = tf.map_fn(
            lambda elems: tf.nn.depthwise_conv2d(
                tf.expand_dims(elems[0], 0),  # H,W,C -> 1,H,W,C
                tf.expand_dims(elems[1], 3),  # H,W,C -> H,W,C,1
                strides=[1, self.strides[0], self.strides[1], 1],
                dilations=self.dilation_rate,
                padding="VALID",
            ),  # --> Result of conv is 1,H,W,C (channel multiplier is 1)
            elems=[padded_input, inputs[1]],  # [input, kernel]
            dtype=tf.float32,
        )
        result = output[:, 0, :, :, :]  # B,1,H,W,C -> B,H,W,C
        return result

    def enforce_encoding(self, *args, **kwargs):
        """
        Enforce quantized encoding of scales and zeros from input to pre-accumulator input.

        The output scale is the product of both input scales multiplied by 2**pre_acc_shift, and
        the output zero point is set to 0.
        """
        self.output_scale = self.input_scales[0] * self.input_scales[1] * 2 ** (self.pre_acc_shift)
        self.output_zero_point = np.array(0)

    def call_hw_sim(self, inputs, **kwargs):
        """
        Args:
            inputs: a list (2,) of EagerTensor for the input and filters (dynamic weights).
        Description:
            Simulating the core numeric functionality of Hailo MAC, that is, this function runs the
            mathematical operations (as in call_native()) with chip enabled valid mathematical operations
            (e.g., using only uint, int, etc.). The quantized padding constant is used for padding
            and the result is scaled by 2**(-pre_acc_shift).

        """
        out_pre_shift = self.call_native(inputs, self.padding_const_value_q, **kwargs)
        out = out_pre_shift * 2.0 ** (-1 * self.pre_acc_shift)
        return out

    def export_independent_params(self):
        """Return the pre accumulator shift as a float32 numpy array, keyed by "pre_acc_shift"."""
        return {
            "pre_acc_shift": np.array(self.pre_acc_shift, np.float32),
        }

    def import_independent_params(self, params):
        """Restore the pre accumulator shift from `params["pre_acc_shift"]`."""
        self.pre_acc_shift = params["pre_acc_shift"]

    def export_quant_weights(self):
        """Return the quantized padding constant as a numpy value."""
        return {"padding_const_value": self.padding_const_value_q.numpy()}

    def export_hw_params(self):
        """Return the quantized padding constant (as uint16) and the pre accumulator shift (as uint8)."""
        return {
            "padding_const_value": self.padding_const_value_q.numpy().astype(np.uint16),
            "output_stage/mult_shift": np.array(self.pre_acc_shift, np.uint8),
        }

    def import_weights(self, layer_params: LayerParams):
        """Load the native padding constant from `layer_params`, keeping the current value if absent."""
        param_dict = dict(layer_params)
        pad_const_value = param_dict.get("padding_const_value", self.padding_const_value)
        self.padding_const_value = np.float32(pad_const_value)

    def export_weights(self, apply_scale_factors=False):
        """Return the native padding constant; `apply_scale_factors` is not used by this op."""
        return {"padding_const_value": self.padding_const_value}

    def create_weight_quant_element(self, kernel_bits, signed=True):
        """No-op: this op does not create a weight quantization element (filters arrive as an input)."""
        pass
