UniquePratham committed on
Commit
93bd871
•
1 Parent(s): 0fbad84

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -45
app.py CHANGED
@@ -1,9 +1,6 @@
 
1
  import streamlit as st
2
  from transformers import AutoModel, AutoTokenizer, Qwen2VLForConditionalGeneration, AutoProcessor
3
- from surya.ocr import run_ocr
4
- from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
5
- from surya.model.recognition.model import load_model as load_rec_model
6
- from surya.model.recognition.processor import load_processor as load_rec_processor
7
  from PIL import Image
8
  import torch
9
  import os
@@ -15,47 +12,56 @@ from groq import Groq
15
  from st_keyup import st_keyup
16
  from st_img_pastebutton import paste
17
 
18
-
19
  # Page configuration
20
- st.set_page_config(page_title="DualTextOCRFusion", page_icon="🔍", layout="wide")
 
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
 
23
- # Load Surya OCR Models (English + Hindi)
24
- det_processor, det_model = load_det_processor(), load_det_model()
25
- det_model.to(device)
26
- rec_model, rec_processor = load_rec_model(), load_rec_processor()
27
- rec_model.to(device)
28
-
29
  # Load GOT Models
 
 
30
  @st.cache_resource
31
  def init_got_model():
32
- tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
33
- model = AutoModel.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
 
 
34
  return model.eval(), tokenizer
35
 
 
36
  @st.cache_resource
37
  def init_got_gpu_model():
38
- tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
39
- model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
 
 
40
  return model.eval().cuda(), tokenizer
41
 
42
  # Load Qwen Model
 
 
43
  @st.cache_resource
44
  def init_qwen_model():
45
- model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", device_map="cpu", torch_dtype=torch.float16)
 
46
  processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
47
  return model.eval(), processor
48
 
49
  # Text Cleaning AI - Clean spaces, handle dual languages
 
 
50
  def clean_extracted_text(text):
51
  cleaned_text = re.sub(r'\s+', ' ', text).strip()
52
  cleaned_text = re.sub(r'\s([?.!,])', r'\1', cleaned_text)
53
  return cleaned_text
54
 
55
  # Polish the text using a model
 
 
56
  def polish_text_with_ai(cleaned_text):
57
  prompt = f"Remove unwanted spaces between and inside words to join incomplete words, creating a meaningful sentence in either Hindi, English, or Hinglish without altering any words from the given extracted text. Then, return the corrected text with adjusted spaces, keeping it as close to the original as possible, along with relevant details or insights that an AI can provide about the extracted text. Extracted Text : {cleaned_text}"
58
- client = Groq(api_key="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")
 
