|
|
|
"""Fine-tune t5-small to summarize PubMed scientific papers and serve it with Gradio.

Automatically generated by Colaboratory from the notebook Untitled0.ipynb.

Original file is located at
https://colab.research.google.com/drive/1aMkctyYgdHD61sv7-bJHFN1B5taCv6c2
"""
|
|
|
import gradio as gr |
|
from datasets import load_dataset |
|
import evaluate |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer |
|
import numpy as np |
|
import nltk |
|
|
|
# Sentence-splitter data for nltk.sent_tokenize, which compute_metrics uses.
# NLTK >= 3.8.2 loads "punkt_tab" instead of "punkt", so fetch both to work
# with old and new NLTK releases.
nltk.download("punkt")
nltk.download("punkt_tab")

# PubMed subset of scientific_papers: preprocess_function reads "article" as
# the model input and "abstract" as the reference summary.
raw_dataset = load_dataset("scientific_papers", "pubmed")

# ROUGE scorer used by compute_metrics during evaluation.
metric = evaluate.load("rouge")

model_checkpoint = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# The original T5 checkpoints were pre-trained with task prefixes, so they are
# given "summarize: " in front of the text; other checkpoints get no prefix.
T5_CHECKPOINTS = ("t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b")
prefix = "summarize: " if model_checkpoint in T5_CHECKPOINTS else ""
|
|
|
|
|
# Token budgets: articles are truncated to 512 tokens, abstracts to 128.
max_input_length = 512
max_target_length = 128


def preprocess_function(examples):
    """Tokenize a batch of PubMed examples for seq2seq training.

    Args:
        examples: A batched ``datasets`` mapping with "article" and
            "abstract" columns.

    Returns:
        The tokenizer output for the prefixed articles, with an extra
        "labels" entry holding the token ids of the abstracts.
    """
    prefixed_articles = [prefix + article for article in examples["article"]]
    encoded = tokenizer(
        prefixed_articles,
        max_length=max_input_length,
        truncation=True,
    )

    # Abstracts are tokenized as targets; only their input_ids become labels.
    target_encoding = tokenizer(
        text_target=examples["abstract"],
        max_length=max_target_length,
        truncation=True,
    )
    encoded["labels"] = target_encoding["input_ids"]
    return encoded
|
|
|
# Subsample every split to keep this notebook-scale run fast.
# Rows are drawn WITHOUT replacement over the full index range. The previous
# randint(0, n - 1, k) never picked the last row (its upper bound is
# exclusive) and could pick the same row several times. A seeded generator
# makes the subset reproducible across runs.
samples_per_split = 1_000
sampling_rng = np.random.default_rng(seed=42)
for split in ["train", "validation", "test"]:
    n_rows = len(raw_dataset[split])
    # Never ask for more rows than the split has.
    indices = sampling_rng.choice(
        n_rows, size=min(samples_per_split, n_rows), replace=False
    )
    raw_dataset[split] = raw_dataset[split].select(indices.tolist())

tokenized_dataset = raw_dataset.map(preprocess_function, batched=True)
|
|
|
|
|
# Pre-trained seq2seq model; the Trainer below fine-tunes it in place.
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

batch_size = 8

# Single-epoch fine-tune, evaluated once per epoch. Gradient accumulation of 2
# means each optimizer step sees 2 * batch_size examples per device.
# predict_with_generate=True makes evaluation predictions generated token ids,
# which compute_metrics decodes.
# NOTE(review): newer transformers releases rename evaluation_strategy to
# eval_strategy; confirm against the installed version.
args = Seq2SeqTrainingArguments(
    output_dir=f"{model_checkpoint}-scientific_papers",
    evaluation_strategy="epoch",
    num_train_epochs=1,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=2,
    predict_with_generate=True,
    save_total_limit=3,
    push_to_hub=False,
)

# Pads inputs and labels per batch; passing the model lets it build decoder inputs.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
|
|
|
|
|
def _replace_ignore_index(token_ids):
    """Return token_ids as an array with every -100 replaced by the pad token id."""
    token_ids = np.asarray(token_ids)
    return np.where(token_ids != -100, token_ids, tokenizer.pad_token_id)


def _decode_for_rouge(token_ids):
    """Decode a batch of token ids to text with one sentence per line.

    rougeLsum scores summaries line by line, so sentences are joined with
    newline characters before scoring.
    """
    texts = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
    return ["\n".join(nltk.sent_tokenize(text.strip())) for text in texts]


def compute_metrics(eval_pred):
    """Compute ROUGE scores and mean generated length for one evaluation pass.

    Args:
        eval_pred: ``(predictions, labels)`` pair from Seq2SeqTrainer with
            ``predict_with_generate=True``: generated token ids and reference
            label ids. Either may contain -100 padding.

    Returns:
        dict mapping each ROUGE key to its score scaled to 0-100, plus
        "gen_len" (mean number of non-pad tokens per prediction). Every value
        is rounded to 4 decimals.
    """
    predictions, labels = eval_pred
    # Some models return extra outputs alongside the generated ids.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # -100 is not a valid token id and makes batch_decode fail. Labels use it
    # for ignored positions. Predictions can get it when the Trainer pads
    # generations of different lengths across eval batches. Map it back to the
    # pad id in both, which also keeps gen_len from counting that padding.
    predictions = _replace_ignore_index(predictions)
    labels = _replace_ignore_index(labels)

    decoded_preds = _decode_for_rouge(predictions)
    decoded_labels = _decode_for_rouge(labels)

    result = metric.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    result = {key: value * 100 for key, value in result.items()}

    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions
    ]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}
|
|
|
# Wire the model, datasets, collator and metric function together, then fine-tune.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()
|
|
|
|
|
import gradio as gr |
|
|
|
def summarizer(input_text):
    """Generate an abstract-style summary of one piece of text.

    Args:
        input_text: Raw text to summarize (the contents of the Gradio textbox).

    Returns:
        The decoded summary string, or "" when input_text is empty or blank.
    """
    if not input_text or not input_text.strip():
        return ""

    # Training can leave the model on an accelerator while the tokenizer
    # returns CPU tensors, so move the inputs to the model's device. eval()
    # disables dropout so beam search gives the same summary every time.
    model.eval()
    model_inputs = tokenizer(
        [prefix + input_text],
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt",
    ).to(model.device)

    summary_ids = model.generate(
        input_ids=model_inputs["input_ids"],
        attention_mask=model_inputs["attention_mask"],
        num_beams=4,
        length_penalty=2.0,
        max_length=max_target_length + 2,
        repetition_penalty=2.0,
        early_stopping=True,
        use_cache=True,
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
|
|
|
|
# Web UI: one text box in, one text box out, backed by summarizer().
# gr.inputs / gr.outputs were deprecated in Gradio 3 and removed in Gradio 4;
# gr.Textbox works as both an input and an output component.
iface = gr.Interface(
    fn=summarizer,
    inputs=gr.Textbox(label="Input Text"),
    outputs=gr.Textbox(label="Summary"),
    title="Scientific Paper Summarizer",
    description="Summarizes scientific papers using a fine-tuned T5 model",
    # NOTE(review): "gray" does not appear to be a built-in Gradio theme name;
    # confirm it resolves to the intended theme in the installed version.
    theme="gray",
)

iface.launch()