import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import Tensor, nn
import transformers
from transformers import SamProcessor
from transformers import SamModel, SamVisionConfig
from transformers import SamImageProcessor
from PIL import Image
# Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->Sam
class SamLayerNorm(nn.Module):
    r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
    width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(f"Unsupported data format: {self.data_format}")
        self.normalized_shape = (normalized_shape,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.data_format == "channels_last":
            x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            input_dtype = x.dtype
            x = x.float()
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = x.to(dtype=input_dtype)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x
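

# Illustrative sanity-check sketch (not part of the original file): for channels_first
# input, SamLayerNorm above is equivalent to a standard layer norm taken over the
# channel dimension after permuting to channels_last. The function name is hypothetical.
def _check_sam_layer_norm():
    x = torch.randn(2, 8, 4, 4)  # (B, C, H, W)
    ln = SamLayerNorm(8, data_format="channels_first")
    # Reference: permute to (B, H, W, C), normalize over the last (channel) dim, permute back.
    ref = F.layer_norm(x.permute(0, 2, 3, 1), (8,), ln.weight, ln.bias, ln.eps).permute(0, 3, 1, 2)
    return torch.allclose(ln(x), ref, atol=1e-5)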
class ShortSamVisionNeck(nn.Module):
    """Shortened SAM vision neck: keeps only the first 1x1 conv and layer norm of the
    stock SamVisionNeck, and returns features in (batch, height, width, channels) order."""

    def __init__(self, config: SamVisionConfig):
        super().__init__()
        self.config = config
        self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False)
        self.layer_norm1 = SamLayerNorm(config.output_channels, data_format="channels_first")

    def forward(self, hidden_states):
        # (B, H, W, C) -> (B, C, H, W) for the conv/norm, then back to (B, H, W, C).
        hidden_states = hidden_states.permute(0, 3, 1, 2)
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = hidden_states.permute(0, 2, 3, 1)
        return hidden_states
class SAMVisionTower(nn.Module):
    def __init__(self, vision_tower, args):
        super().__init__()
        self.args = args
        self.is_loaded = False
        self.vision_tower_name = vision_tower
        self.input_image_size = args.input_image_size
        self.pixel_shuffle = getattr(args, 'add_pixel_shuffle', False)
        self.freeze = args.freeze_vision
        self.load_model()

    def load_model(self):
        if self.is_loaded:
            return
        self.image_processor = SamProcessor.from_pretrained("facebook/sam-vit-large")
        sam_model = SamModel.from_pretrained("facebook/sam-vit-large").vision_encoder
        # Swap the stock SAM neck for the shortened neck defined above.
        sam_model.neck = ShortSamVisionNeck(sam_model.config)
        self.image_processor.preprocess = self.image_processor.__call__
        self.image_processor.image_mean = [0.485, 0.456, 0.406]
        self.vision_tower = sam_model
        if self.freeze:
            self.vision_tower.requires_grad_(False)
        self.is_loaded = True
    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_feature = self.vision_tower(
                    image.to(device=self.device).unsqueeze(0)
                ).last_hidden_state.flatten(start_dim=1, end_dim=2)
                image_features.append(image_feature)
            image_features = torch.cat(image_features, dim=0)
        else:
            image_features = self.vision_tower(
                images.to(device=self.device)
            ).last_hidden_state.flatten(start_dim=1, end_dim=2)
        if self.pixel_shuffle:
            # (B, N, C) -> (B, C, H, W), then pixel-unshuffle to (B, 4*C, H/2, W/2).
            b, n, c = image_features.shape
            h = w = int(n ** 0.5)
            image_features = image_features.transpose(1, 2).reshape(b, c, h, w)
            image_features = nn.functional.pixel_unshuffle(image_features, 2)
        return image_features
    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return next(self.vision_tower.parameters()).dtype

    @property
    def device(self):
        return next(self.vision_tower.parameters()).device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        return SamVisionConfig()

    @property
    def hidden_size(self):
        # The shortened neck outputs 256 channels; pixel-unshuffle by 2 multiplies channels by 4.
        return 256 * 4 if self.pixel_shuffle else 256

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2
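

# Illustrative usage sketch (not part of the original pipeline): builds the wrapper above
# with a minimal stand-in for `args`. The attribute names (`input_image_size`, `freeze_vision`,
# `add_pixel_shuffle`) follow what SAMVisionTower.__init__ reads, and the 1024x1024 dummy
# input matches the SAM processor's default resolution. The function name is hypothetical.
def build_sam_vision_tower_demo():
    from types import SimpleNamespace

    args = SimpleNamespace(input_image_size=1024, freeze_vision=True, add_pixel_shuffle=False)
    tower = SAMVisionTower("facebook/sam-vit-large", args)
    dummy = torch.zeros(1, 3, 1024, 1024)
    features = tower(dummy)
    # With the shortened neck, features come out as (1, 64 * 64, 256).
    return features.shape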
if __name__ == "__main__":
    # Quick smoke test: run a single image through the frozen SAM-L vision encoder.
    sam_model = SamModel.from_pretrained("facebook/sam-vit-large").vision_encoder
    sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-large")
    for name, param in sam_model.named_parameters():
        param.requires_grad = False

    raw_image = Image.open('/lustre/fsw/portfolios/llmservice/users/fuxiaol/image/me.jpg').convert("RGB")
    inputs = sam_processor(raw_image, return_tensors="pt")
    out = sam_model(inputs['pixel_values'])
    print(out[0].size())