59
  chat_completion = client.chat.completions.create(
60
  messages=[
61
  {
@@ -80,16 +86,22 @@ def extract_text_got(image_file, model, tokenizer):
80
  def extract_text_qwen(image_file, model, processor):
81
  try:
82
  image = Image.open(image_file).convert('RGB')
83
- conversation = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Extract text from this image."}]}]
84
- text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
85
- inputs = processor(text=[text_prompt], images=[image], return_tensors="pt")
 
 
 
86
  output_ids = model.generate(**inputs)
87
- output_text = processor.batch_decode(output_ids, skip_special_tokens=True)
 
88
  return output_text[0] if output_text else "No text extracted from the image."
89
  except Exception as e:
90
  return f"An error occurred: {str(e)}"
91
 
92
  # Function to highlight the keyword in the text
 
 
93
  def highlight_text(text, search_term):
94
  if not search_term: # If no search term is provided, return the original text
95
  return text
@@ -98,6 +110,7 @@ def highlight_text(text, search_term):
98
  # Highlight matched terms with yellow background
99
  return pattern.sub(lambda m: f'<span style="background-color: yellow;">{m.group()}</span>', text)
100
 
 
101
  # Title and UI
102
  st.title("DualTextOCRFusion - 🔍")
103
  st.header("OCR Application - Multimodel Support")
@@ -105,19 +118,22 @@ st.write("Upload an image for OCR using various models, with support for English
105
 
106
  # Sidebar Configuration
107
  st.sidebar.header("Configuration")
108
- model_choice = st.sidebar.selectbox("Select OCR Model:", ("GOT_CPU", "GOT_GPU", "Qwen", "Surya (English+Hindi)"))
 
109
 
110
  # Upload Section
111
- uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
 
112
 
113
  # Input from clipboard
114
  # Paste image button
115
- image_data = paste(label="paste from clipboard",key="image_clipboard")
116
  if image_data is not None:
117
- header, encoded = image_data.split(",", 1)
118
- decoded_bytes = base64.b64decode(encoded)
119
- img_stream = io.BytesIO(decoded_bytes)
120
- uploaded_file=img_stream
 
121
 
122
  # Input from camera
123
  camera_file = st.sidebar.camera_input("Capture from Camera")
@@ -134,19 +150,22 @@ col1, col2 = st.columns([2, 1])
134
  if uploaded_file:
135
  image = Image.open(uploaded_file)
136
  with col1:
137
- col1.image(image, caption='Uploaded Image', use_column_width=False, width=300)
 
138
 
139
  # Save uploaded image to 'images' folder
140
  images_dir = 'images'
141
  os.makedirs(images_dir, exist_ok=True)
142
- image_path = os.path.join(images_dir, uploaded_file.name)
 
143
  with open(image_path, 'wb') as f:
144
  f.write(uploaded_file.getvalue())
145
 
146
  # Check if the result already exists
147
  results_dir = 'results'
148
  os.makedirs(results_dir, exist_ok=True)
149
- result_path = os.path.join(results_dir, f"{uploaded_file.name}_result.json")
 
150
 
151
  # Handle predictions
152
  if predict_button:
@@ -158,28 +177,27 @@ if uploaded_file:
158
  with st.spinner("Processing..."):
159
  if model_choice == "GOT_CPU":
160
  got_model, tokenizer = init_got_model()
161
- extracted_text = extract_text_got(image_path, got_model, tokenizer)
 
162
 
163
  elif model_choice == "GOT_GPU":
164
  got_gpu_model, tokenizer = init_got_gpu_model()
165
- extracted_text = extract_text_got(image_path, got_gpu_model, tokenizer)
 
166
 
167
  elif model_choice == "Qwen":
168
  qwen_model, qwen_processor = init_qwen_model()
169
- extracted_text = extract_text_qwen(image_path, qwen_model, qwen_processor)
170
-
171
- elif model_choice == "Surya (English+Hindi)":
172
- langs = ["en", "hi"]
173
- predictions = run_ocr([image], [langs], det_model, det_processor, rec_model, rec_processor)
174
- text_list = re.findall(r"text='(.*?)'", str(predictions[0]))
175
- extracted_text = ' '.join(text_list)
176
 
177
  # Clean and polish extracted text
178
  cleaned_text = clean_extracted_text(extracted_text)
179
- polished_text = polish_text_with_ai(cleaned_text) if model_choice in ["GOT_CPU", "GOT_GPU"] else cleaned_text
 
180
 
181
  # Save results to JSON file
182
- result_data = {"extracted_text":extracted_text,"cleaner_text":cleaned_text,"polished_text": polished_text}
 
183
  with open(result_path, 'w') as f:
184
  json.dump(result_data, f)
185
 
@@ -197,9 +215,11 @@ if uploaded_file:
197
  st.session_state["highlighted_result"] = extracted_text
198
 
199
  # Input search term with real-time update on key press
200
- search_query = st_keyup("Search in extracted text:", key="search_key", on_change=update_search)
 
201
 
202
  # Display highlighted results if they exist in session state
203
  if "highlighted_result" in st.session_state:
204
  st.markdown("### Highlighted Search Results:")
205
- st.markdown(st.session_state["highlighted_result"], unsafe_allow_html=True)
 
 
1
+ import io
2
  import streamlit as st
3
  from transformers import AutoModel, AutoTokenizer, Qwen2VLForConditionalGeneration, AutoProcessor
 
 
 
 
4
  from PIL import Image
5
  import torch
6
  import os
 
12
  from st_keyup import st_keyup
13
  from st_img_pastebutton import paste
14
 
 
15
  # Page configuration
16
+ st.set_page_config(page_title="DualTextOCRFusion",
17
+ page_icon="🔍", layout="wide")
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
 
 
 
 
 
 
 
20
  # Load GOT Models
21
+
22
+
23
  @st.cache_resource
24
  def init_got_model():
25
+ tokenizer = AutoTokenizer.from_pretrained(
26
+ 'srimanth-d/GOT_CPU', trust_remote_code=True)
27
+ model = AutoModel.from_pretrained(
28
+ 'srimanth-d/GOT_CPU', trust_remote_code=True, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
29
  return model.eval(), tokenizer
30
 
31
+
32
  @st.cache_resource
33
  def init_got_gpu_model():
34
+ tokenizer = AutoTokenizer.from_pretrained(
35
+ 'ucaslcl/GOT-OCR2_0', trust_remote_code=True)
36
+ model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True,
37
+ device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
38
  return model.eval().cuda(), tokenizer
39
 
40
  # Load Qwen Model
41
+
42
+
43
  @st.cache_resource
44
  def init_qwen_model():
45
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
46
+ "Qwen/Qwen2-VL-2B-Instruct", device_map="cpu", torch_dtype=torch.float16)
47
  processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
48
  return model.eval(), processor
49
 
50
  # Text Cleaning AI - Clean spaces, handle dual languages
51
+
52
+
53
  def clean_extracted_text(text):
54
  cleaned_text = re.sub(r'\s+', ' ', text).strip()
55
  cleaned_text = re.sub(r'\s([?.!,])', r'\1', cleaned_text)
56
  return cleaned_text
57
 
58
  # Polish the text using a model
59
+
60
+
61
  def polish_text_with_ai(cleaned_text):
62
  prompt = f"Remove unwanted spaces between and inside words to join incomplete words, creating a meaningful sentence in either Hindi, English, or Hinglish without altering any words from the given extracted text. Then, return the corrected text with adjusted spaces, keeping it as close to the original as possible, along with relevant details or insights that an AI can provide about the extracted text. Extracted Text : {cleaned_text}"
63
+ client = Groq(
64
+ api_key="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")
65
  chat_completion = client.chat.completions.create(
66
  messages=[
67
  {
 
86
  def extract_text_qwen(image_file, model, processor):
87
  try:
88
  image = Image.open(image_file).convert('RGB')
89
+ conversation = [{"role": "user", "content": [{"type": "image"}, {
90
+ "type": "text", "text": "Extract text from this image."}]}]
91
+ text_prompt = processor.apply_chat_template(
92
+ conversation, add_generation_prompt=True)
93
+ inputs = processor(text=[text_prompt], images=[
94
+ image], return_tensors="pt")
95
  output_ids = model.generate(**inputs)
96
+ output_text = processor.batch_decode(
97
+ output_ids, skip_special_tokens=True)
98
  return output_text[0] if output_text else "No text extracted from the image."
99
  except Exception as e:
100
  return f"An error occurred: {str(e)}"
101
 
102
  # Function to highlight the keyword in the text
103
+
104
+
105
  def highlight_text(text, search_term):
106
  if not search_term: # If no search term is provided, return the original text
107
  return text
 
110
  # Highlight matched terms with yellow background
111
  return pattern.sub(lambda m: f'<span style="background-color: yellow;">{m.group()}</span>', text)
112
 
113
+
114
  # Title and UI
115
  st.title("DualTextOCRFusion - 🔍")
116
  st.header("OCR Application - Multimodel Support")
 
118
 
119
  # Sidebar Configuration
120
  st.sidebar.header("Configuration")
121
+ model_choice = st.sidebar.selectbox(
122
+ "Select OCR Model:", ("GOT_CPU", "GOT_GPU", "Qwen"))
123
 
124
  # Upload Section
125
+ uploaded_file = st.sidebar.file_uploader(
126
+ "Choose an image...", type=["png", "jpg", "jpeg"])
127
 
128
  # Input from clipboard
129
  # Paste image button
130
+ image_data = paste(label="paste from clipboard", key="image_clipboard")
131
  if image_data is not None:
132
+ clipboard_use = True
133
+ header, encoded = image_data.split(",", 1)
134
+ decoded_bytes = base64.b64decode(encoded)
135
+ img_stream = io.BytesIO(decoded_bytes)
136
+ uploaded_file = img_stream
137
 
138
  # Input from camera
139
  camera_file = st.sidebar.camera_input("Capture from Camera")
 
150
  if uploaded_file:
151
  image = Image.open(uploaded_file)
152
  with col1:
153
+ col1.image(image, caption='Uploaded Image',
154
+ use_column_width=False, width=300)
155
 
156
  # Save uploaded image to 'images' folder
157
  images_dir = 'images'
158
  os.makedirs(images_dir, exist_ok=True)
159
+ image_path = os.path.join(
160
+ images_dir, "temp_file.jpg" if clipboard_use else uploaded_file.name)
161
  with open(image_path, 'wb') as f:
162
  f.write(uploaded_file.getvalue())
163
 
164
  # Check if the result already exists
165
  results_dir = 'results'
166
  os.makedirs(results_dir, exist_ok=True)
167
+ result_path = os.path.join(
168
+ results_dir, "temp_file_result.json" if clipboard_use else f"{uploaded_file.name}_result.json")
169
 
170
  # Handle predictions
171
  if predict_button:
 
177
  with st.spinner("Processing..."):
178
  if model_choice == "GOT_CPU":
179
  got_model, tokenizer = init_got_model()
180
+ extracted_text = extract_text_got(
181
+ image_path, got_model, tokenizer)
182
 
183
  elif model_choice == "GOT_GPU":
184
  got_gpu_model, tokenizer = init_got_gpu_model()
185
+ extracted_text = extract_text_got(
186
+ image_path, got_gpu_model, tokenizer)
187
 
188
  elif model_choice == "Qwen":
189
  qwen_model, qwen_processor = init_qwen_model()
190
+ extracted_text = extract_text_qwen(
191
+ image_path, qwen_model, qwen_processor)
 
 
 
 
 
192
 
193
  # Clean and polish extracted text
194
  cleaned_text = clean_extracted_text(extracted_text)
195
+ polished_text = polish_text_with_ai(cleaned_text) if model_choice in [
196
+ "GOT_CPU", "GOT_GPU"] else cleaned_text
197
 
198
  # Save results to JSON file
199
+ result_data = {"extracted_text": extracted_text,
200
+ "cleaner_text": cleaned_text, "polished_text": polished_text}
201
  with open(result_path, 'w') as f:
202
  json.dump(result_data, f)
203
 
 
215
  st.session_state["highlighted_result"] = extracted_text
216
 
217
  # Input search term with real-time update on key press
218
+ search_query = st_keyup(
219
+ "Search in extracted text:", key="search_key", on_change=update_search)
220
 
221
  # Display highlighted results if they exist in session state
222
  if "highlighted_result" in st.session_state:
223
  st.markdown("### Highlighted Search Results:")
224
+ st.markdown(
225
+ st.session_state["highlighted_result"], unsafe_allow_html=True)