Varun Wadhwa committed on
Commit
aac732c
·
unverified ·
1 Parent(s): 19a3733
Files changed (1) hide show
  1. app.py +10 -2
app.py CHANGED
@@ -122,6 +122,8 @@ def evaluate_model(model, dataloader, device):
122
  model.eval() # Set model to evaluation mode
123
  all_preds = []
124
  all_labels = []
 
 
125
 
126
  # Disable gradient calculations
127
  with torch.no_grad():
@@ -149,12 +151,18 @@ def evaluate_model(model, dataloader, device):
149
  all_preds.extend(valid_preds.tolist())
150
  all_labels.extend(valid_labels.tolist())
151
 
 
 
 
 
 
 
 
 
152
  # Calculate evaluation metrics
153
  print("evaluate_model sizes")
154
  print(len(all_preds))
155
  print(len(all_labels))
156
- print(id2label[all_preds[0]])
157
- print(id2label[all_labels[0]])
158
  all_preds = np.asarray(all_preds, dtype=np.float32)
159
  all_labels = np.asarray(all_labels, dtype=np.float32)
160
  accuracy = accuracy_score(all_labels, all_preds)
 
122
  model.eval() # Set model to evaluation mode
123
  all_preds = []
124
  all_labels = []
125
+ sample_count = 0
126
+ num_samples=100
127
 
128
  # Disable gradient calculations
129
  with torch.no_grad():
 
151
  all_preds.extend(valid_preds.tolist())
152
  all_labels.extend(valid_labels.tolist())
153
 
154
+ if sample_count < num_samples:
155
+ print(f"Sample {sample_count + 1}:")
156
+ print(f"Tokens: {tokenizer.convert_ids_to_tokens(input_ids[i])}")
157
+ print(f"True Labels: {[id2label[label] for label in valid_labels]}")
158
+ print(f"Predicted Labels: {[id2label[pred] for pred in valid_preds]}")
159
+ print("-" * 50)
160
+ sample_count += 1
161
+
162
  # Calculate evaluation metrics
163
  print("evaluate_model sizes")
164
  print(len(all_preds))
165
  print(len(all_labels))
 
 
166
  all_preds = np.asarray(all_preds, dtype=np.float32)
167
  all_labels = np.asarray(all_labels, dtype=np.float32)
168
  accuracy = accuracy_score(all_labels, all_preds)