import json
import os
import shutil
import tempfile
from pathlib import Path

import gradio as gr
from PIL import Image
from ultralytics import YOLO

# Minimum detection confidence used when a caller does not pass an explicit threshold.
DEFAULT_CONF = 0.25


def load_models(models_dir='models', info_file='models_info.json'):
    """
    Load YOLO models and their information from the specified directory and JSON file.

    Each entry in ``info_file`` is expected to be an object with a ``model_name`` key
    (the sub-directory of ``models_dir`` that holds ``best.pt``) and optional
    ``mAP_score`` / ``num_images`` keys. Entries that are malformed, whose weights are
    missing, or that fail to load are reported and skipped, so one bad model does not
    stop the app from starting.

    Args:
        models_dir (str): Path to the models directory.
        info_file (str): Path to the JSON file containing model info.

    Returns:
        dict: Maps model name to ``{'model': YOLO, 'mAP': ..., 'num_images': ...}``.
        Empty if the info file cannot be read or parsed.
    """
    try:
        with open(info_file, 'r', encoding='utf-8') as f:
            models_info = json.load(f)
    except (OSError, json.JSONDecodeError) as e:
        print(f"Could not read model info file '{info_file}': {e}")
        return {}

    models = {}
    for model_info in models_info:
        model_name = model_info.get('model_name') if isinstance(model_info, dict) else None
        if not model_name:
            print(f"Skipping model info entry without a 'model_name': {model_info!r}")
            continue

        # 'best.pt' is the weight file name this app expects inside each model directory.
        model_path = os.path.join(models_dir, model_name, 'best.pt')
        if not os.path.isfile(model_path):
            print(f"Model weight file for '{model_name}' not found at '{model_path}'. Skipping.")
            continue

        try:
            model = YOLO(model_path)
        except Exception as e:  # Best-effort: report the broken checkpoint and keep loading the rest.
            print(f"Error loading model '{model_name}': {e}")
            continue

        models[model_name] = {
            'model': model,
            'mAP': model_info.get('mAP_score', 'N/A'),
            'num_images': model_info.get('num_images', 'N/A'),
        }
        print(f"Loaded model '{model_name}' from '{model_path}'.")
    return models


def get_model_info(model_name, models):
    """
    Retrieve model information for the selected model.

    Args:
        model_name (str): The name of the model.
        models (dict): The dictionary containing models and their info.

    Returns:
        str: A Markdown-formatted string containing model information, or a
        fallback message when the model is unknown.
    """
    model_info = models.get(model_name, {})
    if not model_info:
        return "Model information not available."

    return (
        f"**Model Name:** {model_name}\n\n"
        f"**mAP Score:** {model_info.get('mAP', 'N/A')}\n\n"
        f"**Number of Images Trained On:** {model_info.get('num_images', 'N/A')}"
    )


def _new_output_path(prefix, suffix):
    """Create and return a unique, empty file path in the system temp directory."""
    # A unique name per request keeps concurrent users from overwriting each other's outputs.
    fd, path = tempfile.mkstemp(prefix=prefix, suffix=suffix)
    os.close(fd)
    return path


def predict_image(model_name, image, models, conf=DEFAULT_CONF):
    """
    Perform prediction on an uploaded image using the selected YOLO model.

    Args:
        model_name (str): The name of the selected model.
        image (PIL.Image.Image): The uploaded image, or None if nothing was uploaded.
        models (dict): The dictionary containing models and their info.
        conf (float): Minimum detection confidence passed to the model.

    Returns:
        tuple: A status message, the annotated PIL image (or None), and the path to
        the saved annotated JPEG (or None).
    """
    model = models.get(model_name, {}).get('model')
    if model is None:
        return "Error: Model not found.", None, None
    if image is None:
        return "Error: Please upload an image first.", None, None

    try:
        # Normalise to 3-channel RGB: RGBA/palette uploads (e.g. PNGs) cannot be written as JPEG.
        rgb_image = image.convert('RGB')
        results = model(rgb_image, conf=conf)

        # Results.plot() returns the annotated frame as a BGR numpy array; reverse channels for PIL.
        output_image = Image.fromarray(results[0].plot()[..., ::-1])
        output_image_path = _new_output_path(f"{model_name}_prediction_", ".jpg")
        output_image.save(output_image_path)

        return "Prediction completed successfully.", output_image, output_image_path
    except Exception as e:
        return f"Error during prediction: {str(e)}", None, None


