from torch import nn
from transformers import XLMRobertaModel


class XLMRobertaForSequenceClassification(XLMRobertaModel):
    def __init__(self, config):
        super().__init__(config)
        # Binary classification head on top of the pooled representation;
        # config.hidden_size avoids hard-coding 768.
        self.classifier = nn.Linear(config.hidden_size, 2)

    def forward(self, input_ids, attention_mask):
        outputs = super().forward(input_ids=input_ids, attention_mask=attention_mask)
        # pooler_output is the pooled representation of the first (<s>) token.
        return self.classifier(outputs.pooler_output)
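A minimal usage sketch, assuming the public xlm-roberta-base checkpoint and a hypothetical input sentence: from_pretrained loads the encoder weights, while the new classifier head stays randomly initialised and still needs fine-tuning.

import torch
from transformers import XLMRobertaTokenizer

# Encoder weights come from the checkpoint; the classifier layer is randomly initialised.
model = XLMRobertaForSequenceClassification.from_pretrained("xlm-roberta-base")
tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")

inputs = tokenizer("This is a test sentence.", return_tensors="pt")
with torch.no_grad():
    logits = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
print(logits.shape)  # torch.Size([1, 2])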