Varun Wadhwa committed on
Commit
0444fba
·
unverified ·
1 Parent(s): 8e7d1ea
Files changed (1) hide show
  1. app.py +41 -39
app.py CHANGED
@@ -78,32 +78,31 @@ print(raw_dataset.column_names)
78
  # function to align labels with tokens
79
  # --> special tokens: -100 label id (ignored by cross entropy),
80
  # --> if tokens are inside a word, replace 'B-' with 'I-'
81
- def align_labels_with_tokens(labels, word_ids, max_length):
82
  aligned_label_ids = []
83
- for word_id in word_ids:
84
- if word_id is None:
85
- aligned_label_ids.append(-100)
86
- else:
87
- aligned_label_ids.append(label2id[labels[word_id]].replace("B-", "I-"))
88
-
89
- # Pad to max length
90
- aligned_label_ids += [-100] * (max_length - len(aligned_label_ids))
91
  return aligned_label_ids
92
 
93
  # create tokenize function
94
  def tokenize_function(examples):
 
 
 
95
  inputs = tokenizer(
96
  examples['mbert_tokens'],
97
  is_split_into_words=True,
 
98
  truncation=True,
99
- max_length=512,
100
- padding="max_length"
101
- )
102
- word_ids = inputs.word_ids()
103
- inputs["labels"] = [
104
- align_labels_with_tokens(labels, word_ids, tokenizer.model_max_length)
105
- for labels in examples['mbert_token_classes']
106
- ]
107
  return inputs
108
 
109
  # tokenize training and validation datasets
@@ -112,43 +111,46 @@ tokenized_data = raw_dataset.map(
112
  batched=True)
113
  tokenized_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
114
  # data collator
115
- data_collator = DataCollatorForTokenClassification(
116
- tokenizer, padding=True, truncation=True, max_length=512
117
- )
118
 
119
  st.write(tokenized_data["train"][:2]["labels"])
120
 
121
  # Function to evaluate model performance
122
  def evaluate_model(model, dataloader, device):
123
- model.eval()
124
- all_preds, all_labels = [], []
 
125
 
 
126
  with torch.no_grad():
127
  for batch in dataloader:
128
  input_ids = batch['input_ids'].to(device)
129
  attention_mask = batch['attention_mask'].to(device)
130
- labels = batch['labels'].to(device)
131
 
 
132
  outputs = model(input_ids, attention_mask=attention_mask)
133
  logits = outputs.logits
134
- preds = torch.argmax(logits, dim=-1)
135
-
136
- # Mask out padding tokens (-100 in labels)
137
- mask = labels != -100
138
- valid_preds = preds[mask]
139
- valid_labels = labels[mask]
140
-
141
- all_preds.extend(valid_preds.cpu().numpy())
142
- all_labels.extend(valid_labels.cpu().numpy())
143
-
144
- # Convert to numpy arrays for metrics calculation
145
- all_preds = np.array(all_preds)
146
- all_labels = np.array(all_labels)
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  accuracy = accuracy_score(all_labels, all_preds)
149
- precision, recall, f1, _ = precision_recall_fscore_support(
150
- all_labels, all_preds, average='micro'
151
- )
152
 
153
  return accuracy, precision, recall, f1
154
 
 
78
  # function to align labels with tokens
79
  # --> special tokens: -100 label id (ignored by cross entropy),
80
  # --> if tokens are inside a word, replace 'B-' with 'I-'
