Varun Wadhwa commited on
Commit
2d0c6e2
·
unverified ·
1 Parent(s): aeb8f5a
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -146,7 +146,7 @@ def evaluate_model(model, dataloader, device):
146
  logits = outputs.logits
147
 
148
  # Get predictions
149
- preds = torch.argmax(logits, dim=-1).cpu().numpy()
150
 
151
  # Process each sequence in the batch
152
  for i in range(current_batch_size):
 
146
  logits = outputs.logits
147
 
148
  # Get predictions
149
+ preds = torch.argmax(logits, dim=-1)
150
 
151
  # Process each sequence in the batch
152
  for i in range(current_batch_size):