PhantHive commited on
Commit
c3fe735
1 Parent(s): 03994f7

Delete gen.py

Browse files
Files changed (1) hide show
  1. gen.py +0 -33
gen.py DELETED
@@ -1,33 +0,0 @@
1
- import numpy as np
2
-
3
- from transformers import Pipeline
4
-
5
-
6
- def softmax(outputs):
7
- maxes = np.max(outputs, axis=-1, keepdims=True)
8
- shifted_exp = np.exp(outputs - maxes)
9
- return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
10
-
11
-
12
- class PairClassificationPipeline(Pipeline):
13
- def _sanitize_parameters(self, **kwargs):
14
- preprocess_kwargs = {}
15
- if "second_text" in kwargs:
16
- preprocess_kwargs["second_text"] = kwargs["second_text"]
17
- return preprocess_kwargs, {}, {}
18
-
19
- def preprocess(self, text, second_text=None):
20
- return self.tokenizer(text, text_pair=second_text, return_tensors=self.framework)
21
-
22
- def _forward(self, model_inputs):
23
- return self.model(**model_inputs)
24
-
25
- def postprocess(self, model_outputs):
26
- logits = model_outputs.logits[0].numpy()
27
- probabilities = softmax(logits)
28
-
29
- best_class = np.argmax(probabilities)
30
- label = self.model.config.id2label[best_class]
31
- score = probabilities[best_class].item()
32
- logits = logits.tolist()
33
- return {"label": label, "score": score, "logits": logits}