import torch
import torch.nn as nn
from transformers import CLIPImageProcessor
from .vision_models.convnext import convnext_xxlarge
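# Preprocessing config for CLIPImageProcessor: the standard OpenAI CLIP
# normalization statistics with a 256px resize and center crop; resample=3 is
# PIL's bicubic filter.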
cfg = {
    "crop_size": 256,
    "do_center_crop": True,
    "do_normalize": True,
    "do_resize": True,
    "feature_extractor_type": "CLIPFeatureExtractor",
    "image_mean": [0.48145466, 0.4578275, 0.40821073],
    "image_std": [0.26862954, 0.26130258, 0.27577711],
    "resample": 3,
    "size": 256,
}
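# Sketch of what this config does at runtime (assumes `pil_image` is a PIL image):
#   processor = CLIPImageProcessor(**cfg)
#   pixels = processor(images=pil_image, return_tensors='pt')['pixel_values']
#   pixels.shape  # torch.Size([1, 3, 256, 256])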
class ConvNextVisionTower(nn.Module):
def __init__(self, vision_tower, args, delay_load=False):
super().__init__()
self.is_loaded = False
        self.freeze_vision = args.freeze_vision
        self.input_image_size = args.input_image_size
        self.vision_tower_name = vision_tower
        self.select_layer = -1  # hardcoded: always take the last stage's features
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
self.load_model()
def load_model(self):
self.image_processor = CLIPImageProcessor(**cfg)
if 'xxlarge' in self.vision_tower_name:
self.vision_tower = convnext_xxlarge(self.vision_tower_name)
setattr(self.vision_tower, 'hidden_size', 3072)
else:
raise NotImplementedError
if self.freeze_vision:
self.vision_tower.requires_grad_(False)
        # Hardcoded: enable gradient checkpointing in every ConvNeXt stage
        # (timm-style stages invoke torch.utils.checkpoint internally when this
        # flag is set), trading compute for activation memory.
        for s in self.vision_tower.stages:
            s.grad_checkpointing = True
        if self.input_image_size is not None:
            self.image_processor.size = self.input_image_size
            self.image_processor.crop_size = {
                'height': self.input_image_size,
                'width': self.input_image_size,
            }
self.is_loaded = True
def feature_select(self, image_forward_outs):
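        # select_layer is hardcoded to -1, so this returns the last stage's
        # tokens. ConvNeXt produces no CLS token, so 'patch' is the only
        # meaningful select_feature value here.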
image_features = image_forward_outs[self.select_layer]
return image_features
def forward_features(self, x):
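        # One flattened token grid per stage. Assuming the bundled
        # convnext_xxlarge follows timm's stage widths (384/768/1536/3072), a
        # 256x256 input yields
        # [(B, 4096, 384), (B, 1024, 768), (B, 256, 1536), (B, 64, 3072)]:
        # the stem downsamples 4x and each stage after the first a further 2x.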
        x = self.vision_tower.stem(x)
        image_forward_out = []
        for blk in self.vision_tower.stages:
            x = blk(x)
            b, c, h, w = x.shape
            # Flatten the (B, C, H, W) map into (B, H*W, C) token form.
            image_forward_out.append(x.view(b, c, -1).transpose(1, 2))
        return image_forward_out
def forward(self, images):
if self.freeze_vision:
with torch.no_grad():
image_features = self._forward_images(images)
else:
image_features = self._forward_images(images)
return image_features
def _forward_images(self, images):
image_forward_outs = self.forward_features(images.to(device=self.device, dtype=self.dtype))
image_features = self.feature_select(image_forward_outs)
return image_features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return next(self.vision_tower.parameters()).dtype
@property
def device(self):
return next(self.vision_tower.parameters()).device
@property
def config(self):
        # No HF-style config object is available for the ConvNeXt tower.
        raise NotImplementedError
@property
def num_attention_heads(self):
        # Constant for interface compatibility; ConvNeXt has no attention heads.
        return 16
@property
def num_layers(self):
        # Constant: the ConvNeXt tower has four stages.
        return 4
@property
def hidden_size(self):
return self.vision_tower.hidden_size
@property
    def num_patches(self):
        # Token count of the last stage's map: ConvNeXt reduces resolution 32x
        # overall (4x stem, then 2x in each of the last three stages).
        return (self.input_image_size // 32) ** 2
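
# Usage sketch: the checkpoint name below is hypothetical and only needs to
# contain 'xxlarge' to route through load_model(); the `args` fields mirror
# exactly what __init__ reads.
if __name__ == '__main__':
    from types import SimpleNamespace

    args = SimpleNamespace(freeze_vision=True, input_image_size=256,
                           mm_vision_select_feature='patch')
    tower = ConvNextVisionTower('convnext_xxlarge.hypothetical_checkpoint', args)
    feats = tower(torch.randn(1, 3, 256, 256))  # last-stage tokens: (1, 64, 3072)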