Delete gen.py
Browse files
gen.py
DELETED
@@ -1,33 +0,0 @@
|
|
1 |
-
import numpy as np
|
2 |
-
|
3 |
-
from transformers import Pipeline
|
4 |
-
|
5 |
-
|
6 |
-
def softmax(outputs):
|
7 |
-
maxes = np.max(outputs, axis=-1, keepdims=True)
|
8 |
-
shifted_exp = np.exp(outputs - maxes)
|
9 |
-
return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
|
10 |
-
|
11 |
-
|
12 |
-
class PairClassificationPipeline(Pipeline):
|
13 |
-
def _sanitize_parameters(self, **kwargs):
|
14 |
-
preprocess_kwargs = {}
|
15 |
-
if "second_text" in kwargs:
|
16 |
-
preprocess_kwargs["second_text"] = kwargs["second_text"]
|
17 |
-
return preprocess_kwargs, {}, {}
|
18 |
-
|
19 |
-
def preprocess(self, text, second_text=None):
|
20 |
-
return self.tokenizer(text, text_pair=second_text, return_tensors=self.framework)
|
21 |
-
|
22 |
-
def _forward(self, model_inputs):
|
23 |
-
return self.model(**model_inputs)
|
24 |
-
|
25 |
-
def postprocess(self, model_outputs):
|
26 |
-
logits = model_outputs.logits[0].numpy()
|
27 |
-
probabilities = softmax(logits)
|
28 |
-
|
29 |
-
best_class = np.argmax(probabilities)
|
30 |
-
label = self.model.config.id2label[best_class]
|
31 |
-
score = probabilities[best_class].item()
|
32 |
-
logits = logits.tolist()
|
33 |
-
return {"label": label, "score": score, "logits": logits}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|