Spaces:
Sleeping
Sleeping
Update to 710M Character Model, add RTD
Browse files
app.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1 |
-
from transformers import pipeline, AutoTokenizer
|
2 |
import pandas as pd
|
3 |
import numpy as np
|
4 |
import torch
|
5 |
import streamlit as st
|
|
|
6 |
|
7 |
USE_GPU = True
|
8 |
|
@@ -11,64 +12,127 @@ if USE_GPU and torch.cuda.is_available():
|
|
11 |
else:
|
12 |
device = torch.device('cpu')
|
13 |
|
14 |
-
MODEL_NAME_CHINESE = "IDEA-CCNL/Erlangshen-DeBERTa-v2-
|
15 |
-
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
TOP_K_WORDS = 200
|
20 |
-
|
21 |
-
CHINESE_WORDLIST = ['一定','一样','不得了','主观','从此','便于','俗话','倒霉','候选','充沛','分别','反倒','只好','同情','吹捧','咳嗽','围绕','如意','实行','将近','就职','应该','归还','当面','忘记','急忙','恢复','悲哀','感冒','成长','截至','打架','把握','报告','抱怨','担保','拒绝','拜访','拥护','拳头','拼搏','损坏','接待','握手','揭发','攀登','显示','普遍','未免','欣赏','正式','比如','流浪','涂抹','深刻','演绎','留念','瞻仰','确保','稍微','立刻','精心','结算','罕见','访问','请示','责怪','起初','转达','辅导','过瘾','运动','连忙','适合','遭受','重叠','镇静']
|
22 |
|
23 |
@st.cache_resource
|
24 |
def get_model_chinese():
|
25 |
return pipeline("fill-mask", MODEL_NAME_CHINESE, device = device)
|
26 |
|
27 |
@st.cache_resource
|
28 |
-
def
|
29 |
-
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
def assess_chinese(word, sentence):
|
33 |
print("Assessing Chinese")
|
34 |
-
|
35 |
-
|
|
|
|
|
36 |
if sentence.lower().find(word.lower()) == -1:
|
37 |
print('Sentence does not contain the word!')
|
38 |
return
|
39 |
|
40 |
-
text = sentence.replace(word.lower(), "
|
41 |
|
42 |
-
top_k_prediction =
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
norm_factor = 0
|
46 |
for output in top_k_prediction:
|
47 |
-
if output['
|
48 |
norm_factor += output['score']
|
49 |
|
50 |
top_k_prediction_new = []
|
51 |
for output in top_k_prediction:
|
52 |
-
if output['
|
53 |
output['score'] = output['score']/(1-min(0.5,norm_factor))
|
54 |
top_k_prediction_new.append(output)
|
55 |
-
|
56 |
-
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
# append the original word if its not found in the results
|
60 |
top_k_prediction_filtered = [output for output in top_k_prediction_new if \
|
61 |
output['token_str'] == word]
|
62 |
if len(top_k_prediction_filtered) == 0:
|
63 |
-
top_k_prediction_new.extend(target_word_prediction)
|
64 |
|
65 |
return top_k_prediction_new, score
|
66 |
|
67 |
def assess_sentence(word, sentence):
|
68 |
return assess_chinese(word, sentence)
|
69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
def get_chinese_word():
|
71 |
-
possible_words =
|
72 |
word = np.random.choice(possible_words)
|
73 |
return word
|
74 |
|
@@ -77,6 +141,8 @@ def get_word():
|
|
77 |
|
78 |
mask_filler_chinese = get_model_chinese()
|
79 |
#wordlist_chinese = get_wordlist_chinese()
|
|
|
|
|
80 |
|
81 |
def highlight_given_word(row):
|
82 |
color = '#ACE5EE' if row.Words == target_word else 'white'
|
@@ -101,14 +167,17 @@ def get_top_5_results(top_k_prediction):
|
|
101 |
return top_5_df
|
102 |
|
103 |
#### Streamlit Page
|
104 |
-
st.title("造句
|
105 |
|
106 |
if 'target_word' not in st.session_state:
|
107 |
st.session_state['target_word'] = get_word()
|
108 |
target_word = st.session_state['target_word']
|
|
|
|
|
|
|
|
|
109 |
|
110 |
-
st.
|
111 |
-
if st.button("Get new word"):
|
112 |
st.session_state['target_word'] = get_word()
|
113 |
st.experimental_rerun()
|
114 |
|
@@ -122,16 +191,23 @@ if st.button("Grade"):
|
|
122 |
with open('./result01.json', 'w') as outfile:
|
123 |
outfile.write(str(top_k_prediction))
|
124 |
|
125 |
-
|
126 |
-
|
|
|
|
|
|
|
|
|
|
|
127 |
predictions_df = get_top_5_results(top_k_prediction)
|
128 |
df_style = predictions_df.style.apply(highlight_given_word, axis=1)
|
129 |
|
130 |
if (score >= WORD_PROBABILITY_THRESHOLD):
|
131 |
# st.balloons()
|
132 |
-
|
133 |
-
|
|
|
|
|
134 |
else:
|
135 |
-
st.warning("
|
136 |
-
|
137 |
|
|
|
1 |
+
from transformers import pipeline, AutoTokenizer, ElectraForPreTraining
|
2 |
import pandas as pd
|
3 |
import numpy as np
|
4 |
import torch
|
5 |
import streamlit as st
|
6 |
+
from annotated_text import annotated_text
|
7 |
|
8 |
USE_GPU = True
|
9 |
|
|
|
12 |
else:
|
13 |
device = torch.device('cpu')
|
14 |
|
15 |
+
# Hugging Face model id passed to the fill-mask pipeline in get_model_chinese().
MODEL_NAME_CHINESE = "IDEA-CCNL/Erlangshen-DeBERTa-v2-710M-Chinese"
# Hugging Face model id loaded as the ELECTRA discriminator (tokenizer + model)
# used for replaced-token detection (RTD).
RTD_MODEL_NAME_CHINESE = "hfl/chinese-electra-180g-large-discriminator"

# Minimum normalised probability of the target word for a sentence to pass.
WORD_PROBABILITY_THRESHOLD = 0.05
# Number of first-character candidates requested from the fill-mask pipeline.
TOP_K_WORDS = 10
|
|
|
|
|
|
|
20 |
|
21 |
@st.cache_resource
def get_model_chinese():
    """Build the fill-mask pipeline for ``MODEL_NAME_CHINESE``.

    The pipeline is placed on the module-level ``device`` and the result is
    cached through ``st.cache_resource``.
    """
    fill_mask = pipeline(task="fill-mask", model=MODEL_NAME_CHINESE, device=device)
    return fill_mask
|
24 |
|
25 |
@st.cache_resource
def get_rtd_tokenizer_chinese():
    """Load the tokenizer that belongs to ``RTD_MODEL_NAME_CHINESE``.

    The result is cached through ``st.cache_resource``.
    """
    tokenizer = AutoTokenizer.from_pretrained(RTD_MODEL_NAME_CHINESE)
    return tokenizer
|
28 |
+
|
29 |
+
@st.cache_resource
def get_rtd_model_chinese():
    """Load the ELECTRA discriminator named by ``RTD_MODEL_NAME_CHINESE``.

    The result is cached through ``st.cache_resource``.
    """
    discriminator = ElectraForPreTraining.from_pretrained(RTD_MODEL_NAME_CHINESE)
    return discriminator
|
32 |
+
|
33 |
+
@st.cache_resource
def get_wordlist_chinese():
    """Return the practice words from ``wordlist_chinese_v2.csv`` as a list.

    Only rows whose ``assess`` column equals ``True`` are kept; the values of
    their ``Chinese`` column are returned in file order.
    """
    table = pd.read_csv('wordlist_chinese_v2.csv')
    # Explicit comparison kept on purpose: it matches the original filter even
    # if the column is not a pure boolean dtype.
    is_assessed = table['assess'] == True  # noqa: E712
    return table.loc[is_assessed, 'Chinese'].tolist()
|
38 |
+
|
39 |
+
@st.cache_resource
def get_allowed_words():
    """Return the set of words listed in the ``word`` column of ``allowed_words.csv``.

    The result is cached through ``st.cache_resource``; callers only use it
    for membership tests, so a set is the right container.
    """
    df = pd.read_csv('allowed_words.csv')
    # A Series is iterable, so the intermediate list() was redundant.
    return set(df['word'])
|
43 |
|
44 |
def assess_chinese(word, sentence):
    """Score how naturally a two-character ``word`` fits into ``sentence``.

    The first occurrence of ``word`` is replaced by two ``[MASK]`` tokens and
    the fill-mask pipeline is run in two passes (first character, then second
    character given the first) to build a ranked list of two-character
    candidates. Candidates that are not in the allowed-word set are dropped and
    their probability mass is used to rescale the remaining scores.

    Args:
        word: The target word; must be exactly two characters long.
        sentence: The learner's sentence, expected to contain ``word``.

    Returns:
        ``None`` when ``sentence`` does not contain ``word``; otherwise a tuple
        ``(predictions, score)`` where ``predictions`` is a list of pipeline
        result dicts (``token_str`` holds the two-character word, ``score`` the
        rescaled probability) and ``score`` is the rescaled probability of
        ``word`` itself.

    Raises:
        ValueError: If ``word`` is not exactly two characters long.
    """
    print("Assessing Chinese")
    number_of_chars = len(word)
    # Explicit check instead of `assert`, which disappears under `python -O`.
    if number_of_chars != 2:
        raise ValueError(
            f"assess_chinese only supports two-character words, got {word!r}"
        )

    allowed_words = get_allowed_words()
    # Locate the word case-insensitively and mask exactly that occurrence.
    # Masking every occurrence would produce more than two masks and break the
    # two-pass logic below. NOTE: this assumes lower() keeps string length,
    # which holds for CJK text.
    start = sentence.lower().find(word.lower())
    if start == -1:
        print('Sentence does not contain the word!')
        return

    text = (
        sentence[:start]
        + "[MASK]" * number_of_chars
        + sentence[start + number_of_chars:]
    )

    # Pass 1: candidates for the first character. With two masks the pipeline
    # result is indexed per mask, so [0] selects the first mask's candidates.
    top_k_prediction = []
    candidates = mask_filler_chinese(text, top_k=TOP_K_WORDS)[0]
    for candidate in candidates:
        # Pass 2: fix the first character and predict the second one.
        temp_text = text.replace("[MASK]", candidate['token_str'], 1)
        second_predictions = mask_filler_chinese(temp_text, top_k=5)
        for prediction in second_predictions:
            prediction['token_str'] = candidate['token_str'] + prediction['token_str']
            prediction['score'] = candidate['score'] * prediction['score']

        top_k_prediction.extend(second_predictions)
    top_k_prediction = sorted(
        top_k_prediction, key=lambda x: x['score'], reverse=True
    )[:(TOP_K_WORDS * 5)]

    # Probability mass spent on candidates that are not allowed words.
    norm_factor = 0
    for output in top_k_prediction:
        if output['token_str'] not in allowed_words:
            norm_factor += output['score']

    # Rescale the surviving candidates; the correction is capped so scores are
    # at most doubled.
    scale = 1 - min(0.5, norm_factor)
    top_k_prediction_new = []
    for output in top_k_prediction:
        if output['token_str'] in allowed_words:
            output['score'] = output['score'] / scale
            top_k_prediction_new.append(output)
    print (f"NORM_FACTOR: {norm_factor}")

    # Get target word prediction: probability of the first character times the
    # probability of the second character given the first.
    output1 = mask_filler_chinese(text, targets=word[0])[0][0]
    temp_text = text.replace("[MASK]", word[0], 1)
    output2 = mask_filler_chinese(temp_text, targets=word[1])[0]
    output2['token_str'] = output1['token_str'] + output2['token_str']
    output2['score'] = output1['score'] * output2['score']
    target_word_prediction = output2

    target_word_prediction['score'] = target_word_prediction['score'] / scale
    score = target_word_prediction['score']

    # append the original word if its not found in the results
    if not any(output['token_str'] == word for output in top_k_prediction_new):
        top_k_prediction_new.append(target_word_prediction)

    return top_k_prediction_new, score
|
99 |
|
100 |
def assess_sentence(word, sentence):
    """Assess ``sentence`` for ``word``; currently delegates to ``assess_chinese``."""
    result = assess_chinese(word, sentence)
    return result
|
102 |
|
103 |
+
def get_annotated_sentence(sentence, errors):
    """Split ``sentence`` into plain and highlighted parts for ``annotated_text``.

    Args:
        sentence: The sentence that was checked.
        errors: Sequence of ``(start, end)`` character offsets of suspicious
            tokens. Each element may be a tensor row or a plain pair; only
            ``e[0]`` and ``e[1]`` are read.

    Returns:
        ``sentence`` unchanged when there is nothing to highlight; otherwise a
        list starting with ``"Input sentence: "`` in which plain substrings
        alternate with ``(text, "", "#F8C8DC")`` tuples for flagged spans.
    """
    if len(errors) == 0:
        return sentence

    # Normalise to int pairs. Zero-length spans are dropped: they point at no
    # character at all, and highlighting sentence[0] for them would be wrong.
    spans = sorted(
        (int(e[0]), int(e[1])) for e in errors if int(e[1]) > int(e[0])
    )
    if not spans:
        return sentence

    output = ["Input sentence: "]
    curr_ind = 0
    for start, end in spans:
        if start < curr_ind:
            # Overlaps a span that was already emitted; skip it.
            continue
        output.append(sentence[curr_ind:start])
        # Highlight the whole token span, not only its first character.
        output.append((sentence[start:end], "", "#F8C8DC"))
        curr_ind = end
    output.append(sentence[curr_ind:])
    print(output)

    return output
|
120 |
+
|
121 |
+
def get_word_errors(word, sentence):
    """Return the character offsets of tokens the RTD model flags in ``sentence``.

    Args:
        word: Unused; kept so existing callers keep working.
        sentence: The sentence to check.

    Returns:
        A list of ``(start, end)`` offset tensors, one per token whose
        discriminator score is positive.
    """
    # Tokenize once and reuse the encoding for both the offsets and the model.
    encoded = rtd_tokenizer_chinese(
        sentence, return_tensors='pt', return_offsets_mapping=True
    )
    # offset_mapping is not a model input, so take it out before the forward call.
    offsets = encoded.pop('offset_mapping')[0]

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        scores = rtd_model_chinese(**encoded)[0][0]

    # Zero-length offsets (presumably special tokens such as [CLS]/[SEP]) do
    # not correspond to any character of the sentence, so they are not reported.
    errors = [
        offset
        for offset, token_score in zip(offsets, scores)
        if token_score > 0 and offset[1] > offset[0]
    ]

    print(errors)
    return errors
|
132 |
+
|
133 |
+
|
134 |
def get_chinese_word():
    """Pick one word at random (via ``np.random.choice``) from the Chinese word list."""
    return np.random.choice(get_wordlist_chinese())
|
138 |
|
|
|
141 |
|
142 |
# Module-level handles used by assess_chinese() and get_word_errors().
# Each getter is wrapped in st.cache_resource.
mask_filler_chinese = get_model_chinese()
#wordlist_chinese = get_wordlist_chinese()
rtd_tokenizer_chinese = get_rtd_tokenizer_chinese()
rtd_model_chinese = get_rtd_model_chinese()
|
146 |
|
147 |
def highlight_given_word(row):
|
148 |
color = '#ACE5EE' if row.Words == target_word else 'white'
|
|
|
167 |
return top_5_df
|
168 |
|
169 |
#### Streamlit Page
st.title("造句 Self-marking Demo")

# Seed the session with a word the first time the page is loaded.
if 'target_word' not in st.session_state:
    st.session_state['target_word'] = get_word()

# Pre-select the session's word in the dropdown; the user may pick another one.
word_options = get_wordlist_chinese()
target_word_ind = word_options.index(st.session_state['target_word'])
target_word = st.selectbox("Choose a word:", word_options, index = target_word_ind)

# Draw a fresh word and rerun the script so the dropdown picks it up.
if st.button("Get random word"):
    st.session_state['target_word'] = get_word()
    st.experimental_rerun()
|
183 |
|
|
|
191 |
with open('./result01.json', 'w') as outfile:
|
192 |
outfile.write(str(top_k_prediction))
|
193 |
|
194 |
+
errors = get_word_errors(target_word, sentence)
|
195 |
+
annotated_sentence = get_annotated_sentence(sentence, errors)
|
196 |
+
|
197 |
+
annotated_text(annotated_sentence)
|
198 |
+
|
199 |
+
st.write(f"Probability score: {score:.1%}. (Target: {WORD_PROBABILITY_THRESHOLD:.1%})")
|
200 |
+
# st.write(f"Target probability: {WORD_PROBABILITY_THRESHOLD:.1%}")
|
201 |
predictions_df = get_top_5_results(top_k_prediction)
|
202 |
df_style = predictions_df.style.apply(highlight_given_word, axis=1)
|
203 |
|
204 |
if (score >= WORD_PROBABILITY_THRESHOLD):
|
205 |
# st.balloons()
|
206 |
+
if (len(errors) == 0):
|
207 |
+
st.success("Yay good job! 🕺 Practice again with other words", icon="✅")
|
208 |
+
else:
|
209 |
+
st.warning("Potential word errors detected. Try again?")
|
210 |
else:
|
211 |
+
st.warning("Probability score too low. Maybe try again?")
|
212 |
+
st.table(df_style)
|
213 |
|