# Page-capture residue: the hosting page's "Spaces" banner reported "Runtime error" (shown twice).
import re

import streamlit as st
from transformers import AutoModelWithLMHead, AutoTokenizer
from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers import pipeline
# Streamlit's caching guide describes the app script as being re-run on user
# interaction, so the checkpoints are obtained through a cached factory rather
# than re-loaded by bare top-level statements.  Releases that predate
# ``st.cache_resource`` fall back to ``st.experimental_singleton``; if neither
# exists the factory simply runs uncached, which is the original behaviour.
_cache_resource = getattr(st, "cache_resource", None) or getattr(
    st, "experimental_singleton", lambda func: func
)


@_cache_resource
def _load_models():
    """Load every tokenizer/model pair used by this app.

    Returns:
        A tuple ``(model, tokenizer, mrm_tokenizer, mrm_model,
        jules_tokenizer, jules_model)``, loaded in that order.
    """
    headline_ckpt = "Michau/t5-base-en-generate-headline"
    news_ckpt = "mrm8488/t5-base-finetuned-summarize-news"
    small_headline_ckpt = "JulesBelveze/t5-small-headline-generator"
    return (
        T5ForConditionalGeneration.from_pretrained(headline_ckpt),
        T5Tokenizer.from_pretrained(headline_ckpt),
        AutoTokenizer.from_pretrained(news_ckpt),
        # Loaded with the concrete T5 class instead of the deprecated
        # AutoModelWithLMHead factory; the checkpoint is a T5 model.
        T5ForConditionalGeneration.from_pretrained(news_ckpt),
        AutoTokenizer.from_pretrained(small_headline_ckpt),
        T5ForConditionalGeneration.from_pretrained(small_headline_ckpt),
    )


# NOTE(review): nothing visible in this file reads mrm_tokenizer / mrm_model.
# They are kept so the module-level names stay available; drop the pair if no
# other code depends on it, since it costs a full t5-base checkpoint in memory.
(
    model,
    tokenizer,
    mrm_tokenizer,
    mrm_model,
    jules_tokenizer,
    jules_model,
) = _load_models()
def WHITESPACE_HANDLER(k):
    """Trim *k* and collapse every run of whitespace into one space.

    Args:
        k: The text to normalise.

    Returns:
        The stripped text with each whitespace run (spaces, tabs, newlines)
        replaced by a single space.
    """
    # A raw-string pattern avoids the invalid "\s" escape of a plain literal.
    # "\s+" already matches newline runs, so a separate "\n+" pass is not
    # needed to get the same result.
    return re.sub(r"\s+", " ", k.strip())
def generate_title(article):
    """Generate a headline for *article* with the module-level ``model``.

    The article is prefixed with ``"headline: "``, tokenized with truncation
    at 2048 tokens, and decoded from the first sequence returned by a
    3-beam, non-sampling search capped at 50 tokens.

    Args:
        article: The article text.

    Returns:
        The decoded first sequence.  ``decode`` is called with its default
        options, so marker tokens such as ``<pad>`` / ``</s>`` may still be
        present; ``main`` strips those two before display.
    """
    prompt = "headline: " + article
    encoded = tokenizer.encode_plus(
        prompt,
        return_tensors="pt",
        max_length=2048,
        truncation=True,
    )

    search_options = {
        "max_length": 50,
        "num_beams": 3,
        "do_sample": False,
        "early_stopping": False,
    }
    candidates = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        **search_options,
    )

    best_candidate = candidates[0]
    return tokenizer.decode(best_candidate)
def generate_title_2(article):
    """Generate a second headline using the JulesBelveze headline checkpoint.

    The whitespace handler and the generation settings used here (384-token
    padded input, 84-token output, bigram repetition block, 4 beams) are the
    ones paired with ``jules_tokenizer`` / ``jules_model``.  Running them
    through ``tokenizer`` / ``model`` would send un-prefixed text to the same
    checkpoint ``generate_title`` uses and leave the Jules pair unused.

    Args:
        article: The article text.

    Returns:
        The generated headline with special tokens removed.
    """
    encoded = jules_tokenizer(
        [WHITESPACE_HANDLER(article)],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=384,
    )
    output_ids = jules_model.generate(
        input_ids=encoded["input_ids"],
        # The input is padded to max_length, so hand over the mask explicitly
        # rather than relying on generate() to infer which positions are padding.
        attention_mask=encoded["attention_mask"],
        max_length=84,
        no_repeat_ngram_size=2,
        num_beams=4,
    )[0]
    return jules_tokenizer.decode(
        output_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
# Prefer Streamlit's resource cache so the pipeline is built once; older
# releases expose it as ``experimental_singleton``.  If neither exists the
# loader runs uncached, matching the original per-call construction.
_cache_resource = getattr(st, "cache_resource", None) or getattr(
    st, "experimental_singleton", lambda func: func
)


@_cache_resource
def _load_summarizer():
    """Build the ``facebook/bart-large-cnn`` summarization pipeline."""
    return pipeline("summarization", model="facebook/bart-large-cnn")


def generate_summary(article):
    """Summarize *article* with the BART CNN summarization pipeline.

    Args:
        article: The article text; only its first 1024 characters are used.

    Returns:
        The pipeline's result list; ``main`` reads ``[0]['summary_text']``.
    """
    summarizer = _load_summarizer()
    return summarizer(
        article[:1024],
        max_length=130,
        min_length=30,
        do_sample=False,
        # The slice above limits characters, not tokens; ask the tokenizer to
        # truncate as well so text that tokenizes densely cannot exceed the
        # model's input limit.
        truncation=True,
    )
def main():
    """Render the page and, on button press, show two titles and a summary.

    Shows an error instead when the text area is empty or whitespace-only.
    """
    st.title("Text Summarization")
    text = st.text_area("Enter your text here:", "")

    # Nothing to do until the user presses the button.
    if not st.button("Generate Summary"):
        return

    if not text.strip():
        st.error("Please enter some text.")
        return

    # Run all three generators before rendering any result.
    title = generate_title(text)
    title_2 = generate_title_2(text)
    summary = generate_summary(text)

    # generate_title's output can still carry these two marker tokens.
    cleaned_title = title.replace('<pad>', '').replace('</s>', '')

    st.subheader("Generated Title:")
    st.write(cleaned_title)

    st.subheader("Second Title:")
    st.write(title_2)

    st.subheader("Generated Description:")
    st.write(summary[0]['summary_text'])


if __name__ == "__main__":
    main()