|
import os
import subprocess

# Install FlashAttention without compiling its CUDA kernels from source:
# FLASH_ATTENTION_SKIP_CUDA_BUILD makes the package's setup skip the slow
# build step. The current environment is merged in so `pip` stays on PATH
# (passing env= to subprocess.run replaces the environment wholesale).
subprocess.run(
    'pip install flash-attn==2.7.0.post2 --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': 'TRUE'},
    shell=True,
)

import spaces  # ZeroGPU support; imported before torch initializes CUDA

import re
from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, TextIteratorStreamer
|
|
|
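# Hugging Face Hub repo of the multimodal model served by this demo.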
model_name = 'AIDC-AI/Ovis2-16B' |
|
# Load Ovis2 in bfloat16 on the GPU; trust_remote_code pulls in the model's
# custom multimodal implementation from the Hub.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    multimodal_max_length=8192,
    trust_remote_code=True,
).to(device='cuda')
text_tokenizer = model.get_text_tokenizer()
visual_tokenizer = model.get_visual_tokenizer()
# Streams decoded tokens as they are generated; the prompt and special tokens
# are skipped so only the model's reply reaches the UI.
streamer = TextIteratorStreamer(text_tokenizer, skip_prompt=True, skip_special_tokens=True)
image_placeholder = '<image>'
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
def submit_chat(chatbot, text_input):
    # Append the new user turn with an empty reply slot (a list, not a tuple,
    # so ovis_chat can fill the reply in place) and clear the textbox.
    chatbot.append([text_input, ''])
    return chatbot, ''
|
|
|
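# @spaces.GPU requests a ZeroGPU device for the duration of each call.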
@spaces.GPU
def ovis_chat(chatbot, image_input):
    # Rebuild the conversation in Ovis' expected format: a system turn
    # followed by alternating human/gpt turns.
    conversations = [{
        "from": "system",
        "value": "You are a helpful assistant, and your task is to provide reliable and structured responses to users."
    }]
    response = ""
    text_input = chatbot[-1][0]
    for query, response in chatbot[:-1]:
        conversations.append({
            "from": "human",
            "value": query
        })
        conversations.append({
            "from": "gpt",
            "value": response
        })
    # Strip any literal '<image>' tokens typed by the user; the placeholder
    # is inserted programmatically below.
    text_input = text_input.replace(image_placeholder, '')
    conversations.append({
        "from": "human",
        "value": text_input
    })
    if image_input is not None:
        # Prepend the placeholder to the first human turn so the visual
        # tokens are spliced in at that position.
        conversations[1]["value"] = image_placeholder + '\n' + conversations[1]["value"]
    prompt, input_ids, pixel_values = model.preprocess_inputs(conversations, [image_input])
    attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
    input_ids = input_ids.unsqueeze(0).to(device=model.device)
    attention_mask = attention_mask.unsqueeze(0).to(device=model.device)
    if image_input is None:
        pixel_values = [None]
    else:
        pixel_values = [pixel_values.to(dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)]
|
|
|
    with torch.inference_mode():
        gen_kwargs = dict(
            max_new_tokens=1536,
            do_sample=False,
            top_p=None,
            top_k=None,
            temperature=None,
            repetition_penalty=None,
            eos_token_id=model.generation_config.eos_token_id,
            pad_token_id=text_tokenizer.pad_token_id,
            use_cache=True
        )
        response = ""
        # Run generation on a background thread so the streamer can hand
        # tokens to the UI while the model is still decoding; a blocking
        # call would only update the chat after the whole reply finished.
        # (The Thread import above exists for exactly this pattern.)
        thread = Thread(
            target=model.generate,
            args=(input_ids,),
            kwargs=dict(
                pixel_values=pixel_values,
                attention_mask=attention_mask,
                streamer=streamer,
                **gen_kwargs,
            ),
        )
        thread.start()
        for new_text in streamer:
            response += new_text
            chatbot[-1][1] = response
            yield chatbot
        thread.join()
|
    # Log the completed exchange to stdout for debugging.
    print('*' * 60)
    print('*' * 60)
    print('OVIS_CONV_START')
    for i, (request, answer) in enumerate(chatbot[:-1], 1):
        print(f'Q{i}:\n {request}')
        print(f'A{i}:\n {answer}')
    print('New_Q:\n', text_input)
    print('New_A:\n', response)
    print('OVIS_CONV_END')
|
|
|
def clear_chat():
    return [], None, ""
|
|
|
# Inline the logo SVG next to the model name in the page header, sizing it to
# match the heading's font size.
with open(f"{cur_dir}/resource/logo.svg", "r", encoding="utf-8") as svg_file:
    svg_content = svg_file.read()
font_size = "2.5em"
svg_content = re.sub(r'(<svg[^>]*)(>)', rf'\1 height="{font_size}" style="vertical-align: middle; display: inline-block;"\2', svg_content)
html = f"""
<p align="center" style="font-size: {font_size}; line-height: 1;">
    <span style="display: inline-block; vertical-align: middle;">{svg_content}</span>
    <span style="display: inline-block; vertical-align: middle;">{model_name.split('/')[-1]}</span>
</p>
<center><font size=3><b>Ovis</b> has been open-sourced on <a href='https://huggingface.co/{model_name}'>😊 Huggingface</a> and <a href='https://github.com/AIDC-AI/Ovis'>🌟 GitHub</a>. If you find Ovis useful, a like ❤️ or a star 🌟 would be appreciated.</font></center>
"""
|
|
|
# Delimiters the Chatbot should render as LaTeX: \( \) is inline math; the
# environments and \[ \] are display math.
latex_delimiters_set = [
    {"left": "\\(", "right": "\\)", "display": False},
    {"left": "\\begin{equation}", "right": "\\end{equation}", "display": True},
    {"left": "\\begin{align}", "right": "\\end{align}", "display": True},
    {"left": "\\begin{alignat}", "right": "\\end{alignat}", "display": True},
    {"left": "\\begin{gather}", "right": "\\end{gather}", "display": True},
    {"left": "\\begin{CD}", "right": "\\end{CD}", "display": True},
    {"left": "\\[", "right": "\\]", "display": True},
]
|
|
|
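# The textbox is created detached so gr.Examples can reference it; it is
# rendered into the layout later via text_input.render().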
text_input = gr.Textbox(label="prompt", placeholder="Enter your text here...", lines=1, container=False) |
|
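# Two-column layout: image upload and examples on the left, chat on the right.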
with gr.Blocks(title=model_name.split('/')[-1], theme=gr.themes.Ocean()) as demo:
    gr.HTML(html)
    with gr.Row():
        with gr.Column(scale=3):
            image_input = gr.Image(label="image", height=350, type="pil")
            gr.Examples(
                examples=[
                    [f"{cur_dir}/examples/case0.png", "Find the area of the shaded region."],
                    [f"{cur_dir}/examples/case1.png", "Explain this model to me."],
                    [f"{cur_dir}/examples/case2.png", "What is net profit margin as a percentage of total revenue?"],
                ],
                inputs=[image_input, text_input]
            )
        with gr.Column(scale=7):
            chatbot = gr.Chatbot(label="Ovis", layout="panel", height=600, show_copy_button=True, latex_delimiters=latex_delimiters_set)
            text_input.render()
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")
|
|
|
    # Record the user turn first, then stream the model's reply into it.
    send_click_event = send_btn.click(
        submit_chat, [chatbot, text_input], [chatbot, text_input]
    ).then(ovis_chat, [chatbot, image_input], chatbot)
    submit_event = text_input.submit(
        submit_chat, [chatbot, text_input], [chatbot, text_input]
    ).then(ovis_chat, [chatbot, image_input], chatbot)
    clear_btn.click(clear_chat, outputs=[chatbot, image_input, text_input])
|
|
|
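# Start the Gradio server.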
demo.launch() |
|
|