Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
|
2 |
-
|
3 |
import gradio as gr
|
4 |
from torch.nn import functional as F
|
5 |
import seaborn
|
@@ -60,7 +60,9 @@ def generate(model_name, text):
|
|
60 |
input_ids = tokenizer.encode("AFA:{}".format(text), return_tensors="pt")
|
61 |
outputs = model.generate(input_ids, max_length=200, num_beams=2, repetition_penalty=2.5, top_k=50, top_p=0.98, length_penalty=1.0, early_stopping=True)
|
62 |
output = tokenizer.decode(outputs[0])
|
63 |
-
return ".".join(output.split(".")[:-1]) + "."
|
|
|
|
|
64 |
|
65 |
|
66 |
output_text = gr.outputs.Textbox()
|
|
|
1 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
|
2 |
+
import re
|
3 |
import gradio as gr
|
4 |
from torch.nn import functional as F
|
5 |
import seaborn
|
|
|
60 |
input_ids = tokenizer.encode("AFA:{}".format(text), return_tensors="pt")
|
61 |
outputs = model.generate(input_ids, max_length=200, num_beams=2, repetition_penalty=2.5, top_k=50, top_p=0.98, length_penalty=1.0, early_stopping=True)
|
62 |
output = tokenizer.decode(outputs[0])
|
63 |
+
#return ".".join(output.split(".")[:-1]) + "."
|
64 |
+
sent = ".".join(output.split(".")[:-1]) + "."
|
65 |
+
return re.match(r'<pad> ([^<>]*)', sent).group(1)
|
66 |
|
67 |
|
68 |
output_text = gr.outputs.Textbox()
|