|
from transformers import Pipeline |
|
import torch |
|
|
|
class PairClassificationPipeline(Pipeline):
    """Classify a text, optionally paired with a second text.

    Tokenizes ``text`` (plus an optional ``text_pair``), runs the wrapped model,
    and reports the highest-probability label together with its softmax score
    and the raw logits.
    """

    def _sanitize_parameters(self, **kwargs):
        """Split call-time keyword arguments between the pipeline stages.

        Only ``text_pair`` is recognised; it is forwarded to ``preprocess``.
        ``_forward`` and ``postprocess`` receive no extra arguments.

        Returns:
            A ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)`` tuple,
            as required by the ``transformers.Pipeline`` base class.
        """
        preprocess_kwargs = {}
        if "text_pair" in kwargs:
            preprocess_kwargs["text_pair"] = kwargs["text_pair"]
        return preprocess_kwargs, {}, {}

    def preprocess(self, text, text_pair=None):
        """Tokenize ``text`` (and optionally ``text_pair``) into PyTorch tensors.

        Args:
            text: The first input sequence.
            text_pair: Optional second sequence, passed to the tokenizer as
                ``text_pair`` so both are encoded together.

        Returns:
            The tokenizer output (``return_tensors="pt"``), which ``_forward``
            unpacks as keyword arguments to the model.
        """
        return self.tokenizer(text, text_pair=text_pair, return_tensors="pt")

    def _forward(self, model_inputs):
        """Run the model on the tokenized inputs produced by ``preprocess``."""
        return self.model(**model_inputs)

    def postprocess(self, model_outputs):
        """Convert model logits into the top label, its probability and the logits.

        Args:
            model_outputs: The object returned by ``_forward``; only its
                ``logits`` attribute is read. Expected to describe a single
                example, i.e. shape ``(1, num_labels)`` or ``(num_labels,)``.

        Returns:
            A dict with ``"label"`` (from ``self.model.config.id2label``),
            ``"score"`` (softmax probability of that label, as a float) and
            ``"logits"`` (the example's logits as a list of floats).

        Raises:
            ValueError: If the logits hold more than one example.
        """
        logits = model_outputs.logits
        # Collapse leading (batch) dimensions so only the class axis remains.
        # Unlike a bare ``squeeze()``, this keeps the class axis even when the
        # model has a single label, so indexing and ``tolist()`` stay valid.
        rows = logits.reshape(-1, logits.shape[-1])
        if rows.shape[0] != 1:
            raise ValueError(
                "postprocess expects logits for a single example, "
                f"got shape {tuple(logits.shape)}"
            )
        # Upcast so the softmax and reported values are float32 even for
        # half-precision models.
        row = rows[0].float()

        probabilities = torch.nn.functional.softmax(row, dim=-1)
        best_class = int(probabilities.argmax())
        label = self.model.config.id2label[best_class]
        score = probabilities[best_class].item()

        return {"label": label, "score": score, "logits": row.tolist()}
|
|