Refreshing zero-shot classification with ModernBERT
Text classification is one of the fundamental tasks in ML, with a long history of research and profound practical interest. It is also an essential component of many practical systems, from search engines to biomedical research. Text classification approaches are used for scientific article classification, user ticket classification, sentiment analysis of social media posts, company classification for financial research, and more. If we generalize the task to sequence classification, the range of applications and their impact grow even larger, from DNA sequence classification to RAG pipelines, which are one of the most widely used ways to keep chatbot outputs grounded and up to date.
Recent advances in auto-regressive language modelling have opened new frontiers for many zero-shot classification tasks, including text classification. While these models offer impressive versatility, they often struggle with strict instruction adherence and can be computationally inefficient for both training and inference. Check out our project designed to enable zero-shot classification with generative language models: Knowledgator Unlimited Classifier.
Cross-encoders trained as Natural Language Inference (NLI) models are another popular approach, frequently used as zero-shot classifiers and in Retrieval-Augmented Generation (RAG) pipelines. The method works by posing the sequence to be classified as the NLI premise and constructing a hypothesis from each candidate label. However, this approach faces efficiency challenges when dealing with a large number of classes, because each text-label pair must be processed separately. Moreover, it has a limited ability to use cross-label information, which can hurt prediction quality, especially in complex scenarios.
Since the Word2Vec work, embedding-based approaches have been recognised as a promising method for text classification, particularly in zero-shot settings. Sentence encoders brought a better semantic understanding of sentences and longer texts, which suggested using sentence embeddings for text classification. The emergence of Sentence Transformers improved embedding quality even further, making it possible to use them for classification tasks even without fine-tuning. SetFit, a method built on top of sentence transformers, made it possible to achieve good performance while training on only a few examples per label. Despite their efficiency and good performance on many semantic tasks, embedding-based methods often fall short in complex scenarios involving logical and semantic constraints.
We propose a novel approach for text classification that builds upon the GLiNER architecture, adapting it specifically for sequence classification tasks. Our method aims to strike a balance between the accuracy of more complex models and the efficiency of embedding-based approaches while maintaining good zero-shot and few-shot capabilities.
Architecture of GLiClass
Our architecture introduces a novel approach to sequence classification that enables rich interactions between labels and input text while maintaining computational efficiency. The implementation consists of several key stages that work in concert to achieve superior classification performance.
Input Processing and Label Integration
The process begins with a label integration mechanism. We prepend each class label with a special <<LABEL>> token, then concatenate all labels with the input text, separated by a <<SEP>> token. The resulting sequence is tokenized and fed to the encoder as a single input.
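As a rough illustration (the gliclass package constructs this prompt internally, and the exact token placement may differ), the combined input before tokenization could look like this:
# Illustrative sketch of the label-integration step; the exact formatting
# used by gliclass may differ.
labels = ["travel", "dreams", "sport"]
text = "One day I will see the world!"
prompt = "".join(f"<<LABEL>>{label}" for label in labels) + "<<SEP>>" + text
print(prompt)
# <<LABEL>>travel<<LABEL>>dreams<<LABEL>>sport<<SEP>>One day I will see the world!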
Contextual Representation Learning
Following tokenization, the combined input IDs are processed through a bi-directional transformer architecture (such as BERT or DeBERTa). This stage is crucial as it enables three distinct types of contextual understanding:
- Label-to-label interactions: Labels can share information between themselves, enabling the model to understand label relationships and hierarchies
- Text-to-label interactions: The input text can directly influence label representations
- Label-to-text interactions: Label information can guide the interpretation of the text
This multi-directional flow of information represents a significant advantage over traditional cross-encoder architectures, which typically limit interactions to text-label pairs and miss valuable inter-label relationships.
Representation Pooling
After obtaining contextualized representations, we employ separate pooling mechanisms for labels and text to distil the essential information from the transformer outputs. Our implementation supports multiple pooling strategies:
- First token pooling: Utilizing the initial token’s representation
- Mean pooling: Averaging across all tokens
- Attention-weighted pooling: Applying learned attention weights
- Custom pooling strategies: Tailored to specific classification requirements
The choice of pooling strategy can be optimized based on the specific requirements for the classification task and the nature of the input data.
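As an illustration of the idea, here is a minimal sketch of mean pooling over the token positions that belong to a single label, using hypothetical tensors rather than the actual gliclass internals:
import torch

# Hypothetical encoder output and a mask marking the tokens of one label.
hidden_states = torch.randn(1, 128, 768)             # (batch, seq_len, hidden)
label_mask = torch.zeros(1, 128, dtype=torch.bool)
label_mask[0, 1:4] = True                             # positions of the label's tokens

def mean_pool(states, mask):
    # Average hidden states over the positions selected by the mask.
    mask = mask.unsqueeze(-1).float()                 # (batch, seq_len, 1)
    summed = (states * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1.0)
    return summed / counts                            # (batch, hidden)

label_embedding = mean_pool(hidden_states, label_mask)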
Scoring Mechanism
The final stage involves computing compatibility scores between the pooled representations. We implement this through a flexible scoring framework that can accommodate various approaches:
- Simple dot product scoring: Efficient and effective for many applications
- Neural network scoring: More complex scoring functions for challenging scenarios
- Task-specific scoring modules: Customized for particular classification requirements
This modular scoring approach allows the architecture to adapt to different classification scenarios while maintaining computational efficiency.
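For the simplest variant, here is a minimal sketch of dot-product scoring between a pooled text representation and pooled label representations (again with hypothetical tensors, not the actual gliclass code):
import torch

# Hypothetical pooled vectors: one for the text, one per candidate label.
text_embedding = torch.randn(768)                     # (hidden,)
label_embeddings = torch.randn(5, 768)                # (num_labels, hidden)

# Dot-product compatibility scores, turned into independent per-label
# probabilities with a sigmoid for the multi-label setting.
logits = label_embeddings @ text_embedding            # (num_labels,)
scores = torch.sigmoid(logits)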
How to use the models
We open-sourced our models on Hugging Face: Modern GLiClass Collection.
To use them, first install the gliclass package:
pip install gliclass
Then you need to initialize a model and a pipeline:
from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from transformers import AutoTokenizer
model = GLiClassModel.from_pretrained("knowledgator/gliclass-modern-base-v2.0-init")
tokenizer = AutoTokenizer.from_pretrained("knowledgator/gliclass-modern-base-v2.0-init")
pipeline = ZeroShotClassificationPipeline(model, tokenizer, classification_type='multi-label', device='cuda:0')
Here is how you can perform inference:
text = "One day I will see the world!"
labels = ["travel", "dreams", "sport", "science", "politics"]
results = pipeline(text, labels, threshold=0.5)[0]
for result in results:
    print(result["label"], "=>", result["score"])
How to fine-tune
First of all, you need to prepare training data in the following format:
[
    {
        "text": "Some text here!",
        "all_labels": ["sport", "science", "business", …],
        "true_labels": ["other"]
    },
    …
]
Below are the required imports:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"
from datasets import load_dataset, Dataset, DatasetDict
from sklearn.metrics import classification_report, f1_score, precision_recall_fscore_support, accuracy_score
import numpy as np
import random
from transformers import AutoTokenizer
import torch
from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from gliclass.data_processing import GLiClassDataset, DataCollatorWithPadding
from gliclass.training import TrainingArguments, Trainer
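The snippets below assume a train_data variable holding examples in the format shown above. As a minimal sketch (assuming the data is stored in a local train.json file; adjust to your setup), it could be loaded like this:
import json

# Assumes a local JSON file with a list of examples in the format shown above.
with open("train.json", "r") as f:
    train_data = json.load(f)

random.shuffle(train_data)
print(f"Loaded {len(train_data)} training examples")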
Then, we initialize the model and tokenizer:
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
model_name = 'knowledgator/gliclass-base-v1.0'
model = GLiClassModel.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
After that, we specify the training arguments:
max_length = 1024
problem_type = "multi_label_classification"
architecture_type = model.config.architecture_type
prompt_first = model.config.prompt_first
training_args = TrainingArguments(
    output_dir='models/test',
    learning_rate=1e-5,
    weight_decay=0.01,
    others_lr=1e-5,
    others_weight_decay=0.01,
    lr_scheduler_type='linear',
    warmup_ratio=0.0,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=8,
    evaluation_strategy="epoch",
    save_steps=1000,
    save_total_limit=10,
    dataloader_num_workers=8,
    logging_steps=10,
    use_cpu=False,
    report_to="none",
    fp16=False,
)
Once your dataset is prepared in the right format, initialize the GLiClass dataset and data collator:
train_dataset = GLiClassDataset(train_data, tokenizer, max_length, problem_type, architecture_type, prompt_first)
test_dataset = GLiClassDataset(train_data[:int(len(train_data)*0.1)], tokenizer, max_length, problem_type, architecture_type, prompt_first)
data_collator = DataCollatorWithPadding(device=device)
When everything is set, we can start the training:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
trainer.train()
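After training, the fine-tuned weights can be saved and reused with the zero-shot pipeline. A minimal sketch, assuming GLiClassModel follows the standard Hugging Face save_pretrained/from_pretrained behaviour:
# Save the fine-tuned model and tokenizer to a local directory (path is arbitrary).
save_dir = "models/test/final"
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)

# Reload them and run inference through the zero-shot pipeline.
tuned_model = GLiClassModel.from_pretrained(save_dir)
tuned_tokenizer = AutoTokenizer.from_pretrained(save_dir)
pipeline = ZeroShotClassificationPipeline(
    tuned_model, tuned_tokenizer, classification_type='multi-label', device='cuda:0'
)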
See more examples in the repo: https://github.com/Knowledgator/GLiClass/blob/main/finetuning.ipynb
Key Applications
GLiClass demonstrates remarkable versatility across a wide range of natural language processing tasks, making it particularly valuable for both research and practical applications.
Multi-class Classification
The architecture efficiently handles large-scale classification tasks with up to 100 distinct classes in a single processing run. This capability is particularly valuable for applications such as document categorization, product classification, and content tagging systems where multiple detailed categories are required.
Topic Classification
GLiClass excels in identifying and categorizing the subject matter of texts, making it ideal for:
- Academic paper categorization
- News article classification
- Content recommendation systems
- Research document organization
Sentiment Analysis
The architecture effectively captures nuanced emotional and opinion-based content, supporting:
- Social media sentiment tracking
- Customer feedback analysis
- Product review classification
- Brand perception monitoring
Event Classification
GLiClass demonstrates strong capabilities in identifying and categorizing events within the text, supporting:
- News event categorization
- Social media event detection
- Historical event classification
- Timeline analysis and organization
Prompt-based Constrained Classification
The system offers flexible prompt-based classification with custom constraints, enabling:
- Guided classification tasks
- Context-aware categorization
- Custom classification rules
- Dynamic category adaptation
Natural Language Inference
GLiClass supports sophisticated reasoning about textual relationships, facilitating:
- Textual entailment detection
- Contradiction identification
- Semantic similarity assessment
- Logical relationship analysis
Retrieval-augmented generation (RAG)
The architecture's good generalization, together with its support for natural language inference tasks, makes it a good choice for reranking in RAG pipelines. Moreover, the efficiency of GLiClass makes it especially competitive in comparison with cross-encoders.
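As a rough sketch of this idea (an illustrative recipe rather than an official GLiClass reranking API, reusing the ZeroShotClassificationPipeline shown earlier), each retrieved passage can be scored against a query-derived label and the passages re-ordered by that score:
query = "What are the health benefits of green tea?"
passages = [
    "Green tea is rich in antioxidants that may reduce inflammation.",
    "The history of tea ceremonies in Japan goes back centuries.",
    "Coffee consumption has been linked to improved alertness.",
]

# Use the query as a single candidate label and score every passage against it.
label = f"relevant to: {query}"
scored = []
for passage in passages:
    result = pipeline(passage, [label], threshold=0.0)[0]
    scored.append((result[0]["score"], passage))

# Pass the highest-scoring passages to the generator first.
for score, passage in sorted(scored, reverse=True):
    print(f"{score:.3f}  {passage}")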
This comprehensive range of applications positions GLiClass as a versatile tool for modern natural language processing challenges, offering both flexibility and precision across diverse classification tasks.
Benchmarking results
We released new GLiClass models based on ModernBERT. They open new possibilities compared to older backbones like DeBERTa by supporting longer contexts (up to 8k tokens) and offering much faster inference. We benchmarked the new models against other zero-shot classification architectures and against previous versions of GLiClass.
Below, you can see the F1 score on several text classification datasets. None of the tested models were fine-tuned on these datasets; all were evaluated in a zero-shot setting.
| Model | IMDB | AG_NEWS | Emotions |
|---|---|---|---|
| gliclass-modern-large-v2.0-init (399 M) | 0.9137 | 0.7357 | 0.4140 |
| gliclass-modern-base-v2.0-init (151 M) | 0.8264 | 0.6637 | 0.2985 |
| gliclass-large-v1.0 (438 M) | 0.9404 | 0.7516 | 0.4874 |
| gliclass-base-v1.0 (186 M) | 0.8650 | 0.6837 | 0.4749 |
| gliclass-small-v1.0 (144 M) | 0.8650 | 0.6805 | 0.4664 |
| Bart-large-mnli (407 M) | 0.89 | 0.6887 | 0.3765 |
| Deberta-base-v3 (184 M) | 0.85 | 0.6455 | 0.5095 |
| Comprehendo (184 M) | 0.90 | 0.7982 | 0.5660 |
| SetFit BAAI/bge-small-en-v1.5 (33.4 M) | 0.86 | 0.5636 | 0.5754 |
Below you can find a more comprehensive comparison of ModernBERT GLiClass with other GLiClass models:
| Dataset | gliclass-base-v1.0-init | gliclass-large-v1.0-init | gliclass-modern-base-v2.0-init | gliclass-modern-large-v2.0-init |
|---|---|---|---|---|
| CR | 0.8672 | 0.8024 | 0.9041 | 0.8980 |
| sst2 | 0.8342 | 0.8734 | 0.9011 | 0.9434 |
| sst5 | 0.2048 | 0.1638 | 0.1972 | 0.1123 |
| 20_news_groups | 0.2317 | 0.4151 | 0.2448 | 0.2792 |
| spam | 0.5963 | 0.5407 | 0.5074 | 0.6364 |
| financial_phrasebank | 0.3594 | 0.3705 | 0.2537 | 0.2562 |
| imdb | 0.8772 | 0.8836 | 0.8255 | 0.9137 |
| ag_news | 0.5614 | 0.7069 | 0.6050 | 0.6933 |
| emotion | 0.2865 | 0.3840 | 0.2474 | 0.3746 |
| cap_sotu | 0.3966 | 0.4353 | 0.2929 | 0.2919 |
| rotten_tomatoes | 0.6626 | 0.7933 | 0.6630 | 0.5928 |
| AVERAGE | 0.5344 | 0.5790 | 0.5129 | 0.5447 |
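The zero-shot numbers above were obtained without fine-tuning on the target datasets. As a rough sketch of how such an evaluation could be reproduced (an illustrative recipe assuming the Hugging Face datasets library and the pipeline initialized earlier; the label verbalizations and averaging behind the reported numbers may differ):
from datasets import load_dataset
from sklearn.metrics import f1_score

# Small shuffled slice of IMDB for a quick zero-shot check (illustrative only).
dataset = load_dataset("imdb", split="test").shuffle(seed=42).select(range(200))
label_names = ["negative", "positive"]  # assumed verbalization of IMDB labels 0/1

predictions = []
for example in dataset:
    results = pipeline(example["text"], label_names, threshold=0.0)[0]
    best = max(results, key=lambda r: r["score"])
    predictions.append(label_names.index(best["label"]))

print("F1:", f1_score(dataset["label"], predictions, average="micro"))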
We examined how performance grows when the model is fine-tuned on a small number of examples per label. Additionally, we tested a simple approach in which we provided no real texts but rather generic short descriptions of a given topic; we call it weak supervision. Surprisingly, for some datasets like "emotion" it improved performance substantially.
| Model | Num Examples | sst5 | ag_news | emotion | AVERAGE |
|---|---|---|---|---|---|
| gliclass-modern-large-v2.0-init | 0 | 0.1123 | 0.6933 | 0.3746 | 0.3934 |
| gliclass-modern-large-v2.0-init | 8 | 0.5098 | 0.8339 | 0.5010 | 0.6149 |
| gliclass-modern-large-v2.0-init | Weak Supervision | 0.0951 | 0.6478 | 0.4520 | 0.3983 |
| gliclass-modern-base-v2.0-init | 0 | 0.1972 | 0.6050 | 0.2474 | 0.3499 |
| gliclass-modern-base-v2.0-init | 8 | 0.3604 | 0.7481 | 0.4420 | 0.5168 |
| gliclass-modern-base-v2.0-init | Weak Supervision | 0.1599 | 0.5713 | 0.3216 | 0.3509 |
| gliclass-large-v1.0-init | 0 | 0.1639 | 0.7069 | 0.3840 | 0.4183 |
| gliclass-large-v1.0-init | 8 | 0.4226 | 0.8415 | 0.4886 | 0.5842 |
| gliclass-large-v1.0-init | Weak Supervision | 0.1689 | 0.7051 | 0.4586 | 0.4442 |
| gliclass-base-v1.0-init | 0 | 0.2048 | 0.5614 | 0.2865 | 0.3509 |
| gliclass-base-v1.0-init | 8 | 0.2007 | 0.8359 | 0.4856 | 0.5074 |
| gliclass-base-v1.0-init | Weak Supervision | 0.0681 | 0.6627 | 0.3066 | 0.3458 |
Conclusion
GLiClass represents a significant advancement in the field of text classification, offering a robust and efficient solution that bridges the gap between the accuracy of complex transformer-based models and the simplicity of embedding-based approaches. By leveraging a novel architecture that facilitates rich interactions between input text and labels, GLiClass achieves superior performance in zero-shot and few-shot classification tasks while maintaining computational efficiency, even with large label sets. Its ability to capture cross-label dependencies, adapt to diverse classification scenarios, and integrate seamlessly with existing NLP pipelines positions it as a versatile tool for a wide range of applications, from sentiment analysis and topic classification to retrieval-augmented generation and natural language inference.