import copy

import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    LayerEquivType,
    LayerHandlerType,
    LayerSupportStatus,
)
from hailo_sdk_common.hailo_nn.hn_definitions import (
    ActivationType,
    ActivationTypes,
    DefuseType,
    HnStage,
    LayerType,
    NormalizationType,
    PaddingType,
)
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_activation import LayerWithActivation
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification
from hailo_sdk_common.numeric_utils.normalization_params import calc_normalization_params


class NormalizationLayer(LayerWithActivation):
    """HN layer that normalizes its input using a mean and a std value per feature.

    The ``mean``/``std`` parameters are exposed to the rest of the SDK as a kernel and a
    bias (see ``kernel``/``bias``), computed by ``calc_normalization_params`` for a
    ``[1, 1, input_features, 1]`` kernel shape. The layer can additionally hold
    elementwise-add connections (see ``add_ew_connection``).
    """

    _REQUIRES_NATIVE_WEIGHTS = True
    _REQUIRES_QUANTIZED_WEIGHTS = True

    def __init__(self):
        super().__init__()
        self._op = LayerType.normalization
        self._padding = PaddingType.same
        # mean/std stay None until set through `create`, the setters or `move_params`.
        self._mean = None
        self._std = None
        # NOTE(review): presumably 2 so that an elementwise-add input can be attached
        # next to the main input - confirm against the base class.
        self._number_of_inputs_supported = 2
        self._ew_connections = []
        self._normalization_type = NormalizationType.normalization

    @classmethod
    def create(
        cls,
        original_name,
        input_vertex_order,
        mean,
        std,
        output_shapes=None,
        padding=PaddingType.same,
        activation=ActivationType.linear,
        normalization_type=NormalizationType.normalization,
    ):
        """Build a layer through the parent ``create`` and set the normalization attributes.

        Args:
            original_name: Forwarded to the parent ``create``.
            input_vertex_order: Forwarded to the parent ``create``.
            mean: Value stored as ``layer.mean`` (not validated or broadcast here).
            std: Value stored as ``layer.std`` (not validated or broadcast here).
            output_shapes: Forwarded to the parent ``create``.
            padding: Value stored as ``layer.padding``.
            activation: Value stored as ``layer.activation``.
            normalization_type: Value stored as ``layer.normalization_type``.

        Returns:
            The newly created layer.
        """
        layer = super().create(original_name, input_vertex_order, output_shapes)
        layer.mean = mean
        layer.std = std
        layer.padding = padding
        layer.activation = activation
        layer.normalization_type = normalization_type
        return layer

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Fill the protobuf node produced by the parent with this layer's fields.

        Returns:
            The protobuf node, with type, strides, kernel shape, elementwise-add flag,
            normalization type and activation set.
        """
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_NORMALIZATION
        # The first strides entry and the third kernel-shape entry have no protobuf
        # field here, so they are discarded.
        _, node.strides.height, node.strides.width, node.strides.features = self.strides
        node.kernel_shape.height, node.kernel_shape.width, _, node.kernel_shape.features = self.kernel_shape
        node.ew_add = self.ew_add_enabled
        node.normalization_type = pb_wrapper.NORMALIZATION_TYPE_TO_PB[self.normalization_type]
        node.activation = pb_wrapper.ACTIVATION_TYPE_TO_PB[self._activation]
        return node

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Build a layer from a protobuf node, restoring normalization type and activation."""
        layer = super().from_pb(pb, pb_wrapper)
        layer.normalization_type = pb_wrapper.NORMALIZATION_TYPE_PB_TO_TYPE[pb.normalization_type]
        layer.activation = pb_wrapper.ACTIVATION_PB_TO_TYPE[pb.activation]
        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """Build a layer from ``old_layer`` through the parent ``from_layer``.

        The normalization type and activation are copied from ``old_layer`` unless its
        op is one of the layer types listed below, for which these two attributes are
        not read.
        """
        layer = super().from_layer(old_layer)
        ops_without_copied_attributes = [
            LayerType.base_ew_mean,
            LayerType.rnn,
            LayerType.lstm,
            LayerType.equal,
        ]
        if old_layer.op not in ops_without_copied_attributes:
            layer.normalization_type = old_layer.normalization_type
            layer.activation = old_layer.activation
        return layer

    def to_hn(self, should_get_default_params=False):
        """Return the HN dict of the layer, extended with elementwise-add and activation params.

        The parent's result is deep-copied before being modified, so the object returned
        by the parent is left untouched.
        """
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        params = result["params"]
        params["elementwise_add"] = self.ew_add_enabled
        params["activation"] = self._activation.value
        return result

    @property
    def padding(self):
        return self._padding

    @padding.setter
    def padding(self, padding):
        self._padding = padding

    @property
    def mean(self):
        return self._mean

    @mean.setter
    def mean(self, mean):
        self._mean = mean

    @property
    def normalization_type(self):
        return self._normalization_type

    @normalization_type.setter
    def normalization_type(self, norm_type):
        self._normalization_type = norm_type

    @property
    def std(self):
        return self._std

    @std.setter
    def std(self, std):
        self._std = std

    @property
    def strides(self):
        """Fixed strides: always ``[1, 1, 1, 1]``."""
        return [1, 1, 1, 1]

    @property
    def dilations(self):
        """Fixed dilations: always ``[1, 1, 1, 1]``."""
        return [1, 1, 1, 1]

    def _convert_normalization_params_to_dw_params(self):
        """Return the ``(kernel, bias)`` pair computed from mean, std and the kernel shape."""
        kernel, bias = calc_normalization_params(self._mean, self._std, self.kernel_shape)
        return kernel, bias

    @property
    def kernel(self):
        """Kernel part of ``_convert_normalization_params_to_dw_params`` (recomputed on each access)."""
        kernel, _ = self._convert_normalization_params_to_dw_params()
        return kernel

    @property
    def kernel_height(self):
        return self.kernel_shape[0]

    @property
    def kernel_width(self):
        return self.kernel_shape[1]

    @property
    def bias(self):
        """Bias part of ``_convert_normalization_params_to_dw_params`` (recomputed on each access)."""
        _, bias = self._convert_normalization_params_to_dw_params()
        return bias

    @property
    def output_features(self):
        # Third kernel-shape entry, i.e. `input_features` (see `kernel_shape`).
        return self.kernel_shape[2]

    @property
    def macs(self):
        # The /2 is because we don't do multiply
        return self.ops / 2

    @property
    def ops(self):
        """Operation count: two operations per element of the input shape."""
        # A little trick that simplifies it all: take the absolute value of the product
        # of all input dimensions.
        # NOTE(review): presumably the abs() cancels a -1 batch dimension - confirm.
        total_input_size = float(np.abs(np.prod(np.array(self.input_shape))))
        # For each value, we subtract a value and div by a value
        return total_input_size * 2

    @property
    def kernel_shape(self):
        """Kernel shape: ``[1, 1, input_features, 1]``."""
        return [1, 1, self.input_features, 1]

    @property
    def ew_add_enabled(self):
        """True when at least one elementwise-add connection is registered."""
        return len(self._ew_connections) > 0

    @property
    def ew_add_connections(self):
        """The list of registered elementwise-add connections (the internal list, not a copy)."""
        return self._ew_connections

    @property
    def input_shape(self):
        """First input shape when elementwise-add is enabled, otherwise the parent's value."""
        if self.ew_add_enabled:
            return self._input_shapes[0]

        return super().input_shape

    @input_shape.setter
    def input_shape(self, input_shape):
        # Overriding the getter above creates a new property object, so the parent's
        # setter has to be invoked explicitly through the parent's property descriptor.
        super(NormalizationLayer, self.__class__).input_shape.fset(self, input_shape)

    def get_equalization_handler_type(self, predecessor=None):
        """Return the equalization classification of the layer.

        Unsupported when the output width/features transpose flag is set; otherwise a
        consumer that is also a source. ``predecessor`` is not used.
        """
        if self._transpose_output_width_features:
            return EquivClassification(LayerHandlerType.unsupported, is_source=False)
        return EquivClassification(LayerHandlerType.consumer, is_source=True)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Return the params-sorter classification: always featurewise, not a source."""
        return EquivClassification(LayerHandlerType.featurewise, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Return the dead-channels-removal classification of the layer.

        Unsupported when the output is transposed (width/features) or spatially
        flattened; featurewise otherwise. ``predecessor`` is not used.
        """
        if self._transpose_output_width_features or self._spatial_flatten_output:
            return EquivClassification(LayerHandlerType.unsupported, is_source=False)
        return EquivClassification(LayerHandlerType.featurewise, is_source=False)

    def ibc_supported(self):
        return LayerSupportStatus.supported

    def move_params(self, layer):
        """Take over parameters from ``layer``.

        After the parent's ``move_params``:
        * a normalization layer hands over its mean and std as they are;
        * a PReLU activation layer is translated to a zero mean and a std of ``-1.0``
          per output feature when this layer's activation is ReLU, or ``-1 / slope``
          per PReLU slope otherwise.
        """
        super().move_params(layer)
        if layer.op == LayerType.normalization:
            self._mean = layer._mean
            self._std = layer._std
        if layer.op == LayerType.base_activation and layer.activation == ActivationType.prelu:
            if self.activation == ActivationType.relu:
                self._mean = [0.0] * layer.output_features
                self._std = [-1.0] * layer.output_features
            else:
                self._mean = [0.0] * layer.output_features
                # NOTE(review): a zero slope raises ZeroDivisionError here; presumably
                # callers never pass one - confirm.
                self._std = [-1 / x for x in layer.prelu_slope]

    def update_output_shapes(self, **kwargs):
        """Update the output shapes, normalizing mean/std first in the pre-fused stage.

        Args:
            **kwargs: Must contain ``hn_stage``.

        Raises:
            KeyError: If ``hn_stage`` is not passed.
        """
        hn_stage = kwargs["hn_stage"]
        if hn_stage == HnStage.PRE_FUSED:
            self.update_mean_and_std()
        # The keyword arguments are intentionally not forwarded (original behaviour).
        super().update_output_shapes()

    def _calc_output_shape(self):
        """Compute the output shape from the input shape and the defuse parameters.

        Returns:
            A list whose first entry is ``-1``, followed by the remaining dimensions.
        """
        output_f = self.input_shape[-1]
        # A defuse type other than none/compute_lanes/spatial_w overrides the feature
        # count with `defuse_features`.
        if "defuse_features" in self.defuse_params and self.defuse_type not in [
            DefuseType.none,
            DefuseType.compute_lanes,
            DefuseType.spatial_w,
        ]:
            output_f = self.defuse_features

        output_shape = list(self.input_shape)
        # Spatial-width defuse replaces the entry at index 2 with the defused input width.
        if (
            self.defuse_type is DefuseType.spatial_w
            and "defuse_input_width" in self.defuse_params
            and self.defuse_input_width != 0
        ):
            output_shape[2] = self.defuse_input_width

        if self._transpose_output_width_features:
            # Swap the last two dimensions. The unpacking rebinds `output_f` to the last
            # entry of `output_shape`, so a `defuse_features` override is not used here.
            _, output_h, output_w, output_f = output_shape
            output_shape = [-1, output_h, output_f, output_w]
        else:
            output_shape = [-1, *output_shape[1:-1], output_f]

        return output_shape

    def _expand_norm_param_to_features(self, values):
        """Return ``values`` as a list, broadcast to ``input_features`` where needed.

        * ``np.ndarray`` values are converted with ``tolist()`` and tuples with ``list()``;
        * a value that is still not a list is treated as a scalar and repeated
          ``input_features`` times;
        * a one-element list is repeated ``input_features`` times when there is more
          than one input feature;
        * any other list is returned as is (its length is not validated).
        """
        if isinstance(values, np.ndarray):
            values = values.tolist()
        elif isinstance(values, tuple):
            # Without this conversion a tuple would be treated as a scalar and repeated
            # as a whole, giving a nested structure instead of per-feature values.
            values = list(values)

        if not isinstance(values, list):
            return [values] * self.input_features
        if len(values) == 1 and self.input_features != 1:
            return values * self.input_features
        return values

    def update_mean_and_std(self):
        """Normalize ``mean`` and ``std`` into lists broadcast to ``input_features``.

        Scalars, numpy arrays, tuples and lists are accepted; see
        ``_expand_norm_param_to_features`` for the exact rules.
        """
        self.mean = self._expand_norm_param_to_features(self.mean)
        self.std = self._expand_norm_param_to_features(self.std)

    @classmethod
    def from_hn(cls, hn):
        """Build a layer from an HN dict, reading the activation from its params when present."""
        layer = super().from_hn(hn, validate_params_exist=False)
        if "params" in hn and "activation" in hn["params"]:
            layer._activation = ActivationTypes[hn["params"]["activation"]]

        return layer

    @staticmethod
    def get_axes_mask(type_of_layer=None):
        """Return the axes mask of the layer.

        The mask is the same for every ``type_of_layer`` (including
        ``LayerEquivType.producer``); the parameter is kept for interface compatibility.
        """
        return [True, True, False, True]

    def add_ew_connection(self, other_layer):
        """Register ``other_layer`` as an elementwise-add connection."""
        self._ew_connections.append(other_layer)

    def clear_ew_connections(self):
        """Drop all elementwise-add connections (a new empty list is assigned)."""
        self._ew_connections = []

    def _is_ew_connection(self, other_layer):
        return other_layer in self._ew_connections

    def sort_inputs(self):
        """Return a cmp-style comparator over input layers.

        The comparator ranks an elementwise-add connection as greater than a layer that
        is not one, and treats two layers of the same kind as equal.
        """

        def sort_function(layer1, layer2):
            is_ew1 = bool(self._is_ew_connection(layer1))
            is_ew2 = bool(self._is_ew_connection(layer2))
            # 1 when only layer1 is an ew connection, -1 when only layer2 is, 0 otherwise.
            return int(is_ew1) - int(is_ew2)

        return sort_function

    @property
    def hn_name(self):
        """The HN name of the layer: the value of its normalization type."""
        return self._normalization_type.value
