UniquePratham committed • commit 93bd871 • parent 0fbad84
Update app.py

app.py CHANGED (lines marked … were truncated in the source diff view)
```diff
@@ -1,9 +1,6 @@
+import io
 import streamlit as st
 from transformers import AutoModel, AutoTokenizer, Qwen2VLForConditionalGeneration, AutoProcessor
-from surya.ocr import run_ocr
-from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
-from surya.model.recognition.model import load_model as load_rec_model
-from surya.model.recognition.processor import load_processor as load_rec_processor
 from PIL import Image
 import torch
 import os
```
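The import hunk mirrors the functional change below: the four `surya` imports leave together with the Surya OCR pipeline (its model-loading block and the "Surya (English+Hindi)" branch are deleted in later hunks), while `io` arrives for the new clipboard-paste path, which wraps the decoded image bytes in `io.BytesIO`.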
```diff
@@ -15,47 +12,56 @@ from groq import Groq
 from st_keyup import st_keyup
 from st_img_pastebutton import paste
 
-
 # Page configuration
-st.set_page_config(page_title="DualTextOCRFusion", …
+st.set_page_config(page_title="DualTextOCRFusion",
+                   page_icon="π", layout="wide")
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-# Load Surya OCR Models (English + Hindi)
-det_processor, det_model = load_det_processor(), load_det_model()
-det_model.to(device)
-rec_model, rec_processor = load_rec_model(), load_rec_processor()
-rec_model.to(device)
-
 # Load GOT Models
+
+
 @st.cache_resource
 def init_got_model():
-    tokenizer = AutoTokenizer.from_pretrained( …
-    …
+    tokenizer = AutoTokenizer.from_pretrained(
+        'srimanth-d/GOT_CPU', trust_remote_code=True)
+    model = AutoModel.from_pretrained(
+        'srimanth-d/GOT_CPU', trust_remote_code=True, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
     return model.eval(), tokenizer
 
+
 @st.cache_resource
 def init_got_gpu_model():
-    tokenizer = AutoTokenizer.from_pretrained( …
-    …
+    tokenizer = AutoTokenizer.from_pretrained(
+        'ucaslcl/GOT-OCR2_0', trust_remote_code=True)
+    model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True,
+                                      device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
     return model.eval().cuda(), tokenizer
 
 # Load Qwen Model
+
+
 @st.cache_resource
 def init_qwen_model():
-    model = Qwen2VLForConditionalGeneration.from_pretrained( …
+    model = Qwen2VLForConditionalGeneration.from_pretrained(
+        "Qwen/Qwen2-VL-2B-Instruct", device_map="cpu", torch_dtype=torch.float16)
     processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
     return model.eval(), processor
 
 # Text Cleaning AI - Clean spaces, handle dual languages
+
+
 def clean_extracted_text(text):
     cleaned_text = re.sub(r'\s+', ' ', text).strip()
     cleaned_text = re.sub(r'\s([?.!,])', r'\1', cleaned_text)
     return cleaned_text
 
 # Polish the text using a model
+
+
 def polish_text_with_ai(cleaned_text):
     prompt = f"Remove unwanted spaces between and inside words to join incomplete words, creating a meaningful sentence in either Hindi, English, or Hinglish without altering any words from the given extracted text. Then, return the corrected text with adjusted spaces, keeping it as close to the original as possible, along with relevant details or insights that an AI can provide about the extracted text. Extracted Text : {cleaned_text}"
-    client = Groq( …
+    client = Groq(
+        api_key="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")
     chat_completion = client.chat.completions.create(
         messages=[
             {
```
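Two notes on this hunk: `@st.cache_resource` keeps each model loaded once per server process rather than on every Streamlit rerun, and the Groq API key is now committed in plain text. Streamlit's `st.secrets` (or an environment variable) avoids that; a minimal sketch, assuming a `GROQ_API_KEY` entry is configured in the Space's secrets (not part of this commit):

```python
import os

import streamlit as st
from groq import Groq

# Read the key from the Space's secrets or the environment instead of
# hardcoding it in app.py; GROQ_API_KEY is an assumed secret name.
api_key = st.secrets.get("GROQ_API_KEY", os.environ.get("GROQ_API_KEY"))
client = Groq(api_key=api_key)
```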
```diff
@@ -80,16 +86,22 @@ def extract_text_got(image_file, model, tokenizer):
 def extract_text_qwen(image_file, model, processor):
     try:
         image = Image.open(image_file).convert('RGB')
-        conversation = [{"role": "user", "content": [{"type": "image"}, { …
-        …
-        …
+        conversation = [{"role": "user", "content": [{"type": "image"}, {
+            "type": "text", "text": "Extract text from this image."}]}]
+        text_prompt = processor.apply_chat_template(
+            conversation, add_generation_prompt=True)
+        inputs = processor(text=[text_prompt], images=[
+            image], return_tensors="pt")
         output_ids = model.generate(**inputs)
-        output_text = processor.batch_decode( …
+        output_text = processor.batch_decode(
+            output_ids, skip_special_tokens=True)
         return output_text[0] if output_text else "No text extracted from the image."
     except Exception as e:
         return f"An error occurred: {str(e)}"
 
 # Function to highlight the keyword in the text
+
+
 def highlight_text(text, search_term):
     if not search_term:  # If no search term is provided, return the original text
         return text
```
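One behavior worth knowing about `extract_text_qwen` as committed: `model.generate` returns the prompt tokens followed by the generated ones, so `batch_decode(output_ids, ...)` also decodes the chat-template text into the result. The Qwen2-VL model card trims the prompt ids before decoding; a sketch of that variant, reusing the names from the function above:

```python
# Inside extract_text_qwen, after building `inputs` (a sketch, not the committed code):
output_ids = model.generate(**inputs, max_new_tokens=512)
# Keep only the tokens generated after the prompt for each sequence
generated_ids = [out[len(inp):] for inp, out in zip(inputs.input_ids, output_ids)]
output_text = processor.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
```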
```diff
@@ -98,6 +110,7 @@ def highlight_text(text, search_term):
     # Highlight matched terms with yellow background
     return pattern.sub(lambda m: f'<span style="background-color: yellow;">{m.group()}</span>', text)
 
+
 # Title and UI
 st.title("DualTextOCRFusion - π")
 st.header("OCR Application - Multimodel Support")
```
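The lines that compile `pattern` fall outside this hunk. For reference, a self-contained version of `highlight_text` consistent with what is shown; the `re.compile` line is an assumption about the unshown code, not part of the diff:

```python
import re

def highlight_text(text, search_term):
    if not search_term:  # If no search term is provided, return the original text
        return text
    # Assumed: case-insensitive match on the escaped search term
    pattern = re.compile(re.escape(search_term), re.IGNORECASE)
    # Highlight matched terms with yellow background
    return pattern.sub(
        lambda m: f'<span style="background-color: yellow;">{m.group()}</span>', text)
```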
```diff
@@ -105,19 +118,22 @@ st.write("Upload an image for OCR using various models, with support for English
 
 # Sidebar Configuration
 st.sidebar.header("Configuration")
-model_choice = st.sidebar.selectbox( …
+model_choice = st.sidebar.selectbox(
+    "Select OCR Model:", ("GOT_CPU", "GOT_GPU", "Qwen"))
 
 # Upload Section
-uploaded_file = st.sidebar.file_uploader( …
+uploaded_file = st.sidebar.file_uploader(
+    "Choose an image...", type=["png", "jpg", "jpeg"])
 
 # Input from clipboard
 # Paste image button
-image_data = paste(label="paste from clipboard",key="image_clipboard")
+image_data = paste(label="paste from clipboard", key="image_clipboard")
 if image_data is not None:
-    …
-    …
-    …
-    …
+    clipboard_use = True
+    header, encoded = image_data.split(",", 1)
+    decoded_bytes = base64.b64decode(encoded)
+    img_stream = io.BytesIO(decoded_bytes)
+    uploaded_file = img_stream
 
 # Input from camera
 camera_file = st.sidebar.camera_input("Capture from Camera")
```
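`paste()` from st_img_pastebutton returns the image as a `data:image/...;base64,<payload>` string, which is why the new code splits on the first comma and base64-decodes the remainder into a `BytesIO`. A standalone round-trip of that decode (a hypothetical helper, not in the commit):

```python
import base64
import io

from PIL import Image

def image_from_data_url(data_url: str) -> Image.Image:
    """Decode a 'data:image/png;base64,...' string into a PIL image."""
    header, encoded = data_url.split(",", 1)  # header is e.g. "data:image/png;base64"
    return Image.open(io.BytesIO(base64.b64decode(encoded)))
```

Note also that `clipboard_use` is only assigned on the paste path; presumably it is initialized to `False` somewhere outside the visible hunks, since the `os.path.join` calls below read it on the regular-upload path as well.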
```diff
@@ -134,19 +150,22 @@ col1, col2 = st.columns([2, 1])
 if uploaded_file:
     image = Image.open(uploaded_file)
     with col1:
-        col1.image(image, caption='Uploaded Image', …
+        col1.image(image, caption='Uploaded Image',
+                   use_column_width=False, width=300)
 
     # Save uploaded image to 'images' folder
     images_dir = 'images'
     os.makedirs(images_dir, exist_ok=True)
-    image_path = os.path.join( …
+    image_path = os.path.join(
+        images_dir, "temp_file.jpg" if clipboard_use else uploaded_file.name)
     with open(image_path, 'wb') as f:
         f.write(uploaded_file.getvalue())
 
     # Check if the result already exists
     results_dir = 'results'
     os.makedirs(results_dir, exist_ok=True)
-    result_path = os.path.join( …
+    result_path = os.path.join(
+        results_dir, "temp_file_result.json" if clipboard_use else f"{uploaded_file.name}_result.json")
 
     # Handle predictions
     if predict_button:
```
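Because every clipboard paste is saved as `temp_file.jpg`, all pasted images share a single cache slot: a second paste silently reuses or overwrites the first one's `temp_file_result.json`. If distinct cache entries per image matter, keying by content hash is one option; a sketch, not part of the commit:

```python
import hashlib
import os

def content_key(data: bytes) -> str:
    """Stable cache key derived from the image bytes themselves."""
    return hashlib.sha256(data).hexdigest()[:16]

data = uploaded_file.getvalue()
image_path = os.path.join(images_dir, f"{content_key(data)}.jpg")
result_path = os.path.join(results_dir, f"{content_key(data)}_result.json")
```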
```diff
@@ -158,28 +177,27 @@ if uploaded_file:
         with st.spinner("Processing..."):
             if model_choice == "GOT_CPU":
                 got_model, tokenizer = init_got_model()
-                extracted_text = extract_text_got( …
+                extracted_text = extract_text_got(
+                    image_path, got_model, tokenizer)
 
             elif model_choice == "GOT_GPU":
                 got_gpu_model, tokenizer = init_got_gpu_model()
-                extracted_text = extract_text_got( …
+                extracted_text = extract_text_got(
+                    image_path, got_gpu_model, tokenizer)
 
             elif model_choice == "Qwen":
                 qwen_model, qwen_processor = init_qwen_model()
-                extracted_text = extract_text_qwen( …
-                …
-            elif model_choice == "Surya (English+Hindi)":
-                langs = ["en", "hi"]
-                predictions = run_ocr([image], [langs], det_model, det_processor, rec_model, rec_processor)
-                text_list = re.findall(r"text='(.*?)'", str(predictions[0]))
-                extracted_text = ' '.join(text_list)
+                extracted_text = extract_text_qwen(
+                    image_path, qwen_model, qwen_processor)
 
             # Clean and polish extracted text
             cleaned_text = clean_extracted_text(extracted_text)
-            polished_text = polish_text_with_ai(cleaned_text) if model_choice in [ …
+            polished_text = polish_text_with_ai(cleaned_text) if model_choice in [
+                "GOT_CPU", "GOT_GPU"] else cleaned_text
 
             # Save results to JSON file
-            result_data = {"extracted_text":extracted_text, …
+            result_data = {"extracted_text": extracted_text,
+                           "cleaner_text": cleaned_text, "polished_text": polished_text}
             with open(result_path, 'w') as f:
                 json.dump(result_data, f)
 
```
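For reference, the JSON written to `result_path` now carries three keys; note the spelling `cleaner_text` (not `cleaned_text`), which any consumer of these cache files has to match. A read-back sketch:

```python
import json

# Hypothetical read-back of a cached result file written by the block above
with open(result_path) as f:
    result_data = json.load(f)

extracted = result_data["extracted_text"]  # raw OCR output
cleaned = result_data["cleaner_text"]      # whitespace-normalized text
polished = result_data["polished_text"]    # Groq-polished (GOT) or cleaned text (Qwen)
```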
```diff
@@ -197,9 +215,11 @@ if uploaded_file:
        st.session_state["highlighted_result"] = extracted_text
 
    # Input search term with real-time update on key press
-   search_query = st_keyup( …
+   search_query = st_keyup(
+       "Search in extracted text:", key="search_key", on_change=update_search)
 
    # Display highlighted results if they exist in session state
    if "highlighted_result" in st.session_state:
        st.markdown("### Highlighted Search Results:")
-       st.markdown( …
+       st.markdown(
+           st.session_state["highlighted_result"], unsafe_allow_html=True)
```
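`update_search` is referenced by the `st_keyup` callback but defined outside the visible hunks; judging from the surrounding code it re-applies `highlight_text` to the stored result. A plausible reconstruction, with the session-state key names being assumptions (it relies on the app's own imports and `highlight_text`):

```python
def update_search():
    # Hypothetical: re-highlight the stored text whenever the search key changes
    term = st.session_state.get("search_key", "")
    base = st.session_state.get("extracted_result", "")
    st.session_state["highlighted_result"] = highlight_text(base, term)
```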