Spaces:
Running
Running
import string
import tempfile
import time
from pathlib import Path

import gradio as gr
import nltk
import pandas as pd
import plotly.express as px
import regex as re
import torch
import tqdm
from nltk.tokenize import sent_tokenize
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Fetch NLTK's 'punkt' data, which backs sent_tokenize (imported above).
nltk.download('punkt')

# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the IEQ-BERT checkpoint and its tokenizer.
checkpoint = "ieq/IEQ-BERT"
# Tokenizers have no `.to()` method (calling it raised AttributeError at
# startup). Only the model, and the encoded batches the tokenizer returns,
# are moved to `device`.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint).to(device)
# Tokens consisting of a single punctuation character, dropped by prep_text.
PUNCTUATIONS = frozenset(string.punctuation)


def prep_text(text):
    """Normalise text before it is passed to the tokenizer.

    The text is split on full stops. Within each piece, words are
    whitespace-split, stripped and lower-cased, and any token that is a
    single punctuation character is dropped. The pieces are re-joined with
    '. ' and the result is stripped of leading/trailing spaces.

    Args:
        text: Value to clean; converted with ``str()`` first.

    Returns:
        str: The cleaned text ("" when nothing but whitespace/punctuation
        tokens remain).
    """
    clean_sents = []
    for sent_token in str(text).split('.'):
        word_tokens = [word.strip().lower() for word in sent_token.split()]
        # `punctuations` was previously referenced here but never defined,
        # so every call raised NameError.
        word_tokens = [word for word in word_tokens if word not in PUNCTUATIONS]
        clean_sents.append(' '.join(word_tokens))
    return '. '.join(clean_sents).strip(' ')
# APP INFO
def app_info():
    """Return the short hint shown on the general-information tab.

    The hint directs users to the two analysis tabs.
    """
    lines = [
        "",
        'Please go to either the "Single-Text-Prediction" or "Multi-Text-Prediction" tab to analyse your text.',
        "",
    ]
    return "\n".join(lines)
