Varun Wadhwa committed on
Commit f1c8407 · unverified · 1 Parent(s): 2c409e9
Files changed (1)
  1. app.py +9 -9
app.py CHANGED
@@ -123,13 +123,9 @@ def evaluate_model(model, dataloader, device):

     # Disable gradient calculations
     with torch.no_grad():
-        for batch in dataloader:
-            print("Sample sequence labels:", batch['labels'][0].tolist()[:20])
-            print("Corresponding predictions:", torch.argmax(model(batch['input_ids'].to(device),
-                attention_mask=batch['attention_mask'].to(device)).logits, dim=-1)[0].tolist()[:20])
-            break
         for batch in dataloader:
             input_ids = batch['input_ids'].to(device)
+            current_batch_size = input_ids.size(0)
             attention_mask = batch['attention_mask'].to(device)
             labels = batch['labels'].to(device).cpu().numpy()

@@ -140,11 +136,15 @@ def evaluate_model(model, dataloader, device):
             # Get predictions
             preds = torch.argmax(logits, dim=-1).cpu().numpy()

-            print("Shape of preds:", preds.shape)
-            print("Shape of labels:", labels.shape)
+            # Use attention mask to get valid tokens
+            mask = batch['attention_mask'].cpu().numpy().astype(bool)

-            all_preds.extend(preds)
-            all_labels.extend(labels)
+            # Process each sequence in the batch
+            for i in range(current_batch_size):
+                valid_preds = preds[i][mask[i]].flatten()
+                valid_labels = labels[i][mask[i]].flatten()
+                all_preds.extend(valid_preds.tolist())
+                all_labels.extend(valid_labels.tolist())

     # Calculate evaluation metrics
     print("evaluate_model sizes")