wuhp commited on
Commit
3414d75
Β·
verified Β·
1 Parent(s): 69a8551

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -145
app.py CHANGED
@@ -6,21 +6,24 @@ from PIL import Image
6
  import shutil
7
  from ultralytics import YOLO
8
  import requests
 
 
9
 
10
  # Constants
11
  MODELS_DIR = "models"
12
  MODELS_INFO_FILE = "models_info.json"
13
  TEMP_DIR = "temp"
14
  OUTPUT_DIR = "outputs"
 
15
 
16
  def download_file(url, dest_path):
17
  """
18
  Download a file from a URL to the destination path.
19
-
20
  Args:
21
  url (str): The URL to download from.
22
  dest_path (str): The local path to save the file.
23
-
24
  Returns:
25
  bool: True if download succeeded, False otherwise.
26
  """
@@ -40,17 +43,17 @@ def load_models(models_dir=MODELS_DIR, info_file=MODELS_INFO_FILE):
40
  """
41
  Load YOLO models and their information from the specified directory and JSON file.
42
  Downloads models if they are not already present.
43
-
44
  Args:
45
  models_dir (str): Path to the models directory.
46
  info_file (str): Path to the JSON file containing model info.
47
-
48
  Returns:
49
  dict: A dictionary of models and their associated information.
50
  """
51
  with open(info_file, 'r') as f:
52
  models_info = json.load(f)
53
-
54
  models = {}
55
  for model_info in models_info:
56
  model_name = model_info['model_name']
@@ -59,7 +62,7 @@ def load_models(models_dir=MODELS_DIR, info_file=MODELS_INFO_FILE):
59
  os.makedirs(model_dir, exist_ok=True)
60
  model_path = os.path.join(model_dir, f"{model_name}.pt") # e.g., models/human/human.pt
61
  download_url = model_info['download_url']
62
-
63
  # Check if the model file exists
64
  if not os.path.isfile(model_path):
65
  print(f"Model '{display_name}' not found locally. Downloading from {download_url}...")
@@ -67,7 +70,7 @@ def load_models(models_dir=MODELS_DIR, info_file=MODELS_INFO_FILE):
67
  if not success:
68
  print(f"Skipping model '{display_name}' due to download failure.")
69
  continue # Skip loading this model
70
-
71
  try:
72
  # Load the YOLO model
73
  model = YOLO(model_path)
@@ -79,16 +82,16 @@ def load_models(models_dir=MODELS_DIR, info_file=MODELS_INFO_FILE):
79
  print(f"Loaded model '{display_name}' from '{model_path}'.")
80
  except Exception as e:
81
  print(f"Error loading model '{display_name}': {e}")
82
-
83
  return models
84
 
85
  def get_model_info(model_info):
86
  """
87
  Retrieve formatted model information for display.
88
-
89
  Args:
90
  model_info (dict): The model's information dictionary.
91
-
92
  Returns:
93
  str: A formatted string containing model details.
94
  """
@@ -96,11 +99,11 @@ def get_model_info(model_info):
96
  class_ids = info.get('class_ids', {})
97
  class_image_counts = info.get('class_image_counts', {})
98
  datasets_used = info.get('datasets_used', [])
99
-
100
  class_ids_formatted = "\n".join([f"{cid}: {cname}" for cid, cname in class_ids.items()])
101
  class_image_counts_formatted = "\n".join([f"{cname}: {count}" for cname, count in class_image_counts.items()])
102
  datasets_used_formatted = "\n".join([f"- {dataset}" for dataset in datasets_used])
103
-
104
  info_text = (
105
  f"**{info.get('display_name', 'Model Name')}**\n\n"
106
  f"**Architecture:** {info.get('architecture', 'N/A')}\n\n"
@@ -117,66 +120,41 @@ def get_model_info(model_info):
117
  )
118
  return info_text
119
 
120
- def predict_image(model_name, image, confidence, models):
121
  """
122
- Perform prediction on an uploaded image using the selected YOLO model.
123
-
124
  Args:
125
- model_name (str): The name of the selected model.
126
- image (PIL.Image.Image): The uploaded image.
127
- confidence (float): The confidence threshold for detections.
128
- models (dict): The dictionary containing models and their info.
129
-
130
  Returns:
131
- tuple: A status message, the processed image, and the path to the output image.
132
  """
