wanderer2k1 commited on
Commit
9833a80
1 Parent(s): baa3568
README.txt ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Setup
2
+
3
+ 1. Cài Python 3.9.13: https://www.python.org/ftp/python/3.9.13/python-3.9.13-amd64.exe
4
+ Lưu ý: khi install lưu ý tick chọn "Add Python 3.9 to PATH".
5
+
6
+ 2. Mở command line dẫn đến thư mục này, nhập:
7
+ python -m venv venv
8
+ venv/Scripts/activate
9
+ python -m pip install -r requirements.txt
10
+
11
+ 3. Tải file dữ liệu về từ link: https://drive.google.com/file/d/1s2-Yi1R8pEgGOPNwbJEsSY-Ltum1UNLH/view?usp=sharing
12
+ Giải nén file ở thư mục này, tên thư mục sau giải nén là "data". Lưu ý: các file dữ liệu ở ngay trong thư mục data, tránh sau khi giải nén thêm folder data bên trong folder data.
13
+
14
+ 4. Tải file models về từ link: https://drive.google.com/file/d/1aHBXKINBuLEDLPYF-GMUTwQBDDF-FNSj/view?usp=sharing
15
+ Giải nén file ở thư mục này, tên thư mục sau giải nén là "models".
16
+
17
+ 5. để chạy chương trình, mở command line dẫn đến thư mục này, nhập:
18
+ venv/Scripts/activate
19
+ streamlit run streamlit/main.py
20
+
21
+ * Lưu ý: lần đầu query đầu, hệ thống sẽ tải các models về từ repo cá nhân, dung lượng khoảng 3GB nên mất nhiều thời gian.
22
+
23
+ # Cấu trúc thư mục:
24
+ .
25
+ |
26
+ |_Notebooks: Các .ipynb notebooks đã xử lý dữ liệu, huấn luyện mô hình và đánh giá mô hình. Đã chạy trên Google colab.
27
+ ||_Prepare_data: Các .ipynb notebook xử lý dữ liệu cho huấn luyện và đánh giá mô hình.
28
+ ||_Training: Các .ipynb notebook huấn luyện mô hình.
29
+ ||_Evaluation: Các .ipynb notebook đánh giá mô hình.
30
+ |
31
+ |_src: Script python chứa hàm chạy chương trình.
32
+ |
33
+ |_streamlit: Chứa script chạy webapp, file css style webapp, các lớp liên quan đến web app.
34
+ |
35
+ |_requirements.txt: file chứa các thư viện python cần cài đặt.
36
+ |
37
+ |_README.txt
38
+
SessionState.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hack to add per-session state to Streamlit.
2
+
3
+ Usage
4
+ -----
5
+
6
+ >>> import SessionState
7
+ >>>
8
+ >>> session_state = SessionState.get(user_name='', favorite_color='black')
9
+ >>> session_state.user_name
10
+ ''
11
+ >>> session_state.user_name = 'Mary'
12
+ >>> session_state.favorite_color
13
+ 'black'
14
+
15
+ Since you set user_name above, next time your script runs this will be the
16
+ result:
17
+ >>> session_state = get(user_name='', favorite_color='black')
18
+ >>> session_state.user_name
19
+ 'Mary'
20
+
21
+ """
22
+ try:
23
+ import streamlit.ReportThread as ReportThread
24
+ from streamlit.server.Server import Server
25
+ except Exception:
26
+ # Streamlit >= 0.65.0
27
+ import streamlit.report_thread as ReportThread
28
+ from streamlit.server.server import Server
29
+
30
+
31
+ class SessionState(object):
32
+ def __init__(self, **kwargs):
33
+ """A new SessionState object.
34
+
35
+ Parameters
36
+ ----------
37
+ **kwargs : any
38
+ Default values for the session state.
39
+
40
+ Example
41
+ -------
42
+ >>> session_state = SessionState(user_name='', favorite_color='black')
43
+ >>> session_state.user_name = 'Mary'
44
+ ''
45
+ >>> session_state.favorite_color
46
+ 'black'
47
+
48
+ """
49
+ for key, val in kwargs.items():
50
+ setattr(self, key, val)
51
+
52
+
53
+ def get(**kwargs):
54
+ """Gets a SessionState object for the current session.
55
+
56
+ Creates a new object if necessary.
57
+
58
+ Parameters
59
+ ----------
60
+ **kwargs : any
61
+ Default values you want to add to the session state, if we're creating a
62
+ new one.
63
+
64
+ Example
65
+ -------
66
+ >>> session_state = get(user_name='', favorite_color='black')
67
+ >>> session_state.user_name
68
+ ''
69
+ >>> session_state.user_name = 'Mary'
70
+ >>> session_state.favorite_color
71
+ 'black'
72
+
73
+ Since you set user_name above, next time your script runs this will be the
74
+ result:
75
+ >>> session_state = get(user_name='', favorite_color='black')
76
+ >>> session_state.user_name
77
+ 'Mary'
78
+
79
+ """
80
+ # Hack to get the session object from Streamlit.
81
+
82
+ ctx = ReportThread.get_report_ctx()
83
+
84
+ this_session = None
85
+
86
+ current_server = Server.get_current()
87
+ if hasattr(current_server, '_session_infos'):
88
+ # Streamlit < 0.56
89
+ session_infos = Server.get_current()._session_infos.values()
90
+ else:
91
+ session_infos = Server.get_current()._session_info_by_id.values()
92
+
93
+ for session_info in session_infos:
94
+ s = session_info.session
95
+ if (
96
+ # Streamlit < 0.54.0
97
+ (hasattr(s, '_main_dg') and s._main_dg == ctx.main_dg)
98
+ or
99
+ # Streamlit >= 0.54.0
100
+ (not hasattr(s, '_main_dg') and s.enqueue == ctx.enqueue)
101
+ or
102
+ # Streamlit >= 0.65.2
103
+ (not hasattr(s, '_main_dg') and s._uploaded_file_mgr == ctx.uploaded_file_mgr)
104
+ ):
105
+ this_session = s
106
+
107
+ if this_session is None:
108
+ raise RuntimeError(
109
+ "Oh noes. Couldn't get your Streamlit Session object. "
110
+ 'Are you doing something fancy with threads?')
111
+
112
+ # Got the session object! Now let's attach some state into it.
113
+
114
+ if not hasattr(this_session, '_custom_session_state'):
115
+ this_session._custom_session_state = SessionState(**kwargs)
116
+
117
+ return this_session._custom_session_state
app.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #basics
2
+ import time
3
+ import pandas as pd
4
+ import numpy as np
5
+ import pickle
6
+ from PIL import Image
7
+
8
+ #DL
9
+ import torch
10
+ from transformers import T5ForConditionalGeneration, T5TokenizerFast
11
+ from sentence_transformers import SentenceTransformer
12
+ from sentence_transformers.util import cos_sim
13
+
14
+ #streamlit
15
+ import streamlit as st
16
+ import SessionState
17
+ from load_css import local_css
18
+ local_css("./style.css")
19
+
20
+ #text preprocess
21
+ import re
22
+ from pyvi import ViTokenizer
23
+ from rank_bm25 import BM25Okapi
24
+
25
+ #helper functions
26
+ from inspect import getsourcefile
27
+ import os.path as path, sys
28
+ from pathlib import Path
29
+ current_dir = path.dirname(path.abspath(getsourcefile(lambda:0)))
30
+ sys.path.insert(0, current_dir[:current_dir.rfind(path.sep)])
31
+ import src.clean_dataset as clean
32
+
33
+ @st.cache(allow_output_mutation=True)
34
+
35
+ def preprocess(sentence):
36
+ sentence=str(sentence)
37
+ sentence = sentence.lower()
38
+ sentence=sentence.replace('{html}',"")
39
+ cleanr = re.compile('<.*?>')
40
+ cleantext = re.sub(cleanr, '', sentence)
41
+ rem_url=re.sub(r'http\S+', '',cleantext)
42
+ word_list = rem_url.split()
43
+ preped = ViTokenizer.tokenize(" ".join(word_list))
44
+ return preped
45
+
46
+ DEFAULT = '< PICK A VALUE >'
47
+
48
+ def selectbox_with_default(text, values, default=DEFAULT, sidebar=False):
49
+ func = st.sidebar.selectbox if sidebar else st.selectbox
50
+ return func(text, np.insert(np.array(values, object), 0, default))
51
+
52
+ def neuralqa():
53
+
54
+ model = T5ForConditionalGeneration.from_pretrained("wanderer2k1/T5-LawsQA")
55
+ tokenizer = T5TokenizerFast.from_pretrained("wanderer2k1/T5-LawsQA")
56
+
57
+ bi_encoder = SentenceTransformer('wanderer2k1/BertCondenser_LawsQA')
58
+ return tokenizer, model, bi_encoder
59
+
60
+ def hf_run_model(tokenizer, model, input_string, **generator_args):
61
+ generator_args = {
62
+ "max_length": 256,
63
+ "temperature":0.0,
64
+ "num_beams": 4,
65
+ "length_penalty": 0.1,
66
+ "no_repeat_ngram_size": 8,
67
+ "early_stopping": True,
68
+ }
69
+ input_string = "generate questions: " + input_string + " </s>"
70
+ input_ids = tokenizer.encode(input_string, return_tensors="pt")
71
+ res = model.generate(input_ids, **generator_args)
72
+ output = tokenizer.batch_decode(res, skip_special_tokens=True)
73
+ output = [item.split("<sep>") for item in output]
74
+ return output
75
+
76
+
77
+ #%%
78
+ sys.path.pop(0)
79
+
80
+ #1. load in complete transformed and processed dataset
81
+
82
+ df = pd.read_csv('./data/corpus.pkl', sep = '\t')
83
+ passages = df['text'].values.tolist()
84
+ passage_id = df['title'].values.tolist()
85
+
86
+ #2 load corpus embeddings for neural QA:
87
+ with open("./data/embedded_corpus_BertCondenser_tuples.pkl", 'rb') as inp:
88
+ embedded_passages = pickle.load(inp)
89
+ embedded_passages = torch.Tensor(embedded_passages)
90
+
91
+ #3 load BM25:
92
+ with open("models/BM25_pyvi_segmented_splitted.pkl", 'rb') as inp:
93
+ bm25 = pickle.load(inp)
94
+
95
+ #%%
96
+ session = SessionState.get(run_id=0)
97
+
98
+ #%%
99
+ #title start page
100
+ st.title('Closed Domain (Vietnamese Laws) QA System')
101
+
102
+ sdg = Image.open('./logo.jpg')
103
+ st.sidebar.image(sdg, width=300)
104
+ st.sidebar.title('Settings')
105
+
106
+
107
+ st.caption("by HoangNV - on custom laws QA data set")
108
+ returns = st.sidebar.slider('Maximal number of answer suggestions:', 1, 3, 2)
109
+
110
+ def deploy(question):
111
+ tokenizer, model, bi_encoder = neuralqa()
112
+ top_k = returns # Number of passages we want to retrieve with the bi-encoder
113
+
114
+ tokenized_query = preprocess(question).split()
115
+ query = ' '.join(tokenized_query)
116
+ emb_query = bi_encoder.encode(query)
117
+
118
+ scores = bm25.get_scores(tokenized_query)
119
+ top_score_ids = np.argpartition(scores, -50)[-50:]
120
+
121
+ emb_candidates = torch.Tensor()
122
+
123
+ for i in top_score_ids:
124
+ emb_candidates = torch.cat([emb_candidates,embedded_passages[i:i+1]], axis = 0)
125
+
126
+
127
+ cosine_sim = cos_sim(emb_query, emb_candidates)
128
+
129
+ doc_inds = np.argpartition(cosine_sim.numpy()[0], -top_k)[-top_k:]
130
+
131
+ top_score_ids = top_score_ids.take(doc_inds)
132
+
133
+ matches = []
134
+ ids = []
135
+ answers = []
136
+
137
+ for doc_ind in top_score_ids:
138
+ doc = passages[doc_ind].replace('_',' ')
139
+
140
+ matches.append(doc)#' '.join(doc).replace('_',' '))
141
+ ids.append(passage_id[doc_ind].replace('_',' '))#' '.join(doc[:30].split()[:3]))
142
+ # i=0
143
+ for context in matches:
144
+ q = "Trả lời câu hỏi: "+query + " Trong ngữ cảnh: "+context#tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(context))
145
+ a = hf_run_model(tokenizer, model, q)[0][0]
146
+ answers.append(a)
147
+
148
+ # generate result df
149
+ df_results = pd.DataFrame(
150
+ {'Title': ids,
151
+ 'Answer': answers,
152
+ 'Retrieved': matches,
153
+ })
154
+
155
+ # st.header("Retrieved Answers:")
156
+ # df_results.set_index('title', inplace=True)
157
+ st.header("Results:")
158
+ st.table(df_results)
159
+
160
+ del tokenizer, model, bi_encoder#, question_embedding
161
+
162
+ #%%
163
+ question = st.text_input('Type in your legal question (be as specific as possible):')
164
+
165
+ if len(question) != 0:
166
+ t0 = time.time()
167
+ with st.spinner('Finding best answers...'):
168
+ deploy(question)
169
+ st.write(str(time.time()-t0))
170
+
171
+ st.write(' ')
172
+ st.write(' ')
173
+ st.write(' ')
174
+ st.write(' ')
175
+ st.write(' ')
176
+ st.write(' ')
177
+ if st.button("Run again!"):
178
+ session.run_id += 1
179
+
180
+ #%%
181
+ p = Path('.')
data/corpus.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e987b70b6d1d60dbee9db5999dc05faf72de806f77c66ad28002fc22b115c664
3
+ size 136262181
data/embedded_corpus_BertCondenser_tuples.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57fd74e8ba28cf5ed8a9b9785eb06b0dd1dd1b2147bf180a9c9987aedd1d5a67
3
+ size 308422821
load_css.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ def local_css(file_name):
4
+ with open(file_name) as f:
5
+ st.markdown('<style>{}</style>'.format(f.read()), unsafe_allow_html=True)
logo.jpg ADDED
models/BM25_pyvi_segmented_splitted.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f782ece67739ffc40cd64cd443918b857e2ab26cf13003e8bfee6c620da0f66d
3
+ size 93633728
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit==1.3.1
2
+ sentence_transformers==2.2.2
3
+ numpy==1.20.1
4
+ transformers==4.28.0
5
+ pandas==1.2.3
6
+ textwrap3==0.9.2
7
+ torch==1.8.0
8
+ joblib==1.0.1
9
+ Pillow==8.1.2
10
+ protobuf==3.20.*
11
+ altair<5
12
+ rank-bm25
13
+ nltk
14
+ pyvi
src/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Mon Oct 5 2020
5
+
6
+ @author: jonas
7
+ """
8
+
src/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (202 Bytes). View file
 
