Varun Wadhwa commited on
Commit
a23d6ef
·
unverified ·
1 Parent(s): 2578830
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -150,7 +150,7 @@ def evaluate_model(model, dataloader, device):
150
 
151
  # Process each sequence in the batch
152
  for i in range(current_batch_size):
153
- valid_mask = (labels[i] != -100) & attention_mask[i]
154
  valid_preds = preds[i][valid_mask[i]].flatten()
155
  valid_labels = labels[i][valid_mask[i]].flatten()
156
  print(valid_mask.dtype)
 
150
 
151
  # Process each sequence in the batch
152
  for i in range(current_batch_size):
153
+ valid_mask = (labels[i] != -100) & (attention_mask[i] != 0)
154
  valid_preds = preds[i][valid_mask[i]].flatten()
155
  valid_labels = labels[i][valid_mask[i]].flatten()
156
  print(valid_mask.dtype)