Spaces:
Runtime error
Runtime error
File size: 2,331 Bytes
4db4d66 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
from tqdm import tqdm
import torch
import torch.nn.functional as F
def train(
    model,
    device,
    train_loader,
    optimizer,
    criterion,
    scheduler,
    L1=False,
    l1_lambda=0.01,
):
    """Run one pass over ``train_loader``, updating ``model`` after every batch.

    For each batch the inputs and targets are moved to ``device``, the loss is
    computed with ``criterion`` (optionally with an L1 penalty added), and
    then ``optimizer.step()`` and ``scheduler.step()`` are called, so the
    scheduler is advanced once per batch rather than once per epoch. A tqdm
    progress bar shows the latest batch loss and the running accuracy.

    Args:
        model: Network to train; it is switched to training mode first.
        device: Target passed to ``.to()`` for each batch of data and targets.
        train_loader: Iterable yielding ``(data, target)`` batches.
        optimizer: Optimizer whose gradients are zeroed before each forward
            pass and which is stepped after each backward pass.
        criterion: Callable ``criterion(y_pred, target)`` returning the loss.
        scheduler: Learning-rate scheduler providing ``step()`` and
            ``get_last_lr()``.
        L1: When true, add ``l1_lambda`` times the sum of absolute values of
            all model parameters to the loss.
        l1_lambda: Weight of the L1 penalty; ignored unless ``L1`` is true.

    Returns:
        A tuple ``(train_losses, train_acc, lrs)`` with one entry per batch:
        the batch loss (including any L1 penalty), the running accuracy in
        percent over all samples seen so far, and whatever
        ``scheduler.get_last_lr()`` returned after that batch's step.
    """
    model.train()
    pbar = tqdm(train_loader)

    train_losses = []
    train_acc = []
    lrs = []
    correct = 0
    processed = 0

    for data, target in pbar:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        y_pred = model(data)

        # Calculate loss
        loss = criterion(y_pred, target)
        if L1:
            l1_loss = sum(p.abs().sum() for p in model.parameters())
            loss = loss + l1_lambda * l1_loss

        # Read the scalar once; each .item() call copies the value to the host.
        loss_value = loss.item()
        train_losses.append(loss_value)

        # Backpropagation
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Predicted class is the index of the largest score along dim 1.
        pred = y_pred.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
        processed += len(data)
        running_acc = 100 * correct / processed

        # Update pbar-tqdm
        pbar.set_description(
            desc=f"Loss={loss_value:0.2f} Accuracy={running_acc:0.2f}"
        )
        train_acc.append(running_acc)
        lrs.append(scheduler.get_last_lr())

    return train_losses, train_acc, lrs
def test(model, device, criterion, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.cross_entropy(output, target, reduction="sum").item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print(
"\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n".format(
test_loss,
correct,
len(test_loader.dataset),
100.0 * correct / len(test_loader.dataset),
)
)
test_acc = 100.0 * correct / len(test_loader.dataset)
return test_loss, test_acc
|