import numpy as np

from hailo_model_optimization.saitama.translators.hailo_translator.mappings.common_unit_mapping import CommonMappings
from hailo_model_optimization.saitama.translators.translator_utils import KeyHandler


def get_ew_mult_on_apu_mapping(tags=None):
    """Build the key mappings for an element-wise multiply executed on the APU.

    Each of the two operands gets its own bias mapping from
    ``CommonMappings.get_bias_mapping``. No weight is given, and the shape is
    taken from ``act_op/input_scale:0:0``. Handlers for the APU shift and for
    the accumulator quantizer's scale and zero point follow.

    Args:
        tags: Forwarded unchanged to every generated mapping / handler.

    Returns:
        list: The bias entries for ``mac.bias.0`` (op ``bias_add_op_a``),
        then for ``mac.bias.1`` (op ``bias_add_op_b``), followed by the
        ``mac.apu_shift``, ``mac.accumulator_quantizer.scale`` and
        ``mac.accumulator_quantizer.zero_point`` handlers, in that order.
    """
    input_scale_key = "act_op/input_scale:0:0"
    # (bias key, bias-add op name): one entry per multiplicand, in order.
    operand_biases = (
        ("mac.bias.0", "bias_add_op_a"),
        ("mac.bias.1", "bias_add_op_b"),
    )

    mappings = []
    for bias_key, bias_op in operand_biases:
        mappings.extend(
            CommonMappings.get_bias_mapping(
                bias_key,
                tags=tags,
                weight=None,
                shape_key=input_scale_key,
                op_name=bias_op,
            )
        )
    mappings.append(KeyHandler("mac.apu_shift", "elementwise_mult_op/mult_shift:0", tags=tags))
    # The accumulator scale reuses the same source key as the bias shapes above.
    mappings.append(KeyHandler("mac.accumulator_quantizer.scale", input_scale_key, tags=tags))
    mappings.append(
        KeyHandler("mac.accumulator_quantizer.zero_point", "act_op/input_zero_point:0:0", tags=tags)
    )
    return mappings


def get_ew_mult_on_mac_mapping(tags=None):
    """Build the key mappings for an element-wise multiply executed on the MAC.

    Args:
        tags: Forwarded unchanged to every generated handler / mapping.

    Returns:
        list: The ``mac.accumulator_quantizer.scale`` handler, then the
        ``mac.accumulator_quantizer.zero_point`` handler, then the entries
        produced by ``CommonMappings.get_bias_mapping("mac.bias")``.
    """

    def _scale_shaped_like_bias(x, y, **kw):
        # Adding a zero array shaped like ``y`` (the ``bias:0`` value)
        # broadcasts ``x`` (the bias-add output scale, presumably a scalar)
        # to the bias shape. Parameter names are kept as x/y in case
        # KeyHandler passes them by keyword.
        return x + np.zeros_like(y)

    scale_handler = KeyHandler(
        "mac.accumulator_quantizer.scale",
        ["bias_add_op/output_scale:0:0", "bias:0"],
        _scale_shaped_like_bias,
        tags=tags,
    )
    zero_point_handler = KeyHandler(
        "mac.accumulator_quantizer.zero_point",
        "bias_add_op/output_zero_point:0:0",
        tags=tags,
    )
    bias_entries = CommonMappings.get_bias_mapping("mac.bias", tags=tags)
    return [scale_handler, zero_point_handler, *bias_entries]
