vykanand commited on
Commit
bd6f71c
Β·
verified Β·
1 Parent(s): 73d58c2

Create new-app.py

Browse files
Files changed (1) hide show
  1. new-app.py +128 -0
new-app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
3
+ from qwen_vl_utils import process_vision_info
4
+ import torch
5
+ import uuid
6
+ import io
7
+ from PIL import Image
8
+ from threading import Thread
9
+
10
+ # Define model options (for the OCR model specifically)
11
+ MODEL_OPTIONS = {
12
+ "Latex OCR": "prithivMLmods/Qwen2-VL-OCR-2B-Instruct",
13
+ }
14
+
15
+ # Preload models and processors into CUDA
16
+ models = {}
17
+ processors = {}
18
+ for name, model_id in MODEL_OPTIONS.items():
19
+ print(f"Loading {name}...")
20
+ models[name] = Qwen2VLForConditionalGeneration.from_pretrained(
21
+ model_id,
22
+ trust_remote_code=True,
23
+ torch_dtype=torch.float16
24
+ ).to("cuda").eval()
25
+ processors[name] = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
26
+
27
+ image_extensions = Image.registered_extensions()
28
+
29
+ def identify_and_save_blob(blob_path):
30
+ """Identifies if the blob is an image and saves it."""
31
+ try:
32
+ with open(blob_path, 'rb') as file:
33
+ blob_content = file.read()
34
+ try:
35
+ Image.open(io.BytesIO(blob_content)).verify() # Check if it's a valid image
36
+ extension = ".png" # Default to PNG for saving
37
+ media_type = "image"
38
+ except (IOError, SyntaxError):
39
+ raise ValueError("Unsupported media type. Please upload a valid image.")
40
+
41
+ filename = f"temp_{uuid.uuid4()}_media{extension}"
42
+ with open(filename, "wb") as f:
43
+ f.write(blob_content)
44
+
45
+ return filename, media_type
46
+
47
+ except FileNotFoundError:
48
+ raise ValueError(f"The file {blob_path} was not found.")
49
+ except Exception as e:
50
+ raise ValueError(f"An error occurred while processing the file: {e}")
51
+
52
+ def qwen_inference(model_name, media_input, text_input=None):
53
+ """Handles inference for the selected model."""
54
+ model = models[model_name]
55
+ processor = processors[model_name]
56
+
57
+ if isinstance(media_input, str):
58
+ media_path = media_input
59
+ if media_path.endswith(tuple([i for i in image_extensions.keys()])):
60
+ media_type = "image"
61
+ else:
62
+ try:
63
+ media_path, media_type = identify_and_save_blob(media_input)
64
+ except Exception as e:
65
+ raise ValueError("Unsupported media type. Please upload a valid image.")
66
+
67
+ messages = [
68
+ {
69
+ "role": "user",
70
+ "content": [
71
+ {
72
+ "type": media_type,
73
+ media_type: media_path
74
+ },
75
+ {"type": "text", "text": text_input},
76
+ ],
77
+ }
78
+ ]
79
+
80
+ text = processor.apply_chat_template(
81
+ messages, tokenize=False, add_generation_prompt=True
82
+ )
83
+ image_inputs, _ = process_vision_info(messages)
84
+ inputs = processor(
85
+ text=[text],
86
+ images=image_inputs,
87
+ padding=True,
88
+ return_tensors="pt",
89
+ ).to("cuda")
90
+
91
+ streamer = TextIteratorStreamer(
92
+ processor.tokenizer, skip_prompt=True, skip_special_tokens=True
93
+ )
94
+ generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
95
+
96
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
97
+ thread.start()
98
+
99
+ buffer = ""
100
+ for new_text in streamer:
101
+ buffer += new_text
102
+ # Remove <|im_end|> or similar tokens from the output
103
+ buffer = buffer.replace("<|im_end|>", "")
104
+ yield buffer
105
+
106
+ def ocr_endpoint(image, question):
107
+ """This function will be exposed to the /ocr endpoint for OCR processing."""
108
+ return qwen_inference("Latex OCR", image, question)
109
+
110
+ # Gradio app setup for OCR endpoint
111
+ with gr.Blocks() as demo:
112
+ gr.Markdown("# Qwen2VL OCR Model - Latex OCR")
113
+
114
+ with gr.Row():
115
+ with gr.Column():
116
+ input_media = gr.File(label="Upload Image", type="filepath")
117
+ text_input = gr.Textbox(label="Question", placeholder="Ask a question about the image...")
118
+ submit_btn = gr.Button(value="Submit")
119
+
120
+ with gr.Column():
121
+ output_text = gr.Textbox(label="Output Text", lines=10)
122
+
123
+ submit_btn.click(
124
+ ocr_endpoint, [input_media, text_input], [output_text]
125
+ )
126
+
127
+ # Launch the app on the /ocr endpoint
128
+ demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=True)