# Source: Hugging Face upload by srikar-v05 ("Upload folder using
# huggingface_hub", commit f563d24, 3.57 kB).
import html
import os
import re
import tempfile
import warnings

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
# Silence every Python warning for the whole process (this hides warnings
# raised by transformers / gradio as well).
warnings.simplefilter("ignore")

# Optional Hugging Face access token; None when HF_TOKEN is not set.
hf_token = os.getenv("HF_TOKEN")

_MODEL_REPO = 'ucaslcl/GOT-OCR2_0'
# Place the model on the GPU when one is visible to torch, otherwise on CPU.
_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# trust_remote_code=True lets transformers execute the code shipped in the
# model repository.
tokenizer = AutoTokenizer.from_pretrained(
    _MODEL_REPO,
    trust_remote_code=True,
    use_auth_token=hf_token,
)
model = AutoModel.from_pretrained(
    _MODEL_REPO,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map=_DEVICE,
    use_safetensors=True,
    pad_token_id=tokenizer.eos_token_id,
    use_auth_token=hf_token,
).eval()  # inference mode; eval() returns the model itself

# Text of the most recent OCR run. It is a module-level global, so it is
# shared by every user of the app; written by perform_ocr, read by search_text.
ocr_result = ""
def perform_ocr(image):
    """Run GOT-OCR on an uploaded image and remember the recognised text.

    Args:
        image: Value of the ``gr.Image(type="numpy")`` input, i.e. an array
            accepted by ``PIL.Image.fromarray``, or ``None`` when no image
            has been uploaded.

    Returns:
        The recognised text ("" when *image* is None). It is also stored in
        the module-level ``ocr_result`` so that ``search_text`` can query it.
    """
    global ocr_result
    if image is None:
        # Nothing uploaded: clear the stored result instead of crashing
        # inside Image.fromarray(None).
        ocr_result = ""
        return ocr_result

    # model.chat is given a file path, so the image is written to disk first.
    # A unique temporary file (not a fixed name) keeps concurrent requests
    # from overwriting or deleting each other's image.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        image_file = tmp.name
    try:
        Image.fromarray(image).save(image_file)
        with torch.no_grad():
            ocr_result = model.chat(tokenizer, image_file, ocr_type='ocr')
    finally:
        # Remove the temporary file even when saving or inference fails.
        os.remove(image_file)
    return ocr_result
def highlight_text(text, query):
    """Return *text* as HTML with each case-insensitive match of *query* highlighted.

    Every occurrence is wrapped in a light-blue ``<span>``. The matched text
    keeps its original casing, and all text is HTML-escaped so that OCR output
    or a query containing ``<``, ``>`` or ``&`` cannot inject markup.

    Args:
        text: Plain text to search, e.g. the OCR result.
        query: Literal search string. Regex metacharacters and backslashes
            are matched literally and never interpreted.

    Returns:
        An HTML string; for an empty *query* it is simply the escaped *text*.
    """
    if not query:
        # An empty pattern would match between every pair of characters.
        return html.escape(text)
    # With a capturing group, re.split keeps the matches at the odd indices.
    # The matches are inserted via a generator rather than a sub() template,
    # so backslashes in the query cannot be misread as group references.
    parts = re.split(f"({re.escape(query)})", text, flags=re.IGNORECASE)
    return "".join(
        f"<span style='background-color: #ADD8E6; color: black;'>{html.escape(part)}</span>"
        if index % 2
        else html.escape(part)
        for index, part in enumerate(parts)
    )
def search_text(query):
    """Search the most recent OCR result for *query*.

    Args:
        query: Text from the search box, matched literally and
            case-insensitively.

    Returns:
        A ``(full_text_html, matches_html)`` pair for the two ``gr.HTML``
        outputs:
        - the OCR text with matches highlighted by ``highlight_text``;
        - the matching lines joined with ``<br>``, or "No matches found."
          when nothing matches or *query* is empty.
    """
    if not query:
        # Nothing to search for: show the OCR text unchanged, escaped for HTML.
        return html.escape(ocr_result), "No matches found."

    highlighted_result = highlight_text(ocr_result, query)

    needle = query.lower()
    matching_lines = [
        html.escape(line)
        for line in ocr_result.split('\n')
        if needle in line.lower()
    ]
    if not matching_lines:
        return highlighted_result, "No matches found."
    # The output is a gr.HTML component, where a bare '\n' collapses into a
    # space; <br> keeps one matching line per row.
    return highlighted_result, '<br>'.join(matching_lines)
# Gradio UI: an OCR panel (upload + run) stacked above a search panel.
with gr.Blocks() as demo:
    # OCR panel: image upload, HTML view of the recognised text, run button.
    with gr.Row(), gr.Column():
        image_input = gr.Image(type="numpy", label="Upload Image")
        # HTML rather than a textbox so that search highlights can be rendered.
        ocr_output = gr.HTML(label="OCR Output")
        ocr_button = gr.Button("Run OCR")

    # Search panel: query box, HTML list of matching lines, search button.
    with gr.Row(), gr.Column():
        search_input = gr.Textbox(label="Search Text")
        search_output = gr.HTML(label="Search Result")
        search_button = gr.Button("Search in OCR Text")

    # Event wiring. A search also rewrites ocr_output with the highlighted text.
    ocr_button.click(fn=perform_ocr, inputs=image_input, outputs=ocr_output)
    search_button.click(
        fn=search_text,
        inputs=search_input,
        outputs=[ocr_output, search_output],
    )

# share=True asks Gradio to create a public share link as well as the local server.
demo.launch(share=True)