update
image_tower_magma.py +379 -0
image_tower_magma.py
ADDED
@@ -0,0 +1,379 @@
# coding=utf-8
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Image tower (vision backbone) class for Magma."""

import json
import logging as std_logging
import os
from dataclasses import asdict
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type, Union

# Configure the root logger (stdlib); aliased so it is not shadowed by
# `transformers.utils.logging` imported below.
std_logging.basicConfig(level=std_logging.INFO)

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import torchvision

from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_transforms import (
    convert_to_rgb,
)
from transformers.image_utils import (
    OPENAI_CLIP_MEAN,
    OPENAI_CLIP_STD,
    ImageInput,
    make_list_of_images,
    valid_images,
)
from transformers.utils import TensorType, is_vision_available, logging

logging.set_verbosity_info()
logger = logging.get_logger(__name__)


if is_vision_available():
    from PIL import Image

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import open_clip
from open_clip import get_model_config, list_models, load_openai_model  # used by create_model below
from open_clip.coca_model import CoCa
from open_clip.transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs
from open_clip.pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, \
    list_pretrained_tags_by_model, download_pretrained_from_hf
from open_clip.model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict, \
    resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg

HF_HUB_PREFIX = 'hf-hub:'


def _get_hf_config(model_id, cache_dir=None):
    config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    return config

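# `_get_hf_config` returns the parsed `open_clip_config.json` from the Hub repo.
# Judging from how `create_model` below consumes it, the file contains at least
# the following keys (illustrative sketch only, values abridged):
#
#   {
#       "model_cfg":      {...},  # architecture config fed to CLIP/CustomTextCLIP
#       "preprocess_cfg": {...}   # merged into the default PreprocessCfg
#   }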


def create_model(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_path_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        force_preprocess_cfg: Optional[Dict[str, Any]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
        **model_kwargs,
):
    force_preprocess_cfg = force_preprocess_cfg or {}
    preprocess_cfg = asdict(PreprocessCfg())
    has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
    if has_hf_hub_prefix:
        model_id = model_name[len(HF_HUB_PREFIX):]
        checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
        config = _get_hf_config(model_id, cache_dir)
        preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
        model_cfg = config['model_cfg']
        pretrained_hf = False  # override, no need to load original HF text weights
    else:
        model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names
        checkpoint_path = None
        model_cfg = None

    if device == "auto":
        # keep "auto" as a dict sentinel; checked again before .to(...) below
        device = {'': device}
    else:
        device = torch.device(device)

    if pretrained and pretrained.lower() == 'openai':
        logger.info(f'Loading pretrained {model_name} from OpenAI.')
        model = load_openai_model(
            model_name,
            precision=precision,
            device=device,
            cache_dir=cache_dir,
        )
    else:
        model_cfg = model_cfg or get_model_config(model_name)
        if model_cfg is not None:
            logger.info(f'Loaded {model_name} model config.')
        else:
            logger.error(f'Model config for {model_name} not found; available models {list_models()}.')
            raise RuntimeError(f'Model config for {model_name} not found.')

        if force_quick_gelu:
            # override for use of QuickGELU on non-OpenAI transformer models
            model_cfg["quick_gelu"] = True

        if force_patch_dropout is not None:
            # override the default patch dropout value
            model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout

        if force_path_dropout is not None:
            # override the default drop path (stochastic depth) rate for timm backbones
            model_cfg["vision_cfg"]["timm_drop_path"] = force_path_dropout

        if force_image_size is not None:
            # override model config's image size
            model_cfg["vision_cfg"]["image_size"] = force_image_size

        is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
        if pretrained_image:
            if is_timm_model:
                # pretrained weight loading for timm models set via vision_cfg
                model_cfg['vision_cfg']['timm_model_pretrained'] = True
            else:
                assert False, 'pretrained image towers currently only supported for timm models'

        # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
        cast_dtype = get_cast_dtype(precision)
        is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
        if is_hf_model:
            # load pretrained weights for HF text model IFF no CLIP weights being loaded
            model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained
        custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model

        # model_cfg = dict(model_cfg, **model_kwargs)  # merge cfg dict w/ kwargs (kwargs overrides cfg)
        if custom_text:
            if "multimodal_cfg" in model_cfg:
                model = CoCa(**model_cfg, cast_dtype=cast_dtype)
            else:
                model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
        else:
            model = CLIP(**model_cfg, cast_dtype=cast_dtype)

        if precision in ("fp16", "bf16"):
            dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
            # manual mixed precision that matches original OpenAI behaviour
            if is_timm_model:
                # FIXME this is a bit janky, create timm based model in low-precision and
                # then cast only LayerNormFp32 instances back to float32 so they don't break.
                # Why? The convert_weights_to_lp fn only works with native models.
                if device != {'': 'auto'}:
                    model.to(device=device, dtype=dtype)
                else:
                    model.to(dtype=dtype)
                from open_clip.transformer import LayerNormFp32

                def _convert_ln(m):
                    if isinstance(m, LayerNormFp32):
                        m.weight.data = m.weight.data.to(torch.float32)
                        m.bias.data = m.bias.data.to(torch.float32)

                model.apply(_convert_ln)
            else:
                model.to(device=device)
                convert_weights_to_lp(model, dtype=dtype)
        elif precision in ("pure_fp16", "pure_bf16"):
            dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
            model.to(device=device, dtype=dtype)
        # else:
        #     model.to(device=device)

        pretrained_loaded = False
        if pretrained:
            checkpoint_path = ''
            pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
            if pretrained_cfg:
                checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
                preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg)
            elif os.path.exists(pretrained):
                checkpoint_path = pretrained

            # if checkpoint_path:
            #     logger.info(f'Loading pretrained {model_name} weights ({pretrained}).')
            #     open_clip.load_checkpoint(model, checkpoint_path)
            # else:
            #     error_str = (
            #         f'Pretrained weights ({pretrained}) not found for model {model_name}.'
            #         f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
            #     logger.warning(error_str)
            #     raise RuntimeError(error_str)
            # pretrained_loaded = True
        elif has_hf_hub_prefix and require_pretrained:
            logger.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
            print(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
            open_clip.load_checkpoint(model, checkpoint_path)
            pretrained_loaded = True

        if require_pretrained and not pretrained_loaded:
            # callers of create_model_from_pretrained always expect pretrained weights
            raise RuntimeError(
                f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')

    if output_dict and hasattr(model, "output_dict"):
        model.output_dict = True

    if jit:
        model = torch.jit.script(model)

    # set image preprocessing configuration in model attributes for convenience
    if getattr(model.visual, 'image_size', None) is not None:
        # use image_size set on model creation (via config or force_image_size arg)
        force_preprocess_cfg['size'] = model.visual.image_size
    set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg))

    return model

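# Minimal usage sketch for `create_model` (assumes network access to the HF Hub; the
# repo id below is the one used by `D2CLIP_HF` further down):
#
#   model = create_model(
#       'hf-hub:laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg',
#       require_pretrained=True,  # raises unless the checkpoint is actually loaded
#   )
#   vision_backbone = model.visual  # ConvNeXt trunk used by the tower below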


def create_model_and_transforms(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_path_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        image_interpolation: Optional[str] = None,
        image_resize_mode: Optional[str] = None,  # only effective for inference
        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        **model_kwargs,
):
    # NOTE: unlike the open_clip function of the same name, this variant returns only
    # the model, not (model, preprocess_train, preprocess_val).
    force_preprocess_cfg = merge_preprocess_kwargs(
        {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode)

    return create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        force_path_dropout=force_path_dropout,
        force_image_size=force_image_size,
        force_preprocess_cfg=force_preprocess_cfg,
        pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf,
        cache_dir=cache_dir,
        output_dict=output_dict,
        **model_kwargs,
    )
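

# Illustrative sketch only: `_demo_build_backbone` is a hypothetical helper (not used
# elsewhere in this file) that builds the ConvNeXt tower the same way `D2CLIP_HF` below
# does. With require_pretrained=False the architecture is built but no checkpoint
# weights are loaded (see the pretrained-loading branch in `create_model`).
def _demo_build_backbone():
    model = create_model_and_transforms(
        'hf-hub:laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg',
        require_pretrained=False,
    )
    return model.visual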


class D2CLIP_HF(nn.Module):
    def __init__(self, config, **kwargs):
        super().__init__()
        self.model_name = config['vision_backbone']

        require_pretrained = kwargs.get('require_pretrained', False)
        if self.model_name == "convnextxxlarge":
            clip_model = create_model_and_transforms('hf-hub:laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg', require_pretrained=require_pretrained)
        elif self.model_name == "convnextlarge":
            clip_model = create_model_and_transforms('hf-hub:laion/CLIP-convnext_large-laion2B-s34B-b82K-augreg', require_pretrained=require_pretrained)
        else:
            raise ValueError(f"Unsupported vision backbone: {self.model_name}")

        self.clip_vision_model = clip_model.visual

        model_name = self.model_name.lower()
        assert 'convnext' in model_name, f"Only convnext backbone is supported for Magma model, but got {model_name}"
        self.model_type = 'convnext'
        if 'xxlarge' in model_name:
            self.output_channels = [384, 384, 768, 1536, 3072]
        elif 'large' in model_name:
            self.output_channels = [192, 192, 384, 768, 1536]
        elif 'base' in model_name:
            self.output_channels = [128, 128, 256, 512, 1024]

        self._out_feature_strides = {
            "res2": 4,
            "res3": 8,
            "res4": 16,
            "res5": 32,
        }
        self._out_feature_channels = {
            "res2": self.output_channels[1],
            "res3": self.output_channels[2],
            "res4": self.output_channels[3],
            "res5": self.output_channels[4],
        }

    def extract_features_convnext(self, x, gradient_checkpointing=True):
        out = {}
        x = self.clip_vision_model.trunk.stem(x)
        if gradient_checkpointing:
            x = checkpoint.checkpoint(self.clip_vision_model.trunk.stages, x)
        else:
            x = self.clip_vision_model.trunk.stages(x)
        out['clip_vis_dense'] = x
        return out

    def forward(self, x, gradient_checkpointing=True):
        """
        Args:
            x: Tensor of shape (N, C, H, W). H, W must be a multiple of ``self.size_divisibility``.
        Returns:
            dict[str->Tensor]: names and the corresponding features
        """
        return self.extract_features_convnext(x, gradient_checkpointing=gradient_checkpointing)

    @property
    def size_divisibility(self):
        return 32
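

# Shape sketch (derived from `output_channels` and `size_divisibility` above, not
# measured): the ConvNeXt stem and four stages downsample by a total stride of 32,
# so for a (N, 3, 256, 256) input, out['clip_vis_dense'] from the xxlarge trunk is
# expected to be (N, 3072, 8, 8).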


class MagmaImageTower(D2CLIP_HF):
    r"""
    Constructs the Magma image tower: a CLIP ConvNeXt vision backbone that extracts dense
    features from images, including the high resolution crops produced following the
    [InternLM-XComposer2-4KHD](https://arxiv.org/pdf/2404.06512) scheme.

    Args:
        config (dict): Configuration dictionary for the image tower; must contain the
            `vision_backbone` key (e.g. `"convnextxxlarge"`).
    """

    def __init__(
        self,
        config,
        **kwargs
    ) -> None:
        super().__init__(config, **kwargs)

    @property
    def hidden_size(self):
        return self.output_channels[-1]

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        r"""
        Args:
            x (torch.Tensor): A tensor of shape (N, C, H, W) representing a batch of images.

        Returns:
            dict[str, torch.Tensor]: A single entry ``'clip_vis_dense'`` holding the dense
            feature map of shape (N, hidden_size, H/32, W/32).
        """
        return super().forward(x)
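

if __name__ == "__main__":
    # Illustrative smoke test sketch: instantiate the tower and check the dense feature
    # shape. Fetches the open_clip config/checkpoint files for the hf-hub repo, so it
    # needs network access on first run.
    tower = MagmaImageTower({'vision_backbone': 'convnextxxlarge'})
    with torch.no_grad():
        feats = tower.extract_features_convnext(
            torch.randn(1, 3, 256, 256), gradient_checkpointing=False)
    print(feats['clip_vis_dense'].shape)  # expected: (1, tower.hidden_size, 8, 8)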