import copy
import math

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.exceptions import HailoNNException, UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import (
    DefuseType,
    FormatConversionType,
    LayerType,
    PaddingType,
)
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_params import LayerWithParams
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification

# Conversion types whose layers are pinned to the "lcu" hardware layer type
# when they are created or loaded from HN (membership test only).
LCU_CONVERSIONS = {
    FormatConversionType.f_to_hxw_transposed,
    FormatConversionType.hxf_to_w_transposed,
    FormatConversionType.reshape_1xw0_to_hxw,
    FormatConversionType.reshape_height_features,
    FormatConversionType.reshape_post_ew_mult,
    FormatConversionType.spatial_reshape,
    FormatConversionType.transpose_height_width,
    FormatConversionType.transpose_matmul,
    FormatConversionType.transpose_width_features,
}


class FormatConversionLayer(LayerWithParams):
    """HN layer of type ``LayerType.format_conversion``.

    The concrete conversion is selected by ``conversion_type`` (a ``FormatConversionType``).
    The remaining attributes - groups, kernel/stride/dilation/padding, block sizes, spatial
    reshape sizes, width/height slices and input/output windows - are only used by some of
    the conversion types, as handled in ``_calc_output_shape`` and the (de)serialization
    methods below.
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False

    def __init__(self):
        super().__init__()
        self._op = LayerType.format_conversion
        # Output height/width requested for flat_to_frames when the output features were
        # left unknown (-1) at creation time; see ``create`` and ``_calc_output_shape``.
        self._external_output_height = None
        self._external_output_width = None
        self._conversion_type = FormatConversionType.tf_rgb_to_hailo_rgb
        self._groups = 1
        self._kernel_shape = None
        self._strides = None
        self._dilations = None
        self._padding = None
        self._block_sizes = None
        self._spatial_reshape_sizes = None
        self._width_slice = None
        self._height_slice = None
        self._attention_params = None
        self._input_windows = None
        self._output_windows = None

    @classmethod
    def create(
        cls,
        original_name,
        input_vertex_order,
        conversion_type,
        output_shapes=None,
        groups=1,
        spatial_reshape_sizes=None,
        attention_params=None,
        input_windows=None,
        output_windows=None,
    ):
        """Create a new format-conversion layer.

        For ``flat_to_frames``, an output shape of the form ``[-1, height, width, -1]``
        (features unknown) is not forwarded to the base class. Instead, height and width
        are stored as external output dims and the output features are derived later in
        ``_calc_output_shape``.

        Args:
            original_name: Name of the layer in the original model.
            input_vertex_order: Forwarded to ``LayerWithParams.create``.
            conversion_type: The ``FormatConversionType`` to perform.
            output_shapes: A single shape, or a list of shapes (only the first one is
                inspected for the flat_to_frames special case).
            groups: Number of groups.
            spatial_reshape_sizes: Target sizes for reshape-style conversions.
            attention_params: Stored as-is on the layer.
            input_windows: Input window counts per axis.
            output_windows: Output window counts per axis.

        Returns:
            The new layer. Conversions in ``LCU_CONVERSIONS`` are pinned to the "lcu"
            hardware layer type.
        """
        external_output_height = None
        external_output_width = None
        if conversion_type == FormatConversionType.flat_to_frames and output_shapes:
            # ``output_shapes`` may be either a single shape or a list of shapes.
            output_shape = output_shapes[0] if isinstance(output_shapes[0], list) else output_shapes
            if output_shape[1] != -1 and output_shape[2] != -1 and output_shape[3] == -1:
                # Read the dims from the normalized single shape; indexing ``output_shapes``
                # would pick whole shapes when a list of shapes was given.
                external_output_height = output_shape[1]
                external_output_width = output_shape[2]
                output_shapes = None

        layer = super().create(original_name, input_vertex_order, output_shapes)
        layer.external_output_height = external_output_height
        layer.external_output_width = external_output_width
        layer.conversion_type = conversion_type
        layer.groups = groups
        layer.spatial_reshape_sizes = spatial_reshape_sizes
        layer.attention_params = attention_params
        layer.input_windows = input_windows
        layer.output_windows = output_windows

        if conversion_type in LCU_CONVERSIONS:
            layer.set_compilation_params(hw_layer_type_list=["lcu"])

        return layer

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize this layer into a protobuf node (the reverse of ``from_pb``)."""
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_FORMAT_CONVERSION
        node.conversion_type = pb_wrapper.CONVERSION_TYPE_TO_PB[self._conversion_type]
        node.groups = self._groups
        # spatial_reshape_sizes defaults to None in __init__; only serialize when set.
        if self._spatial_reshape_sizes is not None:
            node.spatial_reshape_sizes.extend(self._spatial_reshape_sizes)

        # NOTE(review): only start/end of the slices are serialized; the stride is always
        # written as 1 even if a different stride is stored - confirm this is intended.
        if self._width_slice:
            node.width_slice.start, node.width_slice.end, node.width_slice.stride = (
                self._width_slice[0],
                self._width_slice[1],
                1,
            )

        if self._height_slice:
            node.height_slice.start, node.height_slice.end, node.height_slice.stride = (
                self._height_slice[0],
                self._height_slice[1],
                1,
            )

        if self.input_windows is not None:
            node.input_windows.extend(self.input_windows)

        if self.output_windows is not None:
            node.output_windows.extend(self.output_windows)

        return node

    @property
    def conversion_type(self):
        return self._conversion_type

    @conversion_type.setter
    def conversion_type(self, conversion_type):
        self._conversion_type = conversion_type

    @property
    def input_windows(self):
        return self._input_windows

    @input_windows.setter
    def input_windows(self, input_windows):
        self._input_windows = input_windows

    @property
    def output_windows(self):
        return self._output_windows

    @output_windows.setter
    def output_windows(self, output_windows):
        self._output_windows = output_windows

    @property
    def input_height(self):
        # flat_to_frames inputs are reported as 1 high.
        if self._conversion_type == FormatConversionType.flat_to_frames:
            return 1
        return super().input_height

    @property
    def input_width(self):
        # flat_to_frames inputs are reported as 1 wide.
        if self._conversion_type == FormatConversionType.flat_to_frames:
            return 1
        return super().input_width

    @property
    def height_slice(self):
        return self._height_slice

    @height_slice.setter
    def height_slice(self, height_slice):
        self._height_slice = height_slice

    @property
    def width_slice(self):
        return self._width_slice

    @width_slice.setter
    def width_slice(self, width_slice):
        self._width_slice = width_slice

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Build a layer from a protobuf node (the reverse of ``to_pb``)."""
        layer = super().from_pb(pb, pb_wrapper)
        layer.conversion_type = pb_wrapper.CONVERSION_PB_TO_TYPE[pb.conversion_type]
        # An unset groups field reads as 0; treat it as a single group.
        layer.groups = max(1, pb.groups)
        if layer.conversion_type in [
            FormatConversionType.hxf_to_w_transposed,
            FormatConversionType.f_to_hxw_transposed,
        ]:
            layer.block_sizes = pb.block_sizes[:]

        if layer.conversion_type in (
            FormatConversionType.hxf_to_w_transposed,
            FormatConversionType.transpose_width_features,
        ):
            layer.kernel_shape = [
                pb.kernel_shape.height,
                pb.kernel_shape.width,
                pb.input_shapes[0].width,
                pb.kernel_shape.features,
            ]

            layer.strides = [1, pb.strides.height, pb.strides.width, 1]
            layer.dilations = [1, pb.dilations.height, pb.dilations.width, 1]
        layer.padding = pb_wrapper.PADDING_PB_TO_TYPE[pb.padding]
        layer.spatial_reshape_sizes = pb.spatial_reshape_sizes[:]
        layer.width_slice = [pb.width_slice.start, pb.width_slice.end, pb.width_slice.stride]
        layer.height_slice = [pb.height_slice.start, pb.height_slice.end, pb.height_slice.stride]
        layer.input_windows = pb.input_windows[:]
        layer.output_windows = pb.output_windows[:]
        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """Build a new layer copying the conversion parameters of ``old_layer``.

        Block sizes and kernel/stride/dilation/padding are copied only for the
        conversion types that use them.
        """
        layer = super().from_layer(old_layer)
        layer.conversion_type = old_layer.conversion_type
        layer.groups = old_layer.groups
        layer.spatial_reshape_sizes = old_layer.spatial_reshape_sizes
        layer.input_windows = old_layer.input_windows
        layer.output_windows = old_layer.output_windows

        if layer.conversion_type in [
            FormatConversionType.hxf_to_w_transposed,
            FormatConversionType.f_to_hxw_transposed,
        ]:
            layer.block_sizes = old_layer.block_sizes.copy() if old_layer.block_sizes else None

        if layer.conversion_type in [
            FormatConversionType.hxf_to_w_transposed,
            FormatConversionType.transpose_width_features,
        ]:
            layer.kernel_shape = old_layer.kernel_shape.copy() if old_layer.kernel_shape else None
            layer.strides = old_layer.strides.copy() if old_layer.strides else None
            layer.dilations = old_layer.dilations.copy() if old_layer.dilations else None
            layer.padding = old_layer.padding if old_layer.padding else None

        return layer

    @classmethod
    def from_hn(cls, hn):
        """Build a layer from its HN dict representation.

        Besides reading ``hn["params"]``, this normalizes legacy encodings
        (``spatial_flatten``/``spatial_expand`` types, the ``expand_spatial_sizes`` key,
        2-value slices and 2-value reshape sizes) and validates groups and windowed
        reshape sizes against the input shape. ``hn`` itself is not modified.

        Raises:
            HailoNNException: If ``conversion_type`` is not a ``FormatConversionType`` value.
            UnsupportedModelError: If groups or reshape parameters are inconsistent with
                the input shape.
        """
        layer = super().from_hn(hn)
        if "params" not in hn:
            return layer
        params = hn["params"]

        if "conversion_type" in params:
            if params["conversion_type"] not in [
                conversion_type.value for conversion_type in iter(FormatConversionType)
            ]:
                raise HailoNNException(f"None-existent conversion_type at {layer.full_name_msg}")
            layer.conversion_type = FormatConversionType(params["conversion_type"])
            if layer.conversion_type in LCU_CONVERSIONS:
                layer.set_compilation_params(hw_layer_type_list=["lcu"])
            # NOTE(review): legacy types are remapped *after* the LCU check, so a legacy
            # spatial_flatten/spatial_expand layer is not pinned to "lcu" although
            # spatial_reshape is in LCU_CONVERSIONS - confirm this is intended.
            if layer.conversion_type in [
                FormatConversionType.spatial_flatten,
                FormatConversionType.spatial_expand,
            ]:
                layer.conversion_type = FormatConversionType.spatial_reshape

        if "block_sizes" in params and layer.conversion_type in [
            FormatConversionType.hxf_to_w_transposed,
            FormatConversionType.f_to_hxw_transposed,
        ]:
            layer.block_sizes = params["block_sizes"]

        if "groups" in params:
            layer.groups = params["groups"]
            if layer.input_shape[-1] % layer.groups != 0:
                raise UnsupportedModelError(f"{layer.full_name_msg} input features must be a multiply of groups")
        else:
            layer.groups = 1

        if "expand_spatial_sizes" in params:  # for backward compatibility
            layer.spatial_reshape_sizes = params["expand_spatial_sizes"]
        elif "spatial_reshape_sizes" in params:
            layer.spatial_reshape_sizes = params["spatial_reshape_sizes"]
        if layer.conversion_type == FormatConversionType.spatial_reshape and not layer.spatial_reshape_sizes:
            # for backward compatibility: derive the sizes from the input shape
            if len(layer.input_shape) == 4:
                layer.spatial_reshape_sizes = [1, layer.input_shape[1] * layer.input_shape[2]]
            elif len(layer.input_shape) == 2:
                layer.spatial_reshape_sizes = [1, 1]

        # Legacy 2-value slices get an implicit stride of 1.
        if "width_slice" in params and any(params["width_slice"]):
            layer.width_slice = params["width_slice"]
            if len(layer.width_slice) == 2:
                layer.width_slice = [*layer.width_slice, 1]

        if "height_slice" in params and any(params["height_slice"]):
            layer.height_slice = params["height_slice"]
            if len(layer.height_slice) == 2:
                layer.height_slice = [*layer.height_slice, 1]

        layer.input_windows = params.get("input_windows", [1, 1, 1])
        layer.output_windows = params.get("output_windows", [1, 1, 1])
        if (
            layer.conversion_type in [FormatConversionType.reshape_height_features, FormatConversionType.spatial_reshape]
            and layer.spatial_reshape_sizes
        ):
            reshape_sizes = layer.spatial_reshape_sizes
            # Normalize 2-value sizes to 3 values. A new list is built rather than appending
            # in place, because ``reshape_sizes`` may be the very list stored in ``hn``.
            if layer.conversion_type == FormatConversionType.spatial_reshape and len(reshape_sizes) == 2:
                layer.spatial_reshape_sizes = [*reshape_sizes, layer.input_features]
            elif layer.conversion_type == FormatConversionType.reshape_height_features and len(reshape_sizes) == 2:
                layer.spatial_reshape_sizes = [reshape_sizes[0], layer.input_width, reshape_sizes[1]]
            elif len(reshape_sizes) != 3:
                raise UnsupportedModelError(
                    f"{layer.full_name_msg} spatial_reshape_sizes must have 2 or 3 values",
                )

            input_windows = layer.input_windows
            if input_windows[1] != 1 or input_windows[2] != 1:
                # The per-window size is only needed here; computing it lazily avoids
                # indexing a 4-D input shape on the non-windowed (possibly 2-D) path.
                window_size = [
                    layer.input_shape[1] // input_windows[0],
                    layer.input_shape[2] // input_windows[1],
                    layer.input_shape[3] // input_windows[2],
                ]
                if math.prod(layer.spatial_reshape_sizes) != math.prod(window_size):
                    raise UnsupportedModelError(
                        f"{layer.full_name_msg} Number of pixels in spatial_reshape_sizes must be equal to the number of pixels in the input window",
                    )

                # On every axis, one of window/reshape size must divide the other.
                for window_dim, reshape_dim in zip(window_size, layer.spatial_reshape_sizes):
                    if window_dim % reshape_dim != 0 and reshape_dim % window_dim != 0:
                        raise UnsupportedModelError(
                            f"{layer.full_name_msg} reshape_height_features requires spatial_reshape_sizes to be a multiple of the input window size on same axis",
                        )

        return layer

    def set_input_shapes(self, input_shapes, validate=True):
        """Set the input shapes; also validates features_to_width_features if outputs exist."""
        super().set_input_shapes(input_shapes, validate)
        if self.output_shapes and self._conversion_type == FormatConversionType.features_to_width_features:
            self._validate_features_to_width_features_conversion_shape()

    def _validate_features_to_width_features_conversion_shape(self):
        """Raise UnsupportedModelError unless out_f * out_w == in_f and input width is 1."""
        if self.output_features * self.output_width != self.input_features:
            raise UnsupportedModelError(
                f"{self.full_name_msg} features_to_width_features not support "
                f"output_features * output_width != input_features ",
            )
        if self.input_width != 1:
            raise UnsupportedModelError(f"{self.full_name_msg} features_to_width_features not support input_width != 1")

    def _validate_flat_to_frames_conversion_shape(self):
        """If output shapes are set, require out_f * out_w == in_f and output height 1."""
        if self.output_shapes:
            if self.output_features * self.output_width != self.input_features:
                raise UnsupportedModelError(
                    f"{self.full_name_msg} dense to frames conversion not supported - "
                    f"output_features * output_width != input_features ",
                )
            if self.output_height != 1:
                raise UnsupportedModelError(
                    f"{self.full_name_msg} dense to frames conversion not supported - output_height must be 1",
                )

    def update_output_shapes(self, **kwargs):
        """Recompute output shapes, except for conversions whose outputs are set externally."""
        if self.is_nv_converter():
            self.output_shapes = self._calc_output_shape()
        elif self._conversion_type not in [
            FormatConversionType.general_reshape,
            FormatConversionType.split_windowed_attention,
            FormatConversionType.merge_windowed_attention,
            FormatConversionType.groups_to_spatial_flatten,
            FormatConversionType.spatial_flatten_to_groups,
            FormatConversionType.partial_groups_to_spatial_flatten,
        ]:
            super().update_output_shapes()

    def _calc_output_shape(self):
        """Compute the output shape implied by ``conversion_type`` and the input shape.

        Starts from a deep copy of ``input_shape`` and overwrites the dims the conversion
        changes; conversion types without a branch here keep the input shape. The
        flat_to_frames branch may also assign ``self.output_shapes`` as a side effect.
        """
        output_shape = copy.deepcopy(self.input_shape)
        if self._conversion_type in [
            FormatConversionType.mipi_bayer_rggb_to_hailo_rgb,
            FormatConversionType.mipi_bayer_bggr_to_hailo_rgb,
            FormatConversionType.mipi_bayer_gbrg_to_hailo_rgb,
            FormatConversionType.mipi_bayer_grbg_to_hailo_rgb,
        ]:
            output_shape[3] *= 3
        elif self._conversion_type == FormatConversionType.twelve_to_eight_bit:
            output_shape[2] = math.floor(output_shape[2] / 2)
        elif self._conversion_type == FormatConversionType.twelve_to_sixteen_bit:
            output_shape[2] = (output_shape[2] * 2) // 3
        elif self._conversion_type == FormatConversionType.sixteen_to_twelve_bit:
            output_shape[2] = (output_shape[2] * 3) // 2
        elif self._conversion_type == FormatConversionType.reshape_post_ew_mult:
            output_shape[2] = output_shape[2] // 2  # in b0 should be 4
        elif self._conversion_type == FormatConversionType.features_to_width_features:
            self._validate_features_to_width_features_conversion_shape()
            output_shape[2:] = self.output_shape[2:]
        elif self._conversion_type == FormatConversionType.flat_to_frames:
            out_height = self._external_output_height
            out_width = self._external_output_width
            if out_height and out_width and not self.output_shape:
                # Output features are whatever remains after spreading the flat input over
                # the requested height x width; each output copy gets its own list.
                out_features = int(self.input_features // (out_height * out_width))
                self.output_shapes = [
                    [-1, int(out_height), int(out_width), out_features] for _ in range(self._output_copies)
                ]
            self._validate_flat_to_frames_conversion_shape()
            output_shape[1] = 1
            if self.output_shape:
                output_shape[2:] = self.output_shape[2:]
        elif self._conversion_type == FormatConversionType.yuy2_to_hailo_yuv:
            output_shape[3] = 3
        elif self._conversion_type in [
            FormatConversionType.transpose_width_features,
            FormatConversionType.transpose_matmul,
        ]:
            input_features = self.input_shape[3]
            defuse_width = self._defuse_params.get("defuse_input_width")
            # Fall back to the input width when no positive defuse width is recorded.
            output_f = defuse_width if defuse_width is not None and defuse_width > 0 else self.input_shape[2]
            output_f = output_f * self._groups
            if self.defuse_type is DefuseType.super_dw and "defuse_original_features" in self._defuse_params:
                # this condition should check defuse_types with INPUT_FEATURES after
                # https://hailotech.atlassian.net/browse/SDK-27066 is ready
                input_features = self._defuse_params.get("defuse_original_features")

            output_w = int(input_features / self._groups)
            if self._kernel_shape is None or self._kernel_shape[0] == 0:
                output_shape[2] = output_w
                output_shape[3] = output_f
            else:
                # Effective (dilated) kernel extents.
                dilation_kernel_h = self._kernel_shape[0] + (self._kernel_shape[0] - 1) * (self._dilations[1] - 1)
                dilation_kernel_w = self._kernel_shape[1] + (self._kernel_shape[1] - 1) * (self._dilations[2] - 1)
                if self._padding != PaddingType.valid:
                    if self.input_shape[1] % self._strides[1] == 0:
                        pad_total_h = max(dilation_kernel_h - self._strides[1], 0)
                    else:
                        pad_total_h = max(dilation_kernel_h - (self.input_shape[1] % self._strides[1]), 0)
                    if self.input_shape[2] % self._strides[2] == 0:
                        pad_total_f = max(dilation_kernel_w - self._strides[2], 0)
                    else:
                        pad_total_f = max(dilation_kernel_w - (self.input_shape[2] % self._strides[2]), 0)
                else:
                    pad_total_h = 0
                    pad_total_f = 0

                output_shape[1] = self.input_shape[1] + pad_total_h
                output_shape[2] = output_w
                output_shape[3] = output_f + pad_total_f
        elif self._conversion_type == FormatConversionType.hxf_to_w_transposed:
            if self.defuse_type == DefuseType.spatial_reshape:
                # HACK: We want to take the original shape, since we might have padding in the last input_row
                output_shape = self.output_shape
            elif self._kernel_shape and self._strides:
                if self._padding == PaddingType.valid:
                    dilation_kernel_h = self._kernel_shape[0] + (self._kernel_shape[0] - 1) * (self._dilations[1] - 1)
                    dilation_kernel_w = self._kernel_shape[1] + (self._kernel_shape[1] - 1) * (self._dilations[2] - 1)
                    h = int((self.input_height - dilation_kernel_h) / self._strides[1]) + 1
                    w = int((self.input_features - dilation_kernel_w) / self._strides[2]) + 1
                else:
                    h = math.ceil(self.input_height / self._strides[1])
                    w = math.ceil(self.input_features / self._strides[2])
                f = int(self.input_width * self._kernel_shape[0] * self._kernel_shape[1])
                output_shape[1:] = [h, w, f]
            else:
                # No kernel: plain swap of width and features.
                output_shape[2] = self.input_shape[3]
                output_shape[3] = self.input_shape[2]
        elif self._conversion_type == FormatConversionType.f_to_hxw_transposed:
            block_sizes = [1, 1] if not self._block_sizes else self._block_sizes

            h = int(self.input_height * block_sizes[0])
            w = int(self.input_features / (block_sizes[0] * block_sizes[1]))
            f = int(self.input_width * block_sizes[1])
            output_shape[1:] = [h, w, f]
            # A [0, 0, ...] slice means "no slicing".
            if self._height_slice and self._height_slice[:-1] != [0, 0]:
                output_shape[1] = self._height_slice[1] - self._height_slice[0]

            if self._width_slice and self._width_slice[:-1] != [0, 0]:
                output_shape[3] = self._width_slice[1] - self._width_slice[0]
        elif self.is_nv_converter():
            if self.input_features == 3:
                output_shape[1] = self.input_height * 2
            else:
                output_shape = [self.input_shape[0], self.input_shape[1], self.input_shape[2] // 2, 2]

        elif self._conversion_type == FormatConversionType.tf_rgbx_to_hailo_rgb:
            output_shape[3] = self.input_shape[3] - 1
        elif self._conversion_type == FormatConversionType.spatial_reshape:
            output_shape = [-1, self._spatial_reshape_sizes[0], self._spatial_reshape_sizes[1], self.input_features]
            if self.output_windows:
                output_shape[1] *= self.output_windows[0]
                output_shape[2] *= self.output_windows[1]
        elif self._conversion_type == FormatConversionType.transpose_height_width:
            output_shape[1] = self.input_width
            output_shape[2] = self.input_height
        elif self._conversion_type == FormatConversionType.reshape_1xw0_to_hxw:
            # currently support only width = 8, will change after https://hailotech.atlassian.net/browse/SDK-39391 is solved.
            width_f = 8
            output_shape = [-1, math.ceil(self.input_width / width_f), width_f, self.input_features]
        elif self._conversion_type == FormatConversionType.reshape_height_features:
            output_windows = self.output_windows if self.output_windows else [1, 1, 1]
            # Windowing along width is not supported for this conversion.
            assert output_windows[1] == 1
            output_shape[1] = self._spatial_reshape_sizes[0] * output_windows[0]
            output_shape[3] = self._spatial_reshape_sizes[2] * output_windows[2]
        return output_shape

    def _get_output_shape(self, validate=False, layer_name=None, layer_index=None):
        """Return the output shape for a given successor.

        NV converters have one output shape per successor, so the successor name (and
        index, when ``output_indices`` is populated) is required to pick one.
        """
        if self.is_nv_converter():
            if layer_name is None:
                raise UnsupportedModelError(
                    f"{self.full_name_msg} successor name is missing, output shape is ambiguous",
                )
            if len(self.output_indices) > 0:
                if layer_index is None:
                    raise UnsupportedModelError(
                        f"{self.full_name_msg} successor index is missing, output shape is ambiguous",
                    )
                return self.output_shapes[self.output_indices.index(layer_index)]
            return self.output_shapes[self.outputs.index(layer_name)]
        return super()._get_output_shape()

    def to_hn(self, should_get_default_params=False):
        """Serialize this layer to its HN dict representation (the reverse of ``from_hn``)."""
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        result["params"]["conversion_type"] = self._conversion_type.value
        result["params"]["groups"] = self._groups

        if self._spatial_reshape_sizes is not None:
            result["params"]["spatial_reshape_sizes"] = self._spatial_reshape_sizes
        if self._width_slice:
            result["params"]["width_slice"] = [int(x) for x in self._width_slice]
        if self._height_slice:
            result["params"]["height_slice"] = [int(x) for x in self._height_slice]
        if self.input_windows is not None:
            result["params"]["input_windows"] = self.input_windows
        if self.output_windows is not None:
            result["params"]["output_windows"] = self.output_windows

        return result

    def is_from_dense(self, validate=True):
        """Return True when the first input shape is 2-D (``validate`` is unused)."""
        # SDK-9087- remove after in_format_type and out_format_type implemented
        return len(self.input_shapes[0]) == 2

    @property
    def input_features(self):
        # 2-D inputs keep features at index 1; otherwise the last dim.
        if self.is_from_dense():
            return self._get_shape_single_dim(self.input_shapes, 1)
        return self._get_shape_single_dim(self.input_shapes, -1)

    @property
    def output_features(self):
        # NV converters split features across outputs, so report their total.
        if self.is_nv_converter():
            return sum(shape[3] for shape in self.output_shapes)
        return self._get_shape_single_dim(self.output_shapes, -1)

    @property
    def groups(self):
        return self._groups

    @groups.setter
    def groups(self, groups):
        self._groups = groups

    @property
    def spatial_reshape_sizes(self):
        return self._spatial_reshape_sizes

    @spatial_reshape_sizes.setter
    def spatial_reshape_sizes(self, spatial_reshape_sizes):
        self._spatial_reshape_sizes = spatial_reshape_sizes

    @property
    def attention_params(self):
        return self._attention_params

    @attention_params.setter
    def attention_params(self, attention_params):
        self._attention_params = attention_params

    @property
    def external_output_height(self):
        return self._external_output_height

    @external_output_height.setter
    def external_output_height(self, external_output_height):
        self._external_output_height = external_output_height

    @property
    def external_output_width(self):
        return self._external_output_width

    @external_output_width.setter
    def external_output_width(self, external_output_width):
        self._external_output_width = external_output_width

    def get_equalization_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def sort_outputs(self):
        """Return a cmp-style comparator ordering layers by their position in ``outputs``."""
        return lambda layer1, layer2: 1 if self.outputs.index(layer1.name) > self.outputs.index(layer2.name) else -1

    def ibc_supported(self):
        return LayerSupportStatus.unsupported

    def is_nv_converter(self):
        """Return True for the NV12/NV21/I420 to Hailo-YUV conversion types."""
        return self._conversion_type in [
            FormatConversionType.nv12_to_hailo_yuv,
            FormatConversionType.nv21_to_hailo_yuv,
            FormatConversionType.i420_to_hailo_yuv,
        ]

    @property
    def kernel_shape(self):
        return self._kernel_shape

    @kernel_shape.setter
    def kernel_shape(self, kernel_shape):
        # Store a list copy of any non-empty sequence; falsy values are stored as-is.
        self._kernel_shape = list(kernel_shape) if kernel_shape else kernel_shape

    @property
    def strides(self):
        return self._strides

    @strides.setter
    def strides(self, strides):
        self._strides = strides

    @property
    def dilations(self):
        return self._dilations

    @dilations.setter
    def dilations(self, dilations):
        self._dilations = dilations

    @property
    def padding(self):
        return self._padding

    @padding.setter
    def padding(self, padding):
        self._padding = padding

    @property
    def block_sizes(self):
        return self._block_sizes

    @block_sizes.setter
    def block_sizes(self, block_sizes):
        self._block_sizes = block_sizes

    @property
    def is_flatten_reshape(self):
        # A spatial_reshape whose target height is 1.
        return self._conversion_type == FormatConversionType.spatial_reshape and self._spatial_reshape_sizes[0] == 1

    @property
    def is_expand_reshape(self):
        # A spatial_reshape whose target height is not 1.
        return self._conversion_type == FormatConversionType.spatial_reshape and self._spatial_reshape_sizes[0] != 1