File size: 7,357 Bytes
a55cbd2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 |
import math
from typing import ClassVar, List, Optional, Tuple, Union
import torch
from PIL import Image
from transformers import BatchFeature
from transformers.models.qwen2_vl import Qwen2VLProcessor
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
def round_by_factor(number: float, factor: int) -> int:
    """Snap ``number`` to the nearest multiple of ``factor``.

    The quotient is rounded with Python's built-in ``round``, so an exact
    half-way value goes to the even multiple (banker's rounding).
    """
    quotient = number / factor
    return factor * round(quotient)
def ceil_by_factor(number: float, factor: int) -> int:
    """Return the smallest multiple of ``factor`` that is >= ``number``."""
    steps = math.ceil(number / factor)
    return factor * steps
def floor_by_factor(number: float, factor: int) -> int:
    """Return the largest multiple of ``factor`` that is <= ``number``."""
    whole_steps = math.floor(number / factor)
    return factor * whole_steps
class ColQwenStellaProcessor(BaseVisualRetrieverProcessor, Qwen2VLProcessor):
    """
    Processor for ColQwenStella, a ColQwen2-style visual document retriever.

    Builds image batches (resized documents behind a fixed visual prompt) and query
    batches (instruction prefix + query + augmentation suffix) through the inherited
    processor ``__call__``, and scores multi-vector embeddings with MaxSim.
    """

    visual_prompt_prefix: ClassVar[str] = (
        "<|im_start|><|image_pad|><|im_end|><|endoftext|>"
    )
    query_prefix: ClassVar[str] = (
        "Instruct: Given a web search query, retrieve relevant passages that answer the query.\nQuery: "
    )
    query_augmentation_token: ClassVar[str] = "<|endoftext|>"
    image_token: ClassVar[str] = "<|image_pad|>"

    # `process_images` rewrites every occurrence of `image_pad_source_id` in the image
    # input ids to `image_pad_target_id`.
    # NOTE(review): the source id is presumably the tokenizer's id for `image_token` and
    # the target the image id the Stella backbone expects — confirm against the model config.
    image_pad_source_id: ClassVar[int] = 151655
    image_pad_target_id: ClassVar[int] = 151646

    @property
    def image_token_id(self) -> int:
        """Tokenizer id of ``image_token`` (before any remapping done by ``process_images``)."""
        return self.tokenizer.convert_tokens_to_ids(self.image_token)

    def __init__(self, *args, **kwargs):
        """
        Args:
            num_image_tokens: optional keyword (default 768); the resized image area is
                capped at ``num_image_tokens * 28 * 28`` pixels.
            *args, **kwargs: forwarded unchanged to the parent processors.
        """
        num_image_tokens = kwargs.pop("num_image_tokens", 768)
        super().__init__(*args, **kwargs)
        # Pad batches on the left side.
        self.tokenizer.padding_side = "left"
        # Both resized image sides are multiples of this factor.
        self.factor = 28
        self.min_pixels = 4 * self.factor * self.factor
        self.max_pixels = num_image_tokens * self.factor * self.factor
        # Largest accepted long-side / short-side ratio.
        self.max_ratio = 200

    @staticmethod
    def smart_resize_helper(
        width: int,
        height: int,
        factor: int,
        max_ratio: int,
        min_pixels: int,
        max_pixels: int,
    ) -> Tuple[int, int]:
        """
        Compute a target ``(height, width)`` such that:

        1. Both dimensions are positive multiples of ``factor``.
        2. The area is brought into ``[min_pixels, max_pixels]`` while preserving the
           aspect ratio as closely as possible. When shrinking a very thin image, each
           side is clamped to at least ``factor``, which can leave the area slightly
           above ``max_pixels`` rather than producing a 0-pixel side.

        Raises:
            ValueError: if a dimension is not positive, or if the long/short side ratio
                exceeds ``max_ratio``.
        """
        if min(height, width) <= 0:
            raise ValueError(
                f"image dimensions must be positive, got width={width}, height={height}"
            )
        aspect_ratio = max(height, width) / min(height, width)
        if aspect_ratio > max_ratio:
            raise ValueError(
                f"absolute aspect ratio must be smaller than {max_ratio}, "
                f"got {aspect_ratio}"
            )
        h_bar = max(factor, round_by_factor(height, factor))
        w_bar = max(factor, round_by_factor(width, factor))
        if h_bar * w_bar > max_pixels:
            beta = math.sqrt((height * width) / max_pixels)
            # Clamp so a thin image never floors one side down to 0 pixels.
            h_bar = max(factor, floor_by_factor(height / beta, factor))
            w_bar = max(factor, floor_by_factor(width / beta, factor))
        elif h_bar * w_bar < min_pixels:
            beta = math.sqrt(min_pixels / (height * width))
            h_bar = ceil_by_factor(height * beta, factor)
            w_bar = ceil_by_factor(width * beta, factor)
        return h_bar, w_bar

    def smart_resize(self, image: Image.Image) -> Image.Image:
        """
        Convert ``image`` to RGB and resize it to the size chosen by
        ``smart_resize_helper`` with this processor's factor and pixel bounds.
        """
        # PIL reports and accepts sizes as (width, height).
        width, height = image.size
        resized_height, resized_width = self.smart_resize_helper(
            width=width,
            height=height,
            factor=self.factor,
            max_ratio=self.max_ratio,
            min_pixels=self.min_pixels,
            max_pixels=self.max_pixels,
        )
        return image.convert("RGB").resize((resized_width, resized_height))

    def process_images(
        self,
        images: List[Image.Image],
    ) -> BatchFeature:
        """
        Build a padded batch for a list of document images.

        Each image is resized with ``smart_resize`` and paired with
        ``visual_prompt_prefix``. Ids equal to ``image_pad_source_id`` are rewritten to
        ``image_pad_target_id``. ``pixel_values`` is re-packed into a zero-padded tensor
        stacked per image.
        """
        texts_doc = [self.visual_prompt_prefix] * len(images)
        resized_images: List[Image.Image] = [self.smart_resize(image) for image in images]
        batch_doc = self(
            text=texts_doc,
            images=resized_images,
            padding="longest",
            return_tensors="pt",
        )

        # In-place remap of the image-pad id across the whole batch.
        input_ids = batch_doc["input_ids"]
        input_ids[input_ids == self.image_pad_source_id] = self.image_pad_target_id

        # NOTE: The following code is a hack to make sure the scatter in DDP is done
        # correctly when training on multiple GPUs.
        # Rows of the flat pixel_values owned by each image: grid height * grid width.
        # NOTE(review): ignores image_grid_thw[:, 0] (temporal dim) — assumes it is 1 for
        # still images; confirm.
        patches_per_image = batch_doc["image_grid_thw"][:, 1] * batch_doc["image_grid_thw"][:, 2]
        pixel_values = torch.split(batch_doc["pixel_values"], patches_per_image.tolist())

        # Zero-pad every image to the same number of rows so they can be stacked.
        max_length = max(len(pv) for pv in pixel_values)
        pixel_values = [
            torch.cat([pv, torch.zeros((max_length - len(pv), pv.shape[1]), dtype=pv.dtype, device=pv.device)])
            for pv in pixel_values
        ]
        batch_doc["pixel_values"] = torch.stack(pixel_values)
        return batch_doc

    def process_queries(
        self,
        queries: List[str],
        max_length: int = 50,
        suffix: Optional[str] = None,
    ) -> BatchFeature:
        """
        Tokenize queries as ``query_prefix + query + suffix``, padded to the longest one.

        Args:
            queries: raw query strings.
            max_length: not used by this implementation (no truncation is applied);
                presumably kept for signature compatibility with other processors.
            suffix: appended to every query; defaults to ``query_augmentation_token``
                repeated 10 times.
        """
        if suffix is None:
            suffix = self.query_augmentation_token * 10
        texts_query: List[str] = [self.query_prefix + query + suffix for query in queries]
        return self(
            text=texts_query,
            return_tensors="pt",
            padding="longest",
        )

    def score(
        self,
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute the MaxSim score (ColBERT-like) for the given multi-vector query and
        passage embeddings by delegating to ``score_multi_vector``.
        """
        return self.score_multi_vector(qs, ps, device=device, **kwargs)

    def get_n_patches(
        self,
        image_size: Tuple[int, int],
        patch_size: int,
        spatial_merge_size: int,
    ) -> Tuple[int, int]:
        """
        Get the number of patches ``(n_patches_x, n_patches_y)`` used for an image.

        Args:
            image_size: ``(width, height)`` of the original image, i.e. the order of
                PIL's ``Image.size``.
            patch_size: side length of one vision patch, in pixels.
            spatial_merge_size: number of patches merged along each side. It is stored
                as a ``Qwen2VLForConditionalGeneration`` attribute under
                ``model.spatial_merge_size``.
        """
        height_new, width_new = self.smart_resize_helper(
            width=image_size[0],
            height=image_size[1],
            factor=self.factor,
            max_ratio=self.max_ratio,
            min_pixels=self.min_pixels,
            max_pixels=self.max_pixels,
        )
        n_patches_x = width_new // patch_size // spatial_merge_size
        n_patches_y = height_new // patch_size // spatial_merge_size
        return n_patches_x, n_patches_y

    def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor:
        """
        Return a boolean mask marking image-token positions in ``batch_images.input_ids``.

        ``process_images`` rewrites image-pad ids to ``image_pad_target_id``, so positions
        holding either that id or the tokenizer's ``image_token_id`` are matched. Without
        this, batches from ``process_images`` would yield an all-False mask.
        """
        input_ids = batch_images.input_ids
        return (input_ids == self.image_token_id) | (input_ids == self.image_pad_target_id)
|