cnstvariable committed
Commit 504015e
1 Parent(s): c34fe27

Upload app.py

Files changed (1)
  app.py +180 -0
app.py ADDED
@@ -0,0 +1,180 @@
# -*- coding: utf-8 -*-
"""Untitled0.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/13kE5uGoL2gfzSwTJli-WZolqCNBZXxNV
"""

import tensorflow as tf
import numpy as np
import pandas as pd
import streamlit as st
import re
import os
import csv
from tqdm import tqdm
import faiss
from nltk.translate.bleu_score import sentence_bleu
from datetime import datetime

def decontractions(phrase):
    """Expand English contractions in `phrase` into their full forms.
    ref: https://stackoverflow.com/questions/19790188/expanding-english-language-contractions-in-python/47091490#47091490"""
    # specific
    phrase = re.sub(r"won\'t", "will not", phrase)
    phrase = re.sub(r"can\'t", "can not", phrase)
    phrase = re.sub(r"won\’t", "will not", phrase)
    phrase = re.sub(r"can\’t", "can not", phrase)

    # general
    phrase = re.sub(r"n\'t", " not", phrase)
    phrase = re.sub(r"\'re", " are", phrase)
    phrase = re.sub(r"\'s", " is", phrase)
    phrase = re.sub(r"\'d", " would", phrase)
    phrase = re.sub(r"\'ll", " will", phrase)
    phrase = re.sub(r"\'t", " not", phrase)
    phrase = re.sub(r"\'ve", " have", phrase)
    phrase = re.sub(r"\'m", " am", phrase)

    # the same patterns, with the curly apostrophe
    phrase = re.sub(r"n\’t", " not", phrase)
    phrase = re.sub(r"\’re", " are", phrase)
    phrase = re.sub(r"\’s", " is", phrase)
    phrase = re.sub(r"\’d", " would", phrase)
    phrase = re.sub(r"\’ll", " will", phrase)
    phrase = re.sub(r"\’t", " not", phrase)
    phrase = re.sub(r"\’ve", " have", phrase)
    phrase = re.sub(r"\’m", " am", phrase)

    return phrase

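# Illustrative only (hypothetical input):
#   decontractions("I won't go because he's busy")
#   -> "I will not go because he is busy"
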
def preprocess(text):
    # lowercase the text and expand contractions
    # remove these characters: {'$', ')', '?', '"', '’', '.', '°', '!', ';', '/', "'", '€', '%', ':', ',', '('}
    # replace these special characters with a space: '\u200b', '\xa0', '-'
    text = text.lower()
    text = decontractions(text)
    text = re.sub('[$)\?"’.°!;\'€%:,(/]', '', text)
    text = re.sub('\u200b', ' ', text)
    text = re.sub('\xa0', ' ', text)
    text = re.sub('-', ' ', text)
    return text

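# Illustrative only (hypothetical input):
#   preprocess("Can't sleep (at night)!")
#   -> "can not sleep at night"
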
# import the BioBERT tokenizer and load the trained question-embedding extractor model
from transformers import AutoTokenizer

@st.cache(allow_output_mutation=True)
def return_biobert_tokenizer_model():
    '''returns the pretrained BioBERT tokenizer and the question extractor model'''
    biobert_tokenizer = AutoTokenizer.from_pretrained("cambridgeltl/BioRedditBERT-uncased")
    question_extractor_model1 = tf.keras.models.load_model('question_extractor_model_2_11')
    return biobert_tokenizer, question_extractor_model1

# import the GPT-2 tokenizer and load the fine-tuned GPT-2 model
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

@st.cache(allow_output_mutation=True)
def return_gpt2_tokenizer_model():
    '''returns the pretrained GPT-2 tokenizer and the fine-tuned GPT-2 model'''
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tf_gpt2_model = TFGPT2LMHeadModel.from_pretrained("tf_gpt2_model_2_118_50000")
    return gpt2_tokenizer, tf_gpt2_model

# prepare the FAISS search: load precomputed question/answer embeddings
# and build an inner-product index over each set
qa = pd.read_pickle('train_gpt_data.pkl')
question_bert = np.array(qa["Q_FFNN_embeds"].tolist()).astype('float32')
answer_bert = np.array(qa["A_FFNN_embeds"].tolist()).astype('float32')

answer_index = faiss.IndexFlatIP(answer_bert.shape[-1])
question_index = faiss.IndexFlatIP(question_bert.shape[-1])
answer_index.add(answer_bert)
question_index.add(question_bert)

print('finished initializing')

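# Note: IndexFlatIP ranks by raw inner product, so the scores are cosine
# similarities only if the stored embeddings are unit-normalized (an
# assumption; the pickle is not inspected here). If they are not, FAISS
# can normalize the rows in place beforehand:
#   faiss.normalize_L2(question_bert)
#   faiss.normalize_L2(answer_bert)
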
# prepare the input for GPT-2 inference: retrieve the top-k most similar stored
# Q&A pairs via FAISS and pack them, plus the user question, into one prompt
# https://github.com/ash3n/DocProduct

def preparing_gpt_inference_data(gpt2_tokenizer, question, question_embedding):
    topk = 20
    scores, indices = answer_index.search(
        question_embedding.astype('float32'), topk)
    q_sub = qa.iloc[indices.reshape(topk)]

    line = '`QUESTION: %s `ANSWER: ' % (
        question)
    encoded_len = len(gpt2_tokenizer.encode(line))
    for i in q_sub.iterrows():
        line = '`QUESTION: %s `ANSWER: %s ' % (i[1]['question'], i[1]['answer']) + line
        line = line.replace('\n', '')
        encoded_len = len(gpt2_tokenizer.encode(line))
        if encoded_len >= 1024:
            break
    return gpt2_tokenizer.encode(line)[-1024:]

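# The packed prompt (kept to the last 1024 tokens) has the shape:
#   `QUESTION: <retrieved q> `ANSWER: <retrieved a> ... `QUESTION: <user q> `ANSWER:
# Retrieved pairs are prepended and the user's question comes last, so GPT-2
# continues generating right after the final '`ANSWER: ' marker.
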
# generate an answer for a question, given the required answer length (in tokens)

def give_answer(question, answer_len):
    preprocessed_question = preprocess(question)
    question_len = len(preprocessed_question.split(' '))
    truncated_question = preprocessed_question
    if question_len > 500:
        truncated_question = ' '.join(preprocessed_question.split(' ')[:500])
    biobert_tokenizer, question_extractor_model1 = return_biobert_tokenizer_model()
    gpt2_tokenizer, tf_gpt2_model = return_gpt2_tokenizer_model()
    encoded_question = biobert_tokenizer.encode(truncated_question)
    max_length = 512
    padded_question = tf.keras.preprocessing.sequence.pad_sequences(
        [encoded_question], maxlen=max_length, padding='post')
    question_mask = [[1 if token != 0 else 0 for token in question] for question in padded_question]
    embeddings = question_extractor_model1({'question': np.array(padded_question), 'question_mask': np.array(question_mask)})
    gpt_input = preparing_gpt_inference_data(gpt2_tokenizer, truncated_question, embeddings.numpy())
    # token id 4600 locates the last '`ANSWER' marker in the encoded prompt;
    # the input is cut just after it so generation starts at the answer slot
    mask_start = len(gpt_input) - list(gpt_input[::-1]).index(4600) + 1
    input_ids = gpt_input[:mask_start + 1]
    if len(input_ids) > (1024 - answer_len):
        input_ids = input_ids[-(1024 - answer_len):]
    gpt2_output = gpt2_tokenizer.decode(tf_gpt2_model.generate(input_ids=tf.constant([np.array(input_ids)]), max_length=1024, temperature=0.7)[0])
    answer = gpt2_output.rindex('`ANSWER: ')
    return gpt2_output[answer + len('`ANSWER: '):]

# the final function to generate an answer, with a default answer length of 25
def final_func_1(question):
    answer_len = 25
    return give_answer(question, answer_len)

def main():
    st.title('Medical Chatbot')
    question = st.text_input('Question', "Type Here")
    result = ""
    if st.button('ask'):
        #with st.spinner("You Know! an apple a day keeps doctor away!"):
        start = datetime.now()
        result = final_func_1(question)
        end_time = datetime.now()
        st.success("Here is the answer")
        st.text(result)
        st.text("result received within " + str((end_time - start).total_seconds()) + " seconds")


if __name__ == '__main__':
    main()
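
# To launch the app locally, run: streamlit run app.py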