import os
from abc import abstractmethod
from typing import Dict, Iterator, List, Optional, Tuple

import numpy as np

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_conv import BaseHailoConv
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.layer_decompose_flow import LayerDecomposeFlow
from hailo_model_optimization.acceleras.lossy_elements.quant_element import QuantElement
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerTranslationConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    DataPath,
    OpStates,
    OptimizationTarget,
    PrecisionMode,
    SplittedPrecisionMode,
)


class BaseHailoLayerDecompose(BaseHailoLayer):
    """
    This Base layer is a flow of layers instead of ops
    """

    # Layer block intermediate data path
    INTERMEDIATE_DATA_PATH = {
        DataPath.INTER_BLOCK_16: 16,
        DataPath.INTER_BLOCK_8: 8,
    }

    def _init_flow(self) -> LayerDecomposeFlow:
        """
        Build the flow of this decompose layer (overrides the base hailo layer method).

        Every instance attribute that is a ``BaseHailoLayer`` is registered as a node of a
        new ``LayerDecomposeFlow``, in the order the attributes appear in the instance ``__dict__``.
        """
        flow = LayerDecomposeFlow()
        # Collect the sub-layers first, then register them one by one
        sub_layers = list(filter(lambda attribute: isinstance(attribute, BaseHailoLayer), vars(self).values()))
        for sub_layer in sub_layers:
            flow.add_node(sub_layer)
        return flow

    # region abstract methods
    @abstractmethod
    def _create_weight_constraints(self, layer_precision_mode: Dict[str, int]) -> Dict[str, int]:
        """
        Return the weights precision to use for each sub-layer.

        This method should be overridden by each decompose layer to provide hardcoded precision modes for each layer.
        By default, the precision mode is copied to each class. The in/out precision of layers is overridden with data
        in DecomposeLayerFlow generated in the add_edge() method. In this method, the weights precision of any layer
        can be overridden. If the decompose layer does not have precision constraints, simply return the input
        layer_precision_mode.

        Args:
            layer_precision_mode: mapping of sub-layer name to its current weights precision.

        Returns:
            Mapping of sub-layer name to the constrained weights precision. The caller only applies
            entries whose name exists in the original mapping.
        """

    @abstractmethod
    def _apply_precision_config_constraints(
        self, layer_precision_config: Dict[str, LayerPrecisionConfig]
    ) -> Dict[str, LayerPrecisionConfig]:
        """
        Same as _create_weight_constraints, but used to apply constraints on the precision config itself,
        such as bias_mode or quantization groups.

        Args:
            layer_precision_config: mapping of sub-layer name to its precision config.

        Returns:
            Mapping of sub-layer name to the constrained precision config.
        """

    @abstractmethod
    def _get_activation_layer(self) -> BaseHailoLayer:
        """
        Return the sub-layer that holds the activation of this decompose layer.

        Base Hailo Layer is hardcoded to have only one activation op. Thus, any decomposed layer has to define
        its own activation layer, if one exists.
        """

    @abstractmethod
    def _get_kernel_bits(self) -> Optional[int]:
        """
        Return the bit width of the main kernel of this decompose layer.

        Each decomposed layer needs a main kernel; this is the kernel that the layer will import and export
        the layer precision config from. The default implementation returns None.
        """
        return None

    @property
    @abstractmethod
    def _training_layer(self) -> BaseHailoLayer:
        """
        Sub-layer that the train-mode and wraparound-loss methods of this class delegate to.

        Can be left empty if training is disabled.
        """
        pass

    # region precision config

    def import_precision_config(self, precision_config: LayerPrecisionConfig, optimization_target: OptimizationTarget):
        """This method is responsible not only to apply precision_config and create the necessary quant_elements, but also generate a separate precision config per layer, each with it's own modifications"""
        # Initialize layer_precision_mode with split precision strings
        layer_precision_config = self._generate_layer_precision_config(precision_config)
        # Apply layerwise import_precision_config
        for layer in self._iterate_layers():
            layer.import_precision_config(
                precision_config=layer_precision_config[layer.name], optimization_target=optimization_target
            )
        self.enforce_io_encoding()

    def import_translation_config(
        self,
        translation_config: LayerTranslationConfig,
    ):
        """
        Distribute the translation config of the decompose layer to its sub-layers.

        A fresh ``LayerTranslationConfig`` is created for every sub-layer and only the fields that match the
        role of the sub-layer are copied from ``translation_config``:
        input layers receive the input range fields, output layers the output range fields, conv layers the
        feed-repeat / shift / null-channel fields and the activation layer the activation fields.
        ``ignore_hw_limitation_assertion`` and ``activation_symmetric_range`` are copied to every sub-layer.

        Args:
            translation_config: translation config of the whole decompose layer.
        """
        # The input/output iterators yield (layer, index) tuples, so the layers must be unpacked before
        # testing membership; comparing a layer against the tuples never matches.
        input_layers = [layer for layer, _ in self._iterate_input_layers()]
        output_layers = [layer for layer, _ in self._iterate_output_layers()]
        conv_layers = list(self._iterate_conv_layer())
        activation_layer = self._get_activation_layer()

        for layer in self._iterate_layers():
            layer_translation_config = LayerTranslationConfig()
            if layer in input_layers:
                layer_translation_config.force_range_in = translation_config.force_range_in
                # input normalization is not forwarded to the sub-layers
                layer_translation_config.input_normalization = None
                layer_translation_config.force_range_index = translation_config.force_range_index
            if layer in output_layers:
                layer_translation_config.weak_force_range_out = translation_config.weak_force_range_out
                layer_translation_config.force_range_out = translation_config.force_range_out
                layer_translation_config.force_range_index = translation_config.force_range_index
            if layer in conv_layers:
                layer_translation_config.null_channels_cutoff_factor = translation_config.null_channels_cutoff_factor
                layer_translation_config.max_elementwise_feed_repeat = translation_config.max_elementwise_feed_repeat
                layer_translation_config.max_bias_feed_repeat = translation_config.max_bias_feed_repeat
                layer_translation_config.force_shift = translation_config.force_shift
            if layer in (activation_layer,):
                layer_translation_config.activation_fit = translation_config.activation_fit
                layer_translation_config.force_range_preact = translation_config.force_range_preact
            # Copy from the incoming config (the previous code assigned the fresh config's field to itself)
            layer_translation_config.ignore_hw_limitation_assertion = translation_config.ignore_hw_limitation_assertion
            layer_translation_config.activation_symmetric_range = translation_config.activation_symmetric_range

            layer.import_translation_config(translation_config=layer_translation_config)

    def _generate_layer_precision_config(self, precision_config):
        layer_precision_mode_splitted = self._split_precision_config(precision_config)

        layer_precision_mode_splitted = self._apply_weight_constraints(layer_precision_mode_splitted)
        layer_precision_mode_splitted = self._create_intermediate_quant_element(layer_precision_mode_splitted)

        layer_precision_config = self._join_precision_config(precision_config, layer_precision_mode_splitted)
        layer_precision_config = self._apply_precision_config_constraints(layer_precision_config)
        return layer_precision_config

    def _join_precision_config(
        self, precision_config: LayerPrecisionConfig, layer_precision_mode_splitted: Dict[str, SplittedPrecisionMode]
    ) -> Dict[str, PrecisionMode]:
        """Build layer precision config for each layer from the split precision mode"""
        layer_precision_config = {layer_name: precision_config.copy() for layer_name in self._iterate_layer_names()}
        for layer_name in self._iterate_layer_names():
            layer_precision_config[layer_name].precision_mode = layer_precision_mode_splitted[
                layer_name
            ].to_precision_mode()
        return layer_precision_config

    def _split_precision_config(self, precision_config: LayerPrecisionConfig) -> Dict[str, SplittedPrecisionMode]:
        """
        Give every sub-layer its own split precision mode, all created from the same precision mode.

        Returns:
            Mapping of sub-layer name to a ``SplittedPrecisionMode`` built from
            ``precision_config.precision_mode``.
        """
        split_by_layer = {}
        for layer_name in self._iterate_layer_names():
            split_by_layer[layer_name] = SplittedPrecisionMode.from_precision_mode(precision_config.precision_mode)
        return split_by_layer

    def _apply_weight_constraints(
        self, layer_precision_mode: Dict[str, SplittedPrecisionMode]
    ) -> Dict[str, SplittedPrecisionMode]:
        # Extract weights
        weights_to_modify: Dict[str, int] = {
            layer_name: splitted_precision.weights for layer_name, splitted_precision in layer_precision_mode.items()
        }

        # Apply constraints (now only on weights)
        constrained_weights = self._create_weight_constraints(weights_to_modify)

        # Update layer_precision_mode with constrained weights
        for layer_name, constrained_weight in constrained_weights.items():
            if layer_name in layer_precision_mode:
                layer_precision_mode[layer_name].weights = constrained_weight
        return layer_precision_mode

    def _create_intermediate_quant_element(
        self, layer_precision_config: Dict[str, SplittedPrecisionMode]
    ) -> Dict[str, SplittedPrecisionMode]:
        """Enforce precision based on intermediate DataPath edges"""
        # First Part: iterate all (intermediate) edges
        for u, v in self._layer_flow.toposort_edges():
            curr_data_path = self._layer_flow.get_edge_data_path(u, v)
            if curr_data_path not in BaseHailoLayerDecompose.INTERMEDIATE_DATA_PATH.keys():
                continue
            bits = BaseHailoLayerDecompose.INTERMEDIATE_DATA_PATH[curr_data_path]
            u_bn = os.path.basename(u)
            v_bn = os.path.basename(v)

            # For each edge, translate intermediate precision config to input edge config
            layer_precision_config[u_bn].output = bits
            layer_precision_config[v_bn].input = bits
        return layer_precision_config

    # region encoding
    def _enforce_io_encoding_forward(self, training=False, **kwargs):
        """This is currently not in use, to be used by layer with no degree of freedom"""
        source_nodes = set()  # TODO has to understand if this is really necessary, currently being overwritten
        op_dict = self._layer_flow._get_op_attribute()
        for u, v in self._layer_flow.toposort_edges():
            if u not in op_dict or v not in op_dict:
                continue
            layer: BaseHailoLayer = self._layer_flow.get_op(u)
            successor = self._layer_flow.get_op(v)
            if u not in source_nodes:
                layer.enforce_io_encoding(training=training, **kwargs)
                source_nodes.add(u)
            self._set_successors_inputs_encodings(layer, successor)

    def _set_successors_inputs_encodings(self, layer: BaseHailoLayer, successor: BaseHailoLayer):
        """
        set inputs scales/zp of successor based on output_scales and output_zp of layer
        Args:
            layer: the layer name
            successor: the successor name
        """
        out_ind = self._layer_flow.get_edge_output_index(layer.full_name, successor.full_name)
        out_ind = layer.resolve_output_index(out_ind)
        inp_ind = self._layer_flow.get_edge_input_index(layer.full_name, successor.full_name)

        output_scale = layer.output_scales[out_ind]
        output_zero_point = layer.output_zero_points[out_ind]

        successor.set_input_scale(output_scale, inp_ind)
        successor.set_input_zero_point(output_zero_point, inp_ind)

    def _set_predecessor_output_encodings(self, layer: BaseHailoLayer, successor: BaseHailoLayer):
        """
        set inputs scales/zp of predecessor based on input scale and input zp of layer successor
        Args:
            layer: the layer name
            successor: the successor name
        """
        out_ind = self._layer_flow.get_edge_output_index(layer.full_name, successor.full_name)
        out_ind = layer.resolve_output_index(out_ind)
        inp_ind = self._layer_flow.get_edge_input_index(layer.full_name, successor.full_name)

        input_scale = successor.input_scales[inp_ind]
        input_zero_point = successor.input_zero_points[inp_ind]

        layer.set_output_scale(input_scale, out_ind)
        layer.set_output_zero_point(input_zero_point, out_ind)

    def check_encoding_consistency(self):
        """Run the base-class consistency check, then the consistency check of every sub-layer."""
        super().check_encoding_consistency()
        for sub_layer in self._iterate_layers():
            sub_layer.check_encoding_consistency()

    def _validate_consistency_of_lossy_elements(self, u_layer, output_index: int, v_layer, input_index: int):
        return u_layer.get_output_lossy_elements()[output_index] == v_layer.get_input_lossy_elements()[input_index]

    def create_io_encoding_candidates(self, **kwargs):
        for layer in self._iterate_layers():
            layer.create_io_encoding_candidates(**kwargs)

    def create_input_encoding_candidates(self, translation_config=None):
        # we iterate all layers for ouptput/ input encoding candidates
        for layer in self._iterate_layers():
            layer.create_input_encoding_candidates(translation_config)

    def create_output_encoding_candidates(self, forced_range=None, translation_config=None):
        for layer in self._iterate_layers():
            layer.create_output_encoding_candidates(forced_range, translation_config)

    def disable_internal_encoding(
        self, encode_inputs=None, decode_outputs=None, quant_inputs=None, *, export_model_state=False
    ):
        # TODO test we don't need to add keys nesting
        model_state_dict = {}
        for layer in self._iterate_layers():
            layer_state = layer.disable_internal_encoding(
                encode_inputs, decode_outputs, quant_inputs, export_model_state=export_model_state
            )
            if export_model_state:
                model_state_dict.update(layer_state)
        self.enforce_internal_encoding_in_call = False
        if export_model_state:
            return model_state_dict

    # region Iterators

    def _iterate_layers(self) -> Iterator[BaseHailoLayer]:
        """
        Yield the sub-layers (``BaseHailoLayer`` instances) of the layer flow.

        Note that the toposort order of the result relies on LayerFlow.get_ops() returning the ops in
        toposort order.
        """
        for candidate in self._layer_flow.get_ops():
            if isinstance(candidate, BaseHailoLayer):
                yield candidate

    def _iterate_layer_names(self) -> Iterator[str]:
        """
        Yield the ``name`` of every sub-layer (``BaseHailoLayer`` instance) of the layer flow,
        in the order returned by the flow's ``get_ops()``.
        """
        yield from (layer.name for layer in self._layer_flow.get_ops() if isinstance(layer, BaseHailoLayer))

    def iterate_input_ops(self) -> Iterator[Tuple[BaseAtomicOp, int]]:
        """
        Yield the input ops of the decompose layer together with a flat, running index.

        It is not a good idea to iterate directly from a decompose layer on its atomic ops, but a lot of legacy
        code does it, including hailo model and various validation tools.

        Yields:
            ``(op, index)`` pairs; for an input node that is a layer, all of its own input ops are yielded,
            otherwise the node itself is yielded.
        """
        index = 0  # account for different indexes at layer flow
        for node, _ in super().iterate_input_ops():
            if isinstance(node, BaseHailoLayer):
                for op, _ in node.iterate_input_ops():
                    yield op, index
                    index += 1
            else:
                # The node is not a layer: yield it directly (``op`` is unbound or stale in this branch)
                yield node, index
                index += 1

    def iterate_output_ops(self) -> Iterator[Tuple[BaseAtomicOp, int]]:
        """
        Yield the output ops of the decompose layer together with a flat, running index.

        Yields:
            ``(op, index)`` pairs; for an output node that is a layer, all of its own output ops are yielded,
            otherwise the node itself is yielded.
        """
        index = 0  # account for different indexes at layer flow
        for node, _ in super().iterate_output_ops():
            if isinstance(node, BaseHailoLayer):
                for op, _ in node.iterate_output_ops():
                    yield op, index
                    index += 1
            else:
                # The node is not a layer: yield it directly (``op`` is unbound or stale in this branch)
                yield node, index
                index += 1

    def _iterate_output_layers(self) -> Iterator[Tuple[BaseHailoLayer, int]]:
        """Yield the items of the base-class ``iterate_output_ops``, bypassing the override in this class."""
        for output_entry in super().iterate_output_ops():
            yield output_entry

    def _iterate_input_layers(self) -> Iterator[Tuple[BaseHailoLayer, int]]:
        """Yield the items of the base-class ``iterate_input_ops``, bypassing the override in this class."""
        for input_entry in super().iterate_input_ops():
            yield input_entry

    def get_bias_ops(self):
        """
        Base Hailo layer uses the bias ops for purposes such as returning the bias decomposition of the layer.
        """
        bias_ops = []
        for layer in self._iterate_layers():
            bias_ops += list(layer.get_bias_ops())
        return bias_ops

    def _iterate_act_ops(self):
        yield from self._get_activation_layer()._iterate_act_ops()

    def _iterate_conv_layer(self):
        """Yield the sub-layers that are ``BaseHailoConv`` instances."""
        yield from (sub_layer for sub_layer in self._iterate_layers() if isinstance(sub_layer, BaseHailoConv))

    def get_input_lossy_elements(self) -> List[QuantElement]:
        """
        returns a list of the lossy_elements.
        """
        lossy_elements = []
        for layer, _ in self._iterate_input_layers():
            lossy_elements += layer.get_input_lossy_elements()
        return lossy_elements

    def get_weight_lossy_elements(self) -> List[QuantElement]:
        """
        returns a list of the lossy_elements.
        """
        lossy_elements = []
        for layer in self._iterate_layers():
            lossy_elements += layer.get_weight_lossy_elements()
        return lossy_elements

    def get_output_lossy_elements(self) -> List[QuantElement]:
        """
        returns a list of the lossy_elements.
        """
        lossy_elements = []
        for layer, _ in self._iterate_output_layers():
            lossy_elements += layer.get_output_lossy_elements()
        return lossy_elements

    # region miscellaneous

    def start_stats_collection(self, **kwargs):
        for layer in self._iterate_layers():
            layer.start_stats_collection(**kwargs)

    def add_supported_state(self, *states: OpStates):
        for state in states:
            for layer in self._iterate_layers():
                layer.add_supported_state(state)
            self.supported_states.add(state)

    def vectorize_scales(self):
        for layer in self._iterate_layers():
            layer.vectorize_scales()
        self._create_out_in_scale_ratio()

    def _handle_negative_exponent(self, layer: BaseHailoLayer) -> bool:
        """return true if a change was made"""
        act_op = layer.activation_atomic_op
        if act_op is None:
            return False
        assigned_exp = act_op.get_assigned_exponent()
        output_shift_fix_offset = act_op.get_offset_needed_shift()

        if np.all(assigned_exp >= 0) and output_shift_fix_offset <= 0:
            # No negative exp to fix
            return False

        output_shift_fix = np.max([output_shift_fix_offset, np.max(-assigned_exp)])
        layer.update_negative_slope_exponent_shift(output_shift_fix)
        self._update_out_in_scale_ratio(output_shift_fix)
        return True

    def _update_out_in_scale_ratio(self, output_shift_fix):
        """No-op hook; ``_handle_negative_exponent`` calls it with the output shift fix it applied."""
        pass

    def enable_lossy(self, **kwargs):
        for layer in self._iterate_layers():
            layer.enable_lossy(**kwargs)

    def disable_lossy(self, **kwargs):
        for layer in self._iterate_layers():
            layer.disable_lossy(**kwargs)

    def lossy_status_summary(self):
        """
        handy utility for debug.
        """
        ret = "\n"
        for aop in self.atomic_ops:
            ret += "---" + aop.full_name + " ops:"
            ret += aop.lossy_status_summary()
        return ret + "\n"

    # TODO go to HailoModel and make sure HailoModel update layers and not aops
    @property
    def debug_mode(self):
        return self._debug_mode

    @debug_mode.setter
    def debug_mode(self, value):
        for layer in self._iterate_layers():
            layer._debug_mode = value
            for aop in layer.atomic_ops:
                aop.debug_mode = True
        self._debug_mode = value

    # region import/export

    def import_acceleras(self, acceleras_params: Dict[str, dict]):
        """Import acceleras params to each layer, The dictionary is nested by layer n"""
        for layer in self._iterate_layers():
            layer.import_acceleras(self._import_sub_layer_params(layer, acceleras_params))
        self._create_out_in_scale_ratio()
        self._has_hw_params = True

    def export_acceleras(self, include_shared_weights=True) -> Dict[str, dict]:
        """Export acceleras params from each layer, The dictionary is nested by layer name"""
        npz = dict()
        for layer in self._iterate_layers():
            npz.update(
                self._export_sub_layer_params(
                    layer, layer.export_acceleras(include_shared_weights=include_shared_weights)
                )
            )
        return npz

    def _export_sub_layer_params(self, layer: BaseHailoLayer, export_params: Dict[str, dict]) -> Dict[str, dict]:
        return {f"{layer.name}/{k}": v for k, v in export_params.items()}

    def _import_sub_layer_params(self, layer: BaseHailoLayer, params: Dict[str, dict]):
        return {k[len(layer.name) + 1 :]: v for k, v in params.items() if k.startswith(layer.name)}

    # def _export_unfolded_hw_params(self, layer: BaseHailoLayer, export_params: Dict[str, dict]) -> Dict[str, dict]:
    #     """This is a one way (lossy) export of the hw params, Each decompose layer is exported as multiple layers"""
    #     return {f"{layer.name}/{k}": v for k, v in export_params.items()}

    def export_hw_params(self, include_shared_weights=True, unfold_nesting=False) -> dict:
        # TODO this is WIP, will be determined by Adi's team
        if unfold_nesting:
            return  # TODO this is still not working
        params = {}
        for layer in self._iterate_layers():
            params.update(
                self._export_sub_layer_params(
                    layer, layer.export_hw_params(include_shared_weights=include_shared_weights)
                )
            )
        return params

    # region properties for scale matching

    @property
    def vector_zp_in(self):
        """
        when there is a kernel that is trainable and (1,1) we can allow a vector of zero_point input
        """
        return np.any(
            [layer.vector_zp_in for layer, _ in self._iterate_input_layers()],
        )

    @property
    def consumer_input_scale(self):
        """
        when there is a kernel that is trainable we have a vector degree of freedom for the input scale
        """
        return np.any(
            [layer.consumer_input_scale for layer, _ in self._iterate_input_layers()],
        )

    @property
    def has_activation(self):
        """
        when there is activation we have a potential scalar degree of freedom for the output scale
        Returns:
        """
        return np.any(
            [layer.has_activation for layer, _ in self._iterate_output_layers()],
        )

    @property
    def homogeneous(self):
        """
        when there is an homogenous activation we have a scalar degree of freedom for the output scale
        """
        return np.any(
            [layer.homogeneous for layer, _ in self._iterate_output_layers()],
        )

    @property
    def shift_delta(self):
        """
        Return the shift delta needed by this layer. If the layer does not support shift delta or
        the shift delta was not calculated yet, return None
        """
        for layer in self._iterate_layers():
            shift_delta = layer.shift_delta
            if shift_delta is not None:
                return shift_delta
        return None

    @property
    def strong_force_range(self):
        return self._strong_force_range

    @strong_force_range.setter
    def strong_force_range(self, value):
        self._strong_force_range = value
        for layer, _ in self._iterate_output_layers():
            layer.strong_force_range = value

    # endregion

    # region QFT

    def get_quant_element_train_mode(self):
        return self._training_layer.get_quant_element_train_mode()

    def update_quant_elements_train_mode(self, train_mode):
        return self._training_layer.update_quant_elements_train_mode(train_mode)

    def set_wraparound_loss(self, state):
        return self._training_layer.set_wraparound_loss(state)
