|
from transformers import pipeline, Pipeline, AutoTokenizer, AutoConfig, AutoModelForSequenceClassification |
|
from transformers.pipelines import PIPELINE_REGISTRY |
|
import torch |
|
|
|
|
|
class SpanClassificationPipeline(Pipeline):
    """Text classification pipeline that returns the predicted class index.

    Each input is tokenized and run through the wrapped model. The model's
    ``logits`` are reduced with ``argmax`` to a single ``int``.

    Extra keyword arguments supplied at construction or call time (for
    example ``truncation=True, max_length=128``) are forwarded to the
    tokenizer in :meth:`preprocess`.
    """

    def __init__(self, model, tokenizer, device="cpu", **kwargs):
        """Build the pipeline and place the model on ``device`` in eval mode.

        Args:
            model: Model whose outputs expose a ``.logits`` tensor.
            tokenizer: Tokenizer used to encode the inputs for ``model``.
            device: Device handed to :class:`Pipeline`; defaults to ``"cpu"``.
            **kwargs: Forwarded unchanged to :class:`Pipeline.__init__`.
        """
        super().__init__(model=model, tokenizer=tokenizer, device=device, **kwargs)
        # NOTE(review): newer Pipeline versions presumably already move the
        # model; kept explicit so placement does not depend on the version.
        self.model.to(self.device)
        # Switch layers such as dropout to inference behaviour.
        self.model.eval()

    def _sanitize_parameters(self, **kwargs):
        """Split keyword arguments into (preprocess, forward, postprocess) params.

        ``_forward`` and ``postprocess`` take no options, so every keyword is
        treated as a tokenizer option and routed to ``preprocess``. Routing
        them to ``_forward`` (the previous behaviour) raised ``TypeError`` on
        every call, because ``_forward`` accepts only ``model_inputs``.

        Returns:
            A 3-tuple of dicts: tokenizer kwargs, an empty dict, an empty dict.
        """
        return kwargs, {}, {}

    def preprocess(self, inputs, **tokenizer_kwargs):
        """Tokenize ``inputs`` into PyTorch tensors on the pipeline device.

        Args:
            inputs: Text accepted by ``self.tokenizer``.
            **tokenizer_kwargs: Extra tokenizer options such as
                ``truncation`` or ``max_length``.

        Returns:
            The tokenizer's encoding, moved to ``self.device``.
        """
        encoded = self.tokenizer(inputs, return_tensors="pt", **tokenizer_kwargs)
        return encoded.to(self.device)

    def _forward(self, model_inputs):
        """Run the model on the encoded inputs without tracking gradients."""
        with torch.no_grad():
            outputs = self.model(**model_inputs)
        return outputs

    def postprocess(self, model_outputs):
        """Return the index of the highest-scoring logit as a plain ``int``.

        NOTE(review): assumes ``model_outputs.logits`` holds a single example,
        presumably shaped (1, num_labels). ``.item()`` raises if the argmax
        result contains more than one element.
        """
        logits = model_outputs.logits
        return int(torch.argmax(logits, dim=1).item())
|
|
|
|
|
# Register the task name so that pipeline("spancnn-classification", ...)
# builds a SpanClassificationPipeline, loading PyTorch checkpoints through
# AutoModelForSequenceClassification.
PIPELINE_REGISTRY.register_pipeline(
    "spancnn-classification",
    pipeline_class=SpanClassificationPipeline,
    pt_model=AutoModelForSequenceClassification,
)
|
|