import copy

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.cache_op import CacheOp
from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.atomic_ops.spatial_transpose import SpatialTransposeOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_single_atomic import BaseHailoSingleAtomic
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.model.preprocess.conversion import ConvertionWeightsDataStruct
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    CacheOpMode,
    DataPath,
    EquivClassification,
    FormatConversionType,
    LayerHandlerType,
    LayerType,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import BadInputsShape
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams
from hailo_model_optimization.acceleras.utils.opt_utils import get_scalar_vector


class HailoInputLayer(BaseHailoSingleAtomic):
    """Network input layer built around a single atomic op.

    The op is a ``SpatialTransposeOp`` when ``transposed`` is set and a ``PassthruOp``
    otherwise, unless a pre-built ``op`` is supplied. Optional format-conversion weights
    are kept in ``conversion_weights`` and selected by ``conversion_type``.
    Unlike most layers, the input encoding is derived from the output encoding
    (see ``enforce_io_encoding``).
    """

    _hn_type = LayerType.INPUT_LAYER
    OP_NAME = "output_op"

    def __init__(
        self,
        name: str,
        input_shapes,
        logger=None,
        transposed=False,
        conversion_type=None,
        emulate_conversion=False,
        op=None,
        **kwargs,
    ):
        """
        Args:
            name: layer name; a default op is named ``f"{name}/{OP_NAME}"``.
            input_shapes: list of shapes; the first one defines ``self.input_spec``.
            logger: optional logger forwarded to the op and the base layer.
            transposed: select ``SpatialTransposeOp`` (instead of ``PassthruOp``) when no
                ``op`` is given; also swaps axes 1 and 2 of the input spec.
            conversion_type: optional ``FormatConversionType`` selecting which conversion
                weights are imported/exported.
            emulate_conversion: stored flag; also relaxes ``validate_shape``.
            op: optional pre-built atomic op (used e.g. by cache input layers).
        """
        if op is None:
            op_cls = SpatialTransposeOp if transposed else PassthruOp
            op = op_cls(f"{name}/{self.OP_NAME}", logger=logger)
        super().__init__(name=name, core_op=op, logger=logger, **kwargs)
        self.conversion_type = conversion_type
        self.emulate_conversion = emulate_conversion
        self.conversion_weights = ConvertionWeightsDataStruct()
        # set_input_spec reads self.transposed, so it must be assigned first.
        self.transposed = transposed
        self.set_input_spec(input_shapes)
        # Cleared by _update_encoding_scalar when a non-uniform scale is imported.
        self.force_scalar_encoding = True

        self.encoding_const = False

    def _build_flow(self) -> LayerFlow:
        """Build the linear flow ``input -> atomic_op -> output`` on ``DataPath.LAYER_OUT``."""
        layer_flow = LayerFlow()
        in1 = layer_flow.add_input()
        out1 = layer_flow.add_output()

        layer_flow.add_node(self.atomic_op)

        layer_flow.add_edge(in1, self.atomic_op, DataPath.LAYER_OUT)
        layer_flow.add_edge(self.atomic_op, out1, DataPath.LAYER_OUT)
        return layer_flow

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """Create and finalize the layer from an HN element dict.

        Reads ``transposed`` and ``input_shapes`` (any shape whose rank is not 4 becomes
        ``[shape[0], 1, 1, shape[-1]]``), plus ``conversion_type``/``emulate_conversion``
        when ``conversion_type`` is present.
        """
        conversion_type = None
        emulate_conversion = False
        transposed = hn_element["transposed"]
        # Convert each shape in input_shapes to have rank 4
        input_shapes = [
            shape if len(shape) == 4 else [shape[0], 1, 1, shape[-1]] for shape in hn_element["input_shapes"]
        ]
        if "conversion_type" in hn_element:
            conversion_type = FormatConversionType(hn_element["conversion_type"])
            emulate_conversion = hn_element["emulate_conversion"]
        layer = cls(
            name=lname,
            input_shapes=input_shapes,
            logger=logger,
            transposed=transposed,
            conversion_type=conversion_type,
            emulate_conversion=emulate_conversion,
        )
        layer.finalize_from_hn(hn_element)
        return layer

    def set_input_spec(self, input_shapes):
        """Set ``self.input_spec`` from ``input_shapes[0]``, replacing the batch dim with None.

        - transposed and rank 4: ``[None, s[2], s[1], *s[3:]]`` (axes 1 and 2 swapped)
        - rank 2: ``[None, 1, 1, s[-1]]``
        - otherwise: ``[None, *s[1:]]``
        """
        shape = input_shapes[0]
        if self.transposed and len(shape) == 4:
            # Built directly (no slice item-assignment) so tuple shapes work too.
            spec_shape = [None, shape[2], shape[1], *shape[3:]]
        elif len(shape) == 2:
            spec_shape = [None, 1, 1, shape[-1]]
        else:
            spec_shape = [None, *shape[1:]]
        self.input_spec = tf.keras.layers.InputSpec(shape=spec_shape)

    def _input_stats_ops(self):
        return []

    def get_equalization_handler_type(self, predecessor_index=None):
        return EquivClassification(LayerHandlerType.unexpected, is_source=True)

    def get_quarot_handler_type(self, predecessor_index=None):
        return EquivClassification(LayerHandlerType.unexpected, is_source=True)

    def enforce_io_encoding(self, training=False, **kwargs):
        """Copy the output encoding onto the input.

        This is the reverse of the usual same-io-scales behavior: instead of updating
        the output scale from the input scale, the input lossy element, scale and zero
        point are taken from the output.
        """
        self.atomic_op.set_input_lossy_element(self.atomic_op.output_lossy_element)
        self.set_input_scale(self.output_scale, 0)
        self.set_input_zero_point(self.output_zero_point, 0)

    def _export_weights(self):
        """Return the conversion weights relevant to ``conversion_type`` (empty dict if none)."""
        if self.conversion_type == FormatConversionType.rotation:
            return {"rotation": self.conversion_weights.rotation}
        elif self.conversion_type == FormatConversionType.mask:
            return {"tile": self.conversion_weights.tile}
        elif self.conversion_type in [FormatConversionType.cos, FormatConversionType.sin]:
            return {
                "factor": self.conversion_weights.factor,
                "tile": self.conversion_weights.tile,
                "theta": self.conversion_weights.theta,
            }
        elif self.conversion_type == FormatConversionType.embedding:
            return {"embed": self.conversion_weights.embed}
        return dict()

    def import_weights(self, layer_params: LayerParams):
        """Load the conversion weights for ``conversion_type``, then delegate to the base class."""
        if self.conversion_type == FormatConversionType.rotation:
            self.conversion_weights.rotation = layer_params["rotation"]
        elif self.conversion_type == FormatConversionType.mask:
            self.conversion_weights.tile = layer_params["tile"]
        elif self.conversion_type in [FormatConversionType.cos, FormatConversionType.sin]:
            self.conversion_weights.factor = layer_params["factor"]
            self.conversion_weights.tile = layer_params["tile"]
            self.conversion_weights.theta = layer_params["theta"]
        elif self.conversion_type == FormatConversionType.embedding:
            self.conversion_weights.embed = layer_params["embed"]
        super().import_weights(layer_params)

    def import_layer_params(self, params):
        """Import params, then re-evaluate whether scalar encoding can still be forced."""
        super().import_layer_params(params)
        self._update_encoding_scalar()

    def _export_layer_params(self):
        """Extend the base params with ``conversion/*`` entries for the conversion type.

        For ``embedding``, the embedding table is quantized with the input scale/zero point
        through the input lossy element and stored as uint8 (8-bit) or uint16 (otherwise).
        """
        params = super()._export_layer_params()
        if self.conversion_type == FormatConversionType.rotation:
            pass
        elif self.conversion_type == FormatConversionType.mask:
            params["conversion/tile"] = np.array(self.conversion_weights.tile, dtype=np.int32)
        elif self.conversion_type in [FormatConversionType.cos, FormatConversionType.sin]:
            # params["conversion/factor"] = self.conversion_weights.factor
            params["conversion/tile"] = np.array(self.conversion_weights.tile, np.int32)
            params["conversion/theta"] = np.array(self.conversion_weights.theta, np.float32)
        elif self.conversion_type == FormatConversionType.embedding:
            embed_q = self.atomic_op.input_lossy_element(
                self.conversion_weights.embed / self.input_scale + self.input_zero_point
            )
            dtype = np.uint8 if self.atomic_op.input_lossy_element.bits == 8 else np.uint16
            params["conversion/embed"] = np.array(embed_q, dtype=dtype)
        return params

    @staticmethod
    def _is_uniform_scale(scale):
        """Return True when every element of ``scale`` equals the first (empty/0-d count as uniform)."""
        flat = np.ravel(scale)
        return flat.size == 0 or bool(np.all(flat == flat[0]))

    def _update_encoding_scalar(self):
        """Clear ``force_scalar_encoding`` if any output or input scale is non-uniform.

        The flag is only ever cleared here, never set back to True.
        """
        all_scales = [*self.output_scales, *self.input_scales]
        if not all(self._is_uniform_scale(scale) for scale in all_scales):
            self.force_scalar_encoding = False

    def validate_shape(self, input_data):
        """Check ``input_data``'s non-batch shape against ``self.input_spec``.

        Validation is skipped when the data shape is missing/unknown. Rank-1 non-batch
        data shapes ``(f,)`` are compared as ``(1, 1, f)``.

        Raises:
            BadInputsShape: on a mismatch when ``emulate_conversion`` is off.
        """
        # TODO: this logic is a bit redundant, the shape is checked by the keras' build logic
        # This guard must come before any slicing/iteration of the shape: iterating an
        # unknown-rank TensorShape raises, which previously happened before this check.
        if not input_data.shape:
            self._logger.verbose("Dataset signature has missing information, skipping shape validation")
            return
        input_shape = tuple(self.input_spec.shape[1:])
        # TODO when we support transposed add a check

        data_shape = tuple(input_data.shape[1:])
        if len(data_shape) == 1:
            data_shape = (1, 1, data_shape[0])
        if None in data_shape and len(input_shape) == len(data_shape):
            self._logger.verbose("Dataset signature has missing information, skipping shape validation")
            return
        if self.transposed and len(input_shape) >= 3 and len(data_shape) >= 3:
            # Transposed data has axes 0 and 1 swapped relative to the spec.
            if input_shape[0] == data_shape[1] and input_shape[1] == data_shape[0] and input_shape[2] == data_shape[2]:
                return
        if input_shape == data_shape:
            return
        if self.emulate_conversion:
            if (
                self.conversion_type == FormatConversionType.yuy2_to_hailo_yuv
                and input_shape[0] * input_shape[1] * 2 == data_shape[0]
            ):
                return
            # NOTE(review): with emulate_conversion set, any remaining mismatch is accepted
            # without raising (preserved from the original behavior) — confirm intent.
            return
        raise BadInputsShape(self.full_name, input_shape, data_shape)

    def disable_internal_encoding(self, encode_inputs=None, decode_outputs=None, quant_inputs=None, **kwgs):
        """Delegate to the base class with input-layer defaults.

        Unset arguments default to: encode_inputs=True, decode_outputs=False, quant_inputs=True.
        """
        encode_inputs = True if encode_inputs is None else encode_inputs
        decode_outputs = False if decode_outputs is None else decode_outputs
        quant_inputs = True if quant_inputs is None else quant_inputs
        return super().disable_internal_encoding(encode_inputs, decode_outputs, quant_inputs, **kwgs)

    def define_encodings(self, flow):
        """Define base encodings, then mark the op's input scale as scalar per ``force_scalar_encoding``."""
        super().define_encodings(flow)
        flow.nodes[f"{self.atomic_op.full_name}/input_scale:0"]["encoding"].scalar = self.force_scalar_encoding

    def _verify_and_set_hn_io_shapes(self):
        """For transposed layers, swap axes 1 and 2 of the first HN input shape, then verify via the base class."""
        input_shapes = copy.deepcopy(self.hn_element.get("input_shapes", None))
        output_shapes = copy.deepcopy(self.hn_element.get("output_shapes", None))

        if self.transposed and input_shapes == output_shapes:
            # Should transpose the hn element shapes only in the first of multiple calls of this function.
            # Only in the first call the input and output shapes are equal.
            if input_shapes is None:
                # NOTE(review): this returns without calling the base-class verification,
                # unlike every other path — preserved as-is; confirm it is intended.
                return
            first_shape = input_shapes[0]
            if first_shape is not None and len(first_shape) == 4:
                self._hn_element["input_shapes"][0][1] = first_shape[2]
                self._hn_element["input_shapes"][0][2] = first_shape[1]
        super()._verify_and_set_hn_io_shapes()

    def to_hn(self, out_degree=None):
        """Serialize to an HN dict, swapping axes 1 and 2 of the first input shape for transposed layers.

        The swap is applied only when the op's output shapes are valid and the first input
        and output shapes differ.
        """
        hn_dict = super().to_hn(out_degree=out_degree)
        if (
            self.transposed
            and self.atomic_op.output_shapes_is_valid
            and hn_dict["input_shapes"][0] != hn_dict["output_shapes"][0]
        ):
            input_shapes = copy.deepcopy(hn_dict["input_shapes"][0])
            hn_dict["input_shapes"][0][1] = input_shapes[2]
            hn_dict["input_shapes"][0][2] = input_shapes[1]
        # TODO: when this feature is supported, uncomment this line.
        # hn_dict['input_scale_per_channel'] = not (self.force_scalar_encoding)
        return hn_dict

    @property
    def groups(self):
        """Always 1."""
        # TODO: this is a temporary solution as equalization expect all producers to have groups
        return 1


