UniquePratham commited on
Commit
b16db73
ยท
verified ยท
1 Parent(s): 47ab79c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +142 -71
app.py CHANGED
@@ -1,73 +1,144 @@
1
  import streamlit as st
2
- from ocr_cpu import extract_text_got, extract_text_qwen, extract_text_llama, clean_extracted_text
3
- import json
4
-
5
- # Set up page layout and styling
6
- st.set_page_config(page_title="MultiModel OCR Fusion", layout="centered", page_icon="๐Ÿ“„")
7
-
8
- st.markdown(
9
- """
10
- <style>
11
- .reportview-container { background: #f4f4f4; }
12
- .sidebar .sidebar-content { background: #e0e0e0; }
13
- h1 { color: #007BFF; }
14
- .upload-btn { background-color: #007BFF; color: white; padding: 10px; border-radius: 5px; text-align: center; }
15
- </style>
16
- """, unsafe_allow_html=True
17
- )
18
-
19
- # --- Title Section ---
20
- st.title("๐Ÿ“„ MultiModel OCR Fusion")
21
- st.write("Upload an image to extract and clean text using multiple OCR models (GOT, Qwen, LLaMA).")
22
-
23
- # --- Image Upload Section ---
24
- uploaded_file = st.file_uploader("Upload an image file", type=["jpg", "jpeg", "png"])
25
-
26
- # Model selection
27
- st.sidebar.title("Model Selection")
28
- model_choice = st.sidebar.selectbox("Choose OCR Model", ("GOT", "Qwen", "LLaMA"))
29
-
30
- if uploaded_file is not None:
31
- st.image(uploaded_file, caption='Uploaded Image', use_column_width=True)
32
-
33
- # Extract text from the image based on selected model
34
- with st.spinner(f"Extracting text using the {model_choice} model..."):
35
- try:
36
- if model_choice == "GOT":
37
- extracted_text = extract_text_got(uploaded_file)
38
- elif model_choice == "Qwen":
39
- extracted_text = extract_text_qwen(uploaded_file)
40
- elif model_choice == "LLaMA":
41
- extracted_text = extract_text_llama(uploaded_file)
42
-
43
- # If no text extracted
44
- if not extracted_text.strip():
45
- st.warning(f"No text extracted using {model_choice}.")
46
- else:
47
- # Clean the extracted text
48
- cleaned_text = clean_extracted_text(extracted_text)
49
- except Exception as e:
50
- st.error(f"Error during text extraction: {str(e)}")
51
- extracted_text, cleaned_text = "", ""
52
-
53
- # --- Display Extracted and Cleaned Text ---
54
- st.subheader(f"Extracted Text using {model_choice}")
55
- st.text_area(f"Raw Text ({model_choice})", extracted_text, height=200)
56
-
57
- st.subheader("Cleaned Text (AI-processed)")
58
- st.text_area("Cleaned Text", cleaned_text, height=200)
59
-
60
- # Save extracted text for further use
61
- if extracted_text:
62
- with open("extracted_text.json", "w") as json_file:
63
- json.dump({"text": extracted_text}, json_file)
64
-
65
- # --- Keyword Search ---
66
- st.subheader("Search for Keywords")
67
- keyword = st.text_input("Enter a keyword to search in the extracted text")
68
-
69
- if keyword:
70
- if keyword.lower() in cleaned_text.lower():
71
- st.success(f"Keyword **'{keyword}'** found in the cleaned text!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  else:
73
- st.error(f"Keyword **'{keyword}'** not found.")
 
 
1
  import streamlit as st
2
+ from transformers import AutoModel, AutoTokenizer, Qwen2VLForConditionalGeneration, AutoProcessor
3
+ from surya.ocr import run_ocr
4
+ from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
5
+ from surya.model.recognition.model import load_model as load_rec_model
6
+ from surya.model.recognition.processor import load_processor as load_rec_processor
7
+ from PIL import Image
8
+ import torch
9
+ import tempfile
10
+ import os
11
+ import re
12
+
13
+ # Page configuration
14
+ st.set_page_config(page_title="OCR Application", page_icon="๐Ÿ–ผ๏ธ", layout="wide")
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+
17
+ # Load Surya OCR Models (English + Hindi)
18
+ det_processor, det_model = load_det_processor(), load_det_model()
19
+ det_model.to(device)
20
+ rec_model, rec_processor = load_rec_model(), load_rec_processor()
21
+ rec_model.to(device)
22
+
23
+ # Load GOT Models
24
+ @st.cache_resource
25
+ def init_got_model():
26
+ tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
27
+ model = AutoModel.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
28
+ return model.eval(), tokenizer
29
+
30
+ @st.cache_resource
31
+ def init_got_gpu_model():
32
+ tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
33
+ model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
34
+ return model.eval().cuda(), tokenizer
35
+
36
+ # Load Qwen Model
37
+ @st.cache_resource
38
+ def init_qwen_model():
39
+ model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", device_map="cpu", torch_dtype=torch.float16)
40
+ processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
41
+ return model.eval(), processor
42
+
43
+ # Extract text using GOT
44
+ def extract_text_got(image_file, model, tokenizer):
45
+ return model.chat(tokenizer, image_file, ocr_type='ocr')
46
+
47
+ # Extract text using Qwen
48
+ def extract_text_qwen(image_file, model, processor):
49
+ try:
50
+ image = Image.open(image_file).convert('RGB')
51
+ conversation = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Extract text from this image."}]}]
52
+ text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
53
+ inputs = processor(text=[text_prompt], images=[image], return_tensors="pt")
54
+ output_ids = model.generate(**inputs)
55
+ output_text = processor.batch_decode(output_ids, skip_special_tokens=True)
56
+ return output_text[0] if output_text else "No text extracted from the image."
57
+ except Exception as e:
58
+ return f"An error occurred: {str(e)}"
59
+
60
+ # Text Cleaning AI - Clean spaces, handle dual languages
61
+ def clean_extracted_text(text):
62
+ # Remove extra spaces
63
+ cleaned_text = re.sub(r'\s+', ' ', text).strip()
64
+ cleaned_text = re.sub(r'\s([?.!,])', r'\1', cleaned_text)
65
+ return cleaned_text
66
+
67
+ # Highlight keyword search
68
+ def highlight_text(text, search_term):
69
+ if not search_term:
70
+ return text
71
+ pattern = re.compile(re.escape(search_term), re.IGNORECASE)
72
+ return pattern.sub(lambda m: f'<span style="background-color: yellow;">{m.group()}</span>', text)
73
+
74
+ # Title and UI
75
+ st.title("OCR Application - Multimodel Support")
76
+ st.write("Upload an image for OCR using various models, with support for English, Hindi, and Hinglish.")
77
+
78
+ # Sidebar Configuration
79
+ st.sidebar.header("Configuration")
80
+ model_choice = st.sidebar.selectbox("Select OCR Model:", ("GOT_CPU", "GOT_GPU", "Qwen", "Surya (English+Hindi)"))
81
+
82
+ # Upload Section
83
+ uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
84
+
85
+ # Predict button
86
+ predict_button = st.sidebar.button("Predict")
87
+
88
+ # Main columns
89
+ col1, col2 = st.columns([2, 1])
90
+
91
+ # Display image preview
92
+ if uploaded_file:
93
+ image = Image.open(uploaded_file)
94
+ with col1:
95
+ col1.image(image, caption='Uploaded Image', use_column_width=False, width=300)
96
+
97
+ # Handle predictions
98
+ if predict_button and uploaded_file:
99
+ with st.spinner("Processing..."):
100
+ # Save uploaded image
101
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
102
+ temp_file.write(uploaded_file.getvalue())
103
+ temp_file_path = temp_file.name
104
+
105
+ image = Image.open(temp_file_path)
106
+ image = image.convert("RGB")
107
+
108
+ if model_choice == "GOT_CPU":
109
+ got_model, tokenizer = init_got_model()
110
+ extracted_text = extract_text_got(temp_file_path, got_model, tokenizer)
111
+
112
+ elif model_choice == "GOT_GPU":
113
+ got_gpu_model, tokenizer = init_got_gpu_model()
114
+ extracted_text = extract_text_got(temp_file_path, got_gpu_model, tokenizer)
115
+
116
+ elif model_choice == "Qwen":
117
+ qwen_model, qwen_processor = init_qwen_model()
118
+ extracted_text = extract_text_qwen(temp_file_path, qwen_model, qwen_processor)
119
+
120
+ elif model_choice == "Surya (English+Hindi)":
121
+ langs = ["en", "hi"]
122
+ predictions = run_ocr([image], [langs], det_model, det_processor, rec_model, rec_processor)
123
+ text_list = re.findall(r"text='(.*?)'", str(predictions[0]))
124
+ extracted_text = ' '.join(text_list)
125
+
126
+ # Clean extracted text
127
+ cleaned_text = clean_extracted_text(extracted_text)
128
+
129
+ # Delete temp file
130
+ if os.path.exists(temp_file_path):
131
+ os.remove(temp_file_path)
132
+
133
+ # Display extracted text and search functionality
134
+ st.subheader("Extracted Text (Cleaned)")
135
+ st.markdown(cleaned_text, unsafe_allow_html=True)
136
+
137
+ search_query = st.text_input("Search in extracted text:", key="search_query", placeholder="Type to search...")
138
+ if search_query:
139
+ highlighted_text = highlight_text(cleaned_text, search_query)
140
+ st.markdown("### Highlighted Search Results:")
141
+ st.markdown(highlighted_text, unsafe_allow_html=True)
142
  else:
143
+ st.markdown("### Extracted Text:")
144
+ st.markdown(cleaned_text, unsafe_allow_html=True)