import copy

import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationTypes, DefuseType, LayerType, PaddingType
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_activation import LayerWithActivation
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class ReduceSumLayer(LayerWithActivation):
    """HN layer that sums its input over ``reduce_axes``.

    Axis indices refer to positions in ``input_shape``. For rank-4 inputs:
    - axis 1 is height;
    - axis 2 is width;
    - axis 3 is features.

    For rank-2 inputs, ``set_input_shapes`` forces the reduction onto axis 1, which is the last
    (features) axis.

    When a dimension collapses to 1 it is widened back:
    - output height becomes ``height_groups``;
    - output features become ``groups``.

    See ``_calc_output_shape``. The reduction is also described as a kernel/stride pair
    (``kernel_shape`` / ``strides``) with SAME padding.
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = True
    _IS_REAL_LAYER = True

    def __init__(self):
        super().__init__()
        self._op = LayerType.reduce_sum
        self._reduce_axes = [3]  # default: reduce over the features axis of a rank-4 input
        self._groups = 1
        self._height_groups = 1
        self._interleaved_groups = False

    @staticmethod
    def _as_axes_list(reduce_axes):
        """Return a fresh ``list`` copy of ``reduce_axes``; ``None`` is passed through unchanged.

        Copying prevents the layer from aliasing caller-owned containers, such as:
        - an HN dict entry;
        - a protobuf repeated field;
        - another layer's list.

        Without the copy, ``from_hn`` would mutate those containers in place. The copy also
        keeps list comparisons such as ``is_reduce_rank2`` working when a tuple is supplied.
        """
        return None if reduce_axes is None else list(reduce_axes)

    @classmethod
    def create(
        cls,
        original_name,
        input_vertex_order,
        reduce_axes,
        height_groups=1,
        groups=1,
        interleaved_groups=False,
        output_shapes=None,
    ):
        """Build a reduce-sum layer from explicit parameters.

        Args:
            original_name: Forwarded to the base-class ``create``.
            input_vertex_order: Forwarded to the base-class ``create``.
            reduce_axes: Axes to sum over; copied into a new list.
            height_groups: Output height used when the height dimension collapses to 1.
            groups: Output feature count used when the features dimension collapses to 1.
            interleaved_groups: Stored as-is on the layer.
            output_shapes: Forwarded to the base-class ``create``.

        Returns:
            The new layer instance.
        """
        layer = super().create(original_name, input_vertex_order, output_shapes)
        layer._reduce_axes = cls._as_axes_list(reduce_axes)
        layer._height_groups = height_groups
        layer.groups = groups
        layer.interleaved_groups = interleaved_groups
        return layer

    @property
    def interleaved_groups(self):
        """Flag stored alongside ``groups``.

        It is not serialized by ``to_hn``/``to_pb`` in this class.
        """
        return self._interleaved_groups

    @interleaved_groups.setter
    def interleaved_groups(self, interleaved_groups):
        self._interleaved_groups = interleaved_groups

    @property
    def macs(self):
        """Multiply-accumulate count: half of ``ops``, since a reduce sum never multiplies."""
        return self.ops / 2

    @property
    def ops(self):
        """Operation count: one addition per input element, i.e. the input element count.

        ``abs`` guards against a negative placeholder in ``input_shape``. This is presumably an
        unknown batch size of -1 — TODO confirm.
        """
        return float(np.abs(np.prod(np.array(self.input_shape))))

    def get_width_kernel(self):
        """Return the full input width if the width axis is reduced, otherwise 1."""
        width_dim = len(self.input_shape) - 2
        return self.input_shape[2] if width_dim in self.reduce_axes else 1

    def get_height_kernel(self):
        """Return ``input height // height_groups`` if the height axis is reduced, otherwise 1.

        Rank-2 reductions always yield 1, because their axis 1 is the features axis.
        """
        height_dim = 1
        if height_dim not in self.reduce_axes or self.is_reduce_rank2():
            return 1
        # int() keeps a plain Python int even if input_shape holds numpy integers.
        return int(self.input_shape[height_dim] // self._height_groups)

    def is_reduce_height_or_width(self):
        """True when reducing only spatial axes.

        That means ``reduce_axes`` is non-empty, excludes features (3), and this is not a
        rank-2 reduction.
        """
        features_dim = 3
        return not self.is_reduce_rank2() and features_dim not in self._reduce_axes and len(self._reduce_axes) > 0

    def is_reduce_rank2(self):
        """True for a rank-2 input reduced over exactly axis 1."""
        height_dim = 1
        return list(self._reduce_axes) == [height_dim] and len(self.input_shape) == 2

    @property
    def strides(self):
        """Strides equal to the spatial kernel, so each kernel window is summed exactly once."""
        return [1, self.get_height_kernel(), self.get_width_kernel(), 1]

    @property
    def kernel_shape(self):
        """Equivalent kernel as ``[height, width, input_features, output_features]``.

        When the features (last) axis is reduced, each of the ``groups`` outputs sums
        ``features // groups`` inputs. Otherwise the features are passed through one-to-one.
        """
        features_dim = len(self.input_shape) - 1
        if features_dim in self.reduce_axes:
            kernel_input_features = int(self.input_shape[features_dim] // self.groups)
            kernel_output_features = self.groups
        else:
            kernel_input_features = 1
            kernel_output_features = self.input_shape[features_dim]
        return [self.get_height_kernel(), self.get_width_kernel(), kernel_input_features, kernel_output_features]

    @property
    def padding(self):
        """Padding is always SAME for this layer."""
        return PaddingType.same

    @property
    def groups(self):
        """Number of output features produced when the features axis is reduced."""
        return self._groups

    @groups.setter
    def groups(self, groups):
        self._groups = groups

    @property
    def height_groups(self):
        """Number of output rows produced when the height axis is reduced."""
        return self._height_groups

    @height_groups.setter
    def height_groups(self, height_groups):
        self._height_groups = height_groups

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize to a protobuf node.

        Serialized fields: reduce axes, groups, height groups, activation and spatial strides.
        ``interleaved_groups`` is not written here.
        """
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_REDUCE_SUM
        node.reduce_axes.extend(self.reduce_axes)
        node.groups = self._groups
        node.height_groups = self.height_groups
        node.activation = pb_wrapper.ACTIVATION_TYPE_TO_PB[self._activation]
        _, node.strides.height, node.strides.width, _ = self.strides
        return node

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Deserialize from a protobuf node (inverse of ``to_pb``)."""
        layer = super().from_pb(pb, pb_wrapper)
        # Copy the repeated field into a plain list so the layer does not hold protobuf internals.
        layer._reduce_axes = cls._as_axes_list(pb.reduce_axes)
        layer._groups = pb.groups
        layer._height_groups = pb.height_groups
        layer._activation = pb_wrapper.ACTIVATION_PB_TO_TYPE[pb.activation]
        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """Create a new layer copying the reduce-sum parameters of ``old_layer``."""
        layer = super().from_layer(old_layer)
        layer._reduce_axes = cls._as_axes_list(old_layer.reduce_axes)
        layer._groups = old_layer.groups
        layer._height_groups = old_layer.height_groups
        # Keep the grouping layout flag together with ``groups``; default when absent.
        layer._interleaved_groups = getattr(old_layer, "interleaved_groups", False)
        layer._activation = old_layer.activation
        return layer

    def to_hn(self, should_get_default_params=False):
        """Return the HN dict for this layer.

        Adds ``reduce_axes``, ``groups``, ``height_groups`` and ``activation`` under ``params``.
        """
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        result["params"]["reduce_axes"] = list(self.reduce_axes)
        result["params"]["groups"] = self._groups
        result["params"]["height_groups"] = self._height_groups
        result["params"]["activation"] = self._activation.value
        return result

    @classmethod
    def from_hn(cls, hn):
        """Build a layer from an HN dict, validating the grouping parameters.

        The input ``hn`` dict is not modified.

        Raises:
            UnsupportedModelError: If any of the following holds:
                - input features are not divisible by ``groups``;
                - input height is not divisible by ``height_groups``;
                - ``height_groups > 1`` is used with axes other than a pure spatial reduction.
        """
        layer = super().from_hn(hn)
        if "params" not in hn:
            return layer
        params = hn["params"]
        if "reduce_axes" in params:
            # The setter copies, so the fix-up below never writes into the caller's HN dict.
            layer.reduce_axes = params["reduce_axes"]
            # Backward compatibility for rank-2 reduce sum: an input whose height and width are
            # both 1 and that reduces axis 1 is remapped to reduce the features axis (3).
            if list(layer.input_shape[1:-1]) == [1, 1] and layer.reduce_axes == [1]:
                layer.reduce_axes[0] = 3
        if "groups" in params:
            layer.groups = params["groups"]
            if layer.input_shape[-1] % layer.groups != 0:
                raise UnsupportedModelError(
                    f"input features must be a multiple of groups for {layer.full_name_msg}",
                )
        if "height_groups" in params:
            layer.height_groups = params["height_groups"]
            if layer.input_shape[1] % layer.height_groups != 0:
                raise UnsupportedModelError(
                    f"input height must be a multiple of height_groups for {layer.full_name_msg}",
                )
            if layer.height_groups > 1 and not layer.is_reduce_height_or_width():
                raise UnsupportedModelError(
                    f"height_groups is not supported with given reduce axes for {layer.full_name_msg}",
                )
        if "activation" in params:
            layer.activation = ActivationTypes[params["activation"]]
        return layer

    def move_params(self, layer):
        """Take base-class params from ``layer``, plus a copy of its ``reduce_axes`` if it has one."""
        super().move_params(layer)
        if hasattr(layer, "reduce_axes"):
            self.reduce_axes = layer.reduce_axes

    def _calc_output_shape(self):
        """Compute the rank-4 output shape.

        Steps:
        - each reduced axis collapses to 1;
        - a collapsed height is widened to ``height_groups``;
        - collapsed features are widened to ``groups``.

        Rank-2 inputs map to ``[batch, 1, 1, groups]``. A non-zero ``defuse_input_width``
        overrides the input width.
        """
        if len(self.input_shape) == 2:
            return [self.input_shape[0], 1, 1, self._groups]
        if "defuse_input_width" in self.defuse_params and self.defuse_input_width != 0:
            width = self.defuse_input_width
        else:
            width = self.input_shape[2]
        output_shape = [self.input_shape[0], self.input_shape[1], width, self.input_shape[3]]
        for axis in self.reduce_axes:
            output_shape[axis] = 1
        if output_shape[1] == 1:
            output_shape[1] = self._height_groups
        if output_shape[-1] == 1:
            output_shape[-1] = self._groups
        return output_shape

    def set_input_shapes(self, input_shapes, validate=True):
        """Set input shapes; a rank-2 input forces the reduction onto axis 1."""
        super().set_input_shapes(input_shapes, validate)
        if len(self.input_shape) == 2:
            self._reduce_axes = [1]

    @property
    def input_width(self):
        """Input width, taken from ``defuse_input_width`` when spatially defused along width."""
        if self.defuse_type == DefuseType.spatial_w:
            return self.defuse_input_width
        return super().input_width

    @property
    def reduce_axes(self):
        """Axes (indices into ``input_shape``) summed over."""
        return self._reduce_axes

    @reduce_axes.setter
    def reduce_axes(self, reduce_axes):
        # Store a private copy so later in-place edits never leak back to the caller's container.
        self._reduce_axes = self._as_axes_list(reduce_axes)

    def get_equalization_handler_type(self, predecessor=None):
        """Equalization is not supported for this layer."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Params sorting is not supported for this layer."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Dead-channel removal is not supported for this layer."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def ibc_supported(self):
        """IBC is not supported for this layer."""
        return LayerSupportStatus.unsupported