133
- model_entry = models.get(model_name, {})
134
- model = model_entry.get('model', None)
135
- if not model:
136
- return "Error: Model not found.", None, None
137
- try:
138
- # Ensure temporary and output directories exist
139
- os.makedirs(TEMP_DIR, exist_ok=True)
140
- os.makedirs(OUTPUT_DIR, exist_ok=True)
141
-
142
- # Save the uploaded image to a temporary path
143
- input_image_path = os.path.join(TEMP_DIR, f"{model_name}_input_image.jpg")
144
- image.save(input_image_path)
145
-
146
- # Perform prediction with user-specified confidence
147
- results = model(input_image_path, save=True, save_txt=False, conf=confidence)
148
-
149
- # Determine the output path
150
- # Ultralytics YOLO saves the results in 'runs/detect/predict' by default
151
- latest_run = sorted(Path("runs/detect").glob("predict*"), key=os.path.getmtime)[-1]
152
- output_image_path = os.path.join(latest_run, Path(input_image_path).name)
153
- if not os.path.isfile(output_image_path):
154
- # Alternative method to get the output path
155
- output_image_path = results[0].save()[0]
156
-
157
- # Copy the output image to OUTPUT_DIR
158
- final_output_path = os.path.join(OUTPUT_DIR, f"{model_name}_output_image.jpg")
159
- shutil.copy(output_image_path, final_output_path)
160
-
161
- # Open the output image
162
- output_image = Image.open(final_output_path)
163
-
164
- return "βœ… Prediction completed successfully.", output_image, final_output_path
165
- except Exception as e:
166
- return f"❌ Error during prediction: {str(e)}", None, None
167
 
168
- def predict_video(model_name, video, confidence, models):
 
 
 
169
  """
170
- Perform prediction on an uploaded video using the selected YOLO model.
171
-
172
  Args:
173
  model_name (str): The name of the selected model.
174
- video (str): Path to the uploaded video file.
175
  confidence (float): The confidence threshold for detections.
176
  models (dict): The dictionary containing models and their info.
177
-
178
  Returns:
179
- tuple: A status message, the processed video, and the path to the output video.
180
  """
181
  model_entry = models.get(model_name, {})
182
  model = model_entry.get('model', None)
@@ -186,28 +164,44 @@ def predict_video(model_name, video, confidence, models):
186
  # Ensure temporary and output directories exist
187
  os.makedirs(TEMP_DIR, exist_ok=True)
188
  os.makedirs(OUTPUT_DIR, exist_ok=True)
189
-
190
- # Save the uploaded video to a temporary path
191
- input_video_path = os.path.join(TEMP_DIR, f"{model_name}_input_video.mp4")
192
- shutil.copy(video, input_video_path)
193
-
194
- # Perform prediction with user-specified confidence and specify output format
195
- # Here, we set save_format to 'avi' to ensure compatibility
196
- results = model(input_video_path, save=True, save_txt=False, conf=confidence, save_format='avi')
197
-
198
- # Determine the output path
199
- # Ultralytics YOLO saves the results in 'runs/detect/predict' by default
200
- latest_run = sorted(Path("runs/detect").glob("predict*"), key=os.path.getmtime)[-1]
201
- output_video_path = os.path.join(latest_run, f"{model_name}_input_video.avi")
202
- if not os.path.isfile(output_video_path):
203
- # Alternative method to get the output path
204
- output_video_path = results[0].save()[0]
205
-
206
- # Copy the output video to OUTPUT_DIR
207
- final_output_path = os.path.join(OUTPUT_DIR, f"{model_name}_output_video.avi")
208
- shutil.copy(output_video_path, final_output_path)
209
-
210
- return "βœ… Prediction completed successfully.", final_output_path, final_output_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  except Exception as e:
212
  return f"❌ Error during prediction: {str(e)}", None, None
213
 
@@ -217,16 +211,16 @@ def main():
217
  if not models:
218
  print("No models loaded. Please check your models_info.json and model URLs.")
219
  return
220
-
221
  # Initialize Gradio Blocks interface
222
  with gr.Blocks() as demo:
223
  gr.Markdown("# πŸ§ͺ YOLOv11 Model Tester")
224
  gr.Markdown(
225
  """
226
- Upload images or videos to test different YOLOv11 models. Select a model from the dropdown to see its details.
227
  """
228
  )
229
-
230
  # Model selection and info
231
  with gr.Row():
232
  model_dropdown = gr.Dropdown(
@@ -235,10 +229,10 @@ def main():
235
  value=None
236
  )
