from transformers import (
    pipeline,
    Pipeline,
    AutoTokenizer,
    AutoConfig,
    AutoModelForSequenceClassification,
)
from transformers.pipelines import PIPELINE_REGISTRY
import torch


class SpanClassificationPipeline(Pipeline):
    """Pipeline that returns the index of the highest-scoring class.

    Each input is tokenized to PyTorch tensors and run through the wrapped
    model without gradient tracking. The result is ``argmax`` over the
    model's ``logits``, returned as a Python ``int``.
    """

    def __init__(self, model, tokenizer, device="cpu", **kwargs):
        """Build the pipeline and place ``model`` on ``device`` in eval mode.

        Args:
            model: Model whose output exposes a ``logits`` attribute.
            tokenizer: Tokenizer callable that accepts ``return_tensors="pt"``.
            device: Target device, forwarded to ``Pipeline``.
            **kwargs: Forwarded to ``Pipeline.__init__``. Unrecognised extras
                are routed through ``_sanitize_parameters`` to the model call.
        """
        super().__init__(model=model, tokenizer=tokenizer, device=device, **kwargs)
        # Explicitly move the model to the resolved pipeline device and
        # switch it to evaluation mode for inference.
        self.model.to(self.device)
        self.model.eval()

    def _sanitize_parameters(self, **kwargs):
        """Split keyword arguments into (preprocess, forward, postprocess).

        There are no preprocess/postprocess options. Every extra keyword is
        treated as a forward parameter and passed on to the model call in
        ``_forward``.
        """
        return {}, kwargs, {}

    def preprocess(self, inputs):
        """Tokenize ``inputs`` into PyTorch tensors on the pipeline device."""
        return self.tokenizer(inputs, return_tensors="pt").to(self.device)

    def _forward(self, model_inputs, **forward_params):
        """Run the model on the tokenized inputs without gradient tracking.

        Args:
            model_inputs: Mapping produced by ``preprocess``, unpacked into
                the model call.
            **forward_params: Extra keywords routed here by
                ``_sanitize_parameters``; passed to the model call unchanged.
                Previously the base ``Pipeline`` supplied these but this
                method did not accept them, so any extra keyword raised
                ``TypeError``.

        Returns:
            The raw model output object.
        """
        with torch.no_grad():
            outputs = self.model(**model_inputs, **forward_params)
        return outputs

    def postprocess(self, model_outputs):
        """Return the predicted class index as a Python ``int``.

        ``.item()`` requires the argmax to hold exactly one element.
        NOTE(review): this assumes logits shaped (1, num_labels), i.e. one
        input per call — confirm when batching is used.
        """
        logits = model_outputs.logits
        return int(torch.argmax(logits, dim=1).item())