"""DualTextOCRFusion: a Streamlit front-end for several OCR back-ends.

The user uploads an image, picks one of the OCR models (GOT on CPU, GOT on
GPU, Qwen2-VL, or Surya for English+Hindi), and the extracted text is
whitespace-cleaned, optionally polished by a Groq-hosted LLM, and shown with
a case-insensitive keyword highlighter.
"""

import html
import os
import re
import tempfile

import streamlit as st
import torch
from groq import Groq
from PIL import Image
from surya.model.detection.model import (
    load_model as load_det_model,
    load_processor as load_det_processor,
)
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.ocr import run_ocr
from transformers import (
    AutoModel,
    AutoProcessor,
    AutoTokenizer,
    Qwen2VLForConditionalGeneration,
)

# Page configuration (kept directly after the imports, before any other
# Streamlit call, as in the original script).
st.set_page_config(page_title="DualTextOCRFusion", page_icon="🔍", layout="wide")

device = "cuda" if torch.cuda.is_available() else "cpu"

# Name of the environment variable that holds the Groq API key.
# SECURITY: the key used to be hard-coded in this file; treat that key as
# compromised and rotate it.
GROQ_API_KEY_ENV = "GROQ_API_KEY"

# Session-state slot that keeps the last OCR result alive across reruns.
RESULT_STATE_KEY = "polished_text"


# Load Surya OCR Models (English + Hindi)
@st.cache_resource
def init_surya_models():
    """Load the Surya detection and recognition models once per process.

    Streamlit re-executes this script on every widget interaction; caching
    prevents the models from being reloaded on each rerun.

    Returns:
        Tuple ``(det_processor, det_model, rec_model, rec_processor)`` with
        both models moved to ``device``.
    """
    detection_processor, detection_model = load_det_processor(), load_det_model()
    detection_model.to(device)
    recognition_model, recognition_processor = load_rec_model(), load_rec_processor()
    recognition_model.to(device)
    return detection_processor, detection_model, recognition_model, recognition_processor


# Module-level names are preserved for backward compatibility.
det_processor, det_model, rec_model, rec_processor = init_surya_models()


# Load GOT Models
@st.cache_resource
def init_got_model():
    """Load the CPU build of GOT-OCR and return ``(model, tokenizer)``."""
    tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
    model = AutoModel.from_pretrained(
        'srimanth-d/GOT_CPU',
        trust_remote_code=True,
        use_safetensors=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    return model.eval(), tokenizer


@st.cache_resource
def init_got_gpu_model():
    """Load the CUDA build of GOT-OCR and return ``(model, tokenizer)``.

    The model is placed on CUDA unconditionally, so callers must make sure
    a GPU is available before calling this.
    """
    tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
    model = AutoModel.from_pretrained(
        'ucaslcl/GOT-OCR2_0',
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        device_map='cuda',
        use_safetensors=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    return model.eval().cuda(), tokenizer


# Load Qwen Model
@st.cache_resource
def init_qwen_model():
    """Load Qwen2-VL-2B-Instruct on CPU and return ``(model, processor)``."""
    # NOTE(review): float16 weights on CPU may be slow or unsupported for
    # some ops depending on the torch build -- confirm on the target host.
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",
        device_map="cpu",
        torch_dtype=torch.float16,
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    return model.eval(), processor


# Text Cleaning AI - Clean spaces, handle dual languages
def clean_extracted_text(text):
    """Normalize whitespace in OCR output.

    Collapses every run of whitespace to a single space, trims the ends, and
    removes a whitespace character that directly precedes ``?``, ``.``, ``!``
    or ``,``.

    Args:
        text: Raw OCR text. ``None`` or an empty string yields ``""``.

    Returns:
        The cleaned text.
    """
    if not text:
        return ""
    cleaned_text = re.sub(r'\s+', ' ', text).strip()
    cleaned_text = re.sub(r'\s([?.!,])', r'\1', cleaned_text)
    return cleaned_text


# Polish the text using a model
def polish_text_with_ai(cleaned_text):
    """Ask a Groq-hosted chat model to correct and tidy the OCR text.

    The API key is read from the ``GROQ_API_KEY`` environment variable. When
    the key is not configured, or there is no text to polish, the input is
    returned unchanged so the OCR result is still usable.

    Args:
        cleaned_text: Whitespace-cleaned OCR text.

    Returns:
        The model's rewritten text, or ``cleaned_text`` when polishing is
        skipped.
    """
    if not cleaned_text:
        return cleaned_text

    api_key = os.environ.get(GROQ_API_KEY_ENV)
    if not api_key:
        # Best effort: polishing is optional, so degrade instead of failing.
        return cleaned_text

    prompt = f"Correct and clean the following text: '{cleaned_text}' and make it meaningful."
    client = Groq(api_key=api_key)
    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": "You are a meaningful sentence pedantic, you remove extra spaces in between words and word to make the sentence meaningful in English/Hindi/Hinglish according to the sentence."
            },
            {
                "role": "user",
                "content": prompt,
            }
        ],
        model="gemma2-9b-it",
    )
    return chat_completion.choices[0].message.content


