import copy
from abc import ABC

import numpy as np

from hailo_sdk_common.hailo_nn.hn_definitions import DefuseType, LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_params import LayerWithParams


class InnerLayer(LayerWithParams, ABC):
    """
    Base class for layers that own a kernel and a bias.

    On top of ``LayerWithParams`` it tracks:

    * ``kernel`` / ``kernel_shape`` -- the weights and their shape. The last two
      entries of ``kernel_shape`` are read as (input features, output features).
    * ``bias`` -- falls back to a zero vector of size ``output_shape[-1]`` when unset.
    * ``pre_layer_bias`` -- a per-input-feature bias that can be folded into
      ``bias`` with :meth:`compute_pre_layer_bias_approx`.
    """

    # Weight-requirement flags consumed by the LayerWithParams machinery
    # (NOTE(review): their exact semantics live in the base class, not visible here).
    _REQUIRES_NATIVE_WEIGHTS = True
    _REQUIRES_QUANTIZED_WEIGHTS = True

    # Ops whose kernel (and bias, when present) are absorbed by move_params().
    _KERNEL_SOURCE_OPS = (
        LayerType.dense,
        LayerType.base_dense,
        LayerType.conv,
        LayerType.base_conv,
        LayerType.dw,
        LayerType.base_dw,
        LayerType.deconv,
        LayerType.base_deconv,
    )

    def __init__(self):
        super().__init__()
        self._pre_layer_bias = None
        self._bias = None
        self._kernel = None
        self._kernel_shape = None

    @property
    def defuse_features(self):
        """
        The ``defuse_features`` value of this layer.

        Returns ``defuse_params["defuse_features"]`` when the layer is defused
        (``defuse_type`` is not ``DefuseType.none``) and the value is non-zero.
        Defuse features can't legitimately be zero, so a missing or zero value is
        treated as an invalid hn: a warning is logged and ``kernel_shape[-1]`` is
        returned instead.
        """
        defuse_features = self.defuse_params.get("defuse_features")
        if defuse_features and self.defuse_type is not DefuseType.none:
            return defuse_features
        self._logger.warning(
            f"Layer {self.name} has defuse_features=0. Assuming invalid hn, and using kernel_features={self.kernel_shape[-1]}",
        )
        return self.kernel_shape[-1]

    def compute_pre_layer_bias_approx(self):
        """
        Fold ``pre_layer_bias`` through the kernel and store the result as ``bias``.

        For every output feature ``fout`` this computes
        ``sum over fin of (sum(kernel[..., fin, fout]) * pre_layer_bias[fin])``:
        the kernel taps of each (fin, fout) pair are summed over all leading
        (spatial) axes and weighted by the pre-layer bias of that input feature.

        The result is only an approximation because of image-edge effects; an
        exact result would need a bias in image shape. Any existing ``bias`` is
        replaced, not accumulated into. The resulting bias is float64.
        """
        fin_count, fout_count = self._kernel_shape[-2:]
        kernel = np.asarray(self._kernel, dtype=np.float64)[..., :fin_count, :fout_count]
        # Reduce every axis except the trailing (fin, fout) pair. For a 2-D kernel
        # the axis tuple is empty and the kernel is used unchanged.
        kernel_sums = kernel.sum(axis=tuple(range(kernel.ndim - 2)))
        pre_layer_bias = np.asarray(self._pre_layer_bias, dtype=np.float64)[:fin_count]
        # (fin,) @ (fin, fout) -> (fout,): replaces the O(fin * fout) Python loop.
        self._bias = pre_layer_bias @ kernel_sums

    def to_hn(self, should_get_default_params=False):
        """Serialize to an hn dict, adding ``params.kernel_shape`` to the base result."""
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        # Copy so that editing the returned dict cannot mutate this layer's shape list.
        result["params"]["kernel_shape"] = copy.copy(self._kernel_shape)
        return result

    def move_params(self, layer):
        """
        Absorb the parameters of ``layer`` into this layer.

        After delegating to the base class:

        * a kernel-carrying op (see ``_KERNEL_SOURCE_OPS``) contributes its kernel
          and, when not None, its bias (replacing the current ones);
        * a ``bias_add`` op has its bias added on top of the current bias (or
          becomes the bias if none is set yet).
        """
        super().move_params(layer)
        if layer.op in self._KERNEL_SOURCE_OPS:
            # NOTE(review): original comments mention group convolutions and dense
            # layers that stem from old conv1x1 layers; no such special-casing is
            # visible in this method.
            self._kernel = layer.kernel
            if layer.bias is not None:
                self._bias = layer.bias
        elif layer.op == LayerType.bias_add:
            if self._bias is None:
                self._bias = layer.bias
            else:
                # Out-of-place add: self._bias may be the very array object owned by
                # another layer (it is assigned from layer.bias without copying), so
                # "+=" would silently modify that layer's bias too. It would also
                # fail when an integer bias is combined with a float one.
                self._bias = self._bias + layer.bias

    @property
    def kernel_shape(self):
        return self._kernel_shape

    @kernel_shape.setter
    def kernel_shape(self, kernel_shape):
        """Store ``kernel_shape`` as a fresh list (``None`` is kept as ``None``)."""
        # Explicit None check: truthiness would raise for multi-element numpy arrays.
        self._kernel_shape = None if kernel_shape is None else list(kernel_shape)

    @property
    def bias_shape(self):
        """Size of the bias vector, taken from the last entry of ``kernel_shape``."""
        return self._kernel_shape[-1]

    @property
    def kernel(self):
        return self._kernel

    @kernel.setter
    def kernel(self, kernel):
        self._kernel = kernel

    @property
    def pre_layer_bias(self):
        return self._pre_layer_bias

    @pre_layer_bias.setter
    def pre_layer_bias(self, pre_layer_bias):
        self._pre_layer_bias = pre_layer_bias

    @property
    def bias(self):
        """The bias, or a fresh zero vector of size ``output_shape[-1]`` if unset (not stored)."""
        return self._bias if self._bias is not None else np.zeros(self.output_shape[-1])

    @bias.setter
    def bias(self, bias):
        self._bias = bias

    @classmethod
    def from_hn(cls, hn):
        """Build the layer from an hn dict; ``hn["params"]["kernel_shape"]`` is required."""
        layer = super().from_hn(hn)
        layer.kernel_shape = hn["params"]["kernel_shape"]
        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """
        Build a new layer from ``old_layer`` and copy its inner-layer parameters.

        The pre-layer bias, bias, kernel and kernel shape are copied over unless
        ``old_layer`` is an element-wise add/sub op (presumably those don't carry
        these parameters -- TODO confirm). Arrays are shared, not copied; the
        kernel shape list is copied.
        """
        layer = super().from_layer(old_layer)
        if old_layer.op not in (LayerType.base_ew_add, LayerType.base_ew_sub, LayerType.ew_add, LayerType.ew_sub):
            layer.pre_layer_bias = old_layer.pre_layer_bias
            layer.bias = old_layer.bias
            layer.kernel = old_layer.kernel
            # The setter stores a fresh list, so the two layers don't share it, and
            # it tolerates a None shape where ".copy()" would raise AttributeError.
            layer.kernel_shape = old_layer.kernel_shape
        return layer
