Imagechat / app.py
Daemontatox's picture
Update app.py
b62635c verified
raw
history blame
8.04 kB
import spaces
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
from PIL import Image
import torch
from threading import Thread
import gradio as gr
import fitz # PyMuPDF
import io
import logging
import time
import numpy as np
from threading import Thread
import gradio as gr
from gradio import FileData
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load model and processor.
# Both are created once, at import time, and used by bot_streaming for every
# request. The model weights are loaded in bfloat16 and moved to the "cuda"
# device immediately, so importing this module triggers the checkpoint
# download/load.
# NOTE(review): Qwen2_5_VLForConditionalGeneration is imported directly from
# transformers, so trust_remote_code=True is presumably unnecessary here —
# confirm before removing.
ckpt = "Qwen/Qwen2.5-VL-7B-Instruct"
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(ckpt, torch_dtype=torch.bfloat16, trust_remote_code=True).to("cuda")
processor = AutoProcessor.from_pretrained(ckpt, trust_remote_code=True)
class DocumentState:
    """Mutable holder for the document currently loaded into the app.

    The same three attributes are (re)initialised to their empty values on
    construction and on every ``clear()`` call.
    """

    def __init__(self):
        # Construction is just a reset to the empty state.
        self.clear()

    def clear(self):
        """Drop the loaded document: no images, no text, no type."""
        # PIL images: one per rendered PDF page, or the single uploaded image.
        self.current_doc_images = []
        # Extracted text; only PDFs populate this.
        self.current_doc_text = ""
        # 'pdf', 'image', or None when nothing is loaded.
        self.doc_type = None
# Module-level document store: written by process_uploaded_file and
# clear_context, read by bot_streaming.
# NOTE(review): this single instance is presumably shared by every concurrent
# user of the Gradio app, so one user's upload replaces another's document —
# consider per-session gr.State if multi-user use matters.
doc_state = DocumentState()
def process_pdf_file(file_path, zoom=3, max_size=1600):
    """Render every page of a PDF to an RGB image and extract its text.

    Pages that fail to render are logged and skipped; the rest are kept.

    Args:
        file_path: Path of the PDF, passed to ``fitz.open``.
        zoom: Scale factor applied to each page when rendering it
            (``fitz.Matrix(zoom, zoom)``), before any downscaling.
        max_size: Upper bound in pixels for the longest side of each page
            image; larger renders are downscaled with LANCZOS, keeping the
            aspect ratio.

    Returns:
        ``(images, text)`` where ``images`` is a list of RGB ``PIL.Image``
        objects, one per successfully rendered page, and ``text`` is the
        concatenated text of every page with non-blank text, each section
        prefixed with ``"Page N:"`` (1-based).

    Raises:
        ValueError: If no page could be rendered to an image.
        Exception: Anything raised while opening the document is logged and
            re-raised unchanged.
    """
    try:
        # The context manager closes the document on every exit path, not
        # only on the happy path.
        with fitz.open(file_path) as doc:
            images = []
            text_parts = []
            mat = fitz.Matrix(zoom, zoom)  # identical for every page
            for page_num in range(doc.page_count):
                try:
                    page = doc[page_num]
                    page_text = page.get_text("text")
                    if page_text.strip():
                        text_parts.append(f"Page {page_num + 1}:\n{page_text}\n\n")
                    pix = page.get_pixmap(matrix=mat, alpha=False)
                    img = Image.open(io.BytesIO(pix.tobytes("png"))).convert("RGB")
                    if max(img.size) > max_size:
                        ratio = max_size / max(img.size)
                        new_size = tuple(int(dim * ratio) for dim in img.size)
                        img = img.resize(new_size, Image.Resampling.LANCZOS)
                    images.append(img)
                except Exception:
                    # Best effort: a broken page is skipped, not fatal.
                    # 1-based to match the "Page N:" labels in the text.
                    logger.exception("Error processing page %d", page_num + 1)
                    continue
        if not images:
            raise ValueError("No valid images could be extracted from the PDF")
        return images, "".join(text_parts)
    except Exception as e:
        logger.error(f"Error processing PDF file: {str(e)}")
        raise
