from past.utils import old_div

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers import ConcatLayer
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification
from hailo_sdk_common.numeric_utils import numeric_utils


class OutputMuxLayer(ConcatLayer):
    """Concat-derived layer whose op type is ``LayerType.output_mux``.

    The calculated output shape is flat: ``[batch, height, width, 1]``, where
    ``height`` is the GCD of the predecessors' output heights and ``width``
    is the sum of the per-predecessor widths from :meth:`_calc_width`.

    Equalization, params sorting, dead-channels removal, IBC and finetune
    are all reported as unsupported for this layer.
    """

    # Each predecessor's shape[2] is aligned to this value (through
    # ``numeric_utils.align_to``) before it contributes to the muxed width.
    MEMORY_ALIGNMENT = 8

    def __init__(self):
        super().__init__()
        self._op = LayerType.output_mux

    @staticmethod
    def _calc_width(shape, h_ratios, index):
        """Return the width contribution of a single predecessor.

        Args:
            shape: Output shape of the predecessor; entries 2 and 3 are read.
            h_ratios: Per-predecessor ratios of its height to the mux height.
            index: Position of this predecessor inside ``h_ratios``.

        Returns:
            ``align_to(shape[2], MEMORY_ALIGNMENT) * shape[3] * h_ratios[index]``.
        """
        aligned_dim = numeric_utils.align_to(shape[2], OutputMuxLayer.MEMORY_ALIGNMENT)
        return aligned_dim * shape[3] * h_ratios[index]

    def _calc_output_shape(self):
        """Compute the flat ``[batch, height, width, 1]`` output shape.

        Batch is taken from the first predecessor. Height is the GCD of all
        predecessors' heights (entry 1 of their output shapes). Width sums
        the :meth:`_calc_width` result over every predecessor.
        """
        predecessors = self._input_list
        batch = self.pred_layer_output_shape(predecessors[0])[0]

        # we take w0*f0 + w1*f1 + ...
        # FIXME this is not true. (consider padded_output_height)
        heights = [self.pred_layer_output_shape(pred)[1] for pred in predecessors]
        common_height = numeric_utils.get_gcd(*heights)

        # How many rows of the muxed output each predecessor's rows fold into.
        ratios = [old_div(self.pred_layer_output_shape(pred)[1], common_height) for pred in predecessors]

        total_width = 0
        for position, pred in enumerate(predecessors):
            pred_shape = self.pred_layer_output_shape(pred)
            total_width += self._calc_width(pred_shape, ratios, position)

        features = 1  # Flat
        return [batch, common_height, total_width, features]

    def _get_output_shape(self, validate=True, layer_name=None, layer_index=None):
        """Return the superclass output shape, always passing ``False`` to it.

        The ``validate``, ``layer_name`` and ``layer_index`` arguments are
        accepted but not forwarded.
        """
        return super()._get_output_shape(False)

    @property
    def input_height(self):
        """Input height: ``1`` when fed from a dense layer, else dim 1 of the input shapes."""
        if not self.is_from_dense():
            return self._get_shape_single_dim(self._input_shapes, 1, validate=False)
        return 1

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize through the superclass and tag the node as ``PROTO_NETWORK_OUTPUT_MUX``."""
        pb_node = super().to_pb(pb_wrapper, is_multi_scope)
        pb_node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_OUTPUT_MUX
        return pb_node

    # TODO: set value instead undefined
    def get_equalization_handler_type(self, predecessor=None):
        """Report equalization as unsupported (non-source); ``predecessor`` is unused."""
        handler = LayerHandlerType.unsupported
        return EquivClassification(handler, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Report params sorting as unsupported (non-source); ``predecessor`` is unused."""
        handler = LayerHandlerType.unsupported
        return EquivClassification(handler, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Report dead-channels removal as unsupported (non-source); ``predecessor`` is unused."""
        handler = LayerHandlerType.unsupported
        return EquivClassification(handler, is_source=False)

    def ibc_supported(self):
        """Return ``LayerSupportStatus.unsupported``."""
        return LayerSupportStatus.unsupported

    @property
    def finetune_supported(self):
        """Always ``False``."""
        return False
