import copy

import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationTypes, DefuseType, LayerType, PaddingType
from hailo_sdk_common.hailo_nn.hn_layers.conv2d import Conv2DLayer
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_activation import LayerWithActivation
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class FusedConv2DLayer(Conv2DLayer, LayerWithActivation):
    """Convolution-family layer (conv / depthwise / deconv) with fused extras.

    On top of what ``Conv2DLayer`` and ``LayerWithActivation`` provide, the layer tracks:

    * ``bn_enabled`` - whether a batch norm is marked as fused into the layer,
    * ``bn_info`` - batch-norm info copied from the layer it was fused from,
    * ``pre_layer_bn`` - whether a normalization is folded *forward* into the layer,
    * ``ew_add_connections`` - the layers feeding the layer through an elementwise-add input,
    * ``ew_add_factor`` - the value serialized as ``elementwise_add_factor``.
    """

    _REQUIRES_NATIVE_WEIGHTS = True
    _REQUIRES_QUANTIZED_WEIGHTS = True

    def __init__(self):
        super().__init__()
        self._bn_info = None
        self._pre_layer_bn = False
        self._bn_enabled = False
        # Layers connected to this one through an elementwise-add input.
        self._ew_connections = []
        self._op = LayerType.conv
        self._ew_add_factor = 1

    def move_params(self, old_layer):
        """Take over the parameters of ``old_layer`` while it is being fused into this layer.

        After the base-class handling, the following op-specific steps are applied:

        * a 16x16 ``base_conv`` moved into a 1x1 layer triggers ``reshape_kernel_conv_octxoct``,
        * ``base_batch_norm`` / ``deconv`` layers hand over their ``bn_info``,
        * ``normalization`` layers are folded into kernel and bias
          (see ``move_normalization_params``),
        * an elementwise add/sub layer that shares this layer's index is turned into a
          1x1 identity convolution (negated identity for subtraction) with a zero bias.

        Args:
            old_layer: The layer whose parameters are moved.

        Raises:
            UnsupportedModelError: Propagated from ``move_normalization_params``.
        """
        super().move_params(old_layer)
        if (
            old_layer.op == LayerType.base_conv
            and self.kernel_height == self.kernel_width == 1
            and old_layer.kernel_height == old_layer.kernel_width == 16
        ):
            self.reshape_kernel_conv_octxoct()
        if old_layer.op in [LayerType.base_batch_norm, LayerType.deconv]:
            self._bn_info = old_layer.bn_info
        if old_layer.op == LayerType.normalization:
            self.move_normalization_params(old_layer)
        if (
            old_layer.op in [LayerType.base_ew_add, LayerType.base_ew_sub, LayerType.ew_add, LayerType.ew_sub]
            and self.index == old_layer.index
        ):
            # Stand-alone elementwise add/sub: express it as a 1x1 convolution whose kernel is
            # the identity matrix over the features, so the conv path passes its input through.
            self._strides = [1, 1, 1, 1]
            self._dilations = [1, 1, 1, 1]
            self._kernel_shape = [1, 1, self.output_features, self.output_features]
            self._padding = PaddingType.valid
            self._bias = np.zeros(self.output_features, dtype=np.float32)
            self._kernel = np.reshape(np.identity(self.output_features, dtype=np.float32), self._kernel_shape)
            if old_layer.op in [LayerType.base_ew_sub, LayerType.ew_sub]:
                # Subtraction: the operand going through the conv path is negated.
                self._kernel = -1.0 * self._kernel
            self._groups = 1

    def move_normalization_params(self, old_layer):
        """Fold the ``mean`` / ``std`` of a normalization layer into this layer's kernel and bias.

        With ``pre_layer_bn`` set, the normalization is folded forward (it precedes the
        convolution): the kernel is divided by ``std`` along its input-feature axis and the
        bias is reduced by ``mean`` multiplied by the rescaled kernel. Otherwise the
        normalization follows the convolution: the bias becomes ``(bias - mean) / std`` and
        the kernel is divided by ``std``.

        The arithmetic runs in extended precision and the results are stored as ``float32``.

        Args:
            old_layer: Normalization layer exposing ``std`` and ``mean``.

        Raises:
            UnsupportedModelError: If forward folding is requested and the kernel is not 1x1.
        """
        # np.longdouble is the portable spelling of the platform's extended-precision float.
        # np.float128 only exists where the C long double is 128 bits wide, so referencing it
        # raises AttributeError elsewhere (e.g. Windows); where it does exist it is the same
        # type as np.longdouble.
        std = np.array(old_layer.std, dtype=np.longdouble)
        mean = np.array(old_layer.mean, dtype=np.longdouble)
        # Avoid division by zero for features with a zero std.
        std[std == 0] += 1e-7
        new_kernel = np.copy(self.kernel).astype(np.longdouble)
        new_bias = np.copy(self.bias).astype(np.longdouble)

        # Forward folding normalization is only supported for 1x1 kernels, due to the high time consumption
        # of the inefficient calculation of the new bias, and problem with supporting paddings
        if self.pre_layer_bn:
            if not (self.kernel_height == self.kernel_width == 1):
                raise UnsupportedModelError(
                    f"Cannot forward fold {old_layer.full_name_msg} to {self.full_name_msg} that has kernel != 1x1.",
                )
            # The normalization's original names come first, followed by this layer's own
            # names that are not already listed.
            self.original_names = old_layer.original_names + [
                x for x in self.original_names if x not in old_layer.original_names
            ]
            fin, fout = np.shape(new_kernel)[-2:]
            new_kernel /= std.reshape([1, 1, fin, 1])
            # The kernel is 1x1, so it can be viewed as a (fin, fout) matrix.
            new_mean = np.matmul(mean, new_kernel.reshape([fin, fout])).reshape([fout])
            new_bias = new_bias - new_mean
        else:
            new_bias = (new_bias - mean) / std
            if self.op == LayerType.dw:
                # Align std with the third kernel axis for depthwise kernels.
                std = std.reshape([1, 1, std.shape[0], 1])
            new_kernel /= std

        self.kernel = new_kernel.astype(np.float32)
        self.bias = new_bias.astype(np.float32)

    @property
    def bn_enabled(self):
        """Whether a batch norm is marked as fused into this layer."""
        return self._bn_enabled

    @bn_enabled.setter
    def bn_enabled(self, bn_enabled):
        self._bn_enabled = bn_enabled

    @property
    def ew_add_enabled(self):
        """True when at least one elementwise-add connection is registered."""
        return len(self._ew_connections) > 0

    @property
    def ew_add_connections(self):
        """The registered elementwise-add layers (the internal list itself, not a copy)."""
        return self._ew_connections

    @property
    def input_shape(self):
        """Input shape of the layer; the first input shape for ew-add / dynamic-weights layers."""
        if self.ew_add_enabled or self.dynamic_weights:
            # In case of conv with ew add, the input shapes can be different,
            # and we will fail the default input_shape property
            return self._input_shapes[0]

        return super().input_shape

    @input_shape.setter
    def input_shape(self, input_shape):
        # Overriding the getter above creates a new property object on this class, so the
        # setter has to be re-declared here; it simply delegates to the parent's setter.
        super(FusedConv2DLayer, self.__class__).input_shape.fset(self, input_shape)

    def add_ew_connection(self, other_layer):
        """Register ``other_layer`` as an elementwise-add input of this layer.

        Raises:
            UnsupportedModelError: If this layer's op is neither ``conv`` nor ``dw``.
        """
        if self.op not in [LayerType.conv, LayerType.dw]:
            raise UnsupportedModelError(
                f"Cannot add elementwise add connections for {self.full_name_msg}. "
                f"It is only supported for conv layers, but the layer is of type {self.op}",
            )
        self._ew_connections.append(other_layer)

    def clear_ew_connections(self):
        """Drop all registered elementwise-add connections."""
        self._ew_connections = []

    def _is_ew_connection(self, other_layer):
        """Return True if ``other_layer`` is one of the registered elementwise-add layers."""
        return other_layer in self._ew_connections

    def sort_inputs(self):
        """Return a cmp-style function that orders elementwise-add inputs after regular inputs.

        The returned callable takes two layers and returns ``1``, ``-1`` or ``0``; two layers
        of the same kind compare as equal.
        """

        def sort_function(layer1, layer2):
            ew1 = self._is_ew_connection(layer1)
            ew2 = self._is_ew_connection(layer2)
            if ew1 and (not ew2):
                return 1
            if (not ew1) and ew2:
                return -1
            return 0

        return sort_function

    def to_hn(self, should_get_default_params=False):
        """Serialize the layer to an HN dictionary, adding the fused-layer parameters.

        The base-class result is deep-copied before it is extended.
        ``elementwise_add_factor`` is only written when elementwise add is enabled, and
        ``pre_layer_batch_norm`` only when it is set.
        """
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        result["params"]["batch_norm"] = self.bn_enabled
        result["params"]["elementwise_add"] = self.ew_add_enabled
        if self.ew_add_enabled:
            result["params"]["elementwise_add_factor"] = self.ew_add_factor
        result["params"]["activation"] = self._activation.value
        if self.pre_layer_bn:
            result["params"]["pre_layer_batch_norm"] = True
        return result

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize the layer to a protobuf node and fill in the fused-layer fields.

        The node type is set for conv / dw / deconv ops; for any other op it is left as
        produced by the base class.
        """
        node = super().to_pb(pb_wrapper, is_multi_scope)
        if self._op == LayerType.conv:
            node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_CONV
        elif self._op == LayerType.dw:
            node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_DW
        elif self._op == LayerType.deconv:
            node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_DECONV

        node.batch_norm = self.bn_enabled
        node.ew_add = self.ew_add_enabled
        node.activation = pb_wrapper.ACTIVATION_TYPE_TO_PB[self._activation]
        node.pre_layer_batch_norm = self.pre_layer_bn
        return node

    @property
    def input_width(self):
        """Input width; the defused input width when the layer is defused along the width."""
        if self.defuse_type == DefuseType.spatial_w:
            return self.defuse_input_width
        return super().input_width

    @property
    def bn_info(self):
        """Batch-norm info copied from the fused layer, or ``None`` if none was moved."""
        return self._bn_info

    @property
    def pre_layer_bn(self):
        """Whether a normalization is folded forward into this layer."""
        return self._pre_layer_bn

    @pre_layer_bn.setter
    def pre_layer_bn(self, pre_layer_bn):
        self._pre_layer_bn = pre_layer_bn

    @property
    def ew_add_factor(self):
        """Value serialized as ``elementwise_add_factor`` (initialized to 1)."""
        return self._ew_add_factor

    @ew_add_factor.setter
    def ew_add_factor(self, ew_add_factor):
        self._ew_add_factor = ew_add_factor

    @classmethod
    def from_hn(cls, hn):
        """Build a layer from its HN dictionary.

        ``params.batch_norm`` and ``params.activation`` are required keys;
        ``pre_layer_batch_norm``, ``elementwise_add`` and ``elementwise_add_factor`` are
        optional. The op is set only when ``hn["type"]`` is conv, dw or deconv.
        """
        layer = super().from_hn(hn)
        if hn["params"]["batch_norm"]:
            layer.bn_enabled = True
        if "pre_layer_batch_norm" in hn["params"]:
            layer.pre_layer_bn = hn["params"]["pre_layer_batch_norm"]
        if hn["params"].get("elementwise_add"):
            layer.ew_add_factor = hn["params"].get("elementwise_add_factor", 1)

        if hn["type"] == LayerType.conv.value:
            layer.op = LayerType.conv
        elif hn["type"] == LayerType.dw.value:
            layer.op = LayerType.dw
        elif hn["type"] == LayerType.deconv.value:
            layer.op = LayerType.deconv
        layer.activation = ActivationTypes[hn["params"]["activation"]]

        return layer

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Build a layer from a protobuf node.

        The op is set only when the node type is conv, dw or deconv.
        """
        layer = super().from_pb(pb, pb_wrapper)
        if pb.type == pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_CONV:
            layer._op = LayerType.conv
        elif pb.type == pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_DW:
            layer._op = LayerType.dw
        elif pb.type == pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_DECONV:
            layer._op = LayerType.deconv

        # Edge case: when the fused layer is a stand-alone ew-add node, the batch-norm and
        # activation fields of the node are not read.
        if pb.type != pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_EW_ADD:
            layer.bn_enabled = pb.batch_norm
            layer.pre_layer_bn = pb.pre_layer_batch_norm
            layer._activation = pb_wrapper.ACTIVATION_PB_TO_TYPE[pb.activation]

        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """Build a fused layer out of ``old_layer``.

        ``base_dw`` / ``base_deconv`` / ``base_conv`` ops are mapped to their fused
        counterparts; for any other op the activation of ``old_layer`` is copied instead.
        """
        layer = super().from_layer(old_layer)
        layer.bn_enabled = old_layer.bn_enabled

        if old_layer.op == LayerType.base_dw:
            layer.op = LayerType.dw
        elif old_layer.op == LayerType.base_deconv:
            layer.op = LayerType.deconv
        elif old_layer.op == LayerType.base_conv:
            layer.op = LayerType.conv
        else:
            # Not one of the base convolution ops: carry the activation over.
            layer.activation = old_layer.activation

        return layer

    @property
    def macs(self):
        """Number of multiply-accumulate operations of the layer, counting one for the bias.

        Raises:
            UnsupportedModelError: If the op is not conv, deconv or dw.
        """
        number_of_elements = self.output_features * self.output_height * self.output_width
        if self.op in [LayerType.conv, LayerType.deconv]:
            # Adding plus for the bias
            macs_per_element = (
                self.input_features
                / self.groups
                / self.input_disparity
                * self.kernel_height
                * self.kernel_width
                * self.kernel_disparity
                + 1
            )
        elif self.op is LayerType.dw:
            # Adding plus for the bias
            macs_per_element = self.kernel_height * self.kernel_width * self.kernel_disparity + 1
        else:
            raise UnsupportedModelError(f"Cannot get macs for layer {self.name}. Invalid op: {self.op}")
        return macs_per_element * number_of_elements

    @property
    def ops(self):
        """
        Return the number of multiplications and additions.
        Multiplications are input_features / input_disparity / groups * kernel_height * kernel_width * kernel_disparity
        Additions are (input_features - 1) / input_disparity / groups * kernel_height * kernel_width * kernel_disparity
        and bias addition is kernel_height * kernel_width * kernel_disparity / groups / input_disparity.

        Raises:
            UnsupportedModelError: If the op is not conv, deconv or dw.
        """
        number_of_elements = self.output_features * self.output_height * self.output_width
        if self.op in [LayerType.conv, LayerType.deconv]:
            ops_per_element = (
                self.input_features
                / self.input_disparity
                / self.groups
                * self.kernel_disparity
                * self.kernel_height
                * self.kernel_width
                * 2
            )
        elif self.op is LayerType.dw:
            ops_per_element = self.kernel_height * self.kernel_width * self.kernel_disparity * 2
        else:
            raise UnsupportedModelError(f"Cannot get ops for layer {self.name}. Invalid op: {self.op}")
        return ops_per_element * number_of_elements

    @property
    def weights(self):
        """Number of weights of the layer: kernel elements plus one bias term per output feature.

        Both terms are scaled down by ``layer_disparity``.

        Raises:
            UnsupportedModelError: If the op is not conv, deconv or dw.
        """
        common_weights = (
            (self.output_features / self.layer_disparity)
            * self.kernel_height
            * self.kernel_width
            * self.kernel_disparity
        )
        if self.op in [LayerType.conv, LayerType.deconv]:
            return (
                common_weights * self.input_features / self.groups / self.input_disparity
                + self.output_features / self.layer_disparity
            )
        elif self.op is LayerType.dw:
            return common_weights + self.output_features / self.layer_disparity
        else:
            raise UnsupportedModelError(f"Cannot get weights for layer {self.name}. Invalid op: {self.op}")

    def get_equalization_handler_type(self, predecessor=None):
        """Classify this layer for equalization, as seen from ``predecessor``.

        * With elementwise add enabled the layer is never a source; it is an ``ew_bouncer``
          for a predecessor that is an ew-add connection and a ``consumer`` otherwise.
        * With dynamic weights or transposed output width/features it is ``unsupported``.
        * Otherwise it is a ``consumer`` that is also a source.
        """
        if self.ew_add_enabled:
            is_source = False
            if (predecessor is not None) and predecessor in self.ew_add_connections:
                handler_type = LayerHandlerType.ew_bouncer
            else:
                handler_type = LayerHandlerType.consumer
        elif self.dynamic_weights or self._transpose_output_width_features:
            is_source = False
            handler_type = LayerHandlerType.unsupported
        else:
            is_source = True
            handler_type = LayerHandlerType.consumer
        return EquivClassification(handler_type, is_source=is_source)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Classify this layer for the params sorter, as seen from ``predecessor``.

        Grouped layers are ``unsupported``. A predecessor that is an ew-add connection makes
        the layer an ``ew_bouncer``. Otherwise the classification depends on the op; only a
        conv without elementwise add is a source.
        """
        if self.groups > 1:
            return EquivClassification(LayerHandlerType.unsupported, is_source=False)
        if (predecessor is not None) and self.ew_add_enabled and predecessor in self.ew_add_connections:
            return EquivClassification(LayerHandlerType.ew_bouncer, is_source=False)
        handler = {
            LayerType.dw: LayerHandlerType.featurewise if not self.dynamic_weights else LayerHandlerType.unsupported,
            LayerType.conv: LayerHandlerType.consumer,
            LayerType.deconv: LayerHandlerType.consumer,
        }
        is_source = {
            LayerType.dw: False,
            LayerType.conv: not self.ew_add_enabled,
            LayerType.deconv: False,
        }
        return EquivClassification(handler[self.op], is_source=is_source[self.op])

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Classify this layer for dead-channels removal, as seen from ``predecessor``.

        A predecessor that is an ew-add connection makes the layer an ``ew_bouncer`` (checked
        before anything else). Grouped layers and layers with transposed or spatially
        flattened output are ``unsupported``. Otherwise the classification depends on the op;
        conv and deconv are sources unless elementwise add is enabled.
        """
        if (predecessor is not None) and self.ew_add_enabled and predecessor in self.ew_add_connections:
            return EquivClassification(LayerHandlerType.ew_bouncer, is_source=False)
        if self.groups > 1 or self.transpose_output_width_features or self.spatial_flatten_output:
            return EquivClassification(LayerHandlerType.unsupported, is_source=False)
        handler = {
            LayerType.dw: LayerHandlerType.featurewise,
            LayerType.conv: LayerHandlerType.consumer,
            LayerType.deconv: LayerHandlerType.consumer,
        }
        is_source = {
            LayerType.dw: False,
            LayerType.conv: not self.ew_add_enabled,
            LayerType.deconv: not self.ew_add_enabled,
        }
        return EquivClassification(handler[self.op], is_source=is_source[self.op])

    def ibc_supported(self):
        """Return the IBC support status: unsupported only for depthwise with dynamic weights."""
        if (self.op == LayerType.dw) and self.dynamic_weights:
            return LayerSupportStatus.unsupported
        return LayerSupportStatus.supported

    def is_zippable(self, other):
        """Return whether this layer can be zipped with ``other``.

        Both layers must agree on whether elementwise add is enabled and must share the same
        activation; the remaining checks are delegated to the base class. The ew-add
        connections themselves are not compared.
        """
        if self.ew_add_enabled != other.ew_add_enabled:
            return False
        if self.activation != other.activation:
            return False
        return super().is_zippable(other)