PereLluis13 committed on
Commit 7d9090e · verified · 1 Parent(s): a5f606b

Upload model

Files changed (4)
  1. config.json +24 -0
  2. configuration_relik.py +45 -0
  3. modeling_relik_dev.py +1130 -0
  4. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,24 @@
+ {
+ "_name_or_path": "experiments/relik-reader-large-lr-0.0001-half/2024-03-02/12-14-56/wandb/latest-run/files/hf_model/relik-large-nyt/",
+ "activation": "gelu",
+ "add_entity_embedding": null,
+ "additional_special_symbols": 24,
+ "additional_special_symbols_types": 0,
+ "architectures": [
+ "RelikReaderREModel"
+ ],
+ "auto_map": {
+ "AutoModel": "modeling_relik_dev.RelikReaderREModel"
+ },
+ "default_reader_class": null,
+ "entity_type_loss": null,
+ "linears_hidden_size": 512,
+ "model_type": "relik-reader",
+ "num_layers": null,
+ "threshold": 0.912111759185791,
+ "torch_dtype": "float32",
+ "training": true,
+ "transformer_model": "microsoft/deberta-v3-large",
+ "transformers_version": "4.33.3",
+ "use_last_k_layers": 1
+ }
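Note: the auto_map entry above points AutoModel at the RelikReaderREModel class shipped in modeling_relik_dev.py, so the checkpoint is intended to be loaded with remote code enabled. A minimal loading sketch (the path below is a placeholder, not part of this commit, and depending on the transformers version the custom config may also need to be resolvable):

    from transformers import AutoModel

    # "<repo-or-local-path>" is a placeholder for the Hub repo id or local directory holding these files
    reader = AutoModel.from_pretrained("<repo-or-local-path>", trust_remote_code=True)
    reader.eval()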
configuration_relik.py ADDED
@@ -0,0 +1,45 @@
+ from typing import Optional
+
+ from transformers import AutoConfig
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class RelikReaderConfig(PretrainedConfig):
+ model_type = "relik-reader"
+
+ def __init__(
+ self,
+ transformer_model: str = "microsoft/deberta-v3-base",
+ additional_special_symbols: int = 101,
+ additional_special_symbols_types: Optional[int] = 0,
+ num_layers: Optional[int] = None,
+ activation: str = "gelu",
+ linears_hidden_size: Optional[int] = 512,
+ use_last_k_layers: int = 1,
+ threshold: Optional[float] = 0.5,
+ entity_type_loss: bool = False,
+ add_entity_embedding: Optional[bool] = None,
+ training: bool = False,
+ default_reader_class: Optional[str] = None,
+ **kwargs
+ ) -> None:
+ self.transformer_model = transformer_model
+ self.additional_special_symbols = additional_special_symbols
+ self.additional_special_symbols_types = additional_special_symbols_types
+ self.num_layers = num_layers
+ self.activation = activation
+ self.linears_hidden_size = linears_hidden_size
+ self.use_last_k_layers = use_last_k_layers
+ self.entity_type_loss = entity_type_loss
+ self.add_entity_embedding = (
+ True
+ if add_entity_embedding is None and entity_type_loss
+ else add_entity_embedding
+ )
+ self.training = training
+ self.threshold = threshold
+ self.default_reader_class = default_reader_class
+ super().__init__(**kwargs)
+
+
+ AutoConfig.register("relik-reader", RelikReaderConfig)
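For reference, the defaulting logic above only switches add_entity_embedding on automatically when entity_type_loss is set. A small sketch, assuming configuration_relik.py is importable as a plain module:

    from configuration_relik import RelikReaderConfig

    cfg_plain = RelikReaderConfig()                       # add_entity_embedding stays None
    cfg_typed = RelikReaderConfig(entity_type_loss=True)  # add_entity_embedding defaults to True
    assert cfg_plain.add_entity_embedding is None
    assert cfg_typed.add_entity_embedding is True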
modeling_relik_dev.py ADDED
@@ -0,0 +1,1130 @@
1
+ from typing import Any, Dict, Optional
2
+
3
+ import torch
4
+ from transformers import AutoModel, PreTrainedModel
5
+ from transformers.activations import ClippedGELUActivation, GELUActivation
6
+ from transformers.configuration_utils import PretrainedConfig
7
+ from transformers.modeling_utils import PoolerEndLogits
8
+
9
+ from .configuration_relik import RelikReaderConfig
10
+
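+ # allow lower-precision matmul kernels (e.g. bfloat16) on supported hardware, trading a little accuracy for speed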
11
+ torch.set_float32_matmul_precision('medium')
12
+
13
+ def get_mention_features(
14
+ features: torch.Tensor,
15
+ starts: torch.Tensor,
16
+ ends: torch.Tensor,
17
+ batch_size: int,
18
+ ) -> torch.Tensor:
19
+ # Step 1: Create the mask for the spans
20
+ start_positions = starts.nonzero(as_tuple=True)
21
+ end_positions = ends.nonzero(as_tuple=True)
22
+ averaged_features = []
23
+ for batch_idx, (start, end) in enumerate(zip(start_positions[1], end_positions[1])):
24
+ # Select the batch where the span is located
25
+ batch_id = start_positions[0][batch_idx]
26
+ # Extract features for the span
27
+ span_features = features[batch_id, start:end+1]
28
+ # Compute the average of the features
29
+ span_avg = span_features.mean(dim=0)
30
+ averaged_features.append(span_avg)
31
+
32
+ averaged_features = torch.stack(averaged_features, dim=0)
33
+
34
+ # use torch.nn.utils.rnn.pad_sequence and split to pad and split the features to batch_size
35
+ averaged_features = torch.nn.utils.rnn.pad_sequence(
36
+ torch.split(averaged_features, torch.sum(starts, dim=1).tolist(), dim=0),
37
+ )
38
+ return averaged_features
39
+
40
+ def custom_margin_ranking_loss(scores, labels, margin=1.0):
41
+ """
42
+ Custom implementation of margin ranking loss for imbalanced positive/negative scores.
43
+
44
+ Parameters:
45
+ - scores: Tensor containing the scores for each example.
46
+ - labels: Tensor containing labels (1 for positive, 0 for negative examples).
47
+ - margin: Desired margin between positive and negative scores.
48
+
49
+ Returns:
50
+ - loss: Computed loss value.
51
+ """
52
+ # Separate scores into positive and negative based on labels
53
+ positive_scores = scores[labels == 1]
54
+ negative_scores = scores[labels == 0]
55
+
56
+ # Compute all pairs of positive-negative scores
57
+ pos_neg_diff = positive_scores.unsqueeze(1) - negative_scores.unsqueeze(0)
58
+
59
+ # Calculate loss for all positive-negative pairs
60
+ loss_components = torch.clamp(margin - pos_neg_diff, min=0)
61
+
62
+ # Average loss across all comparisons
63
+ loss = torch.mean(loss_components)
64
+
65
+ return loss
66
+
67
+ def split_and_process(tensor, projector_start, projector_end):
68
+ # Split the tensor along the last dimension
69
+ half = tensor.shape[-1] // 2
70
+ tensor_start, tensor_end = tensor[..., :half], tensor[..., half:]
71
+
72
+ # Apply the linear layers
73
+ tensor_start_processed = projector_start(tensor_start)
74
+ tensor_end_processed = projector_end(tensor_end)
75
+
76
+ return tensor_start_processed, tensor_end_processed
77
+
78
+ def get_mention_features_vectorized(features, starts, ends, batch_size):
79
+ # Create a range tensor that matches the size of the longest span
80
+ max_length = (ends - starts).max() + 1
81
+ range_tensor = torch.arange(max_length).to(features.device)
82
+
83
+ # Expand starts and range_tensor to compute a mask for each position in each span
84
+ expanded_starts = starts.unsqueeze(-1) # Adding an extra dimension for broadcasting
85
+ expanded_ends = ends.unsqueeze(-1)
86
+ range_mask = (range_tensor < (expanded_ends - expanded_starts + 1))
87
+
88
+ # Use the mask to select features, handling variable-length sequences with padding
89
+ span_lengths = (expanded_ends - expanded_starts).squeeze(-1) + 1
90
+ max_span_length = span_lengths.max()
91
+ padded_features = torch.zeros((batch_size, max_span_length, features.size(-1)), device=features.device)
92
+
93
+ for i in range(batch_size):
94
+ span = features[i, starts[i]:ends[i]+1]
95
+ padded_features[i, :span.size(0)] = span
96
+
97
+ # Compute the mean of features for each span, using the mask for correct averaging
98
+ span_means = (padded_features * range_mask.unsqueeze(-1)).sum(dim=1) / range_mask.sum(dim=1, keepdim=True)
99
+
100
+ return span_means
101
+
102
+ def random_half_tensor_dropout(tensor, dropout_prob=0.5, is_training=True):
103
+ """
104
+ Applies dropout to either the first half or the second half of the tensor with a specified probability.
105
+ Dropout is only applied during training.
106
+
107
+ Args:
108
+ tensor (torch.Tensor): The input tensor.
109
+ dropout_prob (float): The probability of dropping out half of the tensor.
110
+ is_training (bool): If True, apply dropout; if False, do not apply dropout.
111
+
112
+ Returns:
113
+ torch.Tensor: The tensor after applying dropout.
114
+ """
115
+ assert 0 <= dropout_prob <= 1, "Dropout probability must be in the range [0, 1]"
116
+
117
+ if is_training:
118
+ # Size of the last dimension
119
+ last_dim_size = tensor.size(-1)
120
+
121
+ # Calculate the index for splitting the tensor into two halves
122
+ split_index = last_dim_size // 2
123
+
124
+ # Generate a random number and compare it with the dropout probability
125
+ if torch.rand(1).item() < dropout_prob:
126
+ # Randomly choose to drop the first half or the second half
127
+ if torch.rand(1).item() < 0.5:
128
+ # Set the first half to zero
129
+ tensor[..., :split_index] = 0
130
+ else:
131
+ # Set the second half to zero
132
+ tensor[..., split_index:] = 0
133
+
134
+ return tensor
135
+
136
+ class RelikReaderSample:
137
+ def __init__(self, **kwargs):
138
+ super().__setattr__("_d", {})
139
+ self._d = kwargs
140
+
141
+ def __getattribute__(self, item):
142
+ return super(RelikReaderSample, self).__getattribute__(item)
143
+
144
+ def __getattr__(self, item):
145
+ if item.startswith("__") and item.endswith("__"):
146
+ # this is likely some python library-specific variable (such as __deepcopy__ for copy)
147
+ # better follow standard behavior here
148
+ raise AttributeError(item)
149
+ elif item in self._d:
150
+ return self._d[item]
151
+ else:
152
+ return None
153
+
154
+ def __setattr__(self, key, value):
155
+ if key in self._d:
156
+ self._d[key] = value
157
+ else:
158
+ super().__setattr__(key, value)
159
+
160
+
161
+ activation2functions = {
162
+ "relu": torch.nn.ReLU(),
163
+ "gelu": GELUActivation(),
164
+ "gelu_10": ClippedGELUActivation(-10, 10),
165
+ }
166
+
167
+
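+ # Variant of transformers' PoolerEndLogits whose final projection outputs 2 logits per token,
+ # turning end-boundary detection into a per-token binary decision rather than a single softmax over positions.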
168
+ class PoolerEndLogitsBi(PoolerEndLogits):
169
+ def __init__(self, config: PretrainedConfig):
170
+ super().__init__(config)
171
+ self.dense_1 = torch.nn.Linear(config.hidden_size, 2)
172
+
173
+ def forward(
174
+ self,
175
+ hidden_states: torch.FloatTensor,
176
+ start_states: Optional[torch.FloatTensor] = None,
177
+ start_positions: Optional[torch.LongTensor] = None,
178
+ p_mask: Optional[torch.FloatTensor] = None,
179
+ ) -> torch.FloatTensor:
180
+ if p_mask is not None:
181
+ p_mask = p_mask.unsqueeze(-1)
182
+ logits = super().forward(
183
+ hidden_states,
184
+ start_states,
185
+ start_positions,
186
+ p_mask,
187
+ )
188
+ return logits
189
+
190
+
191
+ class RelikReaderSpanModel(PreTrainedModel):
192
+ config_class = RelikReaderConfig
193
+
194
+ def __init__(self, config: RelikReaderConfig, *args, **kwargs):
195
+ super().__init__(config)
196
+ # Transformer model declaration
197
+ self.config = config
198
+ self.transformer_model = (
199
+ AutoModel.from_pretrained(self.config.transformer_model)
200
+ if self.config.num_layers is None
201
+ else AutoModel.from_pretrained(
202
+ self.config.transformer_model, num_hidden_layers=self.config.num_layers
203
+ )
204
+ )
205
+ self.transformer_model.resize_token_embeddings(
206
+ self.transformer_model.config.vocab_size
207
+ + self.config.additional_special_symbols,
208
+ pad_to_multiple_of=8,
209
+ )
210
+
211
+ self.activation = self.config.activation
212
+ self.linears_hidden_size = self.config.linears_hidden_size
213
+ self.use_last_k_layers = self.config.use_last_k_layers
214
+
215
+ # named entity detection layers
216
+ self.ned_start_classifier = self._get_projection_layer(
217
+ self.activation, last_hidden=2, layer_norm=False
218
+ )
219
+ self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config)
220
+
221
+ # END entity disambiguation layer
222
+ self.ed_projector = self._get_projection_layer(self.activation, last_hidden = 2*self.linears_hidden_size, hidden=2*self.linears_hidden_size)
223
+
224
+ self.training = self.config.training
225
+
226
+ # criterion
227
+ self.criterion = torch.nn.CrossEntropyLoss()
228
+
229
+ def _get_projection_layer(
230
+ self,
231
+ activation: str,
232
+ last_hidden: Optional[int] = None,
233
+ hidden: Optional[int] = None,
234
+ input_hidden=None,
235
+ layer_norm: bool = True,
236
+ ) -> torch.nn.Sequential:
237
+ head_components = [
238
+ torch.nn.Dropout(0.1),
239
+ torch.nn.Linear(
240
+ (
241
+ self.transformer_model.config.hidden_size * self.use_last_k_layers
242
+ if input_hidden is None
243
+ else input_hidden
244
+ ),
245
+ self.linears_hidden_size if hidden is None else hidden,
246
+ ),
247
+ activation2functions[activation],
248
+ torch.nn.Dropout(0.1),
249
+ torch.nn.Linear(
250
+ self.linears_hidden_size if hidden is None else hidden,
251
+ self.linears_hidden_size if last_hidden is None else last_hidden,
252
+ ),
253
+ ]
254
+
255
+ if layer_norm:
256
+ head_components.append(
257
+ torch.nn.LayerNorm(
258
+ self.linears_hidden_size if last_hidden is None else last_hidden,
259
+ self.transformer_model.config.layer_norm_eps,
260
+ )
261
+ )
262
+
263
+ return torch.nn.Sequential(*head_components)
264
+
265
+ def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
266
+ mask = mask.unsqueeze(-1)
267
+ if next(self.parameters()).dtype == torch.float16:
268
+ logits = logits * (1 - mask) - 65500 * mask
269
+ else:
270
+ logits = logits * (1 - mask) - 1e30 * mask
271
+ return logits
272
+
273
+ def _get_model_features(
274
+ self,
275
+ input_ids: torch.Tensor,
276
+ attention_mask: torch.Tensor,
277
+ token_type_ids: Optional[torch.Tensor],
278
+ ):
279
+ model_input = {
280
+ "input_ids": input_ids,
281
+ "attention_mask": attention_mask,
282
+ "output_hidden_states": self.use_last_k_layers > 1,
283
+ }
284
+
285
+ if token_type_ids is not None:
286
+ model_input["token_type_ids"] = token_type_ids
287
+
288
+ model_output = self.transformer_model(**model_input)
289
+
290
+ if self.use_last_k_layers > 1:
291
+ model_features = torch.cat(
292
+ model_output[1][-self.use_last_k_layers :], dim=-1
293
+ )
294
+ else:
295
+ model_features = model_output[0]
296
+
297
+ return model_features
298
+
299
+ def compute_ned_end_logits(
300
+ self,
301
+ start_predictions,
302
+ start_labels,
303
+ model_features,
304
+ prediction_mask,
305
+ batch_size,
306
+ ) -> Optional[torch.Tensor]:
307
+ # todo: maybe when constraining on the spans,
308
+ # we should not use a prediction_mask for the end tokens.
309
+ # at least we should not during training imo
310
+ start_positions = start_labels if self.training else start_predictions
311
+ start_positions_indices = (
312
+ torch.arange(start_positions.size(1), device=start_positions.device)
313
+ .unsqueeze(0)
314
+ .expand(batch_size, -1)[start_positions > 0]
315
+ ).to(start_positions.device)
316
+
317
+ if len(start_positions_indices) > 0:
318
+ expanded_features = model_features.repeat_interleave(
319
+ torch.sum(start_positions > 0, dim=-1), dim=0
320
+ )
321
+ expanded_prediction_mask = prediction_mask.repeat_interleave(
322
+ torch.sum(start_positions > 0, dim=-1), dim=0
323
+ )
324
+ end_logits = self.ned_end_classifier(
325
+ hidden_states=expanded_features,
326
+ start_positions=start_positions_indices,
327
+ p_mask=expanded_prediction_mask,
328
+ )
329
+
330
+ return end_logits
331
+
332
+ return None
333
+
334
+ def compute_classification_logits(
335
+ self,
336
+ model_features,
337
+ special_symbols_mask,
338
+ prediction_mask,
339
+ batch_size,
340
+ start_positions=None,
341
+ end_positions=None,
342
+ attention_mask=None,
343
+ ) -> torch.Tensor:
344
+ if start_positions is None or end_positions is None:
345
+ start_positions = torch.zeros_like(prediction_mask)
346
+ end_positions = torch.zeros_like(prediction_mask)
347
+
348
+ model_ed_features = self.ed_projector(model_features)
349
+
350
+ model_ed_features[start_positions > 0][:, model_ed_features.shape[-1] // 2:] = model_ed_features[end_positions > 0][
351
+ :, :model_ed_features.shape[-1] // 2
352
+ ]
353
+
354
+ # computing ed features
355
+ classes_representations = torch.sum(special_symbols_mask, dim=1)[0].item()
356
+ special_symbols_mask_start = special_symbols_mask.roll(1, 1)
357
+ special_symbols_mask_start[:, :2] = torch.tensor([True, False], device=special_symbols_mask.device).expand_as(
358
+ special_symbols_mask_start[:, :2]
359
+ )
360
+
361
+ special_symbols_mask_end = special_symbols_mask.roll(-1, 1)
362
+ cumsum = special_symbols_mask_end.cumsum(dim=1)
363
+ # Identify the second True in each row (where cumulative sum equals 2)
364
+ special_symbols_mask_end[cumsum == 2] = False
365
+ special_symbols_mask_end[:, [0, -1]] = torch.tensor([True, False], device=special_symbols_mask.device).expand_as(
366
+ special_symbols_mask_end[:, [0, -1]]
367
+ )
368
+ # mark the second-to-last non-padding token as the end of the last special-symbol segment
369
+ last_token_ent = attention_mask.sum(1) - 2
370
+ special_symbols_mask_end[torch.arange(special_symbols_mask_end.shape[0], device=special_symbols_mask_end.device), last_token_ent] = True
371
+
372
+
373
+ special_symbols_representation_start = model_ed_features[special_symbols_mask_start][:,:model_ed_features.shape[-1] // 2].view(
374
+ batch_size, classes_representations, -1
375
+ )
376
+ special_symbols_representation_end = model_ed_features[special_symbols_mask_end][:,model_ed_features.shape[-1] // 2:].view(
377
+ batch_size, classes_representations, -1
378
+ )
379
+ # special_symbols_representation_start = self.ed_special_tokens_projector_start(special_symbols_representation_start)
380
+ # special_symbols_representation_end = self.ed_special_tokens_projector_end(special_symbols_representation_end)
381
+
382
+ special_symbols_representation = torch.cat(
383
+ [special_symbols_representation_start, special_symbols_representation_end, special_symbols_representation_end, special_symbols_representation_start], dim=-1
384
+ )
385
+ model_ed_features = torch.cat(
386
+ [model_ed_features, model_ed_features], dim=-1
387
+ )
388
+
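+ # score every token against every candidate entity:
+ # [batch, seq, d] x [batch, d, num_candidates] -> [batch, seq, num_candidates] logits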
389
+ logits = torch.bmm(
390
+ model_ed_features,
391
+ torch.permute(special_symbols_representation, (0, 2, 1)),
392
+ )
393
+
394
+ logits = self._mask_logits(logits, prediction_mask)
395
+
396
+ return logits
397
+
398
+ def forward(
399
+ self,
400
+ input_ids: torch.Tensor,
401
+ attention_mask: torch.Tensor,
402
+ token_type_ids: Optional[torch.Tensor] = None,
403
+ prediction_mask: Optional[torch.Tensor] = None,
404
+ special_symbols_mask: Optional[torch.Tensor] = None,
405
+ start_labels: Optional[torch.Tensor] = None,
406
+ end_labels: Optional[torch.Tensor] = None,
407
+ use_predefined_spans: bool = False,
408
+ *args,
409
+ **kwargs,
410
+ ) -> Dict[str, Any]:
411
+ batch_size, seq_len = input_ids.shape
412
+
413
+ model_features = self._get_model_features(
414
+ input_ids, attention_mask, token_type_ids
415
+ )
416
+
417
+ ned_start_labels = None
418
+
419
+ # named entity detection if required
420
+ if use_predefined_spans: # no need to compute spans
421
+ ned_start_logits, ned_start_probabilities, ned_start_predictions = (
422
+ None,
423
+ None,
424
+ torch.clone(start_labels)
425
+ if start_labels is not None
426
+ else torch.zeros_like(input_ids),
427
+ )
428
+ ned_end_logits, ned_end_probabilities, ned_end_predictions = (
429
+ None,
430
+ None,
431
+ torch.clone(end_labels)
432
+ if end_labels is not None
433
+ else torch.zeros_like(input_ids),
434
+ )
435
+
436
+ ned_start_predictions[ned_start_predictions > 0] = 1
437
+ ned_end_predictions[ned_end_predictions > 0] = 1
438
+
439
+ else: # compute spans
440
+ # start boundary prediction
441
+ ned_start_logits = self.ned_start_classifier(model_features)
442
+ ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask)
443
+ ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
444
+ ned_start_predictions = ned_start_probabilities.argmax(dim=-1)
445
+
446
+ # end boundary prediction
447
+ ned_start_labels = (
448
+ torch.zeros_like(start_labels) if start_labels is not None else None
449
+ )
450
+
451
+ if ned_start_labels is not None:
452
+ ned_start_labels[start_labels == -100] = -100
453
+ ned_start_labels[start_labels > 0] = 1
454
+
455
+ ned_end_logits = self.compute_ned_end_logits(
456
+ ned_start_predictions,
457
+ ned_start_labels,
458
+ model_features,
459
+ prediction_mask,
460
+ batch_size,
461
+ )
462
+
463
+ if ned_end_logits is not None:
464
+ ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
465
+ ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
466
+ else:
467
+ ned_end_logits, ned_end_probabilities = None, None
468
+ ned_end_predictions = ned_start_predictions.new_zeros(batch_size)
469
+
470
+ # flattening end predictions
471
+ # (flattening can happen only if the
472
+ # end boundaries were not predicted using the gold labels)
473
+ if not self.training and ned_end_logits is not None:
474
+ flattened_end_predictions = torch.zeros_like(ned_start_predictions)
475
+
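+ # drop spans that do not advance past the previous span's end: encode each (row, end) pair as a
+ # single global index, keep only entries that strictly exceed the running cummax, and zero out
+ # the start predictions of the discarded spans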
476
+ row_indices, start_positions = torch.where(ned_start_predictions > 0)
477
+ ned_end_predictions[ned_end_predictions<start_positions] = start_positions[ned_end_predictions<start_positions]
478
+
479
+ end_spans_repeated = (row_indices + 1)* seq_len + ned_end_predictions
480
+ cummax_values, _ = end_spans_repeated.cummax(dim=0)
481
+
482
+ end_spans_repeated = (end_spans_repeated > torch.cat((end_spans_repeated[:1], cummax_values[:-1])))
483
+ end_spans_repeated[0] = True
484
+
485
+ ned_start_predictions[row_indices[~end_spans_repeated], start_positions[~end_spans_repeated]] = 0
486
+
487
+ row_indices, start_positions, ned_end_predictions = row_indices[end_spans_repeated], start_positions[end_spans_repeated], ned_end_predictions[end_spans_repeated]
488
+
489
+ flattened_end_predictions[row_indices, ned_end_predictions] = 1
490
+
491
+ total_start_predictions, total_end_predictions = ned_start_predictions.sum(), flattened_end_predictions.sum()
492
+
493
+ assert (
494
+ total_start_predictions == 0
495
+ or total_start_predictions == total_end_predictions
496
+ ), (
497
+ f"Total number of start predictions = {total_start_predictions}. "
498
+ f"Total number of end predictions = {total_end_predictions}"
499
+ )
500
+ ned_end_predictions = flattened_end_predictions
501
+ else:
502
+ ned_end_predictions = torch.zeros_like(ned_start_predictions)
503
+
504
+ start_position, end_position = (
505
+ (start_labels, end_labels)
506
+ if self.training
507
+ else (ned_start_predictions, ned_end_predictions)
508
+ )
509
+
510
+ # Entity disambiguation
511
+ ed_logits = self.compute_classification_logits(
512
+ model_features,
513
+ special_symbols_mask,
514
+ prediction_mask,
515
+ batch_size,
516
+ start_position,
517
+ end_position,
518
+ attention_mask,
519
+ )
520
+ ed_probabilities = torch.softmax(ed_logits, dim=-1)
521
+ ed_predictions = torch.argmax(ed_probabilities, dim=-1)
522
+
523
+ # output build
524
+ output_dict = dict(
525
+ batch_size=batch_size,
526
+ ned_start_logits=ned_start_logits,
527
+ ned_start_probabilities=ned_start_probabilities,
528
+ ned_start_predictions=ned_start_predictions,
529
+ ned_end_logits=ned_end_logits,
530
+ ned_end_probabilities=ned_end_probabilities,
531
+ ned_end_predictions=ned_end_predictions,
532
+ ed_logits=ed_logits,
533
+ ed_probabilities=ed_probabilities,
534
+ ed_predictions=ed_predictions,
535
+ )
536
+
537
+ # compute loss if labels
538
+ if start_labels is not None and end_labels is not None and self.training:
539
+ # named entity detection loss
540
+
541
+ # start
542
+ if ned_start_logits is not None:
543
+ ned_start_loss = self.criterion(
544
+ ned_start_logits.view(-1, ned_start_logits.shape[-1]),
545
+ ned_start_labels.view(-1),
546
+ )
547
+ else:
548
+ ned_start_loss = 0
549
+
550
+ # end
551
+ if ned_end_logits is not None:
552
+ ned_end_labels = torch.zeros_like(end_labels)
553
+ ned_end_labels[end_labels == -100] = -100
554
+ ned_end_labels[end_labels > 0] = 1
555
+
556
+ ned_end_loss = self.criterion(
557
+ ned_end_logits,
558
+ (
559
+ torch.arange(
560
+ ned_end_labels.size(1), device=ned_end_labels.device
561
+ )
562
+ .unsqueeze(0)
563
+ .expand(batch_size, -1)[ned_end_labels > 0]
564
+ ).to(ned_end_labels.device),
565
+ )
566
+
567
+ else:
568
+ ned_end_loss = 0
569
+
570
+ # entity disambiguation loss
571
+ start_labels[ned_start_labels != 1] = -100
572
+ ed_labels = torch.clone(start_labels)
573
+ ed_labels[end_labels > 0] = end_labels[end_labels > 0]
574
+ ed_loss = self.criterion(
575
+ ed_logits.view(-1, ed_logits.shape[-1]),
576
+ ed_labels.view(-1),
577
+ )
578
+
579
+ output_dict["ned_start_loss"] = ned_start_loss
580
+ output_dict["ned_end_loss"] = ned_end_loss
581
+ output_dict["ed_loss"] = ed_loss
582
+
583
+ output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss
584
+
585
+ return output_dict
586
+
587
+
588
+ class RelikReaderREModel(PreTrainedModel):
589
+ config_class = RelikReaderConfig
590
+
591
+ def __init__(self, config, *args, **kwargs):
592
+ super().__init__(config)
593
+ # Transformer model declaration
594
+ # self.transformer_model_name = transformer_model
595
+ self.config = config
596
+ self.transformer_model = (
597
+ AutoModel.from_pretrained(config.transformer_model)
598
+ if config.num_layers is None
599
+ else AutoModel.from_pretrained(
600
+ config.transformer_model, num_hidden_layers=config.num_layers
601
+ )
602
+ )
603
+ self.transformer_model.resize_token_embeddings(
604
+ self.transformer_model.config.vocab_size
605
+ + config.additional_special_symbols
606
+ + config.additional_special_symbols_types,
607
+ pad_to_multiple_of=8,
608
+ )
609
+
610
+ # named entity detection layers
611
+ self.ned_start_classifier = self._get_projection_layer(
612
+ config.activation, last_hidden=2, layer_norm=False
613
+ )
614
+
615
+ self.ned_end_classifier = PoolerEndLogitsBi(self.transformer_model.config)
616
+
617
+ self.relation_disambiguation_loss = (
618
+ config.relation_disambiguation_loss
619
+ if hasattr(config, "relation_disambiguation_loss")
620
+ else False
621
+ )
622
+
623
+ if self.config.entity_type_loss and self.config.add_entity_embedding:
624
+ input_hidden_ents = 3 * self.config.linears_hidden_size
625
+ else:
626
+ input_hidden_ents = 2 * self.config.linears_hidden_size
627
+
628
+ self.re_projector = self._get_projection_layer(
629
+ config.activation, input_hidden=2*self.transformer_model.config.hidden_size, hidden=input_hidden_ents, last_hidden=2*self.config.linears_hidden_size
630
+ )
631
+
632
+ self.re_relation_projector = self._get_projection_layer(
633
+ config.activation, input_hidden=self.transformer_model.config.hidden_size,
634
+ )
635
+
636
+ if self.config.entity_type_loss or self.relation_disambiguation_loss:
637
+ self.re_entities_projector = self._get_projection_layer(
638
+ config.activation,
639
+ input_hidden=2 * self.transformer_model.config.hidden_size,
640
+ )
641
+ self.re_definition_projector = self._get_projection_layer(
642
+ config.activation,
643
+ )
644
+
645
+ self.re_classifier = self._get_projection_layer(
646
+ config.activation,
647
+ input_hidden=config.linears_hidden_size,
648
+ last_hidden=2,
649
+ layer_norm=False,
650
+ )
651
+
652
+ self.training = config.training
653
+
654
+ # criterion
655
+ self.criterion = torch.nn.CrossEntropyLoss()
656
+ self.criterion_type = torch.nn.BCEWithLogitsLoss()
657
+
658
+ def _get_projection_layer(
659
+ self,
660
+ activation: str,
661
+ last_hidden: Optional[int] = None,
662
+ hidden: Optional[int] = None,
663
+ input_hidden=None,
664
+ layer_norm: bool = True,
665
+ ) -> torch.nn.Sequential:
666
+ head_components = [
667
+ torch.nn.Dropout(0.1),
668
+ torch.nn.Linear(
669
+ (
670
+ self.transformer_model.config.hidden_size * self.config.use_last_k_layers
671
+ if input_hidden is None
672
+ else input_hidden
673
+ ),
674
+ self.config.linears_hidden_size if hidden is None else hidden,
675
+ ),
676
+ activation2functions[activation],
677
+ torch.nn.Dropout(0.1),
678
+ torch.nn.Linear(
679
+ self.config.linears_hidden_size if hidden is None else hidden,
680
+ self.config.linears_hidden_size if last_hidden is None else last_hidden,
681
+ ),
682
+ ]
683
+
684
+ if layer_norm:
685
+ head_components.append(
686
+ torch.nn.LayerNorm(
687
+ self.config.linears_hidden_size if last_hidden is None else last_hidden,
688
+ self.transformer_model.config.layer_norm_eps,
689
+ )
690
+ )
691
+
692
+ return torch.nn.Sequential(*head_components)
693
+
694
+ def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
695
+ mask = mask.unsqueeze(-1)
696
+ if next(self.parameters()).dtype == torch.float16:
697
+ logits = logits * (1 - mask) - 65500 * mask
698
+ else:
699
+ logits = logits * (1 - mask) - 1e30 * mask
700
+ return logits
701
+
702
+ def _get_model_features(
703
+ self,
704
+ input_ids: torch.Tensor,
705
+ attention_mask: torch.Tensor,
706
+ token_type_ids: Optional[torch.Tensor],
707
+ ):
708
+ model_input = {
709
+ "input_ids": input_ids,
710
+ "attention_mask": attention_mask,
711
+ "output_hidden_states": self.config.use_last_k_layers > 1,
712
+ }
713
+
714
+ if token_type_ids is not None:
715
+ model_input["token_type_ids"] = token_type_ids
716
+
717
+ model_output = self.transformer_model(**model_input)
718
+
719
+ if self.config.use_last_k_layers > 1:
720
+ model_features = torch.cat(
721
+ model_output[1][-self.config.use_last_k_layers :], dim=-1
722
+ )
723
+ else:
724
+ model_features = model_output[0]
725
+
726
+ return model_features
727
+
728
+ def compute_ned_end_logits(
729
+ self,
730
+ start_predictions,
731
+ start_labels,
732
+ model_features,
733
+ prediction_mask,
734
+ batch_size,
735
+ mask_preceding: bool = False,
736
+ ) -> Optional[torch.Tensor]:
737
+ # todo: maybe when constraining on the spans,
738
+ # we should not use a prediction_mask for the end tokens.
739
+ # at least we should not during training imo
740
+ start_positions = start_labels if self.training else start_predictions
741
+ start_positions_indices = (
742
+ torch.arange(start_positions.size(1), device=start_positions.device)
743
+ .unsqueeze(0)
744
+ .expand(batch_size, -1)[start_positions > 0]
745
+ ).to(start_positions.device)
746
+
747
+ if len(start_positions_indices) > 0:
748
+ expanded_features = model_features.repeat_interleave(
749
+ torch.sum(start_positions > 0, dim=-1), dim=0
750
+ )
751
+ expanded_prediction_mask = prediction_mask.repeat_interleave(
752
+ torch.sum(start_positions > 0, dim=-1), dim=0
753
+ )
754
+ if mask_preceding:
755
+ expanded_prediction_mask[
756
+ torch.arange(
757
+ expanded_prediction_mask.shape[1],
758
+ device=expanded_prediction_mask.device,
759
+ )
760
+ < start_positions_indices.unsqueeze(1)
761
+ ] = 1
762
+ end_logits = self.ned_end_classifier(
763
+ hidden_states=expanded_features,
764
+ start_positions=start_positions_indices,
765
+ p_mask=expanded_prediction_mask,
766
+ )
767
+
768
+ return end_logits
769
+
770
+ return None
771
+
772
+ def compute_relation_logits(
773
+ self,
774
+ model_entity_features,
775
+ special_symbols_features,
776
+ ) -> torch.Tensor:
777
+ model_subject_object_features = self.re_projector(model_entity_features)
778
+ model_subject_features = model_subject_object_features[
779
+ :, :, : model_subject_object_features.shape[-1] // 2
780
+ ]
781
+ model_object_features = model_subject_object_features[
782
+ :, :, model_subject_object_features.shape[-1] // 2 :
783
+ ]
784
+ special_symbols_start_representation = self.re_relation_projector(
785
+ special_symbols_features
786
+ )
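+ # "bse,bde,bfe->bsdfe": broadcasted element-wise product of subject, object and relation-symbol
+ # features, one hidden vector per (subject, object, relation) triple, which re_classifier maps to 2 logits each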
787
+ re_logits = torch.einsum(
788
+ "bse,bde,bfe->bsdfe",
789
+ model_subject_features,
790
+ model_object_features,
791
+ special_symbols_start_representation,
792
+ )
793
+ re_logits = self.re_classifier(re_logits)
794
+
795
+ return re_logits
796
+
797
+ def compute_entity_logits(
798
+ self,
799
+ model_entity_features,
800
+ special_symbols_features,
801
+ ) -> torch.Tensor:
802
+ model_ed_features = self.re_entities_projector(model_entity_features)
803
+ special_symbols_ed_representation = self.re_definition_projector(
804
+ special_symbols_features
805
+ )
806
+
807
+ logits = torch.bmm(
808
+ model_ed_features,
809
+ torch.permute(special_symbols_ed_representation, (0, 2, 1)),
810
+ )
811
+ logits = self._mask_logits(
812
+ logits, (model_entity_features == -100).all(2).long()
813
+ )
814
+ return logits
815
+
816
+ def compute_loss(self, logits, labels, mask=None):
817
+ logits = logits.reshape(-1, logits.shape[-1])
818
+ labels = labels.reshape(-1).long()
819
+ if mask is not None:
820
+ return self.criterion(logits[mask], labels[mask])
821
+ return self.criterion(logits, labels)
822
+
823
+ def compute_ned_type_loss(
824
+ self,
825
+ disambiguation_labels,
826
+ re_ned_entities_logits,
827
+ ned_type_logits,
828
+ re_entities_logits,
829
+ entity_types,
830
+ mask,
831
+ ):
832
+ if self.config.entity_type_loss and self.relation_disambiguation_loss:
833
+ return self.criterion_type(
834
+ re_ned_entities_logits[disambiguation_labels != -100],
835
+ disambiguation_labels[disambiguation_labels != -100],
836
+ )
837
+ if self.config.entity_type_loss:
838
+ return self.criterion_type(
839
+ ned_type_logits[mask],
840
+ disambiguation_labels[:, :, :entity_types][mask],
841
+ )
842
+
843
+ if self.relation_disambiguation_loss:
844
+ return self.criterion_type(
845
+ re_entities_logits[disambiguation_labels != -100],
846
+ disambiguation_labels[disambiguation_labels != -100],
847
+ )
848
+ return 0
849
+
850
+ def compute_relation_loss(self, relation_labels, re_logits):
851
+ return self.compute_loss(
852
+ re_logits, relation_labels, relation_labels.view(-1) != -100
853
+ )
854
+
855
+ def forward(
856
+ self,
857
+ input_ids: torch.Tensor,
858
+ attention_mask: torch.Tensor,
859
+ token_type_ids: torch.Tensor,
860
+ prediction_mask: Optional[torch.Tensor] = None,
861
+ special_symbols_mask: Optional[torch.Tensor] = None,
862
+ special_symbols_mask_entities: Optional[torch.Tensor] = None,
863
+ start_labels: Optional[torch.Tensor] = None,
864
+ end_labels: Optional[torch.Tensor] = None,
865
+ disambiguation_labels: Optional[torch.Tensor] = None,
866
+ relation_labels: Optional[torch.Tensor] = None,
867
+ relation_threshold: Optional[float] = None,
868
+ is_validation: bool = False,
869
+ is_prediction: bool = False,
870
+ use_predefined_spans: bool = False,
871
+ *args,
872
+ **kwargs,
873
+ ) -> Dict[str, Any]:
874
+
875
+ relation_threshold = self.config.threshold if relation_threshold is None else relation_threshold
876
+
877
+ batch_size = input_ids.shape[0]
878
+
879
+ model_features = self._get_model_features(
880
+ input_ids, attention_mask, token_type_ids
881
+ )
882
+
883
+ # named entity detection
884
+ if use_predefined_spans:
885
+ ned_start_logits, ned_start_probabilities, ned_start_predictions = (
886
+ None,
887
+ None,
888
+ torch.zeros_like(start_labels),
889
+ )
890
+ ned_end_logits, ned_end_probabilities, ned_end_predictions = (
891
+ None,
892
+ None,
893
+ torch.zeros_like(end_labels),
894
+ )
895
+
896
+ ned_start_predictions[start_labels > 0] = 1
897
+ ned_end_predictions[end_labels > 0] = 1
898
+ ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)]
899
+ ned_start_labels = start_labels
900
+ ned_start_labels[start_labels > 0] = 1
901
+ else:
902
+ # start boundary prediction
903
+ ned_start_logits = self.ned_start_classifier(model_features)
904
+ if is_validation or is_prediction:
905
+ ned_start_logits = self._mask_logits(
906
+ ned_start_logits, prediction_mask
907
+ ) # why?
908
+ ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
909
+ ned_start_predictions = ned_start_probabilities.argmax(dim=-1)
910
+
911
+ # end boundary prediction
912
+ ned_start_labels = (
913
+ torch.zeros_like(start_labels) if start_labels is not None else None
914
+ )
915
+
916
+ # start_labels contain entity id at their position, we just need 1 for start of entity
917
+ if ned_start_labels is not None:
918
+ ned_start_labels[start_labels == -100] = -100
919
+ ned_start_labels[start_labels > 0] = 1
920
+
921
+ # compute end logits only if there are any start predictions.
922
+ # For each start prediction, n end predictions are made
923
+ ned_end_logits = self.compute_ned_end_logits(
924
+ ned_start_predictions,
925
+ ned_start_labels,
926
+ model_features,
927
+ prediction_mask,
928
+ batch_size,
929
+ True,
930
+ )
931
+
932
+ if ned_end_logits is not None:
933
+ # For each start prediction, n end predictions are made based on
934
+ # binary classification ie. argmax at each position.
935
+ ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
936
+ ned_end_predictions = ned_end_probabilities.argmax(dim=-1)
937
+ else:
938
+ ned_end_logits, ned_end_probabilities = None, None
939
+ ned_end_predictions = torch.zeros_like(ned_start_predictions)
940
+
941
+ if is_prediction or is_validation:
942
+ end_preds_count = ned_end_predictions.sum(1)
943
+ # If there are no end predictions for a start prediction, remove the start prediction
944
+ if (end_preds_count == 0).any() and (ned_start_predictions > 0).any():
945
+ ned_start_predictions[ned_start_predictions == 1] = (
946
+ end_preds_count != 0
947
+ ).long()
948
+ ned_end_predictions = ned_end_predictions[end_preds_count != 0]
949
+
950
+ if end_labels is not None:
951
+ end_labels = end_labels[~(end_labels == -100).all(2)]
952
+
953
+ start_position, end_position = (
954
+ (start_labels, end_labels)
955
+ if (not is_prediction and not is_validation)
956
+ else (ned_start_predictions, ned_end_predictions)
957
+ )
958
+
959
+ start_counts = (start_position > 0).sum(1)
960
+ if (start_counts > 0).any():
961
+ ned_end_predictions = ned_end_predictions.split(start_counts.tolist())
962
+ # (optional) limit to 30 predictions per document using start_counts, by zeroing every start position after the cumulative count reaches 30
963
+ # if is_validation or is_prediction:
964
+ # ned_start_predictions[ned_start_predictions == 1] = start_counts
965
+ # We can only predict relations if we have start and end predictions
966
+ if (end_position > 0).sum() > 0:
967
+ ends_count = (end_position > 0).sum(1)
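+ # represent each candidate span by concatenating its start-token and end-token features, then pad
+ # the per-document span lists into a [batch, max_spans, 2*hidden] tensor (padding value -100)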
968
+ model_subject_features = torch.cat(
969
+ [
970
+ torch.repeat_interleave(
971
+ model_features[start_position > 0], ends_count, dim=0
972
+ ), # start position features
973
+ torch.repeat_interleave(model_features, start_counts, dim=0)[
974
+ end_position > 0
975
+ ], # end position features
976
+ ],
977
+ dim=-1,
978
+ )
979
+ ents_count = torch.nn.utils.rnn.pad_sequence(
980
+ torch.split(ends_count, start_counts.tolist()),
981
+ batch_first=True,
982
+ padding_value=0,
983
+ ).sum(1)
984
+ model_subject_features = torch.nn.utils.rnn.pad_sequence(
985
+ torch.split(model_subject_features, ents_count.tolist()),
986
+ batch_first=True,
987
+ padding_value=-100,
988
+ )
989
+
990
+ # if is_validation or is_prediction:
991
+ # model_subject_features = model_subject_features[:, :30, :]
992
+
993
+ # entity disambiguation. Here relation_disambiguation_loss would only be useful to
994
+ # reduce the number of candidate relations for the next step, but currently unused.
995
+ if self.config.entity_type_loss or self.relation_disambiguation_loss:
996
+ (re_ned_entities_logits) = self.compute_entity_logits(
997
+ model_subject_features,
998
+ model_features[
999
+ special_symbols_mask | special_symbols_mask_entities
1000
+ ].view(batch_size, -1, model_features.shape[-1]),
1001
+ )
1002
+ entity_types = torch.sum(special_symbols_mask_entities, dim=1)[0].item()
1003
+ ned_type_logits = re_ned_entities_logits[:, :, :entity_types]
1004
+ re_entities_logits = re_ned_entities_logits[:, :, entity_types:]
1005
+
1006
+ if self.config.entity_type_loss:
1007
+ ned_type_probabilities = torch.sigmoid(ned_type_logits)
1008
+ ned_type_predictions = ned_type_probabilities.argmax(dim=-1)
1009
+
1010
+ if self.config.add_entity_embedding:
1011
+ special_symbols_representation = model_features[
1012
+ special_symbols_mask_entities
1013
+ ].view(batch_size, entity_types, -1)
1014
+
1015
+ entities_representation = torch.einsum(
1016
+ "bsp,bpe->bse",
1017
+ ned_type_probabilities,
1018
+ special_symbols_representation,
1019
+ )
1020
+ model_subject_features = torch.cat(
1021
+ [model_subject_features, entities_representation], dim=-1
1022
+ )
1023
+ re_entities_probabilities = torch.sigmoid(re_entities_logits)
1024
+ re_entities_predictions = re_entities_probabilities.round()
1025
+ else:
1026
+ (
1027
+ ned_type_logits,
1028
+ ned_type_probabilities,
1029
+ re_entities_logits,
1030
+ re_entities_probabilities,
1031
+ ) = (None, None, None, None)
1032
+ ned_type_predictions, re_entities_predictions = (
1033
+ torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
1034
+ torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
1035
+ )
1036
+
1037
+ # Compute relation logits
1038
+ re_logits = self.compute_relation_logits(
1039
+ model_subject_features,
1040
+ model_features[special_symbols_mask].view(
1041
+ batch_size, -1, model_features.shape[-1]
1042
+ ),
1043
+ )
1044
+
1045
+ re_probabilities = torch.softmax(re_logits, dim=-1)
1046
+ # we use a threshold instead of argmax in case it needs to be tweaked
1047
+ re_predictions = re_probabilities[:, :, :, :, 1] > relation_threshold
1048
+ re_probabilities = re_probabilities[:, :, :, :, 1]
1049
+
1050
+ else:
1051
+ (
1052
+ ned_type_logits,
1053
+ ned_type_probabilities,
1054
+ re_entities_logits,
1055
+ re_entities_probabilities,
1056
+ ) = (None, None, None, None)
1057
+ ned_type_predictions, re_entities_predictions = (
1058
+ torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
1059
+ torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
1060
+ )
1061
+ re_logits, re_probabilities, re_predictions = (
1062
+ torch.zeros(
1063
+ [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
1064
+ ).to(input_ids.device),
1065
+ torch.zeros(
1066
+ [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
1067
+ ).to(input_ids.device),
1068
+ torch.zeros(
1069
+ [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
1070
+ ).to(input_ids.device),
1071
+ )
1072
+
1073
+ # output build
1074
+ output_dict = dict(
1075
+ batch_size=batch_size,
1076
+ ned_start_logits=ned_start_logits,
1077
+ ned_start_probabilities=ned_start_probabilities,
1078
+ ned_start_predictions=ned_start_predictions,
1079
+ ned_end_logits=ned_end_logits,
1080
+ ned_end_probabilities=ned_end_probabilities,
1081
+ ned_end_predictions=ned_end_predictions,
1082
+ ned_type_logits=ned_type_logits,
1083
+ ned_type_probabilities=ned_type_probabilities,
1084
+ ned_type_predictions=ned_type_predictions,
1085
+ re_entities_logits=re_entities_logits,
1086
+ re_entities_probabilities=re_entities_probabilities,
1087
+ re_entities_predictions=re_entities_predictions,
1088
+ re_logits=re_logits,
1089
+ re_probabilities=re_probabilities,
1090
+ re_predictions=re_predictions,
1091
+ )
1092
+
1093
+ if (
1094
+ start_labels is not None
1095
+ and end_labels is not None
1096
+ and relation_labels is not None
1097
+ and is_prediction is False
1098
+ ):
1099
+ ned_start_loss = self.compute_loss(ned_start_logits, ned_start_labels)
1100
+ end_labels[end_labels > 0] = 1
1101
+ ned_end_loss = self.compute_loss(ned_end_logits, end_labels)
1102
+ if self.config.entity_type_loss or self.relation_disambiguation_loss:
1103
+ ned_type_loss = self.compute_ned_type_loss(
1104
+ disambiguation_labels,
1105
+ re_ned_entities_logits,
1106
+ ned_type_logits,
1107
+ re_entities_logits,
1108
+ entity_types,
1109
+ (model_subject_features != -100).all(2),
1110
+ )
1111
+ relation_loss = self.compute_relation_loss(relation_labels, re_logits)
1112
+ # compute loss. We can skip the relation loss if we are in the first epochs (optional)
1113
+ if self.config.entity_type_loss or self.relation_disambiguation_loss:
1114
+ output_dict["loss"] = (
1115
+ ned_start_loss + ned_end_loss + relation_loss + ned_type_loss
1116
+ ) / 4
1117
+ output_dict["ned_type_loss"] = ned_type_loss
1118
+ else:
1119
+ # output_dict["loss"] = ((1 / 4) * (ned_start_loss + ned_end_loss)) + (
1120
+ # (1 / 2) * relation_loss
1121
+ # )
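+ # weight the boundary (start/end) losses much lower than the relation loss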
1122
+ output_dict["loss"] = ((1 / 20) * (ned_start_loss + ned_end_loss)) + (
1123
+ (9 / 10) * relation_loss
1124
+ )
1125
+
1126
+ output_dict["ned_start_loss"] = ned_start_loss
1127
+ output_dict["ned_end_loss"] = ned_end_loss
1128
+ output_dict["re_loss"] = relation_loss
1129
+
1130
+ return output_dict
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dceb2b0ba198d2bb49e448cd48906dae7a34b00d6fea56c914820bdc895b03bd
+ size 1763625485