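"""Gradio app for testing trained YOLO models on uploaded images and videos."""
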
import gradio as gr
import json
import os
from pathlib import Path
from PIL import Image
from ultralytics import YOLO


def load_models(models_dir='models', info_file='models_info.json'):
    """
    Load YOLO models and their information from the specified directory and JSON file.

    Args:
        models_dir (str): Path to the models directory.
        info_file (str): Path to the JSON file containing model info.

    Returns:
        dict: A dictionary of models and their associated information.
    """
    with open(info_file, 'r') as f:
        models_info = json.load(f)

    models = {}
    for model_info in models_info:
        model_name = model_info['model_name']
        model_path = os.path.join(models_dir, model_name, 'best.pt')  # Assuming 'best.pt' as the weight file
        if os.path.isfile(model_path):
            try:
                # Load the YOLO model
                model = YOLO(model_path)
                models[model_name] = {
                    'model': model,
                    'mAP': model_info.get('mAP_score', 'N/A'),
                    'num_images': model_info.get('num_images', 'N/A')
                }
                print(f"Loaded model '{model_name}' from '{model_path}'.")
            except Exception as e:
                print(f"Error loading model '{model_name}': {e}")
        else:
            print(f"Model weight file for '{model_name}' not found at '{model_path}'. Skipping.")
    return models
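
# A minimal sketch of the models_info.json layout that load_models() expects:
# a JSON list of objects keyed by the model's folder name under models/, plus the
# optional metadata fields read via .get() above. The model names and numbers below
# are purely hypothetical examples, not values shipped with this Space:
#
#   [
#     {"model_name": "yolov8n-custom", "mAP_score": 0.82, "num_images": 1200},
#     {"model_name": "yolov8s-custom", "mAP_score": 0.87, "num_images": 3400}
#   ]
#
# with the corresponding weights stored as models/<model_name>/best.pt.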


def get_model_info(model_name, models):
    """
    Retrieve model information for the selected model.

    Args:
        model_name (str): The name of the model.
        models (dict): The dictionary containing models and their info.

    Returns:
        str: A formatted string containing model information.
    """
    model_info = models.get(model_name, {})
    if not model_info:
        return "Model information not available."
    info_text = (
        f"**Model Name:** {model_name}\n\n"
        f"**mAP Score:** {model_info.get('mAP', 'N/A')}\n\n"
        f"**Number of Images Trained On:** {model_info.get('num_images', 'N/A')}"
    )
    return info_text


def predict_image(model_name, image, models):
    """
    Perform prediction on an uploaded image using the selected YOLO model.

    Args:
        model_name (str): The name of the selected model.
        image (PIL.Image.Image): The uploaded image.
        models (dict): The dictionary containing models and their info.

    Returns:
        tuple: A status message, the processed image, and the path to the output image.
    """
    model = models.get(model_name, {}).get('model')
    if not model:
        return "Error: Model not found.", None, None
    try:
        # Save the uploaded image to a temporary path (convert to RGB so it can be written as JPEG)
        input_image_path = f"temp/{model_name}_input_image.jpg"
        os.makedirs(os.path.dirname(input_image_path), exist_ok=True)
        image.convert('RGB').save(input_image_path)
        # Perform prediction
        results = model(input_image_path, conf=0.25)
        # Results.plot() returns the annotated image as a BGR numpy array;
        # reverse the channel order to RGB and copy so PIL gets a contiguous array.
        annotated = results[0].plot()
        output_image = Image.fromarray(annotated[:, :, ::-1].copy())
        output_image_path = f"temp/{model_name}_output_image.jpg"
        output_image.save(output_image_path)
        return "Prediction completed successfully.", output_image, output_image_path
    except Exception as e:
        return f"Error during prediction: {str(e)}", None, None


def predict_video(model_name, video, models):
    """
    Perform prediction on an uploaded video using the selected YOLO model.

    Args:
        model_name (str): The name of the selected model.
        video (str or file-like): Path to the uploaded video (Gradio may pass a path string
            or a temp-file object with a .name attribute).
        models (dict): The dictionary containing models and their info.

    Returns:
        tuple: A status message, the processed video, and the path to the output video.
    """
    model = models.get(model_name, {}).get('model')
    if not model:
        return "Error: Model not found.", None, None
    try:
        # Resolve the uploaded video to a file path
        input_video_path = video if isinstance(video, str) else video.name
        if not os.path.isfile(input_video_path):
            return "Error: Uploaded video file not found.", None, None
        # Perform prediction; save=True makes Ultralytics write the annotated video
        results = model(input_video_path, save=True, save_txt=False, conf=0.25)
        # Ultralytics writes the annotated video into the run's save directory
        # (typically 'runs/detect/predict*'); locate it by the input file's stem.
        save_dir = Path(results[0].save_dir)
        matches = list(save_dir.glob(f"{Path(input_video_path).stem}.*"))
        if not matches:
            return "Error: Annotated video not found in the output directory.", None, None
        output_video_path = str(matches[0])
        return "Prediction completed successfully.", output_video_path, output_video_path
    except Exception as e:
        return f"Error during prediction: {str(e)}", None, None
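
# Quick local smoke test for the prediction helpers (hypothetical model name and file
# paths; assumes load_models() found at least one model):
#
#   models = load_models()
#   status, img, img_path = predict_image("yolov8n-custom", Image.open("samples/cat.jpg"), models)
#   status, vid_path, _ = predict_video("yolov8n-custom", "samples/traffic.mp4", models)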


def main():
    # Load the models and their information
    models = load_models()

    # Initialize Gradio Blocks interface
    with gr.Blocks() as demo:
        gr.Markdown("# 🧪 YOLO Model Tester")
        gr.Markdown(
            """
            Upload images or videos to test different YOLO models. Select a model from the dropdown to see its details.
            """
        )

        # Model selection and info
        with gr.Row():
            model_dropdown = gr.Dropdown(
                choices=list(models.keys()),
                label="Select Model",
                value=None
            )
            model_info = gr.Markdown("**Model Information will appear here.**")

        # Update model_info when a model is selected
        model_dropdown.change(
            fn=lambda model_name: get_model_info(model_name, models) if model_name else "Please select a model.",
            inputs=model_dropdown,
            outputs=model_info
        )

        # Tabs for different input types
        with gr.Tabs():
            # Image Prediction Tab
            with gr.Tab("🖼️ Image"):
                with gr.Column():
                    # Note: the 'tool="editor"' argument from older Gradio releases is
                    # omitted here; it is not available in Gradio 4.x.
                    image_input = gr.Image(
                        type='pil',
                        label="Upload Image for Prediction"
                    )
                    image_predict_btn = gr.Button("🔍 Predict on Image")
                    image_status = gr.Markdown("**Status will appear here.**")
                    image_output = gr.Image(label="Predicted Image")
                    image_download_btn = gr.File(label="⬇️ Download Predicted Image")

                # Define the image prediction function
                def process_image(model_name, image):
                    return predict_image(model_name, image, models)

                # Connect the predict button
                image_predict_btn.click(
                    fn=process_image,
                    inputs=[model_dropdown, image_input],
                    outputs=[image_status, image_output, image_download_btn]
                )

            # Video Prediction Tab
            with gr.Tab("🎥 Video"):
                with gr.Column():
                    video_input = gr.Video(
                        label="Upload Video for Prediction"
                    )
                    video_predict_btn = gr.Button("🔍 Predict on Video")
                    video_status = gr.Markdown("**Status will appear here.**")
                    video_output = gr.Video(label="Predicted Video")
                    video_download_btn = gr.File(label="⬇️ Download Predicted Video")

                # Define the video prediction function
                def process_video(model_name, video):
                    return predict_video(model_name, video, models)

                # Connect the predict button
                video_predict_btn.click(
                    fn=process_video,
                    inputs=[model_dropdown, video_input],
                    outputs=[video_status, video_output, video_download_btn]
                )

        gr.Markdown(
            """
            ---
            **Note:** Ensure that the YOLO models are correctly placed in the `models/` directory and that `models_info.json` is properly configured.
            """
        )

    # Launch the Gradio app
    demo.launch()


if __name__ == "__main__":
    main()
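
# To run this app locally (assumed dependencies, not pinned by this file):
#   pip install gradio ultralytics pillow
#   python app.py
# demo.launch() serves the interface on http://127.0.0.1:7860 by default.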