from typing import List, Union

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.atomic_ops.reorder_op import ReorderOp
from hailo_model_optimization.acceleras.atomic_ops.slice_op import SliceOp
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_mult_on_mac import HailoElementwiseMultOnMac
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    DataPath,
    EquivClassification,
    FeatureMultiplierType,
    LayerHandlerType,
    LayerType,
)
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams

# YOLO-related defaults. They are not referenced in this part of the file;
# presumably consumed by other modules — TODO confirm before changing them.
DEFAULT_YOLO_ANCHORS: int = 3
DEFAULT_YOLO_CLASSES: int = 80
BBOX_PER_CHUNK: int = 4

# Name prefix of the per-output SliceOp instances, which are named "<layer>/out_slice_op_<index>".
OUT_SLICE_PREFIX: str = "out_slice_op"


class HailoFeatureMultiplierOnMac(HailoElementwiseMultOnMac):
    """
    Layer whose output features are products of (at most) two of its input features.

    The layer reuses the element-wise multiplication flow, but both operands come from the same
    input. The input is routed through two reorder ops (``_reorder1`` / ``_reorder2``). Their
    recipes select, for every output channel, which input channel feeds each side of the product.
    A recipe entry equal to ``features_in`` selects an extra "quantized one" slot. Thanks to that
    slot an output channel can also be a single input feature (``x * 1``) or the constant one.

    If ``feature_multiplier_type`` is ``square``, every input feature is multiplied by itself.
    If it is ``user_specified`` (or ``yolov5``), the products come from the ``power_table`` weight.
    ``power_table[out, in]`` is the exponent of input ``in`` in output ``out``. Each row may use at
    most two factors: entries in {0, 1, 2} and a row sum <= 2.

    The multiplied tensor is split into one output per entry of ``features_out`` by slice ops.

    Args:
        name - layer name
        features_in - number of input features
        features_out - number of output features, or a list with the feature count of each output
        activation - activation applied after the multiplication (and optional reduce-sum / bias)
        feature_multiplier_type - the type of multiplication to do, hence defining the power_table
        logger - optional logger forwarded to the internal ops
    """

    _hn_type = LayerType.FEATURE_MULTIPLIER

    def __init__(
        self,
        name: str,
        features_in: int,
        features_out: Union[int, List[int]],
        activation: Union[str, callable, ActivationType] = ActivationType.LINEAR,
        feature_multiplier_type=FeatureMultiplierType.user_specified,
        logger=None,
        **kwargs,
    ):
        if isinstance(features_out, int):
            features_out = [features_out]
        # The sub-ops are created before super().__init__(). Presumably the base class builds
        # the layer flow from them, so keep this order.
        self.input_pass_op = PassthruOp(name=f"{name}/in_passthru_op", logger=logger)
        self.feature_multiplier_type = feature_multiplier_type
        self._features_in = features_in
        if feature_multiplier_type == FeatureMultiplierType.square:
            # Squaring yields exactly one output feature per input feature.
            self._features_out = features_in
        else:
            self._features_out = sum(features_out)

        self._reorder1 = ReorderOp(name=f"{name}/reorder1_op", logger=logger)
        self._reorder2 = ReorderOp(name=f"{name}/reorder2_op", logger=logger)

        # One slice op per output. Consecutive outputs take consecutive feature ranges.
        self._output_slices: List[SliceOp] = []
        slice_start = 0
        for index, features_count in enumerate(features_out):
            slice_end = slice_start + features_count
            self._output_slices.append(
                SliceOp(
                    name=f"{name}/{OUT_SLICE_PREFIX}_{index}",
                    features_slice=(slice_start, slice_end, 1),
                    logger=logger,
                )
            )
            slice_start = slice_end
        super().__init__(name=name, activation=activation, logger=logger, **kwargs)

        self.encoding_const = True

    @classmethod
    def get_default_params(cls):
        """Return the default HN params of this layer (currently only a linear activation)."""
        # TODO: this is temporary solution until we have pydantic scheme
        defaults = {
            "activation": "linear",
        }
        return dict(defaults)

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        Build the layer from its HN description.

        ``features_in`` is taken from the last dim of the first input shape. The per-output
        feature counts are taken from the last dim of every output shape.
        """
        params = cls.get_default_params()
        features_in = hn_element["input_shapes"][0][-1]
        params.update(hn_element.get("params", {}))
        reduce_sum_groups = params.get("reduce_sum_groups", None)
        feature_multiplier_type = FeatureMultiplierType(params["feature_multiplier_type"])

        output_features = [shape[-1] for shape in hn_element["output_shapes"]]
        layer = cls(
            name=lname,
            features_in=features_in,
            features_out=output_features,
            activation=params["activation"],
            feature_multiplier_type=feature_multiplier_type,
            reduce_sum_groups=reduce_sum_groups,
            logger=logger,
        )
        layer.finalize_from_hn(hn_element)
        return layer

    def _build_flow(self) -> LayerFlow:
        """
        Wire the internal ops.

        input -> passthru -> (reorder1, reorder2) -> ew_mult -> reduce_sum -> bias_add
        -> act -> output_op -> one slice op per layer output.
        """
        layer_flow = LayerFlow()
        in1 = layer_flow.add_input()
        outputs = [layer_flow.add_output() for _ in self._output_slices]

        layer_flow.add_node(self.input_pass_op)
        layer_flow.add_node(self._reorder1)
        layer_flow.add_node(self._reorder2)

        layer_flow.add_node(self.ew_mult_op)
        layer_flow.add_node(self.bias_add_op)
        layer_flow.add_node(self.optional_reduce_sum_op)
        layer_flow.add_node(self.act_op)
        layer_flow.add_node(self.output_op)
        for slice_op in self._output_slices:
            layer_flow.add_node(slice_op)

        layer_flow.add_edge(in1, self.input_pass_op, DataPath.LAYER_IN)
        # The same input feeds both reorder ops; they become the two multiplication operands.
        layer_flow.add_edge(self.input_pass_op, self._reorder1, DataPath.LAYER_IN)
        layer_flow.add_edge(self.input_pass_op, self._reorder2, DataPath.LAYER_IN)
        layer_flow.add_edge(self._reorder1, self.ew_mult_op, DataPath.LAYER_IN, input_index=0)
        layer_flow.add_edge(self._reorder2, self.ew_mult_op, DataPath.LAYER_IN, input_index=1)
        layer_flow.add_edge(self.ew_mult_op, self.optional_reduce_sum_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.optional_reduce_sum_op, self.bias_add_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.bias_add_op, self.act_op, DataPath.ACCUMULATOR)
        layer_flow.add_edge(self.act_op, self.output_op, DataPath.LAYER_OUT)
        for slice_op, output in zip(self._output_slices, outputs):
            layer_flow.add_edge(self.output_op, slice_op, DataPath.LAYER_OUT)
            layer_flow.add_edge(slice_op, output, DataPath.LAYER_OUT)

        return layer_flow

    def _layer_dependent_hw_params_modifications(self, params):
        """Add the int8 ``power_table`` and the uint16 ``quantized_one`` to the HW params."""
        if self.feature_multiplier_type != FeatureMultiplierType.square:
            params["power_table"] = np.array(self._recreate_power_table(), np.int8)
        else:
            # Square mode does not need a real table; emit a minimal placeholder.
            params["power_table"] = np.array([[0]], np.int8)

        quantized_one = self._reorder2.get_quantized_one()
        params["quantized_one"] = np.array(quantized_one, dtype=np.uint16)
        return params

    def import_weights(self, layer_params: LayerParams):
        """Load the power table into the two reorder recipes, and load the activation weights."""
        power_table = self.get_power_table(layer_params)
        recipe1, recipe2 = self._prepare_reorder_matrix(power_table)
        self._reorder1.import_weights(recipe1)
        self._reorder2.import_weights(recipe2)
        self.act_op.import_weights(layer_params)

    def _get_post_mult_scale(self):
        """
        Return the per-output-channel product of the two reordered input scales.

        Utility for the DOF scale calculation; subclasses may override it.
        """
        in1, in2 = self._get_repeated_input_scales()
        # The "quantized one" slot (recipe index == features_in) uses input channel 0's scale.
        quantized_one_scale = self.input_scales[0][0]
        in1_extended = np.append(in1, quantized_one_scale)
        in2_extended = np.append(in2, quantized_one_scale)
        return tf.multiply(in1_extended[self._reorder1.recipe], in2_extended[self._reorder2.recipe])

    def _export_weights(self):
        """Export base and activation weights, plus the power table when not in square mode."""
        weights = super()._export_weights()
        weights.update(self.act_op.export_weights())
        if self.feature_multiplier_type != FeatureMultiplierType.square:
            weights["power_table"] = self._recreate_power_table()
        return weights

    def _recreate_power_table(self):
        """
        Rebuild the (features_out x features_in) power table from the two reorder recipes.

        Recipe entries equal to ``features_in`` point at the "quantized one" slot and add nothing.
        """
        recipe1 = self._reorder1.recipe
        recipe2 = self._reorder2.recipe
        power_table = np.zeros(shape=(self._features_out, self._features_in), dtype=int)
        for out_channel in range(self._features_out):
            channel_in1 = recipe1[out_channel]
            channel_in2 = recipe2[out_channel]
            # Skip factors that select the constant one rather than a real input channel.
            if channel_in1 < self._features_in:
                power_table[out_channel][channel_in1] += 1
            if channel_in2 < self._features_in:
                power_table[out_channel][channel_in2] += 1
        return power_table

    def get_power_table(self, layer_params):
        """
        Return the power table that describes the products to compute.

        Raises:
            ValueError: if ``feature_multiplier_type`` is not square, user_specified or yolov5.
        """
        if self.feature_multiplier_type == FeatureMultiplierType.square:
            # Output channel i is x_i * x_i.
            return np.identity(self._features_in, dtype="int") * 2
        if self.feature_multiplier_type in (FeatureMultiplierType.user_specified, FeatureMultiplierType.yolov5):
            return layer_params["power_table"]
        # Previously this fell through and returned None, which later failed with an unrelated IndexError.
        raise ValueError(f"Unsupported feature_multiplier_type: {self.feature_multiplier_type}")

    def _prepare_reorder_matrix(self, power_table):
        """
        Translate a power table into the recipes of the two reorder ops.

        For every output row, the first factor goes to ``recipe1`` and the second to ``recipe2``:
        - a 2 at column c gives (c, c);
        - two 1s at columns a < b give (a, b);
        - a single 1 at column a gives (a, one);
        - an empty row gives (one, one).
        Here "one" is ``self._features_in``, the index of the "quantized one" slot.

        Raises:
            ValueError: if the table holds values other than 0/1/2, or a row needs more than two
                factors. Such products cannot be expressed with one multiplication and used to be
                silently truncated.
        """
        power_table = np.asarray(power_table)
        if not np.all(np.isin(power_table, (0, 1, 2))):
            raise ValueError("Feature multiplier power_table entries must be 0, 1 or 2")
        if np.any(power_table.sum(axis=-1) > 2):
            raise ValueError("Feature multiplier power_table rows must sum to at most 2 (product of two features)")

        reorder1 = np.zeros(shape=(self._features_out, self._features_in), dtype=int)
        reorder2 = np.zeros(shape=(self._features_out, self._features_in), dtype=int)
        # Squared positions feed the same column to both operands.
        indexes = np.where(power_table == 2)
        reorder1[indexes] = 1
        reorder2[indexes] = 1

        # np.where walks row-major, so within a row the lower column goes to operand 1 first.
        indexes = np.where(power_table == 1)
        row_has_first_factor = np.zeros(self._features_out, dtype=bool)
        for row, col in zip(indexes[0], indexes[1]):
            if row_has_first_factor[row]:
                reorder2[row, col] = 1
            else:
                reorder1[row, col] = 1
                row_has_first_factor[row] = True

        recipe1, recipe2 = [], []
        for row1, row2 in zip(reorder1, reorder2):
            ind1 = np.where(row1 == 1)
            if ind1[0].shape[0] == 0:
                # Empty row: use the index of the constant one (the last index in the extended input).
                recipe1.append(self._features_in)
            else:
                recipe1.append(ind1[0][0])

            ind2 = np.where(row2 == 1)
            if ind2[0].shape[0] == 0:
                # Empty row: use the index of the constant one (the last index in the extended input).
                recipe2.append(self._features_in)
            else:
                recipe2.append(ind2[0][0])

        return recipe1, recipe2

    def create_output_encoding_candidates(self, forced_range=None, translation_config=None):
        """
        Create output encoding candidates from one range that covers every layer output.

        Unless ``forced_range`` is given, the range is the union of the limit values reported
        by ``self._output_stats_ops()``.
        """
        if forced_range is None:
            lim_min = np.inf
            lim_max = -np.inf
            for op, ind in self._output_stats_ops():
                curr_min, curr_max = op.get_output_limvals(ind)
                lim_max = max(lim_max, curr_max)
                lim_min = min(lim_min, curr_min)
            forced_range = (lim_min, lim_max)
        super().create_output_encoding_candidates(forced_range, translation_config=translation_config)

    def _force_output_scale(self):
        """When a forced scalar DOF is set, set output scale 0 to ``input_scale**2 * dof``."""
        # NOTE(review): `shape != 0` compares a shape with the int 0, which is unequal for any
        # tuple shape, so this condition effectively only checks the DOF. The intended check is
        # unclear (perhaps `!= ()`); left unchanged to preserve behavior.
        if self.forced_output_scale_scalar_dof is not None and self.output_scale.shape != 0:
            self.set_output_scale(self.input_scale**2 * self.forced_output_scale_scalar_dof, 0)

    def _get_output_scale_for_scalar_dof(self):
        """Return the output scale of the internal output op (the pre-slice result)."""
        return self.output_op.output_scale

    def enforce_io_encoding(self, training=False, **kwargs):
        """
        Derive the output scale from the input scale, then propagate it to the output slice ops.

        Output channel ``i`` gets ``scale[recipe1[i]] * scale[recipe2[i]] * output_scale_scalar_dof``.
        Here ``scale`` is the input scale extended with one extra entry for the "quantized one" slot.
        """
        input_scale = self.input_scales[0]
        if input_scale.shape == ():
            # Per-tensor scale: broadcast it to every input feature first. Recipe indices in
            # [0, features_in) then resolve to the input scale, and index features_in to the "one"
            # slot. Without the broadcast, indices >= 2 raised IndexError.
            # NOTE(review): the "one" slot uses 1 here, but input channel 0's scale in the
            # per-channel branch below and in _get_post_mult_scale — confirm which is intended.
            input_scales_extended = np.append(np.full(self._features_in, input_scale), 1).astype("float32")
        else:
            input_scales_extended = np.append(input_scale, input_scale[0]).astype("float32")
        self.output_op.output_scale = (
            input_scales_extended[self._reorder1.recipe]
            * input_scales_extended[self._reorder2.recipe]
            * self.output_scale_scalar_dof
        )
        for slice_op in self._output_slices:
            slice_op.input_zero_point = self.output_op.output_zero_point
            slice_op.input_scale = self.output_op.output_scale
            slice_op.enforce_encoding()

    @staticmethod
    def get_length(input_data):
        """
        Return the length of ``input_data``'s first dimension.

        For a ``tf.Tensor``, the static length is returned when known. Otherwise the dynamic
        length is returned as a scalar tensor (from ``tf.shape``).
        """
        if isinstance(input_data, tf.Tensor):
            static_shape = input_data.shape.as_list()
            if static_shape and static_shape[0] is not None:
                return static_shape[0]
            # Static length unknown: fall back to the dynamic length (a tensor).
            return tf.shape(input_data)[0]
        return len(input_data)

    def _enforce_output_encoding(self):
        """Build the internal output_op encoding from the slice ops' output encodings."""
        slice_output_scales = []
        slice_zp = []
        for slice_op in self._output_slices:
            slice_output_scales.append(slice_op.output_scale)
            slice_op.input_zero_point = slice_op.output_zero_point
            slice_zp.append(slice_op.input_zero_point)

        # The output_op here isn't really the output.
        # We reuse output_op because this layer inherits from elementwise mult.
        # The real outputs here are the slice ops.
        if self.get_length(tf.shape(slice_output_scales[0])) == 0:
            # Scalar (rank-0) scale: it should be 1 at this point.
            self.output_op.output_scale = slice_output_scales[0]
        else:
            self.output_op.output_scale = tf.concat(slice_output_scales, 0)
        self.output_op.output_zero_point = slice_zp[0]
        for slice_op in self._output_slices:
            slice_op.input_scale = self.output_op.output_scale
        return super()._enforce_output_encoding()

    def enforce_internal_encoding(self, training=False, **kwargs):
        """Propagate the input encoding through the passthru and reorder ops into ew_mult."""
        # Passthru op
        self.input_pass_op.forward_encoding()

        # Reorder ops: both consume the passthru output.
        self._reorder1.input_scale = self.input_pass_op.output_scale
        self._reorder1.input_zero_point = self.input_pass_op.output_zero_point
        self._reorder1.enforce_encoding()
        self._reorder2.input_scale = self.input_pass_op.output_scale
        self._reorder2.input_zero_point = self.input_pass_op.output_zero_point
        self._reorder2.enforce_encoding()

        self.ew_mult_op.input_scales = [self._reorder1.output_scale, self._reorder2.output_scale]
        self.ew_mult_op.input_zero_points = [self._reorder1.output_zero_point, self._reorder2.output_zero_point]

        super().enforce_internal_encoding(training=training, **kwargs)

    def get_equalization_handler_type(self, predecessor_index=None):
        """Equalization is not supported for this layer."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)