237
  model_info = gr.Markdown("**Model Information will appear here.**")
238
-
239
  # Mapping from display_name to model_name
240
  display_to_name = {models[m]['display_name']: m for m in models}
241
-
242
  # Update model_info when a model is selected
243
  def update_model_info(selected_display_name):
244
  if not selected_display_name:
@@ -248,13 +242,13 @@ def main():
248
  return "Model information not available."
249
  model_entry = models[model_name]['info']
250
  return get_model_info(model_entry)
251
-
252
  model_dropdown.change(
253
  fn=update_model_info,
254
  inputs=model_dropdown,
255
  outputs=model_info
256
  )
257
-
258
  # Confidence Threshold Slider
259
  with gr.Row():
260
  confidence_slider = gr.Slider(
@@ -265,68 +259,42 @@ def main():
265
  label="Confidence Threshold",
266
  info="Adjust the minimum confidence required for detections to be displayed."
267
  )
268
-
269
- # Tabs for different input types
270
- with gr.Tabs():
271
- # Image Prediction Tab
272
- with gr.Tab("πŸ–ΌοΈ Image"):
273
- with gr.Column():
274
- image_input = gr.Image(
275
- type='pil',
276
- label="Upload Image for Prediction"
277
- # Removed 'tool' parameter
278
- )
279
- image_predict_btn = gr.Button("πŸ” Predict on Image")
280
- image_status = gr.Markdown("**Status will appear here.**")
281
- image_output = gr.Image(label="Predicted Image")
282
- image_download_btn = gr.File(label="⬇️ Download Predicted Image")
283
-
284
- # Define the image prediction function
285
- def process_image(selected_display_name, image, confidence):
286
- if not selected_display_name:
287
- return "❌ Please select a model.", None, None
288
- model_name = display_to_name.get(selected_display_name)
289
- return predict_image(model_name, image, confidence, models)
290
-
291
- # Connect the predict button
292
- image_predict_btn.click(
293
- fn=process_image,
294
- inputs=[model_dropdown, image_input, confidence_slider],
295
- outputs=[image_status, image_output, image_download_btn]
296
- )
297
-
298
- # Video Prediction Tab
299
- with gr.Tab("πŸŽ₯ Video"):
300
- with gr.Column():
301
- video_input = gr.Video(
302
- label="Upload Video for Prediction"
303
- )
304
- video_predict_btn = gr.Button("πŸ” Predict on Video")
305
- video_status = gr.Markdown("**Status will appear here.**")
306
- video_output = gr.Video(label="Predicted Video")
307
- video_download_btn = gr.File(label="⬇️ Download Predicted Video")
308
-
309
- # Define the video prediction function
310
- def process_video(selected_display_name, video, confidence):
311
- if not selected_display_name:
312
- return "❌ Please select a model.", None, None
313
- model_name = display_to_name.get(selected_display_name)
314
- return predict_video(model_name, video, confidence, models)
315
-
316
- # Connect the predict button
317
- video_predict_btn.click(
318
- fn=process_video,
319
- inputs=[model_dropdown, video_input, confidence_slider],
320
- outputs=[video_status, video_output, video_download_btn]
321
  )
322
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
  gr.Markdown(
324
  """
325
  ---
326
  **Note:** Models are downloaded from GitHub upon first use. Ensure that you have a stable internet connection and sufficient storage space.
327
  """
328
  )
329
-
330
  # Launch the Gradio app
331
  demo.launch()
332
 
 
6
  import shutil
7
  from ultralytics import YOLO
8
  import requests
9
+ import zipfile
10
+ import uuid
11
 
12
  # Constants
13
  MODELS_DIR = "models"
14
  MODELS_INFO_FILE = "models_info.json"
15
  TEMP_DIR = "temp"
16
  OUTPUT_DIR = "outputs"
17
+ ZIP_DIR = "zips"
18
 
19
  def download_file(url, dest_path):
20
  """
21
  Download a file from a URL to the destination path.
22
+
23
  Args:
24
  url (str): The URL to download from.
25
  dest_path (str): The local path to save the file.
26
+
27
  Returns:
28
  bool: True if download succeeded, False otherwise.
29
  """
 
43
  """
44
  Load YOLO models and their information from the specified directory and JSON file.
45
  Downloads models if they are not already present.
46
+
47
  Args:
48
  models_dir (str): Path to the models directory.
49
  info_file (str): Path to the JSON file containing model info.
50
+
51
  Returns:
52
  dict: A dictionary of models and their associated information.
53
  """
