from typing import List

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_non_arithmetic_op import BaseNonArithmeticAtomicOp
from hailo_model_optimization.acceleras.atomic_ops.passthru_op import PassthruOp
from hailo_model_optimization.acceleras.model.preprocess.conversion import (
    i420_to_yuv_conversion,
    nv12_to_yuv_conversion,
    nv21_to_yuv_conversion,
    yuy2_to_yuv_conversion,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import FormatConversionType
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasImplementationError
from hailo_model_optimization.acceleras.utils.layer_utils import reshape_input_by_windows, reshape_output_by_windows


class BaseRGBConversionOp(BaseNonArithmeticAtomicOp):
    """
    Demosaic a Bayer mosaic into three RGB channels.

    The input is split into its four 2x2 tile phases, indexed in row-major order:
    0 = (even row, even col), 1 = (even row, odd col), 2 = (odd row, even col),
    3 = (odd row, odd col). ``pixel_map`` tells which phase holds which colour.
    Red and blue are upsampled to full size with nearest-neighbour resizing. Each green
    phase is stretched to full width; they are then interleaved row by row, with ``g0`` on
    even rows and ``g1`` on odd rows.

    NOTE(review): the green reshape hard-codes 1 channel and ``enforce_encoding`` repeats the
    scale 3 times, so a single-channel input is presumably expected — confirm.

    Attributes
        pixel_map: Mapping from colour name (``r``, ``g0``, ``g1``, ``b``) to tile-phase index;
            defined by subclasses.

    """

    pixel_map: dict
    num_inputs = 1
    num_outputs = 1

    def __init__(self, name: str, logger=None, fully_native=None, **kwargs):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)

    def call_native(self, inputs, **kwargs):
        mosaic = inputs[0]
        height, width = mosaic.shape[1], mosaic.shape[2]

        # Tile phases in row-major order: (0, 0), (0, 1), (1, 0), (1, 1).
        phases = [mosaic[:, row::2, col::2, :] for row in (0, 1) for col in (0, 1)]
        red, green_even, green_odd, blue = self.pixel_routing(phases)

        full_size = [height, width]
        half_height_size = [height // 2, width]
        red = tf.image.resize(red, full_size, method="nearest")
        blue = tf.image.resize(blue, full_size, method="nearest")
        green_even = tf.image.resize(green_even, half_height_size, method="nearest")
        green_odd = tf.image.resize(green_odd, half_height_size, method="nearest")

        # Placing the two greens side by side along the width and reshaping row-major turns each
        # (even, odd) pair of half-height rows into two consecutive full-height rows.
        side_by_side = tf.concat([green_even, green_odd], 2)
        green = tf.reshape(side_by_side, [-1, height, width, 1])
        return tf.concat([red, green, blue], 3, name="format_conversion")

    def enforce_encoding(self):
        input_scale = self.input_scales[0]
        is_scalar = len(input_scale.shape) == 0
        # A per-channel scale is expanded to the three output channels.
        self.output_scale = input_scale if is_scalar else tf.repeat(input_scale, 3, axis=0)
        self.output_zero_point = self.input_zero_points[0]

    def pixel_routing(self, pixels: List[np.array]):
        """
        Select the red, first-green, second-green and blue tiles from ``pixels``.

        Returns a 4-tuple ordered (r, g0, g1, b), using ``pixel_map`` to index ``pixels``.
        """
        return tuple(pixels[self.pixel_map[color]] for color in ("r", "g0", "g1", "b"))


class RggbToHailoRgbOp(BaseRGBConversionOp):
    """
    Demosaic an RGGB Bayer mosaic, whose 2x2 tile is

        [ r, g ]
        [ g, b ]

    into [r, g, b] channels.
    """

    pixel_map = dict(r=0, g0=1, g1=2, b=3)


class BggrToHailoRgbOp(BaseRGBConversionOp):
    """
    Demosaic a BGGR Bayer mosaic, whose 2x2 tile is

        [ b, g ]
        [ g, r ]

    into [r, g, b] channels.
    """

    pixel_map = dict(b=0, g0=1, g1=2, r=3)


class GrbgToHailoRgbOp(BaseRGBConversionOp):
    """
    Demosaic a GRBG Bayer mosaic, whose 2x2 tile is

        [ g, r ]
        [ b, g ]

    into [r, g, b] channels.
    """

    pixel_map = dict(g0=0, r=1, b=2, g1=3)


class GbrgToHailoRgbOp(BaseRGBConversionOp):
    """
    Demosaic a GBRG Bayer mosaic, whose 2x2 tile is

        [ g, b ]
        [ r, g ]

    into [r, g, b] channels.
    """

    pixel_map = dict(g0=0, b=1, r=2, g1=3)


###########################################


class TwelveToEightBitOp(BaseNonArithmeticAtomicOp):
    """
    Reduce 12-bit values stored as byte pairs to their top 8 bits.

    Each pair of adjacent columns is read as one value: the even column is the low byte and the
    odd column carries bits 8..11 in its low nibble. The result is
    ``(low >> 4) + ((high & 0xF) << 4)`` computed with float arithmetic, so the output width is
    half the input width.
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(self, name: str, logger=None, fully_native=None, **kwargs):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)

    def enforce_encoding(self):
        self.output_scales = self.input_scales
        self.output_zero_points = self.input_zero_points

    def call_native(self, inputs, **kwargs):
        packed = inputs[0]
        low_bytes = packed[:, :, 0::2, :]
        high_bytes = packed[:, :, 1::2, :]
        # Drop the 4 LSBs of the low byte, then place the high byte's low nibble above them.
        top_of_low = tf.floor(low_bytes / 16)
        nibble_of_high = (high_bytes % 16) * 16
        return top_of_low + nibble_of_high


class TwelveToSixteenBitOp(BaseNonArithmeticAtomicOp):
    """
    Unpack 12-bit pixels stored as 3 bytes per 2 pixels into one value per pixel.

    Every 3 consecutive input columns (bytes a, b, c) yield two output pixels:
        first  = a | (b & 0x0F) << 8
        second = (b & 0xF0) >> 4 | c << 4
    so the output width is 2/3 of the input width. The input width is expected to be a multiple
    of 3.
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(self, name: str, logger=None, fully_native=None, **kwargs):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)

    def enforce_encoding(self):
        self.output_scales = self.input_scales
        self.output_zero_points = self.input_zero_points

    def call_native(self, inputs, **kwargs):
        source = inputs[0]
        # NOTE(review): the result stays in tf.uint16 rather than the input dtype — confirm that
        # downstream consumers expect an integer tensor here.
        as_uint16 = tf.cast(source, tf.uint16)

        height = int(source.shape[1])
        width = int(source.shape[2])
        features = int(source.shape[3])
        # Three input bytes hold two 12-bit pixels, so the output width is 2/3 of the input's.
        unpacked_width = (width * 2) // 3

        bw = tf.bitwise
        triplets = tf.reshape(as_uint16, [-1, height, width // 3, 3, features])
        byte_a, byte_b, byte_c = (triplets[:, :, :, k, :] for k in range(3))
        # Each byte_* has shape [batch, height, width // 3, features].

        # First pixel: all of byte a as bits 0..7, low nibble of byte b as bits 8..11.
        first_pixel = bw.bitwise_or(
            bw.bitwise_and(byte_a, 0xFF),
            bw.left_shift(bw.bitwise_and(byte_b, 0x0F), 8),
        )
        # Second pixel: high nibble of byte b as bits 0..3, all of byte c as bits 4..11.
        second_pixel = bw.bitwise_or(
            bw.right_shift(bw.bitwise_and(byte_b, 0xF0), 4),
            bw.left_shift(bw.bitwise_and(byte_c, 0xFF), 4),
        )

        # Concatenating on the features axis gives [..., width // 3, 2 * features]; the row-major
        # reshape then splits each group into two consecutive output pixels.
        pixel_pairs = tf.concat([first_pixel, second_pixel], 3, name="op")
        return tf.reshape(pixel_pairs, [-1, height, unpacked_width, features])


class SixteenToTwelveBitOp(BaseNonArithmeticAtomicOp):
    """
    Pack 12-bit pixels (one per column) into 3 bytes per 2 pixels.

    Only the low 12 bits of each input value are kept. Each pair of pixels (p, q) becomes
        p & 0xFF,  (p >> 8) & 0xF | (q & 0xF) << 4,  (q >> 4) & 0xFF
    so the output width is 3/2 of the input width. The input width is expected to be a multiple
    of 4, since pixels are processed in groups of four.
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(self, name: str, logger=None, fully_native=None, **kwargs):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)

    def enforce_encoding(self):
        self.output_scales = self.input_scales
        self.output_zero_points = self.input_zero_points

    def call_native(self, inputs, **kwargs):
        source = inputs[0]

        height = int(source.shape[1])
        width = int(source.shape[2])
        features = int(source.shape[3])
        # Four 12-bit pixels (48 bits) become six bytes, so the output width is 3/2 of the input's.
        packed_width = (width * 3) // 2

        bw = tf.bitwise
        # NOTE(review): the result stays in tf.uint64 rather than the input dtype — confirm that
        # downstream consumers expect an integer tensor here.
        quads = tf.reshape(tf.cast(source, tf.uint64), [-1, height, width // 4, 4, features])
        p0, p1, p2, p3 = (bw.bitwise_and(quads[:, :, :, k, :], 0xFFF) for k in range(4))

        def _pack_pair(first, second):
            """Return the three bytes encoding the 12-bit pixels ``first`` and ``second``."""
            return [
                bw.bitwise_and(first, 0xFF),
                bw.bitwise_or(
                    bw.right_shift(bw.bitwise_and(first, 0xF00), 8),
                    bw.left_shift(bw.bitwise_and(second, 0xF), 4),
                ),
                bw.right_shift(bw.bitwise_and(second, 0xFF0), 4),
            ]

        six_bytes = _pack_pair(p0, p1) + _pack_pair(p2, p3)
        # Concatenating on the features axis then reshaping row-major lays the six bytes of each
        # group out as six consecutive output columns.
        packed = tf.concat(six_bytes, 3, name="op")
        return tf.reshape(packed, [-1, height, packed_width, features])


class FlatToFramesOp(BaseNonArithmeticAtomicOp):
    """
    Reshape a flat input into this op's configured ``output_shape``.

    No encoding propagation is defined for this op: ``enforce_encoding`` always raises
    ``AccelerasImplementationError``.
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(self, name: str, logger=None, fully_native=None, **kwargs):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)

    def enforce_encoding(self):
        raise AccelerasImplementationError(f"enforce_encoding is not implemented for {self.full_name}")

    def call_native(self, inputs, **kwargs):
        flat = inputs[0]
        target_shape = self.output_shape
        return tf.reshape(flat, target_shape)


class TransposeWidthFeaturesOp(BaseNonArithmeticAtomicOp):
    """
    Swap the width and features axes, optionally per feature group.

    For an input [N, H, W, C] and ``groups`` = G, the features are split into G contiguous chunks
    of C // G. The output is [N, H, C // G, G * W], where
        out[n, h, k, g * W + w] = in[n, h, w, g * (C // G) + k].
    With G = 1 this is a plain width/features transpose.
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(self, name: str, groups: int = 1, logger=None, fully_native=None, **kwargs):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        self._groups = groups

    @property
    def groups(self):
        return self._groups

    def call_native(self, inputs, **kwargs):
        inp = inputs[0]
        input_shape = inp.shape
        # Integer division: the features are expected to split evenly into the groups.
        output_width = int(input_shape[3]) // self.groups
        output_features = int(input_shape[2]) * self.groups

        r0 = tf.reshape(inp, [-1, input_shape[1], input_shape[2], self.groups, output_width])
        # [N, H, W, G, C/G] -> [N, H, C/G, G, W]
        t0 = tf.transpose(r0, perm=[0, 1, 4, 3, 2])
        op = tf.reshape(t0, [-1, input_shape[1], output_width, output_features])
        return op

    def backward_encoding(self):
        """Derive the input encoding from the output encoding (assumes a uniform scale)."""
        if len(self.output_scale.shape) == 0:
            self.input_scale = self.output_scale
        else:
            # The input has C = (C // G) * G = output_shape[-2] * groups features, so the
            # per-channel input scale needs that many entries. output_shape[-2] alone is
            # only correct when groups == 1.
            input_features = self.output_shape[-2] * self.groups
            self.input_scale = tf.repeat(self.output_scale[0], input_features, axis=0)
        self.input_zero_point = self.output_zero_point

    def enforce_encoding(self):
        """Derive the output encoding from the input encoding (assumes a uniform scale)."""
        if len(self.input_scales[0].shape) == 0:
            self.output_scale = self.input_scales[0]
        else:
            # One entry per output feature (output_shape[-1] == W * groups).
            self.output_scale = tf.repeat(self.input_scales[0][0], self.output_shape[-1], axis=0)
        self.output_zero_point = self.input_zero_points[0]


class TransposeHeightWidthOp(TransposeWidthFeaturesOp):
    """
    Swap the height and width axes: [N, H, W, C] -> [N, W, H, C].

    The features axis is not moved, so the input has exactly as many channels as the output.
    ``backward_encoding`` is overridden to size the per-channel input scale by the feature count
    (``output_shape[-1]``) rather than by a spatial dimension.
    """

    def call_native(self, inputs, **kwargs):
        return tf.transpose(inputs[0], (0, 2, 1, 3))

    def backward_encoding(self):
        """Derive the input encoding from the output encoding (assumes a uniform scale)."""
        if len(self.output_scale.shape) == 0:
            self.input_scale = self.output_scale
        else:
            # Channels are unchanged by an H/W transpose: one entry per feature.
            self.input_scale = tf.repeat(self.output_scale[0], self.output_shape[-1], axis=0)
        self.input_zero_point = self.output_zero_point


class FeaturesToWidthFeaturesOp(BaseNonArithmeticAtomicOp):
    """
    Single-input, single-output op declaration.

    Only the arity is defined here; no ``call_native`` or encoding logic is provided by this
    class itself.
    """

    num_outputs = 1
    num_inputs = 1


class SpatialReshapeOp(BaseNonArithmeticAtomicOp):
    """
    Spatial "un-flatten": re-lay the spatial positions of the input on an H' x W' grid.

    After the input is reshaped by ``input_windows``, its H * W positions are flattened, then
    zero-padded at the end or truncated to H' * W' = ``spatial_reshape_sizes[0] *
    spatial_reshape_sizes[1]``. The result is reshaped to [-1, H', W', C] and finally reshaped by
    ``output_windows``. The typical case is [N, 1, W, C] -> [N, H', W', C] with W = H' * W'.
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(
        self,
        name: str,
        logger=None,
        fully_native=None,
        spatial_reshape_sizes=None,
        input_windows=None,
        output_windows=None,
        **kwargs,
    ):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        self._spatial_reshape_sizes = spatial_reshape_sizes
        self._input_windows = input_windows
        self._output_windows = output_windows

    @property
    def spatial_reshape_sizes(self):
        return self._spatial_reshape_sizes

    @property
    def input_windows(self):
        # Falsy (None or empty) windows mean "no windowing".
        return self._input_windows or [1, 1, 1]

    @property
    def output_windows(self):
        return self._output_windows or [1, 1, 1]

    def call_native(self, inputs: tf.Tensor, **kwargs):
        windowed = reshape_input_by_windows(inputs[0], self.input_windows)
        out_rows = self.spatial_reshape_sizes[0]
        out_cols = self.spatial_reshape_sizes[1]
        target_len = out_rows * out_cols
        source_len = windowed.shape[1] * windowed.shape[2]

        flat = tf.reshape(windowed, [-1, source_len, windowed.shape[-1]])
        if target_len < source_len:
            # Drop the trailing positions that do not fit on the target grid.
            flat = flat[:, :target_len, :]
        else:
            # Zero-pad the flattened positions up to the target grid size.
            flat = tf.pad(flat, [[0, 0], [0, target_len - source_len], [0, 0]])

        grid = tf.reshape(flat, [-1, out_rows, out_cols, windowed.shape[3]])
        return reshape_output_by_windows(grid, self.output_windows)

    def enforce_encoding(self):
        self.output_scale = self.input_scales[0]
        self.output_zero_point = self.input_zero_points[0]


class SpatialReshapeHeightFeaturesOp(BaseNonArithmeticAtomicOp):
    """
    Regroup the height and features axes while keeping the width.

    After the input is reshaped by ``input_windows``, [N, H, W, C] becomes [N', H', W, C'], where
    H' = ``spatial_reshape_sizes[0]`` and C' = ``spatial_reshape_sizes[2]``. The features are
    moved in front of the width, the (H, C) block is reshaped to (H', C'), and the axes are
    moved back. The result is then reshaped by ``output_windows``.
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(
        self,
        name: str,
        logger=None,
        fully_native=None,
        spatial_reshape_sizes=None,
        input_windows=None,
        output_windows=None,
        **kwargs,
    ):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        self._spatial_reshape_sizes = spatial_reshape_sizes
        self._input_windows = input_windows
        self._output_windows = output_windows

    @property
    def spatial_reshape_sizes(self):
        return self._spatial_reshape_sizes

    @property
    def input_windows(self):
        # Falsy (None or empty) windows mean "no windowing".
        return self._input_windows or [1, 1, 1]

    @property
    def output_windows(self):
        return self._output_windows or [1, 1, 1]

    def call_native(self, inputs: tf.Tensor, **kwargs):
        windowed = reshape_input_by_windows(inputs[0], self.input_windows)
        new_height = self.spatial_reshape_sizes[0]
        new_features = self.spatial_reshape_sizes[2]
        width = windowed.shape[2]

        # [N, H, W, C] -> [N, H, C, W] so the (H, C) block is contiguous before the width.
        features_before_width = tf.transpose(windowed, perm=[0, 1, 3, 2])
        regrouped = tf.reshape(features_before_width, [-1, new_height, new_features, width])
        # [N', H', C', W] -> [N', H', W, C']
        restored = tf.transpose(regrouped, perm=[0, 1, 3, 2])
        return reshape_output_by_windows(restored, self.output_windows)

    def enforce_encoding(self):
        self.output_scale = self.input_scales[0]
        self.output_zero_point = self.input_zero_points[0]


class RgbxToRgb(BaseNonArithmeticAtomicOp):
    """
    Keep only the first three feature channels: [N, H, W, C] -> [N, H, W, 3].

    Intended for RGBX inputs, where the 4th channel is presumably padding — confirm.
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(self, name: str, logger=None, fully_native=None, **kwargs):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)

    def call_native(self, inputs, **kwargs):
        inp = inputs[0]
        conversion = inp[:, :, :, :3]
        return conversion

    def enforce_encoding(self):
        """Propagate the input encoding to the three kept channels."""
        input_scale = self.input_scales[0]
        if len(input_scale.shape) == 0:
            # A scalar (per-tensor) scale cannot be sliced; it applies unchanged to the output.
            self.output_scale = input_scale
        else:
            # Per-channel scale: keep the entries of the first three channels only.
            self.output_scale = input_scale[:3]
        self.output_zero_point = self.input_zero_points[0]


