seanbenhur committed
Commit b6283c9 · 1 Parent(s): 4ccbdf9
tamilatis DELETED
@@ -1 +0,0 @@
- Subproject commit b1022a9187d9d47c18b360fc45b7f55d3b40824f
tamilatis/configs/config.yaml ADDED
@@ -0,0 +1,7 @@
+ defaults:
+ - model: default
+ - dataset: default
+ - training: default
+ - wandb: default
+ - override hydra/job_logging: colorlog
+ - override hydra/hydra_logging: colorlog
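
Hydra composes these defaults (model, dataset, training, wandb) into a single config tree at run time and hands it to main(). As a hedged sketch, not part of this commit and assuming hydra plus the hydra-colorlog plugin are installed, the same config can be composed and overridden programmatically:

    # hedged sketch: composing the config outside @hydra.main (Hydra 1.x compose API)
    from hydra import compose, initialize

    with initialize(config_path="tamilatis/configs"):
        cfg = compose(config_name="config", overrides=["training.lr=3e-5"])
        print(cfg.model.model_name)  # "xlm-roberta-base"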
tamilatis/configs/dataset/default.yaml ADDED
@@ -0,0 +1,6 @@
+ train_path: "/content/train_intent.pkl"
+ valid_path: "/content/val_intent.pkl"
+ test_path: "/content/test_intent.pkl"
+ output_dir: "/content/saved_models"
+ num_labels: 78
+ num_intents: 23
tamilatis/configs/model/default.yaml ADDED
@@ -0,0 +1,5 @@
+ tokenizer_name: "xlm-roberta-base"
+ model_name: "xlm-roberta-base"
+ num_labels: 78
+ num_intents: 23
+ test_model:
tamilatis/configs/training/default.yaml ADDED
@@ -0,0 +1,11 @@
+ batch_size: 32
+ weight_decay: 0.01
+ lr: 1e-4
+ max_epochs: 20
+ patience: 5
+ scheduler: "cosine"
+ warmup_steps: 0
+ do_train: True
+ do_predict: False
+ ner_cls_path: /content/ner_cls_rlw.csv
+ intent_cls_path: /content/intent_cls_rlw.csv
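
The scheduler name above is handed straight to transformers.get_scheduler in main.py; "cosine" is one of the names that helper accepts (alongside "linear", "constant" and a few others). A minimal, hedged sketch of that call, with a placeholder optimizer and step count:

    # hedged sketch: `model` and the step count are placeholders
    import torch
    from transformers import get_scheduler

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scheduler = get_scheduler(
        "cosine", optimizer, num_warmup_steps=0, num_training_steps=1000
    )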
tamilatis/configs/wandb/default.yaml ADDED
@@ -0,0 +1,3 @@
+ project_name: "tamilatis"
+ group_name: "hard-parameter-sharing-rlw"
+ run_name:
tamilatis/dataset.py ADDED
@@ -0,0 +1,120 @@
+ import pickle
+
+ import torch
+ from torch.utils.data import Dataset
+ from tqdm import tqdm
+ from transformers import AutoTokenizer
+
+
+ class BuildDataset:
+     def __init__(self):
+         pass
+
+     def tokenize(self, text):
+         """Split the text on whitespace and compute character offsets."""
+         text = text.strip()
+         tokens = text.split()
+         offsets = []
+         for token in tokens:
+             start_idx = text.find(token)
+             end_idx = start_idx + len(token)
+             offsets.append([start_idx, end_idx])
+         return tokens, offsets
+
+     def convert_to_boi(self, text, annotations):
+         """Convert slot annotations to BOI tags."""
+         tokens, offsets = self.tokenize(text)
+         boi_tags = ["O"] * len(tokens)
+
+         for name, value, [start_idx, end_idx] in annotations:
+             value = value.strip()
+             try:
+                 token_span = len(value.split())
+
+                 start_token_idx = [
+                     token_idx
+                     for token_idx, (s, e) in enumerate(offsets)
+                     if s == start_idx
+                 ][0]
+                 end_token_idx = start_token_idx + token_span
+                 annotation = [name] + ["I" + name[1:]] * (token_span - 1)
+                 boi_tags[start_token_idx:end_token_idx] = annotation
+             except Exception:
+                 # skip annotations whose offsets do not line up with the
+                 # whitespace tokenization
+                 pass
+
+         return list(zip(tokens, boi_tags))
+
+     def build_dataset(self, path):
+         """Build a TOD dataset from a pickled list of (text, annotations, intent)."""
+         with open(path, "rb") as f:
+             data = pickle.load(f)
+
+         boi_data = []
+         for text, annotation, intent in tqdm(data):
+             boi_item = self.convert_to_boi(text, annotation)
+             is_valid = any(tag != "O" for _, tag in boi_item)
+             wrong_intent = intent[0] == "B" or intent[0] == "I"
+
+             if is_valid and not wrong_intent:
+                 boi_data.append((boi_item, intent))
+         return boi_data
+
+
+ class ATISDataset(Dataset):
+     def __init__(self, data, tokenizer, label_encoder, intent_encoder):
+         self.data = data
+         self.label_encoder = label_encoder
+         self.intent_encoder = intent_encoder
+         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         tokens = [token for token, annotation in self.data[idx][0]]
+         tags = [tag for token, tag in self.data[idx][0]]
+
+         intent_name = self.data[idx][1]
+         intent_label = self.intent_encoder.transform([intent_name])
+         text = "#".join(tokens)
+
+         encoding = self.tokenizer(
+             tokens,
+             max_length=60,
+             padding="max_length",
+             truncation=True,
+             is_split_into_words=True,
+             return_tensors="pt",
+         )
+
+         input_ids = encoding.input_ids.squeeze(0)
+         attention_mask = encoding.attention_mask.squeeze(0)
+         word_ids = encoding.word_ids()
+
+         tags = self.label_encoder.transform(tags)
+
+         labels = []
+         label_all_tokens = False  # label only the first word piece of each word
+         previous_word_idx = None
+
+         for word_idx in word_ids:
+             if word_idx is None:
+                 labels.append(-100)
+             elif word_idx != previous_word_idx:
+                 labels.append(tags[word_idx])
+             else:
+                 labels.append(tags[word_idx] if label_all_tokens else -100)
+             previous_word_idx = word_idx
+
+         labels = torch.tensor(labels)
+         tags = tags.tolist()
+         tags.extend([-100] * (50 - len(tags)))
+
+         return {
+             "text": text,
+             "input_ids": input_ids,
+             "attention_mask": attention_mask,
+             "labels": labels,
+             "intent": intent_label.item(),
+             "tags": tags,
+         }
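
A quick, hedged illustration of the BOI conversion implemented above; the utterance, slot names, and character offsets below are hypothetical examples, not taken from the pickled Tamil ATIS data:

    builder = BuildDataset()
    text = "book a flight from chennai to madurai"
    annotations = [("B-fromloc.city_name", "chennai", [19, 26]),   # hypothetical annotation tuples
                   ("B-toloc.city_name", "madurai", [30, 37])]
    builder.convert_to_boi(text, annotations)
    # [("book", "O"), ("a", "O"), ("flight", "O"), ("from", "O"),
    #  ("chennai", "B-fromloc.city_name"), ("to", "O"), ("madurai", "B-toloc.city_name")]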
tamilatis/main.py ADDED
@@ -0,0 +1,180 @@
+ import logging
+ import os
+
+ import hydra
+ import pandas as pd
+ import torch.nn as nn
+ import torch.optim as optim
+ import wandb
+ from accelerate import Accelerator
+ from omegaconf.omegaconf import OmegaConf
+ from sklearn.preprocessing import LabelEncoder
+ from torch.utils.data import DataLoader
+ from transformers import get_scheduler
+
+ from dataset import ATISDataset, BuildDataset
+ from model import JointATISModel
+ from predict import TamilATISPredictor
+ from trainer import ATISTrainer
+
+ logger = logging.getLogger(__name__)
+
+
+ @hydra.main(config_path="./configs", config_name="config")
+ def main(cfg):
+     os.environ["WANDB_PROJECT"] = cfg.wandb.project_name
+     os.environ["WANDB_RUN_GROUP"] = cfg.wandb.group_name
+
+     logger.info(OmegaConf.to_yaml(cfg, resolve=True))
+     accelerator = Accelerator()
+
+     logger.info("Building Dataset")
+     data_utils = BuildDataset()
+     train_data = data_utils.build_dataset(cfg.dataset.train_path)
+     valid_data = data_utils.build_dataset(cfg.dataset.valid_path)
+     test_data = data_utils.build_dataset(cfg.dataset.test_path)
+
+     # Collect all slot tags and intents across the three splits
+     annotations, intents, count = set(), set(), 0
+     for split in (train_data, valid_data, test_data):
+         for boi_data, intent in split:
+             if intent[0] == "B" or intent[0] == "I":
+                 count += 1
+             intents.add(intent)
+             for token, annotation in boi_data:
+                 annotations.add(annotation)
+
+     annotations = list(annotations)
+     intents = list(intents)
+
+     # convert string labels to int
+     label_encoder = LabelEncoder()
+     label_encoder.fit(annotations)
+
+     intent_encoder = LabelEncoder()
+     intent_encoder.fit(intents)
+
+     train_ds = ATISDataset(
+         train_data, cfg.model.tokenizer_name, label_encoder, intent_encoder
+     )
+     val_ds = ATISDataset(
+         valid_data, cfg.model.tokenizer_name, label_encoder, intent_encoder
+     )
+     test_ds = ATISDataset(
+         test_data, cfg.model.tokenizer_name, label_encoder, intent_encoder
+     )
+
+     train_dl = DataLoader(train_ds, batch_size=cfg.training.batch_size, pin_memory=True)
+     val_dl = DataLoader(val_ds, batch_size=cfg.training.batch_size * 2, pin_memory=True)
+     test_dl = DataLoader(
+         test_ds, batch_size=cfg.training.batch_size * 2, pin_memory=True
+     )
+     logger.info("DataLoaders are created!")
+
+     model = JointATISModel(
+         cfg.model.model_name, cfg.model.num_labels, cfg.model.num_intents
+     )
+     criterion = nn.CrossEntropyLoss()
+     # Optimizer
+     # Split weights in two groups, one with weight decay and the other not.
+     no_decay = ["bias", "LayerNorm.weight"]
+     optimizer_grouped_parameters = [
+         {
+             "params": [
+                 p
+                 for n, p in model.named_parameters()
+                 if not any(nd in n for nd in no_decay)
+             ],
+             "weight_decay": cfg.training.weight_decay,
+         },
+         {
+             "params": [
+                 p
+                 for n, p in model.named_parameters()
+                 if any(nd in n for nd in no_decay)
+             ],
+             "weight_decay": 0.0,
+         },
+     ]
+     optimizer = optim.AdamW(optimizer_grouped_parameters, lr=cfg.training.lr)
+     # len(train_dl) is already the number of batches per epoch
+     nb_train_steps = len(train_dl) * cfg.training.max_epochs
+
+     if cfg.training.scheduler is not None:
+         scheduler = get_scheduler(
+             cfg.training.scheduler,
+             optimizer,
+             num_warmup_steps=cfg.training.warmup_steps,
+             num_training_steps=nb_train_steps,
+         )
+         # Register the LR scheduler for checkpointing
+         accelerator.register_for_checkpointing(scheduler)
+     else:
+         scheduler = None
+
+     model, optimizer, train_dl, val_dl = accelerator.prepare(
+         model, optimizer, train_dl, val_dl
+     )
+
+     run = wandb.init(
+         project=cfg.wandb.project_name,
+         group=cfg.wandb.group_name,
+         name=cfg.wandb.run_name,
+     )
+     if cfg.training.do_train:
+         trainer = ATISTrainer(
+             model,
+             optimizer,
+             scheduler,
+             criterion,
+             accelerator,
+             cfg.dataset.output_dir,
+             cfg.dataset.num_labels,
+             cfg.dataset.num_intents,
+             run,
+         )
+         best_model, best_loss = trainer.fit(
+             cfg.training.max_epochs, train_dl, val_dl, cfg.training.patience
+         )
+         model_dir = f"{cfg.dataset.output_dir}/model_{best_loss}"
+         if not os.path.exists(model_dir):
+             os.makedirs(model_dir)
+         best_model.save_pretrained(model_dir, push_to_hub=False)
+         logger.info(
+             f"The best model with validation loss {best_loss} is saved in {model_dir}"
+         )
+     if cfg.training.do_predict:
+         predictor = TamilATISPredictor(
+             model,
+             cfg.model.test_model,
+             cfg.model.tokenizer_name,
+             label_encoder,
+             intent_encoder,
+             cfg.model.num_labels,
+         )
+         outputs, intents = predictor.predict_test_data(test_data)
+         ner_cls_rep, intent_cls_rep = predictor.evaluate(outputs, intents)
+         ner_cls_df = pd.DataFrame(ner_cls_rep).transpose()
+         intent_cls_df = pd.DataFrame(intent_cls_rep).transpose()
+         ner_cls_df.to_csv(cfg.training.ner_cls_path)
+         intent_cls_df.to_csv(cfg.training.intent_cls_path)
+         logger.info(
+             f"Classification reports of intents and slots are saved in "
+             f"{cfg.training.ner_cls_path} and {cfg.training.intent_cls_path}"
+         )
+
+
+ if __name__ == "__main__":
+     main()
tamilatis/model.py ADDED
@@ -0,0 +1,25 @@
+ import torch.nn as nn
+ from huggingface_hub import PyTorchModelHubMixin
+ from transformers import AutoConfig, AutoModelForTokenClassification
+
+
+ class JointATISModel(nn.Module, PyTorchModelHubMixin):
+     def __init__(self, model_name, num_labels, num_intents):
+         super().__init__()
+         self.model = AutoModelForTokenClassification.from_pretrained(
+             model_name, num_labels=num_labels
+         )
+         self.model_config = AutoConfig.from_pretrained(model_name)
+         self.intent_head = nn.Linear(self.model_config.hidden_size, num_intents)
+
+     def forward(self, input_ids, attention_mask, labels):
+         outputs = self.model(
+             input_ids, attention_mask, labels=labels, output_hidden_states=True
+         )
+         # first token of the last hidden layer is used as the pooled representation
+         pooled_output = outputs["hidden_states"][-1][:, 0, :]
+         intent_logits = self.intent_head(pooled_output)
+         return {
+             "dst_logits": outputs.logits,
+             "intent_loss": intent_logits,  # intent logits; key name kept for compatibility with trainer/predictor
+             "dst_loss": outputs.loss,
+         }
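
As a hedged sketch, not part of this commit, a dummy forward pass shows the shape and meaning of the returned dictionary; the random token ids and batch shape below are placeholders, assuming xlm-roberta-base with the label counts from the configs above:

    import torch
    from model import JointATISModel

    model = JointATISModel("xlm-roberta-base", num_labels=78, num_intents=23)
    input_ids = torch.randint(0, 250000, (2, 60))      # placeholder token ids
    attention_mask = torch.ones(2, 60, dtype=torch.long)
    out = model(input_ids, attention_mask, labels=None)
    out["dst_logits"].shape   # torch.Size([2, 60, 78]), token-level slot logits
    out["intent_loss"].shape  # torch.Size([2, 23]), intent logits
    out["dst_loss"]           # None, since labels=None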
tamilatis/predict.py ADDED
@@ -0,0 +1,117 @@
+ import numpy as np
+ import torch
+ from seqeval.metrics import classification_report as seqeval_classification_report
+ from sklearn.metrics import classification_report as sklearn_classification_report
+ from tqdm import tqdm
+ from transformers import AutoTokenizer
+
+
+ class TamilATISPredictor:
+     def __init__(
+         self,
+         model,
+         checkpoint_path,
+         tokenizer,
+         label_encoder,
+         intent_encoder,
+         num_labels,
+     ):
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.model = model
+         self.model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))
+         self.model.to(self.device)
+         self.model.eval()
+         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+         self.num_labels = num_labels
+         self.label_encoder = label_encoder
+         self.intent_encoder = intent_encoder
+
+     def get_predictions(self, text):
+         inputs = self.tokenizer(
+             text.split(),
+             is_split_into_words=True,
+             return_offsets_mapping=True,
+             padding="max_length",
+             truncation=True,
+             max_length=60,
+             return_tensors="pt",
+         )
+         ids = inputs["input_ids"].to(self.device)
+         mask = inputs["attention_mask"].to(self.device)
+
+         # forward pass
+         loss_dict = self.model(input_ids=ids, attention_mask=mask, labels=None)
+         slot_logits, intent_logits, slot_loss = (
+             loss_dict["dst_logits"],
+             loss_dict["intent_loss"],
+             loss_dict["dst_loss"],
+         )
+
+         active_logits = slot_logits.view(
+             -1, self.num_labels
+         )  # shape (batch_size * seq_len, num_labels)
+         flattened_predictions = torch.argmax(
+             active_logits, axis=1
+         )  # shape (batch_size * seq_len,) - predictions at the token level
+         tokens = self.tokenizer.convert_ids_to_tokens(ids.squeeze().tolist())
+         token_predictions = self.label_encoder.inverse_transform(
+             flattened_predictions.cpu().numpy()
+         )
+         wp_preds = list(
+             zip(tokens, token_predictions)
+         )  # list of tuples. Each tuple = (wordpiece, prediction)
+
+         slot_prediction = []
+         for token_pred, mapping in zip(
+             wp_preds, inputs["offset_mapping"].squeeze().tolist()
+         ):
+             # only predictions on the first word piece of each word are kept
+             if mapping[0] == 0 and mapping[1] != 0 and token_pred[0] != "▁":
+                 slot_prediction.append(token_pred[1])
+             else:
+                 continue
+         intent_preds = torch.argmax(intent_logits, axis=1)
+         intent_preds = self.intent_encoder.inverse_transform(intent_preds.cpu().numpy())
+
+         return intent_preds, slot_prediction
+
+     def predict_test_data(self, test_data):
+         outputs = []
+         intents = []
+
+         for item, intent in tqdm(test_data):
+             try:
+                 tokens = [token for token, tag in item]
+                 tags = [tag for token, tag in item]
+                 text = " ".join(tokens)
+                 intent_preds, slot_preds = self.get_predictions(text)
+                 outputs.append((tags, slot_preds))
+                 intents.append((intent, intent_preds.item()))
+             except Exception as error:
+                 print(error)
+         return outputs, intents
+
+     def evaluate(self, outputs, intents):
+         for output in tqdm(outputs):
+             assert len(output[0]) == len(output[1])
+
+         # Compute entity-level metrics for slots with seqeval
+         y_true = [output[0] for output in outputs]
+         y_pred = [output[1] for output in outputs]
+         ner_cls_rep = seqeval_classification_report(y_true, y_pred, output_dict=True)
+
+         # Compute metrics for intents with sklearn
+         y_true = self.intent_encoder.transform(
+             [output[0] for output in intents]
+         ).tolist()
+         y_pred = self.intent_encoder.transform(
+             [output[1] for output in intents]
+         ).tolist()
+
+         target_names = self.intent_encoder.classes_.tolist()
+         target_names = [target_names[idx] for idx in np.unique(y_true + y_pred)]
+         intent_cls_rep = sklearn_classification_report(
+             y_true, y_pred, target_names=target_names, output_dict=True
+         )
+
+         return ner_cls_rep, intent_cls_rep
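
A hedged usage sketch of the predictor on a single utterance; the checkpoint path and utterance are placeholders, and label_encoder / intent_encoder are assumed to be the fitted LabelEncoders built in main.py:

    from model import JointATISModel
    from predict import TamilATISPredictor

    predictor = TamilATISPredictor(
        model=JointATISModel("xlm-roberta-base", num_labels=78, num_intents=23),
        checkpoint_path="/content/saved_models/model.bin",  # placeholder path
        tokenizer="xlm-roberta-base",
        label_encoder=label_encoder,     # fitted sklearn LabelEncoder (assumed)
        intent_encoder=intent_encoder,   # fitted sklearn LabelEncoder (assumed)
        num_labels=78,
    )
    intent, slots = predictor.get_predictions("placeholder utterance")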
tamilatis/trainer.py ADDED
@@ -0,0 +1,284 @@
+ import logging
+ import os
+
+ import torch
+ import torch.nn.functional as F
+ from torchmetrics.functional import accuracy, f1_score, precision, recall
+ from tqdm import tqdm, trange
+
+ logger = logging.getLogger(__name__)
+
+
+ class ATISTrainer:
+     """A Trainer class with utility functions for training the model."""
+
+     def __init__(
+         self,
+         model,
+         optimizer,
+         scheduler,
+         criterion,
+         accelerator,
+         output_dir,
+         num_labels,
+         num_intents,
+         run,
+     ):
+         self.model = model
+         self.criterion = criterion
+         self.optimizer = optimizer
+         self.scheduler = scheduler
+         self.accelerator = accelerator
+         self.output_dir = output_dir
+         self.num_labels = num_labels
+         self.num_intents = num_intents
+
+         if not os.path.exists(self.output_dir):
+             os.makedirs(self.output_dir)
+
+         self.run = run
+         logger.info(f"Starting training, outputs are saved in {self.output_dir}")
+
+     def train_step(self, iterator):
+         epoch_loss, epoch_intent_acc, epoch_intent_f1 = 0.0, 0.0, 0.0
+         epoch_slot_acc, epoch_slot_f1 = 0.0, 0.0
+
+         training_progress_bar = tqdm(iterator, desc="training")
+         for batch in training_progress_bar:
+             input_ids, attention_mask, labels, intents = (
+                 batch["input_ids"],
+                 batch["attention_mask"],
+                 batch["labels"],
+                 batch["intent"],
+             )
+             self.optimizer.zero_grad()
+             loss_dict = self.model(input_ids, attention_mask, labels)
+             slot_logits, intent_logits, slot_loss = (
+                 loss_dict["dst_logits"],
+                 loss_dict["intent_loss"],
+                 loss_dict["dst_loss"],
+             )
+
+             # compute training accuracy for slots
+             flattened_target_labels = batch["labels"].view(
+                 -1
+             )  # [batch_size * seq_len, ]
+             active_logits = slot_logits.view(
+                 -1, self.num_labels
+             )  # [batch_size * seq_len, num_labels]
+             flattened_preds = torch.argmax(
+                 active_logits, axis=-1
+             )  # [batch_size * seq_len, ]
+
+             # keep only positions with real labels (ignore index -100)
+             active_accuracy = (
+                 batch["labels"].view(-1) != -100
+             )  # [batch_size * seq_len, ]
+
+             slot_labels = torch.masked_select(flattened_target_labels, active_accuracy)
+             slot_preds = torch.masked_select(flattened_preds, active_accuracy)
+
+             # compute loss for intents, using random loss weighting (RLW)
+             intent_loss = self.criterion(intent_logits, batch["intent"])
+             weight = F.softmax(torch.randn(1), dim=-1)  # RLW is only this!
+             intent_loss = torch.sum(intent_loss * weight.to(intent_loss.device))
+             intent_preds = torch.argmax(intent_logits, axis=1)
+             train_loss = slot_loss + intent_loss
+             self.accelerator.backward(train_loss)
+             self.optimizer.step()
+
+             if self.scheduler is not None:
+                 if not self.accelerator.optimizer_step_was_skipped:
+                     self.scheduler.step()
+
+             intent_acc = accuracy(
+                 intent_preds, intents, num_classes=self.num_intents, average="weighted"
+             )
+             intent_f1 = f1_score(
+                 intent_preds, intents, num_classes=self.num_intents, average="weighted"
+             )
+             intent_rec = recall(
+                 intent_preds, intents, num_classes=self.num_intents, average="weighted"
+             )
+             intent_prec = precision(
+                 intent_preds, intents, num_classes=self.num_intents, average="weighted"
+             )
+
+             slot_acc = accuracy(
+                 slot_preds, slot_labels, num_classes=self.num_labels, average="weighted"
+             )
+             slot_f1 = f1_score(
+                 slot_preds, slot_labels, num_classes=self.num_labels, average="weighted"
+             )
+             slot_rec = recall(
+                 slot_preds, slot_labels, num_classes=self.num_labels, average="weighted"
+             )
+             slot_prec = precision(
+                 slot_preds, slot_labels, num_classes=self.num_labels, average="weighted"
+             )
+
+             self.run.log(
+                 {
+                     "train_loss_step": train_loss.cpu().detach().numpy(),
+                     "train_intent_acc_step": intent_acc,
+                     "train_intent_f1_step": intent_f1,
+                     "train_slot_acc_step": slot_acc,
+                     "train_slot_f1_step": slot_f1,
+                 }
+             )
+
+             # accumulate running totals so the epoch values are true averages
+             epoch_loss += train_loss.item()
+             epoch_intent_acc += intent_acc
+             epoch_intent_f1 += intent_f1
+             epoch_slot_acc += slot_acc
+             epoch_slot_f1 += slot_f1
+
+         n_batches = len(iterator)
+         return {
+             "train_loss_epoch": epoch_loss / n_batches,
+             "train_intent_f1_epoch": epoch_intent_f1 / n_batches,
+             "train_intent_acc_epoch": epoch_intent_acc / n_batches,
+             "train_slot_f1_epoch": epoch_slot_f1 / n_batches,
+             "train_slot_acc_epoch": epoch_slot_acc / n_batches,
+         }
+
+     @torch.no_grad()
+     def eval_step(self, iterator):
+         epoch_loss, epoch_intent_acc, epoch_intent_f1 = 0.0, 0.0, 0.0
+         epoch_slot_acc, epoch_slot_f1 = 0.0, 0.0
+
+         eval_progress_bar = tqdm(iterator, desc="Evaluating")
+         for batch in eval_progress_bar:
+             input_ids, attention_mask, labels, intents = (
+                 batch["input_ids"],
+                 batch["attention_mask"],
+                 batch["labels"],
+                 batch["intent"],
+             )
+             loss_dict = self.model(input_ids, attention_mask, labels)
+             slot_logits, intent_logits, slot_loss = (
+                 loss_dict["dst_logits"],
+                 loss_dict["intent_loss"],
+                 loss_dict["dst_loss"],
+             )
+             # compute validation accuracy for slots
+             flattened_target_labels = batch["labels"].view(
+                 -1
+             )  # [batch_size * seq_len, ]
+             active_logits = slot_logits.view(
+                 -1, self.num_labels
+             )  # [batch_size * seq_len, num_labels]
+             flattened_preds = torch.argmax(
+                 active_logits, axis=-1
+             )  # [batch_size * seq_len, ]
+
+             # keep only positions with real labels (ignore index -100)
+             active_accuracy = (
+                 batch["labels"].view(-1) != -100
+             )  # [batch_size * seq_len, ]
+
+             slot_labels = torch.masked_select(flattened_target_labels, active_accuracy)
+             slot_preds = torch.masked_select(flattened_preds, active_accuracy)
+
+             # compute loss for intents, using random loss weighting (RLW)
+             intent_loss = self.criterion(intent_logits, batch["intent"])
+             weight = F.softmax(torch.randn(1), dim=-1)  # RLW is only this!
+             intent_loss = torch.sum(intent_loss * weight.to(intent_loss.device))
+
+             intent_preds = torch.argmax(intent_logits, axis=1)
+             eval_loss = slot_loss + intent_loss
+
+             intent_acc = accuracy(
+                 intent_preds, intents, num_classes=self.num_intents, average="weighted"
+             )
+             intent_f1 = f1_score(
+                 intent_preds, intents, num_classes=self.num_intents, average="weighted"
+             )
+             intent_rec = recall(
+                 intent_preds, intents, num_classes=self.num_intents, average="weighted"
+             )
+             intent_prec = precision(
+                 intent_preds, intents, num_classes=self.num_intents, average="weighted"
+             )
+
+             slot_acc = accuracy(
+                 slot_preds, slot_labels, num_classes=self.num_labels, average="weighted"
+             )
+             slot_f1 = f1_score(
+                 slot_preds, slot_labels, num_classes=self.num_labels, average="weighted"
+             )
+             slot_rec = recall(
+                 slot_preds, slot_labels, num_classes=self.num_labels, average="weighted"
+             )
+             slot_prec = precision(
+                 slot_preds, slot_labels, num_classes=self.num_labels, average="weighted"
+             )
+
+             self.run.log(
+                 {
+                     "eval_loss_step": eval_loss,
+                     "eval_intent_acc_step": intent_acc,
+                     "eval_intent_f1_step": intent_f1,
+                     "eval_slot_acc_step": slot_acc,
+                     "eval_slot_f1_step": slot_f1,
+                 }
+             )
+
+             epoch_loss += eval_loss.item()
+             epoch_intent_acc += intent_acc
+             epoch_intent_f1 += intent_f1
+             epoch_slot_acc += slot_acc
+             epoch_slot_f1 += slot_f1
+
+         n_batches = len(iterator)
+         return {
+             "eval_loss_epoch": epoch_loss / n_batches,
+             "eval_intent_f1_epoch": epoch_intent_f1 / n_batches,
+             "eval_intent_acc_epoch": epoch_intent_acc / n_batches,
+             "eval_slot_f1_epoch": epoch_slot_f1 / n_batches,
+             "eval_slot_acc_epoch": epoch_slot_acc / n_batches,
+         }
+
+     def fit(self, n_epochs, train_dataloader, eval_dataloader, patience):
+         best_eval_loss = float("inf")
+         best_model = self.model
+         pbar = trange(n_epochs)
+
+         for epoch in pbar:
+             train_metrics_dict = self.train_step(train_dataloader)
+             eval_metrics_dict = self.eval_step(eval_dataloader)
+             # access all the values from the dicts
+             train_loss, eval_loss = (
+                 train_metrics_dict["train_loss_epoch"],
+                 eval_metrics_dict["eval_loss_epoch"],
+             )
+             train_intent_f1, eval_intent_f1 = (
+                 train_metrics_dict["train_intent_f1_epoch"],
+                 eval_metrics_dict["eval_intent_f1_epoch"],
+             )
+             train_intent_acc, eval_intent_acc = (
+                 train_metrics_dict["train_intent_acc_epoch"],
+                 eval_metrics_dict["eval_intent_acc_epoch"],
+             )
+             train_slot_f1, eval_slot_f1 = (
+                 train_metrics_dict["train_slot_f1_epoch"],
+                 eval_metrics_dict["eval_slot_f1_epoch"],
+             )
+             train_slot_acc, eval_slot_acc = (
+                 train_metrics_dict["train_slot_acc_epoch"],
+                 eval_metrics_dict["eval_slot_acc_epoch"],
+             )
+
+             if eval_loss < best_eval_loss:
+                 best_model = self.model
+                 best_eval_loss = eval_loss
+
+             train_logs = {
+                 "epoch": epoch,
+                 "train_loss": train_loss,
+                 "eval_loss": eval_loss,
+                 "train_intent_acc": train_intent_acc,
+                 "train_intent_f1": train_intent_f1,
+                 "eval_intent_f1": eval_intent_f1,
+                 "eval_intent_acc": eval_intent_acc,
+                 "train_slot_f1": train_slot_f1,
+                 "train_slot_acc": train_slot_acc,
+                 "lr": self.optimizer.param_groups[0]["lr"],
+             }
+
+             # patience is currently only logged; early stopping is not implemented
+             train_logs["patience"] = patience
+             logger.info(train_logs)
+             logger.info(eval_metrics_dict)
+
+         self.accelerator.wait_for_everyone()
+         model = self.accelerator.unwrap_model(self.model)
+         self.accelerator.save_state(self.output_dir)
+         logger.info(f"Checkpoint is saved in {self.output_dir}")
+
+         return best_model, best_eval_loss
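
A note on the RLW lines above: F.softmax(torch.randn(1), dim=-1) is a softmax over a single element, so the resulting weight is always exactly 1.0 and the intent loss is effectively unweighted. A hedged sketch of random loss weighting drawn over both task losses, which is not what this commit does, would draw one random weight per task:

    # hedged sketch, not the committed behaviour: one random weight per task
    import torch
    import torch.nn.functional as F

    def rlw_total_loss(slot_loss, intent_loss):
        weights = F.softmax(torch.randn(2, device=slot_loss.device), dim=-1)
        return weights[0] * slot_loss + weights[1] * intent_loss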