Upload 2 files
Browse files- main.py +249 -0
- requirements.txt +6 -0
main.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
# import inflect
|
3 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
4 |
+
import torch
|
5 |
+
import string
|
6 |
+
import plotly.express as px
|
7 |
+
import pandas as pd
|
8 |
+
import nltk
|
9 |
+
from nltk.tokenize import sent_tokenize
|
10 |
+
# sent_tokenize needs the Punkt sentence models. Older NLTK releases load the
# pickled "punkt" package, while newer releases load "punkt_tab" instead, so
# fetch both to keep the app working whichever NLTK version gets installed.
for _punkt_resource in ("punkt", "punkt_tab"):
    nltk.download(_punkt_resource)
|
11 |
+
|
12 |
+
# Note - USE "VBA_venv" environment in the local github folder
|
13 |
+
|
14 |
+
punctuations = string.punctuation
|
15 |
+
|
16 |
+
def prep_text(text):
    """Normalise free text into a single lowercase, space-separated string.

    The input is coerced to ``str`` and split into sentences with
    ``sent_tokenize``. Each sentence is split on whitespace, every word is
    stripped and lowercased, and words found inside the module-level
    ``punctuations`` string are dropped. The cleaned sentences are joined
    with single spaces and surrounding spaces are removed.
    """
    cleaned_sentences = []
    for sentence in sent_tokenize(str(text)):
        lowered = (piece.strip().lower() for piece in sentence.split())
        # NOTE: `punctuations` is a str, so `in` is a substring test: single
        # punctuation marks are dropped, and so is any token that happens to
        # be a contiguous run of string.punctuation (e.g. "()").
        kept = [piece for piece in lowered if piece not in punctuations]
        cleaned_sentences.append(' '.join(kept))
    return ' '.join(cleaned_sentences).strip(' ')
|
28 |
+
|
29 |
+
|
30 |
+
# model name or path to model
|
31 |
+
checkpoint_1 = "Highway/SubCat"
|
32 |
+
|
33 |
+
checkpoint_2 = "Highway/ExtraOver"
|
34 |
+
|
35 |
+
|
36 |
+
@st.cache(allow_output_mutation=True)
def load_model_1():
    """Load the sequence-classification model identified by ``checkpoint_1``.

    Wrapped in ``st.cache`` so Streamlit can reuse the loaded model instead
    of calling ``from_pretrained`` again on every script rerun.
    """
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint_1)
    return model
|
39 |
+
|
40 |
+
|
41 |
+
@st.cache(allow_output_mutation=True)
def load_tokenizer_1():
    """Load the tokenizer identified by ``checkpoint_1``.

    Wrapped in ``st.cache`` so Streamlit can reuse the loaded tokenizer
    instead of calling ``from_pretrained`` again on every script rerun.
    """
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_1)
    return tokenizer
|
44 |
+
|
45 |
+
|
46 |
+
@st.cache(allow_output_mutation=True)
def load_model_2():
    """Load the sequence-classification model identified by ``checkpoint_2``.

    Wrapped in ``st.cache`` so Streamlit can reuse the loaded model instead
    of calling ``from_pretrained`` again on every script rerun.
    """
    model = AutoModelForSequenceClassification.from_pretrained(checkpoint_2)
    return model
|
49 |
+
|
50 |
+
|
51 |
+
@st.cache(allow_output_mutation=True)
def load_tokenizer_2():
    """Load the tokenizer identified by ``checkpoint_2``.

    Wrapped in ``st.cache`` so Streamlit can reuse the loaded tokenizer
    instead of calling ``from_pretrained`` again on every script rerun.
    """
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_2)
    return tokenizer
|
54 |
+
|
55 |
+
|
56 |
+
# Page-level settings, issued before any other element is drawn.
st.set_page_config(
    page_title="Cost Data Classifier",
    page_icon="💷",
    layout="wide",
    initial_sidebar_state="auto",
)

st.title("🚦 AI Infrastructure Cost Data Classifier")

# Collapsed-by-default panel describing what the app is for.
about_panel = st.expander("About this app", expanded=False)
with about_panel:
    st.write(
        """
- Artificial Intelligence (AI) and Machine learning (ML) tool for automatic classification of infrastructure cost data for benchmarking
- Classifies cost descriptions from documents such as Bills of Quantities (BOQs) and Schedule of Rates
- Can be trained to classify granular and itemised cost descriptions into any predefined categories for benchmarking
- Contact research team to discuss your data structures and suitability for the app
- It is best to use this app on a laptop or desktop computer
"""
    )

st.markdown("##### Description")

# Input form: `Text_entry` holds the pasted description and `submitted` is
# the submit button's return value; both are read by the prediction code.
input_form = st.form(key="my_form")
with input_form:
    Text_entry = st.text_area(
        "Paste or type infrastructure cost description in the text box below (i.e., input)"
    )
    submitted = st.form_submit_button(label="👉 Get SubCat and ExtraOver!")
