Bashir Rastegarpanah commited on
Commit
ec0497f
1 Parent(s): da18d19
Files changed (1) hide show
  1. handler.py +39 -0
handler.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+
4
+
5
+ class EndpointHandler:
6
+ def __init__(self,
7
+ path=""
8
+ ):
9
+ # Preload all the elements you are going to need at inference.
10
+ # pseudo:
11
+ self.model = AutoModelForSequenceClassification.from_pretrained(path)
12
+ self.tokenizer = AutoTokenizer.from_pretrained("roberta-large", padding_side='right')
13
+
14
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
15
+
16
+ input_dict = data.pop("inputs", data)
17
+
18
+ self.model.eval()
19
+ input = self.tokenizer(input_dict['answer'],
20
+ input_dict['source'],
21
+ truncation=True,
22
+ max_length=None,
23
+ return_tensors="pt"
24
+ )
25
+ # input.to(device)
26
+ # with torch.no_grad():
27
+ # output = model(**input)
28
+ output = model(**input)
29
+
30
+ prediction = output.logits.argmax(dim=-1)
31
+
32
+ #smax = nn.Softmax(dim=1)
33
+ #score = smax(output.logits)
34
+
35
+ return [{
36
+ "label": prediction.item(),
37
+ #"score": score[0][0].item()
38
+ }
39
+ ]