import torch
import torch.utils.checkpoint
from torch import nn

from transformers.models.gpt2.modeling_gpt2 import GPT2PreTrainedModel

from .configuration_backpack_gpt2_nli import BackpackGPT2NLIConfig
from .modeling_backpack_gpt2 import BackpackGPT2Model


class BackpackGPT2NLIModel(GPT2PreTrainedModel):
    config_class = BackpackGPT2NLIConfig

    def __init__(self, config):
        super().__init__(config)

        self.backpack = BackpackGPT2Model(config)
        self.n_embd = config.n_embd
        self.num_labels = config.num_labels

        # Classification head applied to the hidden state of the last
        # non-padding token: Linear -> Dropout -> Linear over the NLI labels.
        self.nli_head = nn.Sequential(
            nn.Linear(self.n_embd, self.n_embd),
            nn.Dropout(0.1),
            nn.Linear(self.n_embd, self.num_labels),
        )

        # Optionally freeze the Backpack encoder so that only the NLI head is trained.
        self.backpack.requires_grad_(not config.freeze_backpack)

        self.loss_func = nn.CrossEntropyLoss()

        self.model_parallel = False

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        backpack_outputs = self.backpack(input_ids=input_ids, position_ids=None)
        # The Backpack also returns a contextualization tensor, but only the
        # hidden states are needed for classification.
        backpack_hidden_states = backpack_outputs.hidden_states

        # Index of the last non-padding token in each sequence, i.e. the last
        # position where attention_mask == 1.
        last_toks_indices = attention_mask.shape[1] - 1 - attention_mask.flip((1,)).argmax(dim=1)
        # Gather the hidden state at that position for every example in the batch.
        last_backpack_hidden_states = backpack_hidden_states[
            torch.arange(backpack_hidden_states.shape[0]), last_toks_indices, :
        ]

        logits = self.nli_head(last_backpack_hidden_states)

        if labels is not None:
            # logits: (batch, num_labels); labels are flattened to shape (batch,).
            loss = self.loss_func(logits, labels.view(-1))
            return {'logits': logits, 'loss': loss}
        else:
            return {'logits': logits}

    def predict(self, input_ids=None, attention_mask=None):
        logits = self.forward(input_ids, attention_mask, labels=None)['logits']
        # Map the highest-scoring class index to its string label via the config.
        predictions = torch.argmax(logits, dim=1)
        labels = [self.config.id2label[index.item()] for index in predictions]
        return {'labels': labels, 'logits': logits}
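
# ------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The checkpoint
# path, tokenizer, and input text below are illustrative assumptions only;
# an actual Backpack NLI checkpoint with a matching tokenizer is required.
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   model = BackpackGPT2NLIModel.from_pretrained("path/to/backpack-gpt2-nli")
#
#   inputs = tokenizer(
#       "A man is playing guitar. A person is making music.",
#       return_tensors="pt",
#   )
#   out = model.predict(inputs["input_ids"], inputs["attention_mask"])
#   print(out["labels"])  # e.g. ["entailment"], depending on config.id2label
# ------------------------------------------------------------------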