# Ovis2-16B / app.py
import logging
from threading import Thread
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, TextIteratorStreamer
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
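
# Load Ovis2-16B in bfloat16 on the GPU; trust_remote_code is required
# because the Ovis repo ships its own modeling code on the Hub.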
model_name = 'AIDC-AI/Ovis2-16B'
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
).to(device='cuda')
text_tokenizer = model.get_text_tokenizer()
visual_tokenizer = model.get_visual_tokenizer()
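
# A single TextIteratorStreamer is reused across turns; the demo serves
# requests sequentially, so sharing one instance is safe here.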
streamer = TextIteratorStreamer(
    text_tokenizer,
    skip_prompt=True,
    skip_special_tokens=True
)

IMAGE_PLACEHOLDER = "<image>"
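
# Deterministic decoding: do_sample=False, with the sampling knobs
# (top_p / top_k / temperature) explicitly disabled.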
def initialize_gen_kwargs():
    return {
        "max_new_tokens": 1536,
        "do_sample": False,
        "top_p": None,
        "top_k": None,
        "temperature": None,
        "repetition_penalty": 1.05,
        "eos_token_id": model.generation_config.eos_token_id,
        "pad_token_id": text_tokenizer.pad_token_id,
        "use_cache": True
    }
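
# submit_chat records the pending user turn; ovis_chat later streams the
# model reply into the same chat entry.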
def submit_chat(chatbot, text_input, image_input):
    if text_input.strip() or image_input is not None:
        # Append a list, not a tuple: ovis_chat mutates the reply in place.
        chatbot.append([text_input, ""])
    # Clear only the text box here: the image must remain set until the
    # chained ovis_chat call has read it, so it is cleared afterwards.
    return chatbot, ""
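
# ovis_chat is a Gradio generator: each yield pushes the partially streamed
# reply to the Chatbot component, so the answer appears incrementally.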
def ovis_chat(chatbot, text_input, image_input):
    if not chatbot:
        # Nothing was submitted (empty text and no image); keep the UI as-is.
        yield chatbot
        return
    # Rebuild the prior conversation as alternating human/gpt turns.
    conversations = []
    for q, r in chatbot[:-1]:
        conversations.append({"from": "human", "value": q})
        conversations.append({"from": "gpt", "value": r})
    # Last user query; prepend the image placeholder when an image is attached.
    last_query = chatbot[-1][0]
    if image_input is not None:
        last_query = f"{IMAGE_PLACEHOLDER}\n{last_query}"
    conversations.append({"from": "human", "value": last_query})
    # preprocess_inputs() takes the images as its second positional argument;
    # pass [image_input] as a list rather than image=image_input as a keyword.
    prompt, input_ids, pixel_values = model.preprocess_inputs(
        conversations,
        [image_input] if image_input is not None else None,
        max_partition=16
    )
    attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
    model_inputs = {
        "inputs": input_ids.unsqueeze(0).to(device='cuda'),
        "attention_mask": attention_mask.unsqueeze(0).to(device='cuda'),
        # Ovis expects pixel_values as a per-sample list; [None] means text-only.
        "pixel_values": (
            [pixel_values.to(dtype=visual_tokenizer.dtype, device='cuda')]
            if pixel_values is not None else
            [None]
        ),
    }
    gen_kwargs = initialize_gen_kwargs()

    def _generate():
        # inference_mode is thread-local, so apply it inside the worker thread.
        with torch.inference_mode():
            model.generate(**model_inputs, **gen_kwargs, streamer=streamer)

    # Run generation in a background thread so the streamer can yield partial
    # text; an inline generate() call would block until the response finished.
    thread = Thread(target=_generate)
    thread.start()
    response = ""
    for new_text in streamer:
        response += new_text
        chatbot[-1][1] = response
        yield chatbot
    thread.join()
logger.info("[OVIS_CONV_START]")
for i, (req, ans) in enumerate(chatbot, 1):
logger.info(f"Q{i}: {req}\nA{i}: {ans}")
logger.info("[OVIS_CONV_END]")

def clear_chat():
    return [], "", None
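
# Minimal UI: chat pane, prompt box, optional image, and Send/Clear buttons.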
with gr.Blocks(title="Ovis Demo", theme=gr.themes.Ocean()) as demo:
    chatbot = gr.Chatbot(label="Ovis", height=500, show_copy_button=True)
    text_input = gr.Textbox(label="Prompt", placeholder="Enter your question...", lines=1)
    image_input = gr.Image(label="Image (optional)", type="pil")
    with gr.Row():
        send_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.Button("Clear", variant="secondary")

    # submit_chat records the user turn and clears the text box; the chained
    # ovis_chat call streams the reply, then the image input is cleared.
    send_btn.click(
        fn=submit_chat,
        inputs=[chatbot, text_input, image_input],
        outputs=[chatbot, text_input]
    ).then(
        fn=ovis_chat,
        inputs=[chatbot, text_input, image_input],
        outputs=chatbot
    ).then(
        fn=lambda: None,
        outputs=image_input
    )
    text_input.submit(
        fn=submit_chat,
        inputs=[chatbot, text_input, image_input],
        outputs=[chatbot, text_input]
    ).then(
        fn=ovis_chat,
        inputs=[chatbot, text_input, image_input],
        outputs=chatbot
    ).then(
        fn=lambda: None,
        outputs=image_input
    )
    clear_btn.click(fn=clear_chat, outputs=[chatbot, text_input, image_input])

demo.launch()