UniquePratham's picture
Update app.py
aa47a7c verified
raw
history blame
7.93 kB
import streamlit as st
from transformers import AutoModel, AutoTokenizer, Qwen2VLForConditionalGeneration, AutoProcessor
from surya.ocr import run_ocr
from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from PIL import Image
import torch
import tempfile
import os
import re
import json
from groq import Groq
# Page configuration
st.set_page_config(page_title="DualTextOCRFusion", page_icon="🔍", layout="wide")
device = "cuda" if torch.cuda.is_available() else "cpu"

# Directories for images and results
IMAGES_DIR = "images"
RESULTS_DIR = "results"
os.makedirs(IMAGES_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)


@st.cache_resource
def init_surya_models():
    """Load the Surya detection and recognition models once per process.

    Streamlit re-executes this whole script on every widget interaction, so
    loading at module level without caching reloaded the models on each rerun.

    Returns:
        (det_processor, det_model, rec_model, rec_processor), with both models
        moved to ``device``.
    """
    surya_det_processor, surya_det_model = load_det_processor(), load_det_model()
    surya_det_model.to(device)
    surya_rec_model, surya_rec_processor = load_rec_model(), load_rec_processor()
    surya_rec_model.to(device)
    return surya_det_processor, surya_det_model, surya_rec_model, surya_rec_processor


# Load Surya OCR Models (English + Hindi); module-level names are kept for the rest of the app.
det_processor, det_model, rec_model, rec_processor = init_surya_models()
# GOT-OCR2.0 (CPU build)
@st.cache_resource
def init_got_model():
    """Load the CPU variant of GOT-OCR2.0 once per process.

    Returns:
        (model, tokenizer): the model switched to eval mode, and its tokenizer.
    """
    repo_id = 'srimanth-d/GOT_CPU'
    got_tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    got_model = AutoModel.from_pretrained(
        repo_id,
        trust_remote_code=True,
        use_safetensors=True,
        pad_token_id=got_tokenizer.eos_token_id,
    )
    got_model.eval()
    return got_model, got_tokenizer
