import nltk
import torch
from summarizer import Summarizer
from sumy.nlp.tokenizers import Tokenizer
from sumy.summarizers.lsa import LsaSummarizer
from sumy.parsers.plaintext import PlaintextParser
from sumy.summarizers.lex_rank import LexRankSummarizer
from sumy.summarizers.sum_basic import SumBasicSummarizer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Download the Punkt sentence-tokenizer data used by sumy's Tokenizer("en") below.
nltk.download('punkt')
def extractive(method, file, sentences_count=5):
    """Run a sumy extractive summarizer over a plain-text document.

    Args:
        method: an instantiated sumy summarizer (e.g. ``LexRankSummarizer()``).
        file: the raw document text to summarize.
        sentences_count: number of sentences to keep (default 5, matching the
            previous hard-coded value).

    Returns:
        The selected sentences joined into a single space-separated string.
    """
    document = PlaintextParser(file, Tokenizer("en")).document
    # The summarizer yields Sentence objects; str() extracts their text.
    return " ".join(str(sentence) for sentence in method(document, sentences_count))
def _seq2seq_summary(doc, checkpoint):
    """Abstractive summary via a standard seq2seq checkpoint (Pegasus-style).

    Truncates the input to 1024 tokens before generation, as the original
    Pegasus/Distill branches did.
    """
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    inputs = tokenizer(doc,
                       max_length=1024,
                       truncation=True,
                       return_tensors="pt")
    summary_ids = seq2seq.generate(inputs["input_ids"])
    return tokenizer.batch_decode(summary_ids,
                                  skip_special_tokens=True,
                                  clean_up_tokenization_spaces=False)[0]


def _led_summary(doc, checkpoint):
    """Abstractive summary via a Longformer-Encoder-Decoder (LED) checkpoint."""
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    led = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
                                                return_dict_in_generate=True)
    input_ids = tokenizer(doc, return_tensors="pt").input_ids
    # LED requires global attention on at least the first token.
    global_attention_mask = torch.zeros_like(input_ids)
    global_attention_mask[:, 0] = 1
    sequences = led.generate(input_ids,
                             global_attention_mask=global_attention_mask).sequences
    return tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]


def summarize(file, model):
    """Summarize the text in *file* with the summarizer named by *model*.

    Args:
        file: an object with a ``.name`` attribute pointing at a readable
            text file (e.g. a Gradio/tempfile upload handle — presumably;
            verify against the caller).
        model: one of "Pegasus", "LEDBill", "ILC", "Distill", "TextRank",
            "SumBasic", "Lsa", "BERT".

    Returns:
        The summary as a single string.

    Raises:
        ValueError: if *model* is not one of the supported names
            (previously this path crashed with UnboundLocalError).
    """
    # NOTE(review): explicit encoding added — the uploaded docs are assumed
    # UTF-8; confirm against the upload path.
    with open(file.name, encoding="utf-8") as f:
        doc = f.read()

    if model == "Pegasus":
        summary = _seq2seq_summary(doc, "google/pegasus-billsum")
    elif model == "LEDBill":
        summary = _led_summary(doc, "d0r1h/LEDBill")
    elif model == "ILC":
        summary = _led_summary(doc, "d0r1h/led-base-ilc")
    elif model == "Distill":
        summary = _seq2seq_summary(doc, "sshleifer/distill-pegasus-cnn-16-4")
    elif model == "TextRank":
        # Kept as-is from the original: the "TextRank" option actually uses
        # sumy's LexRank summarizer.
        summary = extractive(LexRankSummarizer(), doc)
    elif model == "SumBasic":
        summary = extractive(SumBasicSummarizer(), doc)
    elif model == "Lsa":
        summary = extractive(LsaSummarizer(), doc)
    elif model == "BERT":
        bert = Summarizer('distilbert-base-uncased', hidden=[-1, -2], hidden_concat=True)
        summary = ''.join(bert(doc))
    else:
        raise ValueError(f"Unknown summarization model: {model!r}")
    return summary