def predict_video(model_name, video, models, conf=DEFAULT_CONF):
    """
    Perform prediction on an uploaded video using the selected YOLO model.

    Args:
        model_name (str): The name of the selected model.
        video (str): Path to the uploaded video file. An object exposing the path as
            ``.name`` (older Gradio file objects) is also accepted.
        models (dict): The dictionary containing models and their info.
        conf (float): Minimum detection confidence passed to the model.

    Returns:
        tuple: A status message, the processed video path (or None), and the same
        path for download (or None).
    """
    model = models.get(model_name, {}).get('model')
    if model is None:
        return "Error: Model not found.", None, None
    if not video:
        return "Error: Please upload a video first.", None, None

    try:
        # gr.Video hands over a filepath string; fall back to `.name` for file-like objects.
        input_video_path = video if isinstance(video, (str, os.PathLike)) else video.name
        if not os.path.isfile(input_video_path):
            return "Error: Uploaded video file not found.", None, None

        # stream=True yields per-frame results lazily instead of keeping every frame in RAM.
        # The annotated video is written by the predictor (save=True) and finalised once
        # the generator is exhausted.
        for _ in model(input_video_path, stream=True, save=True, save_txt=False, conf=conf):
            pass

        # The predictor names the output after the input's stem; the container suffix
        # (.avi/.mp4) depends on the Ultralytics version and OS, so match on the stem only.
        save_dir = Path(model.predictor.save_dir)
        stem = Path(input_video_path).stem
        candidates = sorted(
            p for p in save_dir.iterdir() if p.is_file() and p.stem == stem
        ) if save_dir.is_dir() else []
        if not candidates:
            return "Error: Prediction finished but no output video was written.", None, None

        output_video_path = str(candidates[0])
        return "Prediction completed successfully.", output_video_path, output_video_path
    except Exception as e:
        return f"Error during prediction: {str(e)}", None, None


def main():
    """Load the models, build the Gradio Blocks UI around them, and launch it."""
    models = load_models()
    if not models:
        print("No models were loaded; the model dropdown will be empty.")

    with gr.Blocks() as demo:
        gr.Markdown("# 🧪 YOLO Model Tester")
        gr.Markdown(
            """
            Upload images or videos to test different YOLO models.
            Select a model from the dropdown to see its details.
            """
        )

        # Model selection and info
        with gr.Row():
            model_dropdown = gr.Dropdown(
                choices=list(models.keys()),
                label="Select Model",
                value=None,
            )
            model_info = gr.Markdown("**Model Information will appear here.**")

        def show_model_info(model_name):
            # The dropdown emits None/"" when cleared.
            if not model_name:
                return "Please select a model."
            return get_model_info(model_name, models)

        model_dropdown.change(
            fn=show_model_info,
            inputs=model_dropdown,
            outputs=model_info,
        )

        # Tabs for different input types
        with gr.Tabs():
            # Image Prediction Tab
            with gr.Tab("🖼️ Image"):
                with gr.Column():
                    # `tool="editor"` was dropped: it is the Gradio 3 default for uploads
                    # and the parameter no longer exists in Gradio 4.
                    image_input = gr.Image(
                        type='pil',
                        label="Upload Image for Prediction",
                    )
                    image_predict_btn = gr.Button("🔍 Predict on Image")
                    image_status = gr.Markdown("**Status will appear here.**")
                    image_output = gr.Image(label="Predicted Image")
                    image_download_btn = gr.File(label="⬇️ Download Predicted Image")

                def process_image(model_name, image):
                    return predict_image(model_name, image, models)

                image_predict_btn.click(
                    fn=process_image,
                    inputs=[model_dropdown, image_input],
                    outputs=[image_status, image_output, image_download_btn],
                )

            # Video Prediction Tab
            with gr.Tab("🎥 Video"):
                with gr.Column():
                    video_input = gr.Video(label="Upload Video for Prediction")
                    video_predict_btn = gr.Button("🔍 Predict on Video")
                    video_status = gr.Markdown("**Status will appear here.**")
                    video_output = gr.Video(label="Predicted Video")
                    video_download_btn = gr.File(label="⬇️ Download Predicted Video")

                def process_video(model_name, video):
                    return predict_video(model_name, video, models)

                video_predict_btn.click(
                    fn=process_video,
                    inputs=[model_dropdown, video_input],
                    outputs=[video_status, video_output, video_download_btn],
                )

        gr.Markdown(
            """
            ---
            **Note:** Ensure that the YOLO models are correctly placed in the `models/` directory
            and that `models_info.json` is properly configured.
            """
        )

    demo.launch()


if __name__ == "__main__":
    main()