import copy
from abc import ABC
from collections import OrderedDict

from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer import Layer


class LayerWithParams(Layer, ABC):
    """
    LayerWithParams is a base class for layers which have custom parameters
    in their HN representation.
    """

    # NOTE: the class docstring must precede these attributes; otherwise it is a
    # discarded string expression and ``__doc__`` is None.
    _REQUIRES_NATIVE_WEIGHTS = None
    _REQUIRES_QUANTIZED_WEIGHTS = None

    def __init__(self):
        super().__init__()
        self._transpose_output_width_features = False
        self._spatial_flatten_output = False
        self._zp_comp_added = False  # backward compatibility of _zp_comp_rank
        self._decompose_weights = False

    def to_hn(self, should_get_default_params=False):
        """
        Serialize the layer to an HN dictionary.

        The base-class HN is deep-copied and a fresh ``params`` section is
        attached to it. Each flag is written only when it is truthy.

        Args:
            should_get_default_params: Forwarded unchanged to the base class.

        Returns:
            The HN dictionary, whose ``params`` entry is an ``OrderedDict``.
        """
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        result["params"] = OrderedDict()
        # ``_dynamic_weights`` / ``dynamic_weights`` are not defined in this class;
        # presumably they are provided by ``Layer``.
        if self._dynamic_weights:
            result["params"]["dynamic_weights"] = self.dynamic_weights
        if self._transpose_output_width_features:
            result["params"]["transpose_output_width_features"] = self.transpose_output_width_features
        if self._spatial_flatten_output:
            result["params"]["spatial_flatten_output"] = self.spatial_flatten_output
        if self._zp_comp_added:
            result["params"]["zp_comp_added"] = self._zp_comp_added
        if self._decompose_weights:
            result["params"]["decompose_weights"] = self._decompose_weights

        return result

    def to_pb(self, pb_wrapper, is_multi_scope):
        """
        Serialize the layer into the protobuf node created by the base class.

        Writes ``dynamic_weights``, ``transpose_output_width_features``,
        ``spatial_flatten_output`` and ``decompose_weights`` onto the node.

        Returns:
            The populated protobuf node.
        """
        # NOTE(review): ``zp_comp_added`` is not written here (nor read in
        # ``from_pb``), so it does not survive a protobuf round trip — confirm
        # whether the pb schema is meant to carry it.
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.dynamic_weights = self.dynamic_weights
        node.transpose_output_width_features = self.transpose_output_width_features
        node.spatial_flatten_output = self.spatial_flatten_output
        node.decompose_weights = self._decompose_weights
        return node

    @classmethod
    def from_hn(cls, hn, validate_params_exist=True):
        """
        Build a layer from its HN dictionary.

        Args:
            hn: The layer's HN dictionary.
            validate_params_exist: When True, an HN without a ``params``
                entry is rejected.

        Returns:
            The layer built by the base class, with every flag found under
            ``hn["params"]`` copied onto it. Flags absent from the HN are left
            untouched.

        Raises:
            UnsupportedModelError: If ``validate_params_exist`` is True and
                ``hn`` has no ``params`` entry.
        """
        if validate_params_exist and "params" not in hn:
            raise UnsupportedModelError(
                f'layer {hn["name"]} of type {hn["type"]} requires params, but the HN does not contain them',
            )

        layer = super().from_hn(hn)
        # Look up the params section once instead of re-checking ``"params" in hn``
        # for every flag.
        params = hn.get("params", {})
        if "dynamic_weights" in params:
            layer._dynamic_weights = params["dynamic_weights"]
        if "transpose_output_width_features" in params:
            layer._transpose_output_width_features = params["transpose_output_width_features"]
        if "spatial_flatten_output" in params:
            layer._spatial_flatten_output = params["spatial_flatten_output"]
        if "zp_comp_added" in params:
            layer._zp_comp_added = params["zp_comp_added"]
        if "decompose_weights" in params:
            layer._decompose_weights = params["decompose_weights"]
        return layer

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """
        Build a layer from its protobuf node.

        Reads ``dynamic_weights``, ``transpose_output_width_features``,
        ``spatial_flatten_output`` and ``decompose_weights`` from ``pb``.

        Returns:
            The layer built by the base class with those flags set.
        """
        layer = super().from_pb(pb, pb_wrapper)
        layer._dynamic_weights = pb.dynamic_weights
        layer._transpose_output_width_features = pb.transpose_output_width_features
        layer._spatial_flatten_output = pb.spatial_flatten_output
        layer._decompose_weights = pb.decompose_weights
        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """
        Build a layer of this class from an existing layer.

        ``dynamic_weights`` is always copied. The transpose, spatial-flatten and
        zp-compensation flags are copied only when ``old_layer.op`` is not
        ``LayerType.equal``.

        Returns:
            The new layer built by the base class with the copied flags.
        """
        # NOTE(review): ``decompose_weights`` is not copied here although it is
        # round-tripped by to_hn/from_hn and to_pb/from_pb — confirm intentional.
        layer = super().from_layer(old_layer)
        layer.dynamic_weights = old_layer.dynamic_weights
        if old_layer.op != LayerType.equal:
            layer.transpose_output_width_features = old_layer.transpose_output_width_features
            layer.spatial_flatten_output = old_layer.spatial_flatten_output
            layer.zp_comp_added = old_layer.zp_comp_added
        return layer

    @property
    def transpose_output_width_features(self):
        return self._transpose_output_width_features

    @transpose_output_width_features.setter
    def transpose_output_width_features(self, transpose_output_width_features):
        self._transpose_output_width_features = transpose_output_width_features

    @property
    def spatial_flatten_output(self):
        return self._spatial_flatten_output

    @spatial_flatten_output.setter
    def spatial_flatten_output(self, spatial_flatten_output):
        self._spatial_flatten_output = spatial_flatten_output

    def __str__(self):
        """Return the base description, suffixed with " +T" when the output is transposed."""
        # NOTE(review): unlike ``short_description``, " +SFO" is not appended here
        # — confirm whether that difference is intended.
        description = super().__str__()
        if self._transpose_output_width_features:
            description += " +T"

        return description

    @property
    def short_description(self):
        """Base short description plus " +T" / " +SFO" markers for the set flags."""
        base = super().short_description
        transpose_output_width_features_description = " +T" if self._transpose_output_width_features else ""
        spatial_flatten_output = " +SFO" if self._spatial_flatten_output else ""

        return f"{base}{transpose_output_width_features_description}{spatial_flatten_output}"

    @property
    def zp_comp_added(self):
        return self._zp_comp_added

    @zp_comp_added.setter
    def zp_comp_added(self, zp_comp_added):
        self._zp_comp_added = zp_comp_added

    @property
    def decompose_weights(self):
        return self._decompose_weights

    @decompose_weights.setter
    def decompose_weights(self, decompose_weights):
        self._decompose_weights = decompose_weights
