Spaces:
Runtime error
Runtime error
import numpy as np | |
import torch | |
""" | |
Helpers for converting (possibly nested) values between numpy arrays, torch tensors, and Python lists.
Also provides general tensor operations such as moving to a device (`to(dev)`) and detaching.
""" | |
def thing2list(thing):
    """Recursively convert tensors/arrays inside *thing* to plain Python values.

    - torch.Tensor / np.ndarray -> ``thing.tolist()`` (nested lists, or a
      Python scalar for 0-d inputs).
    - dict -> dict with converted values (keys unchanged).
    - list -> list of converted elements.
    - tuple -> tuple of converted elements.
    Any other value is returned unchanged.
    """
    if isinstance(thing, (torch.Tensor, np.ndarray)):
        return thing.tolist()
    if isinstance(thing, dict):
        # Only values are converted; keys are kept as-is.
        return {k: thing2list(v) for k, v in thing.items()}
    if isinstance(thing, list):
        return [thing2list(item) for item in thing]
    if isinstance(thing, tuple):
        return tuple(thing2list(item) for item in thing)
    return thing
def thing2dev(thing, dev):
    """Recursively move tensors (or anything exposing ``.to``) to *dev*.

    Objects with a ``to`` attribute (e.g. torch.Tensor, torch.nn.Module) are
    returned as ``thing.to(dev)``. Lists, tuples and dicts are traversed
    recursively and rebuilt with the same container type (dict keys are
    unchanged). Any other value is returned as-is.

    Args:
        thing: value or nested container to move.
        dev: target passed straight through to ``.to`` (e.g. "cpu", a torch.device).
    """
    # The duck-typed check also covers torch.Tensor, so no separate tensor
    # branch is needed.
    if hasattr(thing, "to"):
        return thing.to(dev)
    if isinstance(thing, list):
        return [thing2dev(item, dev) for item in thing]
    if isinstance(thing, tuple):
        return tuple(thing2dev(item, dev) for item in thing)
    if isinstance(thing, dict):
        return {k: thing2dev(v, dev) for k, v in thing.items()}
    return thing
def thing2np(thing):
    """Recursively convert tensors and lists inside *thing* to numpy arrays.

    - torch.Tensor -> detached CPU ``np.ndarray``.
    - list -> ``np.array`` built from the recursively converted elements, so
      lists of tensors that require grad or live off-CPU convert cleanly.
    - dict -> dict with converted values (keys unchanged).
    Any other value (including tuples and existing arrays) is returned unchanged.
    """
    if isinstance(thing, list):
        # Convert elements first: np.array() on raw tensors goes through
        # Tensor.__array__, which raises for tensors that require grad or
        # are not on the CPU.
        return np.array([thing2np(item) for item in thing])
    if isinstance(thing, torch.Tensor):
        # detach() before cpu() so autograd does not record the device copy.
        return thing.detach().cpu().numpy()
    if isinstance(thing, dict):
        return {k: thing2np(v) for k, v in thing.items()}
    return thing
def thing2torch(thing):
    """Recursively convert numpy arrays and lists inside *thing* to torch tensors.

    - dict -> dict whose values are converted (keys unchanged).
    - np.ndarray -> ``torch.from_numpy`` (the tensor shares memory with the array).
    - list -> ``torch.tensor`` of ``np.array(thing)`` (a fresh copy).
    Every other value, including tuples, is passed through untouched.
    """
    if isinstance(thing, dict):
        return {key: thing2torch(val) for key, val in thing.items()}
    if isinstance(thing, np.ndarray):
        # Zero-copy: later writes to the array are visible through the tensor.
        return torch.from_numpy(thing)
    if isinstance(thing, list):
        as_array = np.array(thing)
        return torch.tensor(as_array)
    return thing
def detach_thing(thing):
    """Recursively detach tensors inside *thing* and move them to the CPU.

    Tensors are returned as CPU tensors that no longer require grad. Lists,
    tuples and dicts are rebuilt with the same container type (dict keys are
    unchanged). Any other value is returned as-is.
    """
    if isinstance(thing, torch.Tensor):
        # detach() first so the device-to-host copy is not recorded by autograd.
        return thing.detach().cpu()
    if isinstance(thing, list):
        return [detach_thing(item) for item in thing]
    if isinstance(thing, tuple):
        return tuple(detach_thing(item) for item in thing)
    if isinstance(thing, dict):
        return {k: detach_thing(v) for k, v in thing.items()}
    return thing