# SageMaker-style inference handlers for LLaVA: model_fn loads the model
# once at container start-up, predict_fn answers each invocation.
import requests
import torch
from io import BytesIO
from PIL import Image
from transformers import AutoTokenizer, BitsAndBytesConfig

from llava.constants import IMAGE_TOKEN_INDEX
from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
from llava.model import LlavaLlamaForCausalLM
from llava.utils import disable_torch_init


def model_fn(model_dir):
    # Load the language model in fp16, sharded across available GPUs by
    # accelerate's device_map.
    kwargs = {"device_map": "auto", "torch_dtype": torch.float16}
    model = LlavaLlamaForCausalLM.from_pretrained(
        model_dir, low_cpu_mem_usage=True, **kwargs
    )
    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False)

    # The vision tower is loaded lazily; materialize it and move it to the
    # GPU in fp16 before serving requests.
    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model()
    vision_tower.to(device="cuda", dtype=torch.float16)
    image_processor = vision_tower.image_processor
    return model, tokenizer, image_processor
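

# Not part of the original handler: BitsAndBytesConfig is imported above but
# never used, which suggests quantized loading was contemplated. A minimal
# sketch of a 4-bit variant of model_fn, assuming bitsandbytes is installed
# in the serving container (hypothetical, untested on an endpoint):
def model_fn_4bit(model_dir):
    kwargs = {
        "device_map": "auto",
        "quantization_config": BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
        ),
    }
    model = LlavaLlamaForCausalLM.from_pretrained(
        model_dir, low_cpu_mem_usage=True, **kwargs
    )
    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False)
    # The vision tower is handled exactly as in model_fn above.
    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model()
    vision_tower.to(device="cuda", dtype=torch.float16)
    return model, tokenizer, vision_tower.image_processor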


def predict_fn(data, model_and_tokenizer):
    model, tokenizer, image_processor = model_and_tokenizer

    # Unpack the request payload; generation parameters fall back to
    # defaults when the client omits them.
    image_file = data.pop("image", data)
    prompt = data.pop("question", data)
    max_new_tokens = data.pop("max_new_tokens", 1024)
    temperature = data.pop("temperature", 0.2)
    stop_str = data.pop("stop_str", "###")

    # Accept either an image URL or a local file path. (The original check,
    # startswith("http") or startswith("https"), was redundant: any string
    # starting with "https" also starts with "http".)
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")

    # LLaVA utility that skips torch's redundant default weight
    # initialization.
    disable_torch_init()

    # Preprocess the image into pixel values and move them to the GPU
    # in fp16.
    image_tensor = (
        image_processor.preprocess(image, return_tensors="pt")["pixel_values"]
        .half()
        .cuda()
    )

    # Tokenize the prompt, splicing IMAGE_TOKEN_INDEX in place of the
    # "<image>" placeholder, and stop generation once stop_str appears.
    keywords = [stop_str]
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .cuda()
    )
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )

    # Decode only the newly generated tokens, dropping the prompt prefix.
    outputs = tokenizer.decode(
        output_ids[0, input_ids.shape[1]:], skip_special_tokens=True
    ).strip()
    return outputs
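

# A minimal local smoke test, not part of the deployed handler. The model
# directory and image URL are placeholders; note the prompt must contain the
# "<image>" placeholder so tokenizer_image_token can splice in
# IMAGE_TOKEN_INDEX.
if __name__ == "__main__":
    artifacts = model_fn("./llava-v1.5-7b")  # hypothetical local model dir
    payload = {
        "image": "https://example.com/cat.jpg",  # placeholder URL
        "question": "<image>\nWhat is shown in this picture?",
        "max_new_tokens": 256,
        "temperature": 0.2,
    }
    print(predict_fn(payload, artifacts))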