Varun Wadhwa committed on
Commit 5d92510 · unverified · 1 Parent(s): a23d6ef
Files changed (1)
  1. app.py +5 -2
app.py CHANGED
@@ -151,10 +151,11 @@ def evaluate_model(model, dataloader, device):
         # Process each sequence in the batch
         for i in range(current_batch_size):
             valid_mask = (labels[i] != -100) & (attention_mask[i] != 0)
-            valid_preds = preds[i][valid_mask[i]].flatten()
-            valid_labels = labels[i][valid_mask[i]].flatten()
+            valid_preds = preds[i][valid_mask].flatten()
+            valid_labels = labels[i][valid_mask].flatten()
             print(valid_mask.dtype)
             print(labels[i].shape)
+            print(labels[i])
             print(attention_mask[i].shape)
             print(valid_mask.shape)
             print(valid_labels)
@@ -162,6 +163,8 @@ def evaluate_model(model, dataloader, device):
             all_preds.extend(valid_preds.tolist())
             all_labels.extend(valid_labels.tolist())

+            assert not torch.any(valid_labels == -100), f"Found -100 in valid_labels for batch {i}"
+
             if sample_count < num_samples:
                 print(f"Sample {sample_count + 1}:")
                 print(f"Tokens: {tokenizer.convert_ids_to_tokens(input_ids[i])}")