Paula Leonova
commited on
Commit
·
a6b5529
1
Parent(s):
0a49db3
Clean up notes
Browse files
app.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
# Reference: https://huggingface.co/spaces/team-zero-shot-nli/zero-shot-nli/blob/main/app.py
|
2 |
|
3 |
from os import write
|
4 |
import pandas as pd
|
@@ -8,7 +7,6 @@ import streamlit as st
|
|
8 |
|
9 |
from models import create_nest_sentences, load_summary_model, summarizer_gen, load_model, classifier_zero
|
10 |
from utils import plot_result, plot_dual_bar_chart, examples_load, example_long_text_load
|
11 |
-
# from utils import plot_result, examples_load, example_long_text_load, to_excel
|
12 |
import json
|
13 |
|
14 |
|
@@ -31,8 +29,6 @@ if __name__ == '__main__':
|
|
31 |
if text_input == display_text:
|
32 |
text_input = example_text
|
33 |
|
34 |
-
# minimum_tokens = 30
|
35 |
-
# maximum_tokens = 100
|
36 |
labels = st.text_input('Possible labels (comma-separated):',ex_labels, max_chars=1000)
|
37 |
labels = list(set([x.strip() for x in labels.strip().split(',') if len(x.strip()) > 0]))
|
38 |
submit_button = st.form_submit_button(label='Submit')
|
@@ -41,8 +37,6 @@ if __name__ == '__main__':
|
|
41 |
if len(labels) == 0:
|
42 |
st.write('Enter some text and at least one possible topic to see predictions.')
|
43 |
|
44 |
-
|
45 |
-
|
46 |
# For each body of text, create text chunks of a certain token size required for the transformer
|
47 |
nested_sentences = create_nest_sentences(document = text_input, token_max_length = 1024)
|
48 |
|
@@ -69,21 +63,17 @@ if __name__ == '__main__':
|
|
69 |
st.markdown(final_summary)
|
70 |
|
71 |
topics, scores = classifier_zero(classifier, sequence=final_summary, labels=labels, multi_class=True)
|
72 |
-
|
73 |
# st.markdown("### Top Label Predictions: Combined Summary")
|
74 |
# plot_result(topics[::-1][:], scores[::-1][:])
|
75 |
-
|
76 |
# st.markdown("### Download Data")
|
77 |
data = pd.DataFrame({'label': topics, 'scores_from_summary': scores})
|
78 |
# st.dataframe(data)
|
79 |
-
|
80 |
# coded_data = base64.b64encode(data.to_csv(index = False). encode ()).decode()
|
81 |
# st.markdown(
|
82 |
# f'<a href="data:file/csv;base64, {coded_data}" download = "data.csv">Download Data</a>',
|
83 |
# unsafe_allow_html = True
|
84 |
# )
|
85 |
|
86 |
-
|
87 |
st.markdown("### Top Label Predictions: Summary & Full Text")
|
88 |
topics_ex_text, scores_ex_text = classifier_zero(classifier, sequence=example_text, labels=labels, multi_class=True)
|
89 |
plot_dual_bar_chart(topics, scores, topics_ex_text, scores_ex_text)
|
|
|
|
|
1 |
|
2 |
from os import write
|
3 |
import pandas as pd
|
|
|
7 |
|
8 |
from models import create_nest_sentences, load_summary_model, summarizer_gen, load_model, classifier_zero
|
9 |
from utils import plot_result, plot_dual_bar_chart, examples_load, example_long_text_load
|
|
|
10 |
import json
|
11 |
|
12 |
|
|
|
29 |
if text_input == display_text:
|
30 |
text_input = example_text
|
31 |
|
|
|
|
|
32 |
labels = st.text_input('Possible labels (comma-separated):',ex_labels, max_chars=1000)
|
33 |
labels = list(set([x.strip() for x in labels.strip().split(',') if len(x.strip()) > 0]))
|
34 |
submit_button = st.form_submit_button(label='Submit')
|
|
|
37 |
if len(labels) == 0:
|
38 |
st.write('Enter some text and at least one possible topic to see predictions.')
|
39 |
|
|
|
|
|
40 |
# For each body of text, create text chunks of a certain token size required for the transformer
|
41 |
nested_sentences = create_nest_sentences(document = text_input, token_max_length = 1024)
|
42 |
|
|
|
63 |
st.markdown(final_summary)
|
64 |
|
65 |
topics, scores = classifier_zero(classifier, sequence=final_summary, labels=labels, multi_class=True)
|
|
|
66 |
# st.markdown("### Top Label Predictions: Combined Summary")
|
67 |
# plot_result(topics[::-1][:], scores[::-1][:])
|
|
|
68 |
# st.markdown("### Download Data")
|
69 |
data = pd.DataFrame({'label': topics, 'scores_from_summary': scores})
|
70 |
# st.dataframe(data)
|
|
|
71 |
# coded_data = base64.b64encode(data.to_csv(index = False). encode ()).decode()
|
72 |
# st.markdown(
|
73 |
# f'<a href="data:file/csv;base64, {coded_data}" download = "data.csv">Download Data</a>',
|
74 |
# unsafe_allow_html = True
|
75 |
# )
|
76 |
|
|
|
77 |
st.markdown("### Top Label Predictions: Summary & Full Text")
|
78 |
topics_ex_text, scores_ex_text = classifier_zero(classifier, sequence=example_text, labels=labels, multi_class=True)
|
79 |
plot_dual_bar_chart(topics, scores, topics_ex_text, scores_ex_text)
|
models.py
CHANGED
@@ -33,7 +33,6 @@ def load_summary_model():
|
|
33 |
summarizer = pipeline(task='summarization', model=model_name)
|
34 |
return summarizer
|
35 |
|
36 |
-
|
37 |
# def load_summary_model():
|
38 |
# model_name = "facebook/bart-large-mnli"
|
39 |
# tokenizer = BartTokenizer.from_pretrained(model_name)
|
@@ -41,7 +40,6 @@ def load_summary_model():
|
|
41 |
# summarizer = pipeline(task='summarization', model=model, tokenizer=tokenizer, framework='pt')
|
42 |
# return summarizer
|
43 |
|
44 |
-
|
45 |
def summarizer_gen(summarizer, sequence:str, maximum_tokens:int, minimum_tokens:int):
|
46 |
output = summarizer(sequence, num_beams=4, max_length=maximum_tokens, min_length=minimum_tokens, do_sample=False)
|
47 |
return output[0].get('summary_text')
|
|
|
33 |
summarizer = pipeline(task='summarization', model=model_name)
|
34 |
return summarizer
|
35 |
|
|
|
36 |
# def load_summary_model():
|
37 |
# model_name = "facebook/bart-large-mnli"
|
38 |
# tokenizer = BartTokenizer.from_pretrained(model_name)
|
|
|
40 |
# summarizer = pipeline(task='summarization', model=model, tokenizer=tokenizer, framework='pt')
|
41 |
# return summarizer
|
42 |
|
|
|
43 |
def summarizer_gen(summarizer, sequence:str, maximum_tokens:int, minimum_tokens:int):
|
44 |
output = summarizer(sequence, num_beams=4, max_length=maximum_tokens, min_length=minimum_tokens, do_sample=False)
|
45 |
return output[0].get('summary_text')
|