Refreshing zero-shot classification with ModernBERT

Community Article Published February 25, 2025

Text classification is one of the fundamental tasks in machine learning, with a long history of research and significant practical value. It is also an essential component of many real-world systems, from search engines to biomedical research. Text classification approaches are used for scientific article classification, user ticket classification, sentiment analysis of social media posts, company classification for financial research, and more. If we generalize the task to sequence classification, the range of applications grows even larger, from DNA sequence classification to reranking in RAG pipelines, one of the most widely used ways to keep chatbot outputs accurate and up to date.


Recent advancements in auto-regressive language modelling have opened new frontiers for many zero-shot tasks, including text classification. While these models offer impressive versatility, they often struggle with strict instruction adherence and can be computationally inefficient for both training and inference. Check out our project designed to enable zero-shot classification with generative language models: Knowledgator Unlimited Classifier.

Cross-encoders trained as Natural Language Inference (NLI) models are another popular approach, frequently used as zero-shot classifiers and in Retrieval-Augmented Generation (RAG) pipelines. The method works by posing the sequence to be classified as the NLI premise and constructing a hypothesis from each candidate label. However, this pairwise processing makes the approach inefficient when dealing with a large number of classes. It also has limited ability to use cross-label information, which can hurt prediction quality, especially in complex scenarios. A minimal illustration of the NLI-based approach is shown below.
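To make the NLI-based baseline concrete, here is a minimal sketch using the Hugging Face transformers zero-shot-classification pipeline with a publicly available NLI cross-encoder (facebook/bart-large-mnli is used purely as an illustrative choice); each label is turned into a hypothesis and scored against the text:

from transformers import pipeline

# Each candidate label is converted into a hypothesis such as
# "This example is about travel." and scored against the premise (the text).
nli_classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

text = "One day I will see the world!"
labels = ["travel", "dreams", "sport", "science", "politics"]

result = nli_classifier(text, candidate_labels=labels, multi_label=True)
for label, score in zip(result["labels"], result["scores"]):
    print(label, "=>", score)

Note that the model runs one premise-hypothesis pair per label, which is exactly why the approach scales poorly with large label sets.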


Since the Word2Vec work, embedding-based approaches have been recognised as a promising method for text classification, particularly in zero-shot settings. Sentence encoders brought a clearly better semantic understanding of sentences and texts, which suggested using sentence embeddings directly for classification. The emergence of Sentence Transformers improved embedding quality even further, making it possible to use them for classification tasks even without fine-tuning. SetFit, a method built on top of sentence transformers, made it possible to achieve good performance while training on only a small number of examples per label. Despite their efficiency and good performance on many semantic tasks, embedding-based methods often fall short in complex scenarios involving logical and semantic constraints. A minimal sketch of embedding-based zero-shot classification follows.
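For illustration, here is a minimal sketch of embedding-based zero-shot classification with sentence-transformers, where each label is embedded and the text is scored by cosine similarity (the model name all-MiniLM-L6-v2 is just an example choice):

from sentence_transformers import SentenceTransformer, util

# Embed the text and the candidate labels in the same vector space.
model = SentenceTransformer("all-MiniLM-L6-v2")

text = "One day I will see the world!"
labels = ["travel", "dreams", "sport", "science", "politics"]

text_emb = model.encode(text, convert_to_tensor=True)
label_embs = model.encode(labels, convert_to_tensor=True)

# Cosine similarity between the text and each label acts as the class score.
scores = util.cos_sim(text_emb, label_embs)[0]
for label, score in zip(labels, scores):
    print(label, "=>", float(score))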

We propose a novel approach for text classification that builds upon the GLiNER architecture, adapting it specifically for sequence classification tasks. Our method aims to strike a balance between the accuracy of more complex models and the efficiency of embedding-based approaches while maintaining good zero-shot and few-shot capabilities.

Architecture of GLiClass

Our architecture introduces a novel approach to sequence classification that enables rich interactions between labels and input text while maintaining computational efficiency. The implementation consists of several key stages that work in concert to achieve superior classification performance.

[Figure: GLiClass architecture overview]

Input Processing and Label Integration

The process begins with a label integration mechanism. We prepend each class label with a special <<LABEL>> token, join the marked labels together, add a <<SEP>> token, and concatenate the result with the input text. The combined sequence is then tokenized as a single input, so labels and text are processed together.
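As a rough illustration of how such an input could be assembled (the exact token handling inside gliclass may differ, this is only a sketch):

# Hypothetical sketch of prompt construction; the gliclass library handles this internally.
labels = ["travel", "dreams", "sport"]
text = "One day I will see the world!"

prompt = "".join(f"<<LABEL>>{label}" for label in labels) + "<<SEP>>" + text
print(prompt)
# <<LABEL>>travel<<LABEL>>dreams<<LABEL>>sport<<SEP>>One day I will see the world!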

Contextual Representation Learning

Following tokenization, the combined input IDs are processed through a bi-directional transformer architecture (such as BERT or DeBERTa). This stage is crucial as it enables three distinct types of contextual understanding:

  • Label-to-label interactions: Labels can share information between themselves, enabling the model to understand label relationships and hierarchies
  • Text-to-label interactions: The input text can directly influence label representations
  • Label-to-text interactions: Label information can guide the interpretation of the text

This multi-directional flow of information represents a significant advantage over traditional cross-encoder architectures, which typically limit interactions to text-label pairs and miss valuable inter-label relationships.

Representation Pooling

After obtaining contextualized representations, we employ separate pooling mechanisms for labels and text to distil the essential information from the transformer outputs. Our implementation supports multiple pooling strategies:

  • First token pooling: Utilizing the initial token’s representation
  • Mean pooling: Averaging across all tokens
  • Attention-weighted pooling: Applying learned attention weights
  • Custom pooling strategies: Tailored to specific classification requirements

The choice of pooling strategy can be optimized based on the specific requirements for the classification task and the nature of the input data.
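For instance, first-token and mean pooling over transformer outputs can be sketched in a few lines of PyTorch (a generic illustration, not the exact gliclass implementation):

import torch

def first_token_pooling(hidden_states: torch.Tensor) -> torch.Tensor:
    # hidden_states: (batch, seq_len, hidden_dim) -> (batch, hidden_dim)
    return hidden_states[:, 0]

def mean_pooling(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Average only over non-padding positions.
    mask = attention_mask.unsqueeze(-1).float()
    return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)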

Scoring Mechanism

The final stage involves computing compatibility scores between the pooled representations. We implement this through a flexible scoring framework that can accommodate various approaches:

  • Simple dot product scoring: Efficient and effective for many applications
  • Neural network scoring: More complex scoring functions for challenging scenarios
  • Task-specific scoring modules: Customized for particular classification requirements

This modular scoring approach allows the architecture to adapt to different classification scenarios while maintaining computational efficiency.
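A minimal sketch of dot-product scoring between pooled label and text representations (again an illustration, not the library's exact code):

import torch

def dot_product_scores(text_repr: torch.Tensor, label_reprs: torch.Tensor) -> torch.Tensor:
    # text_repr: (batch, hidden_dim); label_reprs: (batch, num_labels, hidden_dim)
    # Returns per-label probabilities of shape (batch, num_labels).
    logits = torch.einsum("bd,bld->bl", text_repr, label_reprs)
    # For multi-label classification, a sigmoid turns logits into independent probabilities.
    return torch.sigmoid(logits)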

How to use the models

We open-sourced our models on Hugging Face: Modern GLiClass Collection.

To use them, first install the gliclass package:

pip install gliclass

You can find more details in the GLiClass repo.

Then you need to initialize a model and a pipeline:

from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from transformers import AutoTokenizer

model = GLiClassModel.from_pretrained("knowledgator/gliclass-modern-base-v2.0-init")
tokenizer = AutoTokenizer.from_pretrained("knowledgator/gliclass-modern-base-v2.0-init")

pipeline = ZeroShotClassificationPipeline(model, tokenizer, classification_type='multi-label', device='cuda:0')

Here is how you can perform inference:

text = "One day I will see the world!"
labels = ["travel", "dreams", "sport", "science", "politics"]
results = pipeline(text, labels, threshold=0.5)[0] 

for result in results:
    print(result["label"], "=>", result["score"])

How to fine-tune

First of all, you need to prepare training data in the following format:

[
  {"text": "Some text here!",
   "all_labels": ["sport", "science", "business", "other"],
   "true_labels": ["other"]}
]

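Assuming the data is stored in a JSON file (the file name data.json below is just an example), it can be loaded into a plain Python list like this:

import json

# Load the training examples prepared in the format shown above.
with open("data.json", "r", encoding="utf-8") as f:
    train_data = json.load(f)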
Below are the required imports:

import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"

from datasets import load_dataset, Dataset, DatasetDict

from sklearn.metrics import classification_report, f1_score, precision_recall_fscore_support, accuracy_score
import numpy as np
import random

from transformers import AutoTokenizer
import torch

from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from gliclass.data_processing import GLiClassDataset, DataCollatorWithPadding
from gliclass.training import TrainingArguments, Trainer

Then, we initialize the model and tokenizer:

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

model_name = 'knowledgator/gliclass-base-v1.0'

model = GLiClassModel.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

After that, we specify the training arguments:

max_length = 1024
problem_type = "multi_label_classification"
architecture_type = model.config.architecture_type
prompt_first = model.config.prompt_first

training_args = TrainingArguments(
    output_dir='models/test',
    learning_rate=1e-5,
    weight_decay=0.01,
    others_lr=1e-5,
    others_weight_decay=0.01,
    lr_scheduler_type='linear',
    warmup_ratio=0.0,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=8,
    evaluation_strategy="epoch",
    save_steps=1000,
    save_total_limit=10,
    dataloader_num_workers=8,
    logging_steps=10,
    use_cpu=False,
    report_to="none",
    fp16=False,
    )

Once your dataset is prepared in the right format, initialize the GLiClass dataset and the data collator:

# train_data is the list of examples loaded above
train_dataset = GLiClassDataset(train_data, tokenizer, max_length, problem_type, architecture_type, prompt_first)
# here the first 10% of the training examples are reused for evaluation; prefer a held-out split in practice
test_dataset = GLiClassDataset(train_data[:int(len(train_data)*0.1)], tokenizer, max_length, problem_type, architecture_type, prompt_first)

data_collator = DataCollatorWithPadding(device=device)

When everything is set, we can start training:

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
trainer.train()

See more examples in the repo: https://github.com/Knowledgator/GLiClass/blob/main/finetuning.ipynb

Key Applications

GLiClass demonstrates remarkable versatility across a wide range of natural language processing tasks, making it particularly valuable for both research and practical applications.

Multi-class Classification

The architecture efficiently handles large-scale classification tasks with up to 100 distinct classes in a single processing run. This capability is particularly valuable for applications such as document categorization, product classification, and content tagging systems where multiple detailed categories are required.

Topic Classification

GLiClass excels in identifying and categorizing the subject matter of texts, making it ideal for:

  • Academic paper categorization
  • News article classification
  • Content recommendation systems
  • Research document organization

Sentiment Analysis

The architecture effectively captures nuanced emotional and opinion-based content, supporting:

  • Social media sentiment tracking
  • Customer feedback analysis
  • Product review classification
  • Brand perception monitoring

Event Classification

GLiClass demonstrates strong capabilities in identifying and categorizing events within the text, supporting:

  • News event categorization
  • Social media event detection
  • Historical event classification
  • Timeline analysis and organization

Prompt-based Constrained Classification

The system offers flexible prompt-based classification with custom constraints, enabling:

  • Guided classification tasks
  • Context-aware categorization
  • Custom classification rules
  • Dynamic category adaptation

Natural Language Inference

GLiClass supports sophisticated reasoning about textual relationships, facilitating:

  • Textual entailment detection
  • Contradiction identification
  • Semantic similarity assessment
  • Logical relationship analysis

Retrieval-augmented generation (RAG)

The architecture's good generalization, together with its support for natural language inference tasks, makes it a good choice for reranking in RAG pipelines. Moreover, the efficiency of GLiClass makes it especially competitive compared with cross-encoders. A minimal reranking sketch is shown below.
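As a rough illustration (not an official recipe), the zero-shot pipeline initialized in the usage section above can rerank retrieved passages by treating the user query as the label to score each passage against:

query = "places to travel around the world"
retrieved_passages = [
    "One day I will see the world!",
    "The stock market fell sharply on Monday.",
    "Top destinations to visit in South America.",
]

# Score each passage against the query and sort by relevance.
scored = []
for passage in retrieved_passages:
    result = pipeline(passage, [query], threshold=0.0)[0]
    scored.append((passage, result[0]["score"]))

for passage, score in sorted(scored, key=lambda x: x[1], reverse=True):
    print(round(score, 3), passage)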

This comprehensive range of applications positions GLiClass as a versatile tool for modern natural language processing challenges, offering both flexibility and precision across diverse classification tasks.

Benchmarking results

We released new GLiClass models based on ModernBERT. Compared to older backbones such as DeBERTa, they support a longer context (up to 8k tokens) and offer much faster inference. We benchmarked the new models against previous versions of the GLiClass architecture as well as other zero-shot classifiers.

Below, you can see the F1 scores on several text classification datasets. None of the tested models were fine-tuned on these datasets; all were evaluated in a zero-shot setting.

| Model | IMDB | AG_NEWS | Emotions |
|---|---|---|---|
| gliclass-modern-large-v2.0-init (399M) | 0.9137 | 0.7357 | 0.4140 |
| gliclass-modern-base-v2.0-init (151M) | 0.8264 | 0.6637 | 0.2985 |
| gliclass-large-v1.0 (438M) | 0.9404 | 0.7516 | 0.4874 |
| gliclass-base-v1.0 (186M) | 0.8650 | 0.6837 | 0.4749 |
| gliclass-small-v1.0 (144M) | 0.8650 | 0.6805 | 0.4664 |
| Bart-large-mnli (407M) | 0.89 | 0.6887 | 0.3765 |
| Deberta-base-v3 (184M) | 0.85 | 0.6455 | 0.5095 |
| Comprehendo (184M) | 0.90 | 0.7982 | 0.5660 |
| SetFit BAAI/bge-small-en-v1.5 (33.4M) | 0.86 | 0.5636 | 0.5754 |

Below you can find a more comprehensive comparison of ModernBERT GLiClass with other GLiClass models:

| Dataset | gliclass-base-v1.0-init | gliclass-large-v1.0-init | gliclass-modern-base-v2.0-init | gliclass-modern-large-v2.0-init |
|---|---|---|---|---|
| CR | 0.8672 | 0.8024 | 0.9041 | 0.8980 |
| sst2 | 0.8342 | 0.8734 | 0.9011 | 0.9434 |
| sst5 | 0.2048 | 0.1638 | 0.1972 | 0.1123 |
| 20_news_groups | 0.2317 | 0.4151 | 0.2448 | 0.2792 |
| spam | 0.5963 | 0.5407 | 0.5074 | 0.6364 |
| financial_phrasebank | 0.3594 | 0.3705 | 0.2537 | 0.2562 |
| imdb | 0.8772 | 0.8836 | 0.8255 | 0.9137 |
| ag_news | 0.5614 | 0.7069 | 0.6050 | 0.6933 |
| emotion | 0.2865 | 0.3840 | 0.2474 | 0.3746 |
| cap_sotu | 0.3966 | 0.4353 | 0.2929 | 0.2919 |
| rotten_tomatoes | 0.6626 | 0.7933 | 0.6630 | 0.5928 |
| AVERAGE | 0.5344 | 0.5790 | 0.5129 | 0.5447 |

We examined how performance grows when we fine-tune the model on a small number of examples per label. Additionally, we tested a simple approach in which we provided no real texts but only generic short descriptions of each topic; we call this weak supervision. Surprisingly, for some datasets such as "emotion" it improved performance substantially. A minimal illustration of such weak-supervision data follows the table below.

| Model | Num Examples | sst5 | ag_news | emotion | AVERAGE |
|---|---|---|---|---|---|
| gliclass-modern-large-v2.0-init | 0 | 0.1123 | 0.6933 | 0.3746 | 0.3934 |
| gliclass-modern-large-v2.0-init | 8 | 0.5098 | 0.8339 | 0.5010 | 0.6149 |
| gliclass-modern-large-v2.0-init | Weak Supervision | 0.0951 | 0.6478 | 0.4520 | 0.3983 |
| gliclass-modern-base-v2.0-init | 0 | 0.1972 | 0.6050 | 0.2474 | 0.3499 |
| gliclass-modern-base-v2.0-init | 8 | 0.3604 | 0.7481 | 0.4420 | 0.5168 |
| gliclass-modern-base-v2.0-init | Weak Supervision | 0.1599 | 0.5713 | 0.3216 | 0.3509 |
| gliclass-large-v1.0-init | 0 | 0.1639 | 0.7069 | 0.3840 | 0.4183 |
| gliclass-large-v1.0-init | 8 | 0.4226 | 0.8415 | 0.4886 | 0.5842 |
| gliclass-large-v1.0-init | Weak Supervision | 0.1689 | 0.7051 | 0.4586 | 0.4442 |
| gliclass-base-v1.0-init | 0 | 0.2048 | 0.5614 | 0.2865 | 0.3509 |
| gliclass-base-v1.0-init | 8 | 0.2007 | 0.8359 | 0.4856 | 0.5074 |
| gliclass-base-v1.0-init | Weak Supervision | 0.0681 | 0.6627 | 0.3066 | 0.3458 |
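For illustration, weak-supervision training data for the "emotion" dataset could look like the sketch below, where a short generic description of the topic stands in for a real example (the description texts are made up for illustration; the label set matches the emotion dataset):

[
  {"text": "A short text expressing joy and happiness.",
   "all_labels": ["joy", "sadness", "anger", "fear", "love", "surprise"],
   "true_labels": ["joy"]},
  {"text": "A short text expressing sadness and grief.",
   "all_labels": ["joy", "sadness", "anger", "fear", "love", "surprise"],
   "true_labels": ["sadness"]}
]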

Conclusion

GLiClass represents a significant advancement in the field of text classification, offering a robust and efficient solution that bridges the gap between the accuracy of complex transformer-based models and the simplicity of embedding-based approaches. By leveraging a novel architecture that facilitates rich interactions between input text and labels, GLiClass achieves superior performance in zero-shot and few-shot classification tasks while maintaining computational efficiency, even with large label sets. Its ability to capture cross-label dependencies, adapt to diverse classification scenarios, and integrate seamlessly with existing NLP pipelines positions it as a versatile tool for a wide range of applications, from sentiment analysis and topic classification to retrieval-augmented generation and natural language inference.