81
def align_labels_with_tokens(labels, word_ids=None):
    """Map one example's per-word string labels onto its token positions.

    Args:
        labels: Label strings for a single example, one per input word
            (e.g. "B-..." / "I-..." tags). Each must be a key of the
            module-level ``label2id`` mapping.
        word_ids: Per-token word indices for the same example, as returned
            by ``BatchEncoding.word_ids(batch_index=i)``. ``None`` entries
            mark special or padding tokens. When omitted, every word is
            assumed to map to exactly one token, framed by one leading and
            one trailing special token.

    Returns:
        A list of label ids with one entry per token. Positions that the
        loss must ignore (special/padding tokens) hold -100.
    """
    if word_ids is None:
        # Legacy one-argument call: one token per word between two special
        # tokens. Every token starts its own word, so "B-" tags are kept.
        return [-100] + [label2id[label] for label in labels] + [-100]

    aligned_label_ids = []
    previous_word_id = None
    for word_id in word_ids:
        if word_id is None:
            # Special or padding token: ignored by cross entropy.
            aligned_label_ids.append(-100)
        else:
            label = labels[word_id]
            # Only a token that continues the previous token's word is
            # "inside" that word; the first token keeps its "B-" tag.
            if word_id == previous_word_id and label.startswith("B-"):
                label = "I-" + label[2:]
            aligned_label_ids.append(label2id[label])
        previous_word_id = word_id
    return aligned_label_ids


# create tokenize function
def tokenize_function(examples):
    """Tokenize a batch of pre-split examples and attach aligned label ids.

    Args:
        examples: A batch of rows with ``'mbert_tokens'`` (lists of word
            strings) and ``'mbert_token_classes'`` (lists of label strings,
            one per word).

    Returns:
        The tokenizer output with an added ``"labels"`` entry: one list of
        label ids per example, the same length as that example's
        ``input_ids``.
    """
    # Truncate to 512 tokens and pad to the longest example in this batch.
    inputs = tokenizer(
        examples['mbert_tokens'],
        is_split_into_words=True,
        padding=True,
        truncation=True,
        max_length=512)
    # word_ids() describes a single example, so it is queried per batch
    # index. It reflects truncation and padding, which keeps each label row
    # exactly as long as its input_ids row.
    # NOTE(review): word_ids() needs a fast tokenizer -- confirm.
    inputs["labels"] = [
        align_labels_with_tokens(labels, inputs.word_ids(batch_index=i))
        for i, labels in enumerate(examples['mbert_token_classes'])
    ]
    return inputs
107
 
108
  # tokenize training and validation datasets
 
111
  batched=True)
112
  tokenized_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
113
  # data collator
114
+ data_collator = DataCollatorForTokenClassification(tokenizer)
 
 
115
 
116
  st.write(tokenized_data["train"][:2]["labels"])
117
 
118
  # Function to evaluate model performance
119
  def evaluate_model(model, dataloader, device):
120
+ model.eval() # Set model to evaluation mode
121
+ all_preds = []
122
+ all_labels = []
123
 
124
+ # Disable gradient calculations
125
  with torch.no_grad():
126
  for batch in dataloader:
127
  input_ids = batch['input_ids'].to(device)
128
  attention_mask = batch['attention_mask'].to(device)
129
+ labels = batch['labels'].to(device).cpu().numpy()
130
 
131
+ # Forward pass to get logits
132
  outputs = model(input_ids, attention_mask=attention_mask)
133
  logits = outputs.logits
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
+ # Get predictions
136
+ preds = torch.argmax(logits, dim=-1).cpu().numpy()
137
+
138
+ all_preds.extend(preds)
139
+ all_labels.extend(labels)
140
+
141
+ # Calculate evaluation metrics
142
+ print("evaluate_model sizes")
143
+ print("Shape of preds:", all_preds.shape)
144
+ print("Shape of labels:", all_labels.shape)
145
+ all_preds = np.asarray(all_preds, dtype=np.float32)
146
+ all_labels = np.asarray(all_labels, dtype=np.float32)
147
+ print("Flattened sizes")
148
+ print(all_preds.size)
149
+ print(all_labels.size)
150
+ all_preds = all_preds.flatten()
151
+ all_labels = all_labels.flatten()
152
  accuracy = accuracy_score(all_labels, all_preds)
153
+ precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='micro')
 
 
154
 
155
  return accuracy, precision, recall, f1
156