from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig
import re
import gradio as gr
from torch.nn import functional as F
import seaborn
import matplotlib
import platform
from transformers.file_utils import ModelOutput
if platform.system() == "Darwin":
    print("MacOS")
    matplotlib.use('Agg')
import matplotlib.pyplot as plt
import io
from PIL import Image
import matplotlib.font_manager as fm
# global var
MODEL_NAME = 'yseop/FNP_T5_D2T_complete'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
config = AutoConfig.from_pretrained(MODEL_NAME)
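# Cache the active checkpoint so the UI can swap models at runtime without reloading on every call.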
MODEL_BUF = {
"name": MODEL_NAME,
"tokenizer": tokenizer,
"model": model,
"config": config
}
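# Register fonts found in the app directory with matplotlib and default to NanumGothicCoding.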
font_dir = ['./']
for font in fm.findSystemFonts(font_dir):
    print(font)
    fm.fontManager.addfont(font)
plt.rcParams["font.family"] = 'NanumGothicCoding'
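# Replace the cached tokenizer/model/config with a newly selected checkpoint.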
def change_model_name(name):
MODEL_BUF["name"] = name
MODEL_BUF["tokenizer"] = AutoTokenizer.from_pretrained(name)
MODEL_BUF["model"] = AutoModelForSeq2SeqLM.from_pretrained(name)
MODEL_BUF["config"] = AutoConfig.from_pretrained(name)
def generate(model_name, text):
    # Reload only when the requested checkpoint differs from the one currently cached.
    if model_name != MODEL_BUF["name"]:
        change_model_name(model_name)
    tokenizer = MODEL_BUF["tokenizer"]
    model = MODEL_BUF["model"]
    config = MODEL_BUF["config"]
    model.eval()
    input_ids = tokenizer.encode("AFA:{}".format(text), return_tensors="pt")
    outputs = model.generate(input_ids, max_length=200, num_beams=2, repetition_penalty=2.5,
                             top_k=50, top_p=0.98, length_penalty=1.0, early_stopping=True)
    output = tokenizer.decode(outputs[0])
    # Drop the trailing partial sentence, then strip the leading <pad> token.
    sent = ".".join(output.split(".")[:-1]) + "."
    return re.match(r'<pad> ([^<>]*)', sent).group(1)
output_text = gr.outputs.Textbox()
if __name__ == '__main__':
    text = ['Group profit | valIs | € 115.7 million && € 115.7 million | dTime | in 2019',
            'Net income | valIs | $48 million && $48 million | diGeo | in France && Net income | jPose | the interest rate && the interest rate | valIs | 0.6%',
            'The retirement age | incBy | 7 years && 7 years | cTime | 2018 && The retirement age | jpose | life expectancy && life expectancy | incBy | 10 years',
            'sales | incBy | € 115.7 million && € 115.7 million | dTime | in 2019 && € 115.7 million | diGeo | Europe']
    model_name_list = [
        'yseop/FNP_T5_D2T_complete',
        'yseop/FNP_T5_D2T_simple'
    ]
    app = gr.Interface(
        fn=generate,
        inputs=[gr.inputs.Dropdown(model_name_list, label="Model Name"), 'text'],
        outputs=output_text,
        # One example row per sample input: (model name, triple string).
        examples=[[MODEL_BUF["name"], t] for t in text],
        title="FTG",
        description="Financial Text Generation"
    )
    app.launch(inline=False)