54
  with open(info_file, 'r') as f:
55
  models_info = json.load(f)
56
+
57
  models = {}
58
  for model_info in models_info:
59
  model_name = model_info['model_name']
 
62
  os.makedirs(model_dir, exist_ok=True)
63
  model_path = os.path.join(model_dir, f"{model_name}.pt") # e.g., models/human/human.pt
64
  download_url = model_info['download_url']
65
+
66
  # Check if the model file exists
67
  if not os.path.isfile(model_path):
68
  print(f"Model '{display_name}' not found locally. Downloading from {download_url}...")
 
70
  if not success:
71
  print(f"Skipping model '{display_name}' due to download failure.")
72
  continue # Skip loading this model
73
+
74
  try:
75
  # Load the YOLO model
76
  model = YOLO(model_path)
 
82
  print(f"Loaded model '{display_name}' from '{model_path}'.")
83
  except Exception as e:
84
  print(f"Error loading model '{display_name}': {e}")
85
+
86
  return models
87
 
88
  def get_model_info(model_info):
89
  """
90
  Retrieve formatted model information for display.
91
+
92
  Args:
93
  model_info (dict): The model's information dictionary.
94
+
95
  Returns:
96
  str: A formatted string containing model details.
97
  """
 
99
  class_ids = info.get('class_ids', {})
100
  class_image_counts = info.get('class_image_counts', {})
101
  datasets_used = info.get('datasets_used', [])
102
+
103
  class_ids_formatted = "\n".join([f"{cid}: {cname}" for cid, cname in class_ids.items()])
104
  class_image_counts_formatted = "\n".join([f"{cname}: {count}" for cname, count in class_image_counts.items()])
105
  datasets_used_formatted = "\n".join([f"- {dataset}" for dataset in datasets_used])
106
+
107
  info_text = (
108
  f"**{info.get('display_name', 'Model Name')}**\n\n"
109
  f"**Architecture:** {info.get('architecture', 'N/A')}\n\n"
 
120
  )
121
  return info_text
122
 
123
+ def zip_processed_images(processed_image_paths, model_name):
124
  """
125
+ Create a ZIP file containing all processed images.
126
+
127
  Args:
128
+ processed_image_paths (list): List of file paths to processed images.
129
+ model_name (str): Name of the model used for processing.
130
+
 
 
131
  Returns:
132
+ str: Path to the created ZIP file.
133
  """
134
+ os.makedirs(ZIP_DIR, exist_ok=True)
135
+ zip_filename = f"{model_name}_processed_images_{uuid.uuid4().hex}.zip"
136
+ zip_path = os.path.join(ZIP_DIR, zip_filename)
137
+
138
+ with zipfile.ZipFile(zip_path, 'w') as zipf:
139
+ for img_path in processed_image_paths:
140
+ arcname = os.path.basename(img_path)
141
+ zipf.write(img_path, arcname)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
+ print(f"Created ZIP file at {zip_path}.")
144
+ return zip_path
145
+
146
+ def predict_image(model_name, images, confidence, models):
147
  """
148
+ Perform prediction on uploaded images using the selected YOLO model.
149
+
150
  Args:
151
  model_name (str): The name of the selected model.
152
+ images (list): List of uploaded PIL.Image.Image objects.
153
  confidence (float): The confidence threshold for detections.
154
  models (dict): The dictionary containing models and their info.
155
+
156
  Returns:
157
+ tuple: A status message, list of processed images, and a ZIP file for download.
158
  """
159
  model_entry = models.get(model_name, {})
160
  model = model_entry.get('model', None)
 
164
  # Ensure temporary and output directories exist
165
  os.makedirs(TEMP_DIR, exist_ok=True)
166
  os.makedirs(OUTPUT_DIR, exist_ok=True)
