File size: 3,498 Bytes
d3c86fd
2c55a80
38e9364
d3c86fd
38e9364
 
 
 
 
d3c86fd
 
 
38e9364
 
 
 
 
 
 
 
 
 
 
 
 
d3c86fd
38e9364
 
 
 
 
 
 
 
 
 
 
 
 
 
2c55a80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e76e540
 
 
 
 
 
 
 
 
 
 
 
 
 
d3c86fd
38e9364
 
d3c86fd
38e9364
 
 
 
 
 
 
 
 
 
d3c86fd
38e9364
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import nltk
import torch 
from summarizer import Summarizer
from sumy.nlp.tokenizers import Tokenizer
from sumy.summarizers.lsa import LsaSummarizer
from sumy.parsers.plaintext import PlaintextParser
from sumy.summarizers.lex_rank import LexRankSummarizer
from sumy.summarizers.sum_basic import SumBasicSummarizer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

nltk.download('punkt')

def extractive(method, file, sentences_count=5):
  """Return an extractive summary of *file* using a sumy summarizer.

  Args:
    method: an instantiated sumy summarizer, e.g. ``LexRankSummarizer()``.
    file: raw document text to summarize (passed to ``PlaintextParser``).
    sentences_count: number of sentences to select (default 5, matching
      the original hard-coded value).

  Returns:
    The selected sentences joined by single spaces. Returns "" for a
    document that yields no sentences (the original raised
    UnboundLocalError in that case because the join ran inside the loop).
  """
  document = PlaintextParser(file, Tokenizer("en")).document
  # Join once after collecting everything instead of re-joining on every
  # loop iteration.
  return " ".join(str(sentence) for sentence in method(document, sentences_count))

def _seq2seq_summary(checkpoint, doc):
  """Abstractive summary via a standard HF encoder-decoder checkpoint."""
  tokenizer = AutoTokenizer.from_pretrained(checkpoint)
  seq2seq = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
  inputs = tokenizer(doc,
                     max_length=1024,
                     truncation=True,
                     return_tensors="pt")
  summary_ids = seq2seq.generate(inputs["input_ids"])
  return tokenizer.batch_decode(summary_ids,
                                skip_special_tokens=True,
                                clean_up_tokenization_spaces=False)[0]


def _led_summary(checkpoint, doc):
  """Abstractive summary via a Longformer-Encoder-Decoder checkpoint."""
  tokenizer = AutoTokenizer.from_pretrained(checkpoint)
  led = AutoModelForSeq2SeqLM.from_pretrained(checkpoint,
                                              return_dict_in_generate=True)
  input_ids = tokenizer(doc, return_tensors="pt").input_ids
  # LED requires global attention on at least one token; the convention
  # is to put it on the first (<s>) token.
  global_attention_mask = torch.zeros_like(input_ids)
  global_attention_mask[:, 0] = 1
  sequences = led.generate(input_ids,
                           global_attention_mask=global_attention_mask).sequences
  return tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]


def summarize(file, model):
  """Summarize the text behind *file* with the chosen model/method.

  Args:
    file: an object exposing a ``.name`` path attribute (e.g. a Gradio
      file upload) whose contents are read as the document.
    model: one of "Pegasus", "LEDBill", "ILC", "Distill" (abstractive),
      "TextRank", "SumBasic", "Lsa" (extractive via sumy), or "BERT"
      (extractive via bert-extractive-summarizer).

  Returns:
    The summary as a single string.

  Raises:
    ValueError: if *model* is not one of the supported names (the
      original code fell through and raised UnboundLocalError instead).
  """
  # Explicit encoding so reads don't depend on the platform default.
  with open(file.name, encoding="utf-8") as f:
    doc = f.read()

  if model == "Pegasus":
    summary = _seq2seq_summary("google/pegasus-billsum", doc)
  elif model == "LEDBill":
    summary = _led_summary("d0r1h/LEDBill", doc)
  elif model == "ILC":
    summary = _led_summary("d0r1h/led-base-ilc", doc)
  elif model == "Distill":
    summary = _seq2seq_summary("sshleifer/distill-pegasus-cnn-16-4", doc)
  elif model == "TextRank":
    # NOTE(review): this option is labelled "TextRank" but uses sumy's
    # LexRankSummarizer, preserved as-is from the original.
    summary = extractive(LexRankSummarizer(), doc)
  elif model == "SumBasic":
    summary = extractive(SumBasicSummarizer(), doc)
  elif model == "Lsa":
    summary = extractive(LsaSummarizer(), doc)
  elif model == "BERT":
    bert = Summarizer('distilbert-base-uncased', hidden=[-1, -2],
                      hidden_concat=True)
    summary = ''.join(bert(doc))
  else:
    raise ValueError(f"Unknown summarization model: {model!r}")

  return summary