import torch
from models.resnet_simclr import ResNetSimCLR
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from loss.nt_xent import NTXentLoss
import os
import shutil
import sys

apex_support = False
try:
    sys.path.append('./apex')
    from apex import amp

    apex_support = True
except ImportError:
    print("Please install apex for mixed precision training from: https://github.com/NVIDIA/apex")
    apex_support = False
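
# Note: apex's amp API has since been deprecated upstream in favor of the native
# torch.cuda.amp module (PyTorch >= 1.6); the try/except above keeps apex optional,
# and training simply falls back to full FP32 precision when it is missing.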

import numpy as np

torch.manual_seed(0)


def _save_config_file(model_checkpoints_folder):
    if not os.path.exists(model_checkpoints_folder):
        os.makedirs(model_checkpoints_folder)
        shutil.copy('./config.yaml', os.path.join(model_checkpoints_folder, 'config.yaml'))


class SimCLR(object):

    def __init__(self, dataset, config, args=None):
        self.config = config
        self.device = self._get_device()
        self.writer = SummaryWriter()
        self.dataset = dataset
        self.nt_xent_criterion = NTXentLoss(self.device, config['batch_size'], **config['loss'])
        self.args = args

    def _get_device(self):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("Running on:", device)
        return device

    def _step(self, model, xis, xjs, n_iter):
        # get the representations and the projections for the first augmented view
        ris, zis = model(xis)  # [N, C]

        # get the representations and the projections for the second augmented view
        rjs, zjs = model(xjs)  # [N, C]

        # normalize projection feature vectors before computing the contrastive loss
        zis = F.normalize(zis, dim=1)
        zjs = F.normalize(zjs, dim=1)

        loss = self.nt_xent_criterion(zis, zjs)
        return loss
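
    # NTXentLoss (imported from loss.nt_xent) is expected to implement the
    # normalized temperature-scaled cross-entropy loss from the SimCLR paper
    # (Chen et al., 2020). For a positive pair (i, j) drawn from the two views:
    #
    #   l(i, j) = -log( exp(sim(z_i, z_j) / tau)
    #                   / sum_{k=1..2N, k != i} exp(sim(z_i, z_k) / tau) )
    #
    # where sim(u, v) is cosine similarity, tau is the temperature passed in via
    # config['loss'], and the total loss averages l(i, j) and l(j, i) over all
    # N positive pairs in the batch.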

    def train(self):
        train_loader, valid_loader = self.dataset.get_data_loaders()

        model = ResNetSimCLR(**self.config["model"])
        if self.config['n_gpu'] > 1:
            model = torch.nn.DataParallel(model, device_ids=eval(self.config['gpu_ids']))
        model = self._load_pre_trained_weights(model)
        model = model.to(self.device)

        optimizer = torch.optim.Adam(model.parameters(), 1e-5, weight_decay=eval(self.config['weight_decay']))
        # cosine-decay the learning rate once per epoch (stepped in the epoch loop below)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.config['epochs'], eta_min=0,
                                                               last_epoch=-1)

        if apex_support and self.config['fp16_precision']:
            model, optimizer = amp.initialize(model, optimizer,
                                              opt_level='O2',
                                              keep_batchnorm_fp32=True)

        if self.args is None:
            model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')
        else:
            model_checkpoints_folder = os.path.dirname(self.args.dest_weights)

        # save config file
        _save_config_file(model_checkpoints_folder)

        n_iter = 0
        valid_n_iter = 0
        best_valid_loss = np.inf

        for epoch_counter in range(self.config['epochs']):
            for (xis, xjs) in train_loader:
                optimizer.zero_grad()

                xis = xis.to(self.device)
                xjs = xjs.to(self.device)

                loss = self._step(model, xis, xjs, n_iter)

                if n_iter % self.config['log_every_n_steps'] == 0:
                    self.writer.add_scalar('train_loss', loss, global_step=n_iter)
                    print("[%d/%d] step: %d train_loss: %.3f" % (epoch_counter, self.config['epochs'], n_iter, loss))

                if apex_support and self.config['fp16_precision']:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                optimizer.step()
                n_iter += 1

            # validate the model if requested
            if epoch_counter % self.config['eval_every_n_epochs'] == 0:
                valid_loss = self._validate(model, valid_loader)
                print("[%d/%d] val_loss: %.3f" % (epoch_counter, self.config['epochs'], valid_loss))
                if valid_loss < best_valid_loss:
                    # save the model weights whenever the validation loss improves
                    best_valid_loss = valid_loss
                    torch.save(model.state_dict(), os.path.join(model_checkpoints_folder, 'model.pth'))
                    print('saved')

                self.writer.add_scalar('validation_loss', valid_loss, global_step=valid_n_iter)
                valid_n_iter += 1

            # warmup: hold the learning rate constant for the first 10 epochs, then
            # step the cosine decay once per epoch (note: since T_max equals the total
            # epoch count but stepping starts at epoch 10, the schedule never fully
            # reaches eta_min)
            if epoch_counter >= 10:
                scheduler.step()
            # get_last_lr() replaces the deprecated scheduler.get_lr()
            self.writer.add_scalar('cosine_lr_decay', scheduler.get_last_lr()[0], global_step=n_iter)
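
    # Caveat (DataParallel): when n_gpu > 1, checkpoints saved above carry a
    # 'module.' prefix in their state_dict keys, so they will not load directly
    # into an unwrapped ResNetSimCLR (and vice versa) without stripping the prefix.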

    def _load_pre_trained_weights(self, model):
        try:
            checkpoints_folder = os.path.join('./runs', self.config['fine_tune_from'], 'checkpoints')
            # map_location lets checkpoints trained on GPU load on CPU-only machines
            state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'), map_location=self.device)
            model.load_state_dict(state_dict)
            print("Loaded pre-trained model with success.")
        except FileNotFoundError:
            print("Pre-trained weights not found. Training from scratch.")

        return model

    def _validate(self, model, valid_loader):
        # validation: average the contrastive loss over the full validation loader
        with torch.no_grad():
            model.eval()

            valid_loss = 0.0
            counter = 0
            for (xis, xjs) in valid_loader:
                xis = xis.to(self.device)
                xjs = xjs.to(self.device)

                loss = self._step(model, xis, xjs, counter)
                valid_loss += loss.item()
                counter += 1
            valid_loss /= counter

        model.train()
        return valid_loss
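

# A minimal usage sketch (illustrative): it assumes the surrounding repo provides
# a dataset wrapper whose get_data_loaders() returns (train_loader, valid_loader)
# of paired augmented views, and a config.yaml with the keys read above
# ('batch_size', 'dataset', 'model', 'loss', 'epochs', 'n_gpu', 'weight_decay',
# 'fp16_precision', 'log_every_n_steps', 'eval_every_n_epochs'). The
# data_aug.dataset_wrapper import path is an assumption, not part of this module.
if __name__ == "__main__":
    import yaml
    from data_aug.dataset_wrapper import DataSetWrapper

    config = yaml.safe_load(open("config.yaml", "r"))
    dataset = DataSetWrapper(config['batch_size'], **config['dataset'])
    simclr = SimCLR(dataset, config)
    simclr.train()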