import torch
from torch import nn
from transformers import SamModel, SamProcessor, SamVisionConfig
from PIL import Image
# Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->Sam
class SamLayerNorm(nn.Module):
r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
"""
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format not in ["channels_last", "channels_first"]:
raise NotImplementedError(f"Unsupported data format: {self.data_format}")
self.normalized_shape = (normalized_shape,)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.data_format == "channels_last":
x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            # Normalize over the channel dimension (dim=1) in float32 for numerical stability,
            # then cast back and apply the affine parameters broadcast over height and width.
            input_dtype = x.dtype
            x = x.float()
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = x.to(dtype=input_dtype)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
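
# A minimal illustration (not part of the original file): in channels_first mode SamLayerNorm
# normalizes an NCHW tensor over its channel dimension, e.g.
#   ln = SamLayerNorm(256, data_format="channels_first")
#   y = ln(torch.randn(1, 256, 64, 64))  # output keeps the (1, 256, 64, 64) shape
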
class ShortSamVisionNeck(nn.Module):
    """Truncated SAM vision neck.

    Keeps only the first 1x1 conv + layer norm of the stock SamVisionNeck (dropping its 3x3 conv
    and second norm) and permutes the result back to (batch_size, height, width, output_channels).
    """

    def __init__(self, config: SamVisionConfig):
        super().__init__()
        self.config = config
        self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False)
        self.layer_norm1 = SamLayerNorm(config.output_channels, data_format="channels_first")

    def forward(self, hidden_states):
        # (batch, height, width, hidden_size) -> (batch, hidden_size, height, width)
        hidden_states = hidden_states.permute(0, 3, 1, 2)
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.layer_norm1(hidden_states)
        # back to channels_last: (batch, height, width, output_channels)
        hidden_states = hidden_states.permute(0, 2, 3, 1)
        return hidden_states
class SAMVisionTower(nn.Module):
def __init__(self, vision_tower, args):
super().__init__()
self.args = args
self.is_loaded = False
self.vision_tower_name = vision_tower
self.input_image_size = args.input_image_size
self.pixel_shuffle = getattr(args, 'add_pixel_shuffle', False)
self.freeze = args.freeze_vision
self.load_model()
    def load_model(self):
        if self.is_loaded:
            return
        self.image_processor = SamProcessor.from_pretrained("facebook/sam-vit-large")
        sam_model = SamModel.from_pretrained("facebook/sam-vit-large").vision_encoder
        # Replace the stock neck with the truncated variant defined above (1x1 conv + norm, channels-last output).
        sam_model.neck = ShortSamVisionNeck(sam_model.config)
        # Expose a preprocess() alias and the ImageNet channel mean, presumably for callers that
        # expect a CLIP-style image-processor interface.
        self.image_processor.preprocess = self.image_processor.__call__
        self.image_processor.image_mean = [0.485, 0.456, 0.406]
        self.vision_tower = sam_model
        if self.freeze:
            self.vision_tower.requires_grad_(False)
        self.is_loaded = True
    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_feature = self.vision_tower(image.to(device=self.device).unsqueeze(0)).last_hidden_state.flatten(start_dim=1, end_dim=2)
                image_features.append(image_feature)
        else:
            # (batch, height, width, channels) -> (batch, height * width, channels)
            image_features = self.vision_tower(images.to(device=self.device)).last_hidden_state.flatten(start_dim=1, end_dim=2)
        if self.pixel_shuffle:
            # Fold each 2x2 spatial neighbourhood into the channel dimension:
            # (b, n, c) -> (b, c, h, w) -> (b, 4 * c, h / 2, w / 2).
            b, n, c = image_features.shape
            h = w = int(n ** 0.5)
            image_features = image_features.transpose(1, 2).reshape(b, c, h, w)
            image_features = nn.functional.pixel_unshuffle(image_features, 2)
        return image_features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return next(self.vision_tower.parameters()).dtype
@property
def device(self):
return next(self.vision_tower.parameters()).device
    @property
    def config(self):
        # The underlying encoder config is not kept around; return a default SamVisionConfig,
        # which this class only consumes for image_size / patch_size.
        return SamVisionConfig()
    @property
    def hidden_size(self):
        # 256 channels from the truncated SAM neck, quadrupled when pixel_unshuffle folds a
        # 2x2 neighbourhood into the channel dimension.
        return 256 * 4 if self.pixel_shuffle else 256
    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2
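
# Rough feature geometry for the facebook/sam-vit-large checkpoint this tower loads (1024x1024
# inputs, 16x16 patches): 64 * 64 = 4096 tokens of width 256 per image, or a
# (batch, 1024, 32, 32) map after the optional pixel_unshuffle step.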
# Quick manual check: run the stock SAM ViT-L vision encoder on a single image.
if __name__ == "__main__":
sam_model = SamModel.from_pretrained("facebook/sam-vit-large").vision_encoder
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-large")
    for param in sam_model.parameters():
        param.requires_grad = False
raw_image = Image.open('/lustre/fsw/portfolios/llmservice/users/fuxiaol/image/me.jpg').convert("RGB")
inputs = sam_processor(raw_image, return_tensors="pt")
out = sam_model(inputs['pixel_values'])
print(out[0].size())
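
    # --- Hypothetical usage sketch for SAMVisionTower (not part of the original demo). ---
    # Assumes an args namespace carrying the attributes the constructor reads
    # (input_image_size, freeze_vision, add_pixel_shuffle); note that load_model() always
    # downloads facebook/sam-vit-large regardless of the vision_tower name passed in.
    from types import SimpleNamespace

    tower_args = SimpleNamespace(input_image_size=1024, freeze_vision=True, add_pixel_shuffle=False)
    tower = SAMVisionTower("facebook/sam-vit-large", tower_args)
    features = tower(inputs['pixel_values'])
    # With the truncated neck this is (batch, 64 * 64, 256) for a 1024x1024 input.
    print(features.size())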