import logging
from threading import Thread

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, TextIteratorStreamer
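
# Silence httpx's per-request logs (Gradio uses httpx internally) while
# keeping this app's own logs at INFO.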
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
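
# Load Ovis2-16B in bfloat16 on the GPU. trust_remote_code is required because
# Ovis ships custom modeling code on the Hugging Face Hub.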
model_name = 'AIDC-AI/Ovis2-16B'
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
).to(device='cuda')
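
# Ovis exposes separate tokenizers for text and images. The TextIteratorStreamer
# turns generate() output into an iterable of decoded text chunks, skipping the
# prompt and special tokens. It is a single module-level instance, so this demo
# assumes one generation runs at a time.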
text_tokenizer = model.get_text_tokenizer()
visual_tokenizer = model.get_visual_tokenizer()
streamer = TextIteratorStreamer(
    text_tokenizer,
    skip_prompt=True,
    skip_special_tokens=True
)
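
# Marker prepended to a query so preprocess_inputs knows where to splice the
# image tokens into the prompt.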
IMAGE_PLACEHOLDER = "<image>"
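
# Deterministic decoding: sampling is off, so top_p/top_k/temperature are
# explicitly None, and a mild repetition penalty discourages loops. Built per
# call so the eos/pad ids come from the loaded model and tokenizer.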
def initialize_gen_kwargs():
    return {
        "max_new_tokens": 1536,
        "do_sample": False,
        "top_p": None,
        "top_k": None,
        "temperature": None,
        "repetition_penalty": 1.05,
        "eos_token_id": model.generation_config.eos_token_id,
        "pad_token_id": text_tokenizer.pad_token_id,
        "use_cache": True
    }
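
# Stage 1 of a send: record the user turn and clear the textbox. The new entry
# is a list, not a tuple, so ovis_chat can overwrite its response element in
# place while streaming. The image component is left untouched here so that
# ovis_chat, which runs next in the event chain, can still read it.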
def submit_chat(chatbot, text_input, image_input):
    if text_input.strip() or image_input is not None:
        chatbot.append([text_input, ""])
    return chatbot, ""
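
# Stage 2 of a send: rebuild the history in Ovis's conversation format, run
# generation in a worker thread, and stream partial responses back to the UI.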
def ovis_chat(chatbot, text_input, image_input):
    # Guard against an empty send: submit_chat appends nothing in that case.
    if not chatbot:
        yield chatbot
        return

    # Replay every completed turn in the {"from": ..., "value": ...} format
    # expected by preprocess_inputs.
    conversations = []
    for q, r in chatbot[:-1]:
        conversations.append({"from": "human", "value": q})
        conversations.append({"from": "gpt", "value": r})

    # Prefix the latest query with the image placeholder when an image is attached.
    last_query = chatbot[-1][0]
    if image_input is not None:
        last_query = f"{IMAGE_PLACEHOLDER}\n{last_query}"
    conversations.append({"from": "human", "value": last_query})
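
    # preprocess_inputs tokenizes the conversation and splits the image into at
    # most max_partition tiles; pixel_values is None for text-only turns.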
    prompt, input_ids, pixel_values = model.preprocess_inputs(
        conversations,
        [image_input] if image_input is not None else None,
        max_partition=16
    )
    attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
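
    # Batch of one: unsqueeze ids and mask, and pass pixel_values as a
    # one-element list ([None] when there is no image).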
    model_inputs = {
        "inputs": input_ids.unsqueeze(0).to(device='cuda'),
        "attention_mask": attention_mask.unsqueeze(0).to(device='cuda'),
        "pixel_values": (
            [pixel_values.to(dtype=visual_tokenizer.dtype, device='cuda')]
            if pixel_values is not None else
            [None]
        ),
    }
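
    # generate() blocks until decoding finishes, so run it in a worker thread
    # and consume the streamer concurrently: each decoded chunk is appended to
    # the last chat entry and yielded so Gradio re-renders incrementally.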
    gen_kwargs = initialize_gen_kwargs()
    with torch.inference_mode():
        thread = Thread(
            target=model.generate,
            kwargs={**model_inputs, **gen_kwargs, "streamer": streamer}
        )
        thread.start()
        response = ""
        for new_text in streamer:
            response += new_text
            chatbot[-1][1] = response
            yield chatbot
        thread.join()
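
    # Once streaming completes, log the whole transcript.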
    logger.info("[OVIS_CONV_START]")
    for i, (req, ans) in enumerate(chatbot, 1):
        logger.info(f"Q{i}: {req}\nA{i}: {ans}")
    logger.info("[OVIS_CONV_END]")
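
# Clear resets the conversation, the textbox, and the attached image.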
def clear_chat():
    return [], "", None
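
# UI: a chat pane with copy buttons, a one-line prompt box, an optional PIL
# image, and Send/Clear controls.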
with gr.Blocks(title="Ovis Demo", theme=gr.themes.Ocean()) as demo:
    chatbot = gr.Chatbot(label="Ovis", height=500, show_copy_button=True)
    text_input = gr.Textbox(label="Prompt", placeholder="Enter your question...", lines=1)
    image_input = gr.Image(label="Image (optional)", type="pil")

    with gr.Row():
        send_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.Button("Clear", variant="secondary")
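
    # Both send paths (button click and Enter in the textbox) chain the same
    # two steps: submit_chat records the turn, then ovis_chat streams the
    # reply. Only the textbox is cleared here; clearing the image too would
    # hand ovis_chat a None image in the chained call.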
    send_btn.click(
        fn=submit_chat,
        inputs=[chatbot, text_input, image_input],
        outputs=[chatbot, text_input]
    ).then(
        fn=ovis_chat,
        inputs=[chatbot, text_input, image_input],
        outputs=chatbot
    )

    text_input.submit(
        fn=submit_chat,
        inputs=[chatbot, text_input, image_input],
        outputs=[chatbot, text_input]
    ).then(
        fn=ovis_chat,
        inputs=[chatbot, text_input, image_input],
        outputs=chatbot
    )

    clear_btn.click(fn=clear_chat, outputs=[chatbot, text_input, image_input])

demo.launch()