|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Generic utilities |
|
""" |
|
|
|
from collections import OrderedDict |
|
from dataclasses import fields |
|
from typing import Any, Tuple |
|
|
|
import numpy as np |
|
|
|
from .import_utils import is_torch_available |
|
|
|
|
|
def is_tensor(x):
    """
    Return whether `x` is an array-like tensor object.

    A `np.ndarray` always qualifies. When PyTorch is available, a `torch.Tensor` qualifies as well; otherwise torch is
    never imported.
    """
    if is_torch_available():
        import torch

        accepted = (torch.Tensor, np.ndarray)
    else:
        accepted = (np.ndarray,)

    return isinstance(x, accepted)
|
|
|
|
|
class BaseOutput(OrderedDict):
    """
    Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
    tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
    Python dictionary.

    <Tip warning={true}>

    You can't unpack a [`BaseOutput`] directly: like any dictionary, iterating over it yields its keys, not its values.
    Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple of values first.

    </Tip>
    """

    def __post_init__(self):
        """
        Populate the mapping from the dataclass fields after the generated `__init__` has run.

        If every field except the first is `None` and the first field is a `dict`, that dict's items become the
        entries of this output. Otherwise, each field whose value is not `None` is stored under its field name, in
        field declaration order.

        Raises:
            ValueError: if the dataclass declares no fields.
        """
        class_fields = fields(self)

        # Safety and consistency checks
        if not class_fields:
            raise ValueError(f"{self.__class__.__name__} has no fields.")

        first_field = getattr(self, class_fields[0].name)
        other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])

        if other_fields_are_none and isinstance(first_field, dict):
            for key, value in first_field.items():
                self[key] = value
        else:
            for field in class_fields:
                v = getattr(self, field.name)
                if v is not None:
                    self[field.name] = v

    # The mutating dict methods below are disabled so entries cannot be removed or bulk-changed.

    def __delitem__(self, *args, **kwargs):
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __getitem__(self, k):
        """
        Look up an entry by name (`str`) or by position/slice (anything else).

        String keys are looked up in the underlying mapping, so a field that was `None` (and was therefore never
        stored) raises `KeyError`. Any other key indexes into `to_tuple()`.
        """
        if isinstance(k, str):
            # Look up directly in the base mapping instead of copying every entry into a new dict on each access;
            # `to_tuple()` calls this once per key, so a copy here made it quadratic.
            return super().__getitem__(k)
        return self.to_tuple()[k]

    def __setattr__(self, name, value):
        # Keep an already-stored entry in sync with its attribute. Assigning `None` leaves the stored entry untouched,
        # and names that are not stored yet are only set as plain attributes.
        if name in self and value is not None:
            # Use the base-class method so this override and `__setitem__` don't call each other.
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        # Store the entry first (the base class raises if the key is unusable).
        super().__setitem__(key, value)
        # Mirror it as an attribute via the base class so `self.__setattr__` is not re-entered.
        super().__setattr__(key, value)

    def to_tuple(self) -> Tuple[Any, ...]:
        """
        Convert self to a tuple containing all the attributes/keys that are not `None`, in insertion order.
        """
        return tuple(self[k] for k in self.keys())
|
|