|
|
|
""" |
|
Wrappers around some nn functions, mainly to support empty tensors.
|
|
|
Ideally, support for empty tensors would be added directly to those functions in PyTorch.
|
|
|
These can be removed once https://github.com/pytorch/pytorch/issues/12013 |
|
is implemented.
|
""" |
|
|
|
import functools |
|
import warnings |
|
from typing import List, Optional |
|
import torch |
|
from torch.nn import functional as F |
|
|
|
from detectron2.utils.env import TORCH_VERSION |
|
|
|
|
|
def shapes_to_tensor(x: List[int], device: Optional[torch.device] = None) -> torch.Tensor:
    """
    Convert a list of integer scalars (or scalar integer Tensors) into a 1-D tensor,
    in a way that works under both tracing and scripting.

    Args:
        x: when tracing, a list of scalar Tensors (so the traced output depends on
            them); when scripting or running eagerly, a list of ints.
        device: optional device for the result.

    Returns:
        A tensor holding the values of ``x``.
    """
    if torch.jit.is_scripting():
        return torch.as_tensor(x, device=device)
    if not torch.jit.is_tracing():
        # Eager mode: plain Python ints.
        return torch.as_tensor(x, device=device)

    # Tracing: stacking the scalar Tensors keeps the graph connected to them,
    # whereas torch.as_tensor would bake the values in.
    assert all(
        [isinstance(t, torch.Tensor) for t in x]
    ), "Shape should be tensor during tracing!"
    stacked = torch.stack(x)
    return stacked.to(device=device) if stacked.device != device else stacked
|
|
|
|
|
def check_if_dynamo_compiling():
    """
    Return whether the caller is currently being compiled by TorchDynamo.

    On PyTorch versions older than 2.1 this always returns False.
    """
    if TORCH_VERSION < (2, 1):
        return False

    from torch._dynamo import is_compiling

    return is_compiling()
|
|
|
|
|
def disable_torch_compiler(func):
    """
    Decorator that excludes ``func`` from ``torch.compile``.

    On PyTorch >= 2.1, ``func`` is wrapped in a ``functools.wraps``-preserving
    shim decorated with ``torch.compiler.disable``. On older versions ``func``
    is returned unchanged.
    """
    if TORCH_VERSION < (2, 1):
        return func

    @torch.compiler.disable
    @functools.wraps(func)
    def _call_eagerly(*args, **kwargs):
        return func(*args, **kwargs)

    return _call_eagerly
|
|
|
|
|
def cat(tensors: List[torch.Tensor], dim: int = 0):
    """
    Concatenate ``tensors`` along ``dim`` like :func:`torch.cat`, except that a
    single-element list/tuple returns its only element as-is (no copy).
    """
    assert isinstance(tensors, (list, tuple))
    return tensors[0] if len(tensors) == 1 else torch.cat(tensors, dim)
|
|
|
|
|
def empty_input_loss_func_wrapper(loss_func):
    """
    Wrap ``loss_func`` so that a ``"mean"``-reduced loss over an empty target
    evaluates to 0 instead of nan.

    The returned callable has signature ``(input, target, *, reduction="mean", **kwargs)``
    and forwards everything else to ``loss_func`` unchanged.
    """

    def wrapped_loss_func(input, target, *, reduction="mean", **kwargs):
        """
        Same as `loss_func`, but returns 0 (instead of nan) for empty inputs.
        """
        is_empty_mean = target.numel() == 0 and reduction == "mean"
        if not is_empty_mean:
            return loss_func(input, target, reduction=reduction, **kwargs)
        # Unlike a constant 0, `input.sum() * 0.0` still depends on `input`,
        # so it participates in autograd when `input` requires grad.
        return input.sum() * 0.0

    return wrapped_loss_func


# Cross-entropy that returns 0 rather than nan for empty targets with mean reduction.
cross_entropy = empty_input_loss_func_wrapper(loss_func=F.cross_entropy)
|
|
|
|
|
class _NewEmptyTensorOp(torch.autograd.Function):
    """
    Autograd function that returns an uninitialized tensor of ``new_shape``
    created via ``x.new_empty``. Its backward likewise returns an uninitialized
    tensor with the original shape of ``x``; ``new_shape`` gets no gradient.
    """

    @staticmethod
    def forward(ctx, x, new_shape):
        # Remember the input's shape so backward can build a gradient of that shape.
        ctx.shape = x.shape
        return x.new_empty(new_shape)

    @staticmethod
    def backward(ctx, grad_output):
        return _NewEmptyTensorOp.apply(grad_output, ctx.shape), None
|
|
|
|
|
class Conv2d(torch.nn.Conv2d):
    """
    A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
    """

    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        It assumes that norm layer is used before activation.
        """
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        """
        Apply the convolution, then the optional ``norm``, then the optional
        ``activation``.

        In training mode with an empty input, asserts that ``norm`` is not a
        ``SyncBatchNorm`` (this check is skipped under TorchScript and TorchDynamo).
        """
        if not torch.jit.is_scripting():
            # The check is also skipped while TorchDynamo is compiling.
            is_dynamo_compiling = check_if_dynamo_compiling()
            if not is_dynamo_compiling:
                with warnings.catch_warnings(record=True):
                    if x.numel() == 0 and self.training:
                        # https://github.com/pytorch/pytorch/issues/12013
                        assert not isinstance(
                            self.norm, torch.nn.SyncBatchNorm
                        ), "SyncBatchNorm does not support empty inputs!"

        if self.padding_mode != "zeros":
            # F.conv2d only performs implicit zero padding. For "reflect",
            # "replicate" or "circular" delegate to the parent implementation,
            # which pads explicitly; otherwise `padding_mode` would be silently ignored.
            x = self._conv_forward(x, self.weight, self.bias)
        else:
            x = F.conv2d(
                x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
            )
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x
|
|
|
|
|
# Re-exported unchanged from PyTorch; no wrapper is applied to these.
ConvTranspose2d, BatchNorm2d, interpolate, Linear = (
    torch.nn.ConvTranspose2d,
    torch.nn.BatchNorm2d,
    F.interpolate,
    torch.nn.Linear,
)
|
|
|
|
|
def nonzero_tuple(x):
    """
    A TorchScript-compatible equivalent of ``torch.nonzero(x, as_tuple=True)``.

    Needed because of https://github.com/pytorch/pytorch/issues/38718
    """
    if not torch.jit.is_scripting():
        return x.nonzero(as_tuple=True)
    # TorchScript path: take the 2-D index matrix from nonzero() and split it
    # into one index tensor per dimension. A 0-d input is first promoted to 1-d,
    # mirroring how `as_tuple=True` treats 0-d inputs.
    promoted = x.unsqueeze(0) if x.dim() == 0 else x
    return promoted.nonzero().unbind(1)
|
|
|
|
|
@torch.jit.script_if_tracing
def move_device_like(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
    """
    Return ``src`` moved to the device of ``dst``.

    When tracing, a plain ``.to(device)`` would record the device as a constant.
    ``script_if_tracing`` compiles this function with TorchScript instead, so the
    cast is captured as a whole.
    """
    target_device = dst.device
    return src.to(target_device)
|
|