Varun Wadhwa commited on
Commit
b62f161
·
unverified ·
1 Parent(s): 5d92510
Files changed (1) hide show
  1. app.py +4 -6
app.py CHANGED
@@ -151,8 +151,8 @@ def evaluate_model(model, dataloader, device):
151
  # Process each sequence in the batch
152
  for i in range(current_batch_size):
153
  valid_mask = (labels[i] != -100) & (attention_mask[i] != 0)
154
- valid_preds = preds[i][valid_mask].flatten()
155
- valid_labels = labels[i][valid_mask].flatten()
156
  print(valid_mask.dtype)
157
  print(labels[i].shape)
158
  print(labels[i])
@@ -160,11 +160,7 @@ def evaluate_model(model, dataloader, device):
160
  print(valid_mask.shape)
161
  print(valid_labels)
162
  print(valid_mask)
163
- all_preds.extend(valid_preds.tolist())
164
- all_labels.extend(valid_labels.tolist())
165
-
166
  assert not torch.any(valid_labels == -100), f"Found -100 in valid_labels for batch {i}"
167
-
168
  if sample_count < num_samples:
169
  print(f"Sample {sample_count + 1}:")
170
  print(f"Tokens: {tokenizer.convert_ids_to_tokens(input_ids[i])}")
@@ -172,6 +168,8 @@ def evaluate_model(model, dataloader, device):
172
  print(f"Predicted Labels: {[id2label[pred] for pred in valid_preds]}")
173
  print("-" * 50)
174
  sample_count += 1
 
 
175
 
176
  # Calculate evaluation metrics
177
  print("evaluate_model sizes")
 
151
  # Process each sequence in the batch
152
  for i in range(current_batch_size):
153
  valid_mask = (labels[i] != -100) & (attention_mask[i] != 0)
154
+ valid_preds = preds[i][valid_mask[i]].flatten()
155
+ valid_labels = labels[i][valid_mask[i]].flatten()
156
  print(valid_mask.dtype)
157
  print(labels[i].shape)
158
  print(labels[i])
 
160
  print(valid_mask.shape)
161
  print(valid_labels)
162
  print(valid_mask)
 
 
 
163
  assert not torch.any(valid_labels == -100), f"Found -100 in valid_labels for batch {i}"
 
164
  if sample_count < num_samples:
165
  print(f"Sample {sample_count + 1}:")
166
  print(f"Tokens: {tokenizer.convert_ids_to_tokens(input_ids[i])}")
 
168
  print(f"Predicted Labels: {[id2label[pred] for pred in valid_preds]}")
169
  print("-" * 50)
170
  sample_count += 1
171
+ all_preds.extend(valid_preds.tolist())
172
+ all_labels.extend(valid_labels.tolist())
173
 
174
  # Calculate evaluation metrics
175
  print("evaluate_model sizes")