# llava-v1.5-7b / code / inference.py
from io import BytesIO

import requests
import torch
from PIL import Image
from transformers import AutoTokenizer

from llava.constants import IMAGE_TOKEN_INDEX
from llava.mm_utils import KeywordsStoppingCriteria, tokenizer_image_token
from llava.model import LlavaLlamaForCausalLM
from llava.utils import disable_torch_init


def model_fn(model_dir):
    """Load the LLaVA model, tokenizer, and image processor from model_dir."""
    kwargs = {"device_map": "auto", "torch_dtype": torch.float16}
    model = LlavaLlamaForCausalLM.from_pretrained(
        model_dir, low_cpu_mem_usage=True, **kwargs
    )
    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False)

    # The vision tower is loaded lazily; make sure it is loaded and on the GPU in fp16.
    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model()
    vision_tower.to(device="cuda", dtype=torch.float16)
    image_processor = vision_tower.image_processor

    return model, tokenizer, image_processor


def predict_fn(data, model_and_tokenizer):
    """Run one image-plus-prompt generation request against the loaded model."""
    # Unpack the artifacts returned by model_fn.
    model, tokenizer, image_processor = model_and_tokenizer

    # Pull the prompt, image reference, and generation parameters out of the request.
    image_file = data.pop("image", data)
    prompt = data.pop("question", data)
    max_new_tokens = data.pop("max_new_tokens", 1024)
    temperature = data.pop("temperature", 0.2)
    stop_str = data.pop("stop_str", "###")

    # The image can be given either as an HTTP(S) URL or as a local file path.
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")

    disable_torch_init()  # skip torch's redundant default weight initialization

    # Preprocess the image into an fp16 pixel tensor on the GPU.
    image_tensor = (
        image_processor.preprocess(image, return_tensors="pt")["pixel_values"]
        .half()
        .cuda()
    )

    # Tokenize the prompt, replacing the <image> placeholder with IMAGE_TOKEN_INDEX.
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .cuda()
    )

    # Stop generation once the stop string (default "###") is produced.
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

    # Generate, then decode only the newly produced tokens (everything after the prompt).
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )

    outputs = tokenizer.decode(
        output_ids[0, input_ids.shape[1]:], skip_special_tokens=True
    ).strip()
    return outputs
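

# Minimal local usage sketch (not part of the original handlers). The checkpoint
# directory "./llava-v1.5-7b", the image URL, and the prompt below are illustrative
# placeholders; in the intended SageMaker-style flow, model_fn and predict_fn are
# invoked by the serving container rather than by hand. The prompt follows LLaVA's
# v1 conversation template with an <image> placeholder; adjust it to match the
# template the deployed model expects.
if __name__ == "__main__":
    model_artifacts = model_fn("./llava-v1.5-7b")
    result = predict_fn(
        {
            "image": "https://example.com/sample.jpg",
            "question": "USER: <image>\nDescribe this picture.\nASSISTANT:",
        },
        model_artifacts,
    )
    print(result)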