import nltk
import torch
from summarizer import Summarizer
from sumy.nlp.tokenizers import Tokenizer
from sumy.parsers.plaintext import PlaintextParser
from sumy.summarizers.lsa import LsaSummarizer
from sumy.summarizers.lex_rank import LexRankSummarizer
from sumy.summarizers.sum_basic import SumBasicSummarizer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Punkt sentence-tokenizer data is needed by sumy's Tokenizer.
nltk.download('punkt')


def extractive(method, file):
    """Return a five-sentence extractive summary of `file` (a plain-text string)
    produced by the given sumy summarizer instance."""
    document = PlaintextParser(file, Tokenizer("english")).document
    sentences = [str(sentence) for sentence in method(document, 5)]
    return " ".join(sentences)


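# Usage sketch (illustrative only): extractive() takes any sumy summarizer instance
# plus the raw document text; "judgment.txt" below is an assumed sample path, not a
# file shipped with this app.
#
#   with open("judgment.txt") as fh:
#       print(extractive(LsaSummarizer(), fh.read()))

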
def summarize(file, model):
    """Read the uploaded file and summarize it with the model/method named by `model`."""

    with open(file.name) as f:
        doc = f.read()

    # Abstractive summarization with Pegasus checkpoints.
    if model in ("Pegasus", "Distill"):
        checkpoint = ("google/pegasus-billsum" if model == "Pegasus"
                      else "sshleifer/distill-pegasus-cnn-16-4")
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        seq2seq = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
        inputs = tokenizer(doc,
                           max_length=1024,
                           truncation=True,
                           return_tensors="pt")

        summary_ids = seq2seq.generate(inputs["input_ids"])
        summary = tokenizer.batch_decode(summary_ids,
                                         skip_special_tokens=True,
                                         clean_up_tokenization_spaces=False)[0]

    # Abstractive summarization with Longformer Encoder-Decoder (LED) checkpoints.
    elif model in ("LEDBill", "ILC"):
        checkpoint = "d0r1h/LEDBill" if model == "LEDBill" else "d0r1h/led-base-ilc"
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        seq2seq = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, return_dict_in_generate=True)

        input_ids = tokenizer(doc, return_tensors="pt").input_ids
        # Global attention on the first token, as recommended for LED.
        global_attention_mask = torch.zeros_like(input_ids)
        global_attention_mask[:, 0] = 1

        sequences = seq2seq.generate(input_ids, global_attention_mask=global_attention_mask).sequences
        summary = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]

    # Extractive summarization with sumy.
    elif model == "TextRank":
        # Note: the "TextRank" option is backed by sumy's LexRankSummarizer.
        summary = extractive(LexRankSummarizer(), doc)

    elif model == "SumBasic":
        summary = extractive(SumBasicSummarizer(), doc)

    elif model == "Lsa":
        summary = extractive(LsaSummarizer(), doc)

    # Extractive summarization with bert-extractive-summarizer.
    elif model == "BERT":
        modelbert = Summarizer('distilbert-base-uncased', hidden=[-1, -2], hidden_concat=True)
        result = modelbert(doc)
        summary = ''.join(result)

    else:
        raise ValueError(f"Unknown model/method: {model}")

    return summary
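

# Minimal command-line sketch (an assumption for illustration; the app may instead pass
# a UI file object, which also exposes a .name attribute). The model names mirror the
# branches in summarize() above.
if __name__ == "__main__":
    import sys

    choice = sys.argv[2] if len(sys.argv) > 2 else "Lsa"
    with open(sys.argv[1]) as fh:
        print(summarize(fh, choice))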