ZetangForward commited on
Commit
2503e87
·
verified ·
1 Parent(s): eb143e0

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. config.json +12 -1
  2. spancnn_pipeline.py +24 -0
config.json CHANGED
@@ -31,5 +31,16 @@
31
  "transformers_version": "4.28.1",
32
  "type_vocab_size": 2,
33
  "use_cache": true,
34
- "vocab_size": 30522
 
 
 
 
 
 
 
 
 
 
 
35
  }
 
31
  "transformers_version": "4.28.1",
32
  "type_vocab_size": 2,
33
  "use_cache": true,
34
+ "vocab_size": 30522,
35
+ "custom_pipelines": {
36
+ "spancnn-classification": {
37
+ "impl": "spancnn_pipeline.SpanClassificationPipeline",
38
+ "pt": [
39
+ "AutoModelForSequenceClassification"
40
+ ],
41
+ "tf": [
42
+ "TFAutoModelForSequenceClassification"
43
+ ]
44
+ }
45
+ }
46
  }
spancnn_pipeline.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline, Pipeline, AutoTokenizer, AutoConfig, AutoModelForSequenceClassification
2
+ from transformers.pipelines import PIPELINE_REGISTRY
3
+ import torch
4
+
5
+ class SpanClassificationPipeline(Pipeline):
6
+ def __init__(self, model, tokenizer, device="cpu", **kwargs):
7
+ super().__init__(model=model, tokenizer=tokenizer, device=device, **kwargs)
8
+ self.model.to(self.device)
9
+ self.model.eval()
10
+
11
+ def _sanitize_parameters(self, **kwargs):
12
+ return {}, kwargs, {}
13
+
14
+ def preprocess(self, inputs):
15
+ return self.tokenizer(inputs, return_tensors="pt").to(self.device)
16
+
17
+ def _forward(self, model_inputs):
18
+ with torch.no_grad():
19
+ outputs = self.model(**model_inputs)
20
+ return outputs
21
+
22
+ def postprocess(self, model_outputs):
23
+ logits = model_outputs.logits
24
+ return int(torch.argmax(logits, dim=1).item())