# vlm-o/inference.py
from PIL import Image
import torch
import fire
from processor import MultiModalProcessor
from model.utils.kv_cache import KVCache
from model.multimodal.multimodal_model import PaliGemmaForConditionalGeneration
from load_model import load_hf_model


def move_inputs_to_device(model_inputs: dict, device: str):
    # Move every tensor produced by the processor onto the target device.
    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
    return model_inputs


def get_model_inputs(
    processor: MultiModalProcessor, prompt: str, image_file_path: str, device: str
):
    # Load the image and tokenize the prompt into model-ready tensors.
    image = Image.open(image_file_path)
    images = [image]
    prompts = [prompt]
    model_inputs = processor(text=prompts, images=images)
    model_inputs = move_inputs_to_device(model_inputs, device)
    return model_inputs


def test_inference(
    model: PaliGemmaForConditionalGeneration,
    processor: MultiModalProcessor,
    device: str,
    prompt: str,
    image_file_path: str,
    max_tokens_to_generate: int,
    temperature: float,
    top_p: float,
    do_sample: bool,
):
    model_inputs = get_model_inputs(processor, prompt, image_file_path, device)
    input_ids = model_inputs["input_ids"]
    attention_mask = model_inputs["attention_mask"]
    pixel_values = model_inputs["pixel_values"]

    kv_cache = KVCache()
    stop_token = processor.tokenizer.eos_token_id
    generated_tokens = []

    for _ in range(max_tokens_to_generate):
        outputs = model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            kv_cache=kv_cache,
        )
        kv_cache = outputs["kv_cache"]
        # Only the logits at the last position are needed for the next token.
        next_token_logits = outputs["logits"][:, -1, :]
        if do_sample:
            # Temperature-scaled softmax followed by nucleus (top-p) sampling.
            next_token_logits = torch.softmax(next_token_logits / temperature, dim=-1)
            next_token = _sample_top_p(next_token_logits, top_p)
        else:
            # Greedy decoding: pick the highest-probability token.
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        assert next_token.size() == (1, 1)
        next_token = next_token.squeeze(0)
        generated_tokens.append(next_token)
        if next_token.item() == stop_token:
            break
        # Feed only the new token back in; previous context lives in the KV cache.
        input_ids = next_token.unsqueeze(-1)
        attention_mask = torch.cat(
            [attention_mask, torch.ones((1, 1), device=input_ids.device)], dim=-1
        )

    generated_tokens = torch.cat(generated_tokens, dim=-1)
    decoded = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
    print(decoded)


def _sample_top_p(probs: torch.Tensor, p: float):
    # Sort probabilities in descending order and keep the smallest prefix
    # whose cumulative mass exceeds p (nucleus sampling).
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    # Renormalize the surviving probabilities and sample one token.
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map the sampled position back to the original vocabulary index.
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
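

# Worked example for _sample_top_p, assuming p = 0.9: for sorted probabilities
# [0.5, 0.3, 0.15, 0.05] the cumulative sums are [0.5, 0.8, 0.95, 1.0], so the
# mask (cumsum - prob > p) zeroes only the last entry and sampling draws from
# the renormalized top three tokens.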


def main(
    model_path: str = None,
    prompt: str = None,
    image_file_path: str = None,
    max_tokens_to_generate: int = 100,
    temperature: float = 0.8,
    top_p: float = 0.9,
    do_sample: bool = False,
    only_cpu: bool = False,
):
    # Prefer CUDA, then Apple MPS, falling back to CPU.
    device = "cpu"
    if not only_cpu:
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"
    print("Device in use:", device)

    print("Loading model")
    model, tokenizer = load_hf_model(model_path, device)
    model = model.to(device).eval()

    num_image_tokens = model.config.vision_config.num_image_tokens
    image_size = model.config.vision_config.image_size
    max_length = 512
    processor = MultiModalProcessor(tokenizer, num_image_tokens, image_size, max_length)

    print("Running inference")
    with torch.no_grad():
        test_inference(
            model,
            processor,
            device,
            prompt,
            image_file_path,
            max_tokens_to_generate,
            temperature,
            top_p,
            do_sample,
        )


if __name__ == "__main__":
    fire.Fire(main)
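

# Usage sketch: fire.Fire(main) exposes main's keyword arguments as CLI flags.
# The model and image paths below are illustrative, not files shipped with
# this script:
#
#   python inference.py \
#       --model_path /path/to/paligemma-weights \
#       --prompt "describe this image" \
#       --image_file_path /path/to/image.jpg \
#       --max_tokens_to_generate 100 \
#       --do_sample True --temperature 0.8 --top_p 0.9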