from typing import Union

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.conv_stripped_op import ConvStrippedOp
from hailo_model_optimization.acceleras.atomic_ops.flatten_op import FlattenOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_conv import BaseHailoConv
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerWeightsClippingConfig

# TODO: Something is off about this implementation: the input shape should be set in `build` instead of `call`.
#       `real_inputs_shape` is a unique property - where exactly is it used? Maybe `input_spec` would be enough?
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    BiasMode,
    DataPath,
    EquivClassification,
    LayerHandlerType,
    LayerType,
    OptimizationTarget,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasException


class HailoDense(BaseHailoConv):
    """
    Represents a `dense` layer in the HN.

    The dense layer is implemented as a 1x1 convolution applied to a flattened input. The layer flow is
    ``flatten_op -> conv_op -> bias_add_op -> act_op -> output_op`` (see ``_build_flow``).
    Kernels live in the conv op as 4D conv kernels with two leading singleton axes. They are converted
    to/from the native 2D dense kernel at the import/export boundaries (``import_native_kernel``,
    ``get_numeric_kernel_np``, ``_change_native_kernel``, ``_layer_dependent_hw_params_modifications``).
    """

    SUPPORTED_PRECISION_MODE = BaseHailoConv.SUPPORTED_PRECISION_MODE
    _hn_type = LayerType.DENSE

    # Targets for which this layer defines its own HW support sets in
    # _get_precision_mode_supported_in_hw / _get_bias_mode_supported_in_hw.
    # Every other target defers to BaseHailoConv.
    _DENSE_HW_SUPPORT_TARGETS = frozenset(
        {OptimizationTarget.MERCURY, OptimizationTarget.SAGE, OptimizationTarget.PLUTO}
    )

    def __init__(
        self,
        name: str,
        units: int,
        activation: Union[str, callable, ActivationType] = "linear",
        logger=None,
        **kwargs,
    ):
        """
        Args:
            name: Layer name; also used as the prefix of the inner ops' names.
            units: Number of output features (passed to the conv op as ``filters``).
            activation: Activation forwarded to ``BaseHailoConv``.
            logger: Optional logger forwarded to the inner ops and to the base class.
            **kwargs: Forwarded to ``BaseHailoConv.__init__``.
        """
        # flatten_op is created before super().__init__. This is presumably so that it already exists
        # if the base class builds the layer flow (which references it) - TODO confirm.
        self.flatten_op = FlattenOp(f"{name}/flatten_op", logger=logger)
        conv_op = ConvStrippedOp(
            f"{name}/conv_op",
            kernel_size=(1, 1),
            is_depthwise=False,
            filters=units,
            groups=1,
            strides=(1, 1),
            dilation_rate=(1, 1),
            logger=logger,
        )
        super().__init__(
            name=name,
            conv_op=conv_op,
            activation=activation,
            logger=logger,
            **kwargs,
        )
        self.conv_op.validate_shapes = True
        self.act_op.validate_shapes = True
        # Dense is currently agnostic to the number of spatial dims of its input; it flattens any input.
        self.input_spec = tf.keras.layers.InputSpec()

        self.encoding_const = False

    def _build_flow(self) -> LayerFlow:
        """Build the flow: input -> flatten -> 1x1 conv -> bias add -> activation -> output."""
        layer_flow = LayerFlow()

        in1 = layer_flow.add_input()
        out1 = layer_flow.add_output()

        layer_flow.add_node(self.flatten_op)
        layer_flow.add_node(self.conv_op)
        layer_flow.add_node(self.bias_add_op)
        layer_flow.add_node(self.act_op)
        layer_flow.add_node(self.output_op)

        layer_flow.add_edge(in1, self.flatten_op, DataPath.LAYER_IN)
        layer_flow.add_edge(self.flatten_op, self.conv_op, DataPath.LAYER_IN)
        layer_flow.add_edge(self.conv_op, self.bias_add_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.bias_add_op, self.act_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.act_op, self.output_op, DataPath.LAYER_OUT)
        layer_flow.add_edge(self.output_op, out1, DataPath.LAYER_OUT)
        return layer_flow

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        Create a HailoDense layer from an HN element.

        The element's ``params`` are merged over the class defaults. The number of units is the last
        axis of ``params["kernel_shape"]``.

        Raises:
            AccelerasException: if ``transpose_output_width_features`` is enabled (not supported yet).
        """
        params = cls.get_default_params()
        params.update(hn_element.get("params", {}))
        kshape = params["kernel_shape"]

        if params.get("transpose_output_width_features", False):
            raise AccelerasException("transpose_output_width_features is not supported in dense layer in acceleras yet")
        layer = cls(
            name=lname,
            units=kshape[-1],
            activation=params["activation"],
            logger=logger,
        )
        layer.finalize_from_hn(hn_element)
        return layer

    def import_native_kernel(self, kernel, layer_params=None):
        """
        Import a native 2D dense kernel into the underlying 1x1 conv op.

        The kernel gains two leading singleton axes (``(a, b) -> (1, 1, a, b)``) before being loaded.
        ``layer_params`` is accepted for interface compatibility and is not used here.
        """
        kernel_4d = np.expand_dims(kernel, axis=(0, 1))
        self.conv_op.import_weights(kernel_4d)

    def get_numeric_kernel_np(self):
        """Return the base numeric kernel with its two leading (singleton) axes squeezed away."""
        return np.squeeze(super().get_numeric_kernel_np(), axis=(0, 1))

    def _layer_dependent_hw_params_modifications(self, params: dict):
        """Squeeze the two leading (singleton) axes of ``params["kernel"]`` in place and return ``params``."""
        params["kernel"] = np.squeeze(params["kernel"], axis=(0, 1))
        return params

    def _change_native_kernel(self, kernel_changed):
        """Return ``kernel_changed`` with its two leading (singleton) axes squeezed away."""
        return np.squeeze(kernel_changed, axis=(0, 1))

    def _consumer_handler_type(self):
        """
        Shared classification for the equalization and QuaRot handlers.

        The layer is transparent (not a source) when a forced scalar output-scale DOF is set.
        Otherwise it is a consumer that is also a source.
        """
        if self.forced_output_scale_scalar_dof is not None:
            return EquivClassification(LayerHandlerType.transparent, is_source=False)
        return EquivClassification(LayerHandlerType.consumer, is_source=True)

    def get_equalization_handler_type(self, predecessor_index=None):
        """Return the equalization handler classification for this layer."""
        # TODO: we don't handle dynamic weights yet
        return self._consumer_handler_type()

    def get_quarot_handler_type(self, predecessor_index=None):
        """Return the QuaRot handler classification for this layer."""
        return self._consumer_handler_type()

    def enforce_internal_encoding(self, training=False, **kwargs):
        """Propagate the flatten op's encoding to the conv op, then run the base internal encoding."""
        self._enforce_flatten_encoding()
        super().enforce_internal_encoding(training=training, **kwargs)

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """Propagate the flatten op's encoding to the conv op, then create HW params via the base class."""
        self._enforce_flatten_encoding()
        return super().create_hw_params(weights_clipping, optimization_target, hw_shifts=hw_shifts)

    def _enforce_flatten_encoding(self):
        """Enforce the flatten op's encoding and use its output scale/zero point as the conv op's input encoding."""
        self.flatten_op.enforce_encoding()
        self.conv_op.input_scales = [self.flatten_op.output_scale]
        self.conv_op.input_zero_points = [self.flatten_op.output_zero_point]

    def _enforce_input_encoding(self):
        """Input encoding for dense is the flatten-to-conv propagation."""
        self._enforce_flatten_encoding()

    def enforce_io_encoding(self, training=False, **kwargs):
        """Propagate the flatten op's encoding to the conv op, then run the base IO encoding."""
        self._enforce_flatten_encoding()  # TODO: we use it here for equalization - not so nice
        super().enforce_io_encoding(training=training, **kwargs)

    def _get_precision_mode_supported_in_hw(self, arch):
        """Return the HW-supported precision modes for ``arch``, deferring to the base class for other targets."""
        if arch not in self._DENSE_HW_SUPPORT_TARGETS:
            return super()._get_precision_mode_supported_in_hw(arch)
        return {
            PrecisionMode.a8_w4_a8,
            PrecisionMode.a8_w8_a8,
            PrecisionMode.a8_w4_a16,
            PrecisionMode.a8_w8_a16,
            PrecisionMode.a16_w16_a16,
            PrecisionMode.a8_w4,
            PrecisionMode.a8_w8,
            PrecisionMode.a16_w16,
        }

    def _get_bias_mode_supported_in_hw(self, arch):
        """Return the HW-supported bias modes for ``arch``, deferring to the base class for other targets."""
        if arch not in self._DENSE_HW_SUPPORT_TARGETS:
            return super()._get_bias_mode_supported_in_hw(arch)
        # by default same as emulator support
        return {
            BiasMode.single_scale_decomposition,
            BiasMode.double_scale_initialization,
        }

    def verify_layer_inputs_shape(self, input_shapes):
        """Intentional no-op: dense flattens its input, so input shapes are not verified here."""
        pass

    @staticmethod
    def _verify_hn_to_keras_input_shapes(keras_shapes, hn_shapes):
        """
        Check that every keras input shape is compatible with the corresponding HN input shape.

        An HN shape whose axes 1 and 2 are both 1 is treated as a flattened input. It matches a keras shape
        with the same number of elements after axis 0. Any other HN shape must equal the keras shape exactly.

        Returns:
            bool: True only if the input counts match and every input pair is compatible.
        """
        if len(keras_shapes) != len(hn_shapes):
            return False
        for keras_shape, hn_shape in zip(keras_shapes, hn_shapes):
            is_flatten_input = hn_shape[1] == 1 and hn_shape[2] == 1
            if is_flatten_input:
                # Only the element count has to agree. Do not return early here:
                # the remaining input pairs must still be checked.
                if np.prod(keras_shape[1:]) != np.prod(hn_shape[1:]):
                    return False
            elif not np.array_equal(keras_shape, hn_shape):
                # array_equal returns False on a rank mismatch, whereas elementwise `==` raises on newer NumPy.
                return False
        return True

    def _get_hn_output_shapes(self):
        """Report every HN output shape as flattened: ``[-1, 1, 1, prod(shape[1:])]``."""
        output_shapes = super()._get_hn_output_shapes()
        return [[-1, 1, 1, np.prod(shape[1:])] for shape in output_shapes]
