from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerWeightsClippingConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    BiasMode,
    DataPath,
    OptimizationTarget,
    PrecisionMode,
)


class BaseHailoSingleAtomic(BaseHailoLayer):
    """Base class for a composite layer that encapsulates exactly one atomic op.

    The layer flow is a straight line: layer input -> ``atomic_op`` -> layer output.
    Weight import/export, HW-param creation and IO encoding are forwarded to the
    wrapped op. The internal-encoding hooks are intentionally no-ops.
    """

    # The meaningful choice is the activation width (8 or 16 bit); the weight width
    # presumably does not matter for this layer.
    # NOTE(review): the original author was unsure whether the weight width is irrelevant; confirm.
    SUPPORTED_PRECISION_MODE = {
        PrecisionMode.a8_w8,
        PrecisionMode.a16_w16,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a16_w16_a16,
    }
    # A single atomic op is not expected to carry a bias, so the exact bias mode is
    # irrelevant; this is simply the one value advertised.
    SUPPORTED_BIAS_MODE = {
        BiasMode.single_scale_decomposition,
    }
    SUPPORTED_QUANTIZATION_GROUPS = False

    # TODO(SDK-23646): this layer is not fully ready yet; it should have proper lossy handling.
    def __init__(
        self,
        name: str,
        core_op: BaseAtomicOp,
        logger=None,
        **kwargs,
    ):
        """Wrap ``core_op`` as a single-op layer.

        Args:
            name: Layer name, forwarded to the base layer.
            core_op: The atomic op this layer encapsulates.
            logger: Optional logger, forwarded to the base layer.
            **kwargs: Extra arguments forwarded unchanged to ``BaseHailoLayer``.
        """
        # Set before the base constructor runs. Presumably BaseHailoLayer.__init__
        # calls _build_flow(), which reads this attribute; verify before reordering.
        self.atomic_op = core_op
        super().__init__(name=name, logger=logger, **kwargs)

    def _build_flow(self) -> LayerFlow:
        """Build and return the linear flow ``input -> atomic_op -> output``."""
        op = self.atomic_op
        flow = LayerFlow()
        src = flow.add_input()
        dst = flow.add_output()
        flow.add_node(op)
        flow.add_edge(src, op, DataPath.LAYER_IN)
        flow.add_edge(op, dst, DataPath.LAYER_OUT)
        return flow

    def import_weights(self, *args, **kwargs):
        """Forward weight import, with all arguments, to the wrapped atomic op."""
        self.atomic_op.import_weights(*args, **kwargs)

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """Create HW params on the wrapped op for ``optimization_target``.

        ``weights_clipping`` and ``hw_shifts`` are accepted to match the layer
        interface but are ignored by this layer.
        """
        target = optimization_target
        self.atomic_op.create_hw_params(optimization_target=target)

    def enforce_io_encoding(self, training=False, **kwargs):
        """Delegate encoding enforcement to the wrapped op.

        ``training`` and ``kwargs`` are ignored.
        """
        self.atomic_op.enforce_encoding()

    def _export_weights(self):
        """Return whatever the wrapped atomic op exports as its weights."""
        return self.atomic_op.export_weights()

    def enforce_internal_encoding(self, training=False, **kwargs):
        """Intentionally a no-op (defined explicitly rather than inherited)."""

    def fast_enforce_internal_encoding(self, **kwargs):
        """Intentionally a no-op."""

    def _get_precision_mode_supported_in_hw(self, arch):
        """Return the precision modes supported in HW for this layer on ``arch``.

        MERCURY, SAGE and PLUTO get a fixed set of four 8-bit and 16-bit modes.
        Any other target falls back to the base-class answer.
        """
        if arch not in {OptimizationTarget.MERCURY, OptimizationTarget.SAGE, OptimizationTarget.PLUTO}:
            return super()._get_precision_mode_supported_in_hw(arch)
        return {
            PrecisionMode.a8_w8,
            PrecisionMode.a16_w16,
            PrecisionMode.a8_w8_a8,
            PrecisionMode.a16_w16_a16,
        }
