Varun Wadhwa committed on
Commit
ee1a894
·
unverified ·
1 Parent(s): f2d5c7a
Files changed (1) hide show
  1. app.py +4 -6
app.py CHANGED
@@ -137,7 +137,7 @@ def evaluate_model(model, dataloader, device):
137
  for batch in dataloader:
138
  input_ids = batch['input_ids'].to(device)
139
  current_batch_size = input_ids.size(0)
140
- attention_mask = batch['attention_mask'].to(device)
141
  labels = batch['labels'].to(device).cpu().numpy()
142
 
143
  # Forward pass to get logits
@@ -147,14 +147,12 @@ def evaluate_model(model, dataloader, device):
147
 
148
  # Get predictions
149
  preds = torch.argmax(logits, dim=-1).cpu().numpy()
150
-
151
- # Use attention mask to get valid tokens
152
- mask = batch['attention_mask'].cpu().numpy().astype(bool)
153
 
154
  # Process each sequence in the batch
155
  for i in range(current_batch_size):
156
- valid_preds = preds[i][mask[i]].flatten()
157
- valid_labels = labels[i][mask[i]].flatten()
 
158
  all_preds.extend(valid_preds.tolist())
159
  all_labels.extend(valid_labels.tolist())
160
 
 
137
  for batch in dataloader:
138
  input_ids = batch['input_ids'].to(device)
139
  current_batch_size = input_ids.size(0)
140
+ attention_mask = batch['attention_mask'].cpu().numpy().astype(bool)
141
  labels = batch['labels'].to(device).cpu().numpy()
142
 
143
  # Forward pass to get logits
 
147
 
148
  # Get predictions
149
  preds = torch.argmax(logits, dim=-1).cpu().numpy()
 
 
 
150
 
151
  # Process each sequence in the batch
152
  for i in range(current_batch_size):
153
+ valid_mask = (labels[i] != -100) & attention_mask[i]
154
+ valid_preds = preds[i][valid_mask].flatten()
155
+ valid_labels = labels[i][valid_mask].flatten()
156
  all_preds.extend(valid_preds.tolist())
157
  all_labels.extend(valid_labels.tolist())
158