class HailoOutputLayer(BaseHailoSingleAtomic):
    """Network output layer built around a single atomic op created with ``fully_native=True``.

    The op is a ``SpatialTransposeOp`` when ``transposed`` is set and a ``PassthruOp``
    otherwise, unless a pre-built ``op`` is supplied. The output encoding is copied from
    the input encoding (see ``enforce_io_encoding``).
    """

    _hn_type = LayerType.OUTPUT_LAYER
    OP_NAME = "input_op"

    def __init__(self, name: str, input_shapes, logger=None, transposed=False, op=None, **kwargs):
        """
        Args:
            name: layer name; a default op is named ``f"{name}/{OP_NAME}"``.
            input_shapes: shapes reported as both HN input and output shapes.
            logger: optional logger forwarded to the op and the base layer.
            transposed: select ``SpatialTransposeOp`` instead of ``PassthruOp`` when no ``op`` is given.
            op: optional pre-built atomic op (used e.g. by cache output layers).
        """
        if op is None:
            op_cls = SpatialTransposeOp if transposed else PassthruOp
            op = op_cls(f"{name}/{self.OP_NAME}", logger=logger, fully_native=True)
        super().__init__(name=name, core_op=op, logger=logger, **kwargs)
        self._input_shapes = input_shapes

        # Cleared by _update_encoding_scalar when a non-uniform scale is imported.
        self.force_scalar_encoding = True
        self.encoding_const = False

    def _build_flow(self) -> LayerFlow:
        """Build the linear flow ``input -> atomic_op -> output`` on ``DataPath.LAYER_IN``."""
        layer_flow = LayerFlow()
        in1 = layer_flow.add_input()
        out1 = layer_flow.add_output()

        layer_flow.add_node(self.atomic_op)

        layer_flow.add_edge(in1, self.atomic_op, DataPath.LAYER_IN)
        layer_flow.add_edge(self.atomic_op, out1, DataPath.LAYER_IN)
        return layer_flow

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """Create and finalize the layer from an HN element dict.

        Any input shape whose rank is not 4 becomes ``[shape[0], 1, 1, shape[-1]]``.
        """
        transposed = hn_element["transposed"]
        input_shapes = [
            shape if len(shape) == 4 else [shape[0], 1, 1, shape[-1]] for shape in hn_element["input_shapes"]
        ]
        layer = cls(name=lname, input_shapes=input_shapes, logger=logger, transposed=transposed)
        layer.finalize_from_hn(hn_element)
        return layer

    def _layer_dependent_hw_params_modifications(self, params):
        if self.force_scalar_encoding:
            # NOTE(review): the return value of get_scalar_vector is discarded, so unless it
            # mutates its argument this call does not affect ``params``. Kept as-is; confirm
            # whether the result was meant to be written back.
            get_scalar_vector(self.output_scale)
        return super()._layer_dependent_hw_params_modifications(params)

    def _input_stats_ops(self):
        return []

    def enforce_io_encoding(self, training=False, **kwargs):
        """Copy the input encoding onto the output.

        This is a non-arithmetic layer, so the output lossy element, scale and zero point
        are taken from the input.
        """
        self.atomic_op.set_output_lossy_element(self.atomic_op.input_lossy_element)
        self.set_output_scale(self.input_scale, 0)
        self.set_output_zero_point(self.input_zero_point, 0)

    def get_equalization_handler_type(self, predecessor_index=None):
        return EquivClassification(LayerHandlerType.output, is_source=False)

    def get_quarot_handler_type(self, predecessor_index=None):
        return EquivClassification(LayerHandlerType.output, is_source=False)

    def _export_weights(self):
        """Output layers have no weights."""
        return dict()

    def import_layer_params(self, params):
        """Import params, then re-evaluate whether scalar encoding can still be forced."""
        super().import_layer_params(params)
        self._update_encoding_scalar()

    @staticmethod
    def _is_uniform_scale(scale):
        """Return True when every element of ``scale`` equals the first (empty/0-d count as uniform)."""
        flat = np.ravel(scale)
        return flat.size == 0 or bool(np.all(flat == flat[0]))

    def _update_encoding_scalar(self):
        """Clear ``force_scalar_encoding`` if any output or input scale is non-uniform.

        The flag is only ever cleared here, never set back to True.
        """
        all_scales = [*self.output_scales, *self.input_scales]
        if not all(self._is_uniform_scale(scale) for scale in all_scales):
            self.force_scalar_encoding = False

    def disable_internal_encoding(self, encode_inputs=None, decode_outputs=None, quant_inputs=None, **kwgs):
        """Delegate to the base class with output-layer defaults.

        Unset arguments default to: encode_inputs=False, decode_outputs=True, quant_inputs=True.
        """
        encode_inputs = False if encode_inputs is None else encode_inputs
        decode_outputs = True if decode_outputs is None else decode_outputs
        quant_inputs = True if quant_inputs is None else quant_inputs
        return super().disable_internal_encoding(encode_inputs, decode_outputs, quant_inputs, **kwgs)

    def is_differentiable(self):
        return False

    def define_encodings(self, flow):
        """Define base encodings, then mark the op's output scale as scalar per ``force_scalar_encoding``."""
        super().define_encodings(flow)
        flow.nodes[f"{self.atomic_op.full_name}/output_scale:0"]["encoding"].scalar = self.force_scalar_encoding

    def to_hn(self, out_degree=None):
        """Serialize to an HN dict, recording ``output_scale_per_channel``."""
        hn_dict = super().to_hn(out_degree=out_degree)
        hn_dict["output_scale_per_channel"] = not self.force_scalar_encoding
        return hn_dict

    def _get_hn_input_shapes(self):
        return self._input_shapes

    def _get_hn_output_shapes(self):
        # The layer does not change shapes, so output shapes mirror the input shapes.
        return self._input_shapes


