#!/usr/bin/env python

from hailo_sdk_common.hailo_nn.hn_definitions import (
    EnableLcuFromSequencerPolicy,
    KoPolicy,
    ParamsLoadTimeCompressionPolicy,
    ShouldAlignCcwsSection,
    UseSequencerPolicy,
    policies_dict,
)
from hailo_sdk_common.hailo_nn.tools_params import (
    AutoInt,
    ToolsParams,
    allowed_str_for_bool,
    allowed_str_for_enumerate,
    convert_str_to_bool,
    convert_str_to_enumerate,
    convert_to_auto_int,
)


class HefParams(ToolsParams):
    """Tool parameters controlling HEF generation.

    The class-level tables declare the following for each parameter name:

    * its default value (``DEFAULT_PARAMS``);
    * the converter applied to a supplied value (``VALUE_CONVERTERS``);
    * the accepted string values (``ALLOWED_VALUES``);
    * for some enumerated parameters, a policy mapping (``POLICY_TYPES``).

    These tables are presumably consumed by the ``ToolsParams`` base class;
    confirm against its implementation. ``to_pb`` and ``from_pb`` copy the
    parameters into and out of a protobuf message.
    """

    # Default value of every parameter. The integer-valued params
    # (cfg_channel_count, num_preliminary_groups, dma_engine_count) default to
    # AutoInt("automatic").
    DEFAULT_PARAMS = {
        "should_use_ccw": True,
        "should_use_sequencer": UseSequencerPolicy.allowed,
        "enable_lcu_from_sequencer": EnableLcuFromSequencerPolicy.enabled,
        "should_use_sequencer_interleave": True,
        "should_use_sequencer_l2_interleave": True,
        "should_compress_sequencer_data": True,
        "should_prioritize_sequencer_over_l3": True,
        "should_run_abstract_validation": True,
        "should_run_full_validation": True,
        "posted_writes": False,
        "params_load_time_compression": ParamsLoadTimeCompressionPolicy.allowed,
        "cfg_channel_count": AutoInt("automatic"),
        "enable_axis_upsize_workaround": True,
        "dump_debug_params": False,
        "dump_debug_hef_per_context": False,
        "delay_context_switch": False,
        "should_use_confifo": True,
        "enable_ko": KoPolicy.ALLOWED,
        "num_preliminary_groups": AutoInt("automatic"),
        "dma_engine_count": AutoInt("automatic"),
        "should_align_ccws_section": ShouldAlignCcwsSection.allowed,
        "should_padd_channels": True,
        "strip_mapping_info": False,
    }
    # Converter applied to a supplied value of each parameter: bool, enum,
    # or AutoInt.
    VALUE_CONVERTERS = {
        "should_use_ccw": convert_str_to_bool,
        "should_use_sequencer_interleave": convert_str_to_bool,
        "should_use_sequencer_l2_interleave": convert_str_to_bool,
        "should_compress_sequencer_data": convert_str_to_bool,
        "should_prioritize_sequencer_over_l3": convert_str_to_bool,
        "should_use_sequencer": convert_str_to_enumerate(UseSequencerPolicy),
        "enable_lcu_from_sequencer": convert_str_to_enumerate(EnableLcuFromSequencerPolicy),
        "should_run_abstract_validation": convert_str_to_bool,
        "should_run_full_validation": convert_str_to_bool,
        "posted_writes": convert_str_to_bool,
        "params_load_time_compression": convert_str_to_enumerate(ParamsLoadTimeCompressionPolicy),
        "cfg_channel_count": convert_to_auto_int,
        "enable_axis_upsize_workaround": convert_str_to_bool,
        "dump_debug_params": convert_str_to_bool,
        "dump_debug_hef_per_context": convert_str_to_bool,
        "delay_context_switch": convert_str_to_bool,
        "should_use_confifo": convert_str_to_bool,
        "enable_ko": convert_str_to_enumerate(KoPolicy),
        "num_preliminary_groups": convert_to_auto_int,
        "dma_engine_count": convert_to_auto_int,
        "should_align_ccws_section": convert_str_to_enumerate(ShouldAlignCcwsSection),
        "should_padd_channels": convert_str_to_bool,
        "strip_mapping_info": convert_str_to_bool,
    }
    # Policy mappings for a subset of the enumerated parameters.
    # NOTE(review): should_align_ccws_section is also enumerated (see
    # VALUE_CONVERTERS) but has no entry here — confirm this is intentional.
    POLICY_TYPES = {
        "should_use_sequencer": policies_dict(UseSequencerPolicy),
        "enable_lcu_from_sequencer": policies_dict(EnableLcuFromSequencerPolicy),
        "params_load_time_compression": policies_dict(ParamsLoadTimeCompressionPolicy),
        "enable_ko": policies_dict(KoPolicy),
    }
    # Accepted string values per parameter. The AutoInt params have no entry;
    # presumably convert_to_auto_int validates them — confirm.
    ALLOWED_VALUES = {
        "should_use_ccw": allowed_str_for_bool(),
        "should_use_sequencer_interleave": allowed_str_for_bool(),
        "should_use_sequencer_l2_interleave": allowed_str_for_bool(),
        "should_compress_sequencer_data": allowed_str_for_bool(),
        "should_prioritize_sequencer_over_l3": allowed_str_for_bool(),
        "should_use_sequencer": allowed_str_for_enumerate(UseSequencerPolicy),
        "enable_lcu_from_sequencer": allowed_str_for_enumerate(EnableLcuFromSequencerPolicy),
        "should_run_abstract_validation": allowed_str_for_bool(),
        "should_run_full_validation": allowed_str_for_bool(),
        "posted_writes": allowed_str_for_bool(),
        "params_load_time_compression": allowed_str_for_enumerate(ParamsLoadTimeCompressionPolicy),
        "enable_axis_upsize_workaround": allowed_str_for_bool(),
        "dump_debug_params": allowed_str_for_bool(),
        "dump_debug_hef_per_context": allowed_str_for_bool(),
        "delay_context_switch": allowed_str_for_bool(),
        "should_use_confifo": allowed_str_for_bool(),
        "enable_ko": allowed_str_for_enumerate(KoPolicy),
        "should_align_ccws_section": allowed_str_for_enumerate(ShouldAlignCcwsSection),
        "should_padd_channels": allowed_str_for_bool(),
        "strip_mapping_info": allowed_str_for_bool(),
    }
    # Backward-compatible alias: old parameter name -> current parameter name.
    BACKWARD_COMP_PARAMS = {"params_compression": "params_load_time_compression"}
    # Parameter names that are no longer supported; handling is up to ToolsParams.
    DEPRECATED_PARAMS = {"should_force_bias_interleave"}

    # Protobuf (de)serialization order shared by to_pb and from_pb.
    # Each entry is (param_name, to_pb_table, from_pb_table). The two table
    # fields name attributes of ``pb_wrapper`` that translate enum values. They
    # are None for fields copied without translation.
    _PB_FIELD_SPECS = (
        ("should_use_ccw", None, None),
        (
            "should_use_sequencer",
            "USE_SEQUENCER_POLICY_TYPE_TO_PB",
            "USE_SEQUENCER_POLICY_PB_TO_TYPE",
        ),
        (
            "enable_lcu_from_sequencer",
            "ENABLE_LCU_FROM_SEQUENCER_POLICY_TYPE_TO_PB",
            "ENABLE_LCU_FROM_SEQUENCER_POLICY_PB_TO_TYPE",
        ),
        ("should_run_abstract_validation", None, None),
        ("should_run_full_validation", None, None),
        ("posted_writes", None, None),
        (
            "params_load_time_compression",
            "PARAMS_LOAD_TIME_COMPRESSION_POLICY_TYPE_TO_PB",
            "PARAMS_LOAD_TIME_COMPRESSION_POLICY_PB_TO_TYPE",
        ),
        ("cfg_channel_count", None, None),
        ("should_use_sequencer_interleave", None, None),
        ("should_use_sequencer_l2_interleave", None, None),
        ("should_compress_sequencer_data", None, None),
        ("should_prioritize_sequencer_over_l3", None, None),
        ("enable_axis_upsize_workaround", None, None),
        ("dump_debug_params", None, None),
        ("dump_debug_hef_per_context", None, None),
        ("delay_context_switch", None, None),
        ("should_use_confifo", None, None),
        ("enable_ko", "KO_POLICY_TYPE_TO_PB", "KO_POLICY_PB_TO_TYPE"),
        ("num_preliminary_groups", None, None),
        ("dma_engine_count", None, None),
        (
            "should_align_ccws_section",
            "ALIGN_CCWS_SECTION_POLICY_TYPE_TO_PB",
            "ALIGN_CCWS_SECTION_POLICY_PB_TO_TYPE",
        ),
        ("should_padd_channels", None, None),
        ("strip_mapping_info", None, None),
    )
    # AutoInt-valued fields; these go through set_auto_pb_field /
    # get_auto_pb_field rather than set_pb_field / get_pb_field.
    _AUTO_INT_PB_FIELDS = frozenset({"cfg_channel_count", "num_preliminary_groups", "dma_engine_count"})

    def __init__(self):
        super().__init__()
        # Inherited helper; presumably seeds every param from DEFAULT_PARAMS.
        self.set_default_params()

    def to_pb(self, pb, pb_wrapper):
        """Write every HEF parameter into the protobuf message ``pb``.

        Args:
            pb: Destination protobuf message.
            pb_wrapper: Provides the ``*_TYPE_TO_PB`` enum translation tables
                used by enumerated fields. It is also passed through to
                ``set_auto_pb_field``.
        """
        for field_name, to_pb_table, _ in self._PB_FIELD_SPECS:
            if field_name in self._AUTO_INT_PB_FIELDS:
                self.set_auto_pb_field(pb, pb_wrapper, field_name)
            elif to_pb_table is not None:
                # Resolve the translation table only right before it is used.
                self.set_pb_field(pb, field_name, getattr(pb_wrapper, to_pb_table))
            else:
                self.set_pb_field(pb, field_name)

    def from_pb(self, pb, pb_wrapper):
        """Read every HEF parameter back from the protobuf message ``pb``.

        Args:
            pb: Source protobuf message.
            pb_wrapper: Provides the ``*_PB_TO_TYPE`` enum translation tables
                used by enumerated fields.
        """
        for field_name, _, from_pb_table in self._PB_FIELD_SPECS:
            if field_name in self._AUTO_INT_PB_FIELDS:
                self.get_auto_pb_field(pb, field_name, AutoInt)
            elif from_pb_table is not None:
                # Resolve the translation table only right before it is used.
                self.get_pb_field(pb, field_name, getattr(pb_wrapper, from_pb_table))
            else:
                self.get_pb_field(pb, field_name)
