sadickam commited on
Commit
3129827
·
verified ·
1 Parent(s): 2b0abca

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +299 -0
app.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import regex as re
3
+ import torch
4
+ import nltk
5
+ import pandas as pd
6
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
7
+ from nltk.tokenize import sent_tokenize
8
+ import plotly.express as px
9
+ import time
10
+ import tqdm
11
+
12
+ nltk.download('punkt')
13
+
14
+ # Define the device (GPU or CPU)
15
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16
+
17
+ # Define the model and tokenizer
18
+ checkpoint = "ieq/IEQ-BERT"
19
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint).to(device)
20
+ model = AutoModelForSequenceClassification.from_pretrained(checkpoint).to(device)
21
+
22
+
23
+ # Define the function for preprocessing text
24
+ def prep_text(text):
25
+ clean_sents = [] # append clean con sentences
26
+ sent_tokens = str(text).split('.')
27
+ for sent_token in sent_tokens:
28
+ word_tokens = [str(word_token).strip().lower() for word_token in sent_token.split()]
29
+ word_tokens = [word_token for word_token in word_tokens if word_token not in punctuations]
30
+ clean_sents.append(' '.join((word_tokens)))
31
+ joined_clean_sents = '. '.join(clean_sents).strip(' ')
32
+ return joined_clean_sents
33
+
34
+
35
+ # APP INFO
36
+ def app_info():
37
+ check = """
38
+ Please go to either the "Single-Text-Prediction" or "Multi-Text-Prediction" tab to analyse your text.
39
+ """
40
+
41
+ return check
42
+
43
+
44
+ # Create Gradio interface for app info
45
+ iface1 = gr.Interface(
46
+ fn=app_info, inputs=None, outputs=['text'], title="General-Infomation",
47
+ description='''
48
+ This app, powered by the IEQ-BERT model (sadickam/sdg-classification-bert), is for automating the classification of text concerning
49
+ with respect to indoor environmetal quality (IEQ). IEQ refers to the quality of the indoor air, lighting,
50
+ temperature, and acoustics within a building, as well as the overall comfort and well-being of its occupants. It encompasses various
51
+ factors that can impact the health, productivity, and satisfaction of people who spend time indoors, such as office workers, students,
52
+ patients, and residents. This app assigns five labels to any given text and a text may be assigned one or more labels. The five labels include
53
+ the following:
54
+ - Acoustic
55
+ - Indoor air quality (IAQ)
56
+ - No IEQ (label assigned when no IEQ is defected)
57
+ - Thermal
58
+ - Visual
59
+
60
+ Because IEQ-BERT is capable of assigning one or more labels to a text, it is possible that the returned prediction like
61
+ (Acoustic_No IEQ) or (NO IEQ_Thermal). These multiple predictions that include "No IEQ" may suggest lack of contextual
62
+ clarity in the text and need manual review to affirm label.
63
+
64
+ This app has two analysis modules summarised below:
65
+ - Single-Text-Prediction - Analyses text pasted in a text box and return IEQ prediction.
66
+ - Multi-Text-Prediction - Analyses multiple rows of texts in an uploaded CSV file and returns a downloadable CSV file with IEQ prediction for each row of text.
67
+
68
+ This app runs on a free server and may therefore not be suitable for analysing large CSV files.
69
+ If you need assistance with analysing large CSV, do get in touch using the contact information in the Contact section.
70
+
71
+ <h3>Contact</h3>
72
+ <p>We would be happy to receive your feedback regarding this app. If you would also like to collaborate with us to explore some use cases for the model
73
+ powering this app, we are happy to hear from you.</p>
74
+
75
+ Dr Abdul-Manan Sadick - [email protected]
76
+ Dr Giorgia Chinazzo - [email protected]
77
+ ''')
78
+
79
+
80
+ # SINGLE TEXT
81
+ # Define the prediction function
82
+ def predict_single_text(text):
83
+ """
84
+ Predicts the IEQ labels for a single text.
85
+
86
+ Args:
87
+ text (str): The text to be analyzed.
88
+
89
+ Returns:
90
+ top_prediction (dict): A dictionary containing the top predicted IEQ labels and their corresponding probabilities.
91
+ fig (plotly.graph_objs.Figure): A bar chart showing the likelihood of each IEQ label.
92
+ """
93
+ # Preprocess the input text
94
+ cleaned_text = prep_text(text)
95
+
96
+ # Check if the text is empty after preprocessing
97
+ if cleaned_text == "":
98
+ raise gr.Error('This model needs some text input to return a prediction')
99
+
100
+ # Tokenize the preprocessed text
101
+ tokenized_text = tokenizer(cleaned_text, return_tensors="pt", truncation=True, max_length=512, padding=True).to(
102
+ device)
103
+
104
+ # Make predictions
105
+ with torch.no_grad():
106
+ outputs = model(**tokenized_text)
107
+ logits = outputs.logits
108
+
109
+ # Calculate the probabilities
110
+ probabilities = torch.sigmoid(logits).squeeze()
111
+
112
+ # Define the threshold for prediction
113
+ threshold = 0.3
114
+
115
+ # Get the predicted labels
116
+ predicted_labels_ = (probabilities.cpu().numpy() > threshold).tolist()
117
+
118
+ # Define the list of IEQ labels
119
+ label_list = [
120
+ 'Acoustic',
121
+ 'Indoor air quality',
122
+ 'No IEQ',
123
+ 'Thermal',
124
+ 'Visual'
125
+ ]
126
+
127
+ # Map the predicted labels to their corresponding names
128
+ predicted_labels = [label_list[i] for i in range(len(label_list)) if predicted_labels_[i] == 1]
129
+
130
+ # Get the probabilities of the predicted labels
131
+ predicted_prob = [round(a_, 3) for a_ in probabilities.cpu().numpy().tolist() if a_ > threshold]
132
+
133
+ # Create a dictionary containing the top predicted IEQ labels and their corresponding probabilities
134
+ top_prediction = (dict(zip(predicted_labels, predicted_prob)))
135
+
136
+ # Create a bar chart showing the likelihood of each IEQ label
137
+ # Make dataframe for plotly bar chart
138
+ u, v = zip(*dict(zip(label_list, probabilities.cpu().numpy().tolist())).items())
139
+ m = list(u)
140
+ n = list(v)
141
+ df2 = pd.DataFrame()
142
+ df2['IEQ'] = m
143
+ df2['Likelihood'] = n
144
+
145
+ # plot graph of predictions
146
+ fig = px.bar(df2, x="Likelihood", y="IEQ", orientation="h")
147
+
148
+ fig.update_layout(
149
+ # barmode='stack',
150
+ template='seaborn', font=dict(family="Arial", size=12, color="black"),
151
+ autosize=True,
152
+ # width=800,
153
+ # height=500,
154
+ xaxis_title="Likelihood of IEQ",
155
+ yaxis_title="Indoor environmental quality (IEQ)",
156
+ # legend_title="Topics"
157
+ )
158
+
159
+ fig.update_xaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
160
+ fig.update_yaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
161
+ fig.update_annotations(font_size=12)
162
+
163
+ return top_prediction, fig
164
+
165
+ # Create Gradio interface for single text
166
+
167
+
168
+ iface2 = gr.Interface(fn=predict_single_text,
169
+ inputs=gr.Textbox(lines=7, label="Paste or type text here"),
170
+ outputs=[gr.Label(label="Top Prediction", show_label=True),
171
+ gr.Plot(label="Likelihood of all labels", show_label=True)],
172
+ title="Single Text Prediction",
173
+ article="**Note:** The quality of model predictions may depend on the quality of information provided."
174
+ )
175
+
176
+
177
+ # UPLOAD CSV
178
+ # Define the prediction function
179
+ def predict_from_csv(file, column_name, progress=gr.Progress()):
180
+ """
181
+ Predicts the IEQ labels for a list of texts in a CSV file.
182
+
183
+ Args:
184
+ file (str): The path to the CSV file.
185
+ column_name (str): The name of the column containing the text to be analyzed.
186
+ progress (gr.Progress): A progress bar to display the analysis progress.
187
+
188
+ Returns:
189
+ fig (plotly.graph_objs.Figure): A histogram showing the frequency of each IEQ label.
190
+ output_csv (gr.File): A downloadable CSV file containing the predictions.
191
+ """
192
+ # Read the CSV file
193
+ df_docs = pd.read_csv(file)
194
+
195
+ # Check if the specified column exists
196
+ if column_name not in df_docs.columns:
197
+ raise gr.Error(f"The column '{column_name}' does not exist in the uploaded CSV file.")
198
+
199
+ # Extract the text list from the specified column
200
+ text_list = df_docs[column_name].tolist()
201
+
202
+ # Define the list of IEQ labels
203
+ label_list = [
204
+ 'Acoustic',
205
+ 'Indoor air quality',
206
+ 'No IEQ',
207
+ 'Thermal',
208
+ 'Visual'
209
+ ]
210
+
211
+ # Initialize lists to store the predictions
212
+ labels_predicted = []
213
+ prediction_scores = []
214
+
215
+ # Preprocess text and make predictions
216
+ for text_input in progress.tqdm(text_list, desc="Analysing data"):
217
+ # Sleep to avoid rate limiting
218
+ time.sleep(0.02)
219
+
220
+ # Preprocess the text
221
+ cleaned_text = prep_text(text_input)
222
+
223
+ # Tokenize the text
224
+ tokenized_text = tokenizer(cleaned_text, return_tensors="pt", truncation=True, max_length=512, padding=True).to(
225
+ device)
226
+
227
+ # Make predictions
228
+ with torch.no_grad():
229
+ outputs = model(**tokenized_text)
230
+ logits = outputs.logits
231
+
232
+ # Calculate the probabilities
233
+ predictions = torch.sigmoid(logits).squeeze()
234
+
235
+ # Define the threshold for prediction
236
+ threshold = 0.3
237
+
238
+ # Get the predicted labels
239
+ predicted_labels_ = (predictions.cpu().numpy() > threshold).tolist()
240
+
241
+ # Map the predicted labels to their corresponding names
242
+ predicted_labels = [label_list[i] for i in range(len(label_list)) if predicted_labels_[i] == 1]
243
+
244
+ # Get the probabilities of the predicted labels
245
+ prediction_score = [round(a_, 3) for a_ in predictions.cpu().numpy().tolist() if a_ > threshold]
246
+
247
+ # Append the predictions to the lists
248
+ labels_predicted.append(predicted_labels)
249
+ prediction_scores.append(prediction_score)
250
+
251
+ # Append the predictions to the DataFrame
252
+ df_docs['IEQ_predicted'] = labels_predicted
253
+ df_docs['prediction_scores'] = prediction_scores
254
+
255
+ # Save the predictions to a CSV file
256
+ df_docs.to_csv('IEQ_predictions.csv')
257
+
258
+ # Create a downloadable CSV file
259
+ output_csv = gr.File(value='IEQ_predictions.csv', visible=True)
260
+
261
+ # Create a histogram showing the frequency of each IEQ label
262
+ fig = px.histogram(df_docs, y="IEQ_predicted")
263
+ fig.update_layout(
264
+ template='seaborn',
265
+ font=dict(family="Arial", size=12, color="black"),
266
+ autosize=True,
267
+ # width=800,
268
+ # height=500,
269
+ xaxis_title="IEQ counts",
270
+ yaxis_title="Indoor environmetal quality (IEQ)",
271
+ )
272
+ fig.update_xaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
273
+ fig.update_yaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
274
+ fig.update_annotations(font_size=12)
275
+
276
+ return fig, output_csv
277
+
278
+
279
+ # Define the input component
280
+ file_input = gr.File(label="Upload CSV file here", show_label=True, file_types=[".csv"])
281
+ column_name_input = gr.Textbox(label="Enter the column name containing the text to be analyzed", show_label=True)
282
+
283
+ # Create the Gradio interface
284
+ iface3 = gr.Interface(fn=predict_from_csv,
285
+ inputs=[file_input, column_name_input],
286
+ outputs=[gr.Plot(label='Frequency of IEQs', show_label=True),
287
+ gr.File(label='Download output CSV', show_label=True)],
288
+ title="Multi-text Prediction (CVS)",
289
+ description='**NOTE:** Please enter the column name containing the text to be analyzed.')
290
+
291
+ # Create a tabbed interface
292
+ demo = gr.TabbedInterface(interface_list=[iface1, iface2, iface3],
293
+ tab_names=["General-App-Info", "Single-Text-Prediction", "Multi-Text-Prediction (CSV)"],
294
+ title="Indoor Environmetal Quality (IEQ) Text Classifier App",
295
+ theme='soft'
296
+ )
297
+
298
+ # Launch the interface
299
+ demo.queue().launch()