---
base_model: google/paligemma2-3b-pt-448
library_name: peft
---

<!-- This model card has been generated automatically according to the information the Trainer had access to. You
should probably proofread and complete it, then remove this comment. -->

# PaliGemma2-3b-VQAv2

This model is a fine-tuned version of [google/paligemma2-3b-pt-448](https://huggingface.co/google/paligemma2-3b-pt-448), trained on half of the VQAv2 validation split for task conditioning.
The fine-tuning script is [here](https://github.com/merveenoyan/smol-vision/blob/main/paligemma.py), and it is also available as a notebook [here](https://github.com/merveenoyan/smol-vision/blob/main/Fine_tune_PaliGemma.ipynb).
Make sure you install transformers from the main branch before using this model or running the fine-tuning script.
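For example, `pip install git+https://github.com/huggingface/transformers` installs Transformers from the current main branch.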

## How to Use

Below is example code for using this model. Also see the [inference notebook](https://colab.research.google.com/drive/100IQcvMvGm9y--oelbLfI__eHCoz5Ser?usp=sharing).

```python
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image
import requests

model_id = "merve/paligemma2-3b-vqav2"
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
# Use the processor of the base checkpoint this model was fine-tuned from.
processor = AutoProcessor.from_pretrained("google/paligemma2-3b-pt-448")

prompt = "What is behind the cat?"
image_file = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cat.png?download=true"
raw_image = Image.open(requests.get(image_file, stream=True).raw)

# Preprocess the prompt and image, then generate an answer.
inputs = processor(text=prompt, images=raw_image.convert("RGB"), return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=20)

# Strip the echoed prompt from the decoded output.
print(processor.decode(output[0], skip_special_tokens=True)[len(prompt):])
# gramophone
```
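Since PEFT is listed under the framework versions below, this repository presumably holds adapter weights rather than a fully merged model. If loading the adapter id directly does not work in your environment, here is a minimal sketch for attaching the adapter to the base checkpoint explicitly, assuming this repo contains PEFT adapter files:

```python
from transformers import PaliGemmaForConditionalGeneration
from peft import PeftModel

# Load the base PaliGemma2 checkpoint, then attach the fine-tuned adapter from this repo.
base_model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-pt-448")
model = PeftModel.from_pretrained(base_model, "merve/paligemma2-3b-vqav2")
```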

### Training hyperparameters

The following hyperparameters were used during training:
- learning_rate: 2e-05
- train_batch_size: 4
- eval_batch_size: 4
- seed: 42
- gradient_accumulation_steps: 4
- total_train_batch_size: 16
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- lr_scheduler_warmup_steps: 2
- num_epochs: 2
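
For reference, a minimal sketch of how these settings could be expressed as a `TrainingArguments` configuration; `output_dir` and anything not listed above are illustrative placeholders, not values from the original run:

```python
from transformers import TrainingArguments

# Sketch only: maps the reported hyperparameters onto TrainingArguments.
training_args = TrainingArguments(
    output_dir="paligemma2-3b-vqav2",  # placeholder
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,     # effective batch size: 4 * 4 = 16
    warmup_steps=2,
    num_train_epochs=2,
    lr_scheduler_type="linear",
    seed=42,
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_epsilon=1e-8,
)
```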

### Framework versions

- Transformers (main as of Dec 5)
- PEFT 0.13.2