#!/usr/bin/env python
import copy

from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerPrecisionConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    BiasMode,
    BiasModes,
    PrecisionMode,
    PrecisionModes,
)
from hailo_sdk_common.hailo_nn.exceptions import HailoNNLayerParamsException
from hailo_sdk_common.hailo_nn.hn_definitions import (
    BalanceOutputMultisplitPolicies,
    BalanceOutputMultisplitPolicy,
    CompressL3WeightsPolicies,
    CompressL3WeightsPolicy,
    DefuseType,
    EnableEwAddOptimizationPolicies,
    EnableEwAddOptimizationPolicy,
    HWLayerType,
    LayerBuffersFormat,
    ParallelActivationPolicies,
    ParallelActivationPolicy,
    ResizeBilinearStreamingPolicies,
    ResizeBilinearStreamingPolicy,
    ResourcesAllocationStrategy,
    Subclusters16x4Policies,
    Subclusters16x4Policy,
    SubclustersNoContextsPolicies,
    SubclustersNoContextsPolicy,
    TwoLineBufferModePolicies,
    TwoLineBufferModePolicy,
    UseL2WeightsPolicies,
    UseL2WeightsPolicy,
)
from hailo_sdk_common.hailo_nn.tools_params import (
    AutoInt,
    AutoVariablePolicy,
    ToolsParams,
    allowed_str_for_bool,
    allowed_str_for_enumerate,
    convert_str_to_bool,
    convert_str_to_enumerate,
    convert_to_auto_int,
    convert_to_float,
    convert_to_int,
)
from hailo_sdk_common.logger.logger import default_logger


class HNLayersParams(ToolsParams):
    """Base class for the per-layer parameter groups in this module.

    Subclasses define the class-level ``DEFAULT_PARAMS`` table (and, where
    needed, ``ALLOWED_VALUES`` / ``VALUE_CONVERTERS`` / ``POLICY_TYPES``) and
    implement ``to_hn``. Storage of ``_params`` / ``_default_params`` and the
    ``_logger`` attribute come from ``ToolsParams``.
    """

    DEFAULT_PARAMS = {}

    def __init__(self):
        self._layer_name = None
        super().__init__()

    def change_default_param(self, param, param_value, should_set=False):
        """Override the default value of a parameter that already has a default.

        Args:
            param: Parameter name. Names without a registered default are ignored.
            param_value: The new default value.
            should_set: When true, also overwrite the current value with the new default.
        """
        if param not in self._default_params:
            return
        self._default_params[param] = param_value
        if should_set:
            self._params[param] = param_value

    def __deepcopy__(self, memo):
        """Deep-copy every instance attribute except the logger.

        The clone is registered in ``memo`` before its attributes are copied, and
        ``memo`` is forwarded to the nested ``copy.deepcopy`` calls. This way
        objects shared between attributes stay shared in the copy, and cycles
        back to this object do not recurse forever.
        """
        new_layer_params = type(self)()
        memo[id(self)] = new_layer_params
        self._logger.debug("Copying HNLayerParams")
        for attr, value in vars(self).items():
            if attr == "_logger":
                # The logger is not copied; the clone gets a fresh default logger.
                new_layer_params._logger = default_logger()
            else:
                setattr(new_layer_params, attr, copy.deepcopy(value, memo))
        self._logger.debug("HNLayerParams was copied.")
        return new_layer_params

    @property
    def non_default_params(self):
        """Return a new dict with the params whose value differs from their default.

        A param that has no registered default is reported as non-default
        instead of raising ``KeyError``.
        """
        defaults = self._default_params
        return {
            key: value
            for key, value in self._params.items()
            if key not in defaults or defaults[key] != value
        }

    def to_hn(self, hn_item, should_get_default_params=False):
        """Serialize these params into the HN dict ``hn_item`` (subclass hook)."""
        raise NotImplementedError

    def set_layer_name(self, layer_name):
        """Record the name of the layer these params belong to."""
        self._layer_name = layer_name


