ashhadahsan commited on
Commit
da5430f
Β·
1 Parent(s): adf26de

Create pages/1_πŸ“ˆ_predict.py

Browse files
Files changed (1) hide show
  1. pages/1_πŸ“ˆ_predict.py +560 -0
pages/1_πŸ“ˆ_predict.py ADDED
@@ -0,0 +1,560 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from transformers import pipeline
4
+ from stqdm import stqdm
5
+ from simplet5 import SimpleT5
6
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
+ from transformers import BertTokenizer, TFBertForSequenceClassification
8
+ from datetime import datetime
9
+ import logging
10
+ from transformers import TextClassificationPipeline
11
+ import gc
12
+ from datasets import load_dataset
13
+ from utils.openllmapi.api import ChatBot
14
+ from utils.openllmapi.exceptions import *
15
+ import time
16
+ from typing import List
17
+ from collections import OrderedDict
18
+
19
+ tokenizer_kwargs = dict(
20
+ max_length=128,
21
+ truncation=True,
22
+ padding=True,
23
+ )
24
+ SLEEP = 2
25
+
26
+
27
+ def cleanMemory(obj: TextClassificationPipeline):
28
+ del obj
29
+ gc.collect()
30
+
31
+
32
+ @st.cache_data
33
+ def getAllCats():
34
+ data = load_dataset("ashhadahsan/amazon_theme")
35
+ data = data["train"].to_pandas()
36
+ labels = [x for x in list(set(data.iloc[:, 1].values.tolist())) if x != "Unknown"]
37
+ del data
38
+ return labels
39
+
40
+
41
+ @st.cache_data
42
+ def getAllSubCats():
43
+ data = load_dataset("ashhadahsan/amazon_theme")
44
+ data = data["train"].to_pandas()
45
+ labels = [x for x in list(set(data.iloc[:, 1].values.tolist())) if x != "Unknown"]
46
+ del data
47
+ return labels
48
+
49
+
50
+ def assignHF(bot, what: str, to: str, old: List):
51
+ try:
52
+ old = ", ".join(old)
53
+ message_content = bot.chat(
54
+ f"""'Assign a one-line {what} to this summary of the text of a review
55
+ {to}
56
+ already assigned themes are , {old}
57
+ theme""",
58
+ )
59
+ try:
60
+ return message_content.split(":")[1].strip()
61
+ except:
62
+ return message_content.strip()
63
+ except ChatError:
64
+ return ""
65
+
66
+
67
+ @st.cache_resource
68
+ def loadZeroShotClassification():
69
+ classifierzero = pipeline(
70
+ "zero-shot-classification", model="facebook/bart-large-mnli"
71
+ )
72
+ return classifierzero
73
+
74
+
75
+ def assignZeroShot(zero, to: str, old: List):
76
+ assigned = zero(to, old)
77
+ assigneddict = dict(zip(assigned["labels"], assigned["scores"]))
78
+ od = OrderedDict(sorted(assigneddict.items(), key=lambda x: x[1], reverse=True))
79
+ return [od.keys()][0]
80
+
81
+
82
+ date = datetime.now().strftime(r"%Y-%m-%d")
83
+
84
+
85
+ @st.cache_resource
86
+ def load_t5() -> (AutoModelForSeq2SeqLM, AutoTokenizer):
87
+ model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
88
+
89
+ tokenizer = AutoTokenizer.from_pretrained("t5-base")
90
+ return model, tokenizer
91
+
92
+
93
+ @st.cache_resource
94
+ def summarizationModel():
95
+ return pipeline("summarization", model="my_awesome_sum/")
96
+
97
+
98
+ @st.cache_resource
99
+ def convert_df(df: pd.DataFrame):
100
+ # IMPORTANT: Cache the conversion to prevent computation on every rerun
101
+ return df.to_csv(index=False).encode("utf-8")
102
+
103
+
104
+ # @st.cache(allow_output_mutation=True, suppress_st_warning=True)
105
+ # @st.cache_resource
106
+ def load_one_line_summarizer(model):
107
+ return model.load_model("t5", "snrspeaks/t5-one-line-summary")
108
+
109
+
110
+ @st.cache_resource
111
+ def classify_theme() -> TextClassificationPipeline:
112
+ tokenizer = BertTokenizer.from_pretrained(
113
+ "ashhadahsan/amazon-theme-bert-base-finetuned"
114
+ )
115
+ model = TFBertForSequenceClassification.from_pretrained(
116
+ "ashhadahsan/amazon-theme-bert-base-finetuned"
117
+ )
118
+ pipeline = TextClassificationPipeline(
119
+ model=model, tokenizer=tokenizer, top_k=1, **tokenizer_kwargs
120
+ )
121
+ return pipeline
122
+
123
+
124
+ @st.cache_resource
125
+ def classify_sub_theme() -> TextClassificationPipeline:
126
+ tokenizer = BertTokenizer.from_pretrained(
127
+ "ashhadahsan/amazon-subtheme-bert-base-finetuned"
128
+ )
129
+ model = TFBertForSequenceClassification.from_pretrained(
130
+ "ashhadahsan/amazon-subtheme-bert-base-finetuned"
131
+ )
132
+ pipeline = TextClassificationPipeline(
133
+ model=model, tokenizer=tokenizer, top_k=1, **tokenizer_kwargs
134
+ )
135
+ return pipeline
136
+
137
+
138
+ st.set_page_config(layout="wide", page_title="Amazon Review | Summarizer")
139
+ st.title("Amazon Review Summarizer")
140
+
141
+ uploaded_file = st.file_uploader("Choose a file", type=["xlsx", "xls", "csv"])
142
+
143
+ try:
144
+ bot = ChatBot(
145
+ cookies={
146
+ "hf-chat": st.secrets["hf-chat"],
147
+ "token": st.secrets["token"],
148
+ }
149
+ )
150
+ except ChatBotInitError as e:
151
+ print(e)
152
+
153
+ summarizer_option = st.selectbox(
154
+ "Select Summarizer",
155
+ ("Custom trained on the dataset", "t5-base", "t5-one-line-summary"),
156
+ )
157
+ col1, col2, col3 = st.columns([1, 1, 1])
158
+
159
+ with col1:
160
+ summary_yes = st.checkbox("Summrization", value=False)
161
+
162
+ with col2:
163
+ classification = st.checkbox("Classify Category", value=True)
164
+
165
+ with col3:
166
+ sub_theme = st.checkbox("Sub theme classification", value=True)
167
+
168
+ treshold = st.slider(
169
+ label="Model Confidence value",
170
+ min_value=0.1,
171
+ max_value=0.8,
172
+ step=0.1,
173
+ value=0.6,
174
+ help="The confidence value of the model",
175
+ )
176
+
177
+ ps = st.empty()
178
+
179
+ if st.button("Process", type="primary"):
180
+ themes = getAllCats()
181
+ subthemes = getAllSubCats()
182
+ # st.write(themes)
183
+
184
+ oneline = SimpleT5()
185
+ load_one_line_summarizer(model=oneline)
186
+ zeroline = loadZeroShotClassification()
187
+
188
+ cancel_button = st.empty()
189
+ cancel_button2 = st.empty()
190
+ cancel_button3 = st.empty()
191
+ if uploaded_file is not None:
192
+ if uploaded_file.name.split(".")[-1] in ["xls", "xlsx"]:
193
+ df = pd.read_excel(uploaded_file, engine="openpyxl")
194
+ if uploaded_file.name.split(".")[-1] in [".csv"]:
195
+ df = pd.read_csv(uploaded_file)
196
+ columns = df.columns.values.tolist()
197
+ columns = [x.lower() for x in columns]
198
+ df.columns = columns
199
+ print(summarizer_option)
200
+ outputdf = pd.DataFrame()
201
+ try:
202
+ text = df["text"].values.tolist()
203
+ outputdf["text"] = text
204
+ if summarizer_option == "Custom trained on the dataset":
205
+ if summary_yes:
206
+ model = summarizationModel()
207
+
208
+ progress_text = "Summarization in progress. Please wait."
209
+ summary = []
210
+
211
+ for x in stqdm(range(len(text))):
212
+ if cancel_button.button("Cancel", key=x):
213
+ del model
214
+ break
215
+ try:
216
+ summary.append(
217
+ model(
218
+ f"summarize: {text[x]}",
219
+ max_length=50,
220
+ early_stopping=True,
221
+ )[0]["summary_text"]
222
+ )
223
+ except:
224
+ pass
225
+ outputdf["summary"] = summary
226
+ del model
227
+ if classification:
228
+ themePipe = classify_theme()
229
+ classes = []
230
+ classesUnlabel = []
231
+ classesUnlabelZero = []
232
+ for x in stqdm(
233
+ text,
234
+ desc="Assigning Themes ...",
235
+ total=len(text),
236
+ colour="#BF1A1A",
237
+ ):
238
+ output = themePipe(x)[0][0]["label"]
239
+ classes.append(output)
240
+ score = round(themePipe(x)[0][0]["score"], 2)
241
+ if score <= treshold:
242
+ onelineoutput=oneline.predict(x)[0]
243
+ time.sleep(SLEEP)
244
+ print("hit")
245
+ classesUnlabel.append(
246
+ assignHF(
247
+ bot=bot,
248
+ what="theme",
249
+ to=onelineoutput,
250
+ old=themes,
251
+ )
252
+ )
253
+ classesUnlabelZero.append(
254
+ assignZeroShot(
255
+ zero=zeroline, to=onelineoutput, old=themes
256
+ )
257
+ )
258
+
259
+ else:
260
+ classesUnlabel.append("")
261
+ classesUnlabelZero.append("")
262
+
263
+ outputdf["Review Theme"] = classes
264
+ outputdf["Review Theme-issue-new"] = classesUnlabel
265
+ outputdf["Review SubTheme-issue-zero"] = classesUnlabelZero
266
+ cleanMemory(themePipe)
267
+ if sub_theme:
268
+ subThemePipe = classify_sub_theme()
269
+ classes = []
270
+ classesUnlabel = []
271
+ classesUnlabelZero = []
272
+ for x in stqdm(
273
+ text,
274
+ desc="Assigning Subthemes ...",
275
+ total=len(text),
276
+ colour="green",
277
+ ):
278
+ output = subThemePipe(x)[0][0]["label"]
279
+ classes.append(output)
280
+ score = round(subThemePipe(x)[0][0]["score"], 2)
281
+ if score <= treshold:
282
+ onelineoutput=oneline.predict(x)[0]
283
+
284
+ time.sleep(SLEEP)
285
+
286
+ print("hit")
287
+ classesUnlabel.append(
288
+ assignHF(
289
+ bot=bot,
290
+ what="subtheme",
291
+ to=onelineoutput,
292
+ old=subthemes,
293
+ )
294
+ )
295
+ classesUnlabelZero.append(
296
+ assignZeroShot(
297
+ zero=zeroline,
298
+ to=onelineoutput,
299
+ old=subthemes,
300
+ )
301
+ )
302
+
303
+ else:
304
+ classesUnlabel.append("")
305
+ classesUnlabelZero.append("")
306
+
307
+ outputdf["Review SubTheme"] = classes
308
+ outputdf["Review SubTheme-issue-new"] = classesUnlabel
309
+ outputdf["Review SubTheme-issue-zero"] = classesUnlabelZero
310
+
311
+ cleanMemory(subThemePipe)
312
+
313
+ csv = convert_df(outputdf)
314
+ st.download_button(
315
+ label="Download output as CSV",
316
+ data=csv,
317
+ file_name=f"{summarizer_option}_{date}_df.csv",
318
+ mime="text/csv",
319
+ use_container_width=True,
320
+ )
321
+ if summarizer_option == "t5-base":
322
+ if summary_yes:
323
+ model, tokenizer = load_t5()
324
+ summary = []
325
+ for x in stqdm(range(len(text))):
326
+ if cancel_button2.button("Cancel", key=x):
327
+ del model, tokenizer
328
+ break
329
+ tokens_input = tokenizer.encode(
330
+ "summarize: " + text[x],
331
+ return_tensors="pt",
332
+ max_length=tokenizer.model_max_length,
333
+ truncation=True,
334
+ )
335
+ summary_ids = model.generate(
336
+ tokens_input,
337
+ min_length=80,
338
+ max_length=150,
339
+ length_penalty=20,
340
+ num_beams=2,
341
+ )
342
+ summary_gen = tokenizer.decode(
343
+ summary_ids[0], skip_special_tokens=True
344
+ )
345
+ summary.append(summary_gen)
346
+ del model, tokenizer
347
+ outputdf["summary"] = summary
348
+
349
+ if classification:
350
+ themePipe = classify_theme()
351
+ classes = []
352
+ classesUnlabel = []
353
+ classesUnlabelZero = []
354
+ for x in stqdm(
355
+ text, desc="Assigning Themes ...", total=len(text), colour="red"
356
+ ):
357
+ output = themePipe(x)[0][0]["label"]
358
+ classes.append(output)
359
+ score = round(themePipe(x)[0][0]["score"], 2)
360
+ if score <= treshold:
361
+ onelineoutput=oneline.predict(x)[0]
362
+
363
+ print("hit")
364
+ time.sleep(SLEEP)
365
+
366
+ classesUnlabel.append(
367
+ assignHF(
368
+ bot=bot,
369
+ what="theme",
370
+ to=onelineoutput
371
+ old=themes,
372
+ )
373
+ )
374
+ classesUnlabelZero.append(
375
+ assignZeroShot(
376
+ zero=zeroline, to=onelineoutput, old=themes
377
+ )
378
+ )
379
+
380
+ else:
381
+ classesUnlabel.append("")
382
+ classesUnlabelZero.append("")
383
+ outputdf["Review Theme"] = classes
384
+ outputdf["Review Theme-issue-new"] = classesUnlabel
385
+ outputdf["Review SubTheme-issue-zero"] = classesUnlabelZero
386
+ cleanMemory(themePipe)
387
+
388
+ if sub_theme:
389
+ subThemePipe = classify_sub_theme()
390
+ classes = []
391
+ classesUnlabelZero = []
392
+
393
+ for x in stqdm(
394
+ text,
395
+ desc="Assigning Subthemes ...",
396
+ total=len(text),
397
+ colour="green",
398
+ ):
399
+ output = subThemePipe(x)[0][0]["label"]
400
+ classes.append(output)
401
+ score = round(subThemePipe(x)[0][0]["score"], 2)
402
+ if score <= treshold:
403
+ onelineoutput=oneline.predict(x)[0]
404
+
405
+ time.sleep(SLEEP)
406
+ print("hit")
407
+ classesUnlabel.append(
408
+ assignHF(
409
+ bot=bot,
410
+ what="subtheme",
411
+ to=onelineoutput,
412
+ old=subthemes,
413
+ )
414
+ )
415
+ classesUnlabelZero.append(
416
+ assignZeroShot(
417
+ zero=zeroline,
418
+ to=onelineoutput,
419
+ old=subthemes,
420
+ )
421
+ )
422
+
423
+ else:
424
+ classesUnlabel.append("")
425
+ classesUnlabelZero.append("")
426
+
427
+ outputdf["Review SubTheme"] = classes
428
+ outputdf["Review SubTheme-issue-new"] = classesUnlabel
429
+ outputdf["Review SubTheme-issue-zero"] = classesUnlabelZero
430
+
431
+ cleanMemory(subThemePipe)
432
+
433
+ csv = convert_df(outputdf)
434
+ st.download_button(
435
+ label="Download output as CSV",
436
+ data=csv,
437
+ file_name=f"{summarizer_option}_{date}_df.csv",
438
+ mime="text/csv",
439
+ use_container_width=True,
440
+ )
441
+
442
+ if summarizer_option == "t5-one-line-summary":
443
+ if summary_yes:
444
+ model = SimpleT5()
445
+ load_one_line_summarizer(model=model)
446
+
447
+ summary = []
448
+ for x in stqdm(range(len(text))):
449
+ if cancel_button3.button("Cancel", key=x):
450
+ del model
451
+ break
452
+ try:
453
+ summary.append(model.predict(text[x])[0])
454
+ except:
455
+ pass
456
+ outputdf["summary"] = summary
457
+ del model
458
+
459
+ if classification:
460
+ themePipe = classify_theme()
461
+ classes = []
462
+ classesUnlabel = []
463
+ classesUnlabelZero = []
464
+ for x in stqdm(
465
+ text, desc="Assigning Themes ...", total=len(text), colour="red"
466
+ ):
467
+ output = themePipe(x)[0][0]["label"]
468
+ classes.append(output)
469
+ score = round(themePipe(x)[0][0]["score"], 2)
470
+ if score <= treshold:
471
+ onelineoutput=oneline.predict(x)[0]
472
+
473
+ time.sleep(SLEEP)
474
+
475
+ print("hit")
476
+ classesUnlabel.append(
477
+ assignHF(
478
+ bot=bot,
479
+ what="theme",
480
+ to=onelineoutput,
481
+ old=themes,
482
+ )
483
+ )
484
+ classesUnlabelZero.append(
485
+ assignZeroShot(
486
+ zero=zeroline, to=onelineoutput, old=themes
487
+ )
488
+ )
489
+
490
+ else:
491
+ classesUnlabel.append("")
492
+ classesUnlabelZero.append("")
493
+ outputdf["Review Theme"] = classes
494
+ outputdf["Review Theme-issue-new"] = classesUnlabel
495
+ outputdf["Review SubTheme-issue-zero"] = classesUnlabelZero
496
+
497
+ if sub_theme:
498
+ subThemePipe = classify_sub_theme()
499
+ classes = []
500
+ classesUnlabelZero = []
501
+
502
+ for x in stqdm(
503
+ text,
504
+ desc="Assigning Subthemes ...",
505
+ total=len(text),
506
+ colour="green",
507
+ ):
508
+ output = subThemePipe(x)[0][0]["label"]
509
+ classes.append(output)
510
+ score = round(subThemePipe(x)[0][0]["score"], 2)
511
+ if score <= treshold:
512
+ print("hit")
513
+ onelineoutput=oneline.predict(x)[0]
514
+
515
+ time.sleep(SLEEP)
516
+ classesUnlabel.append(
517
+ assignHF(
518
+ bot=bot,
519
+ what="subtheme",
520
+ to=onelineoutput,
521
+ old=subthemes,
522
+ )
523
+ )
524
+ classesUnlabelZero.append(
525
+ assignZeroShot(
526
+ zero=zeroline,
527
+ to=onelineoutput,
528
+ old=subthemes,
529
+ )
530
+ )
531
+
532
+ else:
533
+ classesUnlabel.append("")
534
+ classesUnlabelZero.append("")
535
+
536
+ outputdf["Review SubTheme"] = classes
537
+ outputdf["Review SubTheme-issue-new"] = classesUnlabel
538
+ outputdf["Review SubTheme-issue-zero"] = classesUnlabelZero
539
+
540
+ cleanMemory(subThemePipe)
541
+
542
+ csv = convert_df(outputdf)
543
+ st.download_button(
544
+ label="Download output as CSV",
545
+ data=csv,
546
+ file_name=f"{summarizer_option}_{date}_df.csv",
547
+ mime="text/csv",
548
+ use_container_width=True,
549
+ )
550
+
551
+ except KeyError as e:
552
+ st.error(
553
+ "Please Make sure that your data must have a column named text",
554
+ icon="🚨",
555
+ )
556
+ st.info("Text column must have amazon reviews", icon="ℹ️")
557
+ # st.exception(e)
558
+
559
+ except BaseException as e:
560
+ logging.exception("An exception was occurred")