import logging
from threading import Thread

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, TextIteratorStreamer

logging.getLogger("httpx").setLevel(logging.WARNING)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

model_name = 'AIDC-AI/Ovis2-16B'
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
).to(device='cuda')
text_tokenizer = model.get_text_tokenizer()
visual_tokenizer = model.get_visual_tokenizer()

streamer = TextIteratorStreamer(
    text_tokenizer,
    skip_prompt=True,
    skip_special_tokens=True
)

# Ovis expects the literal <image> token in the prompt text wherever an image
# is attached; an empty placeholder would drop the image from the prompt.
IMAGE_PLACEHOLDER = "<image>"


def initialize_gen_kwargs():
    return {
        "max_new_tokens": 1536,
        "do_sample": False,
        "top_p": None,
        "top_k": None,
        "temperature": None,
        "repetition_penalty": 1.05,
        "eos_token_id": model.generation_config.eos_token_id,
        "pad_token_id": text_tokenizer.pad_token_id,
        "use_cache": True
    }


def submit_chat(chatbot, text_input, image_input):
    if text_input.strip() or image_input is not None:
        # Append a list, not a tuple, so the streamed answer can be assigned in place.
        chatbot.append([text_input, ""])
    return chatbot, ""


def ovis_chat(chatbot, text_input, image_input):
    # Nothing to do if submit_chat rejected an empty submission.
    if not chatbot:
        return

    # Rebuild the earlier turns (question/answer pairs) as an Ovis conversation.
    conversations = []
    for q, r in chatbot[:-1]:
        conversations.append({"from": "human", "value": q})
        conversations.append({"from": "gpt", "value": r})

    # Last user query; prepend the image placeholder when an image is attached.
    last_query = chatbot[-1][0]
    if image_input is not None:
        last_query = f"{IMAGE_PLACEHOLDER}\n{last_query}"
    conversations.append({"from": "human", "value": last_query})

    # preprocess_inputs() takes the images as its second positional argument:
    # pass [image_input] as a list, not image=image_input as a keyword.
    prompt, input_ids, pixel_values = model.preprocess_inputs(
        conversations,
        [image_input] if image_input is not None else None,
        max_partition=16
    )
    attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)

    model_inputs = {
        "inputs": input_ids.unsqueeze(0).to(device='cuda'),
        "attention_mask": attention_mask.unsqueeze(0).to(device='cuda'),
        "pixel_values": (
            [pixel_values.to(dtype=visual_tokenizer.dtype, device='cuda')]
            if pixel_values is not None else [None]
        ),
    }

    gen_kwargs = initialize_gen_kwargs()

    # TextIteratorStreamer only streams while generate() runs in another thread;
    # calling it inline would block until generation finishes. inference_mode is
    # thread-local, so it has to be entered inside the worker thread.
    def generate():
        with torch.inference_mode():
            model.generate(**model_inputs, **gen_kwargs, streamer=streamer)

    thread = Thread(target=generate)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        chatbot[-1][1] = response
        yield chatbot
    thread.join()

    logger.info("[OVIS_CONV_START]")
    for i, (req, ans) in enumerate(chatbot, 1):
        logger.info(f"Q{i}: {req}\nA{i}: {ans}")
    logger.info("[OVIS_CONV_END]")


def clear_chat():
    return [], "", None


with gr.Blocks(title="Ovis Demo", theme=gr.themes.Ocean()) as demo:
    chatbot = gr.Chatbot(label="Ovis", height=500, show_copy_button=True)
    text_input = gr.Textbox(label="Prompt", placeholder="Type your question...", lines=1)
    image_input = gr.Image(label="Image (optional)", type="pil")
    with gr.Row():
        send_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.Button("Clear", variant="secondary")

    # submit_chat clears only the text box. A .then event reads the *updated*
    # component values, so the image must survive until ovis_chat has consumed
    # it; it is reset in a final chained step instead.
    send_btn.click(
        fn=submit_chat,
        inputs=[chatbot, text_input, image_input],
        outputs=[chatbot, text_input]
    ).then(
        fn=ovis_chat,
        inputs=[chatbot, text_input, image_input],
        outputs=chatbot
    ).then(
        fn=lambda: None,
        outputs=image_input
    )
    text_input.submit(
        fn=submit_chat,
        inputs=[chatbot, text_input, image_input],
        outputs=[chatbot, text_input]
    ).then(
        fn=ovis_chat,
        inputs=[chatbot, text_input, image_input],
        outputs=chatbot
    ).then(
        fn=lambda: None,
        outputs=image_input
    )
    clear_btn.click(fn=clear_chat, outputs=[chatbot, text_input, image_input])

demo.launch()
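
# ---------------------------------------------------------------------------
# Usage sketch (commented out; not part of the app). A minimal one-shot,
# non-streaming call against the same preprocess_inputs()/generate() API used
# above, handy for smoke-testing the model outside Gradio. "example.jpg" is a
# hypothetical local file, and the decode step assumes generate() returns only
# the newly generated token ids, as in the Ovis model card examples.
#
# from PIL import Image
#
# img = Image.open("example.jpg")
# conv = [{"from": "human", "value": f"{IMAGE_PLACEHOLDER}\nDescribe this image."}]
# _, ids, px = model.preprocess_inputs(conv, [img], max_partition=16)
# mask = torch.ne(ids, text_tokenizer.pad_token_id)
# with torch.inference_mode():
#     out = model.generate(
#         inputs=ids.unsqueeze(0).to(device='cuda'),
#         attention_mask=mask.unsqueeze(0).to(device='cuda'),
#         pixel_values=[px.to(dtype=visual_tokenizer.dtype, device='cuda')],
#         **initialize_gen_kwargs(),
#     )[0]
# print(text_tokenizer.decode(out, skip_special_tokens=True))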