src/__pycache__/clean_dataset.cpython-39.pyc ADDED
Binary file (1.91 kB). View file
 
src/clean_dataset.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on
5
+
6
+ @author:
7
+
8
+ @title: clean_dataset
9
+
10
+ @descriptions: set of functions that enable splitting and cleaning.
11
+ """
12
+
13
+ #%%
14
+ import pandas as pd
15
+ import numpy as np
16
+ import string
17
+ from itertools import chain
18
+ from textwrap3 import wrap
19
+ import re
20
+
21
+ def split_at_length(dataframe, column, length, title = True):
22
+ wrapped = []
23
+ for i in dataframe[column]:
24
+ wrapped.append(wrap(str(i), length))
25
+
26
+ dataframe = dataframe.assign(wrapped=wrapped)
27
+ dataframe['wrapped'] = dataframe['wrapped'].apply(lambda x: '; '.join(map(str, x)))
28
+
29
+ if title == True:
30
+ splitted = pd.concat([pd.Series(row['title'], row['wrapped'].split("; "), )
31
+ for _, row in dataframe.iterrows()]).reset_index()
32
+ splitted = splitted.rename(columns={"index": "text", 0: "title"})
33
+
34
+ else:
35
+ splitted = []
36
+
37
+
38
+
39
+ return dataframe, splitted
40
+
41
+ def basic(s):
42
+ """
43
+ :param s: string to be processed
44
+ :return: processed string: see comments in the source code for more info
45
+ """
46
+ # Text Lowercase
47
+ s = s.lower()
48
+ # Remove punctuation
49
+ translator = str.maketrans(' ', ' ', string.punctuation)
50
+ s = s.translate(translator)
51
+ # Remove URLs
52
+ s = re.sub(r'^https?:\/\/.*[\r\n]*', ' ', s, flags=re.MULTILINE)
53
+ s = re.sub(r"http\S+", " ", s)
54
+ # Remove new line characters
55
+ s = re.sub('\n', ' ', s)
56
+
57
+ # Remove distracting single quotes
58
+ s = re.sub("\'", " ", s)
59
+ # Remove all remaining numbers and non alphanumeric characters
60
+ s = re.sub(r'\d+', ' ', s)
61
+ s = re.sub(r'\W+', ' ', s)
62
+
63
+ # define custom words to replace:
64
+ #s = re.sub(r'strengthenedstakeholder', 'strengthened stakeholder', s)
65
+
66
+ return s.strip()
67
+
68
+ def remove_linebreaks(s):
69
+ """
70
+ :param s: string to be processed
71
+ :return: processed string: see comments in the source code for more info
72
+ """
73
+ # Remove new line characters
74
+ s = re.sub('\n', ' ', s)
75
+
76
+ return s.strip()
style.css ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .highlight {
2
+ border-radius: 0.4rem;
3
+ color: white;
4
+ padding: 0.5rem;
5
+ margin-bottom: 1rem;
6
+ }
7
+ .bold {
8
+ padding-left: 1rem;
9
+ font-weight: 700;
10
+ }
11
+ .blue {
12
+ background-color: lightcoral;
13
+ }
14
+ .green {
15
+ background-color: green;
16
+ }
17
+ .red {
18
+ background-color: red;
19
+ }
20
+ .IndianRed {
21
+ background-color: IndianRed;
22
+ }
23
+ .lightgreen {
24
+ background-color: lightgreen;
25
+ }
26
+ .turquoise {
27
+ background-color: turquoise;
28
+ }
29
+
30
+