dmariko commited on
Commit
e098c18
1 Parent(s): c9f2232

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -1,5 +1,5 @@
1
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
2
-
3
  import gradio as gr
4
  from torch.nn import functional as F
5
  import seaborn
@@ -60,7 +60,9 @@ def generate(model_name, text):
60
  input_ids = tokenizer.encode("AFA:{}".format(text), return_tensors="pt")
61
  outputs = model.generate(input_ids, max_length=200, num_beams=2, repetition_penalty=2.5, top_k=50, top_p=0.98, length_penalty=1.0, early_stopping=True)
62
  output = tokenizer.decode(outputs[0])
63
- return ".".join(output.split(".")[:-1]) + "."
 
 
64
 
65
 
66
  output_text = gr.outputs.Textbox()
 
1
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
2
+ import re
3
  import gradio as gr
4
  from torch.nn import functional as F
5
  import seaborn
 
60
  input_ids = tokenizer.encode("AFA:{}".format(text), return_tensors="pt")
61
  outputs = model.generate(input_ids, max_length=200, num_beams=2, repetition_penalty=2.5, top_k=50, top_p=0.98, length_penalty=1.0, early_stopping=True)
62
  output = tokenizer.decode(outputs[0])
63
+ #return ".".join(output.split(".")[:-1]) + "."
64
+ sent = ".".join(output.split(".")[:-1]) + "."
65
+ return re.match(r'<pad> ([^<>]*)', sent).group(1)
66
 
67
 
68
  output_text = gr.outputs.Textbox()