Varun Wadhwa commited on
Commit
cb2cd7f
·
unverified ·
1 Parent(s): ee1a894
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -137,7 +137,7 @@ def evaluate_model(model, dataloader, device):
137
  for batch in dataloader:
138
  input_ids = batch['input_ids'].to(device)
139
  current_batch_size = input_ids.size(0)
140
- attention_mask = batch['attention_mask'].cpu().numpy().astype(bool)
141
  labels = batch['labels'].to(device).cpu().numpy()
142
 
143
  # Forward pass to get logits
@@ -147,10 +147,11 @@ def evaluate_model(model, dataloader, device):
147
 
148
  # Get predictions
149
  preds = torch.argmax(logits, dim=-1).cpu().numpy()
 
150
 
151
  # Process each sequence in the batch
152
  for i in range(current_batch_size):
153
- valid_mask = (labels[i] != -100) & attention_mask[i]
154
  valid_preds = preds[i][valid_mask[i]].flatten()
155
  valid_labels = labels[i][valid_mask[i]].flatten()
156
  all_preds.extend(valid_preds.tolist())
 
137
  for batch in dataloader:
138
  input_ids = batch['input_ids'].to(device)
139
  current_batch_size = input_ids.size(0)
140
+ attention_mask = batch['attention_mask'].to(device)
141
  labels = batch['labels'].to(device).cpu().numpy()
142
 
143
  # Forward pass to get logits
 
147
 
148
  # Get predictions
149
  preds = torch.argmax(logits, dim=-1).cpu().numpy()
150
+ mask = attention_mask.cpu().numpy().astype(bool)
151
 
152
  # Process each sequence in the batch
153
  for i in range(current_batch_size):
154
+ valid_mask = (labels[i] != -100) & mask[i]
155
  valid_preds = preds[i][valid_mask[i]].flatten()
156
  valid_labels = labels[i][valid_mask[i]].flatten()
157
  all_preds.extend(valid_preds.tolist())