luminoria committed on
Commit
bc70267
1 Parent(s): ec34458

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -15
app.py CHANGED
@@ -2,6 +2,7 @@ from transformers import T5ForConditionalGeneration,T5Tokenizer
2
  from transformers import AutoModelWithLMHead, AutoTokenizer
3
  from transformers import pipeline
4
  import streamlit as st
 
5
 
6
  model = T5ForConditionalGeneration.from_pretrained("Michau/t5-base-en-generate-headline")
7
  tokenizer = T5Tokenizer.from_pretrained("Michau/t5-base-en-generate-headline")
@@ -9,37 +10,55 @@ tokenizer = T5Tokenizer.from_pretrained("Michau/t5-base-en-generate-headline")
9
  mrm_tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-summarize-news")
10
  mrm_model = AutoModelWithLMHead.from_pretrained("mrm8488/t5-base-finetuned-summarize-news")
11
 
 
 
 
 
12
 
13
  def generate_title(article):
14
  text = "headline: " + article
15
  encoding = tokenizer.encode_plus(text, return_tensors = "pt", max_length=2048, truncation=True)
16
  input_ids = encoding["input_ids"]
17
  attention_masks = encoding["attention_mask"]
18
-
19
  beam_outputs = model.generate(
20
- input_ids = input_ids,
21
- attention_mask = attention_masks,
22
- max_length = 50,
23
- num_beams = 3,
24
- do_sample = True,
25
- top_k=10,
26
- early_stopping = False,
27
- )
28
 
29
  return tokenizer.decode(beam_outputs[0])
30
 
31
- # def generate_summary(article):
32
- # input_ids = mrm_tokenizer.encode(article, return_tensors="pt", add_special_tokens=True)
33
-
34
- # generated_ids = mrm_model.generate(input_ids=input_ids, num_beams=3, max_length=200, repetition_penalty=2.5, length_penalty=1.0, early_stopping=False, truncation=True)
 
 
 
 
35
 
36
- # preds = [mrm_tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- # return preds[0]
39
  def generate_summary(article):
40
  article = article[:1024]
41
  summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
42
  return summarizer(article, max_length=130, min_length=30, do_sample=False)
 
43
  def main():
44
  st.title("Text Summarization")
45
  text = st.text_area("Enter your text here:", "")
@@ -49,11 +68,15 @@ def main():
49
  st.error("Please enter some text.")
50
  else:
51
  title = generate_title(text)
 
52
  summary = generate_summary(text)
53
  # summary = summary[0]['summary_text']
54
 
55
  st.subheader("Generated Title:")
56
  st.write(title.replace('<pad>', '').replace('</s>', ''))
 
 
 
57
 
58
  st.subheader("Generated Description:")
59
 
 
2
  from transformers import AutoModelWithLMHead, AutoTokenizer
3
  from transformers import pipeline
4
  import streamlit as st
5
+ import re
6
 
7
  model = T5ForConditionalGeneration.from_pretrained("Michau/t5-base-en-generate-headline")
8
  tokenizer = T5Tokenizer.from_pretrained("Michau/t5-base-en-generate-headline")
 
10
  mrm_tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-summarize-news")
11
  mrm_model = AutoModelWithLMHead.from_pretrained("mrm8488/t5-base-finetuned-summarize-news")
12
 
13
+ jules_tokenizer = AutoTokenizer.from_pretrained("JulesBelveze/t5-small-headline-generator")
14
+ jules_model = T5ForConditionalGeneration.from_pretrained("JulesBelveze/t5-small-headline-generator")
15
+ # rouge = Rouge()
16
# Compiled once at import time; a raw string avoids the invalid "\s" escape
# that a plain string literal triggers on recent Python versions.
_WHITESPACE_RUN = re.compile(r"\s+")


def WHITESPACE_HANDLER(k):
    """Return *k* stripped, with every run of whitespace collapsed to one space.

    Newlines are whitespace, so a single ``\\s+`` substitution covers them;
    no separate newline pass is needed.
    """
    return _WHITESPACE_RUN.sub(" ", k.strip())
17
 
18
  def generate_title(article):
19
  text = "headline: " + article
20
  encoding = tokenizer.encode_plus(text, return_tensors = "pt", max_length=2048, truncation=True)
21
  input_ids = encoding["input_ids"]
22
  attention_masks = encoding["attention_mask"]
 
23
  beam_outputs = model.generate(
24
+ input_ids = input_ids,
25
+ attention_mask = attention_masks,
26
+ max_length = 50,
27
+ num_beams = 3,
28
+ do_sample = False,
29
+ # top_k=10,
30
+ early_stopping = False,
31
+ )
32
 
33
  return tokenizer.decode(beam_outputs[0])
34
 
35
def generate_title_2(article):
    """Generate an alternative headline with the JulesBelveze T5-small model.

    The article text is whitespace-normalised with ``WHITESPACE_HANDLER``,
    tokenized to a fixed length of 384 tokens (padded / truncated), and a
    headline of at most 84 tokens is produced with 4-beam search and no
    repeated bigrams.

    Returns the decoded headline with special tokens removed.
    """
    # Use the tokenizer/model pair loaded specifically for this generator
    # (jules_tokenizer / jules_model). Using the global ``tokenizer`` and
    # ``model`` here would run the same headline model as generate_title(),
    # making the "second" title redundant.
    encoded = jules_tokenizer(
        [WHITESPACE_HANDLER(article)],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=384,
    )

    output_ids = jules_model.generate(
        input_ids=encoded["input_ids"],
        # The input is padded to max_length, so pass the mask explicitly
        # to make sure padded positions are ignored.
        attention_mask=encoded["attention_mask"],
        max_length=84,
        no_repeat_ngram_size=2,
        num_beams=4,
    )[0]

    return jules_tokenizer.decode(
        output_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
56
 
 
57
  def generate_summary(article):
58
  article = article[:1024]
59
  summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
60
  return summarizer(article, max_length=130, min_length=30, do_sample=False)
61
+
62
  def main():
63
  st.title("Text Summarization")
64
  text = st.text_area("Enter your text here:", "")
 
68
  st.error("Please enter some text.")
69
  else:
70
  title = generate_title(text)
71
+ title_2 = generate_title_2(text)
72
  summary = generate_summary(text)
73
  # summary = summary[0]['summary_text']
74
 
75
  st.subheader("Generated Title:")
76
  st.write(title.replace('<pad>', '').replace('</s>', ''))
77
+
78
+ st.subheader("Second Title:")
79
+ st.write(title_2)
80
 
81
  st.subheader("Generated Description:")
82