import streamlit as st
from transformers import AutoModel, AutoTokenizer, Qwen2VLForConditionalGeneration, AutoProcessor
from surya.ocr import run_ocr
from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from PIL import Image
import torch
import os
import re
import json
from groq import Groq
# Page configuration
st.set_page_config(page_title="DualTextOCRFusion", page_icon="π", layout="wide")

device = "cuda" if torch.cuda.is_available() else "cpu"

# Directories for uploaded images and cached results
IMAGES_DIR = "images"
RESULTS_DIR = "results"
os.makedirs(IMAGES_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)
# Load Surya OCR detection and recognition models (English + Hindi)
det_processor, det_model = load_det_processor(), load_det_model()
det_model.to(device)
rec_model, rec_processor = load_rec_model(), load_rec_processor()
rec_model.to(device)
# Load GOT model (CPU variant)
def init_got_model():
    tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
    model = AutoModel.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
    return model.eval(), tokenizer

# Load GOT model (GPU variant)
def init_got_gpu_model():
    tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
    model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
    return model.eval().cuda(), tokenizer

# Load Qwen2-VL model
def init_qwen_model():
    model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", device_map="cpu", torch_dtype=torch.float16)
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    return model.eval(), processor
# Text cleaning - collapse whitespace and fix spacing around punctuation
def clean_extracted_text(text):
    cleaned_text = re.sub(r'\s+', ' ', text).strip()
    cleaned_text = re.sub(r'\s([?.!,])', r'\1', cleaned_text)
    return cleaned_text
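# Illustrative example (not part of the app flow): the cleaner collapses runs of
# whitespace and removes spaces before punctuation, e.g.
#   clean_extracted_text("Hello   world , how are  you ?")  ->  "Hello world, how are you?"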
# Polish the cleaned text with an LLM via the Groq API
def polish_text_with_ai(cleaned_text):
    # Pass the extracted text into the prompt (the original f-string never included it)
    prompt = (
        "Remove unwanted spaces between and inside words to join incomplete words, "
        "creating a meaningful sentence in either Hindi, English, or Hinglish without "
        "altering any words from the given extracted text. Then, return the corrected "
        f"text with adjusted spaces.\n\nExtracted text:\n{cleaned_text}"
    )
    # Read the API key from the environment rather than hard-coding a secret in source
    client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": "You are a pedantic sentence corrector."},
            {"role": "user", "content": prompt},
        ],
        model="gemma2-9b-it",
    )
    return chat_completion.choices[0].message.content
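# Usage sketch (illustrative input; requires GROQ_API_KEY set in the environment):
#   polish_text_with_ai("nam aste , kai se ho ?")  # -> joined/corrected sentence from the LLM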
# Extract text using GOT
def extract_text_got(image_file, model, tokenizer):
    return model.chat(tokenizer, image_file, ocr_type='ocr')
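# Note: GOT's model.chat takes an image file path (not a PIL image), which is why the
# app passes image_path below. Usage sketch with an illustrative path:
#   extract_text_got("images/sample.png", got_model, tokenizer)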
# Extract text using Qwen2-VL
def extract_text_qwen(image_file, model, processor):
    try:
        image = Image.open(image_file).convert('RGB')
        conversation = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Extract text from this image."}]}]
        text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = processor(text=[text_prompt], images=[image], return_tensors="pt")
        # Without max_new_tokens, generate() falls back to a very short default and truncates the OCR output
        output_ids = model.generate(**inputs, max_new_tokens=512)
        output_text = processor.batch_decode(output_ids, skip_special_tokens=True)
        return output_text[0] if output_text else "No text extracted from the image."
    except Exception as e:
        return f"An error occurred: {str(e)}"
# Highlight keyword matches in the extracted text
def highlight_text(text, search_term):
    if not search_term:
        return text
    pattern = re.compile(re.escape(search_term), re.IGNORECASE)
    return pattern.sub(lambda m: f'<span style="background-color: yellow;">{m.group()}</span>', text)
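# Illustrative example (not part of the app flow): matching is case-insensitive and
# the original casing is preserved in the replacement, e.g.
#   highlight_text("Dual OCR fusion", "ocr")
#   -> 'Dual <span style="background-color: yellow;">OCR</span> fusion'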
# Title and UI
st.title("DualTextOCRFusion - π")
st.header("OCR Application - Multi-Model Support")
st.write("Upload an image for OCR using various models, with support for English, Hindi, and Hinglish.")
# Sidebar configuration
st.sidebar.header("Configuration")
model_choice = st.sidebar.selectbox("Select OCR Model:", ("GOT_CPU", "GOT_GPU", "Qwen", "Surya (English+Hindi)"))

# Upload section
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])
clipboard_text = st.sidebar.text_area("Paste image path from clipboard:")
if uploaded_file or clipboard_text:
    image_path = None
    if uploaded_file:
        # Persist the upload so results can be cached under the file name
        image_path = os.path.join(IMAGES_DIR, uploaded_file.name)
        with open(image_path, "wb") as f:
            f.write(uploaded_file.getvalue())
    elif clipboard_text:
        image_path = clipboard_text.strip()

    # Predict button
    predict_button = st.sidebar.button("Predict")

    # Main columns
    col1, col2 = st.columns([2, 1])

    # Path of the cached result JSON for this image, if any
    result_json_path = os.path.join(RESULTS_DIR, f"{os.path.basename(image_path)}_result.json") if image_path else None
    # Initialize so the display section below never hits an undefined name
    polished_text = ""
    if predict_button and image_path:
        if os.path.exists(result_json_path):
            # Reuse the cached result instead of re-running OCR
            with open(result_json_path, "r") as json_file:
                result_data = json.load(json_file)
            polished_text = result_data.get("polished_text", "")
        else:
            with st.spinner("Processing..."):
                image = Image.open(image_path).convert("RGB")
                if model_choice == "GOT_CPU":
                    got_model, tokenizer = init_got_model()
                    extracted_text = extract_text_got(image_path, got_model, tokenizer)
                elif model_choice == "GOT_GPU":
                    got_gpu_model, tokenizer = init_got_gpu_model()
                    extracted_text = extract_text_got(image_path, got_gpu_model, tokenizer)
                elif model_choice == "Qwen":
                    qwen_model, qwen_processor = init_qwen_model()
                    extracted_text = extract_text_qwen(image_path, qwen_model, qwen_processor)
                elif model_choice == "Surya (English+Hindi)":
                    langs = ["en", "hi"]
                    predictions = run_ocr([image], [langs], det_model, det_processor, rec_model, rec_processor)
                    # Pull the recognized line texts out of the prediction's string repr
                    text_list = re.findall(r"text='(.*?)'", str(predictions[0]))
                    extracted_text = ' '.join(text_list)

                cleaned_text = clean_extracted_text(extracted_text)
                # Only GOT output goes through the LLM polishing step
                polished_text = polish_text_with_ai(cleaned_text) if model_choice in ["GOT_CPU", "GOT_GPU"] else cleaned_text

            # Cache the result as JSON
            with open(result_json_path, "w") as json_file:
                json.dump({"polished_text": polished_text}, json_file)
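    # Note: cached results are keyed by the image's base file name, so predicting the
    # same file again loads the stored JSON above instead of re-running OCR.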
    # Display image preview and extracted text
    if image_path:
        with col1:
            col1.image(image_path, caption='Uploaded Image', use_column_width=False, width=300)
        if polished_text:
            st.subheader("Extracted Text (Cleaned & Polished)")
            st.markdown(polished_text, unsafe_allow_html=True)

        # Input box for real-time search; Streamlit reruns the script on every edit,
        # so the broken on_change lambda (a syntax error in the original) is unnecessary
        search_query = st.text_input("Search in extracted text:", key="search_query", placeholder="Type to search...", disabled=not uploaded_file)
        # Highlight the search term in the text
        if search_query:
            highlighted_text = highlight_text(polished_text, search_query)
            st.markdown("### Highlighted Search Results:")
            st.markdown(highlighted_text, unsafe_allow_html=True)