# Create Gradio interface for app info: a static page whose only output is
# the hint returned by app_info().
# NOTE(review): the contact e-mail addresses below were already obfuscated
# ("[email protected]") in this copy of the file -- restore the real addresses.
iface1 = gr.Interface(
    fn=app_info, inputs=None, outputs=['text'], title="General-Information",
    description='''
This app, powered by the IEQ-BERT model (ieq/IEQ-BERT), automates the classification of text with respect
to indoor environmental quality (IEQ). IEQ refers to the quality of the indoor air, lighting,
temperature, and acoustics within a building, as well as the overall comfort and well-being of its occupants. It encompasses various
factors that can impact the health, productivity, and satisfaction of people who spend time indoors, such as office workers, students,
patients, and residents. This app can assign one or more of five labels to any given text. The five labels are
the following:

- Acoustic
- Indoor air quality (IAQ)
- No IEQ (label assigned when no IEQ is detected)
- Thermal
- Visual

Because IEQ-BERT can assign more than one label to a text, a returned prediction may combine labels, such as
(Acoustic_No IEQ) or (No IEQ_Thermal). Such multiple predictions that include "No IEQ" may suggest a lack of contextual
clarity in the text and need manual review to confirm the label.

This app has two analysis modules, summarised below:

- Single-Text-Prediction - Analyses text pasted in a text box and returns the IEQ prediction.
- Multi-Text-Prediction - Analyses multiple rows of text in an uploaded CSV file and returns a downloadable CSV file with the IEQ prediction for each row of text.

This app runs on a free server and may therefore not be suitable for analysing large CSV files.
If you need assistance with analysing large CSV files, do get in touch using the contact information in the Contact section.

<h3>Contact</h3>
<p>We would be happy to receive your feedback regarding this app. If you would also like to collaborate with us to explore some use cases for the model
powering this app, we are happy to hear from you.</p>

Dr Abdul-Manan Sadick - [email protected]

Dr Giorgia Chinazzo - [email protected]
''')
# SINGLE TEXT
def predict_single_text(text):
    """
    Predict the IEQ labels for a single piece of text.

    Args:
        text (str): The text to be analysed.

    Returns:
        top_prediction (dict): Maps every label whose sigmoid probability is
            above the 0.3 threshold to that probability, rounded to 3 d.p.
        fig (plotly.graph_objs.Figure): Horizontal bar chart of the
            probability of every IEQ label.

    Raises:
        gr.Error: If the text is empty (or only whitespace, punctuation tokens
            and full stops) after preprocessing.
    """
    label_list = [
        'Acoustic',
        'Indoor air quality',
        'No IEQ',
        'Thermal',
        'Visual',
    ]
    threshold = 0.3

    cleaned_text = prep_text(text)
    # prep_text re-joins its pieces with '. ', so input such as "..." comes
    # back as ". . ." -- treat that as empty as well.
    if not cleaned_text.strip('. '):
        raise gr.Error('This model needs some text input to return a prediction')

    tokenized_text = tokenizer(
        cleaned_text, return_tensors="pt", truncation=True, max_length=512, padding=True
    ).to(device)

    with torch.no_grad():
        logits = model(**tokenized_text).logits

    # One independent sigmoid per label; convert to a plain Python list once.
    probabilities = torch.sigmoid(logits).squeeze().cpu().numpy().tolist()

    # Select labels and scores with the SAME comparison. Previously labels
    # used a float32 numpy comparison and scores a float64 one, which can
    # disagree right at the threshold and misalign the zip.
    top_prediction = {
        label: round(prob, 3)
        for label, prob in zip(label_list, probabilities)
        if prob > threshold
    }

    # Bar chart of the likelihood of every label.
    df2 = pd.DataFrame({'IEQ': label_list, 'Likelihood': probabilities})
    fig = px.bar(df2, x="Likelihood", y="IEQ", orientation="h")
    fig.update_layout(
        template='seaborn',
        font=dict(family="Arial", size=12, color="black"),
        autosize=True,
        xaxis_title="Likelihood of IEQ",
        yaxis_title="Indoor environmental quality (IEQ)",
    )
    fig.update_xaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_yaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_annotations(font_size=12)
    return top_prediction, fig
