import copy

import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    EWMultType,
    LayerHandlerType,
    LayerSupportStatus,
)
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationType, FeatureMultiplierType, LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_activation import LayerWithActivation
from hailo_sdk_common.hailo_nn.hn_layers.nms import NMSLayer
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification
from hailo_sdk_common.hailo_nn.nms_postprocess_defaults import DEFAULT_YOLO_ANCHORS, DEFAULT_YOLO_CLASSES


class FeatureMultiplierLayer(LayerWithActivation):
    """HN layer that builds output features by multiplying input features together.

    The multiplication pattern is described by a *recipe* (see :meth:`convert_to_array`). The recipe is turned
    into a power table that states, for every output feature, the power applied to each input feature. The
    ``square`` type derives its output shapes from its input shapes (see :meth:`_calc_output_shape`). The
    ``user_specified`` and ``yolov5`` types keep the output shapes that were set on the layer.
    """

    # Per-type answers for ``requires_native_weights`` / ``requires_quantized_weights``; types missing from
    # these maps fall back to True with a warning.
    _REQUIRES_NATIVE_WEIGHTS = {
        FeatureMultiplierType.user_specified: True,
        FeatureMultiplierType.square: False,
        FeatureMultiplierType.yolov5: True,
    }
    _REQUIRES_QUANTIZED_WEIGHTS = {
        FeatureMultiplierType.user_specified: True,
        FeatureMultiplierType.square: True,
        FeatureMultiplierType.yolov5: True,
    }

    def __init__(self, recipe=None):
        # ``recipe`` is accepted for interface compatibility but is not used here; the power table is built
        # explicitly through ``init_power_table``.
        super().__init__()
        self._op = LayerType.feature_multiplier
        self._number_of_inputs_supported = 1
        self._feature_multiplier_type = FeatureMultiplierType.user_specified
        self._input_list = []
        self._power_table = None
        self._ew_mult_type = EWMultType.on_apu
        self._reduce_sum_groups = None
        self._dynamic_weights = False

    @staticmethod
    def convert_to_array(recipe, input_features):
        """
        Convert a recipe into a 2D array of powers.

        The recipe is a list with one entry per output. Each entry is a list of tuples, one tuple per output
        feature. A tuple lists the (0-based) input feature indices whose product is stored at that tuple's
        position. A repeated index multiplies that feature by itself.
            For example, the recipe [[(0, 1), (2, 2)], [(0, 0), (2,)]]
            will store feature_0 x feature_1 in feature_0 of the first output,
            store feature_2 x feature_2 in feature_1 of the first output,
            store feature_0 x feature_0 in feature_0 of the second output,
            and store feature_2 in feature_1 of the second output.

        The function returns a float32 np.array with one row per output feature, where all outputs are
        concatenated in recipe order. Each row holds the power applied to each of the ``input_features`` input
        features. For the example above (with input_features=3), the result is:
        [
            # first output:
            [1, 1, 0], # (0, 1)
            [0, 0, 2], # (2, 2)
            # second output:
            [2, 0, 0], # (0, 0)
            [0, 0, 1], # (2, )
        ]
        """
        tables = []
        for output in recipe:
            power_table = np.full([len(output), input_features], fill_value=0, dtype="float32")
            for row_index, features in enumerate(output):
                for feature in features:
                    # Each occurrence of an index raises that feature's power by one.
                    power_table[row_index, feature] += 1
            tables.append(power_table)

        return np.concatenate(tables)

    @property
    def output_features(self):
        """Total number of output features (sum of the last dimension of all output shapes)."""
        return sum([shape[3] for shape in self._output_shapes])

    @property
    def feature_multiplier_type(self):
        return self._feature_multiplier_type

    @feature_multiplier_type.setter
    def feature_multiplier_type(self, new_type):
        self._feature_multiplier_type = new_type

    @property
    def power_table(self):
        return self._power_table

    @power_table.setter
    def power_table(self, new_power_table):
        self._power_table = new_power_table

    def validate_recipe(self, recipe):
        """Check that ``recipe`` has one entry per output and one tuple per feature of that output.

        Raises:
            Exception: if the recipe does not match the layer's output shapes.
        """
        if len(recipe) != len(self.output_shapes):
            raise Exception(
                f"Recipe must contain same number of elements as number of output shapes at {self.full_name_msg}",
            )
        for output, output_shape in zip(recipe, self.output_shapes):
            if len(output) != output_shape[-1]:
                raise Exception(
                    f"Each output in recipe should contains same number of tuples as output shape features "
                    f"at {self.full_name_msg}",
                )

    @staticmethod
    def validate_table(table):
        """Validate a power table produced by :meth:`convert_to_array`.

        The following rules apply:
        - Every row may multiply at most two features in total. A power of 2 counts as two features.
        - Every row may only contain the powers 0, 1 or 2.
        - A table in which every row is a single pass-through feature (exactly one power of 1) is rejected.
          This includes an empty table.

        Raises:
            Exception: if any of the rules above is violated.
        """
        only_single_features = True
        for row in table:
            square_count = np.count_nonzero(row == 2)
            pass_count = np.count_nonzero(row == 1)
            zero_count = np.count_nonzero(row == 0)

            # Count multiplicands, not non-zero entries: a power of 2 uses its feature twice, so rows such as
            # [1, 2] or [2, 2] multiply three / four features and must be rejected as well.
            if pass_count + 2 * square_count > 2:
                raise Exception("Maximum two features can be multiplied at same output feature")

            if square_count + pass_count + zero_count != row.shape[0]:
                raise Exception("Invalid power in power table, no one of [0, 1, 2]")

            if pass_count != 1:
                only_single_features = False

        if only_single_features:
            raise Exception("Table Cannot contain only output features constructed from single feature multiplication")

    def init_power_table(self, recipe):
        """Build, validate and store the power table for ``recipe``.

        The table width is the output channel count for the ``square`` type and the input channel count for
        other types. If ``recipe`` is empty or None, nothing is rebuilt.

        Returns:
            The layer's current power table. This is None if none was ever set.
        """
        if recipe:
            self.validate_recipe(recipe)
            shape = (
                self.output_shape[-1]
                if self._feature_multiplier_type == FeatureMultiplierType.square
                else self.input_shape[-1]
            )
            power_table = FeatureMultiplierLayer.convert_to_array(recipe, shape)
            self.validate_table(power_table)
            self._power_table = power_table
        # Return the stored table so a falsy recipe no longer hits an unbound local variable.
        return self._power_table

    def sort_outputs(self):
        """Return a cmp-style comparator that orders layers by their position in ``self.outputs``.

        The comparator never returns 0; it returns 1 when ``layer1`` comes later and -1 otherwise.
        """
        return lambda layer1, layer2: 1 if self.outputs.index(layer1.name) > self.outputs.index(layer2.name) else -1

    @classmethod
    def create(cls, original_name, input_vertex_order, feature_multiplier_type=None, output_shapes=None):
        """Create the layer through the base factory, optionally overriding the feature multiplier type."""
        layer = super().create(original_name, input_vertex_order, output_shapes)
        if feature_multiplier_type:
            layer.feature_multiplier_type = feature_multiplier_type
        return layer

    def set_input_shapes(self, input_shapes, validate=True):
        super().set_input_shapes(input_shapes, validate)

    @property
    def input_list(self):
        return self._input_list

    @input_list.setter
    def input_list(self, input_list):
        self._input_list = input_list

    @property
    def dynamic_weights(self):
        return self._dynamic_weights

    @dynamic_weights.setter
    def dynamic_weights(self, dynamic_weights):
        self._dynamic_weights = dynamic_weights

    def append_to_input_list(self, inp):
        self._input_list.append(inp)

    def to_hn(self, should_get_default_params=False):
        """Serialize to an HN dict.

        The result adds activation, feature multiplier type and ew_mult_type to ``params``. It adds
        ``reduce_sum_groups`` only when that value is set.
        """
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        result["params"]["activation"] = self._activation.value
        result["params"]["feature_multiplier_type"] = self._feature_multiplier_type.value
        result["params"]["ew_mult_type"] = self._ew_mult_type.value
        if self._reduce_sum_groups is not None:
            result["params"]["reduce_sum_groups"] = self._reduce_sum_groups

        return result

    @classmethod
    def from_hn(cls, hn):
        """Deserialize from an HN dict.

        Raises:
            ValueError: if ``reduce_sum_groups`` is given while ``ew_mult_type`` is not ``on_mac``.
        """
        layer = super().from_hn(hn, validate_params_exist=False)

        if "params" in hn:
            if "activation" in hn["params"]:
                layer.activation = ActivationType(hn["params"]["activation"])
            if "feature_multiplier_type" in hn["params"]:
                layer.feature_multiplier_type = FeatureMultiplierType(hn["params"]["feature_multiplier_type"])
            layer._ew_mult_type = EWMultType(hn["params"].get("ew_mult_type", EWMultType.on_apu))
            layer._reduce_sum_groups = hn["params"].get("reduce_sum_groups", None)
            # Any backend other than the APU is treated as using dynamic weights.
            layer.dynamic_weights = layer._ew_mult_type != EWMultType.on_apu
            if layer._ew_mult_type != EWMultType.on_mac and layer._reduce_sum_groups is not None:
                raise ValueError("reduce sum groups is only supported if ew_mult_type is on_mac")
        return layer

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize to a protobuf node.

        ``reduce_sum_groups`` falls back to the first input's channel count when it is unset.
        """
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_FEATURE_MULTIPLIER
        node.activation = pb_wrapper.ACTIVATION_TYPE_TO_PB[self._activation]
        node.feature_multiplier_type = pb_wrapper.FEATURE_MULTIPLIER_TYPE_TYPE_TO_PB[self._feature_multiplier_type]
        node.reduce_sum_groups = (
            self._reduce_sum_groups if self._reduce_sum_groups is not None else self._input_shapes[0][-1]
        )
        return node

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        layer = super().from_pb(pb, pb_wrapper)
        layer._activation = pb_wrapper.ACTIVATION_PB_TO_TYPE[pb.activation]
        layer.feature_multiplier_type = pb_wrapper.FEATURE_MULTIPLIER_TYPE_PB_TO_TYPE[pb.feature_multiplier_type]
        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """Build a new layer from ``old_layer``, copying its feature-multiplier specific state."""
        layer = super().from_layer(old_layer)
        layer.activation = old_layer.activation
        layer.feature_multiplier_type = old_layer.feature_multiplier_type
        layer.power_table = old_layer.power_table
        # to_hn/to_pb serialize these settings, so dropping them would silently reset the copy to the
        # on_apu defaults. getattr keeps this tolerant of source layers that lack them.
        layer._ew_mult_type = getattr(old_layer, "_ew_mult_type", layer._ew_mult_type)
        layer._reduce_sum_groups = getattr(old_layer, "_reduce_sum_groups", layer._reduce_sum_groups)
        layer.dynamic_weights = getattr(old_layer, "_dynamic_weights", layer.dynamic_weights)
        return layer

    def update_output_shapes(self, **kwargs):
        self.output_shapes = self._calc_output_shape()

    def _get_output_shape(self, validate=True, layer_name=None, layer_index=None):
        """Return the output shape feeding ``layer_name``.

        The name is only needed when there are several outputs.

        Raises:
            UnsupportedModelError: if the layer has several outputs and ``layer_name`` is None.
        """
        if len(self._output_shapes) == 1:
            return self._output_shapes[0]
        if layer_name is None:
            raise UnsupportedModelError(f"{self.full_name_msg} successor name is missing, output shape is ambiguous")
        return self._output_shapes[self.outputs.index(layer_name)]

    @property
    def macs(self):
        # The /2 is because we don't do accumulate
        return self.ops / 2

    @property
    def ops(self):
        """Operation count.

        This is the element count of the input for the ``square`` type and 0 for other types.
        """
        if self._feature_multiplier_type == FeatureMultiplierType.square:
            # We square each value once. np.abs presumably cancels a -1 placeholder dimension (e.g. an unknown
            # batch size) in input_shape -- TODO confirm.
            return float(np.abs(np.prod(np.array(self.input_shape))))
        return 0

    def _calc_output_shape(self):
        """Compute the list of output shapes.

        For the ``square`` type the outputs mirror the input shapes. When ``reduce_sum_groups`` is set, the
        last (channel) dimension is replaced by it. Other types keep the current output shapes. When a non-zero
        ``defuse_input_width`` defuse parameter is present, index 2 of every shape is overridden with it. Index 2
        is presumably the width in an NHWC-style shape.
        """
        if self._feature_multiplier_type == FeatureMultiplierType.square:
            output_shapes = self.input_shapes
            if self._reduce_sum_groups is not None:
                output_shapes = [[*shape[:3], self._reduce_sum_groups] for shape in output_shapes]
        else:
            output_shapes = self.output_shapes

        if "defuse_input_width" in self.defuse_params and self.defuse_input_width != 0:
            # The result is a *list* of shapes, so the width must be patched inside each shape. Indexing the list
            # directly replaced a whole shape or raised IndexError. New lists are built so the input shapes are
            # never mutated in place.
            output_shapes = [[*shape[:2], self.defuse_input_width, *shape[3:]] for shape in output_shapes]
        return output_shapes

    def get_equalization_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.transparent, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def ibc_supported(self):
        return LayerSupportStatus.unsupported

    def user_specified(self):
        """Recipe for the ``user_specified`` type; this layer provides an empty one."""
        return []

    def yolov5(self, number_of_anchors=DEFAULT_YOLO_ANCHORS, number_of_classes=DEFAULT_YOLO_CLASSES):
        """Build the two-output recipe used by the ``yolov5`` type.

        Each anchor occupies a chunk of ``BBOX_PER_CHUNK + 1 + number_of_classes`` input features.

        The first output holds four entries per anchor. The first two pass through chunk features 1 and 0, in
        that order. The last two square chunk features 3 and 2, in that order.

        The second output holds one entry per (anchor, class) pair. Each entry multiplies the chunk feature at
        offset ``BBOX_PER_CHUNK`` by the class feature ``BBOX_PER_CHUNK + 1 + j``. According to the original
        comment, the feature at offset ``BBOX_PER_CHUNK`` is the objectness score.
        """

        def _build_feature_power_table_scores(number_of_anchors, number_of_probs):
            scores = [()] * number_of_anchors * number_of_probs
            skip_count = number_of_probs + NMSLayer.BBOX_PER_CHUNK + 1  # Obj not count in BBOX_PER_CHUNK

            for i in range(number_of_anchors):
                for j in range(number_of_probs):
                    scores[(i * number_of_probs) + j] = (
                        skip_count * i + NMSLayer.BBOX_PER_CHUNK,
                        skip_count * i + NMSLayer.BBOX_PER_CHUNK + 1 + j,
                    )

            return scores

        def _build_feature_power_table_centers_and_scales(number_of_anchors, number_of_probs):
            # NOTE(review): only offsets 0..3 of each chunk are written, so this assumes BBOX_PER_CHUNK == 4.
            centers_and_scales = [()] * NMSLayer.BBOX_PER_CHUNK * number_of_anchors
            skip_count = number_of_probs + NMSLayer.BBOX_PER_CHUNK + 1  # Obj not count in BBOX_PER_CHUNK

            for i in range(number_of_anchors):
                centers_and_scales[0 + NMSLayer.BBOX_PER_CHUNK * i] = (1 + skip_count * i,)
                centers_and_scales[1 + NMSLayer.BBOX_PER_CHUNK * i] = (0 + skip_count * i,)
                centers_and_scales[2 + NMSLayer.BBOX_PER_CHUNK * i] = (3 + skip_count * i, 3 + skip_count * i)
                centers_and_scales[3 + NMSLayer.BBOX_PER_CHUNK * i] = (2 + skip_count * i, 2 + skip_count * i)

            return centers_and_scales

        return [
            _build_feature_power_table_centers_and_scales(number_of_anchors, number_of_classes),
            _build_feature_power_table_scores(number_of_anchors, number_of_classes),
        ]

    @property
    def hn_name(self):
        """HN name of the layer.

        This is the type value for the ``square`` type and the base-class name otherwise.
        """
        if self.feature_multiplier_type == FeatureMultiplierType.square:
            return self.feature_multiplier_type.value
        return super().hn_name

    @property
    def finetune_supported(self):
        return False

    @property
    def requires_native_weights(self):
        """Whether this layer type needs native weights; unknown types warn and default to True."""
        if self.feature_multiplier_type not in self._REQUIRES_NATIVE_WEIGHTS:
            self._logger.warning(
                f"Layer {self.name} of feature multiplier type {self.feature_multiplier_type.value} does not specify"
                f" whether native weights are required. Assuming True.",
            )
            return True

        return self._REQUIRES_NATIVE_WEIGHTS[self.feature_multiplier_type]

    @property
    def requires_quantized_weights(self):
        """Whether this layer type needs quantized weights; unknown types warn and default to True."""
        if self.feature_multiplier_type not in self._REQUIRES_QUANTIZED_WEIGHTS:
            self._logger.warning(
                f"Layer {self.name} of feature multiplier type {self.feature_multiplier_type.value} does not specify"
                f" whether quantized weights are required. Assuming True.",
            )
            return True

        return self._REQUIRES_QUANTIZED_WEIGHTS[self.feature_multiplier_type]
