shramay-palta commited on
Commit
9ca2cbf
·
verified ·
1 Parent(s): 7e287fa

Delete demo_t5_qa_pipe.py

Browse files
Files changed (1) hide show
  1. demo_t5_qa_pipe.py +0 -39
demo_t5_qa_pipe.py DELETED
@@ -1,39 +0,0 @@
1
- import torch
2
- import tensorflow as tf
3
- import numpy as np
4
- from transformers import Text2TextGenerationPipeline
5
-
6
- class DemoT5QAPipeline(Text2TextGenerationPipeline):
7
- def _forward(self, model_inputs, **generate_kwargs):
8
- if self.framework == "pt":
9
- in_b, input_length = model_inputs["input_ids"].shape
10
- elif self.framework == "tf":
11
- in_b, input_length = tf.shape(model_inputs["input_ids"]).numpy()
12
-
13
- self.check_inputs(
14
- input_length,
15
- generate_kwargs.get("min_length", self.model.config.min_length),
16
- generate_kwargs.get("max_length", self.model.config.max_length),
17
- )
18
- outputs = self.model.generate(**model_inputs, **generate_kwargs, return_dict_in_generate=True, output_scores=True, max_new_tokens=75)
19
-
20
- # Code from the parent class
21
- output_ids = outputs.sequences
22
- out_b = output_ids.shape[0]
23
- if self.framework == "pt":
24
- output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
25
- elif self.framework == "tf":
26
- output_ids = tf.reshape(output_ids, (in_b, out_b // in_b, *output_ids.shape[1:]))
27
-
28
- output_sequences = outputs.sequences
29
- output_scores = outputs.scores
30
- return {"output_ids": output_ids, "output_sequences": output_sequences, "output_scores": output_scores}
31
-
32
- def postprocess(self, model_outputs):
33
- guess_text = super().postprocess(model_outputs)[0]['generated_text']
34
-
35
- transition_scores = self.model.compute_transition_scores(model_outputs['output_sequences'], model_outputs['output_scores'], normalize_logits=True)
36
- log_probs = np.round(np.exp(transition_scores.cpu().numpy()), 3)[0]
37
- guess_prob = np.product(log_probs)
38
-
39
- return {'guess': guess_text, 'confidence': guess_prob}