BecomeAllan committed
Commit 6755d15
1 Parent(s): a701d2a

init_comit
Files changed (3)
  1. app.py +233 -0
  2. requeriments.txt +3 -0
  3. utils.py +479 -0
app.py ADDED
@@ -0,0 +1,233 @@
from utils import *

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import unicodedata
import re
import json
import numpy as np
import pandas as pd
from copy import deepcopy

# Undesirable patterns within texts
patterns = {
    'CONCLUSIONS AND IMPLICATIONS': '',
    'BACKGROUND AND PURPOSE': '',
    'EXPERIMENTAL APPROACH': '',
    'KEY RESULTS AEA': '',
    '©': '',
    '®': '',
    'μ': '',
    '(C)': '',
    'OBJECTIVE:': '',
    'MATERIALS AND METHODS:': '',
    'SIGNIFICANCE:': '',
    'BACKGROUND:': '',
    'RESULTS:': '',
    'METHODS:': '',
    'CONCLUSIONS:': '',
    'AIM:': '',
    'STUDY DESIGN:': '',
    'CLINICAL RELEVANCE:': '',
    'CONCLUSION:': '',
    'HYPOTHESIS:': '',
    'Questions/Purposes:': '',
    'Introduction:': '',
    'PURPOSE:': '',
    'PATIENTS AND METHODS:': '',
    'FINDINGS:': '',
    'INTERPRETATIONS:': '',
    'FUNDING:': '',
    'PROGRESS:': '',
    'CONTEXT:': '',
    'MEASURES:': '',
    'DESIGN:': '',
    'BACKGROUND AND OBJECTIVES:': '',
    '<p>': '',
    '</p>': '',
    '<<ETX>>': '',
    '+/-': '',
}

patterns = {x.lower(): y for x, y in patterns.items()}


class treat_text:
    """Normalizes and cleans abstract text before tokenization."""

    def __init__(self, patterns):
        self.patterns = patterns

    def __call__(self, text):
        text = unicodedata.normalize("NFKD", str(text))
        text = multiple_replace(self.patterns, text.lower())
        text = re.sub(r'(\(.+\))|(\[.+\])|( \d )|(<)|(>)|(- )', '', text)
        text = re.sub(r'( +)', ' ', text)
        text = re.sub(r'(, ,)|(,,)', ',', text)
        text = re.sub(r'(%)|(per cent)', ' percent', text)
        return text


# Regex multiple replace function
def multiple_replace(dict, text):
    # Build a single regex out of the dictionary keys
    regex = re.compile("(%s)" % "|".join(map(re.escape, dict.keys())))
    # Substitute every match with its mapped value
    return regex.sub(lambda mo: dict[mo.string[mo.start():mo.end()]], text)


treat_text_fun = treat_text(patterns)

import sys
sys.path.append('ML-SLRC/')

path = 'ML-SLRC/'

model_path = path + 'model.pt'
info_path = path + 'Info.json'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the trained model
model = torch.load(model_path)

# Load the trained model's meta information
with open(info_path, 'r') as f:
    Info = json.load(f)

import random
from datetime import datetime


rand_seed = 2003

# datetime object containing the current date and time
now = datetime.now()

time_stamp = now.strftime("%d_%m_%Y_HR_%H_%M_%S")

config = {
    "shots_per_class": 8,
    "batch_size": 4,
    "epochs": 8,
    "learning_rate": 5e-05,
    "weight_decay": 0.85,
    "rand_seed": rand_seed,
    'pos_weight': 3.5,
    'p_incld': 0.2,
    'p_excld': 0.01,
}


NAME = str(config['shots_per_class']) + '-shots-Learner' + '_' + time_stamp
num_workers = 0
val_batch = 100

p_included = 0.7
p_notincluded = 0.3
sample_valid = 300


gen_seed = torch.Generator().manual_seed(rand_seed)
np.random.seed(rand_seed)
torch.manual_seed(rand_seed)
random.seed(rand_seed)


def treat_data_input(data, etailment_txt):
    # Labelled rows (shuffled per label) form the few-shot training set;
    # the full table is scored afterwards.
    data_train = data.groupby('test').sample(frac=1)
    dataload_all = data.copy()

    dataload_all.test = dataload_all.test.replace({np.nan: 'NANN'})

    # NOTE: initializer_model_scibert is expected to come from the ML-SLRC
    # package added to sys.path above; it is not defined in this repository.
    dataset_train = SLR_DataSet(data=data_train,
                                input='text',
                                output='test',
                                tokenizer=initializer_model_scibert.tokenizer,
                                LABEL_MAP=LABEL_MAP,
                                treat_text=treat_text_fun,
                                etailment_txt=etailment_txt)

    dataset_remain = SLR_DataSet(data=dataload_all,
                                 input='text',
                                 output='test',
                                 tokenizer=initializer_model_scibert.tokenizer,
                                 LABEL_MAP=LABEL_MAP,
                                 treat_text=treat_text_fun,
                                 etailment_txt=etailment_txt)

    dataload_train = DataLoader(dataset_train,
                                batch_size=config['batch_size'], drop_last=False,
                                num_workers=num_workers)

    dataload_remain = DataLoader(dataset_remain,
                                 batch_size=200, drop_last=False,
                                 num_workers=num_workers)

    # dataload_all is returned as well so predictions can be attached to it later
    return dataload_train, dataload_remain, dataload_all

import gc
from torch.optim import Adam
from scipy.stats import entropy


def treat_train_evaluate(dataload_train, dataload_remain):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    gc.collect()
    torch.cuda.empty_cache()

    model_few = deepcopy(model)
    model_few.loss_fn = nn.BCEWithLogitsLoss(reduction='mean',
                                             pos_weight=torch.FloatTensor([config['pos_weight']]))

    optimizer = Adam(model_few.parameters(), lr=config['learning_rate'],
                     weight_decay=config['weight_decay'])

    model_few.to(device)
    model_few.train()

    trainlog = model_few.fit(optimizer=optimizer,
                             scheduler=None,
                             data_train_loader=dataload_train,
                             epochs=config['epochs'], print_info=1, metrics=False,
                             log=None, metrics_print=False)

    (loss, features_out, (logits, outputs)) = model_few.evaluate(dataload_remain)
    return logits


def treat_sort(dataload_all, logits):
    dataload_all['prediction'] = torch.sigmoid(logits).flatten().numpy()
    dataload_all = dataload_all.sort_values(by=['prediction'], ascending=False).reset_index(drop=True)
    dataload_all.to_excel("output.xlsx")


def pipeline(data):
    # data = pd.read_csv(fil.name)
    data = pd.read_excel(data)
    dataload_train, dataload_remain, dataload_all = treat_data_input(data, "its a great text")
    logits = treat_train_evaluate(dataload_train, dataload_remain)
    treat_sort(dataload_all, logits)
    return "output.xlsx"


import gradio as gr

with gr.Blocks() as demo:
    fil = gr.File(label="input data")
    output = gr.File(label="output data")
    greet_btn = gr.Button("Run")
    greet_btn.click(fn=pipeline, inputs=fil, outputs=output)

demo.launch()
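For reference, pipeline() expects an Excel sheet with a 'text' column (title/abstract text) and a 'test' column holding labels that LABEL_MAP understands ('included'/'excluded', 1/0); rows with an empty label are only scored, not trained on. Below is a minimal sketch that drives the same pipeline without the Gradio UI, assuming the ML-SLRC/ assets loaded above (model.pt, Info.json, the SciBERT initializer) are available; the file name example_reviews.xlsx is purely illustrative.

# Minimal usage sketch (assumes the ML-SLRC/ assets loaded above are in place).
import pandas as pd

sample = pd.DataFrame({
    'text': [
        'Deep learning for systematic review screening ...',
        'A randomized trial of drug X in a murine model ...',
        'An unrelated report on crop rotation ...',
    ],
    # labels understood by LABEL_MAP; rows left empty are scored but not trained on
    'test': ['included', 'excluded', None],
})
sample.to_excel('example_reviews.xlsx', index=False)  # writing .xlsx needs openpyxl

# The same call the Gradio button makes; writes the ranked sheet to output.xlsx
pipeline('example_reviews.xlsx')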
requeriments.txt ADDED
@@ -0,0 +1,3 @@
transformers==4.16.2
torchmetrics==0.8.0
matplotlib==3.5.1
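As committed, requeriments.txt pins only three packages, while the code above also imports torch, numpy, pandas, scipy, tqdm, gradio and IPython, and pandas' read_excel/to_excel needs openpyxl. An indicative list covering those imports (versions deliberately left unpinned, since the tested versions are not recorded in this commit) would look like:

transformers==4.16.2
torchmetrics==0.8.0
matplotlib==3.5.1
torch
numpy
pandas
scipy
tqdm
gradio
ipython
openpyxl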
utils.py ADDED
@@ -0,0 +1,479 @@
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from copy import deepcopy

LABEL_MAP = {'negative': 0,
             'not included': 0,
             '0': 0,
             0: 0,
             'excluded': 0,
             'positive': 1,
             'included': 1,
             '1': 1,
             1: 1,
             }

class SLR_DataSet(Dataset):
    def __init__(self,
                 treat_text=None,
                 etailment_txt=None,
                 LABEL_MAP=None,
                 NA=None,
                 **args):
        self.tokenizer = args.get('tokenizer')
        self.data = args.get('data').reset_index()
        self.max_seq_length = args.get("max_seq_length", 512)
        self.INPUT_NAME = args.get("input", 'x')
        self.LABEL_NAME = args.get("output", None)
        self.treat_text = treat_text
        self.etailment_txt = etailment_txt
        self.LABEL_MAP = LABEL_MAP
        self.NA = NA

        if self.INPUT_NAME not in self.data.columns:
            self.data[self.INPUT_NAME] = np.nan

    # Tokenizing and processing text
    def encode_text(self, example):
        comment_text = example[self.INPUT_NAME]
        if self.treat_text is not None:
            comment_text = self.treat_text(comment_text)

        if pd.isna(example[self.LABEL_NAME]) and self.NA is not None:
            labels = self.NA

        elif self.LABEL_NAME is not None:
            try:
                labels = self.LABEL_MAP[example[self.LABEL_NAME]]
            except (KeyError, TypeError):
                labels = -1
                # raise TypeError(f"Label passed {example[self.LABEL_NAME]} is not in LABEL_MAP")
        else:
            labels = None

        if self.etailment_txt:
            tensor_data = self.tokenize((comment_text, self.etailment_txt), labels)
        else:
            tensor_data = self.tokenize((comment_text), labels)

        return tensor_data

    def tokenize(self, comment_text, labels):
        encoding = self.tokenizer.encode_plus(
            (comment_text),
            add_special_tokens=True,
            max_length=self.max_seq_length,
            return_token_type_ids=True,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        if labels is not None:
            return tuple(((
                encoding["input_ids"].flatten(),
                encoding["attention_mask"].flatten(),
                encoding["token_type_ids"].flatten()
            ),
                torch.tensor([torch.tensor(labels).to(int)])
            ))
        else:
            return tuple(((
                encoding["input_ids"].flatten(),
                encoding["attention_mask"].flatten(),
                encoding["token_type_ids"].flatten()
            ),
                torch.empty(0)
            ))

    def __len__(self):
        return len(self.data)

    # Returning data
    def __getitem__(self, index: int):
        data_row = self.data.iloc[index]
        tensor_data = self.encode_text(data_row)
        return tensor_data

from tqdm import tqdm
import gc
from IPython.display import clear_output
from collections import namedtuple

features = namedtuple('features', ['bert', 'feature_map'])
Output = namedtuple('Output', ['loss', 'features', 'logit'])

bert_tuple = namedtuple('bert', ['hidden_states', 'attentions'])

class loop():

    @classmethod
    def train_loop(self, model, device, optimizer, data_train_loader, scheduler=None, data_valid_loader=None,
                   epochs=4, print_info=1000000000, metrics=True, log=None, metrics_print=True):
        # Start the model's parameters
        table.reset()
        model.to(device)
        model.train()

        # Task epochs (Inner epochs)
        for epoch in range(0, epochs):
            train_loss, _, out = self.batch_loop(data_train_loader, model, optimizer, device)

            if scheduler is not None:
                for sched in scheduler:
                    sched.step()

            if (epoch % print_info == 0):
                if metrics:
                    labels = self.map_batch(out[1]).to(int).squeeze()
                    logits = self.map_batch(out[0]).squeeze()

                    # NOTE: plot() is an external metrics helper expected to be
                    # available in the environment; it is not defined in this file.
                    train_metrics, _ = plot(logits, labels, 0.9)

                    del labels, logits

                    train_metrics['Loss'] = torch.Tensor(train_loss).mean().item()

                    if log is not None:
                        log({"train_" + x: y for x, y in train_metrics.items()})

                    table(train_metrics, epoch, "Train")

                else:
                    print("Loss: ", torch.Tensor(train_loss).mean().item())

                if data_valid_loader:
                    valid_loss, _, out = self.eval_loop(data_valid_loader, model, device=device)
                    if metrics:
                        global out2
                        out2 = out
                        labels = self.map_batch(out[1]).to(int).squeeze()
                        logits = self.map_batch(out[0]).squeeze()

                        valid_metrics, _ = plot(logits, labels, 0.9)
                        valid_metrics['Loss'] = torch.Tensor(valid_loss).mean().item()

                        del labels, logits

                        if log is not None:
                            log({"valid_" + x: y for x, y in valid_metrics.items()})

                        table(valid_metrics, epoch, "Valid")

                        if metrics_print:
                            print(table.data_frame().round(4))

                    else:
                        print("Valid Loss: ", torch.Tensor(valid_loss).mean().item())

        return table.data_frame()

    @classmethod
    def batch_loop(self, loader, model, optimizer, device):
        all_loss = []
        features_lst = []
        attention_lst = []
        logits = []
        outputs = []

        # Batch loop over the training data
        for inner_step, batch in enumerate(tqdm(loader,
                                                desc="Train validation | ",
                                                ncols=80)):
            input, output = batch
            input = tuple(t.to(device) for t in input)

            if isinstance(output, torch.Tensor):
                output = output.to(device)

            optimizer.zero_grad()

            # Predictions
            loss, feature, logit = model(input, output)

            # Compute gradients
            loss.backward()

            # Update parameters
            optimizer.step()

            input = tuple(t.to("cpu") for t in input)

            if isinstance(output, torch.Tensor):
                output = output.to("cpu")

            if isinstance(loss, torch.Tensor):
                all_loss.append(loss.to('cpu').detach().clone())

            if isinstance(logit, torch.Tensor):
                logits.append(logit.to('cpu').detach().clone())

            if isinstance(output, torch.Tensor):
                outputs.append(output.to('cpu').detach().clone())

            if len(feature.feature_map) != 0:
                features_lst.append([x.to('cpu').detach().clone() for x in feature.feature_map])

            del batch, input, output, loss, feature, logit

        # model.to('cpu')
        gc.collect()
        torch.cuda.empty_cache()

        # del model, optimizer

        return Output(all_loss, features(None, features_lst), (logits, outputs))

    @classmethod
    def eval_loop(self, loader, model, device, attention=False, hidden_states=False):
        all_loss = []
        features_lst = []
        attention_lst = []
        hidden_states_lst = []
        logits = []
        outputs = []
        model.eval()

        with torch.no_grad():
            # Evaluation batch loop
            for inner_step, batch in enumerate(tqdm(loader,
                                                    desc="Test validation | ",
                                                    ncols=80)):
                input, output = batch
                input = tuple(t.to(device) for t in input)

                if output.numel() != 0:
                    # Predictions with labels (loss is computed)
                    loss, feature, logit = model(input, output.to(device),
                                                 attention=attention, hidden_states=hidden_states)
                else:
                    # Predictions without labels
                    loss, feature, logit = model(input,
                                                 attention=attention, hidden_states=hidden_states)

                input = tuple(t.to("cpu") for t in input)

                if isinstance(output, torch.Tensor):
                    output = output.to("cpu")

                if isinstance(loss, torch.Tensor):
                    all_loss.append(loss.to('cpu').detach().clone())

                if isinstance(logit, torch.Tensor):
                    logits.append(logit.to('cpu').detach().clone())

                try:
                    if feature.bert.attentions is not None:
                        attention_lst.append([x.to('cpu').detach().clone() for x in feature.bert.attentions])
                except Exception:
                    attention_lst = None

                try:
                    if feature.bert.hidden_states is not None:
                        hidden_states_lst.append([x.to('cpu').detach().clone() for x in feature.bert.hidden_states])
                except Exception:
                    hidden_states_lst = None

                if isinstance(output, torch.Tensor):
                    outputs.append(output.to('cpu').detach().clone())

                if len(feature.feature_map) != 0:
                    features_lst.append([x.to('cpu').detach().clone() for x in feature.feature_map])

                del batch, input, output, loss, feature, logit

        # model.to('cpu')
        gc.collect()
        torch.cuda.empty_cache()

        # del model, optimizer

        return Output(all_loss, features(bert_tuple(hidden_states_lst, attention_lst), features_lst), (logits, outputs))

    # Process predictions and map the feature_map in tsne
    @staticmethod
    def map_batch(features):
        features = torch.cat(features, dim=0)
        # features = np.concatenate(np.array(features, dtype=object)).astype(np.float32)
        # features = torch.tensor(features)
        return features.detach().clone()

class table:
    data = []
    index = []

    @torch.no_grad()
    def __init__(self, data, epochs, name):
        self.index.append((epochs, name))
        self.data.append(data)

    @classmethod
    @torch.no_grad()
    def data_frame(cls):
        clear_output()
        index = pd.MultiIndex.from_tuples(cls.index, names=["Epochs", "Data"])
        data = pd.DataFrame(cls.data, index=index)
        return data

    @classmethod
    @torch.no_grad()
    def reset(cls):
        cls.data = []
        cls.index = []


from collections import namedtuple

# Declaring namedtuple()

# Pre-trained model
class Encoder(nn.Module):
    def __init__(self, layers, freeze_bert, model):
        super(Encoder, self).__init__()

        # Dummy Parameter
        self.dummy_param = nn.Parameter(torch.empty(0))

        # Pre-trained model
        self.model = deepcopy(model)

        # Freezing bert parameters
        if freeze_bert:
            for param in self.model.parameters():
                param.requires_grad = False

        # Selecting hidden layers of the pre-trained model
        old_model_encoder = self.model.encoder.layer
        new_model_encoder = nn.ModuleList()

        for i in layers:
            new_model_encoder.append(old_model_encoder[i])

        self.model.encoder.layer = new_model_encoder

    # Feed forward
    def forward(self, output_attentions=False, output_hidden_states=False, **x):
        return self.model(output_attentions=output_attentions,
                          output_hidden_states=output_hidden_states,
                          return_dict=True,
                          **x)

# Complete model
class SLR_Classifier(nn.Module):
    def __init__(self, **data):
        super(SLR_Classifier, self).__init__()

        # Dummy Parameter
        self.dummy_param = nn.Parameter(torch.empty(0))

        # Loss function
        # Binary Cross Entropy with logits reduced to mean
        self.loss_fn = nn.BCEWithLogitsLoss(reduction='mean',
                                            pos_weight=torch.FloatTensor([data.get("pos_weight", 2.5)]))

        # Pre-trained model
        self.Encoder = Encoder(layers=data.get("bert_layers", range(12)),
                               freeze_bert=data.get("freeze_bert", False),
                               model=data.get("model"),
                               )

        # Feature Map Layer
        self.feature_map = nn.Sequential(
            # nn.LayerNorm(self.Encoder.model.config.hidden_size),
            nn.BatchNorm1d(self.Encoder.model.config.hidden_size),
            # nn.Dropout(data.get("drop", 0.5)),
            nn.Linear(self.Encoder.model.config.hidden_size, 200),
            nn.Dropout(data.get("drop", 0.5)),
        )

        # Classifier Layer
        self.classifier = nn.Sequential(
            # nn.LayerNorm(self.Encoder.model.config.hidden_size),
            # nn.Dropout(data.get("drop", 0.5)),
            # nn.BatchNorm1d(self.Encoder.model.config.hidden_size),
            # nn.Dropout(data.get("drop", 0.5)),
            nn.Tanh(),
            nn.Linear(200, 1)
        )

        # Initializing layer parameters
        nn.init.normal_(self.feature_map[1].weight, mean=0, std=0.00001)
        nn.init.zeros_(self.feature_map[1].bias)

    # Feed forward
    def forward(self, input, output=None, attention=False, hidden_states=False):
        # input, output = batch
        input_ids, attention_mask, token_type_ids = input

        predict = self.Encoder(output_attentions=attention,
                               output_hidden_states=hidden_states,
                               **{"input_ids": input_ids,
                                  "attention_mask": attention_mask,
                                  "token_type_ids": token_type_ids
                                  })

        feature_maped = self.feature_map(predict['pooler_output'])
        logit = self.classifier(feature_maped)

        # predict = torch.sigmoid(logit)

        if output is not None:
            # Loss function
            loss = self.loss_fn(logit.to(torch.float), output.to(torch.float))

            return Output(loss, features(predict, feature_maped), logit)
        else:
            return Output(None, features(predict, feature_maped), logit)

    def fit(self, optimizer, data_train_loader, scheduler=None, data_valid_loader=None,
            epochs=4, print_info=1000000000, metrics=True, log=None, metrics_print=True):

        return loop.train_loop(self,
                               device=self.dummy_param.device,
                               optimizer=optimizer,
                               scheduler=scheduler,
                               data_train_loader=data_train_loader,
                               data_valid_loader=data_valid_loader,
                               epochs=epochs,
                               print_info=print_info,
                               metrics=metrics,
                               log=log,
                               metrics_print=metrics_print)

    def evaluate(self, loader, attention=False, hidden_states=False):
        all_loss, feature, (logits, outputs) = loop.eval_loop(loader, self, self.dummy_param.device,
                                                              attention=attention, hidden_states=hidden_states)

        logits = loop.map_batch(logits)

        if len(outputs) != 0:
            outputs = loop.map_batch(outputs)

        return Output(np.mean(all_loss), feature, (logits, outputs))
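utils.py can also be exercised on its own. The sketch below wires SLR_DataSet and SLR_Classifier to a Hugging Face backbone; the checkpoint name allenai/scibert_scivocab_uncased is only an assumption about which encoder the shipped model.pt was built from, and the two example abstracts are made up.

# Standalone sketch for utils.py (the backbone choice is an assumption).
import pandas as pd
import torch
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer

from utils import LABEL_MAP, SLR_DataSet, SLR_Classifier

tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
backbone = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased')

df = pd.DataFrame({'text': ['a candidate abstract ...', 'another abstract ...'],
                   'test': ['included', 'excluded']})

dataset = SLR_DataSet(data=df, input='text', output='test',
                      tokenizer=tokenizer, LABEL_MAP=LABEL_MAP,
                      etailment_txt='its a great text')
loader = DataLoader(dataset, batch_size=2)

model = SLR_Classifier(model=backbone, pos_weight=3.5)

# Score the two rows without any training; returns loss, features and logits
loss, feats, (logits, labels) = model.evaluate(loader)
print(torch.sigmoid(logits))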