import re

import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import (
    AutoModel,
    AutoProcessor,
    CLIPImageProcessor,
    Pix2StructForConditionalGeneration,
    Pix2StructProcessor,
    Pix2StructVisionModel,
)

# Default CLIP-style preprocessing config (256 px); the __main__ demo below overrides it with a 1024 px variant.
cfg = {
    "crop_size": 256,
    "do_center_crop": True,
    "do_normalize": True,
    "do_resize": True,
    "feature_extractor_type": "CLIPFeatureExtractor",
    "image_mean": [0.48145466, 0.4578275, 0.40821073],
    "image_std": [0.26862954, 0.26130258, 0.27577711],
    "resample": 3,
    "size": 256,
}

'''
Pix2Struct-Large vision tower (pretrained google/pix2struct-large encoder)
'''

class Pix2StructLargeVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.vision_tower_name = vision_tower
        self.do_resize = args.do_resize
        self.de_normalize = args.de_normalize  # de-normalize the input image and preprocess it with the pix2struct processor
        self.select_layer = args.mm_vision_select_layer  # NOTE: not implemented yet, this parameter has no effect
        self.input_image_size = args.input_image_size
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
        self.freeze_vision = args.freeze_vision
        self.args = args

        if not delay_load:
            self.load_model()

    def load_model(self):
        if self.is_loaded:
            return

        # The vision tower is the encoder of the full Pix2Struct model.
        whole_model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-large")
        self.vision_tower = whole_model.encoder

        self.pix2struct_processor = AutoProcessor.from_pretrained("google/pix2struct-large")
        self.pix2struct_processor.image_processor.is_vqa = False

        # CLIP-style image processor; its mean/std are used in forward() to de-normalize the inputs.
        self.image_processor = CLIPImageProcessor(**cfg)
        if self.input_image_size is not None:
            self.image_processor.size = self.input_image_size
            self.image_processor.crop_size = {
                'height': self.input_image_size,
                'width': self.input_image_size
            }

        if self.freeze_vision:
            self.vision_tower.requires_grad_(False)

        self.image_mean = torch.tensor(self.image_processor.image_mean).view(1, 3, 1, 1)
        self.image_std = torch.tensor(self.image_processor.image_std).view(1, 3, 1, 1)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        image_features = image_forward_outs.hidden_states[self.select_layer]  # [bs, n, c], cls at idx=0
        if self.select_feature == 'patch':
            image_features = image_features[:, 1:]
        elif self.select_feature == 'cls_patch':
            image_features = image_features
        else:
            raise ValueError(f'Unexpected select feature: {self.select_feature}')
        return image_features

    # @torch.no_grad()
    def forward(self, images):
        if self.de_normalize:
            # Undo CLIP normalization and rescale to [0, 255] before handing the images to the pix2struct processor.
            mean = self.image_mean.clone().view(1, 3, 1, 1).to(dtype=images.dtype, device=images.device)
            std = self.image_std.clone().view(1, 3, 1, 1).to(dtype=images.dtype, device=images.device)
            x = (images * std + mean) * 255.0
            x = self.pix2struct_processor(images=x.float(), return_tensors="pt")

        image_features = self.vision_tower(**(x.to(device=self.device, dtype=self.dtype))).last_hidden_state
        bs, n, c = image_features.shape
        image_features = image_features[:, :2025, :]  # HARD CODE: keep the first 45*45 = 2025 patch tokens

        if self.do_resize:
            image_features = image_features.transpose(1, 2).reshape(bs, c, 45, 45)  # HARD CODE: 45x45 patch grid
            image_features = F.interpolate(image_features.float(), size=(32, 32), mode='bilinear', align_corners=True).to(dtype=image_features.dtype)  # HARD CODE: downsample to a 32x32 grid
            return image_features
        else:
            return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return next(self.vision_tower.parameters()).dtype

    @property
    def device(self):
        return next(self.vision_tower.parameters()).device

    @property
    def config(self):
        return self.vision_tower.config

    @property
    def hidden_size(self):
        # return self.config.hidden_size
        hidden_dim = 1536  # hidden size of pix2struct-large
        return hidden_dim

    @property
    def num_patches(self):
        # return (self.config.image_size // self.config.patch_size) ** 2
        return self.config['num_patches']
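
def _example_usage():
    """Hedged usage sketch (not part of the original training pipeline).

    The real `args` object comes from the surrounding training framework, so a
    SimpleNamespace with just the fields read in __init__ stands in for it here.
    Calling this function downloads google/pix2struct-large.
    """
    from types import SimpleNamespace

    example_args = SimpleNamespace(
        do_resize=True,
        de_normalize=True,
        mm_vision_select_layer=-2,  # currently unused by this tower
        input_image_size=1024,
        freeze_vision=True,
    )
    tower = Pix2StructLargeVisionTower("google/pix2struct-large", example_args)
    dummy = torch.randn(2, 3, 1024, 1024)  # stands in for a CLIP-normalized image batch
    feats = tower(dummy)
    print(feats.shape)  # expected: torch.Size([2, 1536, 32, 32]) when do_resize=True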

# main
if __name__ == "__main__":
    '''
    print('hello')
    from PIL import Image
    import requests
    from transformers import AutoProcessor, Pix2StructVisionModel

    model = Pix2StructVisionModel.from_pretrained("google/pix2struct-textcaps-base")
    processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open("/lustre/fsw/portfolios/llmservice/users/fuxiaol/me.jpg")

    for name, param in model.named_parameters():
        param.requires_grad = False

    #inputs = processor(images=image, return_tensors="pt")
    image_processor = CLIPImageProcessor.from_pretrained('OpenGVLab/InternViT-6B-448px-V1-5')
    pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
    pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
    #inputs = pixel_values.to(torch.bfloat16)
    print('pixel_values:', pixel_values.size())

    inputs = processor(images=pixel_values, max_patches=1024, return_tensors='pt')['flattened_patches']
    print(inputs.size())
    outputs = model(inputs)
    print(outputs.last_hidden_state.size())
    '''

    cfg = {
        "crop_size": 1024,
        "do_center_crop": True,
        "do_normalize": True,
        "do_resize": True,
        "feature_extractor_type": "CLIPFeatureExtractor",
        "image_mean": [0.48145466, 0.4578275, 0.40821073],
        "image_std": [0.26862954, 0.26130258, 0.27577711],
        "resample": 3,
        "size": 1024,
    }

    from PIL import Image
    import requests
    from transformers import AutoProcessor, Pix2StructForConditionalGeneration
    from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
    import torchvision.transforms as T

    processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-large")
    model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-large")

    # url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    # image = Image.open(requests.get(url, stream=True).raw)
    image = Image.open("/lustre/fsw/portfolios/llmservice/users/fuxiaol/sample2.jpg")

    image_processor = CLIPImageProcessor(**cfg)
    pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
    print(pixel_values.size())

    # De-normalize back to [0, 1] so the pix2struct processor sees a plain image tensor.
    mean = [0.48145466, 0.4578275, 0.40821073]
    std = [0.26862954, 0.26130258, 0.27577711]
    mean = torch.tensor(mean).view(1, 3, 1, 1)
    std = torch.tensor(std).view(1, 3, 1, 1)
    pixel_values = pixel_values * std + mean
    print(pixel_values.size())

    # pixel_values.save('pix2image.jpg')
    transform = T.ToPILImage()
    img = transform(pixel_values.squeeze(0))
    img.save('pix2image.jpg')

    inputs = processor(images=pixel_values, max_patches=1024, return_tensors="pt")['flattened_patches']

    # autoregressive generation
    generated_ids = model.generate(inputs, max_new_tokens=50)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(generated_text)
    # Expected output: "A stop sign is on a street corner."

    '''
    from PIL import Image
    import requests
    from transformers import AutoProcessor, CLIPModel
    from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig

    processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
    model = CLIPVisionModel.from_pretrained('openai/clip-vit-large-patch14-336')

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    print(image)

    inputs = processor(images=image, return_tensors="pt")
    #image_features = model.get_image_features(**inputs)
    outputs = model(**inputs, output_hidden_states=True)
    print(outputs.hidden_states[-1].size())
    print(outputs.hidden_states[-2].size())
    print(outputs.hidden_states[-3].size())
    '''

    # sequence = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    # sequence = processor.post_process_generation(sequence, fix_markdown=False)
    # note: repr is used here so the \n characters stay visible; feel free to just print the sequence
    # print(repr(sequence))