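"""
Gradio app for testing YOLO object-detection models.

Model weights and metadata are read from models_info.json; any missing weights
are downloaded into the local models/ directory when the app starts. Run with
`python app.py` (requires gradio, ultralytics, pillow, and requests).
"""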
import gradio as gr
import json
import os
from pathlib import Path
from PIL import Image
import shutil
from ultralytics import YOLO
import requests
MODELS_DIR = "models"
MODELS_INFO_FILE = "models_info.json"
TEMP_DIR = "temp"
OUTPUT_DIR = "outputs"
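# Illustrative sketch of a models_info.json entry, inferred from the keys read
# below; all names and values here are placeholders, not the real metadata.
# [
#   {
#     "model_name": "example_model",
#     "display_name": "Example Model",
#     "download_url": "https://example.com/example_model.pt",
#     "architecture": "...",
#     "training_epochs": "...",
#     "batch_size": "...",
#     "optimizer": "...",
#     "learning_rate": "...",
#     "data_augmentation_level": "...",
#     "mAP_score": "...",
#     "num_images": "...",
#     "class_ids": {"0": "example_class"},
#     "datasets_used": ["example_dataset"],
#     "class_image_counts": {"example_class": 123}
#   }
# ]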
def download_file(url, dest_path):
"""
Download a file from a URL to the destination path.
Args:
url (str): The URL to download from.
dest_path (str): The local path to save the file.
Returns:
bool: True if download succeeded, False otherwise.
"""
try:
response = requests.get(url, stream=True)
response.raise_for_status()
with open(dest_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
print(f"Downloaded {url} to {dest_path}.")
return True
except Exception as e:
print(f"Failed to download {url}. Error: {e}")
return False
def load_models(models_dir=MODELS_DIR, info_file=MODELS_INFO_FILE):
"""
Load YOLO models and their information from the specified directory and JSON file.
Downloads models if they are not already present.
Args:
models_dir (str): Path to the models directory.
info_file (str): Path to the JSON file containing model info.
Returns:
dict: A dictionary of models and their associated information.
"""
with open(info_file, 'r') as f:
models_info = json.load(f)
models = {}
for model_info in models_info:
model_name = model_info['model_name']
display_name = model_info.get('display_name', model_name)
model_dir = os.path.join(models_dir, model_name)
os.makedirs(model_dir, exist_ok=True)
model_path = os.path.join(model_dir, f"{model_name}.pt")
download_url = model_info['download_url']
if not os.path.isfile(model_path):
print(f"Model '{display_name}' not found locally. Downloading from {download_url}...")
success = download_file(download_url, model_path)
if not success:
print(f"Skipping model '{display_name}' due to download failure.")
continue
try:
model = YOLO(model_path)
models[model_name] = {
'display_name': display_name,
'model': model,
'info': model_info
}
print(f"Loaded model '{display_name}' from '{model_path}'.")
except Exception as e:
print(f"Error loading model '{display_name}': {e}")
return models
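# The dict returned by load_models() maps each model_name to:
#   {"display_name": str, "model": YOLO, "info": dict}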
def get_model_info(model_info):
"""
Retrieve formatted model information for display.
Args:
model_info (dict): The model's information dictionary.
Returns:
str: A formatted string containing model details.
"""
info = model_info
class_ids = info.get('class_ids', {})
class_image_counts = info.get('class_image_counts', {})
datasets_used = info.get('datasets_used', [])
class_ids_formatted = "\n".join([f"{cid}: {cname}" for cid, cname in class_ids.items()])
class_image_counts_formatted = "\n".join([f"{cname}: {count}" for cname, count in class_image_counts.items()])
datasets_used_formatted = "\n".join([f"- {dataset}" for dataset in datasets_used])
info_text = (
f"**{info.get('display_name', 'Model Name')}**\n\n"
f"**Architecture:** {info.get('architecture', 'N/A')}\n\n"
f"**Training Epochs:** {info.get('training_epochs', 'N/A')}\n\n"
f"**Batch Size:** {info.get('batch_size', 'N/A')}\n\n"
f"**Optimizer:** {info.get('optimizer', 'N/A')}\n\n"
f"**Learning Rate:** {info.get('learning_rate', 'N/A')}\n\n"
f"**Data Augmentation Level:** {info.get('data_augmentation_level', 'N/A')}\n\n"
f"**[email protected]:** {info.get('mAP_score', 'N/A')}\n\n"
f"**Number of Images Trained On:** {info.get('num_images', 'N/A')}\n\n"
f"**Class IDs:**\n{class_ids_formatted}\n\n"
f"**Datasets Used:**\n{datasets_used_formatted}\n\n"
f"**Class Image Counts:**\n{class_image_counts_formatted}"
)
return info_text
def predict_image(model_name, image, confidence, models):
"""
Perform prediction on an uploaded image using the selected YOLO model.
Args:
model_name (str): The name of the selected model.
image (PIL.Image.Image): The uploaded image.
confidence (float): The confidence threshold for detections.
models (dict): The dictionary containing models and their info.
Returns:
tuple: A status message, the processed image, and the path to the output image.
"""
model_entry = models.get(model_name, {})
model = model_entry.get('model', None)
if not model:
return "Error: Model not found.", None, None
try:
os.makedirs(TEMP_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
input_image_path = os.path.join(TEMP_DIR, f"{model_name}_input_image.jpg")
        # Convert to RGB so images with an alpha channel can still be saved as JPEG.
        image.convert("RGB").save(input_image_path)
results = model(input_image_path, save=True, save_txt=False, conf=confidence)
        # Ultralytics saves annotated images under runs/detect/predict*; use the most recent run.
        run_dirs = sorted(Path("runs/detect").glob("predict*"), key=os.path.getmtime)
        output_image_path = os.path.join(run_dirs[-1], Path(input_image_path).name) if run_dirs else ""
        if not os.path.isfile(output_image_path):
            # Fall back to saving the annotated result directly (Results.save() returns the saved filename).
            output_image_path = results[0].save()
final_output_path = os.path.join(OUTPUT_DIR, f"{model_name}_output_image.jpg")
shutil.copy(output_image_path, final_output_path)
output_image = Image.open(final_output_path)
return "βœ… Prediction completed successfully.", output_image, final_output_path
except Exception as e:
return f"❌ Error during prediction: {str(e)}", None, None
def main():
models = load_models()
if not models:
print("No models loaded. Please check your models_info.json and model URLs.")
return
with gr.Blocks() as demo:
gr.Markdown("# πŸ§ͺ YOLOv11 Model Tester")
gr.Markdown(
"""
Upload images to test different YOLOv11 models. Select a model from the dropdown to see its details.
"""
)
with gr.Row():
model_dropdown = gr.Dropdown(
choices=[models[m]['display_name'] for m in models],
label="Select Model",
value=None
)
model_info = gr.Markdown("**Model Information will appear here.**")
display_to_name = {models[m]['display_name']: m for m in models}
def update_model_info(selected_display_name):
if not selected_display_name:
return "Please select a model."
model_name = display_to_name.get(selected_display_name)
if not model_name:
return "Model information not available."
model_entry = models[model_name]['info']
return get_model_info(model_entry)
model_dropdown.change(
fn=update_model_info,
inputs=model_dropdown,
outputs=model_info
)
with gr.Row():
confidence_slider = gr.Slider(
minimum=0.0,
maximum=1.0,
step=0.01,
value=0.25,
label="Confidence Threshold",
info="Adjust the minimum confidence required for detections to be displayed."
)
        with gr.Tab("🖼️ Image"):
with gr.Column():
image_input = gr.Image(
type='pil',
label="Upload Image for Prediction"
)
                image_predict_btn = gr.Button("🔍 Predict on Image")
image_status = gr.Markdown("**Status will appear here.**")
image_output = gr.Image(label="Predicted Image")
                image_download_btn = gr.File(label="⬇️ Download Predicted Image")
def process_image(selected_display_name, image, confidence):
if not selected_display_name:
return "❌ Please select a model.", None, None
model_name = display_to_name.get(selected_display_name)
return predict_image(model_name, image, confidence, models)
image_predict_btn.click(
fn=process_image,
inputs=[model_dropdown, image_input, confidence_slider],
outputs=[image_status, image_output, image_download_btn]
)
gr.Markdown(
"""
---
            **Note:** Models are downloaded from GitHub at startup if they are not already present locally. Ensure that you have a stable internet connection and sufficient storage space.
"""
)
demo.launch()
if __name__ == "__main__":
main()