# Extract text using GOT
def extract_text_got(image_file, model, tokenizer):
    """Run GOT-OCR in plain ``ocr`` mode on the image at ``image_file``."""
    return model.chat(tokenizer, image_file, ocr_type='ocr')


# Extract text using Qwen
def extract_text_qwen(image_file, model, processor, max_new_tokens=512):
    """Extract text from an image with Qwen2-VL.

    Args:
        image_file: Path or file-like object readable by ``PIL.Image.open``.
        model: Loaded ``Qwen2VLForConditionalGeneration`` instance.
        processor: Matching ``AutoProcessor``.
        max_new_tokens: Upper bound on generated tokens. Without an explicit
            bound, generation falls back to the library default length,
            which is too short for OCR output.

    Returns:
        The generated text, a fixed "no text" message when nothing was
        generated, or an ``"An error occurred: ..."`` string on failure.
    """
    try:
        # Close the file handle as soon as the pixels are loaded.
        with Image.open(image_file) as raw_image:
            image = raw_image.convert('RGB')
        conversation = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Extract text from this image."}]}]
        text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = processor(text=[text_prompt], images=[image], return_tensors="pt").to(model.device)
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
        # generate() returns prompt + completion; drop the prompt tokens so
        # the chat template is not echoed back as "extracted text".
        completion_ids = [
            full_ids[len(prompt_ids):]
            for prompt_ids, full_ids in zip(inputs.input_ids, output_ids)
        ]
        output_text = processor.batch_decode(completion_ids, skip_special_tokens=True)
        if output_text and output_text[0].strip():
            return output_text[0]
        return "No text extracted from the image."
    except Exception as e:
        # Deliberate best-effort: the failure is surfaced to the user as the
        # result text instead of aborting the Streamlit run.
        return f"An error occurred: {str(e)}"


# Highlight keyword search
def highlight_text(text, search_term):
    """Return ``text`` as HTML with every match of ``search_term`` highlighted.

    Matching is literal (the term is regex-escaped) and case-insensitive.
    Because the result is meant to be rendered with
    ``unsafe_allow_html=True``, everything that comes from ``text`` is
    HTML-escaped so OCR output cannot inject markup.

    Args:
        text: The text to search in.
        search_term: The term to highlight. When empty, ``text`` is returned
            unchanged (and unescaped).

    Returns:
        An HTML string with matches wrapped in a yellow-background ``<span>``.
    """
    if not search_term:
        return text

    pattern = re.compile(re.escape(search_term), re.IGNORECASE)
    pieces = []
    cursor = 0
    for match in pattern.finditer(text):
        pieces.append(html.escape(text[cursor:match.start()], quote=False))
        pieces.append(
            '<span style="background-color: yellow; color: black;">'
            f'{html.escape(match.group(), quote=False)}</span>'
        )
        cursor = match.end()
    pieces.append(html.escape(text[cursor:], quote=False))
    return "".join(pieces)


