Fine-tuned Longformer for Summarization of Machine Learning Articles
Model Details
- GitHub: https://github.com/Bakhitovd/led-base-7168-ml
- Model name: bakhitovd/led-base-7168-ml
- Model type: Longformer (allenai/led-base-16384)
- Model description: This Longformer model has been fine-tuned on a focused subset of the arXiv part of the scientific papers dataset, specifically targeting articles about Machine Learning. It aims to generate accurate and consistent summaries of machine learning research papers.
Intended Use
This model is intended to be used for text summarization tasks, specifically for summarizing machine learning research papers.
How to Use
import torch
from transformers import LEDTokenizer, LEDForConditionalGeneration

# Fall back to CPU when no GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = LEDTokenizer.from_pretrained("bakhitovd/led-base-7168-ml")
model = LEDForConditionalGeneration.from_pretrained("bakhitovd/led-base-7168-ml").to(device)

# Use the model for summarization
article = "... long document ..."
# Call the tokenizer directly (not .encode) so it returns both input_ids and attention_mask
inputs_dict = tokenizer(article, padding="max_length", max_length=16384, return_tensors="pt", truncation=True)
input_ids = inputs_dict.input_ids.to(device)
attention_mask = inputs_dict.attention_mask.to(device)
# Put global attention on the first token only; all other tokens use local attention
global_attention_mask = torch.zeros_like(attention_mask)
global_attention_mask[:, 0] = 1
predicted_abstract_ids = model.generate(input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask, max_length=512)
# generate() returns a batch of sequences; decode the first (and only) one
summary = tokenizer.decode(predicted_abstract_ids[0], skip_special_tokens=True)
print(summary)
Training Data
Dataset name: bakhitovd/data_science_arxiv
This dataset is a subset of the 'Scientific papers' dataset. It contains the articles that are semantically and structurally closest to articles describing machine learning. The subset was obtained by running K-means clustering on embeddings generated by SciBERT.
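The selection idea can be illustrated with a minimal sketch. It is not the author's pipeline: the toy 2-D vectors stand in for SciBERT embeddings, the `kmeans` helper is a plain-Python stand-in for whatever implementation was actually used, and picking the cluster that contains a known machine learning article (`seed_article`) is an assumption about how the target cluster might be identified.

```python
import math
import random


def kmeans(points, k, iters=20, seed=0):
    """Plain Lloyd's algorithm on a list of equal-length tuples."""
    rng = random.Random(seed)
    centroids = rng.sample(points, k)  # initialise from k distinct data points
    labels = [0] * len(points)
    for _ in range(iters):
        # Assignment step: each point goes to its nearest centroid.
        labels = [
            min(range(k), key=lambda c: math.dist(p, centroids[c]))
            for p in points
        ]
        # Update step: move each centroid to the mean of its members.
        for c in range(k):
            members = [p for p, label in zip(points, labels) if label == c]
            if members:  # keep the old centroid if a cluster is empty
                centroids[c] = tuple(sum(dim) / len(members) for dim in zip(*members))
    return centroids, labels


# Toy 2-D vectors standing in for SciBERT embeddings (hypothetical values).
embeddings = [
    (0.9, 0.1), (1.0, 0.0), (0.8, 0.2),  # "ML-like" articles
    (0.1, 0.9), (0.0, 1.0), (0.2, 0.8),  # other articles
]
seed_article = 0  # index of an article known to be about ML (assumption)

_, labels = kmeans(embeddings, k=2)
# Keep every article that landed in the same cluster as the seed article.
ml_subset = [i for i, label in enumerate(labels) if label == labels[seed_article]]
print(ml_subset)
```

With real data, the embeddings would be high-dimensional vectors and the number of clusters would need tuning; the source does not state either value.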
Evaluation Results
The model was evaluated with ROUGE metrics and outperformed the baseline models.
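For readers unfamiliar with ROUGE, the sketch below computes a simplified ROUGE-1 score (unigram overlap) between a reference abstract and a generated summary. It is an illustration only: the `rouge_1` helper and the example strings are hypothetical, it uses whitespace tokenization without stemming, and it is not the evaluation script behind the results reported here.

```python
from collections import Counter


def rouge_1(reference, candidate):
    """Simplified ROUGE-1: unigram precision, recall and F1."""
    ref_counts = Counter(reference.lower().split())
    cand_counts = Counter(candidate.lower().split())
    # Clipped overlap: each word counts at most as often as it appears in both texts.
    overlap = sum((ref_counts & cand_counts).values())
    if overlap == 0:
        return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
    precision = overlap / sum(cand_counts.values())
    recall = overlap / sum(ref_counts.values())
    f1 = 2 * precision * recall / (precision + recall)
    return {"precision": precision, "recall": recall, "f1": f1}


# Hypothetical reference abstract and generated summary.
scores = rouge_1(
    reference="the model summarizes long papers",
    candidate="the model summarizes papers",
)
print(scores)
```

Here all four candidate words appear in the five-word reference, so precision is 1.0 and recall is 0.8. Full evaluations usually also report ROUGE-2 and ROUGE-L and rely on an established implementation rather than a hand-rolled one.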