Varun Wadhwa commited on
Commit
a509844
·
unverified ·
1 Parent(s): d04aeed
Files changed (1) hide show
  1. app.py +1 -9
app.py CHANGED
@@ -130,7 +130,7 @@ def evaluate_model(model, dataloader, device):
130
  all_preds = []
131
  all_labels = []
132
  sample_count = 0
133
- num_samples=100
134
 
135
  # Disable gradient calculations
136
  with torch.no_grad():
@@ -153,14 +153,6 @@ def evaluate_model(model, dataloader, device):
153
  valid_mask = (labels[i] != -100) & (attention_mask[i] != 0)
154
  valid_preds = preds[i][valid_mask].flatten()
155
  valid_labels = labels[i][valid_mask].flatten()
156
- print(valid_mask.dtype)
157
- print(labels[i].shape)
158
- print(labels[i])
159
- print(attention_mask[i].shape)
160
- print(valid_mask.shape)
161
- print(valid_labels)
162
- print(valid_mask)
163
- assert not torch.any(valid_labels == -100), f"Found -100 in valid_labels for batch {i}"
164
  if sample_count < num_samples:
165
  print(f"Sample {sample_count + 1}:")
166
  print(f"Tokens: {tokenizer.convert_ids_to_tokens(input_ids[i])}")
 
130
  all_preds = []
131
  all_labels = []
132
  sample_count = 0
133
+ num_samples=10
134
 
135
  # Disable gradient calculations
136
  with torch.no_grad():
 
153
  valid_mask = (labels[i] != -100) & (attention_mask[i] != 0)
154
  valid_preds = preds[i][valid_mask].flatten()
155
  valid_labels = labels[i][valid_mask].flatten()
 
 
 
 
 
 
 
 
156
  if sample_count < num_samples:
157
  print(f"Sample {sample_count + 1}:")
158
  print(f"Tokens: {tokenizer.convert_ids_to_tokens(input_ids[i])}")