def _run_selected_model(model_choice, image_path):
    """Run the OCR back-end named by ``model_choice`` on ``image_path``."""
    if model_choice == "GOT_CPU":
        got_model, tokenizer = init_got_model()
        return extract_text_got(image_path, got_model, tokenizer)
    if model_choice == "GOT_GPU":
        got_gpu_model, tokenizer = init_got_gpu_model()
        return extract_text_got(image_path, got_gpu_model, tokenizer)
    if model_choice == "Qwen":
        qwen_model, qwen_processor = init_qwen_model()
        return extract_text_qwen(image_path, qwen_model, qwen_processor)
    if model_choice == "Surya (English+Hindi)":
        langs = ["en", "hi"]
        # Load the pixels and release the file handle so the temp file can
        # be deleted afterwards on every platform.
        with Image.open(image_path) as raw_image:
            rgb_image = raw_image.convert("RGB")
        predictions = run_ocr([rgb_image], [langs], det_model, det_processor, rec_model, rec_processor)
        # Read the recognized lines from the result object instead of
        # regex-parsing its repr, which breaks on text containing quotes.
        return ' '.join(line.text for line in predictions[0].text_lines)
    return ""


# Title and UI
st.title("DualTextOCRFusion - 🔍")
st.title("OCR Application - Multimodel Support")
st.write("Upload an image for OCR using various models, with support for English, Hindi, and Hinglish.")

# Sidebar Configuration
st.sidebar.header("Configuration")
model_choice = st.sidebar.selectbox(
    "Select OCR Model:",
    ("GOT_CPU", "GOT_GPU", "Qwen", "Surya (English+Hindi)"),
)

# Upload Section
uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["png", "jpg", "jpeg"])

# Predict button
predict_button = st.sidebar.button("Predict")

# Main columns
col1, col2 = st.columns([2, 1])

# Display image preview
if uploaded_file:
    image = Image.open(uploaded_file)
    with col1:
        col1.image(image, caption='Uploaded Image', use_column_width=False, width=300)
else:
    # Do not keep showing text that belongs to an image that was removed.
    st.session_state.pop(RESULT_STATE_KEY, None)

# Handle predictions
if predict_button and uploaded_file:
    if model_choice == "GOT_GPU" and not torch.cuda.is_available():
        st.error("GOT_GPU requires a CUDA-capable GPU, but none is available.")
    else:
        with st.spinner("Processing..."):
            # Save uploaded image
            with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
                temp_file.write(uploaded_file.getvalue())
                temp_file_path = temp_file.name

            try:
                extracted_text = _run_selected_model(model_choice, temp_file_path)
            finally:
                # Delete temp file even when the OCR back-end raises.
                if os.path.exists(temp_file_path):
                    os.remove(temp_file_path)

            # Clean extracted text
            cleaned_text = clean_extracted_text(extracted_text)

            # Optionally, polish text with AI model for better language flow
            if model_choice in ["GOT_CPU", "GOT_GPU"]:
                if not os.environ.get(GROQ_API_KEY_ENV):
                    st.warning(
                        f"{GROQ_API_KEY_ENV} is not set; showing cleaned text without AI polishing."
                    )
                polished_text = polish_text_with_ai(cleaned_text)
            else:
                polished_text = cleaned_text

        # Persist the result: typing in the search box triggers a rerun in
        # which the Predict button is no longer "pressed".
        st.session_state[RESULT_STATE_KEY] = polished_text

# Display extracted text and search functionality
polished_text = st.session_state.get(RESULT_STATE_KEY)
if polished_text is not None:
    st.subheader("Extracted Text (Cleaned & Polished)")
    # Rendered without unsafe_allow_html: OCR/LLM output is untrusted.
    st.markdown(polished_text)

    # Input box for real-time search
    search_query = st.text_input(
        "Search in extracted text:",
        key="search_query",
        placeholder="Type to search...",
    )

    # Update results dynamically based on the search term
    if search_query:
        # Highlight the search term in the text
        highlighted_text = highlight_text(polished_text, search_query)
        st.markdown("### Highlighted Search Results:")
        # HTML rendering is needed for the highlight; highlight_text has
        # already escaped everything that came from the OCR text.
        st.markdown(highlighted_text, unsafe_allow_html=True)
    else:
        # If no search term is provided, display the original text
        st.markdown("### Extracted Text:")
        st.markdown(polished_text)