import gradio as gr
import spaces
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, Qwen2_5_VLForConditionalGeneration
from qwen_vl_utils import process_vision_info
import torch
from PIL import Image
import subprocess
from datetime import datetime
import numpy as np
import os
from gliner import GLiNER
import json
import tempfile
import zipfile

# Initialize GLiNER model
gliner_model = GLiNER.from_pretrained("knowledgator/modern-gliner-bi-large-v1.0")
DEFAULT_NER_LABELS = "person, organization, location, date, event"

# subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

# models = {
#     "Qwen/Qwen2-VL-7B-Instruct": AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True, torch_dtype="auto", _attn_implementation="flash_attention_2").cuda().eval()
# }


class TextWithMetadata(list):
    def __init__(self, *args, **kwargs):
        super().__init__(*args)
        self.original_text = kwargs.get('original_text', '')
        self.entities = kwargs.get('entities', [])


def array_to_image_path(image_array):
    # Convert numpy array to PIL Image
    img = Image.fromarray(np.uint8(image_array))
    img.thumbnail((1024, 1024))

    # Generate a unique filename using timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"image_{timestamp}.png"

    # Save the image
    img.save(filename)

    # Get the full path of the saved image
    full_path = os.path.abspath(filename)

    return full_path


models = {
    "Qwen/Qwen2.5-VL-7B-Instruct": Qwen2_5_VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_code=True, torch_dtype="auto"
    ).cuda().eval()
}

processors = {
    "Qwen/Qwen2.5-VL-7B-Instruct": AutoProcessor.from_pretrained(
        "Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_code=True
    )
}

DESCRIPTION = "This demo uses [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)"

kwargs = {}
kwargs['torch_dtype'] = torch.bfloat16

user_prompt = '<|user|>\n'
assistant_prompt = '<|assistant|>\n'
prompt_suffix = "<|end|>\n"


@spaces.GPU
def run_example(image, model_id="Qwen/Qwen2.5-VL-7B-Instruct", run_ner=False, ner_labels=DEFAULT_NER_LABELS):
    # First get the OCR text
    text_input = "Convert the image to text."
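    # Pipeline overview (comments added for clarity): the uploaded image is saved to disk
    # and referenced by path in a chat message; the Qwen2.5-VL processor builds the model
    # input via apply_chat_template/process_vision_info; the model generates the OCR text;
    # GLiNER then optionally tags entities in that text for HighlightedText display.
    # Note: the f-string `prompt` built below uses an older <|user|>/<|assistant|> template
    # and is never passed to the processor; the chat-template path is what actually runs.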
    image_path = array_to_image_path(image)

    model = models[model_id]
    processor = processors[model_id]

    prompt = f"{user_prompt}<|image_1|>\n{text_input}{prompt_suffix}{assistant_prompt}"
    image = Image.fromarray(image).convert("RGB")
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image_path,
                },
                {"type": "text", "text": text_input},
            ],
        }
    ]

    # Preparation for inference
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")

    # Inference: Generation of the output
    generated_ids = model.generate(**inputs, max_new_tokens=1024)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    ocr_text = output_text[0]

    # If NER is enabled, process the OCR text
    if run_ner:
        # Strip whitespace so labels like " organization" still match
        ner_results = gliner_model.predict_entities(
            ocr_text, [label.strip() for label in ner_labels.split(",")], threshold=0.3
        )

        # Create a list of tuples (text, label) for highlighting
        highlighted_text = []
        last_end = 0

        # Sort entities by start position
        sorted_entities = sorted(ner_results, key=lambda x: x["start"])

        # Process each entity and add non-entity text segments
        for entity in sorted_entities:
            # Add non-entity text before the current entity
            if last_end < entity["start"]:
                highlighted_text.append((ocr_text[last_end:entity["start"]], None))

            # Add the entity text with its label
            highlighted_text.append((
                ocr_text[entity["start"]:entity["end"]],
                entity["label"]
            ))
            last_end = entity["end"]

        # Add any remaining text after the last entity
        if last_end < len(ocr_text):
            highlighted_text.append((ocr_text[last_end:], None))

        # Create TextWithMetadata instance with the highlighted text and metadata
        result = TextWithMetadata(highlighted_text, original_text=ocr_text, entities=ner_results)
        return result, result  # Return twice: once for display, once for state

    # If NER is disabled, return the text without highlighting
    result = TextWithMetadata([(ocr_text, None)], original_text=ocr_text, entities=[])
    return result, result  # Return twice: once for display, once for state


css = """
/* Overall app styling */
.gradio-container {
    max-width: 1200px !important;
    margin: 0 auto;
    padding: 20px;
    background-color: #f8f9fa;
}

/* Tabs styling */
.tabs {
    border-radius: 8px;
    background: white;
    padding: 20px;
    box-shadow: 0 2px 6px rgba(0, 0, 0, 0.1);
}

/* Input/Output containers */
.input-container, .output-container {
    background: white;
    border-radius: 8px;
    padding: 15px;
    margin: 10px 0;
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
}

/* Button styling */
.submit-btn {
    background-color: #2d31fa !important;
    border: none !important;
    padding: 8px 20px !important;
    border-radius: 6px !important;
    color: white !important;
    transition: all 0.3s ease !important;
}
.submit-btn:hover {
    background-color: #1f24c7 !important;
    transform: translateY(-1px);
}

/* Output text area */
#output {
    height: 500px;
    overflow: auto;
    border: 1px solid #e0e0e0;
    border-radius: 6px;
    padding: 15px;
    background: #ffffff;
    font-family: 'Arial', sans-serif;
}

/* Dropdown styling */
.gr-dropdown {
    border-radius: 6px !important;
    border: 1px solid #e0e0e0 !important;
}

/* Image upload area */
.gr-image-input {
    border: 2px dashed #ccc;
    border-radius: 8px;
    padding: 20px;
    transition: all 0.3s ease;
}
.gr-image-input:hover {
    border-color: #2d31fa;
}
"""

with gr.Blocks(css=css) as demo:
    # Add state variables to store OCR results
    ocr_state = gr.State()

    gr.Image("Caracal.jpg", interactive=False)
    with gr.Tab(label="Image Input", elem_classes="tabs"):
        with gr.Row():
            with gr.Column(elem_classes="input-container"):
                input_img = gr.Image(label="Input Picture", elem_classes="gr-image-input")
                model_selector = gr.Dropdown(
                    choices=list(models.keys()),
                    label="Model",
                    value="Qwen/Qwen2.5-VL-7B-Instruct",
                    elem_classes="gr-dropdown"
                )

                # Add NER controls
                with gr.Row():
                    ner_checkbox = gr.Checkbox(label="Run Named Entity Recognition", value=False)
                    ner_labels = gr.Textbox(
                        label="NER Labels (comma-separated)",
                        value=DEFAULT_NER_LABELS,
                        visible=False
                    )

                submit_btn = gr.Button(value="Submit", elem_classes="submit-btn")
            with gr.Column(elem_classes="output-container"):
                output_text = gr.HighlightedText(label="Output Text", elem_id="output")

        # Show/hide NER labels based on checkbox
        ner_checkbox.change(
            lambda x: gr.update(visible=x),
            inputs=[ner_checkbox],
            outputs=[ner_labels]
        )

        # Modify the submit button click handler to update state
        submit_btn.click(
            run_example,
            inputs=[input_img, model_selector, ner_checkbox, ner_labels],
            outputs=[output_text, ocr_state]  # Add ocr_state to outputs
        )

        with gr.Row():
            filename = gr.Textbox(label="Save filename (without extension)", placeholder="Enter filename to save")
            download_btn = gr.Button("Download Image & Text", elem_classes="submit-btn")
            download_output = gr.File(label="Download")

    # Modify create_zip to use the state data
    def create_zip(image, fname, ocr_result):
        # Validate inputs
        if not fname or image is None:  # Changed the validation check
            return None

        try:
            # Convert numpy array to PIL Image if needed
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            elif not isinstance(image, Image.Image):
                return None

            with tempfile.TemporaryDirectory() as temp_dir:
                # Save image
                img_path = os.path.join(temp_dir, f"{fname}.png")
                image.save(img_path)

                # Use the OCR result from state
                original_text = ocr_result.original_text if ocr_result else ""
                entities = ocr_result.entities if ocr_result else []

                # Save text
                txt_path = os.path.join(temp_dir, f"{fname}.txt")
                with open(txt_path, 'w', encoding='utf-8') as f:
                    f.write(original_text)

                # Create JSON with text and entities
                json_data = {
                    "text": original_text,
                    "entities": entities,
                    "image_file": f"{fname}.png"
                }

                # Save JSON
                json_path = os.path.join(temp_dir, f"{fname}.json")
                with open(json_path, 'w', encoding='utf-8') as f:
                    json.dump(json_data, f, indent=2, ensure_ascii=False)

                # Create zip file
                output_dir = "downloads"
                os.makedirs(output_dir, exist_ok=True)
                zip_path = os.path.join(output_dir, f"{fname}.zip")

                with zipfile.ZipFile(zip_path, 'w') as zipf:
                    zipf.write(img_path, os.path.basename(img_path))
                    zipf.write(txt_path, os.path.basename(txt_path))
                    zipf.write(json_path, os.path.basename(json_path))

                return zip_path

        except Exception as e:
            print(f"Error creating zip: {str(e)}")
            return None

    # Update the download button click handler to include state
    download_btn.click(
        create_zip,
        inputs=[input_img, filename, ocr_state],
        outputs=[download_output]
    )

demo.queue(api_open=False)
demo.launch(debug=True)
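
# --------------------------------------------------------------------------------------
# Optional sanity check (a minimal sketch, not part of the app): verifies that the GLiNER
# label set extracts entities from a small made-up sentence before relying on OCR output.
# Kept commented out so it never interferes with the Space; the sample sentence below is
# an assumption, not data from the app.
#
# sample = "Alan Turing joined the Government Code and Cypher School at Bletchley Park in 1939."
# labels = [label.strip() for label in DEFAULT_NER_LABELS.split(",")]
# for ent in gliner_model.predict_entities(sample, labels, threshold=0.3):
#     print(f'{ent["text"]!r} -> {ent["label"]} (score={ent["score"]:.2f})')
# --------------------------------------------------------------------------------------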