def process_uploaded_file(file):
    """Load an uploaded PDF or image into the module-level ``doc_state``.

    The previously loaded document is always discarded first. ``doc_state``
    is populated only after the new file has been processed successfully,
    so a failed upload leaves the state empty rather than half-filled.

    Args:
        file: Value delivered by the ``gr.File`` component. Accepted forms:
            - a path string or ``os.PathLike`` (what Gradio 4's default
              ``type="filepath"`` passes);
            - a dict with a ``"name"`` key;
            - an object exposing its path as ``.name`` (older Gradio
              tempfile wrappers).
            ``None`` means no file is present.

    Returns:
        A human-readable status message for the "Upload Status" textbox.
        Failures are reported through this message; nothing is raised.
    """
    import os  # local import keeps this block self-contained

    try:
        doc_state.clear()
        if file is None:
            return "No file uploaded. Please upload a file."

        # A plain path string has no ``.name`` attribute, so it must be
        # handled before the attribute fallback.
        if isinstance(file, (str, os.PathLike)):
            file_path = os.fspath(file)
        elif isinstance(file, dict):
            file_path = file["name"]
        else:
            file_path = file.name

        # splitext copes with dot-less names and dots in directory names,
        # which a bare split('.') would misreport as the extension.
        file_ext = os.path.splitext(file_path)[1].lower().lstrip(".")
        image_extensions = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'}

        if file_ext == 'pdf':
            try:
                images, text = process_pdf_file(file_path)
            except Exception as e:
                # process_pdf_file has already logged the failure.
                return f"Error processing PDF: {str(e)}. Please try a different PDF file."
            doc_state.doc_type = 'pdf'
            doc_state.current_doc_images = images
            doc_state.current_doc_text = text
            return f"PDF processed successfully. Total pages: {len(images)}. You can now ask questions about the content."

        if file_ext in image_extensions:
            try:
                # ``with`` closes the file handle Image.open keeps open lazily;
                # convert() returns an independent in-memory image.
                with Image.open(file_path) as src:
                    img = src.convert("RGB")
                max_size = 1600
                if max(img.size) > max_size:
                    ratio = max_size / max(img.size)
                    new_size = tuple(int(dim * ratio) for dim in img.size)
                    img = img.resize(new_size, Image.Resampling.LANCZOS)
            except Exception as e:
                logger.exception("Error processing image %s", file_path)
                return f"Error processing image: {str(e)}. Please try a different image file."
            doc_state.doc_type = 'image'
            doc_state.current_doc_images = [img]
            return "Image loaded successfully. You can now ask questions about the content."

        return f"Unsupported file type: {file_ext or 'unknown'}. Please upload a PDF or image file."
    except Exception as e:
        logger.exception("Error in process_uploaded_file: %s", e)
        return "An error occurred while processing the file. Please try again."
