zaoju-demo / app.py
cn91's picture
Update app.py
9a919c2
raw
history blame
4.46 kB
import json
import re

import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import pipeline, AutoTokenizer
USE_GPU = True
# Use the first CUDA GPU when GPU use is enabled and CUDA is available;
# otherwise fall back to the CPU.
device = torch.device("cuda:0" if USE_GPU and torch.cuda.is_available() else 'cpu')
# Fill-mask model used for grading.
# Alternative model, not currently in use:
#   "IDEA-CCNL/Erlangshen-DeBERTa-v2-186M-Chinese-SentencePiece"
MODEL_NAME_CHINESE = "IDEA-CCNL/Erlangshen-DeBERTa-v2-97M-CWS-Chinese"

# Minimum model probability of the target word for a sentence to pass.
WORD_PROBABILITY_THRESHOLD = 0.02

# Number of top candidate words requested from the model per sentence.
TOP_K_WORDS = 10

# Pool of practice words; a target word is drawn at random from this list.
CHINESE_WORDLIST = [
    '一定', '一样', '不得了', '主观', '从此', '便于', '俗话', '倒霉', '候选', '充沛',
    '分别', '反倒', '只好', '同情', '吹捧', '咳嗽', '围绕', '如意', '实行', '将近',
    '就职', '应该', '归还', '当面', '忘记', '急忙', '恢复', '悲哀', '感冒', '成长',
    '截至', '打架', '把握', '报告', '抱怨', '担保', '拒绝', '拜访', '拥护', '拳头',
    '拼搏', '损坏', '接待', '握手', '揭发', '攀登', '显示', '普遍', '未免', '欣赏',
    '正式', '比如', '流浪', '涂抹', '深刻', '演绎', '留念', '瞻仰', '确保', '稍微',
    '立刻', '精心', '结算', '罕见', '访问', '请示', '责怪', '起初', '转达', '辅导',
    '过瘾', '运动', '连忙', '适合', '遭受', '重叠', '镇静',
]
@st.cache_resource
def get_model_chinese():
    """Build the Chinese fill-mask pipeline for MODEL_NAME_CHINESE on ``device``.

    Wrapped in ``st.cache_resource`` so the model is loaded once and the same
    pipeline object is reused instead of being rebuilt on every script rerun.
    """
    fill_mask = pipeline(task="fill-mask", model=MODEL_NAME_CHINESE, device=device)
    return fill_mask
def assess_chinese(word, sentence):
    """Score how well ``word`` fits where it appears in ``sentence``.

    The first occurrence of ``word`` (matched case-insensitively) is replaced
    by the pipeline tokenizer's mask token. The fill-mask pipeline is then
    asked for (a) its ``TOP_K_WORDS`` best candidates and (b) the probability
    of ``word`` itself at that position.

    Args:
        word: The target word the user was asked to use.
        sentence: The user's sentence.

    Returns:
        ``(top_k_prediction, score)``. ``top_k_prediction`` is the list of
        pipeline result dicts, with the target's own result appended when the
        target is not among the top-k. ``score`` is the target word's
        probability. Returns ``None`` when ``word`` is empty or does not occur
        in ``sentence``; callers must check for that before unpacking.
    """
    print("Assessing Chinese")
    # re.search gives the span in the ORIGINAL string, so the case-insensitive
    # check and the masking below always agree on where the word is.
    match = re.search(re.escape(word), sentence, flags=re.IGNORECASE) if word else None
    if match is None:
        print('Sentence does not contain the word!')
        return None
    # Use the tokenizer's own mask token instead of a hard-coded string: mask
    # tokens differ between tokenizers ("<mask>" vs "[MASK]"), and the
    # pipeline errors on input that contains no mask token.
    mask_token = mask_filler_chinese.tokenizer.mask_token
    # Mask only the first occurrence. With several masks the pipeline either
    # rejects the input or returns one result list per mask, and the
    # [0]['score'] lookup below assumes a single mask.
    text = sentence[:match.start()] + mask_token + sentence[match.end():]
    top_k_prediction = mask_filler_chinese(text, top_k=TOP_K_WORDS)
    target_word_prediction = mask_filler_chinese(text, targets=word)
    score = target_word_prediction[0]['score']
    # Append the target's own prediction when it is not already in the
    # top-k, so the results table can always show it.
    if not any(output['token_str'] == word for output in top_k_prediction):
        top_k_prediction.extend(target_word_prediction)
    return top_k_prediction, score