# Single-text tab: a free-text box goes in; the labels above threshold and a
# likelihood bar chart for every label come out.
iface2 = gr.Interface(
    fn=predict_single_text,
    inputs=gr.Textbox(label="Paste or type text here", lines=7),
    outputs=[
        gr.Label(show_label=True, label="Top Prediction"),
        gr.Plot(show_label=True, label="Likelihood of all labels"),
    ],
    article="**Note:** The quality of model predictions may depend on the quality of information provided.",
    title="Single Text Prediction",
)
# UPLOAD CSV
def predict_from_csv(file, column_name, progress=gr.Progress()):
    """
    Predict the IEQ labels for every row of text in an uploaded CSV file.

    Args:
        file: The uploaded CSV, given either as a path string or as a file
            wrapper exposing the path as ``.name``. Which one gr.File passes
            presumably depends on the Gradio version.
        column_name (str): Name of the column containing the text to analyse.
        progress (gr.Progress): Progress tracker shown while rows are analysed.

    Returns:
        fig (plotly.graph_objs.Figure): Histogram of how often each predicted
            label combination (e.g. "Acoustic_No IEQ") occurs.
        output_csv (gr.File): Downloadable copy of the input CSV with two extra
            columns: ``IEQ_predicted`` (labels above the 0.3 threshold) and
            ``prediction_scores`` (their probabilities, rounded to 3 d.p.).

    Raises:
        gr.Error: If no file was uploaded or ``column_name`` is not a column
            of the CSV.
    """
    if file is None:
        raise gr.Error("Please upload a CSV file to analyse.")
    csv_path = getattr(file, "name", file)
    df_docs = pd.read_csv(csv_path)

    if column_name not in df_docs.columns:
        raise gr.Error(f"The column '{column_name}' does not exist in the uploaded CSV file.")

    label_list = [
        'Acoustic',
        'Indoor air quality',
        'No IEQ',
        'Thermal',
        'Visual',
    ]
    threshold = 0.3

    labels_predicted = []
    prediction_scores = []
    for text_input in progress.tqdm(df_docs[column_name].tolist(), desc="Analysing data"):
        # Missing cells would otherwise be classified as the string "nan".
        cleaned_text = "" if pd.isna(text_input) else prep_text(text_input)
        if not cleaned_text.strip('. '):
            # Nothing meaningful to classify: record no labels for this row.
            labels_predicted.append([])
            prediction_scores.append([])
            continue

        tokenized_text = tokenizer(
            cleaned_text, return_tensors="pt", truncation=True, max_length=512, padding=True
        ).to(device)
        with torch.no_grad():
            logits = model(**tokenized_text).logits
        probabilities = torch.sigmoid(logits).squeeze().cpu().numpy().tolist()

        # One comparison selects both labels and scores, so they stay aligned.
        above = [
            (label, round(prob, 3))
            for label, prob in zip(label_list, probabilities)
            if prob > threshold
        ]
        labels_predicted.append([label for label, _ in above])
        prediction_scores.append([score for _, score in above])

    df_docs['IEQ_predicted'] = labels_predicted
    df_docs['prediction_scores'] = prediction_scores

    # Write to a fresh per-call directory. A fixed 'IEQ_predictions.csv' in
    # the working directory was shared by every user, so one run could
    # overwrite another's download.
    output_path = Path(tempfile.mkdtemp(prefix="ieq_")) / 'IEQ_predictions.csv'
    df_docs.to_csv(output_path)
    output_csv = gr.File(value=str(output_path), visible=True)

    # Histogram over one string category per label combination; plotting a
    # column of Python lists directly is not a well-defined category axis.
    combo_labels = [
        '_'.join(labels) if labels else 'No label above threshold'
        for labels in labels_predicted
    ]
    plot_df = pd.DataFrame({'IEQ_predicted': combo_labels})
    fig = px.histogram(plot_df, y="IEQ_predicted")
    fig.update_layout(
        template='seaborn',
        font=dict(family="Arial", size=12, color="black"),
        autosize=True,
        xaxis_title="IEQ counts",
        yaxis_title="Indoor environmental quality (IEQ)",
    )
    fig.update_xaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_yaxes(tickangle=0, tickfont=dict(family='Arial', color='black', size=12))
    fig.update_annotations(font_size=12)
    return fig, output_csv
# Input components for the CSV tab: the uploaded file and the name of the
# column holding the text.
file_input = gr.File(label="Upload CSV file here", show_label=True, file_types=[".csv"])
column_name_input = gr.Textbox(label="Enter the column name containing the text to be analyzed", show_label=True)

# Multi-text tab: CSV + column name go in; a label-frequency chart and the
# annotated CSV come out.
iface3 = gr.Interface(fn=predict_from_csv,
                      inputs=[file_input, column_name_input],
                      outputs=[gr.Plot(label='Frequency of IEQs', show_label=True),
                               gr.File(label='Download output CSV', show_label=True)],
                      title="Multi-text Prediction (CSV)",
                      description='**NOTE:** Please enter the column name containing the text to be analyzed.')
# Combine the three interfaces into a single tabbed app.
demo = gr.TabbedInterface(
    interface_list=[iface1, iface2, iface3],
    tab_names=["General-App-Info", "Single-Text-Prediction", "Multi-Text-Prediction (CSV)"],
    title="Indoor Environmental Quality (IEQ) Text Classifier App",
    theme='soft',
)

# Launch only when run as a script (e.g. `python app.py`). This way,
# importing the module (tests, `gradio` reload mode) does not start a server.
if __name__ == "__main__":
    # The queue lets gr.Progress updates from predict_from_csv reach the UI.
    demo.queue().launch()