|
import collections
import collections.abc

import numpy as np
import torch
from torch.autograd import Variable
|
|
|
# Public helpers exported by ``from <module> import *``.
__all__ = [
    'as_variable',
    'as_numpy',
    'mark_volatile',
]
|
|
|
def as_variable(obj):
    """Recursively wrap *obj* (or the leaves of a nested container) as Variables.

    Args:
        obj: A ``Variable``, anything accepted by ``Variable(...)``, or an
            arbitrarily nested sequence / mapping of those.

    Returns:
        ``obj`` itself if it is already a ``Variable``; a ``list`` for any
        non-string sequence (tuples are converted to lists); a ``dict`` for
        any mapping; otherwise ``Variable(obj)``.
    """
    if isinstance(obj, Variable):
        return obj
    # A str is a Sequence whose items are 1-char strs, so recursing into it
    # never terminates; treat it as a leaf and let Variable() handle it.
    if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
        return [as_variable(v) for v in obj]
    if isinstance(obj, collections.abc.Mapping):
        return {k: as_variable(v) for k, v in obj.items()}
    return Variable(obj)
|
|
|
def as_numpy(obj):
    """Recursively convert *obj* (or the leaves of a nested container) to NumPy.

    Args:
        obj: A ``Variable``, a tensor, a nested sequence / mapping of those,
            or any value accepted by ``np.array``.

    Returns:
        A ``list`` for any non-string sequence (tuples become lists), a
        ``dict`` for any mapping, ``obj.data.cpu().numpy()`` for a
        ``Variable``, ``obj.cpu().numpy()`` for a tensor, and
        ``np.array(obj)`` for anything else (including strings).
    """
    # A str is a Sequence whose items are 1-char strs, so recursing into it
    # never terminates; let it fall through to np.array instead.
    if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
        return [as_numpy(v) for v in obj]
    if isinstance(obj, collections.abc.Mapping):
        return {k: as_numpy(v) for k, v in obj.items()}
    if isinstance(obj, Variable):
        # Unwrap via .data before moving to host memory and converting.
        return obj.data.cpu().numpy()
    if torch.is_tensor(obj):
        return obj.cpu().numpy()
    return np.array(obj)
|
|
|
def mark_volatile(obj):
    """Recursively mark every tensor / Variable inside *obj* as "volatile".

    Tensors are first wrapped with ``Variable(...)``. Each ``Variable`` gets
    the attribute ``no_grad = True`` and is returned. Mappings are rebuilt as
    ``dict``s and non-string sequences as ``list``s, with their elements
    processed recursively. Any other value, including strings, is returned
    unchanged.
    """
    if torch.is_tensor(obj):
        obj = Variable(obj)
    if isinstance(obj, Variable):
        # NOTE(review): ``no_grad`` is just a plain Python attribute here;
        # nothing in this module reads it, and it is neither the legacy
        # ``volatile`` flag nor ``requires_grad``. Confirm whether autograd is
        # actually meant to be disabled (e.g. ``detach()`` / ``torch.no_grad()``).
        obj.no_grad = True
        return obj
    if isinstance(obj, collections.abc.Mapping):
        return {k: mark_volatile(o) for k, o in obj.items()}
    # A str is a Sequence whose items are 1-char strs, so recursing into it
    # never terminates; it is returned unchanged by the fallthrough below.
    if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
        return [mark_volatile(o) for o in obj]
    return obj
|
|