---
license: gemma
base_model: google/paligemma-3b-pt-224
tags:
  - generated_from_trainer
datasets:
  - HuggingFaceM4/VQAv2
model-index:
  - name: paligemma_vqav2
    results: []
---

paligemma_vqav2

This model is a fine-tuned version of google/paligemma-3b-pt-224, trained on a small subset of the VQAv2 dataset (HuggingFaceM4/VQAv2). The fine-tuning code is here.

How to use

The code below shows how to run inference with this model. See also the inference notebook.

from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image
import requests

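# Load the fine-tuned checkpoint; the processor comes from the base model it was trained from
model_id = "merve/paligemma_vqav2"
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")

prompt = "What is behind the cat?"
image_file = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cat.png?download=true"
raw_image = Image.open(requests.get(image_file, stream=True).raw)

# Tokenize the question and preprocess the image into one batch (keyword arguments avoid argument-order ambiguity)
inputs = processor(text=prompt, images=raw_image.convert("RGB"), return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=20)

# Decode only the newly generated tokens, skipping the image and prompt tokens
input_len = inputs["input_ids"].shape[-1]
print(processor.decode(output[0][input_len:], skip_special_tokens=True))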
# gramophone

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 2e-05
  • train_batch_size: 4
  • eval_batch_size: 8
  • seed: 42
  • gradient_accumulation_steps: 4
  • total_train_batch_size: 16
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: linear
  • lr_scheduler_warmup_steps: 2
  • num_epochs: 2
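The two derived quantities in this list can be illustrated in plain Python. This is a minimal sketch, not the training code itself. `total_train_batch_size` is the per-device batch size multiplied by the gradient-accumulation steps, assuming a single device (consistent with 4 × 4 = 16). The `linear` scheduler with 2 warmup steps ramps the learning rate from 0 up to 2e-05, then decays it linearly to 0. The schedule function follows the formula used by `transformers.get_linear_schedule_with_warmup`. `NUM_DEVICES` and the 10-step run length are hypothetical, because the card states neither the device count nor the number of optimizer steps.

```python
# Hyperparameters copied from the list above.
LEARNING_RATE = 2e-05
TRAIN_BATCH_SIZE = 4            # per-device batch size
GRADIENT_ACCUMULATION_STEPS = 4
WARMUP_STEPS = 2
NUM_DEVICES = 1                 # hypothetical: single-device training assumed


def effective_batch_size(per_device: int, accumulation: int, devices: int) -> int:
    """Number of examples contributing to each optimizer step."""
    return per_device * accumulation * devices


def linear_warmup_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Learning rate at an optimizer step: linear ramp-up, then linear decay to 0."""
    if step < warmup_steps:
        # Warmup phase: scale linearly from 0 toward base_lr
        return base_lr * step / max(1, warmup_steps)
    # Decay phase: scale linearly from base_lr down to 0 at total_steps
    remaining = total_steps - step
    return base_lr * max(0.0, remaining / max(1, total_steps - warmup_steps))


if __name__ == "__main__":
    print(effective_batch_size(TRAIN_BATCH_SIZE, GRADIENT_ACCUMULATION_STEPS, NUM_DEVICES))  # 16
    # Hypothetical run length; the card does not report the number of optimizer steps.
    total_steps = 10
    for step in range(total_steps + 1):
        print(step, linear_warmup_lr(step, LEARNING_RATE, WARMUP_STEPS, total_steps))
```

With the Hugging Face `Trainer`, these values correspond to the `TrainingArguments` fields `learning_rate`, `per_device_train_batch_size`, `per_device_eval_batch_size`, `gradient_accumulation_steps`, `seed`, `lr_scheduler_type`, `warmup_steps` and `num_train_epochs`.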

Training results

Framework versions

  • Transformers 4.42.0.dev0
  • Pytorch 2.3.0+cu121
  • Datasets 2.19.1
  • Tokenizers 0.19.1