|
81 |
+
|
82 |
+
def _rank_labels(text, tokenizer, model, labels):
    """Score *text* with *model* and return ``(label, probability)`` pairs.

    Probabilities are the softmax of the model logits, rounded to three
    decimals and paired positionally with *labels*. The pairs are returned
    sorted from most to least likely.
    """
    # truncation=True keeps over-long descriptions within the tokenizer's
    # maximum length instead of letting the model fail on them.
    encoded = tokenizer(text, return_tensors="pt", truncation=True)
    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**encoded).logits
    probabilities = [round(p, 3) for p in torch.softmax(logits, dim=1).tolist()[0]]
    scored = dict(zip(labels, probabilities))
    return sorted(scored.items(), key=lambda pair: pair[1], reverse=True)


def _show_prediction(ranked, name, y_axis_title, chart_height):
    """Render the likelihood bar chart and headline metrics for one classifier.

    *ranked* is the sorted ``(label, probability)`` list, *name* is used for
    the section header, dataframe column and messages, *y_axis_title* labels
    the chart's y axis and *chart_height* is the figure height in pixels.
    """
    frame = pd.DataFrame(ranked, columns=[name, "Likelihood"])
    chart_col, _spacer_col, metric_col = st.columns([1.5, 0.5, 1])

    with chart_col:
        st.header(name)
        fig = px.bar(frame, x="Likelihood", y=name, orientation="h")
        fig.update_layout(
            template='ggplot2',
            font=dict(family="Arial", size=14, color="black"),
            autosize=False,
            width=800,
            height=chart_height,
            xaxis_title=f"Likelihood of {name}",
            yaxis_title=y_axis_title,
        )
        tick_font = dict(family='Arial', color='black', size=14)
        fig.update_xaxes(tickangle=0, tickfont=tick_font)
        fig.update_yaxes(tickangle=0, tickfont=tick_font)
        # Changes y_axis, x_axis and subplot title font sizes.
        fig.update_annotations(font_size=14)
        st.plotly_chart(fig, use_container_width=False)

    with metric_col:
        # Empty header keeps the metrics level with the chart header.
        st.header("")
        top_label, top_probability = ranked[0]
        st.metric(f"Predicted {name}", top_label)
        st.metric("Prediction confidence", str(round(top_probability * 100, 1)) + "%")

    st.success(f"Great! {name} successfully predicted. ", icon="✅")


if submitted:
    # Class names for the SubCat model, paired positionally with its outputs.
    label_list_1 = [
        'Arrow, Triangle, Circle, Letter, Numeral, Symbol and Sundries',
        'Binder',
        'Cable',
        'Catman Other Adjustment',
        'Cold Milling',
        'Disposal of Acceptable/Unacceptable Material',
        'Drain/Service Duct In Trench',
        'Erection & Dismantling of Temporary Accommodation/Facilities (All Types)',
        'Excavate And Replace Filter Material/Recycle Filter Material',
        'Excavation',
        'General TM Item',
        'Information boards',
        'Joint/Termination',
        'Line, Ancillary Line, Solid Area',
        'Loop Detector Installation',
        'Minimum Lining Visit Charge',
        'Node Marker',
        'PCC Kerb',
        'Provision of Mobile Welfare Facilities',
        'Removal of Deformable Safety Fence',
        'Removal of Line, Ancillary Line, Solid Area',
        'Removal of Traffic Sign and post(s)',
        'Road Stud',
        'Safety Barrier Or Bifurcation (Non-Concrete)',
        'Servicing of Temporary Accommodation/Facilities (All Types) (day)',
        'Tack Coat',
        'Temporary Road Markings',
        'Thin Surface Course',
        'Traffic Sign - Unknown specification',
        'Vegetation Clearance/Weed Control (m2)',
        'Others',
    ]
    # Class names for the ExtraOver model, paired positionally with its outputs.
    label_list_2 = ["False", "True"]

    # Preprocess once; both classifiers read the same cleaned text.
    joined_clean_sents = prep_text(Text_entry)

    if not joined_clean_sents:
        # Nothing usable was entered, so do not report a prediction.
        st.warning("Please enter a cost description before submitting.")
    else:
        # First prediction
        sorted_preds_1 = _rank_labels(
            joined_clean_sents, load_tokenizer_1(), load_model_1(), label_list_1
        )
        _show_prediction(sorted_preds_1, "SubCatName", "SubCatNames", 500)

        # Second prediction
        sorted_preds_2 = _rank_labels(
            joined_clean_sents, load_tokenizer_2(), load_model_2(), label_list_2
        )
        _show_prediction(sorted_preds_2, "ExtraOver", "ExtraOver", 200)
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformers
|
2 |
+
torch
|
3 |
+
plotly
|
4 |
+
pandas
|
5 |
+
nltk
|
6 |
+
streamlit
|