class CompilationParams(HNLayersParams):
    """Per-layer compilation/allocation parameters.

    Policy-typed values are stored as enum members and converted to their
    ``.value`` when written to HN. In protobuf they live under
    ``pb.allocation_params``.
    """

    DEFAULT_PARAMS = {
        "mixed_mem": UseL2WeightsPolicy.allowed,
        "resources_allocation_strategy": ResourcesAllocationStrategy.automatic_scs_selection,
        "use_16x4_sc": Subclusters16x4Policy.allowed,
        "no_contexts": SubclustersNoContextsPolicy.allowed,
        "defuse_num_layers": 1,
        "defuse_spatial_w": False,
        "balance_output_multisplit": BalanceOutputMultisplitPolicy.allowed,
        "number_of_subclusters": 1,
        "buffer_in_l4": False,
        "enable_exhaustive_merge": False,
        "hw_layer_type_list": [],
        "number_of_apus": 1,
        "number_of_input_aligners": 1,
        "resize_bilinear_streaming": ResizeBilinearStreamingPolicy.allowed,
        "two_line_buffer_mode": TwoLineBufferModePolicy.allowed,
        "fps": 0,
        "layer_output_buffers_format": LayerBuffersFormat.automatic,
        "enable_ew_add_optimization": EnableEwAddOptimizationPolicy.allowed,
        "parallel_activation": ParallelActivationPolicy.allowed,
        "microcoder_without_halts": False,
        "prepost_haltless": False,
        "nms_burst_size": AutoInt("automatic"),
        "enable_mjitc": False,
        "compress_l3_weights": CompressL3WeightsPolicy.allowed,
        "enable_nested_defuse": False,
    }
    ALLOWED_VALUES = {
        "mixed_mem": allowed_str_for_enumerate(UseL2WeightsPolicy),
        "resources_allocation_strategy": allowed_str_for_enumerate(ResourcesAllocationStrategy),
        "use_16x4_sc": allowed_str_for_enumerate(Subclusters16x4Policy),
        "no_contexts": allowed_str_for_enumerate(SubclustersNoContextsPolicy),
        "defuse_spatial_w": allowed_str_for_bool(),
        "balance_output_multisplit": allowed_str_for_enumerate(BalanceOutputMultisplitPolicy),
        "buffer_in_l4": allowed_str_for_bool(),
        "enable_exhaustive_merge": allowed_str_for_bool(),
        "resize_bilinear_streaming": allowed_str_for_enumerate(ResizeBilinearStreamingPolicy),
        "two_line_buffer_mode": allowed_str_for_enumerate(TwoLineBufferModePolicy),
        "layer_output_buffers_format": allowed_str_for_enumerate(LayerBuffersFormat),
        "enable_ew_add_optimization": allowed_str_for_enumerate(EnableEwAddOptimizationPolicy),
        "parallel_activation": allowed_str_for_enumerate(ParallelActivationPolicy),
        "microcoder_without_halts": allowed_str_for_bool(),
        "prepost_haltless": allowed_str_for_bool(),
        "enable_mjitc": allowed_str_for_bool(),
        "compress_l3_weights": allowed_str_for_enumerate(CompressL3WeightsPolicy),
        "enable_nested_defuse": allowed_str_for_bool(),
    }
    VALUE_CONVERTERS = {
        "mixed_mem": convert_str_to_enumerate(UseL2WeightsPolicy),
        "resources_allocation_strategy": convert_str_to_enumerate(ResourcesAllocationStrategy),
        "use_16x4_sc": convert_str_to_enumerate(Subclusters16x4Policy),
        "no_contexts": convert_str_to_enumerate(SubclustersNoContextsPolicy),
        "defuse_num_layers": convert_to_int,
        "defuse_spatial_w": convert_str_to_bool,
        "balance_output_multisplit": convert_str_to_enumerate(BalanceOutputMultisplitPolicy),
        "number_of_subclusters": convert_to_int,
        "buffer_in_l4": convert_str_to_bool,
        "enable_exhaustive_merge": convert_str_to_bool,
        "hw_layer_type_list": lambda list_val: [HWLayerType[layer_type] for layer_type in list_val],
        "number_of_apus": convert_to_int,
        "number_of_input_aligners": convert_to_int,
        "resize_bilinear_streaming": convert_str_to_enumerate(ResizeBilinearStreamingPolicy),
        "two_line_buffer_mode": convert_str_to_enumerate(TwoLineBufferModePolicy),
        "fps": convert_to_float,
        "layer_output_buffers_format": convert_str_to_enumerate(LayerBuffersFormat),
        "enable_ew_add_optimization": convert_str_to_enumerate(EnableEwAddOptimizationPolicy),
        "parallel_activation": convert_str_to_enumerate(ParallelActivationPolicy),
        "microcoder_without_halts": convert_str_to_bool,
        "prepost_haltless": convert_str_to_bool,
        "nms_burst_size": convert_to_auto_int,
        "enable_mjitc": convert_str_to_bool,
        "compress_l3_weights": convert_str_to_enumerate(CompressL3WeightsPolicy),
        "enable_nested_defuse": convert_str_to_bool,
    }
    POLICY_TYPES = {
        "mixed_mem": UseL2WeightsPolicies,
        "use_16x4_sc": Subclusters16x4Policies,
        "no_contexts": SubclustersNoContextsPolicies,
        "resize_bilinear_streaming": ResizeBilinearStreamingPolicies,
        "two_line_buffer_mode": TwoLineBufferModePolicies,
        "balance_output_multisplit": BalanceOutputMultisplitPolicies,
        "enable_ew_add_optimization": EnableEwAddOptimizationPolicies,
        "parallel_activation": ParallelActivationPolicies,
        "compress_l3_weights": CompressL3WeightsPolicies,
    }
    DEPRECATED_PARAMS = ["number_of_lanes"]

    # Enum-valued params that are not policies but are likewise written to HN as ``.value``.
    _ENUM_VALUE_PARAMS = ("resources_allocation_strategy", "layer_output_buffers_format")

    def set_default_params(self):
        """Reset ``self._params`` to a deep copy of ``DEFAULT_PARAMS``, without ``fps``."""
        params = copy.deepcopy(self.DEFAULT_PARAMS)

        # Only "fps" is excluded from the default parameters.
        # NOTE(review): an earlier comment also named microcoder_without_halts, but it
        # has never been removed here - confirm which behavior is intended.
        del params["fps"]
        self._params = params

    def convert_compilation_param_to_value(self, hn_item, param_key):
        """Replace ``hn_item["compilation_params"][param_key]`` by its ``.value``, if present."""
        if param_key in hn_item["compilation_params"]:
            hn_item["compilation_params"][param_key] = hn_item["compilation_params"][param_key].value

    def to_hn(self, hn_item, should_get_default_params=False):
        """Write the compilation params into ``hn_item["compilation_params"]``.

        Args:
            hn_item: The layer's HN dict; mutated in place.
            should_get_default_params: When true, write every current param.
                Otherwise write only the non-default params, and leave ``hn_item``
                untouched if there are none.

        Enum values are converted to their ``.value``. ``hw_layer_type_list`` is
        converted element-wise. An automatic ``nms_burst_size`` is omitted;
        otherwise its concrete value is written.
        """
        if should_get_default_params:
            # Shallow-copy so the conversions below never overwrite the live values
            # in self._params (which would break later calls and to_pb()).
            compilation_params = copy.copy(self._params)
        else:
            compilation_params = self.non_default_params
            if not compilation_params:
                return
        hn_item["compilation_params"] = compilation_params

        for param_key in (*self.POLICY_TYPES, *self._ENUM_VALUE_PARAMS):
            self.convert_compilation_param_to_value(hn_item, param_key)

        if "hw_layer_type_list" in compilation_params:
            compilation_params["hw_layer_type_list"] = [
                item.value for item in compilation_params["hw_layer_type_list"]
            ]
        if "nms_burst_size" in compilation_params:
            nms_burst_size = compilation_params["nms_burst_size"]
            if nms_burst_size.policy() == AutoVariablePolicy.AUTOMATIC:
                del compilation_params["nms_burst_size"]
            else:
                compilation_params["nms_burst_size"] = nms_burst_size.val()

    def to_pb(self, pb, pb_wrapper):
        """Serialize the compilation params into ``pb.allocation_params``.

        ``pb_wrapper`` supplies the enum -> protobuf lookup tables. The field
        writes other than ``mixed_mem`` / ``hw_layer_type_list`` go through the
        inherited ``set_pb_field`` / ``set_auto_pb_field`` helpers.
        """
        params = self.get()
        if not params:
            params = self.get_default_params()
        if "mixed_mem" in params:
            pb.allocation_params.use_l2_weights = pb_wrapper.L2_WEIGHTS_POLICY_TYPE_TO_PB[params["mixed_mem"]]
        if "hw_layer_type_list" in params:
            pb.allocation_params.hw_layer_type_list.extend(
                [pb_wrapper.HW_LAYER_TYPE_TO_PB[layer] for layer in params["hw_layer_type_list"]],
            )
        self.set_pb_field(
            pb.allocation_params,
            "resources_allocation_strategy",
            pb_wrapper.RESOURCES_ALLOCATION_STRATEGY_TYPE_TO_PB,
        )
        self.set_pb_field(pb.allocation_params, "use_16x4_sc", pb_wrapper.SUBCLUSTERS_16x4_POLICY_TYPE_TO_PB)
        self.set_pb_field(pb.allocation_params, "no_contexts", pb_wrapper.SUBCLUSTERS_NO_CONTEXTS_POLICY_TYPE_TO_PB)
        self.set_pb_field(pb.allocation_params, "defuse_num_layers")
        self.set_pb_field(pb.allocation_params, "defuse_spatial_w", pb_wrapper.STR_TO_BOOL_TO_PB)
        self.set_pb_field(
            pb.allocation_params,
            "balance_output_multisplit",
            pb_wrapper.BALANCE_OUTPUT_MULTISPLIT_POLICY_TYPE_TO_PB,
        )
        self.set_pb_field(pb.allocation_params, "buffer_in_l4", pb_wrapper.STR_TO_BOOL_TO_PB)
        self.set_pb_field(pb.allocation_params, "enable_exhaustive_merge", pb_wrapper.STR_TO_BOOL_TO_PB)
        self.set_pb_field(pb.allocation_params, "number_of_subclusters")
        self.set_pb_field(pb.allocation_params, "number_of_apus")
        self.set_pb_field(pb.allocation_params, "number_of_input_aligners")
        self.set_pb_field(
            pb.allocation_params,
            "resize_bilinear_streaming",
            pb_wrapper.RESIZE_BILINEAR_STREAMING_POLICY_TYPE_TO_PB,
        )
        self.set_pb_field(
            pb.allocation_params,
            "two_line_buffer_mode",
            pb_wrapper.TWO_LINE_BUFFER_MODE_POLICY_TYPE_TO_PB,
        )
        self.set_pb_field(pb.allocation_params, "fps")
        self.set_pb_field(pb.allocation_params, "microcoder_without_halts")
        self.set_pb_field(pb.allocation_params, "prepost_haltless")
        self.set_pb_field(pb.allocation_params, "layer_output_buffers_format", pb_wrapper.BUFFERS_FORMAT_TYPE_TO_PB)
        self.set_pb_field(
            pb.allocation_params,
            "enable_ew_add_optimization",
            pb_wrapper.ENABLE_EW_ADD_OPTIMZATION_POLICY_TYPE_TO_PB,
        )
        self.set_pb_field(pb.allocation_params, "parallel_activation", pb_wrapper.PARALLEL_ACTIVATION_POLICY_TYPE_TO_PB)
        self.set_pb_field(pb.allocation_params, "compress_l3_weights", pb_wrapper.COMPRESS_L3_WEIGHTS_POLICY_TYPE_TO_PB)
        self.set_auto_pb_field(pb.allocation_params, pb_wrapper, "nms_burst_size")
        self.set_pb_field(pb.allocation_params, "enable_mjitc")
        self.set_pb_field(pb.allocation_params, "enable_nested_defuse", pb_wrapper.STR_TO_BOOL_TO_PB)

    def from_pb(self, pb, pb_wrapper):
        """Load the compilation params from ``pb.allocation_params`` into ``self._params``.

        ``pb_wrapper`` supplies the protobuf -> enum lookup tables.
        """
        # NOTE(review): to_pb() writes fps, microcoder_without_halts, prepost_haltless
        # and enable_mjitc, but they are not read back here - confirm this is intended.
        pb_alloc = pb.allocation_params
        if pb_alloc.HasField("use_l2_weights"):
            self._params["mixed_mem"] = pb_wrapper.L2_WEIGHTS_POLICY_PB_TO_TYPE[pb_alloc.use_l2_weights]
        self.get_pb_field(
            pb_alloc,
            "resources_allocation_strategy",
            pb_wrapper.RESOURCES_ALLOCATION_STRATEGY_PB_TO_TYPE,
        )
        self.get_pb_field(pb_alloc, "use_16x4_sc", pb_wrapper.SUBCLUSTERS_16x4_POLICY_PB_TO_TYPE)
        self.get_pb_field(pb_alloc, "no_contexts", pb_wrapper.SUBCLUSTERS_NO_CONTEXTS_POLICY_PB_TO_TYPE)
        self.get_pb_field(pb_alloc, "defuse_num_layers")
        self.get_pb_field(pb_alloc, "defuse_spatial_w")
        self.get_pb_field(pb_alloc, "balance_output_multisplit", pb_wrapper.BALANCE_OUTPUT_MULTISPLIT_POLICY_PB_TO_TYPE)
        self.get_pb_field(pb_alloc, "buffer_in_l4")
        self.get_pb_field(pb_alloc, "enable_exhaustive_merge")
        self.get_pb_field(pb_alloc, "enable_nested_defuse")
        self.get_pb_field(pb_alloc, "number_of_subclusters")
        self.get_pb_field(pb_alloc, "number_of_apus")
        self.get_pb_field(pb_alloc, "number_of_input_aligners")
        self.get_pb_field(
            pb_alloc,
            "enable_ew_add_optimization",
            pb_wrapper.ENABLE_EW_ADD_OPTIMZATION_POLICY_PB_TO_TYPE,
        )
        self.get_pb_field(pb_alloc, "parallel_activation", pb_wrapper.PARALLEL_ACTIVATION_POLICY_PB_TO_TYPE)
        self.get_pb_field(pb_alloc, "compress_l3_weights", pb_wrapper.COMPRESS_L3_WEIGHTS_POLICY_PB_TO_TYPE)
        self.get_auto_pb_field(pb_alloc, "nms_burst_size", AutoInt)
        if len(pb_alloc.hw_layer_type_list) > 0:
            pb_to_type = pb_wrapper.HW_LAYER_PB_TO_TYPE
            self._params["hw_layer_type_list"] = [pb_to_type[layer_type] for layer_type in pb_alloc.hw_layer_type_list]
        if pb_alloc.HasField("resize_bilinear_streaming"):
            self.get_pb_field(
                pb_alloc, "resize_bilinear_streaming", pb_wrapper.RESIZE_BILINEAR_STREAMING_POLICY_PB_TO_TYPE
            )
        if pb_alloc.HasField("two_line_buffer_mode"):
            self.get_pb_field(pb_alloc, "two_line_buffer_mode", pb_wrapper.TWO_LINE_BUFFER_MODE_POLICY_PB_TO_TYPE)
        self.get_pb_field(pb_alloc, "layer_output_buffers_format", pb_wrapper.BUFFERS_FORMAT_PB_TO_TYPE)