class HailoCacheInputLayer(HailoInputLayer):
    """Input layer backed by a ``CacheOp`` created in ``CacheOpMode.READ`` for ``cache_id``."""

    _hn_type = LayerType.INPUT_LAYER
    OP_NAME = "cache_op"

    def __init__(
        self,
        name: str,
        cache_id: str,
        input_shapes: list,
        logger=None,
        transposed=False,
        conversion_type=None,
        emulate_conversion=False,
        **kwargs,
    ):
        """
        Args:
            name: layer name; the cache op is named ``f"{name}/{OP_NAME}"``.
            cache_id: identifier passed to the ``CacheOp`` and exposed as ``cache_id``.
            input_shapes, logger, transposed, conversion_type, emulate_conversion:
                forwarded unchanged to ``HailoInputLayer``.
        """
        op = CacheOp(f"{name}/{self.OP_NAME}", cache_id, CacheOpMode.READ, logger=logger)
        # Forward by keyword so this call does not silently break if the parent's
        # positional parameter order changes.
        super().__init__(
            name=name,
            input_shapes=input_shapes,
            logger=logger,
            transposed=transposed,
            conversion_type=conversion_type,
            emulate_conversion=emulate_conversion,
            op=op,
            **kwargs,
        )
        self._cache_id = cache_id

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """Create and finalize the layer from an HN element dict (requires ``cache_id``)."""
        conversion_type = None
        emulate_conversion = False
        if "conversion_type" in hn_element:
            conversion_type = FormatConversionType(hn_element["conversion_type"])
            emulate_conversion = hn_element["emulate_conversion"]

        # NOTE(review): unlike HailoInputLayer.from_hn, input shapes are passed through
        # without rank-4 normalization — confirm this is intended for cache inputs.
        layer = cls(
            name=lname,
            cache_id=hn_element["cache_id"],
            input_shapes=hn_element["input_shapes"],
            logger=logger,
            transposed=hn_element["transposed"],
            conversion_type=conversion_type,
            emulate_conversion=emulate_conversion,
        )
        layer.finalize_from_hn(hn_element)
        return layer

    @property
    def cache_id(self):
        """The cache identifier this layer reads from."""
        return self._cache_id