167
+
168
+ processed_image_paths = []
169
+ processed_images = []
170
+
171
+ for idx, image in enumerate(images):
172
+ # Generate unique filenames to avoid conflicts
173
+ unique_id = uuid.uuid4().hex
174
+ input_image_path = os.path.join(TEMP_DIR, f"{model_name}_input_image_{unique_id}.jpg")
175
+ output_image_path = os.path.join(OUTPUT_DIR, f"{model_name}_output_image_{unique_id}.jpg")
176
+
177
+ # Save the uploaded image to a temporary path
178
+ image.save(input_image_path)
179
+
180
+ # Perform prediction with user-specified confidence
181
+ results = model(input_image_path, save=True, save_txt=False, conf=confidence)
182
+
183
+ # Determine the output path
184
+ # Ultralytics YOLO saves the results in 'runs/detect/predict' by default
185
+ latest_run = sorted(Path("runs/detect").glob("predict*"), key=os.path.getmtime)[-1]
186
+ detected_image_path = os.path.join(latest_run, Path(input_image_path).name)
187
+
188
+ if not os.path.isfile(detected_image_path):
189
+ # Alternative method to get the output path
190
+ detected_image_path = results[0].save()[0]
191
+
192
+ # Copy the output image to OUTPUT_DIR with a unique name
193
+ shutil.copy(detected_image_path, output_image_path)
194
+ processed_image_paths.append(output_image_path)
195
+
196
+ # Open the processed image for display
197
+ processed_image = Image.open(output_image_path)
198
+ processed_images.append(processed_image)
199
+
200
+ # Create a ZIP file containing all processed images
201
+ zip_path = zip_processed_images(processed_image_paths, model_name)
202
+
203
+ return "βœ… Prediction completed successfully.", processed_images, zip_path
204
+
205
  except Exception as e:
206
  return f"❌ Error during prediction: {str(e)}", None, None
207
 
 
211
  if not models:
212
  print("No models loaded. Please check your models_info.json and model URLs.")
213
  return
214
+
215
  # Initialize Gradio Blocks interface
216
  with gr.Blocks() as demo:
217
  gr.Markdown("# πŸ§ͺ YOLOv11 Model Tester")
218
  gr.Markdown(
219
  """
220
+ Upload one or multiple images to test different YOLOv11 models. Select a model from the dropdown to see its details.
221
  """
222
  )
223
+
224
  # Model selection and info
225
  with gr.Row():
226
  model_dropdown = gr.Dropdown(
 
229
  value=None
230
  )
231
  model_info = gr.Markdown("**Model Information will appear here.**")
232
+
233
  # Mapping from display_name to model_name
234
  display_to_name = {models[m]['display_name']: m for m in models}
235
+
236
  # Update model_info when a model is selected
237
  def update_model_info(selected_display_name):
238
  if not selected_display_name:
 
242
  return "Model information not available."
243
  model_entry = models[model_name]['info']
244
  return get_model_info(model_entry)
245
+
246
  model_dropdown.change(
247
  fn=update_model_info,
248
  inputs=model_dropdown,
249
  outputs=model_info
250
  )
251
+
252
  # Confidence Threshold Slider
253
  with gr.Row():
254
  confidence_slider = gr.Slider(
 
259
  label="Confidence Threshold",
260
  info="Adjust the minimum confidence required for detections to be displayed."
261
  )
262
+
263
+ # Image Prediction Tab (now supporting multiple images)
264
+ with gr.Tab("πŸ–ΌοΈ Image"):
265
+ with gr.Column():
266
+ image_input = gr.Images(
267
+ label="Upload Images for Prediction",
268
+ type='pil'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
  )
270
+ image_predict_btn = gr.Button("πŸ” Predict on Images")
271
+ image_status = gr.Markdown("**Status will appear here.**")
272
+ image_gallery = gr.Gallery(label="Predicted Images").style(grid=[2], height="auto")
273
+ image_download_btn = gr.File(label="⬇️ Download All Processed Images (ZIP)")
274
+
275
+ # Define the image prediction function
276
+ def process_image(selected_display_name, images, confidence):
277
+ if not selected_display_name:
278
+ return "❌ Please select a model.", None, None
279
+ if not images:
280
+ return "❌ Please upload at least one image.", None, None
281
+ model_name = display_to_name.get(selected_display_name)
282
+ return predict_image(model_name, images, confidence, models)
283
+
284
+ # Connect the predict button
285
+ image_predict_btn.click(
286
+ fn=process_image,
287
+ inputs=[model_dropdown, image_input, confidence_slider],
288
+ outputs=[image_status, image_gallery, image_download_btn]
289
+ )
290
+
291
  gr.Markdown(
292
  """
293
  ---
294
  **Note:** Models are downloaded from GitHub upon first use. Ensure that you have a stable internet connection and sufficient storage space.
295
  """
296
  )
297
+
298
  # Launch the Gradio app
299
  demo.launch()
300