Commit 8a00d0d
Parent(s): initial commit
Files changed:
- .gitignore +1 -0
- __pycache__/text_encoder.cpython-311.pyc +0 -0
- __pycache__/train.cpython-311.pyc +0 -0
- __pycache__/vision_encoder.cpython-311.pyc +0 -0
- _dataset/__pycache__/preprocess_images.cpython-311.pyc +0 -0
- _dataset/preprocess_captions.ipynb +188 -0
- _dataset/preprocess_images.py +79 -0
- demo.ipynb +240 -0
- text_encoder.py +27 -0
- train.py +204 -0
- vision_encoder.py +56 -0
.gitignore
ADDED
@@ -0,0 +1 @@
checkpoints
__pycache__/text_encoder.cpython-311.pyc
ADDED
Binary file (1.82 kB)
__pycache__/train.cpython-311.pyc
ADDED
Binary file (11.5 kB)
__pycache__/vision_encoder.cpython-311.pyc
ADDED
Binary file (3.05 kB)
_dataset/__pycache__/preprocess_images.cpython-311.pyc
ADDED
Binary file (5.8 kB)
_dataset/preprocess_captions.ipynb
ADDED
@@ -0,0 +1,188 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "from transformers import AutoTokenizer\n",
    "from tqdm import tqdm\n",
    "import json\n",
    "\n",
    "def load_and_process_token_file(input_path, tokenizer_name=\"answerdotai/ModernBERT-base\"):\n",
    "    captions_dict = defaultdict(list)\n",
    "    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)\n",
    "    max_length = 0  # Track the longest unpadded sequence\n",
    "\n",
    "    # Read and tokenize the caption (.token) file\n",
    "    with open(input_path, 'r') as file:\n",
    "        for line in tqdm(file, desc=\"Processing Captions\"):\n",
    "            image_id, caption = line.strip().split('\\t')\n",
    "            jpg_number = image_id.split('.')[0]\n",
    "\n",
    "            # Tokenize without padding or truncation to measure the true length\n",
    "            tokens = tokenizer(caption, return_tensors=\"pt\", padding=False, truncation=False)\n",
    "            token_ids = tokens['input_ids'].squeeze(0).tolist()\n",
    "\n",
    "            # Update max_length based on this tokenized sequence length\n",
    "            max_length = max(max_length, len(token_ids))\n",
    "\n",
    "            # Tokenize with padding and attention mask (padded to max_length=128)\n",
    "            tokens_padded = tokenizer(caption, return_tensors=\"pt\", padding=\"max_length\", truncation=True, max_length=2**7)  # longest caption is 93 tokens < 2**7\n",
    "            token_ids_padded = tokens_padded['input_ids'].squeeze(0).tolist()\n",
    "            attention_mask = tokens_padded['attention_mask'].squeeze(0).tolist()\n",
    "\n",
    "            # Save the raw caption, the tokenized version, and the attention mask\n",
    "            captions_dict[jpg_number].append({\n",
    "                \"text\": caption,\n",
    "                \"tokenized\": token_ids_padded,\n",
    "                \"attention_mask\": attention_mask\n",
    "            })\n",
    "\n",
    "    print(f\"Maximum sequence length (before padding): {max_length}\")\n",
    "    return captions_dict, max_length\n",
    "\n",
    "# Define the input path and process the file\n",
    "input_path = '/mnt/nvme/shared_A/datasets/flickr30k/data/results_20130124.token'\n",
    "captions_dict, max_length = load_and_process_token_file(input_path)\n",
    "\n",
    "# Save the dictionary with tokenized captions and attention masks to a JSON file\n",
    "output_path = '/mnt/nvme/shared_A/datasets/flickr30k/data/captions_tokenized.json'\n",
    "with open(output_path, 'w') as json_file:\n",
    "    json.dump(captions_dict, json_file)\n",
    "\n",
    "# Display the maximum token length\n",
    "print(f\"Final maximum token length across dataset: {max_length}\")\n",
    "\n",
    "# Display the first few entries to verify the content\n",
    "for jpg, captions in list(captions_dict.items())[:5]:\n",
    "    print(f\"{jpg}: {captions}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the dictionary to a JSON file\n",
    "output_path = '/mnt/nvme/shared_A/datasets/flickr30k/data/captions_dict.json'\n",
    "with open(output_path, 'w') as json_file:\n",
    "    json.dump(captions_dict, json_file)\n",
    "\n",
    "print(f\"Captions dictionary saved to {output_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import os\n",
    "import json\n",
    "import numpy as np\n",
    "import random\n",
    "\n",
    "\n",
    "# Vision Caption Dataset\n",
    "class VisionCaptionDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, captions_path, embeddings_dir, normalize=True):\n",
    "        with open(captions_path, 'r') as f:\n",
    "            self.captions_dict = json.load(f)\n",
    "\n",
    "        self.embeddings_dir = embeddings_dir\n",
    "        self.image_ids = list(self.captions_dict.keys())\n",
    "        self.normalize = normalize\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.image_ids)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        image_id = self.image_ids[idx]\n",
    "\n",
    "        # Randomly select a caption and load the tokenized version\n",
    "        caption_entry = random.choice(self.captions_dict[image_id])\n",
    "        tokenized_caption = caption_entry[\"tokenized\"]\n",
    "        attention_mask = caption_entry[\"attention_mask\"]\n",
    "\n",
    "        # Load vision embedding\n",
    "        embedding_path = os.path.join(self.embeddings_dir, f\"{image_id}.npy\")\n",
    "        embedding = np.load(embedding_path)\n",
    "\n",
    "        # Convert vision embedding and tokenized caption to tensors\n",
    "        embedding = torch.tensor(embedding, dtype=torch.float32)\n",
    "        tokenized_caption = torch.tensor(tokenized_caption, dtype=torch.long)\n",
    "        attention_mask = torch.tensor(attention_mask, dtype=torch.long)\n",
    "\n",
    "        return embedding, tokenized_caption, attention_mask\n",
    "\n",
    "# Example usage\n",
    "# Paths for dataset\n",
    "captions_path = '/mnt/nvme/shared_A/datasets/flickr30k/data/captions_tokenized.json'\n",
    "embeddings_dir = '/mnt/nvme/shared_A/datasets/flickr30k/data/reduced_vision_embeddings'\n",
    "\n",
    "# Initialize the dataset\n",
    "full_dataset = VisionCaptionDataset(captions_path, embeddings_dir)\n",
    "\n",
    "# Initialize the DataLoader with `num_workers` and `pin_memory`\n",
    "train_dataloader = DataLoader(full_dataset, batch_size=16, shuffle=True, num_workers=8, pin_memory=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Verify a batch\n",
    "for batch in train_dataloader:\n",
    "    embeddings, captions, attn_mask = batch\n",
    "    print(embeddings.shape, len(captions))\n",
    "\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hf-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
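For reference, a quick way to inspect the captions_tokenized.json file written by the first cell (a sketch, not part of the commit; the path and field names follow the notebook above, and the sample entry is simply whichever key comes first):

import json

captions_path = '/mnt/nvme/shared_A/datasets/flickr30k/data/captions_tokenized.json'
with open(captions_path) as f:
    captions = json.load(f)

# Keys are Flickr30k image ids (the .jpg filename without extension); each value is a list of
# records holding the raw caption, the padded token ids, and the matching attention mask.
image_id, records = next(iter(captions.items()))
first = records[0]
print(image_id, len(records))
print(first["text"])
print(len(first["tokenized"]), sum(first["attention_mask"]))  # padded length (128) vs. real token count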
_dataset/preprocess_images.py
ADDED
@@ -0,0 +1,79 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import shutil
from PIL import Image
from transformers.image_utils import load_image
import sys
sys.path.append('..')
from vision_encoder import ideficsV3
from tqdm import tqdm

class VisionPreprocessor:
    def __init__(self, device=None, param_dtype=torch.float32):
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.param_dtype = param_dtype

        # Initialize and freeze the vision encoder
        self.vision_encoder = ideficsV3("HuggingFaceTB/SmolVLM-Instruct").eval().to(self.device)
        for param in self.vision_encoder.parameters():
            param.requires_grad = False

    def load_image(self, image_path):
        """Load an image and convert it to pixel values with the SmolVLM image processor."""
        image = load_image(image_path)
        # No resizing or augmentation beyond what the image processor itself applies
        inputs = self.vision_encoder.image_processor(images=[image], return_tensors="pt")
        pixel_values = inputs.pixel_values.to(self.param_dtype).to(self.device)
        return pixel_values

    def extract_embedding(self, image_tensor):
        """Extract the raw vision embedding."""
        with torch.no_grad():
            vision_output = self.vision_encoder(image_tensor)

        # Average over the image patches -> one [729, 1152] embedding per image
        vision_output = vision_output.mean(axis=0)

        return vision_output

    def save_embedding(self, vision_output, file_path):
        """Save the vision output to a numpy file."""
        np.save(file_path, vision_output.cpu().numpy())

    def process_directory(self, image_paths, output_dir):
        """Process all images in a directory with a progress bar and save the embeddings."""
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
            print(f"Existing directory cleared: {output_dir}")
        os.makedirs(output_dir, exist_ok=True)

        for image_path in tqdm(image_paths, desc="Processing Images", unit="image"):
            # Load the image and extract its features
            image_tensor = self.load_image(image_path)
            vision_output = self.extract_embedding(image_tensor)

            # Save the output with the same filename but as a .npy
            output_file_path = os.path.join(output_dir, f"{os.path.splitext(os.path.basename(image_path))[0]}.npy")
            self.save_embedding(vision_output, output_file_path)


if __name__ == "__main__":
    torch.manual_seed(42)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    param_dtype = torch.float32

    # Instantiate the pipeline
    pipeline = VisionPreprocessor(device, param_dtype)

    # Specify input and output directories
    input_directory = "/mnt/nvme/shared_A/datasets/flickr30k/data/flickr30k-images"
    output_directory = "/mnt/nvme/shared_A/datasets/flickr30k/data/vision_embeddings_reduced2"

    image_paths = [os.path.join(input_directory, f) for f in os.listdir(input_directory) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

    # Process all images in the input directory
    pipeline.process_directory(image_paths, output_directory)
    print("Processing complete!")
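A small sanity check of what process_directory() writes (a sketch, not part of the commit; the directory matches the script above, and the expected [729, 1152] shape assumes the SmolVLM vision tower in vision_encoder.py):

import os
import numpy as np

emb_dir = "/mnt/nvme/shared_A/datasets/flickr30k/data/vision_embeddings_reduced2"
first_file = sorted(f for f in os.listdir(emb_dir) if f.endswith(".npy"))[0]

# extract_embedding() averages over the image patches, so one [729, 1152] array is saved per image.
emb = np.load(os.path.join(emb_dir, first_file))
print(first_file, emb.shape, emb.dtype)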
demo.ipynb
ADDED
@@ -0,0 +1,240 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Image search with modernBERT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from _dataset.preprocess_images import *\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "pipeline = VisionPreprocessor(device, param_dtype=torch.float32)\n",
    "\n",
    "num_images = 25\n",
    "input_directory = \"/mnt/nvme/shared_A/datasets/coco-image-caption/versions/1/val2017/val2017\"\n",
    "image_paths = [os.path.join(input_directory, f) for f in os.listdir(input_directory) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]\n",
    "\n",
    "# Optionally shuffle, then take the first num_images images\n",
    "# random.shuffle(image_paths)\n",
    "image_paths = image_paths[:num_images]\n",
    "\n",
    "# Print the selected image paths\n",
    "print(\"Selected Image Paths:\")\n",
    "for path in image_paths:\n",
    "    print(path)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "\n",
    "# Specify the output directory\n",
    "output_directory = \"/mnt/nvme/shared_A/datasets/coco-image-caption/versions/1/val2017/vision_embeddings\"\n",
    "\n",
    "# Clear the vision embeddings directory if it exists, then (re)create it\n",
    "if os.path.exists(output_directory):\n",
    "    shutil.rmtree(output_directory)\n",
    "    print(f\"Existing directory cleared: {output_directory}\")\n",
    "os.makedirs(output_directory, exist_ok=True)\n",
    "\n",
    "# Process all images in the input directory\n",
    "pipeline.process_directory(image_paths, output_directory)\n",
    "print(\"Image embeddings saved!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from train import JointNetwork\n",
    "\n",
    "def load_checkpoint_and_prepare_model(checkpoint_path, device=\"cuda\"):\n",
    "    \"\"\"Load a trained JointNetwork from a checkpoint.\"\"\"\n",
    "    device = torch.device(device)\n",
    "    model = JointNetwork()\n",
    "    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)\n",
    "    model.load_state_dict(checkpoint['model_state_dict'])\n",
    "    model.to(device)\n",
    "    model.eval()\n",
    "    model.device = device\n",
    "    print(f\"Model loaded successfully from {checkpoint_path}.\")\n",
    "    return model\n",
    "\n",
    "def get_text_embedding(model, text_prompt):\n",
    "    \"\"\"Encode a text prompt into an embedding using the modernBERT encoder.\"\"\"\n",
    "    tokenized_text = model.text_encoder.tokenizer(text_prompt, return_tensors=\"pt\").to(model.device)\n",
    "    with torch.no_grad():\n",
    "        text_features = model.text_encoder(tokenized_text)\n",
    "        text_features = model.text_projector(text_features.mean(dim=1))\n",
    "        text_features = F.normalize(text_features, dim=1)\n",
    "    return text_features\n",
    "\n",
    "def load_image_embeddings(model, embeddings_dir):\n",
    "    \"\"\"Load all precomputed image embeddings from the specified directory.\"\"\"\n",
    "    vision_embeddings = []\n",
    "    for file in sorted(os.listdir(embeddings_dir)):\n",
    "        if file.endswith(\".npy\"):\n",
    "            image_encoding = torch.tensor(np.load(os.path.join(embeddings_dir, file)), dtype=torch.float32).to(model.device)\n",
    "            vision_pooled = image_encoding.mean(dim=0).unsqueeze(0)\n",
    "            vision_embedded = model.vision_projector(vision_pooled)\n",
    "            vision_embedded = F.normalize(vision_embedded, dim=1)\n",
    "            vision_embeddings.append(vision_embedded)\n",
    "\n",
    "    if len(vision_embeddings) == 0:\n",
    "        raise ValueError(\"No vision embeddings found in the specified directory.\")\n",
    "    print(f\"Vision embeddings loaded successfully from {embeddings_dir}.\")\n",
    "    return torch.stack(vision_embeddings).squeeze(1)\n",
    "\n",
    "def compare_text_to_images(text_embedding, vision_embeddings):\n",
    "    \"\"\"Compare a text embedding against a batch of image embeddings using cosine similarity.\"\"\"\n",
    "    cosine_similarities = torch.matmul(text_embedding, vision_embeddings.T).squeeze(0)\n",
    "    similarity_scores = cosine_similarities.cpu().detach().numpy()\n",
    "    ranked_indices = similarity_scores.argsort()[::-1]  # Sort in descending order\n",
    "    return ranked_indices, similarity_scores\n",
    "\n",
    "\n",
    "# Paths and settings\n",
    "checkpoint_path = \"/home/nolan4/projects/hf-contest/checkpoints/model_checkpoint_20250109_102039.pth\"\n",
    "embeddings_dir = \"/mnt/nvme/shared_A/datasets/coco-image-caption/versions/1/val2017/vision_embeddings\"\n",
    "\n",
    "# Load the model and precomputed vision embeddings\n",
    "model = load_checkpoint_and_prepare_model(checkpoint_path)\n",
    "vision_embeddings = load_image_embeddings(model, embeddings_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "from PIL import Image\n",
    "\n",
    "def display_images_from_paths(image_paths, num_images=5):\n",
    "    num_images = min(num_images, len(image_paths))\n",
    "    if num_images == 0:\n",
    "        print(\"No images found in the directory.\")\n",
    "        return\n",
    "\n",
    "    plt.figure(figsize=(12, 8))\n",
    "    for i, image_path in enumerate(image_paths[:num_images]):\n",
    "        img = Image.open(image_path)\n",
    "        plt.subplot(1, num_images, i + 1)\n",
    "        plt.imshow(img)\n",
    "        plt.axis('off')\n",
    "        plt.title(f\"{os.path.basename(image_path).split('.')[0]}\")\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "# Example usage\n",
    "# random.shuffle(image_paths)\n",
    "display_images_from_paths(image_paths, num_images=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Query text\n",
    "text_prompt = \"cars driving down the road\"\n",
    "# text_prompt = \"stuffed brown teddy bear\"\n",
    "\n",
    "# Encode the text prompt\n",
    "text_embedding = get_text_embedding(model, text_prompt)\n",
    "\n",
    "# Perform comparison and display results\n",
    "ranked_indices, similarity_scores = compare_text_to_images(text_embedding, vision_embeddings)\n",
    "print(\"\\nTop 5 Most Similar Images:\")\n",
    "for idx in ranked_indices[:5]:\n",
    "    print(f\"Image Index: {idx}, Similarity Score: {similarity_scores[idx]:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gather the paths of the top-ranked images\n",
    "selected_image_paths = [image_paths[idx] for idx in ranked_indices[:10]]\n",
    "\n",
    "# Display the top N ranked images\n",
    "display_images_from_paths(selected_image_paths, num_images=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hf-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
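The retrieval step above reduces to a dot product between L2-normalized vectors; a self-contained sketch with random stand-in embeddings (the 512-d size follows the projection heads in train.py):

import torch
import torch.nn.functional as F

# Stand-in embeddings: 1 text query and 100 images, both projected to 512-d as in JointNetwork
text_embedding = F.normalize(torch.randn(1, 512), dim=1)
vision_embeddings = F.normalize(torch.randn(100, 512), dim=1)

# With unit-norm vectors, the matmul yields cosine similarities in [-1, 1]
scores = (text_embedding @ vision_embeddings.T).squeeze(0)
ranked = scores.argsort(descending=True)
print(ranked[:5], scores[ranked[:5]])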
text_encoder.py
ADDED
@@ -0,0 +1,27 @@
from transformers import AutoTokenizer, ModernBertModel
import torch
import torch.nn as nn
import torch.optim as optim
import pdb

class modernBERT(nn.Module):
    def __init__(self, model_name="answerdotai/ModernBERT-base"):
        super(modernBERT, self).__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.bert = ModernBertModel.from_pretrained(model_name)

    def forward(self, inputs):
        # inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        outputs = self.bert(**inputs)

        return outputs.last_hidden_state  # [batch_size, seq_len, 768] hidden states

# Quick smoke test
if __name__ == "__main__":
    model = modernBERT("answerdotai/ModernBERT-base")

    texts = ["Potato's no name for a dog"]
    text_inputs = model.tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
    output = model(text_inputs)

    print(output[0].shape)
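For context, train.py calls this encoder with pre-tokenized, padded batches (input ids plus attention mask) rather than raw strings; a minimal sketch of that calling convention (the 128-token padding mirrors the caption preprocessing, and 768 is ModernBERT-base's hidden size):

import torch
from text_encoder import modernBERT

model = modernBERT("answerdotai/ModernBERT-base").eval()

# Mimic the dataset output: padded token ids plus an attention mask
batch = model.tokenizer(["two dogs running", "a man rides a horse"],
                        return_tensors="pt", padding="max_length", truncation=True, max_length=128)
with torch.no_grad():
    hidden = model({"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]})
print(hidden.shape)  # expected [2, 128, 768]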
train.py
ADDED
@@ -0,0 +1,204 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from text_encoder import *
from vision_encoder import *
import os
import json
import numpy as np
import random
from tqdm import tqdm
import datetime

# Vision Caption Dataset
class VisionCaptionDataset(torch.utils.data.Dataset):
    def __init__(self, captions_path, embeddings_dir, normalize=True):
        with open(captions_path, 'r') as f:
            self.captions_dict = json.load(f)

        self.embeddings_dir = embeddings_dir
        self.image_ids = list(self.captions_dict.keys())
        self.normalize = normalize

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]

        caption_entry = random.choice(self.captions_dict[image_id])
        tokenized_caption = caption_entry["tokenized"]
        attention_mask = caption_entry["attention_mask"]

        embedding_path = os.path.join(self.embeddings_dir, f"{image_id}.npy")
        embedding = np.load(embedding_path)

        embedding = torch.tensor(embedding, dtype=torch.float32)
        tokenized_caption = torch.tensor(tokenized_caption, dtype=torch.long)
        attention_mask = torch.tensor(attention_mask, dtype=torch.long)

        return embedding, tokenized_caption, attention_mask


class JointNetwork(nn.Module):
    def __init__(self):
        super(JointNetwork, self).__init__()

        self.text_encoder = modernBERT("answerdotai/ModernBERT-base")

        for param in self.text_encoder.parameters():
            param.requires_grad = True

        self.vision_projector = nn.Linear(1152, 512)
        self.text_projector = nn.Linear(768, 512)

    def forward(self, tokenized_text, image_encoding):
        vision_patch_pooled = image_encoding.mean(dim=1)
        text_output = self.text_encoder(tokenized_text)
        text_pooled = text_output.mean(dim=1)

        vision_embedded = self.vision_projector(vision_patch_pooled)
        text_embedded = self.text_projector(text_pooled)

        vision_embedded = F.normalize(vision_embedded, dim=1)
        text_embedded = F.normalize(text_embedded, dim=1)

        return text_embedded, vision_embedded


def infoNCE_loss(text_features, vision_features, temperature=0.07):
    text_features = F.normalize(text_features, p=2, dim=-1)
    vision_features = F.normalize(vision_features, p=2, dim=-1)

    similarity_matrix = torch.matmul(text_features, vision_features.T) / temperature
    batch_size = vision_features.size(0)
    labels = torch.arange(batch_size, device=vision_features.device)

    loss_text_to_image = F.cross_entropy(similarity_matrix, labels)
    loss_image_to_text = F.cross_entropy(similarity_matrix.T, labels)

    return (loss_text_to_image + loss_image_to_text) / 2


def train_model(model, train_loader, val_loader, optimizer, scheduler, num_epochs=5, freeze_text_encoder=True, checkpoint_path=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    best_val_loss = float('inf')  # Initialize with a very high value

    # Freeze text encoder if specified
    if freeze_text_encoder:
        for param in model.text_encoder.parameters():
            param.requires_grad = False

    # Ensure new layers are trainable
    for param in model.vision_projector.parameters():
        param.requires_grad = True
    for param in model.text_projector.parameters():
        param.requires_grad = True

    model.to(device)

    for epoch in range(num_epochs):

        # Train loop
        model.train()
        total_loss = 0.0

        print(f"\nEpoch {epoch + 1}/{num_epochs} - Training:")
        train_progress = tqdm(train_loader, desc="Training", leave=True)

        for image_embeddings, tokenized_captions, attention_masks in train_progress:
            text_inputs = {"input_ids": tokenized_captions.to(device), "attention_mask": attention_masks.to(device)}
            image_embeddings = image_embeddings.to(device)

            optimizer.zero_grad()
            text_features, vision_features = model(text_inputs, image_embeddings)
            loss = infoNCE_loss(text_features, vision_features)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            train_progress.set_postfix(loss=loss.item())

        scheduler.step()

        # Validation loop
        model.eval()
        val_loss = 0.0

        print(f"\nEpoch {epoch + 1}/{num_epochs} - Validation:")
        val_progress = tqdm(val_loader, desc="Validation", leave=True)

        with torch.no_grad():
            for image_embeddings, tokenized_captions, attention_masks in val_progress:
                text_inputs = {"input_ids": tokenized_captions.to(device), "attention_mask": attention_masks.to(device)}
                image_embeddings = image_embeddings.to(device)

                text_features, vision_features = model(text_inputs, image_embeddings)
                loss = infoNCE_loss(text_features, vision_features)
                val_loss += loss.item()
                val_progress.set_postfix(loss=loss.item())

        avg_train_loss = total_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        print(f"\nEpoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        # Save best model
        if checkpoint_path is not None:
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_loss': best_val_loss
                }, checkpoint_path)
                print(f"New Best Model Saved at: {checkpoint_path} (Val Loss: {best_val_loss:.4f})")

    print("Training completed!")


if __name__ == "__main__":
    # Set random seed for reproducibility
    # torch.manual_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Paths for dataset
    captions_path = '/mnt/nvme/shared_A/datasets/flickr30k/data/captions_tokenized.json'
    # embeddings_dir = '/mnt/nvme/shared_A/datasets/flickr30k/data/reduced_vision_embeddings'
    embeddings_dir = '/mnt/nvme/shared_A/datasets/flickr30k/data/vision_embeddings_reduced2'

    # Initialize datasets and loaders
    full_dataset = VisionCaptionDataset(captions_path, embeddings_dir)
    train_size = int(0.85 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=8, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=8, pin_memory=True)

    # Initialize model, optimizer, and scheduler
    model = JointNetwork().to(device)

    checkpoint_path = f"./checkpoints/model_checkpoint_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.pth"

    # **Phase 1 Configuration: Training new layers only**
    initial_lr = 1e-4
    min_lr = 1e-6
    num_epochs = 16
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=initial_lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=min_lr)

    # **Phase 1: Train new layers only, freeze text encoder**
    print("\n### Phase 1: Training new layers only (Text Encoder Frozen) ###")
    train_model(model, train_loader, val_loader, optimizer, scheduler, num_epochs=num_epochs, freeze_text_encoder=True, checkpoint_path=checkpoint_path)

    # # **Phase 2 Configuration: Fine-tuning with adjusted learning rate**
    # initial_lr = 1e-4
    # min_lr = 1e-6
    # num_epochs = 3
    # optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=initial_lr)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=min_lr)

    # print("\n### Phase 2: Fine-tuning text encoder and new layers ###")
    # train_model(model, train_loader, val_loader, optimizer, scheduler, num_epochs=num_epochs, freeze_text_encoder=False, checkpoint_path=checkpoint_path)
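infoNCE_loss above is the symmetric CLIP-style contrastive objective; a small standalone check of its behavior (a sketch, assuming train.py imports cleanly in the current environment; untrained random features score near log(batch_size), while perfectly aligned features score much lower):

import math
import torch
from train import infoNCE_loss

torch.manual_seed(0)
batch_size, dim = 8, 512

# Random, unrelated text/vision features: the loss should sit near log(batch_size)
text = torch.randn(batch_size, dim)
vision = torch.randn(batch_size, dim)
print(infoNCE_loss(text, vision).item(), math.log(batch_size))

# Identical features: the diagonal of the similarity matrix dominates, so the loss drops sharply
print(infoNCE_loss(text, text).item())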
vision_encoder.py
ADDED
@@ -0,0 +1,56 @@
import torch
import torch.nn as nn
from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers.image_utils import load_image


class ideficsV3(nn.Module):
    def __init__(self, model_name="HuggingFaceTB/SmolVLM-Instruct"):
        super().__init__()

        # Load the SmolVLM model from Hugging Face
        self.image_processor = AutoProcessor.from_pretrained(model_name).image_processor
        smolVLM = AutoModelForVision2Seq.from_pretrained(model_name, torch_dtype=torch.float32)

        # Keep only the vision tower
        self.vision_model = smolVLM.model.vision_model

    def forward(self, pixel_values):
        # The image processor returns pixel_values of shape [batch_size, num_patches, 3, 384, 384],
        # but the vision transformer expects a 4D [batch, channels, height, width] tensor
        # (passing the 5D tensor raises "ValueError: too many values to unpack (expected 4)").
        # Flatten the patch dimension into the batch dimension before the forward pass.
        batch_size, num_patches, channels, height, width = pixel_values.shape
        pixel_values = pixel_values.view(batch_size * num_patches, channels, height, width)

        # Run images through the vision transformer
        vision_outputs = self.vision_model(pixel_values)
        x = vision_outputs.last_hidden_state  # shape := [batch_size * num_patches, 729, 1152]

        return x


if __name__ == "__main__":
    # Instantiate the truncated model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_dtype = torch.float32
    truncated_model = ideficsV3().to(device).eval()

    image1 = load_image("https://huggingface.co/spaces/merve/chameleon-7b/resolve/main/bee.jpg")
    image2 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")

    inputs1 = truncated_model.image_processor(images=[image1, image2], return_tensors="pt")
    pixel_values = inputs1.pixel_values.to(model_dtype).to(device)

    # Pass pixel_values through the truncated model
    with torch.no_grad():
        outputs = truncated_model(pixel_values)

    print(outputs.shape)  # [batch_size * num_patches, 729, 1152] patch features from the vision tower
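To tie the pieces together, a short shape walk-through from this encoder to what train.py consumes (a sketch with random tensors; num_patches and the 729x1152 token grid are the values noted in the comments above):

import torch

# Hypothetical single-image output of ideficsV3: num_patches sub-images, 729 tokens, 1152 dims
num_patches = 13
patch_features = torch.randn(num_patches, 729, 1152)

# _dataset/preprocess_images.py averages over the patch dimension before saving to .npy
per_image = patch_features.mean(dim=0)           # -> [729, 1152]

# train.JointNetwork then pools over the 729 tokens and projects 1152 -> 512
batch = per_image.unsqueeze(0)                   # -> [1, 729, 1152] after DataLoader batching
pooled = batch.mean(dim=1)                       # -> [1, 1152]
projected = torch.nn.Linear(1152, 512)(pooled)   # -> [1, 512] (random weights, shapes only)
print(per_image.shape, pooled.shape, projected.shape)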