from transformers import BartTokenizer, BartForConditionalGeneration
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


def load_summarizer(model_code):
    """Load a pretrained summarization model and its tokenizer by short code."""
    name_dict = {
        "bart": "facebook/bart-large-cnn",
        "distill-bart": "sshleifer/distilbart-cnn-12-6",
        "roberta": "google/roberta2roberta_L-24_cnn_daily_mail",
        "pegasus": "google/pegasus-cnn_dailymail",
    }
    model_name = name_dict[model_code.lower()]
    model, tokenizer = None, None
    # "bart" also matches the distilled checkpoint ("distilbart").
    if "bart" in model_name:
        tokenizer = BartTokenizer.from_pretrained(model_name)
        model = BartForConditionalGeneration.from_pretrained(model_name)
    if "pegasus" in model_name:
        tokenizer = PegasusTokenizer.from_pretrained(model_name)
        model = PegasusForConditionalGeneration.from_pretrained(model_name)
    if "roberta" in model_name:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return model, tokenizer


def summarize_input(
    input_article,
    model,
    tokenizer,
    max_length=150,
    min_length=50,
    num_beams=3,
    length_penalty=0.5,
    no_repeat_ngram_size=3,
):
    """Generate an abstractive summary of a single article with beam search."""
    # Tokenize the article, truncating it to the model's maximum input length.
    text_input_ids = tokenizer.batch_encode_plus(
        [input_article],
        return_tensors="pt",
        truncation=True,
        max_length=tokenizer.model_max_length,
    )["input_ids"].to("cpu")
    summary_ids = model.generate(
        text_input_ids,
        num_beams=int(num_beams),
        length_penalty=float(length_penalty),
        max_length=int(max_length),
        min_length=int(min_length),
        no_repeat_ngram_size=int(no_repeat_ngram_size),
    )
    summary_txt = tokenizer.decode(summary_ids.squeeze(), skip_special_tokens=True)
    # Pegasus CNN/DailyMail checkpoints emit "<n>" as a sentence separator token.
    return summary_txt.replace("<n>", "")
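

# Minimal usage sketch for running this module directly (assumptions: the
# "distill-bart" choice and the placeholder article text below are illustrative
# only and not part of the Space's app code).
if __name__ == "__main__":
    model, tokenizer = load_summarizer("distill-bart")
    sample_article = (
        "Replace this placeholder text with a real news article "
        "to produce a meaningful summary."
    )
    print(summarize_input(sample_article, model, tokenizer))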