def assess_sentence(word, sentence):
    """Grade ``sentence`` for ``word`` by delegating to ``assess_chinese``.

    Chinese is the only supported language, so this forwards the arguments
    unchanged and returns whatever ``assess_chinese`` returns (including
    ``None`` when the sentence does not contain the word).
    """
    result = assess_chinese(word, sentence)
    return result
def get_chinese_word():
    """Return one word drawn uniformly at random from CHINESE_WORDLIST."""
    return np.random.choice(CHINESE_WORDLIST)
def get_word():
    """Pick a new practice target word (Chinese is the only word source)."""
    chosen = get_chinese_word()
    return chosen
# Module-level fill-mask pipeline used by assess_chinese. get_model_chinese is
# wrapped in st.cache_resource, so reruns reuse the already-loaded model.
mask_filler_chinese = get_model_chinese()
def highlight_given_word(row, target=None):
    """Return per-cell CSS that highlights the row holding the target word.

    Meant to be passed to ``DataFrame.style.apply(..., axis=1)``, so ``row``
    is one row of the predictions table and must have a ``Words`` field.

    Args:
        row: One table row (a pandas Series) with a ``Words`` entry.
        target: Word to highlight. Defaults to the module-level
            ``target_word`` so the original single-argument call keeps working.

    Returns:
        A list with one ``background-color`` declaration per cell in ``row``.
    """
    if target is None:
        target = target_word
    color = '#ACE5EE' if row.Words == target else 'white'
    return [f'background-color:{color}'] * len(row)
def get_top_5_results(top_k_prediction, target=None):
    """Build the results table shown to the user.

    The table holds the first five predictions. If the target word is not
    among them, the target's own row(s) from ``top_k_prediction`` are appended
    so the user always sees its probability.

    Args:
        top_k_prediction: List of fill-mask result dicts with ``score``,
            ``token``, ``token_str`` and ``sequence`` keys.
        target: Word to look for. Defaults to the module-level
            ``target_word`` so the original single-argument call keeps working.

    Returns:
        A DataFrame with ``Probability`` (formatted as a percentage string)
        and ``Words`` columns.
    """
    if target is None:
        target = target_word
    predictions_df = (
        pd.DataFrame(top_k_prediction)
        .drop(columns=["token", "sequence"])
        .rename(columns={"score": "Probability", "token_str": "Words"})
    )
    top_5_df = predictions_df.head(5)
    if not (top_5_df.Words == target).any():
        print("target word not in top 5")
        target_word_df = predictions_df[predictions_df.Words == target]
        top_5_df = pd.concat([top_5_df, target_word_df])
    # Work on an independent copy: formatting in place on a slice of
    # predictions_df is chained assignment (SettingWithCopyWarning).
    top_5_df = top_5_df.copy()
    top_5_df['Probability'] = top_5_df['Probability'].apply(lambda x: f"{x:.2%}")
    return top_5_df
#### Streamlit Page
st.title("造句 Auto-marking Demo")

# Streamlit re-executes this script on every interaction, so the chosen
# target word is kept in session_state to survive reruns.
if 'target_word' not in st.session_state:
    st.session_state['target_word'] = get_word()
target_word = st.session_state['target_word']
st.write("Target word: ", target_word)

if st.button("Get new word"):
    st.session_state['target_word'] = get_word()
    # st.experimental_rerun was superseded by st.rerun (and later removed);
    # prefer the new API but keep working on older Streamlit releases.
    _rerun = getattr(st, "rerun", None) or st.experimental_rerun
    _rerun()

st.subheader("Form your sentence and input below!")
sentence = st.text_input('Enter your sentence here', placeholder="Enter your sentence here!")

if st.button("Grade"):
    result = assess_sentence(target_word, sentence)
    if result is None:
        # assess_sentence returns None when the sentence lacks the target
        # word; unpacking None would crash the page with a TypeError.
        st.warning(f"Your sentence must contain the target word: {target_word}")
    else:
        top_k_prediction, score = result
        # Dump the raw predictions as real JSON (str() produced a Python repr,
        # not JSON); explicit UTF-8 keeps the Chinese text intact on any OS.
        with open('./result01.json', 'w', encoding='utf-8') as outfile:
            json.dump(top_k_prediction, outfile, ensure_ascii=False)
        st.write(f"Probability: {score:.2%}")
        st.write(f"Target probability: {WORD_PROBABILITY_THRESHOLD:.2%}")
        predictions_df = get_top_5_results(top_k_prediction)
        df_style = predictions_df.style.apply(highlight_given_word, axis=1)
        if score >= WORD_PROBABILITY_THRESHOLD:
            st.success("Yay good job! 🕺 Practice again with other words", icon="✅")
        else:
            st.warning("Hmmm.. maybe try again?")
        st.table(df_style)