class Yuy2ToYuv(BaseNonArithmeticAtomicOp):
    """
    Convert a YUY2-packed input to three YUV channels via ``yuy2_to_yuv_conversion``.

    The target shape keeps every input dimension except the last, which becomes 3.
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(self, name: str, logger=None, fully_native=None, **kwargs):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)

    def call_native(self, inputs, **kwargs):
        packed = inputs[0]
        yuv_shape = list(packed.shape)
        yuv_shape[-1] = 3
        return yuy2_to_yuv_conversion(packed, yuv_shape)

    def backward_encoding(self):
        scale = self.output_scale
        if len(scale.shape) != 0:
            # NOTE(review): this sizes the per-channel input scale by output_shape[-2] (the width),
            # not by the input feature count — looks copied from TransposeWidthFeaturesOp; verify.
            scale = tf.repeat(scale[0], self.output_shape[-2], axis=0)
        self.input_scale = scale
        self.input_zero_point = self.output_zero_point

    def enforce_encoding(self):
        scale = self.input_scales[0]
        if len(scale.shape) != 0:
            # Uniform per-channel scale: one copy of the first entry per output feature.
            scale = tf.repeat(scale[0], self.output_shape[-1], axis=0)
        self.output_scale = scale
        self.output_zero_point = self.input_zero_points[0]


class NV12ToYuv(BaseNonArithmeticAtomicOp):
    """
    Convert an NV12-packed input to YUV via ``nv12_to_yuv_conversion``.

    The target shape doubles the input height and keeps the width and feature count:
    [N, H, W, C] -> [N, 2H, W, C]. NV21ToYuv and I420ToYuv subclass this op and reuse its
    encoding logic.
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(self, name: str, logger=None, fully_native=None, **kwargs):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)

    def call_native(self, inputs, **kwargs):
        inp = inputs[0]

        output_shape = list(inp.shape)
        output_shape[1] *= 2
        return nv12_to_yuv_conversion(inp, output_shape)

    def backward_encoding(self):
        """Derive the input encoding from the output encoding (assumes a uniform scale)."""
        if len(self.output_scale.shape) == 0:
            self.input_scale = self.output_scale
        else:
            # call_native only doubles the height, so the input has as many features as the
            # output. Size the per-channel input scale by output_shape[-1] (the feature count),
            # not output_shape[-2] (the width).
            self.input_scale = tf.repeat(self.output_scale[0], self.output_shape[-1], axis=0)
        self.input_zero_point = self.output_zero_point

    def enforce_encoding(self):
        """Derive the output encoding from the input encoding (assumes a uniform scale)."""
        if len(self.input_scales[0].shape) == 0:
            self.output_scale = self.input_scales[0]
        else:
            self.output_scale = tf.repeat(self.input_scales[0][0], self.output_shape[-1], axis=0)
        self.output_zero_point = self.input_zero_points[0]


