import spaces
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
from PIL import Image
import torch
from threading import Thread
import gradio as gr
import fitz  # PyMuPDF
import io
import logging
import time
import numpy as np
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load model and processor
ckpt = "Qwen/Qwen2.5-VL-7B-Instruct"
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    ckpt, torch_dtype=torch.bfloat16, trust_remote_code=True
).to("cuda")
processor = AutoProcessor.from_pretrained(ckpt, trust_remote_code=True)
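# Note (assumption, not part of the original Space): if GPU memory is tight, the model
# could instead be loaded with transformers' device_map="auto" (requires accelerate),
# e.g. from_pretrained(ckpt, torch_dtype=torch.bfloat16, device_map="auto"),
# rather than an explicit .to("cuda").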
class DocumentState:
    def __init__(self):
        self.current_doc_images = []
        self.current_doc_text = ""
        self.doc_type = None

    def clear(self):
        self.current_doc_images = []
        self.current_doc_text = ""
        self.doc_type = None

doc_state = DocumentState()
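# doc_state is a module-level singleton, so the uploaded document is shared by every
# concurrent session of this demo; a per-session alternative (not used here) would be
# Gradio's gr.State.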
def process_pdf_file(file_path):
    """Convert PDF pages to images and extract text using PyMuPDF."""
    try:
        doc = fitz.open(file_path)
        images = []
        text = ""
        for page_num in range(doc.page_count):
            try:
                page = doc[page_num]
                page_text = page.get_text("text")
                if page_text.strip():
                    text += f"Page {page_num + 1}:\n{page_text}\n\n"
                # Render the page at 3x zoom so small text stays legible to the VLM
                zoom = 3
                mat = fitz.Matrix(zoom, zoom)
                pix = page.get_pixmap(matrix=mat, alpha=False)
                img_data = pix.tobytes("png")
                img = Image.open(io.BytesIO(img_data))
                img = img.convert("RGB")
                # Downscale very large renders to keep processor inputs manageable
                max_size = 1600
                if max(img.size) > max_size:
                    ratio = max_size / max(img.size)
                    new_size = tuple(int(dim * ratio) for dim in img.size)
                    img = img.resize(new_size, Image.Resampling.LANCZOS)
                images.append(img)
            except Exception as e:
                logger.error(f"Error processing page {page_num}: {str(e)}")
                continue
        doc.close()
        if not images:
            raise ValueError("No valid images could be extracted from the PDF")
        return images, text
    except Exception as e:
        logger.error(f"Error processing PDF file: {str(e)}")
        raise
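# Usage sketch (hypothetical local file, not part of the app flow):
#   pages, extracted_text = process_pdf_file("example.pdf")
#   print(f"Rendered {len(pages)} page image(s); extracted {len(extracted_text)} characters of text")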
def process_uploaded_file(file):
    """Process an uploaded file and update the document state."""
    try:
        doc_state.clear()
        if file is None:
            return "No file uploaded. Please upload a file."
        if isinstance(file, dict):
            file_path = file["name"]
        else:
            file_path = file.name
        file_ext = file_path.lower().split('.')[-1]
        image_extensions = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp'}
        if file_ext == 'pdf':
            doc_state.doc_type = 'pdf'
            try:
                doc_state.current_doc_images, doc_state.current_doc_text = process_pdf_file(file_path)
                return f"PDF processed successfully. Total pages: {len(doc_state.current_doc_images)}. You can now ask questions about the content."
            except Exception as e:
                return f"Error processing PDF: {str(e)}. Please try a different PDF file."
        elif file_ext in image_extensions:
            doc_state.doc_type = 'image'
            try:
                img = Image.open(file_path).convert("RGB")
                max_size = 1600
                if max(img.size) > max_size:
                    ratio = max_size / max(img.size)
                    new_size = tuple(int(dim * ratio) for dim in img.size)
                    img = img.resize(new_size, Image.Resampling.LANCZOS)
                doc_state.current_doc_images = [img]
                return "Image loaded successfully. You can now ask questions about the content."
            except Exception as e:
                return f"Error processing image: {str(e)}. Please try a different image file."
        else:
            return f"Unsupported file type: {file_ext}. Please upload a PDF or image file."
    except Exception as e:
        logger.error(f"Error in process_uploaded_file: {str(e)}")
        return "An error occurred while processing the file. Please try again."
@spaces.GPU  # allocate a GPU for this call on ZeroGPU Spaces; a no-op elsewhere
def bot_streaming(user_prompt, max_new_tokens=4096):
    try:
        if not user_prompt.strip():
            yield "Please enter a valid prompt/question."
            return

        messages = []
        # Include document context
        if doc_state.current_doc_images:
            context = f"\nDocument context:\n{doc_state.current_doc_text}" if doc_state.current_doc_text else ""
            current_msg = f"{user_prompt}{context}"
            messages.append({"role": "user", "content": [{"type": "text", "text": current_msg}, {"type": "image"}]})
        else:
            messages.append({"role": "user", "content": [{"type": "text", "text": user_prompt}]})

        # Process inputs
        texts = processor.apply_chat_template(messages, add_generation_prompt=True)
        try:
            if doc_state.current_doc_images:
                # Only the first page image is passed to the model; the remaining pages
                # are represented through the extracted text context above.
                inputs = processor(
                    text=texts,
                    images=doc_state.current_doc_images[0:1],
                    return_tensors="pt"
                ).to("cuda")
            else:
                inputs = processor(text=texts, return_tensors="pt").to("cuda")

            streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
            generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)

            # Run generation in a background thread and stream tokens as they arrive
            thread = Thread(target=model.generate, kwargs=generation_kwargs)
            thread.start()

            buffer = ""
            for new_text in streamer:
                buffer += new_text
                time.sleep(0.01)
                yield buffer
        except Exception as e:
            logger.error(f"Error in model processing: {str(e)}")
            yield "An error occurred while processing your request. Please try again."
    except Exception as e:
        logger.error(f"Error in bot_streaming: {str(e)}")
        yield "An error occurred. Please try again."
def clear_context():
    """Clear the current document context."""
    doc_state.clear()
    return "Document context cleared. You can upload a new document."
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Document Analyzer with Custom Prompts")
    gr.Markdown("Upload a document and enter your custom prompt/question about its contents.")

    with gr.Row():
        file_upload = gr.File(
            label="Upload Document (PDF or Image)",
            file_types=[".pdf", ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"]
        )
        upload_status = gr.Textbox(
            label="Upload Status",
            interactive=False
        )

    with gr.Row():
        user_prompt = gr.Textbox(
            label="Enter your prompt/question",
            placeholder="e.g., Explain this document...\nExtract key points...\nWhat is the main idea?",
            lines=3
        )

    generate_btn = gr.Button("Generate")
    clear_btn = gr.Button("Clear Document Context")
    output_text = gr.Textbox(
        label="Output",
        interactive=False
    )

    file_upload.change(
        fn=process_uploaded_file,
        inputs=[file_upload],
        outputs=[upload_status]
    )
    generate_btn.click(
        fn=bot_streaming,
        inputs=[user_prompt],
        outputs=[output_text]
    )
    clear_btn.click(
        fn=clear_context,
        outputs=[upload_status]
    )

# Launch the interface
demo.launch(debug=True)
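# Outside Hugging Face Spaces (assumption), the app could instead be exposed on the local
# network with demo.launch(server_name="0.0.0.0", server_port=7860).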