wanderer2k1 committed · Commit a7b7647 · Parent(s): d3a77a4

major fix

Browse files:
- SessionState.py +0 -117
- app.py +49 -41
SessionState.py DELETED
@@ -1,117 +0,0 @@
-"""Hack to add per-session state to Streamlit.
-
-Usage
------
-
->>> import SessionState
->>>
->>> session_state = SessionState.get(user_name='', favorite_color='black')
->>> session_state.user_name
-''
->>> session_state.user_name = 'Mary'
->>> session_state.favorite_color
-'black'
-
-Since you set user_name above, next time your script runs this will be the
-result:
->>> session_state = get(user_name='', favorite_color='black')
->>> session_state.user_name
-'Mary'
-
-"""
-try:
-    import streamlit.ReportThread as ReportThread
-    from streamlit.server.Server import Server
-except Exception:
-    # Streamlit >= 0.65.0
-    import streamlit.report_thread as ReportThread
-    from streamlit.server.server import Server
-
-
-class SessionState(object):
-    def __init__(self, **kwargs):
-        """A new SessionState object.
-
-        Parameters
-        ----------
-        **kwargs : any
-            Default values for the session state.
-
-        Example
-        -------
-        >>> session_state = SessionState(user_name='', favorite_color='black')
-        >>> session_state.user_name = 'Mary'
-        ''
-        >>> session_state.favorite_color
-        'black'
-
-        """
-        for key, val in kwargs.items():
-            setattr(self, key, val)
-
-
-def get(**kwargs):
-    """Gets a SessionState object for the current session.
-
-    Creates a new object if necessary.
-
-    Parameters
-    ----------
-    **kwargs : any
-        Default values you want to add to the session state, if we're creating a
-        new one.
-
-    Example
-    -------
-    >>> session_state = get(user_name='', favorite_color='black')
-    >>> session_state.user_name
-    ''
-    >>> session_state.user_name = 'Mary'
-    >>> session_state.favorite_color
-    'black'
-
-    Since you set user_name above, next time your script runs this will be the
-    result:
-    >>> session_state = get(user_name='', favorite_color='black')
-    >>> session_state.user_name
-    'Mary'
-
-    """
-    # Hack to get the session object from Streamlit.
-
-    ctx = ReportThread.get_report_ctx()
-
-    this_session = None
-
-    current_server = Server.get_current()
-    if hasattr(current_server, '_session_infos'):
-        # Streamlit < 0.56
-        session_infos = Server.get_current()._session_infos.values()
-    else:
-        session_infos = Server.get_current()._session_info_by_id.values()
-
-    for session_info in session_infos:
-        s = session_info.session
-        if (
-            # Streamlit < 0.54.0
-            (hasattr(s, '_main_dg') and s._main_dg == ctx.main_dg)
-            or
-            # Streamlit >= 0.54.0
-            (not hasattr(s, '_main_dg') and s.enqueue == ctx.enqueue)
-            or
-            # Streamlit >= 0.65.2
-            (not hasattr(s, '_main_dg') and s._uploaded_file_mgr == ctx.uploaded_file_mgr)
-        ):
-            this_session = s
-
-    if this_session is None:
-        raise RuntimeError(
-            "Oh noes. Couldn't get your Streamlit Session object. "
-            'Are you doing something fancy with threads?')
-
-    # Got the session object! Now let's attach some state into it.
-
-    if not hasattr(this_session, '_custom_session_state'):
-        this_session._custom_session_state = SessionState(**kwargs)
-
-    return this_session._custom_session_state
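Note: this hack predates Streamlit's built-in session state (st.session_state, shipped in Streamlit 0.84), which is what app.py switches to below and what makes this file safe to delete. A minimal sketch of the equivalent usage with the built-in API, reusing the docstring's example names:

import streamlit as st

# st.session_state persists across script reruns within one browser session,
# which is exactly what SessionState.get() emulated.
if 'user_name' not in st.session_state:
    st.session_state['user_name'] = ''            # like get(user_name='')
if 'favorite_color' not in st.session_state:
    st.session_state['favorite_color'] = 'black'  # like get(favorite_color='black')

st.session_state['user_name'] = 'Mary'            # still 'Mary' on the next rerun
st.write(st.session_state['user_name'], st.session_state['favorite_color'])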
app.py CHANGED
@@ -1,4 +1,5 @@
 #basics
+from http import server
 import time
 import pandas as pd
 import numpy as np
@@ -13,7 +14,8 @@ from sentence_transformers.util import cos_sim
 
 #streamlit
 import streamlit as st
-import
+# from streamlit_server_state import server_state, server_state_lock
+# import SessionState
 from load_css import local_css
 local_css("./style.css")
 
@@ -28,9 +30,8 @@ import os.path as path, sys
 from pathlib import Path
 current_dir = path.dirname(path.abspath(getsourcefile(lambda:0)))
 sys.path.insert(0, current_dir[:current_dir.rfind(path.sep)])
-import src.clean_dataset as clean
+# import src.clean_dataset as clean
 
-@st.cache(allow_output_mutation=True)
 
 def preprocess(sentence):
     sentence=str(sentence)
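Aside: the two path lines kept above make the repository root importable regardless of where streamlit run is launched from. What they compute, in isolation (inspect.getsourcefile on a lambda resolves to the file that defines it, a common stand-in for __file__):

from inspect import getsourcefile
import os.path as path, sys

# Absolute directory of the current source file, without relying on __file__.
current_dir = path.dirname(path.abspath(getsourcefile(lambda: 0)))
# Prepend the parent directory so sibling packages (e.g. src/) can be imported.
sys.path.insert(0, current_dir[:current_dir.rfind(path.sep)])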
@@ -49,14 +50,15 @@ def selectbox_with_default(text, values, default=DEFAULT, sidebar=False):
     func = st.sidebar.selectbox if sidebar else st.selectbox
     return func(text, np.insert(np.array(values, object), 0, default))
 
+@st.cache(allow_output_mutation=True)
 def neuralqa():
-
     model = T5ForConditionalGeneration.from_pretrained("wanderer2k1/T5-LawsQA")
     tokenizer = T5TokenizerFast.from_pretrained("wanderer2k1/T5-LawsQA")
 
     bi_encoder = SentenceTransformer('wanderer2k1/BertCondenser_LawsQA')
     return tokenizer, model, bi_encoder
 
+
 def hf_run_model(tokenizer, model, input_string, **generator_args):
     generator_args = {
         "max_length": 256,
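The hunk above moves @st.cache(allow_output_mutation=True) onto neuralqa(), so the T5 checkpoint, tokenizer, and bi-encoder are loaded once and served from cache on every rerun. For reference, a sketch of the same load-once pattern with the newer API; the st.cache_resource swap is an assumption about Streamlit >= 1.18 deployments, not part of this commit:

import streamlit as st
from transformers import T5ForConditionalGeneration, T5TokenizerFast
from sentence_transformers import SentenceTransformer

@st.cache_resource  # modern replacement for st.cache(allow_output_mutation=True)
def neuralqa():
    # Heavy downloads/deserialization run only on the first call per process.
    model = T5ForConditionalGeneration.from_pretrained("wanderer2k1/T5-LawsQA")
    tokenizer = T5TokenizerFast.from_pretrained("wanderer2k1/T5-LawsQA")
    bi_encoder = SentenceTransformer("wanderer2k1/BertCondenser_LawsQA")
    return tokenizer, model, bi_encoder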
@@ -73,55 +75,52 @@ def hf_run_model(tokenizer, model, input_string, **generator_args):
     output = [item.split("<sep>") for item in output]
     return output
 
-
 #%%
 sys.path.pop(0)
 
 #1. load in complete transformed and processed dataset
+if 'df' not in st.session_state:
+    st.session_state['df'] = pd.read_csv('./data/corpus.pkl', sep = '\t')
+    st.session_state['passages'] = st.session_state['df']['text'].values.tolist()
+    st.session_state['passage_id'] = st.session_state['df']['title'].values.tolist()
 
-df = pd.read_csv('./data/corpus.pkl', sep = '\t')
-passages = df['text'].values.tolist()
-passage_id = df['title'].values.tolist()
 
 #2 load corpus embeddings for neural QA:
-
-
-embedded_passages =
+if 'embedded_passages' not in st.session_state:
+    with open("./data/embedded_corpus_BertCondenser_tuples.pkl", 'rb') as inp:
+        embedded_passages = pickle.load(inp)
+    st.session_state['embedded_passages'] = torch.Tensor(embedded_passages)
 
 #3 load BM25:
-
-
+if 'bm25' not in st.session_state:
+    with open("models/BM25_pyvi_segmented_splitted.pkl", 'rb') as inp:
+        st.session_state['bm25'] = pickle.load(inp)
 
-
-
+#4: model
+if 'model' not in st.session_state:
+    st.session_state['tokenizer'], st.session_state['model'], st.session_state['bi_encoder'] = neuralqa()
 
 #%%
-#title start page
-st.title('Closed Domain (Vietnamese Laws) QA System')
 
-sdg = Image.open('./logo.jpg')
-st.sidebar.image(sdg, width=300)
-st.sidebar.title('Settings')
 
+#%%
 
-st.caption("by HoangNV - on custom laws QA data set")
-returns = st.sidebar.slider('Maximal number of answer suggestions:', 1, 3, 2)
 
 def deploy(question):
-    tokenizer, model, bi_encoder = neuralqa()
+    # tokenizer, model, bi_encoder = neuralqa()
     top_k = returns # Number of passages we want to retrieve with the bi-encoder
 
     tokenized_query = preprocess(question).split()
     query = ' '.join(tokenized_query)
-    emb_query = bi_encoder.encode(query)
+    emb_query = st.session_state['bi_encoder'].encode(query)
 
-    scores = bm25.get_scores(tokenized_query)
+    scores = st.session_state['bm25'].get_scores(tokenized_query)
     top_score_ids = np.argpartition(scores, -50)[-50:]
 
     emb_candidates = torch.Tensor()
 
     for i in top_score_ids:
-        emb_candidates = torch.cat([emb_candidates,embedded_passages[i:i+1]], axis = 0)
+        emb_candidates = torch.cat([emb_candidates,st.session_state['embedded_passages'][i:i+1]], axis = 0)
 
 
     cosine_sim = cos_sim(emb_query, emb_candidates)
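For readers tracing deploy(): retrieval stays two-stage; BM25 prefilters the 50 lexically closest passages, then the bi-encoder re-ranks only those by cosine similarity, and the hunk merely reroutes every lookup through st.session_state. A self-contained sketch of the two stages on a toy corpus; the app loads precomputed passage embeddings from a pickle, whereas this sketch encodes candidates on the fly, and rank_bm25 is an assumption about the pickled BM25 object's type:

import numpy as np
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

passages = ["bồi thường thiệt hại ...", "hợp đồng lao động ...", "đăng ký kết hôn ..."]
bm25 = BM25Okapi([p.split() for p in passages])
bi_encoder = SentenceTransformer('wanderer2k1/BertCondenser_LawsQA')

query = "bồi thường thiệt hại"  # "compensation for damages"
tokenized_query = query.split()

# Stage 1: cheap lexical prefilter (the app keeps the top 50; O(n) via argpartition).
scores = bm25.get_scores(tokenized_query)
k = min(2, len(passages))
top_score_ids = np.argpartition(scores, -k)[-k:]

# Stage 2: dense re-ranking of the surviving candidates only.
emb_query = bi_encoder.encode(query)
emb_candidates = np.stack([bi_encoder.encode(passages[i]) for i in top_score_ids])
order = np.argsort(-cos_sim(emb_query, emb_candidates).numpy()[0])
print([passages[i] for i in top_score_ids[order]])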
@@ -135,14 +134,14 @@ def deploy(question):
     answers = []
 
     for doc_ind in top_score_ids:
-        doc = passages[doc_ind].replace('_',' ')
+        doc = st.session_state['passages'][doc_ind].replace('_',' ')
 
         matches.append(doc)#' '.join(doc).replace('_',' '))
-        ids.append(passage_id[doc_ind].replace('_',' '))#' '.join(doc[:30].split()[:3]))
+        ids.append(st.session_state['passage_id'][doc_ind].replace('_',' '))#' '.join(doc[:30].split()[:3]))
     # i=0
     for context in matches:
         q = "Trả lời câu hỏi: "+query + " Trong ngữ cảnh: "+context#tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(context))
-        a = hf_run_model(tokenizer, model, q)[0][0]
+        a = hf_run_model(st.session_state['tokenizer'], st.session_state['model'], q)[0][0]
         answers.append(a)
 
     # generate result df
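The loop above prompts the T5 model once per retrieved context; the Vietnamese template reads "Answer the question: {query} In the context: {context}". A minimal sketch of one such call, keeping only the max_length from generator_args (hf_run_model in app.py also splits outputs on "<sep>"):

from transformers import T5ForConditionalGeneration, T5TokenizerFast

model = T5ForConditionalGeneration.from_pretrained("wanderer2k1/T5-LawsQA")
tokenizer = T5TokenizerFast.from_pretrained("wanderer2k1/T5-LawsQA")

def answer(query: str, context: str) -> str:
    # Same prompt template as deploy(): "Trả lời câu hỏi: ... Trong ngữ cảnh: ..."
    q = "Trả lời câu hỏi: " + query + " Trong ngữ cảnh: " + context
    input_ids = tokenizer(q, return_tensors="pt", truncation=True).input_ids
    out = model.generate(input_ids, max_length=256)  # mirrors generator_args["max_length"]
    return tokenizer.decode(out[0], skip_special_tokens=True)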
@@ -157,25 +156,34 @@ def deploy(question):
     st.header("Results:")
     st.table(df_results)
 
-    del tokenizer, model, bi_encoder
+    # del tokenizer, model, bi_encoder, emb_candidates
+
+
+
 
 #%%
-
+#title start page
+st.title('Closed Domain (Vietnamese Laws) QA System')
+
+sdg = Image.open('./logo.jpg')
+st.sidebar.image(sdg, width=300)
+st.sidebar.title('Settings')
+
+
+st.caption("by HoangNV - on custom laws QA data set")
+returns = st.sidebar.slider('Number of answer suggestions:', 1, 3, 2)
+
+
+question = st.text_input('Type in your legal question:')
 
 if len(question) != 0:
     t0 = time.time()
     with st.spinner('Finding best answers...'):
         deploy(question)
-    st.write(str(time.time()-t0))
-
-
-
-    st.write(' ')
-    st.write(' ')
-    st.write(' ')
-    st.write(' ')
-    if st.button("Run again!"):
-        session.run_id += 1
+    st.write("Runtime: "+str(time.time()-t0))
+
+
+
 
 #%%
 p = Path('.')
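Finally, the hunk above moves the title, sidebar, and question input after the one-time loads and drops the old "Run again!" button: every widget interaction already reruns the script top to bottom, and st.session_state now carries the loaded artifacts across reruns, so no manual run_id bump is needed. The resulting page flow, stripped to a runnable skeleton (deploy reduced to a stub):

import time
import streamlit as st

def deploy(question):
    st.header("Results:")  # stand-in for the retrieval + generation defined above

question = st.text_input('Type in your legal question:')

if len(question) != 0:
    t0 = time.time()
    with st.spinner('Finding best answers...'):
        deploy(question)
    st.write("Runtime: " + str(time.time() - t0))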