Spaces:
Runtime error
Runtime error
Commit
·
cca4ece
1
Parent(s):
9111b95
optimizing app
Browse files
app.py
CHANGED
@@ -14,6 +14,25 @@ _plotly_config={'displayModeBar': False}
|
|
14 |
from sentiment_clf_helper import classify_sentiment,create_onnx_model_sentiment,classify_sentiment_onnx
|
15 |
from zeroshot_clf_helper import zero_shot_classification,create_onnx_model_zs,zero_shot_classification_onnx
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
st.set_page_config( # Alternate names: setup_page, page, layout
|
19 |
layout="wide", # Can be "centered" or "wide". In the future also "dashboard", etc.
|
@@ -97,24 +116,27 @@ if select_task=='README':
|
|
97 |
|
98 |
############### Pre-Download & instantiate objects for sentiment analysis *********************** START **********************
|
99 |
|
100 |
-
|
101 |
-
|
102 |
-
#create model/token dir for sentiment classification
|
103 |
-
create_model_dir(chkpt=sent_chkpt, model_dir=sent_model_dir)
|
104 |
|
105 |
|
106 |
@st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
|
107 |
-
def sentiment_task_selected(task,
|
|
|
|
|
|
|
|
|
|
|
108 |
#model & tokenizer initialization for normal sentiment classification
|
109 |
-
model_sentiment=AutoModelForSequenceClassification.from_pretrained(
|
110 |
-
tokenizer_sentiment=AutoTokenizer.from_pretrained(
|
111 |
|
112 |
# create onnx model for sentiment classification
|
113 |
create_onnx_model_sentiment(_model=model_sentiment, _tokenizer=tokenizer_sentiment)
|
114 |
|
115 |
#create inference session
|
116 |
-
sentiment_session = ort.InferenceSession("
|
117 |
-
sentiment_session_quant = ort.InferenceSession("
|
118 |
|
119 |
return model_sentiment,tokenizer_sentiment,sentiment_session,sentiment_session_quant
|
120 |
|
@@ -123,26 +145,31 @@ def sentiment_task_selected(task,sent_model_dir=sent_model_dir):
|
|
123 |
|
124 |
############### Pre-Download & instantiate objects for Zero shot clf *********************** START **********************
|
125 |
|
126 |
-
|
127 |
-
|
128 |
-
# create model/token dir for zeroshot clf
|
129 |
-
create_model_dir(chkpt=zs_chkpt, model_dir=zs_model_dir)
|
130 |
|
131 |
|
132 |
@st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
|
133 |
-
def zs_task_selected(task,
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
create_onnx_model_zs()
|
140 |
|
141 |
#create inference session from onnx model
|
142 |
-
zs_session = ort.InferenceSession(f"{
|
143 |
-
zs_session_quant = ort.InferenceSession(f"{
|
144 |
|
145 |
-
return
|
146 |
|
147 |
############## Pre-Download & instantiate objects for Zero shot analysis ********************* END **********************************
|
148 |
|
@@ -256,7 +283,7 @@ if select_task == 'Detect Sentiment':
|
|
256 |
if select_task=='Zero Shot Classification':
|
257 |
|
258 |
t1=time.time()
|
259 |
-
|
260 |
t2 = time.time()
|
261 |
st.write(f"Total time to load Model is {(t2-t1)*1000:.1f} ms")
|
262 |
|
@@ -267,29 +294,16 @@ if select_task=='Zero Shot Classification':
|
|
267 |
c1,c2,c3,c4=st.columns(4)
|
268 |
|
269 |
with c1:
|
270 |
-
response1=st.button("
|
271 |
with c2:
|
272 |
-
response2=st.button("ONNX runtime")
|
273 |
-
with c3:
|
274 |
-
|
275 |
-
with c4:
|
276 |
-
|
277 |
|
278 |
-
if any([response1,response2
|
279 |
if response1:
|
280 |
-
start=time.time()
|
281 |
-
df_output = zero_shot_classification(input_texts, input_lables,model=model_zs,tokenizer=tokenizer_zs)
|
282 |
-
end=time.time()
|
283 |
-
st.write("")
|
284 |
-
st.write(f"Time taken for computation {(end-start)*1000:.1f} ms")
|
285 |
-
fig = px.bar(x='Probability',
|
286 |
-
y='labels',
|
287 |
-
text='Probability',
|
288 |
-
data_frame=df_output,
|
289 |
-
title='Zero Shot Normalized Probabilities')
|
290 |
-
|
291 |
-
st.plotly_chart(fig, config=_plotly_config)
|
292 |
-
elif response2:
|
293 |
start = time.time()
|
294 |
df_output=zero_shot_classification_onnx(premise=input_texts,labels=input_lables,_session=zs_session,_tokenizer=tokenizer_zs)
|
295 |
end=time.time()
|
@@ -303,7 +317,7 @@ if select_task=='Zero Shot Classification':
|
|
303 |
title='Zero Shot Normalized Probabilities')
|
304 |
|
305 |
st.plotly_chart(fig,config=_plotly_config)
|
306 |
-
elif
|
307 |
start = time.time()
|
308 |
df_output = zero_shot_classification_onnx(premise=input_texts, labels=input_lables, _session=zs_session_quant,
|
309 |
_tokenizer=tokenizer_zs)
|
@@ -317,53 +331,6 @@ if select_task=='Zero Shot Classification':
|
|
317 |
title='Zero Shot Normalized Probabilities')
|
318 |
|
319 |
st.plotly_chart(fig, config=_plotly_config)
|
320 |
-
elif response4:
|
321 |
-
normal_runtime = []
|
322 |
-
for i in range(100):
|
323 |
-
start = time.time()
|
324 |
-
_ = zero_shot_classification(input_texts, input_lables,model=model_zs,tokenizer=tokenizer_zs)
|
325 |
-
end = time.time()
|
326 |
-
t = (end - start) * 1000
|
327 |
-
normal_runtime.append(t)
|
328 |
-
normal_runtime = np.clip(normal_runtime, 50, 400)
|
329 |
-
|
330 |
-
onnx_runtime = []
|
331 |
-
for i in range(100):
|
332 |
-
start = time.time()
|
333 |
-
_ = zero_shot_classification_onnx(premise=input_texts, labels=input_lables, _session=zs_session,
|
334 |
-
_tokenizer=tokenizer_zs)
|
335 |
-
end = time.time()
|
336 |
-
t = (end - start) * 1000
|
337 |
-
onnx_runtime.append(t)
|
338 |
-
onnx_runtime = np.clip(onnx_runtime, 50, 200)
|
339 |
-
|
340 |
-
onnx_runtime_quant = []
|
341 |
-
for i in range(100):
|
342 |
-
start = time.time()
|
343 |
-
_ = zero_shot_classification_onnx(premise=input_texts, labels=input_lables, _session=zs_session_quant,
|
344 |
-
_tokenizer=tokenizer_zs)
|
345 |
-
end = time.time()
|
346 |
-
|
347 |
-
t = (end - start) * 1000
|
348 |
-
onnx_runtime_quant.append(t)
|
349 |
-
onnx_runtime_quant = np.clip(onnx_runtime_quant, 50, 200)
|
350 |
-
|
351 |
-
temp_df = pd.DataFrame({'Normal Runtime (ms)': normal_runtime,
|
352 |
-
'ONNX Runtime (ms)': onnx_runtime,
|
353 |
-
'ONNX Quant Runtime (ms)': onnx_runtime_quant})
|
354 |
-
|
355 |
-
from plotly.subplots import make_subplots
|
356 |
-
|
357 |
-
fig = make_subplots(rows=1, cols=3, start_cell="bottom-left",
|
358 |
-
subplot_titles=['Normal Runtime', 'ONNX Runtime', 'ONNX Runtime with Quantization'])
|
359 |
-
|
360 |
-
fig.add_trace(go.Histogram(x=temp_df['Normal Runtime (ms)']), row=1, col=1)
|
361 |
-
fig.add_trace(go.Histogram(x=temp_df['ONNX Runtime (ms)']), row=1, col=2)
|
362 |
-
fig.add_trace(go.Histogram(x=temp_df['ONNX Quant Runtime (ms)']), row=1, col=3)
|
363 |
-
fig.update_layout(height=400, width=1000,
|
364 |
-
title_text="10 Simulations of different Runtimes",
|
365 |
-
showlegend=False)
|
366 |
-
st.plotly_chart(fig, config=_plotly_config)
|
367 |
else:
|
368 |
pass
|
369 |
|
|
|
14 |
from sentiment_clf_helper import classify_sentiment,create_onnx_model_sentiment,classify_sentiment_onnx
|
15 |
from zeroshot_clf_helper import zero_shot_classification,create_onnx_model_zs,zero_shot_classification_onnx
|
16 |
|
17 |
+
import yaml
|
18 |
+
def read_yaml(file_path):
    """Load a YAML file and return its parsed content.

    Args:
        file_path: Path of the YAML file to read.

    Returns:
        The object produced by ``yaml.safe_load`` for the file's content.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(file_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
|
21 |
+
|
22 |
+
# Parse the settings file once; the values below are pulled from it.
config = read_yaml('config.yaml')

# Values taken from the SENTIMENT_CLF section.
_sentiment_cfg = config['SENTIMENT_CLF']
sent_chkpt = _sentiment_cfg['sent_chkpt']
sent_mdl_dir = _sentiment_cfg['sent_mdl_dir']
sent_onnx_mdl_dir = _sentiment_cfg['sent_onnx_mdl_dir']
sent_onnx_mdl_name = _sentiment_cfg['sent_onnx_mdl_name']
sent_onnx_quant_mdl_name = _sentiment_cfg['sent_onnx_quant_mdl_name']

# Values taken from the ZEROSHOT_CLF section.
_zeroshot_cfg = config['ZEROSHOT_CLF']
zs_chkpt = _zeroshot_cfg['zs_chkpt']
zs_mdl_dir = _zeroshot_cfg['zs_mdl_dir']
zs_onnx_mdl_dir = _zeroshot_cfg['zs_onnx_mdl_dir']
zs_onnx_mdl_name = _zeroshot_cfg['zs_onnx_mdl_name']
zs_onnx_quant_mdl_name = _zeroshot_cfg['zs_onnx_quant_mdl_name']
|
35 |
+
|
36 |
|
37 |
st.set_page_config( # Alternate names: setup_page, page, layout
|
38 |
layout="wide", # Can be "centered" or "wide". In the future also "dashboard", etc.
|
|
|
116 |
|
117 |
############### Pre-Download & instantiate objects for sentiment analysis *********************** START **********************
|
118 |
|
119 |
+
# #create model/token dir for sentiment classification for faster inference
|
120 |
+
# create_model_dir(chkpt=sent_chkpt, model_dir=sent_mdl_dir)
|
|
|
|
|
121 |
|
122 |
|
123 |
@st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
def sentiment_task_selected(task,
                            sent_chkpt=sent_chkpt,
                            sent_mdl_dir=sent_mdl_dir,
                            sent_onnx_mdl_dir=sent_onnx_mdl_dir,
                            sent_onnx_mdl_name=sent_onnx_mdl_name,
                            sent_onnx_quant_mdl_name=sent_onnx_quant_mdl_name):
    """Instantiate the objects needed for sentiment classification.

    Loads the model and tokenizer for ``sent_chkpt``, passes both to
    ``create_onnx_model_sentiment``, then opens one ONNX Runtime session per
    file under ``sent_onnx_mdl_dir``.

    Args:
        task: Accepted as a call argument; not read inside the body.
        sent_chkpt: Checkpoint passed to both ``from_pretrained`` calls.
        sent_mdl_dir: Accepted as a call argument; not read inside the body.
        sent_onnx_mdl_dir: Directory part of both ONNX file paths.
        sent_onnx_mdl_name: File name opened for the first session.
        sent_onnx_quant_mdl_name: File name opened for the second session.

    Returns:
        Tuple ``(model, tokenizer, session, quant_session)`` in that order.
    """
    # Model and tokenizer for the regular (non-ONNX) classification path.
    model = AutoModelForSequenceClassification.from_pretrained(sent_chkpt)
    tokenizer = AutoTokenizer.from_pretrained(sent_chkpt)

    # The sessions below open files under sent_onnx_mdl_dir, so this call
    # must run before them.
    # NOTE(review): presumably this call is what writes those files - confirm.
    create_onnx_model_sentiment(_model=model, _tokenizer=tokenizer)

    # One inference session per ONNX file.
    onnx_path = f"{sent_onnx_mdl_dir}/{sent_onnx_mdl_name}"
    session = ort.InferenceSession(onnx_path)
    onnx_quant_path = f"{sent_onnx_mdl_dir}/{sent_onnx_quant_mdl_name}"
    quant_session = ort.InferenceSession(onnx_quant_path)

    return model, tokenizer, session, quant_session
|
142 |
|
|
|
145 |
|
146 |
############### Pre-Download & instantiate objects for Zero shot clf *********************** START **********************
|
147 |
|
148 |
+
# # create model/token dir for zeroshot clf
|
149 |
+
# create_model_dir(chkpt=zs_chkpt, model_dir=zs_mdl_dir)
|
|
|
|
|
150 |
|
151 |
|
152 |
@st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=None, ttl=None)
def zs_task_selected(task,
                     zs_chkpt=zs_chkpt,
                     zs_mdl_dir=zs_mdl_dir,
                     zs_onnx_mdl_dir=zs_onnx_mdl_dir,
                     zs_onnx_mdl_name=zs_onnx_mdl_name,
                     zs_onnx_quant_mdl_name=zs_onnx_quant_mdl_name):
    """Instantiate the objects needed for zero-shot classification.

    Loads the tokenizer for ``zs_chkpt``, calls ``create_onnx_model_zs``,
    then opens one ONNX Runtime session per file under ``zs_onnx_mdl_dir``.

    Args:
        task: Accepted as a call argument; not read inside the body.
        zs_chkpt: Checkpoint passed to ``AutoTokenizer.from_pretrained``.
        zs_mdl_dir: Accepted as a call argument; not read inside the body.
        zs_onnx_mdl_dir: Directory part of both ONNX file paths.
        zs_onnx_mdl_name: File name opened for the first session.
        zs_onnx_quant_mdl_name: File name opened for the second session.

    Returns:
        Tuple ``(tokenizer, session, quant_session)`` in that order.
    """
    # Only the tokenizer is loaded here: inference goes through the ONNX
    # sessions, so the transformers model object is not instantiated.
    tokenizer = AutoTokenizer.from_pretrained(zs_chkpt)

    # The sessions below open files under zs_onnx_mdl_dir, so this call
    # must run before them.
    # NOTE(review): presumably this call is what writes those files - confirm.
    create_onnx_model_zs()

    # One inference session per ONNX file.
    onnx_path = f"{zs_onnx_mdl_dir}/{zs_onnx_mdl_name}"
    session = ort.InferenceSession(onnx_path)
    onnx_quant_path = f"{zs_onnx_mdl_dir}/{zs_onnx_quant_mdl_name}"
    quant_session = ort.InferenceSession(onnx_quant_path)

    return tokenizer, session, quant_session
|
173 |
|
174 |
############## Pre-Download & instantiate objects for Zero shot analysis ********************* END **********************************
|
175 |
|
|
|
283 |
if select_task=='Zero Shot Classification':
|
284 |
|
285 |
t1=time.time()
|
286 |
+
tokenizer_zs,zs_session,zs_session_quant = zs_task_selected(task=select_task)
|
287 |
t2 = time.time()
|
288 |
st.write(f"Total time to load Model is {(t2-t1)*1000:.1f} ms")
|
289 |
|
|
|
294 |
c1,c2,c3,c4=st.columns(4)
|
295 |
|
296 |
with c1:
|
297 |
+
response1=st.button("ONNX runtime")
|
298 |
with c2:
|
299 |
+
response2=st.button("ONNX runtime Quantized")
|
300 |
+
# with c3:
|
301 |
+
# response3=st.button("ONNX runtime with Quantization")
|
302 |
+
# with c4:
|
303 |
+
# response4 = st.button("Simulate 10 runs each runtime")
|
304 |
|
305 |
+
if any([response1,response2]):
|
306 |
if response1:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
307 |
start = time.time()
|
308 |
df_output=zero_shot_classification_onnx(premise=input_texts,labels=input_lables,_session=zs_session,_tokenizer=tokenizer_zs)
|
309 |
end=time.time()
|
|
|
317 |
title='Zero Shot Normalized Probabilities')
|
318 |
|
319 |
st.plotly_chart(fig,config=_plotly_config)
|
320 |
+
elif response2:
|
321 |
start = time.time()
|
322 |
df_output = zero_shot_classification_onnx(premise=input_texts, labels=input_lables, _session=zs_session_quant,
|
323 |
_tokenizer=tokenizer_zs)
|
|
|
331 |
title='Zero Shot Normalized Probabilities')
|
332 |
|
333 |
st.plotly_chart(fig, config=_plotly_config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
334 |
else:
|
335 |
pass
|
336 |
|