shramay-palta committed on
Commit
b5dc66e
·
verified ·
1 Parent(s): 793dccd

Upload DemoT5QAPipeline

Browse files
Files changed (2) hide show
  1. config.json +1 -1
  2. demo_t5_qa_pipe.py +2 -2
config.json CHANGED
@@ -6,7 +6,7 @@
6
  "classifier_dropout": 0.0,
7
  "custom_pipelines": {
8
  "demo-t5-small-qa": {
9
- "impl": "__main__.DemoT5QAPipeline",
10
  "pt": [
11
  "AutoModelForSeq2SeqLM"
12
  ],
 
6
  "classifier_dropout": 0.0,
7
  "custom_pipelines": {
8
  "demo-t5-small-qa": {
9
+ "impl": "demo_t5_qa_pipe.DemoT5QAPipeline",
10
  "pt": [
11
  "AutoModelForSeq2SeqLM"
12
  ],
demo_t5_qa_pipe.py CHANGED
@@ -16,7 +16,7 @@ class DemoT5QAPipeline(Text2TextGenerationPipeline):
16
  generate_kwargs.get("max_length", self.model.config.max_length),
17
  )
18
  outputs = self.model.generate(**model_inputs, **generate_kwargs, return_dict_in_generate=True, output_scores=True, max_new_tokens=75)
19
-
20
  # Code from the parent class
21
  output_ids = outputs.sequences
22
  out_b = output_ids.shape[0]
@@ -28,7 +28,7 @@ class DemoT5QAPipeline(Text2TextGenerationPipeline):
28
  output_sequences = outputs.sequences
29
  output_scores = outputs.scores
30
  return {"output_ids": output_ids, "output_sequences": output_sequences, "output_scores": output_scores}
31
-
32
  def postprocess(self, model_outputs):
33
  guess_text = super().postprocess(model_outputs)[0]['generated_text']
34
 
 
16
  generate_kwargs.get("max_length", self.model.config.max_length),
17
  )
18
  outputs = self.model.generate(**model_inputs, **generate_kwargs, return_dict_in_generate=True, output_scores=True, max_new_tokens=75)
19
+
20
  # Code from the parent class
21
  output_ids = outputs.sequences
22
  out_b = output_ids.shape[0]
 
28
  output_sequences = outputs.sequences
29
  output_scores = outputs.scores
30
  return {"output_ids": output_ids, "output_sequences": output_sequences, "output_scores": output_scores}
31
+
32
  def postprocess(self, model_outputs):
33
  guess_text = super().postprocess(model_outputs)[0]['generated_text']
34