Varun Wadhwa commited on
Commit
2578830
·
unverified ·
1 Parent(s): 1f408e5
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -151,14 +151,14 @@ def evaluate_model(model, dataloader, device):
151
  # Process each sequence in the batch
152
  for i in range(current_batch_size):
153
  valid_mask = (labels[i] != -100) & attention_mask[i]
 
 
154
  print(valid_mask.dtype)
155
  print(labels[i].shape)
156
  print(attention_mask[i].shape)
157
  print(valid_mask.shape)
158
  print(valid_labels)
159
  print(valid_mask)
160
- valid_preds = preds[i][valid_mask[i]].flatten()
161
- valid_labels = labels[i][valid_mask[i]].flatten()
162
  all_preds.extend(valid_preds.tolist())
163
  all_labels.extend(valid_labels.tolist())
164
 
 
151
  # Process each sequence in the batch
152
  for i in range(current_batch_size):
153
  valid_mask = (labels[i] != -100) & attention_mask[i]
154
+ valid_preds = preds[i][valid_mask[i]].flatten()
155
+ valid_labels = labels[i][valid_mask[i]].flatten()
156
  print(valid_mask.dtype)
157
  print(labels[i].shape)
158
  print(attention_mask[i].shape)
159
  print(valid_mask.shape)
160
  print(valid_labels)
161
  print(valid_mask)
 
 
162
  all_preds.extend(valid_preds.tolist())
163
  all_labels.extend(valid_labels.tolist())
164