license: mit
base_model: bert-base-cased
tags:
  - CENIA
  - News
metrics:
  - accuracy
model-index:
  - name: bert-base-cased-finetuned
    results: []
datasets:
  - cmunhozc/usa_news_en
language:
  - en
pipeline_tag: text-classification
widget:
  - text: >-
      Poll: Which COVID-related closure in San Francisco has you the most shook
      up? || President Trump has pardoned Edward DeBartolo Jr., the former San
      Francisco 49ers owner convicted in a gambling fraud scandal.
    output:
      - label: RELATED
        score: 0
      - label: UNRELATED
        score: 1
  - text: >-
      The first batch of 2020 census data surprised many. A look at what's next
      || There were some genuine surprises in the first batch of data from the
      nation’s 2020 head count released this week by the U.S. Census Bureau.
    output:
      - label: RELATED
        score: 1
      - label: UNRELATED
        score: 0

bert-base-cased-finetuned

This model is a fine-tuned version of bert-base-cased on the usa_news_en dataset. It achieves the following results on the evaluation set:

  • Loss: 0.0900
  • Accuracy: 0.9800

Model description

The fine-tuned model is a binary classifier that determines whether two English news headlines are related or unrelated. More details can be found in the paper "News Gathering: Leveraging Transformers to Rank News". To use the fine-tuned model, follow the steps below:

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import Trainer
import numpy as np
import evaluate  # assumption: the accuracy metric is loaded with the `evaluate` library

metric = evaluate.load("accuracy")

### 1. Load the model:
model_name = "cmunhozc/news-ranking-ft-bert"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

### 2. Dataset:
# `dataset` is assumed to be a DatasetDict whose splits have "sentence1",
# "sentence2" and "label" columns; loading it is elided below.
def preprocess_fctn(examples):
    # Encode each pair as one BERT input: [CLS] sentence1 [SEP] sentence2 [SEP]
    return tokenizer(examples["sentence1"], examples["sentence2"], truncation=True)
...
encoded_dataset = dataset.map(preprocess_fctn, batched=True, load_from_cache_file=False)
...

### 3. Evaluation:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # Convert logits to predicted class ids before scoring
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels)

trainer_hf = Trainer(model,
                     eval_dataset=encoded_dataset["validation"],
                     tokenizer=tokenizer,
                     compute_metrics=compute_metrics)

trainer_hf.evaluate()

predictions = trainer_hf.predict(encoded_dataset["validation"])
acc_val = metric.compute(predictions=np.argmax(predictions.predictions, axis=1).tolist(),
                         references=predictions.label_ids)["accuracy"]
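The ranking procedure below relies on class probabilities rather than on the argmax used for accuracy. The following is a minimal sketch of that conversion. It assumes the model returns one row of logits per pair and that label names come from the model's `config.id2label` mapping. The mapping in the example call is hypothetical, so check the actual config. The helper `split_widget_text` is also hypothetical: it assumes the ` || ` in the widget examples separates the two texts of a pair.

```python
import numpy as np


def split_widget_text(text):
    """Split a widget-style "first || second" string into (sentence1, sentence2).

    Hypothetical helper: assumes "||" separates the two texts, as in the widget examples.
    """
    first, second = text.split("||", maxsplit=1)
    return first.strip(), second.strip()


def pair_probabilities(logits, id2label):
    """Turn a (num_pairs, num_labels) array of logits into per-label probabilities.

    `id2label` maps class index -> label name, e.g. model.config.id2label.
    """
    logits = np.asarray(logits, dtype=np.float64)
    # Numerically stable softmax over the label axis
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=1, keepdims=True)
    return [
        {id2label[i]: float(row[i]) for i in range(row.shape[0])}
        for row in probs
    ]


# Example with made-up logits and a hypothetical label mapping
id2label = {0: "UNRELATED", 1: "RELATED"}
print(split_widget_text("Headline A || Lead sentence B"))
print(pair_probabilities([[2.0, -1.0], [0.0, 0.0]], id2label))
```

In real use, the logits would be `predictions.predictions` from `trainer_hf.predict(...)`, and the RELATED entry of each dictionary is the score used for ranking.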

Finally, with the classification model above, you can generate the news ranking as follows:

  • For each news article in the google_news_en dataset that appears as the first element of a pair, retrieve all pairs in which it occupies that position.
  • Using the pair encoder, score each article in the second position of those pairs for its relevance to the first article.
  • Sort each resulting list by the probability the encoder assigns to the relevance (RELATED) class.
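The three steps above can be sketched in plain Python. This is a minimal illustration, not the authors' implementation. It assumes the RELATED probability of every (first, second) pair has already been computed with the classifier. The function name `rank_news` and the tuple layout are hypothetical.

```python
from collections import defaultdict


def rank_news(scored_pairs):
    """Rank candidate articles for each anchor article.

    `scored_pairs` is an iterable of (first_id, second_id, p_related) tuples, where
    p_related is the classifier's probability for the RELATED class (hypothetical layout).
    Returns {first_id: [(second_id, p_related), ...]} sorted from most to least related.
    """
    # Step 1: gather all pairs that share the same article in the first position
    groups = defaultdict(list)
    for first_id, second_id, p_related in scored_pairs:
        # Step 2: keep each second-position article with its relevance score
        groups[first_id].append((second_id, p_related))
    # Step 3: order each list by the RELATED-class probability, highest first
    return {
        first_id: sorted(candidates, key=lambda c: c[1], reverse=True)
        for first_id, candidates in groups.items()
    }


pairs = [
    ("a1", "b1", 0.10),
    ("a1", "b2", 0.95),
    ("a1", "b3", 0.40),
    ("a2", "b1", 0.80),
]
print(rank_news(pairs))
# {'a1': [('b2', 0.95), ('b3', 0.4), ('b1', 0.1)], 'a2': [('b1', 0.8)]}
```

Sorting by probability rather than by predicted label keeps a full ordering of the candidates, even when several of them are classified as RELATED.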

Intended uses & limitations

More information needed

Training, evaluation and test data

The training data comes from the train split of the usa_news_en dataset, and the validation data comes from its validation split. For testing, the text classification model is evaluated on the test_1 and test_2 splits, while the ranking model is evaluated on the test data of the google_news_en dataset.

Training procedure

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 1e-05
  • train_batch_size: 32
  • eval_batch_size: 32
  • seed: 42
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: linear
  • num_epochs: 3
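With `lr_scheduler_type: linear`, the learning rate decays linearly from 1e-05 to zero over training. The following is a minimal sketch of that schedule. It assumes no warmup steps, which is the Trainer default; the card does not list a warmup setting. It also assumes 10,578 total optimizer steps, that is, 3 epochs of 3,526 steps as reported in the training results.

```python
def linear_lr(step, base_lr=1e-05, total_steps=10578, warmup_steps=0):
    """Learning rate after `step` optimizer steps under a linear schedule.

    warmup_steps=0 is an assumption; the card does not state a warmup value.
    """
    if warmup_steps > 0 and step < warmup_steps:
        # Linear warmup from 0 up to base_lr
        return base_lr * step / warmup_steps
    # Linear decay from base_lr down to 0 at total_steps, clamped at 0 afterwards
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


# Learning rate at the start and at the end of each epoch
for step in (0, 3526, 7052, 10578):
    print(step, linear_lr(step))
```

Under these assumptions, the second epoch starts at two thirds of the base rate, the third epoch starts at one third, and training ends at zero.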

Training results

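| Training Loss | Epoch | Step  | Validation Loss | Accuracy |
|:-------------:|:-----:|:-----:|:---------------:|:--------:|
| 0.0967        | 1.0   | 3526  | 0.0651          | 0.9771   |
| 0.0439        | 2.0   | 7052  | 0.0820          | 0.9776   |
| 0.0231        | 3.0   | 10578 | 0.0900          | 0.9800   |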
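The Step column is cumulative, and each epoch adds 3,526 optimizer steps. The card does not state whether training used one device, gradient accumulation, or a kept final partial batch. Assuming a single device, no gradient accumulation, and that the final partial batch is kept, `steps_per_epoch = ceil(num_pairs / 32)`. This bounds the size of the training split. The following is a quick check of that arithmetic under those assumptions:

```python
import math


def steps_per_epoch(num_examples, batch_size=32):
    """Optimizer steps in one epoch when the last partial batch is kept."""
    return math.ceil(num_examples / batch_size)


def example_count_range(steps, batch_size=32):
    """Smallest and largest example counts consistent with `steps` per epoch."""
    return (steps - 1) * batch_size + 1, steps * batch_size


low, high = example_count_range(3526)
print(low, high)
# 112801 112832

# The cumulative Step column is epoch * steps_per_epoch
print([epoch * 3526 for epoch in (1, 2, 3)])
# [3526, 7052, 10578]
```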

Framework versions

  • Transformers 4.35.2
  • Pytorch 2.1.0+cu121
  • Datasets 2.16.1
  • Tokenizers 0.15.0