Varun Wadhwa commited on
Commit
157d28c
·
unverified ·
1 Parent(s): 4fb10c8
Files changed (1) hide show
  1. app.py +5 -0
app.py CHANGED
@@ -131,6 +131,7 @@ def evaluate_model(model, dataloader, device):
131
 
132
  # Forward pass to get logits
133
  outputs = model(input_ids, attention_mask=attention_mask)
 
134
  logits = outputs.logits
135
 
136
  # Get predictions
@@ -150,6 +151,10 @@ def evaluate_model(model, dataloader, device):
150
  print("evaluate_model sizes")
151
  print(len(all_preds))
152
  print(len(all_labels))
 
 
 
 
153
  all_preds = np.asarray(all_preds, dtype=np.float32)
154
  all_labels = np.asarray(all_labels, dtype=np.float32)
155
  accuracy = accuracy_score(all_labels, all_preds)
 
131
 
132
  # Forward pass to get logits
133
  outputs = model(input_ids, attention_mask=attention_mask)
134
+
135
  logits = outputs.logits
136
 
137
  # Get predictions
 
151
  print("evaluate_model sizes")
152
  print(len(all_preds))
153
  print(len(all_labels))
154
+ print(id2label(all_preds[0]))
155
+ print(id2label(all_labels[0]))
156
+ print(id2label(all_preds[1]))
157
+ print(id2label(all_labels[1]))
158
  all_preds = np.asarray(all_preds, dtype=np.float32)
159
  all_labels = np.asarray(all_labels, dtype=np.float32)
160
  accuracy = accuracy_score(all_labels, all_preds)