jw2yang committed on
Commit 3fe01ed · 1 Parent(s): 18e9ab4
Files changed (1)
  1. image_tower_magma.py +379 -0
image_tower_magma.py ADDED
@@ -0,0 +1,379 @@
# coding=utf-8
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Image tower (vision backbone) for Magma."""

from typing import List, Optional, Union
import logging

# Configure root logger
logging.basicConfig(level=logging.INFO)

import numpy as np
import torchvision
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_transforms import (
    convert_to_rgb,
)
from transformers.image_utils import (
    OPENAI_CLIP_MEAN,
    OPENAI_CLIP_STD,
    ImageInput,
    make_list_of_images,
    valid_images,
)

from transformers.utils import TensorType, is_vision_available, logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)


if is_vision_available():
    from PIL import Image

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

import open_clip
# factory helpers referenced by create_model below
from open_clip import CoCa, get_model_config, list_models, load_openai_model
from open_clip.transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs
from open_clip.pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
    list_pretrained_tags_by_model, download_pretrained_from_hf
from open_clip.model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
    resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg
from pathlib import Path
from typing import Optional, Tuple, Type
from functools import partial
import torch.utils.checkpoint as checkpoint
from typing import Any, Dict, Optional, Tuple, Union
from dataclasses import asdict
HF_HUB_PREFIX = 'hf-hub:'

def _get_hf_config(model_id, cache_dir=None):
    config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    return config

def create_model(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_path_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        force_preprocess_cfg: Optional[Dict[str, Any]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
        **model_kwargs,
):
    force_preprocess_cfg = force_preprocess_cfg or {}
    preprocess_cfg = asdict(PreprocessCfg())
    has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
    if has_hf_hub_prefix:
        model_id = model_name[len(HF_HUB_PREFIX):]
        checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
        config = _get_hf_config(model_id, cache_dir)
        preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
        model_cfg = config['model_cfg']
        pretrained_hf = False  # override, no need to load original HF text weights
    else:
        model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names
        checkpoint_path = None
        model_cfg = None

    if device == "auto":
        device = {'': device}
    else:
        device = torch.device(device)

    if pretrained and pretrained.lower() == 'openai':
        logger.info(f'Loading pretrained {model_name} from OpenAI.')
        model = load_openai_model(
            model_name,
            precision=precision,
            device=device,
            cache_dir=cache_dir,
        )
    else:
        model_cfg = model_cfg or get_model_config(model_name)
        if model_cfg is not None:
            logger.info(f'Loaded {model_name} model config.')
        else:
            logger.error(f'Model config for {model_name} not found; available models {list_models()}.')
            raise RuntimeError(f'Model config for {model_name} not found.')

        if force_quick_gelu:
            # override for use of QuickGELU on non-OpenAI transformer models
            model_cfg["quick_gelu"] = True

        if force_patch_dropout is not None:
            # override the default patch dropout value
            model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout

        if force_path_dropout is not None:
            # override the default drop path (stochastic depth) rate for timm backbones
            model_cfg["vision_cfg"]["timm_drop_path"] = force_path_dropout

        if force_image_size is not None:
            # override model config's image size
            model_cfg["vision_cfg"]["image_size"] = force_image_size

        is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
        if pretrained_image:
            if is_timm_model:
                # pretrained weight loading for timm models set via vision_cfg
                model_cfg['vision_cfg']['timm_model_pretrained'] = True
            else:
                assert False, 'pretrained image towers currently only supported for timm models'

        # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
        cast_dtype = get_cast_dtype(precision)
        is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
        if is_hf_model:
            # load pretrained weights for HF text model IFF no CLIP weights being loaded
            model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained
        custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model

        # model_cfg = dict(model_cfg, **model_kwargs)  # merge cfg dict w/ kwargs (kwargs overrides cfg)
        if custom_text:
            if "multimodal_cfg" in model_cfg:
                model = CoCa(**model_cfg, cast_dtype=cast_dtype)
            else:
                model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
        else:
            model = CLIP(**model_cfg, cast_dtype=cast_dtype)

        if precision in ("fp16", "bf16"):
            dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
            # manual mixed precision that matches original OpenAI behaviour
            if is_timm_model:
                # FIXME this is a bit janky, create timm based model in low-precision and
                # then cast only LayerNormFp32 instances back to float32 so they don't break.
                # Why? The convert_weights_to_lp fn only works with native models.
                if device != {'':'auto'}:
                    model.to(device=device, dtype=dtype)
                else:
                    model.to(dtype=dtype)
                from open_clip.transformer import LayerNormFp32

                def _convert_ln(m):
                    if isinstance(m, LayerNormFp32):
                        m.weight.data = m.weight.data.to(torch.float32)
                        m.bias.data = m.bias.data.to(torch.float32)
                model.apply(_convert_ln)
            else:
                model.to(device=device)
                convert_weights_to_lp(model, dtype=dtype)
        elif precision in ("pure_fp16", "pure_bf16"):
            dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
            model.to(device=device, dtype=dtype)
        # else:
        #     model.to(device=device)

        pretrained_loaded = False
        if pretrained:
            checkpoint_path = ''
            pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
            if pretrained_cfg:
                checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
                preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg)
            elif os.path.exists(pretrained):
                checkpoint_path = pretrained

            # if checkpoint_path:
            #     logger.info(f'Loading pretrained {model_name} weights ({pretrained}).')
            #     open_clip.load_checkpoint(model, checkpoint_path)
            # else:
            #     error_str = (
            #         f'Pretrained weights ({pretrained}) not found for model {model_name}.'
            #         f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
            #     logger.warning(error_str)
            #     raise RuntimeError(error_str)
            # pretrained_loaded = True
        elif has_hf_hub_prefix and require_pretrained:
            logger.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
            print(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
            open_clip.load_checkpoint(model, checkpoint_path)
            pretrained_loaded = True

        if require_pretrained and not pretrained_loaded:
            # callers of create_model_from_pretrained always expect pretrained weights
            raise RuntimeError(
                f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')

    if output_dict and hasattr(model, "output_dict"):
        model.output_dict = True

    if jit:
        model = torch.jit.script(model)

    # set image preprocessing configuration in model attributes for convenience
    if getattr(model.visual, 'image_size', None) is not None:
        # use image_size set on model creation (via config or force_image_size arg)
        force_preprocess_cfg['size'] = model.visual.image_size
    set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg))

    return model

def create_model_and_transforms(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_path_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        image_interpolation: Optional[str] = None,
        image_resize_mode: Optional[str] = None,  # only effective for inference
        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        **model_kwargs,
):
    force_preprocess_cfg = merge_preprocess_kwargs(
        {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode)

    # Note: despite the name, this returns only the model (no preprocessing transforms).
    return create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        force_path_dropout=force_path_dropout,
        force_image_size=force_image_size,
        force_preprocess_cfg=force_preprocess_cfg,
        pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf,
        cache_dir=cache_dir,
        output_dict=output_dict,
        **model_kwargs,
    )

class D2CLIP_HF(nn.Module):
    def __init__(self, config, **kwargs):
        super().__init__()
        self.model_name = config['vision_backbone']

        require_pretrained = kwargs.get('require_pretrained', False)
        if self.model_name == "convnextxxlarge":
            clip_model = create_model_and_transforms('hf-hub:laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg', require_pretrained=require_pretrained)
        elif self.model_name == "convnextlarge":
            clip_model = create_model_and_transforms('hf-hub:laion/CLIP-convnext_large-laion2B-s34B-b82K-augreg', require_pretrained=require_pretrained)
        else:
            raise ValueError(f"Unsupported vision backbone for Magma: {self.model_name}")

        self.clip_vision_model = clip_model.visual

        model_name = self.model_name.lower()
        assert 'convnext' in model_name, f"Only convnext backbones are supported for the Magma model, but got {model_name}"
        self.model_type = 'convnext'
        if 'xxlarge' in model_name:
            self.output_channels = [384, 384, 768, 1536, 3072]
        elif 'large' in model_name:
            self.output_channels = [192, 192, 384, 768, 1536]
        elif 'base' in model_name:
            self.output_channels = [128, 128, 256, 512, 1024]

        self._out_feature_strides = {
            "res2": 4,
            "res3": 8,
            "res4": 16,
            "res5": 32,
        }
        self._out_feature_channels = {
            "res2": self.output_channels[1],
            "res3": self.output_channels[2],
            "res4": self.output_channels[3],
            "res5": self.output_channels[4],
        }

    def extract_features_convnext(self, x, gradient_checkpointing=True):
        out = {}
        x = self.clip_vision_model.trunk.stem(x)
        if gradient_checkpointing:
            x = checkpoint.checkpoint(self.clip_vision_model.trunk.stages, x)
        else:
            x = self.clip_vision_model.trunk.stages(x)
        out['clip_vis_dense'] = x
        return out

    def forward(self, x, gradient_checkpointing=True):
        """
        Args:
            x: Tensor of shape (N, C, H, W). H and W must be multiples of ``self.size_divisibility``.
        Returns:
            dict[str->Tensor]: names and the corresponding features
        """
        return self.extract_features_convnext(x, gradient_checkpointing=gradient_checkpointing)

    @property
    def size_divisibility(self):
        return 32

class MagmaImageTower(D2CLIP_HF):
    r"""
    Constructs the Magma image tower: the ConvNeXt-based CLIP vision backbone (see [`D2CLIP_HF`]) that
    encodes images, including the high-resolution crops produced by the Magma image processor.

    Args:
        config (dict): Configuration dictionary containing the keys for the image tower
            (in particular `vision_backbone`).
    """

    def __init__(
        self,
        config,
        **kwargs
    ) -> None:
        super().__init__(config, **kwargs)

    @property
    def hidden_size(self):
        return self.output_channels[-1]

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        r"""
        Args:
            x (torch.Tensor): A tensor of shape (N, C, H, W) representing a batch of images.

        Returns:
            Dict[str, torch.Tensor]: A dictionary whose ``'clip_vis_dense'`` entry holds the final-stage
            feature map of shape (N, hidden_size, H/32, W/32).
        """
        return super().forward(x)
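
A minimal usage sketch, assuming the file is importable as the module `image_tower_magma`, that `config` is a plain dict carrying the `vision_backbone` key read in `D2CLIP_HF.__init__`, and that the referenced OpenCLIP ConvNeXt checkpoint can be fetched from the Hugging Face Hub:

    import torch
    from image_tower_magma import MagmaImageTower

    # 'convnextxxlarge' and 'convnextlarge' are the two backbones handled in D2CLIP_HF.__init__.
    tower = MagmaImageTower({'vision_backbone': 'convnextxxlarge'}, require_pretrained=False)
    tower.eval()

    # Height and width must be multiples of tower.size_divisibility (32).
    images = torch.randn(2, 3, 512, 512)
    with torch.no_grad():
        feats = tower(images)['clip_vis_dense']
    print(feats.shape)  # (2, 3072, 16, 16): stride-32 features with tower.hidden_size channels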