@st.cache_resource
def init_got_gpu_model():
    """Load the CUDA variant of GOT-OCR2.0 once per process.

    Returns:
        (model, tokenizer): the model in eval mode on the GPU, and its tokenizer.

    Raises:
        RuntimeError: if no CUDA device is available. The check runs before any
            model files are fetched, so CPU-only hosts fail fast with a clear
            message instead of failing later during GPU placement.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("GOT_GPU requires a CUDA-capable GPU; choose GOT_CPU instead.")
    tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
    model = AutoModel.from_pretrained(
        'ucaslcl/GOT-OCR2_0',
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        device_map='cuda',
        use_safetensors=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    return model.eval().cuda(), tokenizer
# Qwen2-VL (vision-language) model
@st.cache_resource
def init_qwen_model():
    """Load Qwen2-VL-2B-Instruct and its processor once per process.

    The weights are placed on CPU in float16.

    Returns:
        (model, processor): the model in eval mode and the matching AutoProcessor.
    """
    # NOTE(review): float16 on CPU depends on the installed torch having half
    # kernels for every op this model uses — confirm on the deployment host.
    checkpoint = "Qwen/Qwen2-VL-2B-Instruct"
    qwen = Qwen2VLForConditionalGeneration.from_pretrained(
        checkpoint,
        device_map="cpu",
        torch_dtype=torch.float16,
    )
    qwen_processor = AutoProcessor.from_pretrained(checkpoint)
    qwen.eval()
    return qwen, qwen_processor
# Text cleaning: collapse whitespace and tidy the spacing before punctuation.
_WHITESPACE_RUN = re.compile(r'\s+')
_SPACE_BEFORE_PUNCT = re.compile(r'\s([?.!,])')


def clean_extracted_text(text):
    """Normalise whitespace in raw OCR output.

    Every run of whitespace becomes a single space and leading/trailing
    whitespace is removed. After that, the space in front of ``?``, ``.``,
    ``!`` or ``,`` is dropped.

    Args:
        text: raw OCR output.

    Returns:
        The cleaned string.
    """
    single_spaced = _WHITESPACE_RUN.sub(' ', text).strip()
    return _SPACE_BEFORE_PUNCT.sub(r'\1', single_spaced)
# Polish the text using a model
def polish_text_with_ai(cleaned_text):
    """Ask a Groq-hosted LLM (gemma2-9b-it) to re-join words split by OCR spacing errors.

    The API key is read from the ``GROQ_API_KEY`` environment variable. A key
    was previously hard-coded here, committed to source; it should be revoked.

    Args:
        cleaned_text: whitespace-normalised OCR output (Hindi, English or Hinglish).

    Returns:
        The model's corrected text. ``cleaned_text`` is returned unchanged when
        it is empty/blank (no API call is made) or when the model returns no content.

    Raises:
        RuntimeError: if ``GROQ_API_KEY`` is not set.
    """
    if not cleaned_text or not cleaned_text.strip():
        return cleaned_text

    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        raise RuntimeError("GROQ_API_KEY environment variable is not set")

    # The original prompt never included the text itself, so the model had
    # nothing to correct; append it after the instructions.
    prompt = (
        "Remove unwanted spaces between and inside words to join incomplete words, "
        "creating a meaningful sentence in either Hindi, English, or Hinglish without "
        "altering any words from the given extracted text. Then, return the corrected "
        "text with adjusted spaces. Return only the corrected text, with no explanation.\n\n"
        f"Extracted text:\n{cleaned_text}"
    )
    client = Groq(api_key=api_key)
    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": "You are a pedantic sentence corrector."},
            {"role": "user", "content": prompt},
        ],
        model="gemma2-9b-it",
    )
    polished_text = chat_completion.choices[0].message.content
    # Fall back to the input rather than propagating an empty/None reply.
    return polished_text if polished_text else cleaned_text
# GOT-OCR2.0 inference
def extract_text_got(image_file, model, tokenizer):
    """Run plain-text OCR on ``image_file`` through the GOT model's ``chat`` API.

    Args:
        image_file: image path handed straight to ``model.chat``.
        model: a GOT model exposing ``chat`` (see ``init_got_model``).
        tokenizer: the tokenizer loaded alongside ``model``.

    Returns:
        Whatever ``model.chat`` returns for ``ocr_type='ocr'``.
    """
    mode = 'ocr'
    recognised = model.chat(tokenizer, image_file, ocr_type=mode)
    return recognised
# Extract text using Qwen
def extract_text_qwen(image_file, model, processor, max_new_tokens=1024):
    """Run Qwen2-VL on one image with an "Extract text from this image." instruction.

    Args:
        image_file: path or file object accepted by ``PIL.Image.open``.
        model: a Qwen2-VL generation model (see ``init_qwen_model``).
        processor: the matching ``AutoProcessor``.
        max_new_tokens: upper bound on generated tokens. Without it, ``generate``
            falls back to the generation config's total length limit, which an
            image prompt alone can exceed.

    Returns:
        The generated text with the prompt removed. Returns
        "No text extracted from the image." when nothing was generated, or an
        "An error occurred: ..." string if any step raises.
    """
    try:
        # The context manager closes the file handle; convert() loads the
        # pixels into a separate image that outlives the ``with``.
        with Image.open(image_file) as img:
            image = img.convert('RGB')
        conversation = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Extract text from this image."}]}]
        text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = processor(text=[text_prompt], images=[image], return_tensors="pt")
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        # generate() returns prompt + completion; drop the prompt tokens so the
        # chat template and instruction are not echoed into the OCR result.
        generated_ids = [out[len(inp):] for inp, out in zip(inputs["input_ids"], output_ids)]
        output_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
        extracted = output_text[0].strip() if output_text else ""
        return extracted or "No text extracted from the image."
    except Exception as e:
        return f"An error occurred: {str(e)}"
# Highlight keyword search
def highlight_text(text, search_term):
    """Return ``text`` as HTML with every case-insensitive match of ``search_term`` highlighted.

    The result is rendered with ``st.markdown(..., unsafe_allow_html=True)``.
    Both the surrounding text and the matched substrings are therefore
    HTML-escaped: OCR output and the search box are untrusted and must not
    be able to inject markup. Only the highlight ``<span>`` wrappers are raw HTML.

    Args:
        text: the text to search.
        search_term: a literal (not regex) term; empty/None disables highlighting.

    Returns:
        HTML-escaped text with each match wrapped in a yellow-background span.
    """
    import html

    if not search_term:
        return html.escape(text)
    pattern = re.compile(re.escape(search_term), re.IGNORECASE)
    parts = []
    cursor = 0
    for match in pattern.finditer(text):
        # Escape each segment separately: escaping the whole text first would let
        # terms like "&" or "lt" match inside the generated entities.
        parts.append(html.escape(text[cursor:match.start()]))
        parts.append(f'<span style="background-color: yellow;">{html.escape(match.group())}</span>')
        cursor = match.end()
    parts.append(html.escape(text[cursor:]))
    return "".join(parts)
# Title and UI
st.title("DualTextOCRFusion - 🔍")
st.header("OCR Application - Multimodel Support")
st.write("Upload an image for OCR using various models, with support for English, Hindi, and Hinglish.")

# Sidebar Configuration
st.sidebar.header("Configuration")
model_choice = st.sidebar.selectbox("Select OCR Model:", ("GOT_CPU", "GOT_GPU", "Qwen", "Surya (English+Hindi)"))

# Upload Section
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
clipboard_text = st.sidebar.text_area("Paste image path from clipboard:")

if uploaded_file or clipboard_text:
    image_path = None
    if uploaded_file:
        # The filename comes from the client; keep only its final component so
        # it cannot write outside IMAGES_DIR.
        image_path = os.path.join(IMAGES_DIR, os.path.basename(uploaded_file.name))
        with open(image_path, "wb") as f:
            f.write(uploaded_file.getvalue())
    elif clipboard_text:
        image_path = clipboard_text.strip()

    # A pasted path may be mistyped; report it instead of crashing in Image.open.
    if image_path and not os.path.isfile(image_path):
        st.sidebar.error(f"Image not found: {image_path}")
        image_path = None

    # Predict button
    predict_button = st.sidebar.button("Predict")

    # Main columns
    col1, col2 = st.columns([2, 1])

    # Results are cached per image AND per model, so switching models does not
    # return another model's output. The model name is reduced to filename-safe characters.
    result_json_path = None
    if image_path:
        model_tag = re.sub(r"\W+", "_", model_choice).strip("_")
        result_json_path = os.path.join(RESULTS_DIR, f"{os.path.basename(image_path)}_{model_tag}_result.json")

    if predict_button and image_path:
        if os.path.exists(result_json_path):
            with open(result_json_path, "r", encoding="utf-8") as json_file:
                result_data = json.load(json_file)
            polished_text = result_data.get("polished_text", "")
        else:
            with st.spinner("Processing..."):
                if model_choice == "GOT_CPU":
                    got_model, tokenizer = init_got_model()
                    extracted_text = extract_text_got(image_path, got_model, tokenizer)
                elif model_choice == "GOT_GPU":
                    got_gpu_model, tokenizer = init_got_gpu_model()
                    extracted_text = extract_text_got(image_path, got_gpu_model, tokenizer)
                elif model_choice == "Qwen":
                    qwen_model, qwen_processor = init_qwen_model()
                    extracted_text = extract_text_qwen(image_path, qwen_model, qwen_processor)
                elif model_choice == "Surya (English+Hindi)":
                    # Only Surya needs a decoded PIL image; close the file once converted.
                    with Image.open(image_path) as img:
                        image = img.convert("RGB")
                    langs = ["en", "hi"]
                    predictions = run_ocr([image], [langs], det_model, det_processor, rec_model, rec_processor)
                    # Read the recognised lines directly; regex over the repr broke on
                    # lines containing quotes or escape sequences.
                    extracted_text = ' '.join(line.text for line in predictions[0].text_lines)
                cleaned_text = clean_extracted_text(extracted_text)
                polished_text = polish_text_with_ai(cleaned_text) if model_choice in ["GOT_CPU", "GOT_GPU"] else cleaned_text
                # Save result to JSON (UTF-8 so Hindi text stays readable on disk)
                with open(result_json_path, "w", encoding="utf-8") as json_file:
                    json.dump({"polished_text": polished_text}, json_file, ensure_ascii=False)
        # A button is True for a single run only, and typing in the search box
        # triggers a rerun; keep the result so it does not disappear.
        st.session_state["ocr_result"] = {"path": result_json_path, "text": polished_text}

    stored_result = st.session_state.get("ocr_result")
    polished_text = stored_result["text"] if stored_result and stored_result["path"] == result_json_path else None

    # Display image preview and text
    if image_path:
        with col1:
            col1.image(image_path, caption='Uploaded Image', use_column_width=False, width=300)
            if polished_text is not None:
                st.subheader("Extracted Text (Cleaned & Polished)")
                # OCR output is untrusted: render it without raw HTML.
                st.markdown(polished_text)

                # Input box for real-time search (its value persists via the widget key)
                search_query = st.text_input("Search in extracted text:", key="search_query", placeholder="Type to search...")

                # Highlight the search term in the text
                if search_query:
                    highlighted_text = highlight_text(polished_text, search_query)
                    st.markdown("### Highlighted Search Results:")
                    st.markdown(highlighted_text, unsafe_allow_html=True)