import time
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoProcessor

model_id = "microsoft/Phi-3.5-vision-instruct"

# Load the model on CPU. _attn_implementation="eager" avoids the optional
# flash_attn dependency (flash attention is unavailable on CPU anyway).
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cpu",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    _attn_implementation="eager",
)
device = torch.device("cpu")
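
# Hypothetical GPU variant (requires flash_attn installed; not used in this
# CPU-only Space):
# model = AutoModelForCausalLM.from_pretrained(
#     model_id,
#     device_map="cuda",
#     trust_remote_code=True,
#     torch_dtype=torch.bfloat16,
#     _attn_implementation="flash_attention_2",
# )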

# For best performance, use num_crops=4 for multi-frame input and
# num_crops=16 for single-frame input.
processor = AutoProcessor.from_pretrained(
    model_id,
    trust_remote_code=True,
    num_crops=4,
)

# Chat-template markers used by Phi-3.5-vision-instruct.
user_prompt = "<|user|>\n"
assistant_prompt = "<|assistant|>\n"
prompt_suffix = "<|end|>\n"

def call_model(raw_image=None, text_input=None):
    # Build the single-image chat prompt; the processor expands <|image_1|>
    # into the corresponding image tokens.
    prompt = f"{user_prompt}<|image_1|>\n{text_input}{prompt_suffix}{assistant_prompt}"
    image = raw_image.convert("RGB")
    inputs = processor(prompt, image, return_tensors="pt").to(device)
    generate_ids = model.generate(
        **inputs,
        max_new_tokens=1000,
        eos_token_id=processor.tokenizer.eos_token_id,
    )
    # Drop the prompt tokens so only the newly generated answer is decoded.
    generate_ids = generate_ids[:, inputs["input_ids"].shape[1]:]
    response = processor.batch_decode(
        generate_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    return response
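
# Quick local sanity check (hypothetical file name, not part of the app):
#   from PIL import Image
#   print(call_model(raw_image=Image.open("example.jpg"), text_input="Describe this image."))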

def get_model_memory_footprint(model_):
    # get_memory_footprint() reports the size of parameters and buffers in bytes.
    footprint = model_.get_memory_footprint()
    return f"Footprint of the model: {footprint / 1e6:.2f} MB"

def process(raw_image, prompt):
    print("start...")
    start_time = time.time()
    memory_usage = get_model_memory_footprint(model)
    model_response = call_model(raw_image=raw_image, text_input=prompt)
    execution_time = time.time() - start_time
    execution_time_min = round(execution_time / 60, 2)
    print(f"Execution time: {execution_time:.4f} seconds")
    print(f"Execution time: {execution_time_min:.2f} min")
    return memory_usage, model_response, execution_time_min

iface = gr.Interface(
    process,
    inputs=[gr.Image(type="pil"), gr.Textbox(label="What do you want to ask?")],
    outputs=[
        gr.Textbox(label="Memory usage"),
        gr.Textbox(label="Model response"),
        gr.Textbox(label="Execution time (min)"),
    ],
)

if __name__ == "__main__":
    iface.launch()
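    # Optional alternative: iface.launch(share=True) serves the app through a
    # temporary public gradio.live URL.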