class QuantizationParams(HNLayersParams):
    """Per-layer quantization parameters (precision/bias mode and grouping)."""

    DEFAULT_PARAMS = {
        "precision_mode": None,
        "bias_mode": None,
        "quantization_groups": None,
        "quantization_weight_groups": None,
        "signed_output": None,
    }
    ALLOWED_VALUES = {
        "precision_mode": allowed_str_for_enumerate(PrecisionMode),
        "bias_mode": allowed_str_for_enumerate(BiasMode),
    }
    POLICY_TYPES = {
        "precision_mode": PrecisionModes,
        "bias_mode": BiasModes,
    }
    VALUE_CONVERTERS = {
        "precision_mode": convert_str_to_enumerate(PrecisionMode),
        "bias_mode": convert_str_to_enumerate(BiasMode),
        "quantization_groups": convert_to_int,
        "quantization_weight_groups": convert_to_int,
        "signed_output": convert_str_to_bool,
    }

    def _check_for_param_duplication(self, key1, key2, input_dict):
        """Raise if both ``key1`` and ``key2`` are present in ``input_dict``.

        Raises:
            HailoNNLayerParamsException: When both keys are set at the same time.
        """
        if key1 in input_dict and key2 in input_dict:
            raise HailoNNLayerParamsException(
                f'Cannot set "{key1}={input_dict[key1]}" param and "{key2}={input_dict[key2]}" at the same time',
            )

    def to_hn(self, hn_item, should_get_default_params=False):
        """Write a shallow copy of the quantization params into ``hn_item["quantization_params"]``.

        Enum members of ``precision_mode`` / ``bias_mode`` are replaced by their
        ``.value``. A ``None`` value (the DEFAULT_PARAMS value) is written as-is
        instead of failing with ``AttributeError``. Legacy keys that are not
        part of the HN representation are dropped.
        ``should_get_default_params`` is accepted for interface compatibility
        only: all current params are always written.
        """
        quantization_params = copy.copy(self._params)

        for enum_key in ("precision_mode", "bias_mode"):
            if quantization_params.get(enum_key) is not None:
                quantization_params[enum_key] = quantization_params[enum_key].value

        for legacy_key in (
            "null_channels_cutoff_factor",
            "max_elementwise_feed_repeat",
            "activation_fit",
            "equalization",
            "bias_correction",
            "max_bias_feed_repeat",
            "output_min_max_strategy",
        ):
            quantization_params.pop(legacy_key, None)

        hn_item["quantization_params"] = quantization_params

    def to_pb(self, pb, pb_wrapper):
        """Serialize the quantization params into ``pb.quantization_params``."""
        params = self.get()
        pb.quantization_params.precision_mode.CopyFrom(pb_wrapper.DumpPrecisionMode(params.get("precision_mode", None)))

        self.set_pb_field(pb.quantization_params, "quantization_groups")
        self.set_pb_field(pb.quantization_params, "signed_output")
        self.set_pb_field(pb.quantization_params, "quantization_weight_groups")
        self.set_pb_field(pb.quantization_params, "bias_mode", pb_wrapper.BIAS_MODE_TYPE_TO_PB)

        # These fields are irrelevant; they are serialized only for legacy reasons.
        pb.quantization_params.null_channels_cutoff_factor = 0
        pb.quantization_params.max_elementwise_feed_repeat = 0

    def from_pb(self, pb, pb_wrapper):
        """Load the quantization params from ``pb.quantization_params`` into ``self._params``."""
        pb_quant = pb.quantization_params
        self.get_pb_field(pb_quant, "quantization_groups")
        self.get_pb_field(pb_quant, "bias_mode", pb_wrapper.BIAS_MODE_PB_TO_TYPE)
        self.get_pb_field(pb_quant, "signed_output")
        self.get_pb_field(pb_quant, "quantization_weight_groups")
        self._params["precision_mode"] = pb_wrapper.LoadPrecisionMode(pb_quant.precision_mode)

    def keys(self):
        """Return a view of the stored parameter names."""
        return self._params.keys()

    def to_precision_config(self):
        """Build a ``LayerPrecisionConfig`` from the stored params (passed as keyword arguments)."""
        return LayerPrecisionConfig(**self._params)