class NV21ToYuv(NV12ToYuv):
    """
    Convert an NV21-packed input to YUV via ``nv21_to_yuv_conversion``.

    Shape handling and encodings are the same as NV12ToYuv: the target height is doubled.
    """

    def call_native(self, inputs, **kwargs):
        packed = inputs[0]
        full_shape = list(packed.shape)
        full_shape[1] = full_shape[1] * 2
        return nv21_to_yuv_conversion(packed, full_shape)


class I420ToYuv(NV12ToYuv):
    """
    Convert an I420-packed input to YUV via ``i420_to_yuv_conversion``.

    Shape handling and encodings are the same as NV12ToYuv: the target height is doubled.
    """

    def call_native(self, inputs, **kwargs):
        packed = inputs[0]
        full_shape = list(packed.shape)
        full_shape[1] = full_shape[1] * 2
        return i420_to_yuv_conversion(packed, full_shape)


def format_conversion_factory(
    conversion_type: FormatConversionType,
    in_emulation_graph: bool,
) -> BaseNonArithmeticAtomicOp:
    """
    Return the atomic-op class that emulates a format conversion.

    Args:
        conversion_type: The format conversion to emulate.
        in_emulation_graph: When False, ``PassthruOp`` is returned regardless of
            ``conversion_type``.

    Returns:
        The op class (not an instance) registered for ``conversion_type``.

    Raises:
        KeyError: If ``in_emulation_graph`` is True and ``conversion_type`` has no registered op.

    """
    if not in_emulation_graph:
        return PassthruOp

    # Conversions that are emulated by PassthruOp.
    passthru_types = (
        FormatConversionType.tf_rgb_to_hailo_rgb,
        FormatConversionType.mipi_rgb888_to_hailo_rgb,
        FormatConversionType.hailo_rgb_to_tf_rgb,
        FormatConversionType.hailo_rgb_to_ppu,
        FormatConversionType.ppu_to_hailo_rgb,
        FormatConversionType.hailo_rgb_to_f8cr,
        FormatConversionType.f8cr_to_hailo_rgb,
        FormatConversionType.hxf_to_w_transposed,
        FormatConversionType.f_to_hxw_transposed,
        FormatConversionType.fcr_to_c8fr,
        FormatConversionType.f8cr_to_fcr,
        FormatConversionType.c8fr_to_frames,
        FormatConversionType.reshape_1xw0_to_hxw,
        FormatConversionType.hailo_rgb_to_lcu,
        FormatConversionType.reshape_post_ew_mult,
    )
    registry = dict.fromkeys(passthru_types, PassthruOp)

    # Conversions with a dedicated emulation op.
    registry.update(
        {
            FormatConversionType.mipi_bayer_rggb_to_hailo_rgb: RggbToHailoRgbOp,
            FormatConversionType.mipi_bayer_bggr_to_hailo_rgb: BggrToHailoRgbOp,
            FormatConversionType.mipi_bayer_grbg_to_hailo_rgb: GrbgToHailoRgbOp,
            FormatConversionType.mipi_bayer_gbrg_to_hailo_rgb: GbrgToHailoRgbOp,
            FormatConversionType.twelve_to_eight_bit: TwelveToEightBitOp,
            FormatConversionType.twelve_to_sixteen_bit: TwelveToSixteenBitOp,
            FormatConversionType.sixteen_to_twelve_bit: SixteenToTwelveBitOp,
            FormatConversionType.spatial_reshape: SpatialReshapeOp,
            FormatConversionType.tf_rgbx_to_hailo_rgb: RgbxToRgb,
            FormatConversionType.yuy2_to_hailo_yuv: Yuy2ToYuv,
            FormatConversionType.transpose_height_width: TransposeHeightWidthOp,
            FormatConversionType.nv12_to_hailo_yuv: NV12ToYuv,
            FormatConversionType.nv21_to_hailo_yuv: NV21ToYuv,
            FormatConversionType.i420_to_hailo_yuv: I420ToYuv,
            FormatConversionType.reshape_height_features: SpatialReshapeHeightFeaturesOp,
        }
    )

    return registry[conversion_type]
