from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
from PIL import Image
import torch
from threading import Thread
import gradio as gr
import spaces
import fitz  # PyMuPDF for PDF processing
from io import BytesIO
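
# Assumed environment: the Qwen2.5-VL classes need a recent transformers release
# (4.49+), and `fitz` must come from the PyMuPDF package, not the unrelated
# `fitz` project on PyPI.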

# Load model and processor
ckpt = "Qwen/Qwen2.5-VL-7B-Instruct"
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    ckpt,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # accelerate handles placement; a trailing .to("cuda") would conflict
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(ckpt, trust_remote_code=True)

def process_pdf(file_path):
    """Convert the first page of a PDF to a PIL Image."""
    pdf_doc = fitz.open(file_path)
    page = pdf_doc.load_page(0)
    pix = page.get_pixmap()
    img_bytes = pix.tobytes("ppm")
    image = Image.open(BytesIO(img_bytes)).convert("RGB")
    pdf_doc.close()
    return image
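
# Note: only page 0 is rendered. One option for multi-page support is to loop
# over range(pdf_doc.page_count) and return a list of images instead.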

@spaces.GPU  # needed on ZeroGPU Spaces; otherwise the `spaces` import goes unused
def bot_streaming(message, history, max_new_tokens=2048):
    txt = message["text"]
    images = []
    messages = []

    # Rebuild the chat history. Turns that carried files are assumed to arrive
    # as lists of the form [{"text": ...}, file, ...]; plain turns are (str, str).
    for user_msg, bot_msg in history:
        if isinstance(user_msg, list):  # turn contains files
            hist_images = []
            content = [{"type": "text", "text": user_msg[0]["text"]}]
            for file_info in user_msg[1:]:
                file_path = file_info["path"] if isinstance(file_info, dict) else file_info
                if file_path.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
                    img = Image.open(file_path).convert("RGB")
                    hist_images.append(img)
                    content.append({"type": "image"})
                elif file_path.lower().endswith('.pdf'):
                    try:
                        img = process_pdf(file_path)
                        hist_images.append(img)
                        content.append({"type": "image"})
                    except Exception as e:
                        print(f"Error processing PDF: {e}")
            images.extend(hist_images)
            messages.append({"role": "user", "content": content})
            messages.append({"role": "assistant", "content": [{"type": "text", "text": bot_msg}]})
        else:
            messages.extend([
                {"role": "user", "content": [{"type": "text", "text": user_msg}]},
                {"role": "assistant", "content": [{"type": "text", "text": bot_msg}]},
            ])

    # Process the current message
    current_images = []
    content = [{"type": "text", "text": txt}]
    if message["files"]:
        for file_info in message["files"]:
            file_path = file_info["path"] if isinstance(file_info, dict) else file_info
            try:
                if file_path.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp')):
                    img = Image.open(file_path).convert("RGB")
                    current_images.append(img)
                    content.append({"type": "image"})
                elif file_path.lower().endswith('.pdf'):
                    img = process_pdf(file_path)
                    current_images.append(img)
                    content.append({"type": "image"})
            except Exception as e:
                print(f"File processing error: {e}")
        images.extend(current_images)
        messages.append({"role": "user", "content": content})
    else:
        messages.append({"role": "user", "content": [{"type": "text", "text": txt}]})

    # Generate the response on a background thread and stream tokens back
    inputs = processor(
        text=processor.apply_chat_template(messages, add_generation_prompt=True),
        images=images if images else None,
        return_tensors="pt",
    ).to(model.device)  # follow device_map placement instead of hard-coding "cuda"
    # skip_prompt=True keeps the echoed prompt out of the streamed output
    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
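    thread.join()  # make sure generation has finished before the handler returns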

# Configure the Gradio interface. MultimodalTextbox takes file_count/file_types
# directly; there is no file_upload_kwargs argument.
textbox = gr.MultimodalTextbox(
    file_count="multiple",
    file_types=["image", ".pdf"],
    placeholder="Input message or upload files...",
    show_label=False,
)
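
# The textbox submits {"text": ..., "files": [...]}, which is the message dict
# bot_streaming reads; multimodal=True keeps ChatInterface in that format.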
demo = gr.ChatInterface(
    fn=bot_streaming,
    title="MultiFile AI Assistant",
    textbox=textbox,
    multimodal=True,
    additional_inputs=[
        gr.Slider(10, 4096, value=512, label="Max New Tokens")
    ],
    css=".gradio-container {background: #fafafa}",
)

if __name__ == "__main__":
    demo.launch(debug=True)