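"""Interactive image question-answering demo.

A frozen Llama-3-8B-Instruct language model answers questions about an image:
the image is encoded with a SigLIP vision tower, its features are projected into
the LLM's embedding space by a small trained projection module (mm_projector.bin),
and the result is spliced into the prompt embeddings in place of an "<image>"
placeholder token.

Run:  python <this_script>.py -i <path-to-image>
"""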
import argparse
import sys

import torch
import torch.nn as nn
from PIL import Image
from transformers import (
    AutoTokenizer,
    BitsAndBytesConfig,
    LlamaForCausalLM,
    SiglipImageProcessor,
    SiglipVisionModel,
    TextStreamer,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
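# Tokenize a prompt containing the "<image>" placeholder. Each text chunk is
# tokenized separately and the placeholder is replaced by the sentinel id -200,
# which process_tensors() later swaps for the projected image features.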
def tokenizer_image_token(prompt, tokenizer, image_token_index=-200):
    prompt_chunks = prompt.split("<image>")
    tokenized_chunks = [tokenizer(chunk).input_ids for chunk in prompt_chunks]
    input_ids = tokenized_chunks[0]

    for chunk in tokenized_chunks[1:]:
        input_ids.append(image_token_index)
        input_ids.extend(chunk[1:])  # exclude the BOS token on non-first chunks

    return torch.tensor(input_ids, dtype=torch.long)
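# Splice the image features into the text embeddings: split the token embeddings
# at the -200 sentinel, insert the projected image features between the two
# halves, and build a matching all-ones attention mask.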
def process_tensors(input_ids, image_features, embedding_layer):
    # Find the index of the -200 image placeholder in input_ids
    split_index = (input_ids == -200).nonzero(as_tuple=True)[1][0]

    # Split the input_ids at that index, excluding the -200 placeholder itself
    input_ids_1 = input_ids[:, :split_index]
    input_ids_2 = input_ids[:, split_index + 1 :]

    # Convert input_ids to embeddings
    embeddings_1 = embedding_layer(input_ids_1)
    embeddings_2 = embedding_layer(input_ids_2)

    device = image_features.device
    token_embeddings_part1 = embeddings_1.to(device)
    token_embeddings_part2 = embeddings_2.to(device)

    # Concatenate the token embeddings and image features
    concatenated_embeddings = torch.cat(
        [token_embeddings_part1, image_features, token_embeddings_part2], dim=1
    )

    # Create the corresponding attention mask
    attention_mask = torch.ones(
        concatenated_embeddings.shape[:2], dtype=torch.long, device=device
    )
    return concatenated_embeddings, attention_mask
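# Load the tokenizer, the Llama-3 language model, and the SigLIP vision tower.
# The hard-coded E:\ paths are local copies from the original script; the
# commented-out Hub IDs are presumably the upstream checkpoints they were
# downloaded from. The bitsandbytes 4-bit quantization path is left commented
# out, as in the original.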
def initialize_models():
    # bnb_config = BitsAndBytesConfig(
    #     load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16
    # )

    tokenizer = AutoTokenizer.from_pretrained(
        # "unsloth/llama-3-8b-Instruct",
        r"E:\Workspace\BAP\LlamaVision\llama-3-8b-Instruct",
        use_fast=True,
    )
    model = LlamaForCausalLM.from_pretrained(
        # "unsloth/llama-3-8b-Instruct",
        r"E:\Workspace\BAP\LlamaVision\llama-3-8b-Instruct",
        torch_dtype=torch.float16,
        device_map="auto",
        # quantization_config=bnb_config,
    )

    # Freeze the language model weights; this script only runs inference.
    for param in model.base_model.parameters():
        param.requires_grad = False

    # model_name = "google/siglip-so400m-patch14-384"
    model_name = r"E:\Workspace\BAP\LlamaVision\siglip-so400m-patch14-384"
    vision_model = SiglipVisionModel.from_pretrained(
        model_name, torch_dtype=torch.float16
    )
    processor = SiglipImageProcessor.from_pretrained(model_name)

    # The vision tower runs on CPU in this script; only the LLM uses the GPU.
    vision_model = vision_model.to("cpu")

    return tokenizer, model, vision_model, processor
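# Two-layer MLP projector (Linear -> GELU -> Linear) mapping SigLIP hidden states
# (1152-d for siglip-so400m-patch14-384) to the Llama-3 hidden size (4096-d).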
class ProjectionModule(nn.Module):
    def __init__(self, mm_hidden_size, hidden_size):
        super(ProjectionModule, self).__init__()

        # Directly set up the sequential model
        self.model = nn.Sequential(
            nn.Linear(mm_hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, x):
        return self.model(x)
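# Load the trained projector weights from ./mm_projector.bin; the checkpoint keys
# carry a "mm_projector." prefix that is stripped before load_state_dict.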
def load_projection_module(mm_hidden_size=1152, hidden_size=4096, device="cpu"):
    projection_module = ProjectionModule(mm_hidden_size, hidden_size)
    # map_location keeps loading working on machines without the GPU the
    # checkpoint was saved on; the module is moved to `device` afterwards.
    checkpoint = torch.load("./mm_projector.bin", map_location="cpu")
    checkpoint = {k.replace("mm_projector.", ""): v for k, v in checkpoint.items()}
    projection_module.load_state_dict(checkpoint)
    projection_module = projection_module.to(device).half()
    return projection_module
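# Interactive chat loop: encode the image once (SigLIP on CPU, penultimate hidden
# layer), project it, splice it into the first user turn, then keep generating and
# appending follow-up turns using the Llama-3 chat template.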
def answer_question(
    image_path, tokenizer, model, vision_model, processor, projection_module
):
    image = Image.open(image_path).convert("RGB")

    tokenizer.eos_token = "<|eot_id|>"

    try:
        q = input("\nuser: ")
    except EOFError:
        q = ""
    if not q:
        print("no input detected. exiting.")
        sys.exit()

    question = "<image>" + q
    prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

    input_ids = (
        tokenizer_image_token(prompt, tokenizer)
        .unsqueeze(0)
        .to(model.device)
    )

    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    with torch.inference_mode():
        # Preprocess and encode the image on CPU with the SigLIP vision tower
        image_inputs = processor(
            images=[image],
            return_tensors="pt",
            do_resize=True,
            size={"height": 384, "width": 384},
        ).to("cpu")
        image_inputs = image_inputs["pixel_values"].squeeze(0)

        image_forward_outs = vision_model(
            image_inputs.to(device="cpu", dtype=torch.float16).unsqueeze(0),
            output_hidden_states=True,
        )

        # Use the penultimate hidden layer as the image representation
        image_features = image_forward_outs.hidden_states[-2]
        projected_embeddings = projection_module(image_features).to("cpu")

        embedding_layer = model.get_input_embeddings()
        # text_embeddings = embedding_layer(input_ids)

        new_embeds, attn_mask = process_tensors(
            input_ids, projected_embeddings, embedding_layer
        )
        device = model.device
        attn_mask = attn_mask.to(device)
        new_embeds = new_embeds.to(device)

    model_kwargs = {
        "do_sample": True,
        "temperature": 0.2,
        "max_new_tokens": 2000,
        "use_cache": True,
        "streamer": streamer,
        "pad_token_id": tokenizer.eos_token_id,
    }

    while True:
        print("assistant: ")
        generated_ids = model.generate(
            inputs_embeds=new_embeds, attention_mask=attn_mask, **model_kwargs
        )[0]
        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=False)

        try:
            q = input("\nuser: ")
        except EOFError:
            q = ""
        if not q:
            print("no input detected. exiting.")
            break

        # Append the assistant reply and the new user turn to the running
        # conversation, re-embed it, and extend the attention mask to match.
        new_text = (
            generated_text
            + "<|start_header_id|>user<|end_header_id|>\n\n"
            + q
            + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        new_input_ids = tokenizer(
            new_text, return_tensors="pt", add_special_tokens=False
        ).input_ids.to(device)  # avoid inserting another BOS token mid-conversation
        new_embeddings = embedding_layer(new_input_ids)

        new_embeds = torch.cat([new_embeds, new_embeddings], dim=1)
        attn_mask = torch.ones(new_embeds.shape[:2], dtype=torch.long, device=device)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Answer questions based on an image")
    parser.add_argument("-i", "--image", required=True, help="Path to the image file")
    args = parser.parse_args()

    tokenizer, model, vision_model, processor = initialize_models()
    projection_module = load_projection_module()

    answer_question(
        args.image,
        tokenizer,
        model,
        vision_model,
        processor,
        projection_module,
    )