Varun Wadhwa commited on
Commit
2c409e9
·
unverified ·
1 Parent(s): f3c2885
Files changed (1) hide show
  1. app.py +5 -0
app.py CHANGED
@@ -123,6 +123,11 @@ def evaluate_model(model, dataloader, device):
123
 
124
  # Disable gradient calculations
125
  with torch.no_grad():
 
 
 
 
 
126
  for batch in dataloader:
127
  input_ids = batch['input_ids'].to(device)
128
  attention_mask = batch['attention_mask'].to(device)
 
123
 
124
  # Disable gradient calculations
125
  with torch.no_grad():
126
+ for batch in dataloader:
127
+ print("Sample sequence labels:", batch['labels'][0].tolist()[:20])
128
+ print("Corresponding predictions:", torch.argmax(model(batch['input_ids'].to(device),
129
+ attention_mask=batch['attention_mask'].to(device)).logits, dim=-1)[0].tolist()[:20])
130
+ break
131
  for batch in dataloader:
132
  input_ids = batch['input_ids'].to(device)
133
  attention_mask = batch['attention_mask'].to(device)