Spaces:
Runtime error
Runtime error
# Modified from:
# https://github.com/anibali/pytorch-stacked-hourglass
# https://github.com/bearpaw/pytorch-pose
import os | |
import shutil | |
import scipy.io | |
import torch | |
def to_numpy(tensor):
    """Return ``tensor`` as a NumPy object.

    Torch tensors are detached, moved to the CPU and converted with
    ``Tensor.numpy()``. Objects whose type lives in the top-level ``numpy``
    module are returned unchanged (the very same object, no copy).

    Args:
        tensor: A torch tensor or a NumPy object.

    Returns:
        The NumPy counterpart of ``tensor``.

    Raises:
        ValueError: If ``tensor`` is neither a torch tensor nor an object
            whose type is defined in the ``numpy`` module.
    """
    if torch.is_tensor(tensor):
        return tensor.detach().cpu().numpy()
    # NumPy objects are recognised by the module of their type, not by an
    # isinstance check against a NumPy class.
    if type(tensor).__module__ != 'numpy':
        raise ValueError("Cannot convert {} to numpy array"
                         .format(type(tensor)))
    return tensor
def to_torch(ndarray):
    """Return ``ndarray`` as a torch tensor.

    Objects whose type lives in the top-level ``numpy`` module are wrapped
    with ``torch.from_numpy``. Torch tensors are returned unchanged (the very
    same object).

    Args:
        ndarray: A NumPy array or a torch tensor.

    Returns:
        A torch tensor.

    Raises:
        ValueError: If ``ndarray`` is neither a NumPy object nor a torch
            tensor.
    """
    # NumPy objects are recognised by the module of their type, not by an
    # isinstance check against a NumPy class.
    if type(ndarray).__module__ == 'numpy':
        return torch.from_numpy(ndarray)
    if not torch.is_tensor(ndarray):
        raise ValueError("Cannot convert {} to torch tensor"
                         .format(type(ndarray)))
    return ndarray
def save_checkpoint(state, preds, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar', snapshot=None):
    """Write the training state and predictions into the ``checkpoint`` directory.

    Files written (all inside ``checkpoint``):

    * ``filename``: ``state`` serialised with ``torch.save``.
    * ``preds.mat``: ``preds`` stored under the key ``'preds'``.
    * ``checkpoint_<epoch>.pth.tar``: a copy of ``filename``, only when
      ``snapshot`` is truthy and ``state['epoch']`` is a multiple of it.
    * ``model_best.pth.tar`` and ``preds_best.mat``: only when ``is_best``
      is truthy.

    Args:
        state: Object passed to ``torch.save``. It is indexed with
            ``state['epoch']`` when ``snapshot`` is truthy.
        preds: Predictions; converted with ``to_numpy`` before saving.
        is_best: Whether to also write the ``*_best`` files.
        checkpoint: Output directory. Created if it does not exist.
        filename: File name of the main checkpoint inside ``checkpoint``.
        snapshot: Optional epoch interval for numbered checkpoint copies.
    """
    preds = to_numpy(preds)

    # torch.save / savemat do not create missing directories; make sure the
    # target exists. An empty string means the current directory, which
    # os.makedirs would reject.
    if checkpoint:
        os.makedirs(checkpoint, exist_ok=True)

    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)
    scipy.io.savemat(os.path.join(checkpoint, 'preds.mat'), mdict={'preds': preds})

    if snapshot and state['epoch'] % snapshot == 0:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'checkpoint_{}.pth.tar'.format(state['epoch'])))

    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))
        scipy.io.savemat(os.path.join(checkpoint, 'preds_best.mat'), mdict={'preds': preds})
def save_pred(preds, checkpoint='checkpoint', filename='preds_valid.mat'):
    """Save ``preds`` as a MATLAB ``.mat`` file under the key ``'preds'``.

    Args:
        preds: Predictions; converted with ``to_numpy`` before saving.
        checkpoint: Output directory. Created if it does not exist.
        filename: Name of the ``.mat`` file inside ``checkpoint``.
    """
    preds = to_numpy(preds)

    # savemat does not create missing directories; an empty string means the
    # current directory, which os.makedirs would reject.
    if checkpoint:
        os.makedirs(checkpoint, exist_ok=True)

    filepath = os.path.join(checkpoint, filename)
    scipy.io.savemat(filepath, mdict={'preds': preds})
def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma):
    """Decay the learning rate by ``gamma`` when ``epoch`` is in ``schedule``.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts
            whose ``'lr'`` entry is overwritten on a scheduled epoch.
        epoch: Current epoch.
        lr: Current learning rate.
        schedule: Container of epochs at which the rate is decayed.
        gamma: Multiplicative decay factor.

    Returns:
        The learning rate in effect after this call: ``lr * gamma`` on a
        scheduled epoch, otherwise ``lr`` unchanged.
    """
    if epoch in schedule:
        lr *= gamma
        # NOTE(review): the original indentation was lost; the loop is placed
        # inside the ``if``. The result is the same as the alternative
        # nesting whenever the optimizer already holds ``lr`` - confirm.
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    return lr