nolan4 committed
Commit 8a00d0d · 0 Parent(s)

initial commit

.gitignore ADDED
@@ -0,0 +1 @@
+ checkpoints
__pycache__/text_encoder.cpython-311.pyc ADDED
Binary file (1.82 kB)
 
__pycache__/train.cpython-311.pyc ADDED
Binary file (11.5 kB)
 
__pycache__/vision_encoder.cpython-311.pyc ADDED
Binary file (3.05 kB)
 
_dataset/__pycache__/preprocess_images.cpython-311.pyc ADDED
Binary file (5.8 kB)
 
_dataset/preprocess_captions.ipynb ADDED
@@ -0,0 +1,188 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from collections import defaultdict\n",
+     "from transformers import AutoTokenizer\n",
+     "from tqdm import tqdm\n",
+     "import json\n",
+     "\n",
+     "def load_and_process_token_file(input_path, tokenizer_name=\"answerdotai/ModernBERT-base\"):\n",
+     "    captions_dict = defaultdict(list)\n",
+     "    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)\n",
+     "    max_length = 0  # Initialize max length counter\n",
+     "\n",
+     "    # Read and process the token file with tokenization\n",
+     "    with open(input_path, 'r') as file:\n",
+     "        for line in tqdm(file, desc=\"Processing Captions\"):\n",
+     "            image_id, caption = line.strip().split('\\t')\n",
+     "            jpg_number = image_id.split('.')[0]\n",
+     "\n",
+     "            # Tokenize without padding or truncation to measure the true length\n",
+     "            tokens = tokenizer(caption, return_tensors=\"pt\", padding=False, truncation=False)\n",
+     "            token_ids = tokens['input_ids'].squeeze(0).tolist()\n",
+     "\n",
+     "            # Update max_length based on this tokenized sequence length\n",
+     "            max_length = max(max_length, len(token_ids))\n",
+     "\n",
+     "            # Tokenize again with padding and an attention mask (padded to 2**7 = 128 tokens; longest caption is 93)\n",
+     "            tokens_padded = tokenizer(caption, return_tensors=\"pt\", padding=\"max_length\", truncation=True, max_length=2**7)  # 93 < 2**7\n",
+     "            token_ids_padded = tokens_padded['input_ids'].squeeze(0).tolist()\n",
+     "            attention_mask = tokens_padded['attention_mask'].squeeze(0).tolist()\n",
+     "\n",
+     "            # Save the raw caption, the tokenized version, and the attention mask\n",
+     "            captions_dict[jpg_number].append({\n",
+     "                \"text\": caption,\n",
+     "                \"tokenized\": token_ids_padded,\n",
+     "                \"attention_mask\": attention_mask\n",
+     "            })\n",
+     "\n",
+     "    print(f\"Maximum sequence length (before padding): {max_length}\")\n",
+     "    return captions_dict, max_length\n",
+     "\n",
+     "# Define the input path and process the file\n",
+     "input_path = '/mnt/nvme/shared_A/datasets/flickr30k/data/results_20130124.token'\n",
+     "captions_dict, max_length = load_and_process_token_file(input_path)\n",
+     "\n",
+     "# Save the dictionary with tokenized captions and attention masks to a JSON file\n",
+     "output_path = '/mnt/nvme/shared_A/datasets/flickr30k/data/captions_tokenized.json'\n",
+     "with open(output_path, 'w') as json_file:\n",
+     "    json.dump(captions_dict, json_file)\n",
+     "\n",
+     "# Display the maximum token length\n",
+     "print(f\"Final maximum token length across dataset: {max_length}\")\n",
+     "\n",
+     "# Display the first few entries to verify the content\n",
+     "for jpg, captions in list(captions_dict.items())[:5]:\n",
+     "    print(f\"{jpg}: {captions}\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "\n",
+     "# Save the dictionary to a JSON file\n",
+     "output_path = '/mnt/nvme/shared_A/datasets/flickr30k/data/captions_dict.json'\n",
+     "with open(output_path, 'w') as json_file:\n",
+     "    json.dump(captions_dict, json_file)\n",
+     "\n",
+     "print(f\"Captions dictionary saved to {output_path}\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 2,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import torch\n",
+     "from torch.utils.data import Dataset, DataLoader\n",
+     "import os\n",
+     "import json\n",
+     "import numpy as np\n",
+     "import random\n",
+     "\n",
+     "\n",
+     "# Vision Caption Dataset\n",
+     "class VisionCaptionDataset(torch.utils.data.Dataset):\n",
+     "    def __init__(self, captions_path, embeddings_dir, normalize=True):\n",
+     "        with open(captions_path, 'r') as f:\n",
+     "            self.captions_dict = json.load(f)\n",
+     "\n",
+     "        self.embeddings_dir = embeddings_dir\n",
+     "        self.image_ids = list(self.captions_dict.keys())\n",
+     "        self.normalize = normalize\n",
+     "\n",
+     "    def __len__(self):\n",
+     "        return len(self.image_ids)\n",
+     "\n",
+     "    def __getitem__(self, idx):\n",
+     "        image_id = self.image_ids[idx]\n",
+     "\n",
+     "        # Randomly select a caption and load the tokenized version\n",
+     "        caption_entry = random.choice(self.captions_dict[image_id])\n",
+     "        tokenized_caption = caption_entry[\"tokenized\"]\n",
+     "        attention_mask = caption_entry[\"attention_mask\"]\n",
+     "\n",
+     "        # Load the precomputed vision embedding\n",
+     "        embedding_path = os.path.join(self.embeddings_dir, f\"{image_id}.npy\")\n",
+     "        embedding = np.load(embedding_path)\n",
+     "\n",
+     "        # Convert the vision embedding and tokenized caption to tensors\n",
+     "        embedding = torch.tensor(embedding, dtype=torch.float32)\n",
+     "        tokenized_caption = torch.tensor(tokenized_caption, dtype=torch.long)\n",
+     "        attention_mask = torch.tensor(attention_mask, dtype=torch.long)\n",
+     "\n",
+     "        return embedding, tokenized_caption, attention_mask\n",
+     "\n",
+     "# Example usage\n",
+     "# Paths for dataset\n",
+     "captions_path = '/mnt/nvme/shared_A/datasets/flickr30k/data/captions_tokenized.json'\n",
+     "embeddings_dir = '/mnt/nvme/shared_A/datasets/flickr30k/data/reduced_vision_embeddings'\n",
+     "\n",
+     "# Initialize the full dataset (the train/validation split happens in train.py)\n",
+     "full_dataset = VisionCaptionDataset(captions_path, embeddings_dir)\n",
+     "\n",
+     "# Initialize the DataLoader with `num_workers` and `pin_memory`\n",
+     "train_dataloader = DataLoader(full_dataset, batch_size=16, shuffle=True, num_workers=8, pin_memory=True)\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Verify a batch\n",
+     "for batch in train_dataloader:\n",
+     "    embeddings, captions, attn_mask = batch\n",
+     "    print(embeddings.shape, len(captions))\n",
+     "    break"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "hf-env",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.11.11"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
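
Aside (not part of the commit): a quick sanity-check sketch for the tokenized captions produced above, assuming the output path used in the notebook. It reloads the JSON, confirms each entry is padded to 128 tokens, and decodes one entry back to text with the same ModernBERT tokenizer.

import json
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
with open('/mnt/nvme/shared_A/datasets/flickr30k/data/captions_tokenized.json') as f:
    captions_dict = json.load(f)

# Grab the first image id and its first caption entry
image_id, entries = next(iter(captions_dict.items()))
entry = entries[0]
print(len(entry["tokenized"]), len(entry["attention_mask"]))  # both should be 128 (2**7)
print(entry["text"])
# Decoding the padded ids (minus special/pad tokens) should reproduce the raw caption
print(tokenizer.decode(entry["tokenized"], skip_special_tokens=True))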
_dataset/preprocess_images.py ADDED
@@ -0,0 +1,79 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ import os
+ import shutil
+ from PIL import Image
+ from transformers.image_utils import load_image
+ import sys
+ sys.path.append('..')
+ from vision_encoder import ideficsV3
+ from tqdm import tqdm
+
+ class VisionPreprocessor:
+     def __init__(self, device=None, param_dtype=torch.float32):
+         self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
+         self.param_dtype = param_dtype
+
+         # Initialize and freeze the vision encoder
+         self.vision_encoder = ideficsV3("HuggingFaceTB/SmolVLM-Instruct").eval().to(self.device)
+         for param in self.vision_encoder.parameters():
+             param.requires_grad = False
+
+     def load_image(self, image_path):
+         """Load an image and convert it to pixel values without extra preprocessing."""
+         image = load_image(image_path)
+         # Convert to a tensor without resizing or additional transformations
+         inputs = self.vision_encoder.image_processor(images=[image], return_tensors="pt")
+         pixel_values = inputs.pixel_values.to(self.param_dtype).to(self.device)
+         return pixel_values
+
+     def extract_embedding(self, image_tensor):
+         """Extract the raw vision embedding."""
+         with torch.no_grad():
+             vision_output = self.vision_encoder(image_tensor)
+
+         # Average over the patch dimension
+         vision_output = vision_output.mean(axis=0)
+
+         return vision_output
+
+     def save_embedding(self, vision_output, file_path):
+         """Save the vision output to a numpy file."""
+         np.save(file_path, vision_output.cpu().numpy())
+
+     def process_directory(self, image_paths, output_dir):
+         """Process all images in a directory with a progress bar and save the embeddings."""
+         if os.path.exists(output_dir):
+             shutil.rmtree(output_dir)
+             print(f"Existing directory cleared: {output_dir}")
+         os.makedirs(output_dir, exist_ok=True)
+
+         # tqdm provides the progress bar
+         for image_path in tqdm(image_paths, desc="Processing Images", unit="image"):
+
+             # Load the image and extract its features
+             image_tensor = self.load_image(image_path)
+             vision_output = self.extract_embedding(image_tensor)
+
+             # Save the output with the same filename but as a .npy
+             output_file_path = os.path.join(output_dir, f"{os.path.splitext(os.path.basename(image_path))[0]}.npy")
+             self.save_embedding(vision_output, output_file_path)
+
+
+ if __name__ == "__main__":
+     torch.manual_seed(42)
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     param_dtype = torch.float32
+
+     # Instantiate the pipeline
+     pipeline = VisionPreprocessor(device, param_dtype)
+
+     # Specify input and output directories
+     input_directory = "/mnt/nvme/shared_A/datasets/flickr30k/data/flickr30k-images"
+     output_directory = "/mnt/nvme/shared_A/datasets/flickr30k/data/vision_embeddings_reduced2"
+
+     image_paths = [os.path.join(input_directory, f) for f in os.listdir(input_directory) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
+     # Process all images in the input directory
+     pipeline.process_directory(image_paths, output_directory)
+     print("Processing complete!")
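
Aside (not part of the commit): a minimal inspection sketch for the saved embeddings, assuming the output directory above. Since extract_embedding averages the SmolVLM vision output over the patch dimension, each .npy file should hold a single (729, 1152) array.

import os
import numpy as np

emb_dir = "/mnt/nvme/shared_A/datasets/flickr30k/data/vision_embeddings_reduced2"
files = [f for f in os.listdir(emb_dir) if f.endswith(".npy")]
sample = np.load(os.path.join(emb_dir, files[0]))
# Expect one file per image, each of shape (729, 1152) in float32
print(len(files), sample.shape, sample.dtype)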
demo.ipynb ADDED
@@ -0,0 +1,240 @@
+ {
+  "cells": [
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "# Image search with modernBERT"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 18,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from _dataset.preprocess_images import *\n",
+     "import random"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "\n",
+     "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+     "pipeline = VisionPreprocessor(device, param_dtype=torch.float32)\n",
+     "\n",
+     "num_images = 25\n",
+     "input_directory = \"/mnt/nvme/shared_A/datasets/coco-image-caption/versions/1/val2017/val2017\"\n",
+     "image_paths = [os.path.join(input_directory, f) for f in os.listdir(input_directory) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]\n",
+     "\n",
+     "# Optionally shuffle, then take the first 25 images\n",
+     "# random.shuffle(image_paths)\n",
+     "image_paths = image_paths[:num_images]\n",
+     "\n",
+     "# Print the selected image paths\n",
+     "print(\"Selected Image Paths:\")\n",
+     "for path in image_paths:\n",
+     "    print(path)\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import os\n",
+     "import shutil\n",
+     "\n",
+     "# Specify the output directory\n",
+     "output_directory = \"/mnt/nvme/shared_A/datasets/coco-image-caption/versions/1/val2017/vision_embeddings\"\n",
+     "\n",
+     "# Clear the vision embeddings directory if it exists, then recreate it\n",
+     "if os.path.exists(output_directory):\n",
+     "    shutil.rmtree(output_directory)\n",
+     "    print(f\"Existing directory cleared: {output_directory}\")\n",
+     "os.makedirs(output_directory, exist_ok=True)\n",
+     "\n",
+     "# Process all images in the input directory\n",
+     "pipeline.process_directory(image_paths, output_directory)\n",
+     "print(\"Image embeddings saved!\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from train import JointNetwork\n",
+     "\n",
+     "def load_checkpoint_and_prepare_model(checkpoint_path, device=\"cuda\"):\n",
+     "    \"\"\"Load a trained JointNetwork from a checkpoint.\"\"\"\n",
+     "    device = torch.device(device)\n",
+     "    model = JointNetwork()\n",
+     "    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)\n",
+     "    model.load_state_dict(checkpoint['model_state_dict'])\n",
+     "    model.to(device)\n",
+     "    model.eval()\n",
+     "    model.device = device\n",
+     "    print(f\"Model loaded successfully from {checkpoint_path}.\")\n",
+     "    return model\n",
+     "\n",
+     "def get_text_embedding(model, text_prompt):\n",
+     "    \"\"\"Encode a text prompt to get its embedding using the modernBERT encoder.\"\"\"\n",
+     "    tokenized_text = model.text_encoder.tokenizer(text_prompt, return_tensors=\"pt\").to(model.device)\n",
+     "    with torch.no_grad():\n",
+     "        text_features = model.text_encoder(tokenized_text)\n",
+     "        text_features = model.text_projector(text_features.mean(dim=1))\n",
+     "        text_features = F.normalize(text_features, dim=1)\n",
+     "    return text_features\n",
+     "\n",
+     "def load_image_embeddings(model, embeddings_dir):\n",
+     "    \"\"\"Load all precomputed image embeddings from the specified directory.\"\"\"\n",
+     "    vision_embeddings = []\n",
+     "    for file in sorted(os.listdir(embeddings_dir)):\n",
+     "        if file.endswith(\".npy\"):\n",
+     "            image_encoding = torch.tensor(np.load(os.path.join(embeddings_dir, file)), dtype=torch.float32).to(model.device)\n",
+     "            vision_pooled = image_encoding.mean(dim=0).unsqueeze(0)\n",
+     "            vision_embedded = model.vision_projector(vision_pooled)\n",
+     "            vision_embedded = F.normalize(vision_embedded, dim=1)\n",
+     "            vision_embeddings.append(vision_embedded)\n",
+     "\n",
+     "    if len(vision_embeddings) == 0:\n",
+     "        raise ValueError(\"No vision embeddings found in the specified directory.\")\n",
+     "    print(f\"Vision embeddings loaded successfully from {embeddings_dir}.\")\n",
+     "    return torch.stack(vision_embeddings).squeeze(1)\n",
+     "\n",
+     "def compare_text_to_images(text_embedding, vision_embeddings):\n",
+     "    \"\"\"Compare a text embedding against a batch of image embeddings using cosine similarity.\"\"\"\n",
+     "    cosine_similarities = torch.matmul(text_embedding, vision_embeddings.T).squeeze(0)\n",
+     "    similarity_scores = cosine_similarities.cpu().detach().numpy()\n",
+     "    ranked_indices = similarity_scores.argsort()[::-1]  # Sort in descending order\n",
+     "    return ranked_indices, similarity_scores\n",
+     "\n",
+     "\n",
+     "# Paths and settings\n",
+     "checkpoint_path = \"/home/nolan4/projects/hf-contest/checkpoints/model_checkpoint_20250109_102039.pth\"\n",
+     "embeddings_dir = \"/mnt/nvme/shared_A/datasets/coco-image-caption/versions/1/val2017/vision_embeddings\"\n",
+     "\n",
+     "# Load the model and precomputed vision embeddings\n",
+     "model = load_checkpoint_and_prepare_model(checkpoint_path)\n",
+     "vision_embeddings = load_image_embeddings(model, embeddings_dir)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import matplotlib.pyplot as plt\n",
+     "import os\n",
+     "from PIL import Image\n",
+     "\n",
+     "def display_images_from_paths(image_paths, num_images=5):\n",
+     "    \"\"\"Display the first num_images images from a list of paths.\"\"\"\n",
+     "    num_images = min(num_images, len(image_paths))\n",
+     "    if num_images == 0:\n",
+     "        print(\"No images found in the directory.\")\n",
+     "        return\n",
+     "\n",
+     "    plt.figure(figsize=(12, 8))\n",
+     "    for i, image_path in enumerate(image_paths[:num_images]):\n",
+     "        img = Image.open(image_path)\n",
+     "        plt.subplot(1, num_images, i + 1)\n",
+     "        plt.imshow(img)\n",
+     "        plt.axis('off')\n",
+     "        plt.title(f\"{os.path.basename(image_path).split('.')[0]}\")\n",
+     "\n",
+     "    plt.tight_layout()\n",
+     "    plt.show()\n",
+     "\n",
+     "# Example usage\n",
+     "# random.shuffle(image_paths)\n",
+     "display_images_from_paths(image_paths, num_images=10)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Query prompt\n",
+     "text_prompt = \"cars driving down the road\"\n",
+     "# text_prompt = \"stuffed brown teddy bear\"\n",
+     "\n",
+     "# Encode the text prompt\n",
+     "text_embedding = get_text_embedding(model, text_prompt)\n",
+     "\n",
+     "# Perform the comparison and display the results\n",
+     "ranked_indices, similarity_scores = compare_text_to_images(text_embedding, vision_embeddings)\n",
+     "print(\"\\nTop 5 Most Similar Images:\")\n",
+     "for idx in ranked_indices[:5]:\n",
+     "    print(f\"Image Index: {idx}, Similarity Score: {similarity_scores[idx]:.4f}\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Select the top-ranked image paths\n",
+     "selected_image_paths = [image_paths[idx] for idx in ranked_indices[:10]]\n",
+     "\n",
+     "# Display the top N ranked images\n",
+     "display_images_from_paths(selected_image_paths, num_images=4)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "hf-env",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.11.11"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
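
Aside (not part of the commit): in the retrieval cell above, both the text and image embeddings are L2-normalized, so the matrix product in compare_text_to_images is exactly cosine similarity. A minimal sketch with random tensors (the shapes are illustrative, matching the 512-dim projection used in this repo):

import torch
import torch.nn.functional as F

text = F.normalize(torch.randn(1, 512), dim=1)      # one query embedding
images = F.normalize(torch.randn(25, 512), dim=1)   # 25 image embeddings

dot = torch.matmul(text, images.T).squeeze(0)   # [25], dot products of unit vectors
cos = F.cosine_similarity(text, images, dim=1)  # [25], explicit cosine similarity
print(torch.allclose(dot, cos, atol=1e-6))      # True: the two rankings are identical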
text_encoder.py ADDED
@@ -0,0 +1,27 @@
+ from transformers import AutoTokenizer, ModernBertModel
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import pdb
+
+ class modernBERT(nn.Module):
+     def __init__(self, model_name="answerdotai/ModernBERT-base"):
+         super(modernBERT, self).__init__()
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.bert = ModernBertModel.from_pretrained(model_name)
+
+     def forward(self, inputs):
+         # inputs is a dict of tokenizer outputs (input_ids, attention_mask, ...)
+         # e.g. inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+         outputs = self.bert(**inputs)
+
+         return outputs.last_hidden_state  # per-token hidden states, shape [batch, seq_len, 768]
+
+ # Example usage
+ if __name__ == "__main__":
+     model = modernBERT("answerdotai/ModernBERT-base")
+
+     texts = ["Potato's no name for a dog"]
+     text_inputs = model.tokenizer(texts, return_tensors="pt", padding=True)
+     output = model(text_inputs)
+
+     print(output[0].shape)  # [seq_len, 768] for the first (and only) sentence
train.py ADDED
@@ -0,0 +1,204 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader, random_split
+ from text_encoder import *
+ from vision_encoder import *
+ import os
+ import json
+ import numpy as np
+ import random
+ from tqdm import tqdm
+ import datetime
+
+ # Vision Caption Dataset
+ class VisionCaptionDataset(torch.utils.data.Dataset):
+     def __init__(self, captions_path, embeddings_dir, normalize=True):
+         with open(captions_path, 'r') as f:
+             self.captions_dict = json.load(f)
+
+         self.embeddings_dir = embeddings_dir
+         self.image_ids = list(self.captions_dict.keys())
+         self.normalize = normalize
+
+     def __len__(self):
+         return len(self.image_ids)
+
+     def __getitem__(self, idx):
+         image_id = self.image_ids[idx]
+
+         caption_entry = random.choice(self.captions_dict[image_id])
+         tokenized_caption = caption_entry["tokenized"]
+         attention_mask = caption_entry["attention_mask"]
+
+         embedding_path = os.path.join(self.embeddings_dir, f"{image_id}.npy")
+         embedding = np.load(embedding_path)
+
+         embedding = torch.tensor(embedding, dtype=torch.float32)
+         tokenized_caption = torch.tensor(tokenized_caption, dtype=torch.long)
+         attention_mask = torch.tensor(attention_mask, dtype=torch.long)
+
+         return embedding, tokenized_caption, attention_mask
+
+
+ class JointNetwork(nn.Module):
+     def __init__(self):
+         super(JointNetwork, self).__init__()
+
+         self.text_encoder = modernBERT("answerdotai/ModernBERT-base")
+
+         for param in self.text_encoder.parameters():
+             param.requires_grad = True
+
+         self.vision_projector = nn.Linear(1152, 512)
+         self.text_projector = nn.Linear(768, 512)
+
+     def forward(self, tokenized_text, image_encoding):
+         vision_patch_pooled = image_encoding.mean(dim=1)
+         text_output = self.text_encoder(tokenized_text)
+         text_pooled = text_output.mean(dim=1)
+
+         vision_embedded = self.vision_projector(vision_patch_pooled)
+         text_embedded = self.text_projector(text_pooled)
+
+         vision_embedded = F.normalize(vision_embedded, dim=1)
+         text_embedded = F.normalize(text_embedded, dim=1)
+
+         return text_embedded, vision_embedded
+
+
+ def infoNCE_loss(text_features, vision_features, temperature=0.07):
+     text_features = F.normalize(text_features, p=2, dim=-1)
+     vision_features = F.normalize(vision_features, p=2, dim=-1)
+
+     similarity_matrix = torch.matmul(text_features, vision_features.T) / temperature
+     batch_size = vision_features.size(0)
+     labels = torch.arange(batch_size, device=vision_features.device)
+
+     loss_text_to_image = F.cross_entropy(similarity_matrix, labels)
+     loss_image_to_text = F.cross_entropy(similarity_matrix.T, labels)
+
+     return (loss_text_to_image + loss_image_to_text) / 2
+
+
+ def train_model(model, train_loader, val_loader, optimizer, scheduler, num_epochs=5, freeze_text_encoder=True, checkpoint_path=None):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     best_val_loss = float('inf')  # Initialize with a very high value
+
+     # Freeze text encoder if specified
+     if freeze_text_encoder:
+         for param in model.text_encoder.parameters():
+             param.requires_grad = False
+
+     # Ensure new layers are trainable
+     for param in model.vision_projector.parameters():
+         param.requires_grad = True
+     for param in model.text_projector.parameters():
+         param.requires_grad = True
+
+     model.to(device)
+
+     for epoch in range(num_epochs):
+
+         # Train loop
+         model.train()
+         total_loss = 0.0
+
+         print(f"\nEpoch {epoch + 1}/{num_epochs} - Training:")
+         train_progress = tqdm(train_loader, desc="Training", leave=True)
+
+         for image_embeddings, tokenized_captions, attention_masks in train_progress:
+             text_inputs = {"input_ids": tokenized_captions.to(device), "attention_mask": attention_masks.to(device)}
+             image_embeddings = image_embeddings.to(device)
+
+             optimizer.zero_grad()
+             text_features, vision_features = model(text_inputs, image_embeddings)
+             loss = infoNCE_loss(text_features, vision_features)
+             loss.backward()
+             optimizer.step()
+             total_loss += loss.item()
+             train_progress.set_postfix(loss=loss.item())
+
+         scheduler.step()
+
+         # Validation loop
+         model.eval()
+         val_loss = 0.0
+
+         print(f"\nEpoch {epoch + 1}/{num_epochs} - Validation:")
+         val_progress = tqdm(val_loader, desc="Validation", leave=True)
+
+         with torch.no_grad():
+             for image_embeddings, tokenized_captions, attention_masks in val_progress:
+                 text_inputs = {"input_ids": tokenized_captions.to(device), "attention_mask": attention_masks.to(device)}
+                 image_embeddings = image_embeddings.to(device)
+
+                 text_features, vision_features = model(text_inputs, image_embeddings)
+                 loss = infoNCE_loss(text_features, vision_features)
+                 val_loss += loss.item()
+                 val_progress.set_postfix(loss=loss.item())
+
+         avg_train_loss = total_loss / len(train_loader)
+         avg_val_loss = val_loss / len(val_loader)
+         print(f"\nEpoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
+
+         # Save best model
+         if checkpoint_path is not None:
+             if avg_val_loss < best_val_loss:
+                 best_val_loss = avg_val_loss
+                 torch.save({
+                     'epoch': epoch + 1,
+                     'model_state_dict': model.state_dict(),
+                     'optimizer_state_dict': optimizer.state_dict(),
+                     'val_loss': best_val_loss
+                 }, checkpoint_path)
+                 print(f"New Best Model Saved at: {checkpoint_path} (Val Loss: {best_val_loss:.4f})")
+
+     print("Training completed!")
+
+
+ if __name__ == "__main__":
+     # Set random seed for reproducibility
+     # torch.manual_seed(42)
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     # Paths for dataset
+     captions_path = '/mnt/nvme/shared_A/datasets/flickr30k/data/captions_tokenized.json'
+     # embeddings_dir = '/mnt/nvme/shared_A/datasets/flickr30k/data/reduced_vision_embeddings'
+     embeddings_dir = '/mnt/nvme/shared_A/datasets/flickr30k/data/vision_embeddings_reduced2'
+
+     # Initialize datasets and loaders
+     full_dataset = VisionCaptionDataset(captions_path, embeddings_dir)
+     train_size = int(0.85 * len(full_dataset))
+     val_size = len(full_dataset) - train_size
+     train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
+
+     train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=8, pin_memory=True)
+     val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=8, pin_memory=True)
+
+     # Initialize model, optimizer, and scheduler
+     model = JointNetwork().to(device)
+
+     checkpoint_path = f"./checkpoints/model_checkpoint_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.pth"
+
+     # **Phase 1 Configuration: Training new layers only**
+     initial_lr = 1e-4
+     min_lr = 1e-6
+     num_epochs = 16
+     optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=initial_lr)
+     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=min_lr)
+
+     # **Phase 1: Train new layers only, freeze text encoder**
+     print("\n### Phase 1: Training new layers only (Text Encoder Frozen) ###")
+     train_model(model, train_loader, val_loader, optimizer, scheduler, num_epochs=num_epochs, freeze_text_encoder=True, checkpoint_path=checkpoint_path)
+
+     # # **Phase 2 Configuration: Fine-tuning with adjusted learning rate**
+     # initial_lr = 1e-4
+     # min_lr = 1e-6
+     # num_epochs = 3
+     # optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=initial_lr)
+     # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=min_lr)
+
+     # print("\n### Phase 2: Fine-tuning text encoder and new layers ###")
+     # train_model(model, train_loader, val_loader, optimizer, scheduler, num_epochs=num_epochs, freeze_text_encoder=False, checkpoint_path=checkpoint_path)
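
Aside (not part of the commit): a small sanity check for the symmetric InfoNCE loss above. With random, unaligned embeddings the loss should land roughly near ln(batch_size), which is a useful baseline to compare against early training curves. This sketch assumes train.py is importable from the repo root.

import math
import torch
from train import infoNCE_loss  # assumption: run from the repo root

batch_size = 8
text = torch.randn(batch_size, 512)
vision = torch.randn(batch_size, 512)

loss = infoNCE_loss(text, vision)
# Random pairs should score roughly near ln(8) ≈ 2.08; trained pairs should be well below it
print(loss.item(), math.log(batch_size))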
vision_encoder.py ADDED
@@ -0,0 +1,56 @@
+ import torch
+ import torch.nn as nn
+ from transformers import AutoProcessor, AutoModelForVision2Seq
+ from transformers.image_utils import load_image
+
+
+ class ideficsV3(nn.Module):
+     def __init__(self, model_name="HuggingFaceTB/SmolVLM-Instruct"):
+         super().__init__()
+
+         # Load the SmolVLM model from Hugging Face
+         self.image_processor = AutoProcessor.from_pretrained(model_name).image_processor
+         smolVLM = AutoModelForVision2Seq.from_pretrained(model_name, torch_dtype=torch.float32)
+
+         # Keep only the vision tower
+         self.vision_model = smolVLM.model.vision_model
+
+     def forward(self, pixel_values):
+         # The image processor returns pixel_values of shape [batch_size, num_patches, 3, 384, 384],
+         # but the vision transformer expects a 4D tensor [batch_size, channels, height, width]
+         # (passing the 5D tensor raises "ValueError: too many values to unpack (expected 4)").
+         # Flatten the patch dimension into the batch dimension before the forward pass.
+         batch_size, num_patches, channels, height, width = pixel_values.shape
+         pixel_values = pixel_values.view(batch_size * num_patches, channels, height, width)
+
+         # Run the images through the vision transformer
+         vision_outputs = self.vision_model(pixel_values)
+         x = vision_outputs.last_hidden_state  # shape := [batch_size * num_patches, 729, 1152]
+
+         return x
+
+
+ if __name__ == "__main__":
+
+     # Instantiate the truncated model
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     truncated_model = ideficsV3().to(device).eval()
+
+     image1 = load_image("https://huggingface.co/spaces/merve/chameleon-7b/resolve/main/bee.jpg")
+     image2 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
+
+     inputs1 = truncated_model.image_processor(images=[image1, image2], return_tensors="pt")
+     pixel_values = inputs1.pixel_values.to(torch.float32).to(device)
+
+     # Pass pixel_values through the truncated model
+     with torch.no_grad():
+         outputs = truncated_model(pixel_values)
+
+     print(outputs.shape)  # [batch_size * num_patches, 729, 1152]