---
license: apache-2.0
datasets:
- sartajbhuvaji/gutenberg
base_model:
- google-bert/bert-base-uncased
pipeline_tag: text-classification
tags:
- classification
language:
- en
library_name: transformers
---

```python
from transformers import BertForSequenceClassification, BertTokenizer, pipeline
from datasets import load_dataset

# Load the fine-tuned classifier and the matching base tokenizer
model = BertForSequenceClassification.from_pretrained("sartajbhuvaji/gutenberg-bert-base-uncased")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Create a text classification pipeline
classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, device="cuda")

# Test the pipeline on a short string
result = classifier("This is a great book!")
print(result)  # [{'label': 'LABEL_8', 'score': 0.2576160430908203}]

# Test the pipeline on a full document from the Gutenberg dataset
dataset = load_dataset("sartajbhuvaji/gutenberg", split="100")
df = dataset.to_pandas()

doc_id = 1
doc_text = df.loc[df['DocID'] == doc_id, 'Text'].values[0]

result = classifier(doc_text[:512])  # Truncate to the first 512 characters (BERT accepts at most 512 tokens)
print(result)  # [{'label': 'LABEL_2', 'score': 0.28877997398376465}]
```
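
Slicing `doc_text[:512]` cuts the document at 512 characters, which only approximates the model's 512-token limit. As a minimal sketch (not part of the original card), the pipeline can instead forward `truncation=True` and `max_length=512` to the tokenizer so truncation happens at the token level, and it can score several documents in one batched call:

```python
# Minimal sketch: token-level truncation and batched scoring.
# Assumes `classifier` and `df` from the example above are already defined.
texts = df['Text'].head(5).tolist()

# truncation/max_length are forwarded to the tokenizer, so each document is
# cut at 512 tokens rather than 512 characters.
results = classifier(texts, truncation=True, max_length=512, batch_size=2)

for doc_id, prediction in zip(df['DocID'].head(5), results):
    print(doc_id, prediction['label'], round(prediction['score'], 4))
```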