class HailoCacheOutputLayer(HailoOutputLayer):
    """Output layer backed by a ``CacheOp`` created in ``CacheOpMode.WRITE`` for ``cache_id``."""

    _hn_type = LayerType.OUTPUT_LAYER
    OP_NAME = "cache_op"

    def __init__(
        self,
        name: str,
        cache_id: str,
        input_shapes: list,
        logger=None,
        transposed=False,
        **kwargs,
    ):
        """
        Args:
            name: layer name; the cache op is named ``f"{name}/{OP_NAME}"``.
            cache_id: identifier passed to the ``CacheOp``.
            input_shapes, logger, transposed: forwarded unchanged to ``HailoOutputLayer``.
        """
        op = CacheOp(f"{name}/{self.OP_NAME}", cache_id, CacheOpMode.WRITE, logger=logger)
        # Forward by keyword so this call does not silently break if the parent's
        # positional parameter order changes.
        super().__init__(
            name=name,
            input_shapes=input_shapes,
            logger=logger,
            transposed=transposed,
            op=op,
            **kwargs,
        )

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """Create and finalize the layer from an HN element dict (requires ``cache_id``)."""
        layer = cls(
            name=lname,
            cache_id=hn_element["cache_id"],
            input_shapes=hn_element["input_shapes"],
            logger=logger,
            transposed=hn_element["transposed"],
        )
        layer.finalize_from_hn(hn_element)
        return layer