@spaces.GPU()
def bot_streaming(user_prompt, max_new_tokens=4096):
    """Stream the model's answer to ``user_prompt`` about the loaded document.

    When a document is loaded (``doc_state.current_doc_images`` is non-empty),
    its extracted text is appended to the prompt. Only the FIRST image (the
    first PDF page, or the uploaded image) is passed to the model. With no
    document loaded, the prompt is sent as plain text.

    Args:
        user_prompt: Text from the prompt textbox.
        max_new_tokens: Generation length cap forwarded to ``model.generate``.

    Yields:
        The cumulative generated text so far (each yield replaces the Output
        textbox contents), or a validation/error message.
    """
    try:
        # ``not user_prompt`` also covers None, which has no .strip().
        if not user_prompt or not user_prompt.strip():
            yield "Please enter a valid prompt/question."
            return

        messages = []
        # Include document context
        if doc_state.current_doc_images:
            context = f"\nDocument context:\n{doc_state.current_doc_text}" if doc_state.current_doc_text else ""
            current_msg = f"{user_prompt}{context}"
            # One image placeholder, matching the single image passed below.
            messages.append({"role": "user", "content": [{"type": "text", "text": current_msg}, {"type": "image"}]})
        else:
            messages.append({"role": "user", "content": [{"type": "text", "text": user_prompt}]})

        # Process inputs
        texts = processor.apply_chat_template(messages, add_generation_prompt=True)
        try:
            if doc_state.current_doc_images:
                # NOTE(review): for multi-page PDFs only page 1 is sent as an
                # image; the other pages contribute only their extracted text.
                inputs = processor(
                    text=texts,
                    images=doc_state.current_doc_images[0:1],
                    return_tensors="pt"
                ).to("cuda")
            else:
                inputs = processor(text=texts, return_tensors="pt").to("cuda")

            streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
            generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
            generation_failed = []

            def _generate():
                # Runs in the worker thread. If generate() raises, the streamer
                # never receives its end-of-stream signal and the consuming loop
                # below would block forever, so end the stream explicitly.
                try:
                    model.generate(**generation_kwargs)
                except Exception:
                    logger.exception("Error during model.generate")
                    generation_failed.append(True)
                    streamer.end()

            thread = Thread(target=_generate)
            thread.start()
            buffer = ""
            for new_text in streamer:
                buffer += new_text
                time.sleep(0.01)
                yield buffer
            # join() also makes generation_failed safe to read here.
            thread.join()
            if generation_failed:
                yield "An error occurred while processing your request. Please try again."
        except Exception as e:
            logger.exception("Error in model processing: %s", e)
            yield "An error occurred while processing your request. Please try again."
    except Exception as e:
        logger.exception("Error in bot_streaming: %s", e)
        yield "An error occurred. Please try again."
def clear_context():
    """Forget the loaded document and report that to the UI.

    Returns:
        The status message shown in the "Upload Status" textbox.
    """
    status = "Document context cleared. You can upload a new document."
    doc_state.clear()
    return status
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Document Analyzer with Custom Prompts")
    gr.Markdown("Upload a document and enter your custom prompt/question about its contents.")
    # Upload area: the file picker, plus the status box that displays the
    # messages returned by process_uploaded_file and clear_context.
    with gr.Row():
        file_upload = gr.File(
            label="Upload Document (PDF or Image)",
            # Same set process_uploaded_file accepts: pdf plus its image_extensions.
            file_types=[".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"]
        )
        upload_status = gr.Textbox(
            label="Upload Status",
            interactive=False
        )
    with gr.Row():
        user_prompt = gr.Textbox(
            label="Enter your prompt/question",
            placeholder="e.g., Explain this document...\nExtract key points...\nWhat is the main idea?",
            lines=3
        )
    generate_btn = gr.Button("Generate")
    clear_btn = gr.Button("Clear Document Context")
    # Read-only box that receives each successive yield of bot_streaming.
    output_text = gr.Textbox(
        label="Output",
        interactive=False
    )
    # Every file change re-runs process_uploaded_file. It presumably also fires
    # with None when the file is removed, a case the handler covers explicitly.
    file_upload.change(
        fn=process_uploaded_file,
        inputs=[file_upload],
        outputs=[upload_status]
    )
    # bot_streaming is a generator, so the output box updates as text streams in.
    generate_btn.click(
        fn=bot_streaming,
        inputs=[user_prompt],
        outputs=[output_text]
    )
    clear_btn.click(
        fn=clear_context,
        outputs=[upload_status]
    )
# Launch the interface
# NOTE(review): this runs at import time with debug=True, presumably a
# development setting — confirm it is wanted for the deployed Space.
demo.launch(debug=True)