from abc import ABC, abstractmethod

from hailo_model_optimization.acceleras.hailo_layers.base_acceleras_layer import BaseAccelerasLayer
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    EquivClassification,
    LayerHandlerType,
    PrecisionMode,
    QuantizationAlgorithms,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasInitializationError
from hailo_model_optimization.acceleras.utils.flow_state_utils import LayerState


class DataEncoderDecoder:
    """
    Affine encoder/decoder defined by a ``scale`` and a ``zero_point``.

    ``encode`` computes ``data / scale + zero_point`` and ``decode`` computes
    ``(data - zero_point) * scale``, so ``decode`` inverts ``encode`` (up to float rounding).
    Works with anything supporting the arithmetic operators (Python numbers, arrays, tensors).
    """

    def __init__(self, scale, zero_point) -> None:
        self._scale, self._zero_point = scale, zero_point

    def encode(self, data):
        """Map ``data`` into the encoded domain: ``data / scale + zero_point``."""
        scaled = data / self._scale
        return scaled + self._zero_point

    def decode(self, data):
        """Map encoded ``data`` back: ``(data - zero_point) * scale``."""
        shifted = data - self._zero_point
        return shifted * self._scale


class BaseHailoNonNNCoreLayer(BaseAccelerasLayer, ABC):
    """
    Abstract class declaring the API that all our non-NN-core layers are exposed through.

    Subclasses implement ``call_core`` (the layer's computation) and ``from_hn`` (construction from
    an HN slice). This base class has no atomic ops and no weights: most of the quantization, stats
    and encoding hooks below are no-ops that return empty containers or identity scales.
    """

    # Switch for the optional per-tensor input decoding / output encoding applied in ``call``.
    # While False, ``enable_inputs_decoding`` and ``enable_output_encoding`` do nothing, so no
    # codecs get installed through them.
    # NOTE(review): the reason I/O encoding is disabled is not visible here - confirm before enabling.
    _IO_ENCODING_SUPPORTED = False

    def __init__(self, name, input_shapes, output_shapes, logger=None, **kwargs):
        # ``logger`` is accepted for signature compatibility; it is not used or forwarded here.
        super().__init__(name=name, **kwargs)
        self.input_shapes = input_shapes
        self.output_shapes = output_shapes
        # Codec lists consumed by ``call``; None means "pass data through unchanged".
        self._inputs_decoders = None
        self._outputs_encoders = None
        self._precision_mode = PrecisionMode.a8_w8_a8

    def call(self, inputs, training=False, **kwargs):
        """
        Run the layer: optionally decode the inputs, run ``call_core``, optionally encode the outputs.

        Decoding/encoding only happens when the corresponding codec list is set and non-empty. A list
        of tensors is paired positionally with the codecs; any other value uses the first codec.

        Raises:
            ValueError: if a list holds more tensors than there are codecs (``zip`` would otherwise
                drop the extra tensors silently).
        """
        if self._inputs_decoders:
            if isinstance(inputs, list):
                self._check_codec_count(self._inputs_decoders, inputs, "input")
                inputs = [decoder.decode(inp) for decoder, inp in zip(self._inputs_decoders, inputs)]
            else:
                inputs = self._inputs_decoders[0].decode(inputs)
        outputs = self.call_core(inputs, training, **kwargs)
        if self._outputs_encoders:
            if isinstance(outputs, list):
                self._check_codec_count(self._outputs_encoders, outputs, "output")
                outputs = [encoder.encode(out) for encoder, out in zip(self._outputs_encoders, outputs)]
            else:
                outputs = self._outputs_encoders[0].encode(outputs)
        return outputs

    @staticmethod
    def _check_codec_count(codecs, tensors, direction):
        """Raise ValueError if ``tensors`` has entries without a matching codec."""
        if len(tensors) > len(codecs):
            raise ValueError(
                f"{len(tensors)} {direction} tensors but only {len(codecs)} {direction} codecs are configured"
            )

    @abstractmethod
    def call_core(self, inputs, training=False, **kwargs):
        """Layer computation on (already decoded) inputs. OVERRIDE in subclasses."""

    @classmethod
    @abstractmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        OVERRIDE in subclasses with whatever is needed to create instance from an HN slice
        """

    def is_differentiable(self) -> bool:
        return False

    @property
    def atomic_ops(self):
        """Non-NN-core layers have no atomic ops."""
        return []

    @property
    def input_shapes(self):
        return self._input_shapes

    @input_shapes.setter
    def input_shapes(self, input_shapes):
        self._input_shapes = input_shapes

    @property
    def output_shapes(self):
        return self._output_shapes

    @output_shapes.setter
    def output_shapes(self, output_shapes):
        self._output_shapes = output_shapes

    def resolve_output_index(self, output_index):
        """
        Return the data output index of this layer for the given output edge index.

        Computed as ``output_index % num_outputs``: for a single-output layer every edge maps to 0,
        otherwise in-range indices map to themselves.
        """
        return output_index % self.num_outputs

    def get_equalization_handler_type(self, predecessor_index=None):
        """Non-NN-core layers are always classified as unsupported (non-source) for equalization."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_algo_callback(self, algo):
        """
        Return the classifier callback for ``algo``.

        Only ``QuantizationAlgorithms.equalization`` is registered; any other ``algo`` raises KeyError.
        """
        classifier_callbacks_by_algo = {QuantizationAlgorithms.equalization: self.get_equalization_handler_type}
        return classifier_callbacks_by_algo[algo]

    def enable_inputs_decoding(self, input_scales, input_zero_points):
        """
        Install one ``DataEncoderDecoder`` per input so ``call`` decodes inputs before ``call_core``.

        Does nothing while ``_IO_ENCODING_SUPPORTED`` is False (the current setting).
        """
        if not self._IO_ENCODING_SUPPORTED:
            return
        self._inputs_decoders = [
            DataEncoderDecoder(inp_scale, inp_zp) for inp_scale, inp_zp in zip(input_scales, input_zero_points)
        ]

    def enable_output_encoding(self, output_scales, output_zero_points):
        """
        Install one ``DataEncoderDecoder`` per output so ``call`` encodes the ``call_core`` results.

        Does nothing while ``_IO_ENCODING_SUPPORTED`` is False (the current setting).
        """
        if not self._IO_ENCODING_SUPPORTED:
            return
        self._outputs_encoders = [
            DataEncoderDecoder(out_scale, out_zp) for out_scale, out_zp in zip(output_scales, output_zero_points)
        ]

    def disable_inputs_decoding(self):
        self._inputs_decoders = None

    def disable_outputs_encoding(self):
        self._outputs_encoders = None

    @property
    def num_inputs(self):
        return len(self.input_shapes)

    @property
    def num_outputs(self):
        return len(self.output_shapes)

    def enable_internal_encoding(self, **kwargs):
        pass

    def disable_internal_encoding(self, **kwargs):
        pass

    @property
    def is_native_input(self):
        return True

    @property
    def is_native_output(self):
        return True

    def create_io_encoding_candidates(self):
        pass

    def import_precision_config(self, precision_config, optimization_target):
        """Store only the precision mode from ``precision_config``; ``optimization_target`` is ignored."""
        self._precision_mode = precision_config.precision_mode

    def export_flow_state(self) -> LayerState:
        """
        Export the flow parameters of the layer. Non-NN-core layers have no atomic ops.
        """
        return LayerState(full_name=self.full_name, atomic_ops={})

    def import_flow_state(self, layer_state: LayerState) -> None:
        """
        Import the flow parameters of the layer. Non-NN-core layers have no atomic ops,
        so this only validates that ``layer_state`` belongs to this layer.

        Raises:
            AccelerasInitializationError: if ``layer_state.full_name`` differs from this layer's name.
        """
        if self.full_name != layer_state.full_name:
            raise AccelerasInitializationError(
                f"while importing flow states, names didn't match. "
                f"current {self.full_name} and attempted import {layer_state.full_name}"
            )

    # Default (no-op) implementations of the abstract API, so non-NN-core layers support the test flow.
    def import_weights(self, layer_params, **kwargs):
        pass

    def export_weights(self, include_shared_weights=True):
        return {}

    @property
    def activation_atomic_op(self):
        return None

    @property
    def bit_exact_supported(self) -> bool:
        return False

    def start_stats_collection(self, **kwargs):
        pass

    def stop_stats_collection(self):
        pass

    def add_supported_state(self, *states, **kwargs):
        pass

    def create_quant_element_by_data_path(self, data_path, bits):
        pass

    def create_quant_element_custom_behavior(self, precision_config, optimization_target):
        pass

    def enable_lossy(self):
        pass

    def enforce_io_encoding(self):
        pass

    def create_hw_params(self, *args, **kwargs):
        pass

    def enforce_internal_encoding(self):
        pass

    def get_bops(self) -> int:
        return 1

    def disable_lossy(self, *, native_act=False):
        pass

    # Identity encoding values: zero point 0 and scale 1.
    @property
    def output_zero_point(self):
        return 0

    @property
    def output_scale(self):
        return 1

    # NOTE(review): always a single-element list, regardless of num_outputs / num_inputs.
    @property
    def output_scales(self):
        return [1]

    @property
    def input_scales(self):
        return [1]

    def check_encoding_consistency(self):
        pass

    def export_acceleras(self, include_shared_weights=True):
        return {}

    def import_quant(self, quant_params):
        pass

    def export_quant(self, include_shared_weights=True):
        return {}

    def export_qnpz(self, convert=False):
        return {}

    def get_input_lossy_elements(self):
        return []

    def keep_original_output_stats(self):
        pass
