from PIL import Image
import torch
import fire

from processor import MultiModalProcessor
from model.utils.kv_cache import KVCache
from model.multimodal.multimodal_model import PaliGemmaForConditionalGeneration
from load_model import load_hf_model


def move_inputs_to_device(model_inputs: dict, device: str):
    # Move every tensor in the processor output (input_ids, attention_mask,
    # pixel_values) onto the target device.
    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
    return model_inputs


def get_model_inputs(
    processor: MultiModalProcessor, prompt: str, image_file_path: str, device: str
):
    # Wrap the single prompt/image pair in lists (the processor works on
    # batches), run the multimodal processor, and move the tensors to device.
    image = Image.open(image_file_path)
    images = [image]
    prompts = [prompt]
    model_inputs = processor(text=prompts, images=images)
    model_inputs = move_inputs_to_device(model_inputs, device)
    return model_inputs
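
# Note: the dict returned above is expected to contain "input_ids" and
# "attention_mask" with batch size 1, plus "pixel_values" for the vision tower;
# those are the keys consumed in test_inference() below. Exact tensor shapes
# depend on MultiModalProcessor, which is defined elsewhere in this repo.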


def test_inference(
    model: PaliGemmaForConditionalGeneration,
    processor: MultiModalProcessor,
    device: str,
    prompt: str,
    image_file_path: str,
    max_tokens_to_generate: int,
    temperature: float,
    top_p: float,
    do_sample: bool,
):
    model_inputs = get_model_inputs(processor, prompt, image_file_path, device)
    input_ids = model_inputs["input_ids"]
    attention_mask = model_inputs["attention_mask"]
    pixel_values = model_inputs["pixel_values"]

    kv_cache = KVCache()

    stop_token = processor.tokenizer.eos_token_id
    generated_tokens = []

    for _ in range(max_tokens_to_generate):
        # The first iteration prefills the KV cache with the full prompt and
        # image; later iterations feed only the single most recent token.
        outputs = model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            kv_cache=kv_cache,
        )
        kv_cache = outputs["kv_cache"]
        next_token_logits = outputs["logits"][:, -1, :]
        if do_sample:
            # Temperature-scaled softmax followed by top-p (nucleus) sampling.
            next_token_logits = torch.softmax(next_token_logits / temperature, dim=-1)
            next_token = _sample_top_p(next_token_logits, top_p)
        else:
            # Greedy decoding: always pick the highest-probability token.
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        assert next_token.size() == (1, 1)
        next_token = next_token.squeeze(0)  # Remove the batch dimension.
        generated_tokens.append(next_token)
        # Stop as soon as the model emits the end-of-sequence token.
        if next_token.item() == stop_token:
            break
        # Feed only the new token back in; past context lives in the KV cache.
        input_ids = next_token.unsqueeze(-1)
        attention_mask = torch.cat(
            [attention_mask, torch.ones((1, 1), device=input_ids.device)], dim=-1
        )

    generated_tokens = torch.cat(generated_tokens, dim=-1)
    decoded = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
    print(decoded)


def _sample_top_p(probs: torch.Tensor, p: float):
    # Nucleus (top-p) sampling: keep the smallest set of tokens whose
    # cumulative probability exceeds p, renormalize, and sample from that set.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Mask tokens whose cumulative probability, excluding themselves, exceeds p.
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map the sampled position back to the original vocabulary index.
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
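
# Illustrative walkthrough (values chosen for this comment, not from the script):
# with sorted probabilities [0.5, 0.3, 0.15, 0.05] and p = 0.9, the cumulative
# sums are [0.5, 0.8, 0.95, 1.0], so probs_sum - probs_sort = [0.0, 0.5, 0.8, 0.95].
# Only the last entry exceeds 0.9 and is zeroed; the surviving mass 0.95 is
# renormalized to [0.5, 0.3, 0.15] / 0.95 before torch.multinomial draws a token.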


def main(
    model_path: str = None,
    prompt: str = None,
    image_file_path: str = None,
    max_tokens_to_generate: int = 100,
    temperature: float = 0.8,
    top_p: float = 0.9,
    do_sample: bool = False,
    only_cpu: bool = False,
):
    # Prefer CUDA, then Apple MPS, then fall back to CPU.
    device = "cpu"
    if not only_cpu:
        if torch.cuda.is_available():
            device = "cuda"
        elif torch.backends.mps.is_available():
            device = "mps"
    print("Device in use:", device)

    print("Loading model")
    model, tokenizer = load_hf_model(model_path, device)
    model = model.to(device).eval()

    num_image_tokens = model.config.vision_config.num_image_tokens
    image_size = model.config.vision_config.image_size
    max_length = 512
    processor = MultiModalProcessor(tokenizer, num_image_tokens, image_size, max_length)
print("Running inference") | |
with torch.no_grad(): | |
test_inference( | |
model, | |
processor, | |
device, | |
prompt, | |
image_file_path, | |
max_tokens_to_generate, | |
temperature, | |
top_p, | |
do_sample, | |
) | |


if __name__ == "__main__":
    fire.Fire(main)
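
# Example invocation (paths are placeholders, and the filename assumes this
# script is saved as inference.py; fire exposes each keyword argument of
# main() as a command-line flag):
#
#   python inference.py \
#       --model_path="./paligemma-weights" \
#       --prompt="describe this image" \
#       --image_file_path="./test.jpg" \
#       --max_tokens_to_generate=100 \
#       --do_sample=False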