import copy

import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    LayerEquivType,
    LayerHandlerType,
    LayerSupportStatus,
)
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import (
    DefuseType,
    LayerType,
    PaddingType,
    PaddingTypes,
    SubclustersNoContextsPolicy,
    TemporaryPaddingType,
)
from hailo_sdk_common.hailo_nn.hn_layers.inner_layer import InnerLayer
from hailo_sdk_common.hailo_nn.hn_layers.layer_common import input_to_output_height_width
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class Conv2DLayer(InnerLayer):
    """Inner layer describing 2D convolution-style operations.

    The methods of this class handle the `base_conv`, `base_dw`, `dw` and `base_deconv`
    layer types (see `from_hn` / `to_pb` / `from_pb`). An input "disparity" larger than 1
    switches the shape computations to the conv3d variant (see `layer_disparity`).
    """

    # Class-level flags.
    # NOTE(review): presumably consumed by the base class / surrounding framework - confirm.
    _REQUIRES_NATIVE_WEIGHTS = True
    _REQUIRES_QUANTIZED_WEIGHTS = True
    _IS_REAL_LAYER = True

    def __init__(self):
        """Initialize the layer with default convolution attributes.

        The op defaults to `LayerType.base_conv`, groups / input disparity / next-layer
        groups default to 1, and padding, strides and dilations start unset (`None`).
        """
        super().__init__()
        # Operation type and supported number of inputs.
        self._op = LayerType.base_conv
        self._number_of_inputs_supported = 2

        # Spatial configuration.
        self._strides = None
        self._dilations = None
        self._is_dilated_s2b = False

        # Padding configuration.
        self._padding = None
        self._padding_const_value = 0
        self._external_padding_value = None

        # Grouping.
        self._groups = 1
        self._next_layer_groups_num = 1

        # Disparity (see `layer_disparity`).
        self._layer_disparity = None
        self._input_disparity = 1

    @classmethod
    def create(
        cls,
        original_name,
        input_vertex_order,
        op,
        kernel,
        bias,
        padding,
        padding_vals,
        strides,
        dilations,
        groups=1,
        output_shapes=None,
        dynamic_kernel_shape=None,
        transpose_output_width_features=False,
        input_disparity=1,
        next_layer_groups_num=1,
    ):
        """Build a layer from explicit convolution parameters.

        When `kernel` is given, the stored kernel shape is derived from `kernel.shape`
        with its third dimension multiplied by `groups`; otherwise `dynamic_kernel_shape`
        is used as the kernel shape. `padding_vals` is stored as the external padding
        value only for the `external_undecided` and `conv3d` temporary padding types.
        The layer is flagged as having dynamic weights when `dynamic_kernel_shape`
        is not None.

        Returns:
            The newly created layer.
        """
        layer = super().create(original_name, input_vertex_order, output_shapes)
        layer.op = op
        layer.kernel = kernel
        if kernel is None:
            layer.kernel_shape = dynamic_kernel_shape
        else:
            kernel_dims = kernel.shape
            layer.kernel_shape = [kernel_dims[0], kernel_dims[1], kernel_dims[2] * groups, kernel_dims[3]]
        layer.bias = bias
        layer.padding = padding
        external_padding_types = (TemporaryPaddingType.external_undecided, TemporaryPaddingType.conv3d)
        if layer.padding in external_padding_types:
            layer.external_padding_value = padding_vals
        layer.strides = strides
        layer.dilations = dilations
        layer.groups = groups
        has_dynamic_weights = dynamic_kernel_shape is not None
        layer.dynamic_weights = has_dynamic_weights
        layer.transpose_output_width_features = transpose_output_width_features
        layer.input_disparity = input_disparity
        layer.next_layer_groups_num = next_layer_groups_num
        return layer

    @property
    def kernel_short_description(self):
        description = (
            f" ({self.kernel_shape[0]}x{self.kernel_shape[1]}/{self.strides[1]}) "
            f"({self.input_shapes[0][-1]}->{self.output_shapes[0][-1]})"
        )
        if self._dilations and len(self._dilations) > 1 and self._dilations[1] > 1:
            description += f" (dilation={self._dilations[1]})"
        return description

    def to_hn(self, should_get_default_params=False):
        """Return a deep copy of the base-class HN dict extended with the convolution params.

        Adds strides, dilations, padding, groups and the layer / input disparities under
        `"params"`. `group_sizes` is added only when at least one group size differs from 1.
        """
        hn = copy.deepcopy(super().to_hn(should_get_default_params))
        params = hn["params"]
        params.update(
            {
                "strides": [int(stride) for stride in self._strides],
                "dilations": [int(dilation) for dilation in self._dilations],
                "padding": self._padding.value,
                "groups": self._groups,
                "layer_disparity": self.layer_disparity,
                "input_disparity": self.input_disparity,
            },
        )
        group_sizes = self._group_sizes
        if group_sizes and any(size != 1 for size in group_sizes):
            params["group_sizes"] = group_sizes
        return hn

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Fill the protobuf node created by the base class with the convolution fields.

        The node type is set only for the `base_conv`, `base_dw` and `base_deconv` ops;
        for any other op the type assigned by the base class is left untouched.

        Returns:
            The populated protobuf node.
        """
        node = super().to_pb(pb_wrapper, is_multi_scope)

        proto_type_names = (
            (LayerType.base_conv, "PROTO_NETWORK_BASE_CONV"),
            (LayerType.base_dw, "PROTO_NETWORK_BASE_DW"),
            (LayerType.base_deconv, "PROTO_NETWORK_BASE_DECONV"),
        )
        for layer_type, proto_name in proto_type_names:
            if self._op == layer_type:
                node.type = getattr(pb_wrapper.integrated_hw_graph_base_pb2, proto_name)
                break

        # Kernel shape index 2 and the first stride / first and last dilation entries are not serialized.
        kernel_h, kernel_w, _, kernel_f = self.kernel_shape
        node.kernel_shape.height = kernel_h
        node.kernel_shape.width = kernel_w
        node.kernel_shape.features = kernel_f

        _, stride_h, stride_w, stride_f = self.strides
        node.strides.height = stride_h
        node.strides.width = stride_w
        node.strides.features = stride_f

        _, dilation_h, dilation_w, _ = self.dilations
        node.dilations.height = dilation_h
        node.dilations.width = dilation_w

        node.padding = pb_wrapper.PADDING_TYPE_TO_PB[self.padding]
        node.groups = self.groups
        group_sizes = self.group_sizes
        if group_sizes:
            node.group_sizes.extend(group_sizes)
        node.input_disparity = self.input_disparity
        node.layer_disparity = self.layer_disparity
        node.kernel_disparity = self.kernel_disparity

        return node

    def _calc_deconv_output_shape(self, input_shape):
        output_h = input_shape[1] * self._strides[1]
        output_w = input_shape[2] * self._strides[2]
        output_f = self._kernel_shape[3]
        return [-1, output_h, output_w, output_f]

    def _calc_conv_ew_add_output_shape(self):
        """Compute the output shape of a conv fused with an element-wise add and validate both inputs agree.

        The conv output shape is computed from the first input (depthwise or regular
        calculation, according to the op) and compared with the shape of the second
        input (the add operand). For defused layers the comparison is made against the
        shape expected before the defuse.

        Returns:
            The conv output shape, or - for super conv/dw defuse with a non-zero
            `defuse_ew_add_input_width` - the add shape with its width replaced by that value.

        Raises:
            UnsupportedModelError: when the add shape differs from the expected conv output shape.
        """
        is_depthwise = self._op in [LayerType.dw, LayerType.base_dw]
        if is_depthwise:
            conv_shape = self._calc_dw_conv_output_shape(self._input_shapes[0])
        else:
            conv_shape = self._calc_conv_output_shape(self.input_shapes[0])

        add_shape = self.input_shapes[1]

        # Non-defused layers: the two shapes must simply be equal.
        if not self.is_defused():
            if add_shape != conv_shape:
                raise UnsupportedModelError(
                    f"{self.full_name_msg} with element-wise addition requires the output_shape of "
                    f"conv and of the add to be equal\nadd_output_shape='{add_shape}', "
                    f"conv_output_shape={conv_shape} ",
                )
            return conv_shape

        params = self.defuse_params
        defuse_types = (params.get("defuse_types") if params else None) or []

        if DefuseType.super_conv in defuse_types or DefuseType.super_dw in defuse_types:
            if self.defuse_ew_add_input_width != 0:
                batch, height, _, features = add_shape
                return [batch, height, self.defuse_ew_add_input_width, features]
            original_features = self._defuse_params.get("defuse_original_features")
            if original_features != 0:
                add_shape = add_shape[:-1] + [original_features]

        if self._defuse_params.get("defuse_input_width") != 0:
            # Width-defused: compare against the output shape computed without the spatial defuse override.
            calc_shape = self._calc_dw_conv_output_shape if is_depthwise else self._calc_conv_output_shape
            expected_shape = calc_shape(self.input_shapes[0], get_spatial_defused_output_shape=False)
        else:
            expected_shape = conv_shape[:-1] + [self._defuse_params.get("defuse_original_features")]

        if add_shape != expected_shape:
            raise UnsupportedModelError(
                f"{self.full_name_msg} with element-wise addition requires the "
                f"output_shape of conv and of the add to be equal\n"
                f"add_output_shape='{add_shape}', "
                f"conv_output_shape={expected_shape}",
            )

        return conv_shape

    def _calc_dw_conv_output_shape(self, input_shape, get_spatial_defused_output_shape=True):
        """Compute the output shape of a depthwise convolution.

        Args:
            input_shape: shape of the layer input, passed to `input_to_output_height_width`.
            get_spatial_defused_output_shape: when True and the layer is spatially (width)
                defused with a non-zero `defuse_input_width`, that width is used as the output width.

        Returns:
            `[-1, output_height, output_width, output_features]`.
        """
        out_h, out_w = input_to_output_height_width(
            input_shape,
            self._kernel_shape,
            self._strides,
            self._padding,
            self._dilations,
        )

        params = self.defuse_params
        defuse_types = (params.get("defuse_types") if params else None) or []
        types_without_feature_defuse = (
            DefuseType.none,
            DefuseType.compute_lanes,
            DefuseType.spatial_w,
            DefuseType.input_features,
        )

        if DefuseType.super_dw in defuse_types:
            out_f = self.kernel_shape[3]
        elif "defuse_features" in self.defuse_params and self.defuse_type not in types_without_feature_defuse:
            out_f = self.defuse_features
        else:
            # In depthwise conv layers the kernel shape is [h, w, input_channels, depth_multiplier],
            # so the number of output features is the product of the last two dimensions.
            out_f = self._kernel_shape[2] * self._kernel_shape[3]

        is_width_defused = (
            self.defuse_type is DefuseType.spatial_w
            and "defuse_input_width" in self.defuse_params
            and self.defuse_input_width != 0
        )
        if is_width_defused and get_spatial_defused_output_shape:
            out_w = self.defuse_input_width

        return [-1, out_h, out_w, out_f]

    def _calc_conv_output_shape(self, input_shape, get_spatial_defused_output_shape=True):
        """Compute the output shape of a regular convolution.

        Args:
            input_shape: shape of the layer input; it is deep-copied, never modified.
            get_spatial_defused_output_shape: when True and a non-zero `defuse_input_width`
                is set, the input width is replaced by the defused width before the calculation.

        Returns:
            `[-1, output_height, output_width, output_features]`. The features come from
            the conv3d calculation when the input disparity is larger than 1.
        """
        if self.input_disparity > 1:
            out_f = self._calc_conv3d_output_f()
        else:
            out_f = self._kernel_shape[3]

        effective_shape = copy.deepcopy(input_shape)
        if (
            "defuse_input_width" in self.defuse_params
            and self.defuse_input_width != 0
            and get_spatial_defused_output_shape
        ):
            effective_shape[2] = self.defuse_input_width
            defuse_types = self.defuse_params["defuse_types"]
            if (
                defuse_types is not None
                and DefuseType.double_precision_conv in defuse_types
                and self.stride_width == 2
            ):
                effective_shape[2] *= 2
            # In the case of spatial defuse conv, which is part of a deconv layer, only the last
            # spatial-defused conv should output width, which is input width+1.
            # To adjust for the calculation in input_to_output_height_width for all other spatial
            # defused convs, the input width is decreased by one.
            if self.output_shape[2] == self.defuse_input_width and self.padding == PaddingType.deconv:
                effective_shape[2] -= 1

        out_h, out_w = input_to_output_height_width(
            effective_shape,
            self._kernel_shape,
            self._strides,
            self._padding,
            self._dilations,
        )
        return [-1, out_h, out_w, out_f]

    @property
    def layer_disparity(self):
        if not hasattr(self, "input_disparity"):
            raise UnsupportedModelError(f"{self.full_name_msg}: input disparity is required.")
        if self.input_disparity == 1:
            return self.input_disparity  # conv2d

        # compute conv3d output disparity
        input_f = self.input_features // self.input_disparity
        kernel_d = self._kernel_shape[2] // input_f
        strides_d = self._strides[3] // input_f
        dilation_d = 1
        if self._padding in [PaddingType.valid, TemporaryPaddingType.external_undecided]:
            pad_d = 0
            round_mode = np.floor
        elif self._padding == TemporaryPaddingType.conv3d:
            pad_d = sum(self._external_padding_value[4:])
            round_mode = np.floor
        else:
            pad_d = 2 * (dilation_d * (kernel_d - 1) + 1 - strides_d) // 2
            round_mode = np.ceil

        return int(round_mode((self.input_disparity + pad_d - dilation_d * (kernel_d - 1) - 1) / strides_d + 1))

    def _calc_conv3d_output_f(self):
        return self._kernel_shape[3] * self.layer_disparity

    def _calc_output_shape(self):
        """Compute the layer output shape according to its op and configuration.

        The first input shape (4D, or 2D treated as 1x1 spatial) is extended by the
        external padding values when present, and its width is doubled for a
        double-precision-conv defuse with stride width 2. The shape is then dispatched
        to the deconv / conv+ew-add / depthwise / regular conv calculation, followed by
        the optional width<->features transpose and spatial flatten transformations.

        The defuse types are looked up tolerantly (missing/empty `defuse_params` or a
        missing `"defuse_types"` key is treated as "no defuse types"), consistently with
        the other shape calculations of this class, instead of raising `KeyError`.

        Returns:
            The output shape as a 4-element list.
        """
        # TODO: Split Conv2DLayer into subclasses. Each class will have its own _calc_output_shape (SDK-8648)
        if len(self.input_shapes[0]) == 4:
            _, in_h, in_w, in_f = self.input_shapes[0]
        else:
            _, in_f = self.input_shapes[0]
            in_h, in_w = 1, 1

        if self._external_padding_value:
            top, bottom, left, right, front, back = self._external_padding_value
            in_h = in_h + top + bottom
            in_w = in_w + right + left
            in_f = in_f + front + back

        defuse_params = self.defuse_params
        defuse_types = defuse_params.get("defuse_types") if defuse_params else None
        if (
            defuse_types is not None
            and DefuseType.double_precision_conv in defuse_types
            and self.stride_width == 2
        ):
            in_w = in_w * 2
        input_shape = [-1, in_h, in_w, in_f]

        if self._op in [LayerType.deconv, LayerType.base_deconv]:
            output_shape = self._calc_deconv_output_shape(input_shape)
        elif self.ew_add_enabled:
            output_shape = self._calc_conv_ew_add_output_shape()
        elif self._op in [LayerType.dw, LayerType.base_dw]:
            output_shape = self._calc_dw_conv_output_shape(input_shape)
        else:
            output_shape = self._calc_conv_output_shape(input_shape)

        # handle special cases of output shapes transformations
        if self._transpose_output_width_features:
            # Width and features swap places, scaled by the number of groups of the next layer.
            width = output_shape[3] // self._next_layer_groups_num
            features = output_shape[2] * self._next_layer_groups_num
            output_shape = [output_shape[0], output_shape[1], width, features]
        if self._spatial_flatten_output:
            # Height and width are merged into the width axis.
            output_shape = [output_shape[0], 1, output_shape[1] * output_shape[2], output_shape[3]]

        return output_shape

    @property
    def padding(self):
        return self._padding

    @padding.setter
    def padding(self, padding):
        self._padding = padding

    @property
    def next_layer_groups_num(self):
        return self._next_layer_groups_num

    @next_layer_groups_num.setter
    def next_layer_groups_num(self, next_layer_groups_num):
        self._next_layer_groups_num = next_layer_groups_num

    @property
    def external_padding_value(self):
        return self._external_padding_value

    @external_padding_value.setter
    def external_padding_value(self, external_padding_value):
        self._external_padding_value = external_padding_value

    @property
    def strides(self):
        return self._strides

    @strides.setter
    def strides(self, strides):
        self._strides = strides

    @property
    def stride_height(self):
        return self.strides[1]

    @property
    def stride_width(self):
        return self.strides[2]

    @property
    def is_dilated_s2b(self):
        return self._is_dilated_s2b

    @is_dilated_s2b.setter
    def is_dilated_s2b(self, is_dilated_s2b):
        self._is_dilated_s2b = is_dilated_s2b

    @property
    def dilations(self):
        return self._dilations

    @dilations.setter
    def dilations(self, dilations):
        self._dilations = dilations

    @property
    def kernel_height(self):
        return self.kernel_shape[0]

    @property
    def kernel_width(self):
        return self.kernel_shape[1]

    @property
    def kernel_disparity(self):
        """Kernel size along the disparity axis: kernel input features divided by the features per disparity."""
        features_per_disparity = self.input_features // self.input_disparity
        return self.kernel_shape[2] // features_per_disparity

    @property
    def groups(self):
        return self._groups

    def get_axes_mask(self, type_of_layer=None):
        """Return the 4-element boolean axes mask of this layer's op.

        Args:
            type_of_layer: when equal to `LayerEquivType.producer` the producer mask is
                returned; any other value (including None) selects the consumer mask.

        Returns:
            A new list of four booleans for the `dw`, `conv` or `deconv` op.

        Raises:
            KeyError: when the layer op is not one of `dw`, `conv`, `deconv`.
        """
        if type_of_layer == LayerEquivType.producer:
            axes_by_op = {
                LayerType.dw: [True, True, False, True],
                LayerType.conv: [True, True, True, False],
                LayerType.deconv: [True, True, True, False],
            }
        else:
            # As a consumer, all supported ops share the same mask.
            axes_by_op = {
                op: [True, True, False, True] for op in (LayerType.dw, LayerType.conv, LayerType.deconv)
            }
        return axes_by_op[self.op]

    def get_equalization_handler_type(self, predecessor=None):
        """Return an `unexpected`, non-source classification; `predecessor` is not used."""
        handler_type = LayerHandlerType.unexpected
        return EquivClassification(handler_type, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Return an `unexpected`, non-source classification; `predecessor` is not used."""
        handler_type = LayerHandlerType.unexpected
        return EquivClassification(handler_type, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Return an `unexpected`, non-source classification; `predecessor` is not used."""
        handler_type = LayerHandlerType.unexpected
        return EquivClassification(handler_type, is_source=False)

    def ibc_supported(self):
        """Return the IBC support status of this layer, which is always `LayerSupportStatus.unexpected`."""
        support_status = LayerSupportStatus.unexpected
        return support_status

    @groups.setter
    def groups(self, groups):
        self._groups = groups

    @classmethod
    def from_hn(cls, hn):
        """Build a layer from its HN dictionary representation.

        Sets the op for the `base_conv`, `base_dw`, `dw` and `base_deconv` types, then reads
        dilations (default `[1, 1, 1, 1]`), groups (default 1), group sizes (default all 1),
        padding, strides and the input disparity (falling back to the `"disparity"` param,
        then to 1).

        For a `dw` layer with dynamic weights and no explicit `no_contexts` compilation
        param, the no-contexts policy is enabled.

        Raises:
            UnsupportedModelError: when, for `conv` / `deconv` types, the kernel input or
                output features are not divisible by the sum of the group sizes, or when
                the number of group sizes differs from `groups`.
        """
        layer = super().from_hn(hn)
        if hn["type"] == LayerType.base_conv.value:
            layer.op = LayerType.base_conv
        elif hn["type"] == LayerType.base_dw.value:
            layer.op = LayerType.base_dw
        elif hn["type"] == LayerType.dw.value:
            layer.op = LayerType.dw
            if "dynamic_weights" in hn["params"]:
                layer.dynamic_weights = hn["params"].get("dynamic_weights")
            if layer.dynamic_weights and (
                "compilation_params" not in hn or "no_contexts" not in hn["compilation_params"]
            ):
                # Fixed attribute name (was misspelled `_no_cotexts`, which silently created an unused attribute).
                # NOTE(review): `_no_contexts` is assumed to be the attribute backing the
                # `no_contexts` compilation param in the base class - confirm.
                layer._no_contexts = SubclustersNoContextsPolicy.enabled
        elif hn["type"] == LayerType.base_deconv.value:
            layer.op = LayerType.base_deconv

        params = hn["params"]
        layer.dilations = params.get("dilations", [1, 1, 1, 1])
        layer.groups = params.get("groups", 1)
        layer.group_sizes = params.get("group_sizes", [1] * layer.groups)

        if hn["type"] in [LayerType.conv.value, LayerType.deconv.value]:
            group_sizes_sum = sum(layer.group_sizes)
            if layer.kernel_shape[2] % group_sizes_sum != 0 or layer.kernel_shape[3] % group_sizes_sum != 0:
                raise UnsupportedModelError(
                    f"Input features and output features must be a multiply of groups for {layer.full_name_msg}",
                )

        if layer.groups != len(layer.group_sizes):
            raise UnsupportedModelError(
                f"{layer.full_name_msg} group sizes length {layer.group_sizes} must be equal to groups {layer.groups}",
            )

        layer.padding = PaddingTypes[params["padding"]]
        layer.strides = params["strides"]
        # Older representations may store the value under "disparity".
        layer.input_disparity = params.get("input_disparity", params.get("disparity", 1))

        return layer

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Build a layer from its protobuf representation.

        The op is set for the base conv / dw / deconv protobuf types. Unless the protobuf
        node is a stand-alone ew-add, groups, kernel shape, strides, padding, group sizes
        and dilations are read as well; defuse params override the input features and
        groups when the layer is defused and the corresponding value is positive.
        """
        layer = super().from_pb(pb, pb_wrapper)
        layer.input_disparity = pb.input_disparity
        layer.next_layer_groups_num = pb.next_layer_groups_num

        proto = pb_wrapper.integrated_hw_graph_base_pb2
        op_by_proto_name = (
            ("PROTO_NETWORK_BASE_CONV", LayerType.base_conv),
            ("PROTO_NETWORK_BASE_DW", LayerType.base_dw),
            ("PROTO_NETWORK_BASE_DECONV", LayerType.base_deconv),
        )
        for proto_name, layer_type in op_by_proto_name:
            if pb.type == getattr(proto, proto_name):
                layer.op = layer_type
                break

        # Handling an edge case where fused layer is a stand-alone ew-add layer
        if pb.type == proto.PROTO_NETWORK_EW_ADD:
            return layer

        if layer.is_defused() and layer.defuse_params.get("defuse_input_features") > 0:
            input_features = layer.defuse_params.get("defuse_input_features")
        else:
            input_features = pb.input_shapes[0].features

        if layer.is_defused() and layer.defuse_params.get("defuse_groups") > 0:
            layer.groups = layer.defuse_params.get("defuse_groups")
        else:
            layer.groups = max(1, pb.groups)

        kernel_input_features = input_features * pb.kernel_disparity // pb.input_disparity
        layer.kernel_shape = [
            pb.kernel_shape.height,
            pb.kernel_shape.width,
            kernel_input_features,
            pb.kernel_shape.features,
        ]

        layer.strides = [1, pb.strides.height, pb.strides.width, pb.strides.features]
        layer.padding = pb_wrapper.PADDING_PB_TO_TYPE[pb.padding]
        if pb.group_sizes:
            layer.group_sizes = pb.group_sizes

        # Dilations default to 1 unless both height and width are present in the protobuf.
        has_dilations = pb.dilations.HasField("height") and pb.dilations.HasField("width")
        if has_dilations:
            layer.dilations = [1, pb.dilations.height, pb.dilations.width, 1]
        else:
            layer.dilations = [1, 1, 1, 1]

        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """Build a layer from an existing layer, copying its op and convolution attributes.

        The convolution attributes are not copied when the old layer is an element-wise
        add/sub (regular or base), an rnn or an lstm layer. Kernel shape and strides are
        copied; the other attributes are assigned as-is.
        """
        layer = super().from_layer(old_layer)
        layer.op = old_layer.op

        ops_without_conv_params = (
            LayerType.base_ew_add,
            LayerType.base_ew_sub,
            LayerType.ew_add,
            LayerType.ew_sub,
            LayerType.rnn,
            LayerType.lstm,
        )
        if old_layer.op in ops_without_conv_params:
            return layer

        layer.groups = old_layer.groups
        layer.kernel_shape = old_layer.kernel_shape.copy()
        layer.strides = old_layer.strides.copy()
        layer.dilations = old_layer.dilations
        layer.is_dilated_s2b = old_layer.is_dilated_s2b
        layer.padding = old_layer.padding
        layer.external_padding_value = old_layer.external_padding_value
        layer.input_disparity = old_layer.input_disparity
        layer.group_sizes = old_layer.group_sizes
        layer.padding_const_value = old_layer.padding_const_value
        return layer

    @property
    def input_features(self):
        """Last dimension of the input shapes; validation is skipped for ew-add or dynamic-weights layers."""
        should_validate = not self.ew_add_enabled and not self.dynamic_weights
        return self._get_shape_single_dim(self._input_shapes, -1, validate=should_validate)

    @property
    def input_width(self):
        """Dimension 2 of the input shapes; validation is skipped for ew-add or dynamic-weights layers."""
        should_validate = not self.ew_add_enabled and not self.dynamic_weights
        return self._get_shape_single_dim(self._input_shapes, 2, validate=should_validate)

    @property
    def input_height(self):
        """Dimension 1 of the input shapes; validation is skipped for dynamic-weights layers."""
        should_validate = not self.dynamic_weights
        return self._get_shape_single_dim(self._input_shapes, 1, validate=should_validate)

    @property
    def input_disparity(self):
        return self._input_disparity

    @input_disparity.setter
    def input_disparity(self, input_disparity):
        self._input_disparity = input_disparity

    def set_input_shapes(self, input_shapes, validate=True):
        """Set the input shapes and infer the groups number from the kernel shape.

        The inference runs only for layers with input disparity 1, a known kernel shape
        and `groups == 1`: groups is set to the ratio between the input features (or the
        positive `defuse_input_features` of a defused layer) and the kernel input features.

        Raises:
            UnsupportedModelError: when the input features are not a whole multiple of
                the kernel input features.
        """
        super().set_input_shapes(input_shapes, validate=validate)

        can_infer_groups = self.input_disparity == 1 and self.kernel_shape and self.groups == 1
        if not can_infer_groups:
            return

        if self.is_defused() and self.defuse_params.get("defuse_input_features") > 0:
            input_features = self.defuse_params.get("defuse_input_features")
        else:
            input_features = self.input_features
        calculated_groups = int(input_features / self._kernel_shape[2])

        # We don't validate that the feature dimensions are equal in case of conv with ew_add.
        # In that case, there may be two different input_features (one for each input).
        # The input features are validated in _calc_conv_ew_add_output_shape.
        if input_features != self.kernel_shape[2] * calculated_groups:
            raise UnsupportedModelError(
                f"Invalid kernel shape for {self.full_name_msg}.\nEither the input shape "
                f"doesn't match the kernel shape, or the calculated groups number doesn't "
                f"match the expected ratio between kernel shape and input shape.\n"
                f"Kernel features: {self.kernel_shape[2]} Input features: {input_features} "
                f"Groups: {calculated_groups}",
            )

        self.groups = calculated_groups

    def move_group_conv_params(self, old_nodes):
        """Merge the params of the layers fused into this group conv.

        The kernels of the fused `base_conv` / `base_deconv` layers are concatenated along
        the output-features axis (axis 3). The bias is the concatenation of one bias per
        fused conv, so it always lines up with the concatenated kernel:

        * when bias_add layers were fused, each conv takes the bias of the bias_add layer
          it outputs to, or zeros when there is none;
        * otherwise each conv contributes its own bias. If no conv has a non-zero bias the
          result is an all-zeros bias of `self.output_features` entries.

        Previously, without bias_add layers, convs with an all-zero bias were skipped, which
        produced a too-short, misaligned bias when only some of the convs had a bias.

        The remaining fused nodes are passed, sorted, to `move_params`.

        Args:
            old_nodes: the layers that were fused into this layer.
        """
        fused_conv_layers = [x for x in old_nodes if x.op in [LayerType.base_conv, LayerType.base_deconv]]
        fused_bias_add_layers = [x for x in old_nodes if x.op == LayerType.bias_add]
        handled_bias_adds = []
        group_kernels = []
        group_biases = []
        has_conv_bias = False
        for layer in fused_conv_layers:
            self._update_original_names(layer.original_names)
            group_kernels.append(layer.kernel)

            if fused_bias_add_layers:
                bias = np.zeros(layer.output_shape[-1])
                output_name = layer.outputs[0]
                for bias_add_layer in fused_bias_add_layers:
                    self._update_original_names(bias_add_layer.original_names)
                    if output_name == bias_add_layer.name:
                        bias = bias_add_layer.bias
                        handled_bias_adds.append(bias_add_layer)
                        break
                group_biases.append(bias)
            else:
                # Keep one entry per conv (zeros included) so the biases stay aligned with the kernels.
                layer_bias = layer.bias
                if layer_bias is None:
                    layer_bias = np.zeros(layer.output_shape[-1])
                has_conv_bias = has_conv_bias or bool(np.any(layer_bias))
                group_biases.append(layer_bias)

        self.kernel = np.concatenate(group_kernels, axis=3)
        if fused_bias_add_layers or has_conv_bias:
            self.bias = np.concatenate(group_biases, axis=0)
        else:
            self.bias = np.zeros(shape=self.output_features)

        fused_other_nodes = sorted([x for x in old_nodes if x not in fused_conv_layers and x not in handled_bias_adds])
        for old_node in fused_other_nodes:
            self.move_params(old_node)

    def reshape_kernel_conv_octxoct(self):
        kh, kw, kic, kof = self._kernel.shape
        splits = np.split(self._kernel, kic, axis=2)
        splits_stack = []
        for split in splits:
            split_slice = np.reshape(split, (1, 1, kw * kh, kof))
            splits_stack.append(split_slice)

        self._kernel = np.concatenate(splits_stack, axis=2)

    @property
    def bias(self):
        if self._bias is not None:
            return self._bias
        if self.transpose_output_width_features:
            return np.zeros(self.output_shape[2] // self.layer_disparity)
        return np.zeros(self.output_shape[-1] // self.layer_disparity)

    @bias.setter
    def bias(self, bias):
        self._bias = bias

    def is_zippable(self, other):
        """Allow zipping two conv layers as long as they share the same parameters except the output features"""
        if not (self.groups == other.groups == 1):
            return False
        if self.kernel_height != other.kernel_height:
            return False
        if self.kernel_width != other.kernel_width:
            return False
        if self.strides != other.strides:
            return False
        if self.dilations != other.dilations:
            return False
        if self.padding != other.padding:
            return False
        if self.padding_const_value != other.padding_const_value:
            return False
        if not (self.layer_disparity == other.layer_disparity == 1):
            return False
        return super().is_zippable(other)

    @property
    def padding_const_value(self):
        return self._padding_const_value

    @padding_const_value.setter
    def padding_const_value(self, padding_const_value):
        self._padding_const_value = padding_const_value
