File size: 454 Bytes
54e216c
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
from torch import nn

from transformers import XLMRobertaModel


class XLMRobertaForSequenceClassification(XLMRobertaModel):
    """XLM-RoBERTa encoder with a single linear classification head.

    The class extends ``XLMRobertaModel`` directly, so the encoder's modules
    live on ``self`` and ``self.classifier`` is the only added layer. It maps
    the encoder's pooled output to ``num_labels`` logits.

    Args:
        config: Encoder configuration; ``config.hidden_size`` determines the
            classifier's input width.
        num_labels: Number of output logits. Defaults to 2, matching the
            original binary head.
    """

    def __init__(self, config, num_labels=2):
        super().__init__(config)
        # Size the head from the config instead of hard-coding 768, so configs
        # with a different hidden size do not hit a shape mismatch in forward().
        self.classifier = nn.Linear(config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None):
        """Return classification logits for a batch of token ids.

        Args:
            input_ids: Token-id tensor, forwarded unchanged to the encoder.
            attention_mask: Optional mask, forwarded unchanged to the encoder.
                When ``None``, the encoder uses its own default mask.

        Returns:
            ``self.classifier`` applied to element 1 of the encoder outputs
            (the pooled output). Presumably shaped (batch, num_labels);
            confirm against callers.
        """
        outputs = super().forward(
            input_ids=input_ids, attention_mask=attention_mask
        )
        # Index 1 is the pooled output for both tuple and ModelOutput returns.
        return self.classifier(outputs[1])