# Qwen2-VL-2B / new-app.py
import gradio as gr
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
from qwen_vl_utils import process_vision_info
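# process_vision_info extracts the image/video entries referenced in a chat
# message list into the inputs the Qwen2-VL processor expects.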
import torch
import uuid
import io
from PIL import Image
from threading import Thread
# Define model options (for the OCR model specifically)
MODEL_OPTIONS = {
"Latex OCR": "prithivMLmods/Qwen2-VL-OCR-2B-Instruct",
}
# Preload models and processors into CUDA
models = {}
processors = {}
for name, model_id in MODEL_OPTIONS.items():
print(f"Loading {name}...")
models[name] = Qwen2VLForConditionalGeneration.from_pretrained(
model_id,
trust_remote_code=True,
torch_dtype=torch.float16
).to("cuda").eval()
processors[name] = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
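
# Image.registered_extensions() maps file suffixes Pillow knows to format names,
# e.g. {".png": "PNG", ...}; used below to recognize image paths by extension.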
image_extensions = Image.registered_extensions()

def identify_and_save_blob(blob_path):
    """Identifies if the blob is an image and saves it."""
    try:
        with open(blob_path, "rb") as file:
            blob_content = file.read()
        try:
            Image.open(io.BytesIO(blob_content)).verify()  # Check if it's a valid image
            extension = ".png"  # Default to PNG for saving
            media_type = "image"
        except (IOError, SyntaxError):
            raise ValueError("Unsupported media type. Please upload a valid image.")
        filename = f"temp_{uuid.uuid4()}_media{extension}"
        with open(filename, "wb") as f:
            f.write(blob_content)
        return filename, media_type
    except FileNotFoundError:
        raise ValueError(f"The file {blob_path} was not found.")
    except Exception as e:
        raise ValueError(f"An error occurred while processing the file: {e}")

def qwen_inference(model_name, media_input, text_input=None):
    """Handles inference for the selected model, streaming the decoded text."""
    model = models[model_name]
    processor = processors[model_name]

    if isinstance(media_input, str):
        media_path = media_input
        if media_path.endswith(tuple(image_extensions)):
            media_type = "image"
        else:
            try:
                media_path, media_type = identify_and_save_blob(media_input)
            except Exception as e:
                raise ValueError(f"Unsupported media type ({e}). Please upload a valid image.")
    else:
        raise ValueError("Expected a file path to the uploaded image.")
    messages = [
        {
            "role": "user",
            "content": [
                {"type": media_type, media_type: media_path},
                {"type": "text", "text": text_input},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, _ = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        padding=True,
        return_tensors="pt",
    ).to("cuda")
    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
buffer = ""
for new_text in streamer:
buffer += new_text
# Remove <|im_end|> or similar tokens from the output
buffer = buffer.replace("<|im_end|>", "")
yield buffer

def ocr_endpoint(image, question):
    """Exposed as the /ocr endpoint; streams OCR output for the uploaded image."""
    # "yield from" (rather than "return") keeps this a generator function so
    # Gradio streams partial results into the output Textbox.
    yield from qwen_inference("Latex OCR", image, question)

# Gradio app setup for OCR endpoint
with gr.Blocks() as demo:
    gr.Markdown("# Qwen2VL OCR Model - LaTeX OCR")
    with gr.Row():
        with gr.Column():
            input_media = gr.File(label="Upload Image", type="filepath")
            text_input = gr.Textbox(label="Question", placeholder="Ask a question about the image...")
            submit_btn = gr.Button(value="Submit")
        with gr.Column():
            output_text = gr.Textbox(label="Output Text", lines=10)
    submit_btn.click(
        ocr_endpoint, [input_media, text_input], [output_text], api_name="ocr"
    )
# Launch the app on the /ocr endpoint
demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=True)
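
# A minimal client-side sketch (an assumption, not part of this app: it relies
# on the gradio_client package, a locally reachable server, and a placeholder
# image file "equation.png"):
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://localhost:7860")
#   result = client.predict(
#       handle_file("equation.png"),        # File input
#       "Convert this equation to LaTeX.",  # Question textbox
#       api_name="/ocr",
#   )
#   print(result)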