class DefuseParams(HNLayersParams):
    """Per-layer defuse parameters (defuse type, feature/width slicing, shapes)."""

    DEFAULT_PARAMS = {
        "defuse_type": DefuseType.none,
        "defuse_features": 0,
        "defuse_input_width": 0,
        "defuse_original_features": 0,
        "defuse_features_offset": 0,
        "defuse_width_offset": 0,
        "defuse_name": "",
        "defuse_input_shapes": None,
        "defuse_output_shapes": None,
        "feature_split": False,
        "defuse_types": None,
        "defuse_ew_add_input_width": 0,
        "defuse_ew_add_width_offset": 0,
        "defuse_input_features": 0,
        "defuse_groups": 0,
    }
    ALLOWED_VALUES = {"defuse_type": allowed_str_for_enumerate(DefuseType)}
    VALUE_CONVERTERS = {
        "defuse_type": convert_str_to_enumerate(DefuseType),
        "defuse_features": convert_to_int,
        "defuse_input_width": convert_to_int,
        "defuse_original_features": convert_to_int,
        "defuse_features_offset": convert_to_int,
        "defuse_width_offset": convert_to_int,
        "defuse_ew_add_input_width": convert_to_int,
        "defuse_ew_add_width_offset": convert_to_int,
        "defuse_input_features": convert_to_int,
        "defuse_groups": convert_to_int,
        "defuse_name": lambda val: val,
    }

    def set(self, input_dict):
        """Set params from ``input_dict`` (not mutated), coercing ``defuse_type`` via ``DefuseType(...)``."""
        defuse_params = copy.deepcopy(input_dict)
        if "defuse_type" in input_dict:
            defuse_params["defuse_type"] = DefuseType(input_dict["defuse_type"])
        super().set(defuse_params)

    def set_defuse_name(self, name):
        """Set the ``defuse_name`` param directly."""
        self._params["defuse_name"] = name

    def to_hn(self, hn_item, should_get_default_params=False):
        """Write the defuse params into ``hn_item["defuse_params"]``.

        Args:
            hn_item: The layer's HN dict; mutated in place.
            should_get_default_params: When true, write every current param.
                Otherwise write only the non-default params, and leave ``hn_item``
                untouched if there are none.

        ``defuse_type`` and each entry of ``defuse_types`` are written as ``.value``.
        """
        if should_get_default_params:
            # Shallow-copy so the enum -> value conversion below does not overwrite
            # the live values stored in self._params.
            defuse_params = copy.copy(self._params)
        else:
            defuse_params = self.non_default_params
            if not defuse_params:
                return
        hn_item["defuse_params"] = defuse_params

        # TODO: SDK-9156: add validation on defuse params
        if "defuse_type" in defuse_params:
            defuse_params["defuse_type"] = defuse_params["defuse_type"].value

        if defuse_params.get("defuse_types") is not None:
            defuse_params["defuse_types"] = [defuse_type.value for defuse_type in defuse_params["defuse_types"]]

    def to_pb(self, pb, pb_wrapper):
        """Serialize the scalar defuse params into ``pb.defuse_params``."""
        # NOTE(review): from_pb() reads defuse_input_shapes, defuse_output_shapes,
        # defuse_types and feature_split, which are not written here - presumably
        # populated elsewhere; confirm.
        self.set_pb_field(pb.defuse_params, "defuse_type", pb_wrapper.DEFUSE_TYPE_TO_PB)
        self.set_pb_field(pb.defuse_params, "defuse_features")
        self.set_pb_field(pb.defuse_params, "defuse_input_width")
        self.set_pb_field(pb.defuse_params, "defuse_original_features")
        self.set_pb_field(pb.defuse_params, "defuse_features_offset")
        self.set_pb_field(pb.defuse_params, "defuse_width_offset")
        self.set_pb_field(pb.defuse_params, "defuse_ew_add_input_width")
        self.set_pb_field(pb.defuse_params, "defuse_ew_add_width_offset")
        self.set_pb_field(pb.defuse_params, "defuse_name")
        self.set_pb_field(pb.defuse_params, "defuse_input_features")
        self.set_pb_field(pb.defuse_params, "defuse_groups")

    def from_pb(self, pb, pb_wrapper):
        """Load the defuse params from ``pb.defuse_params``; no-op when that field is absent.

        Input/output shapes are stored as ``[height, width, features]`` lists, and
        ``defuse_types`` is mapped through ``pb_wrapper.DEFUSE_PB_TO_TYPE``.
        """
        if not pb.HasField("defuse_params"):
            return
        defuse_params = pb.defuse_params
        self.get_pb_field(defuse_params, "defuse_name")
        self.get_pb_field(defuse_params, "defuse_type", pb_wrapper.DEFUSE_PB_TO_TYPE)
        self.get_pb_field(defuse_params, "defuse_features_offset")
        self.get_pb_field(defuse_params, "defuse_width_offset")
        self.get_pb_field(defuse_params, "defuse_features")
        self.get_pb_field(defuse_params, "defuse_input_width")
        self.get_pb_field(defuse_params, "defuse_original_features")
        self.get_pb_field(defuse_params, "defuse_ew_add_input_width")
        self.get_pb_field(defuse_params, "defuse_ew_add_width_offset")
        self.get_pb_field(defuse_params, "defuse_input_features")
        self.get_pb_field(defuse_params, "defuse_groups")
        defuse_input_shapes = getattr(defuse_params, "defuse_input_shapes", None)
        if defuse_input_shapes:
            self._params["defuse_input_shapes"] = [[dis.height, dis.width, dis.features] for dis in defuse_input_shapes]
        defuse_output_shapes = getattr(defuse_params, "defuse_output_shapes", None)
        if defuse_output_shapes:
            self._params["defuse_output_shapes"] = [
                [dos.height, dos.width, dos.features] for dos in defuse_output_shapes
            ]
        defuse_types = getattr(defuse_params, "defuse_types", None)
        if defuse_types:
            self._params["defuse_types"] = [pb_wrapper.DEFUSE_PB_TO_TYPE[i] for i in defuse_types]
        self.get_pb_field(defuse_params, "feature_split")
