import torch
import torch.nn as nn
from transformers import CLIPImageProcessor
from .vision_models.convnext import convnext_xxlarge
from torch.utils.checkpoint import checkpoint

# CLIP-style preprocessing configuration for a 256x256 input.
cfg = {
    "crop_size": 256,
    "do_center_crop": True,
    "do_normalize": True,
    "do_resize": True,
    "feature_extractor_type": "CLIPFeatureExtractor",
    "image_mean": [0.48145466, 0.4578275, 0.40821073],
    "image_std": [0.26862954, 0.26130258, 0.27577711],
    "resample": 3,
    "size": 256,
}


class ConvNextVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.freeze_vision = args.freeze_vision
        self.input_image_size = args.input_image_size
        self.vision_tower_name = vision_tower
        self.select_layer = -1  # hardcoded: keep only the last stage's features
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        # delay_load is accepted for interface compatibility; the backbone is
        # loaded immediately regardless.
        self.load_model()

    def load_model(self):
        self.image_processor = CLIPImageProcessor(**cfg)
        if 'xxlarge' in self.vision_tower_name:
            self.vision_tower = convnext_xxlarge(self.vision_tower_name)
            setattr(self.vision_tower, 'hidden_size', 3072)
        else:
            raise NotImplementedError

        if self.freeze_vision:
            self.vision_tower.requires_grad_(False)

        # Hardcoded: enable gradient checkpointing on every ConvNeXt stage.
        for s in self.vision_tower.stages:
            s.grad_checkpointing = True

        if self.input_image_size is not None:
            self.image_processor.size = self.input_image_size
            self.image_processor.crop_size = {
                'height': self.input_image_size,
                'width': self.input_image_size,
            }

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        image_features = image_forward_outs[self.select_layer]
        return image_features

    def forward_features(self, x):
        x = self.vision_tower.stem(x)
        image_forward_out = []
        for blk in self.vision_tower.stages:
            x = blk(x)
            b, c, h, w = x.shape
            # Flatten the spatial dims into a token sequence: (B, H*W, C).
            image_forward_out.append(x.view(b, c, -1).transpose(1, 2))
        return image_forward_out

    def forward(self, images):
        if self.freeze_vision:
            # No gradients through a frozen backbone.
            with torch.no_grad():
                image_features = self._forward_images(images)
        else:
            image_features = self._forward_images(images)
        return image_features

    def _forward_images(self, images):
        image_forward_outs = self.forward_features(images.to(device=self.device, dtype=self.dtype))
        image_features = self.feature_select(image_forward_outs)
        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return next(self.vision_tower.parameters()).dtype

    @property
    def device(self):
        return next(self.vision_tower.parameters()).device

    @property
    def config(self):
        raise NotImplementedError

    @property
    def num_attention_heads(self):
        # constant
        return 16

    @property
    def num_layers(self):
        # constant
        return 4

    @property
    def hidden_size(self):
        return self.vision_tower.hidden_size

    @property
    def num_patches(self):
        # Assumes ConvNeXt's total downsampling stride of 32 (stem 4x, then three 2x stages).
        img_size = self.input_image_size if self.input_image_size is not None else cfg['size']
        return (img_size // 32) ** 2
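

# Minimal usage sketch. Assumptions: the SimpleNamespace fields and the
# "convnext_xxlarge" tower name below are illustrative, not the project's exact
# argument names or checkpoint id, and the module's relative import must
# resolve (i.e. run from within the package).
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(freeze_vision=True, input_image_size=256)
    tower = ConvNextVisionTower("convnext_xxlarge", args)

    # One dummy RGB batch at the configured resolution.
    images = torch.randn(1, 3, 256, 256, device=tower.device, dtype=tower.dtype)
    feats = tower(images)  # expected shape: (1, num_patches, hidden_size)
    print(feats.shape)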