diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..019042addff9c4d6537762e20510c79125c03e8a 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +ultralytics/assets/bus.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index 367e0a0ac90363c031c110305db9bf1d72e7852b..cc63d741a68484f2e1efe8fcf122ad452f809b14 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,136 @@ ---- -title: Yolov12 -emoji: 🌖 -colorFrom: purple -colorTo: blue -sdk: gradio -sdk_version: 5.16.1 -app_file: app.py -pinned: false -license: apache-2.0 -short_description: 'YOLOv12: Attention-Centric Real-Time Object Detectors' ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference + + +
+# YOLOv12
+
+### YOLOv12: Attention-Centric Real-Time Object Detectors
+
+[Yunjie Tian](https://sunsmarterjie.github.io/)<sup>1</sup>, [Qixiang Ye](https://people.ucas.ac.cn/~qxye?language=en)<sup>2</sup>, [David Doermann](https://cse.buffalo.edu/~doermann/)<sup>1</sup>
+
+<sup>1</sup> University at Buffalo, SUNY, <sup>2</sup> University of Chinese Academy of Sciences.
+
+![Trade-off comparison](assets/tradeoff.svg)
+
+Comparison with popular methods in terms of latency-accuracy (left) and FLOPs-accuracy (right) trade-offs
+
+[![arXiv](https://img.shields.io/badge/arXiv-2502.12524-b31b1b.svg)](https://arxiv.org/abs/2502.12524)
+
+## Updates
+
+- 2025/02/19: [arXiv version](https://arxiv.org/abs/2502.12524) is public.
+
+## Abstract
+
+Enhancing the network architecture of the YOLO framework has long been crucial, but efforts have focused on CNN-based improvements despite the proven superiority of attention mechanisms in modeling capability, because attention-based models have not been able to match the speed of CNN-based ones. This paper proposes an attention-centric YOLO framework, namely YOLOv12, that matches the speed of previous CNN-based frameworks while harnessing the performance benefits of attention mechanisms.
+
+YOLOv12 surpasses all popular real-time object detectors in accuracy at competitive speed. For example, YOLOv12-N achieves 40.6% mAP with an inference latency of 1.64 ms on a T4 GPU, outperforming the advanced YOLOv10-N / YOLOv11-N by 2.1% / 1.2% mAP at comparable speed. This advantage extends to other model scales. YOLOv12 also surpasses end-to-end real-time detectors that improve upon DETR, such as RT-DETR / RT-DETRv2: YOLOv12-S beats RT-DETR-R18 / RT-DETRv2-R18 while running 42% faster, using only 36% of the computation and 45% of the parameters.
+
+## Main Results
+
+COCO
+
+| Model | size<br>(pixels) | mAP<sup>val</sup><br>50-95 | Speed<br>T4 TensorRT10<br>(ms) | params<br>(M) | FLOPs<br>(G) |
+| :---------------------------------------------------------------------------------------- | :-----: | :-----: | :-----: | :-----: | :-----: |
+| [YOLO12n](https://github.com/sunsmarterjie/yolov12/releases/download/v1.0/yolov12n.pt) | 640 | 40.6 | 1.64 | 2.6 | 6.5 |
+| [YOLO12s](https://github.com/sunsmarterjie/yolov12/releases/download/v1.0/yolov12s.pt) | 640 | 48.0 | 2.61 | 9.3 | 21.4 |
+| [YOLO12m](https://github.com/sunsmarterjie/yolov12/releases/download/v1.0/yolov12m.pt) | 640 | 52.5 | 4.86 | 20.2 | 67.5 |
+| [YOLO12l](https://github.com/sunsmarterjie/yolov12/releases/download/v1.0/yolov12l.pt) | 640 | 53.7 | 6.77 | 26.4 | 88.9 |
+| [YOLO12x](https://github.com/sunsmarterjie/yolov12/releases/download/v1.0/yolov12x.pt) | 640 | 55.2 | 11.79 | 59.1 | 199.0 |
+
+## Installation
+
+```
+wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu11torch2.2cxx11abiFALSE-cp311-cp311-linux_x86_64.whl
+conda create -n yolov12 python=3.11
+conda activate yolov12
+pip install -r requirements.txt
+pip install -e .
+```
+
+## Validation
+
+[`yolov12n`](https://github.com/sunsmarterjie/yolov12/releases/download/v1.0/yolov12n.pt)
+[`yolov12s`](https://github.com/sunsmarterjie/yolov12/releases/download/v1.0/yolov12s.pt)
+[`yolov12m`](https://github.com/sunsmarterjie/yolov12/releases/download/v1.0/yolov12m.pt)
+[`yolov12l`](https://github.com/sunsmarterjie/yolov12/releases/download/v1.0/yolov12l.pt)
+[`yolov12x`](https://github.com/sunsmarterjie/yolov12/releases/download/v1.0/yolov12x.pt)
+
+```python
+from ultralytics import YOLO
+
+model = YOLO('yolov12{n/s/m/l/x}.pt')
+model.val(data='coco.yaml', save_json=True)
+```
+
+## Training
+
+```python
+from ultralytics import YOLO
+
+model = YOLO('yolov12n.yaml')
+
+# Train the model
+results = model.train(
+    data='coco.yaml',
+    epochs=600,
+    batch=256,
+    imgsz=640,
+    scale=0.5,       # S:0.9; M:0.9; L:0.9; X:0.9
+    mosaic=1.0,
+    mixup=0.0,       # S:0.05; M:0.15; L:0.15; X:0.2
+    copy_paste=0.1,  # S:0.15; M:0.4; L:0.5; X:0.6
+    device="0,1,2,3",
+)
+
+# Evaluate model performance on the validation set
+metrics = model.val()
+
+# Perform object detection on an image
+results = model("path/to/image.jpg")
+results[0].show()
+```
+
+## Prediction
+
+```python
+from ultralytics import YOLO
+
+model = YOLO('yolov12{n/s/m/l/x}.pt')
+model.predict()
+```
+
+## Export
+
+```python
+from ultralytics import YOLO
+
+model = YOLO('yolov12{n/s/m/l/x}.pt')
+model.export(format="engine", half=True)  # or format="onnx"
+```
+
+## Demo
+
+```
+python app.py
+# Please visit http://127.0.0.1:7860
+```
+
+## Acknowledgement
+
+The code is based on [ultralytics](https://github.com/ultralytics/ultralytics). Thanks for their excellent work!
+
+## Citation
+
+```BibTeX
+@article{tian2025yolov12,
+  title={YOLOv12: Attention-Centric Real-Time Object Detectors},
+  author={Tian, Yunjie and Ye, Qixiang and Doermann, David},
+  journal={arXiv preprint arXiv:2502.12524},
+  year={2025}
+}
+```
diff --git a/app.py b/app.py
index 5f92ff8ab351503e3ac4956597b6ff02666d22be..6d770057ec48df1a9e3abc13334a351ef3079b35 100644
--- a/app.py
+++ b/app.py
@@ -1,3 +1,7 @@
+# --------------------------------------------------------
+# Based on yolov10
+# https://github.com/THU-MIG/yolov10/app.py
+# --------------------------------------------------------'
 import gradio as gr
 import cv2
diff --git a/assets/tradeoff.svg b/assets/tradeoff.svg
new file mode 100644
index 0000000000000000000000000000000000000000..0525b58a059513148cc22e079711b5f4c9cc72ef
--- /dev/null
+++ b/assets/tradeoff.svg
@@ -0,0 +1,3328 @@
+[assets/tradeoff.svg: latency-accuracy (left) and FLOPs-accuracy (right) trade-off plots, Latency (ms) / FLOPs (G) vs. MS COCO mAP (%), comparing YOLOv12 (ours) with YOLOv6-3.0, YOLOv7, YOLOv8, YOLO-MS, Gold-YOLO, YOLOv10, YOLOv11, RT-DETR, and RT-DETRv2; 3,328 lines of SVG markup omitted]
diff --git a/docker/Dockerfile b/docker/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..2cfbfd352e9ce13bdcfc772f8a381af84a852b2a
--- /dev/null
+++ b/docker/Dockerfile
@@ -0,0 +1,92 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+# Builds ultralytics/ultralytics:latest image on DockerHub https://hub.docker.com/r/ultralytics/ultralytics
+# Image is CUDA-optimized for YOLO11 single/multi-GPU training and inference
+
+# Start FROM PyTorch image https://hub.docker.com/r/pytorch/pytorch or nvcr.io/nvidia/pytorch:23.03-py3
+FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-runtime
+
+# Set environment variables
+# Avoid DDP error "MKL_THREADING_LAYER=INTEL is incompatible with libgomp.so.1 library" https://github.com/pytorch/pytorch/issues/37377
+ENV PYTHONUNBUFFERED=1 \
+    PYTHONDONTWRITEBYTECODE=1 \
+    PIP_NO_CACHE_DIR=1 \
+    PIP_BREAK_SYSTEM_PACKAGES=1 \
+    MKL_THREADING_LAYER=GNU \
+    OMP_NUM_THREADS=1
+
+# Downloads to user config dir
+ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/Arial.ttf \
+    https://github.com/ultralytics/assets/releases/download/v0.0.0/Arial.Unicode.ttf \
+    /root/.config/Ultralytics/
+
+# Install linux packages
+# g++ required to build 'tflite_support' and 'lap' packages, libusb-1.0-0 required for 'tflite_support' package
+# libsm6 required by libqxcb to create QT-based windows for visualization; set 'QT_DEBUG_PLUGINS=1' to test in docker
+RUN apt-get update && \
+    apt-get install -y --no-install-recommends \
+    gcc git zip unzip wget curl htop libgl1 libglib2.0-0 libpython3-dev gnupg g++ libusb-1.0-0 libsm6 \
+    && rm -rf /var/lib/apt/lists/*
+
+# Security updates
+# https://security.snyk.io/vuln/SNYK-UBUNTU1804-OPENSSL-3314796
+RUN apt upgrade --no-install-recommends -y openssl tar
+
+# Create working directory
+WORKDIR /ultralytics
+
+# Copy contents and configure git
+COPY . .
+RUN sed -i '/^\[http "https:\/\/github\.com\/"\]/,+1d' .git/config
+ADD https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n.pt .
+
+# Install pip packages
+RUN pip install uv
+# Note -cu12 must be used with tensorrt
+RUN uv pip install --system -e ".[export]" tensorrt-cu12 "albumentations>=1.4.6" comet pycocotools
+
+# Run exports to AutoInstall packages
+# Edge TPU export fails the first time so is run twice here
+RUN yolo export model=tmp/yolo11n.pt format=edgetpu imgsz=32 || yolo export model=tmp/yolo11n.pt format=edgetpu imgsz=32
+RUN yolo export model=tmp/yolo11n.pt format=ncnn imgsz=32
+# Requires <= Python 3.10, bug with paddlepaddle==2.5.0 https://github.com/PaddlePaddle/X2Paddle/issues/991
+RUN uv pip install --system "paddlepaddle>=2.6.0" x2paddle
+# Fix error: `np.bool` was a deprecated alias for the builtin `bool` segmentation error in Tests
+RUN uv pip install --system numpy==1.23.5
+
+# Remove extra build files
+RUN rm -rf tmp /root/.config/Ultralytics/persistent_cache.json
+
+# Usage Examples -------------------------------------------------------------------------------------------------------
+
+# Build and Push
+# t=ultralytics/ultralytics:latest && sudo docker build -f docker/Dockerfile -t $t .
&& sudo docker push $t + +# Pull and Run with access to all GPUs +# t=ultralytics/ultralytics:latest && sudo docker pull $t && sudo docker run -it --ipc=host --gpus all $t + +# Pull and Run with access to GPUs 2 and 3 (inside container CUDA devices will appear as 0 and 1) +# t=ultralytics/ultralytics:latest && sudo docker pull $t && sudo docker run -it --ipc=host --gpus '"device=2,3"' $t + +# Pull and Run with local directory access +# t=ultralytics/ultralytics:latest && sudo docker pull $t && sudo docker run -it --ipc=host --gpus all -v "$(pwd)"/shared/datasets:/datasets $t + +# Kill all +# sudo docker kill $(sudo docker ps -q) + +# Kill all image-based +# sudo docker kill $(sudo docker ps -qa --filter ancestor=ultralytics/ultralytics:latest) + +# DockerHub tag update +# t=ultralytics/ultralytics:latest tnew=ultralytics/ultralytics:v6.2 && sudo docker pull $t && sudo docker tag $t $tnew && sudo docker push $tnew + +# Clean up +# sudo docker system prune -a --volumes + +# Update Ubuntu drivers +# https://www.maketecheasier.com/install-nvidia-drivers-ubuntu/ + +# DDP test +# python -m torch.distributed.run --nproc_per_node 2 --master_port 1 train.py --epochs 3 + +# GCP VM from Image +# docker.io/ultralytics/ultralytics:latest diff --git a/docker/Dockerfile-arm64 b/docker/Dockerfile-arm64 new file mode 100644 index 0000000000000000000000000000000000000000..dce273203810a5d36cd7d35c22f9c3c532e8b59d --- /dev/null +++ b/docker/Dockerfile-arm64 @@ -0,0 +1,58 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# Builds ultralytics/ultralytics:latest-arm64 image on DockerHub https://hub.docker.com/r/ultralytics/ultralytics +# Image is aarch64-compatible for Apple M1, M2, M3, Raspberry Pi and other ARM architectures + +# Start FROM Ubuntu image https://hub.docker.com/_/ubuntu with "FROM arm64v8/ubuntu:22.04" (deprecated) +# Start FROM Debian image for arm64v8 https://hub.docker.com/r/arm64v8/debian (new) +FROM arm64v8/debian:bookworm-slim + +# Set environment variables +ENV PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 \ + PIP_NO_CACHE_DIR=1 \ + PIP_BREAK_SYSTEM_PACKAGES=1 + +# Downloads to user config dir +ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/Arial.ttf \ + https://github.com/ultralytics/assets/releases/download/v0.0.0/Arial.Unicode.ttf \ + /root/.config/Ultralytics/ + +# Install linux packages +# g++ required to build 'tflite_support' and 'lap' packages, libusb-1.0-0 required for 'tflite_support' package +# pkg-config and libhdf5-dev (not included) are needed to build 'h5py==3.11.0' aarch64 wheel required by 'tensorflow' +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + python3-pip git zip unzip wget curl htop gcc libgl1 libglib2.0-0 libpython3-dev gnupg g++ libusb-1.0-0 \ + && rm -rf /var/lib/apt/lists/* + +# Create working directory +WORKDIR /ultralytics + +# Copy contents and configure git +COPY . . +RUN sed -i '/^\[http "https:\/\/github\.com\/"\]/,+1d' .git/config +ADD https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n.pt . 
+ +# Install pip packages +RUN pip install uv +RUN uv pip install --system -e ".[export]" --break-system-packages + +# Creates a symbolic link to make 'python' point to 'python3' +RUN ln -sf /usr/bin/python3 /usr/bin/python + +# Remove extra build files +RUN rm -rf /root/.config/Ultralytics/persistent_cache.json + +# Usage Examples ------------------------------------------------------------------------------------------------------- + +# Build and Push +# t=ultralytics/ultralytics:latest-arm64 && sudo docker build --platform linux/arm64 -f docker/Dockerfile-arm64 -t $t . && sudo docker push $t + +# Run +# t=ultralytics/ultralytics:latest-arm64 && sudo docker run -it --ipc=host $t + +# Pull and Run +# t=ultralytics/ultralytics:latest-arm64 && sudo docker pull $t && sudo docker run -it --ipc=host $t + +# Pull and Run with local volume mounted +# t=ultralytics/ultralytics:latest-arm64 && sudo docker pull $t && sudo docker run -it --ipc=host -v "$(pwd)"/shared/datasets:/datasets $t diff --git a/docker/Dockerfile-conda b/docker/Dockerfile-conda new file mode 100644 index 0000000000000000000000000000000000000000..aa1dff53bf08aba2bf04f2385bcdd770308446ad --- /dev/null +++ b/docker/Dockerfile-conda @@ -0,0 +1,50 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# Builds ultralytics/ultralytics:latest-conda image on DockerHub https://hub.docker.com/r/ultralytics/ultralytics +# Image is optimized for Ultralytics Anaconda (https://anaconda.org/conda-forge/ultralytics) installation and usage + +# Start FROM miniconda3 image https://hub.docker.com/r/continuumio/miniconda3 +FROM continuumio/miniconda3:latest + +# Set environment variables +ENV PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 \ + PIP_NO_CACHE_DIR=1 \ + PIP_BREAK_SYSTEM_PACKAGES=1 + +# Downloads to user config dir +ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/Arial.ttf \ + https://github.com/ultralytics/assets/releases/download/v0.0.0/Arial.Unicode.ttf \ + /root/.config/Ultralytics/ + +# Install linux packages +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + libgl1 \ + && rm -rf /var/lib/apt/lists/* + +# Copy contents +ADD https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n.pt . + +# Install conda packages +# mkl required to fix 'OSError: libmkl_intel_lp64.so.2: cannot open shared object file: No such file or directory' +RUN conda config --set solver libmamba && \ + conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia && \ + conda install -c conda-forge ultralytics mkl + # conda install -c pytorch -c nvidia -c conda-forge pytorch torchvision pytorch-cuda=12.1 ultralytics mkl + +# Remove extra build files +RUN rm -rf /root/.config/Ultralytics/persistent_cache.json + +# Usage Examples ------------------------------------------------------------------------------------------------------- + +# Build and Push +# t=ultralytics/ultralytics:latest-conda && sudo docker build -f docker/Dockerfile-cpu -t $t . 
&& sudo docker push $t + +# Run +# t=ultralytics/ultralytics:latest-conda && sudo docker run -it --ipc=host $t + +# Pull and Run +# t=ultralytics/ultralytics:latest-conda && sudo docker pull $t && sudo docker run -it --ipc=host $t + +# Pull and Run with local volume mounted +# t=ultralytics/ultralytics:latest-conda && sudo docker pull $t && sudo docker run -it --ipc=host -v "$(pwd)"/shared/datasets:/datasets $t diff --git a/docker/Dockerfile-cpu b/docker/Dockerfile-cpu new file mode 100644 index 0000000000000000000000000000000000000000..79d5d50b707f98d845204a91bbffb3e296c84bf0 --- /dev/null +++ b/docker/Dockerfile-cpu @@ -0,0 +1,62 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# Builds ultralytics/ultralytics:latest-cpu image on DockerHub https://hub.docker.com/r/ultralytics/ultralytics +# Image is CPU-optimized for ONNX, OpenVINO and PyTorch YOLO11 deployments + +# Use official Python base image for reproducibility (3.11.10 for export and 3.12.6 for inference) +FROM python:3.11.10-slim-bookworm + +# Set environment variables +ENV PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 \ + PIP_NO_CACHE_DIR=1 \ + PIP_BREAK_SYSTEM_PACKAGES=1 + +# Downloads to user config dir +ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/Arial.ttf \ + https://github.com/ultralytics/assets/releases/download/v0.0.0/Arial.Unicode.ttf \ + /root/.config/Ultralytics/ + +# Install linux packages +# g++ required to build 'tflite_support' and 'lap' packages, libusb-1.0-0 required for 'tflite_support' package +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + python3-pip git zip unzip wget curl htop libgl1 libglib2.0-0 libpython3-dev gnupg g++ libusb-1.0-0 \ + && rm -rf /var/lib/apt/lists/* + +# Create working directory +WORKDIR /ultralytics + +# Copy contents and configure git +COPY . . +RUN sed -i '/^\[http "https:\/\/github\.com\/"\]/,+1d' .git/config +ADD https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n.pt . + +# Install pip packages +RUN pip install uv +RUN uv pip install --system -e ".[export]" --extra-index-url https://download.pytorch.org/whl/cpu --index-strategy unsafe-first-match + +# Run exports to AutoInstall packages +RUN yolo export model=tmp/yolo11n.pt format=edgetpu imgsz=32 +RUN yolo export model=tmp/yolo11n.pt format=ncnn imgsz=32 +# Requires Python<=3.10, bug with paddlepaddle==2.5.0 https://github.com/PaddlePaddle/X2Paddle/issues/991 +RUN uv pip install --system "paddlepaddle>=2.6.0" x2paddle + +# Remove extra build files +RUN rm -rf tmp /root/.config/Ultralytics/persistent_cache.json + +# Set default command to bash +CMD ["/bin/bash"] + +# Usage Examples ------------------------------------------------------------------------------------------------------- + +# Build and Push +# t=ultralytics/ultralytics:latest-cpu && sudo docker build -f docker/Dockerfile-cpu -t $t . 
&& sudo docker push $t + +# Run +# t=ultralytics/ultralytics:latest-cpu && sudo docker run -it --ipc=host --name NAME $t + +# Pull and Run +# t=ultralytics/ultralytics:latest-cpu && sudo docker pull $t && sudo docker run -it --ipc=host --name NAME $t + +# Pull and Run with local volume mounted +# t=ultralytics/ultralytics:latest-cpu && sudo docker pull $t && sudo docker run -it --ipc=host -v "$(pwd)"/shared/datasets:/datasets $t diff --git a/docker/Dockerfile-jetson-jetpack4 b/docker/Dockerfile-jetson-jetpack4 new file mode 100644 index 0000000000000000000000000000000000000000..e11279dad90804528e09029e18bbad971ced2652 --- /dev/null +++ b/docker/Dockerfile-jetson-jetpack4 @@ -0,0 +1,70 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# Builds ultralytics/ultralytics:jetson-jetpack4 image on DockerHub https://hub.docker.com/r/ultralytics/ultralytics +# Supports JetPack4.x for YOLO11 on Jetson Nano, TX2, Xavier NX, AGX Xavier + +# Start FROM https://catalog.ngc.nvidia.com/orgs/nvidia/containers/l4t-cuda +FROM nvcr.io/nvidia/l4t-cuda:10.2.460-runtime + +# Set environment variables +ENV PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 + +# Downloads to user config dir +ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/Arial.ttf \ + https://github.com/ultralytics/assets/releases/download/v0.0.0/Arial.Unicode.ttf \ + /root/.config/Ultralytics/ + +# Add NVIDIA repositories for TensorRT dependencies +RUN wget -q -O - https://repo.download.nvidia.com/jetson/jetson-ota-public.asc | apt-key add - && \ + echo "deb https://repo.download.nvidia.com/jetson/common r32.7 main" > /etc/apt/sources.list.d/nvidia-l4t-apt-source.list && \ + echo "deb https://repo.download.nvidia.com/jetson/t194 r32.7 main" >> /etc/apt/sources.list.d/nvidia-l4t-apt-source.list + +# Install dependencies +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + git python3.8 python3.8-dev python3-pip python3-libnvinfer libopenmpi-dev libopenblas-base libomp-dev gcc \ + && rm -rf /var/lib/apt/lists/* + +# Create symbolic links for python3.8 and pip3 +RUN ln -sf /usr/bin/python3.8 /usr/bin/python3 +RUN ln -s /usr/bin/pip3 /usr/bin/pip + +# Create working directory +WORKDIR /ultralytics + +# Copy contents and configure git +COPY . . +RUN sed -i '/^\[http "https:\/\/github\.com\/"\]/,+1d' .git/config +ADD https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n.pt . 
+ +# Download onnxruntime-gpu 1.8.0 and tensorrt 8.2.0.6 +# Other versions can be seen in https://elinux.org/Jetson_Zoo and https://forums.developer.nvidia.com/t/pytorch-for-jetson/72048 +ADD https://nvidia.box.com/shared/static/gjqofg7rkg97z3gc8jeyup6t8n9j8xjw.whl onnxruntime_gpu-1.8.0-cp38-cp38-linux_aarch64.whl +ADD https://forums.developer.nvidia.com/uploads/short-url/hASzFOm9YsJx6VVFrDW1g44CMmv.whl tensorrt-8.2.0.6-cp38-none-linux_aarch64.whl + +# Install pip packages +RUN python3 -m pip install --upgrade pip +RUN python3 -m pip install uv +RUN uv pip install --system \ + onnxruntime_gpu-1.8.0-cp38-cp38-linux_aarch64.whl \ + tensorrt-8.2.0.6-cp38-none-linux_aarch64.whl \ + https://github.com/ultralytics/assets/releases/download/v0.0.0/torch-1.11.0a0+gitbc2c6ed-cp38-cp38-linux_aarch64.whl \ + https://github.com/ultralytics/assets/releases/download/v0.0.0/torchvision-0.12.0a0+9b5a3fe-cp38-cp38-linux_aarch64.whl +RUN uv pip install --system -e ".[export]" + +# Remove extra build files +RUN rm -rf *.whl /root/.config/Ultralytics/persistent_cache.json + +# Usage Examples ------------------------------------------------------------------------------------------------------- + +# Build and Push +# t=ultralytics/ultralytics:latest-jetson-jetpack4 && sudo docker build --platform linux/arm64 -f docker/Dockerfile-jetson-jetpack4 -t $t . && sudo docker push $t + +# Run +# t=ultralytics/ultralytics:latest-jetson-jetpack4 && sudo docker run -it --ipc=host $t + +# Pull and Run +# t=ultralytics/ultralytics:latest-jetson-jetpack4 && sudo docker pull $t && sudo docker run -it --ipc=host $t + +# Pull and Run with NVIDIA runtime +# t=ultralytics/ultralytics:latest-jetson-jetpack4 && sudo docker pull $t && sudo docker run -it --ipc=host --runtime=nvidia $t diff --git a/docker/Dockerfile-jetson-jetpack5 b/docker/Dockerfile-jetson-jetpack5 new file mode 100644 index 0000000000000000000000000000000000000000..bfedb6e0cf2bfa44b7d315f8b0814819e6892245 --- /dev/null +++ b/docker/Dockerfile-jetson-jetpack5 @@ -0,0 +1,57 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# Builds ultralytics/ultralytics:jetson-jetson-jetpack5 image on DockerHub https://hub.docker.com/r/ultralytics/ultralytics +# Supports JetPack5.1.2 for YOLO11 on Jetson Xavier NX, AGX Xavier, AGX Orin, Orin Nano and Orin NX + +# Start FROM https://catalog.ngc.nvidia.com/orgs/nvidia/containers/l4t-jetpack +FROM nvcr.io/nvidia/l4t-jetpack:r35.4.1 + +# Set environment variables +ENV PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 \ + PIP_NO_CACHE_DIR=1 \ + PIP_BREAK_SYSTEM_PACKAGES=1 + +# Downloads to user config dir +ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/Arial.ttf \ + https://github.com/ultralytics/assets/releases/download/v0.0.0/Arial.Unicode.ttf \ + /root/.config/Ultralytics/ + +# Install dependencies +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + git python3-pip libopenmpi-dev libopenblas-base libomp-dev \ + && rm -rf /var/lib/apt/lists/* + +# Create working directory +WORKDIR /ultralytics + +# Copy contents and configure git +COPY . . +RUN sed -i '/^\[http "https:\/\/github\.com\/"\]/,+1d' .git/config +ADD https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n.pt . 
+ +# Pip install onnxruntime-gpu, torch, torchvision and ultralytics +RUN python3 -m pip install --upgrade pip uv +RUN uv pip install --system \ + https://github.com/ultralytics/assets/releases/download/v0.0.0/onnxruntime_gpu-1.18.0-cp38-cp38-linux_aarch64.whl \ + https://github.com/ultralytics/assets/releases/download/v0.0.0/torch-2.1.0a0+41361538.nv23.06-cp38-cp38-linux_aarch64.whl \ + https://github.com/ultralytics/assets/releases/download/v0.0.0/torchvision-0.16.2+c6f3977-cp38-cp38-linux_aarch64.whl + +RUN uv pip install --system -e ".[export]" + +# Remove extra build files +RUN rm -rf *.whl /root/.config/Ultralytics/persistent_cache.json + +# Usage Examples ------------------------------------------------------------------------------------------------------- + +# Build and Push +# t=ultralytics/ultralytics:latest-jetson-jetpack5 && sudo docker build --platform linux/arm64 -f docker/Dockerfile-jetson-jetpack5 -t $t . && sudo docker push $t + +# Run +# t=ultralytics/ultralytics:latest-jetson-jetpack5 && sudo docker run -it --ipc=host $t + +# Pull and Run +# t=ultralytics/ultralytics:latest-jetson-jetpack5 && sudo docker pull $t && sudo docker run -it --ipc=host $t + +# Pull and Run with NVIDIA runtime +# t=ultralytics/ultralytics:latest-jetson-jetpack5 && sudo docker pull $t && sudo docker run -it --ipc=host --runtime=nvidia $t diff --git a/docker/Dockerfile-jetson-jetpack6 b/docker/Dockerfile-jetson-jetpack6 new file mode 100644 index 0000000000000000000000000000000000000000..fa6ec651b0a64ed23d88feade5b9df018dd0f8ab --- /dev/null +++ b/docker/Dockerfile-jetson-jetpack6 @@ -0,0 +1,58 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# Builds ultralytics/ultralytics:jetson-jetpack6 image on DockerHub https://hub.docker.com/r/ultralytics/ultralytics +# Supports JetPack6.1 for YOLO11 on Jetson AGX Orin, Orin NX and Orin Nano Series + +# Start FROM https://catalog.ngc.nvidia.com/orgs/nvidia/containers/l4t-jetpack +FROM nvcr.io/nvidia/l4t-jetpack:r36.4.0 + +# Set environment variables +ENV PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 \ + PIP_NO_CACHE_DIR=1 \ + PIP_BREAK_SYSTEM_PACKAGES=1 + +# Downloads to user config dir +ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/Arial.ttf \ + https://github.com/ultralytics/assets/releases/download/v0.0.0/Arial.Unicode.ttf \ + /root/.config/Ultralytics/ + +# Install dependencies +ADD https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/arm64/cuda-keyring_1.1-1_all.deb . +RUN dpkg -i cuda-keyring_1.1-1_all.deb && \ + apt-get update && \ + apt-get install -y --no-install-recommends \ + git python3-pip libopenmpi-dev libopenblas-base libomp-dev libcusparselt0 libcusparselt-dev \ + && rm -rf /var/lib/apt/lists/* + +# Create working directory +WORKDIR /ultralytics + +# Copy contents and configure git +COPY . . +RUN sed -i '/^\[http "https:\/\/github\.com\/"\]/,+1d' .git/config +ADD https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n.pt . 
+ +# Pip install onnxruntime-gpu, torch, torchvision and ultralytics +RUN python3 -m pip install --upgrade pip uv +RUN uv pip install --system \ + https://github.com/ultralytics/assets/releases/download/v0.0.0/onnxruntime_gpu-1.20.0-cp310-cp310-linux_aarch64.whl \ + https://github.com/ultralytics/assets/releases/download/v0.0.0/torch-2.5.0a0+872d972e41.nv24.08-cp310-cp310-linux_aarch64.whl \ + https://github.com/ultralytics/assets/releases/download/v0.0.0/torchvision-0.20.0a0+afc54f7-cp310-cp310-linux_aarch64.whl +RUN uv pip install --system -e ".[export]" + +# Remove extra build files +RUN rm -rf *.whl /root/.config/Ultralytics/persistent_cache.json + +# Usage Examples ------------------------------------------------------------------------------------------------------- + +# Build and Push +# t=ultralytics/ultralytics:latest-jetson-jetpack6 && sudo docker build --platform linux/arm64 -f docker/Dockerfile-jetson-jetpack6 -t $t . && sudo docker push $t + +# Run +# t=ultralytics/ultralytics:latest-jetson-jetpack6 && sudo docker run -it --ipc=host $t + +# Pull and Run +# t=ultralytics/ultralytics:latest-jetson-jetpack6 && sudo docker pull $t && sudo docker run -it --ipc=host $t + +# Pull and Run with NVIDIA runtime +# t=ultralytics/ultralytics:latest-jetson-jetpack6 && sudo docker pull $t && sudo docker run -it --ipc=host --runtime=nvidia $t diff --git a/docker/Dockerfile-jupyter b/docker/Dockerfile-jupyter new file mode 100644 index 0000000000000000000000000000000000000000..c458ff8848056d44566c67bbdb3ebb9108bcb508 --- /dev/null +++ b/docker/Dockerfile-jupyter @@ -0,0 +1,33 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# Builds ultralytics/ultralytics:latest-jupyter image on DockerHub https://hub.docker.com/r/ultralytics/ultralytics +# Image provides JupyterLab interface for interactive YOLO development and includes tutorial notebooks + +# Start from Python-based Ultralytics image for full Python environment +FROM ultralytics/ultralytics:latest-python + +# Install JupyterLab for interactive development +RUN uv pip install --system jupyterlab + +# Create persistent data directory structure +RUN mkdir /data + +# Configure YOLO directories +RUN mkdir /data/{datasets,weights,runs} && \ + yolo settings datasets_dir="/data/datasets" weights_dir="/data/weights" runs_dir="/data/runs" + +# Start JupyterLab with tutorial notebook +ENTRYPOINT ["/usr/local/bin/jupyter", "lab", "--allow-root", "--ip=*", "/ultralytics/examples/tutorial.ipynb"] + +# Usage Examples ------------------------------------------------------------------------------------------------------- + +# Build and Push +# t=ultralytics/ultralytics:latest-jupyter && sudo docker build -f docker/Dockerfile-jupyter -t $t . 
&& sudo docker push $t + +# Run +# t=ultralytics/ultralytics:latest-jupyter && sudo docker run -it --ipc=host -p 8888:8888 $t + +# Pull and Run +# t=ultralytics/ultralytics:latest-jupyter && sudo docker pull $t && sudo docker run -it --ipc=host -p 8888:8888 $t + +# Pull and Run with local volume mounted +# t=ultralytics/ultralytics:latest-jupyter && sudo docker pull $t && sudo docker run -it --ipc=host -p 8888:8888 -v "$(pwd)"/datasets:/data/datasets $t diff --git a/docker/Dockerfile-python b/docker/Dockerfile-python new file mode 100644 index 0000000000000000000000000000000000000000..796d18879243c1313850245fd97acc6f4b32e37b --- /dev/null +++ b/docker/Dockerfile-python @@ -0,0 +1,59 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# Builds ultralytics/ultralytics:latest-cpu image on DockerHub https://hub.docker.com/r/ultralytics/ultralytics +# Image is CPU-optimized for ONNX, OpenVINO and PyTorch YOLO11 deployments + +# Use official Python base image for reproducibility (3.11.10 for export and 3.12.6 for inference) +FROM python:3.11.10-slim-bookworm + +# Set environment variables +ENV PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 \ + PIP_NO_CACHE_DIR=1 \ + PIP_BREAK_SYSTEM_PACKAGES=1 + +# Downloads to user config dir +ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/Arial.ttf \ + https://github.com/ultralytics/assets/releases/download/v0.0.0/Arial.Unicode.ttf \ + /root/.config/Ultralytics/ + +# Install linux packages +# g++ required to build 'tflite_support' and 'lap' packages, libusb-1.0-0 required for 'tflite_support' package +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + python3-pip git zip unzip wget curl htop libgl1 libglib2.0-0 libpython3-dev gnupg g++ libusb-1.0-0 \ + && rm -rf /var/lib/apt/lists/* + +# Create working directory +WORKDIR /ultralytics + +# Copy contents and configure git +COPY . . +RUN sed -i '/^\[http "https:\/\/github\.com\/"\]/,+1d' .git/config +ADD https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n.pt . + +# Install pip packages +RUN pip install uv +RUN uv pip install --system -e ".[export]" --extra-index-url https://download.pytorch.org/whl/cpu --index-strategy unsafe-first-match + +# Run exports to AutoInstall packages +RUN yolo export model=tmp/yolo11n.pt format=edgetpu imgsz=32 +RUN yolo export model=tmp/yolo11n.pt format=ncnn imgsz=32 +# Requires Python<=3.10, bug with paddlepaddle==2.5.0 https://github.com/PaddlePaddle/X2Paddle/issues/991 +RUN uv pip install --system "paddlepaddle>=2.6.0" x2paddle + +# Remove extra build files +RUN rm -rf tmp /root/.config/Ultralytics/persistent_cache.json + +# Usage Examples ------------------------------------------------------------------------------------------------------- + +# Build and Push +# t=ultralytics/ultralytics:latest-python && sudo docker build -f docker/Dockerfile-python -t $t . 
&& sudo docker push $t + +# Run +# t=ultralytics/ultralytics:latest-python && sudo docker run -it --ipc=host $t + +# Pull and Run +# t=ultralytics/ultralytics:latest-python && sudo docker pull $t && sudo docker run -it --ipc=host $t + +# Pull and Run with local volume mounted +# t=ultralytics/ultralytics:latest-python && sudo docker pull $t && sudo docker run -it --ipc=host -v "$(pwd)"/shared/datasets:/datasets $t diff --git a/docker/Dockerfile-runner b/docker/Dockerfile-runner new file mode 100644 index 0000000000000000000000000000000000000000..5de5ee06507bc249872acc612ade577a3afcf000 --- /dev/null +++ b/docker/Dockerfile-runner @@ -0,0 +1,44 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license +# Builds GitHub actions CI runner image for deployment to DockerHub https://hub.docker.com/r/ultralytics/ultralytics +# Image is CUDA-optimized for YOLO11 single/multi-GPU training and inference tests + +# Start FROM Ultralytics GPU image +FROM ultralytics/ultralytics:latest + +# Set environment variables +ENV PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 \ + PIP_NO_CACHE_DIR=1 \ + PIP_BREAK_SYSTEM_PACKAGES=1 \ + RUNNER_ALLOW_RUNASROOT=1 \ + DEBIAN_FRONTEND=noninteractive + +# Set the working directory +WORKDIR /actions-runner + +# Download and unpack the latest runner from https://github.com/actions/runner +RUN FILENAME=actions-runner-linux-x64-2.320.0.tar.gz && \ + curl -o $FILENAME -L https://github.com/actions/runner/releases/download/v2.320.0/$FILENAME && \ + tar xzf $FILENAME && \ + rm $FILENAME + +# Install runner dependencies +RUN uv pip install --system pytest-cov +RUN ./bin/installdependencies.sh && \ + apt-get -y install libicu-dev + +# Inline ENTRYPOINT command to configure and start runner with default TOKEN and NAME +ENTRYPOINT sh -c './config.sh --url https://github.com/ultralytics/ultralytics \ + --token ${GITHUB_RUNNER_TOKEN:-TOKEN} \ + --name ${GITHUB_RUNNER_NAME:-NAME} \ + --labels gpu-latest \ + --replace && \ + ./run.sh' + +# Usage Examples ------------------------------------------------------------------------------------------------------- + +# Build and Push +# t=ultralytics/ultralytics:latest-runner && sudo docker build -f docker/Dockerfile-runner -t $t . && sudo docker push $t + +# Pull and Run in detached mode with access to GPUs 0 and 1 +# t=ultralytics/ultralytics:latest-runner && sudo docker run -d -e GITHUB_RUNNER_TOKEN=TOKEN -e GITHUB_RUNNER_NAME=NAME --ipc=host --gpus '"device=0,1"' $t diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000000000000000000000000000000000000..ee06d337b62c6432e8944b838c0fc05f74e62562 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,40 @@ +## Ultralytics Examples + +This directory features a collection of real-world applications and walkthroughs, provided as either Python files or notebooks. Explore the examples below to see how YOLO can be integrated into various applications. 
+ +### Ultralytics YOLO Example Applications + +| Title | Format | Contributor | +| ----------------------------------------------------------------------------------------------------------------------------------------- | ------------------ | ----------------------------------------------------------------------------------------- | +| [YOLO ONNX Detection Inference with C++](./YOLOv8-CPP-Inference) | C++/ONNX | [Justas Bartnykas](https://github.com/JustasBart) | +| [YOLO OpenCV ONNX Detection Python](./YOLOv8-OpenCV-ONNX-Python) | OpenCV/Python/ONNX | [Farid Inawan](https://github.com/frdteknikelektro) | +| [YOLO C# ONNX-Runtime](https://github.com/dme-compunet/YoloSharp) | .NET/ONNX-Runtime | [Compunet](https://github.com/dme-compunet) | +| [YOLO .Net ONNX Detection C#](https://www.nuget.org/packages/Yolov8.Net) | C# .Net | [Samuel Stainback](https://github.com/sstainba) | +| [YOLOv8 on NVIDIA Jetson(TensorRT and DeepStream)](https://wiki.seeedstudio.com/YOLOv8-DeepStream-TRT-Jetson/) | Python | [Lakshantha](https://github.com/lakshanthad) | +| [YOLOv8 ONNXRuntime Python](./YOLOv8-ONNXRuntime) | Python/ONNXRuntime | [Semih Demirel](https://github.com/semihhdemirel) | +| [RTDETR ONNXRuntime Python](./RTDETR-ONNXRuntime-Python) | Python/ONNXRuntime | [Semih Demirel](https://github.com/semihhdemirel) | +| [YOLOv8 ONNXRuntime CPP](./YOLOv8-ONNXRuntime-CPP) | C++/ONNXRuntime | [DennisJcy](https://github.com/DennisJcy), [Onuralp Sezer](https://github.com/onuralpszr) | +| [RTDETR ONNXRuntime C#](https://github.com/Kayzwer/yolo-cs/blob/master/RTDETR.cs) | C#/ONNX | [Kayzwer](https://github.com/Kayzwer) | +| [YOLOv8 SAHI Video Inference](https://github.com/RizwanMunawar/ultralytics/blob/main/examples/YOLOv8-SAHI-Inference-Video/yolov8_sahi.py) | Python | [Muhammad Rizwan Munawar](https://github.com/RizwanMunawar) | +| [YOLOv8 Region Counter](https://github.com/RizwanMunawar/ultralytics/blob/main/examples/YOLOv8-Region-Counter/yolov8_region_counter.py) | Python | [Muhammad Rizwan Munawar](https://github.com/RizwanMunawar) | +| [YOLOv8 Segmentation ONNXRuntime Python](./YOLOv8-Segmentation-ONNXRuntime-Python) | Python/ONNXRuntime | [jamjamjon](https://github.com/jamjamjon) | +| [YOLOv8 LibTorch CPP](./YOLOv8-LibTorch-CPP-Inference) | C++/LibTorch | [Myyura](https://github.com/Myyura) | +| [YOLOv8 OpenCV INT8 TFLite Python](./YOLOv8-TFLite-Python) | Python | [Wamiq Raza](https://github.com/wamiqraza) | +| [YOLOv8 All Tasks ONNXRuntime Rust](./YOLOv8-ONNXRuntime-Rust) | Rust/ONNXRuntime | [jamjamjon](https://github.com/jamjamjon) | +| [YOLOv8 OpenVINO CPP](./YOLOv8-OpenVINO-CPP-Inference) | C++/OpenVINO | [Erlangga Yudi Pradana](https://github.com/rlggyp) | +| [YOLOv5-YOLO11 ONNXRuntime Rust](./YOLO-Series-ONNXRuntime-Rust) | Rust/ONNXRuntime | [jamjamjon](https://github.com/jamjamjon) | + +### How to Contribute + +We greatly appreciate contributions from the community, including examples, applications, and guides. If you'd like to contribute, please follow these guidelines: + +1. **Create a pull request (PR)** with the title prefix `[Example]`, adding your new example folder to the `examples/` directory within the repository. +2. **Ensure your project adheres to the following standards:** + - Makes use of the `ultralytics` package. + - Includes a `README.md` with clear instructions for setting up and running the example. + - Avoids adding large files or dependencies unless they are absolutely necessary for the example. 
+ - Contributors should be willing to provide support for their examples and address related issues. + +For more detailed information and guidance on contributing, please visit our [contribution documentation](https://docs.ultralytics.com/help/contributing/). + +If you encounter any questions or concerns regarding these guidelines, feel free to open a PR or an issue in the repository, and we will assist you in the contribution process. diff --git a/examples/RTDETR-ONNXRuntime-Python/README.md b/examples/RTDETR-ONNXRuntime-Python/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1861da8295dcb7e1eb34392006839fbae693db42 --- /dev/null +++ b/examples/RTDETR-ONNXRuntime-Python/README.md @@ -0,0 +1,43 @@ +# RTDETR - ONNX Runtime + +This project implements RTDETR using ONNX Runtime. + +## Installation + +To run this project, you need to install the required dependencies. The following instructions will guide you through the installation process. + +### Installing Required Dependencies + +You can install the required dependencies by running the following command: + +```bash +pip install -r requirements.txt +``` + +### Installing `onnxruntime-gpu` + +If you have an NVIDIA GPU and want to leverage GPU acceleration, you can install the onnxruntime-gpu package using the following command: + +```bash +pip install onnxruntime-gpu +``` + +Note: Make sure you have the appropriate GPU drivers installed on your system. + +### Installing `onnxruntime` (CPU version) + +If you don't have an NVIDIA GPU or prefer to use the CPU version of onnxruntime, you can install the onnxruntime package using the following command: + +```bash +pip install onnxruntime +``` + +### Usage + +After successfully installing the required packages, you can run the RTDETR implementation using the following command: + +```bash +python main.py --model rtdetr-l.onnx --img image.jpg --conf-thres 0.5 --iou-thres 0.5 +``` + +Make sure to replace rtdetr-l.onnx with the path to your RTDETR ONNX model file, image.jpg with the path to your input image, and adjust the confidence threshold (conf-thres) and IoU threshold (iou-thres) values as needed. diff --git a/examples/RTDETR-ONNXRuntime-Python/main.py b/examples/RTDETR-ONNXRuntime-Python/main.py new file mode 100644 index 0000000000000000000000000000000000000000..d794a7d648b8ce107015142c11ca2a7b1935590b --- /dev/null +++ b/examples/RTDETR-ONNXRuntime-Python/main.py @@ -0,0 +1,222 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import argparse + +import cv2 +import numpy as np +import onnxruntime as ort +import torch + +from ultralytics.utils import ASSETS, yaml_load +from ultralytics.utils.checks import check_requirements, check_yaml + + +class RTDETR: + """RTDETR object detection model class for handling inference and visualization.""" + + def __init__(self, model_path, img_path, conf_thres=0.5, iou_thres=0.5): + """ + Initializes the RTDETR object with the specified parameters. + + Args: + model_path: Path to the ONNX model file. + img_path: Path to the input image. + conf_thres: Confidence threshold for object detection. 
+ iou_thres: IoU threshold for non-maximum suppression + """ + self.model_path = model_path + self.img_path = img_path + self.conf_thres = conf_thres + self.iou_thres = iou_thres + + # Set up the ONNX runtime session with CUDA and CPU execution providers + self.session = ort.InferenceSession(model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]) + self.model_input = self.session.get_inputs() + self.input_width = self.model_input[0].shape[2] + self.input_height = self.model_input[0].shape[3] + + # Load class names from the COCO dataset YAML file + self.classes = yaml_load(check_yaml("coco8.yaml"))["names"] + + # Generate a color palette for drawing bounding boxes + self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3)) + + def draw_detections(self, box, score, class_id): + """ + Draws bounding boxes and labels on the input image based on the detected objects. + + Args: + box: Detected bounding box. + score: Corresponding detection score. + class_id: Class ID for the detected object. + + Returns: + None + """ + # Extract the coordinates of the bounding box + x1, y1, x2, y2 = box + + # Retrieve the color for the class ID + color = self.color_palette[class_id] + + # Draw the bounding box on the image + cv2.rectangle(self.img, (int(x1), int(y1)), (int(x2), int(y2)), color, 2) + + # Create the label text with class name and score + label = f"{self.classes[class_id]}: {score:.2f}" + + # Calculate the dimensions of the label text + (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1) + + # Calculate the position of the label text + label_x = x1 + label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10 + + # Draw a filled rectangle as the background for the label text + cv2.rectangle( + self.img, + (int(label_x), int(label_y - label_height)), + (int(label_x + label_width), int(label_y + label_height)), + color, + cv2.FILLED, + ) + + # Draw the label text on the image + cv2.putText( + self.img, label, (int(label_x), int(label_y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA + ) + + def preprocess(self): + """ + Preprocesses the input image before performing inference. + + Returns: + image_data: Preprocessed image data ready for inference. + """ + # Read the input image using OpenCV + self.img = cv2.imread(self.img_path) + + # Get the height and width of the input image + self.img_height, self.img_width = self.img.shape[:2] + + # Convert the image color space from BGR to RGB + img = cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB) + + # Resize the image to match the input shape + img = cv2.resize(img, (self.input_width, self.input_height)) + + # Normalize the image data by dividing it by 255.0 + image_data = np.array(img) / 255.0 + + # Transpose the image to have the channel dimension as the first dimension + image_data = np.transpose(image_data, (2, 0, 1)) # Channel first + + # Expand the dimensions of the image data to match the expected input shape + image_data = np.expand_dims(image_data, axis=0).astype(np.float32) + + # Return the preprocessed image data + return image_data + + def bbox_cxcywh_to_xyxy(self, boxes): + """ + Converts bounding boxes from (center x, center y, width, height) format to (x_min, y_min, x_max, y_max) format. + + Args: + boxes (numpy.ndarray): An array of shape (N, 4) where each row represents + a bounding box in (cx, cy, w, h) format. + + Returns: + numpy.ndarray: An array of shape (N, 4) where each row represents + a bounding box in (x_min, y_min, x_max, y_max) format. 
+ """ + # Calculate half width and half height of the bounding boxes + half_width = boxes[:, 2] / 2 + half_height = boxes[:, 3] / 2 + + # Calculate the coordinates of the bounding boxes + x_min = boxes[:, 0] - half_width + y_min = boxes[:, 1] - half_height + x_max = boxes[:, 0] + half_width + y_max = boxes[:, 1] + half_height + + # Return the bounding boxes in (x_min, y_min, x_max, y_max) format + return np.column_stack((x_min, y_min, x_max, y_max)) + + def postprocess(self, model_output): + """ + Postprocesses the model output to extract detections and draw them on the input image. + + Args: + model_output: Output of the model inference. + + Returns: + np.array: Annotated image with detections. + """ + # Squeeze the model output to remove unnecessary dimensions + outputs = np.squeeze(model_output[0]) + + # Extract bounding boxes and scores from the model output + boxes = outputs[:, :4] + scores = outputs[:, 4:] + + # Get the class labels and scores for each detection + labels = np.argmax(scores, axis=1) + scores = np.max(scores, axis=1) + + # Apply confidence threshold to filter out low-confidence detections + mask = scores > self.conf_thres + boxes, scores, labels = boxes[mask], scores[mask], labels[mask] + + # Convert bounding boxes to (x_min, y_min, x_max, y_max) format + boxes = self.bbox_cxcywh_to_xyxy(boxes) + + # Scale bounding boxes to match the original image dimensions + boxes[:, 0::2] *= self.img_width + boxes[:, 1::2] *= self.img_height + + # Draw detections on the image + for box, score, label in zip(boxes, scores, labels): + self.draw_detections(box, score, label) + + # Return the annotated image + return self.img + + def main(self): + """ + Executes the detection on the input image using the ONNX model. + + Returns: + np.array: Output image with annotations. 
+ """ + # Preprocess the image for model input + image_data = self.preprocess() + + # Run the model inference + model_output = self.session.run(None, {self.model_input[0].name: image_data}) + + # Process and return the model output + return self.postprocess(model_output) + + +if __name__ == "__main__": + # Set up argument parser for command-line arguments + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, default="rtdetr-l.onnx", help="Path to the ONNX model file.") + parser.add_argument("--img", type=str, default=str(ASSETS / "bus.jpg"), help="Path to the input image.") + parser.add_argument("--conf-thres", type=float, default=0.5, help="Confidence threshold for object detection.") + parser.add_argument("--iou-thres", type=float, default=0.5, help="IoU threshold for non-maximum suppression.") + args = parser.parse_args() + + # Check for dependencies and set up ONNX runtime + check_requirements("onnxruntime-gpu" if torch.cuda.is_available() else "onnxruntime") + + # Create the detector instance with specified parameters + detection = RTDETR(args.model, args.img, args.conf_thres, args.iou_thres) + + # Perform detection and get the output image + output_image = detection.main() + + # Display the annotated output image + cv2.namedWindow("Output", cv2.WINDOW_NORMAL) + cv2.imshow("Output", output_image) + cv2.waitKey(0) diff --git a/examples/YOLO-Series-ONNXRuntime-Rust/Cargo.toml b/examples/YOLO-Series-ONNXRuntime-Rust/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..048ece887dfb31850fb0d5079e922e21486ba8c1 --- /dev/null +++ b/examples/YOLO-Series-ONNXRuntime-Rust/Cargo.toml @@ -0,0 +1,14 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +[package] +name = "YOLO-ONNXRuntime-Rust" +version = "0.1.0" +edition = "2021" +authors = ["Jamjamjon "] + +[dependencies] +anyhow = "1.0.92" +clap = "4.5.20" +tracing = "0.1.40" +tracing-subscriber = "0.3.18" +usls = { version = "0.0.19", features = ["auto"] } diff --git a/examples/YOLO-Series-ONNXRuntime-Rust/README.md b/examples/YOLO-Series-ONNXRuntime-Rust/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0b6fabe20d979916a89ecc5a2af993b8cafe9faa --- /dev/null +++ b/examples/YOLO-Series-ONNXRuntime-Rust/README.md @@ -0,0 +1,94 @@ +# YOLO-Series ONNXRuntime Rust Demo for Core YOLO Tasks + +This repository provides a Rust demo for key YOLO-Series tasks such as `Classification`, `Segmentation`, `Detection`, `Pose Detection`, and `OBB` using ONNXRuntime. It supports various YOLO models (v5 - 11) across multiple vision tasks. + +## Introduction + +- This example leverages the latest versions of both ONNXRuntime and YOLO models. +- We utilize the [usls](https://github.com/jamjamjon/usls/tree/main) crate to streamline YOLO model inference, providing efficient data loading, visualization, and optimized inference performance. + +## Features + +- **Extensive Model Compatibility**: Supports `YOLOv5`, `YOLOv6`, `YOLOv7`, `YOLOv8`, `YOLOv9`, `YOLOv10`, `YOLO11`, `YOLO-world`, `RTDETR`, and others, covering a wide range of YOLO versions. +- **Versatile Task Coverage**: Includes `Classification`, `Segmentation`, `Detection`, `Pose`, and `OBB`. +- **Precision Flexibility**: Works with `FP16` and `FP32` ONNX models. +- **Execution Providers**: Accelerated support for `CPU`, `CUDA`, `CoreML`, and `TensorRT`. +- **Dynamic Input Shapes**: Dynamically adjusts to variable `batch`, `width`, and `height` dimensions for flexible model input. 
+- **Flexible Data Loading**: The `DataLoader` handles images, folders, videos, and video streams. +- **Real-Time Display and Video Export**: `Viewer` provides real-time frame visualization and video export functions, similar to OpenCV’s `imshow()` and `imwrite()`. +- **Enhanced Annotation and Visualization**: The `Annotator` facilitates comprehensive result rendering, with support for bounding boxes (HBB), oriented bounding boxes (OBB), polygons, masks, keypoints, and text labels. + +## Setup Instructions + +### 1. ONNXRuntime Linking + +
+You have two options to link the ONNXRuntime library: + +- **Option 1: Manual Linking** + + - For detailed setup, consult the [ONNX Runtime linking documentation](https://ort.pyke.io/setup/linking). + - **Linux or macOS**: + 1. Download the ONNX Runtime package from the [Releases page](https://github.com/microsoft/onnxruntime/releases). + 2. Set up the library path by exporting the `ORT_DYLIB_PATH` environment variable: + ```shell + export ORT_DYLIB_PATH=/path/to/onnxruntime/lib/libonnxruntime.so.1.19.0 + ``` + +- **Option 2: Automatic Download** + - Use the `--features auto` flag to handle downloading automatically: + ```shell + cargo run -r --example yolo --features auto + ``` + +
+ +### 2. \[Optional\] Install CUDA, CuDNN, and TensorRT + +- The CUDA execution provider requires CUDA version `12.x`. +- The TensorRT execution provider requires both CUDA `12.x` and TensorRT `10.x`. + +### 3. \[Optional\] Install ffmpeg + +To view video frames and save video inferences, install `rust-ffmpeg`. For instructions, see: +[https://github.com/zmwangx/rust-ffmpeg/wiki/Notes-on-building#dependencies](https://github.com/zmwangx/rust-ffmpeg/wiki/Notes-on-building#dependencies) + +## Get Started + +```Shell +# customized +cargo run -r -- --task detect --ver v8 --nc 6 --model xxx.onnx # YOLOv8 + +# Classify +cargo run -r -- --task classify --ver v5 --scale s --width 224 --height 224 --nc 1000 # YOLOv5 +cargo run -r -- --task classify --ver v8 --scale n --width 224 --height 224 --nc 1000 # YOLOv8 +cargo run -r -- --task classify --ver v11 --scale n --width 224 --height 224 --nc 1000 # YOLO11 + +# Detect +cargo run -r -- --task detect --ver v5 --scale n # YOLOv5 +cargo run -r -- --task detect --ver v6 --scale n # YOLOv6 +cargo run -r -- --task detect --ver v7 --scale t # YOLOv7 +cargo run -r -- --task detect --ver v8 --scale n # YOLOv8 +cargo run -r -- --task detect --ver v9 --scale t # YOLOv9 +cargo run -r -- --task detect --ver v10 --scale n # YOLOv10 +cargo run -r -- --task detect --ver v11 --scale n # YOLO11 +cargo run -r -- --task detect --ver rtdetr --scale l # RTDETR + +# Pose +cargo run -r -- --task pose --ver v8 --scale n # YOLOv8-Pose +cargo run -r -- --task pose --ver v11 --scale n # YOLO11-Pose + +# Segment +cargo run -r -- --task segment --ver v5 --scale n # YOLOv5-Segment +cargo run -r -- --task segment --ver v8 --scale n # YOLOv8-Segment +cargo run -r -- --task segment --ver v11 --scale n # YOLOv8-Segment +cargo run -r -- --task segment --ver v8 --model yolo/FastSAM-s-dyn-f16.onnx # FastSAM + +# OBB +cargo run -r -- --ver v8 --task obb --scale n --width 1024 --height 1024 --source images/dota.png # YOLOv8-Obb +cargo run -r -- --ver v11 --task obb --scale n --width 1024 --height 1024 --source images/dota.png # YOLO11-Obb +``` + +**`cargo run -- --help` for more options** + +For more details, please refer to [usls-yolo](https://github.com/jamjamjon/usls/tree/main/examples/yolo). 
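For a programmatic entry point, the following is a condensed sketch of the pipeline that `src/main.rs` (below) wires up behind these CLI flags: build `Options`, construct the `YOLO` model, stream inputs through `DataLoader`, and annotate the results. It only reuses the `usls` calls that appear in that file; the ONNX model path, source, and confidence threshold are placeholder values.

```rust
use anyhow::Result;
use usls::{models::YOLO, Annotator, DataLoader, Options, Vision, YOLOTask, YOLOVersion};

fn main() -> Result<()> {
    // Configure the detector (model path and confidence value are placeholders).
    let options = Options::new()
        .with_model("path/to/yolov8n.onnx")?
        .with_yolo_version(YOLOVersion::V8)
        .with_yolo_task(YOLOTask::Detect)
        .with_confs(&[0.25]);
    let mut model = YOLO::new(options)?;

    // Stream the source (image, folder, video, or stream) in model-sized batches.
    let dl = DataLoader::new("../../ultralytics/assets/bus.jpg")?
        .with_batch(model.batch() as _)
        .build()?;

    // Run inference and save annotated outputs.
    let annotator = Annotator::default();
    for (xs, _paths) in dl {
        let ys = model.forward(&xs, false)?;
        annotator.plot(&xs, &ys, true)?;
    }
    Ok(())
}
```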
diff --git a/examples/YOLO-Series-ONNXRuntime-Rust/src/main.rs b/examples/YOLO-Series-ONNXRuntime-Rust/src/main.rs new file mode 100644 index 0000000000000000000000000000000000000000..3c71a253108414198462e84a5ae3854a9535486e --- /dev/null +++ b/examples/YOLO-Series-ONNXRuntime-Rust/src/main.rs @@ -0,0 +1,236 @@ +use anyhow::Result; +use clap::Parser; + +use usls::{ + models::YOLO, Annotator, DataLoader, Device, Options, Viewer, Vision, YOLOScale, YOLOTask, + YOLOVersion, COCO_SKELETONS_16, +}; + +#[derive(Parser, Clone)] +#[command(author, version, about, long_about = None)] +pub struct Args { + /// Path to the ONNX model + #[arg(long)] + pub model: Option, + + /// Input source path + #[arg(long, default_value_t = String::from("../../ultralytics/assets/bus.jpg"))] + pub source: String, + + /// YOLO Task + #[arg(long, value_enum, default_value_t = YOLOTask::Detect)] + pub task: YOLOTask, + + /// YOLO Version + #[arg(long, value_enum, default_value_t = YOLOVersion::V8)] + pub ver: YOLOVersion, + + /// YOLO Scale + #[arg(long, value_enum, default_value_t = YOLOScale::N)] + pub scale: YOLOScale, + + /// Batch size + #[arg(long, default_value_t = 1)] + pub batch_size: usize, + + /// Minimum input width + #[arg(long, default_value_t = 224)] + pub width_min: isize, + + /// Input width + #[arg(long, default_value_t = 640)] + pub width: isize, + + /// Maximum input width + #[arg(long, default_value_t = 1024)] + pub width_max: isize, + + /// Minimum input height + #[arg(long, default_value_t = 224)] + pub height_min: isize, + + /// Input height + #[arg(long, default_value_t = 640)] + pub height: isize, + + /// Maximum input height + #[arg(long, default_value_t = 1024)] + pub height_max: isize, + + /// Number of classes + #[arg(long, default_value_t = 80)] + pub nc: usize, + + /// Class confidence + #[arg(long)] + pub confs: Vec, + + /// Enable TensorRT support + #[arg(long)] + pub trt: bool, + + /// Enable CUDA support + #[arg(long)] + pub cuda: bool, + + /// Enable CoreML support + #[arg(long)] + pub coreml: bool, + + /// Use TensorRT half precision + #[arg(long)] + pub half: bool, + + /// Device ID to use + #[arg(long, default_value_t = 0)] + pub device_id: usize, + + /// Enable performance profiling + #[arg(long)] + pub profile: bool, + + /// Disable contour drawing, for saving time + #[arg(long)] + pub no_contours: bool, + + /// Show result + #[arg(long)] + pub view: bool, + + /// Do not save output + #[arg(long)] + pub nosave: bool, +} + +fn main() -> Result<()> { + let args = Args::parse(); + + // logger + if args.profile { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .init(); + } + + // model path + let path = match &args.model { + None => format!( + "yolo/{}-{}-{}.onnx", + args.ver.name(), + args.scale.name(), + args.task.name() + ), + Some(x) => x.to_string(), + }; + + // saveout + let saveout = match &args.model { + None => format!( + "{}-{}-{}", + args.ver.name(), + args.scale.name(), + args.task.name() + ), + Some(x) => { + let p = std::path::PathBuf::from(&x); + p.file_stem().unwrap().to_str().unwrap().to_string() + } + }; + + // device + let device = if args.cuda { + Device::Cuda(args.device_id) + } else if args.trt { + Device::Trt(args.device_id) + } else if args.coreml { + Device::CoreML(args.device_id) + } else { + Device::Cpu(args.device_id) + }; + + // build options + let options = Options::new() + .with_model(&path)? 
+ .with_yolo_version(args.ver) + .with_yolo_task(args.task) + .with_device(device) + .with_trt_fp16(args.half) + .with_ixx(0, 0, (1, args.batch_size as _, 4).into()) + .with_ixx(0, 2, (args.height_min, args.height, args.height_max).into()) + .with_ixx(0, 3, (args.width_min, args.width, args.width_max).into()) + .with_confs(if args.confs.is_empty() { + &[0.2, 0.15] + } else { + &args.confs + }) + .with_nc(args.nc) + .with_find_contours(!args.no_contours) // find contours or not + // .with_names(&COCO_CLASS_NAMES_80) // detection class names + // .with_names2(&COCO_KEYPOINTS_17) // keypoints class names + // .exclude_classes(&[0]) + // .retain_classes(&[0, 5]) + .with_profile(args.profile); + + // build model + let mut model = YOLO::new(options)?; + + // build dataloader + let dl = DataLoader::new(&args.source)? + .with_batch(model.batch() as _) + .build()?; + + // build annotator + let annotator = Annotator::default() + .with_skeletons(&COCO_SKELETONS_16) + .without_masks(true) // no masks plotting when doing segment task + .with_bboxes_thickness(3) + .with_keypoints_name(false) // enable keypoints names + .with_saveout_subs(&["YOLO"]) + .with_saveout(&saveout); + + // build viewer + let mut viewer = if args.view { + Some(Viewer::new().with_delay(5).with_scale(1.).resizable(true)) + } else { + None + }; + + // run & annotate + for (xs, _paths) in dl { + let ys = model.forward(&xs, args.profile)?; + let images_plotted = annotator.plot(&xs, &ys, !args.nosave)?; + + // show image + match &mut viewer { + Some(viewer) => viewer.imshow(&images_plotted)?, + None => continue, + } + + // check out window and key event + match &mut viewer { + Some(viewer) => { + if !viewer.is_open() || viewer.is_key_pressed(usls::Key::Escape) { + break; + } + } + None => continue, + } + + // write video + if !args.nosave { + match &mut viewer { + Some(viewer) => viewer.write_batch(&images_plotted)?, + None => continue, + } + } + } + + // finish video write + if !args.nosave { + if let Some(viewer) = &mut viewer { + viewer.finish_write()?; + } + } + + Ok(()) +} diff --git a/examples/YOLOv8-Action-Recognition/action_recognition.py b/examples/YOLOv8-Action-Recognition/action_recognition.py new file mode 100644 index 0000000000000000000000000000000000000000..38b6a2526931ad34963403a3e9ab1bcbba737bba --- /dev/null +++ b/examples/YOLOv8-Action-Recognition/action_recognition.py @@ -0,0 +1,464 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import argparse +import time +from collections import defaultdict +from typing import List, Optional, Tuple +from urllib.parse import urlparse + +import cv2 +import numpy as np +import torch +from transformers import AutoModel, AutoProcessor + +from ultralytics import YOLO +from ultralytics.data.loaders import get_best_youtube_url +from ultralytics.utils.plotting import Annotator +from ultralytics.utils.torch_utils import select_device + + +class TorchVisionVideoClassifier: + """Classifies videos using pretrained TorchVision models; see https://pytorch.org/vision/stable/.""" + + from torchvision.models.video import ( + MViT_V1_B_Weights, + MViT_V2_S_Weights, + R3D_18_Weights, + S3D_Weights, + Swin3D_B_Weights, + Swin3D_T_Weights, + mvit_v1_b, + mvit_v2_s, + r3d_18, + s3d, + swin3d_b, + swin3d_t, + ) + + model_name_to_model_and_weights = { + "s3d": (s3d, S3D_Weights.DEFAULT), + "r3d_18": (r3d_18, R3D_18_Weights.DEFAULT), + "swin3d_t": (swin3d_t, Swin3D_T_Weights.DEFAULT), + "swin3d_b": (swin3d_b, Swin3D_B_Weights.DEFAULT), + "mvit_v1_b": (mvit_v1_b, 
MViT_V1_B_Weights.DEFAULT), + "mvit_v2_s": (mvit_v2_s, MViT_V2_S_Weights.DEFAULT), + } + + def __init__(self, model_name: str, device: str or torch.device = ""): + """ + Initialize the VideoClassifier with the specified model name and device. + + Args: + model_name (str): The name of the model to use. + device (str or torch.device, optional): The device to run the model on. Defaults to "". + + Raises: + ValueError: If an invalid model name is provided. + """ + if model_name not in self.model_name_to_model_and_weights: + raise ValueError(f"Invalid model name '{model_name}'. Available models: {self.available_model_names()}") + model, self.weights = self.model_name_to_model_and_weights[model_name] + self.device = select_device(device) + self.model = model(weights=self.weights).to(self.device).eval() + + @staticmethod + def available_model_names() -> List[str]: + """ + Get the list of available model names. + + Returns: + list: List of available model names. + """ + return list(TorchVisionVideoClassifier.model_name_to_model_and_weights.keys()) + + def preprocess_crops_for_video_cls(self, crops: List[np.ndarray], input_size: list = None) -> torch.Tensor: + """ + Preprocess a list of crops for video classification. + + Args: + crops (List[np.ndarray]): List of crops to preprocess. Each crop should have dimensions (H, W, C) + input_size (tuple, optional): The target input size for the model. Defaults to (224, 224). + + Returns: + torch.Tensor: Preprocessed crops as a tensor with dimensions (1, T, C, H, W). + """ + if input_size is None: + input_size = [224, 224] + from torchvision.transforms import v2 + + transform = v2.Compose( + [ + v2.ToDtype(torch.float32, scale=True), + v2.Resize(input_size, antialias=True), + v2.Normalize(mean=self.weights.transforms().mean, std=self.weights.transforms().std), + ] + ) + + processed_crops = [transform(torch.from_numpy(crop).permute(2, 0, 1)) for crop in crops] + return torch.stack(processed_crops).unsqueeze(0).permute(0, 2, 1, 3, 4).to(self.device) + + def __call__(self, sequences: torch.Tensor): + """ + Perform inference on the given sequences. + + Args: + sequences (torch.Tensor): The input sequences for the model. The expected input dimensions are + (B, T, C, H, W) for batched video frames or (T, C, H, W) for single video frames. + + Returns: + torch.Tensor: The model's output. + """ + with torch.inference_mode(): + return self.model(sequences) + + def postprocess(self, outputs: torch.Tensor) -> Tuple[List[str], List[float]]: + """ + Postprocess the model's batch output. + + Args: + outputs (torch.Tensor): The model's output. + + Returns: + List[str]: The predicted labels. + List[float]: The predicted confidences. + """ + pred_labels = [] + pred_confs = [] + for output in outputs: + pred_class = output.argmax(0).item() + pred_label = self.weights.meta["categories"][pred_class] + pred_labels.append(pred_label) + pred_conf = output.softmax(0)[pred_class].item() + pred_confs.append(pred_conf) + + return pred_labels, pred_confs + + +class HuggingFaceVideoClassifier: + """Zero-shot video classifier using Hugging Face models for various devices.""" + + def __init__( + self, + labels: List[str], + model_name: str = "microsoft/xclip-base-patch16-zero-shot", + device: str or torch.device = "", + fp16: bool = False, + ): + """ + Initialize the HuggingFaceVideoClassifier with the specified model name. + + Args: + labels (List[str]): List of labels for zero-shot classification. + model_name (str): The name of the model to use. 
Defaults to "microsoft/xclip-base-patch16-zero-shot". + device (str or torch.device, optional): The device to run the model on. Defaults to "". + fp16 (bool, optional): Whether to use FP16 for inference. Defaults to False. + """ + self.fp16 = fp16 + self.labels = labels + self.device = select_device(device) + self.processor = AutoProcessor.from_pretrained(model_name) + model = AutoModel.from_pretrained(model_name).to(self.device) + if fp16: + model = model.half() + self.model = model.eval() + + def preprocess_crops_for_video_cls(self, crops: List[np.ndarray], input_size: list = None) -> torch.Tensor: + """ + Preprocess a list of crops for video classification. + + Args: + crops (List[np.ndarray]): List of crops to preprocess. Each crop should have dimensions (H, W, C) + input_size (tuple, optional): The target input size for the model. Defaults to (224, 224). + + Returns: + torch.Tensor: Preprocessed crops as a tensor (1, T, C, H, W). + """ + if input_size is None: + input_size = [224, 224] + from torchvision import transforms + + transform = transforms.Compose( + [ + transforms.Lambda(lambda x: x.float() / 255.0), + transforms.Resize(input_size), + transforms.Normalize( + mean=self.processor.image_processor.image_mean, std=self.processor.image_processor.image_std + ), + ] + ) + + processed_crops = [transform(torch.from_numpy(crop).permute(2, 0, 1)) for crop in crops] # (T, C, H, W) + output = torch.stack(processed_crops).unsqueeze(0).to(self.device) # (1, T, C, H, W) + if self.fp16: + output = output.half() + return output + + def __call__(self, sequences: torch.Tensor) -> torch.Tensor: + """ + Perform inference on the given sequences. + + Args: + sequences (torch.Tensor): The input sequences for the model. Batched video frames with shape (B, T, H, W, C). + + Returns: + torch.Tensor: The model's output. + """ + input_ids = self.processor(text=self.labels, return_tensors="pt", padding=True)["input_ids"].to(self.device) + + inputs = {"pixel_values": sequences, "input_ids": input_ids} + + with torch.inference_mode(): + outputs = self.model(**inputs) + + return outputs.logits_per_video + + def postprocess(self, outputs: torch.Tensor) -> Tuple[List[List[str]], List[List[float]]]: + """ + Postprocess the model's batch output. + + Args: + outputs (torch.Tensor): The model's output. + + Returns: + List[List[str]]: The predicted top3 labels. + List[List[float]]: The predicted top3 confidences. 
+ """ + pred_labels = [] + pred_confs = [] + + with torch.no_grad(): + logits_per_video = outputs # Assuming outputs is already the logits tensor + probs = logits_per_video.softmax(dim=-1) # Use softmax to convert logits to probabilities + + for prob in probs: + top2_indices = prob.topk(2).indices.tolist() + top2_labels = [self.labels[idx] for idx in top2_indices] + top2_confs = prob[top2_indices].tolist() + pred_labels.append(top2_labels) + pred_confs.append(top2_confs) + + return pred_labels, pred_confs + + +def crop_and_pad(frame, box, margin_percent): + """Crop box with margin and take square crop from frame.""" + x1, y1, x2, y2 = map(int, box) + w, h = x2 - x1, y2 - y1 + + # Add margin + margin_x, margin_y = int(w * margin_percent / 100), int(h * margin_percent / 100) + x1, y1 = max(0, x1 - margin_x), max(0, y1 - margin_y) + x2, y2 = min(frame.shape[1], x2 + margin_x), min(frame.shape[0], y2 + margin_y) + + # Take square crop from frame + size = max(y2 - y1, x2 - x1) + center_y, center_x = (y1 + y2) // 2, (x1 + x2) // 2 + half_size = size // 2 + square_crop = frame[ + max(0, center_y - half_size) : min(frame.shape[0], center_y + half_size), + max(0, center_x - half_size) : min(frame.shape[1], center_x + half_size), + ] + + return cv2.resize(square_crop, (224, 224), interpolation=cv2.INTER_LINEAR) + + +def run( + weights: str = "yolo11n.pt", + device: str = "", + source: str = "https://www.youtube.com/watch?v=dQw4w9WgXcQ", + output_path: Optional[str] = None, + crop_margin_percentage: int = 10, + num_video_sequence_samples: int = 8, + skip_frame: int = 2, + video_cls_overlap_ratio: float = 0.25, + fp16: bool = False, + video_classifier_model: str = "microsoft/xclip-base-patch32", + labels: List[str] = None, +) -> None: + """ + Run action recognition on a video source using YOLO for object detection and a video classifier. + + Args: + weights (str): Path to the YOLO model weights. Defaults to "yolo11n.pt". + device (str): Device to run the model on. Use 'cuda' for NVIDIA GPU, 'mps' for Apple Silicon, or 'cpu'. Defaults to auto-detection. + source (str): Path to mp4 video file or YouTube URL. Defaults to a sample YouTube video. + output_path (Optional[str], optional): Path to save the output video. Defaults to None. + crop_margin_percentage (int, optional): Percentage of margin to add around detected objects. Defaults to 10. + num_video_sequence_samples (int, optional): Number of video frames to use for classification. Defaults to 8. + skip_frame (int, optional): Number of frames to skip between detections. Defaults to 4. + video_cls_overlap_ratio (float, optional): Overlap ratio between video sequences. Defaults to 0.25. + fp16 (bool, optional): Whether to use half-precision floating point. Defaults to False. + video_classifier_model (str, optional): Name or path of the video classifier model. Defaults to "microsoft/xclip-base-patch32". + labels (List[str], optional): List of labels for zero-shot classification. Defaults to predefined list. + + Returns: + None + """ + if labels is None: + labels = [ + "walking", + "running", + "brushing teeth", + "looking into phone", + "weight lifting", + "cooking", + "sitting", + ] + # Initialize models and device + device = select_device(device) + yolo_model = YOLO(weights).to(device) + if video_classifier_model in TorchVisionVideoClassifier.available_model_names(): + print("'fp16' is not supported for TorchVisionVideoClassifier. Setting fp16 to False.") + print( + "'labels' is not used for TorchVisionVideoClassifier. 
Ignoring the provided labels and using Kinetics-400 labels." + ) + video_classifier = TorchVisionVideoClassifier(video_classifier_model, device=device) + else: + video_classifier = HuggingFaceVideoClassifier( + labels, model_name=video_classifier_model, device=device, fp16=fp16 + ) + + # Initialize video capture + if source.startswith("http") and urlparse(source).hostname in {"www.youtube.com", "youtube.com", "youtu.be"}: + source = get_best_youtube_url(source) + elif not source.endswith(".mp4"): + raise ValueError("Invalid source. Supported sources are YouTube URLs and MP4 files.") + cap = cv2.VideoCapture(source) + + # Get video properties + frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + fps = cap.get(cv2.CAP_PROP_FPS) + + # Initialize VideoWriter + if output_path is not None: + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height)) + + # Initialize track history + track_history = defaultdict(list) + frame_counter = 0 + + track_ids_to_infer = [] + crops_to_infer = [] + pred_labels = [] + pred_confs = [] + + while cap.isOpened(): + success, frame = cap.read() + if not success: + break + + frame_counter += 1 + + # Run YOLO tracking + results = yolo_model.track(frame, persist=True, classes=[0]) # Track only person class + + if results[0].boxes.id is not None: + boxes = results[0].boxes.xyxy.cpu().numpy() + track_ids = results[0].boxes.id.cpu().numpy() + + # Visualize prediction + annotator = Annotator(frame, line_width=3, font_size=10, pil=False) + + if frame_counter % skip_frame == 0: + crops_to_infer = [] + track_ids_to_infer = [] + + for box, track_id in zip(boxes, track_ids): + if frame_counter % skip_frame == 0: + crop = crop_and_pad(frame, box, crop_margin_percentage) + track_history[track_id].append(crop) + + if len(track_history[track_id]) > num_video_sequence_samples: + track_history[track_id].pop(0) + + if len(track_history[track_id]) == num_video_sequence_samples and frame_counter % skip_frame == 0: + start_time = time.time() + crops = video_classifier.preprocess_crops_for_video_cls(track_history[track_id]) + end_time = time.time() + preprocess_time = end_time - start_time + print(f"video cls preprocess time: {preprocess_time:.4f} seconds") + crops_to_infer.append(crops) + track_ids_to_infer.append(track_id) + + if crops_to_infer and ( + not pred_labels + or frame_counter % int(num_video_sequence_samples * skip_frame * (1 - video_cls_overlap_ratio)) == 0 + ): + crops_batch = torch.cat(crops_to_infer, dim=0) + + start_inference_time = time.time() + output_batch = video_classifier(crops_batch) + end_inference_time = time.time() + inference_time = end_inference_time - start_inference_time + print(f"video cls inference time: {inference_time:.4f} seconds") + + pred_labels, pred_confs = video_classifier.postprocess(output_batch) + + if track_ids_to_infer and crops_to_infer: + for box, track_id, pred_label, pred_conf in zip(boxes, track_ids_to_infer, pred_labels, pred_confs): + top2_preds = sorted(zip(pred_label, pred_conf), key=lambda x: x[1], reverse=True) + label_text = " | ".join([f"{label} ({conf:.2f})" for label, conf in top2_preds]) + annotator.box_label(box, label_text, color=(0, 0, 255)) + + # Write the annotated frame to the output video + if output_path is not None: + out.write(frame) + + # Display the annotated frame + cv2.imshow("YOLOv8 Tracking with S3D Classification", frame) + + if cv2.waitKey(1) & 0xFF == ord("q"): + break + + cap.release() + 
if output_path is not None: + out.release() + cv2.destroyAllWindows() + + +def parse_opt(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument("--weights", type=str, default="yolo11n.pt", help="ultralytics detector model path") + parser.add_argument("--device", default="", help='cuda device, i.e. 0 or 0,1,2,3 or cpu/mps, "" for auto-detection') + parser.add_argument( + "--source", + type=str, + default="https://www.youtube.com/watch?v=dQw4w9WgXcQ", + help="video file path or youtube URL", + ) + parser.add_argument("--output-path", type=str, default="output_video.mp4", help="output video file path") + parser.add_argument( + "--crop-margin-percentage", type=int, default=10, help="percentage of margin to add around detected objects" + ) + parser.add_argument( + "--num-video-sequence-samples", type=int, default=8, help="number of video frames to use for classification" + ) + parser.add_argument("--skip-frame", type=int, default=2, help="number of frames to skip between detections") + parser.add_argument( + "--video-cls-overlap-ratio", type=float, default=0.25, help="overlap ratio between video sequences" + ) + parser.add_argument("--fp16", action="store_true", help="use FP16 for inference") + parser.add_argument( + "--video-classifier-model", type=str, default="microsoft/xclip-base-patch32", help="video classifier model name" + ) + parser.add_argument( + "--labels", + nargs="+", + type=str, + default=["dancing", "singing a song"], + help="labels for zero-shot video classification", + ) + return parser.parse_args() + + +def main(opt): + """Main function.""" + run(**vars(opt)) + + +if __name__ == "__main__": + opt = parse_opt() + main(opt) diff --git a/examples/YOLOv8-Action-Recognition/readme.md b/examples/YOLOv8-Action-Recognition/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..6cceb594425326743695cb786254fb9a211b7ad8 --- /dev/null +++ b/examples/YOLOv8-Action-Recognition/readme.md @@ -0,0 +1,116 @@ +# Zero-shot Action Recognition with YOLOv8 (Inference on Video) + +- Action recognition is a technique used to identify and classify actions performed by individuals in a video. This process enables more advanced analyses when multiple actions are considered. The actions can be detected and classified in real time. +- The system can be customized to recognize specific actions based on the user's preferences and requirements. + +## Table of Contents + +- [Step 1: Install the Required Libraries](#step-1-install-the-required-libraries) +- [Step 2: Run the Action Recognition Using Ultralytics YOLOv8](#step-2-run-the-action-recognition-using-ultralytics-yolov8) +- [Usage Options](#usage-options) +- [FAQ](#faq) + +## Step 1: Install the Required Libraries + +Clone the repository, install dependencies and `cd` to this local directory for commands in Step 2. + +```bash +# Clone ultralytics repo +git clone https://github.com/ultralytics/ultralytics + +# cd to local directory +cd examples/YOLOv8-Action-Recognition + +# Install dependencies +pip install -U -r requirements.txt +``` + +## Step 2: Run the Action Recognition Using Ultralytics YOLOv8 + +Here are the basic commands for running the inference: + +### Note + +The action recognition model will automatically detect and track people in the video, and classify their actions based on the specified labels. The results will be displayed in real-time on the video output. You can customize the action labels by modifying the `--labels` argument when running the script. 
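+
+The CLI commands below wrap the `run()` helper defined in `action_recognition.py`. If you prefer to drive it directly from Python, a minimal sketch is shown here (keyword arguments taken from the `run()` signature above; the video path is a placeholder):
+
+```python
+from action_recognition import run
+
+# Detect and track people with YOLO, then zero-shot classify their actions with X-CLIP.
+run(
+    weights="yolo11n.pt",                                   # detector weights
+    source="path/to/video.mp4",                             # local .mp4 file or YouTube URL
+    video_classifier_model="microsoft/xclip-base-patch32",  # zero-shot video classifier
+    labels=["dancing", "singing a song"],                   # candidate action labels
+    output_path="output.mp4",                               # annotated video is written here
+)
+```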
+
+```bash
+# Quick start
+python action_recognition.py
+
+# Basic usage
+python action_recognition.py --source "https://www.youtube.com/watch?v=dQw4w9WgXcQ" --labels "dancing" "singing a song"
+
+# Use local video file
+python action_recognition.py --source path/to/video.mp4
+
+# Better detector performance
+python action_recognition.py --weights yolov8m.pt
+
+# Run on CPU
+python action_recognition.py --device cpu
+
+# Use a different video classifier model
+python action_recognition.py --video-classifier-model "s3d"
+
+# Use FP16 for inference (only for HuggingFace models)
+python action_recognition.py --fp16
+
+# Export output as mp4
+python action_recognition.py --output-path output.mp4
+
+# Combine multiple options
+python action_recognition.py --source "https://www.youtube.com/watch?v=dQw4w9WgXcQ" --device 0 --video-classifier-model "microsoft/xclip-base-patch32" --labels "dancing" "singing a song" --fp16
+```
+
+## Usage Options
+
+- `--weights`: Path to the YOLO model weights (default: "yolo11n.pt")
+- `--device`: CUDA device, i.e. 0 or 0,1,2,3, or cpu/mps (default: auto-detect)
+- `--source`: Video file path or YouTube URL (default: "[rickroll](https://www.youtube.com/watch?v=dQw4w9WgXcQ)")
+- `--output-path`: Output video file path (default: "output_video.mp4")
+- `--crop-margin-percentage`: Percentage of margin to add around detected objects (default: 10)
+- `--num-video-sequence-samples`: Number of video frames to use for classification (default: 8)
+- `--skip-frame`: Number of frames to skip between detections (default: 2)
+- `--video-cls-overlap-ratio`: Overlap ratio between video sequences (default: 0.25)
+- `--fp16`: Use FP16 for inference (only for HuggingFace models)
+- `--video-classifier-model`: Video classifier model name or path (default: "microsoft/xclip-base-patch32")
+- `--labels`: Labels for zero-shot video classification (default: \["dancing" "singing a song"\])
+
+## FAQ
+
+**1. What Does Action Recognition Involve?**
+
+Action recognition is a computational method used to identify and classify actions or activities performed by individuals in recorded video or real-time streams. This technique is widely used in video analysis, surveillance, and human-computer interaction, enabling the detection and understanding of human behaviors based on their motion patterns and context.
+
+**2. Are Custom Action Labels Supported?**
+
+Yes, custom action labels are supported by the action recognition system. The `action_recognition.py` script allows users to specify their own custom labels for zero-shot video classification. This can be done using the `--labels` argument when running the script. For example:
+
+```bash
+python action_recognition.py --source https://www.youtube.com/watch?v=dQw4w9WgXcQ --labels "dancing" "singing" "jumping"
+```
+
+You can adjust these labels to match the specific actions you want to recognize in your video. The system will then attempt to classify the detected actions based on these custom labels.
+
+Additionally, you can choose between different video classification models:
+
+1. For Hugging Face models, you can use any compatible video classification model. The default is set to:
+
+   - "microsoft/xclip-base-patch32"
+
+2. For TorchVision models (no support for zero-shot labels), you can select from the following options:
+
+   - "s3d"
+   - "r3d_18"
+   - "swin3d_t"
+   - "swin3d_b"
+   - "mvit_v1_b"
+   - "mvit_v2_s"
+
+**3. Why Combine Action Recognition with YOLOv8?**
+
+YOLOv8 specializes in the detection and tracking of objects in video streams.
Action recognition complements this by enabling the identification and classification of actions performed by individuals, making it a valuable application of YOLOv8. + +**4. Can I Employ Other YOLO Versions?** + +Certainly, you have the flexibility to specify different YOLO model weights using the `--weights` option. diff --git a/examples/YOLOv8-Action-Recognition/requirements.txt b/examples/YOLOv8-Action-Recognition/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3a93595908c54027d17ac0d0cd976ed4321371b3 --- /dev/null +++ b/examples/YOLOv8-Action-Recognition/requirements.txt @@ -0,0 +1,4 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +ultralytics +transformers diff --git a/examples/YOLOv8-CPP-Inference/CMakeLists.txt b/examples/YOLOv8-CPP-Inference/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..bc2f33fffd447ffef995a570f01e0aefdd071acb --- /dev/null +++ b/examples/YOLOv8-CPP-Inference/CMakeLists.txt @@ -0,0 +1,28 @@ +cmake_minimum_required(VERSION 3.5) + +project(Yolov8CPPInference VERSION 0.1) + +set(CMAKE_INCLUDE_CURRENT_DIR ON) + +# CUDA +set(CUDA_TOOLKIT_ROOT_DIR "/usr/local/cuda") +find_package(CUDA 11 REQUIRED) + +set(CMAKE_CUDA_STANDARD 11) +set(CMAKE_CUDA_STANDARD_REQUIRED ON) +# !CUDA + +# OpenCV +find_package(OpenCV REQUIRED) +include_directories(${OpenCV_INCLUDE_DIRS}) +# !OpenCV + +set(PROJECT_SOURCES + main.cpp + + inference.h + inference.cpp +) + +add_executable(Yolov8CPPInference ${PROJECT_SOURCES}) +target_link_libraries(Yolov8CPPInference ${OpenCV_LIBS}) diff --git a/examples/YOLOv8-CPP-Inference/README.md b/examples/YOLOv8-CPP-Inference/README.md new file mode 100644 index 0000000000000000000000000000000000000000..243d448e366d700d5e41fea99789a2a673098dad --- /dev/null +++ b/examples/YOLOv8-CPP-Inference/README.md @@ -0,0 +1,50 @@ +# YOLOv8/YOLOv5 Inference C++ + +This example demonstrates how to perform inference using YOLOv8 and YOLOv5 models in C++ with OpenCV DNN API. + +## Usage + +```bash +git clone ultralytics +cd ultralytics +pip install . +cd examples/YOLOv8-CPP-Inference + +# Add a **yolov8\_.onnx** and/or **yolov5\_.onnx** model(s) to the ultralytics folder. +# Edit the **main.cpp** to change the **projectBasePath** to match your user. + +# Note that by default the CMake file will try to import the CUDA library to be used with the OpenCVs dnn (cuDNN) GPU Inference. +# If your OpenCV build does not use CUDA/cuDNN you can remove that import call and run the example on CPU. + +mkdir build +cd build +cmake .. +make +./Yolov8CPPInference +``` + +## Exporting YOLOv8 and YOLOv5 Models + +To export YOLOv8 models: + +```bash +yolo export model=yolov8s.pt imgsz=480,640 format=onnx opset=12 +``` + +To export YOLOv5 models: + +```bash +python3 export.py --weights yolov5s.pt --img 480 640 --include onnx --opset 12 +``` + +yolov8s.onnx: + +![image](https://user-images.githubusercontent.com/40023722/217356132-a4cecf2e-2729-4acb-b80a-6559022d7707.png) + +yolov5s.onnx: + +![image](https://user-images.githubusercontent.com/40023722/217357005-07464492-d1da-42e3-98a7-fc753f87d5e6.png) + +This repository utilizes OpenCV DNN API to run ONNX exported models of YOLOv5 and YOLOv8. In theory, it should work for YOLOv6 and YOLOv7 as well, but they have not been tested. Note that the example networks are exported with rectangular (640x480) resolutions, but any exported resolution will work. You may want to use the letterbox approach for square images, depending on your use case. 
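+
+If you want to prototype the letterbox step before porting it to C++, a minimal Python sketch of the idea (resize with preserved aspect ratio, then pad to the target size with gray borders) might look like this; it is illustrative only and not part of the C++ example:
+
+```python
+import cv2
+import numpy as np
+
+
+def letterbox(img, new_shape=(640, 640), pad_value=114):
+    """Resize `img` keeping aspect ratio, then pad to `new_shape` (h, w)."""
+    h, w = img.shape[:2]
+    r = min(new_shape[0] / h, new_shape[1] / w)  # scale factor
+    nh, nw = round(h * r), round(w * r)          # resized, unpadded size
+    resized = cv2.resize(img, (nw, nh), interpolation=cv2.INTER_LINEAR)
+    top, left = (new_shape[0] - nh) // 2, (new_shape[1] - nw) // 2
+    out = np.full((*new_shape, 3), pad_value, dtype=np.uint8)
+    out[top : top + nh, left : left + nw] = resized
+    return out, r, (left, top)  # padded image, scale, padding offsets
+```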
+ +The **main** branch version uses Qt as a GUI wrapper. The primary focus here is the **Inference** class file, which demonstrates how to transpose YOLOv8 models to work as YOLOv5 models. diff --git a/examples/YOLOv8-CPP-Inference/inference.cpp b/examples/YOLOv8-CPP-Inference/inference.cpp new file mode 100644 index 0000000000000000000000000000000000000000..12c26079bcbf1b69b92e2305830dce2474a37288 --- /dev/null +++ b/examples/YOLOv8-CPP-Inference/inference.cpp @@ -0,0 +1,185 @@ +#include "inference.h" + +Inference::Inference(const std::string &onnxModelPath, const cv::Size &modelInputShape, const std::string &classesTxtFile, const bool &runWithCuda) +{ + modelPath = onnxModelPath; + modelShape = modelInputShape; + classesPath = classesTxtFile; + cudaEnabled = runWithCuda; + + loadOnnxNetwork(); + // loadClassesFromFile(); The classes are hard-coded for this example +} + +std::vector Inference::runInference(const cv::Mat &input) +{ + cv::Mat modelInput = input; + if (letterBoxForSquare && modelShape.width == modelShape.height) + modelInput = formatToSquare(modelInput); + + cv::Mat blob; + cv::dnn::blobFromImage(modelInput, blob, 1.0/255.0, modelShape, cv::Scalar(), true, false); + net.setInput(blob); + + std::vector outputs; + net.forward(outputs, net.getUnconnectedOutLayersNames()); + + int rows = outputs[0].size[1]; + int dimensions = outputs[0].size[2]; + + bool yolov8 = false; + // yolov5 has an output of shape (batchSize, 25200, 85) (Num classes + box[x,y,w,h] + confidence[c]) + // yolov8 has an output of shape (batchSize, 84, 8400) (Num classes + box[x,y,w,h]) + if (dimensions > rows) // Check if the shape[2] is more than shape[1] (yolov8) + { + yolov8 = true; + rows = outputs[0].size[2]; + dimensions = outputs[0].size[1]; + + outputs[0] = outputs[0].reshape(1, dimensions); + cv::transpose(outputs[0], outputs[0]); + } + float *data = (float *)outputs[0].data; + + float x_factor = modelInput.cols / modelShape.width; + float y_factor = modelInput.rows / modelShape.height; + + std::vector class_ids; + std::vector confidences; + std::vector boxes; + + for (int i = 0; i < rows; ++i) + { + if (yolov8) + { + float *classes_scores = data+4; + + cv::Mat scores(1, classes.size(), CV_32FC1, classes_scores); + cv::Point class_id; + double maxClassScore; + + minMaxLoc(scores, 0, &maxClassScore, 0, &class_id); + + if (maxClassScore > modelScoreThreshold) + { + confidences.push_back(maxClassScore); + class_ids.push_back(class_id.x); + + float x = data[0]; + float y = data[1]; + float w = data[2]; + float h = data[3]; + + int left = int((x - 0.5 * w) * x_factor); + int top = int((y - 0.5 * h) * y_factor); + + int width = int(w * x_factor); + int height = int(h * y_factor); + + boxes.push_back(cv::Rect(left, top, width, height)); + } + } + else // yolov5 + { + float confidence = data[4]; + + if (confidence >= modelConfidenceThreshold) + { + float *classes_scores = data+5; + + cv::Mat scores(1, classes.size(), CV_32FC1, classes_scores); + cv::Point class_id; + double max_class_score; + + minMaxLoc(scores, 0, &max_class_score, 0, &class_id); + + if (max_class_score > modelScoreThreshold) + { + confidences.push_back(confidence); + class_ids.push_back(class_id.x); + + float x = data[0]; + float y = data[1]; + float w = data[2]; + float h = data[3]; + + int left = int((x - 0.5 * w) * x_factor); + int top = int((y - 0.5 * h) * y_factor); + + int width = int(w * x_factor); + int height = int(h * y_factor); + + boxes.push_back(cv::Rect(left, top, width, height)); + } + } + } + + data += dimensions; + } + + 
std::vector nms_result; + cv::dnn::NMSBoxes(boxes, confidences, modelScoreThreshold, modelNMSThreshold, nms_result); + + std::vector detections{}; + for (unsigned long i = 0; i < nms_result.size(); ++i) + { + int idx = nms_result[i]; + + Detection result; + result.class_id = class_ids[idx]; + result.confidence = confidences[idx]; + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dis(100, 255); + result.color = cv::Scalar(dis(gen), + dis(gen), + dis(gen)); + + result.className = classes[result.class_id]; + result.box = boxes[idx]; + + detections.push_back(result); + } + + return detections; +} + +void Inference::loadClassesFromFile() +{ + std::ifstream inputFile(classesPath); + if (inputFile.is_open()) + { + std::string classLine; + while (std::getline(inputFile, classLine)) + classes.push_back(classLine); + inputFile.close(); + } +} + +void Inference::loadOnnxNetwork() +{ + net = cv::dnn::readNetFromONNX(modelPath); + if (cudaEnabled) + { + std::cout << "\nRunning on CUDA" << std::endl; + net.setPreferableBackend(cv::dnn::DNN_BACKEND_CUDA); + net.setPreferableTarget(cv::dnn::DNN_TARGET_CUDA); + } + else + { + std::cout << "\nRunning on CPU" << std::endl; + net.setPreferableBackend(cv::dnn::DNN_BACKEND_OPENCV); + net.setPreferableTarget(cv::dnn::DNN_TARGET_CPU); + } +} + +cv::Mat Inference::formatToSquare(const cv::Mat &source) +{ + int col = source.cols; + int row = source.rows; + int _max = MAX(col, row); + cv::Mat result = cv::Mat::zeros(_max, _max, CV_8UC3); + source.copyTo(result(cv::Rect(0, 0, col, row))); + return result; +} diff --git a/examples/YOLOv8-CPP-Inference/inference.h b/examples/YOLOv8-CPP-Inference/inference.h new file mode 100644 index 0000000000000000000000000000000000000000..dc6149f1875654bf52ccc7497deb5ea8b06f57ca --- /dev/null +++ b/examples/YOLOv8-CPP-Inference/inference.h @@ -0,0 +1,52 @@ +#ifndef INFERENCE_H +#define INFERENCE_H + +// Cpp native +#include +#include +#include +#include + +// OpenCV / DNN / Inference +#include +#include +#include + +struct Detection +{ + int class_id{0}; + std::string className{}; + float confidence{0.0}; + cv::Scalar color{}; + cv::Rect box{}; +}; + +class Inference +{ +public: + Inference(const std::string &onnxModelPath, const cv::Size &modelInputShape = {640, 640}, const std::string &classesTxtFile = "", const bool &runWithCuda = true); + std::vector runInference(const cv::Mat &input); + +private: + void loadClassesFromFile(); + void loadOnnxNetwork(); + cv::Mat formatToSquare(const cv::Mat &source); + + std::string modelPath{}; + std::string classesPath{}; + bool cudaEnabled{}; + + std::vector classes{"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"}; + + cv::Size2f modelShape{}; + + float 
modelConfidenceThreshold {0.25}; + float modelScoreThreshold {0.45}; + float modelNMSThreshold {0.50}; + + bool letterBoxForSquare = true; + + cv::dnn::Net net; +}; + +#endif // INFERENCE_H diff --git a/examples/YOLOv8-CPP-Inference/main.cpp b/examples/YOLOv8-CPP-Inference/main.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fe040c8634a69fa85cec0b8ab1d8decd954974f8 --- /dev/null +++ b/examples/YOLOv8-CPP-Inference/main.cpp @@ -0,0 +1,70 @@ +#include +#include +#include + +#include + +#include "inference.h" + +using namespace std; +using namespace cv; + +int main(int argc, char **argv) +{ + std::string projectBasePath = "/home/user/ultralytics"; // Set your ultralytics base path + + bool runOnGPU = true; + + // + // Pass in either: + // + // "yolov8s.onnx" or "yolov5s.onnx" + // + // To run Inference with yolov8/yolov5 (ONNX) + // + + // Note that in this example the classes are hard-coded and 'classes.txt' is a place holder. + Inference inf(projectBasePath + "/yolov8s.onnx", cv::Size(640, 640), "classes.txt", runOnGPU); + + std::vector imageNames; + imageNames.push_back(projectBasePath + "/ultralytics/assets/bus.jpg"); + imageNames.push_back(projectBasePath + "/ultralytics/assets/zidane.jpg"); + + for (int i = 0; i < imageNames.size(); ++i) + { + cv::Mat frame = cv::imread(imageNames[i]); + + // Inference starts here... + std::vector output = inf.runInference(frame); + + int detections = output.size(); + std::cout << "Number of detections:" << detections << std::endl; + + for (int i = 0; i < detections; ++i) + { + Detection detection = output[i]; + + cv::Rect box = detection.box; + cv::Scalar color = detection.color; + + // Detection box + cv::rectangle(frame, box, color, 2); + + // Detection box text + std::string classString = detection.className + ' ' + std::to_string(detection.confidence).substr(0, 4); + cv::Size textSize = cv::getTextSize(classString, cv::FONT_HERSHEY_DUPLEX, 1, 2, 0); + cv::Rect textBox(box.x, box.y - 40, textSize.width + 10, textSize.height + 20); + + cv::rectangle(frame, textBox, color, cv::FILLED); + cv::putText(frame, classString, cv::Point(box.x + 5, box.y - 10), cv::FONT_HERSHEY_DUPLEX, 1, cv::Scalar(0, 0, 0), 2, 0); + } + // Inference ends here... 
+ + // This is only for preview purposes + float scale = 0.8; + cv::resize(frame, frame, cv::Size(frame.cols*scale, frame.rows*scale)); + cv::imshow("Inference", frame); + + cv::waitKey(-1); + } +} diff --git a/examples/YOLOv8-LibTorch-CPP-Inference/CMakeLists.txt b/examples/YOLOv8-LibTorch-CPP-Inference/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..2cbd796c45df36b81c6882dfa57e067f129d9685 --- /dev/null +++ b/examples/YOLOv8-LibTorch-CPP-Inference/CMakeLists.txt @@ -0,0 +1,47 @@ +cmake_minimum_required(VERSION 3.18 FATAL_ERROR) + +project(yolov8_libtorch_example) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) + + +# -------------- OpenCV -------------- +set(OpenCV_DIR "/path/to/opencv/lib/cmake/opencv4") +find_package(OpenCV REQUIRED) + +message(STATUS "OpenCV library status:") +message(STATUS " config: ${OpenCV_DIR}") +message(STATUS " version: ${OpenCV_VERSION}") +message(STATUS " libraries: ${OpenCV_LIBS}") +message(STATUS " include path: ${OpenCV_INCLUDE_DIRS}") + +include_directories(${OpenCV_INCLUDE_DIRS}) + +# -------------- libtorch -------------- +list(APPEND CMAKE_PREFIX_PATH "/path/to/libtorch") +set(Torch_DIR "/path/to/libtorch/share/cmake/Torch") + +find_package(Torch REQUIRED) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") +message("${TORCH_LIBRARIES}") +message("${TORCH_INCLUDE_DIRS}") + +# The following code block is suggested to be used on Windows. +# According to https://github.com/pytorch/pytorch/issues/25457, +# the DLLs need to be copied to avoid memory errors. +# if (MSVC) +# file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll") +# add_custom_command(TARGET yolov8_libtorch_example +# POST_BUILD +# COMMAND ${CMAKE_COMMAND} -E copy_if_different +# ${TORCH_DLLS} +# $) +# endif (MSVC) + +include_directories(${TORCH_INCLUDE_DIRS}) + +add_executable(yolov8_libtorch_inference "${CMAKE_CURRENT_SOURCE_DIR}/main.cc") +target_link_libraries(yolov8_libtorch_inference ${TORCH_LIBRARIES} ${OpenCV_LIBS}) +set_property(TARGET yolov8_libtorch_inference PROPERTY CXX_STANDARD 17) diff --git a/examples/YOLOv8-LibTorch-CPP-Inference/README.md b/examples/YOLOv8-LibTorch-CPP-Inference/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1380071ee4a9d7067750c9cdaf2e6ec32fb9aa9b --- /dev/null +++ b/examples/YOLOv8-LibTorch-CPP-Inference/README.md @@ -0,0 +1,35 @@ +# YOLOv8 LibTorch Inference C++ + +This example demonstrates how to perform inference using YOLOv8 models in C++ with LibTorch API. + +## Dependencies + +| Dependency | Version | +| ------------ | -------- | +| OpenCV | >=4.0.0 | +| C++ Standard | >=17 | +| Cmake | >=3.18 | +| Libtorch | >=1.12.1 | + +## Usage + +```bash +git clone ultralytics +cd ultralytics +pip install . +cd examples/YOLOv8-LibTorch-CPP-Inference + +mkdir build +cd build +cmake .. 
+make +./yolov8_libtorch_inference +``` + +## Exporting YOLOv8 + +To export YOLOv8 models: + +```bash +yolo export model=yolov8s.pt imgsz=640 format=torchscript +``` diff --git a/examples/YOLOv8-LibTorch-CPP-Inference/main.cc b/examples/YOLOv8-LibTorch-CPP-Inference/main.cc new file mode 100644 index 0000000000000000000000000000000000000000..0937b56828e9ca33dd67ebaca7d88ed892881808 --- /dev/null +++ b/examples/YOLOv8-LibTorch-CPP-Inference/main.cc @@ -0,0 +1,260 @@ +#include + +#include +#include +#include +#include +#include + +using torch::indexing::Slice; +using torch::indexing::None; + + +float generate_scale(cv::Mat& image, const std::vector& target_size) { + int origin_w = image.cols; + int origin_h = image.rows; + + int target_h = target_size[0]; + int target_w = target_size[1]; + + float ratio_h = static_cast(target_h) / static_cast(origin_h); + float ratio_w = static_cast(target_w) / static_cast(origin_w); + float resize_scale = std::min(ratio_h, ratio_w); + return resize_scale; +} + + +float letterbox(cv::Mat &input_image, cv::Mat &output_image, const std::vector &target_size) { + if (input_image.cols == target_size[1] && input_image.rows == target_size[0]) { + if (input_image.data == output_image.data) { + return 1.; + } else { + output_image = input_image.clone(); + return 1.; + } + } + + float resize_scale = generate_scale(input_image, target_size); + int new_shape_w = std::round(input_image.cols * resize_scale); + int new_shape_h = std::round(input_image.rows * resize_scale); + float padw = (target_size[1] - new_shape_w) / 2.; + float padh = (target_size[0] - new_shape_h) / 2.; + + int top = std::round(padh - 0.1); + int bottom = std::round(padh + 0.1); + int left = std::round(padw - 0.1); + int right = std::round(padw + 0.1); + + cv::resize(input_image, output_image, + cv::Size(new_shape_w, new_shape_h), + 0, 0, cv::INTER_AREA); + + cv::copyMakeBorder(output_image, output_image, top, bottom, left, right, + cv::BORDER_CONSTANT, cv::Scalar(114.)); + return resize_scale; +} + + +torch::Tensor xyxy2xywh(const torch::Tensor& x) { + auto y = torch::empty_like(x); + y.index_put_({"...", 0}, (x.index({"...", 0}) + x.index({"...", 2})).div(2)); + y.index_put_({"...", 1}, (x.index({"...", 1}) + x.index({"...", 3})).div(2)); + y.index_put_({"...", 2}, x.index({"...", 2}) - x.index({"...", 0})); + y.index_put_({"...", 3}, x.index({"...", 3}) - x.index({"...", 1})); + return y; +} + + +torch::Tensor xywh2xyxy(const torch::Tensor& x) { + auto y = torch::empty_like(x); + auto dw = x.index({"...", 2}).div(2); + auto dh = x.index({"...", 3}).div(2); + y.index_put_({"...", 0}, x.index({"...", 0}) - dw); + y.index_put_({"...", 1}, x.index({"...", 1}) - dh); + y.index_put_({"...", 2}, x.index({"...", 0}) + dw); + y.index_put_({"...", 3}, x.index({"...", 1}) + dh); + return y; +} + + +// Reference: https://github.com/pytorch/vision/blob/main/torchvision/csrc/ops/cpu/nms_kernel.cpp +torch::Tensor nms(const torch::Tensor& bboxes, const torch::Tensor& scores, float iou_threshold) { + if (bboxes.numel() == 0) + return torch::empty({0}, bboxes.options().dtype(torch::kLong)); + + auto x1_t = bboxes.select(1, 0).contiguous(); + auto y1_t = bboxes.select(1, 1).contiguous(); + auto x2_t = bboxes.select(1, 2).contiguous(); + auto y2_t = bboxes.select(1, 3).contiguous(); + + torch::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t); + + auto order_t = std::get<1>( + scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true)); + + auto ndets = bboxes.size(0); + torch::Tensor suppressed_t = 
torch::zeros({ndets}, bboxes.options().dtype(torch::kByte)); + torch::Tensor keep_t = torch::zeros({ndets}, bboxes.options().dtype(torch::kLong)); + + auto suppressed = suppressed_t.data_ptr(); + auto keep = keep_t.data_ptr(); + auto order = order_t.data_ptr(); + auto x1 = x1_t.data_ptr(); + auto y1 = y1_t.data_ptr(); + auto x2 = x2_t.data_ptr(); + auto y2 = y2_t.data_ptr(); + auto areas = areas_t.data_ptr(); + + int64_t num_to_keep = 0; + + for (int64_t _i = 0; _i < ndets; _i++) { + auto i = order[_i]; + if (suppressed[i] == 1) + continue; + keep[num_to_keep++] = i; + auto ix1 = x1[i]; + auto iy1 = y1[i]; + auto ix2 = x2[i]; + auto iy2 = y2[i]; + auto iarea = areas[i]; + + for (int64_t _j = _i + 1; _j < ndets; _j++) { + auto j = order[_j]; + if (suppressed[j] == 1) + continue; + auto xx1 = std::max(ix1, x1[j]); + auto yy1 = std::max(iy1, y1[j]); + auto xx2 = std::min(ix2, x2[j]); + auto yy2 = std::min(iy2, y2[j]); + + auto w = std::max(static_cast(0), xx2 - xx1); + auto h = std::max(static_cast(0), yy2 - yy1); + auto inter = w * h; + auto ovr = inter / (iarea + areas[j] - inter); + if (ovr > iou_threshold) + suppressed[j] = 1; + } + } + return keep_t.narrow(0, 0, num_to_keep); +} + + +torch::Tensor non_max_suppression(torch::Tensor& prediction, float conf_thres = 0.25, float iou_thres = 0.45, int max_det = 300) { + auto bs = prediction.size(0); + auto nc = prediction.size(1) - 4; + auto nm = prediction.size(1) - nc - 4; + auto mi = 4 + nc; + auto xc = prediction.index({Slice(), Slice(4, mi)}).amax(1) > conf_thres; + + prediction = prediction.transpose(-1, -2); + prediction.index_put_({"...", Slice({None, 4})}, xywh2xyxy(prediction.index({"...", Slice(None, 4)}))); + + std::vector output; + for (int i = 0; i < bs; i++) { + output.push_back(torch::zeros({0, 6 + nm}, prediction.device())); + } + + for (int xi = 0; xi < prediction.size(0); xi++) { + auto x = prediction[xi]; + x = x.index({xc[xi]}); + auto x_split = x.split({4, nc, nm}, 1); + auto box = x_split[0], cls = x_split[1], mask = x_split[2]; + auto [conf, j] = cls.max(1, true); + x = torch::cat({box, conf, j.toType(torch::kFloat), mask}, 1); + x = x.index({conf.view(-1) > conf_thres}); + int n = x.size(0); + if (!n) { continue; } + + // NMS + auto c = x.index({Slice(), Slice{5, 6}}) * 7680; + auto boxes = x.index({Slice(), Slice(None, 4)}) + c; + auto scores = x.index({Slice(), 4}); + auto i = nms(boxes, scores, iou_thres); + i = i.index({Slice(None, max_det)}); + output[xi] = x.index({i}); + } + + return torch::stack(output); +} + + +torch::Tensor clip_boxes(torch::Tensor& boxes, const std::vector& shape) { + boxes.index_put_({"...", 0}, boxes.index({"...", 0}).clamp(0, shape[1])); + boxes.index_put_({"...", 1}, boxes.index({"...", 1}).clamp(0, shape[0])); + boxes.index_put_({"...", 2}, boxes.index({"...", 2}).clamp(0, shape[1])); + boxes.index_put_({"...", 3}, boxes.index({"...", 3}).clamp(0, shape[0])); + return boxes; +} + + +torch::Tensor scale_boxes(const std::vector& img1_shape, torch::Tensor& boxes, const std::vector& img0_shape) { + auto gain = (std::min)((float)img1_shape[0] / img0_shape[0], (float)img1_shape[1] / img0_shape[1]); + auto pad0 = std::round((float)(img1_shape[1] - img0_shape[1] * gain) / 2. - 0.1); + auto pad1 = std::round((float)(img1_shape[0] - img0_shape[0] * gain) / 2. 
- 0.1); + + boxes.index_put_({"...", 0}, boxes.index({"...", 0}) - pad0); + boxes.index_put_({"...", 2}, boxes.index({"...", 2}) - pad0); + boxes.index_put_({"...", 1}, boxes.index({"...", 1}) - pad1); + boxes.index_put_({"...", 3}, boxes.index({"...", 3}) - pad1); + boxes.index_put_({"...", Slice(None, 4)}, boxes.index({"...", Slice(None, 4)}).div(gain)); + return boxes; +} + + +int main() { + // Device + torch::Device device(torch::cuda::is_available() ? torch::kCUDA :torch::kCPU); + + // Note that in this example the classes are hard-coded + std::vector classes {"person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", "fire hydrant", + "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", + "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", + "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", + "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", + "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", + "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"}; + + try { + // Load the model (e.g. yolov8s.torchscript) + std::string model_path = "/path/to/yolov8s.torchscript"; + torch::jit::script::Module yolo_model; + yolo_model = torch::jit::load(model_path); + yolo_model.eval(); + yolo_model.to(device, torch::kFloat32); + + // Load image and preprocess + cv::Mat image = cv::imread("/path/to/bus.jpg"); + cv::Mat input_image; + letterbox(image, input_image, {640, 640}); + cv::cvtColor(input_image, input_image, cv::COLOR_BGR2RGB); + + torch::Tensor image_tensor = torch::from_blob(input_image.data, {input_image.rows, input_image.cols, 3}, torch::kByte).to(device); + image_tensor = image_tensor.toType(torch::kFloat32).div(255); + image_tensor = image_tensor.permute({2, 0, 1}); + image_tensor = image_tensor.unsqueeze(0); + std::vector inputs {image_tensor}; + + // Inference + torch::Tensor output = yolo_model.forward(inputs).toTensor().cpu(); + + // NMS + auto keep = non_max_suppression(output)[0]; + auto boxes = keep.index({Slice(), Slice(None, 4)}); + keep.index_put_({Slice(), Slice(None, 4)}, scale_boxes({input_image.rows, input_image.cols}, boxes, {image.rows, image.cols})); + + // Show the results + for (int i = 0; i < keep.size(0); i++) { + int x1 = keep[i][0].item().toFloat(); + int y1 = keep[i][1].item().toFloat(); + int x2 = keep[i][2].item().toFloat(); + int y2 = keep[i][3].item().toFloat(); + float conf = keep[i][4].item().toFloat(); + int cls = keep[i][5].item().toInt(); + std::cout << "Rect: [" << x1 << "," << y1 << "," << x2 << "," << y2 << "] Conf: " << conf << " Class: " << classes[cls] << std::endl; + } + } catch (const c10::Error& e) { + std::cout << e.msg() << std::endl; + } + + return 0; +} diff --git a/examples/YOLOv8-ONNXRuntime-CPP/CMakeLists.txt b/examples/YOLOv8-ONNXRuntime-CPP/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a99662a51194c140c6c10a613060bf86dda82072 --- /dev/null +++ b/examples/YOLOv8-ONNXRuntime-CPP/CMakeLists.txt @@ -0,0 +1,99 @@ +cmake_minimum_required(VERSION 3.5) + +set(PROJECT_NAME Yolov8OnnxRuntimeCPPInference) +project(${PROJECT_NAME} VERSION 0.0.1 
LANGUAGES CXX) + + +# -------------- Support C++17 for using filesystem ------------------# +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS ON) +set(CMAKE_INCLUDE_CURRENT_DIR ON) + + +# -------------- OpenCV ------------------# +find_package(OpenCV REQUIRED) +include_directories(${OpenCV_INCLUDE_DIRS}) + + +# -------------- Compile CUDA for FP16 inference if needed ------------------# +option(USE_CUDA "Enable CUDA support" ON) +if (NOT APPLE AND USE_CUDA) + find_package(CUDA REQUIRED) + include_directories(${CUDA_INCLUDE_DIRS}) + add_definitions(-DUSE_CUDA) +else () + set(USE_CUDA OFF) +endif () + +# -------------- ONNXRUNTIME ------------------# + +# Set ONNXRUNTIME_VERSION +set(ONNXRUNTIME_VERSION 1.15.1) + +if (WIN32) + if (USE_CUDA) + set(ONNXRUNTIME_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime-win-x64-gpu-${ONNXRUNTIME_VERSION}") + else () + set(ONNXRUNTIME_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime-win-x64-${ONNXRUNTIME_VERSION}") + endif () +elseif (LINUX) + if (USE_CUDA) + set(ONNXRUNTIME_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime-linux-x64-gpu-${ONNXRUNTIME_VERSION}") + else () + set(ONNXRUNTIME_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime-linux-x64-${ONNXRUNTIME_VERSION}") + endif () +elseif (APPLE) + set(ONNXRUNTIME_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime-osx-arm64-${ONNXRUNTIME_VERSION}") + # Apple X64 binary + # set(ONNXRUNTIME_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime-osx-x64-${ONNXRUNTIME_VERSION}") + # Apple Universal binary + # set(ONNXRUNTIME_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime-osx-universal2-${ONNXRUNTIME_VERSION}") +else () + message(SEND_ERROR "Variable ONNXRUNTIME_ROOT is not set properly. Please check if your cmake project \ + is not compiled with `-D WIN32=TRUE`, `-D LINUX=TRUE`, or `-D APPLE=TRUE`!") +endif () + +include_directories(${PROJECT_NAME} ${ONNXRUNTIME_ROOT}/include) + +set(PROJECT_SOURCES + main.cpp + inference.h + inference.cpp +) + +add_executable(${PROJECT_NAME} ${PROJECT_SOURCES}) + +if (WIN32) + target_link_libraries(${PROJECT_NAME} ${OpenCV_LIBS} ${ONNXRUNTIME_ROOT}/lib/onnxruntime.lib) + if (USE_CUDA) + target_link_libraries(${PROJECT_NAME} ${CUDA_LIBRARIES}) + endif () +elseif (LINUX) + target_link_libraries(${PROJECT_NAME} ${OpenCV_LIBS} ${ONNXRUNTIME_ROOT}/lib/libonnxruntime.so) + if (USE_CUDA) + target_link_libraries(${PROJECT_NAME} ${CUDA_LIBRARIES}) + endif () +elseif (APPLE) + target_link_libraries(${PROJECT_NAME} ${OpenCV_LIBS} ${ONNXRUNTIME_ROOT}/lib/libonnxruntime.dylib) +endif () + +# For windows system, copy onnxruntime.dll to the same folder of the executable file +if (WIN32) + add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${ONNXRUNTIME_ROOT}/lib/onnxruntime.dll" + $) +endif () + +# Download https://raw.githubusercontent.com/ultralytics/ultralytics/main/ultralytics/cfg/datasets/coco.yaml +# and put it in the same folder of the executable file +configure_file(coco.yaml ${CMAKE_CURRENT_BINARY_DIR}/coco.yaml COPYONLY) + +# Copy yolov8n.onnx file to the same folder of the executable file +configure_file(yolov8n.onnx ${CMAKE_CURRENT_BINARY_DIR}/yolov8n.onnx COPYONLY) + +# Create folder name images in the same folder of the executable file +add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD + COMMAND ${CMAKE_COMMAND} -E make_directory ${CMAKE_CURRENT_BINARY_DIR}/images +) diff --git a/examples/YOLOv8-ONNXRuntime-CPP/README.md b/examples/YOLOv8-ONNXRuntime-CPP/README.md new file mode 100644 index 
0000000000000000000000000000000000000000..ed9d9364292da365e51ddfb3f630c743002099b0 --- /dev/null +++ b/examples/YOLOv8-ONNXRuntime-CPP/README.md @@ -0,0 +1,120 @@ +# YOLOv8 OnnxRuntime C++ + +C++ Onnx-runtime + +This example demonstrates how to perform inference using YOLOv8 in C++ with ONNX Runtime and OpenCV's API. + +## Benefits ✨ + +- Friendly for deployment in the industrial sector. +- Faster than OpenCV's DNN inference on both CPU and GPU. +- Supports FP32 and FP16 CUDA acceleration. + +## Note ☕ + +1. Benefit for Ultralytics' latest release, a `Transpose` op is added to the YOLOv8 model, while make v8 and v5 has the same output shape. Therefore, you can run inference with YOLOv5/v7/v8 via this project. + +## Exporting YOLOv8 Models 📦 + +To export YOLOv8 models, use the following Python script: + +```python +from ultralytics import YOLO + +# Load a YOLOv8 model +model = YOLO("yolov8n.pt") + +# Export the model +model.export(format="onnx", opset=12, simplify=True, dynamic=False, imgsz=640) +``` + +Alternatively, you can use the following command for exporting the model in the terminal + +```bash +yolo export model=yolov8n.pt opset=12 simplify=True dynamic=False format=onnx imgsz=640,640 +``` + +## Exporting YOLOv8 FP16 Models 📦 + +```python +import onnx +from onnxconverter_common import float16 + +model = onnx.load(R"YOUR_ONNX_PATH") +model_fp16 = float16.convert_float_to_float16(model) +onnx.save(model_fp16, R"YOUR_FP16_ONNX_PATH") +``` + +## Download COCO.yaml file 📂 + +In order to run example, you also need to download coco.yaml. You can download the file manually from [here](https://raw.githubusercontent.com/ultralytics/ultralytics/main/ultralytics/cfg/datasets/coco.yaml) + +## Dependencies ⚙️ + +| Dependency | Version | +| -------------------------------- | ------------- | +| Onnxruntime(linux,windows,macos) | >=1.14.1 | +| OpenCV | >=4.0.0 | +| C++ Standard | >=17 | +| Cmake | >=3.5 | +| Cuda (Optional) | >=11.4 \<12.0 | +| cuDNN (Cuda required) | =8 | + +Note: The dependency on C++17 is due to the usage of the C++17 filesystem feature. + +Note (2): Due to ONNX Runtime, we need to use CUDA 11 and cuDNN 8. Keep in mind that this requirement might change in the future. + +## Build 🛠️ + +1. Clone the repository to your local machine. + +2. Navigate to the root directory of the repository. + +3. Create a build directory and navigate to it: + + ```console + mkdir build && cd build + ``` + +4. Run CMake to generate the build files: + + ```console + cmake .. + ``` + + **Notice**: + + If you encounter an error indicating that the `ONNXRUNTIME_ROOT` variable is not set correctly, you can resolve this by building the project using the appropriate command tailored to your system. + + ```console + # compiled in a win32 system + cmake -D WIN32=TRUE .. + # compiled in a linux system + cmake -D LINUX=TRUE .. + # compiled in an apple system + cmake -D APPLE=TRUE .. + ``` + +5. Build the project: + + ```console + make + ``` + +6. The built executable should now be located in the `build` directory. 
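+
+7. (Optional) Before launching the executable, you can confirm from Python that the exported `yolov8n.onnx` loads and reports the expected input shape. This is only a quick sanity check using the `onnxruntime` Python package, run from the folder containing the model:
+
+   ```python
+   import onnxruntime as ort
+
+   # Load the exported model on CPU and inspect its input signature.
+   session = ort.InferenceSession("yolov8n.onnx", providers=["CPUExecutionProvider"])
+   inp = session.get_inputs()[0]
+   print(inp.name, inp.shape, inp.type)  # expect something like: images [1, 3, 640, 640] tensor(float)
+   ```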
+ +## Usage 🚀 + +```c++ +//change your param as you like +//Pay attention to your device and the onnx model type(fp32 or fp16) +DL_INIT_PARAM params; +params.rectConfidenceThreshold = 0.1; +params.iouThreshold = 0.5; +params.modelPath = "yolov8n.onnx"; +params.imgSize = { 640, 640 }; +params.cudaEnable = true; +params.modelType = YOLO_DETECT_V8; +yoloDetector->CreateSession(params); +Detector(yoloDetector); +``` diff --git a/examples/YOLOv8-ONNXRuntime-CPP/inference.cpp b/examples/YOLOv8-ONNXRuntime-CPP/inference.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a65391f5d7d116befac6f1cb2151891cca5a1c5d --- /dev/null +++ b/examples/YOLOv8-ONNXRuntime-CPP/inference.cpp @@ -0,0 +1,375 @@ +#include "inference.h" +#include + +#define benchmark +#define min(a,b) (((a) < (b)) ? (a) : (b)) +YOLO_V8::YOLO_V8() { + +} + + +YOLO_V8::~YOLO_V8() { + delete session; +} + +#ifdef USE_CUDA +namespace Ort +{ + template<> + struct TypeToTensorType { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; }; +} +#endif + + +template +char* BlobFromImage(cv::Mat& iImg, T& iBlob) { + int channels = iImg.channels(); + int imgHeight = iImg.rows; + int imgWidth = iImg.cols; + + for (int c = 0; c < channels; c++) + { + for (int h = 0; h < imgHeight; h++) + { + for (int w = 0; w < imgWidth; w++) + { + iBlob[c * imgWidth * imgHeight + h * imgWidth + w] = typename std::remove_pointer::type( + (iImg.at(h, w)[c]) / 255.0f); + } + } + } + return RET_OK; +} + + +char* YOLO_V8::PreProcess(cv::Mat& iImg, std::vector iImgSize, cv::Mat& oImg) +{ + if (iImg.channels() == 3) + { + oImg = iImg.clone(); + cv::cvtColor(oImg, oImg, cv::COLOR_BGR2RGB); + } + else + { + cv::cvtColor(iImg, oImg, cv::COLOR_GRAY2RGB); + } + + switch (modelType) + { + case YOLO_DETECT_V8: + case YOLO_POSE: + case YOLO_DETECT_V8_HALF: + case YOLO_POSE_V8_HALF://LetterBox + { + if (iImg.cols >= iImg.rows) + { + resizeScales = iImg.cols / (float)iImgSize.at(0); + cv::resize(oImg, oImg, cv::Size(iImgSize.at(0), int(iImg.rows / resizeScales))); + } + else + { + resizeScales = iImg.rows / (float)iImgSize.at(0); + cv::resize(oImg, oImg, cv::Size(int(iImg.cols / resizeScales), iImgSize.at(1))); + } + cv::Mat tempImg = cv::Mat::zeros(iImgSize.at(0), iImgSize.at(1), CV_8UC3); + oImg.copyTo(tempImg(cv::Rect(0, 0, oImg.cols, oImg.rows))); + oImg = tempImg; + break; + } + case YOLO_CLS://CenterCrop + { + int h = iImg.rows; + int w = iImg.cols; + int m = min(h, w); + int top = (h - m) / 2; + int left = (w - m) / 2; + cv::resize(oImg(cv::Rect(left, top, m, m)), oImg, cv::Size(iImgSize.at(0), iImgSize.at(1))); + break; + } + } + return RET_OK; +} + + +char* YOLO_V8::CreateSession(DL_INIT_PARAM& iParams) { + char* Ret = RET_OK; + std::regex pattern("[\u4e00-\u9fa5]"); + bool result = std::regex_search(iParams.modelPath, pattern); + if (result) + { + Ret = "[YOLO_V8]:Your model path is error.Change your model path without chinese characters."; + std::cout << Ret << std::endl; + return Ret; + } + try + { + rectConfidenceThreshold = iParams.rectConfidenceThreshold; + iouThreshold = iParams.iouThreshold; + imgSize = iParams.imgSize; + modelType = iParams.modelType; + env = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "Yolo"); + Ort::SessionOptions sessionOption; + if (iParams.cudaEnable) + { + cudaEnable = iParams.cudaEnable; + OrtCUDAProviderOptions cudaOption; + cudaOption.device_id = 0; + sessionOption.AppendExecutionProvider_CUDA(cudaOption); + } + 
sessionOption.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); + sessionOption.SetIntraOpNumThreads(iParams.intraOpNumThreads); + sessionOption.SetLogSeverityLevel(iParams.logSeverityLevel); + +#ifdef _WIN32 + int ModelPathSize = MultiByteToWideChar(CP_UTF8, 0, iParams.modelPath.c_str(), static_cast(iParams.modelPath.length()), nullptr, 0); + wchar_t* wide_cstr = new wchar_t[ModelPathSize + 1]; + MultiByteToWideChar(CP_UTF8, 0, iParams.modelPath.c_str(), static_cast(iParams.modelPath.length()), wide_cstr, ModelPathSize); + wide_cstr[ModelPathSize] = L'\0'; + const wchar_t* modelPath = wide_cstr; +#else + const char* modelPath = iParams.modelPath.c_str(); +#endif // _WIN32 + + session = new Ort::Session(env, modelPath, sessionOption); + Ort::AllocatorWithDefaultOptions allocator; + size_t inputNodesNum = session->GetInputCount(); + for (size_t i = 0; i < inputNodesNum; i++) + { + Ort::AllocatedStringPtr input_node_name = session->GetInputNameAllocated(i, allocator); + char* temp_buf = new char[50]; + strcpy(temp_buf, input_node_name.get()); + inputNodeNames.push_back(temp_buf); + } + size_t OutputNodesNum = session->GetOutputCount(); + for (size_t i = 0; i < OutputNodesNum; i++) + { + Ort::AllocatedStringPtr output_node_name = session->GetOutputNameAllocated(i, allocator); + char* temp_buf = new char[10]; + strcpy(temp_buf, output_node_name.get()); + outputNodeNames.push_back(temp_buf); + } + options = Ort::RunOptions{ nullptr }; + WarmUpSession(); + return RET_OK; + } + catch (const std::exception& e) + { + const char* str1 = "[YOLO_V8]:"; + const char* str2 = e.what(); + std::string result = std::string(str1) + std::string(str2); + char* merged = new char[result.length() + 1]; + std::strcpy(merged, result.c_str()); + std::cout << merged << std::endl; + delete[] merged; + return "[YOLO_V8]:Create session failed."; + } + +} + + +char* YOLO_V8::RunSession(cv::Mat& iImg, std::vector& oResult) { +#ifdef benchmark + clock_t starttime_1 = clock(); +#endif // benchmark + + char* Ret = RET_OK; + cv::Mat processedImg; + PreProcess(iImg, imgSize, processedImg); + if (modelType < 4) + { + float* blob = new float[processedImg.total() * 3]; + BlobFromImage(processedImg, blob); + std::vector inputNodeDims = { 1, 3, imgSize.at(0), imgSize.at(1) }; + TensorProcess(starttime_1, iImg, blob, inputNodeDims, oResult); + } + else + { +#ifdef USE_CUDA + half* blob = new half[processedImg.total() * 3]; + BlobFromImage(processedImg, blob); + std::vector inputNodeDims = { 1,3,imgSize.at(0),imgSize.at(1) }; + TensorProcess(starttime_1, iImg, blob, inputNodeDims, oResult); +#endif + } + + return Ret; +} + + +template +char* YOLO_V8::TensorProcess(clock_t& starttime_1, cv::Mat& iImg, N& blob, std::vector& inputNodeDims, + std::vector& oResult) { + Ort::Value inputTensor = Ort::Value::CreateTensor::type>( + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU), blob, 3 * imgSize.at(0) * imgSize.at(1), + inputNodeDims.data(), inputNodeDims.size()); +#ifdef benchmark + clock_t starttime_2 = clock(); +#endif // benchmark + auto outputTensor = session->Run(options, inputNodeNames.data(), &inputTensor, 1, outputNodeNames.data(), + outputNodeNames.size()); +#ifdef benchmark + clock_t starttime_3 = clock(); +#endif // benchmark + + Ort::TypeInfo typeInfo = outputTensor.front().GetTypeInfo(); + auto tensor_info = typeInfo.GetTensorTypeAndShapeInfo(); + std::vector outputNodeDims = tensor_info.GetShape(); + auto output = outputTensor.front().GetTensorMutableData::type>(); + delete[] blob; + switch 
(modelType) + { + case YOLO_DETECT_V8: + case YOLO_DETECT_V8_HALF: + { + int signalResultNum = outputNodeDims[1];//84 + int strideNum = outputNodeDims[2];//8400 + std::vector class_ids; + std::vector confidences; + std::vector boxes; + cv::Mat rawData; + if (modelType == YOLO_DETECT_V8) + { + // FP32 + rawData = cv::Mat(signalResultNum, strideNum, CV_32F, output); + } + else + { + // FP16 + rawData = cv::Mat(signalResultNum, strideNum, CV_16F, output); + rawData.convertTo(rawData, CV_32F); + } + // Note: + // ultralytics add transpose operator to the output of yolov8 model.which make yolov8/v5/v7 has same shape + // https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt + rawData = rawData.t(); + + float* data = (float*)rawData.data; + + for (int i = 0; i < strideNum; ++i) + { + float* classesScores = data + 4; + cv::Mat scores(1, this->classes.size(), CV_32FC1, classesScores); + cv::Point class_id; + double maxClassScore; + cv::minMaxLoc(scores, 0, &maxClassScore, 0, &class_id); + if (maxClassScore > rectConfidenceThreshold) + { + confidences.push_back(maxClassScore); + class_ids.push_back(class_id.x); + float x = data[0]; + float y = data[1]; + float w = data[2]; + float h = data[3]; + + int left = int((x - 0.5 * w) * resizeScales); + int top = int((y - 0.5 * h) * resizeScales); + + int width = int(w * resizeScales); + int height = int(h * resizeScales); + + boxes.push_back(cv::Rect(left, top, width, height)); + } + data += signalResultNum; + } + std::vector nmsResult; + cv::dnn::NMSBoxes(boxes, confidences, rectConfidenceThreshold, iouThreshold, nmsResult); + for (int i = 0; i < nmsResult.size(); ++i) + { + int idx = nmsResult[i]; + DL_RESULT result; + result.classId = class_ids[idx]; + result.confidence = confidences[idx]; + result.box = boxes[idx]; + oResult.push_back(result); + } + +#ifdef benchmark + clock_t starttime_4 = clock(); + double pre_process_time = (double)(starttime_2 - starttime_1) / CLOCKS_PER_SEC * 1000; + double process_time = (double)(starttime_3 - starttime_2) / CLOCKS_PER_SEC * 1000; + double post_process_time = (double)(starttime_4 - starttime_3) / CLOCKS_PER_SEC * 1000; + if (cudaEnable) + { + std::cout << "[YOLO_V8(CUDA)]: " << pre_process_time << "ms pre-process, " << process_time << "ms inference, " << post_process_time << "ms post-process." << std::endl; + } + else + { + std::cout << "[YOLO_V8(CPU)]: " << pre_process_time << "ms pre-process, " << process_time << "ms inference, " << post_process_time << "ms post-process." << std::endl; + } +#endif // benchmark + + break; + } + case YOLO_CLS: + case YOLO_CLS_HALF: + { + cv::Mat rawData; + if (modelType == YOLO_CLS) { + // FP32 + rawData = cv::Mat(1, this->classes.size(), CV_32F, output); + } else { + // FP16 + rawData = cv::Mat(1, this->classes.size(), CV_16F, output); + rawData.convertTo(rawData, CV_32F); + } + float *data = (float *) rawData.data; + + DL_RESULT result; + for (int i = 0; i < this->classes.size(); i++) + { + result.classId = i; + result.confidence = data[i]; + oResult.push_back(result); + } + break; + } + default: + std::cout << "[YOLO_V8]: " << "Not support model type." 
<< std::endl; + } + return RET_OK; + +} + + +char* YOLO_V8::WarmUpSession() { + clock_t starttime_1 = clock(); + cv::Mat iImg = cv::Mat(cv::Size(imgSize.at(0), imgSize.at(1)), CV_8UC3); + cv::Mat processedImg; + PreProcess(iImg, imgSize, processedImg); + if (modelType < 4) + { + float* blob = new float[iImg.total() * 3]; + BlobFromImage(processedImg, blob); + std::vector YOLO_input_node_dims = { 1, 3, imgSize.at(0), imgSize.at(1) }; + Ort::Value input_tensor = Ort::Value::CreateTensor( + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU), blob, 3 * imgSize.at(0) * imgSize.at(1), + YOLO_input_node_dims.data(), YOLO_input_node_dims.size()); + auto output_tensors = session->Run(options, inputNodeNames.data(), &input_tensor, 1, outputNodeNames.data(), + outputNodeNames.size()); + delete[] blob; + clock_t starttime_4 = clock(); + double post_process_time = (double)(starttime_4 - starttime_1) / CLOCKS_PER_SEC * 1000; + if (cudaEnable) + { + std::cout << "[YOLO_V8(CUDA)]: " << "Cuda warm-up cost " << post_process_time << " ms. " << std::endl; + } + } + else + { +#ifdef USE_CUDA + half* blob = new half[iImg.total() * 3]; + BlobFromImage(processedImg, blob); + std::vector YOLO_input_node_dims = { 1,3,imgSize.at(0),imgSize.at(1) }; + Ort::Value input_tensor = Ort::Value::CreateTensor(Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU), blob, 3 * imgSize.at(0) * imgSize.at(1), YOLO_input_node_dims.data(), YOLO_input_node_dims.size()); + auto output_tensors = session->Run(options, inputNodeNames.data(), &input_tensor, 1, outputNodeNames.data(), outputNodeNames.size()); + delete[] blob; + clock_t starttime_4 = clock(); + double post_process_time = (double)(starttime_4 - starttime_1) / CLOCKS_PER_SEC * 1000; + if (cudaEnable) + { + std::cout << "[YOLO_V8(CUDA)]: " << "Cuda warm-up cost " << post_process_time << " ms. 
" << std::endl; + } +#endif + } + return RET_OK; +} diff --git a/examples/YOLOv8-ONNXRuntime-CPP/inference.h b/examples/YOLOv8-ONNXRuntime-CPP/inference.h new file mode 100644 index 0000000000000000000000000000000000000000..3a9d029ccf434e68fb482ce12d83bdba2405dd1b --- /dev/null +++ b/examples/YOLOv8-ONNXRuntime-CPP/inference.h @@ -0,0 +1,94 @@ +#pragma once + +#define RET_OK nullptr + +#ifdef _WIN32 +#include +#include +#include +#endif + +#include +#include +#include +#include +#include "onnxruntime_cxx_api.h" + +#ifdef USE_CUDA +#include +#endif + + +enum MODEL_TYPE +{ + //FLOAT32 MODEL + YOLO_DETECT_V8 = 1, + YOLO_POSE = 2, + YOLO_CLS = 3, + + //FLOAT16 MODEL + YOLO_DETECT_V8_HALF = 4, + YOLO_POSE_V8_HALF = 5, + YOLO_CLS_HALF = 6 +}; + + +typedef struct _DL_INIT_PARAM +{ + std::string modelPath; + MODEL_TYPE modelType = YOLO_DETECT_V8; + std::vector imgSize = { 640, 640 }; + float rectConfidenceThreshold = 0.6; + float iouThreshold = 0.5; + int keyPointsNum = 2;//Note:kpt number for pose + bool cudaEnable = false; + int logSeverityLevel = 3; + int intraOpNumThreads = 1; +} DL_INIT_PARAM; + + +typedef struct _DL_RESULT +{ + int classId; + float confidence; + cv::Rect box; + std::vector keyPoints; +} DL_RESULT; + + +class YOLO_V8 +{ +public: + YOLO_V8(); + + ~YOLO_V8(); + +public: + char* CreateSession(DL_INIT_PARAM& iParams); + + char* RunSession(cv::Mat& iImg, std::vector& oResult); + + char* WarmUpSession(); + + template + char* TensorProcess(clock_t& starttime_1, cv::Mat& iImg, N& blob, std::vector& inputNodeDims, + std::vector& oResult); + + char* PreProcess(cv::Mat& iImg, std::vector iImgSize, cv::Mat& oImg); + + std::vector classes{}; + +private: + Ort::Env env; + Ort::Session* session; + bool cudaEnable; + Ort::RunOptions options; + std::vector inputNodeNames; + std::vector outputNodeNames; + + MODEL_TYPE modelType; + std::vector imgSize; + float rectConfidenceThreshold; + float iouThreshold; + float resizeScales;//letterbox scale +}; diff --git a/examples/YOLOv8-ONNXRuntime-CPP/main.cpp b/examples/YOLOv8-ONNXRuntime-CPP/main.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6e4ef1ddffbf3d0527dd62928596d4d12f81018c --- /dev/null +++ b/examples/YOLOv8-ONNXRuntime-CPP/main.cpp @@ -0,0 +1,193 @@ +#include +#include +#include "inference.h" +#include +#include +#include + +void Detector(YOLO_V8*& p) { + std::filesystem::path current_path = std::filesystem::current_path(); + std::filesystem::path imgs_path = current_path / "images"; + for (auto& i : std::filesystem::directory_iterator(imgs_path)) + { + if (i.path().extension() == ".jpg" || i.path().extension() == ".png" || i.path().extension() == ".jpeg") + { + std::string img_path = i.path().string(); + cv::Mat img = cv::imread(img_path); + std::vector res; + p->RunSession(img, res); + + for (auto& re : res) + { + cv::RNG rng(cv::getTickCount()); + cv::Scalar color(rng.uniform(0, 256), rng.uniform(0, 256), rng.uniform(0, 256)); + + cv::rectangle(img, re.box, color, 3); + + float confidence = floor(100 * re.confidence) / 100; + std::cout << std::fixed << std::setprecision(2); + std::string label = p->classes[re.classId] + " " + + std::to_string(confidence).substr(0, std::to_string(confidence).size() - 4); + + cv::rectangle( + img, + cv::Point(re.box.x, re.box.y - 25), + cv::Point(re.box.x + label.length() * 15, re.box.y), + color, + cv::FILLED + ); + + cv::putText( + img, + label, + cv::Point(re.box.x, re.box.y - 5), + cv::FONT_HERSHEY_SIMPLEX, + 0.75, + cv::Scalar(0, 0, 0), + 2 + ); + + + } + std::cout << "Press 
any key to exit" << std::endl; + cv::imshow("Result of Detection", img); + cv::waitKey(0); + cv::destroyAllWindows(); + } + } +} + + +void Classifier(YOLO_V8*& p) +{ + std::filesystem::path current_path = std::filesystem::current_path(); + std::filesystem::path imgs_path = current_path;// / "images" + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dis(0, 255); + for (auto& i : std::filesystem::directory_iterator(imgs_path)) + { + if (i.path().extension() == ".jpg" || i.path().extension() == ".png") + { + std::string img_path = i.path().string(); + //std::cout << img_path << std::endl; + cv::Mat img = cv::imread(img_path); + std::vector res; + char* ret = p->RunSession(img, res); + + float positionY = 50; + for (int i = 0; i < res.size(); i++) + { + int r = dis(gen); + int g = dis(gen); + int b = dis(gen); + cv::putText(img, std::to_string(i) + ":", cv::Point(10, positionY), cv::FONT_HERSHEY_SIMPLEX, 1, cv::Scalar(b, g, r), 2); + cv::putText(img, std::to_string(res.at(i).confidence), cv::Point(70, positionY), cv::FONT_HERSHEY_SIMPLEX, 1, cv::Scalar(b, g, r), 2); + positionY += 50; + } + + cv::imshow("TEST_CLS", img); + cv::waitKey(0); + cv::destroyAllWindows(); + //cv::imwrite("E:\\output\\" + std::to_string(k) + ".png", img); + } + + } +} + + + +int ReadCocoYaml(YOLO_V8*& p) { + // Open the YAML file + std::ifstream file("coco.yaml"); + if (!file.is_open()) + { + std::cerr << "Failed to open file" << std::endl; + return 1; + } + + // Read the file line by line + std::string line; + std::vector lines; + while (std::getline(file, line)) + { + lines.push_back(line); + } + + // Find the start and end of the names section + std::size_t start = 0; + std::size_t end = 0; + for (std::size_t i = 0; i < lines.size(); i++) + { + if (lines[i].find("names:") != std::string::npos) + { + start = i + 1; + } + else if (start > 0 && lines[i].find(':') == std::string::npos) + { + end = i; + break; + } + } + + // Extract the names + std::vector names; + for (std::size_t i = start; i < end; i++) + { + std::stringstream ss(lines[i]); + std::string name; + std::getline(ss, name, ':'); // Extract the number before the delimiter + std::getline(ss, name); // Extract the string after the delimiter + names.push_back(name); + } + + p->classes = names; + return 0; +} + + +void DetectTest() +{ + YOLO_V8* yoloDetector = new YOLO_V8; + ReadCocoYaml(yoloDetector); + DL_INIT_PARAM params; + params.rectConfidenceThreshold = 0.1; + params.iouThreshold = 0.5; + params.modelPath = "yolov8n.onnx"; + params.imgSize = { 640, 640 }; +#ifdef USE_CUDA + params.cudaEnable = true; + + // GPU FP32 inference + params.modelType = YOLO_DETECT_V8; + // GPU FP16 inference + //Note: change fp16 onnx model + //params.modelType = YOLO_DETECT_V8_HALF; + +#else + // CPU inference + params.modelType = YOLO_DETECT_V8; + params.cudaEnable = false; + +#endif + yoloDetector->CreateSession(params); + Detector(yoloDetector); +} + + +void ClsTest() +{ + YOLO_V8* yoloDetector = new YOLO_V8; + std::string model_path = "cls.onnx"; + ReadCocoYaml(yoloDetector); + DL_INIT_PARAM params{ model_path, YOLO_CLS, {224, 224} }; + yoloDetector->CreateSession(params); + Classifier(yoloDetector); +} + + +int main() +{ + //DetectTest(); + ClsTest(); +} diff --git a/examples/YOLOv8-ONNXRuntime-Rust/Cargo.toml b/examples/YOLOv8-ONNXRuntime-Rust/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..8eb421a86a13496d53cb88887f05b67e02a0375a --- /dev/null +++ b/examples/YOLOv8-ONNXRuntime-Rust/Cargo.toml @@ -0,0 
+1,24 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+[package]
+name = "yolov8-rs"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+clap = { version = "4.2.4", features = ["derive"] }
+image = { version = "0.25.2"}
+imageproc = { version = "0.25.0"}
+ndarray = { version = "0.16" }
+ort = { version = "2.0.0-rc.5", features = ["cuda", "tensorrt", "load-dynamic", "copy-dylibs", "half"]}
+rusttype = { version = "0.9.3" }
+anyhow = { version = "1.0.75" }
+regex = { version = "1.5.4" }
+rand = { version = "0.8.5" }
+chrono = { version = "0.4.30" }
+half = { version = "2.3.1" }
+dirs = { version = "5.0.1" }
+ureq = { version = "2.9.1" }
+ab_glyph = "0.2.29"
diff --git a/examples/YOLOv8-ONNXRuntime-Rust/README.md b/examples/YOLOv8-ONNXRuntime-Rust/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ec09edbf6557f869ebed045da77d7477a3f90073
--- /dev/null
+++ b/examples/YOLOv8-ONNXRuntime-Rust/README.md
@@ -0,0 +1,212 @@
+# YOLOv8-ONNXRuntime-Rust for All the Key YOLO Tasks
+
+This repository provides a Rust demo for performing YOLOv8 tasks like `Classification`, `Segmentation`, `Detection`, `Pose Detection` and `OBB` using ONNXRuntime.
+
+## Recently Updated
+
+- Add YOLOv8-OBB demo
+- Update ONNXRuntime to 1.19.x
+
+The newly updated YOLOv8 example code is located in [this repository](https://github.com/jamjamjon/usls/tree/main/examples/yolo).
+
+## Features
+
+- Support `Classification`, `Segmentation`, `Detection`, `Pose(Keypoints)-Detection`, `OBB` tasks.
+- Support `FP16` & `FP32` ONNX models.
+- Support `CPU`, `CUDA` and `TensorRT` execution providers to accelerate computation.
+- Support dynamic input shapes (`batch`, `width`, `height`).
+
+## Installation
+
+### 1. Install Rust
+
+Please follow the official Rust installation guide (https://www.rust-lang.org/tools/install).
+
+### 2. ONNXRuntime Linking
+
+- #### For detailed setup instructions, refer to the [ORT documentation](https://ort.pyke.io/setup/linking).
+
+- #### For Linux or macOS Users:
+  - Download the ONNX Runtime package from the [Releases page](https://github.com/microsoft/onnxruntime/releases).
+  - Set up the library path by exporting the `ORT_DYLIB_PATH` environment variable:
+    ```shell
+    export ORT_DYLIB_PATH=/path/to/onnxruntime/lib/libonnxruntime.so.1.19.0
+    ```
+
+### 3. \[Optional\] Install CUDA & cuDNN & TensorRT
+
+- CUDA execution provider requires CUDA v11.6+.
+- TensorRT execution provider requires CUDA v11.4+ and TensorRT v8.4+.
+
+## Get Started
+
+### 1. Export the YOLOv8 ONNX Models
+
+```bash
+pip install -U ultralytics
+
+# export onnx model with dynamic shapes
+yolo export model=yolov8m.pt format=onnx simplify dynamic
+yolo export model=yolov8m-cls.pt format=onnx simplify dynamic
+yolo export model=yolov8m-pose.pt format=onnx simplify dynamic
+yolo export model=yolov8m-seg.pt format=onnx simplify dynamic
+
+
+# export onnx model with constant shapes
+yolo export model=yolov8m.pt format=onnx simplify
+yolo export model=yolov8m-cls.pt format=onnx simplify
+yolo export model=yolov8m-pose.pt format=onnx simplify
+yolo export model=yolov8m-seg.pt format=onnx simplify
+```
+
+### 2. Run Inference
+
+This will perform inference with the ONNX model on the source image.
+
+```bash
+cargo run --release -- --model <MODEL> --source <SOURCE>
+```
+
+Set `--cuda` to use the CUDA execution provider to speed up inference.
+
+```bash
+cargo run --release -- --cuda --model <MODEL> --source <SOURCE>
+```
+
+Set `--trt` to use the TensorRT execution provider; you can also set `--fp16` at the same time to use a TensorRT FP16 engine.
+
+```bash
+cargo run --release -- --trt --fp16 --model <MODEL> --source <SOURCE>
+```
+
+Set `--device_id` to select which device to run on. If you have only one GPU and set `device_id` to 1, the program will not panic; `ort` automatically falls back to the `CPU` EP.
+
+```bash
+cargo run --release -- --cuda --device_id 0 --model <MODEL> --source <SOURCE>
+```
+
+Set `--batch` to do multi-batch-size inference.
+
+If you're using `--trt`, you can also set `--batch-min` and `--batch-max` to explicitly specify the min/max/opt batch sizes for dynamic batch input (see https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#explicit-shape-range-for-dynamic-shape-input). Note that the ONNX model should be exported with dynamic shapes.
+
+```bash
+cargo run --release -- --cuda --batch 2 --model <MODEL> --source <SOURCE>
+```
+
+Set `--height` and `--width` to do dynamic image-size inference. (Note that the ONNX model should be exported with dynamic shapes.)
+
+```bash
+cargo run --release -- --cuda --width 480 --height 640 --model <MODEL> --source <SOURCE>
+```
+
+Set `--profile` to check the time consumed in each stage. (Note that the model usually needs 1~3 dry runs to warm up. Make sure to run enough iterations to evaluate the result.)
+
+```bash
+cargo run --release -- --trt --fp16 --profile --model <MODEL> --source <SOURCE>
+```
+
+Results (yolov8m.onnx, batch=1, 3 runs, TensorRT, FP16, RTX 3060Ti):
+
+```bash
+==> 0
+[Model Preprocess]: 12.75788ms
+[ORT H2D]: 237.118µs
+[ORT Inference]: 507.895469ms
+[ORT D2H]: 191.655µs
+[Model Inference]: 508.34589ms
+[Model Postprocess]: 1.061122ms
+==> 1
+[Model Preprocess]: 13.658655ms
+[ORT H2D]: 209.975µs
+[ORT Inference]: 5.12372ms
+[ORT D2H]: 182.389µs
+[Model Inference]: 5.530022ms
+[Model Postprocess]: 1.04851ms
+==> 2
+[Model Preprocess]: 12.475332ms
+[ORT H2D]: 246.127µs
+[ORT Inference]: 5.048432ms
+[ORT D2H]: 187.117µs
+[Model Inference]: 5.493119ms
+[Model Postprocess]: 1.040906ms
+```
+
+Other useful options:
+
+`--conf`: confidence threshold \[default: 0.3\]
+
+`--iou`: IoU threshold in NMS \[default: 0.45\]
+
+`--kconf`: confidence threshold of keypoint \[default: 0.55\]
+
+`--plot`: plot inference results with random RGB colors and save
+
+You can check out all CLI arguments by running:
+
+```bash
+git clone https://github.com/ultralytics/ultralytics
+cd ultralytics/examples/YOLOv8-ONNXRuntime-Rust
+cargo run --release -- --help
+```
+
+## Examples
+
+![Ultralytics YOLO Tasks](https://raw.githubusercontent.com/ultralytics/assets/main/im/banner-tasks.png)
+
+### Classification
+
+Running a dynamic-shape ONNX model on `CPU` with image size `--height 224 --width 224`. The plotted image is saved in the `runs` directory.
+ +```bash +cargo run --release -- --model ../assets/weights/yolov8m-cls-dyn.onnx --source ../assets/images/dog.jpg --height 224 --width 224 --plot --profile +``` + +You will see result like: + +```bash +Summary: +> Task: Classify (Ultralytics 8.0.217) +> EP: Cpu +> Dtype: Float32 +> Batch: 1 (Dynamic), Height: 224 (Dynamic), Width: 224 (Dynamic) +> nc: 1000 nk: 0, nm: 0, conf: 0.3, kconf: 0.55, iou: 0.45 + +[Model Preprocess]: 16.363477ms +[ORT H2D]: 50.722µs +[ORT Inference]: 16.295808ms +[ORT D2H]: 8.37µs +[Model Inference]: 16.367046ms +[Model Postprocess]: 3.527µs +[ + YOLOResult { + Probs(top5): Some([(208, 0.6950566), (209, 0.13823675), (178, 0.04849795), (215, 0.019029364), (212, 0.016506357)]), + Bboxes: None, + Keypoints: None, + Masks: None, + }, +] +``` + +### Object Detection + +Using `CUDA` EP and dynamic image size `--height 640 --width 480` + +```bash +cargo run --release -- --cuda --model ../assets/weights/yolov8m-dynamic.onnx --source ../assets/images/bus.jpg --plot --height 640 --width 480 +``` + +### Pose Detection + +using `TensorRT` EP + +```bash +cargo run --release -- --trt --model ../assets/weights/yolov8m-pose.onnx --source ../assets/images/bus.jpg --plot +``` + +### Instance Segmentation + +using `TensorRT` EP and FP16 model `--fp16` + +```bash +cargo run --release -- --trt --fp16 --model ../assets/weights/yolov8m-seg.onnx --source ../assets/images/0172.jpg --plot +``` diff --git a/examples/YOLOv8-ONNXRuntime-Rust/src/cli.rs b/examples/YOLOv8-ONNXRuntime-Rust/src/cli.rs new file mode 100644 index 0000000000000000000000000000000000000000..b5bc05a585a0fc4b52dd671a101e66f716b45f7c --- /dev/null +++ b/examples/YOLOv8-ONNXRuntime-Rust/src/cli.rs @@ -0,0 +1,87 @@ +use clap::Parser; + +use crate::YOLOTask; + +#[derive(Parser, Clone)] +#[command(author, version, about, long_about = None)] +pub struct Args { + /// ONNX model path + #[arg(long, required = true)] + pub model: String, + + /// input path + #[arg(long, required = true)] + pub source: String, + + /// device id + #[arg(long, default_value_t = 0)] + pub device_id: i32, + + /// using TensorRT EP + #[arg(long)] + pub trt: bool, + + /// using CUDA EP + #[arg(long)] + pub cuda: bool, + + /// input batch size + #[arg(long, default_value_t = 1)] + pub batch: u32, + + /// trt input min_batch size + #[arg(long, default_value_t = 1)] + pub batch_min: u32, + + /// trt input max_batch size + #[arg(long, default_value_t = 32)] + pub batch_max: u32, + + /// using TensorRT --fp16 + #[arg(long)] + pub fp16: bool, + + /// specify YOLO task + #[arg(long, value_enum)] + pub task: Option, + + /// num_classes + #[arg(long)] + pub nc: Option, + + /// num_keypoints + #[arg(long)] + pub nk: Option, + + /// num_masks + #[arg(long)] + pub nm: Option, + + /// input image width + #[arg(long)] + pub width: Option, + + /// input image height + #[arg(long)] + pub height: Option, + + /// confidence threshold + #[arg(long, required = false, default_value_t = 0.3)] + pub conf: f32, + + /// iou threshold in NMS + #[arg(long, required = false, default_value_t = 0.45)] + pub iou: f32, + + /// confidence threshold of keypoint + #[arg(long, required = false, default_value_t = 0.55)] + pub kconf: f32, + + /// plot inference result and save + #[arg(long)] + pub plot: bool, + + /// check time consumed in each stage + #[arg(long)] + pub profile: bool, +} diff --git a/examples/YOLOv8-ONNXRuntime-Rust/src/lib.rs b/examples/YOLOv8-ONNXRuntime-Rust/src/lib.rs new file mode 100644 index 
0000000000000000000000000000000000000000..0084535ee577384e81357a8b35a36463e10387b6 --- /dev/null +++ b/examples/YOLOv8-ONNXRuntime-Rust/src/lib.rs @@ -0,0 +1,160 @@ +#![allow(clippy::type_complexity)] + +use std::io::{Read, Write}; + +pub mod cli; +pub mod model; +pub mod ort_backend; +pub mod yolo_result; +pub use crate::cli::Args; +pub use crate::model::YOLOv8; +pub use crate::ort_backend::{Batch, OrtBackend, OrtConfig, OrtEP, YOLOTask}; +pub use crate::yolo_result::{Bbox, Embedding, Point2, YOLOResult}; + +pub fn non_max_suppression( + xs: &mut Vec<(Bbox, Option>, Option>)>, + iou_threshold: f32, +) { + xs.sort_by(|b1, b2| b2.0.confidence().partial_cmp(&b1.0.confidence()).unwrap()); + + let mut current_index = 0; + for index in 0..xs.len() { + let mut drop = false; + for prev_index in 0..current_index { + let iou = xs[prev_index].0.iou(&xs[index].0); + if iou > iou_threshold { + drop = true; + break; + } + } + if !drop { + xs.swap(current_index, index); + current_index += 1; + } + } + xs.truncate(current_index); +} + +pub fn gen_time_string(delimiter: &str) -> String { + let offset = chrono::FixedOffset::east_opt(8 * 60 * 60).unwrap(); // Beijing + let t_now = chrono::Utc::now().with_timezone(&offset); + let fmt = format!( + "%Y{}%m{}%d{}%H{}%M{}%S{}%f", + delimiter, delimiter, delimiter, delimiter, delimiter, delimiter + ); + t_now.format(&fmt).to_string() +} + +pub const SKELETON: [(usize, usize); 16] = [ + (0, 1), + (0, 2), + (1, 3), + (2, 4), + (5, 6), + (5, 11), + (6, 12), + (11, 12), + (5, 7), + (6, 8), + (7, 9), + (8, 10), + (11, 13), + (12, 14), + (13, 15), + (14, 16), +]; + +pub fn check_font(font: &str) -> rusttype::Font<'static> { + // check then load font + + // ultralytics font path + let font_path_config = match dirs::config_dir() { + Some(mut d) => { + d.push("Ultralytics"); + d.push(font); + d + } + None => panic!("Unsupported operating system. 
Now support Linux, MacOS, Windows."), + }; + + // current font path + let font_path_current = std::path::PathBuf::from(font); + + // check font + let font_path = if font_path_config.exists() { + font_path_config + } else if font_path_current.exists() { + font_path_current + } else { + println!("Downloading font..."); + let source_url = "https://ultralytics.com/assets/Arial.ttf"; + let resp = ureq::get(source_url) + .timeout(std::time::Duration::from_secs(500)) + .call() + .unwrap_or_else(|err| panic!("> Failed to download font: {source_url}: {err:?}")); + + // read to buffer + let mut buffer = vec![]; + let total_size = resp + .header("Content-Length") + .and_then(|s| s.parse::().ok()) + .unwrap(); + let _reader = resp + .into_reader() + .take(total_size) + .read_to_end(&mut buffer) + .unwrap(); + + // save + let _path = std::fs::File::create(font).unwrap(); + let mut writer = std::io::BufWriter::new(_path); + writer.write_all(&buffer).unwrap(); + println!("Font saved at: {:?}", font_path_current.display()); + font_path_current + }; + + // load font + let buffer = std::fs::read(font_path).unwrap(); + rusttype::Font::try_from_vec(buffer).unwrap() +} + +use ab_glyph::FontArc; +pub fn load_font() -> FontArc { + use std::path::Path; + let font_path = Path::new("./font/Arial.ttf"); + match font_path.try_exists() { + Ok(true) => { + let buffer = std::fs::read(font_path).unwrap(); + FontArc::try_from_vec(buffer).unwrap() + } + Ok(false) => { + std::fs::create_dir_all("./font").unwrap(); + println!("Downloading font..."); + let source_url = "https://ultralytics.com/assets/Arial.ttf"; + let resp = ureq::get(source_url) + .timeout(std::time::Duration::from_secs(500)) + .call() + .unwrap_or_else(|err| panic!("> Failed to download font: {source_url}: {err:?}")); + + // read to buffer + let mut buffer = vec![]; + let total_size = resp + .header("Content-Length") + .and_then(|s| s.parse::().ok()) + .unwrap(); + let _reader = resp + .into_reader() + .take(total_size) + .read_to_end(&mut buffer) + .unwrap(); + // save + let mut fd = std::fs::File::create(font_path).unwrap(); + fd.write_all(&buffer).unwrap(); + println!("Font saved at: {:?}", font_path.display()); + FontArc::try_from_vec(buffer).unwrap() + } + Err(e) => { + panic!("Failed to load font {}", e); + } + } +} diff --git a/examples/YOLOv8-ONNXRuntime-Rust/src/main.rs b/examples/YOLOv8-ONNXRuntime-Rust/src/main.rs new file mode 100644 index 0000000000000000000000000000000000000000..fd3845ced08aebc8bf94e514e117db6cfcd5c6bc --- /dev/null +++ b/examples/YOLOv8-ONNXRuntime-Rust/src/main.rs @@ -0,0 +1,28 @@ +use clap::Parser; + +use yolov8_rs::{Args, YOLOv8}; + +fn main() -> Result<(), Box> { + let args = Args::parse(); + + // 1. load image + let x = image::ImageReader::open(&args.source)? + .with_guessed_format()? + .decode()?; + + // 2. model support dynamic batch inference, so input should be a Vec + let xs = vec![x]; + + // You can test `--batch 2` with this + // let xs = vec![x.clone(), x]; + + // 3. build yolov8 model + let mut model = YOLOv8::new(args)?; + model.summary(); // model info + + // 4. 
run + let ys = model.run(&xs)?; + println!("{:?}", ys); + + Ok(()) +} diff --git a/examples/YOLOv8-ONNXRuntime-Rust/src/model.rs b/examples/YOLOv8-ONNXRuntime-Rust/src/model.rs new file mode 100644 index 0000000000000000000000000000000000000000..95b2bdfffaaec5ab9d2fe37f4302f0e6a785d10b --- /dev/null +++ b/examples/YOLOv8-ONNXRuntime-Rust/src/model.rs @@ -0,0 +1,651 @@ +#![allow(clippy::type_complexity)] + +use ab_glyph::FontArc; +use anyhow::Result; +use image::{DynamicImage, GenericImageView, ImageBuffer}; +use ndarray::{s, Array, Axis, IxDyn}; +use rand::{thread_rng, Rng}; +use std::path::PathBuf; + +use crate::{ + gen_time_string, load_font, non_max_suppression, Args, Batch, Bbox, Embedding, OrtBackend, + OrtConfig, OrtEP, Point2, YOLOResult, YOLOTask, SKELETON, +}; + +pub struct YOLOv8 { + // YOLOv8 model for all yolo-tasks + engine: OrtBackend, + nc: u32, + nk: u32, + nm: u32, + height: u32, + width: u32, + batch: u32, + task: YOLOTask, + conf: f32, + kconf: f32, + iou: f32, + names: Vec, + color_palette: Vec<(u8, u8, u8)>, + profile: bool, + plot: bool, +} + +impl YOLOv8 { + pub fn new(config: Args) -> Result { + // execution provider + let ep = if config.trt { + OrtEP::Trt(config.device_id) + } else if config.cuda { + OrtEP::CUDA(config.device_id) + } else { + OrtEP::CPU + }; + + // batch + let batch = Batch { + opt: config.batch, + min: config.batch_min, + max: config.batch_max, + }; + + // build ort engine + let ort_args = OrtConfig { + ep, + batch, + f: config.model, + task: config.task, + trt_fp16: config.fp16, + image_size: (config.height, config.width), + }; + let engine = OrtBackend::build(ort_args)?; + + // get batch, height, width, tasks, nc, nk, nm + let (batch, height, width, task) = ( + engine.batch(), + engine.height(), + engine.width(), + engine.task(), + ); + let nc = engine.nc().or(config.nc).unwrap_or_else(|| { + panic!("Failed to get num_classes, make it explicit with `--nc`"); + }); + let (nk, nm) = match task { + YOLOTask::Pose => { + let nk = engine.nk().or(config.nk).unwrap_or_else(|| { + panic!("Failed to get num_keypoints, make it explicit with `--nk`"); + }); + (nk, 0) + } + YOLOTask::Segment => { + let nm = engine.nm().or(config.nm).unwrap_or_else(|| { + panic!("Failed to get num_masks, make it explicit with `--nm`"); + }); + (0, nm) + } + _ => (0, 0), + }; + + // class names + let names = engine.names().unwrap_or(vec!["Unknown".to_string()]); + + // color palette + let mut rng = thread_rng(); + let color_palette: Vec<_> = names + .iter() + .map(|_| { + ( + rng.gen_range(0..=255), + rng.gen_range(0..=255), + rng.gen_range(0..=255), + ) + }) + .collect(); + + Ok(Self { + engine, + names, + conf: config.conf, + kconf: config.kconf, + iou: config.iou, + color_palette, + profile: config.profile, + plot: config.plot, + nc, + nk, + nm, + height, + width, + batch, + task, + }) + } + + pub fn scale_wh(&self, w0: f32, h0: f32, w1: f32, h1: f32) -> (f32, f32, f32) { + let r = (w1 / w0).min(h1 / h0); + (r, (w0 * r).round(), (h0 * r).round()) + } + + pub fn preprocess(&mut self, xs: &Vec) -> Result> { + let mut ys = + Array::ones((xs.len(), 3, self.height() as usize, self.width() as usize)).into_dyn(); + ys.fill(144.0 / 255.0); + for (idx, x) in xs.iter().enumerate() { + let img = match self.task() { + YOLOTask::Classify => x.resize_exact( + self.width(), + self.height(), + image::imageops::FilterType::Triangle, + ), + _ => { + let (w0, h0) = x.dimensions(); + let w0 = w0 as f32; + let h0 = h0 as f32; + let (_, w_new, h_new) = + self.scale_wh(w0, h0, self.width() as 
f32, self.height() as f32); // f32 round + x.resize_exact( + w_new as u32, + h_new as u32, + if let YOLOTask::Segment = self.task() { + image::imageops::FilterType::CatmullRom + } else { + image::imageops::FilterType::Triangle + }, + ) + } + }; + + for (x, y, rgb) in img.pixels() { + let x = x as usize; + let y = y as usize; + let [r, g, b, _] = rgb.0; + ys[[idx, 0, y, x]] = (r as f32) / 255.0; + ys[[idx, 1, y, x]] = (g as f32) / 255.0; + ys[[idx, 2, y, x]] = (b as f32) / 255.0; + } + } + + Ok(ys) + } + + pub fn run(&mut self, xs: &Vec) -> Result> { + // pre-process + let t_pre = std::time::Instant::now(); + let xs_ = self.preprocess(xs)?; + if self.profile { + println!("[Model Preprocess]: {:?}", t_pre.elapsed()); + } + + // run + let t_run = std::time::Instant::now(); + let ys = self.engine.run(xs_, self.profile)?; + if self.profile { + println!("[Model Inference]: {:?}", t_run.elapsed()); + } + + // post-process + let t_post = std::time::Instant::now(); + let ys = self.postprocess(ys, xs)?; + if self.profile { + println!("[Model Postprocess]: {:?}", t_post.elapsed()); + } + + // plot and save + if self.plot { + self.plot_and_save(&ys, xs, Some(&SKELETON)); + } + Ok(ys) + } + + pub fn postprocess( + &self, + xs: Vec>, + xs0: &[DynamicImage], + ) -> Result> { + if let YOLOTask::Classify = self.task() { + let mut ys = Vec::new(); + let preds = &xs[0]; + for batch in preds.axis_iter(Axis(0)) { + ys.push(YOLOResult::new( + Some(Embedding::new(batch.into_owned())), + None, + None, + None, + )); + } + Ok(ys) + } else { + const CXYWH_OFFSET: usize = 4; // cxcywh + const KPT_STEP: usize = 3; // xyconf + let preds = &xs[0]; + let protos = { + if xs.len() > 1 { + Some(&xs[1]) + } else { + None + } + }; + let mut ys = Vec::new(); + for (idx, anchor) in preds.axis_iter(Axis(0)).enumerate() { + // [bs, 4 + nc + nm, anchors] + // input image + let width_original = xs0[idx].width() as f32; + let height_original = xs0[idx].height() as f32; + let ratio = (self.width() as f32 / width_original) + .min(self.height() as f32 / height_original); + + // save each result + let mut data: Vec<(Bbox, Option>, Option>)> = Vec::new(); + for pred in anchor.axis_iter(Axis(1)) { + // split preds for different tasks + let bbox = pred.slice(s![0..CXYWH_OFFSET]); + let clss = pred.slice(s![CXYWH_OFFSET..CXYWH_OFFSET + self.nc() as usize]); + let kpts = { + if let YOLOTask::Pose = self.task() { + Some(pred.slice(s![pred.len() - KPT_STEP * self.nk() as usize..])) + } else { + None + } + }; + let coefs = { + if let YOLOTask::Segment = self.task() { + Some(pred.slice(s![pred.len() - self.nm() as usize..]).to_vec()) + } else { + None + } + }; + + // confidence and id + let (id, &confidence) = clss + .into_iter() + .enumerate() + .reduce(|max, x| if x.1 > max.1 { x } else { max }) + .unwrap(); // definitely will not panic! 
+ + // confidence filter + if confidence < self.conf { + continue; + } + + // bbox re-scale + let cx = bbox[0] / ratio; + let cy = bbox[1] / ratio; + let w = bbox[2] / ratio; + let h = bbox[3] / ratio; + let x = cx - w / 2.; + let y = cy - h / 2.; + let y_bbox = Bbox::new( + x.max(0.0f32).min(width_original), + y.max(0.0f32).min(height_original), + w, + h, + id, + confidence, + ); + + // kpts + let y_kpts = { + if let Some(kpts) = kpts { + let mut kpts_ = Vec::new(); + // rescale + for i in 0..self.nk() as usize { + let kx = kpts[KPT_STEP * i] / ratio; + let ky = kpts[KPT_STEP * i + 1] / ratio; + let kconf = kpts[KPT_STEP * i + 2]; + if kconf < self.kconf { + kpts_.push(Point2::default()); + } else { + kpts_.push(Point2::new_with_conf( + kx.max(0.0f32).min(width_original), + ky.max(0.0f32).min(height_original), + kconf, + )); + } + } + Some(kpts_) + } else { + None + } + }; + + // data merged + data.push((y_bbox, y_kpts, coefs)); + } + + // nms + non_max_suppression(&mut data, self.iou); + + // decode + let mut y_bboxes: Vec = Vec::new(); + let mut y_kpts: Vec> = Vec::new(); + let mut y_masks: Vec> = Vec::new(); + for elem in data.into_iter() { + if let Some(kpts) = elem.1 { + y_kpts.push(kpts) + } + + // decode masks + if let Some(coefs) = elem.2 { + let proto = protos.unwrap().slice(s![idx, .., .., ..]); + let (nm, nh, nw) = proto.dim(); + + // coefs * proto -> mask + let coefs = Array::from_shape_vec((1, nm), coefs)?; // (n, nm) + + let proto = proto.to_owned(); + let proto = proto.to_shape((nm, nh * nw))?; // (nm, nh*nw) + let mask = coefs.dot(&proto); // (nh, nw, n) + let mask = mask.to_shape((nh, nw, 1))?; + + // build image from ndarray + let mask_im: ImageBuffer, Vec> = + match ImageBuffer::from_raw( + nw as u32, + nh as u32, + mask.to_owned().into_raw_vec_and_offset().0, + ) { + Some(image) => image, + None => panic!("can not create image from ndarray"), + }; + let mut mask_im = image::DynamicImage::from(mask_im); // -> dyn + + // rescale masks + let (_, w_mask, h_mask) = + self.scale_wh(width_original, height_original, nw as f32, nh as f32); + let mask_cropped = mask_im.crop(0, 0, w_mask as u32, h_mask as u32); + let mask_original = mask_cropped.resize_exact( + // resize_to_fill + width_original as u32, + height_original as u32, + match self.task() { + YOLOTask::Segment => image::imageops::FilterType::CatmullRom, + _ => image::imageops::FilterType::Triangle, + }, + ); + + // crop-mask with bbox + let mut mask_original_cropped = mask_original.into_luma8(); + for y in 0..height_original as usize { + for x in 0..width_original as usize { + if x < elem.0.xmin() as usize + || x > elem.0.xmax() as usize + || y < elem.0.ymin() as usize + || y > elem.0.ymax() as usize + { + mask_original_cropped.put_pixel( + x as u32, + y as u32, + image::Luma([0u8]), + ); + } + } + } + y_masks.push(mask_original_cropped.into_raw()); + } + y_bboxes.push(elem.0); + } + + // save each result + let y = YOLOResult { + probs: None, + bboxes: if !y_bboxes.is_empty() { + Some(y_bboxes) + } else { + None + }, + keypoints: if !y_kpts.is_empty() { + Some(y_kpts) + } else { + None + }, + masks: if !y_masks.is_empty() { + Some(y_masks) + } else { + None + }, + }; + ys.push(y); + } + + Ok(ys) + } + } + + pub fn plot_and_save( + &self, + ys: &[YOLOResult], + xs0: &[DynamicImage], + skeletons: Option<&[(usize, usize)]>, + ) { + // check font then load + let font: FontArc = load_font(); + for (_idb, (img0, y)) in xs0.iter().zip(ys.iter()).enumerate() { + let mut img = img0.to_rgb8(); + + // draw for classifier + if let 
Some(probs) = y.probs() { + for (i, k) in probs.topk(5).iter().enumerate() { + let legend = format!("{} {:.2}%", self.names[k.0], k.1); + let scale = 32; + let legend_size = img.width().max(img.height()) / scale; + let x = img.width() / 20; + let y = img.height() / 20 + i as u32 * legend_size; + + imageproc::drawing::draw_text_mut( + &mut img, + image::Rgb([0, 255, 0]), + x as i32, + y as i32, + legend_size as f32, + &font, + &legend, + ); + } + } + + // draw bboxes & keypoints + if let Some(bboxes) = y.bboxes() { + for (_idx, bbox) in bboxes.iter().enumerate() { + // rect + imageproc::drawing::draw_hollow_rect_mut( + &mut img, + imageproc::rect::Rect::at(bbox.xmin() as i32, bbox.ymin() as i32) + .of_size(bbox.width() as u32, bbox.height() as u32), + image::Rgb(self.color_palette[bbox.id()].into()), + ); + + // text + let legend = format!("{} {:.2}%", self.names[bbox.id()], bbox.confidence()); + let scale = 40; + let legend_size = img.width().max(img.height()) / scale; + imageproc::drawing::draw_text_mut( + &mut img, + image::Rgb(self.color_palette[bbox.id()].into()), + bbox.xmin() as i32, + (bbox.ymin() - legend_size as f32) as i32, + legend_size as f32, + &font, + &legend, + ); + } + } + + // draw kpts + if let Some(keypoints) = y.keypoints() { + for kpts in keypoints.iter() { + for kpt in kpts.iter() { + // filter + if kpt.confidence() < self.kconf { + continue; + } + + // draw point + imageproc::drawing::draw_filled_circle_mut( + &mut img, + (kpt.x() as i32, kpt.y() as i32), + 2, + image::Rgb([0, 255, 0]), + ); + } + + // draw skeleton if has + if let Some(skeletons) = skeletons { + for &(idx1, idx2) in skeletons.iter() { + let kpt1 = &kpts[idx1]; + let kpt2 = &kpts[idx2]; + if kpt1.confidence() < self.kconf || kpt2.confidence() < self.kconf { + continue; + } + imageproc::drawing::draw_line_segment_mut( + &mut img, + (kpt1.x(), kpt1.y()), + (kpt2.x(), kpt2.y()), + image::Rgb([233, 14, 57]), + ); + } + } + } + } + + // draw mask + if let Some(masks) = y.masks() { + for (mask, _bbox) in masks.iter().zip(y.bboxes().unwrap().iter()) { + let mask_nd: ImageBuffer, Vec> = + match ImageBuffer::from_vec(img.width(), img.height(), mask.to_vec()) { + Some(image) => image, + None => panic!("can not crate image from ndarray"), + }; + + for _x in 0..img.width() { + for _y in 0..img.height() { + let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_nd, _x, _y); + if mask_p.0[0] > 0 { + let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, _x, _y); + // img_p.0[2] = self.color_palette[bbox.id()].2 / 2; + // img_p.0[1] = self.color_palette[bbox.id()].1 / 2; + // img_p.0[0] = self.color_palette[bbox.id()].0 / 2; + img_p.0[2] /= 2; + img_p.0[1] = 255 - (255 - img_p.0[2]) / 2; + img_p.0[0] /= 2; + imageproc::drawing::Canvas::draw_pixel(&mut img, _x, _y, img_p) + } + } + } + } + } + + // mkdir and save + let mut runs = PathBuf::from("runs"); + if !runs.exists() { + std::fs::create_dir_all(&runs).unwrap(); + } + runs.push(gen_time_string("-")); + let saveout = format!("{}.jpg", runs.to_str().unwrap()); + let _ = img.save(saveout); + } + } + + pub fn summary(&self) { + println!( + "\nSummary:\n\ + > Task: {:?}{}\n\ + > EP: {:?} {}\n\ + > Dtype: {:?}\n\ + > Batch: {} ({}), Height: {} ({}), Width: {} ({})\n\ + > nc: {} nk: {}, nm: {}, conf: {}, kconf: {}, iou: {}\n\ + ", + self.task(), + match self.engine.author().zip(self.engine.version()) { + Some((author, ver)) => format!(" ({} {})", author, ver), + None => String::from(""), + }, + self.engine.ep(), + if let OrtEP::CPU = self.engine.ep() { + "" + } 
else { + "(May still fall back to CPU)" + }, + self.engine.dtype(), + self.batch(), + if self.engine.is_batch_dynamic() { + "Dynamic" + } else { + "Const" + }, + self.height(), + if self.engine.is_height_dynamic() { + "Dynamic" + } else { + "Const" + }, + self.width(), + if self.engine.is_width_dynamic() { + "Dynamic" + } else { + "Const" + }, + self.nc(), + self.nk(), + self.nm(), + self.conf, + self.kconf, + self.iou, + ); + } + + pub fn engine(&self) -> &OrtBackend { + &self.engine + } + + pub fn conf(&self) -> f32 { + self.conf + } + + pub fn set_conf(&mut self, val: f32) { + self.conf = val; + } + + pub fn conf_mut(&mut self) -> &mut f32 { + &mut self.conf + } + + pub fn kconf(&self) -> f32 { + self.kconf + } + + pub fn iou(&self) -> f32 { + self.iou + } + + pub fn task(&self) -> &YOLOTask { + &self.task + } + + pub fn batch(&self) -> u32 { + self.batch + } + + pub fn width(&self) -> u32 { + self.width + } + + pub fn height(&self) -> u32 { + self.height + } + + pub fn nc(&self) -> u32 { + self.nc + } + + pub fn nk(&self) -> u32 { + self.nk + } + + pub fn nm(&self) -> u32 { + self.nm + } + + pub fn names(&self) -> &Vec { + &self.names + } +} diff --git a/examples/YOLOv8-ONNXRuntime-Rust/src/ort_backend.rs b/examples/YOLOv8-ONNXRuntime-Rust/src/ort_backend.rs new file mode 100644 index 0000000000000000000000000000000000000000..d88208dead34f24d3fa9143cf7c1d1a8e93ffa62 --- /dev/null +++ b/examples/YOLOv8-ONNXRuntime-Rust/src/ort_backend.rs @@ -0,0 +1,553 @@ +use anyhow::Result; +use clap::ValueEnum; +use half::f16; +use ndarray::{Array, CowArray, IxDyn}; +use ort::{ + CPUExecutionProvider, CUDAExecutionProvider, ExecutionProvider, ExecutionProviderDispatch, + TensorRTExecutionProvider, +}; +use ort::{Session, SessionBuilder}; +use ort::{TensorElementType, ValueType}; +use regex::Regex; +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] +pub enum YOLOTask { + // YOLO tasks + Classify, + Detect, + Pose, + Segment, +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub enum OrtEP { + // ONNXRuntime execution provider + CPU, + CUDA(i32), + Trt(i32), +} + +#[derive(Debug)] +pub struct Batch { + pub opt: u32, + pub min: u32, + pub max: u32, +} + +impl Default for Batch { + fn default() -> Self { + Self { + opt: 1, + min: 1, + max: 1, + } + } +} + +#[derive(Debug, Default)] +pub struct OrtInputs { + // ONNX model inputs attrs + pub shapes: Vec>, + //pub dtypes: Vec, + pub dtypes: Vec, + pub names: Vec, + pub sizes: Vec>, +} + +impl OrtInputs { + pub fn new(session: &Session) -> Self { + let mut shapes = Vec::new(); + let mut dtypes = Vec::new(); + let mut names = Vec::new(); + for i in session.inputs.iter() { + /* let shape: Vec = i + .dimensions() + .map(|x| if let Some(x) = x { x as i32 } else { -1i32 }) + .collect(); + shapes.push(shape); */ + if let ort::ValueType::Tensor { ty, dimensions } = &i.input_type { + dtypes.push(ty.clone()); + let shape = dimensions.clone(); + shapes.push(shape); + } else { + panic!("不支持的数据格式, {} - {}", file!(), line!()); + } + //dtypes.push(i.input_type); + names.push(i.name.clone()); + } + Self { + shapes, + dtypes, + names, + ..Default::default() + } + } +} + +#[derive(Debug)] +pub struct OrtConfig { + // ORT config + pub f: String, + pub task: Option, + pub ep: OrtEP, + pub trt_fp16: bool, + pub batch: Batch, + pub image_size: (Option, Option), +} + +#[derive(Debug)] +pub struct OrtBackend { + // ORT engine + session: Session, + task: YOLOTask, + ep: OrtEP, + batch: Batch, + inputs: OrtInputs, +} + +impl OrtBackend { + pub fn 
build(args: OrtConfig) -> Result { + // build env & session + // in version 2.x environment is removed + /* let env = ort::EnvironmentBuilder + ::with_name("YOLOv8") + .build()? + .into_arc(); */ + let sessionbuilder = SessionBuilder::new()?; + let session = sessionbuilder.commit_from_file(&args.f)?; + //let session = SessionBuilder::new(&env)?.with_model_from_file(&args.f)?; + + // get inputs + let mut inputs = OrtInputs::new(&session); + + // batch size + let mut batch = args.batch; + let batch = if inputs.shapes[0][0] == -1 { + batch + } else { + assert_eq!( + inputs.shapes[0][0] as u32, batch.opt, + "Expected batch size: {}, got {}. Try using `--batch {}`.", + inputs.shapes[0][0] as u32, batch.opt, inputs.shapes[0][0] as u32 + ); + batch.opt = inputs.shapes[0][0] as u32; + batch + }; + + // input size: height and width + let height = if inputs.shapes[0][2] == -1 { + match args.image_size.0 { + Some(height) => height, + None => panic!("Failed to get model height. Make it explicit with `--height`"), + } + } else { + inputs.shapes[0][2] as u32 + }; + let width = if inputs.shapes[0][3] == -1 { + match args.image_size.1 { + Some(width) => width, + None => panic!("Failed to get model width. Make it explicit with `--width`"), + } + } else { + inputs.shapes[0][3] as u32 + }; + inputs.sizes.push(vec![height, width]); + + // build provider + let (ep, provider) = match args.ep { + OrtEP::CUDA(device_id) => Self::set_ep_cuda(device_id), + OrtEP::Trt(device_id) => Self::set_ep_trt(device_id, args.trt_fp16, &batch, &inputs), + _ => ( + OrtEP::CPU, + ExecutionProviderDispatch::from(CPUExecutionProvider::default()), + ), + }; + + // build session again with the new provider + let session = SessionBuilder::new()? + // .with_optimization_level(ort::GraphOptimizationLevel::Level3)? + .with_execution_providers([provider])? + .commit_from_file(args.f)?; + + // task: using given one or guessing + let task = match args.task { + Some(task) => task, + None => match session.metadata() { + Err(_) => panic!("No metadata found. Try making it explicit by `--task`"), + Ok(metadata) => match metadata.custom("task") { + Err(_) => panic!("Can not get custom value. Try making it explicit by `--task`"), + Ok(value) => match value { + None => panic!("No corresponding value of `task` found in metadata. 
Make it explicit by `--task`"), + Some(task) => match task.as_str() { + "classify" => YOLOTask::Classify, + "detect" => YOLOTask::Detect, + "pose" => YOLOTask::Pose, + "segment" => YOLOTask::Segment, + x => todo!("{:?} is not supported for now!", x), + }, + }, + }, + }, + }; + + Ok(Self { + session, + task, + ep, + batch, + inputs, + }) + } + + pub fn fetch_inputs_from_session( + session: &Session, + ) -> (Vec>, Vec, Vec) { + // get inputs attrs from ONNX model + let mut shapes = Vec::new(); + let mut dtypes = Vec::new(); + let mut names = Vec::new(); + for i in session.inputs.iter() { + if let ort::ValueType::Tensor { ty, dimensions } = &i.input_type { + dtypes.push(ty.clone()); + let shape = dimensions.clone(); + shapes.push(shape); + } else { + panic!("不支持的数据格式, {} - {}", file!(), line!()); + } + names.push(i.name.clone()); + } + (shapes, dtypes, names) + } + + pub fn set_ep_cuda(device_id: i32) -> (OrtEP, ExecutionProviderDispatch) { + let cuda_provider = CUDAExecutionProvider::default().with_device_id(device_id); + if let Ok(true) = cuda_provider.is_available() { + ( + OrtEP::CUDA(device_id), + ExecutionProviderDispatch::from(cuda_provider), //PlantForm::CUDA(cuda_provider) + ) + } else { + println!("> CUDA is not available! Using CPU."); + ( + OrtEP::CPU, + ExecutionProviderDispatch::from(CPUExecutionProvider::default()), //PlantForm::CPU(CPUExecutionProvider::default()) + ) + } + } + + pub fn set_ep_trt( + device_id: i32, + fp16: bool, + batch: &Batch, + inputs: &OrtInputs, + ) -> (OrtEP, ExecutionProviderDispatch) { + // set TensorRT + let trt_provider = TensorRTExecutionProvider::default().with_device_id(device_id); + + //trt_provider. + if let Ok(true) = trt_provider.is_available() { + let (height, width) = (inputs.sizes[0][0], inputs.sizes[0][1]); + if inputs.dtypes[0] == TensorElementType::Float16 && !fp16 { + panic!( + "Dtype mismatch! Expected: Float32, got: {:?}. You should use `--fp16`", + inputs.dtypes[0] + ); + } + // dynamic shape: input_tensor_1:dim_1xdim_2x...,input_tensor_2:dim_3xdim_4x...,... + let mut opt_string = String::new(); + let mut min_string = String::new(); + let mut max_string = String::new(); + for name in inputs.names.iter() { + let s_opt = format!("{}:{}x3x{}x{},", name, batch.opt, height, width); + let s_min = format!("{}:{}x3x{}x{},", name, batch.min, height, width); + let s_max = format!("{}:{}x3x{}x{},", name, batch.max, height, width); + opt_string.push_str(s_opt.as_str()); + min_string.push_str(s_min.as_str()); + max_string.push_str(s_max.as_str()); + } + let _ = opt_string.pop(); + let _ = min_string.pop(); + let _ = max_string.pop(); + + let trt_provider = trt_provider + .with_profile_opt_shapes(opt_string) + .with_profile_min_shapes(min_string) + .with_profile_max_shapes(max_string) + .with_fp16(fp16) + .with_timing_cache(true); + ( + OrtEP::Trt(device_id), + ExecutionProviderDispatch::from(trt_provider), + ) + } else { + println!("> TensorRT is not available! 
Try using CUDA..."); + Self::set_ep_cuda(device_id) + } + } + + pub fn fetch_from_metadata(&self, key: &str) -> Option { + // fetch value from onnx model file by key + match self.session.metadata() { + Err(_) => None, + Ok(metadata) => match metadata.custom(key) { + Err(_) => None, + Ok(value) => value, + }, + } + } + + pub fn run(&self, xs: Array, profile: bool) -> Result>> { + // ORT inference + match self.dtype() { + TensorElementType::Float16 => self.run_fp16(xs, profile), + TensorElementType::Float32 => self.run_fp32(xs, profile), + _ => todo!(), + } + } + + pub fn run_fp16(&self, xs: Array, profile: bool) -> Result>> { + // f32->f16 + let t = std::time::Instant::now(); + let xs = xs.mapv(f16::from_f32); + if profile { + println!("[ORT f32->f16]: {:?}", t.elapsed()); + } + + // h2d + let t = std::time::Instant::now(); + let xs = CowArray::from(xs); + if profile { + println!("[ORT H2D]: {:?}", t.elapsed()); + } + + // run + let t = std::time::Instant::now(); + let ys = self.session.run(ort::inputs![xs.view()]?)?; + if profile { + println!("[ORT Inference]: {:?}", t.elapsed()); + } + + // d2h + Ok(ys + .iter() + .map(|(_k, v)| { + // d2h + let t = std::time::Instant::now(); + let v = v.try_extract_tensor().unwrap(); + //let v = v.try_extract::<_>().unwrap().view().clone().into_owned(); + if profile { + println!("[ORT D2H]: {:?}", t.elapsed()); + } + + // f16->f32 + let t_ = std::time::Instant::now(); + let v = v.mapv(f16::to_f32); + if profile { + println!("[ORT f16->f32]: {:?}", t_.elapsed()); + } + v + }) + .collect::>>()) + } + + pub fn run_fp32(&self, xs: Array, profile: bool) -> Result>> { + // h2d + let t = std::time::Instant::now(); + let xs = CowArray::from(xs); + if profile { + println!("[ORT H2D]: {:?}", t.elapsed()); + } + + // run + let t = std::time::Instant::now(); + let ys = self.session.run(ort::inputs![xs.view()]?)?; + if profile { + println!("[ORT Inference]: {:?}", t.elapsed()); + } + + // d2h + Ok(ys + .iter() + .map(|(_k, v)| { + let t = std::time::Instant::now(); + let v = v.try_extract_tensor::().unwrap().into_owned(); + //let x = x.try_extract::<_>().unwrap().view().clone().into_owned(); + if profile { + println!("[ORT D2H]: {:?}", t.elapsed()); + } + v + }) + .collect::>>()) + } + + pub fn output_shapes(&self) -> Vec> { + let mut shapes = Vec::new(); + for output in &self.session.outputs { + if let ValueType::Tensor { ty: _, dimensions } = &output.output_type { + let shape = dimensions.clone(); + shapes.push(shape); + } else { + panic!("not support data format, {} - {}", file!(), line!()); + } + } + shapes + } + + pub fn output_dtypes(&self) -> Vec { + let mut dtypes = Vec::new(); + for output in &self.session.outputs { + if let ValueType::Tensor { ty, dimensions: _ } = &output.output_type { + dtypes.push(ty.clone()); + } else { + panic!("not support data format, {} - {}", file!(), line!()); + } + } + dtypes + } + + pub fn input_shapes(&self) -> &Vec> { + &self.inputs.shapes + } + + pub fn input_names(&self) -> &Vec { + &self.inputs.names + } + + pub fn input_dtypes(&self) -> &Vec { + &self.inputs.dtypes + } + + pub fn dtype(&self) -> TensorElementType { + self.input_dtypes()[0] + } + + pub fn height(&self) -> u32 { + self.inputs.sizes[0][0] + } + + pub fn width(&self) -> u32 { + self.inputs.sizes[0][1] + } + + pub fn is_height_dynamic(&self) -> bool { + self.input_shapes()[0][2] == -1 + } + + pub fn is_width_dynamic(&self) -> bool { + self.input_shapes()[0][3] == -1 + } + + pub fn batch(&self) -> u32 { + self.batch.opt + } + + pub fn is_batch_dynamic(&self) -> 
bool { + self.input_shapes()[0][0] == -1 + } + + pub fn ep(&self) -> &OrtEP { + &self.ep + } + + pub fn task(&self) -> YOLOTask { + self.task.clone() + } + + pub fn names(&self) -> Option> { + // class names, metadata parsing + // String format: `{0: 'person', 1: 'bicycle', 2: 'sports ball', ..., 27: "yellow_lady's_slipper"}` + match self.fetch_from_metadata("names") { + Some(names) => { + let re = Regex::new(r#"(['"])([-()\w '"]+)(['"])"#).unwrap(); + let mut names_ = vec![]; + for (_, [_, name, _]) in re.captures_iter(&names).map(|x| x.extract()) { + names_.push(name.to_string()); + } + Some(names_) + } + None => None, + } + } + + pub fn nk(&self) -> Option { + // num_keypoints, metadata parsing: String `nk` in onnx model: `[17, 3]` + match self.fetch_from_metadata("kpt_shape") { + None => None, + Some(kpt_string) => { + let re = Regex::new(r"([0-9]+), ([0-9]+)").unwrap(); + let caps = re.captures(&kpt_string).unwrap(); + Some(caps.get(1).unwrap().as_str().parse::().unwrap()) + } + } + } + + pub fn nc(&self) -> Option { + // num_classes + match self.names() { + // by names + Some(names) => Some(names.len() as u32), + None => match self.task() { + // by task calculation + YOLOTask::Classify => Some(self.output_shapes()[0][1] as u32), + YOLOTask::Detect => { + if self.output_shapes()[0][1] == -1 { + None + } else { + // cxywhclss + Some(self.output_shapes()[0][1] as u32 - 4) + } + } + YOLOTask::Pose => { + match self.nk() { + None => None, + Some(nk) => { + if self.output_shapes()[0][1] == -1 { + None + } else { + // cxywhclss3*kpt + Some(self.output_shapes()[0][1] as u32 - 4 - 3 * nk) + } + } + } + } + YOLOTask::Segment => { + if self.output_shapes()[0][1] == -1 { + None + } else { + // cxywhclssnm + Some((self.output_shapes()[0][1] - self.output_shapes()[1][1]) as u32 - 4) + } + } + }, + } + } + + pub fn nm(&self) -> Option { + // num_masks + match self.task() { + YOLOTask::Segment => Some(self.output_shapes()[1][1] as u32), + _ => None, + } + } + + pub fn na(&self) -> Option { + // num_anchors + match self.task() { + YOLOTask::Segment | YOLOTask::Detect | YOLOTask::Pose => { + if self.output_shapes()[0][2] == -1 { + None + } else { + Some(self.output_shapes()[0][2] as u32) + } + } + _ => None, + } + } + + pub fn author(&self) -> Option { + self.fetch_from_metadata("author") + } + + pub fn version(&self) -> Option { + self.fetch_from_metadata("version") + } +} diff --git a/examples/YOLOv8-ONNXRuntime-Rust/src/yolo_result.rs b/examples/YOLOv8-ONNXRuntime-Rust/src/yolo_result.rs new file mode 100644 index 0000000000000000000000000000000000000000..2fcc6d860274174e107addb71df6d4f104486439 --- /dev/null +++ b/examples/YOLOv8-ONNXRuntime-Rust/src/yolo_result.rs @@ -0,0 +1,235 @@ +use ndarray::{Array, Axis, IxDyn}; + +#[derive(Clone, PartialEq, Default)] +pub struct YOLOResult { + // YOLO tasks results of an image + pub probs: Option, + pub bboxes: Option>, + pub keypoints: Option>>, + pub masks: Option>>, +} + +impl std::fmt::Debug for YOLOResult { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("YOLOResult") + .field( + "Probs(top5)", + &format_args!("{:?}", self.probs().map(|probs| probs.topk(5))), + ) + .field("Bboxes", &self.bboxes) + .field("Keypoints", &self.keypoints) + .field( + "Masks", + &format_args!("{:?}", self.masks().map(|masks| masks.len())), + ) + .finish() + } +} + +impl YOLOResult { + pub fn new( + probs: Option, + bboxes: Option>, + keypoints: Option>>, + masks: Option>>, + ) -> Self { + Self { + probs, + bboxes, + keypoints, + masks, 
+ } + } + + pub fn probs(&self) -> Option<&Embedding> { + self.probs.as_ref() + } + + pub fn keypoints(&self) -> Option<&Vec>> { + self.keypoints.as_ref() + } + + pub fn masks(&self) -> Option<&Vec>> { + self.masks.as_ref() + } + + pub fn bboxes(&self) -> Option<&Vec> { + self.bboxes.as_ref() + } + + pub fn bboxes_mut(&mut self) -> Option<&mut Vec> { + self.bboxes.as_mut() + } +} + +#[derive(Debug, PartialEq, Clone, Default)] +pub struct Point2 { + // A point2d with x, y, conf + x: f32, + y: f32, + confidence: f32, +} + +impl Point2 { + pub fn new_with_conf(x: f32, y: f32, confidence: f32) -> Self { + Self { x, y, confidence } + } + + pub fn new(x: f32, y: f32) -> Self { + Self { + x, + y, + ..Default::default() + } + } + + pub fn x(&self) -> f32 { + self.x + } + + pub fn y(&self) -> f32 { + self.y + } + + pub fn confidence(&self) -> f32 { + self.confidence + } +} + +#[derive(Debug, Clone, PartialEq, Default)] +pub struct Embedding { + // An float32 n-dims tensor + data: Array, +} + +impl Embedding { + pub fn new(data: Array) -> Self { + Self { data } + } + + pub fn data(&self) -> &Array { + &self.data + } + + pub fn topk(&self, k: usize) -> Vec<(usize, f32)> { + let mut probs = self + .data + .iter() + .enumerate() + .map(|(a, b)| (a, *b)) + .collect::>(); + probs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); + let mut topk = Vec::new(); + for &(id, confidence) in probs.iter().take(k) { + topk.push((id, confidence)); + } + topk + } + + pub fn norm(&self) -> Array { + let std_ = self.data.mapv(|x| x * x).sum_axis(Axis(0)).mapv(f32::sqrt); + self.data.clone() / std_ + } + + pub fn top1(&self) -> (usize, f32) { + self.topk(1)[0] + } +} + +#[derive(Debug, Clone, PartialEq, Default)] +pub struct Bbox { + // a bounding box around an object + xmin: f32, + ymin: f32, + width: f32, + height: f32, + id: usize, + confidence: f32, +} + +impl Bbox { + pub fn new_from_xywh(xmin: f32, ymin: f32, width: f32, height: f32) -> Self { + Self { + xmin, + ymin, + width, + height, + ..Default::default() + } + } + + pub fn new(xmin: f32, ymin: f32, width: f32, height: f32, id: usize, confidence: f32) -> Self { + Self { + xmin, + ymin, + width, + height, + id, + confidence, + } + } + + pub fn width(&self) -> f32 { + self.width + } + + pub fn height(&self) -> f32 { + self.height + } + + pub fn xmin(&self) -> f32 { + self.xmin + } + + pub fn ymin(&self) -> f32 { + self.ymin + } + + pub fn xmax(&self) -> f32 { + self.xmin + self.width + } + + pub fn ymax(&self) -> f32 { + self.ymin + self.height + } + + pub fn tl(&self) -> Point2 { + Point2::new(self.xmin, self.ymin) + } + + pub fn br(&self) -> Point2 { + Point2::new(self.xmax(), self.ymax()) + } + + pub fn cxcy(&self) -> Point2 { + Point2::new(self.xmin + self.width / 2., self.ymin + self.height / 2.) + } + + pub fn id(&self) -> usize { + self.id + } + + pub fn confidence(&self) -> f32 { + self.confidence + } + + pub fn area(&self) -> f32 { + self.width * self.height + } + + pub fn intersection_area(&self, another: &Bbox) -> f32 { + let l = self.xmin.max(another.xmin); + let r = (self.xmin + self.width).min(another.xmin + another.width); + let t = self.ymin.max(another.ymin); + let b = (self.ymin + self.height).min(another.ymin + another.height); + (r - l + 1.).max(0.) * (b - t + 1.).max(0.) 
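        // Note: the `+ 1.0` offsets treat box edges as inclusive pixel coordinates, while `area()`
        // above uses plain `width * height`, so the resulting `iou()` is a pixel-convention
        // approximation rather than an exact continuous-coordinate IoU.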
+ } + + pub fn union(&self, another: &Bbox) -> f32 { + self.area() + another.area() - self.intersection_area(another) + } + + pub fn iou(&self, another: &Bbox) -> f32 { + self.intersection_area(another) / self.union(another) + } +} diff --git a/examples/YOLOv8-ONNXRuntime/README.md b/examples/YOLOv8-ONNXRuntime/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b206b2e2be97aec2255915c13ef3358ce450a4ba --- /dev/null +++ b/examples/YOLOv8-ONNXRuntime/README.md @@ -0,0 +1,43 @@ +# YOLOv8 - ONNX Runtime + +This project implements YOLOv8 using ONNX Runtime. + +## Installation + +To run this project, you need to install the required dependencies. The following instructions will guide you through the installation process. + +### Installing Required Dependencies + +You can install the required dependencies by running the following command: + +```bash +pip install -r requirements.txt +``` + +### Installing `onnxruntime-gpu` + +If you have an NVIDIA GPU and want to leverage GPU acceleration, you can install the onnxruntime-gpu package using the following command: + +```bash +pip install onnxruntime-gpu +``` + +Note: Make sure you have the appropriate GPU drivers installed on your system. + +### Installing `onnxruntime` (CPU version) + +If you don't have an NVIDIA GPU or prefer to use the CPU version of onnxruntime, you can install the onnxruntime package using the following command: + +```bash +pip install onnxruntime +``` + +### Usage + +After successfully installing the required packages, you can run the YOLOv8 implementation using the following command: + +```bash +python main.py --model yolov8n.onnx --img image.jpg --conf-thres 0.5 --iou-thres 0.5 +``` + +Make sure to replace yolov8n.onnx with the path to your YOLOv8 ONNX model file, image.jpg with the path to your input image, and adjust the confidence threshold (conf-thres) and IoU threshold (iou-thres) values as needed. diff --git a/examples/YOLOv8-ONNXRuntime/main.py b/examples/YOLOv8-ONNXRuntime/main.py new file mode 100644 index 0000000000000000000000000000000000000000..d1e18a404c1812606a143f3d38c7ebefab915d88 --- /dev/null +++ b/examples/YOLOv8-ONNXRuntime/main.py @@ -0,0 +1,229 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import argparse + +import cv2 +import numpy as np +import onnxruntime as ort +import torch + +from ultralytics.utils import ASSETS, yaml_load +from ultralytics.utils.checks import check_requirements, check_yaml + + +class YOLOv8: + """YOLOv8 object detection model class for handling inference and visualization.""" + + def __init__(self, onnx_model, input_image, confidence_thres, iou_thres): + """ + Initializes an instance of the YOLOv8 class. + + Args: + onnx_model: Path to the ONNX model. + input_image: Path to the input image. + confidence_thres: Confidence threshold for filtering detections. + iou_thres: IoU (Intersection over Union) threshold for non-maximum suppression. + """ + self.onnx_model = onnx_model + self.input_image = input_image + self.confidence_thres = confidence_thres + self.iou_thres = iou_thres + + # Load the class names from the COCO dataset + self.classes = yaml_load(check_yaml("coco8.yaml"))["names"] + + # Generate a color palette for the classes + self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3)) + + def draw_detections(self, img, box, score, class_id): + """ + Draws bounding boxes and labels on the input image based on the detected objects. + + Args: + img: The input image to draw detections on. 
+ box: Detected bounding box. + score: Corresponding detection score. + class_id: Class ID for the detected object. + + Returns: + None + """ + # Extract the coordinates of the bounding box + x1, y1, w, h = box + + # Retrieve the color for the class ID + color = self.color_palette[class_id] + + # Draw the bounding box on the image + cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2) + + # Create the label text with class name and score + label = f"{self.classes[class_id]}: {score:.2f}" + + # Calculate the dimensions of the label text + (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1) + + # Calculate the position of the label text + label_x = x1 + label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10 + + # Draw a filled rectangle as the background for the label text + cv2.rectangle( + img, (label_x, label_y - label_height), (label_x + label_width, label_y + label_height), color, cv2.FILLED + ) + + # Draw the label text on the image + cv2.putText(img, label, (label_x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA) + + def preprocess(self): + """ + Preprocesses the input image before performing inference. + + Returns: + image_data: Preprocessed image data ready for inference. + """ + # Read the input image using OpenCV + self.img = cv2.imread(self.input_image) + + # Get the height and width of the input image + self.img_height, self.img_width = self.img.shape[:2] + + # Convert the image color space from BGR to RGB + img = cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB) + + # Resize the image to match the input shape + img = cv2.resize(img, (self.input_width, self.input_height)) + + # Normalize the image data by dividing it by 255.0 + image_data = np.array(img) / 255.0 + + # Transpose the image to have the channel dimension as the first dimension + image_data = np.transpose(image_data, (2, 0, 1)) # Channel first + + # Expand the dimensions of the image data to match the expected input shape + image_data = np.expand_dims(image_data, axis=0).astype(np.float32) + + # Return the preprocessed image data + return image_data + + def postprocess(self, input_image, output): + """ + Performs post-processing on the model's output to extract bounding boxes, scores, and class IDs. + + Args: + input_image (numpy.ndarray): The input image. + output (numpy.ndarray): The output of the model. + + Returns: + numpy.ndarray: The input image with detections drawn on it. 
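        Note:
            For a standard YOLOv8 detection export the raw output has shape (1, 4 + num_classes, num_anchors);
            after the squeeze/transpose below, each row holds [cx, cy, w, h, class scores...] in
            model-input (e.g. 640x640) coordinates, which are then rescaled to the original image size.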
+ """ + # Transpose and squeeze the output to match the expected shape + outputs = np.transpose(np.squeeze(output[0])) + + # Get the number of rows in the outputs array + rows = outputs.shape[0] + + # Lists to store the bounding boxes, scores, and class IDs of the detections + boxes = [] + scores = [] + class_ids = [] + + # Calculate the scaling factors for the bounding box coordinates + x_factor = self.img_width / self.input_width + y_factor = self.img_height / self.input_height + + # Iterate over each row in the outputs array + for i in range(rows): + # Extract the class scores from the current row + classes_scores = outputs[i][4:] + + # Find the maximum score among the class scores + max_score = np.amax(classes_scores) + + # If the maximum score is above the confidence threshold + if max_score >= self.confidence_thres: + # Get the class ID with the highest score + class_id = np.argmax(classes_scores) + + # Extract the bounding box coordinates from the current row + x, y, w, h = outputs[i][0], outputs[i][1], outputs[i][2], outputs[i][3] + + # Calculate the scaled coordinates of the bounding box + left = int((x - w / 2) * x_factor) + top = int((y - h / 2) * y_factor) + width = int(w * x_factor) + height = int(h * y_factor) + + # Add the class ID, score, and box coordinates to the respective lists + class_ids.append(class_id) + scores.append(max_score) + boxes.append([left, top, width, height]) + + # Apply non-maximum suppression to filter out overlapping bounding boxes + indices = cv2.dnn.NMSBoxes(boxes, scores, self.confidence_thres, self.iou_thres) + + # Iterate over the selected indices after non-maximum suppression + for i in indices: + # Get the box, score, and class ID corresponding to the index + box = boxes[i] + score = scores[i] + class_id = class_ids[i] + + # Draw the detection on the input image + self.draw_detections(input_image, box, score, class_id) + + # Return the modified input image + return input_image + + def main(self): + """ + Performs inference using an ONNX model and returns the output image with drawn detections. + + Returns: + output_img: The output image with drawn detections. + """ + # Create an inference session using the ONNX model and specify execution providers + session = ort.InferenceSession(self.onnx_model, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]) + + # Get the model inputs + model_inputs = session.get_inputs() + + # Store the shape of the input for later use + input_shape = model_inputs[0].shape + self.input_width = input_shape[2] + self.input_height = input_shape[3] + + # Preprocess the image data + img_data = self.preprocess() + + # Run inference using the preprocessed image data + outputs = session.run(None, {model_inputs[0].name: img_data}) + + # Perform post-processing on the outputs to obtain output image. 
+ return self.postprocess(self.img, outputs) # output image + + +if __name__ == "__main__": + # Create an argument parser to handle command-line arguments + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, default="yolov8n.onnx", help="Input your ONNX model.") + parser.add_argument("--img", type=str, default=str(ASSETS / "bus.jpg"), help="Path to input image.") + parser.add_argument("--conf-thres", type=float, default=0.5, help="Confidence threshold") + parser.add_argument("--iou-thres", type=float, default=0.5, help="NMS IoU threshold") + args = parser.parse_args() + + # Check the requirements and select the appropriate backend (CPU or GPU) + check_requirements("onnxruntime-gpu" if torch.cuda.is_available() else "onnxruntime") + + # Create an instance of the YOLOv8 class with the specified arguments + detection = YOLOv8(args.model, args.img, args.conf_thres, args.iou_thres) + + # Perform object detection and obtain the output image + output_image = detection.main() + + # Display the output image in a window + cv2.namedWindow("Output", cv2.WINDOW_NORMAL) + cv2.imshow("Output", output_image) + + # Wait for a key press to exit + cv2.waitKey(0) diff --git a/examples/YOLOv8-OpenCV-ONNX-Python/README.md b/examples/YOLOv8-OpenCV-ONNX-Python/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c9076fa55df4c874ac2a8f3705d11917c667d82b --- /dev/null +++ b/examples/YOLOv8-OpenCV-ONNX-Python/README.md @@ -0,0 +1,19 @@ +# YOLOv8 - OpenCV + +Implementation YOLOv8 on OpenCV using ONNX Format. + +Just simply clone and run + +```bash +pip install -r requirements.txt +python main.py --model yolov8n.onnx --img image.jpg +``` + +If you start from scratch: + +```bash +pip install ultralytics +yolo export model=yolov8n.pt imgsz=640 format=onnx opset=12 +``` + +_\*Make sure to include "opset=12"_ diff --git a/examples/YOLOv8-OpenCV-ONNX-Python/main.py b/examples/YOLOv8-OpenCV-ONNX-Python/main.py new file mode 100644 index 0000000000000000000000000000000000000000..e9e095dd46231711b54d5ec734af888f15ea742e --- /dev/null +++ b/examples/YOLOv8-OpenCV-ONNX-Python/main.py @@ -0,0 +1,130 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import argparse + +import cv2.dnn +import numpy as np + +from ultralytics.utils import ASSETS, yaml_load +from ultralytics.utils.checks import check_yaml + +CLASSES = yaml_load(check_yaml("coco8.yaml"))["names"] +colors = np.random.uniform(0, 255, size=(len(CLASSES), 3)) + + +def draw_bounding_box(img, class_id, confidence, x, y, x_plus_w, y_plus_h): + """ + Draws bounding boxes on the input image based on the provided arguments. + + Args: + img (numpy.ndarray): The input image to draw the bounding box on. + class_id (int): Class ID of the detected object. + confidence (float): Confidence score of the detected object. + x (int): X-coordinate of the top-left corner of the bounding box. + y (int): Y-coordinate of the top-left corner of the bounding box. + x_plus_w (int): X-coordinate of the bottom-right corner of the bounding box. + y_plus_h (int): Y-coordinate of the bottom-right corner of the bounding box. + """ + label = f"{CLASSES[class_id]} ({confidence:.2f})" + color = colors[class_id] + cv2.rectangle(img, (x, y), (x_plus_w, y_plus_h), color, 2) + cv2.putText(img, label, (x - 10, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2) + + +def main(onnx_model, input_image): + """ + Main function to load ONNX model, perform inference, draw bounding boxes, and display the output image. 
+ + Args: + onnx_model (str): Path to the ONNX model. + input_image (str): Path to the input image. + + Returns: + list: List of dictionaries containing detection information such as class_id, class_name, confidence, etc. + """ + # Load the ONNX model + model: cv2.dnn.Net = cv2.dnn.readNetFromONNX(onnx_model) + + # Read the input image + original_image: np.ndarray = cv2.imread(input_image) + [height, width, _] = original_image.shape + + # Prepare a square image for inference + length = max((height, width)) + image = np.zeros((length, length, 3), np.uint8) + image[0:height, 0:width] = original_image + + # Calculate scale factor + scale = length / 640 + + # Preprocess the image and prepare blob for model + blob = cv2.dnn.blobFromImage(image, scalefactor=1 / 255, size=(640, 640), swapRB=True) + model.setInput(blob) + + # Perform inference + outputs = model.forward() + + # Prepare output array + outputs = np.array([cv2.transpose(outputs[0])]) + rows = outputs.shape[1] + + boxes = [] + scores = [] + class_ids = [] + + # Iterate through output to collect bounding boxes, confidence scores, and class IDs + for i in range(rows): + classes_scores = outputs[0][i][4:] + (minScore, maxScore, minClassLoc, (x, maxClassIndex)) = cv2.minMaxLoc(classes_scores) + if maxScore >= 0.25: + box = [ + outputs[0][i][0] - (0.5 * outputs[0][i][2]), + outputs[0][i][1] - (0.5 * outputs[0][i][3]), + outputs[0][i][2], + outputs[0][i][3], + ] + boxes.append(box) + scores.append(maxScore) + class_ids.append(maxClassIndex) + + # Apply NMS (Non-maximum suppression) + result_boxes = cv2.dnn.NMSBoxes(boxes, scores, 0.25, 0.45, 0.5) + + detections = [] + + # Iterate through NMS results to draw bounding boxes and labels + for i in range(len(result_boxes)): + index = result_boxes[i] + box = boxes[index] + detection = { + "class_id": class_ids[index], + "class_name": CLASSES[class_ids[index]], + "confidence": scores[index], + "box": box, + "scale": scale, + } + detections.append(detection) + draw_bounding_box( + original_image, + class_ids[index], + scores[index], + round(box[0] * scale), + round(box[1] * scale), + round((box[0] + box[2]) * scale), + round((box[1] + box[3]) * scale), + ) + + # Display the image with bounding boxes + cv2.imshow("image", original_image) + cv2.waitKey(0) + cv2.destroyAllWindows() + + return detections + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", default="yolov8n.onnx", help="Input your ONNX model.") + parser.add_argument("--img", default=str(ASSETS / "bus.jpg"), help="Path to input image.") + args = parser.parse_args() + main(args.model, args.img) diff --git a/examples/YOLOv8-OpenVINO-CPP-Inference/CMakeLists.txt b/examples/YOLOv8-OpenVINO-CPP-Inference/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d34ea96f3e8c817171c12c0693ce72ea333104b6 --- /dev/null +++ b/examples/YOLOv8-OpenVINO-CPP-Inference/CMakeLists.txt @@ -0,0 +1,21 @@ +cmake_minimum_required(VERSION 3.12) +project(yolov8_openvino_example) + +set(CMAKE_CXX_STANDARD 14) + +find_package(OpenCV REQUIRED) + +include_directories( + ${OpenCV_INCLUDE_DIRS} + /path/to/intel/openvino/runtime/include +) + +add_executable(detect + main.cc + inference.cc +) + +target_link_libraries(detect + ${OpenCV_LIBS} + /path/to/intel/openvino/runtime/lib/intel64/libopenvino.so +) diff --git a/examples/YOLOv8-OpenVINO-CPP-Inference/README.md b/examples/YOLOv8-OpenVINO-CPP-Inference/README.md new file mode 100644 index 
0000000000000000000000000000000000000000..6c6c794dea337dccae6f1aa889c7e167b913dda4 --- /dev/null +++ b/examples/YOLOv8-OpenVINO-CPP-Inference/README.md @@ -0,0 +1,69 @@ +# YOLOv8 OpenVINO Inference in C++ 🦾 + +Welcome to the YOLOv8 OpenVINO Inference example in C++! This guide will help you get started with leveraging the powerful YOLOv8 models using OpenVINO and OpenCV API in your C++ projects. Whether you're looking to enhance performance or add flexibility to your applications, this example has got you covered. + +## 🌟 Features + +- 🚀 **Model Format Support**: Compatible with `ONNX` and `OpenVINO IR` formats. +- ⚡ **Precision Options**: Run models in `FP32`, `FP16`, and `INT8` precisions. +- 🔄 **Dynamic Shape Loading**: Easily handle models with dynamic input shapes. + +## 📋 Dependencies + +To ensure smooth execution, please make sure you have the following dependencies installed: + +| Dependency | Version | +| ---------- | -------- | +| OpenVINO | >=2023.3 | +| OpenCV | >=4.5.0 | +| C++ | >=14 | +| CMake | >=3.12.0 | + +## ⚙️ Build Instructions + +Follow these steps to build the project: + +1. Clone the repository: + + ```bash + git clone https://github.com/ultralytics/ultralytics.git + cd ultralytics/YOLOv8-OpenVINO-CPP-Inference + ``` + +2. Create a build directory and compile the project: + ```bash + mkdir build + cd build + cmake .. + make + ``` + +## 🛠️ Usage + +Once built, you can run inference on an image using the following command: + +```bash +./detect +``` + +## 🔄 Exporting YOLOv8 Models + +To use your YOLOv8 model with OpenVINO, you need to export it first. Use the command below to export the model: + +```bash +yolo export model=yolov8s.pt imgsz=640 format=openvino +``` + +## 📸 Screenshots + +### Running Using OpenVINO Model + +![Running OpenVINO Model](https://github.com/ultralytics/ultralytics/assets/76827698/2d7cf201-3def-4357-824c-12446ccf85a9) + +### Running Using ONNX Model + +![Running ONNX Model](https://github.com/ultralytics/ultralytics/assets/76827698/9b90031c-cc81-4cfb-8b34-c619e09035a7) + +## ❤️ Contributions + +We hope this example helps you integrate YOLOv8 with OpenVINO and OpenCV into your C++ projects effortlessly. Happy coding! 🚀 diff --git a/examples/YOLOv8-OpenVINO-CPP-Inference/inference.cc b/examples/YOLOv8-OpenVINO-CPP-Inference/inference.cc new file mode 100644 index 0000000000000000000000000000000000000000..dbabd2a98929678af8da038cea3624df6a3db8a6 --- /dev/null +++ b/examples/YOLOv8-OpenVINO-CPP-Inference/inference.cc @@ -0,0 +1,175 @@ +#include "inference.h" + +#include +#include +#include + +namespace yolo { + +// Constructor to initialize the model with default input shape +Inference::Inference(const std::string &model_path, const float &model_confidence_threshold, const float &model_NMS_threshold) { + model_input_shape_ = cv::Size(640, 640); // Set the default size for models with dynamic shapes to prevent errors. 
+ model_confidence_threshold_ = model_confidence_threshold; + model_NMS_threshold_ = model_NMS_threshold; + InitializeModel(model_path); +} + +// Constructor to initialize the model with specified input shape +Inference::Inference(const std::string &model_path, const cv::Size model_input_shape, const float &model_confidence_threshold, const float &model_NMS_threshold) { + model_input_shape_ = model_input_shape; + model_confidence_threshold_ = model_confidence_threshold; + model_NMS_threshold_ = model_NMS_threshold; + InitializeModel(model_path); +} + +void Inference::InitializeModel(const std::string &model_path) { + ov::Core core; // OpenVINO core object + std::shared_ptr model = core.read_model(model_path); // Read the model from file + + // If the model has dynamic shapes, reshape it to the specified input shape + if (model->is_dynamic()) { + model->reshape({1, 3, static_cast(model_input_shape_.height), static_cast(model_input_shape_.width)}); + } + + // Preprocessing setup for the model + ov::preprocess::PrePostProcessor ppp = ov::preprocess::PrePostProcessor(model); + ppp.input().tensor().set_element_type(ov::element::u8).set_layout("NHWC").set_color_format(ov::preprocess::ColorFormat::BGR); + ppp.input().preprocess().convert_element_type(ov::element::f32).convert_color(ov::preprocess::ColorFormat::RGB).scale({255, 255, 255}); + ppp.input().model().set_layout("NCHW"); + ppp.output().tensor().set_element_type(ov::element::f32); + model = ppp.build(); // Build the preprocessed model + + // Compile the model for inference + compiled_model_ = core.compile_model(model, "AUTO"); + inference_request_ = compiled_model_.create_infer_request(); // Create inference request + + short width, height; + + // Get input shape from the model + const std::vector> inputs = model->inputs(); + const ov::Shape input_shape = inputs[0].get_shape(); + height = input_shape[1]; + width = input_shape[2]; + model_input_shape_ = cv::Size2f(width, height); + + // Get output shape from the model + const std::vector> outputs = model->outputs(); + const ov::Shape output_shape = outputs[0].get_shape(); + height = output_shape[1]; + width = output_shape[2]; + model_output_shape_ = cv::Size(width, height); +} + +// Method to run inference on an input frame +void Inference::RunInference(cv::Mat &frame) { + Preprocessing(frame); // Preprocess the input frame + inference_request_.infer(); // Run inference + PostProcessing(frame); // Postprocess the inference results +} + +// Method to preprocess the input frame +void Inference::Preprocessing(const cv::Mat &frame) { + cv::Mat resized_frame; + cv::resize(frame, resized_frame, model_input_shape_, 0, 0, cv::INTER_AREA); // Resize the frame to match the model input shape + + // Calculate scaling factor + scale_factor_.x = static_cast(frame.cols / model_input_shape_.width); + scale_factor_.y = static_cast(frame.rows / model_input_shape_.height); + + float *input_data = (float *)resized_frame.data; // Get pointer to resized frame data + const ov::Tensor input_tensor = ov::Tensor(compiled_model_.input().get_element_type(), compiled_model_.input().get_shape(), input_data); // Create input tensor + inference_request_.set_input_tensor(input_tensor); // Set input tensor for inference +} + +// Method to postprocess the inference results +void Inference::PostProcessing(cv::Mat &frame) { + std::vector class_list; + std::vector confidence_list; + std::vector box_list; + + // Get the output tensor from the inference request + const float *detections = 
inference_request_.get_output_tensor().data(); + const cv::Mat detection_outputs(model_output_shape_, CV_32F, (float *)detections); // Create OpenCV matrix from output tensor + + // Iterate over detections and collect class IDs, confidence scores, and bounding boxes + for (int i = 0; i < detection_outputs.cols; ++i) { + const cv::Mat classes_scores = detection_outputs.col(i).rowRange(4, detection_outputs.rows); + + cv::Point class_id; + double score; + cv::minMaxLoc(classes_scores, nullptr, &score, nullptr, &class_id); // Find the class with the highest score + + // Check if the detection meets the confidence threshold + if (score > model_confidence_threshold_) { + class_list.push_back(class_id.y); + confidence_list.push_back(score); + + const float x = detection_outputs.at(0, i); + const float y = detection_outputs.at(1, i); + const float w = detection_outputs.at(2, i); + const float h = detection_outputs.at(3, i); + + cv::Rect box; + box.x = static_cast(x); + box.y = static_cast(y); + box.width = static_cast(w); + box.height = static_cast(h); + box_list.push_back(box); + } + } + + // Apply Non-Maximum Suppression (NMS) to filter overlapping bounding boxes + std::vector NMS_result; + cv::dnn::NMSBoxes(box_list, confidence_list, model_confidence_threshold_, model_NMS_threshold_, NMS_result); + + // Collect final detections after NMS + for (int i = 0; i < NMS_result.size(); ++i) { + Detection result; + const unsigned short id = NMS_result[i]; + + result.class_id = class_list[id]; + result.confidence = confidence_list[id]; + result.box = GetBoundingBox(box_list[id]); + + DrawDetectedObject(frame, result); + } +} + +// Method to get the bounding box in the correct scale +cv::Rect Inference::GetBoundingBox(const cv::Rect &src) const { + cv::Rect box = src; + box.x = (box.x - box.width / 2) * scale_factor_.x; + box.y = (box.y - box.height / 2) * scale_factor_.y; + box.width *= scale_factor_.x; + box.height *= scale_factor_.y; + return box; +} + +void Inference::DrawDetectedObject(cv::Mat &frame, const Detection &detection) const { + const cv::Rect &box = detection.box; + const float &confidence = detection.confidence; + const int &class_id = detection.class_id; + + // Generate a random color for the bounding box + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dis(120, 255); + const cv::Scalar &color = cv::Scalar(dis(gen), dis(gen), dis(gen)); + + // Draw the bounding box around the detected object + cv::rectangle(frame, cv::Point(box.x, box.y), cv::Point(box.x + box.width, box.y + box.height), color, 3); + + // Prepare the class label and confidence text + std::string classString = classes_[class_id] + std::to_string(confidence).substr(0, 4); + + // Get the size of the text box + cv::Size textSize = cv::getTextSize(classString, cv::FONT_HERSHEY_DUPLEX, 0.75, 2, 0); + cv::Rect textBox(box.x, box.y - 40, textSize.width + 10, textSize.height + 20); + + // Draw the text box + cv::rectangle(frame, textBox, color, cv::FILLED); + + // Put the class label and confidence text above the bounding box + cv::putText(frame, classString, cv::Point(box.x + 5, box.y - 10), cv::FONT_HERSHEY_DUPLEX, 0.75, cv::Scalar(0, 0, 0), 2, 0); +} +} // namespace yolo diff --git a/examples/YOLOv8-OpenVINO-CPP-Inference/inference.h b/examples/YOLOv8-OpenVINO-CPP-Inference/inference.h new file mode 100644 index 0000000000000000000000000000000000000000..7bcb20df8f259670ebe9729a44d1f2c8faa60bef --- /dev/null +++ b/examples/YOLOv8-OpenVINO-CPP-Inference/inference.h @@ -0,0 +1,59 @@ +#ifndef 
YOLO_INFERENCE_H_ +#define YOLO_INFERENCE_H_ + +#include +#include +#include +#include + +namespace yolo { + +struct Detection { + short class_id; + float confidence; + cv::Rect box; +}; + +class Inference { + public: + Inference() {} + // Constructor to initialize the model with default input shape + Inference(const std::string &model_path, const float &model_confidence_threshold, const float &model_NMS_threshold); + // Constructor to initialize the model with specified input shape + Inference(const std::string &model_path, const cv::Size model_input_shape, const float &model_confidence_threshold, const float &model_NMS_threshold); + + void RunInference(cv::Mat &frame); + + private: + void InitializeModel(const std::string &model_path); + void Preprocessing(const cv::Mat &frame); + void PostProcessing(cv::Mat &frame); + cv::Rect GetBoundingBox(const cv::Rect &src) const; + void DrawDetectedObject(cv::Mat &frame, const Detection &detections) const; + + cv::Point2f scale_factor_; // Scaling factor for the input frame + cv::Size2f model_input_shape_; // Input shape of the model + cv::Size model_output_shape_; // Output shape of the model + + ov::InferRequest inference_request_; // OpenVINO inference request + ov::CompiledModel compiled_model_; // OpenVINO compiled model + + float model_confidence_threshold_; // Confidence threshold for detections + float model_NMS_threshold_; // Non-Maximum Suppression threshold + + std::vector classes_ { + "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", + "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", + "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", + "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", + "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", + "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", + "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", + "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", + "scissors", "teddy bear", "hair drier", "toothbrush" + }; +}; + +} // namespace yolo + +#endif // YOLO_INFERENCE_H_ diff --git a/examples/YOLOv8-OpenVINO-CPP-Inference/main.cc b/examples/YOLOv8-OpenVINO-CPP-Inference/main.cc new file mode 100644 index 0000000000000000000000000000000000000000..2031af6d97e2377c4cdfb12964df7c15761e814e --- /dev/null +++ b/examples/YOLOv8-OpenVINO-CPP-Inference/main.cc @@ -0,0 +1,41 @@ +#include "inference.h" + +#include +#include + +int main(int argc, char **argv) { + // Check if the correct number of arguments is provided + if (argc != 3) { + std::cerr << "usage: " << argv[0] << " " << std::endl; + return 1; + } + + // Get the model and image paths from the command-line arguments + const std::string model_path = argv[1]; + const std::string image_path = argv[2]; + + // Read the input image + cv::Mat image = cv::imread(image_path); + + // Check if the image was successfully loaded + if (image.empty()) { + std::cerr << "ERROR: image is empty" << std::endl; + return 1; + } + + // Define the confidence and NMS thresholds + const float confidence_threshold = 0.5; + const float NMS_threshold = 0.5; + + // Initialize the YOLO inference with the specified model and parameters + yolo::Inference inference(model_path, 
cv::Size(640, 640), confidence_threshold, NMS_threshold); + + // Run inference on the input image + inference.RunInference(image); + + // Display the image with the detections + cv::imshow("image", image); + cv::waitKey(0); + + return 0; +} diff --git a/examples/YOLOv8-Region-Counter/readme.md b/examples/YOLOv8-Region-Counter/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..3ed0679910747ae9c68a770a2cef8e46263089dc --- /dev/null +++ b/examples/YOLOv8-Region-Counter/readme.md @@ -0,0 +1,128 @@ +# Regions Counting Using YOLOv8 (Inference on Video) + +> **Region Counter** is now part of **[Ultralytics Solutions](https://docs.ultralytics.com/solutions/)**, offering improved features and regular updates. Enjoy improved features and regular updates! + +🔗 **[Explore Object Counting in Regions Here](https://docs.ultralytics.com/guides/region-counting/)** + +> 🔔 **Notice:** + +> The GitHub example will remain available but **will no longer be actively maintained**. For the latest updates and improvements, please use the official [link](https://docs.ultralytics.com/guides/region-counting/). Thank you! + +Region counting is a method employed to tally the objects within a specified area, allowing for more sophisticated analyses when multiple regions are considered. These regions can be adjusted interactively using a Left Mouse Click, and the counting process occurs in real time. Regions can be adjusted to suit the user's preferences and requirements. + +
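Under the hood, counting reduces to a point-in-polygon test: each detection's bounding-box center is checked against every region polygon, and a region's counter is incremented on a hit (the script in this example does exactly that with `shapely`). Below is a minimal sketch of that check; the region coordinates and detections are made up for illustration:

```python
from shapely.geometry import Point, Polygon

# One illustrative counting region (a pentagon); coordinates are arbitrary pixel values
region = {
    "name": "example region",
    "polygon": Polygon([(50, 80), (250, 20), (450, 80), (400, 350), (100, 350)]),
    "counts": 0,
}

# Hypothetical bounding-box centers (cx, cy) produced by a detector for one frame
bbox_centers = [(120.0, 200.0), (600.0, 40.0)]

for cx, cy in bbox_centers:
    if region["polygon"].contains(Point(cx, cy)):
        region["counts"] += 1

print(region["name"], region["counts"])  # -> example region 1
```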
+  <!-- YOLOv8 region counting visual 1 | YOLOv8 region counting visual 2 -->
+ +## Table of Contents + +- [Step 1: Install the Required Libraries](#step-1-install-the-required-libraries) +- [Step 2: Run the Region Counting Using Ultralytics YOLOv8](#step-2-run-the-region-counting-using-ultralytics-yolov8) +- [Usage Options](#usage-options) +- [FAQ](#faq) + +## Step 1: Install the Required Libraries + +Clone the repository, install dependencies and `cd` to this local directory for commands in Step 2. + +```bash +# Clone ultralytics repo +git clone https://github.com/ultralytics/ultralytics + +# cd to local directory +cd ultralytics/examples/YOLOv8-Region-Counter +``` + +## Step 2: Run the Region Counting Using Ultralytics YOLOv8 + +Here are the basic commands for running the inference: + +### Note + +After the video begins playing, you can freely move the region anywhere within the video by simply clicking and dragging using the left mouse button. + +```bash +# If you want to save results +python yolov8_region_counter.py --source "path/to/video.mp4" --save-img --view-img + +# If you want to run model on CPU +python yolov8_region_counter.py --source "path/to/video.mp4" --save-img --view-img --device cpu + +# If you want to change model file +python yolov8_region_counter.py --source "path/to/video.mp4" --save-img --weights "path/to/model.pt" + +# If you want to detect specific class (first class and third class) +python yolov8_region_counter.py --source "path/to/video.mp4" --classes 0 2 --weights "path/to/model.pt" + +# If you don't want to save results +python yolov8_region_counter.py --source "path/to/video.mp4" --view-img +``` + +## Usage Options + +- `--source`: Specifies the path to the video file you want to run inference on. +- `--device`: Specifies the device `cpu` or `0` +- `--save-img`: Flag to save the detection results as images. +- `--weights`: Specifies a different YOLOv8 model file (e.g., `yolov8n.pt`, `yolov8s.pt`, `yolov8m.pt`, `yolov8l.pt`, `yolov8x.pt`). +- `--classes`: Specifies the class to be detected +- `--line-thickness`: Specifies the bounding box thickness +- `--region-thickness`: Specifies the region boxes thickness +- `--track-thickness`: Specifies the track line thickness + +## FAQ + +**1. What Does Region Counting Involve?** + +Region counting is a computational method utilized to ascertain the quantity of objects within a specific area in recorded video or real-time streams. This technique finds frequent application in image processing, computer vision, and pattern recognition, facilitating the analysis and segmentation of objects or features based on their spatial relationships. + +**2. Is Friendly Region Plotting Supported by the Region Counter?** + +The Region Counting offers the capability to create regions in various formats, such as polygons and rectangles. 
You have the flexibility to modify region attributes, including coordinates, colors, and other details, as demonstrated in the following code: + +```python +from shapely.geometry import Polygon + +counting_regions = [ + { + "name": "YOLOv8 Polygon Region", + "polygon": Polygon( + [(50, 80), (250, 20), (450, 80), (400, 350), (100, 350)] + ), # Polygon with five points (Pentagon) + "counts": 0, + "dragging": False, + "region_color": (255, 42, 4), # BGR Value + "text_color": (255, 255, 255), # Region Text Color + }, + { + "name": "YOLOv8 Rectangle Region", + "polygon": Polygon([(200, 250), (440, 250), (440, 550), (200, 550)]), # Rectangle with four points + "counts": 0, + "dragging": False, + "region_color": (37, 255, 225), # BGR Value + "text_color": (0, 0, 0), # Region Text Color + }, +] +``` + +**3. Why Combine Region Counting with YOLOv8?** + +YOLOv8 specializes in the detection and tracking of objects in video streams. Region counting complements this by enabling object counting within designated areas, making it a valuable application of YOLOv8. + +**4. How Can I Troubleshoot Issues?** + +To gain more insights during inference, you can include the `--debug` flag in your command: + +```bash +python yolov8_region_counter.py --source "path to video file" --debug +``` + +**5. Can I Employ Other YOLO Versions?** + +Certainly, you have the flexibility to specify different YOLO model weights using the `--weights` option. + +**6. Where Can I Access Additional Information?** + +For a comprehensive guide on using YOLOv8 with Object Tracking, please refer to [Multi-Object Tracking with Ultralytics YOLO](https://docs.ultralytics.com/modes/track/). diff --git a/examples/YOLOv8-Region-Counter/yolov8_region_counter.py b/examples/YOLOv8-Region-Counter/yolov8_region_counter.py new file mode 100644 index 0000000000000000000000000000000000000000..2ce28fb4505404837ba3b1bd2d57576de274a013 --- /dev/null +++ b/examples/YOLOv8-Region-Counter/yolov8_region_counter.py @@ -0,0 +1,253 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import argparse +from collections import defaultdict +from pathlib import Path + +import cv2 +import numpy as np +from shapely.geometry import Polygon +from shapely.geometry.point import Point + +from ultralytics import YOLO +from ultralytics.utils.files import increment_path +from ultralytics.utils.plotting import Annotator, colors + +track_history = defaultdict(list) + +current_region = None +counting_regions = [ + { + "name": "YOLOv8 Polygon Region", + "polygon": Polygon([(50, 80), (250, 20), (450, 80), (400, 350), (100, 350)]), # Polygon points + "counts": 0, + "dragging": False, + "region_color": (255, 42, 4), # BGR Value + "text_color": (255, 255, 255), # Region Text Color + }, + { + "name": "YOLOv8 Rectangle Region", + "polygon": Polygon([(200, 250), (440, 250), (440, 550), (200, 550)]), # Polygon points + "counts": 0, + "dragging": False, + "region_color": (37, 255, 225), # BGR Value + "text_color": (0, 0, 0), # Region Text Color + }, +] + + +def mouse_callback(event, x, y, flags, param): + """ + Handles mouse events for region manipulation. + + Args: + event (int): The mouse event type (e.g., cv2.EVENT_LBUTTONDOWN). + x (int): The x-coordinate of the mouse pointer. + y (int): The y-coordinate of the mouse pointer. + flags (int): Additional flags passed by OpenCV. + param: Additional parameters passed to the callback (not used in this function). + + Global Variables: + current_region (dict): A dictionary representing the current selected region. 
+ + Mouse Events: + - LBUTTONDOWN: Initiates dragging for the region containing the clicked point. + - MOUSEMOVE: Moves the selected region if dragging is active. + - LBUTTONUP: Ends dragging for the selected region. + + Notes: + - This function is intended to be used as a callback for OpenCV mouse events. + - Requires the existence of the 'counting_regions' list and the 'Polygon' class. + + Example: + >>> cv2.setMouseCallback(window_name, mouse_callback) + """ + global current_region + + # Mouse left button down event + if event == cv2.EVENT_LBUTTONDOWN: + for region in counting_regions: + if region["polygon"].contains(Point((x, y))): + current_region = region + current_region["dragging"] = True + current_region["offset_x"] = x + current_region["offset_y"] = y + + # Mouse move event + elif event == cv2.EVENT_MOUSEMOVE: + if current_region is not None and current_region["dragging"]: + dx = x - current_region["offset_x"] + dy = y - current_region["offset_y"] + current_region["polygon"] = Polygon( + [(p[0] + dx, p[1] + dy) for p in current_region["polygon"].exterior.coords] + ) + current_region["offset_x"] = x + current_region["offset_y"] = y + + # Mouse left button up event + elif event == cv2.EVENT_LBUTTONUP: + if current_region is not None and current_region["dragging"]: + current_region["dragging"] = False + + +def run( + weights="yolov8n.pt", + source=None, + device="cpu", + view_img=False, + save_img=False, + exist_ok=False, + classes=None, + line_thickness=2, + track_thickness=2, + region_thickness=2, +): + """ + Run Region counting on a video using YOLOv8 and ByteTrack. + + Supports movable region for real time counting inside specific area. + Supports multiple regions counting. + Regions can be Polygons or rectangle in shape + + Args: + weights (str): Model weights path. + source (str): Video file path. + device (str): processing device cpu, 0, 1 + view_img (bool): Show results. + save_img (bool): Save results. + exist_ok (bool): Overwrite existing files. + classes (list): classes to detect and track + line_thickness (int): Bounding box thickness. + track_thickness (int): Tracking line thickness + region_thickness (int): Region thickness. 
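    Note:
        Counts are recomputed for every frame: a detection is tallied when its bounding-box center
        lies inside a region polygon, and all region counts are reset to zero once the frame has
        been annotated, displayed and/or written.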
+ """ + vid_frame_count = 0 + + # Check source path + if not Path(source).exists(): + raise FileNotFoundError(f"Source path '{source}' does not exist.") + + # Setup Model + model = YOLO(f"{weights}") + model.to("cuda") if device == "0" else model.to("cpu") + + # Extract classes names + names = model.names + + # Video setup + videocapture = cv2.VideoCapture(source) + frame_width = int(videocapture.get(3)) + frame_height = int(videocapture.get(4)) + fps = int(videocapture.get(5)) + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + + # Output setup + save_dir = increment_path(Path("ultralytics_rc_output") / "exp", exist_ok) + save_dir.mkdir(parents=True, exist_ok=True) + video_writer = cv2.VideoWriter(str(save_dir / f"{Path(source).stem}.avi"), fourcc, fps, (frame_width, frame_height)) + + # Iterate over video frames + while videocapture.isOpened(): + success, frame = videocapture.read() + if not success: + break + vid_frame_count += 1 + + # Extract the results + results = model.track(frame, persist=True, classes=classes) + + if results[0].boxes.id is not None: + boxes = results[0].boxes.xyxy.cpu() + track_ids = results[0].boxes.id.int().cpu().tolist() + clss = results[0].boxes.cls.cpu().tolist() + + annotator = Annotator(frame, line_width=line_thickness, example=str(names)) + + for box, track_id, cls in zip(boxes, track_ids, clss): + annotator.box_label(box, str(names[cls]), color=colors(cls, True)) + bbox_center = (box[0] + box[2]) / 2, (box[1] + box[3]) / 2 # Bbox center + + track = track_history[track_id] # Tracking Lines plot + track.append((float(bbox_center[0]), float(bbox_center[1]))) + if len(track) > 30: + track.pop(0) + points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2)) + cv2.polylines(frame, [points], isClosed=False, color=colors(cls, True), thickness=track_thickness) + + # Check if detection inside region + for region in counting_regions: + if region["polygon"].contains(Point((bbox_center[0], bbox_center[1]))): + region["counts"] += 1 + + # Draw regions (Polygons/Rectangles) + for region in counting_regions: + region_label = str(region["counts"]) + region_color = region["region_color"] + region_text_color = region["text_color"] + + polygon_coordinates = np.array(region["polygon"].exterior.coords, dtype=np.int32) + centroid_x, centroid_y = int(region["polygon"].centroid.x), int(region["polygon"].centroid.y) + + text_size, _ = cv2.getTextSize( + region_label, cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.7, thickness=line_thickness + ) + text_x = centroid_x - text_size[0] // 2 + text_y = centroid_y + text_size[1] // 2 + cv2.rectangle( + frame, + (text_x - 5, text_y - text_size[1] - 5), + (text_x + text_size[0] + 5, text_y + 5), + region_color, + -1, + ) + cv2.putText( + frame, region_label, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, region_text_color, line_thickness + ) + cv2.polylines(frame, [polygon_coordinates], isClosed=True, color=region_color, thickness=region_thickness) + + if view_img: + if vid_frame_count == 1: + cv2.namedWindow("Ultralytics YOLOv8 Region Counter Movable") + cv2.setMouseCallback("Ultralytics YOLOv8 Region Counter Movable", mouse_callback) + cv2.imshow("Ultralytics YOLOv8 Region Counter Movable", frame) + + if save_img: + video_writer.write(frame) + + for region in counting_regions: # Reinitialize count for each region + region["counts"] = 0 + + if cv2.waitKey(1) & 0xFF == ord("q"): + break + + del vid_frame_count + video_writer.release() + videocapture.release() + cv2.destroyAllWindows() + + +def parse_opt(): + """Parse command line arguments.""" + 
parser = argparse.ArgumentParser() + parser.add_argument("--weights", type=str, default="yolov8n.pt", help="initial weights path") + parser.add_argument("--device", default="", help="cuda device, i.e. 0 or 0,1,2,3 or cpu") + parser.add_argument("--source", type=str, required=True, help="video file path") + parser.add_argument("--view-img", action="store_true", help="show results") + parser.add_argument("--save-img", action="store_true", help="save results") + parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment") + parser.add_argument("--classes", nargs="+", type=int, help="filter by class: --classes 0, or --classes 0 2 3") + parser.add_argument("--line-thickness", type=int, default=2, help="bounding box thickness") + parser.add_argument("--track-thickness", type=int, default=2, help="Tracking line thickness") + parser.add_argument("--region-thickness", type=int, default=4, help="Region thickness") + + return parser.parse_args() + + +def main(options): + """Main function.""" + run(**vars(options)) + + +if __name__ == "__main__": + opt = parse_opt() + main(opt) diff --git a/examples/YOLOv8-SAHI-Inference-Video/readme.md b/examples/YOLOv8-SAHI-Inference-Video/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..4dc169b3e17030b936f2bd4f9c82a4d8546c88e6 --- /dev/null +++ b/examples/YOLOv8-SAHI-Inference-Video/readme.md @@ -0,0 +1,69 @@ +# YOLO11 with SAHI (Inference on Video) + +[SAHI](https://docs.ultralytics.com/guides/sahi-tiled-inference/) is designed to optimize object detection algorithms for large-scale and high-resolution imagery. It partitions images into manageable slices, performs object detection on each slice, and then stitches the results back together. This tutorial will guide you through the process of running YOLO11 inference on video files with the aid of SAHI. + +## Table of Contents + +- [Step 1: Install the Required Libraries](#step-1-install-the-required-libraries) +- [Step 2: Run the Inference with SAHI using Ultralytics YOLO11](#step-2-run-the-inference-with-sahi-using-ultralytics-yolo11) +- [Usage Options](#usage-options) +- [FAQ](#faq) + +## Step 1: Install the Required Libraries + +Clone the repository, install dependencies and `cd` to this local directory for commands in Step 2. + +```bash +# Clone ultralytics repo +git clone https://github.com/ultralytics/ultralytics + +# Install dependencies +pip install -U sahi ultralytics + +# cd to local directory +cd ultralytics/examples/YOLOv8-SAHI-Inference-Video +``` + +## Step 2: Run the Inference with SAHI using Ultralytics YOLO11 + +Here are the basic commands for running the inference: + +```bash +#if you want to save results +python yolov8_sahi.py --source "path/to/video.mp4" --save-img + +#if you want to change model file +python yolov8_sahi.py --source "path/to/video.mp4" --save-img --weights "yolo11n.pt" +``` + +## Usage Options + +- `--source`: Specifies the path to the video file you want to run inference on. +- `--save-img`: Flag to save the detection results as images. +- `--weights`: Specifies a different YOLO11 model file (e.g., `yolo11n.pt`, `yolov8s.pt`, `yolo11m.pt`, `yolo11l.pt`, `yolo11x.pt`). + +## FAQ + +**1. What is SAHI?** + +SAHI stands for Slicing Aided Hyper Inference. It is a library designed to optimize object detection algorithms for large-scale and high-resolution images. The library source code is available on [GitHub](https://github.com/obss/sahi). + +**2. 
Why use SAHI with YOLO11?** + +SAHI can handle large-scale images by slicing them into smaller, more manageable sizes without compromising the detection quality. This makes it a great companion to YOLO11, especially when working with high-resolution videos. + +**3. How do I debug issues?** + +You can add the `--debug` flag to your command to print out more information during inference: + +```bash +python yolov8_sahi.py --source "path to video file" --debug +``` + +**4. Can I use other YOLO versions?** + +Yes, you can specify different YOLO model weights using the `--weights` option. + +**5. Where can I find more information?** + +For a full guide to YOLO11 with SAHI see [https://docs.ultralytics.com/guides/sahi-tiled-inference](https://docs.ultralytics.com/guides/sahi-tiled-inference/). diff --git a/examples/YOLOv8-SAHI-Inference-Video/yolov8_sahi.py b/examples/YOLOv8-SAHI-Inference-Video/yolov8_sahi.py new file mode 100644 index 0000000000000000000000000000000000000000..69872dcc9e48dc4e195049b7823d19989eb5ce59 --- /dev/null +++ b/examples/YOLOv8-SAHI-Inference-Video/yolov8_sahi.py @@ -0,0 +1,108 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import argparse +from pathlib import Path + +import cv2 +from sahi import AutoDetectionModel +from sahi.predict import get_sliced_prediction +from sahi.utils.ultralytics import download_yolo11n_model + +from ultralytics.utils.files import increment_path +from ultralytics.utils.plotting import Annotator, colors + + +class SAHIInference: + """Runs Ultralytics YOLO11 and SAHI for object detection on video with options to view, save, and track results.""" + + def __init__(self): + """Initializes the SAHIInference class for performing sliced inference using SAHI with YOLO11 models.""" + self.detection_model = None + + def load_model(self, weights): + """Loads a YOLO11 model with specified weights for object detection using SAHI.""" + yolo11_model_path = f"models/{weights}" + download_yolo11n_model(yolo11_model_path) + self.detection_model = AutoDetectionModel.from_pretrained( + model_type="ultralytics", model_path=yolo11_model_path, device="cpu" + ) + + def inference( + self, + weights="yolo11n.pt", + source="test.mp4", + view_img=False, + save_img=False, + exist_ok=False, + ): + """ + Run object detection on a video using YOLO11 and SAHI. + + Args: + weights (str): Model weights path. + source (str): Video file path. + view_img (bool): Show results. + save_img (bool): Save results. + exist_ok (bool): Overwrite existing files. 
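            Note:
                Each frame is converted from BGR to RGB and passed to SAHI's get_sliced_prediction
                with fixed 512x512 slices; the merged predictions are then drawn on the frame with
                the Ultralytics Annotator.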
+ """ + # Video setup + cap = cv2.VideoCapture(source) + assert cap.isOpened(), "Error reading video file" + frame_width, frame_height = int(cap.get(3)), int(cap.get(4)) + + # Output setup + save_dir = increment_path(Path("ultralytics_results_with_sahi") / "exp", exist_ok) + save_dir.mkdir(parents=True, exist_ok=True) + video_writer = cv2.VideoWriter( + str(save_dir / f"{Path(source).stem}.avi"), + cv2.VideoWriter_fourcc(*"MJPG"), + int(cap.get(5)), + (frame_width, frame_height), + ) + + # Load model + self.load_model(weights) + while cap.isOpened(): + success, frame = cap.read() + if not success: + break + annotator = Annotator(frame) # Initialize annotator for plotting detection and tracking results + results = get_sliced_prediction( + frame[..., ::-1], + self.detection_model, + slice_height=512, + slice_width=512, + ) + detection_data = [ + (det.category.name, det.category.id, (det.bbox.minx, det.bbox.miny, det.bbox.maxx, det.bbox.maxy)) + for det in results.object_prediction_list + ] + + for det in detection_data: + annotator.box_label(det[2], label=str(det[0]), color=colors(int(det[1]), True)) + + if view_img: + cv2.imshow(Path(source).stem, frame) + if save_img: + video_writer.write(frame) + + if cv2.waitKey(1) & 0xFF == ord("q"): + break + video_writer.release() + cap.release() + cv2.destroyAllWindows() + + def parse_opt(self): + """Parse command line arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument("--weights", type=str, default="yolo11n.pt", help="initial weights path") + parser.add_argument("--source", type=str, required=True, help="video file path") + parser.add_argument("--view-img", action="store_true", help="show results") + parser.add_argument("--save-img", action="store_true", help="save results") + parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment") + return parser.parse_args() + + +if __name__ == "__main__": + inference = SAHIInference() + inference.inference(**vars(inference.parse_opt())) diff --git a/examples/YOLOv8-Segmentation-ONNXRuntime-Python/README.md b/examples/YOLOv8-Segmentation-ONNXRuntime-Python/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b647700c01995af89672527dc6b1e753fc1f53f9 --- /dev/null +++ b/examples/YOLOv8-Segmentation-ONNXRuntime-Python/README.md @@ -0,0 +1,63 @@ +# YOLOv8-Segmentation-ONNXRuntime-Python Demo + +This repository provides a Python demo for performing segmentation with YOLOv8 using ONNX Runtime, highlighting the interoperability of YOLOv8 models without the need for the full PyTorch stack. + +## Features + +- **Framework Agnostic**: Runs segmentation inference purely on ONNX Runtime without importing PyTorch. +- **Efficient Inference**: Supports both FP32 and FP16 precision for ONNX models, catering to different computational needs. +- **Ease of Use**: Utilizes simple command-line arguments for model execution. +- **Broad Compatibility**: Leverages Numpy and OpenCV for image processing, ensuring broad compatibility with various environments. + +## Installation + +Install the required packages using pip. You will need `ultralytics` for exporting YOLOv8-seg ONNX model and using some utility functions, `onnxruntime-gpu` for GPU-accelerated inference, and `opencv-python` for image processing. 
+ +```bash +pip install ultralytics +pip install onnxruntime-gpu # For GPU support +# pip install onnxruntime # Use this instead if you don't have an NVIDIA GPU +pip install numpy +pip install opencv-python +``` + +## Getting Started + +### 1. Export the YOLOv8 ONNX Model + +Export the YOLOv8 segmentation model to ONNX format using the provided `ultralytics` package. + +```bash +yolo export model=yolov8s-seg.pt imgsz=640 format=onnx opset=12 simplify +``` + +### 2. Run Inference + +Perform inference with the exported ONNX model on your images. + +```bash +python main.py --model --source +``` + +### Example Output + +After running the command, you should see segmentation results similar to this: + +Segmentation Demo + +## Advanced Usage + +For more advanced usage, including real-time video processing, please refer to the `main.py` script's command-line arguments. + +## Contributing + +We welcome contributions to improve this demo! Please submit issues and pull requests for bug reports, feature requests, or submitting a new algorithm enhancement. + +## License + +This project is licensed under the AGPL-3.0 License - see the [LICENSE](https://github.com/ultralytics/ultralytics/blob/main/LICENSE) file for details. + +## Acknowledgments + +- The YOLOv8-Segmentation-ONNXRuntime-Python demo is contributed by GitHub user [jamjamjon](https://github.com/jamjamjon). +- Thanks to the ONNX Runtime community for providing a robust and efficient inference engine. diff --git a/examples/YOLOv8-Segmentation-ONNXRuntime-Python/main.py b/examples/YOLOv8-Segmentation-ONNXRuntime-Python/main.py new file mode 100644 index 0000000000000000000000000000000000000000..b8e2e7d55d80ba1a60d37ffb1dca7b86216ba90e --- /dev/null +++ b/examples/YOLOv8-Segmentation-ONNXRuntime-Python/main.py @@ -0,0 +1,338 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import argparse + +import cv2 +import numpy as np +import onnxruntime as ort + +from ultralytics.utils import ASSETS, yaml_load +from ultralytics.utils.checks import check_yaml +from ultralytics.utils.plotting import Colors + + +class YOLOv8Seg: + """YOLOv8 segmentation model.""" + + def __init__(self, onnx_model): + """ + Initialization. + + Args: + onnx_model (str): Path to the ONNX model. + """ + # Build Ort session + self.session = ort.InferenceSession( + onnx_model, + providers=["CUDAExecutionProvider", "CPUExecutionProvider"] + if ort.get_device() == "GPU" + else ["CPUExecutionProvider"], + ) + + # Numpy dtype: support both FP32 and FP16 onnx model + self.ndtype = np.half if self.session.get_inputs()[0].type == "tensor(float16)" else np.single + + # Get model width and height(YOLOv8-seg only has one input) + self.model_height, self.model_width = [x.shape for x in self.session.get_inputs()][0][-2:] + + # Load COCO class names + self.classes = yaml_load(check_yaml("coco8.yaml"))["names"] + + # Create color palette + self.color_palette = Colors() + + def __call__(self, im0, conf_threshold=0.4, iou_threshold=0.45, nm=32): + """ + The whole pipeline: pre-process -> inference -> post-process. + + Args: + im0 (Numpy.ndarray): original input image. + conf_threshold (float): confidence threshold for filtering predictions. + iou_threshold (float): iou threshold for NMS. + nm (int): the number of masks. + + Returns: + boxes (List): list of bounding boxes. + segments (List): list of segments. + masks (np.ndarray): [N, H, W], output masks. 
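+
+        Notes:
+            Typical usage, mirroring the __main__ block at the bottom of this file (paths are placeholders):
+
+                model = YOLOv8Seg("yolov8s-seg.onnx")
+                boxes, segments, masks = model(cv2.imread("bus.jpg"), conf_threshold=0.25, iou_threshold=0.45)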
+ """ + # Pre-process + im, ratio, (pad_w, pad_h) = self.preprocess(im0) + + # Ort inference + preds = self.session.run(None, {self.session.get_inputs()[0].name: im}) + + # Post-process + boxes, segments, masks = self.postprocess( + preds, + im0=im0, + ratio=ratio, + pad_w=pad_w, + pad_h=pad_h, + conf_threshold=conf_threshold, + iou_threshold=iou_threshold, + nm=nm, + ) + return boxes, segments, masks + + def preprocess(self, img): + """ + Pre-processes the input image. + + Args: + img (Numpy.ndarray): image about to be processed. + + Returns: + img_process (Numpy.ndarray): image preprocessed for inference. + ratio (tuple): width, height ratios in letterbox. + pad_w (float): width padding in letterbox. + pad_h (float): height padding in letterbox. + """ + # Resize and pad input image using letterbox() (Borrowed from Ultralytics) + shape = img.shape[:2] # original image shape + new_shape = (self.model_height, self.model_width) + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) + ratio = r, r + new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) + pad_w, pad_h = (new_shape[1] - new_unpad[0]) / 2, (new_shape[0] - new_unpad[1]) / 2 # wh padding + if shape[::-1] != new_unpad: # resize + img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) + top, bottom = int(round(pad_h - 0.1)), int(round(pad_h + 0.1)) + left, right = int(round(pad_w - 0.1)), int(round(pad_w + 0.1)) + img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)) + + # Transforms: HWC to CHW -> BGR to RGB -> div(255) -> contiguous -> add axis(optional) + img = np.ascontiguousarray(np.einsum("HWC->CHW", img)[::-1], dtype=self.ndtype) / 255.0 + img_process = img[None] if len(img.shape) == 3 else img + return img_process, ratio, (pad_w, pad_h) + + def postprocess(self, preds, im0, ratio, pad_w, pad_h, conf_threshold, iou_threshold, nm=32): + """ + Post-process the prediction. + + Args: + preds (Numpy.ndarray): predictions come from ort.session.run(). + im0 (Numpy.ndarray): [h, w, c] original input image. + ratio (tuple): width, height ratios in letterbox. + pad_w (float): width padding in letterbox. + pad_h (float): height padding in letterbox. + conf_threshold (float): conf threshold. + iou_threshold (float): iou threshold. + nm (int): the number of masks. + + Returns: + boxes (List): list of bounding boxes. + segments (List): list of segments. + masks (np.ndarray): [N, H, W], output masks. 
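+
+        Notes:
+            preds holds the two ONNX outputs: a prediction tensor of shape
+            (batch, 4 + num_classes + nm, num_anchors) and the prototype masks used to build
+            the per-instance masks.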
+ """ + x, protos = preds[0], preds[1] # Two outputs: predictions and protos + + # Transpose dim 1: (Batch_size, xywh_conf_cls_nm, Num_anchors) -> (Batch_size, Num_anchors, xywh_conf_cls_nm) + x = np.einsum("bcn->bnc", x) + + # Predictions filtering by conf-threshold + x = x[np.amax(x[..., 4:-nm], axis=-1) > conf_threshold] + + # Create a new matrix which merge these(box, score, cls, nm) into one + # For more details about `numpy.c_()`: https://numpy.org/doc/1.26/reference/generated/numpy.c_.html + x = np.c_[x[..., :4], np.amax(x[..., 4:-nm], axis=-1), np.argmax(x[..., 4:-nm], axis=-1), x[..., -nm:]] + + # NMS filtering + x = x[cv2.dnn.NMSBoxes(x[:, :4], x[:, 4], conf_threshold, iou_threshold)] + + # Decode and return + if len(x) > 0: + # Bounding boxes format change: cxcywh -> xyxy + x[..., [0, 1]] -= x[..., [2, 3]] / 2 + x[..., [2, 3]] += x[..., [0, 1]] + + # Rescales bounding boxes from model shape(model_height, model_width) to the shape of original image + x[..., :4] -= [pad_w, pad_h, pad_w, pad_h] + x[..., :4] /= min(ratio) + + # Bounding boxes boundary clamp + x[..., [0, 2]] = x[:, [0, 2]].clip(0, im0.shape[1]) + x[..., [1, 3]] = x[:, [1, 3]].clip(0, im0.shape[0]) + + # Process masks + masks = self.process_mask(protos[0], x[:, 6:], x[:, :4], im0.shape) + + # Masks -> Segments(contours) + segments = self.masks2segments(masks) + return x[..., :6], segments, masks # boxes, segments, masks + else: + return [], [], [] + + @staticmethod + def masks2segments(masks): + """ + Takes a list of masks(n,h,w) and returns a list of segments(n,xy), from + https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/ops.py. + + Args: + masks (numpy.ndarray): the output of the model, which is a tensor of shape (batch_size, 160, 160). + + Returns: + segments (List): list of segment masks. + """ + segments = [] + for x in masks.astype("uint8"): + c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[0] # CHAIN_APPROX_SIMPLE + if c: + c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2) + else: + c = np.zeros((0, 2)) # no segments found + segments.append(c.astype("float32")) + return segments + + @staticmethod + def crop_mask(masks, boxes): + """ + Takes a mask and a bounding box, and returns a mask that is cropped to the bounding box, from + https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/ops.py. + + Args: + masks (Numpy.ndarray): [n, h, w] tensor of masks. + boxes (Numpy.ndarray): [n, 4] tensor of bbox coordinates in relative point form. + + Returns: + (Numpy.ndarray): The masks are being cropped to the bounding box. + """ + n, h, w = masks.shape + x1, y1, x2, y2 = np.split(boxes[:, :, None], 4, 1) + r = np.arange(w, dtype=x1.dtype)[None, None, :] + c = np.arange(h, dtype=x1.dtype)[None, :, None] + return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2)) + + def process_mask(self, protos, masks_in, bboxes, im0_shape): + """ + Takes the output of the mask head, and applies the mask to the bounding boxes. This produces masks of higher + quality but is slower, from https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/ops.py. + + Args: + protos (numpy.ndarray): [mask_dim, mask_h, mask_w]. + masks_in (numpy.ndarray): [n, mask_dim], n is number of masks after nms. + bboxes (numpy.ndarray): bboxes re-scaled to original image shape. + im0_shape (tuple): the size of the input image (h,w,c). + + Returns: + (numpy.ndarray): The upsampled masks. 
+ """ + c, mh, mw = protos.shape + masks = np.matmul(masks_in, protos.reshape((c, -1))).reshape((-1, mh, mw)).transpose(1, 2, 0) # HWN + masks = np.ascontiguousarray(masks) + masks = self.scale_mask(masks, im0_shape) # re-scale mask from P3 shape to original input image shape + masks = np.einsum("HWN -> NHW", masks) # HWN -> NHW + masks = self.crop_mask(masks, bboxes) + return np.greater(masks, 0.5) + + @staticmethod + def scale_mask(masks, im0_shape, ratio_pad=None): + """ + Takes a mask, and resizes it to the original image size, from + https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/ops.py. + + Args: + masks (np.ndarray): resized and padded masks/images, [h, w, num]/[h, w, 3]. + im0_shape (tuple): the original image shape. + ratio_pad (tuple): the ratio of the padding to the original image. + + Returns: + masks (np.ndarray): The masks that are being returned. + """ + im1_shape = masks.shape[:2] + if ratio_pad is None: # calculate from im0_shape + gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1]) # gain = old / new + pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2 # wh padding + else: + pad = ratio_pad[1] + + # Calculate tlbr of mask + top, left = int(round(pad[1] - 0.1)), int(round(pad[0] - 0.1)) # y, x + bottom, right = int(round(im1_shape[0] - pad[1] + 0.1)), int(round(im1_shape[1] - pad[0] + 0.1)) + if len(masks.shape) < 2: + raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}') + masks = masks[top:bottom, left:right] + masks = cv2.resize( + masks, (im0_shape[1], im0_shape[0]), interpolation=cv2.INTER_LINEAR + ) # INTER_CUBIC would be better + if len(masks.shape) == 2: + masks = masks[:, :, None] + return masks + + def draw_and_visualize(self, im, bboxes, segments, vis=False, save=True): + """ + Draw and visualize results. + + Args: + im (np.ndarray): original image, shape [h, w, c]. + bboxes (numpy.ndarray): [n, 4], n is number of bboxes. + segments (List): list of segment masks. + vis (bool): imshow using OpenCV. + save (bool): save image annotated. 
+ + Returns: + None + """ + # Draw rectangles and polygons + im_canvas = im.copy() + for (*box, conf, cls_), segment in zip(bboxes, segments): + # draw contour and fill mask + cv2.polylines(im, np.int32([segment]), True, (255, 255, 255), 2) # white borderline + cv2.fillPoly(im_canvas, np.int32([segment]), self.color_palette(int(cls_), bgr=True)) + + # draw bbox rectangle + cv2.rectangle( + im, + (int(box[0]), int(box[1])), + (int(box[2]), int(box[3])), + self.color_palette(int(cls_), bgr=True), + 1, + cv2.LINE_AA, + ) + cv2.putText( + im, + f"{self.classes[cls_]}: {conf:.3f}", + (int(box[0]), int(box[1] - 9)), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + self.color_palette(int(cls_), bgr=True), + 2, + cv2.LINE_AA, + ) + + # Mix image + im = cv2.addWeighted(im_canvas, 0.3, im, 0.7, 0) + + # Show image + if vis: + cv2.imshow("demo", im) + cv2.waitKey(0) + cv2.destroyAllWindows() + + # Save image + if save: + cv2.imwrite("demo.jpg", im) + + +if __name__ == "__main__": + # Create an argument parser to handle command-line arguments + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, required=True, help="Path to ONNX model") + parser.add_argument("--source", type=str, default=str(ASSETS / "bus.jpg"), help="Path to input image") + parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold") + parser.add_argument("--iou", type=float, default=0.45, help="NMS IoU threshold") + args = parser.parse_args() + + # Build model + model = YOLOv8Seg(args.model) + + # Read image by OpenCV + img = cv2.imread(args.source) + + # Inference + boxes, segments, _ = model(img, conf_threshold=args.conf, iou_threshold=args.iou) + + # Draw bboxes and polygons + if len(boxes) > 0: + model.draw_and_visualize(img, boxes, segments, vis=False, save=True) diff --git a/examples/YOLOv8-TFLite-Python/README.md b/examples/YOLOv8-TFLite-Python/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0156759fdba626189e3c42f3c00556f0f7f1bbe6 --- /dev/null +++ b/examples/YOLOv8-TFLite-Python/README.md @@ -0,0 +1,55 @@ +# YOLOv8 - TFLite Runtime + +This example shows how to run inference with YOLOv8 TFLite model. It supports FP32, FP16 and INT8 models. + +## Installation + +### Installing `tflite-runtime` + +To load TFLite models, install the `tflite-runtime` package using: + +```bash +pip install tflite-runtime +``` + +### Installing `tensorflow-gpu` (For NVIDIA GPU Users) + +Leverage GPU acceleration with NVIDIA GPUs by installing `tensorflow-gpu`: + +```bash +pip install tensorflow-gpu +``` + +**Note:** Ensure you have compatible GPU drivers installed on your system. + +### Installing `tensorflow` (CPU Version) + +For CPU usage or non-NVIDIA GPUs, install TensorFlow with: + +```bash +pip install tensorflow +``` + +## Usage + +Follow these instructions to run YOLOv8 after successful installation. + +Convert the YOLOv8 model to TFLite format: + +```bash +yolo export model=yolov8n.pt imgsz=640 format=tflite int8 +``` + +Locate the TFLite model in `yolov8n_saved_model`. Then, execute the following in your terminal: + +```bash +python main.py --model yolov8n_full_integer_quant.tflite --img image.jpg --conf 0.25 --iou 0.45 --metadata "metadata.yaml" +``` + +Replace `best_full_integer_quant.tflite` with the TFLite model path, `image.jpg` with the input image path, `metadata.yaml` with the one generated by `ultralytics` during export, and adjust the confidence (conf) and IoU thresholds (iou) as necessary. 
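+
+If you prefer to call the detector from Python rather than the command line, a minimal sketch based on the `YOLOv8TFLite` class in this example's `main.py` looks like the following. It assumes you run it from this example directory, and the model, metadata, and image paths are placeholders for your own export:
+
+```python
+import cv2
+
+from main import YOLOv8TFLite  # class defined in this example's main.py
+
+# Placeholder paths: point these at your own exported model and metadata
+detector = YOLOv8TFLite(
+    "yolov8n_saved_model/yolov8n_full_integer_quant.tflite",
+    conf=0.25,
+    iou=0.45,
+    metadata="yolov8n_saved_model/metadata.yaml",
+)
+
+result = detector.detect("image.jpg")  # returns the image with detections drawn
+cv2.imwrite("result.jpg", result)
+```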
+ +### Output + +The output would show the detections along with the class labels and confidences of each detected object. + +![image](https://github.com/wamiqraza/Attribute-recognition-and-reidentification-Market1501-dataset/blob/main/img/bus.jpg) diff --git a/examples/YOLOv8-TFLite-Python/main.py b/examples/YOLOv8-TFLite-Python/main.py new file mode 100644 index 0000000000000000000000000000000000000000..00c403032859477a5a8a49260080214af92d920f --- /dev/null +++ b/examples/YOLOv8-TFLite-Python/main.py @@ -0,0 +1,221 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import argparse +from typing import Tuple, Union + +import cv2 +import numpy as np +import tensorflow as tf +import yaml + +from ultralytics.utils import ASSETS + +try: + from tflite_runtime.interpreter import Interpreter +except ImportError: + import tensorflow as tf + + Interpreter = tf.lite.Interpreter + + +class YOLOv8TFLite: + """ + YOLOv8TFLite. + + A class for performing object detection using the YOLOv8 model with TensorFlow Lite. + + Attributes: + model (str): Path to the TensorFlow Lite model file. + conf (float): Confidence threshold for filtering detections. + iou (float): Intersection over Union threshold for non-maximum suppression. + metadata (Optional[str]): Path to the metadata file, if any. + + Methods: + detect(img_path: str) -> np.ndarray: + Performs inference and returns the output image with drawn detections. + """ + + def __init__(self, model: str, conf: float = 0.25, iou: float = 0.45, metadata: Union[str, None] = None): + """ + Initializes an instance of the YOLOv8TFLite class. + + Args: + model (str): Path to the TFLite model. + conf (float, optional): Confidence threshold for filtering detections. Defaults to 0.25. + iou (float, optional): IoU (Intersection over Union) threshold for non-maximum suppression. Defaults to 0.45. + metadata (Union[str, None], optional): Path to the metadata file or None if not used. Defaults to None. 
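+
+        Notes:
+            When metadata is None, class ids are used in place of human-readable names
+            (a generic mapping covering up to 1000 classes).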
+ """ + self.conf = conf + self.iou = iou + if metadata is None: + self.classes = {i: i for i in range(1000)} + else: + with open(metadata) as f: + self.classes = yaml.safe_load(f)["names"] + np.random.seed(42) + self.color_palette = np.random.uniform(128, 255, size=(len(self.classes), 3)) + + self.model = Interpreter(model_path=model) + self.model.allocate_tensors() + + input_details = self.model.get_input_details()[0] + + self.in_width, self.in_height = input_details["shape"][1:3] + self.in_index = input_details["index"] + self.in_scale, self.in_zero_point = input_details["quantization"] + self.int8 = input_details["dtype"] == np.int8 + + output_details = self.model.get_output_details()[0] + self.out_index = output_details["index"] + self.out_scale, self.out_zero_point = output_details["quantization"] + + def letterbox(self, img: np.ndarray, new_shape: Tuple = (640, 640)) -> Tuple[np.ndarray, Tuple[float, float]]: + """Resizes and reshapes images while maintaining aspect ratio by adding padding, suitable for YOLO models.""" + shape = img.shape[:2] # current shape [height, width] + + # Scale ratio (new / old) + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) + + # Compute padding + new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) + dw, dh = (new_shape[1] - new_unpad[0]) / 2, (new_shape[0] - new_unpad[1]) / 2 # wh padding + + if shape[::-1] != new_unpad: # resize + img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) + top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) + img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)) + + return img, (top / img.shape[0], left / img.shape[1]) + + def draw_detections(self, img: np.ndarray, box: np.ndarray, score: np.float32, class_id: int) -> None: + """ + Draws bounding boxes and labels on the input image based on the detected objects. + + Args: + img (np.ndarray): The input image to draw detections on. + box (np.ndarray): Detected bounding box in the format [x1, y1, width, height]. + score (np.float32): Corresponding detection score. + class_id (int): Class ID for the detected object. + + Returns: + None + """ + x1, y1, w, h = box + color = self.color_palette[class_id] + + cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2) + + label = f"{self.classes[class_id]}: {score:.2f}" + + (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1) + + label_x = x1 + label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10 + + cv2.rectangle( + img, + (int(label_x), int(label_y - label_height)), + (int(label_x + label_width), int(label_y + label_height)), + color, + cv2.FILLED, + ) + + cv2.putText(img, label, (int(label_x), int(label_y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA) + + def preprocess(self, img: np.ndarray) -> Tuple[np.ndarray, Tuple[float, float]]: + """ + Preprocesses the input image before performing inference. + + Args: + img (np.ndarray): The input image to be preprocessed. + + Returns: + Tuple[np.ndarray, Tuple[float, float]]: A tuple containing: + - The preprocessed image (np.ndarray). + - A tuple of two float values representing the padding applied (top/bottom, left/right). 
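+
+        Notes:
+            The image is letterboxed to the model input size, converted from BGR to RGB,
+            given a batch dimension (NHWC), and scaled to the [0, 1] range.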
+ """ + img, pad = self.letterbox(img, (self.in_width, self.in_height)) + img = img[..., ::-1][None] # N,H,W,C for TFLite + img = np.ascontiguousarray(img) + img = img.astype(np.float32) + return img / 255, pad + + def postprocess(self, img: np.ndarray, outputs: np.ndarray, pad: Tuple[float, float]) -> np.ndarray: + """ + Performs post-processing on the model's output to extract bounding boxes, scores, and class IDs. + + Args: + img (numpy.ndarray): The input image. + outputs (numpy.ndarray): The output of the model. + pad (Tuple[float, float]): Padding used by letterbox. + + Returns: + numpy.ndarray: The input image with detections drawn on it. + """ + outputs[:, 0] -= pad[1] + outputs[:, 1] -= pad[0] + outputs[:, :4] *= max(img.shape) + + outputs = outputs.transpose(0, 2, 1) + outputs[..., 0] -= outputs[..., 2] / 2 + outputs[..., 1] -= outputs[..., 3] / 2 + + for out in outputs: + scores = out[:, 4:].max(-1) + keep = scores > self.conf + boxes = out[keep, :4] + scores = scores[keep] + class_ids = out[keep, 4:].argmax(-1) + + indices = cv2.dnn.NMSBoxes(boxes, scores, self.conf, self.iou).flatten() + + [self.draw_detections(img, boxes[i], scores[i], class_ids[i]) for i in indices] + + return img + + def detect(self, img_path: str) -> np.ndarray: + """ + Performs inference using a TFLite model and returns the output image with drawn detections. + + Args: + img_path (str): The path to the input image file. + + Returns: + np.ndarray: The output image with drawn detections. + """ + img = cv2.imread(img_path) + x, pad = self.preprocess(img) + if self.int8: + x = (x / self.in_scale + self.in_zero_point).astype(np.int8) + self.model.set_tensor(self.in_index, x) + + self.model.invoke() + + y = self.model.get_tensor(self.out_index) + + if self.int8: + y = (y.astype(np.float32) - self.out_zero_point) * self.out_scale + + return self.postprocess(img, y, pad) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--model", + type=str, + default="yolov8n_saved_model/yolov8n_full_integer_quant.tflite", + help="Path to TFLite model.", + ) + parser.add_argument("--img", type=str, default=str(ASSETS / "bus.jpg"), help="Path to input image") + parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold") + parser.add_argument("--iou", type=float, default=0.45, help="NMS IoU threshold") + parser.add_argument("--metadata", type=str, default="yolov8n_saved_model/metadata.yaml", help="Metadata yaml") + args = parser.parse_args() + + detector = YOLOv8TFLite(args.model, args.conf, args.iou, args.metadata) + result = detector.detect(str(ASSETS / "bus.jpg")) + + cv2.imshow("Output", result) + cv2.waitKey(0) diff --git a/examples/heatmaps.ipynb b/examples/heatmaps.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..4f34da35a4c31efa03ecfa593db24a7dae116665 --- /dev/null +++ b/examples/heatmaps.ipynb @@ -0,0 +1,186 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "PN1cAxdvd61e" + }, + "source": [ + "
\n", + "\n", + " \n", + " \n", + "\n", + " [中文](https://docs.ultralytics.com/zh/) | [한국어](https://docs.ultralytics.com/ko/) | [日本語](https://docs.ultralytics.com/ja/) | [Русский](https://docs.ultralytics.com/ru/) | [Deutsch](https://docs.ultralytics.com/de/) | [Français](https://docs.ultralytics.com/fr/) | [Español](https://docs.ultralytics.com/es/) | [Português](https://docs.ultralytics.com/pt/) | [Türkçe](https://docs.ultralytics.com/tr/) | [Tiếng Việt](https://docs.ultralytics.com/vi/) | [العربية](https://docs.ultralytics.com/ar/)\n", + "\n", + " \"Ultralytics\n", + " \"Run\n", + " \"Open\n", + " \"Open\n", + " \"Discord\"\n", + "\n", + "Welcome to the Ultralytics YOLO11 🚀 notebook! YOLO11 is the latest version of the YOLO (You Only Look Once) AI models developed by Ultralytics. This notebook serves as the starting point for exploring the various resources available to help you get started with YOLO11 and understand its features and capabilities.\n", + "\n", + "YOLO11 models are fast, accurate, and easy to use, making them ideal for various object detection and image segmentation tasks. They can be trained on large datasets and run on diverse hardware platforms, from CPUs to GPUs.\n", + "\n", + "We hope that the resources in this notebook will help you get the most out of YOLO11. Please browse the YOLO11 Heatmap Docs for details, raise an issue on GitHub for support, and join our Discord community for questions and discussions!\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "o68Sg1oOeZm2" + }, + "source": [ + "# Setup\n", + "\n", + "Pip install `ultralytics` and [dependencies](https://github.com/ultralytics/ultralytics/blob/main/pyproject.toml) and check software and hardware.\n", + "\n", + "[![PyPI - Version](https://img.shields.io/pypi/v/ultralytics?logo=pypi&logoColor=white)](https://pypi.org/project/ultralytics/) [![Downloads](https://static.pepy.tech/badge/ultralytics)](https://www.pepy.tech/projects/ultralytics) [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/ultralytics?logo=python&logoColor=gold)](https://pypi.org/project/ultralytics/)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "9dSwz_uOReMI", + "outputId": "99866c77-e210-41e1-d581-8508371ce634" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ultralytics 8.2.17 🚀 Python-3.10.12 torch-2.2.1+cu121 CUDA:0 (T4, 15102MiB)\n", + "Setup complete ✅ (2 CPUs, 12.7 GB RAM, 29.8/78.2 GB disk)\n" + ] + } + ], + "source": [ + "%pip install ultralytics\n", + "import ultralytics\n", + "\n", + "ultralytics.checks()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "m7VkxQ2aeg7k" + }, + "source": [ + "# Introduction to Heatmaps\n", + "\n", + "A heatmap generated with [Ultralytics YOLO11](https://github.com/ultralytics/ultralytics/) transforms complex data into a vibrant, color-coded matrix. This visual tool employs a spectrum of colors to represent varying data values, where warmer hues indicate higher intensities and cooler tones signify lower values. Heatmaps excel in visualizing intricate data patterns, correlations, and anomalies, offering an accessible and engaging approach to data interpretation across diverse domains.\n", + "\n", + "## Real World Applications\n", + "\n", + "| Transportation | Retail |\n", + "|:-----------------------------------------------------------------------------------------------------------------------------------------------:|:---------------------------------------------------------------------------------------------------------------------------------------:|\n", + "| ![Ultralytics YOLO11 Transportation Heatmap](https://github.com/RizwanMunawar/ultralytics/assets/62513924/288d7053-622b-4452-b4e4-1f41aeb764aa) | ![Ultralytics YOLO11 Retail Heatmap](https://github.com/RizwanMunawar/ultralytics/assets/62513924/edef75ad-50a7-4c0a-be4a-a66cdfc12802) |\n", + "| Ultralytics YOLO11 Transportation Heatmap | Ultralytics YOLO11 Retail Heatmap |\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Cx-u59HQdu2o" + }, + "outputs": [], + "source": [ + "import cv2\n", + "\n", + "from ultralytics import solutions\n", + "\n", + "# Open video file\n", + "cap = cv2.VideoCapture(\"path/to/video/file.mp4\")\n", + "assert cap.isOpened(), \"Error reading video file\"\n", + "\n", + "# Get video properties\n", + "w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))\n", + "\n", + "# Initialize video writer\n", + "video_writer = cv2.VideoWriter(\"heatmap_output.avi\", cv2.VideoWriter_fourcc(*\"mp4v\"), fps, (w, h))\n", + "\n", + "# Initialize heatmap object\n", + "heatmap_obj = solutions.Heatmap(\n", + " colormap=cv2.COLORMAP_PARULA, # Color of the heatmap\n", + " show=True, # Display the image during processing\n", + " model=\"yolo11n.pt\", # Ultralytics YOLO11 model file\n", + 
")\n", + "\n", + "while cap.isOpened():\n", + " success, im0 = cap.read()\n", + " if not success:\n", + " print(\"Video frame is empty or video processing has been successfully completed.\")\n", + " break\n", + "\n", + " # Generate heatmap on the frame\n", + " im0 = heatmap_obj.generate_heatmap(im0)\n", + "\n", + " # Write the frame to the output video\n", + " video_writer.write(im0)\n", + "\n", + "# Release resources\n", + "cap.release()\n", + "video_writer.release()\n", + "cv2.destroyAllWindows()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QrlKg-y3fEyD" + }, + "source": [ + "# Additional Resources\n", + "\n", + "## Community Support\n", + "\n", + "For more information on using heatmaps with Ultralytics, you can explore the comprehensive [Ultralytics Heatmaps Docs](https://docs.ultralytics.com/guides/heatmaps/). This guide covers everything from basic concepts to advanced techniques, ensuring you get the most out of your heatmap visualizations.\n", + "\n", + "## Ultralytics ⚡ Resources\n", + "\n", + "At Ultralytics, we are committed to providing cutting-edge AI solutions. Here are some key resources to learn more about our company and get involved with our community:\n", + "\n", + "- [Ultralytics HUB](https://ultralytics.com/hub): Simplify your AI projects with Ultralytics HUB, our no-code tool for effortless YOLO training and deployment.\n", + "- [Ultralytics Licensing](https://ultralytics.com/license): Review our licensing terms to understand how you can use our software in your projects.\n", + "- [About Us](https://ultralytics.com/about): Discover our mission, vision, and the story behind Ultralytics.\n", + "- [Join Our Team](https://ultralytics.com/work): Explore career opportunities and join our team of talented professionals.\n", + "\n", + "## YOLO11 🚀 Resources\n", + "\n", + "YOLO11 is the latest evolution in the YOLO series, offering state-of-the-art performance in object detection and image segmentation. Here are some essential resources to help you get started with YOLO11:\n", + "\n", + "- [GitHub](https://github.com/ultralytics/ultralytics): Access the YOLO11 repository on GitHub, where you can find the source code, contribute to the project, and report issues.\n", + "- [Docs](https://docs.ultralytics.com/): Explore the official documentation for YOLO11, including installation guides, tutorials, and detailed API references.\n", + "- [Discord](https://ultralytics.com/discord): Join our Discord community to connect with other users, share your projects, and get help from the Ultralytics team.\n", + "\n", + "These resources are designed to help you leverage the full potential of Ultralytics' offerings and YOLO11. Whether you're a beginner or an experienced developer, you'll find the information and support you need to succeed." + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/examples/hub.ipynb b/examples/hub.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..05657155dfe900dc1457c49792cab4918936c6be --- /dev/null +++ b/examples/hub.ipynb @@ -0,0 +1,115 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "FIzICjaph_Wy" + }, + "source": [ + "\n", + "\n", + "\n", + "
\n", + "\n", + "[中文](https://docs.ultralytics.com/zh/hub/) | [한국어](https://docs.ultralytics.com/ko/hub/) | [日本語](https://docs.ultralytics.com/ja/hub/) | [Русский](https://docs.ultralytics.com/ru/hub/) | [Deutsch](https://docs.ultralytics.com/de/hub/) | [Français](https://docs.ultralytics.com/fr/hub/) | [Español](https://docs.ultralytics.com/es/hub/) | [Português](https://docs.ultralytics.com/pt/hub/) | [Türkçe](https://docs.ultralytics.com/tr/hub/) | [Tiếng Việt](https://docs.ultralytics.com/vi/hub/) | [العربية](https://docs.ultralytics.com/ar/hub/)\n", + "\n", + " \"CI\n", + " \"Open\n", + "\n", + " \"Discord\"\n", + " \"Ultralytics\n", + " \"Ultralytics\n", + "\n", + "Welcome to the [Ultralytics](https://ultralytics.com/) HUB notebook!\n", + "\n", + "This notebook allows you to train Ultralytics [YOLO](https://github.com/ultralytics/ultralytics) 🚀 models using [HUB](https://hub.ultralytics.com/). Please browse the HUB Docs for details, raise an issue on GitHub for support, and join our Discord community for questions and discussions!\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eRQ2ow94MiOv" + }, + "source": [ + "# Setup\n", + "\n", + "Pip install `ultralytics` and [dependencies](https://github.com/ultralytics/ultralytics/blob/main/pyproject.toml) and check software and hardware.\n", + "\n", + "[![PyPI - Version](https://img.shields.io/pypi/v/ultralytics?logo=pypi&logoColor=white)](https://pypi.org/project/ultralytics/) [![Downloads](https://static.pepy.tech/badge/ultralytics)](https://www.pepy.tech/projects/ultralytics) [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/ultralytics?logo=python&logoColor=gold)](https://pypi.org/project/ultralytics/)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "FyDnXd-n4c7Y", + "outputId": "e1d713ec-e8a6-4422-fe61-c76ec9f03df5" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ultralytics 8.2.3 🚀 Python-3.10.12 torch-2.2.1+cu121 CUDA:0 (T4, 15102MiB)\n", + "Setup complete ✅ (2 CPUs, 12.7 GB RAM, 28.8/78.2 GB disk)\n" + ] + } + ], + "source": [ + "%pip install ultralytics # install\n", + "from ultralytics import YOLO, checks, hub\n", + "\n", + "checks() # checks" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cQ9BwaAqxAm4" + }, + "source": [ + "# Start\n", + "\n", + "⚡ Login with your API key, load your YOLO 🚀 model and start training in 3 lines of code!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XSlZaJ9Iw_iZ" + }, + "outputs": [], + "source": [ + "# Log in to HUB using your API key (https://hub.ultralytics.com/settings?tab=api+keys)\n", + "hub.login(\"YOUR_API_KEY\")\n", + "\n", + "# Load your model from HUB (replace 'YOUR_MODEL_ID' with your model ID)\n", + "model = YOLO(\"https://hub.ultralytics.com/models/YOUR_MODEL_ID\")\n", + "\n", + "# Train the model\n", + "results = model.train()" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "name": "Ultralytics HUB", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/examples/object_counting.ipynb b/examples/object_counting.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..b1f0c523f2953fa55f3e2b831a14206afb29316c --- /dev/null +++ b/examples/object_counting.ipynb @@ -0,0 +1,200 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "PN1cAxdvd61e" + }, + "source": [ + "
\n", + "\n", + " \n", + " \n", + "\n", + " [中文](https://docs.ultralytics.com/zh/) | [한국어](https://docs.ultralytics.com/ko/) | [日本語](https://docs.ultralytics.com/ja/) | [Русский](https://docs.ultralytics.com/ru/) | [Deutsch](https://docs.ultralytics.com/de/) | [Français](https://docs.ultralytics.com/fr/) | [Español](https://docs.ultralytics.com/es/) | [Português](https://docs.ultralytics.com/pt/) | [Türkçe](https://docs.ultralytics.com/tr/) | [Tiếng Việt](https://docs.ultralytics.com/vi/) | [العربية](https://docs.ultralytics.com/ar/)\n", + "\n", + " \"Ultralytics\n", + " \"Run\n", + " \"Open\n", + " \"Open\n", + " \"Discord\"\n", + "\n", + "Welcome to the Ultralytics YOLO11 🚀 notebook! YOLO11 is the latest version of the YOLO (You Only Look Once) AI models developed by Ultralytics. This notebook serves as the starting point for exploring the various resources available to help you get started with YOLO11 and understand its features and capabilities.\n", + "\n", + "YOLO11 models are fast, accurate, and easy to use, making them ideal for various object detection and image segmentation tasks. They can be trained on large datasets and run on diverse hardware platforms, from CPUs to GPUs.\n", + "\n", + "We hope that the resources in this notebook will help you get the most out of YOLO11. Please browse the YOLO11 Object Counting Docs for details, raise an issue on GitHub for support, and join our Discord community for questions and discussions!\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "o68Sg1oOeZm2" + }, + "source": [ + "# Setup\n", + "\n", + "Pip install `ultralytics` and [dependencies](https://github.com/ultralytics/ultralytics/blob/main/pyproject.toml) and check software and hardware.\n", + "\n", + "[![PyPI - Version](https://img.shields.io/pypi/v/ultralytics?logo=pypi&logoColor=white)](https://pypi.org/project/ultralytics/) [![Downloads](https://static.pepy.tech/badge/ultralytics)](https://www.pepy.tech/projects/ultralytics) [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/ultralytics?logo=python&logoColor=gold)](https://pypi.org/project/ultralytics/)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "9dSwz_uOReMI", + "outputId": "fd3bab88-2f25-46c0-cae9-04d2beedc0c1" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ultralytics 8.2.18 🚀 Python-3.10.12 torch-2.2.1+cu121 CUDA:0 (T4, 15102MiB)\n", + "Setup complete ✅ (2 CPUs, 12.7 GB RAM, 29.8/78.2 GB disk)\n" + ] + } + ], + "source": [ + "%pip install ultralytics\n", + "import ultralytics\n", + "\n", + "ultralytics.checks()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "m7VkxQ2aeg7k" + }, + "source": [ + "# Object Counting using Ultralytics YOLO11 🚀\n", + "\n", + "## What is Object Counting?\n", + "\n", + "Object counting with [Ultralytics YOLO11](https://github.com/ultralytics/ultralytics/) involves accurate identification and counting of specific objects in videos and camera streams. YOLO11 excels in real-time applications, providing efficient and precise object counting for various scenarios like crowd analysis and surveillance, thanks to its state-of-the-art algorithms and deep learning capabilities.\n", + "\n", + "## Advantages of Object Counting?\n", + "\n", + "- **Resource Optimization:** Object counting facilitates efficient resource management by providing accurate counts, and optimizing resource allocation in applications like inventory management.\n", + "- **Enhanced Security:** Object counting enhances security and surveillance by accurately tracking and counting entities, aiding in proactive threat detection.\n", + "- **Informed Decision-Making:** Object counting offers valuable insights for decision-making, optimizing processes in retail, traffic management, and various other domains.\n", + "\n", + "## Real World Applications\n", + "\n", + "| Logistics | Aquaculture |\n", + "|:-------------------------------------------------------------------------------------------------------------------------------------------------------------:|:---------------------------------------------------------------------------------------------------------------------------------------------------:|\n", + "| ![Conveyor Belt Packets Counting Using Ultralytics YOLO11](https://github.com/RizwanMunawar/ultralytics/assets/62513924/70e2d106-510c-4c6c-a57a-d34a765aa757) | ![Fish Counting in Sea using Ultralytics YOLO11](https://github.com/RizwanMunawar/ultralytics/assets/62513924/c60d047b-3837-435f-8d29-bb9fc95d2191) |\n", + "| Conveyor Belt Packets Counting Using Ultralytics YOLO11 | Fish Counting in Sea using Ultralytics YOLO11 |\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Cx-u59HQdu2o" + }, + "outputs": [], + "source": [ + "import cv2\n", + "\n", + "from ultralytics import solutions\n", + "\n", + "# Open the video file\n", + "cap = 
cv2.VideoCapture(\"path/to/video/file.mp4\")\n", + "assert cap.isOpened(), \"Error reading video file\"\n", + "\n", + "# Get video properties: width, height, and frames per second (fps)\n", + "w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))\n", + "\n", + "# Define points for a line or region of interest in the video frame\n", + "line_points = [(20, 400), (1080, 400)] # Line coordinates\n", + "\n", + "# Initialize the video writer to save the output video\n", + "video_writer = cv2.VideoWriter(\"object_counting_output.avi\", cv2.VideoWriter_fourcc(*\"mp4v\"), fps, (w, h))\n", + "\n", + "# Initialize the Object Counter with visualization options and other parameters\n", + "counter = solutions.ObjectCounter(\n", + " show=True, # Display the image during processing\n", + " region=line_points, # Region of interest points\n", + " model=\"yolo11n.pt\", # Ultralytics YOLO11 model file\n", + " line_width=2, # Thickness of the lines and bounding boxes\n", + ")\n", + "\n", + "# Process video frames in a loop\n", + "while cap.isOpened():\n", + " success, im0 = cap.read()\n", + " if not success:\n", + " print(\"Video frame is empty or video processing has been successfully completed.\")\n", + " break\n", + "\n", + " # Use the Object Counter to count objects in the frame and get the annotated image\n", + " im0 = counter.count(im0)\n", + "\n", + " # Write the annotated frame to the output video\n", + " video_writer.write(im0)\n", + "\n", + "# Release the video capture and writer objects\n", + "cap.release()\n", + "video_writer.release()\n", + "\n", + "# Close all OpenCV windows\n", + "cv2.destroyAllWindows()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QrlKg-y3fEyD" + }, + "source": [ + "# Additional Resources\n", + "\n", + "## Community Support\n", + "\n", + "For more information on counting objects with Ultralytics, you can explore the comprehensive [Ultralytics Object Counting Docs](https://docs.ultralytics.com/guides/object-counting/). This guide covers everything from basic concepts to advanced techniques, ensuring you get the most out of counting and visualization.\n", + "\n", + "## Ultralytics ⚡ Resources\n", + "\n", + "At Ultralytics, we are committed to providing cutting-edge AI solutions. Here are some key resources to learn more about our company and get involved with our community:\n", + "\n", + "- [Ultralytics HUB](https://ultralytics.com/hub): Simplify your AI projects with Ultralytics HUB, our no-code tool for effortless YOLO training and deployment.\n", + "- [Ultralytics Licensing](https://ultralytics.com/license): Review our licensing terms to understand how you can use our software in your projects.\n", + "- [About Us](https://ultralytics.com/about): Discover our mission, vision, and the story behind Ultralytics.\n", + "- [Join Our Team](https://ultralytics.com/work): Explore career opportunities and join our team of talented professionals.\n", + "\n", + "## YOLO11 🚀 Resources\n", + "\n", + "YOLO11 is the latest evolution in the YOLO series, offering state-of-the-art performance in object detection and image segmentation. 
Here are some essential resources to help you get started with YOLO11:\n", + "\n", + "- [GitHub](https://github.com/ultralytics/ultralytics): Access the YOLO11 repository on GitHub, where you can find the source code, contribute to the project, and report issues.\n", + "- [Docs](https://docs.ultralytics.com/): Explore the official documentation for YOLO11, including installation guides, tutorials, and detailed API references.\n", + "- [Discord](https://ultralytics.com/discord): Join our Discord community to connect with other users, share your projects, and get help from the Ultralytics team.\n", + "\n", + "These resources are designed to help you leverage the full potential of Ultralytics' offerings and YOLO11. Whether you're a beginner or an experienced developer, you'll find the information and support you need to succeed." + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/examples/object_tracking.ipynb b/examples/object_tracking.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..f89c34ddeae2554d4121d77f39c8401793f2768f --- /dev/null +++ b/examples/object_tracking.ipynb @@ -0,0 +1,245 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "PN1cAxdvd61e" + }, + "source": [ + "
\n", + "\n", + " \n", + " \n", + "\n", + " [中文](https://docs.ultralytics.com/zh/) | [한국어](https://docs.ultralytics.com/ko/) | [日本語](https://docs.ultralytics.com/ja/) | [Русский](https://docs.ultralytics.com/ru/) | [Deutsch](https://docs.ultralytics.com/de/) | [Français](https://docs.ultralytics.com/fr/) | [Español](https://docs.ultralytics.com/es/) | [Português](https://docs.ultralytics.com/pt/) | [Türkçe](https://docs.ultralytics.com/tr/) | [Tiếng Việt](https://docs.ultralytics.com/vi/) | [العربية](https://docs.ultralytics.com/ar/)\n", + "\n", + " \"Ultralytics\n", + " \"Run\n", + " \"Open\n", + " \"Open\n", + " \"Discord\"\n", + "\n", + "Welcome to the Ultralytics YOLO11 🚀 notebook! YOLO11 is the latest version of the YOLO (You Only Look Once) AI models developed by Ultralytics. This notebook serves as the starting point for exploring the various resources available to help you get started with YOLO11 and understand its features and capabilities.\n", + "\n", + "YOLO11 models are fast, accurate, and easy to use, making them ideal for various object detection and image segmentation tasks. They can be trained on large datasets and run on diverse hardware platforms, from CPUs to GPUs.\n", + "\n", + "We hope that the resources in this notebook will help you get the most out of YOLO11. Please browse the YOLO11 Tracking Docs for details, raise an issue on GitHub for support, and join our Discord community for questions and discussions!\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "o68Sg1oOeZm2" + }, + "source": [ + "# Setup\n", + "\n", + "Pip install `ultralytics` and [dependencies](https://github.com/ultralytics/ultralytics/blob/main/pyproject.toml) and check software and hardware.\n", + "\n", + "[![PyPI - Version](https://img.shields.io/pypi/v/ultralytics?logo=pypi&logoColor=white)](https://pypi.org/project/ultralytics/) [![Downloads](https://static.pepy.tech/badge/ultralytics)](https://www.pepy.tech/projects/ultralytics) [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/ultralytics?logo=python&logoColor=gold)](https://pypi.org/project/ultralytics/)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "9dSwz_uOReMI", + "outputId": "ed8c2370-8fc7-4e4e-f669-d0bae4d944e9" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ultralytics 8.2.17 🚀 Python-3.10.12 torch-2.2.1+cu121 CUDA:0 (T4, 15102MiB)\n", + "Setup complete ✅ (2 CPUs, 12.7 GB RAM, 29.8/78.2 GB disk)\n" + ] + } + ], + "source": [ + "%pip install ultralytics\n", + "import ultralytics\n", + "\n", + "ultralytics.checks()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "m7VkxQ2aeg7k" + }, + "source": [ + "# Ultralytics Object Tracking\n", + "\n", + "[Ultralytics YOLO11](https://github.com/ultralytics/ultralytics/) instance segmentation involves identifying and outlining individual objects in an image, providing a detailed understanding of spatial distribution. Unlike semantic segmentation, it uniquely labels and precisely delineates each object, crucial for tasks like object detection and medical imaging.\n", + "\n", + "There are two types of instance segmentation tracking available in the Ultralytics package:\n", + "\n", + "- **Instance Segmentation with Class Objects:** Each class object is assigned a unique color for clear visual separation.\n", + "\n", + "- **Instance Segmentation with Object Tracks:** Every track is represented by a distinct color, facilitating easy identification and tracking.\n", + "\n", + "## Samples\n", + "\n", + "| Instance Segmentation | Instance Segmentation + Object Tracking |\n", + "|:---------------------------------------------------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------------------------------------------:|\n", + "| ![Ultralytics Instance Segmentation](https://github.com/RizwanMunawar/ultralytics/assets/62513924/d4ad3499-1f33-4871-8fbc-1be0b2643aa2) | ![Ultralytics Instance Segmentation with Object Tracking](https://github.com/RizwanMunawar/ultralytics/assets/62513924/2e5c38cc-fd5c-4145-9682-fa94ae2010a0) |\n", + "| Ultralytics Instance Segmentation 😍 | Ultralytics Instance Segmentation with Object Tracking 🔥 |" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-ZF9DM6e6gz0" + }, + "source": [ + "## CLI\n", + "\n", + "Command-Line Interface (CLI) example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-XJqhOwo6iqT" + }, + "outputs": [], + "source": [ + "!yolo track source=\"/path/to/video/file.mp4\" save=True" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XRcw0vIE6oNb" + }, + "source": [ + "## Python\n", + "\n", + "Python Instance Segmentation and Object tracking example." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Cx-u59HQdu2o" + }, + "outputs": [], + "source": [ + "from collections import defaultdict\n", + "\n", + "import cv2\n", + "\n", + "from ultralytics import YOLO\n", + "from ultralytics.utils.plotting import Annotator, colors\n", + "\n", + "# Dictionary to store tracking history with default empty lists\n", + "track_history = defaultdict(lambda: [])\n", + "\n", + "# Load the YOLO model with segmentation capabilities\n", + "model = YOLO(\"yolo11n-seg.pt\")\n", + "\n", + "# Open the video file\n", + "cap = cv2.VideoCapture(\"path/to/video/file.mp4\")\n", + "\n", + "# Retrieve video properties: width, height, and frames per second\n", + "w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))\n", + "\n", + "# Initialize video writer to save the output video with the specified properties\n", + "out = cv2.VideoWriter(\"instance-segmentation-object-tracking.avi\", cv2.VideoWriter_fourcc(*\"MJPG\"), fps, (w, h))\n", + "\n", + "while True:\n", + " # Read a frame from the video\n", + " ret, im0 = cap.read()\n", + " if not ret:\n", + " print(\"Video frame is empty or video processing has been successfully completed.\")\n", + " break\n", + "\n", + " # Create an annotator object to draw on the frame\n", + " annotator = Annotator(im0, line_width=2)\n", + "\n", + " # Perform object tracking on the current frame\n", + " results = model.track(im0, persist=True)\n", + "\n", + " # Check if tracking IDs and masks are present in the results\n", + " if results[0].boxes.id is not None and results[0].masks is not None:\n", + " # Extract masks and tracking IDs\n", + " masks = results[0].masks.xy\n", + " track_ids = results[0].boxes.id.int().cpu().tolist()\n", + "\n", + " # Annotate each mask with its corresponding tracking ID and color\n", + " for mask, track_id in zip(masks, track_ids):\n", + " annotator.seg_bbox(mask=mask, mask_color=colors(int(track_id), True), label=str(track_id))\n", + "\n", + " # Write the annotated frame to the output video\n", + " out.write(im0)\n", + " # Display the annotated frame\n", + " cv2.imshow(\"instance-segmentation-object-tracking\", im0)\n", + "\n", + " # Exit the loop if 'q' is pressed\n", + " if cv2.waitKey(1) & 0xFF == ord(\"q\"):\n", + " break\n", + "\n", + "# Release the video writer and capture objects, and close all OpenCV windows\n", + "out.release()\n", + "cap.release()\n", + "cv2.destroyAllWindows()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QrlKg-y3fEyD" + }, + "source": [ + "# Additional Resources\n", + "\n", + "## Community Support\n", + "\n", + "For more information on using tracking with Ultralytics, you can explore the comprehensive [Ultralytics Tracking Docs](https://docs.ultralytics.com/modes/track/). This guide covers everything from basic concepts to advanced techniques, ensuring you get the most out of tracking and visualization.\n", + "\n", + "## Ultralytics ⚡ Resources\n", + "\n", + "At Ultralytics, we are committed to providing cutting-edge AI solutions. 
Here are some key resources to learn more about our company and get involved with our community:\n", + "\n", + "- [Ultralytics HUB](https://ultralytics.com/hub): Simplify your AI projects with Ultralytics HUB, our no-code tool for effortless YOLO training and deployment.\n", + "- [Ultralytics Licensing](https://ultralytics.com/license): Review our licensing terms to understand how you can use our software in your projects.\n", + "- [About Us](https://ultralytics.com/about): Discover our mission, vision, and the story behind Ultralytics.\n", + "- [Join Our Team](https://ultralytics.com/work): Explore career opportunities and join our team of talented professionals.\n", + "\n", + "## YOLO11 🚀 Resources\n", + "\n", + "YOLO11 is the latest evolution in the YOLO series, offering state-of-the-art performance in object detection and image segmentation. Here are some essential resources to help you get started with YOLO11:\n", + "\n", + "- [GitHub](https://github.com/ultralytics/ultralytics): Access the YOLO11 repository on GitHub, where you can find the source code, contribute to the project, and report issues.\n", + "- [Docs](https://docs.ultralytics.com/): Explore the official documentation for YOLO11, including installation guides, tutorials, and detailed API references.\n", + "- [Discord](https://ultralytics.com/discord): Join our Discord community to connect with other users, share your projects, and get help from the Ultralytics team.\n", + "\n", + "These resources are designed to help you leverage the full potential of Ultralytics' offerings and YOLO11. Whether you're a beginner or an experienced developer, you'll find the information and support you need to succeed." + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/examples/tutorial.ipynb b/examples/tutorial.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..9ed5dc32b47b6d403bd54752881249743a2c4b46 --- /dev/null +++ b/examples/tutorial.ipynb @@ -0,0 +1,665 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "YOLO11 Tutorial", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "t6MPjfT5NrKQ" + }, + "source": [ + "
\n", + "\n", + " \n", + " \n", + "\n", + " [中文](https://docs.ultralytics.com/zh/) | [한국어](https://docs.ultralytics.com/ko/) | [日本語](https://docs.ultralytics.com/ja/) | [Русский](https://docs.ultralytics.com/ru/) | [Deutsch](https://docs.ultralytics.com/de/) | [Français](https://docs.ultralytics.com/fr/) | [Español](https://docs.ultralytics.com/es/) | [Português](https://docs.ultralytics.com/pt/) | [Türkçe](https://docs.ultralytics.com/tr/) | [Tiếng Việt](https://docs.ultralytics.com/vi/) | [العربية](https://docs.ultralytics.com/ar/)\n", + "\n", + " \"Ultralytics\n", + " \"Run\n", + " \"Open\n", + " \"Open\n", + "\n", + " \"Discord\"\n", + " \"Ultralytics\n", + " \"Ultralytics\n", + "\n", + "Welcome to the Ultralytics YOLO11 🚀 notebook! YOLO11 is the latest version of the YOLO (You Only Look Once) AI models developed by Ultralytics. This notebook serves as the starting point for exploring the various resources available to help you get started with YOLO11 and understand its features and capabilities.\n", + "\n", + "YOLO11 models are fast, accurate, and easy to use, making them ideal for various object detection and image segmentation tasks. They can be trained on large datasets and run on diverse hardware platforms, from CPUs to GPUs.\n", + "\n", + "We hope that the resources in this notebook will help you get the most out of YOLO11. Please browse the YOLO11 Docs for details, raise an issue on GitHub for support, and join our Discord community for questions and discussions!\n", + "\n", + " \n", + " \"Ultralytics\n", + " \n", + "

\n", + " Watch: How to Train\n", + " Ultralytics\n", + " YOLO11 Model on Custom Dataset using Google Colab Notebook 🚀

\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7mGmQbAO5pQb" + }, + "source": [ + "# Setup\n", + "\n", + "Pip install `ultralytics` and [dependencies](https://github.com/ultralytics/ultralytics/blob/main/pyproject.toml) and check software and hardware.\n", + "\n", + "[![PyPI - Version](https://img.shields.io/pypi/v/ultralytics?logo=pypi&logoColor=white)](https://pypi.org/project/ultralytics/) [![Downloads](https://static.pepy.tech/badge/ultralytics)](https://www.pepy.tech/projects/ultralytics) [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/ultralytics?logo=python&logoColor=gold)](https://pypi.org/project/ultralytics/)" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "wbvMlHd_QwMG", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "2e992f9f-90bb-4668-de12-fed629975285" + }, + "source": [ + "%pip install ultralytics\n", + "import ultralytics\n", + "ultralytics.checks()" + ], + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Ultralytics 8.3.2 🚀 Python-3.10.12 torch-2.4.1+cu121 CUDA:0 (Tesla T4, 15102MiB)\n", + "Setup complete ✅ (2 CPUs, 12.7 GB RAM, 41.1/112.6 GB disk)\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4JnkELT0cIJg" + }, + "source": [ + "# 1. Predict\n", + "\n", + "YOLO11 may be used directly in the Command Line Interface (CLI) with a `yolo` command for a variety of tasks and modes and accepts additional arguments, i.e. `imgsz=640`. See a full list of available `yolo` [arguments](https://docs.ultralytics.com/usage/cfg/) and other details in the [YOLO11 Predict Docs](https://docs.ultralytics.com/modes/train/).\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "zR9ZbuQCH7FX", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "e3ebec6f-658a-4803-d80c-e07d12908767" + }, + "source": [ + "# Run inference on an image with YOLO11n\n", + "!yolo predict model=yolo11n.pt source='https://ultralytics.com/images/zidane.jpg'" + ], + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n.pt to 'yolo11n.pt'...\n", + "100% 5.35M/5.35M [00:00<00:00, 72.7MB/s]\n", + "Ultralytics 8.3.2 🚀 Python-3.10.12 torch-2.4.1+cu121 CUDA:0 (Tesla T4, 15102MiB)\n", + "YOLO11n summary (fused): 238 layers, 2,616,248 parameters, 0 gradients, 6.5 GFLOPs\n", + "\n", + "Downloading https://ultralytics.com/images/zidane.jpg to 'zidane.jpg'...\n", + "100% 49.2k/49.2k [00:00<00:00, 5.37MB/s]\n", + "image 1/1 /content/zidane.jpg: 384x640 2 persons, 1 tie, 63.4ms\n", + "Speed: 14.5ms preprocess, 63.4ms inference, 820.9ms postprocess per image at shape (1, 3, 384, 640)\n", + "Results saved to \u001b[1mruns/detect/predict\u001b[0m\n", + "💡 Learn more at https://docs.ultralytics.com/modes/predict\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hkAzDWJ7cWTr" + }, + "source": [ + "        \n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0eq1SMWl6Sfn" + }, + "source": [ + "# 2. Val\n", + "Validate a model's accuracy on the [COCO](https://docs.ultralytics.com/datasets/detect/coco/) dataset's `val` or `test` splits. The latest YOLO11 [models](https://github.com/ultralytics/ultralytics#models) are downloaded automatically the first time they are used. See [YOLO11 Val Docs](https://docs.ultralytics.com/modes/val/) for more information." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "WQPtK1QYVaD_" + }, + "source": [ + "# Download COCO val\n", + "import torch\n", + "torch.hub.download_url_to_file('https://ultralytics.com/assets/coco2017val.zip', 'tmp.zip') # download (780M - 5000 images)\n", + "!unzip -q tmp.zip -d datasets && rm tmp.zip # unzip" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "X58w8JLpMnjH", + "outputId": "af2a5deb-029b-466d-96a4-bd3e406987fa", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "source": [ + "# Validate YOLO11n on COCO8 val\n", + "!yolo val model=yolo11n.pt data=coco8.yaml" + ], + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Ultralytics 8.3.2 🚀 Python-3.10.12 torch-2.4.1+cu121 CUDA:0 (Tesla T4, 15102MiB)\n", + "YOLO11n summary (fused): 238 layers, 2,616,248 parameters, 0 gradients, 6.5 GFLOPs\n", + "\n", + "Dataset 'coco8.yaml' images not found ⚠️, missing path '/content/datasets/coco8/images/val'\n", + "Downloading https://ultralytics.com/assets/coco8.zip to '/content/datasets/coco8.zip'...\n", + "100% 433k/433k [00:00<00:00, 15.8MB/s]\n", + "Unzipping /content/datasets/coco8.zip to /content/datasets/coco8...: 100% 25/25 [00:00<00:00, 1188.35file/s]\n", + "Dataset download success ✅ (1.4s), saved to \u001b[1m/content/datasets\u001b[0m\n", + "\n", + "Downloading https://ultralytics.com/assets/Arial.ttf to '/root/.config/Ultralytics/Arial.ttf'...\n", + "100% 755k/755k [00:00<00:00, 17.7MB/s]\n", + "\u001b[34m\u001b[1mval: \u001b[0mScanning /content/datasets/coco8/labels/val... 4 images, 0 backgrounds, 0 corrupt: 100% 4/4 [00:00<00:00, 142.04it/s]\n", + "\u001b[34m\u001b[1mval: \u001b[0mNew cache created: /content/datasets/coco8/labels/val.cache\n", + " Class Images Instances Box(P R mAP50 mAP50-95): 100% 1/1 [00:04<00:00, 4.75s/it]\n", + " all 4 17 0.57 0.85 0.847 0.632\n", + " person 3 10 0.557 0.6 0.585 0.272\n", + " dog 1 1 0.548 1 0.995 0.697\n", + " horse 1 2 0.531 1 0.995 0.674\n", + " elephant 1 2 0.371 0.5 0.516 0.256\n", + " umbrella 1 1 0.569 1 0.995 0.995\n", + " potted plant 1 1 0.847 1 0.995 0.895\n", + "Speed: 1.0ms preprocess, 73.8ms inference, 0.0ms loss, 561.4ms postprocess per image\n", + "Results saved to \u001b[1mruns/detect/val\u001b[0m\n", + "💡 Learn more at https://docs.ultralytics.com/modes/val\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZY2VXXXu74w5" + }, + "source": [ + "# 3. Train\n", + "\n", + "

\n", + "\n", + "Train YOLO11 on [Detect](https://docs.ultralytics.com/tasks/detect/), [Segment](https://docs.ultralytics.com/tasks/segment/), [Classify](https://docs.ultralytics.com/tasks/classify/) and [Pose](https://docs.ultralytics.com/tasks/pose/) datasets. See [YOLO11 Train Docs](https://docs.ultralytics.com/modes/train/) for more information." + ] + }, + { + "cell_type": "code", + "source": [ + "#@title Select YOLO11 🚀 logger {run: 'auto'}\n", + "logger = 'Comet' #@param ['Comet', 'TensorBoard']\n", + "\n", + "if logger == 'Comet':\n", + " %pip install -q comet_ml\n", + " import comet_ml; comet_ml.init()\n", + "elif logger == 'TensorBoard':\n", + " %load_ext tensorboard\n", + " %tensorboard --logdir ." + ], + "metadata": { + "id": "ktegpM42AooT" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "1NcFxRcFdJ_O", + "outputId": "952f35f7-666f-4121-fbdf-2b3a33b28081", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "source": [ + "# Train YOLO11n on COCO8 for 3 epochs\n", + "!yolo train model=yolo11n.pt data=coco8.yaml epochs=3 imgsz=640" + ], + "execution_count": 7, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Ultralytics 8.3.2 🚀 Python-3.10.12 torch-2.4.1+cu121 CUDA:0 (Tesla T4, 15102MiB)\n", + "\u001b[34m\u001b[1mengine/trainer: \u001b[0mtask=detect, mode=train, model=yolo11n.pt, data=coco8.yaml, epochs=3, time=None, patience=100, batch=16, imgsz=640, save=True, save_period=-1, cache=False, device=None, workers=8, project=None, name=train3, exist_ok=False, pretrained=True, optimizer=auto, verbose=True, seed=0, deterministic=True, single_cls=False, rect=False, cos_lr=False, close_mosaic=10, resume=False, amp=True, fraction=1.0, profile=False, freeze=None, multi_scale=False, overlap_mask=True, mask_ratio=4, dropout=0.0, val=True, split=val, save_json=False, save_hybrid=False, conf=None, iou=0.7, max_det=300, half=False, dnn=False, plots=True, source=None, vid_stride=1, stream_buffer=False, visualize=False, augment=False, agnostic_nms=False, classes=None, retina_masks=False, embed=None, show=False, save_frames=False, save_txt=False, save_conf=False, save_crop=False, show_labels=True, show_conf=True, show_boxes=True, line_width=None, format=torchscript, keras=False, optimize=False, int8=False, dynamic=False, simplify=True, opset=None, workspace=4, nms=False, lr0=0.01, lrf=0.01, momentum=0.937, weight_decay=0.0005, warmup_epochs=3.0, warmup_momentum=0.8, warmup_bias_lr=0.1, box=7.5, cls=0.5, dfl=1.5, pose=12.0, kobj=1.0, label_smoothing=0.0, nbs=64, hsv_h=0.015, hsv_s=0.7, hsv_v=0.4, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, flipud=0.0, fliplr=0.5, bgr=0.0, mosaic=1.0, mixup=0.0, copy_paste=0.0, copy_paste_mode=flip, auto_augment=randaugment, erasing=0.4, crop_fraction=1.0, cfg=None, tracker=botsort.yaml, save_dir=runs/detect/train3\n", + "\n", + " from n params module arguments \n", + " 0 -1 1 464 ultralytics.nn.modules.conv.Conv [3, 16, 3, 2] \n", + " 1 -1 1 4672 ultralytics.nn.modules.conv.Conv [16, 32, 3, 2] \n", + " 2 -1 1 6640 ultralytics.nn.modules.block.C3k2 [32, 64, 1, False, 0.25] \n", + " 3 -1 1 36992 ultralytics.nn.modules.conv.Conv [64, 64, 3, 2] \n", + " 4 -1 1 26080 ultralytics.nn.modules.block.C3k2 [64, 128, 1, False, 0.25] \n", + " 5 -1 1 147712 ultralytics.nn.modules.conv.Conv [128, 128, 3, 2] \n", + " 6 -1 1 87040 ultralytics.nn.modules.block.C3k2 [128, 128, 1, True] \n", + " 7 -1 1 295424 ultralytics.nn.modules.conv.Conv [128, 256, 3, 2] \n", 
+ " 8 -1 1 346112 ultralytics.nn.modules.block.C3k2 [256, 256, 1, True] \n", + " 9 -1 1 164608 ultralytics.nn.modules.block.SPPF [256, 256, 5] \n", + " 10 -1 1 249728 ultralytics.nn.modules.block.C2PSA [256, 256, 1] \n", + " 11 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n", + " 12 [-1, 6] 1 0 ultralytics.nn.modules.conv.Concat [1] \n", + " 13 -1 1 111296 ultralytics.nn.modules.block.C3k2 [384, 128, 1, False] \n", + " 14 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] \n", + " 15 [-1, 4] 1 0 ultralytics.nn.modules.conv.Concat [1] \n", + " 16 -1 1 32096 ultralytics.nn.modules.block.C3k2 [256, 64, 1, False] \n", + " 17 -1 1 36992 ultralytics.nn.modules.conv.Conv [64, 64, 3, 2] \n", + " 18 [-1, 13] 1 0 ultralytics.nn.modules.conv.Concat [1] \n", + " 19 -1 1 86720 ultralytics.nn.modules.block.C3k2 [192, 128, 1, False] \n", + " 20 -1 1 147712 ultralytics.nn.modules.conv.Conv [128, 128, 3, 2] \n", + " 21 [-1, 10] 1 0 ultralytics.nn.modules.conv.Concat [1] \n", + " 22 -1 1 378880 ultralytics.nn.modules.block.C3k2 [384, 256, 1, True] \n", + " 23 [16, 19, 22] 1 464912 ultralytics.nn.modules.head.Detect [80, [64, 128, 256]] \n", + "YOLO11n summary: 319 layers, 2,624,080 parameters, 2,624,064 gradients, 6.6 GFLOPs\n", + "\n", + "Transferred 499/499 items from pretrained weights\n", + "\u001b[34m\u001b[1mTensorBoard: \u001b[0mStart with 'tensorboard --logdir runs/detect/train', view at http://localhost:6006/\n", + "Freezing layer 'model.23.dfl.conv.weight'\n", + "\u001b[34m\u001b[1mAMP: \u001b[0mrunning Automatic Mixed Precision (AMP) checks with YOLO11n...\n", + "\u001b[34m\u001b[1mAMP: \u001b[0mchecks passed ✅\n", + "\u001b[34m\u001b[1mtrain: \u001b[0mScanning /content/datasets/coco8/labels/train.cache... 4 images, 0 backgrounds, 0 corrupt: 100% 4/4 [00:00\n" + ], + "metadata": { + "id": "Phm9ccmOKye5" + } + }, + { + "cell_type": "markdown", + "source": [ + "## 1. Detection\n", + "\n", + "YOLO11 _detection_ models have no suffix and are the default YOLO11 models, i.e. `yolo11n.pt` and are pretrained on COCO. See [Detection Docs](https://docs.ultralytics.com/tasks/detect/) for full details.\n" + ], + "metadata": { + "id": "yq26lwpYK1lq" + } + }, + { + "cell_type": "code", + "source": [ + "# Load YOLO11n, train it on COCO128 for 3 epochs and predict an image with it\n", + "from ultralytics import YOLO\n", + "\n", + "model = YOLO('yolo11n.pt') # load a pretrained YOLO detection model\n", + "model.train(data='coco8.yaml', epochs=3) # train the model\n", + "model('https://ultralytics.com/images/bus.jpg') # predict on an image" + ], + "metadata": { + "id": "8Go5qqS9LbC5" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## 2. Segmentation\n", + "\n", + "YOLO11 _segmentation_ models use the `-seg` suffix, i.e. `yolo11n-seg.pt` and are pretrained on COCO. 
See [Segmentation Docs](https://docs.ultralytics.com/tasks/segment/) for full details.\n"
+   ],
+   "metadata": {
+    "id": "7ZW58jUzK66B"
+   }
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "# Load YOLO11n-seg, train it on COCO8-seg for 3 epochs and predict an image with it\n",
+    "from ultralytics import YOLO\n",
+    "\n",
+    "model = YOLO('yolo11n-seg.pt') # load a pretrained YOLO segmentation model\n",
+    "model.train(data='coco8-seg.yaml', epochs=3) # train the model\n",
+    "model('https://ultralytics.com/images/bus.jpg') # predict on an image"
+   ],
+   "metadata": {
+    "id": "WFPJIQl_L5HT"
+   },
+   "execution_count": null,
+   "outputs": []
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## 3. Classification\n",
+    "\n",
+    "YOLO11 _classification_ models use the `-cls` suffix, i.e. `yolo11n-cls.pt` and are pretrained on ImageNet. See [Classification Docs](https://docs.ultralytics.com/tasks/classify/) for full details.\n"
+   ],
+   "metadata": {
+    "id": "ax3p94VNK9zR"
+   }
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "# Load YOLO11n-cls, train it on mnist160 for 3 epochs and predict an image with it\n",
+    "from ultralytics import YOLO\n",
+    "\n",
+    "model = YOLO('yolo11n-cls.pt') # load a pretrained YOLO classification model\n",
+    "model.train(data='mnist160', epochs=3) # train the model\n",
+    "model('https://ultralytics.com/images/bus.jpg') # predict on an image"
+   ],
+   "metadata": {
+    "id": "5q9Zu6zlL5rS"
+   },
+   "execution_count": null,
+   "outputs": []
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## 4. Pose\n",
+    "\n",
+    "YOLO11 _pose_ models use the `-pose` suffix, i.e. `yolo11n-pose.pt` and are pretrained on COCO Keypoints. See [Pose Docs](https://docs.ultralytics.com/tasks/pose/) for full details."
+   ],
+   "metadata": {
+    "id": "SpIaFLiO11TG"
+   }
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "# Load YOLO11n-pose, train it on COCO8-pose for 3 epochs and predict an image with it\n",
+    "from ultralytics import YOLO\n",
+    "\n",
+    "model = YOLO('yolo11n-pose.pt') # load a pretrained YOLO pose model\n",
+    "model.train(data='coco8-pose.yaml', epochs=3) # train the model\n",
+    "model('https://ultralytics.com/images/bus.jpg') # predict on an image"
+   ],
+   "metadata": {
+    "id": "si4aKFNg19vX"
+   },
+   "execution_count": null,
+   "outputs": []
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## 5. Oriented Bounding Boxes (OBB)\n",
+    "\n",
+    "YOLO11 _OBB_ models use the `-obb` suffix, i.e. `yolo11n-obb.pt` and are pretrained on the DOTA dataset. See [OBB Docs](https://docs.ultralytics.com/tasks/obb/) for full details."
+   ],
+   "metadata": {
+    "id": "cf5j_T9-B5F0"
+   }
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "# Load YOLO11n-obb, train it on DOTA8 for 3 epochs and predict an image with it\n",
+    "from ultralytics import YOLO\n",
+    "\n",
+    "model = YOLO('yolo11n-obb.pt') # load a pretrained YOLO OBB model\n",
+    "model.train(data='dota8.yaml', epochs=3) # train the model\n",
+    "model('https://ultralytics.com/images/boats.jpg') # predict on an image"
+   ],
+   "metadata": {
+    "id": "IJNKClOOB5YS"
+   },
+   "execution_count": null,
+   "outputs": []
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "IEijrePND_2I"
+   },
+   "source": [
+    "# Appendix\n",
+    "\n",
+    "Additional installation, testing, and multi-model validation utilities below."
+ ] + }, + { + "cell_type": "code", + "source": [ + "# Pip install from source\n", + "!pip install git+https://github.com/ultralytics/ultralytics@main" + ], + "metadata": { + "id": "pIdE6i8C3LYp" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Git clone and run tests on updates branch\n", + "!git clone https://github.com/ultralytics/ultralytics -b main\n", + "%pip install -qe ultralytics" + ], + "metadata": { + "id": "uRKlwxSJdhd1" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Run tests (Git clone only)\n", + "!pytest ultralytics/tests" + ], + "metadata": { + "id": "GtPlh7mcCGZX" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Validate multiple models\n", + "for x in 'nsmlx':\n", + " !yolo val model=yolo11{x}.pt data=coco.yaml" + ], + "metadata": { + "id": "Wdc6t_bfzDDk" + }, + "execution_count": null, + "outputs": [] + } + ] +} diff --git a/logs/yolov12l.csv b/logs/yolov12l.csv new file mode 100644 index 0000000000000000000000000000000000000000..077042f5378a1d96d53e5dec8322cc4b3c15eeb1 --- /dev/null +++ b/logs/yolov12l.csv @@ -0,0 +1,601 @@ +epoch,time,train/box_loss,train/cls_loss,train/dfl_loss,metrics/precision(B),metrics/recall(B),metrics/mAP50(B),metrics/mAP50-95(B),val/box_loss,val/cls_loss,val/dfl_loss,lr/pg0,lr/pg1,lr/pg2 +1,339.515,3.50537,5.39731,3.93491,0.00124,0.03039,0.00102,0.00036,2.90862,5.36478,3.38559,0.00332613,0.00332613,0.00332613 +2,674.31,2.29052,3.87526,2.41236,0.29449,0.04716,0.02494,0.01215,1.98648,3.40587,2.2623,0.00664848,0.00664848,0.00664848 +3,1009.33,1.82444,3.17865,1.90075,0.18715,0.11045,0.06667,0.03792,1.68442,2.86791,1.92568,0.00995982,0.00995982,0.00995982 +4,1343.43,1.60355,2.69529,1.68308,0.27454,0.18943,0.14434,0.08914,1.48768,2.37295,1.71398,0.0099505,0.0099505,0.0099505 +5,1676.4,1.47571,2.37954,1.56957,0.34372,0.23998,0.21,0.13412,1.38809,2.11763,1.62507,0.009934,0.009934,0.009934 +6,2010.38,1.40929,2.18015,1.50784,0.37788,0.27837,0.25476,0.16657,1.32322,1.92451,1.55641,0.0099175,0.0099175,0.0099175 +7,2343.96,1.35212,2.03533,1.46223,0.44001,0.30475,0.29696,0.19705,1.27377,1.78449,1.5056,0.009901,0.009901,0.009901 +8,2676.06,1.3158,1.9333,1.43198,0.44891,0.32622,0.32489,0.21996,1.23968,1.69481,1.47202,0.0098845,0.0098845,0.0098845 +9,3009.05,1.27741,1.82506,1.3979,0.49566,0.34958,0.35489,0.24383,1.20284,1.60799,1.44046,0.009868,0.009868,0.009868 +10,3341.57,1.26078,1.78262,1.38231,0.50673,0.36221,0.37492,0.25732,1.18591,1.5561,1.42331,0.0098515,0.0098515,0.0098515 +11,3673.47,1.24672,1.74267,1.36511,0.53834,0.37242,0.39763,0.27636,1.15624,1.48566,1.38694,0.009835,0.009835,0.009835 +12,4006.24,1.2313,1.69314,1.34667,0.53183,0.38781,0.41007,0.28451,1.1439,1.43917,1.37315,0.0098185,0.0098185,0.0098185 +13,4338.94,1.21715,1.6509,1.33201,0.55692,0.40192,0.4247,0.29919,1.12385,1.39582,1.35435,0.009802,0.009802,0.009802 +14,4672.12,1.20509,1.62334,1.32076,0.54257,0.41828,0.43731,0.30779,1.11269,1.3637,1.34427,0.0097855,0.0097855,0.0097855 +15,5006.71,1.19163,1.59336,1.31065,0.58296,0.41486,0.44933,0.31825,1.1003,1.32987,1.33075,0.009769,0.009769,0.009769 +16,5340.58,1.18282,1.57683,1.30227,0.55761,0.44287,0.46322,0.32927,1.08957,1.29571,1.31662,0.0097525,0.0097525,0.0097525 +17,5673.79,1.16887,1.56098,1.29165,0.5757,0.44285,0.47005,0.33476,1.07934,1.27309,1.30628,0.009736,0.009736,0.009736 
+18,6006.44,1.17089,1.53612,1.29179,0.57017,0.45739,0.48112,0.34346,1.07282,1.25585,1.29976,0.0097195,0.0097195,0.0097195 +19,6339.58,1.15918,1.52115,1.2778,0.57292,0.46464,0.49016,0.35025,1.06064,1.23206,1.28982,0.009703,0.009703,0.009703 +20,6672.2,1.1474,1.50456,1.27399,0.57782,0.46417,0.49455,0.35502,1.05556,1.21453,1.28223,0.0096865,0.0096865,0.0096865 +21,7005.81,1.14619,1.48513,1.26837,0.60426,0.46899,0.50412,0.3633,1.04847,1.19666,1.27322,0.00967,0.00967,0.00967 +22,7338.86,1.13456,1.47845,1.25841,0.60664,0.47729,0.5116,0.36757,1.04232,1.18328,1.2668,0.0096535,0.0096535,0.0096535 +23,7671.45,1.13297,1.46623,1.25447,0.59445,0.49168,0.51799,0.3737,1.03811,1.17036,1.26364,0.009637,0.009637,0.009637 +24,8004.47,1.12597,1.44845,1.25281,0.60622,0.48754,0.52069,0.37656,1.03291,1.16242,1.2585,0.0096205,0.0096205,0.0096205 +25,8337.24,1.12306,1.43984,1.24951,0.61931,0.48193,0.52364,0.37965,1.02809,1.15354,1.25248,0.009604,0.009604,0.009604 +26,8670.01,1.12349,1.43501,1.24851,0.63143,0.47851,0.52814,0.38322,1.02405,1.14283,1.25152,0.0095875,0.0095875,0.0095875 +27,9002.2,1.12864,1.44485,1.2523,0.62479,0.48835,0.53199,0.38683,1.02104,1.13615,1.24809,0.009571,0.009571,0.009571 +28,9336.02,1.11923,1.42122,1.24311,0.62711,0.48906,0.53504,0.38907,1.01765,1.12833,1.24453,0.0095545,0.0095545,0.0095545 +29,9668.81,1.11019,1.42381,1.23884,0.63298,0.49252,0.53897,0.39257,1.01426,1.12196,1.24169,0.009538,0.009538,0.009538 +30,10001.6,1.11902,1.40856,1.23924,0.63402,0.49027,0.54058,0.39316,1.01318,1.11689,1.23994,0.0095215,0.0095215,0.0095215 +31,10333.9,1.1096,1.40098,1.23409,0.62188,0.50063,0.54372,0.39555,1.01067,1.11139,1.23633,0.009505,0.009505,0.009505 +32,10667.1,1.09967,1.40037,1.23033,0.63223,0.50015,0.54639,0.39801,1.00796,1.10738,1.23361,0.0094885,0.0094885,0.0094885 +33,11000,1.10104,1.38723,1.22875,0.62455,0.50649,0.54795,0.39939,1.0061,1.10285,1.23128,0.009472,0.009472,0.009472 +34,11332.2,1.09674,1.3802,1.22636,0.63335,0.50689,0.55029,0.4013,1.00465,1.09914,1.22987,0.0094555,0.0094555,0.0094555 +35,11664.9,1.1023,1.37526,1.22542,0.63352,0.50719,0.5515,0.40233,1.00366,1.09616,1.22844,0.009439,0.009439,0.009439 +36,11997.5,1.09445,1.3663,1.22177,0.63825,0.50439,0.55231,0.40362,1.00171,1.09305,1.22686,0.0094225,0.0094225,0.0094225 +37,12329,1.09544,1.36887,1.22457,0.6382,0.50533,0.55325,0.4047,1.00063,1.09024,1.22566,0.009406,0.009406,0.009406 +38,12662.6,1.09019,1.35392,1.21883,0.64009,0.50625,0.55396,0.4055,0.99988,1.08835,1.22453,0.0093895,0.0093895,0.0093895 +39,12995.6,1.08457,1.35446,1.21445,0.63798,0.50768,0.55442,0.4061,0.99863,1.08597,1.22316,0.009373,0.009373,0.009373 +40,13327.9,1.08957,1.34708,1.21577,0.63801,0.50866,0.55509,0.40676,0.99762,1.08365,1.22187,0.0093565,0.0093565,0.0093565 +41,13661.9,1.0885,1.34489,1.21353,0.64322,0.50727,0.55571,0.40736,0.99719,1.08139,1.22115,0.00934,0.00934,0.00934 +42,13995.2,1.0888,1.34748,1.21251,0.64672,0.50632,0.55687,0.40813,0.99673,1.08031,1.22035,0.0093235,0.0093235,0.0093235 +43,14326.5,1.08285,1.33805,1.20708,0.64384,0.50894,0.55718,0.40857,0.99622,1.07891,1.21959,0.009307,0.009307,0.009307 +44,14659.6,1.07912,1.33274,1.2079,0.64195,0.50975,0.5578,0.40925,0.9956,1.07775,1.21903,0.0092905,0.0092905,0.0092905 +45,14992,1.08148,1.34054,1.20964,0.64139,0.51007,0.55847,0.40974,0.99512,1.07684,1.21837,0.009274,0.009274,0.009274 +46,15325.3,1.08338,1.33678,1.21249,0.64331,0.50963,0.55909,0.41035,0.99451,1.07587,1.21755,0.0092575,0.0092575,0.0092575 
+47,15660,1.08277,1.32516,1.20613,0.64878,0.50784,0.55951,0.41083,0.99396,1.07512,1.21694,0.009241,0.009241,0.009241 +48,15992.5,1.08089,1.31833,1.2057,0.65144,0.5072,0.55973,0.41112,0.99345,1.07418,1.21654,0.0092245,0.0092245,0.0092245 +49,16324.7,1.07137,1.31479,1.20162,0.65194,0.50817,0.55993,0.41131,0.99304,1.07381,1.21606,0.009208,0.009208,0.009208 +50,16656.5,1.07346,1.32107,1.20351,0.64935,0.50929,0.56029,0.4116,0.99263,1.07367,1.21538,0.0091915,0.0091915,0.0091915 +51,16987.9,1.07494,1.32269,1.20471,0.64823,0.51034,0.56067,0.4117,0.99272,1.07342,1.21516,0.009175,0.009175,0.009175 +52,17319,1.07716,1.31053,1.20348,0.64881,0.50998,0.56113,0.41209,0.9926,1.07329,1.21508,0.0091585,0.0091585,0.0091585 +53,17650.7,1.06792,1.30552,1.2008,0.64766,0.51152,0.56117,0.41206,0.99252,1.07375,1.21475,0.009142,0.009142,0.009142 +54,17979.9,1.07475,1.29928,1.1975,0.65118,0.50967,0.56129,0.41201,0.99198,1.07462,1.21412,0.0091255,0.0091255,0.0091255 +55,18312.5,1.06793,1.30218,1.20036,0.65463,0.5078,0.56129,0.4121,0.99178,1.07525,1.21411,0.009109,0.009109,0.009109 +56,18644.9,1.06691,1.29861,1.19885,0.65413,0.50865,0.56101,0.41234,0.99149,1.07613,1.21378,0.0090925,0.0090925,0.0090925 +57,18977.3,1.06584,1.29235,1.19663,0.65457,0.50869,0.56105,0.41247,0.9916,1.07713,1.21363,0.009076,0.009076,0.009076 +58,19309.6,1.06618,1.29262,1.19672,0.6558,0.50807,0.56106,0.41255,0.99179,1.07829,1.21368,0.0090595,0.0090595,0.0090595 +59,19641.8,1.06454,1.29178,1.19326,0.65586,0.5083,0.56125,0.41264,0.99172,1.07972,1.21357,0.009043,0.009043,0.009043 +60,19973,1.06428,1.28334,1.19269,0.65237,0.50998,0.56102,0.41269,0.9918,1.0812,1.2138,0.0090265,0.0090265,0.0090265 +61,20304.9,1.06711,1.29269,1.19731,0.64696,0.5125,0.56101,0.41283,0.99175,1.0827,1.21393,0.00901,0.00901,0.00901 +62,20636,1.05883,1.28137,1.19082,0.6478,0.51231,0.56062,0.41275,0.9917,1.08469,1.21375,0.0089935,0.0089935,0.0089935 +63,20967.4,1.06154,1.28246,1.19157,0.64734,0.51363,0.56061,0.41267,0.99175,1.08654,1.21378,0.008977,0.008977,0.008977 +64,21300.1,1.06152,1.28286,1.19348,0.64783,0.51248,0.56024,0.41253,0.99141,1.0888,1.21373,0.0089605,0.0089605,0.0089605 +65,21632.3,1.06364,1.28306,1.19372,0.64834,0.51178,0.56005,0.4125,0.9915,1.09067,1.21376,0.008944,0.008944,0.008944 +66,21964.9,1.0574,1.26505,1.18739,0.64676,0.51227,0.55972,0.41226,0.9914,1.09253,1.21372,0.0089275,0.0089275,0.0089275 +67,22295.7,1.06142,1.27786,1.18941,0.64965,0.51095,0.55948,0.41229,0.99156,1.09447,1.21382,0.008911,0.008911,0.008911 +68,22627.8,1.05332,1.27129,1.18654,0.64911,0.51044,0.55953,0.41254,0.99168,1.09652,1.21379,0.0088945,0.0088945,0.0088945 +69,22960,1.05257,1.27969,1.19052,0.64956,0.50959,0.55934,0.41265,0.99166,1.09797,1.2138,0.008878,0.008878,0.008878 +70,23292.7,1.05492,1.26569,1.18735,0.64933,0.51003,0.55933,0.41248,0.99162,1.09991,1.21365,0.0088615,0.0088615,0.0088615 +71,23625.8,1.05426,1.27105,1.18699,0.65276,0.50799,0.559,0.41251,0.99176,1.10163,1.21362,0.008845,0.008845,0.008845 +72,23957.4,1.05187,1.25969,1.18452,0.65176,0.50945,0.55935,0.41271,0.99148,1.10329,1.21347,0.0088285,0.0088285,0.0088285 +73,24288,1.04796,1.26084,1.1815,0.65014,0.51026,0.55927,0.41264,0.99132,1.1053,1.2135,0.008812,0.008812,0.008812 +74,24619.3,1.05406,1.26196,1.18655,0.65103,0.50888,0.55933,0.41258,0.99104,1.10679,1.21319,0.0087955,0.0087955,0.0087955 +75,24950.7,1.05526,1.2617,1.18467,0.65058,0.50964,0.55926,0.41268,0.99079,1.10818,1.213,0.008779,0.008779,0.008779 
+76,25282.5,1.04908,1.25372,1.18254,0.64694,0.51088,0.5596,0.41295,0.99068,1.10918,1.21265,0.0087625,0.0087625,0.0087625 +77,25615.6,1.0449,1.25239,1.17886,0.64991,0.50942,0.55984,0.41325,0.99066,1.11031,1.21251,0.008746,0.008746,0.008746 +78,25945.7,1.04572,1.25462,1.18269,0.64998,0.50996,0.56013,0.41337,0.99047,1.11131,1.21235,0.0087295,0.0087295,0.0087295 +79,26277.4,1.04843,1.25849,1.18343,0.65048,0.5095,0.56015,0.41355,0.99007,1.11211,1.21203,0.008713,0.008713,0.008713 +80,26608.9,1.05032,1.24024,1.17944,0.65352,0.50838,0.56011,0.41382,0.98991,1.11276,1.212,0.0086965,0.0086965,0.0086965 +81,26941.1,1.05005,1.24802,1.18048,0.65389,0.50836,0.56073,0.41406,0.9897,1.11289,1.21173,0.00868,0.00868,0.00868 +82,27273.3,1.04755,1.25242,1.18079,0.65771,0.50823,0.56109,0.41444,0.98937,1.11353,1.2114,0.0086635,0.0086635,0.0086635 +83,27605.4,1.04907,1.24937,1.17926,0.65742,0.50791,0.56128,0.41467,0.98885,1.1137,1.21097,0.008647,0.008647,0.008647 +84,27935.1,1.04722,1.24181,1.17834,0.65887,0.50887,0.56197,0.41523,0.98843,1.11371,1.2105,0.0086305,0.0086305,0.0086305 +85,28266.3,1.04714,1.23913,1.17717,0.6573,0.51017,0.56224,0.41578,0.98782,1.11321,1.21001,0.008614,0.008614,0.008614 +86,28597.5,1.04164,1.24271,1.17585,0.65546,0.51115,0.56262,0.41589,0.98732,1.11302,1.20947,0.0085975,0.0085975,0.0085975 +87,28930.7,1.03911,1.23713,1.17393,0.65524,0.51128,0.56297,0.41628,0.98675,1.11258,1.20875,0.008581,0.008581,0.008581 +88,29263.3,1.03681,1.23946,1.17252,0.65406,0.51263,0.56333,0.41703,0.98636,1.11169,1.20812,0.0085645,0.0085645,0.0085645 +89,29596.2,1.03391,1.23469,1.17197,0.6569,0.51275,0.56401,0.41764,0.98575,1.11102,1.20743,0.008548,0.008548,0.008548 +90,29927.5,1.03613,1.23494,1.17411,0.65521,0.51395,0.56438,0.41809,0.9852,1.11006,1.20689,0.0085315,0.0085315,0.0085315 +91,30260,1.03787,1.23002,1.17386,0.65359,0.51522,0.56508,0.41886,0.98479,1.10887,1.20629,0.008515,0.008515,0.008515 +92,30591.6,1.04488,1.23815,1.17763,0.65801,0.51355,0.56555,0.41921,0.98415,1.10759,1.20562,0.0084985,0.0084985,0.0084985 +93,30923.9,1.04192,1.22817,1.17698,0.65804,0.5137,0.56635,0.41996,0.98344,1.10586,1.20474,0.008482,0.008482,0.008482 +94,31256.8,1.04143,1.24721,1.17767,0.65925,0.51328,0.5669,0.4207,0.98285,1.1043,1.20406,0.0084655,0.0084655,0.0084655 +95,31589,1.03604,1.23238,1.17342,0.6608,0.51273,0.5673,0.42127,0.98231,1.10325,1.20343,0.008449,0.008449,0.008449 +96,31917.6,1.04009,1.24287,1.18035,0.66174,0.51318,0.56815,0.4219,0.98166,1.10198,1.20291,0.0084325,0.0084325,0.0084325 +97,32249.4,1.03721,1.2269,1.17355,0.6633,0.51322,0.56892,0.42264,0.98104,1.10033,1.20219,0.008416,0.008416,0.008416 +98,32580.3,1.03622,1.22429,1.17256,0.66389,0.5139,0.56959,0.42309,0.9804,1.09867,1.20133,0.0083995,0.0083995,0.0083995 +99,32913.2,1.03554,1.22783,1.17379,0.66518,0.51334,0.57043,0.42357,0.97956,1.09636,1.20041,0.008383,0.008383,0.008383 +100,33244.9,1.04189,1.2302,1.17632,0.66622,0.51358,0.57091,0.42399,0.97891,1.09456,1.19966,0.0083665,0.0083665,0.0083665 +101,33576.8,1.03999,1.22729,1.17423,0.66728,0.51368,0.57152,0.42449,0.9784,1.09294,1.19888,0.00835,0.00835,0.00835 +102,33910.4,1.03494,1.22791,1.17311,0.6675,0.51493,0.57227,0.42513,0.97778,1.09133,1.19812,0.0083335,0.0083335,0.0083335 +103,34243.2,1.03114,1.2192,1.16914,0.66757,0.51572,0.57313,0.42581,0.97721,1.08962,1.19741,0.008317,0.008317,0.008317 +104,34574.2,1.02431,1.21946,1.16771,0.66936,0.51655,0.57407,0.42644,0.97643,1.08764,1.19667,0.0083005,0.0083005,0.0083005 
+105,34905.4,1.03059,1.21937,1.17308,0.67044,0.51685,0.57461,0.42709,0.97602,1.08575,1.1961,0.008284,0.008284,0.008284 +106,35238.1,1.03573,1.22035,1.17194,0.6699,0.51913,0.5753,0.42787,0.97563,1.08362,1.19558,0.0082675,0.0082675,0.0082675 +107,35569.8,1.03274,1.22663,1.17117,0.6686,0.51912,0.5761,0.42841,0.97513,1.08091,1.19482,0.008251,0.008251,0.008251 +108,35902.1,1.02934,1.20727,1.16816,0.67075,0.51908,0.57678,0.42925,0.97474,1.07885,1.19425,0.0082345,0.0082345,0.0082345 +109,36233.6,1.03131,1.21002,1.16636,0.66849,0.52106,0.57734,0.42971,0.9739,1.07712,1.19349,0.008218,0.008218,0.008218 +110,36566.5,1.02569,1.21139,1.16932,0.66938,0.52165,0.57834,0.43036,0.97318,1.07568,1.19275,0.0082015,0.0082015,0.0082015 +111,36898.5,1.02756,1.2046,1.16631,0.67062,0.52078,0.57906,0.43062,0.97276,1.07384,1.19212,0.008185,0.008185,0.008185 +112,37230.1,1.03071,1.21262,1.1699,0.67382,0.52075,0.57963,0.43106,0.97216,1.07211,1.19146,0.0081685,0.0081685,0.0081685 +113,37562.8,1.02811,1.21612,1.16828,0.67359,0.52171,0.58028,0.43156,0.97152,1.0698,1.1907,0.008152,0.008152,0.008152 +114,37895.7,1.03314,1.20537,1.16629,0.67313,0.52429,0.581,0.43221,0.97088,1.06797,1.18987,0.0081355,0.0081355,0.0081355 +115,38228.3,1.0322,1.21545,1.16559,0.67453,0.52446,0.58148,0.43271,0.97024,1.06575,1.18891,0.008119,0.008119,0.008119 +116,38560.6,1.02464,1.20214,1.16621,0.6761,0.52392,0.58222,0.43336,0.97005,1.06361,1.18825,0.0081025,0.0081025,0.0081025 +117,38892.7,1.03085,1.21251,1.16848,0.67857,0.52489,0.58295,0.43379,0.96952,1.06131,1.18752,0.008086,0.008086,0.008086 +118,39224.4,1.02692,1.19895,1.16275,0.67726,0.52637,0.58338,0.43406,0.96898,1.05927,1.18677,0.0080695,0.0080695,0.0080695 +119,39556,1.0277,1.2049,1.16324,0.68007,0.52474,0.58408,0.43477,0.96834,1.05752,1.18587,0.008053,0.008053,0.008053 +120,39887.6,1.02408,1.19742,1.16284,0.67586,0.52806,0.58467,0.43536,0.96752,1.05509,1.18498,0.0080365,0.0080365,0.0080365 +121,40219.6,1.01912,1.20824,1.16249,0.67684,0.52784,0.58554,0.43646,0.96727,1.05314,1.18443,0.00802,0.00802,0.00802 +122,40550.6,1.02399,1.20315,1.16299,0.6782,0.52778,0.5862,0.43707,0.96666,1.05117,1.18375,0.0080035,0.0080035,0.0080035 +123,40882,1.02844,1.20221,1.16713,0.67807,0.52825,0.58691,0.43783,0.96606,1.04899,1.18318,0.007987,0.007987,0.007987 +124,41214.2,1.02405,1.19628,1.1644,0.68169,0.52792,0.58768,0.43858,0.96542,1.04662,1.18249,0.0079705,0.0079705,0.0079705 +125,41545.8,1.02372,1.19442,1.1658,0.67928,0.53115,0.58865,0.43886,0.96482,1.04425,1.18188,0.007954,0.007954,0.007954 +126,41876.6,1.01784,1.19016,1.16046,0.67812,0.53258,0.58938,0.43957,0.96433,1.04179,1.18145,0.0079375,0.0079375,0.0079375 +127,42208.6,1.02372,1.20012,1.16134,0.68416,0.53004,0.59002,0.44003,0.9639,1.0393,1.1807,0.007921,0.007921,0.007921 +128,42541.2,1.02044,1.20219,1.16031,0.68511,0.52965,0.59061,0.44056,0.9636,1.03748,1.18018,0.0079045,0.0079045,0.0079045 +129,42873,1.02164,1.20321,1.16344,0.68813,0.52955,0.59126,0.44125,0.96307,1.03542,1.17954,0.007888,0.007888,0.007888 +130,43205,1.02425,1.19416,1.16122,0.69475,0.52845,0.59171,0.44168,0.96302,1.03315,1.17934,0.0078715,0.0078715,0.0078715 +131,43537.4,1.02413,1.20253,1.16626,0.69492,0.52909,0.59247,0.44205,0.96261,1.03092,1.17863,0.007855,0.007855,0.007855 +132,43870.4,1.01932,1.19225,1.16115,0.69238,0.53082,0.59306,0.44256,0.96216,1.029,1.178,0.0078385,0.0078385,0.0078385 +133,44203.5,1.01409,1.17868,1.1559,0.69549,0.53027,0.59388,0.44337,0.96199,1.02693,1.17769,0.007822,0.007822,0.007822 
+134,44535.3,1.02194,1.19034,1.15897,0.69376,0.53197,0.59438,0.44389,0.96162,1.02522,1.17726,0.0078055,0.0078055,0.0078055 +135,44867.8,1.02764,1.19649,1.16404,0.69585,0.53198,0.59525,0.44439,0.96104,1.02352,1.17662,0.007789,0.007789,0.007789 +136,45199,1.01793,1.18346,1.16024,0.6954,0.53293,0.5962,0.44508,0.9606,1.0217,1.17616,0.0077725,0.0077725,0.0077725 +137,45529.7,1.02068,1.18795,1.1588,0.68945,0.53547,0.5966,0.44523,0.95989,1.02019,1.17551,0.007756,0.007756,0.007756 +138,45861.8,1.0242,1.1887,1.1595,0.69188,0.53447,0.59732,0.44574,0.95951,1.0183,1.17492,0.0077395,0.0077395,0.0077395 +139,46193,1.0172,1.1765,1.16014,0.69241,0.53601,0.59793,0.44656,0.95913,1.01669,1.17452,0.007723,0.007723,0.007723 +140,46524.7,1.02077,1.17968,1.15942,0.69292,0.53607,0.59859,0.44735,0.95877,1.01481,1.17401,0.0077065,0.0077065,0.0077065 +141,46855.2,1.01601,1.17363,1.15784,0.69388,0.53731,0.59942,0.44765,0.95838,1.01299,1.17347,0.00769,0.00769,0.00769 +142,47183.1,1.01598,1.17476,1.15867,0.69196,0.53867,0.60007,0.4481,0.95792,1.01122,1.17304,0.0076735,0.0076735,0.0076735 +143,47514.9,1.01709,1.18041,1.161,0.69334,0.53827,0.60093,0.4489,0.95748,1.0092,1.17267,0.007657,0.007657,0.007657 +144,47846.7,1.01468,1.17345,1.15704,0.69382,0.53855,0.60154,0.44955,0.95709,1.00733,1.1724,0.0076405,0.0076405,0.0076405 +145,48178,1.01149,1.17323,1.15522,0.6889,0.54208,0.60225,0.45014,0.95696,1.0059,1.17207,0.007624,0.007624,0.007624 +146,48510.1,1.01281,1.17533,1.15784,0.69799,0.53794,0.60262,0.45051,0.95649,1.00375,1.17164,0.0076075,0.0076075,0.0076075 +147,48842.3,1.01843,1.17534,1.1574,0.70025,0.53672,0.60334,0.45102,0.95646,1.00195,1.17159,0.007591,0.007591,0.007591 +148,49174.6,1.01971,1.17156,1.1576,0.69927,0.53841,0.60411,0.45155,0.95592,1.00047,1.17095,0.0075745,0.0075745,0.0075745 +149,49506.4,1.0122,1.16836,1.15144,0.69078,0.5436,0.60464,0.45199,0.9555,0.99856,1.17051,0.007558,0.007558,0.007558 +150,49838.5,1.01034,1.1752,1.15717,0.69172,0.54367,0.60495,0.45233,0.95518,0.99679,1.17011,0.0075415,0.0075415,0.0075415 +151,50170.5,1.0131,1.17401,1.15649,0.69181,0.545,0.60527,0.45264,0.95476,0.99516,1.16963,0.007525,0.007525,0.007525 +152,50502.4,1.00924,1.1712,1.15429,0.69071,0.54578,0.60611,0.45317,0.9541,0.9932,1.16879,0.0075085,0.0075085,0.0075085 +153,50834,1.02241,1.17643,1.15993,0.69716,0.54542,0.60684,0.45352,0.95394,0.99174,1.16857,0.007492,0.007492,0.007492 +154,51164.1,1.01286,1.16461,1.15554,0.69528,0.54734,0.60745,0.45384,0.95354,0.98991,1.16826,0.0074755,0.0074755,0.0074755 +155,51495.3,1.015,1.17106,1.1561,0.69469,0.54841,0.60791,0.45419,0.95293,0.98845,1.16761,0.007459,0.007459,0.007459 +156,51827.3,1.01139,1.16675,1.1555,0.6941,0.54901,0.60849,0.45472,0.95247,0.98683,1.16714,0.0074425,0.0074425,0.0074425 +157,52158.3,1.00452,1.16045,1.15067,0.69325,0.5492,0.60876,0.45488,0.95202,0.98539,1.1667,0.007426,0.007426,0.007426 +158,52489.3,1.00514,1.15611,1.14684,0.69425,0.54972,0.60939,0.45552,0.95179,0.984,1.16644,0.0074095,0.0074095,0.0074095 +159,52821.4,1.00963,1.16231,1.15277,0.69254,0.55009,0.60959,0.45573,0.95158,0.98262,1.16615,0.007393,0.007393,0.007393 +160,53153.9,1.00883,1.16237,1.15375,0.68942,0.55174,0.60957,0.4559,0.9511,0.98125,1.16565,0.0073765,0.0073765,0.0073765 +161,53485.6,1.00681,1.16463,1.15199,0.6918,0.55179,0.60995,0.45628,0.95052,0.98022,1.16534,0.00736,0.00736,0.00736 +162,53817.2,1.01027,1.16736,1.15379,0.6938,0.55144,0.61042,0.45656,0.95022,0.97862,1.16503,0.0073435,0.0073435,0.0073435 
+163,54150,1.00681,1.16208,1.15194,0.69366,0.55129,0.61109,0.45708,0.94998,0.97735,1.16474,0.007327,0.007327,0.007327 +164,54482.5,1.00381,1.15699,1.15102,0.69109,0.55293,0.61184,0.45758,0.94957,0.97591,1.16432,0.0073105,0.0073105,0.0073105 +165,54814.1,1.00684,1.16158,1.151,0.68942,0.55516,0.61238,0.45799,0.94885,0.97463,1.16385,0.007294,0.007294,0.007294 +166,55144.8,1.01448,1.16525,1.15282,0.68867,0.55581,0.61296,0.45853,0.94855,0.97333,1.16365,0.0072775,0.0072775,0.0072775 +167,55476.7,1.00514,1.16054,1.15075,0.69024,0.55554,0.61321,0.45853,0.94811,0.97239,1.16323,0.007261,0.007261,0.007261 +168,55810,1.0097,1.16344,1.15527,0.69371,0.55389,0.61337,0.45885,0.94786,0.97113,1.16294,0.0072445,0.0072445,0.0072445 +169,56142.6,1.00385,1.16397,1.15306,0.69464,0.55437,0.61386,0.45932,0.94739,0.96989,1.16264,0.007228,0.007228,0.007228 +170,56475.1,1.00343,1.15405,1.14922,0.69463,0.55445,0.61441,0.45971,0.94717,0.96873,1.16231,0.0072115,0.0072115,0.0072115 +171,56806.3,1.00623,1.1547,1.15115,0.69707,0.5549,0.61499,0.46013,0.94677,0.96748,1.16168,0.007195,0.007195,0.007195 +172,57138.1,1.00491,1.14682,1.14579,0.69654,0.55599,0.61526,0.46041,0.94652,0.96649,1.16151,0.0071785,0.0071785,0.0071785 +173,57470.7,1.00281,1.1559,1.14737,0.69424,0.55788,0.6155,0.46082,0.94636,0.96554,1.16115,0.007162,0.007162,0.007162 +174,57802.4,1.00499,1.15646,1.15194,0.69667,0.55738,0.6157,0.46103,0.94605,0.96451,1.16099,0.0071455,0.0071455,0.0071455 +175,58133.9,0.99922,1.15259,1.14941,0.69772,0.55756,0.61705,0.4622,0.94546,0.96347,1.16056,0.007129,0.007129,0.007129 +176,58466.3,1.0074,1.15394,1.14946,0.69474,0.5606,0.61759,0.46261,0.94498,0.96265,1.16011,0.0071125,0.0071125,0.0071125 +177,58798.9,1.00129,1.15292,1.14785,0.69402,0.56142,0.61793,0.46287,0.94468,0.96196,1.15971,0.007096,0.007096,0.007096 +178,59130.3,1.00092,1.14774,1.14794,0.6925,0.56159,0.61823,0.46314,0.9443,0.96114,1.15935,0.0070795,0.0070795,0.0070795 +179,59461.6,0.99958,1.14571,1.14603,0.69554,0.56029,0.61882,0.46344,0.94421,0.96004,1.15907,0.007063,0.007063,0.007063 +180,59792.1,0.99563,1.14436,1.14498,0.69679,0.56076,0.619,0.46352,0.94426,0.95948,1.15881,0.0070465,0.0070465,0.0070465 +181,60124.6,1.00111,1.14688,1.14873,0.69694,0.56017,0.61938,0.46377,0.94406,0.95852,1.15849,0.00703,0.00703,0.00703 +182,60458.5,0.99837,1.14473,1.14812,0.69791,0.5593,0.61955,0.46416,0.94385,0.95742,1.15821,0.0070135,0.0070135,0.0070135 +183,60790.1,1.00404,1.15122,1.14727,0.69569,0.56148,0.61992,0.46426,0.94367,0.95654,1.1579,0.006997,0.006997,0.006997 +184,61121.7,0.99982,1.14791,1.14772,0.69601,0.56208,0.62006,0.46456,0.94364,0.95557,1.15778,0.0069805,0.0069805,0.0069805 +185,61453.5,0.9987,1.14864,1.15152,0.69798,0.56138,0.62045,0.46475,0.94332,0.95466,1.15755,0.006964,0.006964,0.006964 +186,61785.7,0.9965,1.14521,1.14412,0.70071,0.5596,0.62082,0.46524,0.943,0.95365,1.15731,0.0069475,0.0069475,0.0069475 +187,62118.1,1.00329,1.1462,1.148,0.70281,0.55886,0.62114,0.46545,0.94264,0.95265,1.15707,0.006931,0.006931,0.006931 +188,62450,0.99894,1.14716,1.14647,0.70076,0.56049,0.62183,0.46586,0.9426,0.95142,1.1568,0.0069145,0.0069145,0.0069145 +189,62782.2,0.99975,1.14195,1.1473,0.70255,0.55998,0.62224,0.46611,0.9423,0.95077,1.15647,0.006898,0.006898,0.006898 +190,63114.5,0.99278,1.13591,1.14368,0.70243,0.56124,0.6228,0.46645,0.94175,0.95023,1.15594,0.0068815,0.0068815,0.0068815 +191,63446.8,1.00073,1.14239,1.14771,0.70153,0.56163,0.62299,0.46662,0.94145,0.94918,1.15548,0.006865,0.006865,0.006865 
+192,63779.7,1.00063,1.14545,1.14491,0.70194,0.56222,0.62301,0.46677,0.94127,0.94848,1.15526,0.0068485,0.0068485,0.0068485 +193,64111.6,0.99471,1.13433,1.1426,0.70198,0.56317,0.62329,0.46714,0.941,0.9476,1.15492,0.006832,0.006832,0.006832 +194,64443.1,0.99358,1.14241,1.14137,0.7048,0.56156,0.6237,0.46733,0.94061,0.94658,1.15455,0.0068155,0.0068155,0.0068155 +195,64774.9,0.99578,1.13352,1.143,0.69989,0.56559,0.62392,0.46739,0.9405,0.9458,1.1545,0.006799,0.006799,0.006799 +196,65103.9,0.99395,1.14008,1.14349,0.69911,0.56547,0.62408,0.46781,0.94021,0.94491,1.15441,0.0067825,0.0067825,0.0067825 +197,65436.3,0.99862,1.13893,1.14377,0.69954,0.5651,0.62408,0.46794,0.94016,0.94426,1.1543,0.006766,0.006766,0.006766 +198,65768.5,0.99767,1.14264,1.14687,0.69909,0.56562,0.6247,0.4685,0.93984,0.94372,1.15403,0.0067495,0.0067495,0.0067495 +199,66100.9,0.99849,1.1408,1.14634,0.69998,0.56491,0.62503,0.46884,0.93955,0.94285,1.15376,0.006733,0.006733,0.006733 +200,66433.2,0.99358,1.1375,1.14504,0.70083,0.56478,0.62514,0.46883,0.93945,0.94223,1.1536,0.0067165,0.0067165,0.0067165 +201,66764.1,0.99746,1.13691,1.14599,0.69954,0.56611,0.62507,0.46878,0.93898,0.94126,1.15318,0.0067,0.0067,0.0067 +202,67096.3,0.99491,1.13868,1.14527,0.70064,0.56593,0.62518,0.46915,0.93874,0.94048,1.15278,0.0066835,0.0066835,0.0066835 +203,67428.6,0.99752,1.12979,1.14342,0.7024,0.56516,0.62528,0.46916,0.93837,0.93943,1.15251,0.006667,0.006667,0.006667 +204,67760,0.99649,1.13281,1.14746,0.70222,0.56535,0.62556,0.46927,0.93828,0.93885,1.15247,0.0066505,0.0066505,0.0066505 +205,68092,0.99538,1.13613,1.14482,0.70175,0.56604,0.62586,0.46932,0.9382,0.9381,1.15233,0.006634,0.006634,0.006634 +206,68424.5,0.99685,1.13407,1.14302,0.7014,0.5662,0.62605,0.46944,0.93818,0.93739,1.15206,0.0066175,0.0066175,0.0066175 +207,68756.7,0.99232,1.12839,1.14065,0.70263,0.56593,0.62646,0.46976,0.93795,0.93664,1.15176,0.006601,0.006601,0.006601 +208,69088.3,0.99833,1.13061,1.14214,0.70321,0.56715,0.62682,0.47001,0.93791,0.93603,1.15163,0.0065845,0.0065845,0.0065845 +209,69418.9,0.9975,1.13767,1.14382,0.70323,0.56754,0.62716,0.47047,0.93799,0.93523,1.15143,0.006568,0.006568,0.006568 +210,69751.1,0.99648,1.13623,1.14485,0.70213,0.56858,0.62749,0.47041,0.93792,0.935,1.15119,0.0065515,0.0065515,0.0065515 +211,70082.5,0.98997,1.12271,1.14039,0.69997,0.57063,0.62762,0.47076,0.93781,0.93415,1.15094,0.006535,0.006535,0.006535 +212,70414.4,0.98831,1.11634,1.14078,0.70119,0.56935,0.6279,0.4709,0.9375,0.9335,1.15066,0.0065185,0.0065185,0.0065185 +213,70745.2,0.99299,1.12402,1.14062,0.7025,0.56914,0.62826,0.47127,0.93715,0.93271,1.15024,0.006502,0.006502,0.006502 +214,71077.8,0.99175,1.12877,1.14175,0.70147,0.57024,0.62853,0.47143,0.93672,0.93209,1.14996,0.0064855,0.0064855,0.0064855 +215,71409,0.98827,1.1201,1.13646,0.69901,0.57247,0.62884,0.47171,0.93655,0.93138,1.14949,0.006469,0.006469,0.006469 +216,71740.9,0.98805,1.12468,1.1388,0.69741,0.57351,0.6292,0.4722,0.93628,0.93067,1.14926,0.0064525,0.0064525,0.0064525 +217,72073.8,0.98255,1.1106,1.13402,0.70096,0.57172,0.62949,0.472,0.93608,0.93003,1.14912,0.006436,0.006436,0.006436 +218,72406,0.98974,1.11669,1.13509,0.7,0.57144,0.62981,0.47199,0.93623,0.92977,1.14896,0.0064195,0.0064195,0.0064195 +219,72737.6,0.99255,1.12758,1.14262,0.69904,0.57308,0.63024,0.47243,0.93605,0.92931,1.14875,0.006403,0.006403,0.006403 +220,73069.7,0.9869,1.1186,1.13882,0.69871,0.57442,0.63069,0.47302,0.93595,0.92874,1.14863,0.0063865,0.0063865,0.0063865 
+221,73402.1,0.98886,1.11543,1.1383,0.69882,0.57395,0.63109,0.47327,0.93603,0.9281,1.14851,0.00637,0.00637,0.00637 +222,73733.7,0.9855,1.11921,1.13598,0.70147,0.57269,0.63091,0.4733,0.93589,0.92762,1.14837,0.0063535,0.0063535,0.0063535 +223,74066.2,0.98776,1.11857,1.14022,0.70002,0.575,0.63164,0.47351,0.93591,0.92735,1.14826,0.006337,0.006337,0.006337 +224,74399,0.9867,1.1229,1.13977,0.70208,0.57364,0.63215,0.47374,0.9357,0.92649,1.14805,0.0063205,0.0063205,0.0063205 +225,74730.8,0.98904,1.11589,1.13908,0.70113,0.57423,0.63221,0.47386,0.93552,0.9261,1.14791,0.006304,0.006304,0.006304 +226,75063.2,0.98443,1.11857,1.13626,0.69951,0.57526,0.63235,0.47412,0.93555,0.92559,1.14773,0.0062875,0.0062875,0.0062875 +227,75393.7,0.97834,1.11387,1.13596,0.70167,0.57461,0.63284,0.47493,0.93517,0.92494,1.14719,0.006271,0.006271,0.006271 +228,75725.7,0.99048,1.11415,1.13712,0.70248,0.57378,0.63304,0.47487,0.93507,0.92416,1.14699,0.0062545,0.0062545,0.0062545 +229,76057.6,0.98318,1.1077,1.13301,0.70148,0.57522,0.63338,0.47536,0.93496,0.92364,1.14682,0.006238,0.006238,0.006238 +230,76387,0.98176,1.10545,1.13218,0.70365,0.57419,0.63347,0.47534,0.93478,0.92285,1.14663,0.0062215,0.0062215,0.0062215 +231,76718.6,0.98749,1.10998,1.13725,0.70455,0.57426,0.63357,0.47537,0.9345,0.92276,1.14618,0.006205,0.006205,0.006205 +232,77050.8,0.98721,1.1171,1.13857,0.70447,0.57433,0.63385,0.47545,0.93446,0.92199,1.14606,0.0061885,0.0061885,0.0061885 +233,77381.5,0.98712,1.10582,1.13628,0.70317,0.57555,0.63429,0.47571,0.93421,0.92145,1.14581,0.006172,0.006172,0.006172 +234,77714.2,0.98826,1.10787,1.13849,0.7025,0.577,0.63444,0.47577,0.93407,0.9208,1.14571,0.0061555,0.0061555,0.0061555 +235,78045.7,0.98065,1.10825,1.13363,0.69963,0.57926,0.63466,0.47602,0.93379,0.92024,1.14533,0.006139,0.006139,0.006139 +236,78376.9,0.98468,1.10695,1.13603,0.70263,0.57771,0.63494,0.47621,0.93358,0.9196,1.14525,0.0061225,0.0061225,0.0061225 +237,78708.3,0.98328,1.11468,1.13741,0.70151,0.57915,0.6351,0.47657,0.9333,0.919,1.14505,0.006106,0.006106,0.006106 +238,79040.4,0.98551,1.10789,1.13701,0.70191,0.57899,0.63506,0.47715,0.93315,0.91885,1.14498,0.0060895,0.0060895,0.0060895 +239,79372.7,0.98191,1.10683,1.13343,0.69992,0.58041,0.6357,0.47742,0.93326,0.91808,1.1449,0.006073,0.006073,0.006073 +240,79704.9,0.98415,1.10036,1.13156,0.70071,0.57969,0.63593,0.478,0.93311,0.9177,1.1447,0.0060565,0.0060565,0.0060565 +241,80037.4,0.9761,1.10215,1.13331,0.70357,0.57868,0.63623,0.47774,0.93307,0.91735,1.14456,0.00604,0.00604,0.00604 +242,80369.1,0.98457,1.10968,1.13617,0.70669,0.57633,0.63635,0.47821,0.93302,0.91667,1.14449,0.0060235,0.0060235,0.0060235 +243,80700.7,0.98915,1.11089,1.13843,0.70879,0.57565,0.63663,0.47847,0.93276,0.91622,1.14412,0.006007,0.006007,0.006007 +244,81033.5,0.97676,1.09881,1.13282,0.70947,0.5759,0.6367,0.47859,0.93276,0.91601,1.14403,0.0059905,0.0059905,0.0059905 +245,81365.9,0.98921,1.11021,1.1357,0.7093,0.57575,0.63697,0.47875,0.93271,0.91548,1.14384,0.005974,0.005974,0.005974 +246,81698.4,0.98636,1.1087,1.13682,0.70788,0.57722,0.63731,0.47901,0.93243,0.91507,1.14358,0.0059575,0.0059575,0.0059575 +247,82029.3,0.97936,1.10077,1.13142,0.70662,0.57813,0.63782,0.47918,0.93243,0.91453,1.14345,0.005941,0.005941,0.005941 +248,82361.1,0.98569,1.10676,1.13626,0.70571,0.57903,0.63788,0.47953,0.93217,0.91399,1.14318,0.0059245,0.0059245,0.0059245 +249,82692.8,0.98121,1.10406,1.13043,0.70522,0.57937,0.63784,0.47958,0.93219,0.91376,1.14296,0.005908,0.005908,0.005908 
+250,83025.4,0.97922,1.08694,1.13118,0.70401,0.58026,0.63805,0.47999,0.93218,0.91322,1.14282,0.0058915,0.0058915,0.0058915 +251,83358.4,0.97617,1.10047,1.13112,0.70663,0.57866,0.63854,0.48025,0.93188,0.91287,1.14255,0.005875,0.005875,0.005875 +252,83688.4,0.97915,1.10346,1.13009,0.70641,0.57938,0.63865,0.48053,0.93169,0.91243,1.14227,0.0058585,0.0058585,0.0058585 +253,84020.6,0.97498,1.09456,1.13137,0.7041,0.58123,0.63924,0.48082,0.93142,0.91203,1.142,0.005842,0.005842,0.005842 +254,84354.3,0.97966,1.08814,1.13061,0.70402,0.58125,0.63925,0.48095,0.93152,0.91187,1.14216,0.0058255,0.0058255,0.0058255 +255,84687.6,0.97595,1.08638,1.12894,0.70591,0.58022,0.63963,0.48139,0.93155,0.91146,1.14201,0.005809,0.005809,0.005809 +256,85021.2,0.97798,1.09875,1.13173,0.70918,0.57921,0.64011,0.48169,0.93136,0.91105,1.14166,0.0057925,0.0057925,0.0057925 +257,85352.9,0.97776,1.08133,1.12773,0.71105,0.57966,0.64069,0.48197,0.93148,0.91048,1.14171,0.005776,0.005776,0.005776 +258,85686.1,0.97898,1.08644,1.12736,0.71153,0.57898,0.64108,0.48221,0.93087,0.91008,1.14122,0.0057595,0.0057595,0.0057595 +259,86018.4,0.97515,1.08808,1.12981,0.71397,0.5789,0.64111,0.4826,0.93092,0.9096,1.14128,0.005743,0.005743,0.005743 +260,86350.9,0.96946,1.08812,1.12727,0.71247,0.57976,0.64167,0.48296,0.93084,0.90918,1.14112,0.0057265,0.0057265,0.0057265 +261,86683.3,0.97739,1.08769,1.12702,0.7141,0.57869,0.64202,0.48321,0.93069,0.90861,1.14111,0.00571,0.00571,0.00571 +262,87015.1,0.97318,1.08465,1.12498,0.71055,0.58217,0.64236,0.48325,0.93022,0.90822,1.14065,0.0056935,0.0056935,0.0056935 +263,87346.2,0.96948,1.08435,1.12506,0.71422,0.58001,0.6425,0.48331,0.92986,0.90765,1.14021,0.005677,0.005677,0.005677 +264,87678.7,0.96952,1.08208,1.12764,0.70341,0.58607,0.64233,0.48359,0.93,0.90735,1.1404,0.0056605,0.0056605,0.0056605 +265,88010.6,0.97576,1.08821,1.12889,0.71342,0.58017,0.64227,0.4835,0.92975,0.90673,1.14019,0.005644,0.005644,0.005644 +266,88343,0.97035,1.08221,1.1294,0.71482,0.58089,0.64253,0.48361,0.92956,0.90636,1.14008,0.0056275,0.0056275,0.0056275 +267,88675.5,0.97294,1.08127,1.12723,0.71504,0.58079,0.64291,0.48411,0.92936,0.90603,1.13995,0.005611,0.005611,0.005611 +268,89008.3,0.97352,1.08809,1.13041,0.71406,0.58145,0.64383,0.48475,0.92901,0.9056,1.13962,0.0055945,0.0055945,0.0055945 +269,89340.9,0.97748,1.08282,1.12814,0.70729,0.58438,0.64412,0.48503,0.92866,0.90504,1.13934,0.005578,0.005578,0.005578 +270,89672.4,0.96968,1.08889,1.12994,0.70465,0.58809,0.64463,0.48559,0.92869,0.90462,1.13938,0.0055615,0.0055615,0.0055615 +271,90004.5,0.97265,1.08626,1.1317,0.70431,0.58802,0.64467,0.48548,0.92852,0.90417,1.13912,0.005545,0.005545,0.005545 +272,90335.4,0.97203,1.08041,1.12957,0.70616,0.58805,0.64507,0.48606,0.92845,0.90416,1.13895,0.0055285,0.0055285,0.0055285 +273,90667.1,0.96934,1.07939,1.12818,0.7036,0.58833,0.64479,0.48576,0.92831,0.90363,1.13888,0.005512,0.005512,0.005512 +274,90999,0.96635,1.07909,1.12484,0.70837,0.58701,0.64473,0.48601,0.92804,0.90299,1.13864,0.0054955,0.0054955,0.0054955 +275,91330.5,0.96979,1.07275,1.12648,0.70719,0.58882,0.64517,0.48614,0.92792,0.90264,1.13854,0.005479,0.005479,0.005479 +276,91662.9,0.97234,1.08175,1.12351,0.70252,0.59011,0.64541,0.48611,0.9279,0.90244,1.13857,0.0054625,0.0054625,0.0054625 +277,91995.6,0.97015,1.08445,1.12681,0.70803,0.58577,0.64561,0.48622,0.9279,0.90211,1.13851,0.005446,0.005446,0.005446 +278,92328.6,0.96802,1.08307,1.12623,0.70634,0.58696,0.64596,0.48632,0.92771,0.90173,1.13822,0.0054295,0.0054295,0.0054295 
+279,92661.1,0.97697,1.08003,1.12917,0.71035,0.5853,0.6462,0.48679,0.92758,0.9013,1.13804,0.005413,0.005413,0.005413 +280,92993.4,0.97119,1.07471,1.12857,0.70868,0.58633,0.64604,0.48689,0.92742,0.90063,1.13793,0.0053965,0.0053965,0.0053965 +281,93326.1,0.97016,1.07739,1.12728,0.70848,0.58663,0.6462,0.48693,0.92722,0.90017,1.13779,0.00538,0.00538,0.00538 +282,93658.8,0.97098,1.07821,1.12733,0.7078,0.58683,0.6464,0.48701,0.92699,0.89952,1.1377,0.0053635,0.0053635,0.0053635 +283,93991.2,0.96886,1.07741,1.12773,0.70578,0.58925,0.64688,0.48742,0.92685,0.89921,1.13759,0.005347,0.005347,0.005347 +284,94323.2,0.96376,1.07439,1.1201,0.70471,0.58986,0.64693,0.48755,0.92667,0.89875,1.13739,0.0053305,0.0053305,0.0053305 +285,94654.2,0.97159,1.07291,1.12224,0.70719,0.5888,0.64733,0.48788,0.92637,0.89842,1.13701,0.005314,0.005314,0.005314 +286,94986,0.96312,1.07078,1.12425,0.70763,0.58879,0.64772,0.48805,0.9261,0.89795,1.13679,0.0052975,0.0052975,0.0052975 +287,95317.9,0.96367,1.06811,1.12403,0.70741,0.58939,0.64788,0.48808,0.92611,0.8976,1.13669,0.005281,0.005281,0.005281 +288,95648.9,0.96669,1.0798,1.12677,0.70825,0.58977,0.64813,0.48829,0.92612,0.89713,1.13662,0.0052645,0.0052645,0.0052645 +289,95980.9,0.96566,1.07385,1.12504,0.70655,0.59106,0.6481,0.48828,0.92598,0.89644,1.13653,0.005248,0.005248,0.005248 +290,96312.9,0.96518,1.07312,1.12453,0.70484,0.59249,0.64794,0.48847,0.9258,0.89617,1.13635,0.0052315,0.0052315,0.0052315 +291,96644.7,0.96397,1.06353,1.12188,0.70883,0.59198,0.64813,0.48874,0.92553,0.89581,1.1363,0.005215,0.005215,0.005215 +292,96976.1,0.96598,1.07191,1.12409,0.70806,0.59228,0.64823,0.48877,0.92541,0.8955,1.13623,0.0051985,0.0051985,0.0051985 +293,97308,0.96731,1.06902,1.12619,0.71333,0.58906,0.64824,0.48889,0.92542,0.89524,1.13632,0.005182,0.005182,0.005182 +294,97639.1,0.96093,1.05981,1.12039,0.70745,0.59255,0.64825,0.48923,0.92522,0.89511,1.13622,0.0051655,0.0051655,0.0051655 +295,97970.9,0.96466,1.05879,1.12281,0.70604,0.59324,0.64824,0.48924,0.92508,0.89467,1.13617,0.005149,0.005149,0.005149 +296,98302.2,0.96668,1.06762,1.12605,0.70621,0.59379,0.64881,0.4894,0.92479,0.89442,1.13596,0.0051325,0.0051325,0.0051325 +297,98634.4,0.96633,1.07734,1.12375,0.70638,0.59393,0.64898,0.48938,0.92436,0.8943,1.1356,0.005116,0.005116,0.005116 +298,98966.2,0.96073,1.06532,1.12141,0.70644,0.5949,0.64886,0.48933,0.924,0.89388,1.13528,0.0050995,0.0050995,0.0050995 +299,99300,0.9648,1.06107,1.12074,0.71,0.59258,0.64904,0.48953,0.92405,0.8934,1.13524,0.005083,0.005083,0.005083 +300,99631.8,0.9645,1.06915,1.12334,0.70993,0.59262,0.64902,0.48944,0.924,0.89307,1.13536,0.0050665,0.0050665,0.0050665 +301,99963.6,0.96365,1.05752,1.12331,0.71287,0.59033,0.64941,0.48966,0.92367,0.89256,1.1352,0.00505,0.00505,0.00505 +302,100296,0.96308,1.06752,1.12291,0.70397,0.59661,0.64949,0.48949,0.9236,0.89206,1.13512,0.0050335,0.0050335,0.0050335 +303,100630,0.96171,1.06126,1.1223,0.70644,0.59627,0.64967,0.48978,0.92339,0.89172,1.13482,0.005017,0.005017,0.005017 +304,100962,0.9603,1.06242,1.12086,0.70489,0.59745,0.6497,0.49016,0.92341,0.89122,1.13462,0.0050005,0.0050005,0.0050005 +305,101295,0.96383,1.06799,1.12233,0.71023,0.59355,0.64981,0.49036,0.92347,0.89115,1.1346,0.004984,0.004984,0.004984 +306,101628,0.95992,1.05794,1.12102,0.71026,0.59362,0.65008,0.49044,0.92315,0.89061,1.13423,0.0049675,0.0049675,0.0049675 +307,101961,0.95955,1.05639,1.11846,0.71134,0.59379,0.65019,0.49064,0.92341,0.89051,1.13428,0.004951,0.004951,0.004951 
+308,102293,0.95579,1.05146,1.11709,0.71448,0.59239,0.6506,0.49082,0.9233,0.89043,1.13415,0.0049345,0.0049345,0.0049345 +309,102626,0.96546,1.0623,1.12108,0.71624,0.59199,0.65072,0.49075,0.92352,0.89022,1.13429,0.004918,0.004918,0.004918 +310,102958,0.96293,1.05846,1.11882,0.71784,0.5913,0.65128,0.49114,0.92331,0.89,1.13391,0.0049015,0.0049015,0.0049015 +311,103291,0.9583,1.0551,1.11937,0.72178,0.58923,0.65117,0.49104,0.92285,0.88955,1.13348,0.004885,0.004885,0.004885 +312,103622,0.95972,1.04551,1.1135,0.71825,0.59044,0.65132,0.49114,0.92279,0.88928,1.13332,0.0048685,0.0048685,0.0048685 +313,103955,0.96145,1.05969,1.12127,0.71427,0.59292,0.65155,0.49133,0.92269,0.88886,1.13318,0.004852,0.004852,0.004852 +314,104287,0.96137,1.05501,1.12014,0.7118,0.59467,0.65192,0.49166,0.92276,0.88842,1.13317,0.0048355,0.0048355,0.0048355 +315,104619,0.95692,1.05532,1.11659,0.70814,0.59668,0.65236,0.492,0.92258,0.88796,1.13293,0.004819,0.004819,0.004819 +316,104951,0.95965,1.05138,1.11949,0.70846,0.59688,0.65229,0.49209,0.92251,0.88777,1.13288,0.0048025,0.0048025,0.0048025 +317,105284,0.95689,1.05359,1.11819,0.70749,0.5985,0.65233,0.49201,0.92234,0.88752,1.13275,0.004786,0.004786,0.004786 +318,105616,0.94998,1.04013,1.11525,0.71333,0.59555,0.65268,0.49223,0.92258,0.88688,1.1327,0.0047695,0.0047695,0.0047695 +319,105949,0.95403,1.04785,1.11798,0.71266,0.59584,0.65272,0.49243,0.92241,0.88648,1.13249,0.004753,0.004753,0.004753 +320,106281,0.95751,1.0481,1.11804,0.71945,0.59102,0.65287,0.49276,0.92235,0.88636,1.13238,0.0047365,0.0047365,0.0047365 +321,106613,0.95834,1.05303,1.11758,0.72146,0.58863,0.65276,0.4926,0.92234,0.88643,1.13242,0.00472,0.00472,0.00472 +322,106944,0.95959,1.04324,1.1152,0.72159,0.59012,0.653,0.49257,0.92213,0.88609,1.13211,0.0047035,0.0047035,0.0047035 +323,107275,0.95608,1.05218,1.11951,0.72152,0.58918,0.65285,0.49285,0.92193,0.88609,1.13193,0.004687,0.004687,0.004687 +324,107605,0.9536,1.04273,1.11334,0.7202,0.58892,0.65313,0.49279,0.92179,0.88585,1.1316,0.0046705,0.0046705,0.0046705 +325,107936,0.95497,1.042,1.11864,0.72393,0.58796,0.65293,0.49249,0.92178,0.88573,1.13172,0.004654,0.004654,0.004654 +326,108267,0.95586,1.05537,1.1187,0.72082,0.58995,0.65365,0.49335,0.92165,0.88528,1.13165,0.0046375,0.0046375,0.0046375 +327,108599,0.9536,1.047,1.11837,0.72182,0.5895,0.65382,0.4932,0.92167,0.88478,1.13155,0.004621,0.004621,0.004621 +328,108931,0.95174,1.03884,1.11328,0.71733,0.59225,0.65401,0.49333,0.92172,0.88472,1.13137,0.0046045,0.0046045,0.0046045 +329,109260,0.9511,1.03468,1.11315,0.71478,0.59434,0.65438,0.49392,0.9214,0.88458,1.13105,0.004588,0.004588,0.004588 +330,109591,0.95061,1.02914,1.11379,0.71387,0.59525,0.65452,0.49401,0.92135,0.88437,1.13101,0.0045715,0.0045715,0.0045715 +331,109923,0.95333,1.04162,1.1153,0.71458,0.59483,0.65445,0.49388,0.92113,0.88416,1.13099,0.004555,0.004555,0.004555 +332,110255,0.94929,1.04706,1.11679,0.71413,0.59484,0.65489,0.49435,0.92094,0.8838,1.13079,0.0045385,0.0045385,0.0045385 +333,110587,0.95661,1.04062,1.11606,0.7241,0.58956,0.65509,0.49433,0.92068,0.88328,1.13054,0.004522,0.004522,0.004522 +334,110919,0.94803,1.03266,1.11267,0.72233,0.59005,0.65503,0.49391,0.92041,0.88272,1.1304,0.0045055,0.0045055,0.0045055 +335,111252,0.95192,1.0327,1.11395,0.71659,0.59331,0.65513,0.49424,0.92001,0.88215,1.13005,0.004489,0.004489,0.004489 +336,111583,0.9519,1.03863,1.11434,0.71347,0.5947,0.65517,0.49443,0.91981,0.88184,1.12985,0.0044725,0.0044725,0.0044725 
+337,111916,0.94799,1.03286,1.11094,0.71493,0.5932,0.65531,0.49454,0.91965,0.88179,1.12965,0.004456,0.004456,0.004456 +338,112248,0.95249,1.02903,1.11453,0.71264,0.59413,0.65536,0.49486,0.91941,0.88133,1.12945,0.0044395,0.0044395,0.0044395 +339,112581,0.94928,1.02748,1.1126,0.71061,0.59607,0.65566,0.49503,0.91945,0.88115,1.1293,0.004423,0.004423,0.004423 +340,112913,0.94842,1.03761,1.11263,0.71117,0.59624,0.65564,0.49506,0.91937,0.88096,1.12905,0.0044065,0.0044065,0.0044065 +341,113246,0.95406,1.04007,1.11749,0.70972,0.59674,0.6558,0.49527,0.91931,0.88056,1.12907,0.00439,0.00439,0.00439 +342,113578,0.94852,1.03662,1.11182,0.70883,0.59799,0.65605,0.4955,0.91913,0.8801,1.1289,0.0043735,0.0043735,0.0043735 +343,113910,0.95219,1.02672,1.11244,0.70716,0.59919,0.65614,0.49566,0.9189,0.87977,1.12863,0.004357,0.004357,0.004357 +344,114242,0.95125,1.02972,1.11047,0.70812,0.60021,0.6562,0.49581,0.91874,0.87929,1.12843,0.0043405,0.0043405,0.0043405 +345,114575,0.95176,1.03965,1.11736,0.70795,0.60048,0.65635,0.49591,0.91851,0.87882,1.12831,0.004324,0.004324,0.004324 +346,114907,0.94494,1.035,1.11069,0.70832,0.60073,0.65683,0.49599,0.91814,0.87858,1.128,0.0043075,0.0043075,0.0043075 +347,115237,0.95341,1.03369,1.11339,0.70847,0.60045,0.65688,0.49615,0.91779,0.8782,1.12765,0.004291,0.004291,0.004291 +348,115569,0.94462,1.01976,1.1105,0.70939,0.59976,0.65713,0.49626,0.91756,0.87765,1.12747,0.0042745,0.0042745,0.0042745 +349,115901,0.95016,1.02735,1.11262,0.70791,0.60187,0.65744,0.49642,0.91774,0.87745,1.12764,0.004258,0.004258,0.004258 +350,116233,0.94322,1.01951,1.10847,0.71179,0.59931,0.65765,0.49661,0.91784,0.87712,1.12751,0.0042415,0.0042415,0.0042415 +351,116563,0.9484,1.01888,1.10919,0.71276,0.59945,0.65763,0.4968,0.91749,0.87666,1.12738,0.004225,0.004225,0.004225 +352,116894,0.9536,1.03194,1.11483,0.71288,0.5998,0.65819,0.49712,0.91717,0.87631,1.12706,0.0042085,0.0042085,0.0042085 +353,117227,0.94697,1.026,1.11432,0.71056,0.60213,0.65838,0.49722,0.91703,0.87587,1.12706,0.004192,0.004192,0.004192 +354,117558,0.9472,1.01587,1.10952,0.71318,0.60074,0.65867,0.49724,0.91694,0.87541,1.12705,0.0041755,0.0041755,0.0041755 +355,117891,0.94764,1.02434,1.11134,0.70909,0.60291,0.65861,0.4974,0.91637,0.8754,1.12662,0.004159,0.004159,0.004159 +356,118222,0.94343,1.0236,1.10952,0.71143,0.60064,0.65858,0.49739,0.91639,0.87534,1.12657,0.0041425,0.0041425,0.0041425 +357,118554,0.94524,1.01986,1.11204,0.70868,0.60232,0.65879,0.4975,0.91622,0.87524,1.1264,0.004126,0.004126,0.004126 +358,118887,0.94754,1.02291,1.10964,0.70752,0.60334,0.65907,0.49774,0.91605,0.87485,1.1262,0.0041095,0.0041095,0.0041095 +359,119219,0.94811,1.02358,1.11241,0.71047,0.60116,0.65942,0.49775,0.9161,0.87469,1.12618,0.004093,0.004093,0.004093 +360,119551,0.94923,1.01962,1.11114,0.70676,0.60389,0.65954,0.49806,0.91603,0.87465,1.12599,0.0040765,0.0040765,0.0040765 +361,119883,0.94457,1.01765,1.10786,0.70641,0.60481,0.65987,0.49811,0.91592,0.87443,1.1258,0.00406,0.00406,0.00406 +362,120215,0.94517,1.01737,1.1086,0.70735,0.60478,0.66022,0.49829,0.91569,0.87451,1.12554,0.0040435,0.0040435,0.0040435 +363,120542,0.94495,1.01133,1.10722,0.70794,0.60523,0.66021,0.49813,0.9157,0.87403,1.12567,0.004027,0.004027,0.004027 +364,120874,0.94489,1.01507,1.10952,0.71238,0.60347,0.6605,0.49831,0.91565,0.87369,1.12558,0.0040105,0.0040105,0.0040105 +365,121206,0.94273,1.01317,1.10741,0.71376,0.6029,0.66001,0.49793,0.91559,0.8735,1.1255,0.003994,0.003994,0.003994 
+366,121538,0.94169,1.01309,1.1059,0.71306,0.60404,0.66057,0.49802,0.91557,0.87293,1.12535,0.0039775,0.0039775,0.0039775 +367,121868,0.93468,1.00931,1.10733,0.71191,0.60453,0.66072,0.49843,0.91543,0.87251,1.12535,0.003961,0.003961,0.003961 +368,122200,0.93707,1.01013,1.10693,0.71213,0.60487,0.66115,0.49873,0.91517,0.87256,1.12528,0.0039445,0.0039445,0.0039445 +369,122533,0.93723,1.0067,1.10495,0.71294,0.60548,0.66134,0.49887,0.91499,0.87246,1.12518,0.003928,0.003928,0.003928 +370,122864,0.94183,1.01072,1.1077,0.71565,0.60371,0.66152,0.49886,0.91487,0.87209,1.12496,0.0039115,0.0039115,0.0039115 +371,123197,0.93659,1.00251,1.10636,0.71512,0.60392,0.66151,0.499,0.91467,0.872,1.12468,0.003895,0.003895,0.003895 +372,123529,0.94405,1.01596,1.10803,0.71638,0.60391,0.6618,0.49926,0.91452,0.87172,1.12445,0.0038785,0.0038785,0.0038785 +373,123861,0.93326,1.00492,1.10398,0.71748,0.60251,0.6626,0.49967,0.91467,0.87162,1.12441,0.003862,0.003862,0.003862 +374,124191,0.94075,1.00339,1.1057,0.71886,0.60287,0.66277,0.49976,0.91457,0.87146,1.12424,0.0038455,0.0038455,0.0038455 +375,124523,0.94105,1.00372,1.10539,0.71979,0.60213,0.66285,0.49983,0.91423,0.87122,1.12383,0.003829,0.003829,0.003829 +376,124856,0.94229,1.01197,1.10612,0.71717,0.60426,0.66297,0.50017,0.91409,0.87112,1.12371,0.0038125,0.0038125,0.0038125 +377,125189,0.93952,1.01259,1.10467,0.71654,0.6052,0.6633,0.50041,0.91434,0.87065,1.12359,0.003796,0.003796,0.003796 +378,125520,0.94312,1.00988,1.10774,0.71501,0.6061,0.66366,0.50083,0.91431,0.87051,1.12347,0.0037795,0.0037795,0.0037795 +379,125850,0.93337,1.00396,1.1043,0.71623,0.60488,0.66352,0.50095,0.91394,0.87052,1.1231,0.003763,0.003763,0.003763 +380,126183,0.94814,1.01943,1.10961,0.71554,0.60581,0.66377,0.50107,0.91366,0.86992,1.12298,0.0037465,0.0037465,0.0037465 +381,126514,0.92873,1.00413,1.10436,0.71622,0.60646,0.66388,0.50136,0.91339,0.86976,1.12278,0.00373,0.00373,0.00373 +382,126847,0.93813,1.00021,1.10527,0.7197,0.60478,0.6643,0.50174,0.91345,0.86946,1.12257,0.0037135,0.0037135,0.0037135 +383,127179,0.94148,1.00477,1.10649,0.71673,0.60649,0.66428,0.50177,0.91315,0.86899,1.12208,0.003697,0.003697,0.003697 +384,127509,0.94013,1.00273,1.10505,0.71381,0.60805,0.66417,0.50195,0.91322,0.8688,1.12219,0.0036805,0.0036805,0.0036805 +385,127841,0.94168,1.00691,1.10594,0.71173,0.60886,0.66445,0.5024,0.91307,0.86827,1.12207,0.003664,0.003664,0.003664 +386,128173,0.935,0.99281,1.10472,0.71293,0.60888,0.66477,0.50226,0.91297,0.86796,1.12201,0.0036475,0.0036475,0.0036475 +387,128505,0.93511,0.99511,1.10341,0.71501,0.60888,0.66571,0.50263,0.91256,0.86753,1.12176,0.003631,0.003631,0.003631 +388,128835,0.92926,0.99587,1.10087,0.71563,0.60911,0.66552,0.50293,0.91239,0.86732,1.1214,0.0036145,0.0036145,0.0036145 +389,129167,0.9325,0.9897,1.10108,0.7158,0.60883,0.66617,0.50305,0.91182,0.86668,1.12109,0.003598,0.003598,0.003598 +390,129497,0.92945,0.99101,1.10379,0.71449,0.60814,0.66578,0.50299,0.91169,0.86666,1.12096,0.0035815,0.0035815,0.0035815 +391,129829,0.93253,0.99591,1.10301,0.71706,0.60708,0.66668,0.5036,0.91125,0.86654,1.12074,0.003565,0.003565,0.003565 +392,130160,0.93115,0.99911,1.1044,0.71703,0.60699,0.66687,0.50385,0.91175,0.8664,1.12102,0.0035485,0.0035485,0.0035485 +393,130492,0.92821,0.99586,1.10087,0.71696,0.60756,0.66663,0.5037,0.91164,0.86606,1.12103,0.003532,0.003532,0.003532 +394,130824,0.93392,0.98112,1.09968,0.71768,0.60778,0.66702,0.50383,0.91133,0.86611,1.12072,0.0035155,0.0035155,0.0035155 
+395,131156,0.92992,0.98514,1.09832,0.7162,0.60815,0.66727,0.504,0.91156,0.86566,1.12082,0.003499,0.003499,0.003499 +396,131488,0.9326,0.99497,1.10113,0.71681,0.60809,0.6675,0.50429,0.91148,0.86546,1.12082,0.0034825,0.0034825,0.0034825 +397,131821,0.93449,0.98915,1.10055,0.71633,0.60748,0.66761,0.50467,0.91166,0.86474,1.12087,0.003466,0.003466,0.003466 +398,132154,0.93755,0.99322,1.10341,0.71493,0.61006,0.66825,0.50484,0.91123,0.86442,1.12067,0.0034495,0.0034495,0.0034495 +399,132487,0.92322,0.98001,1.09685,0.71591,0.61054,0.66872,0.505,0.91131,0.86395,1.12068,0.003433,0.003433,0.003433 +400,132818,0.93056,0.98979,1.0989,0.71618,0.61002,0.66893,0.50533,0.91111,0.86354,1.12026,0.0034165,0.0034165,0.0034165 +401,133151,0.92831,0.97855,1.09596,0.72026,0.60737,0.66911,0.50545,0.9108,0.86304,1.11995,0.0034,0.0034,0.0034 +402,133483,0.92894,0.98666,1.10004,0.71725,0.61057,0.66906,0.50551,0.91065,0.86283,1.11972,0.0033835,0.0033835,0.0033835 +403,133815,0.92476,0.98162,1.09736,0.71721,0.61001,0.6691,0.50583,0.91034,0.86249,1.11958,0.003367,0.003367,0.003367 +404,134147,0.92231,0.98096,1.09629,0.71906,0.6078,0.66896,0.50584,0.91,0.8621,1.11932,0.0033505,0.0033505,0.0033505 +405,134480,0.93134,0.9823,1.10091,0.72061,0.60719,0.66918,0.50594,0.91001,0.86167,1.11936,0.003334,0.003334,0.003334 +406,134813,0.92031,0.9761,1.09724,0.724,0.60599,0.66914,0.50605,0.90955,0.86127,1.11903,0.0033175,0.0033175,0.0033175 +407,135145,0.92198,0.98287,1.09887,0.72141,0.60798,0.66929,0.50596,0.90921,0.861,1.11876,0.003301,0.003301,0.003301 +408,135476,0.92127,0.97242,1.09194,0.72199,0.60793,0.66928,0.5059,0.90892,0.86066,1.1184,0.0032845,0.0032845,0.0032845 +409,135808,0.92443,0.9712,1.09306,0.72374,0.6079,0.66937,0.50617,0.90844,0.86031,1.1181,0.003268,0.003268,0.003268 +410,136140,0.92185,0.97566,1.09558,0.72258,0.60815,0.66952,0.50641,0.90836,0.86018,1.11795,0.0032515,0.0032515,0.0032515 +411,136470,0.92827,0.98073,1.09648,0.72331,0.6075,0.66989,0.50651,0.9082,0.85982,1.11759,0.003235,0.003235,0.003235 +412,136801,0.92708,0.97596,1.09804,0.71673,0.61135,0.67046,0.50709,0.90809,0.85959,1.11744,0.0032185,0.0032185,0.0032185 +413,137134,0.92872,0.9775,1.09619,0.72352,0.60858,0.67033,0.5072,0.90767,0.85918,1.11716,0.003202,0.003202,0.003202 +414,137466,0.92256,0.97384,1.09566,0.71915,0.61111,0.67069,0.50724,0.90748,0.85893,1.11697,0.0031855,0.0031855,0.0031855 +415,137798,0.91894,0.97364,1.09509,0.71955,0.61009,0.67094,0.50756,0.90745,0.85864,1.11687,0.003169,0.003169,0.003169 +416,138132,0.92007,0.96908,1.09233,0.71983,0.60983,0.6711,0.50749,0.9072,0.85825,1.11673,0.0031525,0.0031525,0.0031525 +417,138465,0.92556,0.97906,1.09792,0.72366,0.60847,0.67137,0.50765,0.90706,0.85805,1.11672,0.003136,0.003136,0.003136 +418,138797,0.92505,0.97253,1.09679,0.72352,0.6074,0.67156,0.50783,0.90709,0.85758,1.11666,0.0031195,0.0031195,0.0031195 +419,139129,0.92012,0.97591,1.09394,0.72261,0.60917,0.6721,0.5082,0.90674,0.85725,1.11631,0.003103,0.003103,0.003103 +420,139460,0.92218,0.96957,1.09647,0.72345,0.60914,0.67289,0.50853,0.90656,0.85693,1.11623,0.0030865,0.0030865,0.0030865 +421,139792,0.92265,0.97346,1.09682,0.7233,0.60939,0.67316,0.50887,0.90669,0.85638,1.11625,0.00307,0.00307,0.00307 +422,140125,0.91995,0.96923,1.09495,0.72595,0.60816,0.67316,0.50894,0.9066,0.85612,1.1161,0.0030535,0.0030535,0.0030535 +423,140457,0.92162,0.97235,1.09749,0.72273,0.61047,0.67329,0.50905,0.90644,0.85569,1.11602,0.003037,0.003037,0.003037 
+424,140789,0.91592,0.97237,1.09279,0.72776,0.60637,0.67355,0.50932,0.90608,0.85524,1.11572,0.0030205,0.0030205,0.0030205 +425,141121,0.92587,0.96489,1.09336,0.73129,0.60469,0.67347,0.50927,0.90587,0.85509,1.11559,0.003004,0.003004,0.003004 +426,141453,0.92057,0.96039,1.09218,0.73264,0.60398,0.67362,0.50976,0.90563,0.85482,1.11525,0.0029875,0.0029875,0.0029875 +427,141786,0.92806,0.97185,1.09581,0.72984,0.60597,0.67452,0.51006,0.90566,0.85439,1.11527,0.002971,0.002971,0.002971 +428,142118,0.91536,0.95351,1.0924,0.72335,0.61068,0.67419,0.51024,0.90578,0.85412,1.1154,0.0029545,0.0029545,0.0029545 +429,142450,0.91391,0.96159,1.09255,0.7174,0.61451,0.674,0.51031,0.90589,0.85378,1.11543,0.002938,0.002938,0.002938 +430,142781,0.91496,0.95762,1.09175,0.72196,0.61151,0.67402,0.51011,0.90584,0.85337,1.11537,0.0029215,0.0029215,0.0029215 +431,143114,0.91617,0.96464,1.09425,0.72098,0.61217,0.67414,0.51034,0.90551,0.85317,1.11493,0.002905,0.002905,0.002905 +432,143446,0.91688,0.9587,1.09056,0.72325,0.61131,0.67432,0.51043,0.90527,0.85284,1.11487,0.0028885,0.0028885,0.0028885 +433,143780,0.91547,0.95609,1.08885,0.72214,0.61228,0.6745,0.51066,0.90522,0.85249,1.11478,0.002872,0.002872,0.002872 +434,144112,0.91634,0.95745,1.08839,0.7227,0.61202,0.67468,0.51093,0.90516,0.85206,1.11462,0.0028555,0.0028555,0.0028555 +435,144444,0.92097,0.96302,1.09332,0.72286,0.61245,0.675,0.51088,0.90507,0.85157,1.11444,0.002839,0.002839,0.002839 +436,144776,0.91436,0.95918,1.09264,0.72405,0.61174,0.67492,0.51103,0.90472,0.85125,1.11422,0.0028225,0.0028225,0.0028225 +437,145109,0.91529,0.95087,1.08936,0.7237,0.61244,0.67504,0.51124,0.90454,0.85089,1.1143,0.002806,0.002806,0.002806 +438,145440,0.91486,0.95506,1.0888,0.72428,0.61263,0.67512,0.51118,0.90461,0.85033,1.11432,0.0027895,0.0027895,0.0027895 +439,145769,0.90793,0.95001,1.08795,0.72446,0.61232,0.67526,0.51131,0.90451,0.85031,1.11416,0.002773,0.002773,0.002773 +440,146099,0.91425,0.9489,1.08886,0.72425,0.61362,0.67553,0.51149,0.90392,0.85,1.11367,0.0027565,0.0027565,0.0027565 +441,146430,0.90918,0.94897,1.08952,0.7255,0.61295,0.67573,0.51155,0.90383,0.84983,1.11347,0.00274,0.00274,0.00274 +442,146762,0.91021,0.94254,1.08509,0.72612,0.61202,0.67584,0.51171,0.90332,0.84969,1.11289,0.0027235,0.0027235,0.0027235 +443,147093,0.90975,0.94456,1.08854,0.72366,0.61402,0.67594,0.51178,0.90342,0.84969,1.11292,0.002707,0.002707,0.002707 +444,147425,0.9178,0.94398,1.08731,0.72535,0.61428,0.67537,0.5115,0.90336,0.84947,1.11288,0.0026905,0.0026905,0.0026905 +445,147756,0.91467,0.94991,1.09117,0.72563,0.61354,0.67624,0.51179,0.90319,0.84923,1.11251,0.002674,0.002674,0.002674 +446,148088,0.9053,0.94082,1.0866,0.72345,0.61588,0.67674,0.51204,0.90285,0.84869,1.11213,0.0026575,0.0026575,0.0026575 +447,148421,0.90596,0.94645,1.08559,0.72439,0.61489,0.67667,0.51214,0.90246,0.84783,1.1119,0.002641,0.002641,0.002641 +448,148753,0.91047,0.94358,1.0861,0.72607,0.61451,0.67666,0.51216,0.902,0.84746,1.11148,0.0026245,0.0026245,0.0026245 +449,149085,0.90739,0.94255,1.08261,0.7273,0.61434,0.67674,0.51211,0.90167,0.84724,1.1112,0.002608,0.002608,0.002608 +450,149416,0.90454,0.94142,1.08618,0.72647,0.61453,0.67658,0.51223,0.90129,0.84705,1.11091,0.0025915,0.0025915,0.0025915 +451,149749,0.9058,0.93866,1.08459,0.72826,0.61414,0.67684,0.51229,0.90116,0.84658,1.11069,0.002575,0.002575,0.002575 +452,150080,0.91441,0.95119,1.09328,0.72788,0.61506,0.67689,0.51243,0.90106,0.84639,1.1106,0.0025585,0.0025585,0.0025585 
+453,150411,0.90959,0.93948,1.08723,0.71955,0.61733,0.67696,0.5123,0.90091,0.84615,1.11046,0.002542,0.002542,0.002542 +454,150743,0.90638,0.93804,1.08746,0.72047,0.61829,0.67742,0.51263,0.90121,0.84559,1.11068,0.0025255,0.0025255,0.0025255 +455,151076,0.90671,0.9324,1.0854,0.72054,0.61824,0.67721,0.51285,0.9009,0.8453,1.11041,0.002509,0.002509,0.002509 +456,151408,0.90419,0.93755,1.08425,0.72022,0.61923,0.67748,0.51299,0.90074,0.84502,1.11036,0.0024925,0.0024925,0.0024925 +457,151740,0.90251,0.93335,1.08257,0.72072,0.61894,0.67757,0.51323,0.90051,0.84486,1.11013,0.002476,0.002476,0.002476 +458,152072,0.90452,0.93901,1.08587,0.7176,0.62218,0.67786,0.51347,0.90017,0.84407,1.10976,0.0024595,0.0024595,0.0024595 +459,152404,0.90384,0.93424,1.08421,0.71841,0.62134,0.67859,0.51393,0.89977,0.84363,1.10953,0.002443,0.002443,0.002443 +460,152735,0.90338,0.9257,1.08487,0.71779,0.62166,0.67862,0.51417,0.89954,0.84323,1.10927,0.0024265,0.0024265,0.0024265 +461,153068,0.90031,0.93311,1.08732,0.71523,0.62342,0.67879,0.51432,0.89948,0.84304,1.10932,0.00241,0.00241,0.00241 +462,153401,0.90364,0.93364,1.08495,0.71473,0.62445,0.6791,0.51456,0.8991,0.84266,1.10889,0.0023935,0.0023935,0.0023935 +463,153734,0.90575,0.93124,1.0829,0.71639,0.62321,0.67933,0.5148,0.89896,0.842,1.10863,0.002377,0.002377,0.002377 +464,154067,0.89896,0.92637,1.08083,0.71893,0.62166,0.67956,0.51495,0.89896,0.84152,1.10851,0.0023605,0.0023605,0.0023605 +465,154399,0.90276,0.92528,1.0826,0.72146,0.62056,0.67969,0.51491,0.8988,0.84143,1.10811,0.002344,0.002344,0.002344 +466,154731,0.90129,0.92528,1.08317,0.72483,0.61868,0.68003,0.51508,0.8986,0.84111,1.10779,0.0023275,0.0023275,0.0023275 +467,155062,0.90196,0.93022,1.08522,0.72403,0.61941,0.67991,0.51509,0.89844,0.84054,1.10771,0.002311,0.002311,0.002311 +468,155395,0.89708,0.92637,1.08146,0.72379,0.62034,0.68018,0.51536,0.89834,0.83991,1.10762,0.0022945,0.0022945,0.0022945 +469,155727,0.90427,0.93181,1.0863,0.72537,0.61962,0.6805,0.51534,0.8984,0.83971,1.1076,0.002278,0.002278,0.002278 +470,156059,0.90161,0.92575,1.08532,0.72398,0.62095,0.68075,0.51558,0.89811,0.83912,1.10733,0.0022615,0.0022615,0.0022615 +471,156390,0.90265,0.92668,1.08402,0.72546,0.61907,0.68081,0.51588,0.89785,0.83871,1.10731,0.002245,0.002245,0.002245 +472,156721,0.90291,0.92277,1.08051,0.72495,0.62012,0.68118,0.51599,0.89751,0.83884,1.10697,0.0022285,0.0022285,0.0022285 +473,157053,0.89804,0.92393,1.08239,0.72584,0.61998,0.68126,0.51627,0.89749,0.83843,1.10702,0.002212,0.002212,0.002212 +474,157385,0.89518,0.92115,1.08199,0.72589,0.62102,0.68142,0.51655,0.89714,0.83833,1.10677,0.0021955,0.0021955,0.0021955 +475,157716,0.89605,0.91607,1.07994,0.72854,0.6196,0.68162,0.51676,0.89719,0.83782,1.10673,0.002179,0.002179,0.002179 +476,158048,0.89977,0.91703,1.07891,0.72683,0.62117,0.68187,0.51689,0.89711,0.83717,1.10645,0.0021625,0.0021625,0.0021625 +477,158379,0.89578,0.91209,1.08096,0.73151,0.61838,0.68216,0.51711,0.897,0.83688,1.10638,0.002146,0.002146,0.002146 +478,158711,0.8972,0.91721,1.08008,0.73159,0.61943,0.68232,0.51723,0.89663,0.83635,1.10601,0.0021295,0.0021295,0.0021295 +479,159043,0.89752,0.91273,1.0788,0.73197,0.61845,0.68272,0.51726,0.89641,0.83608,1.10586,0.002113,0.002113,0.002113 +480,159375,0.8993,0.91644,1.07966,0.73038,0.62004,0.68279,0.5174,0.89617,0.83587,1.10566,0.0020965,0.0020965,0.0020965 +481,159706,0.90078,0.91404,1.07922,0.73074,0.61912,0.68275,0.51755,0.89611,0.83539,1.10565,0.00208,0.00208,0.00208 
+482,160038,0.90082,0.91718,1.08017,0.73079,0.61973,0.68284,0.51768,0.89617,0.83493,1.10564,0.0020635,0.0020635,0.0020635 +483,160370,0.89275,0.90979,1.07349,0.73013,0.62036,0.68298,0.51775,0.8962,0.83494,1.1055,0.002047,0.002047,0.002047 +484,160702,0.89373,0.9115,1.07741,0.72995,0.62138,0.68316,0.51788,0.89619,0.83432,1.10537,0.0020305,0.0020305,0.0020305 +485,161034,0.89432,0.90639,1.07493,0.72956,0.62119,0.68322,0.51834,0.89614,0.83398,1.10522,0.002014,0.002014,0.002014 +486,161365,0.88924,0.90117,1.07478,0.72986,0.62136,0.68315,0.51835,0.89592,0.83354,1.10498,0.0019975,0.0019975,0.0019975 +487,161698,0.88923,0.90078,1.07469,0.7306,0.62113,0.6832,0.51845,0.8958,0.83259,1.10494,0.001981,0.001981,0.001981 +488,162029,0.88907,0.90022,1.07492,0.72764,0.62256,0.68343,0.51841,0.89568,0.83224,1.10468,0.0019645,0.0019645,0.0019645 +489,162361,0.89052,0.89972,1.07595,0.72945,0.62159,0.68382,0.51866,0.8953,0.83181,1.10438,0.001948,0.001948,0.001948 +490,162693,0.89223,0.90222,1.07533,0.73053,0.6208,0.68408,0.51913,0.89506,0.8313,1.10419,0.0019315,0.0019315,0.0019315 +491,163025,0.88688,0.89167,1.07454,0.7268,0.62282,0.68451,0.51941,0.89529,0.83104,1.10436,0.001915,0.001915,0.001915 +492,163357,0.88582,0.89938,1.07506,0.72714,0.62189,0.68449,0.51946,0.89514,0.83037,1.10402,0.0018985,0.0018985,0.0018985 +493,163688,0.8904,0.8908,1.07195,0.72493,0.62451,0.68487,0.51969,0.8949,0.82989,1.10376,0.001882,0.001882,0.001882 +494,164018,0.88559,0.89476,1.07151,0.72599,0.62431,0.68529,0.51984,0.89456,0.82965,1.10344,0.0018655,0.0018655,0.0018655 +495,164351,0.8818,0.89587,1.07129,0.72442,0.62644,0.68529,0.52033,0.89441,0.82947,1.1034,0.001849,0.001849,0.001849 +496,164682,0.88708,0.89553,1.07176,0.72336,0.62632,0.6854,0.52043,0.89434,0.82904,1.1031,0.0018325,0.0018325,0.0018325 +497,165014,0.88193,0.89196,1.07038,0.72699,0.62446,0.68554,0.52028,0.89436,0.82853,1.103,0.001816,0.001816,0.001816 +498,165346,0.88127,0.88652,1.06809,0.72571,0.62578,0.68583,0.52073,0.89423,0.82842,1.1027,0.0017995,0.0017995,0.0017995 +499,165678,0.89057,0.89712,1.07249,0.72889,0.62404,0.68587,0.52084,0.89388,0.82816,1.10232,0.001783,0.001783,0.001783 +500,166010,0.88783,0.88846,1.07289,0.72909,0.62437,0.6864,0.52094,0.89354,0.82783,1.10209,0.0017665,0.0017665,0.0017665 +501,166341,0.88886,0.89459,1.07261,0.72913,0.62407,0.68671,0.52113,0.89324,0.82739,1.1017,0.00175,0.00175,0.00175 +502,166673,0.87978,0.88657,1.06958,0.72982,0.62429,0.68714,0.52147,0.89294,0.82723,1.10141,0.0017335,0.0017335,0.0017335 +503,167004,0.87697,0.87633,1.06596,0.72862,0.62435,0.6872,0.52176,0.89259,0.82683,1.10091,0.001717,0.001717,0.001717 +504,167337,0.88109,0.88159,1.07002,0.73176,0.62339,0.68726,0.52216,0.8923,0.82642,1.10056,0.0017005,0.0017005,0.0017005 +505,167669,0.87582,0.87844,1.06702,0.73197,0.62435,0.68739,0.52234,0.89186,0.82586,1.10019,0.001684,0.001684,0.001684 +506,168002,0.87648,0.87132,1.0654,0.72861,0.62611,0.68776,0.52253,0.89176,0.82552,1.09976,0.0016675,0.0016675,0.0016675 +507,168333,0.88163,0.88384,1.07034,0.73036,0.62571,0.68821,0.52289,0.8916,0.82522,1.09983,0.001651,0.001651,0.001651 +508,168666,0.87462,0.88188,1.06505,0.73987,0.61895,0.68844,0.52316,0.8914,0.82491,1.09962,0.0016345,0.0016345,0.0016345 +509,168997,0.87874,0.87846,1.06761,0.74062,0.61856,0.68858,0.52321,0.89127,0.82497,1.09951,0.001618,0.001618,0.001618 +510,169330,0.87667,0.86806,1.06402,0.73948,0.61874,0.6885,0.52342,0.89109,0.82459,1.09922,0.0016015,0.0016015,0.0016015 
+511,169662,0.88033,0.87687,1.06688,0.73824,0.61984,0.68848,0.52358,0.89076,0.82442,1.09902,0.001585,0.001585,0.001585 +512,169994,0.88053,0.8747,1.07055,0.73632,0.62026,0.68818,0.52362,0.89079,0.82416,1.09896,0.0015685,0.0015685,0.0015685 +513,170327,0.87991,0.86988,1.0657,0.73814,0.61929,0.68842,0.52369,0.89072,0.82415,1.09889,0.001552,0.001552,0.001552 +514,170659,0.86754,0.86493,1.06212,0.73698,0.61991,0.68853,0.52389,0.89044,0.82374,1.09858,0.0015355,0.0015355,0.0015355 +515,170991,0.87422,0.8683,1.06199,0.73294,0.62254,0.68866,0.5242,0.89052,0.8234,1.09854,0.001519,0.001519,0.001519 +516,171323,0.86893,0.86801,1.06604,0.72885,0.6263,0.68924,0.52459,0.89039,0.82315,1.09846,0.0015025,0.0015025,0.0015025 +517,171655,0.87627,0.86973,1.06474,0.73039,0.62564,0.68918,0.52473,0.89035,0.82285,1.09826,0.001486,0.001486,0.001486 +518,171984,0.86891,0.86019,1.06322,0.72443,0.62884,0.68933,0.52491,0.89027,0.82223,1.09802,0.0014695,0.0014695,0.0014695 +519,172316,0.86914,0.85777,1.06228,0.72754,0.6274,0.68965,0.52485,0.88987,0.82184,1.09768,0.001453,0.001453,0.001453 +520,172648,0.86772,0.8651,1.0647,0.7291,0.62595,0.68967,0.52486,0.88971,0.82154,1.0977,0.0014365,0.0014365,0.0014365 +521,172980,0.86885,0.8585,1.06281,0.72473,0.62864,0.69002,0.52519,0.88942,0.82091,1.09745,0.00142,0.00142,0.00142 +522,173312,0.867,0.85981,1.06229,0.72435,0.6285,0.6904,0.52514,0.88933,0.82073,1.09735,0.0014035,0.0014035,0.0014035 +523,173644,0.86703,0.85543,1.06182,0.72507,0.62826,0.69011,0.52532,0.88909,0.82023,1.09718,0.001387,0.001387,0.001387 +524,173975,0.86687,0.85552,1.06129,0.72991,0.62638,0.69019,0.52542,0.88882,0.81995,1.097,0.0013705,0.0013705,0.0013705 +525,174306,0.86567,0.8511,1.06004,0.73226,0.62477,0.69038,0.52537,0.88864,0.81974,1.09682,0.001354,0.001354,0.001354 +526,174640,0.86624,0.85603,1.06139,0.73049,0.62636,0.69043,0.52562,0.88838,0.81931,1.09638,0.0013375,0.0013375,0.0013375 +527,174972,0.86555,0.85119,1.06117,0.72857,0.62823,0.6906,0.52579,0.88789,0.81889,1.09618,0.001321,0.001321,0.001321 +528,175303,0.86107,0.84583,1.05739,0.72847,0.6288,0.69073,0.52593,0.88766,0.81843,1.09586,0.0013045,0.0013045,0.0013045 +529,175636,0.86956,0.84722,1.06157,0.72629,0.63041,0.69092,0.52607,0.88754,0.8182,1.09567,0.001288,0.001288,0.001288 +530,175968,0.86481,0.84565,1.05916,0.72247,0.63258,0.69097,0.52637,0.8873,0.81809,1.09543,0.0012715,0.0012715,0.0012715 +531,176299,0.86343,0.84728,1.06205,0.71985,0.6353,0.69108,0.52668,0.8871,0.81758,1.0954,0.001255,0.001255,0.001255 +532,176632,0.86028,0.84462,1.05862,0.71832,0.63679,0.69124,0.52696,0.88696,0.81735,1.09537,0.0012385,0.0012385,0.0012385 +533,176965,0.86536,0.83956,1.05799,0.71885,0.63625,0.69135,0.52728,0.88685,0.81723,1.09525,0.001222,0.001222,0.001222 +534,177297,0.8587,0.83853,1.05967,0.71893,0.63509,0.69159,0.52745,0.88638,0.81679,1.09492,0.0012055,0.0012055,0.0012055 +535,177629,0.85531,0.83538,1.05516,0.72035,0.6352,0.69197,0.5276,0.88648,0.81636,1.09492,0.001189,0.001189,0.001189 +536,177961,0.85829,0.83463,1.05746,0.72221,0.63436,0.69214,0.52769,0.88633,0.81597,1.09494,0.0011725,0.0011725,0.0011725 +537,178292,0.85615,0.8328,1.05598,0.72168,0.63418,0.69238,0.52794,0.88607,0.8153,1.09461,0.001156,0.001156,0.001156 +538,178625,0.8592,0.83272,1.05556,0.72229,0.63351,0.69215,0.52787,0.88602,0.81483,1.0944,0.0011395,0.0011395,0.0011395 +539,178957,0.85842,0.83433,1.05574,0.72062,0.63484,0.69221,0.52807,0.88574,0.81425,1.09427,0.001123,0.001123,0.001123 
+540,179289,0.85738,0.82664,1.05329,0.72313,0.63399,0.69242,0.5283,0.88519,0.81401,1.09366,0.0011065,0.0011065,0.0011065 +541,179622,0.85786,0.8316,1.05698,0.72163,0.63506,0.69271,0.52833,0.88498,0.81395,1.09353,0.00109,0.00109,0.00109 +542,179954,0.85155,0.82884,1.05469,0.72342,0.63415,0.69307,0.52868,0.8845,0.81366,1.09331,0.0010735,0.0010735,0.0010735 +543,180285,0.85648,0.8271,1.0566,0.72285,0.63451,0.69323,0.52868,0.88454,0.81336,1.09354,0.001057,0.001057,0.001057 +544,180618,0.85417,0.82608,1.05366,0.72509,0.63406,0.69366,0.52881,0.88449,0.81317,1.0935,0.0010405,0.0010405,0.0010405 +545,180948,0.84453,0.8146,1.04954,0.72493,0.6341,0.69394,0.52899,0.88446,0.8128,1.09337,0.001024,0.001024,0.001024 +546,181280,0.84759,0.81909,1.05258,0.72631,0.63474,0.69446,0.52941,0.88402,0.8123,1.09308,0.0010075,0.0010075,0.0010075 +547,181613,0.84793,0.81284,1.05273,0.72721,0.635,0.69443,0.52944,0.88388,0.81232,1.093,0.000991,0.000991,0.000991 +548,181945,0.85236,0.81741,1.05181,0.72965,0.63444,0.69476,0.52985,0.8833,0.81204,1.09259,0.0009745,0.0009745,0.0009745 +549,182277,0.84244,0.81463,1.05012,0.72608,0.63696,0.69482,0.52985,0.8831,0.81176,1.09245,0.000958,0.000958,0.000958 +550,182608,0.85322,0.81916,1.05213,0.72163,0.63963,0.69505,0.53032,0.88302,0.81124,1.09232,0.0009415,0.0009415,0.0009415 +551,182940,0.84463,0.8116,1.04803,0.72167,0.63901,0.69493,0.53014,0.88282,0.81108,1.0921,0.000925,0.000925,0.000925 +552,183273,0.85234,0.81232,1.05265,0.72175,0.63977,0.6952,0.53045,0.8824,0.81073,1.09178,0.0009085,0.0009085,0.0009085 +553,183606,0.84906,0.80359,1.04675,0.72183,0.63981,0.69523,0.53054,0.88221,0.81042,1.09151,0.000892,0.000892,0.000892 +554,183937,0.8484,0.80795,1.0504,0.72317,0.63859,0.69555,0.53075,0.88203,0.80987,1.09142,0.0008755,0.0008755,0.0008755 +555,184268,0.84188,0.80778,1.04787,0.72361,0.63865,0.69555,0.53078,0.88165,0.80939,1.09089,0.000859,0.000859,0.000859 +556,184600,0.84643,0.80321,1.0478,0.72338,0.63941,0.69585,0.53105,0.8814,0.80906,1.09063,0.0008425,0.0008425,0.0008425 +557,184932,0.84164,0.7989,1.04564,0.72345,0.63847,0.69621,0.53104,0.88148,0.8089,1.09067,0.000826,0.000826,0.000826 +558,185265,0.84193,0.7997,1.04496,0.72513,0.6381,0.69648,0.53129,0.88111,0.80879,1.0904,0.0008095,0.0008095,0.0008095 +559,185597,0.83801,0.79299,1.04393,0.72387,0.63909,0.6965,0.53139,0.8811,0.80857,1.09033,0.000793,0.000793,0.000793 +560,185928,0.83821,0.7933,1.04693,0.72298,0.64029,0.6968,0.53149,0.88097,0.80838,1.09037,0.0007765,0.0007765,0.0007765 +561,186260,0.83586,0.79311,1.04594,0.72321,0.64119,0.69714,0.53179,0.88074,0.8082,1.09038,0.00076,0.00076,0.00076 +562,186590,0.83862,0.79598,1.04338,0.72418,0.6412,0.69728,0.53213,0.88042,0.80771,1.09013,0.0007435,0.0007435,0.0007435 +563,186923,0.83487,0.7848,1.04199,0.72635,0.63941,0.69728,0.53218,0.8803,0.80719,1.08985,0.000727,0.000727,0.000727 +564,187256,0.83152,0.78205,1.03798,0.72964,0.63769,0.69725,0.53239,0.88004,0.80702,1.08938,0.0007105,0.0007105,0.0007105 +565,187587,0.83573,0.78913,1.04551,0.73124,0.63692,0.69734,0.53256,0.87984,0.80694,1.08916,0.000694,0.000694,0.000694 +566,187921,0.83851,0.78554,1.04462,0.73344,0.63487,0.69719,0.53253,0.87966,0.80689,1.08908,0.0006775,0.0006775,0.0006775 +567,188252,0.83187,0.77485,1.04176,0.73259,0.63571,0.69739,0.5325,0.87955,0.80683,1.08913,0.000661,0.000661,0.000661 +568,188584,0.82939,0.77683,1.03814,0.7338,0.63463,0.69765,0.53285,0.87947,0.80681,1.08893,0.0006445,0.0006445,0.0006445 
+569,188914,0.83004,0.77673,1.0407,0.72966,0.63829,0.69782,0.53319,0.87914,0.80672,1.08886,0.000628,0.000628,0.000628 +570,189247,0.82701,0.76952,1.03531,0.73098,0.63751,0.69802,0.53337,0.87928,0.80629,1.08885,0.0006115,0.0006115,0.0006115 +571,189579,0.82555,0.77242,1.03948,0.72816,0.63844,0.69795,0.53352,0.87917,0.80622,1.08876,0.000595,0.000595,0.000595 +572,189910,0.82398,0.77032,1.04025,0.72852,0.63917,0.69815,0.5336,0.87881,0.80614,1.08861,0.0005785,0.0005785,0.0005785 +573,190242,0.82279,0.76355,1.03723,0.72881,0.63948,0.69842,0.53373,0.87881,0.80562,1.08871,0.000562,0.000562,0.000562 +574,190573,0.8283,0.76636,1.03664,0.7274,0.63891,0.69855,0.53374,0.87856,0.80539,1.08854,0.0005455,0.0005455,0.0005455 +575,190904,0.83311,0.77347,1.03916,0.73302,0.63647,0.69868,0.53381,0.8781,0.8054,1.08817,0.000529,0.000529,0.000529 +576,191235,0.82704,0.76125,1.03562,0.73438,0.63484,0.6989,0.53408,0.8778,0.80487,1.08785,0.0005125,0.0005125,0.0005125 +577,191567,0.82038,0.75326,1.03467,0.73641,0.63423,0.69915,0.53453,0.87765,0.80456,1.08767,0.000496,0.000496,0.000496 +578,191899,0.81731,0.75185,1.03468,0.7372,0.63398,0.69926,0.53476,0.87744,0.80433,1.08753,0.0004795,0.0004795,0.0004795 +579,192231,0.81436,0.75094,1.03333,0.73696,0.63461,0.69982,0.53489,0.87716,0.80386,1.08739,0.000463,0.000463,0.000463 +580,192563,0.81445,0.74475,1.03182,0.73532,0.63604,0.69986,0.53484,0.87717,0.80346,1.08736,0.0004465,0.0004465,0.0004465 +581,192895,0.82257,0.75303,1.03451,0.73588,0.63562,0.70003,0.53508,0.87722,0.80332,1.08738,0.00043,0.00043,0.00043 +582,193226,0.81489,0.74116,1.02948,0.73707,0.63504,0.69998,0.5353,0.87718,0.80325,1.08725,0.0004135,0.0004135,0.0004135 +583,193559,0.8188,0.74574,1.03059,0.73772,0.63483,0.69972,0.53527,0.87696,0.80284,1.08692,0.000397,0.000397,0.000397 +584,193889,0.81105,0.73046,1.02913,0.74026,0.63371,0.7,0.53539,0.87675,0.80268,1.08678,0.0003805,0.0003805,0.0003805 +585,194222,0.81248,0.73256,1.02922,0.73831,0.63592,0.70037,0.53554,0.87662,0.80227,1.08662,0.000364,0.000364,0.000364 +586,194553,0.81324,0.73291,1.02706,0.73698,0.63639,0.70049,0.53568,0.87634,0.80216,1.08653,0.0003475,0.0003475,0.0003475 +587,194885,0.80739,0.73206,1.02741,0.73604,0.64032,0.70086,0.53603,0.87611,0.80196,1.0862,0.000331,0.000331,0.000331 +588,195215,0.80772,0.73086,1.02755,0.73997,0.63769,0.70121,0.5365,0.87585,0.80192,1.08585,0.0003145,0.0003145,0.0003145 +589,195547,0.80986,0.72327,1.02418,0.74063,0.6384,0.70162,0.53653,0.87578,0.8017,1.08553,0.000298,0.000298,0.000298 +590,195878,0.80981,0.73307,1.02791,0.74063,0.63871,0.70188,0.53669,0.87568,0.80163,1.08527,0.0002815,0.0002815,0.0002815 +591,196210,0.85815,0.75222,1.08431,0.7416,0.63818,0.70269,0.53726,0.87514,0.80156,1.08465,0.000265,0.000265,0.000265 +592,196540,0.85535,0.74035,1.07803,0.74375,0.63734,0.70303,0.53768,0.87489,0.80139,1.08442,0.0002485,0.0002485,0.0002485 +593,196871,0.85514,0.73019,1.08344,0.74089,0.63787,0.70304,0.5375,0.8747,0.8012,1.08429,0.000232,0.000232,0.000232 +594,197202,0.84685,0.72525,1.0778,0.73911,0.63864,0.70348,0.53793,0.8743,0.80088,1.08385,0.0002155,0.0002155,0.0002155 +595,197533,0.84805,0.71458,1.06866,0.73979,0.63871,0.70349,0.53764,0.87411,0.80058,1.0833,0.000199,0.000199,0.000199 +596,197863,0.84696,0.71442,1.07341,0.74203,0.6376,0.70358,0.53792,0.87426,0.80051,1.08322,0.0001825,0.0001825,0.0001825 +597,198194,0.83636,0.70403,1.06855,0.74047,0.63941,0.70364,0.53808,0.87418,0.8004,1.08298,0.000166,0.000166,0.000166 
+598,198524,0.83957,0.69976,1.06273,0.73454,0.64506,0.704,0.53816,0.87395,0.80049,1.0823,0.0001495,0.0001495,0.0001495 +599,198855,0.82956,0.69088,1.06695,0.73572,0.64508,0.70427,0.53856,0.87369,0.80078,1.08199,0.000133,0.000133,0.000133 +600,199185,0.83061,0.68498,1.06622,0.74023,0.6418,0.70437,0.53866,0.87363,0.80081,1.08185,0.0001165,0.0001165,0.0001165 diff --git a/logs/yolov12m.csv b/logs/yolov12m.csv new file mode 100644 index 0000000000000000000000000000000000000000..c2756ad8410fa76014e25db87d4bc9f14afbedd6 --- /dev/null +++ b/logs/yolov12m.csv @@ -0,0 +1,601 @@ +epoch,time,train/box_loss,train/cls_loss,train/dfl_loss,metrics/precision(B),metrics/recall(B),metrics/mAP50(B),metrics/mAP50-95(B),val/box_loss,val/cls_loss,val/dfl_loss,lr/pg0,lr/pg1,lr/pg2 +1,282.07,3.75283,5.79522,4.25518,0.00131,0.00824,0.00074,0.00025,3.58229,inf,4.43423,0.00332613,0.00332613,0.00332613 +2,546.658,2.76363,4.47115,2.97402,0.36227,0.03675,0.01852,0.00853,2.1628,3.54736,2.37751,0.00664848,0.00664848,0.00664848 +3,809.195,1.89134,3.26712,1.95665,0.22569,0.12162,0.07663,0.04221,1.7275,2.83615,1.95167,0.00995982,0.00995982,0.00995982 +4,1069.42,1.64624,2.72654,1.71241,0.31588,0.20148,0.16101,0.09756,1.52998,2.33678,1.7313,0.0099505,0.0099505,0.0099505 +5,1328.84,1.51622,2.41359,1.59059,0.34484,0.26198,0.22473,0.14136,1.42241,2.09368,1.63703,0.009934,0.009934,0.009934 +6,1588.21,1.44652,2.21834,1.52407,0.40269,0.28618,0.27009,0.17244,1.3547,1.9055,1.56417,0.0099175,0.0099175,0.0099175 +7,1847.61,1.39441,2.08004,1.48047,0.45128,0.30861,0.30367,0.19881,1.31356,1.78624,1.50911,0.009901,0.009901,0.009901 +8,2107.06,1.35919,1.9829,1.44775,0.4497,0.33572,0.33363,0.22218,1.27227,1.67945,1.47081,0.0098845,0.0098845,0.0098845 +9,2366.42,1.32,1.88119,1.411,0.49232,0.34797,0.35871,0.24074,1.24865,1.61728,1.44455,0.009868,0.009868,0.009868 +10,2625.48,1.30165,1.83752,1.39449,0.49936,0.37117,0.38435,0.26012,1.21333,1.53287,1.41207,0.0098515,0.0098515,0.0098515 +11,2884.42,1.28452,1.79186,1.37486,0.52415,0.38388,0.40506,0.27733,1.18862,1.48524,1.37946,0.009835,0.009835,0.009835 +12,3143.5,1.27096,1.74417,1.35587,0.52308,0.39734,0.42104,0.29019,1.1701,1.4428,1.3609,0.0098185,0.0098185,0.0098185 +13,3403,1.25879,1.69794,1.3408,0.5399,0.41965,0.43775,0.30206,1.15838,1.39777,1.34353,0.009802,0.009802,0.009802 +14,3661.93,1.24108,1.67093,1.32527,0.5588,0.4251,0.44999,0.31418,1.13714,1.35889,1.32395,0.0097855,0.0097855,0.0097855 +15,3921.43,1.2291,1.64034,1.31623,0.57273,0.43413,0.46379,0.32416,1.12205,1.3307,1.31113,0.009769,0.009769,0.009769 +16,4180.67,1.22379,1.62595,1.30848,0.5714,0.44561,0.47316,0.3321,1.11226,1.30609,1.29839,0.0097525,0.0097525,0.0097525 +17,4438.85,1.20879,1.61092,1.29555,0.59072,0.44868,0.4847,0.34295,1.10048,1.27229,1.28541,0.009736,0.009736,0.009736 +18,4697.25,1.20955,1.58447,1.2927,0.57516,0.46021,0.49169,0.34775,1.09344,1.25771,1.27427,0.0097195,0.0097195,0.0097195 +19,4955.64,1.19818,1.57208,1.2812,0.59967,0.46517,0.50078,0.35423,1.08414,1.23796,1.26632,0.009703,0.009703,0.009703 +20,5214.39,1.18363,1.54943,1.27492,0.59679,0.47048,0.50639,0.36002,1.07502,1.22023,1.25471,0.0096865,0.0096865,0.0096865 +21,5473.36,1.18339,1.53664,1.27088,0.59174,0.4792,0.51326,0.3661,1.06802,1.20481,1.24857,0.00967,0.00967,0.00967 +22,5732.42,1.17428,1.52667,1.26173,0.60613,0.48301,0.51943,0.37126,1.06088,1.18962,1.24237,0.0096535,0.0096535,0.0096535 +23,5991.68,1.17094,1.51093,1.25645,0.6045,0.49092,0.52408,0.37519,1.0558,1.17915,1.23648,0.009637,0.009637,0.009637 
+24,6250.32,1.16438,1.49642,1.25454,0.62215,0.48331,0.52697,0.37781,1.05093,1.16928,1.23108,0.0096205,0.0096205,0.0096205 +25,6508.84,1.15953,1.48934,1.2509,0.61844,0.49214,0.53243,0.38194,1.04663,1.15802,1.2269,0.009604,0.009604,0.009604 +26,6767.87,1.16433,1.48516,1.25017,0.61908,0.49154,0.53443,0.38448,1.04293,1.15069,1.22306,0.0095875,0.0095875,0.0095875 +27,7026.38,1.16818,1.49106,1.25387,0.62517,0.49294,0.53748,0.38713,1.0386,1.14477,1.21929,0.009571,0.009571,0.009571 +28,7285.55,1.15589,1.47312,1.24485,0.62162,0.49712,0.54139,0.39025,1.03506,1.13692,1.21565,0.0095545,0.0095545,0.0095545 +29,7544,1.15017,1.47372,1.24108,0.61795,0.50445,0.54498,0.39309,1.03272,1.12992,1.21335,0.009538,0.009538,0.009538 +30,7802.86,1.15701,1.45827,1.24049,0.62229,0.50811,0.54748,0.39588,1.03034,1.12437,1.20931,0.0095215,0.0095215,0.0095215 +31,8062.13,1.14553,1.4486,1.23463,0.63551,0.50385,0.54976,0.3978,1.02851,1.11812,1.20753,0.009505,0.009505,0.009505 +32,8319.63,1.13736,1.44685,1.23141,0.63378,0.50792,0.55222,0.39997,1.02655,1.11319,1.20528,0.0094885,0.0094885,0.0094885 +33,8578.31,1.13822,1.43594,1.22867,0.64331,0.5036,0.55411,0.40123,1.02478,1.10989,1.20301,0.009472,0.009472,0.009472 +34,8837.53,1.13694,1.43267,1.2258,0.64255,0.50468,0.5551,0.40271,1.02161,1.10611,1.20051,0.0094555,0.0094555,0.0094555 +35,9096.13,1.14286,1.42713,1.22761,0.64478,0.50495,0.55692,0.40454,1.01978,1.10241,1.19856,0.009439,0.009439,0.009439 +36,9354.65,1.13375,1.41563,1.21887,0.63457,0.51118,0.55788,0.40543,1.01838,1.09875,1.19693,0.0094225,0.0094225,0.0094225 +37,9612.91,1.13692,1.42218,1.2254,0.63303,0.51419,0.55899,0.40624,1.01725,1.0967,1.1955,0.009406,0.009406,0.009406 +38,9870.73,1.12785,1.40739,1.21831,0.63299,0.51556,0.55992,0.4069,1.01612,1.09422,1.19419,0.0093895,0.0093895,0.0093895 +39,10128.9,1.12435,1.40651,1.21517,0.6317,0.51784,0.56075,0.40775,1.01571,1.0918,1.19351,0.009373,0.009373,0.009373 +40,10387.5,1.13051,1.39791,1.21581,0.63143,0.52027,0.56193,0.40881,1.0147,1.08977,1.19253,0.0093565,0.0093565,0.0093565 +41,10646.3,1.12702,1.39813,1.21324,0.63338,0.51912,0.56255,0.40918,1.0139,1.0882,1.19158,0.00934,0.00934,0.00934 +42,10903.8,1.12809,1.40031,1.21336,0.63383,0.51776,0.56313,0.40988,1.01349,1.087,1.19099,0.0093235,0.0093235,0.0093235 +43,11162.1,1.12206,1.39154,1.20845,0.63912,0.51641,0.56357,0.41038,1.01259,1.08567,1.19008,0.009307,0.009307,0.009307 +44,11420.5,1.1198,1.38535,1.20798,0.64401,0.5146,0.56433,0.41089,1.01182,1.08468,1.18932,0.0092905,0.0092905,0.0092905 +45,11679.3,1.12135,1.39247,1.21009,0.64495,0.51561,0.56521,0.41127,1.01161,1.08393,1.18894,0.009274,0.009274,0.009274 +46,11937.4,1.12593,1.39032,1.21089,0.64448,0.51535,0.56538,0.41154,1.01133,1.08275,1.18836,0.0092575,0.0092575,0.0092575 +47,12196,1.12409,1.37634,1.20577,0.64526,0.51627,0.5657,0.41205,1.01119,1.08215,1.18808,0.009241,0.009241,0.009241 +48,12454.2,1.12216,1.37156,1.2049,0.65279,0.51577,0.56626,0.41236,1.01088,1.08182,1.18745,0.0092245,0.0092245,0.0092245 +49,268.177,1.1117,1.37141,1.20643,0.65188,0.51616,0.56676,0.41299,1.01015,1.07991,1.18697,0.009208,0.009208,0.009208 +50,525.819,1.12521,1.39964,1.21514,0.65313,0.51566,0.56722,0.41342,1.00967,1.07844,1.18646,0.0091915,0.0091915,0.0091915 +51,784.193,1.12172,1.39792,1.21142,0.65369,0.51623,0.56737,0.4137,1.00934,1.07772,1.18605,0.009175,0.009175,0.009175 +52,1043.41,1.12493,1.39825,1.21107,0.65421,0.51623,0.56771,0.41399,1.00922,1.07725,1.18575,0.0091585,0.0091585,0.0091585 
+53,1301.89,1.12067,1.38548,1.21237,0.65599,0.51552,0.56785,0.41427,1.00871,1.07636,1.18532,0.009142,0.009142,0.009142 +54,1559.88,1.12982,1.37806,1.20908,0.65481,0.51613,0.56829,0.41447,1.00851,1.0756,1.18498,0.0091255,0.0091255,0.0091255 +55,1818.02,1.12423,1.38742,1.20871,0.6548,0.5168,0.56906,0.41498,1.00843,1.07481,1.18474,0.009109,0.009109,0.009109 +56,2076.85,1.11766,1.37874,1.20921,0.65411,0.51844,0.56924,0.41513,1.00831,1.07427,1.18439,0.0090925,0.0090925,0.0090925 +57,2335.07,1.11058,1.36767,1.2009,0.6553,0.51844,0.56956,0.41529,1.00823,1.07382,1.18407,0.009076,0.009076,0.009076 +58,2593.78,1.11393,1.37487,1.20509,0.65587,0.51865,0.56995,0.41571,1.00785,1.07298,1.18368,0.0090595,0.0090595,0.0090595 +59,2851.68,1.11031,1.37275,1.2011,0.6562,0.5185,0.5699,0.41575,1.00769,1.07296,1.18321,0.009043,0.009043,0.009043 +60,3109.46,1.11734,1.35971,1.19848,0.65577,0.51904,0.57031,0.41591,1.00725,1.07268,1.18284,0.0090265,0.0090265,0.0090265 +61,3368.03,1.11451,1.36893,1.20079,0.65556,0.52011,0.57057,0.4162,1.00704,1.07247,1.18254,0.00901,0.00901,0.00901 +62,3626.29,1.1087,1.35567,1.19646,0.65213,0.52044,0.57059,0.41631,1.00697,1.07255,1.18232,0.0089935,0.0089935,0.0089935 +63,3884.51,1.10705,1.35364,1.19556,0.65177,0.52046,0.57096,0.41664,1.00688,1.07255,1.18214,0.008977,0.008977,0.008977 +64,4142.17,1.10396,1.35028,1.19393,0.65184,0.52078,0.5714,0.41696,1.00718,1.07262,1.18235,0.0089605,0.0089605,0.0089605 +65,4399.99,1.10696,1.35446,1.19641,0.65483,0.52026,0.57151,0.41714,1.00701,1.07287,1.18191,0.008944,0.008944,0.008944 +66,4657.58,1.10443,1.34158,1.19192,0.65203,0.52251,0.57158,0.4172,1.00688,1.0728,1.18156,0.0089275,0.0089275,0.0089275 +67,4915.42,1.10213,1.3474,1.18971,0.65234,0.52264,0.57179,0.41738,1.00682,1.07308,1.18136,0.008911,0.008911,0.008911 +68,5173.24,1.10485,1.35536,1.19312,0.65138,0.5232,0.57197,0.41734,1.00677,1.07357,1.18099,0.0088945,0.0088945,0.0088945 +69,5431.33,1.10112,1.3426,1.19166,0.65236,0.52338,0.57206,0.4175,1.00661,1.07422,1.18075,0.008878,0.008878,0.008878 +70,5689.16,1.10148,1.33704,1.18812,0.65348,0.5223,0.57222,0.41764,1.00661,1.07495,1.18051,0.0088615,0.0088615,0.0088615 +71,5947.03,1.10113,1.33588,1.18827,0.65601,0.51945,0.57226,0.41759,1.00659,1.0757,1.18033,0.008845,0.008845,0.008845 +72,6203.88,1.09922,1.33364,1.18844,0.65465,0.52127,0.57229,0.41772,1.00638,1.07632,1.18011,0.0088285,0.0088285,0.0088285 +73,6461.89,1.10629,1.33796,1.18885,0.65402,0.52157,0.5721,0.41766,1.00623,1.07746,1.17994,0.008812,0.008812,0.008812 +74,6719.26,1.09453,1.32526,1.1862,0.65463,0.52168,0.57209,0.41772,1.00611,1.0783,1.1798,0.0087955,0.0087955,0.0087955 +75,6977.38,1.10291,1.33086,1.18607,0.65358,0.52338,0.57223,0.41792,1.0061,1.07916,1.17965,0.008779,0.008779,0.008779 +76,7235.51,1.09846,1.33015,1.18824,0.65057,0.5242,0.57211,0.41794,1.00594,1.08027,1.17943,0.0087625,0.0087625,0.0087625 +77,7492.74,1.09138,1.31549,1.1808,0.65184,0.52371,0.57189,0.41823,1.00568,1.08134,1.17913,0.008746,0.008746,0.008746 +78,7750.1,1.0951,1.31539,1.18482,0.65358,0.5228,0.57189,0.41839,1.00582,1.08231,1.17904,0.0087295,0.0087295,0.0087295 +79,8008.18,1.09424,1.31789,1.182,0.65126,0.52347,0.57205,0.41851,1.00591,1.0831,1.17888,0.008713,0.008713,0.008713 +80,8264.9,1.09516,1.32897,1.18494,0.65246,0.52325,0.57209,0.41894,1.00591,1.08373,1.17882,0.0086965,0.0086965,0.0086965 +81,8522.72,1.09209,1.31137,1.18204,0.65305,0.52294,0.57208,0.41914,1.00589,1.08433,1.17866,0.00868,0.00868,0.00868 
+82,8779.54,1.09326,1.32525,1.18114,0.65171,0.52329,0.57233,0.41934,1.00588,1.08509,1.17851,0.0086635,0.0086635,0.0086635 +83,9036.89,1.0963,1.31838,1.18026,0.65671,0.5208,0.57241,0.41954,1.00553,1.08593,1.17819,0.008647,0.008647,0.008647 +84,9294.37,1.08336,1.30036,1.17471,0.65636,0.52246,0.57311,0.42001,1.0053,1.08677,1.17783,0.0086305,0.0086305,0.0086305 +85,9551.5,1.08889,1.31155,1.17944,0.65777,0.52192,0.57347,0.42036,1.00495,1.08722,1.17745,0.008614,0.008614,0.008614 +86,9810,1.08609,1.30921,1.17927,0.65788,0.52224,0.57353,0.4207,1.00471,1.08772,1.17731,0.0085975,0.0085975,0.0085975 +87,10067.1,1.08835,1.30317,1.17458,0.66093,0.52068,0.57393,0.42071,1.0044,1.08829,1.17709,0.008581,0.008581,0.008581 +88,10325.1,1.09005,1.31326,1.17807,0.66147,0.52045,0.5739,0.42085,1.00412,1.08876,1.17692,0.0085645,0.0085645,0.0085645 +89,10582.6,1.08356,1.30289,1.17504,0.65887,0.52165,0.57401,0.42122,1.00378,1.08936,1.17648,0.008548,0.008548,0.008548 +90,10840.1,1.08647,1.30578,1.17581,0.66352,0.52121,0.57415,0.42146,1.00358,1.08926,1.17613,0.0085315,0.0085315,0.0085315 +91,11097.4,1.08504,1.30617,1.17717,0.66526,0.52179,0.5744,0.42204,1.00326,1.08936,1.17566,0.008515,0.008515,0.008515 +92,11354.9,1.09022,1.30435,1.17735,0.66594,0.52237,0.57492,0.42226,1.00308,1.0896,1.17535,0.0084985,0.0084985,0.0084985 +93,11611.5,1.08295,1.29655,1.17642,0.66495,0.52322,0.57524,0.42251,1.00293,1.08971,1.17508,0.008482,0.008482,0.008482 +94,11868.9,1.09021,1.31513,1.17984,0.66754,0.52269,0.57575,0.42312,1.00289,1.08986,1.17479,0.0084655,0.0084655,0.0084655 +95,12127.1,1.08238,1.29813,1.17284,0.66829,0.52327,0.57634,0.42324,1.00267,1.08987,1.17437,0.008449,0.008449,0.008449 +96,12384.9,1.08474,1.29475,1.17402,0.67108,0.52255,0.57674,0.42364,1.00283,1.08932,1.17411,0.0084325,0.0084325,0.0084325 +97,12642.4,1.08324,1.28618,1.17227,0.67091,0.52309,0.57717,0.42412,1.00224,1.08857,1.17352,0.008416,0.008416,0.008416 +98,12899.9,1.07994,1.29487,1.17226,0.67061,0.5235,0.57772,0.4247,1.00165,1.08801,1.17294,0.0083995,0.0083995,0.0083995 +99,13157.2,1.08245,1.293,1.17564,0.67254,0.52424,0.57837,0.425,1.00137,1.08718,1.17245,0.008383,0.008383,0.008383 +100,13414,1.0797,1.28205,1.17013,0.67214,0.52617,0.57914,0.42538,1.00084,1.08603,1.17184,0.0083665,0.0083665,0.0083665 +101,13671.4,1.08075,1.28401,1.17177,0.67328,0.52624,0.57967,0.42613,1.00032,1.08521,1.17124,0.00835,0.00835,0.00835 +102,13928.8,1.08153,1.2923,1.17138,0.67369,0.52685,0.57989,0.42642,0.99975,1.08437,1.1705,0.0083335,0.0083335,0.0083335 +103,14185.9,1.07238,1.28303,1.16851,0.67327,0.52729,0.58012,0.42672,0.99904,1.08354,1.16971,0.008317,0.008317,0.008317 +104,14443.4,1.07408,1.29844,1.17305,0.67435,0.52671,0.5805,0.42708,0.99842,1.08245,1.16889,0.0083005,0.0083005,0.0083005 +105,14701.2,1.07576,1.28171,1.17031,0.67401,0.52701,0.58083,0.42781,0.99784,1.0812,1.16836,0.008284,0.008284,0.008284 +106,14958.9,1.07557,1.27653,1.16576,0.67585,0.52699,0.58147,0.42819,0.99731,1.08027,1.16775,0.0082675,0.0082675,0.0082675 +107,15216.1,1.08205,1.28915,1.17299,0.67638,0.52799,0.58223,0.42877,0.9969,1.07924,1.16732,0.008251,0.008251,0.008251 +108,15474.2,1.08217,1.2864,1.171,0.67574,0.52867,0.58292,0.4293,0.99631,1.07787,1.16664,0.0082345,0.0082345,0.0082345 +109,15731.6,1.07964,1.28354,1.16955,0.67618,0.52936,0.58351,0.4297,0.9957,1.07664,1.16594,0.008218,0.008218,0.008218 +110,15989.5,1.06929,1.27399,1.16392,0.6774,0.52956,0.58461,0.43053,0.99527,1.07515,1.16542,0.0082015,0.0082015,0.0082015 
+111,16246.9,1.0755,1.27514,1.16734,0.67891,0.52965,0.58543,0.43086,0.99486,1.0736,1.16483,0.008185,0.008185,0.008185 +112,16504.1,1.08075,1.28317,1.17102,0.67961,0.52991,0.58589,0.43142,0.99436,1.07196,1.16407,0.0081685,0.0081685,0.0081685 +113,16761.8,1.07459,1.27604,1.16862,0.68093,0.53032,0.58627,0.43198,0.99358,1.07015,1.1633,0.008152,0.008152,0.008152 +114,17018.8,1.07179,1.27669,1.16752,0.6798,0.53234,0.587,0.43272,0.99287,1.06903,1.16244,0.0081355,0.0081355,0.0081355 +115,17276.2,1.07777,1.27791,1.16539,0.67953,0.53322,0.5876,0.43339,0.99229,1.06715,1.16171,0.008119,0.008119,0.008119 +116,17534,1.06893,1.25952,1.16391,0.68086,0.53351,0.58812,0.43387,0.99164,1.06559,1.16101,0.0081025,0.0081025,0.0081025 +117,17791.1,1.07366,1.2702,1.16434,0.67956,0.53538,0.58861,0.43435,0.99073,1.06397,1.16017,0.008086,0.008086,0.008086 +118,18048.8,1.06814,1.26573,1.16332,0.67896,0.53597,0.589,0.4346,0.99008,1.06252,1.15937,0.0080695,0.0080695,0.0080695 +119,18306.9,1.07579,1.2711,1.16407,0.67806,0.53654,0.58945,0.43507,0.98953,1.06103,1.15876,0.008053,0.008053,0.008053 +120,18563.7,1.07191,1.26903,1.16573,0.67753,0.53763,0.59016,0.43557,0.98881,1.05933,1.15796,0.0080365,0.0080365,0.0080365 +121,18821.3,1.06211,1.26206,1.16057,0.67855,0.53812,0.59081,0.43637,0.98824,1.05725,1.15726,0.00802,0.00802,0.00802 +122,19078.4,1.0698,1.26257,1.16459,0.68208,0.53662,0.59143,0.43659,0.98744,1.05542,1.15645,0.0080035,0.0080035,0.0080035 +123,19336,1.071,1.26488,1.16327,0.6854,0.53568,0.59198,0.43737,0.98671,1.05338,1.15566,0.007987,0.007987,0.007987 +124,19594.1,1.07122,1.26911,1.16253,0.68517,0.5364,0.59251,0.43804,0.9861,1.05146,1.15503,0.0079705,0.0079705,0.0079705 +125,19851.4,1.06833,1.26577,1.16375,0.68571,0.53686,0.59323,0.43868,0.9854,1.04942,1.15426,0.007954,0.007954,0.007954 +126,20108.5,1.06525,1.25966,1.15873,0.6882,0.5358,0.59407,0.43923,0.98494,1.04782,1.1537,0.0079375,0.0079375,0.0079375 +127,20365.8,1.06266,1.25765,1.16098,0.687,0.53739,0.5946,0.43992,0.98418,1.04614,1.15299,0.007921,0.007921,0.007921 +128,20623.2,1.06448,1.25298,1.15915,0.68762,0.5388,0.59534,0.4407,0.9835,1.04435,1.15231,0.0079045,0.0079045,0.0079045 +129,20880.9,1.06881,1.26641,1.16029,0.68382,0.53899,0.59578,0.44099,0.98285,1.04235,1.15155,0.007888,0.007888,0.007888 +130,21138.2,1.06843,1.26823,1.16051,0.68475,0.53858,0.5966,0.44159,0.98245,1.04042,1.15099,0.0078715,0.0078715,0.0078715 +131,21396,1.06898,1.27146,1.16384,0.68404,0.53934,0.59694,0.44199,0.98184,1.03855,1.15038,0.007855,0.007855,0.007855 +132,21653.8,1.0691,1.25067,1.15937,0.68427,0.53954,0.59741,0.4423,0.98141,1.03697,1.14974,0.0078385,0.0078385,0.0078385 +133,21911.3,1.06012,1.24439,1.15612,0.68675,0.53902,0.59771,0.44308,0.98082,1.03524,1.14907,0.007822,0.007822,0.007822 +134,22168.5,1.06733,1.24955,1.15828,0.68814,0.53929,0.59819,0.44367,0.98031,1.03351,1.14859,0.0078055,0.0078055,0.0078055 +135,22426.1,1.06315,1.25262,1.15886,0.69092,0.53892,0.59902,0.444,0.97979,1.03167,1.14805,0.007789,0.007789,0.007789 +136,22683.5,1.06448,1.24688,1.15936,0.68818,0.53963,0.5994,0.44465,0.97929,1.03001,1.14753,0.0077725,0.0077725,0.0077725 +137,22940.9,1.06684,1.26155,1.16182,0.69091,0.53905,0.60009,0.44502,0.97854,1.02843,1.1466,0.007756,0.007756,0.007756 +138,23198.9,1.05964,1.24761,1.15422,0.69101,0.54005,0.60029,0.44547,0.97799,1.02651,1.14595,0.0077395,0.0077395,0.0077395 +139,23456.8,1.06071,1.24256,1.15576,0.6907,0.54019,0.60074,0.44578,0.97757,1.02502,1.14529,0.007723,0.007723,0.007723 
+140,23714.2,1.05785,1.24455,1.15805,0.68762,0.54228,0.60133,0.44624,0.97757,1.02376,1.14496,0.0077065,0.0077065,0.0077065 +141,23971.8,1.06184,1.24634,1.15628,0.69051,0.5414,0.60171,0.4466,0.97708,1.02218,1.14447,0.00769,0.00769,0.00769 +142,24228.7,1.06263,1.25198,1.16227,0.68977,0.54237,0.60206,0.44685,0.97657,1.02051,1.14394,0.0076735,0.0076735,0.0076735 +143,24486.2,1.06398,1.24757,1.1602,0.6897,0.54277,0.60272,0.44732,0.97613,1.0189,1.14348,0.007657,0.007657,0.007657 +144,24743.9,1.05954,1.24789,1.15925,0.6874,0.54486,0.60285,0.44747,0.97568,1.0174,1.14306,0.0076405,0.0076405,0.0076405 +145,25000.8,1.05472,1.23943,1.15417,0.68798,0.54506,0.60345,0.44781,0.9753,1.01568,1.14265,0.007624,0.007624,0.007624 +146,25258.5,1.0588,1.24479,1.15757,0.68791,0.54583,0.60371,0.4482,0.97506,1.01445,1.1424,0.0076075,0.0076075,0.0076075 +147,25516.4,1.06192,1.24423,1.1592,0.68686,0.54648,0.60408,0.44841,0.97484,1.01336,1.142,0.007591,0.007591,0.007591 +148,25774.1,1.06904,1.24827,1.16147,0.68829,0.54641,0.60458,0.44881,0.97461,1.01197,1.1417,0.0075745,0.0075745,0.0075745 +149,26031.9,1.05378,1.23998,1.15481,0.68913,0.54668,0.60507,0.44918,0.97431,1.01059,1.14116,0.007558,0.007558,0.007558 +150,26289.8,1.06122,1.24859,1.15825,0.6904,0.54635,0.60554,0.44976,0.97398,1.00881,1.14081,0.0075415,0.0075415,0.0075415 +151,26547.7,1.05552,1.23769,1.15305,0.69795,0.54519,0.60641,0.45042,0.97374,1.00756,1.14041,0.007525,0.007525,0.007525 +152,26804.6,1.05979,1.23806,1.15501,0.69833,0.54625,0.6068,0.45078,0.97338,1.00631,1.14004,0.0075085,0.0075085,0.0075085 +153,27061.8,1.0634,1.24396,1.15622,0.69589,0.54623,0.60748,0.45119,0.97281,1.00508,1.13948,0.007492,0.007492,0.007492 +154,27319.7,1.06347,1.25581,1.15851,0.69602,0.54585,0.60775,0.45152,0.97246,1.00389,1.13913,0.0074755,0.0074755,0.0074755 +155,27577.3,1.05566,1.24407,1.15383,0.69377,0.54793,0.60795,0.4516,0.97214,1.00235,1.13858,0.007459,0.007459,0.007459 +156,27834.6,1.05596,1.24086,1.15565,0.69419,0.54778,0.60824,0.45202,0.972,1.00098,1.13841,0.0074425,0.0074425,0.0074425 +157,28092,1.04947,1.22883,1.15073,0.69745,0.54737,0.60892,0.45254,0.97156,1.00007,1.13808,0.007426,0.007426,0.007426 +158,28349,1.04834,1.2236,1.14955,0.69276,0.54955,0.6092,0.45286,0.97099,0.99893,1.13751,0.0074095,0.0074095,0.0074095 +159,28606.2,1.05686,1.23071,1.15678,0.69139,0.55029,0.60962,0.45303,0.971,0.9977,1.1374,0.007393,0.007393,0.007393 +160,28863.3,1.05336,1.22014,1.15076,0.68884,0.55137,0.61001,0.4535,0.97066,0.99646,1.13705,0.0073765,0.0073765,0.0073765 +161,29120.6,1.05624,1.24144,1.15542,0.69349,0.54994,0.61082,0.45407,0.97053,0.99506,1.13685,0.00736,0.00736,0.00736 +162,29378.7,1.05181,1.22072,1.15033,0.69502,0.55011,0.61161,0.45459,0.97018,0.99361,1.13639,0.0073435,0.0073435,0.0073435 +163,29636.1,1.05286,1.22964,1.15247,0.6963,0.55049,0.61215,0.45506,0.96978,0.99248,1.13594,0.007327,0.007327,0.007327 +164,29893,1.04719,1.22121,1.14543,0.69373,0.55221,0.6125,0.45558,0.96959,0.99157,1.13566,0.0073105,0.0073105,0.0073105 +165,30150.5,1.05529,1.23594,1.15291,0.69536,0.5513,0.61311,0.45586,0.96925,0.99068,1.13517,0.007294,0.007294,0.007294 +166,30408.2,1.05273,1.22627,1.15012,0.69658,0.55079,0.61344,0.4561,0.96855,0.98954,1.1346,0.0072775,0.0072775,0.0072775 +167,30666.1,1.04913,1.22697,1.15128,0.69606,0.55159,0.61377,0.45626,0.96807,0.98844,1.13413,0.007261,0.007261,0.007261 +168,30923.8,1.04869,1.22603,1.14944,0.69506,0.5525,0.61408,0.45652,0.96763,0.98761,1.1337,0.0072445,0.0072445,0.0072445 
+169,31181.6,1.04515,1.23321,1.15026,0.69832,0.55158,0.61451,0.45681,0.96707,0.98624,1.13308,0.007228,0.007228,0.007228 +170,31439.4,1.04959,1.2305,1.15128,0.69874,0.55207,0.61488,0.45732,0.96676,0.98488,1.13275,0.0072115,0.0072115,0.0072115 +171,31697.2,1.04918,1.22759,1.15079,0.69882,0.55277,0.6146,0.45751,0.96618,0.98411,1.13235,0.007195,0.007195,0.007195 +172,31954.8,1.04519,1.20936,1.14632,0.70009,0.5526,0.61507,0.45763,0.96614,0.98326,1.13225,0.0071785,0.0071785,0.0071785 +173,32211.7,1.04828,1.22278,1.14731,0.70325,0.55126,0.61618,0.45816,0.96569,0.98228,1.13182,0.007162,0.007162,0.007162 +174,32470,1.04159,1.21224,1.14828,0.70131,0.55263,0.61667,0.45844,0.96529,0.98152,1.1316,0.0071455,0.0071455,0.0071455 +175,32727.8,1.0486,1.226,1.1495,0.69872,0.55408,0.617,0.45877,0.96492,0.98044,1.13133,0.007129,0.007129,0.007129 +176,32985.3,1.04868,1.21071,1.14502,0.69773,0.55441,0.61706,0.45902,0.96434,0.9794,1.13083,0.0071125,0.0071125,0.0071125 +177,33242.1,1.05248,1.22942,1.15117,0.69865,0.55493,0.61749,0.45929,0.96406,0.97851,1.13067,0.007096,0.007096,0.007096 +178,33499.2,1.05106,1.22124,1.14938,0.69624,0.55717,0.61799,0.45989,0.96375,0.97745,1.13037,0.0070795,0.0070795,0.0070795 +179,33756.2,1.04597,1.22183,1.14899,0.69772,0.55719,0.61827,0.46006,0.96353,0.97645,1.13006,0.007063,0.007063,0.007063 +180,34014.5,1.04141,1.20713,1.1483,0.69987,0.55621,0.61877,0.46036,0.96333,0.97606,1.12966,0.0070465,0.0070465,0.0070465 +181,34272.4,1.04681,1.22035,1.1478,0.69828,0.55721,0.61917,0.46066,0.96301,0.97516,1.12936,0.00703,0.00703,0.00703 +182,34530.8,1.04395,1.20374,1.14604,0.69904,0.55686,0.61937,0.4609,0.96255,0.97445,1.12889,0.0070135,0.0070135,0.0070135 +183,34789.2,1.05321,1.22338,1.15126,0.69885,0.55802,0.61998,0.4613,0.96188,0.97355,1.12843,0.006997,0.006997,0.006997 +184,35046.9,1.04926,1.22258,1.14986,0.70203,0.5576,0.62017,0.46149,0.96161,0.97254,1.12808,0.0069805,0.0069805,0.0069805 +185,35304.4,1.04762,1.21376,1.14731,0.70122,0.55712,0.62052,0.46201,0.96126,0.97162,1.12783,0.006964,0.006964,0.006964 +186,35561.6,1.04263,1.20865,1.14506,0.69972,0.55804,0.6209,0.46244,0.96073,0.97074,1.12742,0.0069475,0.0069475,0.0069475 +187,35819.4,1.04498,1.20876,1.14365,0.70317,0.55667,0.62119,0.46273,0.96031,0.96989,1.12694,0.006931,0.006931,0.006931 +188,36076.6,1.04935,1.21795,1.14923,0.70497,0.55558,0.62141,0.46284,0.96001,0.96932,1.1265,0.0069145,0.0069145,0.0069145 +189,36333.1,1.04292,1.20537,1.14564,0.70549,0.55597,0.62194,0.46324,0.95981,0.96839,1.1263,0.006898,0.006898,0.006898 +190,36590.8,1.04137,1.19862,1.14194,0.70424,0.55679,0.622,0.46299,0.95958,0.96762,1.12613,0.0068815,0.0068815,0.0068815 +191,36848.1,1.04552,1.20192,1.14613,0.70292,0.55758,0.62239,0.46333,0.9593,0.96673,1.12587,0.006865,0.006865,0.006865 +192,37105.5,1.04538,1.20288,1.14475,0.70111,0.55887,0.62259,0.46349,0.95906,0.96594,1.12547,0.0068485,0.0068485,0.0068485 +193,37363,1.03609,1.19689,1.14148,0.7007,0.55938,0.62259,0.46383,0.95878,0.96552,1.12515,0.006832,0.006832,0.006832 +194,37620.3,1.04255,1.20398,1.14294,0.69899,0.56052,0.62301,0.46396,0.95855,0.9648,1.12494,0.0068155,0.0068155,0.0068155 +195,37878.6,1.04331,1.20156,1.14355,0.69815,0.56142,0.62324,0.46428,0.95805,0.96403,1.12454,0.006799,0.006799,0.006799 +196,38135.8,1.0434,1.20707,1.14569,0.69639,0.56222,0.62387,0.46496,0.95753,0.96328,1.12423,0.0067825,0.0067825,0.0067825 +197,38393.2,1.04123,1.20051,1.14071,0.69776,0.56252,0.62405,0.46519,0.95743,0.96263,1.12398,0.006766,0.006766,0.006766 
+198,38650.8,1.04172,1.20792,1.14673,0.7,0.56202,0.62421,0.46532,0.95706,0.96207,1.12383,0.0067495,0.0067495,0.0067495 +199,38908.2,1.04558,1.21024,1.14724,0.69997,0.56224,0.62433,0.46547,0.95722,0.96133,1.12386,0.006733,0.006733,0.006733 +200,39166,1.03807,1.20195,1.14196,0.69751,0.56379,0.6247,0.46578,0.95707,0.96079,1.12367,0.0067165,0.0067165,0.0067165 +201,39422.6,1.04845,1.20938,1.14761,0.69997,0.56252,0.62503,0.46611,0.95679,0.96004,1.12342,0.0067,0.0067,0.0067 +202,39679.4,1.03327,1.19892,1.13973,0.69902,0.56375,0.62538,0.46637,0.9565,0.95917,1.12308,0.0066835,0.0066835,0.0066835 +203,39936.9,1.04121,1.20087,1.14416,0.69745,0.56503,0.62538,0.46663,0.95601,0.95844,1.12251,0.006667,0.006667,0.006667 +204,40193.8,1.03829,1.19675,1.14213,0.69749,0.56557,0.62549,0.46661,0.95567,0.95792,1.1221,0.0066505,0.0066505,0.0066505 +205,40451.7,1.03087,1.20369,1.14026,0.69507,0.5677,0.62572,0.46688,0.95533,0.95726,1.12176,0.006634,0.006634,0.006634 +206,40709.4,1.03721,1.19541,1.13993,0.69279,0.56902,0.62599,0.46698,0.95516,0.95696,1.12153,0.0066175,0.0066175,0.0066175 +207,40966.5,1.03594,1.19256,1.14022,0.69063,0.57055,0.62612,0.46744,0.95485,0.95642,1.12122,0.006601,0.006601,0.006601 +208,41223.9,1.04329,1.19481,1.14002,0.69594,0.56784,0.62614,0.46754,0.9548,0.9557,1.12114,0.0065845,0.0065845,0.0065845 +209,41481.7,1.03656,1.19466,1.14024,0.69374,0.56938,0.6264,0.46765,0.95464,0.95526,1.12092,0.006568,0.006568,0.006568 +210,41738.7,1.0375,1.20456,1.14384,0.69052,0.57189,0.62697,0.46795,0.9546,0.95476,1.12068,0.0065515,0.0065515,0.0065515 +211,41996.5,1.03924,1.19446,1.14215,0.69052,0.57271,0.62726,0.46801,0.95444,0.95448,1.12045,0.006535,0.006535,0.006535 +212,42253.1,1.03282,1.1861,1.1389,0.68891,0.57388,0.62726,0.46794,0.95409,0.95392,1.12004,0.0065185,0.0065185,0.0065185 +213,42510.2,1.0385,1.19341,1.13917,0.68912,0.57479,0.62757,0.46782,0.9538,0.95314,1.11976,0.006502,0.006502,0.006502 +214,42767.9,1.03902,1.19184,1.13862,0.69173,0.57438,0.62808,0.46808,0.9537,0.95267,1.11952,0.0064855,0.0064855,0.0064855 +215,43025.2,1.03218,1.18615,1.13706,0.69028,0.57465,0.62816,0.46832,0.95355,0.95225,1.11944,0.006469,0.006469,0.006469 +216,43282.2,1.03226,1.18664,1.13798,0.69319,0.57247,0.62839,0.46861,0.9538,0.95194,1.11948,0.0064525,0.0064525,0.0064525 +217,43540.3,1.03226,1.18626,1.13556,0.69143,0.57312,0.62869,0.46851,0.9539,0.95167,1.11943,0.006436,0.006436,0.006436 +218,43798.6,1.03182,1.19212,1.13851,0.69212,0.57362,0.6291,0.46883,0.95413,0.95115,1.11942,0.0064195,0.0064195,0.0064195 +219,44056.4,1.0313,1.18896,1.13954,0.68966,0.57577,0.62951,0.46914,0.95411,0.95071,1.11924,0.006403,0.006403,0.006403 +220,44314.4,1.02834,1.17642,1.13161,0.69085,0.57523,0.62982,0.46958,0.95404,0.95032,1.11913,0.0063865,0.0063865,0.0063865 +221,44571.7,1.03152,1.1905,1.14011,0.6928,0.57428,0.63025,0.46997,0.95378,0.94968,1.11878,0.00637,0.00637,0.00637 +222,44829,1.02977,1.1859,1.13481,0.69236,0.57503,0.63054,0.4703,0.95377,0.94898,1.11853,0.0063535,0.0063535,0.0063535 +223,45086.3,1.03317,1.1859,1.13729,0.68968,0.57653,0.63055,0.47013,0.95373,0.94884,1.11844,0.006337,0.006337,0.006337 +224,45343.9,1.03382,1.18694,1.13574,0.69196,0.57563,0.63086,0.47068,0.95353,0.94885,1.11816,0.0063205,0.0063205,0.0063205 +225,45601.3,1.03365,1.18056,1.13698,0.68751,0.57799,0.63107,0.47093,0.95324,0.94831,1.11773,0.006304,0.006304,0.006304 +226,45859,1.03005,1.19217,1.13653,0.68772,0.57777,0.63127,0.47141,0.95302,0.94781,1.11747,0.0062875,0.0062875,0.0062875 
+227,46117.3,1.02243,1.17819,1.13656,0.6902,0.57755,0.63173,0.47147,0.95279,0.94725,1.11724,0.006271,0.006271,0.006271 +228,46373.6,1.03003,1.17859,1.13522,0.69427,0.57593,0.63191,0.47159,0.95254,0.94673,1.11705,0.0062545,0.0062545,0.0062545 +229,46631.1,1.02652,1.17497,1.13367,0.69544,0.57732,0.63229,0.472,0.95237,0.94616,1.11686,0.006238,0.006238,0.006238 +230,46887.7,1.02969,1.17377,1.13399,0.69622,0.57717,0.63276,0.4724,0.95219,0.94534,1.11664,0.0062215,0.0062215,0.0062215 +231,47144.5,1.03009,1.17552,1.13452,0.69457,0.5778,0.63302,0.47254,0.9521,0.94464,1.11646,0.006205,0.006205,0.006205 +232,47402.1,1.03577,1.17574,1.13609,0.69512,0.57774,0.6333,0.47268,0.95187,0.94429,1.11615,0.0061885,0.0061885,0.0061885 +233,47659.7,1.02609,1.17715,1.13622,0.69577,0.57771,0.63337,0.47274,0.95154,0.94383,1.11591,0.006172,0.006172,0.006172 +234,47917.3,1.02834,1.17294,1.13544,0.69644,0.57786,0.63372,0.47263,0.95117,0.9433,1.11563,0.0061555,0.0061555,0.0061555 +235,48174.9,1.02472,1.17733,1.13174,0.69461,0.57957,0.63407,0.47311,0.95109,0.94279,1.11553,0.006139,0.006139,0.006139 +236,48432.6,1.02905,1.17365,1.13463,0.69564,0.57958,0.63423,0.47348,0.95077,0.94227,1.11524,0.0061225,0.0061225,0.0061225 +237,48690.3,1.02954,1.17262,1.13408,0.69488,0.58036,0.63434,0.47358,0.95066,0.94146,1.11496,0.006106,0.006106,0.006106 +238,48948.5,1.02582,1.17702,1.13667,0.69582,0.57993,0.6347,0.47375,0.95048,0.94075,1.11478,0.0060895,0.0060895,0.0060895 +239,49206.2,1.03116,1.18202,1.13705,0.69412,0.5809,0.63475,0.47376,0.95033,0.94087,1.11463,0.006073,0.006073,0.006073 +240,49463.9,1.03308,1.17869,1.13295,0.69635,0.58027,0.635,0.47394,0.94987,0.94014,1.11409,0.0060565,0.0060565,0.0060565 +241,49721.3,1.02908,1.17583,1.13354,0.69525,0.58021,0.63507,0.47399,0.94966,0.93977,1.11391,0.00604,0.00604,0.00604 +242,49979.6,1.02686,1.16891,1.13362,0.69499,0.58011,0.63539,0.47425,0.94977,0.93926,1.11403,0.0060235,0.0060235,0.0060235 +243,50237.6,1.02967,1.16734,1.13246,0.69533,0.58109,0.63569,0.47452,0.94929,0.93869,1.11363,0.006007,0.006007,0.006007 +244,50495.7,1.02669,1.17598,1.13301,0.69464,0.58083,0.63582,0.47487,0.94927,0.93853,1.11359,0.0059905,0.0059905,0.0059905 +245,50753.1,1.02876,1.17725,1.13395,0.69552,0.5806,0.63647,0.47524,0.94918,0.93851,1.11352,0.005974,0.005974,0.005974 +246,51010.3,1.02992,1.17754,1.1355,0.69672,0.57998,0.63653,0.47523,0.94905,0.93822,1.1133,0.0059575,0.0059575,0.0059575 +247,51268.4,1.03169,1.17354,1.13543,0.69876,0.57923,0.63666,0.47547,0.94891,0.93805,1.11312,0.005941,0.005941,0.005941 +248,51526.2,1.03039,1.17038,1.13619,0.69799,0.57947,0.63687,0.47533,0.94902,0.93788,1.11307,0.0059245,0.0059245,0.0059245 +249,51783.6,1.02457,1.16743,1.13118,0.69943,0.57828,0.63717,0.47554,0.94886,0.93777,1.11277,0.005908,0.005908,0.005908 +250,52041.4,1.02829,1.16379,1.1317,0.70255,0.57709,0.63719,0.47586,0.94863,0.93727,1.11248,0.0058915,0.0058915,0.0058915 +251,52298.4,1.02676,1.16845,1.13399,0.70705,0.57497,0.63734,0.4758,0.94829,0.93671,1.11213,0.005875,0.005875,0.005875 +252,52555.9,1.02788,1.17172,1.13342,0.7051,0.57714,0.63734,0.47608,0.948,0.93635,1.11189,0.0058585,0.0058585,0.0058585 +253,52813.4,1.02388,1.16572,1.13087,0.70442,0.57768,0.63746,0.47603,0.9479,0.93592,1.11168,0.005842,0.005842,0.005842 +254,53070.8,1.02231,1.164,1.13125,0.70357,0.57814,0.63795,0.47639,0.94766,0.9352,1.11155,0.0058255,0.0058255,0.0058255 +255,53328.5,1.02229,1.16449,1.1307,0.70453,0.57803,0.6385,0.4768,0.94728,0.93491,1.11124,0.005809,0.005809,0.005809 
+256,53586.1,1.02176,1.16635,1.13023,0.70256,0.57926,0.6389,0.47691,0.94707,0.93459,1.11095,0.0057925,0.0057925,0.0057925 +257,53843.6,1.02977,1.16162,1.13064,0.70002,0.58087,0.63876,0.47685,0.94712,0.93424,1.11095,0.005776,0.005776,0.005776 +258,54101.3,1.02968,1.17139,1.13416,0.69659,0.58324,0.63895,0.47698,0.94689,0.93421,1.1107,0.0057595,0.0057595,0.0057595 +259,54359,1.01955,1.16072,1.13058,0.69716,0.5833,0.63924,0.47722,0.94681,0.93398,1.11061,0.005743,0.005743,0.005743 +260,54616,1.02116,1.16267,1.13166,0.69646,0.58402,0.63949,0.47724,0.9469,0.93346,1.11055,0.0057265,0.0057265,0.0057265 +261,54874.1,1.02314,1.15949,1.13088,0.69587,0.58415,0.63973,0.47752,0.94666,0.93336,1.11045,0.00571,0.00571,0.00571 +262,55131.5,1.0213,1.15736,1.12649,0.7015,0.58282,0.6401,0.47786,0.94653,0.933,1.11037,0.0056935,0.0056935,0.0056935 +263,55388.7,1.02252,1.15897,1.12877,0.70285,0.58126,0.64038,0.47805,0.94647,0.93267,1.11025,0.005677,0.005677,0.005677 +264,55646.2,1.01828,1.15236,1.12981,0.70581,0.58007,0.64046,0.47821,0.94634,0.93235,1.11026,0.0056605,0.0056605,0.0056605 +265,55903.3,1.01643,1.15078,1.12661,0.70541,0.57958,0.64053,0.47812,0.94622,0.93202,1.11001,0.005644,0.005644,0.005644 +266,56160.2,1.01755,1.14788,1.12664,0.70406,0.57851,0.64046,0.47821,0.94594,0.93178,1.10978,0.0056275,0.0056275,0.0056275 +267,56417.3,1.02123,1.14971,1.12707,0.70363,0.57897,0.64071,0.47847,0.94581,0.9315,1.10949,0.005611,0.005611,0.005611 +268,56675.3,1.0198,1.15368,1.12652,0.7024,0.58081,0.64109,0.47872,0.94565,0.93116,1.10945,0.0055945,0.0055945,0.0055945 +269,56933.5,1.02343,1.15412,1.12799,0.70698,0.57795,0.64132,0.4789,0.94534,0.93071,1.10919,0.005578,0.005578,0.005578 +270,57191.4,1.01161,1.14851,1.1255,0.69306,0.57946,0.64157,0.47945,0.94506,0.92992,1.10876,0.0055615,0.0055615,0.0055615 +271,57448.3,1.01517,1.15701,1.12847,0.69409,0.5809,0.64193,0.47946,0.94481,0.92997,1.10862,0.005545,0.005545,0.005545 +272,57705.8,1.01652,1.15602,1.12925,0.69351,0.58153,0.64192,0.47973,0.9446,0.92946,1.10851,0.0055285,0.0055285,0.0055285 +273,57963.3,1.01967,1.15592,1.12828,0.69108,0.58397,0.64194,0.47983,0.94434,0.92893,1.10828,0.005512,0.005512,0.005512 +274,58219.8,1.01467,1.14315,1.12502,0.69502,0.58103,0.64207,0.47997,0.94449,0.92875,1.10828,0.0054955,0.0054955,0.0054955 +275,58476.5,1.01137,1.14126,1.12303,0.69292,0.58261,0.64268,0.48017,0.9446,0.92823,1.10832,0.005479,0.005479,0.005479 +276,58733.9,1.02272,1.15129,1.12751,0.69922,0.5795,0.6427,0.48026,0.9447,0.92786,1.10846,0.0054625,0.0054625,0.0054625 +277,58990.8,1.014,1.14344,1.12366,0.69329,0.58395,0.64271,0.48013,0.94426,0.92788,1.1081,0.005446,0.005446,0.005446 +278,59248,1.0189,1.15275,1.12812,0.69674,0.58171,0.64276,0.48019,0.94405,0.92744,1.10789,0.0054295,0.0054295,0.0054295 +279,59505.6,1.01653,1.14355,1.12356,0.6945,0.5825,0.64276,0.48014,0.9439,0.9274,1.10774,0.005413,0.005413,0.005413 +280,59762.7,1.0143,1.1504,1.12369,0.69903,0.58146,0.64315,0.48039,0.94394,0.9271,1.10765,0.0053965,0.0053965,0.0053965 +281,60019.7,1.01701,1.14489,1.1258,0.69812,0.58204,0.64322,0.48063,0.94391,0.92696,1.10754,0.00538,0.00538,0.00538 +282,60276.2,1.01355,1.14692,1.12635,0.69828,0.58282,0.64339,0.48085,0.94362,0.92659,1.10733,0.0053635,0.0053635,0.0053635 +283,60533.4,1.0107,1.14072,1.12222,0.69746,0.58404,0.64351,0.48079,0.9435,0.92613,1.10724,0.005347,0.005347,0.005347 +284,60790.6,1.01536,1.14441,1.12461,0.69555,0.58526,0.64383,0.48114,0.94332,0.92599,1.10694,0.0053305,0.0053305,0.0053305 
+285,61048.1,1.01233,1.14142,1.12486,0.69704,0.58409,0.64411,0.48129,0.94317,0.92553,1.10689,0.005314,0.005314,0.005314 +286,61305.5,1.01227,1.14065,1.1231,0.69845,0.58428,0.64426,0.48159,0.94301,0.92518,1.10679,0.0052975,0.0052975,0.0052975 +287,61562.2,1.01211,1.13638,1.12281,0.69759,0.58519,0.64425,0.48182,0.94267,0.92445,1.10672,0.005281,0.005281,0.005281 +288,61819,1.01002,1.14313,1.12515,0.69642,0.58463,0.64445,0.48199,0.9424,0.92406,1.10655,0.0052645,0.0052645,0.0052645 +289,62075.8,1.01112,1.13418,1.12115,0.69672,0.58519,0.64472,0.48212,0.9421,0.92369,1.10635,0.005248,0.005248,0.005248 +290,62332.5,1.01436,1.14915,1.12733,0.69625,0.58573,0.64465,0.48215,0.9422,0.92317,1.10635,0.0052315,0.0052315,0.0052315 +291,62589.5,1.01248,1.14213,1.12484,0.69541,0.58606,0.64483,0.48226,0.94189,0.92291,1.10613,0.005215,0.005215,0.005215 +292,62846.6,1.01065,1.13621,1.12403,0.6961,0.58545,0.64514,0.4826,0.94175,0.92238,1.10584,0.0051985,0.0051985,0.0051985 +293,63102.9,1.00763,1.13532,1.12049,0.69522,0.58583,0.64525,0.48249,0.94173,0.92223,1.10577,0.005182,0.005182,0.005182 +294,63359.2,1.01014,1.13671,1.12405,0.69559,0.58579,0.6453,0.4825,0.94181,0.9221,1.10572,0.0051655,0.0051655,0.0051655 +295,63616.4,1.00812,1.12839,1.12109,0.69556,0.58671,0.6455,0.48272,0.94174,0.92201,1.10559,0.005149,0.005149,0.005149 +296,63873.8,1.01017,1.13382,1.12266,0.69687,0.58589,0.64572,0.48275,0.94168,0.92134,1.10557,0.0051325,0.0051325,0.0051325 +297,64131.1,1.01262,1.13637,1.12069,0.69505,0.58612,0.64568,0.48265,0.94141,0.92071,1.10522,0.005116,0.005116,0.005116 +298,64388.3,1.01188,1.13341,1.1212,0.69632,0.58544,0.64589,0.48298,0.94107,0.92026,1.10486,0.0050995,0.0050995,0.0050995 +299,64644.7,1.01435,1.12334,1.12064,0.69737,0.58521,0.64653,0.48324,0.94125,0.91999,1.105,0.005083,0.005083,0.005083 +300,64901.2,1.00952,1.13871,1.12287,0.69785,0.58737,0.64673,0.4835,0.94114,0.91971,1.10499,0.0050665,0.0050665,0.0050665 +301,65158.3,1.0053,1.12192,1.12091,0.69408,0.58954,0.64652,0.48373,0.94079,0.91926,1.10461,0.00505,0.00505,0.00505 +302,65415.3,1.00744,1.12303,1.1189,0.69628,0.58871,0.64673,0.48397,0.9405,0.91901,1.10441,0.0050335,0.0050335,0.0050335 +303,65671.8,1.00506,1.12238,1.11869,0.69406,0.58999,0.64713,0.48417,0.94025,0.91863,1.10406,0.005017,0.005017,0.005017 +304,65928.9,1.01309,1.12992,1.12311,0.69783,0.5883,0.64742,0.4843,0.94021,0.9183,1.10407,0.0050005,0.0050005,0.0050005 +305,66185.5,1.00669,1.12352,1.11999,0.70231,0.58328,0.64722,0.48424,0.9398,0.91812,1.10364,0.004984,0.004984,0.004984 +306,66441.9,1.00515,1.11662,1.11599,0.70395,0.58322,0.64735,0.48451,0.93959,0.9178,1.10336,0.0049675,0.0049675,0.0049675 +307,66699.3,1.00284,1.11842,1.11867,0.70817,0.58101,0.64752,0.48455,0.93953,0.91751,1.10336,0.004951,0.004951,0.004951 +308,66956.9,1.00434,1.12538,1.1188,0.71151,0.58132,0.64762,0.48492,0.93949,0.9173,1.10333,0.0049345,0.0049345,0.0049345 +309,67214,1.00632,1.12348,1.1166,0.70519,0.58401,0.64767,0.48496,0.93945,0.91696,1.10323,0.004918,0.004918,0.004918 +310,67471.4,1.00342,1.12264,1.11719,0.71019,0.58233,0.64799,0.4852,0.93913,0.91656,1.10301,0.0049015,0.0049015,0.0049015 +311,67729.5,1.00294,1.11964,1.11821,0.71108,0.58199,0.64813,0.48531,0.93911,0.91629,1.10284,0.004885,0.004885,0.004885 +312,67986.7,1.00842,1.1204,1.11674,0.70781,0.58382,0.64854,0.48554,0.93921,0.91595,1.10289,0.0048685,0.0048685,0.0048685 +313,68242.8,1.00495,1.12109,1.11949,0.70793,0.58423,0.64866,0.48562,0.93944,0.91571,1.10288,0.004852,0.004852,0.004852 
+314,68499.5,1.00811,1.1218,1.11953,0.70577,0.58555,0.64877,0.4858,0.93926,0.9153,1.10273,0.0048355,0.0048355,0.0048355 +315,68756.8,0.99966,1.11721,1.11574,0.70511,0.58667,0.64885,0.48607,0.93901,0.91492,1.10242,0.004819,0.004819,0.004819 +316,69013.9,1.00315,1.12138,1.11662,0.70408,0.58669,0.64912,0.48604,0.93878,0.91453,1.10222,0.0048025,0.0048025,0.0048025 +317,69270.9,1.00645,1.12441,1.11939,0.70418,0.58727,0.64919,0.48628,0.93874,0.91413,1.10217,0.004786,0.004786,0.004786 +318,69528.4,0.99974,1.11851,1.11815,0.69613,0.59236,0.64946,0.48627,0.93856,0.91362,1.1019,0.0047695,0.0047695,0.0047695 +319,69785.8,1.00332,1.11867,1.11787,0.69632,0.59255,0.64967,0.48646,0.93843,0.9133,1.10172,0.004753,0.004753,0.004753 +320,70043.9,0.99846,1.10819,1.11393,0.69667,0.59152,0.64987,0.48653,0.93809,0.91284,1.10153,0.0047365,0.0047365,0.0047365 +321,70301.6,1.00586,1.11272,1.1194,0.70023,0.58805,0.64977,0.48646,0.93801,0.91249,1.1014,0.00472,0.00472,0.00472 +322,70558.6,0.99458,1.09721,1.11304,0.70241,0.58697,0.64988,0.48659,0.93791,0.91212,1.10136,0.0047035,0.0047035,0.0047035 +323,70816,1.00297,1.11286,1.11565,0.70431,0.58642,0.64992,0.48673,0.93769,0.91192,1.10113,0.004687,0.004687,0.004687 +324,71073.3,0.99815,1.11585,1.11503,0.70336,0.58737,0.65039,0.48708,0.93741,0.91142,1.10105,0.0046705,0.0046705,0.0046705 +325,71331,1.00018,1.10984,1.11495,0.69812,0.59122,0.64999,0.48658,0.93746,0.91125,1.10106,0.004654,0.004654,0.004654 +326,71587.8,0.99731,1.11526,1.11835,0.69639,0.5937,0.65043,0.48719,0.93724,0.91107,1.10089,0.0046375,0.0046375,0.0046375 +327,71845.5,0.99816,1.11848,1.11635,0.69433,0.59524,0.65095,0.48747,0.9372,0.91089,1.10085,0.004621,0.004621,0.004621 +328,72103,0.9951,1.10802,1.11229,0.69466,0.5951,0.65102,0.48761,0.93732,0.91054,1.10076,0.0046045,0.0046045,0.0046045 +329,72360.4,0.99405,1.10014,1.11085,0.69589,0.59432,0.65103,0.48773,0.93714,0.91014,1.10055,0.004588,0.004588,0.004588 +330,72617.4,0.99688,1.09913,1.11223,0.6935,0.59608,0.65128,0.48783,0.93715,0.90948,1.10054,0.0045715,0.0045715,0.0045715 +331,72875,1.00082,1.11216,1.11689,0.69046,0.59709,0.65165,0.48813,0.93702,0.90912,1.10016,0.004555,0.004555,0.004555 +332,73133.1,0.99613,1.11912,1.11634,0.69067,0.59744,0.6518,0.48815,0.93658,0.90887,1.09983,0.0045385,0.0045385,0.0045385 +333,73390.1,0.99908,1.09836,1.11173,0.69017,0.59805,0.65196,0.48821,0.93644,0.9086,1.09975,0.004522,0.004522,0.004522 +334,73647.7,0.99431,1.1058,1.11384,0.68955,0.59901,0.65226,0.48844,0.93617,0.90837,1.09942,0.0045055,0.0045055,0.0045055 +335,73905.6,0.99221,1.10108,1.11289,0.68822,0.60037,0.65294,0.48915,0.93643,0.90802,1.09957,0.004489,0.004489,0.004489 +336,74162.1,0.99556,1.09554,1.11073,0.69126,0.59937,0.65293,0.48898,0.9362,0.9076,1.09922,0.0044725,0.0044725,0.0044725 +337,74419,0.99265,1.10164,1.11443,0.68828,0.60126,0.65295,0.48897,0.93616,0.90709,1.09907,0.004456,0.004456,0.004456 +338,74676.5,0.9941,1.1031,1.11513,0.6901,0.60021,0.65254,0.48913,0.9358,0.90678,1.09879,0.0044395,0.0044395,0.0044395 +339,74934.1,0.99249,1.09775,1.10985,0.69045,0.60052,0.65282,0.48919,0.93569,0.9064,1.09865,0.004423,0.004423,0.004423 +340,75191.4,0.99467,1.09673,1.11086,0.69166,0.6003,0.65318,0.48942,0.93512,0.90604,1.09836,0.0044065,0.0044065,0.0044065 +341,75448.3,0.99791,1.10253,1.11569,0.69068,0.60171,0.6542,0.48987,0.93507,0.90554,1.09839,0.00439,0.00439,0.00439 +342,75705.9,0.98869,1.10163,1.11137,0.69215,0.60101,0.65428,0.49012,0.93486,0.90529,1.09811,0.0043735,0.0043735,0.0043735 
+343,75963.3,0.99007,1.08805,1.11009,0.69577,0.59799,0.65428,0.49011,0.93481,0.90487,1.09795,0.004357,0.004357,0.004357 +344,76220.6,0.99538,1.09904,1.11138,0.69782,0.59693,0.65448,0.49023,0.93481,0.90477,1.09792,0.0043405,0.0043405,0.0043405 +345,76477.9,0.99514,1.10244,1.11529,0.69587,0.59992,0.65423,0.48999,0.9346,0.90444,1.09778,0.004324,0.004324,0.004324 +346,76735.6,0.98984,1.09551,1.10963,0.69522,0.60174,0.65425,0.49021,0.93454,0.90382,1.09778,0.0043075,0.0043075,0.0043075 +347,76993,0.99582,1.09959,1.11448,0.69807,0.59972,0.6543,0.49084,0.93438,0.90329,1.09768,0.004291,0.004291,0.004291 +348,77250.4,0.99223,1.09033,1.11021,0.69751,0.60086,0.65472,0.49075,0.9344,0.90307,1.09761,0.0042745,0.0042745,0.0042745 +349,77508.2,0.99375,1.09035,1.11041,0.69767,0.60171,0.65503,0.49099,0.9345,0.90262,1.09763,0.004258,0.004258,0.004258 +350,77765.1,0.99224,1.09089,1.10787,0.69798,0.60105,0.65524,0.49132,0.93421,0.90217,1.0974,0.0042415,0.0042415,0.0042415 +351,78022.1,0.99027,1.08999,1.10958,0.69751,0.60057,0.65519,0.49154,0.93379,0.90191,1.09695,0.004225,0.004225,0.004225 +352,78279.4,0.99116,1.09477,1.11179,0.69989,0.59987,0.65571,0.49195,0.93349,0.90127,1.09671,0.0042085,0.0042085,0.0042085 +353,78536.2,0.99337,1.09062,1.11039,0.70074,0.6003,0.65574,0.49186,0.93344,0.9012,1.09666,0.004192,0.004192,0.004192 +354,78793.2,0.9927,1.0865,1.10853,0.70274,0.59942,0.65622,0.49231,0.93315,0.90114,1.09637,0.0041755,0.0041755,0.0041755 +355,79050.7,0.99261,1.09566,1.10838,0.70218,0.59985,0.65601,0.49232,0.93279,0.90112,1.09604,0.004159,0.004159,0.004159 +356,79307.8,0.99344,1.09519,1.10818,0.70144,0.60109,0.65599,0.49241,0.93274,0.90093,1.09588,0.0041425,0.0041425,0.0041425 +357,79564.6,0.99237,1.08708,1.10936,0.70237,0.59986,0.65612,0.49279,0.93243,0.90079,1.09563,0.004126,0.004126,0.004126 +358,79821.4,0.98997,1.08485,1.10448,0.70394,0.59854,0.65683,0.49315,0.93226,0.90034,1.09547,0.0041095,0.0041095,0.0041095 +359,80078.7,0.98976,1.08651,1.10907,0.70388,0.59858,0.65715,0.49325,0.93223,0.90044,1.09526,0.004093,0.004093,0.004093 +360,80335.8,0.98688,1.08217,1.10574,0.70403,0.5983,0.65736,0.49321,0.93232,0.89982,1.09537,0.0040765,0.0040765,0.0040765 +361,80593.9,0.99012,1.08486,1.10981,0.70261,0.59918,0.65756,0.49345,0.93215,0.89938,1.09518,0.00406,0.00406,0.00406 +362,80851.8,0.99059,1.08681,1.10771,0.70112,0.60092,0.65791,0.49351,0.93181,0.89921,1.09488,0.0040435,0.0040435,0.0040435 +363,81109.4,0.98623,1.08123,1.10672,0.69949,0.6027,0.65809,0.49364,0.93167,0.89893,1.0948,0.004027,0.004027,0.004027 +364,81367.1,0.98581,1.08473,1.10841,0.69729,0.60327,0.65781,0.49374,0.93167,0.89869,1.09465,0.0040105,0.0040105,0.0040105 +365,81624.2,0.98736,1.0773,1.10895,0.69843,0.60386,0.65841,0.49394,0.93158,0.89852,1.0945,0.003994,0.003994,0.003994 +366,81881.4,0.98393,1.07906,1.10413,0.69982,0.60286,0.65856,0.49407,0.93145,0.89806,1.09445,0.0039775,0.0039775,0.0039775 +367,82138.5,0.98352,1.08174,1.10693,0.69909,0.60308,0.65858,0.49412,0.93138,0.89752,1.09427,0.003961,0.003961,0.003961 +368,82396.3,0.98587,1.08275,1.1083,0.69808,0.60363,0.65849,0.49409,0.93107,0.89704,1.09411,0.0039445,0.0039445,0.0039445 +369,82653.8,0.98363,1.06925,1.10065,0.69768,0.60362,0.65852,0.49435,0.93128,0.89691,1.09409,0.003928,0.003928,0.003928 +370,82911.3,0.98579,1.0791,1.10739,0.69765,0.60402,0.65872,0.49442,0.93123,0.8968,1.09391,0.0039115,0.0039115,0.0039115 +371,83168.8,0.98765,1.07564,1.10705,0.69907,0.60398,0.65864,0.49444,0.9313,0.89666,1.09368,0.003895,0.003895,0.003895 
+372,83425.7,0.98537,1.07874,1.10365,0.70192,0.60145,0.65885,0.4947,0.93113,0.89646,1.09343,0.0038785,0.0038785,0.0038785 +373,83683.7,0.98244,1.07385,1.10442,0.7016,0.60184,0.65882,0.49445,0.931,0.89594,1.09326,0.003862,0.003862,0.003862 +374,83941.5,0.98379,1.0698,1.10541,0.70411,0.60059,0.65899,0.49471,0.93089,0.89575,1.09308,0.0038455,0.0038455,0.0038455 +375,84198.8,0.98559,1.07222,1.10424,0.70071,0.60529,0.65922,0.49513,0.93089,0.89524,1.09301,0.003829,0.003829,0.003829 +376,84455.6,0.98021,1.06908,1.10247,0.70383,0.60281,0.65941,0.49501,0.93093,0.89512,1.09302,0.0038125,0.0038125,0.0038125 +377,84713.2,0.97858,1.07256,1.10334,0.70056,0.60599,0.65966,0.495,0.93082,0.89491,1.09273,0.003796,0.003796,0.003796 +378,84971.5,0.98158,1.06743,1.10284,0.70266,0.60485,0.66007,0.49518,0.9307,0.89489,1.09259,0.0037795,0.0037795,0.0037795 +379,85228.6,0.97761,1.07104,1.10269,0.70266,0.60515,0.66015,0.4953,0.93034,0.89473,1.09233,0.003763,0.003763,0.003763 +380,85486.6,0.98881,1.06899,1.10623,0.7008,0.60499,0.66014,0.49532,0.93015,0.89476,1.09222,0.0037465,0.0037465,0.0037465 +381,85743.8,0.97874,1.06578,1.10271,0.70128,0.60498,0.65985,0.49518,0.92963,0.89458,1.09197,0.00373,0.00373,0.00373 +382,86001.2,0.98049,1.06515,1.10277,0.70461,0.60306,0.66019,0.49543,0.92924,0.89409,1.09173,0.0037135,0.0037135,0.0037135 +383,86258.3,0.98117,1.06583,1.10364,0.70476,0.60415,0.66011,0.49523,0.92873,0.89371,1.09148,0.003697,0.003697,0.003697 +384,86515.5,0.98277,1.07432,1.10427,0.70685,0.60268,0.66033,0.49564,0.9282,0.89353,1.09096,0.0036805,0.0036805,0.0036805 +385,86773.1,0.98329,1.06436,1.1008,0.70002,0.60701,0.66081,0.49597,0.92769,0.89313,1.09043,0.003664,0.003664,0.003664 +386,87030.5,0.97815,1.05933,1.10131,0.70019,0.60648,0.66066,0.49577,0.92767,0.89253,1.09029,0.0036475,0.0036475,0.0036475 +387,87288.1,0.97652,1.05696,1.1005,0.69931,0.60621,0.66099,0.4961,0.92739,0.89269,1.09004,0.003631,0.003631,0.003631 +388,87545.6,0.97624,1.05745,1.10126,0.69833,0.6074,0.66082,0.49619,0.92731,0.89226,1.09006,0.0036145,0.0036145,0.0036145 +389,87803.3,0.97987,1.05993,1.10205,0.6997,0.60801,0.66137,0.49628,0.92733,0.89195,1.09001,0.003598,0.003598,0.003598 +390,88060.6,0.97295,1.06486,1.10096,0.69956,0.60905,0.66128,0.4964,0.9272,0.89185,1.09002,0.0035815,0.0035815,0.0035815 +391,88317.4,0.9731,1.05616,1.09978,0.70015,0.60887,0.66128,0.49643,0.92725,0.89149,1.08991,0.003565,0.003565,0.003565 +392,88574.8,0.97706,1.05901,1.09874,0.70126,0.60699,0.6616,0.49671,0.92701,0.89125,1.08975,0.0035485,0.0035485,0.0035485 +393,88832.1,0.97554,1.05838,1.10017,0.70121,0.60917,0.6621,0.49719,0.92678,0.89075,1.08952,0.003532,0.003532,0.003532 +394,89089.5,0.97828,1.05791,1.10101,0.70318,0.60801,0.66217,0.49727,0.92664,0.89022,1.08927,0.0035155,0.0035155,0.0035155 +395,89347.2,0.97938,1.05673,1.09884,0.70281,0.60927,0.66206,0.49737,0.92663,0.88981,1.08905,0.003499,0.003499,0.003499 +396,89604.7,0.97455,1.05243,1.09963,0.70187,0.60872,0.662,0.49737,0.92626,0.88943,1.0888,0.0034825,0.0034825,0.0034825 +397,89861.9,0.96985,1.05872,1.09772,0.70385,0.60762,0.66217,0.49759,0.92624,0.8891,1.08871,0.003466,0.003466,0.003466 +398,90120.1,0.97893,1.05401,1.09769,0.70375,0.6079,0.66178,0.49734,0.92622,0.88869,1.08861,0.0034495,0.0034495,0.0034495 +399,90377.9,0.973,1.0441,1.09617,0.70138,0.60924,0.66171,0.49725,0.92603,0.88833,1.0885,0.003433,0.003433,0.003433 +400,90635.7,0.97687,1.05337,1.09994,0.70346,0.60836,0.66254,0.4979,0.926,0.88793,1.08837,0.0034165,0.0034165,0.0034165 
+401,90893.6,0.97061,1.04453,1.09623,0.70484,0.6078,0.66299,0.49807,0.92568,0.88712,1.08807,0.0034,0.0034,0.0034 +402,91151.2,0.97331,1.05356,1.09858,0.7062,0.6072,0.66292,0.49819,0.92547,0.88676,1.08795,0.0033835,0.0033835,0.0033835 +403,91408.4,0.97607,1.04658,1.10053,0.70683,0.60651,0.66311,0.49855,0.92534,0.8864,1.08806,0.003367,0.003367,0.003367 +404,91665.6,0.97246,1.04915,1.09657,0.70829,0.6057,0.66316,0.49842,0.92512,0.88601,1.08789,0.0033505,0.0033505,0.0033505 +405,91923.2,0.9727,1.05196,1.09746,0.70599,0.60762,0.66381,0.49878,0.92497,0.88557,1.08782,0.003334,0.003334,0.003334 +406,92180.4,0.97138,1.04222,1.0988,0.70693,0.60798,0.66415,0.49858,0.92474,0.88541,1.08748,0.0033175,0.0033175,0.0033175 +407,92438.5,0.96875,1.04255,1.09644,0.70871,0.60653,0.66381,0.49856,0.9248,0.88507,1.08749,0.003301,0.003301,0.003301 +408,92695.9,0.97127,1.04304,1.09481,0.7083,0.60654,0.66356,0.49856,0.92454,0.88484,1.08736,0.0032845,0.0032845,0.0032845 +409,92953.3,0.96878,1.04099,1.09332,0.71196,0.60445,0.66398,0.49901,0.92429,0.88441,1.08718,0.003268,0.003268,0.003268 +410,93210.5,0.96919,1.03942,1.09314,0.70763,0.60763,0.66421,0.49895,0.92414,0.88428,1.08711,0.0032515,0.0032515,0.0032515 +411,93467.7,0.9743,1.03858,1.09535,0.70732,0.60868,0.66403,0.49888,0.92388,0.88405,1.0869,0.003235,0.003235,0.003235 +412,93724.8,0.97215,1.04685,1.09635,0.70705,0.60835,0.66402,0.49894,0.92375,0.88342,1.08675,0.0032185,0.0032185,0.0032185 +413,93981.7,0.97781,1.04749,1.09854,0.70552,0.61006,0.66442,0.49903,0.92364,0.88324,1.0866,0.003202,0.003202,0.003202 +414,94239,0.96985,1.04336,1.09667,0.70699,0.6087,0.66426,0.49894,0.92326,0.88297,1.0863,0.0031855,0.0031855,0.0031855 +415,94496.3,0.96852,1.04391,1.09666,0.70734,0.6096,0.66401,0.49899,0.92318,0.8828,1.08625,0.003169,0.003169,0.003169 +416,94753.2,0.96821,1.03099,1.09425,0.70349,0.61176,0.66424,0.49916,0.92318,0.88233,1.08625,0.0031525,0.0031525,0.0031525 +417,95010.9,0.96457,1.03452,1.09384,0.70117,0.61328,0.66427,0.49934,0.92334,0.88206,1.08627,0.003136,0.003136,0.003136 +418,95268.6,0.96508,1.04088,1.09618,0.70651,0.60969,0.6645,0.49993,0.92305,0.88156,1.08605,0.0031195,0.0031195,0.0031195 +419,95525.4,0.9639,1.03888,1.09367,0.70777,0.60825,0.66459,0.4998,0.92296,0.88125,1.08592,0.003103,0.003103,0.003103 +420,95783,0.96912,1.0338,1.09281,0.7096,0.60797,0.66477,0.49975,0.92304,0.88081,1.08579,0.0030865,0.0030865,0.0030865 +421,96040.3,0.96572,1.03852,1.09566,0.71062,0.60693,0.66548,0.50046,0.92278,0.88035,1.08553,0.00307,0.00307,0.00307 +422,96297.5,0.96718,1.03741,1.09617,0.7133,0.60615,0.66596,0.50052,0.92278,0.88008,1.0854,0.0030535,0.0030535,0.0030535 +423,96554.8,0.96103,1.02685,1.09164,0.71161,0.60714,0.66599,0.50078,0.92268,0.87965,1.08529,0.003037,0.003037,0.003037 +424,96811.7,0.96469,1.03793,1.09409,0.71145,0.60746,0.66576,0.50043,0.92244,0.87921,1.0851,0.0030205,0.0030205,0.0030205 +425,97069.3,0.96832,1.02947,1.09413,0.7103,0.60852,0.66611,0.50063,0.92215,0.87875,1.08495,0.003004,0.003004,0.003004 +426,97327.5,0.97071,1.03172,1.09455,0.70989,0.60854,0.66608,0.50073,0.92195,0.87823,1.08489,0.0029875,0.0029875,0.0029875 +427,97585.1,0.96633,1.02991,1.09132,0.71604,0.60553,0.66588,0.50056,0.92171,0.87781,1.08472,0.002971,0.002971,0.002971 +428,97842.1,0.96544,1.02611,1.09304,0.71595,0.60626,0.66615,0.50107,0.92165,0.87761,1.08462,0.0029545,0.0029545,0.0029545 +429,98099.7,0.95812,1.02587,1.09367,0.71337,0.60816,0.66658,0.50119,0.92172,0.87748,1.08451,0.002938,0.002938,0.002938 
+430,98357.2,0.96119,1.03029,1.09151,0.7085,0.6108,0.66664,0.50137,0.92167,0.87711,1.08437,0.0029215,0.0029215,0.0029215 +431,98614.3,0.95863,1.0233,1.09135,0.71034,0.61042,0.66683,0.50137,0.92142,0.87664,1.08402,0.002905,0.002905,0.002905 +432,98872.2,0.95444,1.02499,1.08987,0.71034,0.61064,0.66702,0.50153,0.92138,0.87626,1.08406,0.0028885,0.0028885,0.0028885 +433,99130.5,0.96417,1.03265,1.09291,0.71241,0.6103,0.66716,0.50172,0.92145,0.87601,1.08404,0.002872,0.002872,0.002872 +434,99387.8,0.96653,1.02463,1.09045,0.71427,0.60942,0.6673,0.50185,0.9212,0.87589,1.08376,0.0028555,0.0028555,0.0028555 +435,99645.3,0.96147,1.0233,1.09014,0.714,0.60893,0.66735,0.50203,0.92137,0.87594,1.08382,0.002839,0.002839,0.002839 +436,99902.7,0.9561,1.02119,1.08908,0.71217,0.61014,0.66746,0.50219,0.92129,0.87571,1.0838,0.0028225,0.0028225,0.0028225 +437,100160,0.95888,1.01687,1.08932,0.71377,0.60826,0.66748,0.50227,0.92107,0.87554,1.08365,0.002806,0.002806,0.002806 +438,100417,0.95473,1.01564,1.08979,0.71339,0.60893,0.66796,0.50265,0.92083,0.87504,1.08339,0.0027895,0.0027895,0.0027895 +439,100675,0.95671,1.02052,1.08878,0.71125,0.61048,0.66837,0.50262,0.92063,0.8749,1.08321,0.002773,0.002773,0.002773 +440,100932,0.9544,1.01121,1.08861,0.71094,0.6131,0.66859,0.50299,0.92049,0.87455,1.08309,0.0027565,0.0027565,0.0027565 +441,101189,0.95546,1.01752,1.0893,0.71577,0.61026,0.66851,0.50305,0.92041,0.87432,1.08312,0.00274,0.00274,0.00274 +442,101447,0.9625,1.01565,1.09044,0.71479,0.61273,0.66884,0.50326,0.92018,0.87422,1.08294,0.0027235,0.0027235,0.0027235 +443,101704,0.95547,1.00421,1.08687,0.71759,0.6107,0.66867,0.50341,0.92002,0.8741,1.08284,0.002707,0.002707,0.002707 +444,101962,0.96214,1.01319,1.08969,0.71713,0.6115,0.6688,0.50361,0.91996,0.8741,1.08275,0.0026905,0.0026905,0.0026905 +445,102220,0.96046,1.01023,1.09009,0.71706,0.61181,0.66907,0.5038,0.91974,0.87381,1.08257,0.002674,0.002674,0.002674 +446,102477,0.95222,1.01034,1.08915,0.71796,0.61151,0.66911,0.50395,0.9197,0.87358,1.08248,0.0026575,0.0026575,0.0026575 +447,102734,0.95013,1.01143,1.08653,0.7183,0.61147,0.6696,0.5043,0.91964,0.8736,1.08258,0.002641,0.002641,0.002641 +448,102992,0.95626,1.00541,1.09109,0.71529,0.61387,0.66968,0.5043,0.91956,0.87342,1.08256,0.0026245,0.0026245,0.0026245 +449,103249,0.95049,1.00689,1.08606,0.71556,0.61287,0.66984,0.5046,0.9193,0.87309,1.08242,0.002608,0.002608,0.002608 +450,103507,0.95071,1.00477,1.08563,0.71307,0.61405,0.6703,0.50497,0.91896,0.87276,1.08221,0.0025915,0.0025915,0.0025915 +451,103765,0.95554,1.00386,1.086,0.7131,0.61358,0.67037,0.50498,0.919,0.87233,1.08216,0.002575,0.002575,0.002575 +452,104022,0.95586,1.00556,1.08802,0.71044,0.61377,0.67053,0.50524,0.91879,0.87204,1.08187,0.0025585,0.0025585,0.0025585 +453,104279,0.95406,0.99986,1.08564,0.71204,0.6134,0.67069,0.5054,0.91908,0.87178,1.08203,0.002542,0.002542,0.002542 +454,104536,0.94915,1.00386,1.08636,0.7097,0.61465,0.67052,0.50516,0.91909,0.87141,1.08201,0.0025255,0.0025255,0.0025255 +455,104793,0.9511,1.0017,1.08407,0.70555,0.61656,0.67083,0.50558,0.91893,0.87113,1.08189,0.002509,0.002509,0.002509 +456,105051,0.94622,0.99949,1.08209,0.70586,0.61732,0.6712,0.50555,0.9187,0.87074,1.08159,0.0024925,0.0024925,0.0024925 +457,105309,0.95291,0.9985,1.08153,0.70904,0.61592,0.67145,0.50572,0.91852,0.86983,1.08157,0.002476,0.002476,0.002476 +458,105566,0.9485,0.99861,1.08615,0.71225,0.61634,0.67148,0.5056,0.9184,0.86945,1.08144,0.0024595,0.0024595,0.0024595 
+459,105823,0.95054,1.00439,1.08419,0.7106,0.61785,0.67176,0.50608,0.9183,0.86897,1.08148,0.002443,0.002443,0.002443 +460,106080,0.94693,0.987,1.08186,0.71134,0.61952,0.67165,0.50618,0.91828,0.86866,1.0814,0.0024265,0.0024265,0.0024265 +461,106338,0.94757,0.98883,1.08043,0.71157,0.61923,0.67189,0.5061,0.91828,0.8683,1.0812,0.00241,0.00241,0.00241 +462,106595,0.94616,0.99424,1.0821,0.70982,0.62042,0.67211,0.50639,0.91829,0.86793,1.08116,0.0023935,0.0023935,0.0023935 +463,106853,0.94409,0.99341,1.08244,0.71272,0.62086,0.67249,0.50689,0.91816,0.86745,1.08096,0.002377,0.002377,0.002377 +464,107111,0.93839,0.982,1.07787,0.71355,0.62084,0.67277,0.50705,0.91805,0.86701,1.08088,0.0023605,0.0023605,0.0023605 +465,107368,0.94902,0.99716,1.08827,0.71724,0.61881,0.67323,0.50732,0.91772,0.86675,1.08074,0.002344,0.002344,0.002344 +466,107625,0.94835,0.99247,1.08282,0.71743,0.61844,0.67318,0.50735,0.91746,0.86624,1.08057,0.0023275,0.0023275,0.0023275 +467,107883,0.94631,0.9865,1.08048,0.71653,0.61916,0.67323,0.50746,0.91742,0.86577,1.08027,0.002311,0.002311,0.002311 +468,108140,0.94297,0.98774,1.07895,0.71465,0.6194,0.67355,0.50784,0.91734,0.86513,1.08003,0.0022945,0.0022945,0.0022945 +469,108397,0.94598,0.9918,1.08216,0.71539,0.61961,0.67386,0.50798,0.91698,0.86464,1.07981,0.002278,0.002278,0.002278 +470,108655,0.94754,0.99042,1.08322,0.71166,0.62092,0.67363,0.50794,0.9169,0.86426,1.07963,0.0022615,0.0022615,0.0022615 +471,108911,0.95015,0.99509,1.082,0.71138,0.62094,0.67398,0.50824,0.91686,0.86404,1.07951,0.002245,0.002245,0.002245 +472,109168,0.94358,0.98311,1.07776,0.71129,0.62183,0.67393,0.50829,0.91677,0.86369,1.07941,0.0022285,0.0022285,0.0022285 +473,109425,0.94033,0.9804,1.07567,0.71064,0.62311,0.67382,0.5083,0.9164,0.86333,1.07916,0.002212,0.002212,0.002212 +474,109683,0.94106,0.98576,1.07989,0.71045,0.62373,0.6742,0.50848,0.91588,0.86266,1.0788,0.0021955,0.0021955,0.0021955 +475,109939,0.94218,0.97965,1.07733,0.71135,0.62258,0.6742,0.50876,0.91553,0.86248,1.07849,0.002179,0.002179,0.002179 +476,110196,0.93846,0.97964,1.07909,0.71077,0.62384,0.67477,0.50882,0.91546,0.86193,1.07827,0.0021625,0.0021625,0.0021625 +477,110453,0.94311,0.98171,1.0779,0.71363,0.62133,0.67499,0.50904,0.91512,0.86178,1.07794,0.002146,0.002146,0.002146 +478,110710,0.94416,0.9808,1.07763,0.71537,0.62124,0.67529,0.5091,0.91498,0.86143,1.07782,0.0021295,0.0021295,0.0021295 +479,110968,0.93731,0.97156,1.07755,0.71191,0.62073,0.67513,0.50909,0.91491,0.86101,1.07783,0.002113,0.002113,0.002113 +480,111226,0.93971,0.97778,1.07868,0.71101,0.6211,0.67537,0.50932,0.91467,0.86034,1.0777,0.0020965,0.0020965,0.0020965 +481,111483,0.93944,0.97182,1.0763,0.71243,0.62016,0.67578,0.50962,0.91455,0.8598,1.07749,0.00208,0.00208,0.00208 +482,111740,0.9417,0.97225,1.07744,0.71526,0.61919,0.67581,0.5094,0.91418,0.85943,1.07713,0.0020635,0.0020635,0.0020635 +483,111997,0.94261,0.97604,1.07806,0.71532,0.61871,0.67611,0.50957,0.91423,0.8592,1.077,0.002047,0.002047,0.002047 +484,112254,0.93871,0.97854,1.07583,0.71795,0.61687,0.67626,0.50992,0.91446,0.85869,1.07722,0.0020305,0.0020305,0.0020305 +485,112512,0.94334,0.96663,1.07553,0.71448,0.61903,0.67674,0.5101,0.91429,0.85821,1.07704,0.002014,0.002014,0.002014 +486,112769,0.93911,0.96989,1.07554,0.71395,0.6202,0.67713,0.51045,0.91405,0.85778,1.0766,0.0019975,0.0019975,0.0019975 +487,113026,0.93771,0.9626,1.07543,0.71264,0.62099,0.67756,0.51059,0.91396,0.85744,1.07633,0.001981,0.001981,0.001981 
+488,113283,0.93306,0.96662,1.07597,0.70738,0.625,0.6779,0.51121,0.91397,0.85723,1.07625,0.0019645,0.0019645,0.0019645 +489,113541,0.9341,0.96686,1.07509,0.71204,0.62237,0.67812,0.51163,0.91369,0.85687,1.07605,0.001948,0.001948,0.001948 +490,113798,0.93786,0.96748,1.07654,0.71461,0.62035,0.67812,0.51192,0.91331,0.85655,1.07581,0.0019315,0.0019315,0.0019315 +491,114055,0.93467,0.96178,1.07176,0.71299,0.62199,0.6784,0.51204,0.91305,0.85605,1.07556,0.001915,0.001915,0.001915 +492,114313,0.93207,0.96193,1.07231,0.71234,0.62268,0.67844,0.51203,0.91258,0.85586,1.07526,0.0018985,0.0018985,0.0018985 +493,114571,0.93685,0.95718,1.07437,0.70748,0.62634,0.67866,0.51243,0.91247,0.85547,1.07512,0.001882,0.001882,0.001882 +494,114827,0.9319,0.95829,1.07255,0.70951,0.62539,0.67861,0.51264,0.9124,0.85492,1.0751,0.0018655,0.0018655,0.0018655 +495,115085,0.93356,0.95956,1.07271,0.69991,0.6305,0.6789,0.51292,0.91235,0.85436,1.07495,0.001849,0.001849,0.001849 +496,115342,0.92597,0.95341,1.06956,0.7039,0.62968,0.67926,0.51291,0.91192,0.85427,1.07465,0.0018325,0.0018325,0.0018325 +497,115599,0.92624,0.95094,1.07157,0.71088,0.62478,0.67949,0.51285,0.91181,0.85365,1.07443,0.001816,0.001816,0.001816 +498,115856,0.92819,0.94869,1.06857,0.70999,0.62575,0.67948,0.51296,0.91191,0.85331,1.0745,0.0017995,0.0017995,0.0017995 +499,116113,0.9342,0.95382,1.07051,0.71342,0.62284,0.67972,0.51321,0.91187,0.85317,1.07446,0.001783,0.001783,0.001783 +500,116370,0.93616,0.95873,1.0749,0.7118,0.624,0.68001,0.51353,0.91179,0.85275,1.07436,0.0017665,0.0017665,0.0017665 +501,116628,0.92708,0.95565,1.07227,0.7114,0.62548,0.67988,0.51347,0.91168,0.85223,1.07422,0.00175,0.00175,0.00175 +502,116885,0.9262,0.94914,1.0689,0.70837,0.6283,0.67991,0.51409,0.91141,0.85181,1.07383,0.0017335,0.0017335,0.0017335 +503,117143,0.92108,0.94419,1.06762,0.70809,0.63142,0.68029,0.51434,0.9112,0.85141,1.07371,0.001717,0.001717,0.001717 +504,117400,0.93054,0.95612,1.07104,0.70563,0.63026,0.68022,0.51427,0.91106,0.85109,1.07369,0.0017005,0.0017005,0.0017005 +505,117658,0.91848,0.93882,1.06565,0.70749,0.62784,0.68032,0.51421,0.9109,0.85093,1.07355,0.001684,0.001684,0.001684 +506,117914,0.92403,0.94589,1.06928,0.70624,0.62847,0.68043,0.51434,0.91103,0.85055,1.07356,0.0016675,0.0016675,0.0016675 +507,118172,0.92617,0.94509,1.06871,0.70648,0.62946,0.6807,0.51453,0.91093,0.85008,1.07347,0.001651,0.001651,0.001651 +508,118429,0.92461,0.93929,1.06617,0.70935,0.62743,0.68064,0.51457,0.91061,0.84989,1.07312,0.0016345,0.0016345,0.0016345 +509,118686,0.92476,0.94694,1.06846,0.70763,0.62997,0.68098,0.51486,0.91046,0.84942,1.07292,0.001618,0.001618,0.001618 +510,118943,0.92525,0.93804,1.06792,0.7073,0.63055,0.68107,0.51468,0.91049,0.84897,1.07276,0.0016015,0.0016015,0.0016015 +511,119201,0.92481,0.93878,1.06881,0.70885,0.63064,0.68136,0.515,0.91051,0.8487,1.07279,0.001585,0.001585,0.001585 +512,119458,0.92371,0.93835,1.06916,0.71083,0.62959,0.68174,0.51507,0.91025,0.8485,1.07249,0.0015685,0.0015685,0.0015685 +513,119716,0.92696,0.934,1.06516,0.7099,0.63096,0.68186,0.51524,0.91027,0.84823,1.07254,0.001552,0.001552,0.001552 +514,119973,0.91917,0.93299,1.06796,0.70884,0.63,0.68194,0.51554,0.91006,0.84811,1.0723,0.0015355,0.0015355,0.0015355 +515,120231,0.92396,0.93839,1.0654,0.71074,0.62912,0.68217,0.51574,0.90997,0.84795,1.07221,0.001519,0.001519,0.001519 +516,120488,0.91766,0.92568,1.06288,0.70967,0.62959,0.68245,0.51572,0.91028,0.84767,1.07233,0.0015025,0.0015025,0.0015025 
+517,120745,0.92245,0.93112,1.06645,0.72283,0.62286,0.68305,0.51626,0.90987,0.84745,1.07197,0.001486,0.001486,0.001486 +518,121002,0.91845,0.93152,1.06471,0.72242,0.62287,0.68309,0.51631,0.90997,0.84705,1.07181,0.0014695,0.0014695,0.0014695 +519,121259,0.91958,0.92457,1.06365,0.72068,0.62363,0.68318,0.5164,0.90966,0.84679,1.07151,0.001453,0.001453,0.001453 +520,121516,0.91433,0.92312,1.06245,0.71031,0.62996,0.68357,0.51642,0.90949,0.84661,1.07131,0.0014365,0.0014365,0.0014365 +521,121774,0.91427,0.9236,1.06444,0.72692,0.61984,0.68357,0.51655,0.9094,0.84648,1.07135,0.00142,0.00142,0.00142 +522,122031,0.91206,0.91771,1.06242,0.72782,0.6218,0.68404,0.51682,0.90938,0.84609,1.07131,0.0014035,0.0014035,0.0014035 +523,122289,0.91806,0.92441,1.06435,0.7245,0.62318,0.6843,0.51712,0.90937,0.84597,1.07126,0.001387,0.001387,0.001387 +524,122547,0.91673,0.91825,1.06447,0.71935,0.62724,0.68405,0.51705,0.90908,0.84586,1.07107,0.0013705,0.0013705,0.0013705 +525,122804,0.9129,0.91178,1.0628,0.72097,0.62622,0.684,0.51702,0.90893,0.84573,1.07088,0.001354,0.001354,0.001354 +526,123062,0.91549,0.9159,1.06162,0.72742,0.62382,0.68437,0.51713,0.90887,0.8457,1.07079,0.0013375,0.0013375,0.0013375 +527,123319,0.9191,0.91541,1.06383,0.72508,0.62592,0.68448,0.5173,0.90861,0.84549,1.07069,0.001321,0.001321,0.001321 +528,123576,0.9137,0.9163,1.06207,0.72573,0.62529,0.68458,0.51715,0.90858,0.84533,1.07062,0.0013045,0.0013045,0.0013045 +529,123833,0.91609,0.91063,1.06088,0.73009,0.62147,0.68455,0.51735,0.90795,0.84519,1.07018,0.001288,0.001288,0.001288 +530,124091,0.91756,0.91538,1.06024,0.72616,0.62388,0.68464,0.51732,0.90791,0.84516,1.06996,0.0012715,0.0012715,0.0012715 +531,124348,0.90864,0.90522,1.05907,0.72839,0.62302,0.68458,0.51736,0.90774,0.84522,1.06984,0.001255,0.001255,0.001255 +532,124605,0.91763,0.91468,1.06015,0.73024,0.62183,0.68503,0.51796,0.9076,0.84508,1.0697,0.0012385,0.0012385,0.0012385 +533,124861,0.90844,0.90146,1.05947,0.73334,0.61976,0.685,0.51786,0.90737,0.84481,1.06948,0.001222,0.001222,0.001222 +534,125118,0.91028,0.90423,1.06072,0.73198,0.62084,0.68539,0.51807,0.90705,0.84449,1.06917,0.0012055,0.0012055,0.0012055 +535,125375,0.90759,0.90119,1.05672,0.73408,0.62031,0.68542,0.51827,0.90688,0.84383,1.06897,0.001189,0.001189,0.001189 +536,125632,0.90488,0.89625,1.05515,0.73127,0.6216,0.68603,0.5189,0.90657,0.84358,1.06877,0.0011725,0.0011725,0.0011725 +537,125890,0.90756,0.89639,1.05447,0.73212,0.62087,0.68588,0.51879,0.90665,0.84337,1.06869,0.001156,0.001156,0.001156 +538,126147,0.90281,0.89674,1.05583,0.72934,0.62167,0.68607,0.51902,0.90629,0.84306,1.06838,0.0011395,0.0011395,0.0011395 +539,126404,0.90122,0.89424,1.05361,0.73023,0.62184,0.68632,0.51929,0.90621,0.84254,1.06829,0.001123,0.001123,0.001123 +540,126661,0.90554,0.8865,1.05613,0.73165,0.62089,0.68648,0.51943,0.90597,0.84206,1.06808,0.0011065,0.0011065,0.0011065 +541,126918,0.90146,0.8903,1.05188,0.73217,0.6208,0.68649,0.51964,0.90594,0.84186,1.06799,0.00109,0.00109,0.00109 +542,127176,0.89988,0.89077,1.05283,0.73279,0.62059,0.68659,0.51962,0.906,0.8417,1.06792,0.0010735,0.0010735,0.0010735 +543,127433,0.89799,0.89055,1.05267,0.7308,0.6221,0.68649,0.51947,0.90607,0.84153,1.06786,0.001057,0.001057,0.001057 +544,127690,0.90134,0.88118,1.0512,0.73286,0.62121,0.6864,0.51951,0.90597,0.84118,1.06763,0.0010405,0.0010405,0.0010405 +545,127946,0.89808,0.88415,1.05121,0.73434,0.62147,0.68652,0.51965,0.90581,0.84097,1.06739,0.001024,0.001024,0.001024 
+546,128203,0.89299,0.88261,1.04908,0.73514,0.62135,0.6867,0.51962,0.90576,0.84095,1.06728,0.0010075,0.0010075,0.0010075 +547,128460,0.90417,0.88882,1.0545,0.73459,0.6214,0.68674,0.51988,0.90565,0.84068,1.06707,0.000991,0.000991,0.000991 +548,128718,0.89888,0.87938,1.05218,0.73284,0.62157,0.68661,0.52004,0.90563,0.84048,1.06689,0.0009745,0.0009745,0.0009745 +549,128975,0.89574,0.88089,1.0515,0.73334,0.62143,0.68658,0.51998,0.90564,0.84021,1.06679,0.000958,0.000958,0.000958 +550,129232,0.89502,0.87917,1.05117,0.72787,0.62427,0.68671,0.52013,0.90563,0.83978,1.06675,0.0009415,0.0009415,0.0009415 +551,129488,0.89369,0.87026,1.04781,0.72693,0.62498,0.68692,0.52016,0.90553,0.83972,1.0667,0.000925,0.000925,0.000925 +552,129746,0.89634,0.86681,1.04758,0.73109,0.62347,0.68689,0.52017,0.90544,0.83966,1.06651,0.0009085,0.0009085,0.0009085 +553,130002,0.89407,0.87037,1.05039,0.72855,0.6245,0.68704,0.52053,0.90501,0.83953,1.06626,0.000892,0.000892,0.000892 +554,130259,0.88615,0.86163,1.04493,0.72863,0.62524,0.68708,0.5207,0.90466,0.83943,1.06596,0.0008755,0.0008755,0.0008755 +555,130516,0.89364,0.86881,1.04901,0.72886,0.62535,0.6871,0.52075,0.90428,0.83939,1.06586,0.000859,0.000859,0.000859 +556,130773,0.88525,0.8635,1.04499,0.73049,0.62487,0.68735,0.52074,0.90427,0.83918,1.06575,0.0008425,0.0008425,0.0008425 +557,131029,0.88149,0.85834,1.04435,0.7316,0.62441,0.68729,0.52088,0.90408,0.839,1.06558,0.000826,0.000826,0.000826 +558,131286,0.8853,0.85626,1.04279,0.73288,0.62459,0.68725,0.52094,0.90419,0.83856,1.06565,0.0008095,0.0008095,0.0008095 +559,131543,0.88632,0.85213,1.04268,0.73429,0.62404,0.6873,0.52121,0.9039,0.83838,1.06548,0.000793,0.000793,0.000793 +560,131801,0.88169,0.85545,1.04396,0.73102,0.62652,0.68759,0.52137,0.90377,0.83809,1.06532,0.0007765,0.0007765,0.0007765 +561,132058,0.88483,0.84945,1.04508,0.73699,0.62269,0.68768,0.52139,0.90332,0.83772,1.06511,0.00076,0.00076,0.00076 +562,132315,0.88434,0.85097,1.04493,0.73473,0.62393,0.68768,0.52156,0.90323,0.83772,1.06505,0.0007435,0.0007435,0.0007435 +563,132572,0.883,0.84784,1.04176,0.73282,0.62609,0.68777,0.52173,0.9029,0.83744,1.06478,0.000727,0.000727,0.000727 +564,132829,0.87757,0.84393,1.04094,0.73448,0.62533,0.688,0.5218,0.9024,0.83712,1.0644,0.0007105,0.0007105,0.0007105 +565,133086,0.88467,0.84812,1.0428,0.73824,0.62335,0.68797,0.52172,0.90237,0.83697,1.06432,0.000694,0.000694,0.000694 +566,133342,0.87915,0.84306,1.04159,0.73711,0.6241,0.68799,0.5216,0.90221,0.83674,1.06419,0.0006775,0.0006775,0.0006775 +567,133599,0.87415,0.83407,1.03817,0.73628,0.62497,0.68816,0.52172,0.90205,0.83648,1.064,0.000661,0.000661,0.000661 +568,133856,0.87745,0.84048,1.04151,0.73479,0.62552,0.68825,0.52199,0.902,0.8364,1.06383,0.0006445,0.0006445,0.0006445 +569,134113,0.87497,0.8368,1.03863,0.73457,0.62532,0.6884,0.5221,0.90202,0.83618,1.06374,0.000628,0.000628,0.000628 +570,134369,0.87447,0.8361,1.03734,0.72982,0.62753,0.68854,0.52217,0.90174,0.83562,1.06353,0.0006115,0.0006115,0.0006115 +571,134626,0.86811,0.82549,1.03592,0.73042,0.62703,0.68874,0.52229,0.90166,0.83548,1.06338,0.000595,0.000595,0.000595 +572,134884,0.86687,0.82639,1.03457,0.73467,0.62588,0.68899,0.5223,0.90154,0.83547,1.06327,0.0005785,0.0005785,0.0005785 +573,135141,0.8687,0.8196,1.03269,0.73662,0.62452,0.68914,0.52268,0.90133,0.8353,1.06314,0.000562,0.000562,0.000562 +574,135398,0.86988,0.82756,1.03558,0.73507,0.62711,0.6893,0.52285,0.90113,0.83507,1.06301,0.0005455,0.0005455,0.0005455 
+575,135655,0.87413,0.82387,1.03553,0.73875,0.62549,0.68952,0.52311,0.90117,0.83508,1.06301,0.000529,0.000529,0.000529 +576,135911,0.87094,0.82012,1.03652,0.73467,0.62724,0.68959,0.52323,0.90116,0.83491,1.06308,0.0005125,0.0005125,0.0005125 +577,136168,0.87379,0.82507,1.03618,0.73566,0.6276,0.68977,0.52334,0.90128,0.8347,1.0631,0.000496,0.000496,0.000496 +578,136425,0.86966,0.8129,1.03392,0.73834,0.62636,0.68986,0.52344,0.90123,0.83459,1.063,0.0004795,0.0004795,0.0004795 +579,136682,0.86449,0.81483,1.03283,0.73812,0.62703,0.68967,0.523,0.90121,0.83456,1.06294,0.000463,0.000463,0.000463 +580,136938,0.86479,0.812,1.03384,0.73371,0.62958,0.68963,0.52307,0.90127,0.83429,1.06291,0.0004465,0.0004465,0.0004465 +581,137195,0.86753,0.81024,1.03337,0.73519,0.62874,0.68978,0.52322,0.90112,0.83431,1.06282,0.00043,0.00043,0.00043 +582,137452,0.86546,0.80765,1.03335,0.73383,0.6295,0.68995,0.52338,0.90097,0.83429,1.06274,0.0004135,0.0004135,0.0004135 +583,137709,0.85909,0.7975,1.02921,0.73317,0.63043,0.69002,0.52345,0.90102,0.83381,1.06267,0.000397,0.000397,0.000397 +584,137966,0.85282,0.79207,1.02931,0.7338,0.63032,0.6901,0.52341,0.9012,0.83385,1.06279,0.0003805,0.0003805,0.0003805 +585,138223,0.85409,0.79165,1.0283,0.73253,0.63082,0.69027,0.52346,0.90118,0.83357,1.06267,0.000364,0.000364,0.000364 +586,138480,0.8607,0.79747,1.02888,0.73252,0.63094,0.6905,0.52369,0.90118,0.83359,1.06255,0.0003475,0.0003475,0.0003475 +587,138737,0.85479,0.79192,1.02855,0.73249,0.63152,0.69034,0.5239,0.90123,0.8336,1.06252,0.000331,0.000331,0.000331 +588,138994,0.85357,0.7878,1.02483,0.73198,0.63125,0.69034,0.52375,0.90121,0.83363,1.06244,0.0003145,0.0003145,0.0003145 +589,139252,0.85693,0.79041,1.02785,0.72948,0.63248,0.69074,0.52425,0.90107,0.83367,1.06235,0.000298,0.000298,0.000298 +590,139509,0.84959,0.78487,1.0257,0.73303,0.63304,0.69106,0.52438,0.90074,0.83347,1.06208,0.0002815,0.0002815,0.0002815 +591,139754,0.87804,0.76644,1.06396,0.72912,0.63283,0.69071,0.52411,0.90042,0.83354,1.06178,0.000265,0.000265,0.000265 +592,139994,0.87384,0.7536,1.05573,0.7308,0.63288,0.6911,0.52459,0.90004,0.83359,1.06143,0.0002485,0.0002485,0.0002485 +593,140235,0.87286,0.741,1.06019,0.73077,0.63418,0.69144,0.52491,0.89972,0.8334,1.06116,0.000232,0.000232,0.000232 +594,140475,0.86756,0.73707,1.05611,0.7328,0.63278,0.69141,0.52473,0.89953,0.83359,1.06088,0.0002155,0.0002155,0.0002155 +595,140716,0.86653,0.72878,1.05182,0.73596,0.63272,0.69149,0.525,0.89911,0.83388,1.06063,0.000199,0.000199,0.000199 +596,140957,0.86511,0.72675,1.05365,0.73565,0.63171,0.69129,0.52507,0.89904,0.83419,1.06052,0.0001825,0.0001825,0.0001825 +597,141198,0.8536,0.71153,1.04852,0.73749,0.63052,0.69137,0.52513,0.89893,0.83428,1.06032,0.000166,0.000166,0.000166 +598,141439,0.85893,0.7114,1.04775,0.73724,0.63303,0.69174,0.52546,0.89887,0.83423,1.06009,0.0001495,0.0001495,0.0001495 +599,141679,0.84472,0.70191,1.0456,0.7371,0.63272,0.69213,0.52543,0.89872,0.83454,1.05991,0.000133,0.000133,0.000133 +600,141920,0.84947,0.69589,1.04373,0.73621,0.63395,0.69235,0.52551,0.89852,0.83459,1.05968,0.0001165,0.0001165,0.0001165 diff --git a/logs/yolov12n.csv b/logs/yolov12n.csv new file mode 100644 index 0000000000000000000000000000000000000000..50d7c2482a073413d28cb0c86bf2d93647619157 --- /dev/null +++ b/logs/yolov12n.csv @@ -0,0 +1,601 @@ +epoch,time,train/box_loss,train/cls_loss,train/dfl_loss,metrics/precision(B),metrics/recall(B),metrics/mAP50(B),metrics/mAP50-95(B),val/box_loss,val/cls_loss,val/dfl_loss,lr/pg0,lr/pg1,lr/pg2 
+1,272.58,3.67409,5.7997,4.23095,0.00057,0.0187,0.0003,0.0001,3.57714,inf,6.54829,0.00332613,0.00332613,0.00332613 +2,396.246,2.85012,4.83548,3.13866,0.00371,0.12499,0.00571,0.00247,2.28931,3.89322,2.506,0.00664848,0.00664848,0.00664848 +3,524.345,1.98716,3.73223,2.05368,0.21313,0.07748,0.03184,0.01589,1.94887,3.30362,2.00147,0.00995982,0.00995982,0.00995982 +4,648.993,1.70712,3.12,1.75855,0.21209,0.13322,0.08979,0.05135,1.70553,2.70397,1.75597,0.0099505,0.0099505,0.0099505 +5,773.66,1.58303,2.74548,1.62897,0.25343,0.18529,0.13167,0.07811,1.60019,2.45253,1.66144,0.009934,0.009934,0.009934 +6,897.947,1.50906,2.56036,1.55871,0.31814,0.21029,0.17251,0.10573,1.54417,2.27229,1.59234,0.0099175,0.0099175,0.0099175 +7,1023.04,1.4602,2.41213,1.51228,0.34664,0.23039,0.20026,0.12586,1.50192,2.17113,1.55348,0.009901,0.009901,0.009901 +8,1145.74,1.42874,2.29651,1.47605,0.36393,0.25364,0.22907,0.1463,1.45941,2.05214,1.51413,0.0098845,0.0098845,0.0098845 +9,1269.33,1.40195,2.2387,1.45389,0.4092,0.26556,0.25236,0.16157,1.43955,1.97699,1.48588,0.009868,0.009868,0.009868 +10,1392.4,1.3883,2.17318,1.43168,0.41058,0.28172,0.26641,0.17299,1.40492,1.90929,1.45227,0.0098515,0.0098515,0.0098515 +11,1516.05,1.37281,2.11616,1.41624,0.44304,0.29692,0.28647,0.18715,1.37779,1.83621,1.42472,0.009835,0.009835,0.009835 +12,1640.02,1.36088,2.06849,1.40156,0.4621,0.30431,0.30352,0.20025,1.36048,1.77983,1.41147,0.0098185,0.0098185,0.0098185 +13,1762.74,1.34455,2.03859,1.38702,0.47398,0.31457,0.31435,0.20899,1.34247,1.73726,1.39526,0.009802,0.009802,0.009802 +14,1884.78,1.32681,1.98622,1.37114,0.48333,0.32229,0.32799,0.21855,1.3283,1.70222,1.37447,0.0097855,0.0097855,0.0097855 +15,2007.66,1.31964,1.96594,1.3658,0.4928,0.33489,0.34398,0.23035,1.31245,1.66353,1.36128,0.009769,0.009769,0.009769 +16,2129.88,1.32025,1.95014,1.36237,0.51128,0.33969,0.34917,0.23366,1.30166,1.63592,1.34933,0.0097525,0.0097525,0.0097525 +17,2252.27,1.30771,1.92565,1.35007,0.5166,0.34903,0.36174,0.24305,1.29012,1.60427,1.33992,0.009736,0.009736,0.009736 +18,2374.07,1.3037,1.90358,1.34608,0.50867,0.35642,0.36837,0.24875,1.27707,1.57863,1.32586,0.0097195,0.0097195,0.0097195 +19,2498.06,1.29775,1.88557,1.33715,0.51806,0.36104,0.37546,0.25466,1.26967,1.55993,1.32158,0.009703,0.009703,0.009703 +20,2620.42,1.28324,1.86384,1.32838,0.52087,0.37345,0.38518,0.26157,1.26105,1.53643,1.31098,0.0096865,0.0096865,0.0096865 +21,2742.42,1.28602,1.85137,1.32984,0.53165,0.37351,0.38879,0.26442,1.25247,1.5221,1.30371,0.00967,0.00967,0.00967 +22,2865.15,1.28063,1.84304,1.325,0.53354,0.37951,0.39443,0.26836,1.24961,1.50702,1.30023,0.0096535,0.0096535,0.0096535 +23,2987.67,1.2715,1.81947,1.31374,0.53277,0.38094,0.39924,0.2719,1.24219,1.49464,1.29416,0.009637,0.009637,0.009637 +24,3110.09,1.26939,1.79577,1.3128,0.53646,0.38668,0.40301,0.27565,1.2387,1.48106,1.28824,0.0096205,0.0096205,0.0096205 +25,3233.73,1.26147,1.78797,1.30806,0.54045,0.39266,0.40902,0.2799,1.23307,1.46968,1.2834,0.009604,0.009604,0.009604 +26,3356.49,1.26706,1.78851,1.30714,0.53949,0.39549,0.41137,0.28204,1.22771,1.45666,1.27885,0.0095875,0.0095875,0.0095875 +27,3480.42,1.26905,1.78259,1.30829,0.5444,0.39755,0.41569,0.2851,1.22272,1.44797,1.27608,0.009571,0.009571,0.009571 +28,3604.26,1.2626,1.77191,1.30391,0.53966,0.40385,0.41827,0.28708,1.22068,1.4412,1.2737,0.0095545,0.0095545,0.0095545 +29,3726.91,1.25625,1.77257,1.30285,0.54677,0.401,0.42103,0.28951,1.21701,1.43378,1.27075,0.009538,0.009538,0.009538 
+30,3851.69,1.25306,1.7488,1.29884,0.54795,0.40404,0.42353,0.29122,1.21408,1.42671,1.26855,0.0095215,0.0095215,0.0095215 +31,3975.21,1.25142,1.73668,1.29264,0.55645,0.40276,0.42535,0.29318,1.21139,1.42112,1.2656,0.009505,0.009505,0.009505 +32,4098.71,1.24245,1.74128,1.29329,0.55466,0.40694,0.42731,0.29495,1.20951,1.41642,1.26343,0.0094885,0.0094885,0.0094885 +33,4222.14,1.24877,1.73454,1.29099,0.5571,0.40554,0.42864,0.29588,1.20724,1.41169,1.2611,0.009472,0.009472,0.009472 +34,4345.35,1.24448,1.73073,1.2888,0.55539,0.40942,0.4297,0.29699,1.20611,1.40742,1.25976,0.0094555,0.0094555,0.0094555 +35,4467.71,1.24478,1.72061,1.28738,0.55502,0.41213,0.43094,0.29796,1.2048,1.40389,1.25875,0.009439,0.009439,0.009439 +36,4589.77,1.24442,1.70576,1.28674,0.56158,0.41123,0.43206,0.2988,1.20372,1.40075,1.2576,0.0094225,0.0094225,0.0094225 +37,4711.06,1.24551,1.70334,1.28846,0.56322,0.41192,0.43346,0.29965,1.20273,1.39812,1.25669,0.009406,0.009406,0.009406 +38,4832.49,1.23285,1.70076,1.28393,0.56502,0.41103,0.43421,0.3002,1.20123,1.3956,1.25529,0.0093895,0.0093895,0.0093895 +39,4954.24,1.23386,1.69595,1.27936,0.56592,0.41082,0.43433,0.30063,1.20039,1.3934,1.25438,0.009373,0.009373,0.009373 +40,5075.89,1.23987,1.69035,1.28007,0.56781,0.4098,0.43489,0.30128,1.19979,1.39205,1.25382,0.0093565,0.0093565,0.0093565 +41,5198.9,1.23975,1.68687,1.27859,0.56346,0.41182,0.4351,0.30164,1.19936,1.39091,1.25297,0.00934,0.00934,0.00934 +42,5320.62,1.24007,1.69227,1.27853,0.56339,0.41174,0.43524,0.30199,1.19906,1.39016,1.25263,0.0093235,0.0093235,0.0093235 +43,5443.13,1.23329,1.68237,1.2736,0.56467,0.41137,0.43517,0.30198,1.19854,1.38936,1.25203,0.009307,0.009307,0.009307 +44,5565.81,1.22179,1.67425,1.27075,0.56687,0.4117,0.43573,0.30235,1.19782,1.38875,1.25142,0.0092905,0.0092905,0.0092905 +45,5688.33,1.22498,1.66824,1.27022,0.56812,0.41251,0.43623,0.30298,1.1973,1.38815,1.25082,0.009274,0.009274,0.009274 +46,5810.52,1.23694,1.66832,1.27469,0.56896,0.41201,0.43659,0.30331,1.19725,1.3884,1.25057,0.0092575,0.0092575,0.0092575 +47,5932.35,1.23324,1.66502,1.2752,0.57004,0.41089,0.43655,0.30354,1.19688,1.38838,1.25027,0.009241,0.009241,0.009241 +48,6054.92,1.22776,1.65586,1.26948,0.57396,0.40929,0.43667,0.30365,1.19656,1.38887,1.24979,0.0092245,0.0092245,0.0092245 +49,6176.64,1.22131,1.65263,1.26626,0.5747,0.40748,0.43654,0.3039,1.1961,1.39006,1.24937,0.009208,0.009208,0.009208 +50,6298.33,1.22546,1.65956,1.27083,0.57176,0.40878,0.43664,0.3039,1.196,1.39132,1.24912,0.0091915,0.0091915,0.0091915 +51,6418.39,1.21957,1.6573,1.26721,0.56817,0.40952,0.43619,0.30399,1.19614,1.39275,1.24909,0.009175,0.009175,0.009175 +52,6540.88,1.22509,1.65023,1.26978,0.56795,0.41045,0.43599,0.30373,1.19603,1.39484,1.24893,0.0091585,0.0091585,0.0091585 +53,6663.3,1.2193,1.63781,1.2654,0.5734,0.40837,0.43596,0.30356,1.19584,1.39631,1.24879,0.009142,0.009142,0.009142 +54,6786.3,1.22147,1.63226,1.26424,0.5744,0.40891,0.43581,0.30325,1.19575,1.39865,1.24848,0.0091255,0.0091255,0.0091255 +55,6908.45,1.22109,1.63761,1.26534,0.57395,0.41193,0.43574,0.30328,1.19553,1.4006,1.24811,0.009109,0.009109,0.009109 +56,7029.5,1.21706,1.62964,1.26134,0.57197,0.41143,0.43522,0.30307,1.19544,1.40294,1.24815,0.0090925,0.0090925,0.0090925 +57,7151.06,1.21272,1.63046,1.2605,0.57126,0.40953,0.43522,0.30333,1.19541,1.40505,1.24806,0.009076,0.009076,0.009076 +58,7272.67,1.21693,1.64155,1.26545,0.57177,0.40819,0.43493,0.30313,1.19531,1.40772,1.24801,0.0090595,0.0090595,0.0090595 
+59,7394.42,1.21695,1.62813,1.26102,0.57195,0.4072,0.43455,0.30292,1.19542,1.41051,1.24818,0.009043,0.009043,0.009043 +60,7515.94,1.22172,1.62466,1.26031,0.57096,0.40696,0.43434,0.30284,1.19547,1.41321,1.24821,0.0090265,0.0090265,0.0090265 +61,7636.42,1.20662,1.62169,1.25365,0.56578,0.40719,0.43379,0.30279,1.19549,1.41625,1.24819,0.00901,0.00901,0.00901 +62,7758.63,1.20784,1.6167,1.25777,0.56882,0.40597,0.43434,0.30311,1.19546,1.41893,1.24823,0.0089935,0.0089935,0.0089935 +63,7879.3,1.21117,1.61531,1.2579,0.56879,0.40463,0.43407,0.30286,1.19574,1.42179,1.24833,0.008977,0.008977,0.008977 +64,8000.7,1.21345,1.62689,1.26061,0.57029,0.40367,0.43364,0.30279,1.19574,1.42473,1.24838,0.0089605,0.0089605,0.0089605 +65,8122.5,1.21502,1.60843,1.25982,0.57019,0.40228,0.43328,0.30253,1.19602,1.4277,1.24875,0.008944,0.008944,0.008944 +66,8243.49,1.21069,1.60756,1.25767,0.57053,0.4022,0.43283,0.30219,1.19606,1.43024,1.24882,0.0089275,0.0089275,0.0089275 +67,8364.98,1.20927,1.60328,1.2523,0.57126,0.40166,0.43239,0.30217,1.19587,1.4327,1.24878,0.008911,0.008911,0.008911 +68,8485.91,1.20996,1.60326,1.25613,0.57241,0.40286,0.43214,0.30229,1.19585,1.43536,1.24876,0.0088945,0.0088945,0.0088945 +69,8608.01,1.20014,1.60472,1.25315,0.5712,0.40284,0.43199,0.30236,1.19577,1.43832,1.24888,0.008878,0.008878,0.008878 +70,8731.58,1.21112,1.61368,1.25542,0.57347,0.40158,0.43159,0.30211,1.19583,1.44088,1.24905,0.0088615,0.0088615,0.0088615 +71,8853.96,1.2102,1.60073,1.25736,0.57184,0.40152,0.43149,0.30199,1.19558,1.44276,1.24896,0.008845,0.008845,0.008845 +72,8976.11,1.20622,1.60148,1.25376,0.57715,0.40012,0.43172,0.30196,1.19546,1.44477,1.24886,0.0088285,0.0088285,0.0088285 +73,9095.96,1.20769,1.59232,1.2494,0.57702,0.39991,0.43144,0.30184,1.19548,1.44671,1.24891,0.008812,0.008812,0.008812 +74,9216.22,1.2026,1.58846,1.25171,0.57741,0.40047,0.43143,0.30182,1.19541,1.44824,1.24871,0.0087955,0.0087955,0.0087955 +75,9338.49,1.20354,1.58269,1.24962,0.57788,0.39994,0.4316,0.30199,1.19527,1.4497,1.24846,0.008779,0.008779,0.008779 +76,9458.96,1.20608,1.59478,1.25144,0.57917,0.39919,0.43154,0.30192,1.19516,1.45146,1.24834,0.0087625,0.0087625,0.0087625 +77,9580.97,1.20109,1.57932,1.2474,0.58163,0.3988,0.43168,0.30214,1.19485,1.45278,1.24816,0.008746,0.008746,0.008746 +78,9701.64,1.20216,1.59453,1.25294,0.58251,0.39818,0.43154,0.30216,1.19482,1.45452,1.2481,0.0087295,0.0087295,0.0087295 +79,9822.72,1.20267,1.58613,1.24932,0.58177,0.39795,0.43159,0.3025,1.19436,1.45578,1.24783,0.008713,0.008713,0.008713 +80,9945.19,1.19826,1.57358,1.24577,0.58279,0.39869,0.43192,0.30255,1.19413,1.45651,1.24756,0.0086965,0.0086965,0.0086965 +81,10068.5,1.20266,1.57659,1.24788,0.58395,0.39831,0.43204,0.30276,1.19361,1.45645,1.24702,0.00868,0.00868,0.00868 +82,10188.2,1.199,1.58759,1.24845,0.58435,0.39885,0.4329,0.30328,1.19322,1.45663,1.24674,0.0086635,0.0086635,0.0086635 +83,10310,1.19617,1.58604,1.24405,0.5878,0.3973,0.43283,0.30339,1.19285,1.45715,1.24631,0.008647,0.008647,0.008647 +84,10432.5,1.19873,1.57489,1.2451,0.58605,0.39752,0.433,0.30364,1.19255,1.45707,1.24578,0.0086305,0.0086305,0.0086305 +85,10555.2,1.19767,1.57722,1.24754,0.58514,0.39763,0.43363,0.30409,1.1917,1.45702,1.24527,0.008614,0.008614,0.008614 +86,10676.8,1.19377,1.57251,1.24561,0.58718,0.39647,0.43427,0.30459,1.19108,1.45722,1.24485,0.0085975,0.0085975,0.0085975 +87,10797.9,1.19896,1.57918,1.24613,0.5886,0.39712,0.43475,0.30508,1.19054,1.45654,1.2445,0.008581,0.008581,0.008581 
+88,10918.4,1.19238,1.57004,1.24318,0.59124,0.39668,0.43551,0.30566,1.18979,1.45554,1.24385,0.0085645,0.0085645,0.0085645 +89,11039.4,1.19504,1.57131,1.24522,0.5859,0.39839,0.43612,0.30628,1.18892,1.45381,1.24334,0.008548,0.008548,0.008548 +90,11161.3,1.19322,1.57555,1.24296,0.58511,0.39983,0.43675,0.30684,1.188,1.45236,1.24246,0.0085315,0.0085315,0.0085315 +91,11283.1,1.19526,1.56912,1.24645,0.58865,0.39927,0.43779,0.30759,1.18711,1.45089,1.24174,0.008515,0.008515,0.008515 +92,11403.6,1.19931,1.57098,1.24534,0.58434,0.40053,0.43844,0.30817,1.18627,1.44893,1.24104,0.0084985,0.0084985,0.0084985 +93,11525.5,1.19611,1.56027,1.24286,0.5882,0.39994,0.43932,0.30857,1.18521,1.44655,1.24025,0.008482,0.008482,0.008482 +94,11647.7,1.2013,1.57633,1.24541,0.58643,0.40171,0.44008,0.30929,1.18443,1.44437,1.23969,0.0084655,0.0084655,0.0084655 +95,11769.3,1.19463,1.56404,1.24176,0.58746,0.40256,0.44099,0.30995,1.18327,1.4425,1.23859,0.008449,0.008449,0.008449 +96,11891.2,1.20026,1.56659,1.24619,0.58813,0.40356,0.44191,0.31046,1.18264,1.44007,1.23798,0.0084325,0.0084325,0.0084325 +97,12014.3,1.1986,1.55227,1.24436,0.58852,0.40362,0.44278,0.31113,1.18187,1.43765,1.23731,0.008416,0.008416,0.008416 +98,12137.3,1.19095,1.55893,1.23896,0.58856,0.40487,0.4435,0.31184,1.18098,1.43502,1.23645,0.0083995,0.0083995,0.0083995 +99,12258.1,1.19567,1.55785,1.24301,0.58881,0.40582,0.44413,0.31256,1.18033,1.43207,1.2359,0.008383,0.008383,0.008383 +100,12379.6,1.19441,1.55319,1.24144,0.58879,0.40677,0.4452,0.31334,1.17946,1.42931,1.23522,0.0083665,0.0083665,0.0083665 +101,12501.5,1.20127,1.56106,1.24525,0.58618,0.40866,0.44621,0.31397,1.17876,1.42626,1.23451,0.00835,0.00835,0.00835 +102,12623.3,1.19219,1.55762,1.24187,0.58532,0.40996,0.44713,0.31467,1.17808,1.42328,1.23372,0.0083335,0.0083335,0.0083335 +103,12745.1,1.1808,1.54888,1.23777,0.58656,0.41062,0.44816,0.31526,1.17728,1.42012,1.23288,0.008317,0.008317,0.008317 +104,12866.4,1.18618,1.55737,1.23795,0.58972,0.41152,0.44932,0.31605,1.1766,1.4178,1.23231,0.0083005,0.0083005,0.0083005 +105,12988.8,1.19158,1.55084,1.24218,0.59122,0.41201,0.4503,0.31691,1.17614,1.41464,1.23175,0.008284,0.008284,0.008284 +106,13111.9,1.18817,1.55242,1.23836,0.59078,0.41311,0.45125,0.31751,1.17528,1.41163,1.23085,0.0082675,0.0082675,0.0082675 +107,13233.9,1.18886,1.54818,1.23925,0.59404,0.41333,0.45168,0.3179,1.17475,1.40887,1.2302,0.008251,0.008251,0.008251 +108,13355.6,1.19245,1.54029,1.2397,0.59433,0.41392,0.45231,0.31855,1.17394,1.40572,1.22945,0.0082345,0.0082345,0.0082345 +109,13475,1.18858,1.54837,1.24021,0.5952,0.41441,0.45327,0.31927,1.17316,1.40227,1.22857,0.008218,0.008218,0.008218 +110,13597.2,1.18351,1.54788,1.23743,0.5963,0.41431,0.45404,0.32002,1.17234,1.39926,1.22766,0.0082015,0.0082015,0.0082015 +111,13720.9,1.18614,1.54388,1.23602,0.59637,0.41577,0.45513,0.3206,1.17147,1.39625,1.22673,0.008185,0.008185,0.008185 +112,13842.7,1.18707,1.54827,1.23784,0.59798,0.41686,0.45616,0.32124,1.17073,1.39302,1.22584,0.0081685,0.0081685,0.0081685 +113,13963.7,1.18712,1.55793,1.2379,0.59664,0.41865,0.4569,0.3218,1.17028,1.38962,1.22528,0.008152,0.008152,0.008152 +114,14087.1,1.18993,1.55231,1.23765,0.59513,0.42042,0.45777,0.32262,1.16952,1.38629,1.22464,0.0081355,0.0081355,0.0081355 +115,14209.6,1.19311,1.54315,1.23768,0.59669,0.42152,0.45845,0.32332,1.16868,1.38301,1.22379,0.008119,0.008119,0.008119 +116,14332.2,1.17678,1.53148,1.23397,0.60208,0.42171,0.45908,0.32396,1.16805,1.37995,1.22312,0.0081025,0.0081025,0.0081025 
+117,14453.9,1.18599,1.53533,1.23814,0.60045,0.42374,0.46001,0.32435,1.16725,1.37667,1.22225,0.008086,0.008086,0.008086 +118,14575.6,1.18177,1.53878,1.23615,0.60254,0.42283,0.46065,0.32508,1.16678,1.37366,1.22175,0.0080695,0.0080695,0.0080695 +119,14697.5,1.18596,1.54636,1.23511,0.6008,0.42403,0.46131,0.32561,1.16623,1.37084,1.22107,0.008053,0.008053,0.008053 +120,14820.2,1.18476,1.53437,1.23613,0.60477,0.42352,0.46226,0.32626,1.16541,1.36778,1.22027,0.0080365,0.0080365,0.0080365 +121,14943.2,1.17731,1.53816,1.23286,0.6039,0.42453,0.46327,0.3271,1.16486,1.36461,1.21966,0.00802,0.00802,0.00802 +122,15064.9,1.1801,1.53728,1.23708,0.60597,0.42411,0.46437,0.32778,1.16414,1.36138,1.21896,0.0080035,0.0080035,0.0080035 +123,15188.1,1.18451,1.5339,1.23645,0.60644,0.42457,0.46516,0.32855,1.16364,1.35857,1.21854,0.007987,0.007987,0.007987 +124,15310.5,1.18872,1.54779,1.23803,0.60381,0.42618,0.46561,0.32902,1.16317,1.35556,1.21786,0.0079705,0.0079705,0.0079705 +125,15432,1.18783,1.52745,1.23766,0.60196,0.42897,0.4665,0.32953,1.16256,1.35235,1.2173,0.007954,0.007954,0.007954 +126,15553.7,1.18097,1.53143,1.23403,0.60496,0.42883,0.46746,0.33022,1.16193,1.34963,1.21678,0.0079375,0.0079375,0.0079375 +127,15676.7,1.17854,1.53911,1.2323,0.60274,0.4303,0.46862,0.33097,1.16146,1.34703,1.21616,0.007921,0.007921,0.007921 +128,15798.6,1.1867,1.54508,1.2356,0.60403,0.43084,0.46928,0.33162,1.16088,1.34397,1.21558,0.0079045,0.0079045,0.0079045 +129,15920.3,1.18148,1.53145,1.2342,0.60626,0.43103,0.46978,0.33224,1.16023,1.34134,1.2149,0.007888,0.007888,0.007888 +130,16041.9,1.18232,1.52662,1.23126,0.60329,0.43257,0.47054,0.3331,1.15954,1.33876,1.21417,0.0078715,0.0078715,0.0078715 +131,16163.9,1.17675,1.53403,1.23193,0.60509,0.43351,0.47119,0.33357,1.15882,1.33583,1.21344,0.007855,0.007855,0.007855 +132,16283.4,1.18211,1.5325,1.23358,0.60562,0.43358,0.47184,0.33385,1.15819,1.33325,1.21275,0.0078385,0.0078385,0.0078385 +133,16405.1,1.17521,1.52675,1.23037,0.60512,0.43444,0.47243,0.33437,1.15736,1.33048,1.21198,0.007822,0.007822,0.007822 +134,16525.4,1.189,1.53215,1.23541,0.60331,0.43601,0.47286,0.3349,1.15685,1.32799,1.21155,0.0078055,0.0078055,0.0078055 +135,16646,1.18175,1.527,1.23171,0.60344,0.43696,0.47354,0.33541,1.15588,1.32579,1.21061,0.007789,0.007789,0.007789 +136,16767.5,1.18022,1.51587,1.23345,0.60572,0.43639,0.47406,0.33585,1.15541,1.32311,1.21015,0.0077725,0.0077725,0.0077725 +137,16889.3,1.18706,1.53397,1.237,0.60399,0.43742,0.47503,0.33644,1.15477,1.32065,1.20957,0.007756,0.007756,0.007756 +138,17009.8,1.18075,1.52121,1.23116,0.60649,0.43769,0.47617,0.33723,1.15417,1.31811,1.20909,0.0077395,0.0077395,0.0077395 +139,17132.9,1.17823,1.52327,1.23298,0.60503,0.43881,0.47683,0.33761,1.1535,1.31597,1.20867,0.007723,0.007723,0.007723 +140,17255.6,1.18107,1.51469,1.2324,0.60282,0.44047,0.47773,0.33819,1.15288,1.31335,1.20797,0.0077065,0.0077065,0.0077065 +141,17375.8,1.17881,1.51979,1.22973,0.60368,0.4403,0.47834,0.33892,1.15228,1.31041,1.20735,0.00769,0.00769,0.00769 +142,17498,1.17287,1.51357,1.23162,0.60509,0.44013,0.47897,0.33922,1.15161,1.3078,1.20677,0.0076735,0.0076735,0.0076735 +143,17619.1,1.17693,1.51502,1.2296,0.60564,0.4409,0.47947,0.33964,1.15111,1.3054,1.20619,0.007657,0.007657,0.007657 +144,17741.9,1.17463,1.50986,1.22759,0.60567,0.44082,0.47989,0.34019,1.15063,1.30351,1.2056,0.0076405,0.0076405,0.0076405 +145,17863.3,1.17663,1.5146,1.23201,0.60162,0.44281,0.48062,0.3406,1.15035,1.30142,1.20525,0.007624,0.007624,0.007624 
+146,17984.7,1.17106,1.51126,1.2279,0.59985,0.44432,0.48103,0.34088,1.14986,1.29927,1.20482,0.0076075,0.0076075,0.0076075 +147,18107.3,1.17615,1.52005,1.23093,0.60155,0.44521,0.48172,0.34121,1.1492,1.29661,1.20435,0.007591,0.007591,0.007591 +148,18228.5,1.18754,1.5093,1.23406,0.60204,0.44431,0.48226,0.34147,1.14891,1.29471,1.20409,0.0075745,0.0075745,0.0075745 +149,18351.9,1.17355,1.51854,1.22843,0.60299,0.44399,0.48279,0.34209,1.1486,1.29307,1.20369,0.007558,0.007558,0.007558 +150,18473.9,1.18019,1.51467,1.23035,0.60247,0.44506,0.48342,0.34245,1.14834,1.29093,1.20341,0.0075415,0.0075415,0.0075415 +151,18594.6,1.17079,1.49948,1.2253,0.60269,0.44578,0.48413,0.34286,1.14817,1.28927,1.20314,0.007525,0.007525,0.007525 +152,18718,1.17355,1.50297,1.22486,0.60371,0.44589,0.48458,0.34317,1.14789,1.28741,1.20274,0.0075085,0.0075085,0.0075085 +153,18840.5,1.18212,1.51562,1.23121,0.603,0.44666,0.48503,0.3436,1.14758,1.28579,1.20238,0.007492,0.007492,0.007492 +154,18961.5,1.17125,1.50897,1.22914,0.60265,0.44701,0.48557,0.34408,1.14724,1.28395,1.20198,0.0074755,0.0074755,0.0074755 +155,19083.9,1.1763,1.51191,1.22817,0.60523,0.44639,0.48606,0.34459,1.14693,1.28192,1.20157,0.007459,0.007459,0.007459 +156,19205.5,1.1786,1.51402,1.22853,0.60364,0.44777,0.4865,0.34517,1.14655,1.28019,1.20129,0.0074425,0.0074425,0.0074425 +157,19327.6,1.17149,1.50387,1.22585,0.60509,0.44782,0.48713,0.34533,1.14589,1.27861,1.20077,0.007426,0.007426,0.007426 +158,19449.2,1.16756,1.49053,1.21934,0.6076,0.44794,0.48774,0.34575,1.14533,1.27679,1.20023,0.0074095,0.0074095,0.0074095 +159,19571.4,1.17029,1.49716,1.22635,0.60719,0.44803,0.48806,0.34605,1.14493,1.27507,1.19989,0.007393,0.007393,0.007393 +160,19694.5,1.17212,1.49642,1.2248,0.61059,0.44759,0.48864,0.34659,1.14453,1.2736,1.19938,0.0073765,0.0073765,0.0073765 +161,19816.6,1.17288,1.51074,1.22579,0.61156,0.44786,0.48896,0.3468,1.14411,1.27245,1.19901,0.00736,0.00736,0.00736 +162,19938.3,1.16983,1.49706,1.22438,0.60928,0.44906,0.4892,0.34721,1.14402,1.2713,1.19886,0.0073435,0.0073435,0.0073435 +163,20060.7,1.17158,1.51516,1.22668,0.61422,0.44748,0.48968,0.34755,1.14375,1.26959,1.19861,0.007327,0.007327,0.007327 +164,20182.9,1.16892,1.50401,1.22605,0.61044,0.45012,0.49004,0.34775,1.14313,1.26818,1.19816,0.0073105,0.0073105,0.0073105 +165,20306.6,1.16679,1.49467,1.22165,0.60324,0.45287,0.49053,0.34827,1.14278,1.26663,1.19769,0.007294,0.007294,0.007294 +166,20427.2,1.17083,1.49917,1.22336,0.60199,0.45411,0.49096,0.34861,1.14245,1.26533,1.19731,0.0072775,0.0072775,0.0072775 +167,20547.9,1.17663,1.49465,1.22495,0.60431,0.453,0.49093,0.34862,1.14203,1.26417,1.19709,0.007261,0.007261,0.007261 +168,20669.2,1.16766,1.50041,1.22368,0.6022,0.45491,0.49148,0.34876,1.14165,1.26279,1.19677,0.0072445,0.0072445,0.0072445 +169,20792.4,1.16441,1.50072,1.22221,0.60336,0.45544,0.49169,0.34909,1.14116,1.26167,1.19635,0.007228,0.007228,0.007228 +170,20913.8,1.1668,1.50087,1.22457,0.60368,0.45553,0.49225,0.34936,1.1412,1.26056,1.19629,0.0072115,0.0072115,0.0072115 +171,21035.1,1.16852,1.49419,1.22237,0.60471,0.4561,0.49259,0.34986,1.14093,1.2593,1.1961,0.007195,0.007195,0.007195 +172,21157.6,1.17073,1.49018,1.22381,0.60477,0.45649,0.49312,0.35004,1.14061,1.25822,1.19591,0.0071785,0.0071785,0.0071785 +173,21280.4,1.17131,1.49567,1.22341,0.60714,0.45578,0.49366,0.35036,1.14063,1.25761,1.19597,0.007162,0.007162,0.007162 +174,21402.1,1.16395,1.48794,1.22146,0.604,0.45786,0.49408,0.35071,1.14024,1.25646,1.19559,0.0071455,0.0071455,0.0071455 
+175,21525.8,1.16383,1.49168,1.22166,0.6036,0.45861,0.4944,0.35106,1.14003,1.25551,1.19536,0.007129,0.007129,0.007129 +176,21648.8,1.17552,1.49129,1.22453,0.60517,0.45817,0.49501,0.35146,1.13959,1.2542,1.19512,0.0071125,0.0071125,0.0071125 +177,21771.4,1.17042,1.49677,1.22545,0.60722,0.45752,0.49491,0.35176,1.13937,1.25331,1.19484,0.007096,0.007096,0.007096 +178,21893.2,1.16542,1.49618,1.22475,0.607,0.4584,0.49541,0.35236,1.13922,1.25239,1.19452,0.0070795,0.0070795,0.0070795 +179,22014.4,1.17028,1.49335,1.2249,0.60812,0.45851,0.49598,0.35274,1.13883,1.25112,1.19425,0.007063,0.007063,0.007063 +180,22136.6,1.16038,1.48225,1.22248,0.60888,0.45881,0.49646,0.35293,1.13851,1.25018,1.19392,0.0070465,0.0070465,0.0070465 +181,22260.4,1.1704,1.49107,1.22346,0.61045,0.45813,0.49668,0.35299,1.13827,1.24916,1.19383,0.00703,0.00703,0.00703 +182,22384.2,1.16535,1.4826,1.22034,0.60662,0.46056,0.49687,0.35326,1.13783,1.24783,1.19348,0.0070135,0.0070135,0.0070135 +183,22507.6,1.16819,1.48799,1.21977,0.60904,0.45991,0.49713,0.35325,1.13759,1.24668,1.19314,0.006997,0.006997,0.006997 +184,22629.7,1.16597,1.49235,1.22334,0.60826,0.46064,0.49775,0.35353,1.13737,1.24531,1.1929,0.0069805,0.0069805,0.0069805 +185,22751.8,1.16614,1.48531,1.22204,0.61004,0.45989,0.49814,0.35368,1.13721,1.24474,1.19272,0.006964,0.006964,0.006964 +186,22874.2,1.16232,1.48583,1.22098,0.60792,0.46093,0.49823,0.35397,1.13672,1.24364,1.19239,0.0069475,0.0069475,0.0069475 +187,22995.2,1.16154,1.48541,1.22009,0.60623,0.46092,0.49822,0.35425,1.13641,1.24238,1.19207,0.006931,0.006931,0.006931 +188,23115.9,1.16353,1.48441,1.22354,0.60654,0.46199,0.49874,0.35459,1.1359,1.24152,1.19173,0.0069145,0.0069145,0.0069145 +189,23237,1.15932,1.47695,1.22068,0.60855,0.46186,0.49919,0.35495,1.13583,1.24056,1.19165,0.006898,0.006898,0.006898 +190,23358.5,1.16589,1.48396,1.22127,0.60885,0.46166,0.49929,0.35508,1.13546,1.23954,1.19135,0.0068815,0.0068815,0.0068815 +191,23481.4,1.17104,1.48002,1.22356,0.60769,0.46358,0.50004,0.35532,1.13539,1.2386,1.19125,0.006865,0.006865,0.006865 +192,23603.6,1.17294,1.4925,1.22464,0.60902,0.46228,0.49999,0.35535,1.13527,1.23764,1.19114,0.0068485,0.0068485,0.0068485 +193,23725.7,1.16616,1.48037,1.22419,0.60981,0.46283,0.50028,0.35549,1.13519,1.23685,1.19098,0.006832,0.006832,0.006832 +194,23846.6,1.16773,1.48498,1.21928,0.61087,0.46395,0.50076,0.3557,1.13498,1.23623,1.19075,0.0068155,0.0068155,0.0068155 +195,23968.3,1.16187,1.476,1.21577,0.61107,0.46383,0.50076,0.35594,1.1347,1.23538,1.19046,0.006799,0.006799,0.006799 +196,24090.6,1.16186,1.48511,1.22023,0.61093,0.46312,0.50101,0.35617,1.1346,1.23462,1.19022,0.0067825,0.0067825,0.0067825 +197,24213.1,1.16952,1.48677,1.22129,0.61006,0.46477,0.50115,0.35647,1.13422,1.23397,1.18993,0.006766,0.006766,0.006766 +198,24334.5,1.1584,1.47642,1.21795,0.61005,0.46581,0.50139,0.35659,1.13425,1.2331,1.18984,0.0067495,0.0067495,0.0067495 +199,24457.3,1.1677,1.47968,1.21956,0.60625,0.46698,0.50152,0.35695,1.13402,1.23222,1.18967,0.006733,0.006733,0.006733 +200,24578.6,1.16075,1.48808,1.21867,0.60726,0.46687,0.50191,0.35709,1.13328,1.23146,1.18919,0.0067165,0.0067165,0.0067165 +201,24700.7,1.1631,1.4836,1.2208,0.60671,0.46686,0.50196,0.35715,1.13269,1.23049,1.18873,0.0067,0.0067,0.0067 +202,24821.4,1.16021,1.47028,1.21839,0.60689,0.46773,0.50225,0.35733,1.13235,1.22956,1.18844,0.0066835,0.0066835,0.0066835 +203,24942.2,1.16818,1.48032,1.22379,0.6097,0.46745,0.50267,0.35752,1.13185,1.22858,1.18818,0.006667,0.006667,0.006667 
+204,25064.8,1.16258,1.47052,1.21936,0.61367,0.46586,0.50271,0.35763,1.13148,1.22815,1.18791,0.0066505,0.0066505,0.0066505 +205,25186.6,1.15348,1.47451,1.21541,0.61376,0.4666,0.50299,0.35772,1.13137,1.22726,1.18781,0.006634,0.006634,0.006634 +206,25308.4,1.16116,1.47818,1.2187,0.61468,0.46686,0.5032,0.3578,1.13122,1.22635,1.18768,0.0066175,0.0066175,0.0066175 +207,25431.3,1.156,1.46688,1.21665,0.61512,0.46625,0.50373,0.35806,1.13112,1.22545,1.18753,0.006601,0.006601,0.006601 +208,25553.3,1.16161,1.4693,1.21883,0.61328,0.46757,0.50414,0.35863,1.13093,1.22471,1.18742,0.0065845,0.0065845,0.0065845 +209,25677,1.16028,1.46781,1.22123,0.61443,0.4697,0.50451,0.3589,1.13093,1.22412,1.18737,0.006568,0.006568,0.006568 +210,25799.9,1.1617,1.47092,1.21825,0.61448,0.47024,0.50457,0.359,1.13053,1.22326,1.18707,0.0065515,0.0065515,0.0065515 +211,25920,1.16172,1.46807,1.21834,0.61223,0.47215,0.50497,0.35916,1.13041,1.22261,1.18695,0.006535,0.006535,0.006535 +212,26041.6,1.1583,1.45894,1.21416,0.61144,0.47266,0.50508,0.35952,1.13024,1.22201,1.18688,0.0065185,0.0065185,0.0065185 +213,26164.1,1.15832,1.465,1.21645,0.61268,0.47162,0.50534,0.35969,1.13011,1.22111,1.18672,0.006502,0.006502,0.006502 +214,26284.5,1.16271,1.47272,1.21903,0.60954,0.47373,0.50527,0.35998,1.12984,1.22027,1.18652,0.0064855,0.0064855,0.0064855 +215,26406.4,1.1538,1.46152,1.21371,0.60974,0.47412,0.50555,0.36004,1.12948,1.2193,1.18627,0.006469,0.006469,0.006469 +216,26529.3,1.15707,1.46324,1.21743,0.6124,0.47409,0.50547,0.36003,1.12956,1.21886,1.18619,0.0064525,0.0064525,0.0064525 +217,26651.3,1.15743,1.45941,1.2136,0.61178,0.47429,0.50562,0.36003,1.12923,1.21815,1.18596,0.006436,0.006436,0.006436 +218,26774.4,1.15813,1.46462,1.21433,0.61153,0.47473,0.50606,0.36016,1.12867,1.2173,1.18558,0.0064195,0.0064195,0.0064195 +219,26898.2,1.1565,1.46567,1.21623,0.60962,0.47546,0.50634,0.3604,1.12836,1.21703,1.18532,0.006403,0.006403,0.006403 +220,27020.5,1.15573,1.45987,1.2122,0.60999,0.47576,0.50645,0.36045,1.12784,1.21616,1.18493,0.0063865,0.0063865,0.0063865 +221,27142.9,1.15577,1.46553,1.21496,0.61265,0.47451,0.50665,0.36046,1.12773,1.21548,1.18469,0.00637,0.00637,0.00637 +222,27265.7,1.15498,1.45663,1.21415,0.6134,0.47431,0.50708,0.36088,1.12752,1.21494,1.18447,0.0063535,0.0063535,0.0063535 +223,27386.7,1.15991,1.45966,1.21521,0.61437,0.47463,0.50727,0.36107,1.12722,1.2143,1.18418,0.006337,0.006337,0.006337 +224,27508.3,1.15696,1.46322,1.21234,0.6143,0.47427,0.50768,0.36114,1.1272,1.21376,1.18395,0.0063205,0.0063205,0.0063205 +225,27631.3,1.15691,1.46507,1.21566,0.61238,0.47517,0.50785,0.36141,1.12705,1.21314,1.18377,0.006304,0.006304,0.006304 +226,27752.9,1.15458,1.46888,1.21348,0.61127,0.47592,0.50809,0.36169,1.12695,1.21256,1.18358,0.0062875,0.0062875,0.0062875 +227,27875.4,1.14883,1.45668,1.21275,0.61202,0.47589,0.50803,0.36167,1.12674,1.21221,1.1833,0.006271,0.006271,0.006271 +228,27998.2,1.16237,1.45634,1.21692,0.61455,0.47512,0.50834,0.36173,1.12663,1.21155,1.18313,0.0062545,0.0062545,0.0062545 +229,28120.5,1.15235,1.44679,1.21153,0.61113,0.47649,0.50865,0.36189,1.12646,1.21099,1.18293,0.006238,0.006238,0.006238 +230,28242.8,1.15549,1.45083,1.21036,0.6148,0.47539,0.50879,0.36188,1.12614,1.21038,1.18264,0.0062215,0.0062215,0.0062215 +231,28364.6,1.15294,1.45228,1.21412,0.61729,0.47463,0.50898,0.36203,1.12646,1.20979,1.18263,0.006205,0.006205,0.006205 +232,28484.9,1.15678,1.45168,1.21299,0.62126,0.47444,0.50917,0.36217,1.12637,1.20903,1.18255,0.0061885,0.0061885,0.0061885 
+233,28606.1,1.1528,1.44867,1.21122,0.61989,0.47532,0.50956,0.36227,1.12617,1.20848,1.18253,0.006172,0.006172,0.006172 +234,28727.9,1.15224,1.44526,1.21099,0.6191,0.47528,0.50952,0.36234,1.12595,1.20781,1.1822,0.0061555,0.0061555,0.0061555 +235,28850.6,1.14946,1.45854,1.20941,0.61854,0.47747,0.51014,0.36279,1.12577,1.20737,1.182,0.006139,0.006139,0.006139 +236,28973.5,1.16183,1.45517,1.21328,0.62091,0.47653,0.51034,0.3629,1.12543,1.20661,1.18158,0.0061225,0.0061225,0.0061225 +237,29096,1.15413,1.45213,1.21261,0.61682,0.47685,0.51033,0.36282,1.12537,1.20623,1.1815,0.006106,0.006106,0.006106 +238,29216.8,1.15607,1.46,1.21426,0.62197,0.47523,0.51069,0.36299,1.12495,1.20584,1.18123,0.0060895,0.0060895,0.0060895 +239,29339,1.15367,1.45309,1.2146,0.61841,0.47553,0.51074,0.36304,1.12464,1.20506,1.18095,0.006073,0.006073,0.006073 +240,29460.4,1.15787,1.44761,1.2107,0.62691,0.47274,0.51099,0.36321,1.12428,1.20462,1.18061,0.0060565,0.0060565,0.0060565 +241,29581.5,1.14903,1.45092,1.2122,0.62374,0.47336,0.51108,0.36345,1.12406,1.20399,1.18051,0.00604,0.00604,0.00604 +242,29704.1,1.15645,1.44669,1.21269,0.62435,0.47352,0.51135,0.36345,1.12384,1.20369,1.1803,0.0060235,0.0060235,0.0060235 +243,29826.5,1.15887,1.45198,1.21282,0.62397,0.47368,0.51174,0.36353,1.1238,1.20342,1.1802,0.006007,0.006007,0.006007 +244,29948.1,1.14847,1.44544,1.20981,0.6252,0.47458,0.51168,0.3636,1.12345,1.20294,1.17987,0.0059905,0.0059905,0.0059905 +245,30071,1.15926,1.45295,1.21251,0.62363,0.47508,0.51204,0.36375,1.12338,1.20247,1.17968,0.005974,0.005974,0.005974 +246,30192.8,1.15503,1.44675,1.21418,0.62364,0.47549,0.51204,0.36372,1.12328,1.20191,1.17962,0.0059575,0.0059575,0.0059575 +247,30315.3,1.15195,1.44593,1.20971,0.62368,0.47464,0.51236,0.36373,1.12325,1.20136,1.17959,0.005941,0.005941,0.005941 +248,30437.9,1.1511,1.44973,1.21229,0.62462,0.47372,0.51246,0.36392,1.1231,1.20092,1.17945,0.0059245,0.0059245,0.0059245 +249,30560.9,1.14567,1.4502,1.20933,0.62599,0.474,0.51279,0.36415,1.12273,1.20083,1.17912,0.005908,0.005908,0.005908 +250,30683,1.15206,1.44252,1.21118,0.62543,0.47504,0.51316,0.36423,1.12258,1.20074,1.17894,0.0058915,0.0058915,0.0058915 +251,30805.8,1.15248,1.45008,1.2128,0.62526,0.47482,0.51289,0.36435,1.12237,1.20019,1.17877,0.005875,0.005875,0.005875 +252,30926.9,1.14886,1.44554,1.20854,0.62576,0.47511,0.51333,0.36463,1.12188,1.19961,1.17846,0.0058585,0.0058585,0.0058585 +253,31049,1.14653,1.4465,1.21093,0.62648,0.47407,0.51341,0.36472,1.12212,1.19918,1.17864,0.005842,0.005842,0.005842 +254,31172.3,1.14668,1.43547,1.20928,0.62411,0.47539,0.5135,0.36467,1.12191,1.19843,1.17846,0.0058255,0.0058255,0.0058255 +255,31293.9,1.1535,1.44363,1.2107,0.6288,0.47262,0.5135,0.3649,1.12176,1.19805,1.17825,0.005809,0.005809,0.005809 +256,31415.3,1.14864,1.44737,1.21154,0.6253,0.47456,0.51365,0.36484,1.12148,1.19762,1.17806,0.0057925,0.0057925,0.0057925 +257,31537.8,1.1552,1.43833,1.20974,0.62722,0.47396,0.51383,0.36501,1.12128,1.19718,1.17783,0.005776,0.005776,0.005776 +258,31658.9,1.15876,1.44563,1.21083,0.62659,0.47491,0.51405,0.36529,1.12118,1.1969,1.1777,0.0057595,0.0057595,0.0057595 +259,31781.4,1.15249,1.445,1.2076,0.62795,0.47451,0.51415,0.36522,1.12144,1.1966,1.17786,0.005743,0.005743,0.005743 +260,31902.5,1.13892,1.43158,1.20237,0.62756,0.47571,0.51448,0.36542,1.12138,1.19612,1.17776,0.0057265,0.0057265,0.0057265 +261,32025.1,1.15259,1.43886,1.20735,0.62509,0.47685,0.51484,0.36565,1.12113,1.19531,1.17746,0.00571,0.00571,0.00571 
+262,32148,1.14994,1.44462,1.20997,0.62596,0.4766,0.51516,0.36571,1.12117,1.19515,1.17743,0.0056935,0.0056935,0.0056935 +263,32269.4,1.15312,1.445,1.20846,0.62619,0.47613,0.51506,0.36572,1.12117,1.19467,1.1773,0.005677,0.005677,0.005677 +264,32392,1.14397,1.43747,1.2073,0.62541,0.47634,0.5152,0.36609,1.12087,1.1939,1.17715,0.0056605,0.0056605,0.0056605 +265,32515.2,1.14867,1.43804,1.20832,0.62576,0.47702,0.51549,0.36628,1.12044,1.1935,1.17687,0.005644,0.005644,0.005644 +266,32637.5,1.14921,1.43529,1.20791,0.62289,0.47832,0.51566,0.36646,1.12034,1.19323,1.17669,0.0056275,0.0056275,0.0056275 +267,32758.4,1.14976,1.43382,1.20777,0.62404,0.47772,0.51558,0.36652,1.12026,1.19286,1.17662,0.005611,0.005611,0.005611 +268,32881.4,1.14447,1.4293,1.20764,0.62577,0.47685,0.5158,0.3669,1.12031,1.19259,1.17647,0.0055945,0.0055945,0.0055945 +269,33001.6,1.14492,1.42795,1.20575,0.62692,0.47779,0.51614,0.36725,1.11987,1.19209,1.17611,0.005578,0.005578,0.005578 +270,33123.8,1.13725,1.4369,1.204,0.62743,0.47637,0.51644,0.36753,1.11968,1.19172,1.17592,0.0055615,0.0055615,0.0055615 +271,33245.8,1.14209,1.42608,1.20508,0.62582,0.47755,0.51682,0.36768,1.11925,1.1908,1.17543,0.005545,0.005545,0.005545 +272,33365.5,1.14202,1.43078,1.20539,0.63068,0.47666,0.51706,0.36799,1.11909,1.19023,1.17526,0.0055285,0.0055285,0.0055285 +273,33489.1,1.14617,1.4332,1.20547,0.62635,0.4788,0.5174,0.36836,1.11901,1.19002,1.17518,0.005512,0.005512,0.005512 +274,33612.6,1.1508,1.43186,1.20713,0.62539,0.47996,0.51743,0.36833,1.11899,1.18964,1.17504,0.0054955,0.0054955,0.0054955 +275,33734.3,1.14165,1.42282,1.20296,0.62427,0.48108,0.51773,0.36833,1.11874,1.18968,1.17486,0.005479,0.005479,0.005479 +276,33855.9,1.15227,1.43719,1.20808,0.62716,0.47954,0.51792,0.36858,1.11848,1.18916,1.17479,0.0054625,0.0054625,0.0054625 +277,33979,1.1506,1.43395,1.20718,0.62533,0.48099,0.51832,0.36869,1.11861,1.18869,1.17485,0.005446,0.005446,0.005446 +278,34101.2,1.13528,1.4248,1.20126,0.62501,0.48125,0.51861,0.36887,1.11838,1.18805,1.17471,0.0054295,0.0054295,0.0054295 +279,34224.6,1.14285,1.42097,1.2038,0.62055,0.48303,0.51872,0.3688,1.11843,1.18799,1.17463,0.005413,0.005413,0.005413 +280,34344.3,1.14136,1.42075,1.20491,0.61939,0.48369,0.51889,0.36887,1.1184,1.18746,1.17473,0.0053965,0.0053965,0.0053965 +281,34467.2,1.15119,1.43092,1.20673,0.61912,0.48414,0.51922,0.369,1.11849,1.18689,1.17469,0.00538,0.00538,0.00538 +282,34589.7,1.13937,1.41884,1.20365,0.62021,0.48434,0.51957,0.36933,1.1183,1.18651,1.17459,0.0053635,0.0053635,0.0053635 +283,34713,1.14871,1.42153,1.20523,0.62243,0.48367,0.51998,0.36965,1.1183,1.18606,1.17448,0.005347,0.005347,0.005347 +284,34835.2,1.14401,1.42774,1.20389,0.62176,0.48351,0.52003,0.36973,1.11835,1.18548,1.17429,0.0053305,0.0053305,0.0053305 +285,34957.3,1.15205,1.42987,1.20615,0.62171,0.48261,0.51994,0.36978,1.11855,1.18529,1.17429,0.005314,0.005314,0.005314 +286,35078.1,1.13822,1.40917,1.19815,0.62155,0.48299,0.52018,0.36981,1.11841,1.18456,1.174,0.0052975,0.0052975,0.0052975 +287,35200.8,1.13423,1.40652,1.19916,0.6219,0.48258,0.52048,0.36997,1.11796,1.18429,1.17364,0.005281,0.005281,0.005281 +288,35322.9,1.13678,1.42941,1.20383,0.62258,0.48262,0.52054,0.37005,1.11776,1.18418,1.17351,0.0052645,0.0052645,0.0052645 +289,35446.1,1.14222,1.42592,1.20436,0.62685,0.48057,0.52076,0.37007,1.11776,1.18381,1.17346,0.005248,0.005248,0.005248 +290,35568.3,1.14047,1.41462,1.20024,0.62633,0.48171,0.52093,0.37014,1.11754,1.18331,1.17317,0.0052315,0.0052315,0.0052315 
+291,35691.2,1.13849,1.41023,1.20078,0.62988,0.47969,0.52116,0.37049,1.1177,1.18302,1.17316,0.005215,0.005215,0.005215 +292,35812.6,1.13917,1.42771,1.20456,0.62488,0.48242,0.52141,0.37072,1.11739,1.18268,1.1728,0.0051985,0.0051985,0.0051985 +293,35933.9,1.13472,1.42209,1.20208,0.62929,0.47961,0.52159,0.37086,1.11725,1.18218,1.17271,0.005182,0.005182,0.005182 +294,36054.9,1.14068,1.4162,1.20347,0.62596,0.48308,0.52186,0.37082,1.11731,1.18214,1.17255,0.0051655,0.0051655,0.0051655 +295,36176.5,1.14238,1.4131,1.20257,0.62464,0.48333,0.5219,0.37111,1.11719,1.1817,1.17237,0.005149,0.005149,0.005149 +296,36298.3,1.14358,1.41835,1.20333,0.6272,0.48029,0.52179,0.37084,1.11704,1.18145,1.17234,0.0051325,0.0051325,0.0051325 +297,36419.4,1.14227,1.42181,1.20257,0.62674,0.48123,0.52187,0.37091,1.11709,1.18119,1.17229,0.005116,0.005116,0.005116 +298,36540.2,1.1363,1.4127,1.20005,0.62801,0.48166,0.52226,0.37104,1.11704,1.18079,1.17218,0.0050995,0.0050995,0.0050995 +299,36661.6,1.14618,1.41524,1.20622,0.62525,0.48186,0.52204,0.37126,1.11702,1.18062,1.17207,0.005083,0.005083,0.005083 +300,36784.4,1.13965,1.41777,1.20048,0.62549,0.48214,0.52217,0.3712,1.11705,1.18022,1.1721,0.0050665,0.0050665,0.0050665 +301,36907,1.14113,1.41401,1.20102,0.62546,0.48228,0.52236,0.37153,1.11678,1.18003,1.17189,0.00505,0.00505,0.00505 +302,37028.3,1.13843,1.40746,1.20256,0.62776,0.48202,0.52249,0.37164,1.11687,1.17952,1.17177,0.0050335,0.0050335,0.0050335 +303,37150,1.13547,1.41273,1.20169,0.62408,0.48438,0.52272,0.37173,1.11681,1.17906,1.17162,0.005017,0.005017,0.005017 +304,37271.4,1.13766,1.41346,1.20145,0.62599,0.48389,0.52288,0.37174,1.1166,1.17859,1.17143,0.0050005,0.0050005,0.0050005 +305,37393.2,1.14284,1.41598,1.20182,0.62507,0.48414,0.52282,0.3719,1.11636,1.17809,1.17126,0.004984,0.004984,0.004984 +306,37515.2,1.13577,1.41307,1.19989,0.6259,0.4849,0.52323,0.37219,1.11623,1.17772,1.17131,0.0049675,0.0049675,0.0049675 +307,37636.6,1.13482,1.41074,1.19862,0.62444,0.48503,0.5234,0.37239,1.11629,1.17733,1.17122,0.004951,0.004951,0.004951 +308,37759.9,1.13951,1.40839,1.19827,0.62814,0.48441,0.52345,0.37233,1.11612,1.17693,1.17099,0.0049345,0.0049345,0.0049345 +309,37881.8,1.14039,1.41317,1.19865,0.62816,0.48413,0.52358,0.37259,1.11572,1.17644,1.17071,0.004918,0.004918,0.004918 +310,38003.9,1.13493,1.41438,1.19891,0.62762,0.48547,0.5241,0.37259,1.11591,1.17593,1.17062,0.0049015,0.0049015,0.0049015 +311,38127.2,1.13452,1.4025,1.19865,0.62561,0.48658,0.5242,0.37281,1.11584,1.17557,1.17046,0.004885,0.004885,0.004885 +312,38250.2,1.13704,1.40101,1.19729,0.62643,0.48631,0.52427,0.37294,1.11551,1.1756,1.17012,0.0048685,0.0048685,0.0048685 +313,38373.9,1.13657,1.40058,1.19785,0.62741,0.48628,0.52438,0.37321,1.11533,1.17519,1.16999,0.004852,0.004852,0.004852 +314,38496,1.13884,1.40785,1.19942,0.62662,0.48703,0.52437,0.37341,1.11517,1.17477,1.16989,0.0048355,0.0048355,0.0048355 +315,38617.8,1.13536,1.40436,1.19685,0.62803,0.48545,0.52438,0.37348,1.11482,1.17418,1.16964,0.004819,0.004819,0.004819 +316,38739.4,1.1334,1.40157,1.19522,0.62968,0.48461,0.52468,0.37376,1.11436,1.17378,1.16917,0.0048025,0.0048025,0.0048025 +317,38860.7,1.13981,1.40101,1.19721,0.63,0.48471,0.52506,0.37418,1.11412,1.17321,1.16885,0.004786,0.004786,0.004786 +318,38983.4,1.12988,1.39365,1.19601,0.63097,0.48402,0.52527,0.37414,1.11406,1.1726,1.16868,0.0047695,0.0047695,0.0047695 +319,39106.7,1.13633,1.39697,1.19742,0.63151,0.48421,0.5254,0.37413,1.11382,1.17241,1.16841,0.004753,0.004753,0.004753 
+320,39227.9,1.13255,1.39938,1.19653,0.63272,0.48383,0.52555,0.37409,1.11372,1.17208,1.16826,0.0047365,0.0047365,0.0047365 +321,39349.3,1.1388,1.40281,1.19982,0.63496,0.48238,0.52583,0.37453,1.11355,1.17181,1.16829,0.00472,0.00472,0.00472 +322,39471.3,1.13306,1.39656,1.19822,0.63422,0.48324,0.52588,0.37464,1.11353,1.17157,1.16808,0.0047035,0.0047035,0.0047035 +323,39593.1,1.13459,1.39716,1.19605,0.63097,0.48381,0.52597,0.37441,1.11358,1.17128,1.16807,0.004687,0.004687,0.004687 +324,39714.6,1.13332,1.39609,1.19443,0.63195,0.48319,0.52607,0.3747,1.11327,1.17117,1.16788,0.0046705,0.0046705,0.0046705 +325,39836.1,1.13191,1.39813,1.19582,0.63323,0.48227,0.52619,0.37469,1.1131,1.17076,1.16778,0.004654,0.004654,0.004654 +326,39957.8,1.13323,1.40812,1.19998,0.6324,0.48355,0.52643,0.37483,1.11288,1.17043,1.16765,0.0046375,0.0046375,0.0046375 +327,40079.1,1.12451,1.40888,1.19653,0.6345,0.48289,0.52637,0.37492,1.11261,1.17009,1.16746,0.004621,0.004621,0.004621 +328,40200.5,1.12981,1.3984,1.19668,0.63358,0.48266,0.52643,0.37511,1.11252,1.16995,1.16741,0.0046045,0.0046045,0.0046045 +329,40322.9,1.12629,1.38532,1.1923,0.63862,0.48028,0.52633,0.37505,1.11256,1.1696,1.16731,0.004588,0.004588,0.004588 +330,40442.2,1.13411,1.38836,1.19423,0.63615,0.48151,0.52653,0.37529,1.11242,1.16944,1.16711,0.0045715,0.0045715,0.0045715 +331,40562.7,1.13074,1.38847,1.19511,0.63795,0.4808,0.52667,0.37537,1.11218,1.16909,1.167,0.004555,0.004555,0.004555 +332,40685.1,1.12984,1.39602,1.19467,0.63799,0.48016,0.52672,0.37556,1.11211,1.16885,1.1669,0.0045385,0.0045385,0.0045385 +333,40808.3,1.13347,1.39759,1.19658,0.6391,0.4799,0.5268,0.3757,1.11171,1.16851,1.16661,0.004522,0.004522,0.004522 +334,40930.3,1.1236,1.39425,1.19287,0.63904,0.48093,0.5271,0.37548,1.11143,1.16835,1.16636,0.0045055,0.0045055,0.0045055 +335,41058.5,1.12788,1.3869,1.19291,0.64045,0.48044,0.52707,0.37576,1.11136,1.16801,1.16635,0.004489,0.004489,0.004489 +336,41186.4,1.13174,1.38914,1.19368,0.63923,0.48101,0.5272,0.37601,1.11114,1.1677,1.16621,0.0044725,0.0044725,0.0044725 +337,41314.2,1.12872,1.38663,1.19431,0.64236,0.48053,0.52729,0.37598,1.11111,1.16753,1.16618,0.004456,0.004456,0.004456 +338,41442.3,1.12865,1.39012,1.19365,0.64463,0.47973,0.52718,0.37611,1.11099,1.16746,1.16601,0.0044395,0.0044395,0.0044395 +339,41570.1,1.13205,1.38495,1.1952,0.64603,0.48001,0.52742,0.37613,1.11084,1.16709,1.16575,0.004423,0.004423,0.004423 +340,41699,1.13001,1.385,1.19518,0.64644,0.48018,0.52764,0.37625,1.11082,1.16677,1.16572,0.0044065,0.0044065,0.0044065 +341,41825.5,1.13339,1.39333,1.19654,0.64595,0.4803,0.52758,0.37631,1.11076,1.16659,1.16571,0.00439,0.00439,0.00439 +342,41954.1,1.13166,1.38991,1.19423,0.64533,0.48014,0.52773,0.3766,1.11055,1.16634,1.16562,0.0043735,0.0043735,0.0043735 +343,42081.1,1.12984,1.38069,1.19304,0.64367,0.48105,0.52756,0.37664,1.11051,1.16593,1.16546,0.004357,0.004357,0.004357 +344,42207.7,1.13685,1.39235,1.19641,0.64297,0.48192,0.52777,0.37673,1.11066,1.16565,1.16565,0.0043405,0.0043405,0.0043405 +345,42333.9,1.13197,1.38535,1.19252,0.645,0.48168,0.5277,0.3767,1.11055,1.16512,1.16557,0.004324,0.004324,0.004324 +346,42462.5,1.12267,1.38578,1.19173,0.64105,0.48337,0.528,0.377,1.11052,1.16466,1.16552,0.0043075,0.0043075,0.0043075 +347,42588.4,1.1315,1.38939,1.19351,0.63856,0.48498,0.52839,0.37718,1.11034,1.16422,1.16557,0.004291,0.004291,0.004291 +348,42715.2,1.1298,1.3737,1.19316,0.63618,0.4867,0.52888,0.37733,1.1102,1.16383,1.16551,0.0042745,0.0042745,0.0042745 
+349,42841.6,1.12887,1.38555,1.19237,0.63551,0.48647,0.52867,0.37764,1.10988,1.16332,1.1652,0.004258,0.004258,0.004258 +350,42969.8,1.12698,1.38786,1.19202,0.63616,0.48683,0.52869,0.3775,1.1096,1.16284,1.16513,0.0042415,0.0042415,0.0042415 +351,43096.4,1.13186,1.38128,1.19271,0.63562,0.48791,0.52881,0.37765,1.10951,1.16226,1.16494,0.004225,0.004225,0.004225 +352,43222.2,1.13052,1.38271,1.1958,0.63649,0.48734,0.52882,0.37779,1.10957,1.16191,1.16491,0.0042085,0.0042085,0.0042085 +353,43349,1.12531,1.37478,1.19149,0.63509,0.48732,0.52875,0.37768,1.10938,1.16144,1.16479,0.004192,0.004192,0.004192 +354,43476.6,1.12839,1.36722,1.19315,0.63798,0.48676,0.52904,0.37802,1.10962,1.16099,1.16487,0.0041755,0.0041755,0.0041755 +355,43604,1.1225,1.37733,1.18909,0.63866,0.48745,0.52934,0.37799,1.10906,1.16073,1.16444,0.004159,0.004159,0.004159 +356,43730.2,1.12413,1.37934,1.18992,0.64112,0.48589,0.52921,0.37809,1.10907,1.16027,1.16457,0.0041425,0.0041425,0.0041425 +357,43856.5,1.12642,1.37163,1.19082,0.64293,0.4854,0.52931,0.37825,1.10903,1.16018,1.16458,0.004126,0.004126,0.004126 +358,43980.1,1.125,1.37305,1.18839,0.64076,0.48614,0.52895,0.37781,1.10892,1.16004,1.16449,0.0041095,0.0041095,0.0041095 +359,44104.2,1.12529,1.37666,1.19122,0.63971,0.48634,0.52885,0.37786,1.10867,1.15966,1.16427,0.004093,0.004093,0.004093 +360,44230.6,1.12659,1.37287,1.19066,0.64003,0.48661,0.52914,0.37794,1.10845,1.15971,1.16398,0.0040765,0.0040765,0.0040765 +361,44355.7,1.12364,1.37019,1.18943,0.63981,0.48815,0.52988,0.3782,1.10844,1.15938,1.16393,0.00406,0.00406,0.00406 +362,44481.7,1.12499,1.37333,1.19059,0.63777,0.48913,0.5299,0.37842,1.10818,1.15905,1.16375,0.0040435,0.0040435,0.0040435 +363,44606.5,1.12051,1.36284,1.18526,0.63767,0.4891,0.52985,0.37846,1.10811,1.15851,1.16358,0.004027,0.004027,0.004027 +364,44732.7,1.12263,1.37152,1.19166,0.63903,0.4889,0.53012,0.3787,1.10791,1.15776,1.16341,0.0040105,0.0040105,0.0040105 +365,44858.6,1.11826,1.36873,1.18773,0.63988,0.48859,0.53025,0.37889,1.10776,1.15783,1.16327,0.003994,0.003994,0.003994 +366,44984.3,1.1217,1.36137,1.18842,0.64151,0.48861,0.53054,0.37908,1.10761,1.15737,1.16329,0.0039775,0.0039775,0.0039775 +367,45110.7,1.11637,1.36184,1.18618,0.63986,0.48914,0.53036,0.37914,1.10731,1.15702,1.16314,0.003961,0.003961,0.003961 +368,45236.8,1.12542,1.37212,1.18984,0.64171,0.48851,0.53048,0.37909,1.10741,1.15692,1.16304,0.0039445,0.0039445,0.0039445 +369,45361.5,1.11586,1.35413,1.18575,0.64314,0.48714,0.5306,0.37931,1.10761,1.15665,1.16296,0.003928,0.003928,0.003928 +370,45487.4,1.12192,1.36459,1.18575,0.6418,0.48856,0.53073,0.37961,1.10738,1.15634,1.16264,0.0039115,0.0039115,0.0039115 +371,45614.7,1.11962,1.36154,1.18821,0.6409,0.48768,0.5311,0.37978,1.1071,1.15618,1.16245,0.003895,0.003895,0.003895 +372,45739.1,1.12318,1.36313,1.18765,0.64421,0.48577,0.53132,0.37975,1.1069,1.15592,1.16232,0.0038785,0.0038785,0.0038785 +373,45865.8,1.12097,1.36959,1.19014,0.64466,0.48665,0.53163,0.37977,1.10672,1.15557,1.1621,0.003862,0.003862,0.003862 +374,45992.4,1.12509,1.36481,1.18921,0.64339,0.48709,0.53187,0.37991,1.10679,1.1554,1.16216,0.0038455,0.0038455,0.0038455 +375,46119.9,1.12229,1.36632,1.18734,0.64439,0.48619,0.5322,0.38009,1.10685,1.15495,1.16207,0.003829,0.003829,0.003829 +376,46247.6,1.12301,1.36683,1.18949,0.6437,0.48737,0.53255,0.38016,1.10675,1.15482,1.16203,0.0038125,0.0038125,0.0038125 +377,46373,1.11311,1.36885,1.18667,0.64572,0.48679,0.53268,0.38049,1.10644,1.15435,1.16182,0.003796,0.003796,0.003796 
+378,46499.2,1.11875,1.36476,1.18685,0.64523,0.48699,0.53287,0.38058,1.10639,1.15398,1.16171,0.0037795,0.0037795,0.0037795 +379,46624.4,1.11318,1.36784,1.18429,0.64615,0.48774,0.53303,0.38083,1.10659,1.15344,1.16168,0.003763,0.003763,0.003763 +380,46750.1,1.12772,1.37473,1.18987,0.64242,0.48996,0.53319,0.38089,1.10648,1.15293,1.16139,0.0037465,0.0037465,0.0037465 +381,46875.2,1.11257,1.35529,1.18372,0.64262,0.49007,0.5336,0.38118,1.10635,1.15256,1.16126,0.00373,0.00373,0.00373 +382,47000.8,1.12111,1.35787,1.1869,0.63947,0.49089,0.53366,0.38119,1.1063,1.15239,1.16127,0.0037135,0.0037135,0.0037135 +383,47126.1,1.12381,1.36047,1.19006,0.63897,0.49084,0.53384,0.38104,1.10618,1.15213,1.16105,0.003697,0.003697,0.003697 +384,47252.7,1.12278,1.3575,1.1857,0.64085,0.49014,0.53391,0.38131,1.10617,1.15156,1.16102,0.0036805,0.0036805,0.0036805 +385,47378.8,1.12442,1.3627,1.18782,0.63968,0.49178,0.53392,0.38131,1.1061,1.15121,1.16081,0.003664,0.003664,0.003664 +386,47505.3,1.11577,1.35597,1.1854,0.64016,0.49095,0.53405,0.38138,1.10571,1.15082,1.16052,0.0036475,0.0036475,0.0036475 +387,47630.4,1.11707,1.35683,1.18557,0.637,0.4921,0.53417,0.38151,1.10548,1.15051,1.16036,0.003631,0.003631,0.003631 +388,47756,1.11571,1.35301,1.18468,0.6386,0.4913,0.5344,0.38171,1.10534,1.15023,1.16017,0.0036145,0.0036145,0.0036145 +389,47880.5,1.12114,1.34786,1.18443,0.63891,0.49146,0.5343,0.3817,1.10521,1.15025,1.16005,0.003598,0.003598,0.003598 +390,48006.1,1.1107,1.35184,1.18265,0.63903,0.49125,0.53426,0.38166,1.10517,1.15,1.16003,0.0035815,0.0035815,0.0035815 +391,48132.2,1.11104,1.35255,1.18356,0.63854,0.49175,0.53444,0.38193,1.10504,1.14959,1.15987,0.003565,0.003565,0.003565 +392,48258.5,1.12038,1.35572,1.1851,0.63951,0.49235,0.53439,0.38177,1.10512,1.14926,1.15983,0.0035485,0.0035485,0.0035485 +393,48383.8,1.11649,1.35451,1.18412,0.64171,0.49166,0.53456,0.38203,1.10524,1.14908,1.15985,0.003532,0.003532,0.003532 +394,48509.5,1.11774,1.34355,1.1859,0.63993,0.4919,0.53461,0.382,1.10518,1.14895,1.15971,0.0035155,0.0035155,0.0035155 +395,48635.2,1.11723,1.34836,1.18308,0.64141,0.49284,0.53487,0.38216,1.10497,1.14855,1.1595,0.003499,0.003499,0.003499 +396,48762,1.11741,1.35791,1.18502,0.63942,0.49387,0.53487,0.38237,1.10484,1.14824,1.15939,0.0034825,0.0034825,0.0034825 +397,48885,1.11574,1.35147,1.18186,0.63852,0.49532,0.53499,0.38241,1.10496,1.148,1.15935,0.003466,0.003466,0.003466 +398,49010.7,1.12278,1.34984,1.18404,0.64013,0.49451,0.53526,0.38268,1.10486,1.14774,1.15922,0.0034495,0.0034495,0.0034495 +399,49136.9,1.11283,1.34449,1.18382,0.64114,0.49323,0.53527,0.38244,1.10465,1.14746,1.15907,0.003433,0.003433,0.003433 +400,49259.6,1.11194,1.34474,1.18224,0.64045,0.49345,0.53534,0.38242,1.1048,1.14736,1.1593,0.0034165,0.0034165,0.0034165 +401,49384.6,1.11144,1.33513,1.18032,0.64254,0.49318,0.53566,0.38255,1.10483,1.14701,1.15929,0.0034,0.0034,0.0034 +402,49508.2,1.11051,1.35154,1.18185,0.64233,0.49224,0.53562,0.38254,1.10487,1.14663,1.15924,0.0033835,0.0033835,0.0033835 +403,49631,1.11345,1.34522,1.18287,0.64196,0.49188,0.53572,0.38278,1.10469,1.14633,1.15914,0.003367,0.003367,0.003367 +404,49758.3,1.11074,1.34545,1.18178,0.63943,0.49416,0.53589,0.38295,1.10471,1.14593,1.15912,0.0033505,0.0033505,0.0033505 +405,49883.3,1.11841,1.34788,1.18405,0.64266,0.49272,0.53612,0.38306,1.1046,1.14581,1.15918,0.003334,0.003334,0.003334 +406,50006.8,1.11228,1.33868,1.18274,0.64275,0.49157,0.53622,0.38303,1.1045,1.14546,1.15915,0.0033175,0.0033175,0.0033175 
+407,50132.7,1.10713,1.33925,1.17833,0.64049,0.49394,0.53656,0.38328,1.1045,1.14511,1.15916,0.003301,0.003301,0.003301 +408,50257.1,1.1074,1.33673,1.17746,0.64221,0.49232,0.53622,0.38294,1.10456,1.14491,1.15901,0.0032845,0.0032845,0.0032845 +409,50383.1,1.10632,1.33727,1.17624,0.64027,0.49404,0.53635,0.38325,1.10435,1.14477,1.1587,0.003268,0.003268,0.003268 +410,50510.2,1.10605,1.341,1.17959,0.63885,0.49462,0.53643,0.38338,1.10431,1.14451,1.15871,0.0032515,0.0032515,0.0032515 +411,50636.4,1.11751,1.33712,1.17981,0.63818,0.49393,0.53652,0.38308,1.10427,1.14445,1.15852,0.003235,0.003235,0.003235 +412,50761.3,1.11747,1.34257,1.18024,0.63662,0.4948,0.53668,0.38347,1.10416,1.14428,1.15858,0.0032185,0.0032185,0.0032185 +413,50888.4,1.11574,1.34779,1.18147,0.63674,0.49548,0.53645,0.38355,1.10409,1.14435,1.1584,0.003202,0.003202,0.003202 +414,51013.1,1.11126,1.34181,1.18018,0.63701,0.49575,0.53653,0.38356,1.10392,1.14418,1.15815,0.0031855,0.0031855,0.0031855 +415,51139,1.1117,1.34272,1.18203,0.63863,0.49566,0.53684,0.38392,1.10383,1.14388,1.15804,0.003169,0.003169,0.003169 +416,51265.2,1.10782,1.33474,1.17645,0.63686,0.49635,0.53681,0.38394,1.1039,1.14378,1.15802,0.0031525,0.0031525,0.0031525 +417,51392.4,1.11484,1.3361,1.18181,0.63953,0.49539,0.53707,0.384,1.10395,1.14341,1.15806,0.003136,0.003136,0.003136 +418,51518.4,1.10658,1.34342,1.18077,0.64084,0.49514,0.53731,0.38416,1.10399,1.14306,1.15804,0.0031195,0.0031195,0.0031195 +419,51645.7,1.10635,1.32831,1.17525,0.63746,0.4959,0.53746,0.38425,1.10384,1.1428,1.15781,0.003103,0.003103,0.003103 +420,51772.4,1.11255,1.32894,1.17917,0.63849,0.49518,0.53786,0.38456,1.10373,1.14266,1.15752,0.0030865,0.0030865,0.0030865 +421,51896.7,1.10859,1.33312,1.1809,0.63747,0.49631,0.53794,0.38491,1.10352,1.14237,1.15733,0.00307,0.00307,0.00307 +422,52022.8,1.10707,1.33103,1.1786,0.63961,0.49654,0.53858,0.3849,1.10319,1.14203,1.15709,0.0030535,0.0030535,0.0030535 +423,52147.6,1.10698,1.32849,1.17646,0.64102,0.495,0.53849,0.38504,1.10293,1.14174,1.15677,0.003037,0.003037,0.003037 +424,52272.7,1.09733,1.32435,1.17692,0.64371,0.49351,0.53859,0.38501,1.10278,1.14132,1.15663,0.0030205,0.0030205,0.0030205 +425,52398.2,1.10562,1.32036,1.17623,0.64355,0.49505,0.53856,0.3849,1.10263,1.14079,1.15663,0.003004,0.003004,0.003004 +426,52523.1,1.11334,1.33137,1.17899,0.64376,0.49525,0.53916,0.38511,1.10233,1.14043,1.1564,0.0029875,0.0029875,0.0029875 +427,52648.4,1.10903,1.32926,1.17797,0.64485,0.49483,0.53931,0.38511,1.1021,1.1402,1.15628,0.002971,0.002971,0.002971 +428,52772.8,1.09901,1.31011,1.17436,0.6429,0.49546,0.53943,0.38551,1.10186,1.13996,1.15617,0.0029545,0.0029545,0.0029545 +429,52895,1.10452,1.31854,1.17757,0.64417,0.49477,0.53965,0.3856,1.10175,1.13974,1.15613,0.002938,0.002938,0.002938 +430,53021.4,1.10682,1.32884,1.1776,0.64645,0.49517,0.53995,0.38572,1.10158,1.13943,1.15592,0.0029215,0.0029215,0.0029215 +431,53147.5,1.10329,1.33031,1.17916,0.64724,0.4945,0.53963,0.38548,1.10156,1.13925,1.1559,0.002905,0.002905,0.002905 +432,53272.9,1.10194,1.32759,1.1753,0.64801,0.49464,0.53984,0.38576,1.10125,1.13892,1.15575,0.0028885,0.0028885,0.0028885 +433,53399.3,1.1043,1.33122,1.17666,0.65064,0.49393,0.54011,0.38591,1.10125,1.139,1.15575,0.002872,0.002872,0.002872 +434,53524.8,1.10941,1.33472,1.17867,0.65128,0.49352,0.53986,0.38583,1.10128,1.13901,1.1556,0.0028555,0.0028555,0.0028555 +435,53651,1.10597,1.31779,1.17495,0.65054,0.49488,0.53994,0.3859,1.10158,1.13868,1.15578,0.002839,0.002839,0.002839 
+436,53776.4,1.0991,1.31966,1.17387,0.65221,0.49267,0.53975,0.38598,1.10151,1.13865,1.15555,0.0028225,0.0028225,0.0028225 +437,53901.5,1.11288,1.33036,1.17742,0.65244,0.49315,0.53992,0.38606,1.10117,1.13859,1.15513,0.002806,0.002806,0.002806 +438,54027.2,1.10263,1.31631,1.17384,0.65431,0.49315,0.54007,0.38628,1.10145,1.13851,1.15518,0.0027895,0.0027895,0.0027895 +439,54151,1.09671,1.31959,1.17341,0.6553,0.49283,0.54036,0.38643,1.10122,1.13859,1.15505,0.002773,0.002773,0.002773 +440,54276.2,1.10329,1.31398,1.17123,0.6529,0.4935,0.5402,0.38641,1.10139,1.13839,1.15521,0.0027565,0.0027565,0.0027565 +441,54402.7,1.10235,1.31973,1.17493,0.65164,0.49373,0.54031,0.38656,1.10125,1.13826,1.15511,0.00274,0.00274,0.00274 +442,54529.5,1.10173,1.3083,1.1745,0.64993,0.49556,0.54035,0.38661,1.10109,1.13802,1.15499,0.0027235,0.0027235,0.0027235 +443,54653,1.10373,1.31153,1.17409,0.64832,0.49649,0.54069,0.38689,1.10115,1.13783,1.155,0.002707,0.002707,0.002707 +444,54779.7,1.10251,1.31197,1.17134,0.64955,0.49613,0.54066,0.38679,1.10101,1.13737,1.15488,0.0026905,0.0026905,0.0026905 +445,54903.9,1.10628,1.31211,1.17663,0.6514,0.49509,0.54082,0.38695,1.10077,1.13724,1.15467,0.002674,0.002674,0.002674 +446,55030.5,1.09198,1.30427,1.16934,0.65413,0.49324,0.54105,0.38729,1.10076,1.13667,1.15456,0.0026575,0.0026575,0.0026575 +447,55154.3,1.09973,1.31084,1.17314,0.65425,0.49407,0.54131,0.38742,1.10053,1.13641,1.15439,0.002641,0.002641,0.002641 +448,55281.4,1.09289,1.30618,1.17266,0.65171,0.495,0.54117,0.3875,1.10024,1.13587,1.15418,0.0026245,0.0026245,0.0026245 +449,55405.5,1.09481,1.3114,1.17218,0.65066,0.49599,0.54165,0.38752,1.09997,1.13546,1.15397,0.002608,0.002608,0.002608 +450,55531.8,1.09562,1.3055,1.17278,0.65114,0.49642,0.54171,0.38781,1.09982,1.1349,1.154,0.0025915,0.0025915,0.0025915 +451,55657.7,1.10089,1.30984,1.17225,0.64961,0.49731,0.54179,0.38782,1.09973,1.13452,1.15379,0.002575,0.002575,0.002575 +452,55784.7,1.10151,1.31042,1.17321,0.64645,0.49793,0.54166,0.38799,1.09957,1.1341,1.15371,0.0025585,0.0025585,0.0025585 +453,55911.6,1.1016,1.30944,1.17294,0.64825,0.4973,0.54171,0.38804,1.09955,1.13388,1.15364,0.002542,0.002542,0.002542 +454,56038.1,1.1015,1.30515,1.17205,0.65081,0.4971,0.54195,0.38824,1.0991,1.13345,1.15344,0.0025255,0.0025255,0.0025255 +455,56164.8,1.10156,1.30808,1.17067,0.65297,0.49679,0.54241,0.38848,1.09916,1.133,1.15346,0.002509,0.002509,0.002509 +456,56291.4,1.08926,1.30302,1.1668,0.65334,0.49619,0.54255,0.38846,1.0989,1.13278,1.15308,0.0024925,0.0024925,0.0024925 +457,56416.2,1.1034,1.30844,1.17281,0.65291,0.49695,0.54257,0.38843,1.09891,1.1324,1.15313,0.002476,0.002476,0.002476 +458,56540.5,1.09733,1.30623,1.17025,0.65299,0.49769,0.54269,0.38855,1.09869,1.13225,1.15281,0.0024595,0.0024595,0.0024595 +459,56667.4,1.09649,1.30426,1.16946,0.6542,0.49717,0.54283,0.38867,1.09848,1.13196,1.1525,0.002443,0.002443,0.002443 +460,56792.5,1.09159,1.29053,1.16707,0.65383,0.49676,0.54275,0.38874,1.09855,1.13179,1.15238,0.0024265,0.0024265,0.0024265 +461,56919,1.09839,1.29972,1.16945,0.65001,0.49852,0.54294,0.38866,1.09856,1.1312,1.15235,0.00241,0.00241,0.00241 +462,57044,1.09409,1.30309,1.16897,0.65031,0.4988,0.54332,0.38885,1.09825,1.13094,1.15211,0.0023935,0.0023935,0.0023935 +463,57170.2,1.09622,1.30218,1.17083,0.64903,0.49971,0.54347,0.38915,1.09785,1.13037,1.15188,0.002377,0.002377,0.002377 +464,57296.6,1.09214,1.29586,1.16668,0.64986,0.49907,0.54351,0.38921,1.09789,1.13025,1.15181,0.0023605,0.0023605,0.0023605 
+465,57423.2,1.09237,1.2921,1.16893,0.65107,0.49855,0.5436,0.38933,1.09772,1.12982,1.15148,0.002344,0.002344,0.002344 +466,57549.7,1.09283,1.29465,1.16835,0.65,0.49839,0.54378,0.38945,1.09777,1.12933,1.15155,0.0023275,0.0023275,0.0023275 +467,57673.7,1.09612,1.29256,1.16883,0.64924,0.49922,0.5441,0.38962,1.09774,1.12903,1.15133,0.002311,0.002311,0.002311 +468,57799.4,1.09045,1.29409,1.16578,0.64795,0.49944,0.54396,0.38963,1.09749,1.12885,1.15115,0.0022945,0.0022945,0.0022945 +469,57925.5,1.09097,1.29032,1.16856,0.64509,0.50149,0.54415,0.38965,1.09754,1.1288,1.15105,0.002278,0.002278,0.002278 +470,58051,1.09818,1.30006,1.17277,0.64326,0.50225,0.5444,0.38974,1.0975,1.12877,1.15095,0.0022615,0.0022615,0.0022615 +471,58175.8,1.09027,1.29383,1.16699,0.6448,0.50051,0.54436,0.38959,1.09735,1.12832,1.15086,0.002245,0.002245,0.002245 +472,58302.9,1.09622,1.29695,1.16857,0.64802,0.49911,0.54422,0.38959,1.09725,1.12777,1.15069,0.0022285,0.0022285,0.0022285 +473,58430.1,1.08586,1.28448,1.16311,0.64827,0.50008,0.54444,0.38981,1.09699,1.12774,1.15046,0.002212,0.002212,0.002212 +474,58555.8,1.08566,1.28671,1.1675,0.64733,0.49976,0.54429,0.38984,1.09681,1.12757,1.15037,0.0021955,0.0021955,0.0021955 +475,58678.6,1.08457,1.28172,1.16376,0.64508,0.50085,0.54427,0.39006,1.09701,1.12715,1.15038,0.002179,0.002179,0.002179 +476,58805.1,1.09538,1.28443,1.16727,0.63868,0.50538,0.54462,0.39023,1.09683,1.1268,1.15033,0.0021625,0.0021625,0.0021625 +477,58928.4,1.0872,1.28612,1.16488,0.64004,0.50619,0.54489,0.39058,1.09648,1.1262,1.1501,0.002146,0.002146,0.002146 +478,59052.4,1.0901,1.28417,1.16418,0.64123,0.50541,0.54507,0.39059,1.09648,1.12579,1.15006,0.0021295,0.0021295,0.0021295 +479,59178.1,1.0922,1.28111,1.16629,0.64298,0.5044,0.54543,0.39093,1.09647,1.12563,1.14993,0.002113,0.002113,0.002113 +480,59302.9,1.09309,1.28778,1.16587,0.64771,0.50124,0.54538,0.39075,1.09621,1.12546,1.14985,0.0020965,0.0020965,0.0020965 +481,59429.6,1.09148,1.28633,1.16444,0.65025,0.49998,0.54529,0.39071,1.09614,1.12514,1.14979,0.00208,0.00208,0.00208 +482,59556.5,1.08785,1.27713,1.16485,0.65175,0.49896,0.54586,0.39091,1.09587,1.12494,1.14963,0.0020635,0.0020635,0.0020635 +483,59682.2,1.09754,1.28455,1.16707,0.65344,0.49785,0.54597,0.39133,1.09585,1.12479,1.1496,0.002047,0.002047,0.002047 +484,59808.2,1.08698,1.28122,1.16166,0.65392,0.49777,0.54592,0.39119,1.0959,1.12483,1.14957,0.0020305,0.0020305,0.0020305 +485,59934.5,1.09264,1.28211,1.16518,0.65065,0.50043,0.54579,0.39125,1.09559,1.12473,1.14941,0.002014,0.002014,0.002014 +486,60060.3,1.07736,1.26915,1.16206,0.64813,0.50117,0.54583,0.39124,1.09545,1.12486,1.14937,0.0019975,0.0019975,0.0019975 +487,60185.6,1.08558,1.27282,1.1638,0.64974,0.50048,0.54636,0.3917,1.09534,1.12428,1.14931,0.001981,0.001981,0.001981 +488,60310.6,1.08224,1.27685,1.16285,0.64917,0.50062,0.54695,0.39218,1.09511,1.12394,1.14909,0.0019645,0.0019645,0.0019645 +489,60435.2,1.08698,1.27419,1.1645,0.6526,0.50066,0.54699,0.39243,1.09509,1.12356,1.14902,0.001948,0.001948,0.001948 +490,60560.6,1.08588,1.27896,1.16346,0.65214,0.5015,0.54722,0.39255,1.09499,1.12317,1.14879,0.0019315,0.0019315,0.0019315 +491,60686,1.08686,1.26369,1.16205,0.65197,0.50155,0.54758,0.39284,1.09477,1.12286,1.14861,0.001915,0.001915,0.001915 +492,60811.6,1.08338,1.28028,1.16509,0.65233,0.50177,0.54774,0.39281,1.09469,1.12237,1.14862,0.0018985,0.0018985,0.0018985 +493,60937.9,1.08877,1.26721,1.16473,0.65183,0.50189,0.54782,0.39283,1.09483,1.12212,1.14862,0.001882,0.001882,0.001882 
+494,61064,1.08438,1.26826,1.16209,0.65322,0.50139,0.54793,0.39314,1.0947,1.12206,1.14845,0.0018655,0.0018655,0.0018655 +495,61190.4,1.08101,1.26741,1.15913,0.65404,0.50138,0.5482,0.39288,1.09422,1.12192,1.14798,0.001849,0.001849,0.001849 +496,61314.3,1.0803,1.26575,1.16052,0.65137,0.50254,0.54824,0.39291,1.09415,1.1217,1.14785,0.0018325,0.0018325,0.0018325 +497,61439.4,1.08109,1.26323,1.16322,0.65348,0.5021,0.54855,0.39301,1.09399,1.12153,1.14768,0.001816,0.001816,0.001816 +498,61563.9,1.08003,1.26138,1.1594,0.65475,0.50249,0.54829,0.39295,1.09395,1.12123,1.14769,0.0017995,0.0017995,0.0017995 +499,61691,1.08239,1.26506,1.15983,0.65245,0.50385,0.54882,0.39313,1.09385,1.12083,1.14748,0.001783,0.001783,0.001783 +500,61816,1.08866,1.26225,1.16204,0.65317,0.50386,0.54887,0.39318,1.09397,1.12095,1.14748,0.0017665,0.0017665,0.0017665 +501,61941.7,1.08324,1.26557,1.16059,0.65697,0.50159,0.54891,0.39339,1.09381,1.12064,1.14738,0.00175,0.00175,0.00175 +502,62066.8,1.08194,1.25959,1.16076,0.65476,0.50224,0.54902,0.39348,1.09373,1.12032,1.14727,0.0017335,0.0017335,0.0017335 +503,62189.3,1.07787,1.25588,1.15781,0.65479,0.5019,0.54894,0.39365,1.09329,1.12,1.14681,0.001717,0.001717,0.001717 +504,62315.6,1.08015,1.26183,1.15942,0.65789,0.50024,0.54898,0.39388,1.09343,1.11956,1.1469,0.0017005,0.0017005,0.0017005 +505,62441.6,1.07601,1.26641,1.15962,0.65723,0.50001,0.54899,0.39395,1.09372,1.1192,1.14703,0.001684,0.001684,0.001684 +506,62565.8,1.06971,1.2436,1.15638,0.65487,0.5015,0.54921,0.39401,1.09369,1.11906,1.14678,0.0016675,0.0016675,0.0016675 +507,62691.9,1.07695,1.25647,1.15729,0.65812,0.49993,0.54936,0.39433,1.09359,1.11887,1.14665,0.001651,0.001651,0.001651 +508,62816.4,1.07455,1.25454,1.15456,0.65526,0.50126,0.54936,0.39435,1.09335,1.11849,1.1465,0.0016345,0.0016345,0.0016345 +509,62940.3,1.08413,1.26929,1.16098,0.6553,0.5009,0.54924,0.39425,1.0933,1.11853,1.14665,0.001618,0.001618,0.001618 +510,63066.3,1.08144,1.2481,1.15837,0.65512,0.5008,0.54967,0.39458,1.09315,1.11833,1.14658,0.0016015,0.0016015,0.0016015 +511,63192.6,1.07948,1.25525,1.15755,0.6568,0.5003,0.54996,0.39478,1.09289,1.11803,1.1465,0.001585,0.001585,0.001585 +512,63315.4,1.07824,1.24988,1.15667,0.65477,0.50069,0.54991,0.39469,1.09289,1.11777,1.1464,0.0015685,0.0015685,0.0015685 +513,63439.9,1.08178,1.24782,1.15557,0.65828,0.50002,0.55032,0.39505,1.09243,1.11754,1.14603,0.001552,0.001552,0.001552 +514,63565.3,1.06965,1.24735,1.15574,0.66105,0.49924,0.55028,0.39499,1.09227,1.11718,1.14595,0.0015355,0.0015355,0.0015355 +515,63691.3,1.07406,1.24856,1.15202,0.66088,0.49988,0.55036,0.39507,1.09226,1.11674,1.14588,0.001519,0.001519,0.001519 +516,63818.6,1.0689,1.245,1.15273,0.66457,0.49698,0.55043,0.3951,1.09219,1.11638,1.14569,0.0015025,0.0015025,0.0015025 +517,63942.5,1.07194,1.24086,1.15307,0.66085,0.49992,0.55079,0.39545,1.09218,1.11588,1.14565,0.001486,0.001486,0.001486 +518,64065.2,1.08105,1.24432,1.15641,0.66021,0.5009,0.55088,0.39544,1.09199,1.11554,1.14551,0.0014695,0.0014695,0.0014695 +519,64191.2,1.07524,1.24329,1.15468,0.66326,0.49929,0.55083,0.39558,1.09159,1.11553,1.1452,0.001453,0.001453,0.001453 +520,64316.1,1.06785,1.23638,1.1504,0.66257,0.50012,0.55079,0.39558,1.09141,1.11523,1.14501,0.0014365,0.0014365,0.0014365 +521,64441.4,1.06446,1.23103,1.15079,0.66291,0.49996,0.55104,0.39561,1.09141,1.11506,1.14497,0.00142,0.00142,0.00142 +522,64567.5,1.07253,1.23342,1.15154,0.66454,0.49857,0.55116,0.3956,1.09121,1.11498,1.14481,0.0014035,0.0014035,0.0014035 
+523,64693.5,1.07398,1.24487,1.15462,0.66874,0.49739,0.55114,0.39564,1.0913,1.11469,1.14486,0.001387,0.001387,0.001387 +524,64816.6,1.06708,1.23073,1.14882,0.66604,0.49818,0.55135,0.39553,1.09124,1.11432,1.14479,0.0013705,0.0013705,0.0013705 +525,64941.9,1.06522,1.23139,1.14903,0.67093,0.49721,0.55158,0.39565,1.09117,1.11402,1.14463,0.001354,0.001354,0.001354 +526,65068.2,1.06845,1.2277,1.15084,0.67018,0.49774,0.55156,0.39577,1.09089,1.1138,1.1445,0.0013375,0.0013375,0.0013375 +527,65193.8,1.06803,1.23291,1.15196,0.66892,0.49812,0.5515,0.39587,1.09085,1.11375,1.14456,0.001321,0.001321,0.001321 +528,65318.2,1.06901,1.22958,1.15132,0.67119,0.49756,0.55154,0.39594,1.09053,1.11348,1.14426,0.0013045,0.0013045,0.0013045 +529,65443,1.06934,1.22042,1.14798,0.67051,0.49819,0.55158,0.39592,1.09039,1.11306,1.14407,0.001288,0.001288,0.001288 +530,65567.6,1.06566,1.22529,1.14875,0.67206,0.49814,0.55185,0.39589,1.09032,1.1128,1.1439,0.0012715,0.0012715,0.0012715 +531,65692.8,1.06298,1.22556,1.15106,0.67324,0.49735,0.55177,0.39595,1.09028,1.11252,1.14383,0.001255,0.001255,0.001255 +532,65817.7,1.06829,1.22887,1.14997,0.67415,0.49736,0.55208,0.39606,1.09036,1.11255,1.14388,0.0012385,0.0012385,0.0012385 +533,65943.2,1.0641,1.22502,1.14637,0.67308,0.49742,0.5519,0.39604,1.09028,1.11229,1.14378,0.001222,0.001222,0.001222 +534,66067.1,1.06199,1.21143,1.14915,0.66733,0.4991,0.55238,0.39624,1.09002,1.11194,1.14357,0.0012055,0.0012055,0.0012055 +535,66193,1.06044,1.21831,1.14588,0.66913,0.49868,0.5524,0.39638,1.09,1.11181,1.1436,0.001189,0.001189,0.001189 +536,66318.4,1.06631,1.22225,1.1507,0.66918,0.49856,0.55253,0.39647,1.08986,1.11134,1.14346,0.0011725,0.0011725,0.0011725 +537,66445.4,1.06767,1.21979,1.1486,0.66722,0.49955,0.55283,0.39654,1.08961,1.11087,1.14338,0.001156,0.001156,0.001156 +538,66570.3,1.06396,1.21607,1.1474,0.66928,0.49925,0.55294,0.39681,1.08944,1.11058,1.14314,0.0011395,0.0011395,0.0011395 +539,66694.8,1.06061,1.21028,1.14546,0.66799,0.49901,0.55307,0.39684,1.08937,1.11033,1.14307,0.001123,0.001123,0.001123 +540,66822,1.06346,1.20631,1.14662,0.66775,0.49986,0.55315,0.39691,1.0893,1.11006,1.14277,0.0011065,0.0011065,0.0011065 +541,66947.4,1.05598,1.204,1.14225,0.66418,0.50168,0.55354,0.3972,1.08923,1.10953,1.14266,0.00109,0.00109,0.00109 +542,67073.5,1.0625,1.21562,1.1463,0.66511,0.50156,0.55359,0.39741,1.08895,1.10928,1.14247,0.0010735,0.0010735,0.0010735 +543,67199.1,1.0576,1.21123,1.14794,0.6658,0.50153,0.55404,0.39753,1.08896,1.10894,1.14245,0.001057,0.001057,0.001057 +544,67323.5,1.06099,1.2053,1.14529,0.66594,0.5013,0.55432,0.39763,1.08879,1.10864,1.14241,0.0010405,0.0010405,0.0010405 +545,67449.6,1.04789,1.19889,1.14195,0.66698,0.50099,0.5545,0.39778,1.08867,1.10869,1.14234,0.001024,0.001024,0.001024 +546,67576.2,1.05139,1.20362,1.14435,0.66721,0.50126,0.5545,0.39773,1.08869,1.1087,1.14237,0.0010075,0.0010075,0.0010075 +547,67700.7,1.05449,1.20099,1.14094,0.66523,0.50147,0.55465,0.39775,1.08853,1.10858,1.14224,0.000991,0.000991,0.000991 +548,67823.6,1.06023,1.19785,1.1418,0.66416,0.5021,0.55456,0.39789,1.08812,1.10858,1.14202,0.0009745,0.0009745,0.0009745 +549,67950.6,1.05065,1.19806,1.14318,0.67135,0.49963,0.55458,0.39787,1.08777,1.10853,1.14171,0.000958,0.000958,0.000958 +550,68076,1.06016,1.19807,1.14332,0.67105,0.4996,0.55484,0.39804,1.08747,1.10819,1.14146,0.0009415,0.0009415,0.0009415 +551,68201.4,1.05317,1.19799,1.14034,0.67208,0.49912,0.55518,0.39831,1.08743,1.10789,1.14141,0.000925,0.000925,0.000925 
+552,68327,1.0563,1.19395,1.1411,0.67085,0.50049,0.55533,0.39843,1.08722,1.10747,1.14126,0.0009085,0.0009085,0.0009085 +553,68452.6,1.05206,1.1889,1.14163,0.67389,0.49939,0.55536,0.39816,1.08719,1.10728,1.14121,0.000892,0.000892,0.000892 +554,68579.7,1.04861,1.18607,1.14116,0.67001,0.50106,0.55538,0.39849,1.08694,1.10724,1.14093,0.0008755,0.0008755,0.0008755 +555,68704.3,1.05392,1.19297,1.14074,0.67067,0.50144,0.55581,0.39842,1.08678,1.10696,1.14081,0.000859,0.000859,0.000859 +556,68827,1.05944,1.19607,1.14224,0.67046,0.50085,0.55581,0.39875,1.08693,1.10676,1.14085,0.0008425,0.0008425,0.0008425 +557,68952,1.0459,1.18891,1.1374,0.6726,0.50026,0.55578,0.39865,1.08666,1.10671,1.14067,0.000826,0.000826,0.000826 +558,69077.5,1.0466,1.18861,1.13713,0.67054,0.50157,0.55582,0.39863,1.08629,1.10647,1.1405,0.0008095,0.0008095,0.0008095 +559,69203.1,1.05063,1.17965,1.13718,0.67121,0.50215,0.55596,0.39898,1.08615,1.10611,1.14017,0.000793,0.000793,0.000793 +560,69329.5,1.05047,1.1814,1.13808,0.67189,0.50204,0.55616,0.39894,1.08597,1.1056,1.13997,0.0007765,0.0007765,0.0007765 +561,69455.5,1.04356,1.17658,1.13583,0.67183,0.50188,0.55644,0.39898,1.0858,1.10545,1.13965,0.00076,0.00076,0.00076 +562,69580.9,1.05075,1.18716,1.13892,0.67321,0.50169,0.55643,0.39896,1.08556,1.10498,1.13946,0.0007435,0.0007435,0.0007435 +563,69706.2,1.049,1.17805,1.14006,0.67284,0.50238,0.5568,0.39949,1.08524,1.10465,1.13937,0.000727,0.000727,0.000727 +564,69831.3,1.04095,1.17354,1.13469,0.67402,0.50163,0.55665,0.39943,1.08505,1.10437,1.13914,0.0007105,0.0007105,0.0007105 +565,69957,1.04492,1.17502,1.13531,0.6746,0.50199,0.55654,0.39923,1.08508,1.10431,1.13911,0.000694,0.000694,0.000694 +566,70083.2,1.04569,1.16636,1.13515,0.671,0.50356,0.55691,0.39995,1.08483,1.10403,1.13903,0.0006775,0.0006775,0.0006775 +567,70208.7,1.04265,1.16736,1.13424,0.67021,0.50356,0.55705,0.40022,1.08481,1.10369,1.13898,0.000661,0.000661,0.000661 +568,70334.3,1.04309,1.16817,1.13521,0.67225,0.50233,0.55734,0.40034,1.08458,1.1036,1.13871,0.0006445,0.0006445,0.0006445 +569,70459.8,1.03967,1.16353,1.13281,0.67229,0.50357,0.55752,0.40034,1.08435,1.10344,1.13845,0.000628,0.000628,0.000628 +570,70585.2,1.0406,1.1665,1.13403,0.67364,0.50313,0.55759,0.40038,1.0841,1.10313,1.13823,0.0006115,0.0006115,0.0006115 +571,70711.5,1.04359,1.15938,1.13381,0.67461,0.50177,0.55747,0.4005,1.08396,1.1028,1.1381,0.000595,0.000595,0.000595 +572,70837.6,1.03826,1.15759,1.13364,0.67749,0.50123,0.55759,0.40053,1.08423,1.10245,1.1382,0.0005785,0.0005785,0.0005785 +573,70962.7,1.03329,1.14625,1.12792,0.67667,0.50152,0.55772,0.40076,1.08432,1.1023,1.13809,0.000562,0.000562,0.000562 +574,71086.9,1.04278,1.15469,1.13211,0.6765,0.50148,0.55776,0.40086,1.08444,1.10249,1.13815,0.0005455,0.0005455,0.0005455 +575,71214.2,1.04579,1.16019,1.13233,0.67603,0.50181,0.55786,0.40076,1.08441,1.10244,1.13806,0.000529,0.000529,0.000529 +576,71337.9,1.03917,1.15702,1.13627,0.67463,0.5022,0.558,0.40071,1.08436,1.10263,1.13792,0.0005125,0.0005125,0.0005125 +577,71463.6,1.03657,1.15097,1.12985,0.67455,0.50218,0.55807,0.40068,1.08436,1.10251,1.13789,0.000496,0.000496,0.000496 +578,71588.6,1.0381,1.15386,1.13296,0.67157,0.50351,0.55829,0.40093,1.08452,1.10237,1.13787,0.0004795,0.0004795,0.0004795 +579,71715.1,1.03384,1.14506,1.12921,0.66664,0.50592,0.55824,0.40117,1.08451,1.10202,1.13773,0.000463,0.000463,0.000463 +580,71841.1,1.03118,1.14051,1.13,0.6623,0.50781,0.55852,0.40121,1.08455,1.10187,1.13768,0.0004465,0.0004465,0.0004465 
+581,71967.7,1.03173,1.14934,1.12796,0.66591,0.50633,0.55859,0.40126,1.08455,1.10137,1.13758,0.00043,0.00043,0.00043 +582,72092.4,1.0359,1.14792,1.12882,0.67087,0.50445,0.55866,0.40131,1.08469,1.10128,1.13764,0.0004135,0.0004135,0.0004135 +583,72218.8,1.03807,1.14836,1.13049,0.6674,0.50641,0.55883,0.40171,1.08481,1.10125,1.13764,0.000397,0.000397,0.000397 +584,72343.3,1.02687,1.13229,1.12546,0.67627,0.50494,0.55912,0.40192,1.08489,1.1013,1.13755,0.0003805,0.0003805,0.0003805 +585,72469,1.02666,1.12528,1.12325,0.67583,0.5054,0.55909,0.40197,1.08486,1.10114,1.1375,0.000364,0.000364,0.000364 +586,72595.3,1.03546,1.13239,1.12657,0.67212,0.50675,0.559,0.40198,1.08495,1.10099,1.13753,0.0003475,0.0003475,0.0003475 +587,72721.5,1.02616,1.12633,1.12432,0.67156,0.50826,0.55947,0.4022,1.08505,1.10064,1.13743,0.000331,0.000331,0.000331 +588,72847.6,1.03188,1.13387,1.12625,0.67168,0.50863,0.55978,0.40232,1.08516,1.10052,1.13753,0.0003145,0.0003145,0.0003145 +589,72973.9,1.03119,1.12946,1.12504,0.6691,0.50914,0.55974,0.40239,1.08521,1.10026,1.13739,0.000298,0.000298,0.000298 +590,73098.2,1.02542,1.1214,1.12088,0.67026,0.5091,0.55979,0.4022,1.08541,1.10034,1.13751,0.0002815,0.0002815,0.0002815 +591,73218.9,1.06738,1.07422,1.14765,0.66937,0.51037,0.56016,0.40243,1.08534,1.09998,1.13745,0.000265,0.000265,0.000265 +592,73334.6,1.06064,1.05421,1.14115,0.66759,0.51181,0.56057,0.40262,1.08496,1.09987,1.13718,0.0002485,0.0002485,0.0002485 +593,73452.6,1.06054,1.04288,1.1426,0.668,0.51124,0.56074,0.40283,1.08463,1.09955,1.13694,0.000232,0.000232,0.000232 +594,73567.6,1.05331,1.03945,1.13886,0.66459,0.51148,0.56091,0.40325,1.08449,1.09938,1.13676,0.0002155,0.0002155,0.0002155 +595,73686.5,1.05427,1.03467,1.13456,0.66472,0.51107,0.56086,0.40333,1.08417,1.09921,1.13639,0.000199,0.000199,0.000199 +596,73799.9,1.05756,1.03332,1.13949,0.66279,0.51179,0.56103,0.40328,1.0839,1.09899,1.13625,0.0001825,0.0001825,0.0001825 +597,73915.2,1.0475,1.02038,1.13695,0.66361,0.51185,0.56115,0.40348,1.08381,1.09883,1.13619,0.000166,0.000166,0.000166 +598,74033.5,1.05262,1.02638,1.13282,0.66274,0.51234,0.56124,0.40382,1.08354,1.09848,1.13606,0.0001495,0.0001495,0.0001495 +599,74148.2,1.04192,1.01589,1.13243,0.66339,0.51225,0.5616,0.4037,1.08323,1.0982,1.13583,0.000133,0.000133,0.000133 +600,74261.3,1.0428,1.00773,1.13181,0.66324,0.51282,0.5618,0.40385,1.08294,1.09802,1.13559,0.0001165,0.0001165,0.0001165 diff --git a/logs/yolov12s.csv b/logs/yolov12s.csv new file mode 100644 index 0000000000000000000000000000000000000000..54520d4095e19325d21f4f39e4d6fa70717bd690 --- /dev/null +++ b/logs/yolov12s.csv @@ -0,0 +1,601 @@ +epoch,time,train/box_loss,train/cls_loss,train/dfl_loss,metrics/precision(B),metrics/recall(B),metrics/mAP50(B),metrics/mAP50-95(B),val/box_loss,val/cls_loss,val/dfl_loss,lr/pg0,lr/pg1,lr/pg2 +1,249.887,3.70137,5.81363,4.22953,0.00163,0.00952,0.0009,0.00025,3.53785,5.31773,4.13032,0.00332613,0.00332613,0.00332613 +2,501.11,2.64304,4.41195,2.82565,0.30284,0.03563,0.01618,0.00761,2.10517,3.52868,2.35791,0.00664848,0.00664848,0.00664848 +3,749.063,1.90854,3.3877,1.98903,0.25375,0.10235,0.05763,0.03131,1.78302,2.9205,1.98156,0.00995982,0.00995982,0.00995982 +4,996.51,1.68071,2.86168,1.74047,0.26972,0.17118,0.12427,0.0736,1.60207,2.52902,1.77566,0.0099505,0.0099505,0.0099505 +5,1242.93,1.56951,2.54159,1.62493,0.31269,0.20974,0.17463,0.10777,1.50752,2.23793,1.66167,0.009934,0.009934,0.009934 +6,1489.44,1.50069,2.35649,1.56037,0.37566,0.24536,0.22047,0.13786,1.45538,2.07674,1.59936,0.0099175,0.0099175,0.0099175 
+7,1735.89,1.45552,2.22158,1.51502,0.4062,0.27014,0.25713,0.16496,1.40674,1.9515,1.55853,0.009901,0.009901,0.009901 +8,1981.91,1.41964,2.12166,1.47359,0.42024,0.29363,0.28275,0.18261,1.36245,1.84421,1.51109,0.0098845,0.0098845,0.0098845 +9,2227.87,1.39455,2.05542,1.44872,0.45331,0.31614,0.31506,0.20693,1.32734,1.75566,1.46997,0.009868,0.009868,0.009868 +10,2474.46,1.37643,1.99282,1.42542,0.48396,0.32215,0.33207,0.22058,1.29108,1.69556,1.43945,0.0098515,0.0098515,0.0098515 +11,2719.81,1.3539,1.94121,1.40428,0.47537,0.3411,0.34806,0.23411,1.27399,1.64165,1.41143,0.009835,0.009835,0.009835 +12,2965.68,1.34416,1.90253,1.39224,0.5007,0.35463,0.36553,0.24586,1.25277,1.58593,1.39236,0.0098185,0.0098185,0.0098185 +13,3211.88,1.32734,1.87446,1.37705,0.51136,0.36832,0.38059,0.25842,1.23569,1.5345,1.37212,0.009802,0.009802,0.009802 +14,3457.88,1.31472,1.83374,1.36301,0.52078,0.38262,0.39726,0.27065,1.21787,1.49779,1.35451,0.0097855,0.0097855,0.0097855 +15,3704.5,1.30358,1.81092,1.35098,0.51456,0.39141,0.40598,0.2782,1.20211,1.47073,1.33558,0.009769,0.009769,0.009769 +16,3950.28,1.29748,1.78896,1.34317,0.51709,0.40522,0.41896,0.28815,1.19542,1.4337,1.32779,0.0097525,0.0097525,0.0097525 +17,4195.72,1.29389,1.76653,1.33712,0.53748,0.40107,0.42709,0.29553,1.18469,1.40849,1.31373,0.009736,0.009736,0.009736 +18,4441.71,1.27948,1.75191,1.32847,0.54765,0.41264,0.43654,0.30252,1.17086,1.39027,1.30233,0.0097195,0.0097195,0.0097195 +19,4687.29,1.27012,1.71815,1.31962,0.55383,0.42007,0.44361,0.30791,1.15838,1.36878,1.29127,0.009703,0.009703,0.009703 +20,4932.85,1.27232,1.71936,1.31728,0.56057,0.4255,0.4482,0.31261,1.1532,1.34986,1.28092,0.0096865,0.0096865,0.0096865 +21,5178.6,1.26444,1.70154,1.30989,0.58132,0.42135,0.45363,0.31798,1.14492,1.33567,1.27432,0.00967,0.00967,0.00967 +22,5424.28,1.26156,1.6927,1.30429,0.57502,0.43122,0.46174,0.32462,1.1385,1.31787,1.26692,0.0096535,0.0096535,0.0096535 +23,5670.84,1.25558,1.68488,1.30217,0.57289,0.43814,0.46689,0.32824,1.13486,1.30394,1.26269,0.009637,0.009637,0.009637 +24,5917.69,1.24785,1.65836,1.29333,0.55037,0.44884,0.47091,0.33171,1.12665,1.29362,1.25635,0.0096205,0.0096205,0.0096205 +25,6163.9,1.25243,1.6562,1.29246,0.58436,0.4463,0.47726,0.3363,1.1245,1.28359,1.25237,0.009604,0.009604,0.009604 +26,6409.17,1.24462,1.64478,1.28554,0.58216,0.44582,0.47884,0.33854,1.12016,1.2755,1.24755,0.0095875,0.0095875,0.0095875 +27,6655.15,1.24834,1.64286,1.28398,0.57403,0.45474,0.48381,0.34208,1.11556,1.26437,1.24207,0.009571,0.009571,0.009571 +28,6901.64,1.2424,1.6295,1.2819,0.58523,0.45115,0.48662,0.34539,1.11129,1.25765,1.23804,0.0095545,0.0095545,0.0095545 +29,7147.79,1.23585,1.61471,1.27817,0.59016,0.45293,0.48919,0.34768,1.10791,1.25126,1.23474,0.009538,0.009538,0.009538 +30,7393.29,1.23341,1.61371,1.27233,0.6039,0.45533,0.49308,0.35027,1.10563,1.2463,1.23188,0.0095215,0.0095215,0.0095215 +31,7639.01,1.23167,1.61079,1.27372,0.58884,0.46274,0.49473,0.35139,1.10463,1.23976,1.23094,0.009505,0.009505,0.009505 +32,7885.35,1.22641,1.599,1.27156,0.59897,0.46181,0.49713,0.35355,1.10284,1.23455,1.22861,0.0094885,0.0094885,0.0094885 +33,8130.89,1.22944,1.58984,1.2667,0.61292,0.4575,0.49869,0.35492,1.10092,1.23049,1.2267,0.009472,0.009472,0.009472 +34,8376.74,1.22578,1.59046,1.26305,0.61628,0.45949,0.50116,0.35658,1.09805,1.22768,1.22513,0.0094555,0.0094555,0.0094555 +35,8624.29,1.22671,1.58924,1.26512,0.61538,0.46058,0.50256,0.35841,1.09705,1.22353,1.22342,0.009439,0.009439,0.009439 
+36,8872.6,1.21457,1.57134,1.25976,0.62002,0.46007,0.50378,0.35916,1.09553,1.22012,1.2215,0.0094225,0.0094225,0.0094225 +37,9118.31,1.22155,1.57175,1.26156,0.62163,0.46024,0.50496,0.36029,1.09423,1.2169,1.21999,0.009406,0.009406,0.009406 +38,9363.82,1.21726,1.56184,1.25618,0.62148,0.4617,0.506,0.36155,1.0927,1.21406,1.2181,0.0093895,0.0093895,0.0093895 +39,9609.71,1.21609,1.56347,1.2545,0.62325,0.46078,0.50671,0.36204,1.09218,1.21184,1.21747,0.009373,0.009373,0.009373 +40,9855.49,1.21357,1.54706,1.25165,0.62226,0.46207,0.50708,0.36258,1.0917,1.21083,1.21674,0.0093565,0.0093565,0.0093565 +41,10100.4,1.21218,1.55484,1.25123,0.62965,0.46033,0.50773,0.36328,1.09104,1.2095,1.216,0.00934,0.00934,0.00934 +42,10345.4,1.21463,1.55137,1.25138,0.63031,0.46108,0.50851,0.36401,1.09011,1.20825,1.21494,0.0093235,0.0093235,0.0093235 +43,10590.9,1.20861,1.53976,1.24812,0.62838,0.4615,0.50896,0.36439,1.08954,1.20754,1.21424,0.009307,0.009307,0.009307 +44,10836.2,1.19794,1.53415,1.24441,0.62687,0.46174,0.50925,0.3649,1.08884,1.20702,1.2136,0.0092905,0.0092905,0.0092905 +45,11081.3,1.2049,1.5364,1.24358,0.62844,0.46089,0.5096,0.36519,1.08855,1.20623,1.21331,0.009274,0.009274,0.009274 +46,11326.4,1.20565,1.52632,1.24328,0.6277,0.46174,0.50993,0.36546,1.08816,1.20599,1.21266,0.0092575,0.0092575,0.0092575 +47,11571.7,1.20609,1.5288,1.24468,0.62683,0.46142,0.51018,0.36574,1.08768,1.20531,1.21213,0.009241,0.009241,0.009241 +48,11817.4,1.19732,1.51277,1.2384,0.62823,0.46148,0.51051,0.36636,1.08709,1.20532,1.21151,0.0092245,0.0092245,0.0092245 +49,12062.7,1.20148,1.52051,1.24131,0.62572,0.46224,0.51082,0.36647,1.08646,1.20569,1.21085,0.009208,0.009208,0.009208 +50,12307.7,1.20269,1.52531,1.2434,0.62765,0.462,0.51097,0.36686,1.08581,1.20643,1.21022,0.0091915,0.0091915,0.0091915 +51,12552.7,1.20394,1.52772,1.24542,0.62492,0.46231,0.51093,0.36677,1.08513,1.20713,1.2097,0.009175,0.009175,0.009175 +52,12797.8,1.20099,1.51468,1.24015,0.62751,0.46144,0.51179,0.36728,1.08472,1.2079,1.20928,0.0091585,0.0091585,0.0091585 +53,13043.2,1.19383,1.50838,1.23655,0.62726,0.46214,0.51228,0.36761,1.08416,1.20848,1.20885,0.009142,0.009142,0.009142 +54,247.508,1.19281,1.5119,1.24063,0.62622,0.46352,0.51307,0.36839,1.08321,1.20579,1.20808,0.0091255,0.0091255,0.0091255 +55,490.117,1.20581,1.53381,1.24572,0.62771,0.46379,0.5135,0.3687,1.08278,1.20369,1.20742,0.009109,0.009109,0.009109 +56,733.803,1.20874,1.53755,1.24984,0.62737,0.464,0.51416,0.36911,1.08234,1.20173,1.20688,0.0090925,0.0090925,0.0090925 +57,977.318,1.20008,1.52932,1.24274,0.63121,0.46394,0.51451,0.3694,1.08182,1.2,1.2063,0.009076,0.009076,0.009076 +58,1226.6,1.2074,1.52766,1.24403,0.63354,0.46286,0.51468,0.36969,1.08122,1.1986,1.20594,0.0090595,0.0090595,0.0090595 +59,1473.65,1.19963,1.5302,1.24289,0.63615,0.46266,0.51529,0.37015,1.08066,1.1974,1.20516,0.009043,0.009043,0.009043 +60,1722.25,1.19846,1.50805,1.23892,0.6353,0.46288,0.51578,0.37043,1.08008,1.19622,1.20439,0.0090265,0.0090265,0.0090265 +61,1969.13,1.19867,1.51009,1.23822,0.6328,0.46423,0.51641,0.37097,1.07955,1.19524,1.20389,0.00901,0.00901,0.00901 +62,2216.85,1.19769,1.51055,1.23773,0.6308,0.46465,0.51684,0.37131,1.07911,1.19426,1.20336,0.0089935,0.0089935,0.0089935 +63,2465.33,1.19321,1.50445,1.23528,0.62155,0.4694,0.51734,0.37151,1.07836,1.19354,1.20267,0.008977,0.008977,0.008977 +64,2712.73,1.19643,1.50797,1.23542,0.62271,0.46948,0.5176,0.37172,1.07843,1.19326,1.20252,0.0089605,0.0089605,0.0089605 
+65,2962.58,1.18888,1.49372,1.23173,0.62032,0.47111,0.51768,0.37199,1.0781,1.19288,1.20221,0.008944,0.008944,0.008944 +66,3208.5,1.19265,1.48966,1.23402,0.62284,0.47074,0.51791,0.37233,1.07804,1.19289,1.202,0.0089275,0.0089275,0.0089275 +67,3456.83,1.1899,1.49742,1.23389,0.62523,0.4707,0.51811,0.37249,1.07794,1.1927,1.20194,0.008911,0.008911,0.008911 +68,3702.25,1.19233,1.49175,1.23566,0.62672,0.47071,0.51841,0.37273,1.07749,1.19286,1.20154,0.0088945,0.0088945,0.0088945 +69,3946.25,1.18643,1.48737,1.22947,0.6241,0.47227,0.51884,0.37313,1.07697,1.19341,1.20098,0.008878,0.008878,0.008878 +70,4191.01,1.18982,1.49241,1.23152,0.6233,0.47231,0.51858,0.37292,1.07649,1.19352,1.20059,0.0088615,0.0088615,0.0088615 +71,4435.91,1.18306,1.48213,1.22632,0.6207,0.47292,0.51853,0.37346,1.07632,1.19376,1.20006,0.008845,0.008845,0.008845 +72,4680.23,1.18346,1.48386,1.22891,0.62016,0.47376,0.51882,0.37354,1.07588,1.19436,1.19968,0.0088285,0.0088285,0.0088285 +73,4924.47,1.18668,1.47794,1.22877,0.62176,0.47385,0.51858,0.37339,1.07575,1.19539,1.19932,0.008812,0.008812,0.008812 +74,5168.82,1.18131,1.4717,1.22682,0.62245,0.47335,0.51877,0.37368,1.07555,1.19638,1.19914,0.0087955,0.0087955,0.0087955 +75,5413.49,1.18667,1.47278,1.22773,0.62107,0.47312,0.51869,0.37404,1.07537,1.19719,1.19887,0.008779,0.008779,0.008779 +76,5658.4,1.18272,1.4752,1.22717,0.62097,0.47326,0.51882,0.37429,1.07517,1.19808,1.19863,0.0087625,0.0087625,0.0087625 +77,5903.59,1.1833,1.46629,1.22278,0.6228,0.47287,0.51898,0.37434,1.07492,1.19886,1.19842,0.008746,0.008746,0.008746 +78,6148.43,1.18277,1.46914,1.22499,0.62201,0.4724,0.5189,0.37456,1.0749,1.19968,1.19815,0.0087295,0.0087295,0.0087295 +79,6393.15,1.17944,1.46891,1.22397,0.62243,0.47285,0.51872,0.37481,1.07464,1.20088,1.19785,0.008713,0.008713,0.008713 +80,6638.63,1.17667,1.46113,1.22375,0.62439,0.47283,0.51875,0.37468,1.07431,1.20184,1.19775,0.0086965,0.0086965,0.0086965 +81,6883.62,1.17819,1.4624,1.22486,0.62297,0.47426,0.51881,0.37471,1.07392,1.20282,1.19756,0.00868,0.00868,0.00868 +82,7129.52,1.17589,1.45764,1.22117,0.62207,0.47494,0.51901,0.37508,1.07413,1.20374,1.19748,0.0086635,0.0086635,0.0086635 +83,7375.63,1.17756,1.45886,1.22196,0.62233,0.47493,0.51912,0.37519,1.07402,1.20479,1.19731,0.008647,0.008647,0.008647 +84,7620.66,1.17167,1.45795,1.22037,0.62434,0.47533,0.51933,0.37521,1.07385,1.20599,1.19706,0.0086305,0.0086305,0.0086305 +85,7865.71,1.17576,1.45486,1.22176,0.62074,0.47646,0.51941,0.37528,1.07378,1.2067,1.19709,0.008614,0.008614,0.008614 +86,8110.61,1.17369,1.44837,1.21772,0.61814,0.47912,0.5201,0.37572,1.07342,1.20748,1.19685,0.0085975,0.0085975,0.0085975 +87,8355.76,1.17872,1.4461,1.21766,0.62113,0.477,0.51981,0.3757,1.07317,1.20793,1.19659,0.008581,0.008581,0.008581 +88,8600.32,1.17536,1.45603,1.21732,0.62097,0.47844,0.52024,0.37601,1.07294,1.20876,1.19627,0.0085645,0.0085645,0.0085645 +89,8845.44,1.17378,1.44397,1.21757,0.62123,0.47829,0.52025,0.37605,1.0727,1.20968,1.19606,0.008548,0.008548,0.008548 +90,9091.68,1.16973,1.44679,1.21777,0.62534,0.47715,0.52039,0.37623,1.07265,1.21033,1.19591,0.0085315,0.0085315,0.0085315 +91,9336.94,1.17136,1.44822,1.21663,0.62482,0.47773,0.52088,0.37639,1.07261,1.21069,1.19572,0.008515,0.008515,0.008515 +92,9582.42,1.1753,1.44267,1.21731,0.62326,0.47843,0.52136,0.37705,1.0721,1.21087,1.19543,0.0084985,0.0084985,0.0084985 +93,9826.24,1.17297,1.43479,1.21448,0.62696,0.47833,0.52152,0.37698,1.07197,1.21133,1.1951,0.008482,0.008482,0.008482 
+94,10067.8,1.17005,1.44006,1.21458,0.62641,0.47923,0.52188,0.3773,1.07155,1.21099,1.19458,0.0084655,0.0084655,0.0084655 +95,10310.8,1.16686,1.44151,1.21447,0.62796,0.47849,0.52222,0.37773,1.07135,1.21093,1.19429,0.008449,0.008449,0.008449 +96,10556.1,1.17282,1.44024,1.21638,0.62968,0.47853,0.52249,0.3781,1.07104,1.21075,1.19405,0.0084325,0.0084325,0.0084325 +97,10801.9,1.17668,1.43986,1.2141,0.63216,0.47793,0.52285,0.37855,1.07089,1.21052,1.19362,0.008416,0.008416,0.008416 +98,11045.8,1.16964,1.43934,1.21384,0.63348,0.47824,0.52306,0.37885,1.07058,1.21014,1.19322,0.0083995,0.0083995,0.0083995 +99,11291,1.16881,1.42854,1.21137,0.63359,0.4783,0.52363,0.37937,1.07019,1.2089,1.19265,0.008383,0.008383,0.008383 +100,11537.6,1.17086,1.43053,1.20997,0.63704,0.47552,0.52434,0.37979,1.06953,1.20784,1.19213,0.0083665,0.0083665,0.0083665 +101,11783.2,1.16846,1.42479,1.21006,0.63296,0.47909,0.52485,0.38032,1.06895,1.20705,1.1915,0.00835,0.00835,0.00835 +102,12028.3,1.16382,1.43452,1.21176,0.63842,0.47578,0.5253,0.38048,1.06866,1.20609,1.19099,0.0083335,0.0083335,0.0083335 +103,12273.8,1.16727,1.43295,1.21115,0.63229,0.48018,0.52576,0.38102,1.06839,1.20474,1.19065,0.008317,0.008317,0.008317 +104,12517.6,1.16657,1.43044,1.21041,0.63481,0.47951,0.52627,0.38142,1.06783,1.2033,1.19004,0.0083005,0.0083005,0.0083005 +105,12760.4,1.1685,1.42669,1.21222,0.63641,0.48097,0.52693,0.3821,1.06758,1.20207,1.18955,0.008284,0.008284,0.008284 +106,13004.3,1.16255,1.42004,1.20737,0.63833,0.48011,0.52728,0.38252,1.06696,1.20116,1.18898,0.0082675,0.0082675,0.0082675 +107,13248.1,1.16756,1.42225,1.20895,0.63859,0.48051,0.52772,0.38312,1.06644,1.19986,1.1884,0.008251,0.008251,0.008251 +108,13491.8,1.16766,1.42532,1.21118,0.63993,0.48144,0.52792,0.38331,1.06594,1.19831,1.18791,0.0082345,0.0082345,0.0082345 +109,13735.5,1.16509,1.42169,1.21002,0.64168,0.48087,0.52864,0.38399,1.06567,1.197,1.18742,0.008218,0.008218,0.008218 +110,13979.3,1.1638,1.42027,1.20894,0.64077,0.48273,0.52918,0.38441,1.06506,1.19542,1.18672,0.0082015,0.0082015,0.0082015 +111,14222.6,1.16356,1.41684,1.20823,0.64171,0.48236,0.52976,0.38481,1.0645,1.19382,1.18608,0.008185,0.008185,0.008185 +112,14466.2,1.16076,1.41795,1.2077,0.6405,0.48334,0.53052,0.38529,1.06412,1.19206,1.18549,0.0081685,0.0081685,0.0081685 +113,14710.4,1.1604,1.41572,1.2078,0.63854,0.48452,0.53114,0.38584,1.06349,1.1902,1.18482,0.008152,0.008152,0.008152 +114,14955.2,1.16017,1.41946,1.20882,0.64088,0.48419,0.53195,0.38627,1.06305,1.1888,1.18428,0.0081355,0.0081355,0.0081355 +115,15199.6,1.16676,1.42128,1.21047,0.64237,0.48285,0.53251,0.38692,1.06251,1.1869,1.18372,0.008119,0.008119,0.008119 +116,15443.8,1.16101,1.41495,1.20657,0.64268,0.48362,0.53313,0.38733,1.06169,1.18513,1.183,0.0081025,0.0081025,0.0081025 +117,15689.4,1.15878,1.40828,1.20511,0.64389,0.48482,0.53383,0.388,1.06113,1.18338,1.18234,0.008086,0.008086,0.008086 +118,15933.3,1.15975,1.40304,1.20403,0.64462,0.48529,0.53483,0.38885,1.0606,1.18111,1.18173,0.0080695,0.0080695,0.0080695 +119,16177.1,1.16201,1.41423,1.20621,0.64285,0.48702,0.53549,0.38923,1.05977,1.17942,1.18099,0.008053,0.008053,0.008053 +120,16420.9,1.15499,1.40131,1.20207,0.64458,0.48773,0.53609,0.38968,1.0591,1.17752,1.18042,0.0080365,0.0080365,0.0080365 +121,16664.4,1.1582,1.40774,1.2036,0.64376,0.48975,0.5365,0.39004,1.05865,1.17552,1.17985,0.00802,0.00802,0.00802 +122,16907.8,1.15633,1.40659,1.20413,0.64679,0.48758,0.53699,0.39068,1.05759,1.1739,1.17907,0.0080035,0.0080035,0.0080035 
+123,17151,1.16008,1.40355,1.20466,0.64737,0.48734,0.53755,0.39119,1.05714,1.17218,1.17857,0.007987,0.007987,0.007987 +124,17394.2,1.1605,1.40758,1.20375,0.65094,0.48752,0.53873,0.39192,1.05666,1.16999,1.17816,0.0079705,0.0079705,0.0079705 +125,17638.4,1.15966,1.40415,1.20528,0.65272,0.48852,0.5395,0.39254,1.05619,1.16816,1.17764,0.007954,0.007954,0.007954 +126,17881.8,1.15742,1.40197,1.20225,0.64992,0.48928,0.54,0.39328,1.05545,1.16616,1.17693,0.0079375,0.0079375,0.0079375 +127,18126.4,1.15253,1.39853,1.19848,0.65078,0.48918,0.54058,0.39355,1.05476,1.16472,1.17603,0.007921,0.007921,0.007921 +128,18370,1.15921,1.4032,1.2024,0.65107,0.49,0.54095,0.39394,1.05413,1.16269,1.17529,0.0079045,0.0079045,0.0079045 +129,18614.7,1.15891,1.40203,1.20174,0.65422,0.49035,0.54163,0.3946,1.05341,1.16079,1.17459,0.007888,0.007888,0.007888 +130,18859.3,1.15913,1.40143,1.20314,0.65378,0.49168,0.54178,0.39511,1.05284,1.1591,1.17406,0.0078715,0.0078715,0.0078715 +131,19104.4,1.15072,1.39499,1.19994,0.65015,0.49447,0.54289,0.39593,1.05205,1.15692,1.17338,0.007855,0.007855,0.007855 +132,19348,1.15051,1.39069,1.19972,0.65039,0.49571,0.54379,0.39647,1.05147,1.15454,1.17288,0.0078385,0.0078385,0.0078385 +133,19592.2,1.14997,1.39202,1.20077,0.65239,0.49549,0.54434,0.3969,1.05122,1.1526,1.17248,0.007822,0.007822,0.007822 +134,19836.4,1.15746,1.40101,1.20294,0.6525,0.49661,0.54505,0.39726,1.05097,1.15093,1.17196,0.0078055,0.0078055,0.0078055 +135,20080.4,1.15959,1.40287,1.20303,0.65468,0.49683,0.5456,0.39783,1.05025,1.14904,1.17136,0.007789,0.007789,0.007789 +136,20324.3,1.15435,1.4025,1.20157,0.6544,0.49822,0.54611,0.3983,1.04962,1.14699,1.17067,0.0077725,0.0077725,0.0077725 +137,20568,1.16072,1.39897,1.20239,0.65412,0.49803,0.54648,0.39874,1.04926,1.14542,1.17005,0.007756,0.007756,0.007756 +138,20811.6,1.1515,1.39099,1.19968,0.6576,0.49816,0.54726,0.39933,1.04886,1.14367,1.16963,0.0077395,0.0077395,0.0077395 +139,21054.9,1.15296,1.39291,1.20025,0.65777,0.49953,0.54795,0.39981,1.04829,1.14201,1.16893,0.007723,0.007723,0.007723 +140,21297.6,1.15097,1.38978,1.20018,0.65661,0.50032,0.54861,0.40049,1.0478,1.1398,1.16821,0.0077065,0.0077065,0.0077065 +141,21541.4,1.1501,1.39093,1.20067,0.65972,0.50007,0.54942,0.40106,1.04717,1.13736,1.16755,0.00769,0.00769,0.00769 +142,21784.8,1.15752,1.38305,1.20171,0.66096,0.49978,0.54998,0.40142,1.04638,1.13532,1.16683,0.0076735,0.0076735,0.0076735 +143,22028.5,1.15639,1.38675,1.2011,0.66086,0.49983,0.55049,0.40197,1.04563,1.13311,1.16605,0.007657,0.007657,0.007657 +144,22272.6,1.15288,1.38043,1.19748,0.66128,0.50029,0.55121,0.40251,1.045,1.13099,1.16549,0.0076405,0.0076405,0.0076405 +145,22517,1.15168,1.39227,1.1998,0.6572,0.50212,0.55175,0.403,1.04462,1.12904,1.1649,0.007624,0.007624,0.007624 +146,22760.6,1.14874,1.38314,1.20097,0.65792,0.50161,0.55242,0.40354,1.0441,1.12718,1.16454,0.0076075,0.0076075,0.0076075 +147,23006.5,1.14933,1.38779,1.19973,0.65902,0.50255,0.55303,0.40402,1.04379,1.1255,1.16417,0.007591,0.007591,0.007591 +148,23250.8,1.1486,1.38281,1.19687,0.65818,0.50299,0.55388,0.40458,1.04355,1.12346,1.16386,0.0075745,0.0075745,0.0075745 +149,23494.5,1.14743,1.38472,1.19775,0.65888,0.50378,0.55422,0.40492,1.0433,1.12218,1.16354,0.007558,0.007558,0.007558 +150,23738.8,1.14793,1.38057,1.19738,0.66034,0.50361,0.55476,0.40516,1.04278,1.12047,1.16302,0.0075415,0.0075415,0.0075415 +151,23982.7,1.14911,1.37724,1.19482,0.65787,0.50533,0.55563,0.40583,1.04238,1.11888,1.16254,0.007525,0.007525,0.007525 
+152,24226.2,1.15018,1.38172,1.1982,0.65601,0.5069,0.55628,0.40618,1.04202,1.11743,1.16219,0.0075085,0.0075085,0.0075085 +153,24470.6,1.15006,1.38028,1.19829,0.65792,0.50644,0.55691,0.4066,1.04177,1.11575,1.16183,0.007492,0.007492,0.007492 +154,24715,1.1521,1.38208,1.19945,0.65704,0.50894,0.55749,0.40705,1.04148,1.114,1.1615,0.0074755,0.0074755,0.0074755 +155,24958.6,1.14906,1.37955,1.19543,0.65735,0.50897,0.55795,0.40751,1.04124,1.11241,1.16113,0.007459,0.007459,0.007459 +156,25202.5,1.14707,1.3776,1.19589,0.65622,0.51014,0.55854,0.40808,1.04143,1.11051,1.16126,0.0074425,0.0074425,0.0074425 +157,25446.4,1.14858,1.37158,1.19533,0.65471,0.51095,0.55896,0.40842,1.04093,1.10915,1.16081,0.007426,0.007426,0.007426 +158,25690.2,1.1456,1.37241,1.19265,0.65186,0.51213,0.55934,0.40854,1.04025,1.10748,1.16016,0.0074095,0.0074095,0.0074095 +159,25935.9,1.14938,1.3777,1.19565,0.65266,0.5122,0.56002,0.40907,1.03963,1.1059,1.15953,0.007393,0.007393,0.007393 +160,26179.4,1.14248,1.36921,1.19251,0.648,0.5156,0.56054,0.40937,1.03931,1.10447,1.15896,0.0073765,0.0073765,0.0073765 +161,26423.4,1.14653,1.37822,1.19563,0.64548,0.51727,0.56096,0.40991,1.03913,1.10313,1.1587,0.00736,0.00736,0.00736 +162,26667.5,1.14345,1.37596,1.19405,0.64798,0.51626,0.56165,0.41054,1.03861,1.10157,1.15837,0.0073435,0.0073435,0.0073435 +163,26911.2,1.1446,1.38447,1.19687,0.64609,0.51706,0.56233,0.41089,1.0383,1.1002,1.15794,0.007327,0.007327,0.007327 +164,27155.1,1.14231,1.36255,1.19142,0.64793,0.51605,0.56255,0.41116,1.03802,1.09849,1.15768,0.0073105,0.0073105,0.0073105 +165,27400,1.14188,1.36527,1.19252,0.64756,0.51716,0.563,0.41167,1.0377,1.09675,1.15741,0.007294,0.007294,0.007294 +166,27645.1,1.14548,1.36143,1.19152,0.64761,0.51728,0.56368,0.41202,1.03746,1.09519,1.15707,0.0072775,0.0072775,0.0072775 +167,27890,1.14489,1.36552,1.19413,0.64655,0.51822,0.56392,0.4124,1.03707,1.09338,1.15688,0.007261,0.007261,0.007261 +168,28134.8,1.13999,1.36757,1.19135,0.64513,0.52059,0.56443,0.41287,1.03679,1.09171,1.1565,0.0072445,0.0072445,0.0072445 +169,28379.4,1.14232,1.37151,1.19428,0.64723,0.52012,0.56499,0.41304,1.03671,1.09031,1.15626,0.007228,0.007228,0.007228 +170,28624,1.1397,1.36357,1.19154,0.64643,0.52084,0.56542,0.41321,1.03644,1.08923,1.15595,0.0072115,0.0072115,0.0072115 +171,28868.3,1.14294,1.36325,1.19264,0.64296,0.52395,0.5664,0.4139,1.0361,1.08809,1.15574,0.007195,0.007195,0.007195 +172,29113.5,1.14344,1.36377,1.19142,0.64543,0.52255,0.56686,0.41425,1.03587,1.08645,1.15551,0.0071785,0.0071785,0.0071785 +173,29356.9,1.14197,1.36239,1.19089,0.64923,0.52078,0.56702,0.41457,1.0355,1.08515,1.15518,0.007162,0.007162,0.007162 +174,29603.3,1.13809,1.35566,1.19157,0.65126,0.52052,0.56756,0.41506,1.03531,1.08375,1.15483,0.0071455,0.0071455,0.0071455 +175,29846.9,1.14232,1.36127,1.19225,0.64659,0.52207,0.56801,0.41529,1.035,1.08249,1.15433,0.007129,0.007129,0.007129 +176,30090.8,1.14158,1.36024,1.18962,0.64268,0.52453,0.56849,0.41571,1.03465,1.08107,1.15392,0.0071125,0.0071125,0.0071125 +177,30334,1.1387,1.36568,1.19238,0.64429,0.52415,0.56862,0.41593,1.03423,1.08027,1.15341,0.007096,0.007096,0.007096 +178,30578,1.14358,1.35873,1.19037,0.64475,0.52562,0.56904,0.41611,1.0342,1.07948,1.15319,0.0070795,0.0070795,0.0070795 +179,30821.8,1.14945,1.35935,1.19269,0.646,0.52642,0.56956,0.41663,1.03382,1.07854,1.15273,0.007063,0.007063,0.007063 +180,31065.1,1.13811,1.35648,1.19134,0.64314,0.52816,0.56985,0.41691,1.0335,1.07732,1.15242,0.0070465,0.0070465,0.0070465 
+181,31308.4,1.14099,1.35915,1.18809,0.64341,0.52784,0.57003,0.41731,1.03342,1.07665,1.15224,0.00703,0.00703,0.00703 +182,31551.8,1.1396,1.35049,1.18936,0.64339,0.5303,0.5702,0.4173,1.03346,1.07563,1.15214,0.0070135,0.0070135,0.0070135 +183,31796.1,1.13817,1.35073,1.18794,0.64317,0.53046,0.57048,0.41749,1.03308,1.07476,1.15168,0.006997,0.006997,0.006997 +184,32040.7,1.13968,1.35377,1.18947,0.6424,0.53206,0.57104,0.41765,1.0331,1.0736,1.1517,0.0069805,0.0069805,0.0069805 +185,32285,1.14277,1.35189,1.1874,0.64152,0.53251,0.57143,0.41805,1.03281,1.07274,1.15121,0.006964,0.006964,0.006964 +186,32528.6,1.13391,1.35234,1.18812,0.64508,0.5312,0.57171,0.41815,1.03257,1.07162,1.15098,0.0069475,0.0069475,0.0069475 +187,32773.6,1.13684,1.35315,1.18755,0.64308,0.53352,0.57213,0.4185,1.03217,1.07092,1.15065,0.006931,0.006931,0.006931 +188,33017,1.13897,1.35331,1.18816,0.64497,0.53277,0.57255,0.4187,1.03187,1.07015,1.15026,0.0069145,0.0069145,0.0069145 +189,33260.7,1.13409,1.35193,1.18627,0.64603,0.53271,0.57279,0.41906,1.03171,1.06942,1.14989,0.006898,0.006898,0.006898 +190,33504.9,1.12926,1.34417,1.18369,0.6504,0.53009,0.57287,0.4191,1.03138,1.06858,1.14953,0.0068815,0.0068815,0.0068815 +191,33749.8,1.1428,1.35352,1.19381,0.65284,0.52893,0.57316,0.41962,1.03121,1.06787,1.14934,0.006865,0.006865,0.006865 +192,33993.3,1.13609,1.34535,1.18687,0.65112,0.53,0.57375,0.41981,1.03119,1.06721,1.14929,0.0068485,0.0068485,0.0068485 +193,34237.3,1.13406,1.34228,1.18636,0.65276,0.52921,0.57418,0.42021,1.03083,1.06651,1.14891,0.006832,0.006832,0.006832 +194,34481.7,1.13395,1.34588,1.18501,0.65164,0.5298,0.57443,0.42039,1.03041,1.06589,1.14851,0.0068155,0.0068155,0.0068155 +195,34725.8,1.13247,1.33483,1.18345,0.64888,0.53022,0.57446,0.42034,1.03017,1.06509,1.14827,0.006799,0.006799,0.006799 +196,34970.1,1.13236,1.34398,1.18439,0.64684,0.53191,0.57475,0.42052,1.03005,1.06469,1.14804,0.0067825,0.0067825,0.0067825 +197,35213.5,1.13723,1.33863,1.18389,0.64971,0.53073,0.57488,0.42085,1.0297,1.06391,1.14756,0.006766,0.006766,0.006766 +198,35457.4,1.12949,1.34602,1.18508,0.64759,0.53181,0.57506,0.42086,1.02936,1.06292,1.14727,0.0067495,0.0067495,0.0067495 +199,35701.7,1.13204,1.33777,1.18507,0.64849,0.53207,0.57535,0.42124,1.02922,1.0623,1.14718,0.006733,0.006733,0.006733 +200,35946.5,1.13391,1.34234,1.18482,0.65735,0.52827,0.57594,0.42154,1.029,1.06173,1.14689,0.0067165,0.0067165,0.0067165 +201,36191.7,1.13588,1.33939,1.18533,0.65599,0.52941,0.57612,0.42172,1.02876,1.06115,1.14666,0.0067,0.0067,0.0067 +202,36435.6,1.12942,1.33743,1.18462,0.6547,0.53033,0.57634,0.42208,1.02853,1.0608,1.14639,0.0066835,0.0066835,0.0066835 +203,36680.1,1.13305,1.33622,1.18461,0.65783,0.52971,0.57692,0.42219,1.02826,1.06002,1.14608,0.006667,0.006667,0.006667 +204,36923.9,1.13588,1.34143,1.18593,0.65805,0.5299,0.57715,0.42238,1.02807,1.05888,1.14591,0.0066505,0.0066505,0.0066505 +205,37168.5,1.1304,1.34333,1.18491,0.65776,0.5301,0.57731,0.42271,1.02781,1.05788,1.14551,0.006634,0.006634,0.006634 +206,37413.2,1.13588,1.33995,1.18775,0.65618,0.53056,0.57769,0.42272,1.02751,1.0573,1.14526,0.0066175,0.0066175,0.0066175 +207,37657.2,1.12612,1.33041,1.18321,0.65734,0.53013,0.57772,0.42284,1.02747,1.05666,1.14523,0.006601,0.006601,0.006601 +208,37900.6,1.13225,1.33297,1.18313,0.65503,0.5311,0.57807,0.42304,1.02742,1.05617,1.14512,0.0065845,0.0065845,0.0065845 +209,38144,1.13059,1.34079,1.18495,0.65684,0.53141,0.5782,0.42319,1.02729,1.05552,1.14498,0.006568,0.006568,0.006568 
+210,38387.7,1.13094,1.34006,1.18632,0.65438,0.53335,0.57858,0.42358,1.02712,1.05511,1.14471,0.0065515,0.0065515,0.0065515 +211,38631.7,1.13071,1.33355,1.18333,0.66145,0.53355,0.57912,0.42399,1.0268,1.05435,1.14454,0.006535,0.006535,0.006535 +212,38877.2,1.13356,1.32926,1.18272,0.66386,0.53234,0.57969,0.42422,1.02649,1.05361,1.14411,0.0065185,0.0065185,0.0065185 +213,39122.7,1.13423,1.33734,1.1835,0.66476,0.53143,0.57992,0.42444,1.02632,1.05289,1.14375,0.006502,0.006502,0.006502 +214,39367.5,1.13045,1.33665,1.1816,0.66445,0.53228,0.58016,0.42471,1.02589,1.05243,1.14345,0.0064855,0.0064855,0.0064855 +215,39611.2,1.12665,1.33107,1.18072,0.66576,0.53191,0.58011,0.42467,1.0258,1.05156,1.14324,0.006469,0.006469,0.006469 +216,39854.5,1.13546,1.32921,1.1819,0.66633,0.53285,0.58067,0.42513,1.02549,1.05094,1.14291,0.0064525,0.0064525,0.0064525 +217,40098.4,1.13,1.33488,1.18096,0.66826,0.53225,0.58086,0.42551,1.0251,1.05039,1.14254,0.006436,0.006436,0.006436 +218,40342.2,1.12631,1.32367,1.18061,0.6628,0.53292,0.58073,0.42548,1.02479,1.04972,1.14217,0.0064195,0.0064195,0.0064195 +219,40586.5,1.12526,1.32546,1.17857,0.66954,0.53184,0.58133,0.42596,1.02435,1.04909,1.14177,0.006403,0.006403,0.006403 +220,40831.2,1.12729,1.32761,1.18017,0.67123,0.53225,0.58136,0.42614,1.02411,1.04853,1.14135,0.0063865,0.0063865,0.0063865 +221,41075.2,1.12778,1.32871,1.18087,0.67333,0.53117,0.58158,0.42621,1.02394,1.04795,1.14108,0.00637,0.00637,0.00637 +222,41320.5,1.12919,1.32892,1.18184,0.67224,0.53241,0.58205,0.42662,1.02364,1.04741,1.1406,0.0063535,0.0063535,0.0063535 +223,41565.3,1.12779,1.32657,1.1819,0.67173,0.53259,0.58232,0.42665,1.02328,1.04685,1.14048,0.006337,0.006337,0.006337 +224,41808.4,1.12785,1.32722,1.18007,0.67111,0.53251,0.58241,0.42668,1.02297,1.04623,1.14009,0.0063205,0.0063205,0.0063205 +225,42052.8,1.12574,1.32023,1.17979,0.6715,0.53336,0.58277,0.42702,1.02288,1.04547,1.13994,0.006304,0.006304,0.006304 +226,42297.1,1.13103,1.328,1.18101,0.67127,0.53385,0.5831,0.42729,1.02258,1.04486,1.13959,0.0062875,0.0062875,0.0062875 +227,42540.4,1.12675,1.32474,1.18171,0.67045,0.53417,0.58355,0.4275,1.022,1.04445,1.13913,0.006271,0.006271,0.006271 +228,42784.1,1.12662,1.31871,1.17891,0.66928,0.53464,0.58322,0.42777,1.02197,1.0435,1.13906,0.0062545,0.0062545,0.0062545 +229,43027.2,1.12308,1.31444,1.17485,0.66051,0.53688,0.58329,0.42775,1.02189,1.04261,1.13894,0.006238,0.006238,0.006238 +230,43271,1.12433,1.31542,1.17481,0.66538,0.53531,0.58381,0.4281,1.02157,1.04186,1.13869,0.0062215,0.0062215,0.0062215 +231,43514.4,1.12273,1.31511,1.17675,0.66576,0.53534,0.58382,0.42816,1.0213,1.04111,1.13851,0.006205,0.006205,0.006205 +232,43758,1.13069,1.31808,1.17806,0.66367,0.53687,0.58429,0.42862,1.02108,1.04025,1.13821,0.0061885,0.0061885,0.0061885 +233,44001.4,1.12099,1.31103,1.17798,0.66818,0.5357,0.58471,0.42886,1.02117,1.04001,1.13813,0.006172,0.006172,0.006172 +234,44246.4,1.12564,1.31156,1.17606,0.66391,0.53807,0.58496,0.42874,1.02116,1.03957,1.13812,0.0061555,0.0061555,0.0061555 +235,44490.7,1.12203,1.31345,1.17706,0.66639,0.5379,0.58537,0.42893,1.02089,1.03895,1.13784,0.006139,0.006139,0.006139 +236,44735.7,1.12891,1.31184,1.17849,0.66828,0.53805,0.58574,0.42901,1.02097,1.03823,1.13785,0.0061225,0.0061225,0.0061225 +237,44980.7,1.12092,1.31508,1.17678,0.67099,0.53713,0.58628,0.42944,1.02071,1.03772,1.1375,0.006106,0.006106,0.006106 +238,45224.9,1.1235,1.31175,1.17713,0.67693,0.53498,0.58636,0.42954,1.02066,1.03717,1.13742,0.0060895,0.0060895,0.0060895 
+239,45468,1.12549,1.31421,1.17737,0.6775,0.53414,0.5865,0.4295,1.02045,1.03657,1.1371,0.006073,0.006073,0.006073 +240,45711.9,1.12348,1.31495,1.17607,0.67216,0.53766,0.58682,0.42994,1.02024,1.03598,1.13691,0.0060565,0.0060565,0.0060565 +241,45955.6,1.11986,1.31023,1.17725,0.67539,0.53587,0.58679,0.43016,1.02034,1.03548,1.13673,0.00604,0.00604,0.00604 +242,46199.3,1.12042,1.31552,1.17816,0.67844,0.53421,0.58717,0.43014,1.0199,1.03496,1.13635,0.0060235,0.0060235,0.0060235 +243,46442.6,1.12792,1.3115,1.17978,0.67474,0.53623,0.58738,0.43038,1.0199,1.03441,1.13629,0.006007,0.006007,0.006007 +244,46686.2,1.12433,1.31203,1.17714,0.67455,0.53612,0.58768,0.43065,1.01954,1.03411,1.13596,0.0059905,0.0059905,0.0059905 +245,46930,1.12566,1.31655,1.17726,0.66437,0.53997,0.58825,0.43083,1.01925,1.0337,1.13569,0.005974,0.005974,0.005974 +246,47174,1.1235,1.3089,1.17619,0.66514,0.53901,0.58813,0.43105,1.01906,1.03314,1.13548,0.0059575,0.0059575,0.0059575 +247,47418.1,1.11777,1.31064,1.17552,0.66623,0.5389,0.5886,0.43142,1.01873,1.03267,1.13523,0.005941,0.005941,0.005941 +248,47662.2,1.12115,1.31387,1.17698,0.66659,0.5397,0.58867,0.43145,1.01869,1.03246,1.13506,0.0059245,0.0059245,0.0059245 +249,47906.6,1.11886,1.30786,1.17593,0.66398,0.54088,0.58866,0.43149,1.01838,1.03251,1.13489,0.005908,0.005908,0.005908 +250,48149.9,1.12203,1.30485,1.17503,0.66742,0.54015,0.58884,0.43181,1.01802,1.03178,1.13474,0.0058915,0.0058915,0.0058915 +251,48393,1.12055,1.30421,1.17426,0.67541,0.54017,0.58946,0.43234,1.01794,1.03158,1.13464,0.005875,0.005875,0.005875 +252,48635.8,1.11803,1.30316,1.17452,0.67592,0.54005,0.58984,0.43274,1.01785,1.03136,1.13446,0.0058585,0.0058585,0.0058585 +253,48878.7,1.12347,1.3092,1.17495,0.67893,0.53933,0.59016,0.4329,1.01766,1.03103,1.13426,0.005842,0.005842,0.005842 +254,49122.3,1.11648,1.30326,1.17495,0.67577,0.54075,0.58992,0.43288,1.01745,1.03085,1.13426,0.0058255,0.0058255,0.0058255 +255,49365.5,1.11872,1.2984,1.17185,0.67353,0.54046,0.59033,0.43324,1.01734,1.03044,1.13402,0.005809,0.005809,0.005809 +256,49609.3,1.11756,1.30866,1.17465,0.66774,0.54526,0.59033,0.43352,1.01719,1.03007,1.1339,0.0057925,0.0057925,0.0057925 +257,49853,1.1246,1.30115,1.17362,0.66799,0.5456,0.59039,0.43351,1.01683,1.02956,1.13363,0.005776,0.005776,0.005776 +258,50096.8,1.12396,1.30201,1.17303,0.6667,0.54548,0.59048,0.43387,1.01661,1.02897,1.13347,0.0057595,0.0057595,0.0057595 +259,50340.4,1.1145,1.30149,1.17217,0.66812,0.54564,0.59089,0.4342,1.01628,1.02854,1.13321,0.005743,0.005743,0.005743 +260,50584.1,1.11506,1.30144,1.17062,0.66507,0.54594,0.59063,0.43419,1.01614,1.02797,1.13305,0.0057265,0.0057265,0.0057265 +261,50827.6,1.11377,1.29402,1.17069,0.66532,0.5461,0.59087,0.43438,1.01588,1.0275,1.13278,0.00571,0.00571,0.00571 +262,51070.4,1.11785,1.29984,1.17285,0.66765,0.54591,0.59104,0.43446,1.01563,1.02716,1.13262,0.0056935,0.0056935,0.0056935 +263,51313.4,1.11923,1.3004,1.17466,0.67057,0.54386,0.59159,0.4347,1.01528,1.02671,1.13224,0.005677,0.005677,0.005677 +264,51556.1,1.11271,1.29652,1.17283,0.67191,0.54453,0.59192,0.43488,1.0151,1.0265,1.13228,0.0056605,0.0056605,0.0056605 +265,51799.5,1.11656,1.2931,1.16908,0.67346,0.54485,0.59233,0.43504,1.01467,1.02614,1.13203,0.005644,0.005644,0.005644 +266,52042.5,1.11119,1.2992,1.17182,0.67397,0.54356,0.59245,0.43542,1.01451,1.02573,1.13195,0.0056275,0.0056275,0.0056275 +267,52285.4,1.11594,1.29326,1.17106,0.67213,0.54508,0.59284,0.43565,1.01416,1.02498,1.13172,0.005611,0.005611,0.005611 
+268,52528.4,1.11229,1.29473,1.17072,0.6716,0.54533,0.59301,0.43592,1.01395,1.02444,1.13155,0.0055945,0.0055945,0.0055945 +269,52771.7,1.11406,1.29698,1.17107,0.67056,0.5457,0.59309,0.43609,1.01368,1.02431,1.13129,0.005578,0.005578,0.005578 +270,53015.1,1.10766,1.28888,1.16879,0.67264,0.54388,0.59297,0.43624,1.01359,1.02418,1.13116,0.0055615,0.0055615,0.0055615 +271,53258.3,1.10976,1.29229,1.16973,0.67178,0.54496,0.59306,0.43602,1.0131,1.02398,1.13073,0.005545,0.005545,0.005545 +272,53502.5,1.11459,1.29418,1.17241,0.67055,0.54656,0.59334,0.43635,1.01286,1.02386,1.13061,0.0055285,0.0055285,0.0055285 +273,53746.4,1.1131,1.29251,1.17122,0.66728,0.54854,0.59328,0.43648,1.01257,1.02337,1.13051,0.005512,0.005512,0.005512 +274,53990.3,1.11568,1.29536,1.17085,0.66917,0.54755,0.59343,0.43659,1.0125,1.02286,1.13041,0.0054955,0.0054955,0.0054955 +275,54234.8,1.10975,1.28199,1.16899,0.67203,0.54507,0.59363,0.43674,1.01208,1.02244,1.13011,0.005479,0.005479,0.005479 +276,54478.1,1.11665,1.29355,1.17132,0.67271,0.54483,0.59379,0.43686,1.01193,1.0221,1.12979,0.0054625,0.0054625,0.0054625 +277,54721.2,1.11301,1.29454,1.16989,0.67365,0.54545,0.59402,0.43691,1.01142,1.02169,1.12942,0.005446,0.005446,0.005446 +278,54964.4,1.11252,1.28784,1.1683,0.67409,0.54553,0.59428,0.43725,1.01125,1.02113,1.1293,0.0054295,0.0054295,0.0054295 +279,55207.9,1.11196,1.28122,1.16721,0.67608,0.54425,0.59461,0.43755,1.01092,1.02056,1.12904,0.005413,0.005413,0.005413 +280,55451.2,1.11119,1.28438,1.16771,0.67836,0.5432,0.59466,0.43759,1.01073,1.02047,1.12872,0.0053965,0.0053965,0.0053965 +281,55694.7,1.11511,1.28743,1.17013,0.67551,0.54541,0.59485,0.43773,1.01036,1.01971,1.12833,0.00538,0.00538,0.00538 +282,55937.6,1.11022,1.27761,1.16719,0.67709,0.54469,0.59435,0.4374,1.01041,1.01939,1.12818,0.0053635,0.0053635,0.0053635 +283,56181.3,1.11121,1.27869,1.16772,0.67695,0.54455,0.59461,0.43748,1.0103,1.01906,1.12802,0.005347,0.005347,0.005347 +284,56424.7,1.1094,1.28113,1.16676,0.67866,0.54299,0.59487,0.43772,1.01018,1.01899,1.12776,0.0053305,0.0053305,0.0053305 +285,56668.4,1.11298,1.2849,1.16801,0.67884,0.54286,0.5953,0.43798,1.01005,1.01851,1.12765,0.005314,0.005314,0.005314 +286,56911.9,1.1091,1.28268,1.16795,0.67768,0.54385,0.59549,0.43826,1.00993,1.01809,1.12739,0.0052975,0.0052975,0.0052975 +287,57155.4,1.10664,1.28113,1.16663,0.6802,0.54315,0.59586,0.43834,1.01014,1.01809,1.12749,0.005281,0.005281,0.005281 +288,57399.2,1.10631,1.28473,1.16642,0.68205,0.54263,0.59568,0.43833,1.01023,1.01785,1.12749,0.0052645,0.0052645,0.0052645 +289,57642.1,1.10894,1.27927,1.16621,0.6783,0.54553,0.59605,0.43879,1.01009,1.01747,1.12747,0.005248,0.005248,0.005248 +290,57885.5,1.10382,1.27394,1.16345,0.67621,0.5468,0.59667,0.43903,1.01006,1.01681,1.1273,0.0052315,0.0052315,0.0052315 +291,58129,1.10813,1.27516,1.16764,0.6841,0.54329,0.59672,0.43918,1.01031,1.01629,1.12739,0.005215,0.005215,0.005215 +292,58373.4,1.10684,1.27673,1.1668,0.68171,0.54503,0.5968,0.43947,1.0103,1.01596,1.12714,0.0051985,0.0051985,0.0051985 +293,58617.9,1.10658,1.27603,1.16517,0.68148,0.54541,0.59691,0.43975,1.01023,1.01547,1.12695,0.005182,0.005182,0.005182 +294,58862.5,1.10868,1.2752,1.16571,0.68427,0.54392,0.59732,0.43986,1.01023,1.01516,1.12688,0.0051655,0.0051655,0.0051655 +295,59105.3,1.10683,1.2773,1.16526,0.68197,0.54419,0.59758,0.44018,1.01024,1.01491,1.12679,0.005149,0.005149,0.005149 +296,59349.6,1.10769,1.27843,1.16643,0.68779,0.54425,0.59786,0.44059,1.01013,1.01445,1.12668,0.0051325,0.0051325,0.0051325 
+297,59592.6,1.10741,1.27903,1.16439,0.68659,0.54455,0.59826,0.44051,1.00989,1.01398,1.12648,0.005116,0.005116,0.005116 +298,59834.9,1.10848,1.2763,1.16592,0.68538,0.54519,0.59834,0.44063,1.00937,1.01338,1.12606,0.0050995,0.0050995,0.0050995 +299,60078.2,1.10527,1.27304,1.16496,0.68561,0.5449,0.59853,0.44063,1.00924,1.01312,1.1259,0.005083,0.005083,0.005083 +300,60321.2,1.10577,1.27238,1.16297,0.68227,0.54679,0.59879,0.44057,1.00881,1.01243,1.12549,0.0050665,0.0050665,0.0050665 +301,60564.4,1.10521,1.26456,1.16378,0.68088,0.54782,0.59882,0.44061,1.00863,1.01209,1.12526,0.00505,0.00505,0.00505 +302,60808.2,1.10511,1.26939,1.16314,0.68601,0.54585,0.5993,0.44095,1.00822,1.01158,1.12485,0.0050335,0.0050335,0.0050335 +303,61052.9,1.11048,1.26864,1.16694,0.68411,0.54587,0.59974,0.44112,1.00811,1.01113,1.12495,0.005017,0.005017,0.005017 +304,61298.5,1.10192,1.26767,1.16286,0.68039,0.54633,0.60008,0.44152,1.00782,1.01055,1.12467,0.0050005,0.0050005,0.0050005 +305,61543.1,1.10003,1.26141,1.16102,0.68261,0.54568,0.60011,0.4417,1.00731,1.01007,1.12422,0.004984,0.004984,0.004984 +306,61787.4,1.10401,1.26602,1.16419,0.68328,0.54533,0.60029,0.44174,1.00711,1.00948,1.12403,0.0049675,0.0049675,0.0049675 +307,62030.7,1.10355,1.26923,1.16571,0.68214,0.54606,0.60029,0.44179,1.00714,1.00924,1.1239,0.004951,0.004951,0.004951 +308,62274.4,1.10141,1.26417,1.16291,0.68111,0.54694,0.6004,0.44186,1.00729,1.00872,1.12391,0.0049345,0.0049345,0.0049345 +309,62517.8,1.10356,1.26669,1.16199,0.68346,0.54599,0.60067,0.44205,1.007,1.00827,1.12362,0.004918,0.004918,0.004918 +310,62760.9,1.09897,1.26741,1.16188,0.68395,0.54527,0.60104,0.44206,1.00716,1.00802,1.12358,0.0049015,0.0049015,0.0049015 +311,63005,1.09958,1.26263,1.16004,0.68209,0.54599,0.60099,0.44214,1.00717,1.00792,1.12334,0.004885,0.004885,0.004885 +312,63248.6,1.09986,1.25709,1.16052,0.68216,0.5464,0.60128,0.44252,1.00703,1.00748,1.12323,0.0048685,0.0048685,0.0048685 +313,63492,1.09854,1.25717,1.16083,0.68798,0.54343,0.60145,0.4424,1.0073,1.00703,1.1233,0.004852,0.004852,0.004852 +314,63736,1.10245,1.25939,1.16231,0.68453,0.54504,0.60168,0.44241,1.00731,1.00671,1.12313,0.0048355,0.0048355,0.0048355 +315,63979.7,1.10088,1.25868,1.16094,0.68427,0.54566,0.60187,0.44263,1.00672,1.00635,1.12289,0.004819,0.004819,0.004819 +316,64222.6,1.09872,1.25723,1.16201,0.68607,0.54408,0.60152,0.44256,1.00647,1.00597,1.12282,0.0048025,0.0048025,0.0048025 +317,64466.5,1.09722,1.25795,1.16126,0.68768,0.54369,0.60167,0.44296,1.00632,1.00537,1.12277,0.004786,0.004786,0.004786 +318,64709.5,1.09345,1.24812,1.1592,0.68717,0.54402,0.60187,0.44304,1.00631,1.00542,1.12265,0.0047695,0.0047695,0.0047695 +319,64953.3,1.09529,1.24855,1.15938,0.68792,0.5442,0.60236,0.44336,1.006,1.00479,1.12244,0.004753,0.004753,0.004753 +320,65196.9,1.10421,1.25833,1.16165,0.68582,0.54573,0.60238,0.44331,1.00601,1.00443,1.12234,0.0047365,0.0047365,0.0047365 +321,65440.4,1.10467,1.25787,1.16191,0.68528,0.54593,0.60301,0.444,1.00567,1.00403,1.12209,0.00472,0.00472,0.00472 +322,65684.2,1.10088,1.25421,1.16061,0.68385,0.54671,0.60303,0.44417,1.00554,1.00349,1.12206,0.0047035,0.0047035,0.0047035 +323,65928.7,1.0948,1.24954,1.15646,0.6859,0.5467,0.6033,0.44449,1.00541,1.00309,1.12197,0.004687,0.004687,0.004687 +324,66172.8,1.097,1.24674,1.15702,0.68585,0.54691,0.60341,0.44456,1.00526,1.0027,1.12196,0.0046705,0.0046705,0.0046705 +325,66418.7,1.09671,1.24828,1.15876,0.68605,0.54796,0.60375,0.44459,1.00517,1.00257,1.12175,0.004654,0.004654,0.004654 
+326,66662.8,1.09596,1.25728,1.16012,0.68611,0.54757,0.60363,0.4448,1.00491,1.00218,1.1216,0.0046375,0.0046375,0.0046375 +327,66906.6,1.09527,1.2483,1.15823,0.68787,0.54672,0.6039,0.44472,1.00484,1.00159,1.12152,0.004621,0.004621,0.004621 +328,67149.9,1.09873,1.24898,1.15742,0.69261,0.54421,0.6042,0.44488,1.0048,1.00141,1.12132,0.0046045,0.0046045,0.0046045 +329,67393.4,1.10133,1.25105,1.15968,0.69355,0.54476,0.60455,0.44525,1.00457,1.00096,1.12098,0.004588,0.004588,0.004588 +330,67636.9,1.09891,1.24437,1.15756,0.69536,0.54409,0.60464,0.44543,1.00421,1.00081,1.12054,0.0045715,0.0045715,0.0045715 +331,67880.3,1.09532,1.24326,1.15585,0.69426,0.5458,0.60499,0.44574,1.00425,1.00025,1.12057,0.004555,0.004555,0.004555 +332,68124.1,1.09549,1.24714,1.15835,0.69966,0.54392,0.60535,0.44585,1.00381,0.99957,1.12025,0.0045385,0.0045385,0.0045385 +333,68367.5,1.09873,1.24165,1.15652,0.69853,0.54482,0.60549,0.44582,1.00341,0.99899,1.12005,0.004522,0.004522,0.004522 +334,68611.1,1.09281,1.24258,1.15573,0.69588,0.54665,0.60544,0.44595,1.00335,0.99819,1.11995,0.0045055,0.0045055,0.0045055 +335,68854.6,1.0961,1.24231,1.15716,0.69739,0.54636,0.60552,0.44636,1.00333,0.99806,1.11985,0.004489,0.004489,0.004489 +336,69098.5,1.09301,1.24551,1.15596,0.69998,0.54446,0.60559,0.44666,1.00334,0.99778,1.1197,0.0044725,0.0044725,0.0044725 +337,69342,1.09724,1.24046,1.15745,0.69726,0.54427,0.60576,0.44682,1.00289,0.99736,1.1193,0.004456,0.004456,0.004456 +338,69584.8,1.09507,1.24105,1.15382,0.69593,0.54528,0.606,0.44698,1.00246,0.99726,1.1189,0.0044395,0.0044395,0.0044395 +339,69828.8,1.08929,1.23304,1.15413,0.69351,0.54726,0.60626,0.44711,1.00227,0.9969,1.11882,0.004423,0.004423,0.004423 +340,70072.7,1.09568,1.24112,1.158,0.69556,0.54616,0.60652,0.4473,1.00224,0.99644,1.11868,0.0044065,0.0044065,0.0044065 +341,70315.5,1.09587,1.24973,1.15933,0.69397,0.5474,0.60667,0.44758,1.00208,0.9959,1.11857,0.00439,0.00439,0.00439 +342,70558.6,1.09638,1.23678,1.15873,0.68364,0.55462,0.60707,0.44783,1.00166,0.99538,1.11828,0.0043735,0.0043735,0.0043735 +343,70802.5,1.09632,1.23944,1.1577,0.68338,0.55612,0.60744,0.44827,1.00132,0.99504,1.11799,0.004357,0.004357,0.004357 +344,71046.5,1.09177,1.24025,1.15644,0.68116,0.55717,0.60718,0.44796,1.00114,0.99472,1.11795,0.0043405,0.0043405,0.0043405 +345,71289.7,1.09052,1.23275,1.15374,0.67978,0.55838,0.60717,0.44822,1.00095,0.99444,1.11782,0.004324,0.004324,0.004324 +346,71534,1.09444,1.23291,1.15439,0.67629,0.55991,0.60705,0.44837,1.00078,0.99396,1.11765,0.0043075,0.0043075,0.0043075 +347,71777.5,1.09407,1.24265,1.15587,0.68185,0.55579,0.60717,0.44845,1.00043,0.99357,1.11739,0.004291,0.004291,0.004291 +348,72022.1,1.08904,1.22953,1.15427,0.67802,0.55913,0.60727,0.44865,0.99989,0.9931,1.11713,0.0042745,0.0042745,0.0042745 +349,72267.3,1.09454,1.23362,1.15501,0.67751,0.56025,0.60753,0.44886,1.00013,0.99286,1.11717,0.004258,0.004258,0.004258 +350,72512.4,1.09208,1.23228,1.15518,0.67879,0.56005,0.60698,0.44871,1.00016,0.99256,1.11711,0.0042415,0.0042415,0.0042415 +351,72756.2,1.0937,1.23281,1.15449,0.67624,0.56109,0.60728,0.44882,1.00016,0.99218,1.11704,0.004225,0.004225,0.004225 +352,73001.4,1.09406,1.23205,1.15411,0.68156,0.55603,0.60746,0.44895,0.99998,0.9919,1.11689,0.0042085,0.0042085,0.0042085 +353,73246.7,1.08469,1.2249,1.15135,0.6821,0.55595,0.60756,0.44919,0.99986,0.99154,1.11678,0.004192,0.004192,0.004192 +354,73492.1,1.08936,1.22761,1.15372,0.68417,0.55557,0.60788,0.44926,0.9999,0.99134,1.11669,0.0041755,0.0041755,0.0041755 
+355,73735.9,1.08731,1.22822,1.1517,0.68435,0.55521,0.60818,0.44959,0.99996,0.99104,1.11666,0.004159,0.004159,0.004159 +356,73979.7,1.08937,1.2334,1.15292,0.68876,0.55412,0.60833,0.44962,0.99944,0.99071,1.11641,0.0041425,0.0041425,0.0041425 +357,74223.4,1.08882,1.22266,1.1532,0.68826,0.55458,0.60849,0.44986,0.99936,0.99057,1.11609,0.004126,0.004126,0.004126 +358,74468,1.08839,1.22336,1.15092,0.68907,0.55452,0.60907,0.4503,0.99921,0.98996,1.11601,0.0041095,0.0041095,0.0041095 +359,74712.3,1.08887,1.22891,1.15395,0.68697,0.55619,0.60931,0.45064,0.99907,0.98958,1.11595,0.004093,0.004093,0.004093 +360,74957,1.08791,1.22509,1.15235,0.6857,0.55738,0.60986,0.45092,0.99855,0.98901,1.11555,0.0040765,0.0040765,0.0040765 +361,75200.9,1.08662,1.228,1.15136,0.68412,0.55845,0.61043,0.4512,0.99827,0.9887,1.11522,0.00406,0.00406,0.00406 +362,75444.4,1.0869,1.22681,1.15282,0.68384,0.55769,0.61039,0.45121,0.99806,0.98809,1.11503,0.0040435,0.0040435,0.0040435 +363,75687.8,1.08342,1.21506,1.14919,0.68289,0.55709,0.61023,0.45116,0.99799,0.98777,1.11485,0.004027,0.004027,0.004027 +364,75931.9,1.08426,1.21884,1.15167,0.68382,0.55741,0.61041,0.45135,0.99788,0.98753,1.11487,0.0040105,0.0040105,0.0040105 +365,76175.6,1.088,1.22076,1.15114,0.68335,0.55747,0.61049,0.45133,0.99766,0.98741,1.11471,0.003994,0.003994,0.003994 +366,76419.5,1.08788,1.21856,1.15223,0.68411,0.55613,0.61087,0.45171,0.99749,0.98703,1.11453,0.0039775,0.0039775,0.0039775 +367,76664.1,1.08404,1.21802,1.15081,0.6852,0.55639,0.6111,0.45171,0.99728,0.98645,1.11435,0.003961,0.003961,0.003961 +368,76907.4,1.08213,1.21748,1.14978,0.6855,0.55631,0.61133,0.45187,0.9973,0.98566,1.11437,0.0039445,0.0039445,0.0039445 +369,77150.9,1.08388,1.21469,1.15126,0.68878,0.5542,0.61115,0.45212,0.99722,0.98536,1.11426,0.003928,0.003928,0.003928 +370,77393.7,1.08322,1.22093,1.15069,0.68789,0.55458,0.61149,0.45226,0.99703,0.98493,1.11408,0.0039115,0.0039115,0.0039115 +371,77637,1.08408,1.21083,1.14944,0.68673,0.55503,0.61134,0.45244,0.9968,0.98474,1.11387,0.003895,0.003895,0.003895 +372,77879.9,1.08491,1.21902,1.15074,0.68863,0.55502,0.61149,0.45246,0.99639,0.98409,1.11375,0.0038785,0.0038785,0.0038785 +373,78123.4,1.08045,1.20748,1.14538,0.69239,0.55228,0.61159,0.45257,0.99621,0.98352,1.11368,0.003862,0.003862,0.003862 +374,78366.2,1.08048,1.20295,1.14583,0.6928,0.55258,0.61195,0.45263,0.99611,0.98298,1.11356,0.0038455,0.0038455,0.0038455 +375,78609,1.08503,1.21026,1.14617,0.6925,0.55332,0.61209,0.45277,0.9961,0.98257,1.11354,0.003829,0.003829,0.003829 +376,78852.7,1.0816,1.20924,1.14665,0.69313,0.5531,0.61214,0.45292,0.9963,0.98214,1.11358,0.0038125,0.0038125,0.0038125 +377,79096,1.07992,1.21363,1.14721,0.69277,0.55336,0.61261,0.45306,0.9963,0.98181,1.11346,0.003796,0.003796,0.003796 +378,79339.4,1.08012,1.20603,1.14957,0.68582,0.55722,0.61264,0.4532,0.99609,0.9813,1.11331,0.0037795,0.0037795,0.0037795 +379,79583,1.07687,1.20946,1.14765,0.68399,0.55834,0.6131,0.45357,0.9961,0.98093,1.11327,0.003763,0.003763,0.003763 +380,79826.5,1.08863,1.21459,1.15159,0.68333,0.55923,0.6131,0.45344,0.99633,0.98051,1.1133,0.0037465,0.0037465,0.0037465 +381,80069.3,1.0778,1.20417,1.14528,0.67956,0.56206,0.61313,0.4535,0.99635,0.98027,1.11336,0.00373,0.00373,0.00373 +382,80312.4,1.08151,1.20447,1.14631,0.67735,0.56336,0.61327,0.45349,0.99615,0.97992,1.11318,0.0037135,0.0037135,0.0037135 +383,80555,1.08646,1.21414,1.14853,0.67691,0.56365,0.6133,0.45343,0.99613,0.97982,1.11313,0.003697,0.003697,0.003697 
+384,80798.4,1.08095,1.20271,1.1453,0.67514,0.56467,0.61246,0.45304,0.99585,0.97956,1.11305,0.0036805,0.0036805,0.0036805 +385,81041.4,1.0838,1.20709,1.14695,0.67644,0.56439,0.6127,0.45344,0.99566,0.97927,1.11298,0.003664,0.003664,0.003664 +386,81284.3,1.07867,1.20252,1.14703,0.67614,0.56471,0.61286,0.4536,0.99563,0.97912,1.11278,0.0036475,0.0036475,0.0036475 +387,81527.2,1.07494,1.19712,1.14383,0.6779,0.56439,0.61285,0.45355,0.9956,0.97894,1.11261,0.003631,0.003631,0.003631 +388,81770.2,1.07619,1.19679,1.14463,0.67783,0.56466,0.61292,0.45373,0.9956,0.97857,1.11255,0.0036145,0.0036145,0.0036145 +389,82013.5,1.08001,1.19572,1.14487,0.67797,0.56471,0.61338,0.4541,0.99527,0.9782,1.11241,0.003598,0.003598,0.003598 +390,82256.6,1.07584,1.20193,1.1433,0.68015,0.56284,0.6136,0.45432,0.99525,0.97788,1.11226,0.0035815,0.0035815,0.0035815 +391,82499.7,1.07362,1.19282,1.14214,0.68058,0.56319,0.6135,0.45404,0.99518,0.97758,1.11221,0.003565,0.003565,0.003565 +392,82743.3,1.07853,1.20061,1.1456,0.67924,0.56454,0.61374,0.45406,0.99511,0.97694,1.11209,0.0035485,0.0035485,0.0035485 +393,82986.2,1.07451,1.20096,1.1443,0.67494,0.56818,0.61386,0.45408,0.9951,0.97649,1.11206,0.003532,0.003532,0.003532 +394,83229,1.08186,1.19951,1.14558,0.68014,0.5649,0.61403,0.45409,0.99513,0.97566,1.11209,0.0035155,0.0035155,0.0035155 +395,83471.9,1.07931,1.19436,1.14635,0.68058,0.56503,0.61405,0.45416,0.99478,0.97571,1.11186,0.003499,0.003499,0.003499 +396,83715.2,1.0767,1.19657,1.14546,0.68141,0.56422,0.61431,0.45447,0.99469,0.97561,1.11168,0.0034825,0.0034825,0.0034825 +397,83958.2,1.07014,1.19075,1.14091,0.67824,0.56632,0.61525,0.45529,0.99438,0.97514,1.11151,0.003466,0.003466,0.003466 +398,84201.9,1.07557,1.19005,1.13976,0.67937,0.56575,0.61543,0.4555,0.99404,0.97494,1.11112,0.0034495,0.0034495,0.0034495 +399,84445.7,1.07368,1.18905,1.14212,0.67927,0.56622,0.61549,0.45555,0.99385,0.97453,1.11092,0.003433,0.003433,0.003433 +400,84688.9,1.07387,1.18657,1.13915,0.67605,0.56802,0.61532,0.45567,0.9937,0.97435,1.11071,0.0034165,0.0034165,0.0034165 +401,84932.2,1.07033,1.18431,1.14089,0.67524,0.56833,0.61551,0.45549,0.99366,0.97424,1.11056,0.0034,0.0034,0.0034 +402,85175.2,1.0774,1.19147,1.14419,0.67587,0.56832,0.61566,0.4558,0.9936,0.97396,1.11051,0.0033835,0.0033835,0.0033835 +403,85418.3,1.07215,1.18508,1.14086,0.67842,0.56703,0.6158,0.45578,0.99366,0.97354,1.11051,0.003367,0.003367,0.003367 +404,85661.2,1.07571,1.18924,1.1429,0.67471,0.57054,0.61622,0.4562,0.99348,0.97322,1.11025,0.0033505,0.0033505,0.0033505 +405,85904.2,1.07326,1.18714,1.13954,0.67669,0.56921,0.61631,0.45593,0.99321,0.97269,1.11002,0.003334,0.003334,0.003334 +406,86147.4,1.07706,1.18367,1.1421,0.67845,0.56837,0.61663,0.45604,0.99278,0.97223,1.10978,0.0033175,0.0033175,0.0033175 +407,86390.6,1.06858,1.18206,1.14012,0.67834,0.5686,0.6164,0.45631,0.99278,0.97213,1.10969,0.003301,0.003301,0.003301 +408,86633.7,1.06715,1.18076,1.13746,0.67714,0.56892,0.61659,0.45622,0.99279,0.97207,1.10951,0.0032845,0.0032845,0.0032845 +409,86876.1,1.06262,1.17682,1.13587,0.67984,0.56757,0.61649,0.45613,0.99257,0.97177,1.10916,0.003268,0.003268,0.003268 +410,87119.2,1.06816,1.17838,1.13666,0.68232,0.56613,0.61696,0.45621,0.99247,0.97157,1.10901,0.0032515,0.0032515,0.0032515 +411,87361.7,1.07284,1.18633,1.14089,0.68151,0.56742,0.61721,0.45631,0.99259,0.97134,1.10896,0.003235,0.003235,0.003235 +412,87604.7,1.07429,1.18255,1.14065,0.68409,0.566,0.6178,0.45686,0.99233,0.9711,1.10865,0.0032185,0.0032185,0.0032185 
+413,87847.5,1.0683,1.17696,1.13984,0.68559,0.56526,0.61798,0.4572,0.99203,0.97054,1.10842,0.003202,0.003202,0.003202 +414,88090.4,1.07121,1.18067,1.13747,0.68361,0.56581,0.61765,0.45695,0.99171,0.97029,1.10799,0.0031855,0.0031855,0.0031855 +415,88333.6,1.07213,1.18289,1.14008,0.67622,0.57211,0.61793,0.45718,0.99141,0.97012,1.10781,0.003169,0.003169,0.003169 +416,88576.6,1.0662,1.17574,1.13782,0.68183,0.56937,0.61821,0.45731,0.99149,0.96963,1.10784,0.0031525,0.0031525,0.0031525 +417,88819.7,1.07074,1.17697,1.13943,0.6816,0.56895,0.61837,0.4574,0.99157,0.96925,1.10773,0.003136,0.003136,0.003136 +418,89062.2,1.07096,1.17299,1.13718,0.68266,0.56871,0.61852,0.45745,0.99158,0.96908,1.10763,0.0031195,0.0031195,0.0031195 +419,89305.8,1.07092,1.18126,1.13827,0.684,0.56813,0.6189,0.45777,0.99156,0.96857,1.10766,0.003103,0.003103,0.003103 +420,89549.1,1.07049,1.18065,1.13824,0.68247,0.56966,0.61907,0.45781,0.9914,0.96832,1.10745,0.0030865,0.0030865,0.0030865 +421,89792.5,1.0702,1.17625,1.13981,0.68227,0.57034,0.61878,0.4578,0.99125,0.96802,1.10735,0.00307,0.00307,0.00307 +422,90035.6,1.06039,1.16828,1.13537,0.6843,0.56927,0.61912,0.45817,0.99081,0.96758,1.10687,0.0030535,0.0030535,0.0030535 +423,90278.8,1.0618,1.16454,1.13434,0.68622,0.56903,0.61966,0.45835,0.99086,0.96705,1.1069,0.003037,0.003037,0.003037 +424,90522,1.06507,1.17385,1.13663,0.68504,0.57043,0.61976,0.45842,0.99089,0.96659,1.10696,0.0030205,0.0030205,0.0030205 +425,90765.1,1.06343,1.16059,1.13571,0.68573,0.57089,0.61993,0.4585,0.9909,0.96604,1.10696,0.003004,0.003004,0.003004 +426,91008.6,1.06742,1.16923,1.1346,0.68663,0.57041,0.62012,0.45854,0.99088,0.96584,1.10696,0.0029875,0.0029875,0.0029875 +427,91251.5,1.06561,1.16774,1.13389,0.68488,0.57165,0.62021,0.45888,0.99091,0.96558,1.10684,0.002971,0.002971,0.002971 +428,91494.7,1.06164,1.16273,1.1345,0.68548,0.57025,0.62033,0.45912,0.99091,0.96531,1.1067,0.0029545,0.0029545,0.0029545 +429,91737.8,1.06172,1.1632,1.13469,0.68749,0.5682,0.62061,0.45918,0.99057,0.96531,1.10638,0.002938,0.002938,0.002938 +430,91981.6,1.0688,1.1695,1.13684,0.6857,0.56912,0.62117,0.45968,0.99033,0.96497,1.10621,0.0029215,0.0029215,0.0029215 +431,92225.6,1.06042,1.1615,1.13364,0.68496,0.57017,0.6214,0.45978,0.99033,0.96471,1.10611,0.002905,0.002905,0.002905 +432,92469.2,1.05746,1.15938,1.1316,0.68726,0.57016,0.62139,0.45984,0.99017,0.9642,1.10598,0.0028885,0.0028885,0.0028885 +433,92713.5,1.06515,1.16214,1.13392,0.68489,0.57164,0.62142,0.46013,0.98966,0.96388,1.10571,0.002872,0.002872,0.002872 +434,92957.2,1.06351,1.15982,1.13296,0.68439,0.57187,0.62132,0.45993,0.98929,0.96342,1.10543,0.0028555,0.0028555,0.0028555 +435,93200.1,1.06093,1.16143,1.13411,0.68779,0.5708,0.62184,0.46027,0.9889,0.96329,1.10525,0.002839,0.002839,0.002839 +436,93442.4,1.06217,1.15978,1.13244,0.68477,0.57201,0.62193,0.46029,0.98881,0.96313,1.10491,0.0028225,0.0028225,0.0028225 +437,93685.3,1.05834,1.15719,1.13136,0.68551,0.57136,0.62167,0.46019,0.98866,0.96281,1.10479,0.002806,0.002806,0.002806 +438,93928.6,1.05578,1.15316,1.13108,0.68304,0.5723,0.62206,0.46067,0.98861,0.96243,1.10457,0.0027895,0.0027895,0.0027895 +439,94171.9,1.05903,1.15967,1.13308,0.68393,0.57154,0.62192,0.46036,0.98858,0.96197,1.10455,0.002773,0.002773,0.002773 +440,94415.1,1.06357,1.15541,1.13147,0.68401,0.57151,0.62194,0.46034,0.98814,0.96149,1.10407,0.0027565,0.0027565,0.0027565 +441,94658.3,1.0605,1.15219,1.13263,0.68862,0.56973,0.62202,0.46095,0.98802,0.96084,1.10407,0.00274,0.00274,0.00274 
+442,94901.1,1.06221,1.15542,1.1313,0.68979,0.56952,0.62236,0.46111,0.98785,0.9606,1.10404,0.0027235,0.0027235,0.0027235 +443,95143.7,1.06232,1.1523,1.13153,0.69129,0.56915,0.62266,0.46115,0.98775,0.9603,1.1039,0.002707,0.002707,0.002707 +444,95386.7,1.05769,1.14555,1.13033,0.69064,0.56978,0.62293,0.46132,0.98774,0.96,1.10378,0.0026905,0.0026905,0.0026905 +445,95629.6,1.05801,1.14987,1.13118,0.69133,0.57017,0.6235,0.46151,0.98771,0.95946,1.1038,0.002674,0.002674,0.002674 +446,95872.2,1.05618,1.14957,1.13385,0.68909,0.57162,0.6236,0.46152,0.9876,0.95912,1.10369,0.0026575,0.0026575,0.0026575 +447,96116.1,1.05926,1.15019,1.13155,0.69198,0.57085,0.62346,0.46156,0.98744,0.95871,1.10342,0.002641,0.002641,0.002641 +448,96359.6,1.05734,1.15068,1.12943,0.69221,0.57105,0.62338,0.46168,0.98753,0.95811,1.10346,0.0026245,0.0026245,0.0026245 +449,96603.6,1.05588,1.14519,1.12909,0.69206,0.57144,0.62397,0.4621,0.98737,0.95781,1.10327,0.002608,0.002608,0.002608 +450,96847.9,1.05273,1.14351,1.12788,0.69189,0.57089,0.62361,0.46189,0.98707,0.95756,1.10296,0.0025915,0.0025915,0.0025915 +451,97092.2,1.05307,1.14562,1.12838,0.69264,0.57037,0.62394,0.46196,0.98698,0.95741,1.10289,0.002575,0.002575,0.002575 +452,97336.4,1.06069,1.14823,1.12976,0.69342,0.57009,0.62415,0.46194,0.98648,0.95721,1.10226,0.0025585,0.0025585,0.0025585 +453,97580.2,1.05991,1.14396,1.12966,0.69305,0.57171,0.62461,0.46221,0.98663,0.95669,1.10217,0.002542,0.002542,0.002542 +454,97823.3,1.0562,1.13669,1.12895,0.69508,0.57125,0.62503,0.46267,0.98646,0.95616,1.10207,0.0025255,0.0025255,0.0025255 +455,98066,1.05493,1.14486,1.12665,0.69229,0.57214,0.62526,0.46261,0.9865,0.95532,1.10212,0.002509,0.002509,0.002509 +456,98309.3,1.05243,1.1388,1.12687,0.6941,0.57348,0.62526,0.46284,0.98655,0.95513,1.10201,0.0024925,0.0024925,0.0024925 +457,98551.9,1.05265,1.13485,1.12581,0.69712,0.57239,0.62575,0.46321,0.98677,0.95432,1.10216,0.002476,0.002476,0.002476 +458,98795,1.05321,1.13495,1.12805,0.6958,0.57369,0.62515,0.46302,0.98666,0.95374,1.10216,0.0024595,0.0024595,0.0024595 +459,99037.6,1.05423,1.14059,1.12681,0.69722,0.57272,0.62576,0.46312,0.9865,0.95352,1.10212,0.002443,0.002443,0.002443 +460,99280.7,1.04997,1.13355,1.12612,0.69902,0.57106,0.62598,0.46327,0.98631,0.9531,1.10203,0.0024265,0.0024265,0.0024265 +461,99524.4,1.049,1.13046,1.12703,0.70088,0.56982,0.62618,0.46333,0.98632,0.95294,1.10204,0.00241,0.00241,0.00241 +462,99768.2,1.05259,1.13682,1.12586,0.69895,0.572,0.62646,0.46359,0.98619,0.95259,1.10186,0.0023935,0.0023935,0.0023935 +463,100012,1.05004,1.13267,1.12497,0.69696,0.57343,0.62679,0.4639,0.98593,0.95208,1.10189,0.002377,0.002377,0.002377 +464,100255,1.04922,1.13092,1.12584,0.69583,0.57278,0.62693,0.46419,0.9858,0.95179,1.10175,0.0023605,0.0023605,0.0023605 +465,100499,1.04848,1.12644,1.12601,0.69754,0.57389,0.62723,0.46429,0.9855,0.95171,1.10164,0.002344,0.002344,0.002344 +466,100742,1.04587,1.1305,1.12606,0.69651,0.57259,0.62735,0.46423,0.98556,0.95123,1.10159,0.0023275,0.0023275,0.0023275 +467,100984,1.05063,1.13171,1.12473,0.7005,0.57177,0.62744,0.46436,0.98555,0.95119,1.10134,0.002311,0.002311,0.002311 +468,101227,1.05016,1.1294,1.12596,0.69857,0.57363,0.62811,0.46474,0.98545,0.95093,1.10126,0.0022945,0.0022945,0.0022945 +469,101470,1.04949,1.13186,1.12827,0.70023,0.57403,0.62856,0.46509,0.98543,0.95052,1.1015,0.002278,0.002278,0.002278 +470,101713,1.05287,1.12562,1.12496,0.70045,0.57461,0.62864,0.46495,0.98535,0.95016,1.10144,0.0022615,0.0022615,0.0022615 
+471,101957,1.04409,1.12448,1.12549,0.7012,0.57521,0.62861,0.46513,0.98527,0.94995,1.1014,0.002245,0.002245,0.002245 +472,102201,1.05367,1.12784,1.12649,0.70021,0.57605,0.62834,0.4651,0.98522,0.9496,1.10141,0.0022285,0.0022285,0.0022285 +473,102445,1.0466,1.12395,1.12441,0.70053,0.57624,0.62803,0.46502,0.98516,0.94948,1.10144,0.002212,0.002212,0.002212 +474,102688,1.04513,1.12683,1.12488,0.69988,0.57697,0.62829,0.46547,0.98493,0.94915,1.10123,0.0021955,0.0021955,0.0021955 +475,102933,1.04499,1.11565,1.12173,0.70311,0.57442,0.62823,0.46549,0.98457,0.9485,1.10109,0.002179,0.002179,0.002179 +476,103176,1.04097,1.11322,1.11878,0.70237,0.57534,0.62833,0.46564,0.98439,0.94794,1.10078,0.0021625,0.0021625,0.0021625 +477,103419,1.04425,1.1145,1.12324,0.70465,0.57429,0.62831,0.46571,0.98438,0.9476,1.10061,0.002146,0.002146,0.002146 +478,103662,1.04802,1.1266,1.12471,0.70655,0.5725,0.6283,0.46562,0.98407,0.94729,1.10048,0.0021295,0.0021295,0.0021295 +479,103905,1.04529,1.11514,1.12055,0.70719,0.57207,0.62851,0.46579,0.98412,0.94712,1.10048,0.002113,0.002113,0.002113 +480,104148,1.04525,1.1178,1.12262,0.70523,0.57402,0.62879,0.46576,0.98404,0.94709,1.10038,0.0020965,0.0020965,0.0020965 +481,104391,1.04233,1.1127,1.12171,0.70845,0.57109,0.62895,0.46587,0.98404,0.9471,1.10036,0.00208,0.00208,0.00208 +482,104633,1.04302,1.10428,1.12036,0.70195,0.57598,0.62922,0.4661,0.98402,0.94653,1.10021,0.0020635,0.0020635,0.0020635 +483,104876,1.0417,1.10998,1.11913,0.6994,0.57688,0.62921,0.46634,0.98406,0.94631,1.10019,0.002047,0.002047,0.002047 +484,105118,1.03776,1.10925,1.11985,0.69845,0.5771,0.62887,0.46589,0.98419,0.9462,1.09999,0.0020305,0.0020305,0.0020305 +485,105361,1.03637,1.10589,1.11714,0.69744,0.57824,0.62921,0.46622,0.98379,0.9462,1.09956,0.002014,0.002014,0.002014 +486,105603,1.03696,1.10412,1.11949,0.69673,0.57889,0.62961,0.46646,0.98369,0.9456,1.09945,0.0019975,0.0019975,0.0019975 +487,105846,1.04074,1.1044,1.11948,0.69542,0.57965,0.62955,0.46656,0.98377,0.94508,1.09945,0.001981,0.001981,0.001981 +488,106089,1.03462,1.10041,1.11585,0.69542,0.58082,0.62973,0.46691,0.98367,0.94455,1.09944,0.0019645,0.0019645,0.0019645 +489,106332,1.04065,1.10403,1.11809,0.69415,0.58215,0.6302,0.46771,0.98369,0.94421,1.09945,0.001948,0.001948,0.001948 +490,106574,1.03642,1.10132,1.11855,0.69518,0.58278,0.62988,0.46741,0.98352,0.94399,1.09931,0.0019315,0.0019315,0.0019315 +491,106817,1.03374,1.09379,1.11592,0.69666,0.58191,0.63031,0.46768,0.98326,0.94378,1.09924,0.001915,0.001915,0.001915 +492,107060,1.03677,1.09727,1.11858,0.69236,0.58374,0.63051,0.46777,0.98307,0.94351,1.09907,0.0018985,0.0018985,0.0018985 +493,107303,1.03478,1.09475,1.11626,0.69648,0.58101,0.63028,0.46753,0.98297,0.94356,1.09887,0.001882,0.001882,0.001882 +494,107547,1.03806,1.09578,1.11568,0.69573,0.58064,0.63047,0.46769,0.98255,0.94342,1.09854,0.0018655,0.0018655,0.0018655 +495,107790,1.03724,1.09672,1.11532,0.70026,0.58061,0.63068,0.4678,0.98239,0.94283,1.0984,0.001849,0.001849,0.001849 +496,108033,1.03209,1.09476,1.11645,0.69879,0.58228,0.63138,0.46831,0.98243,0.94262,1.0984,0.0018325,0.0018325,0.0018325 +497,108275,1.02931,1.09296,1.1149,0.70044,0.58212,0.63151,0.46853,0.98231,0.942,1.09818,0.001816,0.001816,0.001816 +498,108519,1.03811,1.0937,1.11727,0.70334,0.58143,0.63161,0.46871,0.98222,0.94171,1.09814,0.0017995,0.0017995,0.0017995 +499,108762,1.03611,1.09203,1.11437,0.70276,0.58114,0.63162,0.46898,0.98217,0.94162,1.09792,0.001783,0.001783,0.001783 
+500,109006,1.03823,1.0872,1.11465,0.70264,0.58068,0.6318,0.46925,0.98224,0.94134,1.09798,0.0017665,0.0017665,0.0017665 +501,109249,1.03321,1.08628,1.11374,0.70455,0.57957,0.6319,0.46929,0.98212,0.941,1.09775,0.00175,0.00175,0.00175 +502,109492,1.03289,1.08675,1.11241,0.70834,0.57821,0.63211,0.46944,0.982,0.94075,1.09764,0.0017335,0.0017335,0.0017335 +503,109735,1.02994,1.08542,1.11213,0.70869,0.57856,0.63222,0.46955,0.98221,0.94043,1.09757,0.001717,0.001717,0.001717 +504,109978,1.02994,1.0838,1.11348,0.70834,0.57901,0.63193,0.46916,0.98222,0.94011,1.09755,0.0017005,0.0017005,0.0017005 +505,110221,1.03067,1.08202,1.11322,0.70855,0.57946,0.6323,0.46947,0.98247,0.93967,1.09762,0.001684,0.001684,0.001684 +506,110464,1.0269,1.07706,1.11143,0.70599,0.58114,0.63251,0.46946,0.98241,0.93937,1.09731,0.0016675,0.0016675,0.0016675 +507,110708,1.03047,1.07995,1.11215,0.70754,0.58053,0.63262,0.46959,0.98229,0.93877,1.09726,0.001651,0.001651,0.001651 +508,110950,1.02503,1.07942,1.11244,0.70876,0.5795,0.63259,0.46976,0.98241,0.93861,1.09725,0.0016345,0.0016345,0.0016345 +509,111193,1.03545,1.07957,1.11256,0.7088,0.58004,0.63291,0.46975,0.98223,0.93814,1.09698,0.001618,0.001618,0.001618 +510,111437,1.02809,1.07269,1.11008,0.70665,0.58143,0.63307,0.46981,0.98225,0.93765,1.09678,0.0016015,0.0016015,0.0016015 +511,111680,1.02546,1.07202,1.10866,0.70808,0.581,0.63319,0.46986,0.98208,0.93733,1.09662,0.001585,0.001585,0.001585 +512,111924,1.02937,1.07036,1.11051,0.70661,0.58242,0.63358,0.47007,0.98188,0.93689,1.09652,0.0015685,0.0015685,0.0015685 +513,112167,1.02852,1.07359,1.10819,0.70701,0.58292,0.63405,0.47046,0.98177,0.93634,1.09633,0.001552,0.001552,0.001552 +514,112411,1.02341,1.06551,1.10993,0.70872,0.58266,0.63416,0.47066,0.98173,0.93591,1.09636,0.0015355,0.0015355,0.0015355 +515,112655,1.02344,1.06604,1.10798,0.70758,0.58266,0.63421,0.47053,0.98148,0.93543,1.09611,0.001519,0.001519,0.001519 +516,112899,1.0252,1.06311,1.10741,0.70631,0.58335,0.63479,0.4711,0.98121,0.93515,1.0958,0.0015025,0.0015025,0.0015025 +517,113143,1.02262,1.06369,1.1057,0.70542,0.58374,0.63532,0.47162,0.98095,0.93457,1.09549,0.001486,0.001486,0.001486 +518,113385,1.02161,1.06124,1.10735,0.70412,0.58342,0.63563,0.47202,0.98067,0.93392,1.09521,0.0014695,0.0014695,0.0014695 +519,113628,1.02506,1.06438,1.10562,0.70564,0.58307,0.63593,0.47228,0.98051,0.93334,1.09505,0.001453,0.001453,0.001453 +520,113871,1.01896,1.05831,1.10829,0.70401,0.58352,0.63592,0.47255,0.98033,0.93321,1.09485,0.0014365,0.0014365,0.0014365 +521,114114,1.02176,1.06032,1.10504,0.70481,0.58275,0.63625,0.47258,0.9805,0.93276,1.09484,0.00142,0.00142,0.00142 +522,114357,1.02245,1.05774,1.10657,0.70367,0.58362,0.63648,0.47261,0.98057,0.93225,1.09483,0.0014035,0.0014035,0.0014035 +523,114599,1.01865,1.0557,1.10535,0.70523,0.58324,0.637,0.47293,0.98052,0.93216,1.09461,0.001387,0.001387,0.001387 +524,114842,1.01661,1.0563,1.10519,0.70539,0.5831,0.6372,0.47287,0.98061,0.93178,1.09445,0.0013705,0.0013705,0.0013705 +525,115084,1.01269,1.05091,1.10399,0.70743,0.58225,0.63732,0.47304,0.98055,0.93127,1.09443,0.001354,0.001354,0.001354 +526,115327,1.01876,1.05429,1.10405,0.70852,0.58129,0.63765,0.47339,0.9804,0.93084,1.09428,0.0013375,0.0013375,0.0013375 +527,115569,1.01835,1.05281,1.10364,0.70707,0.58298,0.63796,0.47341,0.9804,0.93055,1.09408,0.001321,0.001321,0.001321 +528,115812,1.0161,1.04894,1.10464,0.70723,0.58327,0.6381,0.47353,0.98021,0.9305,1.09399,0.0013045,0.0013045,0.0013045 
+529,116054,1.01879,1.04654,1.10301,0.71076,0.58213,0.63831,0.47379,0.9797,0.93026,1.09371,0.001288,0.001288,0.001288 +530,116296,1.0132,1.04257,1.10305,0.70771,0.58468,0.63846,0.47385,0.9795,0.93,1.0934,0.0012715,0.0012715,0.0012715 +531,116540,1.01734,1.04487,1.1048,0.70773,0.58428,0.63856,0.4738,0.97952,0.92992,1.09344,0.001255,0.001255,0.001255 +532,116782,1.01424,1.03966,1.10036,0.70006,0.58185,0.63851,0.47383,0.97957,0.92994,1.09338,0.0012385,0.0012385,0.0012385 +533,117025,1.01207,1.03593,1.09886,0.7022,0.58091,0.63864,0.47383,0.97946,0.92939,1.09322,0.001222,0.001222,0.001222 +534,117267,1.01637,1.04308,1.10397,0.71392,0.58113,0.63881,0.47366,0.9793,0.92914,1.09318,0.0012055,0.0012055,0.0012055 +535,117510,1.01024,1.03543,1.09835,0.71495,0.58133,0.63905,0.47412,0.9792,0.92884,1.09297,0.001189,0.001189,0.001189 +536,117753,1.0193,1.04082,1.10426,0.71203,0.58414,0.6392,0.47406,0.97908,0.92845,1.09286,0.0011725,0.0011725,0.0011725 +537,117996,1.01584,1.03547,1.10211,0.71587,0.58223,0.63935,0.4743,0.9789,0.92818,1.09273,0.001156,0.001156,0.001156 +538,118238,1.00738,1.03822,1.10027,0.7124,0.58453,0.63936,0.47429,0.97875,0.92799,1.09249,0.0011395,0.0011395,0.0011395 +539,118480,1.01194,1.0313,1.10036,0.71424,0.58375,0.63959,0.47445,0.9785,0.92774,1.09229,0.001123,0.001123,0.001123 +540,118723,1.00982,1.02992,1.0975,0.71625,0.58275,0.63973,0.47489,0.97815,0.92738,1.09187,0.0011065,0.0011065,0.0011065 +541,118966,1.00811,1.02869,1.1001,0.71379,0.58492,0.63969,0.47461,0.97841,0.92726,1.09197,0.00109,0.00109,0.00109 +542,119208,1.01178,1.03117,1.09914,0.71522,0.58494,0.64014,0.47484,0.97819,0.92743,1.09163,0.0010735,0.0010735,0.0010735 +543,119451,1.00449,1.0223,1.09608,0.71641,0.58424,0.63994,0.47492,0.978,0.92744,1.09162,0.001057,0.001057,0.001057 +544,119693,1.00558,1.02435,1.09727,0.71772,0.58342,0.64011,0.47523,0.97763,0.92733,1.09126,0.0010405,0.0010405,0.0010405 +545,119935,1.00212,1.01777,1.09537,0.72117,0.58153,0.64019,0.47524,0.97744,0.92709,1.091,0.001024,0.001024,0.001024 +546,120178,1.00274,1.01987,1.09457,0.72177,0.58115,0.64022,0.47524,0.97729,0.92707,1.09081,0.0010075,0.0010075,0.0010075 +547,120420,1.00508,1.01699,1.09488,0.72109,0.5817,0.64038,0.47529,0.9771,0.92681,1.09068,0.000991,0.000991,0.000991 +548,120663,1.00133,1.01371,1.09607,0.7212,0.58149,0.64058,0.47553,0.97688,0.92694,1.09053,0.0009745,0.0009745,0.0009745 +549,120905,1.00477,1.01758,1.09519,0.71615,0.58458,0.64096,0.47585,0.97671,0.92699,1.09033,0.000958,0.000958,0.000958 +550,121148,0.99668,1.01026,1.09207,0.71694,0.58347,0.64082,0.47599,0.97636,0.92656,1.09012,0.0009415,0.0009415,0.0009415 +551,121391,0.99632,1.00957,1.09255,0.71567,0.58383,0.64126,0.47588,0.97631,0.92633,1.09002,0.000925,0.000925,0.000925 +552,121634,1.00015,1.0064,1.09534,0.7168,0.58326,0.64141,0.47602,0.97642,0.92616,1.09009,0.0009085,0.0009085,0.0009085 +553,121876,0.99796,1.00482,1.09317,0.71471,0.58483,0.64192,0.4763,0.97646,0.92581,1.09027,0.000892,0.000892,0.000892 +554,122119,0.99659,0.99899,1.09238,0.71774,0.5836,0.64167,0.47625,0.97638,0.92538,1.09038,0.0008755,0.0008755,0.0008755 +555,122362,0.99337,1.00103,1.091,0.71791,0.58348,0.64215,0.47638,0.97616,0.92534,1.09023,0.000859,0.000859,0.000859 +556,122604,0.99998,1.00144,1.09274,0.71972,0.58374,0.64268,0.47646,0.97593,0.92534,1.09003,0.0008425,0.0008425,0.0008425 +557,122846,0.98852,0.99253,1.08925,0.71881,0.58458,0.64279,0.47643,0.97615,0.92512,1.09024,0.000826,0.000826,0.000826 
+558,123089,0.99172,0.99237,1.08919,0.72175,0.58284,0.64287,0.47655,0.9758,0.92504,1.08997,0.0008095,0.0008095,0.0008095 +559,123331,0.98832,0.98923,1.08622,0.72158,0.58312,0.64298,0.4767,0.97582,0.92467,1.09003,0.000793,0.000793,0.000793 +560,123574,0.99305,0.99299,1.08984,0.72439,0.58106,0.64321,0.47714,0.97578,0.9243,1.08997,0.0007765,0.0007765,0.0007765 +561,123816,0.99218,0.98495,1.08695,0.72336,0.58182,0.64337,0.47717,0.97573,0.92398,1.08997,0.00076,0.00076,0.00076 +562,124059,0.99009,0.98491,1.08901,0.72446,0.58108,0.64366,0.47713,0.97567,0.92385,1.08986,0.0007435,0.0007435,0.0007435 +563,124301,0.99108,0.98472,1.08629,0.71776,0.58559,0.6438,0.47723,0.97568,0.92378,1.08988,0.000727,0.000727,0.000727 +564,124544,0.98826,0.98007,1.08487,0.71443,0.58614,0.64376,0.4774,0.97539,0.92379,1.08953,0.0007105,0.0007105,0.0007105 +565,124787,0.99125,0.98529,1.08654,0.70966,0.58852,0.64374,0.4776,0.97544,0.92384,1.08948,0.000694,0.000694,0.000694 +566,125030,0.98708,0.97363,1.08409,0.69872,0.58774,0.64394,0.47793,0.97509,0.92354,1.08922,0.0006775,0.0006775,0.0006775 +567,125274,0.98538,0.97066,1.08327,0.69756,0.58802,0.64403,0.47806,0.9754,0.92363,1.08936,0.000661,0.000661,0.000661 +568,125518,0.98325,0.96737,1.08322,0.69905,0.5881,0.64403,0.47797,0.97547,0.92365,1.08945,0.0006445,0.0006445,0.0006445 +569,125761,0.98081,0.97295,1.08285,0.70112,0.58744,0.64398,0.47807,0.97538,0.92364,1.08926,0.000628,0.000628,0.000628 +570,126003,0.9772,0.96564,1.08226,0.70396,0.58662,0.64389,0.47811,0.9753,0.9237,1.08925,0.0006115,0.0006115,0.0006115 +571,126246,0.9834,0.96476,1.08298,0.70645,0.58555,0.64412,0.47816,0.97522,0.92338,1.08927,0.000595,0.000595,0.000595 +572,126490,0.98214,0.96506,1.08217,0.70468,0.58701,0.64427,0.47833,0.97515,0.92337,1.08906,0.0005785,0.0005785,0.0005785 +573,126734,0.97833,0.95529,1.0799,0.70554,0.58738,0.64405,0.47805,0.97501,0.92318,1.08892,0.000562,0.000562,0.000562 +574,126978,0.97651,0.95645,1.0791,0.70674,0.58611,0.64432,0.47843,0.97506,0.9229,1.08906,0.0005455,0.0005455,0.0005455 +575,127221,0.97936,0.95922,1.08073,0.70609,0.58705,0.64459,0.47863,0.97495,0.9228,1.08906,0.000529,0.000529,0.000529 +576,127465,0.9792,0.95435,1.07988,0.70442,0.58854,0.64454,0.47849,0.97483,0.92267,1.08874,0.0005125,0.0005125,0.0005125 +577,127708,0.97175,0.94756,1.07573,0.69811,0.59264,0.6444,0.4784,0.97511,0.92252,1.08888,0.000496,0.000496,0.000496 +578,127950,0.97559,0.94787,1.07699,0.69844,0.59262,0.64454,0.47847,0.97508,0.92223,1.08891,0.0004795,0.0004795,0.0004795 +579,128193,0.9688,0.94227,1.07742,0.6966,0.59301,0.64458,0.47844,0.9752,0.92197,1.08898,0.000463,0.000463,0.000463 +580,128436,0.97323,0.94231,1.07545,0.69677,0.59273,0.64476,0.47866,0.97519,0.92194,1.08897,0.0004465,0.0004465,0.0004465 +581,128678,0.97049,0.94019,1.07627,0.6998,0.59189,0.6446,0.47869,0.97488,0.92194,1.08869,0.00043,0.00043,0.00043 +582,128921,0.96636,0.93575,1.07325,0.70108,0.59114,0.64471,0.47871,0.97492,0.92177,1.08858,0.0004135,0.0004135,0.0004135 +583,129164,0.97257,0.93516,1.0739,0.70287,0.5908,0.64507,0.47877,0.97493,0.92146,1.08854,0.000397,0.000397,0.000397 +584,129406,0.96405,0.92943,1.07271,0.70502,0.58937,0.64481,0.47885,0.97493,0.92133,1.08861,0.0003805,0.0003805,0.0003805 +585,129649,0.96674,0.92486,1.07124,0.70722,0.58826,0.64492,0.47896,0.97478,0.92128,1.08831,0.000364,0.000364,0.000364 +586,129891,0.96621,0.92939,1.07187,0.70763,0.58729,0.6448,0.47891,0.97454,0.92131,1.08815,0.0003475,0.0003475,0.0003475 
+587,130134,0.96266,0.92377,1.06983,0.70667,0.58807,0.64525,0.47918,0.97469,0.92115,1.08828,0.000331,0.000331,0.000331 +588,130376,0.95783,0.9165,1.06783,0.70664,0.5885,0.64522,0.47885,0.97478,0.92109,1.08826,0.0003145,0.0003145,0.0003145 +589,130620,0.95944,0.91679,1.06756,0.70556,0.58985,0.64511,0.47879,0.97473,0.92124,1.08825,0.000298,0.000298,0.000298 +590,130863,0.9609,0.91361,1.0675,0.70518,0.58973,0.64533,0.4788,0.97488,0.92115,1.08816,0.0002815,0.0002815,0.0002815 +591,131098,0.94799,0.84742,1.08649,0.70632,0.58926,0.64534,0.4789,0.97481,0.92143,1.08792,0.000265,0.000265,0.000265 +592,131331,0.94168,0.83085,1.08185,0.70456,0.5917,0.64548,0.47905,0.97475,0.9215,1.08758,0.0002485,0.0002485,0.0002485 +593,131565,0.94029,0.8265,1.08297,0.70501,0.59244,0.64568,0.47896,0.97465,0.92168,1.08735,0.000232,0.000232,0.000232 +594,131797,0.93413,0.81392,1.08039,0.70235,0.59373,0.64576,0.47905,0.97434,0.92178,1.08704,0.0002155,0.0002155,0.0002155 +595,132028,0.93433,0.809,1.07648,0.70354,0.59324,0.6461,0.47918,0.97447,0.92169,1.08692,0.000199,0.000199,0.000199 +596,132259,0.93404,0.8069,1.07917,0.70245,0.59382,0.646,0.47918,0.97441,0.92201,1.08697,0.0001825,0.0001825,0.0001825 +597,132490,0.92453,0.79685,1.07327,0.70321,0.59415,0.64592,0.47894,0.97449,0.92228,1.08705,0.000166,0.000166,0.000166 +598,132724,0.92831,0.7903,1.07276,0.70699,0.59081,0.64616,0.47899,0.97461,0.92306,1.08692,0.0001495,0.0001495,0.0001495 +599,132958,0.92208,0.78767,1.06982,0.70802,0.58939,0.64615,0.47899,0.97471,0.9234,1.08693,0.000133,0.000133,0.000133 +600,133191,0.91942,0.7776,1.0726,0.70844,0.58907,0.64601,0.47899,0.97493,0.92379,1.08718,0.0001165,0.0001165,0.0001165 diff --git a/logs/yolov12x.csv b/logs/yolov12x.csv new file mode 100644 index 0000000000000000000000000000000000000000..ff6a2a1c505ce809601491f6aab43e26cd163190 --- /dev/null +++ b/logs/yolov12x.csv @@ -0,0 +1,601 @@ +epoch,time,train/box_loss,train/cls_loss,train/dfl_loss,metrics/precision(B),metrics/recall(B),metrics/mAP50(B),metrics/mAP50-95(B),val/box_loss,val/cls_loss,val/dfl_loss,lr/pg0,lr/pg1,lr/pg2 +1,519.197,3.54187,5.39907,3.94662,0.00135,0.02527,0.00113,0.00037,2.88596,4.45246,3.35333,0.00332613,0.00332613,0.00332613 +2,1039.78,2.26566,3.79278,2.34889,0.20557,0.07287,0.03583,0.01831,1.92128,3.24893,2.16043,0.00664848,0.00664848,0.00664848 +3,1556.86,1.78658,3.06815,1.83309,0.24614,0.13364,0.09034,0.05312,1.62598,2.69211,1.83255,0.00995982,0.00995982,0.00995982 +4,2072.73,1.56497,2.58273,1.62293,0.30879,0.20758,0.17107,0.10605,1.43363,2.27481,1.65635,0.0099505,0.0099505,0.0099505 +5,2587.87,1.43536,2.2441,1.51289,0.37642,0.26308,0.24113,0.15775,1.33387,2.00528,1.56117,0.009934,0.009934,0.009934 +6,3103.71,1.35558,2.04958,1.45167,0.42429,0.29737,0.28436,0.18978,1.26977,1.82335,1.49348,0.0099175,0.0099175,0.0099175 +7,3617.05,1.30563,1.90673,1.41003,0.45979,0.32491,0.32429,0.21886,1.2298,1.70068,1.45036,0.009901,0.009901,0.009901 +8,4131.44,1.26857,1.81472,1.37504,0.46949,0.35387,0.35481,0.24208,1.18722,1.62273,1.416,0.0098845,0.0098845,0.0098845 +9,4646.49,1.23735,1.73146,1.35265,0.50558,0.36292,0.37985,0.26353,1.16471,1.55663,1.3882,0.009868,0.009868,0.009868 +10,5160.82,1.21154,1.66658,1.32902,0.51379,0.3941,0.40967,0.28558,1.13229,1.46027,1.36062,0.0098515,0.0098515,0.0098515 +11,5675.44,1.1941,1.63388,1.31402,0.53026,0.40181,0.42135,0.29435,1.11622,1.42328,1.34151,0.009835,0.009835,0.009835 +12,6189.95,1.17611,1.5856,1.29553,0.53939,0.42102,0.44145,0.31071,1.10179,1.36632,1.32868,0.0098185,0.0098185,0.0098185 
+13,6702.71,1.16498,1.55847,1.28609,0.54834,0.43145,0.45573,0.32333,1.09025,1.33136,1.3153,0.009802,0.009802,0.009802 +14,7216.21,1.15052,1.51762,1.27086,0.56712,0.44174,0.46971,0.33287,1.07196,1.28876,1.29421,0.0097855,0.0097855,0.0097855 +15,7729.98,1.13863,1.50004,1.26264,0.57393,0.4455,0.47971,0.34316,1.05663,1.26032,1.28387,0.009769,0.009769,0.009769 +16,8243.71,1.13534,1.48283,1.26007,0.58461,0.45703,0.4924,0.35349,1.04528,1.22671,1.27165,0.0097525,0.0097525,0.0097525 +17,8758.73,1.12487,1.46143,1.24862,0.58888,0.46421,0.49804,0.35999,1.03979,1.21029,1.2656,0.009736,0.009736,0.009736 +18,9272.29,1.11643,1.43329,1.24251,0.59193,0.47442,0.5066,0.36505,1.03029,1.18424,1.25755,0.0097195,0.0097195,0.0097195 +19,9786.44,1.10631,1.4173,1.23572,0.61492,0.47463,0.51547,0.37358,1.02302,1.17187,1.24978,0.009703,0.009703,0.009703 +20,10301.1,1.09932,1.41067,1.23094,0.62253,0.48335,0.52436,0.38043,1.0186,1.14988,1.24516,0.0096865,0.0096865,0.0096865 +21,10816.3,1.09889,1.39934,1.2266,0.61884,0.48843,0.52903,0.38518,1.01018,1.13691,1.23654,0.00967,0.00967,0.00967 +22,11329.7,1.08903,1.3861,1.22041,0.61718,0.49661,0.53729,0.39191,1.00388,1.12673,1.23151,0.0096535,0.0096535,0.0096535 +23,11845.4,1.08583,1.37412,1.21556,0.6293,0.49529,0.53967,0.39457,1.0012,1.11215,1.22857,0.009637,0.009637,0.009637 +24,12359.2,1.08383,1.35927,1.21449,0.628,0.50289,0.54334,0.39682,0.9961,1.10021,1.22397,0.0096205,0.0096205,0.0096205 +25,12873.8,1.07419,1.3482,1.20669,0.63029,0.51138,0.55016,0.40314,0.99375,1.09061,1.22147,0.009604,0.009604,0.009604 +26,13388.2,1.07558,1.3432,1.20773,0.64691,0.50716,0.55524,0.40699,0.99048,1.08465,1.21653,0.0095875,0.0095875,0.0095875 +27,13902,1.07517,1.33847,1.20573,0.65858,0.50665,0.55913,0.41031,0.98839,1.07615,1.21479,0.009571,0.009571,0.009571 +28,14415.5,1.06935,1.32877,1.20234,0.65202,0.51252,0.56183,0.4131,0.98315,1.06965,1.2113,0.0095545,0.0095545,0.0095545 +29,14929.3,1.06222,1.32072,1.19883,0.65137,0.51323,0.56472,0.4162,0.97934,1.06411,1.20827,0.009538,0.009538,0.009538 +30,15443.1,1.06454,1.31837,1.20061,0.645,0.51868,0.56734,0.41825,0.97713,1.05758,1.20592,0.0095215,0.0095215,0.0095215 +31,15958.2,1.05657,1.30569,1.19365,0.64729,0.52038,0.56923,0.41917,0.97672,1.05265,1.20505,0.009505,0.009505,0.009505 +32,16472.7,1.05529,1.30733,1.19364,0.65713,0.51943,0.57028,0.42066,0.97574,1.04958,1.20383,0.0094885,0.0094885,0.0094885 +33,16988.2,1.05261,1.28473,1.18755,0.65963,0.52125,0.5723,0.42184,0.97419,1.04564,1.20241,0.009472,0.009472,0.009472 +34,17502.2,1.05511,1.29261,1.18909,0.6552,0.52306,0.57421,0.42348,0.97265,1.0425,1.20043,0.0094555,0.0094555,0.0094555 +35,18016.5,1.051,1.28983,1.18874,0.65736,0.52581,0.57607,0.42491,0.97171,1.03951,1.19918,0.009439,0.009439,0.009439 +36,18531.6,1.04787,1.28146,1.1831,0.6569,0.52805,0.5776,0.42569,0.97043,1.03606,1.19799,0.0094225,0.0094225,0.0094225 +37,19044.6,1.0444,1.26957,1.18092,0.66053,0.52653,0.57908,0.42699,0.96998,1.03371,1.1974,0.009406,0.009406,0.009406 +38,19559.3,1.04523,1.2675,1.18324,0.66108,0.52675,0.58028,0.42831,0.96888,1.03203,1.19607,0.0093895,0.0093895,0.0093895 +39,20074.6,1.03974,1.25846,1.17811,0.6603,0.52761,0.58106,0.42896,0.96798,1.03033,1.19504,0.009373,0.009373,0.009373 +40,20588.7,1.03609,1.25399,1.17536,0.6649,0.52624,0.58182,0.42988,0.96724,1.0286,1.19407,0.0093565,0.0093565,0.0093565 +41,21102.5,1.0362,1.25297,1.17383,0.66638,0.52686,0.58315,0.43106,0.96645,1.02712,1.19303,0.00934,0.00934,0.00934 
+42,21616.2,1.03968,1.25236,1.17701,0.66965,0.52646,0.5837,0.43166,0.96578,1.02612,1.19231,0.0093235,0.0093235,0.0093235 +43,22130.7,1.03559,1.24917,1.17361,0.66947,0.52819,0.58449,0.4322,0.96521,1.02507,1.19173,0.009307,0.009307,0.009307 +44,22644.7,1.02846,1.23994,1.16984,0.66794,0.52882,0.5849,0.43259,0.9645,1.02457,1.19119,0.0092905,0.0092905,0.0092905 +45,23158.8,1.02691,1.23556,1.16866,0.66495,0.53059,0.58518,0.43292,0.96415,1.02387,1.1907,0.009274,0.009274,0.009274 +46,23673.5,1.03132,1.23385,1.16885,0.66288,0.53157,0.58578,0.43353,0.96364,1.02337,1.19024,0.0092575,0.0092575,0.0092575 +47,24188.3,1.02558,1.23528,1.16937,0.66335,0.53274,0.58609,0.43404,0.96326,1.02267,1.18983,0.009241,0.009241,0.009241 +48,24703,1.02296,1.22214,1.16471,0.6648,0.53243,0.58644,0.4343,0.96306,1.02205,1.18979,0.0092245,0.0092245,0.0092245 +49,25216.7,1.02073,1.22128,1.16507,0.66831,0.53143,0.5866,0.43459,0.96255,1.02206,1.18942,0.009208,0.009208,0.009208 +50,25731,1.02183,1.22089,1.16557,0.66852,0.53203,0.58691,0.43476,0.96206,1.02204,1.18891,0.0091915,0.0091915,0.0091915 +51,26245.2,1.02213,1.22026,1.16573,0.67191,0.53054,0.58702,0.43496,0.96181,1.02198,1.18869,0.009175,0.009175,0.009175 +52,26758.7,1.02122,1.21556,1.16354,0.66931,0.53117,0.58686,0.43512,0.96196,1.02241,1.1887,0.0091585,0.0091585,0.0091585 +53,27271.1,1.01563,1.21214,1.16468,0.66867,0.5313,0.58671,0.43537,0.96161,1.02325,1.1884,0.009142,0.009142,0.009142 +54,27783.6,1.01253,1.20259,1.15838,0.67038,0.53122,0.58676,0.43542,0.96131,1.02447,1.18808,0.0091255,0.0091255,0.0091255 +55,28296.5,1.01414,1.20459,1.15841,0.67286,0.52877,0.58668,0.43564,0.96155,1.02542,1.18819,0.009109,0.009109,0.009109 +56,28809.2,1.01043,1.19468,1.1581,0.67093,0.52858,0.58685,0.43556,0.96161,1.02671,1.18841,0.0090925,0.0090925,0.0090925 +57,29321.9,1.01141,1.1924,1.15543,0.669,0.52916,0.58715,0.43572,0.96163,1.02782,1.18851,0.009076,0.009076,0.009076 +58,29835.1,1.01087,1.19599,1.15699,0.66692,0.53268,0.587,0.43585,0.9616,1.0293,1.18858,0.0090595,0.0090595,0.0090595 +59,30348.1,1.0046,1.19077,1.15484,0.66985,0.53169,0.58701,0.43584,0.9617,1.03071,1.18893,0.009043,0.009043,0.009043 +60,30859.6,1.00777,1.18869,1.15366,0.66769,0.53251,0.58674,0.43581,0.96194,1.03283,1.18917,0.0090265,0.0090265,0.0090265 +61,31370.7,1.0071,1.18782,1.15381,0.66516,0.5339,0.58676,0.4358,0.9625,1.03523,1.18961,0.00901,0.00901,0.00901 +62,31883,1.00448,1.18366,1.15284,0.66346,0.5336,0.58674,0.43582,0.96271,1.03724,1.18981,0.0089935,0.0089935,0.0089935 +63,32396.2,1.00159,1.17664,1.14938,0.66235,0.53348,0.58649,0.43567,0.96305,1.03906,1.18997,0.008977,0.008977,0.008977 +64,32911.6,1.00633,1.18792,1.15075,0.66203,0.53259,0.58608,0.43559,0.96304,1.04113,1.19016,0.0089605,0.0089605,0.0089605 +65,33425.4,0.99816,1.17658,1.14869,0.66388,0.53159,0.5861,0.43545,0.96317,1.04294,1.19043,0.008944,0.008944,0.008944 +66,33939.7,1.00632,1.17596,1.15121,0.66699,0.52991,0.58612,0.43556,0.96337,1.04524,1.19076,0.0089275,0.0089275,0.0089275 +67,34453.2,1.00294,1.17829,1.14946,0.66885,0.52938,0.58612,0.43545,0.96351,1.04686,1.19103,0.008911,0.008911,0.008911 +68,34967.8,0.99944,1.17945,1.15014,0.66934,0.52872,0.58614,0.4353,0.96391,1.04876,1.1913,0.0088945,0.0088945,0.0088945 +69,35481,0.99988,1.17481,1.14835,0.6718,0.52683,0.58569,0.43518,0.96413,1.05046,1.19155,0.008878,0.008878,0.008878 +70,35995.4,0.99871,1.16537,1.14561,0.67027,0.52705,0.58531,0.43499,0.96407,1.05189,1.19165,0.0088615,0.0088615,0.0088615 
+71,36508.6,0.99437,1.16295,1.14794,0.67278,0.52579,0.58528,0.43516,0.96384,1.05337,1.19175,0.008845,0.008845,0.008845 +72,37023.7,0.99383,1.16026,1.14389,0.67012,0.52753,0.5851,0.43478,0.96359,1.05504,1.19159,0.0088285,0.0088285,0.0088285 +73,37537.9,0.99872,1.16772,1.14907,0.67104,0.52642,0.58495,0.43464,0.96357,1.05715,1.19171,0.008812,0.008812,0.008812 +74,38051,0.99536,1.16189,1.14779,0.66991,0.52682,0.58507,0.43482,0.96319,1.05836,1.19149,0.0087955,0.0087955,0.0087955 +75,38563.9,0.99289,1.15642,1.14483,0.67038,0.52669,0.58485,0.43492,0.96236,1.05949,1.19097,0.008779,0.008779,0.008779 +76,39076.6,0.98982,1.1535,1.1428,0.67335,0.52534,0.58488,0.43488,0.9622,1.061,1.19084,0.0087625,0.0087625,0.0087625 +77,39589.4,0.99087,1.15265,1.14252,0.67127,0.52713,0.58499,0.43502,0.96268,1.06234,1.19122,0.008746,0.008746,0.008746 +78,40100.1,0.98943,1.15501,1.14393,0.67022,0.52782,0.58494,0.43515,0.96257,1.06395,1.19106,0.0087295,0.0087295,0.0087295 +79,40611.9,0.98891,1.15213,1.14258,0.67211,0.52747,0.58522,0.43554,0.96247,1.0652,1.19085,0.008713,0.008713,0.008713 +80,41123.4,0.99089,1.15194,1.14431,0.67333,0.52715,0.5853,0.43591,0.96217,1.06589,1.19068,0.0086965,0.0086965,0.0086965 +81,41635.4,0.98716,1.14848,1.1422,0.67439,0.52627,0.58557,0.43612,0.96096,1.06672,1.18971,0.00868,0.00868,0.00868 +82,42148.8,0.98561,1.14577,1.14194,0.67577,0.52588,0.58575,0.43617,0.96094,1.06701,1.18964,0.0086635,0.0086635,0.0086635 +83,42660.4,0.98607,1.14507,1.1407,0.67439,0.52696,0.58603,0.43654,0.96022,1.06735,1.18887,0.008647,0.008647,0.008647 +84,43171.5,0.98708,1.14574,1.14091,0.67512,0.52762,0.58635,0.43688,0.95988,1.06734,1.18831,0.0086305,0.0086305,0.0086305 +85,43683.5,0.9886,1.14708,1.14071,0.67299,0.5286,0.58673,0.43739,0.95953,1.06738,1.18791,0.008614,0.008614,0.008614 +86,44193.6,0.98393,1.14218,1.13912,0.67202,0.52997,0.58724,0.43795,0.95905,1.06722,1.18744,0.0085975,0.0085975,0.0085975 +87,44706.5,0.98416,1.13988,1.13871,0.67527,0.52989,0.5876,0.43831,0.95895,1.0667,1.18715,0.008581,0.008581,0.008581 +88,45218.8,0.98281,1.14194,1.13918,0.67362,0.53055,0.58768,0.43864,0.95868,1.06617,1.18675,0.0085645,0.0085645,0.0085645 +89,45731.5,0.98301,1.14151,1.13858,0.67379,0.53193,0.58825,0.43905,0.95827,1.06539,1.18626,0.008548,0.008548,0.008548 +90,46245,0.98081,1.1356,1.13651,0.67455,0.53166,0.58853,0.43962,0.95804,1.06483,1.186,0.0085315,0.0085315,0.0085315 +91,46757.1,0.97967,1.13294,1.13566,0.67585,0.53128,0.58868,0.43998,0.95731,1.06361,1.18524,0.008515,0.008515,0.008515 +92,47270.9,0.97796,1.13094,1.13705,0.67539,0.53259,0.58926,0.44057,0.95679,1.06227,1.18457,0.0084985,0.0084985,0.0084985 +93,47784.5,0.98279,1.13133,1.13662,0.67545,0.5321,0.58972,0.44139,0.95631,1.06094,1.18411,0.008482,0.008482,0.008482 +94,48296.9,0.98253,1.13063,1.1359,0.67577,0.5317,0.59012,0.44174,0.9557,1.05953,1.18347,0.0084655,0.0084655,0.0084655 +95,48810.2,0.97899,1.12829,1.13611,0.67566,0.5317,0.5908,0.44258,0.95518,1.0582,1.18292,0.008449,0.008449,0.008449 +96,49323,0.98116,1.13461,1.13854,0.67609,0.53005,0.59089,0.44284,0.95466,1.05668,1.18234,0.0084325,0.0084325,0.0084325 +97,49836.7,0.98205,1.13051,1.13514,0.67679,0.53081,0.59133,0.44318,0.95422,1.05479,1.18171,0.008416,0.008416,0.008416 +98,50349.6,0.97968,1.12794,1.13681,0.67598,0.53205,0.59191,0.4436,0.95368,1.05277,1.18111,0.0083995,0.0083995,0.0083995 +99,50863.2,0.97382,1.11973,1.13315,0.67765,0.53248,0.59266,0.44433,0.95304,1.051,1.18045,0.008383,0.008383,0.008383 
+100,51378.5,0.98033,1.12578,1.13681,0.67526,0.53502,0.59344,0.44497,0.95211,1.04949,1.17976,0.0083665,0.0083665,0.0083665 +101,51891.1,0.97779,1.12237,1.13515,0.67482,0.53646,0.59396,0.4454,0.95162,1.0479,1.17918,0.00835,0.00835,0.00835 +102,52403.4,0.97653,1.1246,1.13375,0.67508,0.53756,0.59484,0.44599,0.9513,1.04617,1.17863,0.0083335,0.0083335,0.0083335 +103,52917.5,0.97514,1.12094,1.13205,0.67652,0.53701,0.59564,0.44667,0.95071,1.04414,1.17806,0.008317,0.008317,0.008317 +104,53432.2,0.97449,1.11983,1.1303,0.67634,0.53803,0.59642,0.44732,0.95,1.04198,1.17732,0.0083005,0.0083005,0.0083005 +105,53944.6,0.97163,1.11507,1.13142,0.67731,0.53808,0.59694,0.4478,0.94931,1.03999,1.17656,0.008284,0.008284,0.008284 +106,54457.8,0.97377,1.11771,1.1317,0.67511,0.5418,0.59768,0.4484,0.94883,1.03816,1.17581,0.0082675,0.0082675,0.0082675 +107,54971.6,0.97578,1.11962,1.13212,0.67556,0.54317,0.59822,0.44902,0.94806,1.03616,1.1749,0.008251,0.008251,0.008251 +108,55486.4,0.97566,1.11663,1.13219,0.67837,0.54286,0.59891,0.44964,0.94746,1.03394,1.17404,0.0082345,0.0082345,0.0082345 +109,55999,0.97451,1.1119,1.12966,0.67886,0.54279,0.59945,0.45019,0.94651,1.03209,1.17315,0.008218,0.008218,0.008218 +110,56510.5,0.96892,1.10897,1.12974,0.68022,0.54312,0.60007,0.45068,0.94571,1.03001,1.17231,0.0082015,0.0082015,0.0082015 +111,57023,0.96861,1.10718,1.12906,0.68035,0.54236,0.60071,0.4514,0.94493,1.02755,1.1714,0.008185,0.008185,0.008185 +112,57535.5,0.97086,1.11249,1.12921,0.68086,0.54293,0.60143,0.45201,0.94447,1.02565,1.17082,0.0081685,0.0081685,0.0081685 +113,58048.5,0.96757,1.10783,1.12725,0.67786,0.54517,0.60224,0.45267,0.94377,1.02363,1.17001,0.008152,0.008152,0.008152 +114,58562.3,0.97423,1.1146,1.13216,0.67842,0.54605,0.60274,0.45318,0.94301,1.0213,1.16914,0.0081355,0.0081355,0.0081355 +115,59075.8,0.97529,1.11375,1.13164,0.67692,0.54769,0.60326,0.45385,0.94226,1.01926,1.16838,0.008119,0.008119,0.008119 +116,59589.7,0.96704,1.10348,1.12932,0.67507,0.5481,0.6042,0.45467,0.94169,1.01676,1.16763,0.0081025,0.0081025,0.0081025 +117,60101.7,0.96628,1.1043,1.12754,0.6758,0.54912,0.60509,0.45525,0.94087,1.01442,1.16683,0.008086,0.008086,0.008086 +118,60615.2,0.96584,1.10573,1.12776,0.67649,0.5499,0.60589,0.45578,0.94044,1.01238,1.16617,0.0080695,0.0080695,0.0080695 +119,61129.6,0.96295,1.10221,1.12421,0.6782,0.55086,0.60653,0.45647,0.94,1.01054,1.16555,0.008053,0.008053,0.008053 +120,61642.1,0.96642,1.09952,1.12606,0.67883,0.55171,0.60741,0.45743,0.94001,1.00823,1.16528,0.0080365,0.0080365,0.0080365 +121,62156.3,0.96777,1.10488,1.12596,0.67971,0.55324,0.60805,0.4576,0.93945,1.0062,1.16447,0.00802,0.00802,0.00802 +122,62669.2,0.96639,1.10679,1.12759,0.68119,0.55334,0.60861,0.45815,0.93924,1.00431,1.16405,0.0080035,0.0080035,0.0080035 +123,63182.8,0.96028,1.09847,1.12553,0.67914,0.55583,0.60922,0.4587,0.93888,1.00228,1.16354,0.007987,0.007987,0.007987 +124,63696.4,0.96825,1.09742,1.12423,0.67934,0.55568,0.60999,0.45923,0.93843,1.00048,1.16295,0.0079705,0.0079705,0.0079705 +125,64209.1,0.96421,1.09555,1.12508,0.68028,0.55522,0.61066,0.45983,0.93767,0.99822,1.16211,0.007954,0.007954,0.007954 +126,64722.4,0.96334,1.09789,1.12401,0.67952,0.55687,0.61144,0.46057,0.93707,0.99599,1.16143,0.0079375,0.0079375,0.0079375 +127,65235.5,0.96051,1.09589,1.12236,0.68151,0.55819,0.61198,0.46098,0.93643,0.9939,1.16079,0.007921,0.007921,0.007921 +128,65749.3,0.96132,1.09442,1.12161,0.68559,0.5571,0.61358,0.46221,0.93534,0.99199,1.1597,0.0079045,0.0079045,0.0079045 
+129,66261.8,0.96349,1.09581,1.12529,0.68552,0.55734,0.61419,0.46291,0.93527,0.99025,1.1595,0.007888,0.007888,0.007888 +130,66774.6,0.96357,1.09191,1.12399,0.68864,0.55683,0.61492,0.46351,0.93468,0.98851,1.15899,0.0078715,0.0078715,0.0078715 +131,67288,0.96465,1.09965,1.12634,0.69227,0.55662,0.6155,0.46385,0.93425,0.98692,1.1586,0.007855,0.007855,0.007855 +132,67801.7,0.95905,1.08893,1.12231,0.69316,0.55672,0.61603,0.46442,0.93386,0.98513,1.1582,0.0078385,0.0078385,0.0078385 +133,68314.8,0.95866,1.08922,1.12309,0.69481,0.55635,0.61694,0.46518,0.93358,0.98349,1.15785,0.007822,0.007822,0.007822 +134,68826.8,0.96015,1.08845,1.1219,0.69405,0.55714,0.61755,0.46571,0.93326,0.98141,1.15736,0.0078055,0.0078055,0.0078055 +135,69340.2,0.95847,1.08492,1.12114,0.69297,0.55825,0.61838,0.46613,0.93267,0.97904,1.15672,0.007789,0.007789,0.007789 +136,69854.1,0.95827,1.08368,1.12238,0.69559,0.55864,0.61904,0.46659,0.93201,0.97733,1.15605,0.0077725,0.0077725,0.0077725 +137,70366.3,0.9589,1.08838,1.12263,0.69526,0.55976,0.61923,0.46692,0.93139,0.97555,1.15535,0.007756,0.007756,0.007756 +138,70879.7,0.96073,1.08377,1.12146,0.69593,0.56001,0.61994,0.46746,0.93086,0.97351,1.15472,0.0077395,0.0077395,0.0077395 +139,71392.4,0.95745,1.0829,1.12386,0.6967,0.56073,0.62032,0.46797,0.93038,0.9718,1.15417,0.007723,0.007723,0.007723 +140,71905.9,0.95411,1.08448,1.12058,0.69059,0.56495,0.62092,0.46838,0.92981,0.97032,1.15355,0.0077065,0.0077065,0.0077065 +141,72419.5,0.95846,1.08642,1.1224,0.69187,0.5639,0.62152,0.46875,0.92948,0.96848,1.15321,0.00769,0.00769,0.00769 +142,72934.1,0.95979,1.0823,1.12114,0.69325,0.56358,0.62192,0.46919,0.92883,0.96711,1.1526,0.0076735,0.0076735,0.0076735 +143,73447.6,0.9572,1.07637,1.11987,0.69489,0.56348,0.62249,0.46978,0.9281,0.96511,1.15181,0.007657,0.007657,0.007657 +144,73962.2,0.95499,1.07725,1.11804,0.69731,0.56286,0.623,0.47017,0.92731,0.96344,1.15106,0.0076405,0.0076405,0.0076405 +145,74475.2,0.95378,1.07663,1.11805,0.69885,0.56212,0.62396,0.47066,0.92695,0.96178,1.15062,0.007624,0.007624,0.007624 +146,74988.3,0.95153,1.07003,1.11615,0.70034,0.56246,0.62413,0.47129,0.92628,0.96007,1.14983,0.0076075,0.0076075,0.0076075 +147,75503.1,0.95398,1.07974,1.11669,0.70211,0.56223,0.62461,0.47156,0.92562,0.95839,1.14915,0.007591,0.007591,0.007591 +148,76015.8,0.95797,1.08096,1.12008,0.70421,0.56087,0.62511,0.4721,0.92522,0.9567,1.14855,0.0075745,0.0075745,0.0075745 +149,76529.8,0.95661,1.0816,1.12131,0.70224,0.56359,0.62568,0.47267,0.92506,0.95514,1.14826,0.007558,0.007558,0.007558 +150,77040.7,0.95474,1.07945,1.11872,0.70104,0.56587,0.62627,0.47305,0.92438,0.95399,1.14769,0.0075415,0.0075415,0.0075415 +151,77554.1,0.95649,1.07355,1.11608,0.70111,0.56558,0.62674,0.47332,0.92427,0.95248,1.14742,0.007525,0.007525,0.007525 +152,78067.3,0.95514,1.0704,1.11652,0.70395,0.56447,0.62707,0.47372,0.92398,0.95098,1.14703,0.0075085,0.0075085,0.0075085 +153,78579.2,0.95352,1.07205,1.11617,0.70267,0.5657,0.62777,0.47419,0.92337,0.94947,1.14627,0.007492,0.007492,0.007492 +154,79092.8,0.95547,1.07721,1.11877,0.70041,0.56723,0.6284,0.47456,0.92281,0.94795,1.14576,0.0074755,0.0074755,0.0074755 +155,79605.2,0.95373,1.07392,1.11834,0.70106,0.56699,0.62896,0.47509,0.92239,0.94617,1.14528,0.007459,0.007459,0.007459 +156,80117.2,0.95288,1.06564,1.11634,0.70116,0.56702,0.62908,0.47545,0.92184,0.94496,1.14471,0.0074425,0.0074425,0.0074425 +157,80628.9,0.9514,1.06949,1.11602,0.70446,0.56582,0.62928,0.4758,0.92146,0.94366,1.14432,0.007426,0.007426,0.007426 
+158,81141.5,0.95229,1.06731,1.1161,0.70529,0.56555,0.63004,0.47629,0.9213,0.94241,1.14392,0.0074095,0.0074095,0.0074095 +159,81655.4,0.94892,1.06497,1.11412,0.70547,0.56653,0.63048,0.47666,0.92105,0.94134,1.14347,0.007393,0.007393,0.007393 +160,82169.1,0.95294,1.06239,1.11446,0.70326,0.56839,0.63084,0.47706,0.92054,0.93995,1.14289,0.0073765,0.0073765,0.0073765 +161,82683.3,0.95074,1.06611,1.11364,0.7046,0.56733,0.63143,0.47763,0.92028,0.93878,1.14247,0.00736,0.00736,0.00736 +162,83198,0.95193,1.0736,1.11942,0.70414,0.56812,0.63132,0.47759,0.91975,0.93759,1.14207,0.0073435,0.0073435,0.0073435 +163,83713.2,0.95034,1.06813,1.11453,0.70611,0.56752,0.63191,0.4781,0.9195,0.93637,1.14177,0.007327,0.007327,0.007327 +164,84226.5,0.94862,1.0623,1.11492,0.70727,0.56751,0.63234,0.47863,0.91925,0.93498,1.14131,0.0073105,0.0073105,0.0073105 +165,84743.1,0.94717,1.06058,1.1134,0.70374,0.5689,0.6326,0.47871,0.91905,0.93393,1.14094,0.007294,0.007294,0.007294 +166,85259.4,0.94542,1.06106,1.11277,0.7049,0.5686,0.63311,0.47911,0.9187,0.93281,1.14059,0.0072775,0.0072775,0.0072775 +167,85771.8,0.94981,1.06026,1.11415,0.70733,0.56804,0.63341,0.4795,0.91839,0.93193,1.14011,0.007261,0.007261,0.007261 +168,86286.7,0.95049,1.06432,1.11604,0.70368,0.57037,0.63431,0.47982,0.91817,0.93075,1.13972,0.0072445,0.0072445,0.0072445 +169,86801.2,0.94749,1.05872,1.11308,0.70547,0.57071,0.63457,0.48011,0.91802,0.92965,1.13937,0.007228,0.007228,0.007228 +170,87313.7,0.94564,1.05896,1.1099,0.70887,0.56956,0.63489,0.48036,0.91801,0.92879,1.13901,0.0072115,0.0072115,0.0072115 +171,87827.8,0.94591,1.05475,1.11195,0.70991,0.57023,0.6351,0.48062,0.91768,0.92729,1.13868,0.007195,0.007195,0.007195 +172,88342.1,0.94737,1.05686,1.11327,0.7076,0.57228,0.63566,0.48094,0.91752,0.92639,1.13827,0.0071785,0.0071785,0.0071785 +173,88857.1,0.94555,1.0581,1.11232,0.70942,0.57207,0.6358,0.48086,0.91737,0.9254,1.13806,0.007162,0.007162,0.007162 +174,89376.4,0.94692,1.05319,1.11257,0.70906,0.57327,0.63623,0.48138,0.91709,0.9241,1.13789,0.0071455,0.0071455,0.0071455 +175,89892.3,0.94299,1.0514,1.11049,0.71019,0.57255,0.63653,0.48174,0.91667,0.92285,1.13749,0.007129,0.007129,0.007129 +176,90410.3,0.94657,1.05378,1.11039,0.70803,0.57379,0.63685,0.4819,0.91659,0.92175,1.13758,0.0071125,0.0071125,0.0071125 +177,90926.6,0.94634,1.054,1.11385,0.70941,0.57337,0.6375,0.48234,0.91645,0.92079,1.13723,0.007096,0.007096,0.007096 +178,91440.8,0.94708,1.05606,1.11345,0.70923,0.57383,0.63793,0.48282,0.91608,0.9198,1.13699,0.0070795,0.0070795,0.0070795 +179,91955.9,0.94241,1.05079,1.11165,0.71011,0.57379,0.63835,0.48304,0.91595,0.91864,1.13672,0.007063,0.007063,0.007063 +180,92470.7,0.9414,1.05016,1.11067,0.71184,0.57288,0.63855,0.48297,0.91592,0.91796,1.13658,0.0070465,0.0070465,0.0070465 +181,92984.9,0.94024,1.05089,1.1086,0.71306,0.57301,0.63943,0.48366,0.916,0.91694,1.13649,0.00703,0.00703,0.00703 +182,93502.7,0.94029,1.04796,1.10914,0.71299,0.57321,0.63984,0.48403,0.9158,0.91591,1.1361,0.0070135,0.0070135,0.0070135 +183,94016.9,0.94433,1.05013,1.10969,0.71296,0.57438,0.64023,0.48438,0.91579,0.91481,1.13577,0.006997,0.006997,0.006997 +184,94528.4,0.94368,1.04535,1.11035,0.71391,0.57434,0.64044,0.4847,0.91533,0.91426,1.13534,0.0069805,0.0069805,0.0069805 +185,95041.2,0.94569,1.04428,1.11072,0.7124,0.57593,0.64061,0.48482,0.9152,0.91366,1.13532,0.006964,0.006964,0.006964 +186,95555.2,0.93984,1.0505,1.11203,0.71337,0.57559,0.64122,0.48499,0.9149,0.91291,1.13503,0.0069475,0.0069475,0.0069475 
+187,96071.2,0.94351,1.05048,1.10952,0.71404,0.57566,0.64172,0.48541,0.91494,0.91209,1.13493,0.006931,0.006931,0.006931 +188,96586.8,0.94359,1.04446,1.11162,0.71427,0.57595,0.64233,0.48585,0.91466,0.91128,1.13485,0.0069145,0.0069145,0.0069145 +189,97102.8,0.94318,1.04872,1.11043,0.71505,0.5761,0.64266,0.48584,0.91474,0.91069,1.13467,0.006898,0.006898,0.006898 +190,97617.5,0.93881,1.04714,1.10996,0.71308,0.57707,0.64284,0.48629,0.91436,0.90981,1.13423,0.0068815,0.0068815,0.0068815 +191,98131.7,0.94041,1.04269,1.10844,0.71271,0.57814,0.64313,0.48638,0.91395,0.90877,1.13405,0.006865,0.006865,0.006865 +192,98646,0.93755,1.03903,1.10645,0.7125,0.57931,0.64334,0.48679,0.9136,0.90759,1.13369,0.0068485,0.0068485,0.0068485 +193,99159.9,0.94176,1.04668,1.10796,0.7146,0.57798,0.64374,0.48715,0.91343,0.90674,1.13347,0.006832,0.006832,0.006832 +194,99672.2,0.94103,1.04598,1.10849,0.71536,0.5773,0.64407,0.48724,0.91317,0.90582,1.13312,0.0068155,0.0068155,0.0068155 +195,100186,0.94158,1.03781,1.10687,0.71489,0.57834,0.64457,0.4877,0.91296,0.90502,1.13278,0.006799,0.006799,0.006799 +196,100700,0.93502,1.03995,1.10575,0.71827,0.57613,0.64452,0.48767,0.91286,0.90444,1.13261,0.0067825,0.0067825,0.0067825 +197,101216,0.93744,1.03802,1.10667,0.71393,0.57862,0.64426,0.48759,0.91246,0.90401,1.13206,0.006766,0.006766,0.006766 +198,101728,0.93911,1.04028,1.10913,0.71607,0.5785,0.64458,0.4879,0.91236,0.90349,1.1321,0.0067495,0.0067495,0.0067495 +199,102243,0.9393,1.041,1.10974,0.71538,0.58005,0.64485,0.48804,0.9121,0.90282,1.1318,0.006733,0.006733,0.006733 +200,102756,0.93706,1.0384,1.10486,0.69761,0.58251,0.6454,0.48831,0.91195,0.90214,1.13136,0.0067165,0.0067165,0.0067165 +201,103270,0.93978,1.03742,1.10701,0.69878,0.58275,0.64619,0.48891,0.91192,0.90175,1.1313,0.0067,0.0067,0.0067 +202,103783,0.93859,1.03466,1.10721,0.70112,0.58186,0.64616,0.48894,0.9117,0.90137,1.13116,0.0066835,0.0066835,0.0066835 +203,104297,0.93649,1.03733,1.10789,0.69769,0.58377,0.64673,0.48919,0.91121,0.90065,1.13086,0.006667,0.006667,0.006667 +204,104810,0.93519,1.03866,1.10866,0.70209,0.5812,0.64728,0.48945,0.91099,0.90004,1.13073,0.0066505,0.0066505,0.0066505 +205,105324,0.93456,1.03695,1.106,0.69949,0.58249,0.64756,0.48983,0.91072,0.89998,1.13058,0.006634,0.006634,0.006634 +206,105837,0.93814,1.03273,1.10641,0.70051,0.58053,0.64701,0.48966,0.91052,0.89959,1.13047,0.0066175,0.0066175,0.0066175 +207,106350,0.93902,1.03535,1.10814,0.7021,0.58135,0.64744,0.49018,0.91023,0.89894,1.13041,0.006601,0.006601,0.006601 +208,106863,0.93054,1.02716,1.10398,0.70219,0.58144,0.64786,0.49044,0.90995,0.89826,1.13025,0.0065845,0.0065845,0.0065845 +209,107378,0.9322,1.02929,1.10438,0.70116,0.58246,0.6485,0.49084,0.90981,0.89755,1.12989,0.006568,0.006568,0.006568 +210,107894,0.93585,1.0343,1.10652,0.69594,0.58751,0.64848,0.49085,0.90944,0.89707,1.12968,0.0065515,0.0065515,0.0065515 +211,108408,0.93655,1.036,1.10766,0.69683,0.58779,0.64871,0.49087,0.90906,0.8964,1.12935,0.006535,0.006535,0.006535 +212,108922,0.93514,1.02592,1.10522,0.70039,0.58862,0.6493,0.49138,0.90901,0.89582,1.12937,0.0065185,0.0065185,0.0065185 +213,109436,0.93327,1.02989,1.10263,0.70119,0.5883,0.64969,0.49153,0.90911,0.89499,1.12927,0.006502,0.006502,0.006502 +214,109952,0.93174,1.02662,1.10301,0.70117,0.58956,0.64992,0.49178,0.90875,0.89485,1.12901,0.0064855,0.0064855,0.0064855 +215,110467,0.93034,1.0252,1.10393,0.70094,0.59055,0.65006,0.49224,0.90841,0.89461,1.12881,0.006469,0.006469,0.006469 
+216,110983,0.93469,1.02347,1.10388,0.70048,0.59142,0.65011,0.49246,0.90816,0.89423,1.12864,0.0064525,0.0064525,0.0064525 +217,111498,0.93146,1.02176,1.09937,0.70166,0.59133,0.65055,0.49266,0.90779,0.89364,1.12812,0.006436,0.006436,0.006436 +218,112014,0.93692,1.03384,1.10406,0.7006,0.59237,0.65099,0.49331,0.90724,0.89301,1.12758,0.0064195,0.0064195,0.0064195 +219,112530,0.93083,1.02282,1.09995,0.69642,0.59462,0.65111,0.49328,0.90703,0.89254,1.127,0.006403,0.006403,0.006403 +220,113045,0.93732,1.03048,1.1043,0.69932,0.59313,0.65143,0.49325,0.90697,0.89171,1.12684,0.0063865,0.0063865,0.0063865 +221,113561,0.93078,1.0217,1.10145,0.696,0.59593,0.65174,0.49332,0.9068,0.89079,1.12655,0.00637,0.00637,0.00637 +222,114078,0.92976,1.02149,1.10239,0.69515,0.59718,0.65216,0.4937,0.90651,0.89003,1.12625,0.0063535,0.0063535,0.0063535 +223,114594,0.93261,1.02367,1.10267,0.69518,0.59779,0.65219,0.49375,0.90642,0.88948,1.12611,0.006337,0.006337,0.006337 +224,115109,0.93312,1.02329,1.10173,0.69743,0.59664,0.65273,0.49379,0.90585,0.88893,1.1257,0.0063205,0.0063205,0.0063205 +225,115625,0.92926,1.01739,1.10071,0.69669,0.59604,0.65305,0.49391,0.90559,0.88847,1.12553,0.006304,0.006304,0.006304 +226,116141,0.92804,1.01884,1.10062,0.69556,0.59569,0.65318,0.49416,0.90543,0.88791,1.12522,0.0062875,0.0062875,0.0062875 +227,116658,0.92715,1.01422,1.09907,0.6963,0.59616,0.65339,0.49426,0.90523,0.8874,1.12496,0.006271,0.006271,0.006271 +228,117173,0.92949,1.01551,1.1034,0.69718,0.59683,0.65372,0.49469,0.90523,0.88696,1.12508,0.0062545,0.0062545,0.0062545 +229,117689,0.93037,1.01408,1.10008,0.69772,0.59679,0.65391,0.49502,0.90522,0.88663,1.12502,0.006238,0.006238,0.006238 +230,118206,0.92761,1.01179,1.10107,0.69995,0.59573,0.65398,0.49486,0.90497,0.88629,1.12487,0.0062215,0.0062215,0.0062215 +231,118723,0.9257,1.0129,1.09892,0.70311,0.59394,0.65412,0.49513,0.90484,0.88599,1.1249,0.006205,0.006205,0.006205 +232,119239,0.92765,1.01088,1.10065,0.70559,0.59263,0.6542,0.49554,0.9048,0.88504,1.12483,0.0061885,0.0061885,0.0061885 +233,119754,0.92895,1.01723,1.10168,0.70332,0.5953,0.65454,0.49581,0.90445,0.88475,1.12465,0.006172,0.006172,0.006172 +234,120270,0.92661,1.01113,1.1008,0.70606,0.59448,0.65482,0.49563,0.90434,0.88433,1.1247,0.0061555,0.0061555,0.0061555 +235,120786,0.9285,1.01461,1.10028,0.70667,0.59402,0.65492,0.49612,0.9042,0.88366,1.12471,0.006139,0.006139,0.006139 +236,121302,0.92792,1.01218,1.09881,0.70531,0.59503,0.65493,0.49609,0.90425,0.88327,1.12451,0.0061225,0.0061225,0.0061225 +237,121818,0.9244,1.01593,1.09896,0.70506,0.59542,0.65538,0.4962,0.90382,0.88267,1.12422,0.006106,0.006106,0.006106 +238,122334,0.92273,1.00936,1.09852,0.70438,0.59684,0.65574,0.49661,0.9039,0.88228,1.12421,0.0060895,0.0060895,0.0060895 +239,122849,0.92444,1.01244,1.10052,0.70579,0.59621,0.65578,0.49671,0.90397,0.88205,1.12406,0.006073,0.006073,0.006073 +240,123365,0.92884,1.01408,1.10007,0.70708,0.59529,0.6558,0.49711,0.90359,0.88148,1.12383,0.0060565,0.0060565,0.0060565 +241,123881,0.92419,1.01025,1.09998,0.70432,0.596,0.65586,0.49733,0.9032,0.88131,1.1235,0.00604,0.00604,0.00604 +242,124397,0.92484,1.01041,1.0994,0.70465,0.59627,0.65592,0.4973,0.90313,0.88113,1.12339,0.0060235,0.0060235,0.0060235 +243,124915,0.92712,1.0085,1.10105,0.70222,0.59617,0.65592,0.49735,0.903,0.88061,1.12326,0.006007,0.006007,0.006007 +244,125431,0.92373,1.00857,1.09751,0.70257,0.59667,0.65611,0.49756,0.90258,0.88036,1.12273,0.0059905,0.0059905,0.0059905 
+245,125948,0.9271,1.01346,1.09899,0.70168,0.59626,0.65621,0.4978,0.90219,0.87988,1.12232,0.005974,0.005974,0.005974 +246,126464,0.92052,1.00372,1.09827,0.70447,0.59438,0.65632,0.49785,0.90198,0.8796,1.12224,0.0059575,0.0059575,0.0059575 +247,126979,0.91702,1.00049,1.09668,0.70514,0.59427,0.65688,0.4984,0.90197,0.87903,1.12222,0.005941,0.005941,0.005941 +248,127496,0.92341,1.00705,1.09831,0.70372,0.59578,0.65725,0.4986,0.90168,0.87857,1.12191,0.0059245,0.0059245,0.0059245 +249,128011,0.91828,1.00151,1.09599,0.70369,0.59507,0.65747,0.49865,0.90116,0.87788,1.1216,0.005908,0.005908,0.005908 +250,128526,0.92126,1.00182,1.09885,0.70768,0.59171,0.65789,0.49905,0.90102,0.87737,1.12141,0.0058915,0.0058915,0.0058915 +251,129042,0.92214,1.00449,1.09827,0.70925,0.59144,0.65814,0.49932,0.90078,0.87734,1.12119,0.005875,0.005875,0.005875 +252,129556,0.9209,1.00448,1.09773,0.70276,0.59595,0.65792,0.49913,0.90067,0.87688,1.12092,0.0058585,0.0058585,0.0058585 +253,130072,0.92002,0.99751,1.09574,0.70306,0.59656,0.65816,0.4994,0.90029,0.87644,1.12053,0.005842,0.005842,0.005842 +254,130590,0.91798,0.99675,1.09432,0.7042,0.59565,0.65805,0.49962,0.90026,0.8764,1.12052,0.0058255,0.0058255,0.0058255 +255,131107,0.92047,0.99815,1.09587,0.70461,0.5965,0.65854,0.49976,0.90005,0.87627,1.12041,0.005809,0.005809,0.005809 +256,131624,0.91741,1.00109,1.09627,0.70683,0.59526,0.65879,0.50017,0.90007,0.87576,1.12034,0.0057925,0.0057925,0.0057925 +257,132140,0.91937,0.99634,1.09472,0.69271,0.60063,0.65857,0.50011,0.90003,0.87519,1.12036,0.005776,0.005776,0.005776 +258,132656,0.92016,1.00057,1.09437,0.69354,0.59963,0.65856,0.49998,0.89998,0.87495,1.12034,0.0057595,0.0057595,0.0057595 +259,133173,0.91537,0.99963,1.0955,0.69525,0.60019,0.6589,0.50027,0.89962,0.8745,1.12024,0.005743,0.005743,0.005743 +260,133689,0.91518,0.99253,1.09231,0.69435,0.60102,0.65922,0.50038,0.89927,0.87415,1.11993,0.0057265,0.0057265,0.0057265 +261,134205,0.91515,0.99325,1.09487,0.69425,0.60189,0.65929,0.50056,0.89887,0.87367,1.11984,0.00571,0.00571,0.00571 +262,134720,0.92079,0.9931,1.09661,0.69702,0.60199,0.65936,0.50077,0.8986,0.8735,1.11972,0.0056935,0.0056935,0.0056935 +263,135235,0.91707,0.99732,1.09516,0.69757,0.60242,0.65976,0.5009,0.89843,0.87319,1.11966,0.005677,0.005677,0.005677 +264,135751,0.91739,0.99674,1.09347,0.69792,0.60209,0.6601,0.50105,0.89806,0.87265,1.11927,0.0056605,0.0056605,0.0056605 +265,136268,0.91927,0.99705,1.09473,0.6969,0.60271,0.66022,0.50121,0.89818,0.87224,1.11927,0.005644,0.005644,0.005644 +266,136784,0.9142,0.9952,1.09423,0.70068,0.60117,0.66056,0.50146,0.89804,0.87196,1.11887,0.0056275,0.0056275,0.0056275 +267,137300,0.9158,0.99131,1.09367,0.70475,0.59868,0.66093,0.50149,0.89773,0.87165,1.1186,0.005611,0.005611,0.005611 +268,137815,0.9173,0.98815,1.09315,0.70466,0.59921,0.661,0.50139,0.89769,0.8714,1.11843,0.0055945,0.0055945,0.0055945 +269,138332,0.915,0.9976,1.09495,0.70129,0.60143,0.66094,0.50153,0.89756,0.87102,1.11847,0.005578,0.005578,0.005578 +270,138849,0.91556,0.99251,1.09255,0.70545,0.6001,0.66091,0.50183,0.89739,0.87047,1.11836,0.0055615,0.0055615,0.0055615 +271,139364,0.91241,0.98747,1.0893,0.70728,0.59825,0.6612,0.50176,0.89712,0.8701,1.11783,0.005545,0.005545,0.005545 +272,139879,0.91308,0.99007,1.09169,0.70487,0.59875,0.66141,0.5019,0.89669,0.87015,1.11749,0.0055285,0.0055285,0.0055285 +273,140395,0.91134,0.98641,1.09267,0.70546,0.59884,0.66153,0.50213,0.89661,0.8697,1.11748,0.005512,0.005512,0.005512 
+274,140912,0.91405,0.98741,1.0938,0.70379,0.59897,0.66161,0.50235,0.89667,0.86954,1.11738,0.0054955,0.0054955,0.0054955 +275,141427,0.91421,0.9875,1.09099,0.69927,0.60305,0.66134,0.50219,0.89638,0.86931,1.11696,0.005479,0.005479,0.005479 +276,141943,0.91771,0.9906,1.09307,0.7012,0.60159,0.66143,0.50231,0.89619,0.86909,1.11675,0.0054625,0.0054625,0.0054625 +277,142459,0.91168,0.98271,1.09176,0.70243,0.60121,0.66176,0.50275,0.89622,0.86898,1.11681,0.005446,0.005446,0.005446 +278,142975,0.91628,0.99164,1.08997,0.70217,0.60212,0.66203,0.50254,0.89653,0.86859,1.11665,0.0054295,0.0054295,0.0054295 +279,143492,0.91302,0.97939,1.08947,0.69633,0.6071,0.66218,0.50288,0.8965,0.86782,1.11662,0.005413,0.005413,0.005413 +280,144008,0.91218,0.98474,1.08962,0.69816,0.6067,0.66237,0.50281,0.89648,0.86784,1.11648,0.0053965,0.0053965,0.0053965 +281,144524,0.91387,0.98492,1.09075,0.69905,0.60434,0.66234,0.50273,0.89648,0.86779,1.1164,0.00538,0.00538,0.00538 +282,145040,0.91186,0.97946,1.09055,0.6973,0.60513,0.66252,0.50289,0.89624,0.86745,1.11604,0.0053635,0.0053635,0.0053635 +283,145556,0.90988,0.98315,1.0892,0.69893,0.60585,0.6632,0.50344,0.89614,0.86734,1.11571,0.005347,0.005347,0.005347 +284,146071,0.91067,0.98013,1.09051,0.70171,0.60377,0.66337,0.50354,0.89598,0.86702,1.11562,0.0053305,0.0053305,0.0053305 +285,146588,0.91057,0.98001,1.09064,0.70081,0.6033,0.66337,0.50352,0.89564,0.86718,1.11531,0.005314,0.005314,0.005314 +286,147104,0.91239,0.98136,1.09065,0.70054,0.60365,0.66422,0.50467,0.8951,0.86667,1.11497,0.0052975,0.0052975,0.0052975 +287,147622,0.91157,0.97981,1.09103,0.6988,0.60487,0.66432,0.50449,0.89497,0.8662,1.11481,0.005281,0.005281,0.005281 +288,148138,0.90785,0.98604,1.09172,0.70273,0.60381,0.66462,0.50496,0.89489,0.86557,1.11492,0.0052645,0.0052645,0.0052645 +289,148655,0.90676,0.97552,1.0889,0.71436,0.59966,0.66504,0.50533,0.89516,0.8654,1.11509,0.005248,0.005248,0.005248 +290,149171,0.90433,0.97657,1.08907,0.7138,0.60018,0.66512,0.50565,0.89508,0.86491,1.11499,0.0052315,0.0052315,0.0052315 +291,149688,0.90793,0.97615,1.08905,0.71592,0.59881,0.66518,0.50569,0.89502,0.86463,1.11503,0.005215,0.005215,0.005215 +292,150205,0.90993,0.98133,1.09066,0.71239,0.6011,0.66592,0.50616,0.89497,0.86432,1.11498,0.0051985,0.0051985,0.0051985 +293,150721,0.90996,0.97387,1.08657,0.70996,0.60336,0.666,0.50619,0.89493,0.86391,1.11491,0.005182,0.005182,0.005182 +294,151238,0.90726,0.97314,1.08726,0.705,0.60505,0.66617,0.50632,0.89459,0.86362,1.11459,0.0051655,0.0051655,0.0051655 +295,151752,0.90726,0.97482,1.08806,0.70339,0.60508,0.66626,0.50657,0.89423,0.86353,1.11418,0.005149,0.005149,0.005149 +296,152269,0.90967,0.97916,1.09037,0.69785,0.60905,0.66652,0.5067,0.89423,0.86324,1.11421,0.0051325,0.0051325,0.0051325 +297,152785,0.90649,0.97798,1.08892,0.7004,0.60787,0.66638,0.50639,0.89381,0.86324,1.11409,0.005116,0.005116,0.005116 +298,153301,0.91,0.97496,1.08886,0.7022,0.60699,0.66641,0.50665,0.89384,0.8628,1.11418,0.0050995,0.0050995,0.0050995 +299,153817,0.90494,0.96663,1.08632,0.7021,0.60699,0.66658,0.50671,0.89364,0.86207,1.11396,0.005083,0.005083,0.005083 +300,154333,0.90605,0.97145,1.0878,0.70263,0.60748,0.66728,0.50748,0.89332,0.86171,1.11354,0.0050665,0.0050665,0.0050665 +301,154848,0.90565,0.96907,1.08849,0.70078,0.6098,0.66758,0.50744,0.89279,0.86151,1.11315,0.00505,0.00505,0.00505 +302,155364,0.9023,0.96516,1.08345,0.69892,0.61055,0.66751,0.50772,0.89251,0.86103,1.1129,0.0050335,0.0050335,0.0050335 
+303,155880,0.90624,0.96396,1.08503,0.70291,0.61008,0.66748,0.50793,0.8925,0.8607,1.11258,0.005017,0.005017,0.005017 +304,156396,0.90354,0.96617,1.08595,0.70206,0.6109,0.66727,0.50767,0.89238,0.86018,1.11238,0.0050005,0.0050005,0.0050005 +305,156913,0.90124,0.96635,1.08453,0.70513,0.60835,0.66768,0.50771,0.89236,0.85998,1.11217,0.004984,0.004984,0.004984 +306,157429,0.90519,0.96577,1.08763,0.70345,0.60877,0.66708,0.507,0.89216,0.85964,1.11197,0.0049675,0.0049675,0.0049675 +307,157945,0.90293,0.96617,1.08526,0.70304,0.60927,0.66756,0.50739,0.89216,0.85938,1.11208,0.004951,0.004951,0.004951 +308,158461,0.90009,0.95929,1.08326,0.70277,0.60833,0.66786,0.50747,0.89208,0.85907,1.11197,0.0049345,0.0049345,0.0049345 +309,158977,0.90357,0.96223,1.08378,0.70742,0.60648,0.66814,0.50782,0.89208,0.85897,1.11187,0.004918,0.004918,0.004918 +310,159492,0.90341,0.96334,1.08538,0.70648,0.6065,0.66821,0.50809,0.89165,0.85903,1.11151,0.0049015,0.0049015,0.0049015 +311,160007,0.89956,0.96111,1.08515,0.70591,0.60642,0.66826,0.50799,0.89143,0.85873,1.11131,0.004885,0.004885,0.004885 +312,160523,0.89818,0.95527,1.08112,0.69911,0.61115,0.66857,0.50816,0.89156,0.85826,1.11133,0.0048685,0.0048685,0.0048685 +313,161039,0.90089,0.96264,1.08332,0.69997,0.60982,0.66829,0.50806,0.89131,0.85791,1.11095,0.004852,0.004852,0.004852 +314,161554,0.90123,0.95867,1.08258,0.69587,0.61331,0.66841,0.50823,0.89118,0.85791,1.11062,0.0048355,0.0048355,0.0048355 +315,162070,0.90254,0.95953,1.08408,0.69759,0.61272,0.66861,0.50847,0.89124,0.85766,1.11061,0.004819,0.004819,0.004819 +316,162587,0.90041,0.95937,1.08335,0.69003,0.61763,0.66886,0.50877,0.89115,0.85716,1.11041,0.0048025,0.0048025,0.0048025 +317,163104,0.89574,0.9567,1.0822,0.68874,0.61799,0.66932,0.50904,0.89114,0.85674,1.11025,0.004786,0.004786,0.004786 +318,163625,0.8947,0.95353,1.08253,0.68954,0.61815,0.66975,0.50931,0.89105,0.85657,1.11038,0.0047695,0.0047695,0.0047695 +319,164144,0.89969,0.9549,1.08396,0.69164,0.61851,0.66989,0.50937,0.89072,0.85621,1.11016,0.004753,0.004753,0.004753 +320,164662,0.90195,0.95913,1.08449,0.69609,0.61606,0.6697,0.5091,0.89046,0.85556,1.10995,0.0047365,0.0047365,0.0047365 +321,165178,0.89999,0.95253,1.08104,0.70268,0.61137,0.6697,0.5094,0.89048,0.85558,1.10977,0.00472,0.00472,0.00472 +322,165693,0.8975,0.94931,1.08147,0.6986,0.61523,0.67001,0.50961,0.89034,0.8554,1.10957,0.0047035,0.0047035,0.0047035 +323,166211,0.89852,0.95535,1.07946,0.70066,0.61256,0.66984,0.50975,0.89012,0.85509,1.10921,0.004687,0.004687,0.004687 +324,166729,0.89922,0.95453,1.08144,0.70209,0.61109,0.67011,0.51017,0.88993,0.85497,1.109,0.0046705,0.0046705,0.0046705 +325,167246,0.90416,0.95899,1.0844,0.70256,0.61092,0.6705,0.50997,0.88993,0.85476,1.1088,0.004654,0.004654,0.004654 +326,167764,0.89919,0.95586,1.08346,0.70361,0.61047,0.67031,0.51,0.88964,0.85467,1.10854,0.0046375,0.0046375,0.0046375 +327,168282,0.89696,0.95307,1.08226,0.70332,0.61057,0.67041,0.51002,0.88942,0.85462,1.10834,0.004621,0.004621,0.004621 +328,168799,0.89385,0.95159,1.07874,0.70378,0.61112,0.67048,0.51021,0.88907,0.85463,1.10802,0.0046045,0.0046045,0.0046045 +329,169316,0.89604,0.95335,1.08132,0.70394,0.61204,0.67084,0.51019,0.88925,0.85435,1.10801,0.004588,0.004588,0.004588 +330,169832,0.90058,0.94639,1.082,0.70281,0.61306,0.67093,0.51048,0.88923,0.85456,1.10802,0.0045715,0.0045715,0.0045715 +331,170348,0.89386,0.95193,1.08306,0.70129,0.6146,0.67116,0.51088,0.88929,0.85427,1.10802,0.004555,0.004555,0.004555 
+332,170864,0.89279,0.94904,1.07901,0.70242,0.61453,0.67122,0.51086,0.88911,0.854,1.1079,0.0045385,0.0045385,0.0045385 +333,171380,0.89569,0.94969,1.08338,0.70163,0.61511,0.67158,0.51095,0.88913,0.85355,1.10796,0.004522,0.004522,0.004522 +334,171897,0.8931,0.94652,1.07908,0.70464,0.61275,0.67157,0.51119,0.88915,0.85334,1.10793,0.0045055,0.0045055,0.0045055 +335,172412,0.89353,0.94304,1.07957,0.69692,0.61863,0.6719,0.51149,0.88914,0.85354,1.1079,0.004489,0.004489,0.004489 +336,172929,0.89387,0.94238,1.07845,0.69579,0.61961,0.67179,0.51133,0.88901,0.85295,1.10771,0.0044725,0.0044725,0.0044725 +337,173443,0.89497,0.94467,1.08051,0.69285,0.62198,0.67226,0.5117,0.88877,0.85254,1.10759,0.004456,0.004456,0.004456 +338,173958,0.89082,0.93641,1.077,0.69258,0.62122,0.67191,0.51193,0.8886,0.85213,1.10741,0.0044395,0.0044395,0.0044395 +339,174473,0.8914,0.94548,1.07891,0.69482,0.61982,0.67229,0.51197,0.88859,0.85195,1.10746,0.004423,0.004423,0.004423 +340,174989,0.89111,0.93556,1.07706,0.69865,0.61809,0.67282,0.51219,0.88857,0.85125,1.10743,0.0044065,0.0044065,0.0044065 +341,175506,0.89384,0.94422,1.07971,0.70154,0.61671,0.67311,0.51242,0.88859,0.85089,1.10745,0.00439,0.00439,0.00439 +342,176022,0.89688,0.94597,1.08227,0.70185,0.61756,0.67335,0.51266,0.8884,0.85089,1.10743,0.0043735,0.0043735,0.0043735 +343,176537,0.89335,0.93635,1.07571,0.70641,0.61497,0.67357,0.51288,0.88802,0.85061,1.10702,0.004357,0.004357,0.004357 +344,177054,0.89366,0.94276,1.07849,0.7042,0.61589,0.67373,0.51308,0.88789,0.85029,1.10686,0.0043405,0.0043405,0.0043405 +345,177570,0.89006,0.94037,1.07859,0.70582,0.61506,0.67376,0.51309,0.88814,0.84996,1.10702,0.004324,0.004324,0.004324 +346,178086,0.89123,0.93628,1.0744,0.70057,0.61737,0.67387,0.51349,0.88815,0.84956,1.10696,0.0043075,0.0043075,0.0043075 +347,178602,0.89361,0.93743,1.07908,0.70257,0.61686,0.67406,0.51338,0.88803,0.84938,1.10695,0.004291,0.004291,0.004291 +348,179118,0.89238,0.9356,1.07692,0.70459,0.61576,0.67412,0.51356,0.88818,0.84876,1.10682,0.0042745,0.0042745,0.0042745 +349,179633,0.89196,0.93449,1.07566,0.71087,0.61129,0.67411,0.51362,0.88844,0.8485,1.10702,0.004258,0.004258,0.004258 +350,180148,0.88954,0.93676,1.07551,0.7131,0.6097,0.6744,0.51377,0.88835,0.84804,1.10692,0.0042415,0.0042415,0.0042415 +351,180665,0.88797,0.9316,1.07347,0.71247,0.60952,0.67449,0.51375,0.88855,0.84782,1.10689,0.004225,0.004225,0.004225 +352,181181,0.88588,0.93029,1.07364,0.71437,0.60913,0.67471,0.51389,0.8882,0.8475,1.10673,0.0042085,0.0042085,0.0042085 +353,181697,0.88996,0.93567,1.07516,0.71328,0.61067,0.67459,0.51387,0.888,0.8473,1.10651,0.004192,0.004192,0.004192 +354,182214,0.88337,0.92834,1.0756,0.7105,0.61274,0.67452,0.51396,0.88753,0.84691,1.10642,0.0041755,0.0041755,0.0041755 +355,182729,0.88282,0.92856,1.07227,0.71042,0.61268,0.67463,0.51385,0.8872,0.84635,1.10605,0.004159,0.004159,0.004159 +356,183243,0.89,0.93411,1.0759,0.71351,0.6117,0.67442,0.51387,0.88683,0.84595,1.10576,0.0041425,0.0041425,0.0041425 +357,183757,0.88826,0.93043,1.07528,0.71393,0.61056,0.67462,0.5141,0.88659,0.84561,1.1054,0.004126,0.004126,0.004126 +358,184270,0.88544,0.92758,1.07481,0.71534,0.61007,0.67466,0.51431,0.88622,0.84543,1.10524,0.0041095,0.0041095,0.0041095 +359,184785,0.88663,0.92894,1.0771,0.7152,0.60906,0.67456,0.51423,0.88548,0.84514,1.1047,0.004093,0.004093,0.004093 +360,185298,0.88527,0.93048,1.07572,0.70848,0.61439,0.67457,0.51436,0.88565,0.84473,1.1048,0.0040765,0.0040765,0.0040765 
+361,185813,0.88346,0.92989,1.07583,0.70795,0.61502,0.67478,0.51448,0.88556,0.84425,1.10472,0.00406,0.00406,0.00406 +362,186326,0.88689,0.93006,1.07551,0.70724,0.61528,0.6748,0.51468,0.88532,0.84365,1.10452,0.0040435,0.0040435,0.0040435 +363,186841,0.88532,0.92577,1.07295,0.70856,0.61504,0.67521,0.51481,0.88543,0.84389,1.10457,0.004027,0.004027,0.004027 +364,187355,0.87985,0.92278,1.07376,0.70782,0.61442,0.67547,0.51494,0.88518,0.84369,1.10449,0.0040105,0.0040105,0.0040105 +365,187869,0.88448,0.92317,1.0711,0.7083,0.61446,0.6756,0.51518,0.88503,0.8434,1.10431,0.003994,0.003994,0.003994 +366,188385,0.88124,0.92089,1.07133,0.70331,0.61792,0.67596,0.51539,0.88508,0.84303,1.10428,0.0039775,0.0039775,0.0039775 +367,188898,0.88128,0.91714,1.07095,0.70366,0.618,0.67613,0.51558,0.88496,0.84258,1.10403,0.003961,0.003961,0.003961 +368,189412,0.88359,0.921,1.0728,0.70702,0.61638,0.67637,0.51581,0.88508,0.84224,1.10396,0.0039445,0.0039445,0.0039445 +369,189928,0.88432,0.9214,1.07395,0.70377,0.62012,0.67701,0.51598,0.88495,0.84189,1.10374,0.003928,0.003928,0.003928 +370,190443,0.88186,0.92098,1.07198,0.70471,0.62007,0.67757,0.51649,0.88501,0.84129,1.10376,0.0039115,0.0039115,0.0039115 +371,190957,0.88306,0.92176,1.07485,0.70449,0.62026,0.67743,0.51647,0.88476,0.84105,1.10357,0.003895,0.003895,0.003895 +372,191471,0.88168,0.91515,1.07169,0.70844,0.61919,0.67756,0.51655,0.8845,0.84083,1.10337,0.0038785,0.0038785,0.0038785 +373,191985,0.87935,0.9148,1.06923,0.70953,0.6194,0.67797,0.51698,0.88431,0.84046,1.10315,0.003862,0.003862,0.003862 +374,192499,0.88205,0.91848,1.06923,0.70898,0.61904,0.67826,0.51724,0.88427,0.84037,1.10278,0.0038455,0.0038455,0.0038455 +375,193013,0.88312,0.91814,1.0703,0.70543,0.62151,0.67856,0.51757,0.88408,0.84022,1.10261,0.003829,0.003829,0.003829 +376,193527,0.87982,0.91105,1.06951,0.70076,0.6243,0.67872,0.51737,0.88422,0.83997,1.10266,0.0038125,0.0038125,0.0038125 +377,194042,0.877,0.91156,1.06894,0.7147,0.61657,0.67875,0.51751,0.88413,0.83973,1.10251,0.003796,0.003796,0.003796 +378,194556,0.87869,0.9091,1.06886,0.70686,0.62108,0.67864,0.5175,0.88396,0.83961,1.10225,0.0037795,0.0037795,0.0037795 +379,195070,0.87844,0.91405,1.07144,0.71291,0.61766,0.67886,0.51782,0.88366,0.8394,1.10217,0.003763,0.003763,0.003763 +380,195583,0.88061,0.91526,1.0718,0.71566,0.61657,0.67863,0.51803,0.88339,0.83927,1.10185,0.0037465,0.0037465,0.0037465 +381,196099,0.87893,0.91468,1.06849,0.71259,0.61915,0.67932,0.51813,0.88335,0.83913,1.10159,0.00373,0.00373,0.00373 +382,196613,0.87865,0.91018,1.06978,0.71369,0.62051,0.67955,0.51828,0.88326,0.83872,1.10141,0.0037135,0.0037135,0.0037135 +383,197128,0.87697,0.90582,1.06846,0.71189,0.62194,0.67961,0.5184,0.88318,0.8386,1.10122,0.003697,0.003697,0.003697 +384,197643,0.88089,0.9111,1.06956,0.7136,0.62092,0.67976,0.51853,0.88322,0.83835,1.1012,0.0036805,0.0036805,0.0036805 +385,198158,0.88065,0.90962,1.06815,0.71424,0.62104,0.68005,0.51883,0.88303,0.83767,1.10104,0.003664,0.003664,0.003664 +386,198672,0.87715,0.90569,1.06928,0.71514,0.62209,0.68024,0.51913,0.8829,0.83747,1.10092,0.0036475,0.0036475,0.0036475 +387,199185,0.87669,0.90884,1.06934,0.71343,0.62305,0.68026,0.5193,0.8828,0.83721,1.10068,0.003631,0.003631,0.003631 +388,199699,0.87563,0.90676,1.07033,0.71037,0.6243,0.68042,0.51952,0.88263,0.83708,1.1007,0.0036145,0.0036145,0.0036145 +389,200212,0.87954,0.9052,1.06936,0.71046,0.62422,0.68085,0.51987,0.88245,0.83662,1.10053,0.003598,0.003598,0.003598 
+390,200726,0.87551,0.91018,1.06721,0.70386,0.62881,0.68118,0.52021,0.88232,0.83672,1.10036,0.0035815,0.0035815,0.0035815 +391,201241,0.87187,0.90129,1.06537,0.70384,0.63076,0.68137,0.52032,0.88193,0.83658,1.10019,0.003565,0.003565,0.003565 +392,201755,0.87181,0.90365,1.06609,0.70222,0.63409,0.68172,0.52068,0.88185,0.83646,1.1001,0.0035485,0.0035485,0.0035485 +393,202269,0.87532,0.90622,1.0678,0.69957,0.63465,0.68114,0.52037,0.88182,0.83645,1.09996,0.003532,0.003532,0.003532 +394,202784,0.87798,0.9016,1.06849,0.7,0.63315,0.6813,0.52079,0.88192,0.83635,1.10011,0.0035155,0.0035155,0.0035155 +395,203298,0.87456,0.90208,1.06715,0.70545,0.63088,0.68146,0.52081,0.88151,0.83624,1.09983,0.003499,0.003499,0.003499 +396,203813,0.8717,0.90162,1.06437,0.70217,0.63262,0.68155,0.52094,0.88109,0.83606,1.09939,0.0034825,0.0034825,0.0034825 +397,204326,0.87319,0.89684,1.06452,0.70316,0.63096,0.6816,0.52103,0.88111,0.83583,1.0993,0.003466,0.003466,0.003466 +398,204841,0.8716,0.8911,1.06319,0.71082,0.62742,0.68181,0.52109,0.88099,0.83528,1.09912,0.0034495,0.0034495,0.0034495 +399,205356,0.87428,0.89546,1.067,0.71379,0.62558,0.68193,0.52122,0.88071,0.83455,1.09901,0.003433,0.003433,0.003433 +400,205870,0.87338,0.89574,1.06604,0.71167,0.62701,0.68198,0.52138,0.88063,0.83375,1.09887,0.0034165,0.0034165,0.0034165 +401,206383,0.8671,0.89043,1.06205,0.7083,0.62988,0.68287,0.52198,0.88061,0.83307,1.09873,0.0034,0.0034,0.0034 +402,206899,0.87135,0.89667,1.06604,0.70872,0.63038,0.68304,0.52192,0.88044,0.83279,1.09861,0.0033835,0.0033835,0.0033835 +403,207414,0.87002,0.89248,1.06493,0.71056,0.62878,0.68305,0.52199,0.88044,0.83239,1.09848,0.003367,0.003367,0.003367 +404,207930,0.87211,0.89564,1.06589,0.71091,0.6282,0.6834,0.52223,0.88042,0.83234,1.09848,0.0033505,0.0033505,0.0033505 +405,208445,0.87057,0.88879,1.06444,0.71199,0.62809,0.68374,0.52265,0.88025,0.83206,1.09849,0.003334,0.003334,0.003334 +406,208958,0.8702,0.89238,1.0665,0.71233,0.62787,0.68374,0.52286,0.88005,0.83198,1.09845,0.0033175,0.0033175,0.0033175 +407,209473,0.86758,0.89175,1.06365,0.71565,0.62648,0.68393,0.52309,0.88007,0.83148,1.09849,0.003301,0.003301,0.003301 +408,209987,0.86717,0.88816,1.06189,0.71994,0.62562,0.68434,0.52336,0.88,0.83142,1.09828,0.0032845,0.0032845,0.0032845 +409,210502,0.86806,0.89287,1.06369,0.71977,0.62561,0.68438,0.52351,0.88,0.83087,1.09813,0.003268,0.003268,0.003268 +410,211016,0.86585,0.88439,1.06145,0.72123,0.62407,0.68447,0.52339,0.8796,0.83088,1.09782,0.0032515,0.0032515,0.0032515 +411,211531,0.86834,0.88857,1.06254,0.72386,0.62278,0.68481,0.52365,0.87938,0.83057,1.09755,0.003235,0.003235,0.003235 +412,212046,0.86759,0.88896,1.06367,0.72408,0.62342,0.68489,0.52406,0.87901,0.83011,1.09757,0.0032185,0.0032185,0.0032185 +413,212559,0.86662,0.88567,1.06202,0.72246,0.62483,0.68505,0.52434,0.87887,0.82996,1.09763,0.003202,0.003202,0.003202 +414,213073,0.86626,0.88935,1.06371,0.71706,0.62678,0.68509,0.52457,0.87877,0.82936,1.09753,0.0031855,0.0031855,0.0031855 +415,213588,0.87202,0.88881,1.06453,0.7163,0.62752,0.68543,0.52472,0.87861,0.82951,1.0976,0.003169,0.003169,0.003169 +416,214102,0.86667,0.88534,1.06131,0.71511,0.62816,0.68552,0.52495,0.8784,0.82899,1.09743,0.0031525,0.0031525,0.0031525 +417,214616,0.86801,0.88254,1.06224,0.71168,0.63032,0.68551,0.5247,0.87825,0.82897,1.0973,0.003136,0.003136,0.003136 +418,215131,0.86238,0.88081,1.05874,0.7036,0.63473,0.68571,0.52474,0.87828,0.82842,1.09704,0.0031195,0.0031195,0.0031195 
+419,215644,0.86536,0.88087,1.06156,0.70503,0.63456,0.68577,0.52486,0.87821,0.82791,1.09698,0.003103,0.003103,0.003103 +420,216157,0.86434,0.87537,1.06017,0.70469,0.63484,0.68607,0.52514,0.87782,0.82773,1.09669,0.0030865,0.0030865,0.0030865 +421,216670,0.85986,0.87667,1.05982,0.70642,0.63545,0.68629,0.52532,0.87783,0.82738,1.09655,0.00307,0.00307,0.00307 +422,217184,0.8638,0.87844,1.06193,0.70331,0.63635,0.68646,0.52538,0.87781,0.82692,1.09657,0.0030535,0.0030535,0.0030535 +423,217697,0.86511,0.87811,1.06169,0.70734,0.63661,0.68733,0.52588,0.87759,0.82654,1.09659,0.003037,0.003037,0.003037 +424,218212,0.8625,0.87985,1.05891,0.71129,0.63241,0.68745,0.52565,0.87722,0.82603,1.0962,0.0030205,0.0030205,0.0030205 +425,218726,0.8619,0.8754,1.05959,0.70618,0.63591,0.68782,0.52605,0.87704,0.82569,1.09609,0.003004,0.003004,0.003004 +426,219241,0.86392,0.8714,1.05909,0.71691,0.63021,0.68798,0.52625,0.87678,0.82572,1.09581,0.0029875,0.0029875,0.0029875 +427,219755,0.85996,0.87286,1.06011,0.71521,0.63195,0.68782,0.52627,0.87646,0.82536,1.09561,0.002971,0.002971,0.002971 +428,220270,0.86312,0.87502,1.06127,0.71572,0.63151,0.68764,0.52627,0.87621,0.82527,1.09548,0.0029545,0.0029545,0.0029545 +429,220786,0.85947,0.86885,1.05881,0.7126,0.63406,0.6878,0.52649,0.87617,0.82482,1.09546,0.002938,0.002938,0.002938 +430,221300,0.85793,0.87004,1.05779,0.71403,0.63334,0.68843,0.52702,0.87589,0.82448,1.09519,0.0029215,0.0029215,0.0029215 +431,221816,0.85619,0.87585,1.06136,0.71971,0.63071,0.68858,0.52718,0.87538,0.82431,1.09501,0.002905,0.002905,0.002905 +432,222331,0.85831,0.86669,1.05707,0.72504,0.62762,0.689,0.52739,0.87512,0.82389,1.09471,0.0028885,0.0028885,0.0028885 +433,222847,0.85997,0.86549,1.05483,0.72081,0.63038,0.68922,0.52768,0.87505,0.82393,1.09441,0.002872,0.002872,0.002872 +434,223362,0.8562,0.86463,1.05723,0.723,0.62906,0.68961,0.52778,0.87474,0.82362,1.09436,0.0028555,0.0028555,0.0028555 +435,223876,0.86006,0.8624,1.0563,0.72229,0.62936,0.68974,0.5279,0.87449,0.82326,1.09408,0.002839,0.002839,0.002839 +436,224391,0.85369,0.86602,1.05517,0.72437,0.62864,0.68968,0.52799,0.87432,0.82305,1.09377,0.0028225,0.0028225,0.0028225 +437,224907,0.85543,0.8626,1.05463,0.71923,0.63132,0.69011,0.52846,0.87432,0.82316,1.09359,0.002806,0.002806,0.002806 +438,225423,0.85549,0.86693,1.05482,0.71845,0.63291,0.69022,0.52848,0.87426,0.82338,1.0933,0.0027895,0.0027895,0.0027895 +439,225938,0.85351,0.86447,1.0547,0.71936,0.63247,0.69018,0.52859,0.8743,0.82341,1.09327,0.002773,0.002773,0.002773 +440,226451,0.85827,0.863,1.05536,0.7218,0.63249,0.69042,0.52881,0.87404,0.82327,1.09293,0.0027565,0.0027565,0.0027565 +441,226968,0.85228,0.85462,1.0535,0.7182,0.63284,0.69037,0.52891,0.87438,0.82294,1.09313,0.00274,0.00274,0.00274 +442,227485,0.86074,0.86053,1.05602,0.7208,0.63091,0.69044,0.52878,0.87453,0.82252,1.09305,0.0027235,0.0027235,0.0027235 +443,227998,0.85437,0.85869,1.05413,0.72148,0.63108,0.69073,0.52905,0.87426,0.82243,1.09272,0.002707,0.002707,0.002707 +444,228513,0.8509,0.85417,1.05337,0.72359,0.63116,0.69099,0.52946,0.87403,0.82209,1.0925,0.0026905,0.0026905,0.0026905 +445,229028,0.85639,0.8565,1.05578,0.72099,0.63307,0.69096,0.52957,0.87368,0.82186,1.09214,0.002674,0.002674,0.002674 +446,229542,0.84914,0.85556,1.05409,0.72089,0.6328,0.69111,0.52967,0.87348,0.82154,1.09182,0.0026575,0.0026575,0.0026575 +447,230057,0.8516,0.85308,1.05378,0.72499,0.62966,0.69108,0.52949,0.87297,0.82131,1.09147,0.002641,0.002641,0.002641 
+448,230571,0.85009,0.85753,1.05466,0.72206,0.63143,0.69111,0.52963,0.87308,0.82102,1.09138,0.0026245,0.0026245,0.0026245 +449,231085,0.85094,0.85409,1.05099,0.73566,0.62785,0.69174,0.52988,0.87265,0.8204,1.0911,0.002608,0.002608,0.002608 +450,231602,0.84863,0.85087,1.05043,0.73372,0.62916,0.69204,0.53021,0.87268,0.81981,1.09097,0.0025915,0.0025915,0.0025915 +451,232117,0.84942,0.85687,1.05336,0.73203,0.63048,0.69204,0.53024,0.87264,0.81928,1.09078,0.002575,0.002575,0.002575 +452,232633,0.85205,0.85653,1.05415,0.7315,0.63113,0.69245,0.53057,0.87249,0.81945,1.09055,0.0025585,0.0025585,0.0025585 +453,233148,0.84993,0.85167,1.05321,0.72672,0.63441,0.69249,0.53078,0.87227,0.8192,1.09035,0.002542,0.002542,0.002542 +454,233663,0.84927,0.84683,1.05139,0.72788,0.63386,0.69274,0.53095,0.87226,0.81897,1.0902,0.0025255,0.0025255,0.0025255 +455,234178,0.84837,0.84859,1.05134,0.72642,0.63584,0.69296,0.53119,0.87206,0.81863,1.08994,0.002509,0.002509,0.002509 +456,234695,0.84799,0.84619,1.05137,0.72764,0.6351,0.69294,0.5312,0.87164,0.81805,1.08973,0.0024925,0.0024925,0.0024925 +457,235211,0.84937,0.84571,1.0501,0.73028,0.63422,0.69326,0.5312,0.87136,0.81769,1.08949,0.002476,0.002476,0.002476 +458,235725,0.84447,0.84019,1.0493,0.7329,0.63233,0.6933,0.53146,0.87107,0.81718,1.0893,0.0024595,0.0024595,0.0024595 +459,236241,0.84581,0.84343,1.05079,0.7338,0.63181,0.69344,0.53149,0.87071,0.81681,1.08917,0.002443,0.002443,0.002443 +460,236755,0.84394,0.83584,1.04827,0.73553,0.63145,0.69402,0.53194,0.87043,0.81649,1.08887,0.0024265,0.0024265,0.0024265 +461,237269,0.84347,0.84219,1.04967,0.7359,0.63102,0.69426,0.53235,0.87041,0.81593,1.08877,0.00241,0.00241,0.00241 +462,237785,0.84689,0.84172,1.0482,0.73428,0.63307,0.69456,0.53269,0.87038,0.81561,1.08861,0.0023935,0.0023935,0.0023935 +463,238301,0.84373,0.84388,1.05003,0.73697,0.63097,0.69453,0.53256,0.87045,0.81475,1.08862,0.002377,0.002377,0.002377 +464,238816,0.84367,0.83269,1.04728,0.7379,0.63056,0.69485,0.53295,0.87036,0.81442,1.08843,0.0023605,0.0023605,0.0023605 +465,239331,0.84487,0.839,1.04895,0.73921,0.62972,0.69488,0.53291,0.86995,0.81412,1.0881,0.002344,0.002344,0.002344 +466,239846,0.84272,0.83791,1.04755,0.73702,0.63162,0.69512,0.53288,0.86965,0.81416,1.08784,0.0023275,0.0023275,0.0023275 +467,240363,0.8424,0.83654,1.04887,0.7368,0.63169,0.69509,0.53305,0.86973,0.81392,1.08781,0.002311,0.002311,0.002311 +468,240878,0.8433,0.83557,1.04696,0.7379,0.63122,0.69551,0.53338,0.86925,0.81373,1.08753,0.0022945,0.0022945,0.0022945 +469,241394,0.84398,0.836,1.04893,0.73611,0.63321,0.69545,0.53349,0.86933,0.81352,1.08759,0.002278,0.002278,0.002278 +470,241909,0.84412,0.83397,1.04799,0.73613,0.6329,0.69556,0.53358,0.86888,0.81335,1.08724,0.0022615,0.0022615,0.0022615 +471,242424,0.84451,0.83464,1.04873,0.73979,0.6303,0.69576,0.53376,0.86894,0.81332,1.0873,0.002245,0.002245,0.002245 +472,242939,0.83859,0.82974,1.04644,0.74072,0.63146,0.6961,0.53411,0.8688,0.81315,1.08718,0.0022285,0.0022285,0.0022285 +473,243454,0.83458,0.82955,1.04489,0.74093,0.63096,0.69641,0.5342,0.86881,0.81292,1.08723,0.002212,0.002212,0.002212 +474,243970,0.83827,0.83124,1.04554,0.73997,0.63188,0.69653,0.53448,0.86886,0.81275,1.08706,0.0021955,0.0021955,0.0021955 +475,244486,0.83886,0.82302,1.04399,0.74256,0.6299,0.69659,0.5348,0.86875,0.81264,1.08684,0.002179,0.002179,0.002179 +476,244999,0.83806,0.82593,1.04512,0.74281,0.63043,0.69706,0.53485,0.86839,0.8122,1.08651,0.0021625,0.0021625,0.0021625 
+477,245515,0.83356,0.82421,1.04191,0.73898,0.63429,0.69737,0.53528,0.86798,0.81164,1.086,0.002146,0.002146,0.002146 +478,246030,0.8398,0.82551,1.04572,0.74037,0.63305,0.69741,0.53536,0.86766,0.81125,1.08572,0.0021295,0.0021295,0.0021295 +479,246546,0.84096,0.82441,1.045,0.73641,0.63699,0.6974,0.5356,0.86747,0.81123,1.08554,0.002113,0.002113,0.002113 +480,247062,0.83666,0.82503,1.04352,0.73651,0.63616,0.69742,0.53579,0.86728,0.81079,1.08532,0.0020965,0.0020965,0.0020965 +481,247578,0.8368,0.82124,1.04407,0.73382,0.63815,0.69769,0.53579,0.86699,0.81037,1.08504,0.00208,0.00208,0.00208 +482,248092,0.8341,0.81644,1.04303,0.73345,0.63792,0.69773,0.53605,0.86679,0.81028,1.08473,0.0020635,0.0020635,0.0020635 +483,248608,0.83493,0.81573,1.04305,0.7353,0.63743,0.69796,0.53618,0.86665,0.80999,1.0846,0.002047,0.002047,0.002047 +484,249123,0.83519,0.81783,1.04115,0.73317,0.63835,0.69841,0.53636,0.86642,0.80979,1.08436,0.0020305,0.0020305,0.0020305 +485,249639,0.83064,0.81021,1.03823,0.73311,0.63891,0.69838,0.53654,0.86619,0.80921,1.08426,0.002014,0.002014,0.002014 +486,250155,0.83302,0.82045,1.04502,0.72919,0.64079,0.69845,0.53659,0.86608,0.80889,1.0843,0.0019975,0.0019975,0.0019975 +487,250669,0.83215,0.81553,1.03923,0.72589,0.64249,0.69832,0.5366,0.86593,0.80865,1.08406,0.001981,0.001981,0.001981 +488,251186,0.82921,0.81052,1.03963,0.72546,0.64291,0.69837,0.5368,0.86572,0.80833,1.08379,0.0019645,0.0019645,0.0019645 +489,251701,0.83209,0.81092,1.04038,0.72454,0.6438,0.69872,0.5371,0.86548,0.8081,1.08356,0.001948,0.001948,0.001948 +490,252215,0.83271,0.81379,1.04297,0.72467,0.64404,0.6989,0.53723,0.86521,0.80792,1.08333,0.0019315,0.0019315,0.0019315 +491,252730,0.82816,0.80799,1.03982,0.72387,0.6459,0.69885,0.53713,0.86494,0.80757,1.08316,0.001915,0.001915,0.001915 +492,253245,0.82699,0.80748,1.03789,0.7215,0.64721,0.69882,0.53715,0.86473,0.807,1.08286,0.0018985,0.0018985,0.0018985 +493,253760,0.82616,0.8015,1.03692,0.72081,0.64634,0.69908,0.53723,0.86448,0.80678,1.0826,0.001882,0.001882,0.001882 +494,254275,0.83004,0.80364,1.04117,0.72387,0.64594,0.69938,0.5373,0.86455,0.80663,1.08262,0.0018655,0.0018655,0.0018655 +495,254791,0.82708,0.80151,1.03863,0.72193,0.64768,0.69945,0.5375,0.86439,0.80638,1.08254,0.001849,0.001849,0.001849 +496,255308,0.82567,0.80143,1.03657,0.72191,0.64739,0.69961,0.53777,0.86439,0.80595,1.08238,0.0018325,0.0018325,0.0018325 +497,255823,0.82369,0.80151,1.03626,0.72133,0.64832,0.69973,0.53796,0.86413,0.80564,1.08188,0.001816,0.001816,0.001816 +498,256340,0.82423,0.79778,1.03576,0.72179,0.64795,0.6996,0.53787,0.86401,0.80531,1.08147,0.0017995,0.0017995,0.0017995 +499,256856,0.82747,0.80333,1.03862,0.72314,0.64848,0.70004,0.53798,0.86399,0.8052,1.08125,0.001783,0.001783,0.001783 +500,257372,0.82634,0.79708,1.03826,0.7231,0.64946,0.70016,0.53792,0.86412,0.80505,1.08122,0.0017665,0.0017665,0.0017665 +501,257890,0.82715,0.79928,1.03748,0.73182,0.64427,0.70028,0.53828,0.86419,0.80509,1.08103,0.00175,0.00175,0.00175 +502,258406,0.82391,0.79648,1.03613,0.73135,0.64554,0.70053,0.53828,0.86428,0.80498,1.08111,0.0017335,0.0017335,0.0017335 +503,258921,0.8179,0.79055,1.03415,0.72897,0.64766,0.70065,0.53846,0.86393,0.80434,1.08059,0.001717,0.001717,0.001717 +504,259435,0.82025,0.79244,1.03336,0.7335,0.64526,0.701,0.53889,0.86375,0.8037,1.08047,0.0017005,0.0017005,0.0017005 +505,259949,0.82248,0.78982,1.03457,0.72859,0.64857,0.70118,0.53896,0.86352,0.80331,1.08023,0.001684,0.001684,0.001684 
+506,260463,0.81936,0.78664,1.03581,0.72496,0.65131,0.70151,0.53907,0.86327,0.80285,1.08011,0.0016675,0.0016675,0.0016675 +507,260977,0.81945,0.78734,1.03172,0.72515,0.65115,0.70151,0.53892,0.8629,0.80211,1.07974,0.001651,0.001651,0.001651 +508,261490,0.81654,0.78698,1.03197,0.72665,0.65013,0.70166,0.53916,0.86263,0.80202,1.07962,0.0016345,0.0016345,0.0016345 +509,262004,0.82102,0.78715,1.03386,0.72741,0.6496,0.70167,0.53925,0.86238,0.80218,1.07951,0.001618,0.001618,0.001618 +510,262518,0.81977,0.78558,1.03208,0.72592,0.65081,0.70188,0.53936,0.86246,0.80156,1.0795,0.0016015,0.0016015,0.0016015 +511,263032,0.81519,0.78233,1.03112,0.72619,0.65059,0.70207,0.53961,0.86221,0.80112,1.07918,0.001585,0.001585,0.001585 +512,263546,0.81716,0.78081,1.03135,0.72597,0.65143,0.70197,0.53963,0.8623,0.80097,1.07915,0.0015685,0.0015685,0.0015685 +513,264059,0.81863,0.782,1.03074,0.72661,0.65171,0.70212,0.53991,0.86212,0.80075,1.07895,0.001552,0.001552,0.001552 +514,264574,0.8183,0.78107,1.03279,0.72841,0.65102,0.70232,0.54006,0.86187,0.80019,1.07861,0.0015355,0.0015355,0.0015355 +515,265087,0.81456,0.78185,1.03061,0.72997,0.65057,0.70281,0.54023,0.86161,0.7999,1.07833,0.001519,0.001519,0.001519 +516,265601,0.81272,0.77257,1.0286,0.72975,0.6511,0.7027,0.54058,0.8614,0.79928,1.07802,0.0015025,0.0015025,0.0015025 +517,266114,0.81694,0.77913,1.03172,0.7313,0.65134,0.7031,0.5409,0.86107,0.79928,1.07762,0.001486,0.001486,0.001486 +518,266627,0.81034,0.77344,1.03045,0.73149,0.65145,0.70337,0.5411,0.86047,0.7991,1.07729,0.0014695,0.0014695,0.0014695 +519,267141,0.81496,0.77192,1.0305,0.73285,0.65111,0.70307,0.54133,0.86018,0.79864,1.07721,0.001453,0.001453,0.001453 +520,267655,0.80852,0.77068,1.02795,0.73119,0.65248,0.70327,0.54154,0.86004,0.79807,1.07698,0.0014365,0.0014365,0.0014365 +521,268170,0.81131,0.768,1.02708,0.73092,0.65246,0.70319,0.5417,0.85994,0.79759,1.07688,0.00142,0.00142,0.00142 +522,268684,0.80931,0.77002,1.02796,0.73138,0.65195,0.70352,0.54189,0.85981,0.7972,1.07682,0.0014035,0.0014035,0.0014035 +523,269198,0.80892,0.76858,1.02862,0.73146,0.65211,0.70365,0.54215,0.85939,0.79733,1.07649,0.001387,0.001387,0.001387 +524,269712,0.80932,0.76711,1.02804,0.73303,0.65062,0.70381,0.54233,0.85931,0.79692,1.07644,0.0013705,0.0013705,0.0013705 +525,270225,0.80506,0.7628,1.0263,0.73364,0.64881,0.70395,0.5425,0.85904,0.79641,1.07619,0.001354,0.001354,0.001354 +526,270740,0.81277,0.76282,1.0275,0.73561,0.64745,0.70439,0.54293,0.85887,0.79622,1.07599,0.0013375,0.0013375,0.0013375 +527,271254,0.81106,0.76605,1.02804,0.73596,0.64622,0.70467,0.5433,0.85854,0.79553,1.07588,0.001321,0.001321,0.001321 +528,271768,0.80538,0.76051,1.02354,0.73436,0.64821,0.70475,0.54336,0.8586,0.79543,1.07564,0.0013045,0.0013045,0.0013045 +529,272281,0.80485,0.75322,1.02344,0.73634,0.64681,0.70467,0.54322,0.85837,0.7945,1.07547,0.001288,0.001288,0.001288 +530,272795,0.8045,0.75297,1.02704,0.73829,0.64607,0.70495,0.5434,0.85821,0.7942,1.07535,0.0012715,0.0012715,0.0012715 +531,273309,0.80549,0.75415,1.02483,0.73911,0.64589,0.70491,0.54335,0.85795,0.79403,1.0753,0.001255,0.001255,0.001255 +532,273822,0.80463,0.75137,1.02604,0.73695,0.64794,0.705,0.54352,0.85801,0.79388,1.07533,0.0012385,0.0012385,0.0012385 +533,274337,0.80313,0.75033,1.02321,0.73711,0.64749,0.70532,0.54376,0.85782,0.79345,1.07511,0.001222,0.001222,0.001222 +534,274851,0.80298,0.74969,1.02424,0.74218,0.64557,0.70552,0.54416,0.85804,0.79273,1.07512,0.0012055,0.0012055,0.0012055 
+535,275366,0.80052,0.74801,1.0228,0.74115,0.64664,0.70605,0.54437,0.85788,0.79243,1.07494,0.001189,0.001189,0.001189 +536,275878,0.80267,0.75011,1.02386,0.73764,0.64906,0.70624,0.54427,0.85757,0.79199,1.07483,0.0011725,0.0011725,0.0011725 +537,276393,0.80054,0.74474,1.02245,0.74021,0.6473,0.70659,0.54439,0.85772,0.79163,1.07493,0.001156,0.001156,0.001156 +538,276906,0.79994,0.74272,1.02139,0.73699,0.64994,0.70682,0.54441,0.85789,0.79116,1.07501,0.0011395,0.0011395,0.0011395 +539,277419,0.79602,0.73845,1.01954,0.73731,0.65161,0.70745,0.54497,0.85787,0.79074,1.07497,0.001123,0.001123,0.001123 +540,277934,0.79417,0.73699,1.02039,0.73425,0.65158,0.70751,0.54524,0.85789,0.79032,1.07502,0.0011065,0.0011065,0.0011065 +541,278448,0.79576,0.73589,1.01973,0.73211,0.65265,0.70741,0.54504,0.85784,0.79008,1.0749,0.00109,0.00109,0.00109 +542,278962,0.79384,0.73635,1.01885,0.73472,0.65123,0.70783,0.54541,0.85789,0.7896,1.07502,0.0010735,0.0010735,0.0010735 +543,279475,0.79356,0.73463,1.01682,0.73768,0.65024,0.70826,0.54567,0.85767,0.78904,1.07458,0.001057,0.001057,0.001057 +544,279989,0.79166,0.73266,1.01753,0.73885,0.6506,0.70848,0.54598,0.85768,0.78864,1.0744,0.0010405,0.0010405,0.0010405 +545,280502,0.79111,0.72926,1.01912,0.74063,0.65008,0.70864,0.54617,0.85783,0.78849,1.07467,0.001024,0.001024,0.001024 +546,281016,0.79226,0.73168,1.01873,0.74052,0.6495,0.70853,0.54634,0.85768,0.78817,1.07474,0.0010075,0.0010075,0.0010075 +547,281530,0.79054,0.72847,1.0185,0.74009,0.64915,0.70867,0.54658,0.85755,0.78774,1.07473,0.000991,0.000991,0.000991 +548,282043,0.78962,0.7272,1.01885,0.73858,0.65241,0.709,0.54683,0.85731,0.78784,1.07462,0.0009745,0.0009745,0.0009745 +549,282557,0.7898,0.72699,1.01694,0.73767,0.65271,0.70914,0.54698,0.85712,0.78756,1.07447,0.000958,0.000958,0.000958 +550,283070,0.78711,0.72389,1.01775,0.73631,0.65442,0.70896,0.547,0.85691,0.78721,1.07455,0.0009415,0.0009415,0.0009415 +551,283585,0.78377,0.72216,1.01311,0.73721,0.65582,0.70944,0.54733,0.85673,0.78715,1.07436,0.000925,0.000925,0.000925 +552,284099,0.78707,0.71838,1.01536,0.73747,0.65555,0.70966,0.54754,0.85661,0.78697,1.07434,0.0009085,0.0009085,0.0009085 +553,284613,0.78436,0.71652,1.01508,0.73737,0.65612,0.70962,0.54747,0.85644,0.78691,1.07417,0.000892,0.000892,0.000892 +554,285128,0.7846,0.71137,1.01181,0.73591,0.65709,0.70979,0.54757,0.85669,0.78636,1.0746,0.0008755,0.0008755,0.0008755 +555,285641,0.78296,0.71334,1.01399,0.73422,0.65855,0.71002,0.54789,0.85676,0.78617,1.07456,0.000859,0.000859,0.000859 +556,286154,0.7844,0.71225,1.01158,0.73492,0.65913,0.7102,0.54787,0.85643,0.78591,1.07424,0.0008425,0.0008425,0.0008425 +557,286667,0.77796,0.70653,1.00942,0.73501,0.65904,0.71027,0.54815,0.856,0.78588,1.07376,0.000826,0.000826,0.000826 +558,287180,0.78107,0.70939,1.01237,0.73686,0.65759,0.71012,0.54815,0.85552,0.78559,1.07331,0.0008095,0.0008095,0.0008095 +559,287692,0.77609,0.7022,1.01122,0.7379,0.65732,0.71027,0.5482,0.85548,0.78525,1.07328,0.000793,0.000793,0.000793 +560,288206,0.77715,0.70366,1.01033,0.73894,0.6568,0.71065,0.54836,0.85501,0.78502,1.07287,0.0007765,0.0007765,0.0007765 +561,288719,0.77557,0.69806,1.00827,0.73716,0.65738,0.71064,0.54865,0.8551,0.78463,1.07301,0.00076,0.00076,0.00076 +562,289232,0.7784,0.7016,1.00963,0.7378,0.65743,0.71079,0.54872,0.85495,0.78462,1.07296,0.0007435,0.0007435,0.0007435 +563,289747,0.77548,0.69866,1.00964,0.7378,0.65663,0.71094,0.54878,0.85463,0.78423,1.07281,0.000727,0.000727,0.000727 
+564,290260,0.77404,0.69299,1.00614,0.7367,0.65662,0.71129,0.54924,0.8538,0.7844,1.07208,0.0007105,0.0007105,0.0007105 +565,290774,0.77273,0.69133,1.00465,0.73507,0.65729,0.7117,0.54937,0.85359,0.78411,1.07166,0.000694,0.000694,0.000694 +566,291287,0.77641,0.6943,1.0083,0.74612,0.6541,0.71226,0.54993,0.85343,0.78379,1.07156,0.0006775,0.0006775,0.0006775 +567,291801,0.77258,0.69352,1.01004,0.73545,0.65777,0.7121,0.54968,0.85305,0.78366,1.07157,0.000661,0.000661,0.000661 +568,292314,0.76617,0.68498,1.00536,0.7458,0.65445,0.71229,0.5498,0.85308,0.78337,1.07164,0.0006445,0.0006445,0.0006445 +569,292827,0.76687,0.68345,1.00509,0.7374,0.65675,0.71196,0.54965,0.8529,0.78333,1.07156,0.000628,0.000628,0.000628 +570,293339,0.76667,0.68366,1.00433,0.73834,0.65612,0.71231,0.54984,0.85285,0.7829,1.07151,0.0006115,0.0006115,0.0006115 +571,293851,0.76279,0.67796,1.00449,0.73836,0.65693,0.71281,0.55058,0.85274,0.78278,1.07143,0.000595,0.000595,0.000595 +572,294364,0.76359,0.67829,1.00408,0.73854,0.6572,0.71277,0.55043,0.85251,0.78262,1.07143,0.0005785,0.0005785,0.0005785 +573,294878,0.76379,0.66909,0.99984,0.73977,0.6568,0.71347,0.55074,0.85262,0.78203,1.07128,0.000562,0.000562,0.000562 +574,295392,0.76236,0.67324,1.00225,0.74083,0.65553,0.7139,0.55098,0.8523,0.78183,1.07124,0.0005455,0.0005455,0.0005455 +575,295906,0.76261,0.67191,1.00113,0.74319,0.6538,0.71403,0.55113,0.85221,0.78169,1.07123,0.000529,0.000529,0.000529 +576,296418,0.7626,0.67072,1.00324,0.74593,0.6559,0.7143,0.55146,0.8522,0.78145,1.07128,0.0005125,0.0005125,0.0005125 +577,296931,0.76152,0.66711,1.00007,0.74459,0.65713,0.71449,0.55174,0.85217,0.78143,1.07126,0.000496,0.000496,0.000496 +578,297444,0.75929,0.66197,1.00011,0.73544,0.65955,0.7145,0.55172,0.8521,0.78104,1.07128,0.0004795,0.0004795,0.0004795 +579,297958,0.75669,0.65836,0.99892,0.7354,0.65918,0.71459,0.55178,0.85208,0.78105,1.07127,0.000463,0.000463,0.000463 +580,298471,0.75416,0.66073,0.99838,0.73501,0.65899,0.71471,0.55197,0.85185,0.7812,1.07102,0.0004465,0.0004465,0.0004465 +581,298985,0.75691,0.65925,0.99721,0.7378,0.65834,0.71477,0.55209,0.85174,0.78111,1.07103,0.00043,0.00043,0.00043 +582,299498,0.75325,0.65433,0.99797,0.73849,0.65798,0.71473,0.55207,0.85175,0.781,1.07093,0.0004135,0.0004135,0.0004135 +583,300011,0.75424,0.65087,0.99551,0.73339,0.66269,0.7147,0.55235,0.85216,0.78095,1.07133,0.000397,0.000397,0.000397 +584,300525,0.75195,0.64892,0.9965,0.73314,0.66313,0.71473,0.55249,0.85213,0.78088,1.07127,0.0003805,0.0003805,0.0003805 +585,301038,0.74739,0.64138,0.99299,0.73264,0.66345,0.71487,0.55255,0.85202,0.78123,1.07112,0.000364,0.000364,0.000364 +586,301552,0.75043,0.64398,0.9948,0.73407,0.66168,0.71499,0.55275,0.85198,0.78099,1.07115,0.0003475,0.0003475,0.0003475 +587,302066,0.74486,0.64025,0.99439,0.73474,0.66164,0.71506,0.55288,0.85196,0.78083,1.07112,0.000331,0.000331,0.000331 +588,302580,0.74688,0.64118,0.99233,0.73677,0.65938,0.71489,0.55295,0.85188,0.78083,1.07096,0.0003145,0.0003145,0.0003145 +589,303092,0.74504,0.63476,0.99053,0.73644,0.65939,0.71488,0.55292,0.85176,0.78099,1.07068,0.000298,0.000298,0.000298 +590,303604,0.74171,0.63156,0.98779,0.73556,0.65772,0.71442,0.55241,0.8516,0.78119,1.0704,0.0002815,0.0002815,0.0002815 +591,304113,0.81659,0.68902,1.04882,0.73973,0.65633,0.71471,0.55251,0.85113,0.78126,1.06962,0.000265,0.000265,0.000265 +592,304621,0.81325,0.67394,1.0427,0.74212,0.65583,0.71542,0.55288,0.85059,0.78076,1.06881,0.0002485,0.0002485,0.0002485 
+593,305128,0.80968,0.66528,1.04056,0.74627,0.65514,0.71655,0.55322,0.85027,0.7802,1.068,0.000232,0.000232,0.000232 +594,305634,0.80566,0.65818,1.04286,0.74593,0.65576,0.71696,0.55355,0.84962,0.78016,1.06709,0.0002155,0.0002155,0.0002155 +595,306143,0.8045,0.65022,1.04053,0.74688,0.65537,0.71723,0.55378,0.84943,0.78018,1.06656,0.000199,0.000199,0.000199 +596,306650,0.80301,0.64489,1.04285,0.74361,0.65829,0.71719,0.55367,0.84937,0.78033,1.06637,0.0001825,0.0001825,0.0001825 +597,307156,0.7949,0.63705,1.04111,0.74707,0.65653,0.71759,0.55396,0.84915,0.78076,1.06618,0.000166,0.000166,0.000166 +598,307662,0.79771,0.63083,1.03729,0.74992,0.65439,0.718,0.55404,0.84882,0.78118,1.0658,0.0001495,0.0001495,0.0001495 +599,308169,0.79046,0.62387,1.03344,0.74661,0.65706,0.71787,0.55404,0.84856,0.78193,1.06538,0.000133,0.000133,0.000133 +600,308675,0.78826,0.61676,1.03658,0.74969,0.65466,0.71766,0.55374,0.8484,0.78233,1.06531,0.0001165,0.0001165,0.0001165 diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 0000000000000000000000000000000000000000..1ac53c83bf82b0beac772ea2a4879f050b0b602f --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,769 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Configuration file for building the Ultralytics YOLO documentation site using MkDocs. +# Provides settings to control site metadata, customize the appearance using the +# Material theme, define the navigation structure, and enable various plugins. + +# Site metadata +site_name: Ultralytics YOLO Docs +site_description: Explore Ultralytics YOLO, a cutting-edge real-time object detection and image segmentation model for various applications and hardware platforms. +site_url: https://docs.ultralytics.com +site_author: Ultralytics +repo_url: https://github.com/ultralytics/ultralytics +edit_uri: https://github.com/ultralytics/ultralytics/tree/main/docs/en/ +repo_name: ultralytics/ultralytics +remote_name: https://github.com/ultralytics/docs +docs_dir: "docs/en/" # where to find the markdown files +site_dir: "site/" # where to publish to +use_directory_urls: true # don't display 'index.html' in slugs + +# Theme customization +theme: + name: material + language: en + custom_dir: docs/overrides/ + logo: https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Reverse.svg + favicon: https://raw.githubusercontent.com/ultralytics/assets/refs/heads/main/logo/favicon-yolo.png + icon: + repo: fontawesome/brands/github + # font: # disabled for faster page load times + # text: Helvetica + # code: Roboto Mono + palette: + - media: "(prefers-color-scheme)" + toggle: + icon: material/brightness-auto + name: Switch to light mode + - media: "(prefers-color-scheme: dark)" + scheme: slate + primary: black + accent: indigo + toggle: + icon: material/brightness-4 + name: Switch to system preference + - media: "(prefers-color-scheme: light)" + scheme: default + primary: indigo + accent: indigo + toggle: + icon: material/brightness-7 + name: Switch to dark mode + features: + - content.action.edit + - content.code.annotate + - content.code.copy + - content.tooltips + - search.highlight + - search.share + - search.suggest + - toc.follow + - navigation.top + - navigation.tabs + - navigation.tabs.sticky + - navigation.prune + - navigation.footer + - navigation.tracking + - navigation.instant + - navigation.instant.progress + - navigation.indexes + - navigation.sections # navigation.expand or navigation.sections + - content.tabs.link # all code tabs change simultaneously + +# Customization 
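The hunk that closes just above is the tail of a 600-epoch training log (a `results.csv` in the usual Ultralytics layout: epoch, cumulative time, train box/cls/dfl losses, precision, recall, mAP50, mAP50-95, val losses, and per-group learning rates). A minimal sketch for inspecting such a log with pandas and matplotlib; the column names are assumed from the standard Ultralytics header and the file path is illustrative, so adjust both to the actual file in this diff.

```python
import pandas as pd
import matplotlib.pyplot as plt

# Illustrative path; point this at the results.csv added in this diff.
df = pd.read_csv("results.csv")
df.columns = df.columns.str.strip()  # Ultralytics headers are sometimes space-padded

# Column names below assume the standard Ultralytics results.csv header.
best = df.loc[df["metrics/mAP50-95(B)"].idxmax()]
print(f"best epoch: {int(best['epoch'])}  mAP50-95: {best['metrics/mAP50-95(B)']:.4f}")

plt.plot(df["epoch"], df["metrics/mAP50(B)"], label="mAP50")
plt.plot(df["epoch"], df["metrics/mAP50-95(B)"], label="mAP50-95")
plt.xlabel("epoch")
plt.ylabel("COCO mAP")
plt.legend()
plt.tight_layout()
plt.savefig("map_curves.png")
```

Under that assumed header, the mAP50-95 column in the log above climbs to roughly 0.554 by epoch 600.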
+copyright: © 2025 Ultralytics Inc. All rights reserved. +extra: # version: + homepage: https://www.ultralytics.com/ + # provider: mike # version drop-down menu + robots: robots.txt + analytics: + provider: google + property: G-2M5EHKC0BH + social: + - icon: fontawesome/brands/github + link: https://github.com/ultralytics + - icon: fontawesome/brands/linkedin + link: https://www.linkedin.com/company/ultralytics/ + - icon: fontawesome/brands/x-twitter + link: https://twitter.com/ultralytics + - icon: fontawesome/brands/youtube + link: https://youtube.com/ultralytics?sub_confirmation=1 + - icon: fontawesome/brands/docker + link: https://hub.docker.com/r/ultralytics/ultralytics/ + - icon: fontawesome/brands/python + link: https://pypi.org/project/ultralytics/ + - icon: fontawesome/brands/discord + link: https://discord.com/invite/ultralytics + - icon: fontawesome/brands/reddit + link: https://reddit.com/r/ultralytics + +extra_css: + - stylesheets/style.css + +extra_javascript: + - javascript/extra.js + - javascript/giscus.js + +markdown_extensions: + - admonition + - md_in_html + - tables + - attr_list + - def_list + - pymdownx.critic + - pymdownx.caret + - pymdownx.keys + - pymdownx.mark + - pymdownx.tilde + - pymdownx.details + - pymdownx.superfences + - pymdownx.inlinehilite + - pymdownx.highlight: + anchor_linenums: true + - pymdownx.snippets: + base_path: ./ + - pymdownx.emoji: + emoji_index: !!python/name:material.extensions.emoji.twemoji + emoji_generator: !!python/name:material.extensions.emoji.to_svg + - pymdownx.tabbed: + alternate_style: true + +# Validation settings https://www.mkdocs.org/user-guide/configuration/#validation +validation: + nav: + omitted_files: info + not_found: warn + absolute_links: info + links: + absolute_links: relative_to_docs + anchors: warn + unrecognized_links: warn + +# Primary navigation --------------------------------------------------------------------------------------------------- +nav: + - Home: + - Home: index.md + - Quickstart: quickstart.md + - Modes: + - modes/index.md + - Train: modes/train.md + - Val: modes/val.md + - Predict: modes/predict.md + - Export: modes/export.md + - Track: modes/track.md + - Benchmark: modes/benchmark.md + - Tasks: + - tasks/index.md + - Detect: tasks/detect.md + - Segment: tasks/segment.md + - Classify: tasks/classify.md + - Pose: tasks/pose.md + - OBB: tasks/obb.md + - Models: + - models/index.md + - Datasets: + - datasets/index.md + - Solutions: + - solutions/index.md + - Guides: + - guides/index.md + - YOLO11 🚀 NEW: models/yolo11.md # for promotion of new pages + - Languages: + - 🇬🇧  English: https://ultralytics.com/docs/ + - 🇨🇳  简体中文: https://docs.ultralytics.com/zh/ + - 🇰🇷  한국어: https://docs.ultralytics.com/ko/ + - 🇯🇵  日本語: https://docs.ultralytics.com/ja/ + - 🇷🇺  Русский: https://docs.ultralytics.com/ru/ + - 🇩🇪  Deutsch: https://docs.ultralytics.com/de/ + - 🇫🇷  Français: https://docs.ultralytics.com/fr/ + - 🇪🇸  Español: https://docs.ultralytics.com/es/ + - 🇵🇹  Português: https://docs.ultralytics.com/pt/ + - 🇮🇹  Italiano: https://docs.ultralytics.com/it/ + - 🇹🇷  Türkçe: https://docs.ultralytics.com/tr/ + - 🇻🇳  Tiếng Việt: https://docs.ultralytics.com/vi/ + - 🇸🇦  العربية: https://docs.ultralytics.com/ar/ + - Quickstart: + - quickstart.md + - Usage: + - CLI: usage/cli.md + - Python: usage/python.md + - Callbacks: usage/callbacks.md + - Configuration: usage/cfg.md + - Simple Utilities: usage/simple-utilities.md + - Advanced Customization: usage/engine.md + - Modes: + - modes/index.md + - Train: modes/train.md + - 
Val: modes/val.md + - Predict: modes/predict.md + - Export: modes/export.md + - Track: modes/track.md + - Benchmark: modes/benchmark.md + - Tasks: + - tasks/index.md + - Detect: tasks/detect.md + - Segment: tasks/segment.md + - Classify: tasks/classify.md + - Pose: tasks/pose.md + - OBB: tasks/obb.md + - Models: + - models/index.md + - Datasets: + - datasets/index.md + - Solutions: + - solutions/index.md + - Guides: + - guides/index.md + - Modes: + - modes/index.md + - Train: modes/train.md + - Val: modes/val.md + - Predict: modes/predict.md + - Export: modes/export.md + - Track: modes/track.md + - Benchmark: modes/benchmark.md + - Tasks: + - tasks/index.md + - Detect: tasks/detect.md + - Segment: tasks/segment.md + - Classify: tasks/classify.md + - Pose: tasks/pose.md + - OBB: tasks/obb.md + - Tasks: + - tasks/index.md + - Detect: tasks/detect.md + - Segment: tasks/segment.md + - Classify: tasks/classify.md + - Pose: tasks/pose.md + - OBB: tasks/obb.md + - Modes: + - modes/index.md + - Train: modes/train.md + - Val: modes/val.md + - Predict: modes/predict.md + - Export: modes/export.md + - Track: modes/track.md + - Benchmark: modes/benchmark.md + - Models: + - models/index.md + - YOLOv3: models/yolov3.md + - YOLOv4: models/yolov4.md + - YOLOv5: models/yolov5.md + - YOLOv6: models/yolov6.md + - YOLOv7: models/yolov7.md + - YOLOv8: models/yolov8.md + - YOLOv9: models/yolov9.md + - YOLOv10: models/yolov10.md + - YOLO11 🚀 NEW: models/yolo11.md + - SAM (Segment Anything Model): models/sam.md + - SAM 2 (Segment Anything Model 2): models/sam-2.md + - MobileSAM (Mobile Segment Anything Model): models/mobile-sam.md + - FastSAM (Fast Segment Anything Model): models/fast-sam.md + - YOLO-NAS (Neural Architecture Search): models/yolo-nas.md + - RT-DETR (Realtime Detection Transformer): models/rtdetr.md + - YOLO-World (Real-Time Open-Vocabulary Object Detection): models/yolo-world.md + - Datasets: + - datasets/index.md + - Detection: + - datasets/detect/index.md + - Argoverse: datasets/detect/argoverse.md + - COCO: datasets/detect/coco.md + - LVIS: datasets/detect/lvis.md + - COCO8: datasets/detect/coco8.md + - GlobalWheat2020: datasets/detect/globalwheat2020.md + - Objects365: datasets/detect/objects365.md + - OpenImagesV7: datasets/detect/open-images-v7.md + - SKU-110K: datasets/detect/sku-110k.md + - VisDrone: datasets/detect/visdrone.md + - VOC: datasets/detect/voc.md + - xView: datasets/detect/xview.md + - RF100: datasets/detect/roboflow-100.md + - Brain-tumor: datasets/detect/brain-tumor.md + - African-wildlife: datasets/detect/african-wildlife.md + - Signature: datasets/detect/signature.md + - Medical-pills: datasets/detect/medical-pills.md + - Segmentation: + - datasets/segment/index.md + - COCO: datasets/segment/coco.md + - COCO8-seg: datasets/segment/coco8-seg.md + - Crack-seg: datasets/segment/crack-seg.md + - Carparts-seg: datasets/segment/carparts-seg.md + - Package-seg: datasets/segment/package-seg.md + - Pose: + - datasets/pose/index.md + - COCO: datasets/pose/coco.md + - COCO8-pose: datasets/pose/coco8-pose.md + - Tiger-pose: datasets/pose/tiger-pose.md + - Hand-keypoints: datasets/pose/hand-keypoints.md + - Dog-pose: datasets/pose/dog-pose.md + - Classification: + - datasets/classify/index.md + - Caltech 101: datasets/classify/caltech101.md + - Caltech 256: datasets/classify/caltech256.md + - CIFAR-10: datasets/classify/cifar10.md + - CIFAR-100: datasets/classify/cifar100.md + - Fashion-MNIST: datasets/classify/fashion-mnist.md + - ImageNet: datasets/classify/imagenet.md + - 
ImageNet-10: datasets/classify/imagenet10.md + - Imagenette: datasets/classify/imagenette.md + - Imagewoof: datasets/classify/imagewoof.md + - MNIST: datasets/classify/mnist.md + - Oriented Bounding Boxes (OBB): + - datasets/obb/index.md + - DOTAv2: datasets/obb/dota-v2.md + - DOTA8: datasets/obb/dota8.md + - Multi-Object Tracking: + - datasets/track/index.md + - Solutions 🚀 NEW: + - solutions/index.md + - Object Counting: guides/object-counting.md + - Object Cropping: guides/object-cropping.md + - Object Blurring: guides/object-blurring.md + - Workouts Monitoring: guides/workouts-monitoring.md + - Objects Counting in Regions: guides/region-counting.md + - Security Alarm System: guides/security-alarm-system.md + - Heatmaps: guides/heatmaps.md + - Instance Segmentation with Object Tracking: guides/instance-segmentation-and-tracking.md + - VisionEye Mapping: guides/vision-eye.md + - Speed Estimation: guides/speed-estimation.md + - Distance Calculation: guides/distance-calculation.md + - Queue Management: guides/queue-management.md + - Parking Management: guides/parking-management.md + - Analytics: guides/analytics.md + - Live Inference: guides/streamlit-live-inference.md + - Track Objects in Zone 🚀 NEW: guides/trackzone.md + - Guides: + - guides/index.md + - YOLO Common Issues: guides/yolo-common-issues.md + - YOLO Performance Metrics: guides/yolo-performance-metrics.md + - YOLO Thread-Safe Inference: guides/yolo-thread-safe-inference.md + - Model Deployment Options: guides/model-deployment-options.md + - K-Fold Cross Validation: guides/kfold-cross-validation.md + - Hyperparameter Tuning: guides/hyperparameter-tuning.md + - SAHI Tiled Inference: guides/sahi-tiled-inference.md + - AzureML Quickstart: guides/azureml-quickstart.md + - Conda Quickstart: guides/conda-quickstart.md + - Docker Quickstart: guides/docker-quickstart.md + - Raspberry Pi: guides/raspberry-pi.md + - NVIDIA Jetson: guides/nvidia-jetson.md + - DeepStream on NVIDIA Jetson: guides/deepstream-nvidia-jetson.md + - Triton Inference Server: guides/triton-inference-server.md + - Isolating Segmentation Objects: guides/isolating-segmentation-objects.md + - Edge TPU on Raspberry Pi: guides/coral-edge-tpu-on-raspberry-pi.md + - Viewing Inference Images in a Terminal: guides/view-results-in-terminal.md + - OpenVINO Latency vs Throughput modes: guides/optimizing-openvino-latency-vs-throughput-modes.md + - ROS Quickstart: guides/ros-quickstart.md + - Steps of a Computer Vision Project: guides/steps-of-a-cv-project.md + - Defining A Computer Vision Project's Goals: guides/defining-project-goals.md + - Data Collection and Annotation: guides/data-collection-and-annotation.md + - Preprocessing Annotated Data: guides/preprocessing_annotated_data.md + - Tips for Model Training: guides/model-training-tips.md + - Insights on Model Evaluation and Fine-Tuning: guides/model-evaluation-insights.md + - A Guide on Model Testing: guides/model-testing.md + - Best Practices for Model Deployment: guides/model-deployment-practices.md + - Maintaining Your Computer Vision Model: guides/model-monitoring-and-maintenance.md + - Explorer: + - datasets/explorer/index.md + - Explorer API: datasets/explorer/api.md + - Explorer Dashboard Demo: datasets/explorer/dashboard.md + - VOC Exploration Example: datasets/explorer/explorer.md + - YOLOv5: + - yolov5/index.md + - Quickstart: yolov5/quickstart_tutorial.md + - Environments: + - Amazon Web Services (AWS): yolov5/environments/aws_quickstart_tutorial.md + - Google Cloud (GCP): 
yolov5/environments/google_cloud_quickstart_tutorial.md + - AzureML: yolov5/environments/azureml_quickstart_tutorial.md + - Docker Image: yolov5/environments/docker_image_quickstart_tutorial.md + - Tutorials: + - Train Custom Data: yolov5/tutorials/train_custom_data.md + - Tips for Best Training Results: yolov5/tutorials/tips_for_best_training_results.md + - Multi-GPU Training: yolov5/tutorials/multi_gpu_training.md + - PyTorch Hub: yolov5/tutorials/pytorch_hub_model_loading.md + - TFLite, ONNX, CoreML, TensorRT Export: yolov5/tutorials/model_export.md + - Test-Time Augmentation (TTA): yolov5/tutorials/test_time_augmentation.md + - Model Ensembling: yolov5/tutorials/model_ensembling.md + - Pruning/Sparsity Tutorial: yolov5/tutorials/model_pruning_and_sparsity.md + - Hyperparameter evolution: yolov5/tutorials/hyperparameter_evolution.md + - Transfer learning with frozen layers: yolov5/tutorials/transfer_learning_with_frozen_layers.md + - Architecture Summary: yolov5/tutorials/architecture_description.md + - Roboflow Datasets: yolov5/tutorials/roboflow_datasets_integration.md + - Neural Magic's DeepSparse: yolov5/tutorials/neural_magic_pruning_quantization.md + - Comet Logging: yolov5/tutorials/comet_logging_integration.md + - Clearml Logging: yolov5/tutorials/clearml_logging_integration.md + - Integrations: + - integrations/index.md + - Amazon SageMaker: integrations/amazon-sagemaker.md + - ClearML: integrations/clearml.md + - Comet ML: integrations/comet.md + - CoreML: integrations/coreml.md + - DVC: integrations/dvc.md + - Google Colab: integrations/google-colab.md + - Gradio: integrations/gradio.md + - IBM Watsonx: integrations/ibm-watsonx.md + - JupyterLab: integrations/jupyterlab.md + - Kaggle: integrations/kaggle.md + - MLflow: integrations/mlflow.md + - Neural Magic: integrations/neural-magic.md + - ONNX: integrations/onnx.md + - OpenVINO: integrations/openvino.md + - PaddlePaddle: integrations/paddlepaddle.md + - MNN: integrations/mnn.md + - NCNN: integrations/ncnn.md + - Paperspace Gradient: integrations/paperspace.md + - Ray Tune: integrations/ray-tune.md + - Roboflow: integrations/roboflow.md + - TF GraphDef: integrations/tf-graphdef.md + - TF SavedModel: integrations/tf-savedmodel.md + - TF.js: integrations/tfjs.md + - TFLite: integrations/tflite.md + - TFLite Edge TPU: integrations/edge-tpu.md + - TensorBoard: integrations/tensorboard.md + - TensorRT: integrations/tensorrt.md + - TorchScript: integrations/torchscript.md + - VS Code: integrations/vscode.md + - Weights & Biases: integrations/weights-biases.md + - Albumentations: integrations/albumentations.md + - SONY IMX500: integrations/sony-imx500.md + - HUB: + - hub/index.md + - Web: + - hub/index.md + - Quickstart: hub/quickstart.md + - Datasets: hub/datasets.md + - Projects: hub/projects.md + - Models: hub/models.md + - Pro: hub/pro.md + - Cloud Training: hub/cloud-training.md + - Inference API: hub/inference-api.md + - Teams: hub/teams.md + - Integrations: hub/integrations.md + - App: + - hub/app/index.md + - iOS: hub/app/ios.md + - Android: hub/app/android.md + - Python SDK: + - hub/sdk/index.md + - Quickstart: hub/sdk/quickstart.md + - Model: hub/sdk/model.md + - Dataset: hub/sdk/dataset.md + - Project: hub/sdk/project.md + - Reference: + - base: + - api_client: hub/sdk/reference/base/api_client.md + - auth: hub/sdk/reference/base/auth.md + - crud_client: hub/sdk/reference/base/crud_client.md + - paginated_list: hub/sdk/reference/base/paginated_list.md + - server_clients: hub/sdk/reference/base/server_clients.md + - 
helpers: + - error_handler: hub/sdk/reference/helpers/error_handler.md + - exceptions: hub/sdk/reference/helpers/exceptions.md + - logger: hub/sdk/reference/helpers/logger.md + - utils: hub/sdk/reference/helpers/utils.md + - hub_client: hub/sdk/reference/hub_client.md + - modules: + - datasets: hub/sdk/reference/modules/datasets.md + - models: hub/sdk/reference/modules/models.md + - projects: hub/sdk/reference/modules/projects.md + - teams: hub/sdk/reference/modules/teams.md + - users: hub/sdk/reference/modules/users.md + - REST API: + - hub/api/index.md + + - Reference: + - cfg: + - __init__: reference/cfg/__init__.md + - data: + - annotator: reference/data/annotator.md + - augment: reference/data/augment.md + - base: reference/data/base.md + - build: reference/data/build.md + - converter: reference/data/converter.md + - dataset: reference/data/dataset.md + - loaders: reference/data/loaders.md + - split_dota: reference/data/split_dota.md + - utils: reference/data/utils.md + - engine: + - exporter: reference/engine/exporter.md + - model: reference/engine/model.md + - predictor: reference/engine/predictor.md + - results: reference/engine/results.md + - trainer: reference/engine/trainer.md + - tuner: reference/engine/tuner.md + - validator: reference/engine/validator.md + - hub: + - __init__: reference/hub/__init__.md + - auth: reference/hub/auth.md + - google: + - __init__: reference/hub/google/__init__.md + - session: reference/hub/session.md + - utils: reference/hub/utils.md + - models: + - fastsam: + - model: reference/models/fastsam/model.md + - predict: reference/models/fastsam/predict.md + - utils: reference/models/fastsam/utils.md + - val: reference/models/fastsam/val.md + - nas: + - model: reference/models/nas/model.md + - predict: reference/models/nas/predict.md + - val: reference/models/nas/val.md + - rtdetr: + - model: reference/models/rtdetr/model.md + - predict: reference/models/rtdetr/predict.md + - train: reference/models/rtdetr/train.md + - val: reference/models/rtdetr/val.md + - sam: + - amg: reference/models/sam/amg.md + - build: reference/models/sam/build.md + - model: reference/models/sam/model.md + - modules: + - blocks: reference/models/sam/modules/blocks.md + - decoders: reference/models/sam/modules/decoders.md + - encoders: reference/models/sam/modules/encoders.md + - memory_attention: reference/models/sam/modules/memory_attention.md + - sam: reference/models/sam/modules/sam.md + - tiny_encoder: reference/models/sam/modules/tiny_encoder.md + - transformer: reference/models/sam/modules/transformer.md + - utils: reference/models/sam/modules/utils.md + - predict: reference/models/sam/predict.md + - utils: + - loss: reference/models/utils/loss.md + - ops: reference/models/utils/ops.md + - yolo: + - classify: + - predict: reference/models/yolo/classify/predict.md + - train: reference/models/yolo/classify/train.md + - val: reference/models/yolo/classify/val.md + - detect: + - predict: reference/models/yolo/detect/predict.md + - train: reference/models/yolo/detect/train.md + - val: reference/models/yolo/detect/val.md + - model: reference/models/yolo/model.md + - obb: + - predict: reference/models/yolo/obb/predict.md + - train: reference/models/yolo/obb/train.md + - val: reference/models/yolo/obb/val.md + - pose: + - predict: reference/models/yolo/pose/predict.md + - train: reference/models/yolo/pose/train.md + - val: reference/models/yolo/pose/val.md + - segment: + - predict: reference/models/yolo/segment/predict.md + - train: reference/models/yolo/segment/train.md + - val: 
reference/models/yolo/segment/val.md + - world: + - train: reference/models/yolo/world/train.md + - train_world: reference/models/yolo/world/train_world.md + - nn: + - autobackend: reference/nn/autobackend.md + - modules: + - activation: reference/nn/modules/activation.md + - block: reference/nn/modules/block.md + - conv: reference/nn/modules/conv.md + - head: reference/nn/modules/head.md + - transformer: reference/nn/modules/transformer.md + - utils: reference/nn/modules/utils.md + - tasks: reference/nn/tasks.md + - solutions: + - ai_gym: reference/solutions/ai_gym.md + - analytics: reference/solutions/analytics.md + - distance_calculation: reference/solutions/distance_calculation.md + - heatmap: reference/solutions/heatmap.md + - object_counter: reference/solutions/object_counter.md + - parking_management: reference/solutions/parking_management.md + - queue_management: reference/solutions/queue_management.md + - region_counter: reference/solutions/region_counter.md + - security_alarm: reference/solutions/security_alarm.md + - solutions: reference/solutions/solutions.md + - speed_estimation: reference/solutions/speed_estimation.md + - streamlit_inference: reference/solutions/streamlit_inference.md + - trackzone: reference/solutions/trackzone.md + - trackers: + - basetrack: reference/trackers/basetrack.md + - bot_sort: reference/trackers/bot_sort.md + - byte_tracker: reference/trackers/byte_tracker.md + - track: reference/trackers/track.md + - utils: + - gmc: reference/trackers/utils/gmc.md + - kalman_filter: reference/trackers/utils/kalman_filter.md + - matching: reference/trackers/utils/matching.md + - utils: + - __init__: reference/utils/__init__.md + - autobatch: reference/utils/autobatch.md + - benchmarks: reference/utils/benchmarks.md + - callbacks: + - base: reference/utils/callbacks/base.md + - clearml: reference/utils/callbacks/clearml.md + - comet: reference/utils/callbacks/comet.md + - dvc: reference/utils/callbacks/dvc.md + - hub: reference/utils/callbacks/hub.md + - mlflow: reference/utils/callbacks/mlflow.md + - neptune: reference/utils/callbacks/neptune.md + - raytune: reference/utils/callbacks/raytune.md + - tensorboard: reference/utils/callbacks/tensorboard.md + - wb: reference/utils/callbacks/wb.md + - checks: reference/utils/checks.md + - dist: reference/utils/dist.md + - downloads: reference/utils/downloads.md + - errors: reference/utils/errors.md + - files: reference/utils/files.md + - instance: reference/utils/instance.md + - loss: reference/utils/loss.md + - metrics: reference/utils/metrics.md + - ops: reference/utils/ops.md + - patches: reference/utils/patches.md + - plotting: reference/utils/plotting.md + - tal: reference/utils/tal.md + - torch_utils: reference/utils/torch_utils.md + - triton: reference/utils/triton.md + - tuner: reference/utils/tuner.md + + - Help: + - Help: help/index.md + - Frequently Asked Questions (FAQ): help/FAQ.md + - Contributing Guide: help/contributing.md + - Continuous Integration (CI) Guide: help/CI.md + - Contributor License Agreement (CLA): help/CLA.md + - Minimum Reproducible Example (MRE) Guide: help/minimum-reproducible-example.md + - Code of Conduct: help/code-of-conduct.md + - Environmental, Health and Safety (EHS) Policy: help/environmental-health-safety.md + - Security Policy: help/security.md + - Privacy Policy: help/privacy.md + +# Plugins including 301 redirects navigation --------------------------------------------------------------------------- +plugins: + - macros + # - search: + # lang: en + - mkdocstrings: + enabled: 
true + default_handler: python + handlers: + python: + options: + docstring_options: + ignore_init_summary: true + merge_init_into_class: true + docstring_style: google + show_root_heading: true + show_source: true + separate_signature: true + line_length: 80 + show_signature_annotations: true + show_symbol_type_heading: true # insiders + show_symbol_type_toc: true # insiders + show_inheritance_diagram: true # insiders + - ultralytics: + add_desc: False + add_image: True + add_authors: True + add_json_ld: True + add_share_buttons: True + add_css: False + default_image: https://raw.githubusercontent.com/ultralytics/assets/main/yolov8/banner-yolov8.png + - redirects: + redirect_maps: + hi/index.md: index.md + nl/index.md: index.md + callbacks.md: usage/callbacks.md + cfg.md: usage/cfg.md + cli.md: usage/cli.md + config.md: usage/cfg.md + engine.md: usage/engine.md + environments/AWS-Quickstart.md: yolov5/environments/aws_quickstart_tutorial.md + environments/Docker-Quickstart.md: yolov5/environments/docker_image_quickstart_tutorial.md + environments/GCP-Quickstart.md: yolov5/environments/google_cloud_quickstart_tutorial.md + FAQ/augmentation.md: yolov5/tutorials/tips_for_best_training_results.md + package-framework.md: index.md + package-framework/mock_detector.md: index.md + predict.md: modes/predict.md + python.md: usage/python.md + quick-start.md: quickstart.md + app.md: hub/app/index.md + sdk.md: index.md + hub/inference_api.md: hub/inference-api.md + usage/hyperparameter_tuning.md: integrations/ray-tune.md + models/sam2.md: models/sam-2.md + reference/base_pred.md: reference/engine/predictor.md + reference/base_trainer.md: reference/engine/trainer.md + reference/exporter.md: reference/engine/exporter.md + reference/model.md: reference/engine/model.md + reference/nn.md: reference/nn/modules/head.md + reference/ops.md: reference/utils/ops.md + reference/results.md: reference/engine/results.md + reference/base_val.md: index.md + reference/index.md: reference/cfg/__init__.md + tasks/classification.md: tasks/classify.md + tasks/detection.md: tasks/detect.md + tasks/segmentation.md: tasks/segment.md + tasks/keypoints.md: tasks/pose.md + tasks/tracking.md: modes/track.md + SECURITY.md: help/security.md + help/minimum_reproducible_example.md: help/minimum-reproducible-example.md + help/code_of_conduct.md: help/code-of-conduct.md + tutorials/architecture-summary.md: yolov5/tutorials/architecture_description.md + tutorials/clearml-logging.md: yolov5/tutorials/clearml_logging_integration.md + tutorials/comet-logging.md: yolov5/tutorials/comet_logging_integration.md + tutorials/hyperparameter-evolution.md: yolov5/tutorials/hyperparameter_evolution.md + tutorials/model-ensembling.md: yolov5/tutorials/model_ensembling.md + tutorials/multi-gpu-training.md: yolov5/tutorials/multi_gpu_training.md + tutorials/nvidia-jetson.md: guides/nvidia-jetson.md + tutorials/pruning-sparsity.md: yolov5/tutorials/model_pruning_and_sparsity.md + tutorials/pytorch-hub.md: yolov5/tutorials/pytorch_hub_model_loading.md + tutorials/roboflow.md: yolov5/tutorials/roboflow_datasets_integration.md + tutorials/test-time-augmentation.md: yolov5/tutorials/test_time_augmentation.md + tutorials/torchscript-onnx-coreml-export.md: yolov5/tutorials/model_export.md + tutorials/train-custom-datasets.md: yolov5/tutorials/train_custom_data.md + tutorials/training-tips-best-results.md: yolov5/tutorials/tips_for_best_training_results.md + tutorials/transfer-learning-froze-layers.md: yolov5/tutorials/transfer_learning_with_frozen_layers.md + 
tutorials/weights-and-biasis-logging.md: yolov5/tutorials/comet_logging_integration.md + yolov5/pytorch_hub.md: yolov5/tutorials/pytorch_hub_model_loading.md + yolov5/hyp_evolution.md: yolov5/tutorials/hyperparameter_evolution.md + yolov5/pruning_sparsity.md: yolov5/tutorials/model_pruning_and_sparsity.md + yolov5/roboflow.md: yolov5/tutorials/roboflow_datasets_integration.md + yolov5/comet.md: yolov5/tutorials/comet_logging_integration.md + yolov5/clearml.md: yolov5/tutorials/clearml_logging_integration.md + yolov5/tta.md: yolov5/tutorials/test_time_augmentation.md + yolov5/multi_gpu_training.md: yolov5/tutorials/multi_gpu_training.md + yolov5/ensemble.md: yolov5/tutorials/model_ensembling.md + yolov5/jetson_nano.md: guides/nvidia-jetson.md + yolov5/transfer_learn_frozen.md: yolov5/tutorials/transfer_learning_with_frozen_layers.md + yolov5/neural_magic.md: yolov5/tutorials/neural_magic_pruning_quantization.md + yolov5/train_custom_data.md: yolov5/tutorials/train_custom_data.md + yolov5/architecture.md: yolov5/tutorials/architecture_description.md + yolov5/export.md: yolov5/tutorials/model_export.md + yolov5/yolov5_quickstart_tutorial.md: yolov5/quickstart_tutorial.md + yolov5/tips_for_best_training_results.md: yolov5/tutorials/tips_for_best_training_results.md + yolov5/tutorials/yolov5_neural_magic_tutorial.md: yolov5/tutorials/neural_magic_pruning_quantization.md + yolov5/tutorials/model_ensembling_tutorial.md: yolov5/tutorials/model_ensembling.md + yolov5/tutorials/pytorch_hub_tutorial.md: yolov5/tutorials/pytorch_hub_model_loading.md + yolov5/tutorials/yolov5_architecture_tutorial.md: yolov5/tutorials/architecture_description.md + yolov5/tutorials/multi_gpu_training_tutorial.md: yolov5/tutorials/multi_gpu_training.md + yolov5/tutorials/yolov5_pytorch_hub_tutorial.md: yolov5/tutorials/pytorch_hub_model_loading.md + yolov5/tutorials/model_export_tutorial.md: yolov5/tutorials/model_export.md + yolov5/tutorials/jetson_nano_tutorial.md: guides/nvidia-jetson.md + yolov5/tutorials/yolov5_model_ensembling_tutorial.md: yolov5/tutorials/model_ensembling.md + yolov5/tutorials/roboflow_integration.md: yolov5/tutorials/roboflow_datasets_integration.md + yolov5/tutorials/pruning_and_sparsity_tutorial.md: yolov5/tutorials/model_pruning_and_sparsity.md + yolov5/tutorials/yolov5_transfer_learning_with_frozen_layers_tutorial.md: yolov5/tutorials/transfer_learning_with_frozen_layers.md + yolov5/tutorials/transfer_learning_with_frozen_layers_tutorial.md: yolov5/tutorials/transfer_learning_with_frozen_layers.md + yolov5/tutorials/yolov5_model_export_tutorial.md: yolov5/tutorials/model_export.md + yolov5/tutorials/neural_magic_tutorial.md: yolov5/tutorials/neural_magic_pruning_quantization.md + yolov5/tutorials/yolov5_clearml_integration_tutorial.md: yolov5/tutorials/clearml_logging_integration.md + yolov5/tutorials/yolov5_train_custom_data.md: yolov5/tutorials/train_custom_data.md + yolov5/tutorials/comet_integration_tutorial.md: yolov5/tutorials/comet_logging_integration.md + yolov5/tutorials/yolov5_pruning_and_sparsity_tutorial.md: yolov5/tutorials/model_pruning_and_sparsity.md + yolov5/tutorials/yolov5_jetson_nano_tutorial.md: guides/nvidia-jetson.md + yolov5/tutorials/running_on_jetson_nano.md: guides/nvidia-jetson.md + yolov5/tutorials/yolov5_roboflow_integration.md: yolov5/tutorials/roboflow_datasets_integration.md + yolov5/tutorials/hyperparameter_evolution_tutorial.md: yolov5/tutorials/hyperparameter_evolution.md + yolov5/tutorials/yolov5_hyperparameter_evolution_tutorial.md: 
yolov5/tutorials/hyperparameter_evolution.md + yolov5/tutorials/clearml_integration_tutorial.md: yolov5/tutorials/clearml_logging_integration.md + yolov5/tutorials/test_time_augmentation_tutorial.md: yolov5/tutorials/test_time_augmentation.md + yolov5/tutorials/yolov5_test_time_augmentation_tutorial.md: yolov5/tutorials/test_time_augmentation.md + yolov5/environments/yolov5_amazon_web_services_quickstart_tutorial.md: yolov5/environments/aws_quickstart_tutorial.md + yolov5/environments/yolov5_google_cloud_platform_quickstart_tutorial.md: yolov5/environments/google_cloud_quickstart_tutorial.md + yolov5/environments/yolov5_docker_image_quickstart_tutorial.md: yolov5/environments/docker_image_quickstart_tutorial.md + reference/data/explorer/explorer.md: datasets/explorer/index.md + reference/data/explorer/gui/dash.md: datasets/explorer/index.md + reference/data/explorer/utils.md: datasets/explorer/index.md diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..705a2eb80c361b722f277bec6fe37cf01d262584 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,186 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Overview: +# This pyproject.toml file manages the build, packaging, and distribution of the Ultralytics library. +# It defines essential project metadata, dependencies, and settings used to develop and deploy the library. + +# Key Sections: +# - [build-system]: Specifies the build requirements and backend (e.g., setuptools, wheel). +# - [project]: Includes details like name, version, description, authors, dependencies and more. +# - [project.optional-dependencies]: Provides additional, optional packages for extended features. +# - [tool.*]: Configures settings for various tools (pytest, yapf, etc.) used in the project. + +# Installation: +# The Ultralytics library can be installed using the command: 'pip install ultralytics' +# For development purposes, you can install the package in editable mode with: 'pip install -e .' +# This approach allows for real-time code modifications without the need for re-installation. + +# Documentation: +# For comprehensive documentation and usage instructions, visit: https://docs.ultralytics.com + +[build-system] +requires = ["setuptools>=70.0.0", "wheel"] +build-backend = "setuptools.build_meta" + +# Project settings ----------------------------------------------------------------------------------------------------- +[project] +name = "ultralytics" +dynamic = ["version"] +description = "Ultralytics YOLO 🚀 for SOTA object detection, multi-object tracking, instance segmentation, pose estimation and image classification." 
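The `mkdocs.yml` added above wires the documentation site together (Material theme, nav tree, mkdocstrings, macros, and redirect plugins, with `docs_dir: "docs/en/"`). One possible way to preview it locally, assuming a checkout that actually contains `docs/en/` and the `dev` extras declared in the `pyproject.toml` that follows:

```
pip install -e ".[dev]"   # dev extras pull in mkdocs, mkdocs-material, mkdocstrings, mkdocs-redirects, ...
mkdocs serve              # builds docs/en/ and serves the site at http://127.0.0.1:8000
```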
+readme = "README.md" +requires-python = ">=3.8" +license = { "text" = "AGPL-3.0" } +keywords = ["machine-learning", "deep-learning", "computer-vision", "ML", "DL", "AI", "YOLO", "YOLOv3", "YOLOv5", "YOLOv8", "YOLOv9", "YOLOv10", "YOLO11", "HUB", "Ultralytics"] +authors = [ + { name = "Glenn Jocher", email = "glenn.jocher@ultralytics.com" }, + { name = "Jing Qiu", email = "jing.qiu@ultralytics.com" }, +] +maintainers = [ + { name = "Ultralytics", email = "hello@ultralytics.com" }, +] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Software Development", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Image Recognition", + "Operating System :: POSIX :: Linux", + "Operating System :: MacOS", + "Operating System :: Microsoft :: Windows", +] + +# Required dependencies ------------------------------------------------------------------------------------------------ +dependencies = [ + "numpy>=1.23.0", + "numpy<2.0.0; sys_platform == 'darwin'", # macOS OpenVINO errors https://github.com/ultralytics/ultralytics/pull/17221 + "matplotlib>=3.3.0", + "opencv-python>=4.6.0", + "pillow>=7.1.2", + "pyyaml>=5.3.1", + "requests>=2.23.0", + "scipy>=1.4.1", + "torch>=1.8.0", + "torch>=1.8.0,!=2.4.0; sys_platform == 'win32'", # Windows CPU errors w/ 2.4.0 https://github.com/ultralytics/ultralytics/issues/15049 + "torchvision>=0.9.0", + "tqdm>=4.64.0", # progress bars + "psutil", # system utilization + "py-cpuinfo", # display CPU info + "pandas>=1.1.4", + "seaborn>=0.11.0", # plotting + "ultralytics-thop>=2.0.0", # FLOPs computation https://github.com/ultralytics/thop +] + +# Optional dependencies ------------------------------------------------------------------------------------------------ +[project.optional-dependencies] +dev = [ + "ipython", + "pytest", + "pytest-cov", + "coverage[toml]", + "mkdocs>=1.6.0", + "mkdocs-material>=9.5.9", + "mkdocstrings[python]", + "mkdocs-redirects", # 301 redirects + "mkdocs-ultralytics-plugin>=0.1.8", # for meta descriptions and images, dates and authors + "mkdocs-macros-plugin>=1.0.5" # duplicating content (i.e. 
export tables) in multiple places +] +export = [ + "onnx>=1.12.0", # ONNX export + "coremltools>=7.0; platform_system != 'Windows' and python_version <= '3.11'", # CoreML supported on macOS and Linux + "scikit-learn>=1.3.2; platform_system != 'Windows' and python_version <= '3.11'", # CoreML k-means quantization + "openvino>=2024.0.0", # OpenVINO export + "tensorflow>=2.0.0", # TF bug https://github.com/ultralytics/ultralytics/issues/5161 + "tensorflowjs>=3.9.0", # TF.js export, automatically installs tensorflow + "tensorstore>=0.1.63; platform_machine == 'aarch64' and python_version >= '3.9'", # for TF Raspberry Pi exports + "keras", # not installed automatically by tensorflow>=2.16 + "flatbuffers>=23.5.26,<100; platform_machine == 'aarch64'", # update old 'flatbuffers' included inside tensorflow package + "numpy==1.23.5; platform_machine == 'aarch64'", # fix error: `np.bool` was a deprecated alias for the builtin `bool` when using TensorRT models on NVIDIA Jetson + "h5py!=3.11.0; platform_machine == 'aarch64'", # fix h5py build issues due to missing aarch64 wheels in 3.11 release +] +solutions = [ + "shapely>=2.0.0", # shapely for point and polygon data matching + "streamlit", # for live inference on web browser i.e `yolo streamlit-predict` +] +logging = [ + "comet", # https://docs.ultralytics.com/integrations/comet/ + "tensorboard>=2.13.0", + "dvclive>=2.12.0", +] +extra = [ + "hub-sdk>=0.0.12", # Ultralytics HUB + "ipython", # interactive notebook + "albumentations>=1.4.6", # training augmentations + "pycocotools>=2.0.7", # COCO mAP +] + +[project.urls] +"Homepage" = "https://ultralytics.com" +"Source" = "https://github.com/ultralytics/ultralytics" +"Documentation" = "https://docs.ultralytics.com" +"Bug Reports" = "https://github.com/ultralytics/ultralytics/issues" +"Changelog" = "https://github.com/ultralytics/ultralytics/releases" + +[project.scripts] +yolo = "ultralytics.cfg:entrypoint" +ultralytics = "ultralytics.cfg:entrypoint" + +# Tools settings ------------------------------------------------------------------------------------------------------- +[tool.setuptools] # configuration specific to the `setuptools` build backend. 
+packages = { find = { where = ["."], include = ["ultralytics", "ultralytics.*"] } } +package-data = { "ultralytics" = ["**/*.yaml", "../tests/*.py"], "ultralytics.assets" = ["*.jpg"] } + +[tool.setuptools.dynamic] +version = { attr = "ultralytics.__version__" } + +[tool.pytest.ini_options] +addopts = "--doctest-modules --durations=30 --color=yes" +markers = [ + "slow: skip slow tests unless --slow is set", +] +norecursedirs = [".git", "dist", "build"] + +[tool.coverage.run] +source = ["ultralytics/"] +data_file = "tests/.coverage" +omit = ["ultralytics/utils/callbacks/*"] + +[tool.isort] +line_length = 120 +multi_line_output = 0 + +[tool.yapf] +based_on_style = "pep8" +spaces_before_comment = 2 +column_limit = 120 +coalesce_brackets = true +spaces_around_power_operator = true +space_between_ending_comma_and_closing_bracket = true +split_before_closing_bracket = false +split_before_first_argument = false + +[tool.ruff] +line-length = 120 + +[tool.ruff.format] +docstring-code-format = true + +[tool.docformatter] +wrap-summaries = 120 +wrap-descriptions = 120 +pre-summary-newline = true +close-quotes-on-newline = true +in-place = true + +[tool.codespell] +ignore-words-list = "crate,nd,ned,strack,dota,ane,segway,fo,gool,winn,commend,bloc,nam,afterall" +skip = '*.pt,*.pth,*.torchscript,*.onnx,*.tflite,*.pb,*.bin,*.param,*.mlmodel,*.engine,*.npy,*.data*,*.csv,*pnnx*,*venv*,*translat*,__pycache__*,*.ico,*.jpg,*.png,*.mp4,*.mov,/runs,/.git,./docs/??/*.md,./docs/mkdocs_??.yml' diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..38e813d20d8306e6dbbb2eb9e4722bd7d3e5a292 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,19 @@ +torch==2.2.2 +torchvision==0.17.2 +flash_attn-2.7.3+cu11torch2.2cxx11abiFALSE-cp311-cp311-linux_x86_64.whl +timm==1.0.14 +albumentations==2.0.4 +onnx==1.14.0 +onnxruntime==1.15.1 +pycocotools==2.0.7 +PyYAML==6.0.1 +scipy==1.13.0 +onnxslim==0.1.31 +onnxruntime-gpu==1.18.0 +gradio==4.44.1 +opencv-python==4.9.0.80 +psutil==5.9.8 +py-cpuinfo==9.0.0 +huggingface-hub==0.23.2 +safetensors==0.4.3 +numpy==1.26.4 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9e86aa3c593d0def794a67a5eeadbbf46bba3097 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,22 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.utils import ASSETS, ROOT, WEIGHTS_DIR, checks + +# Constants used in tests +MODEL = WEIGHTS_DIR / "path with spaces" / "yolo11n.pt" # test spaces in path +CFG = "yolo11n.yaml" +SOURCE = ASSETS / "bus.jpg" +SOURCES_LIST = [ASSETS / "bus.jpg", ASSETS, ASSETS / "*", ASSETS / "**/*.jpg"] +TMP = (ROOT / "../tests/tmp").resolve() # temp directory for test files +CUDA_IS_AVAILABLE = checks.cuda_is_available() +CUDA_DEVICE_COUNT = checks.cuda_device_count() + +__all__ = ( + "MODEL", + "CFG", + "SOURCE", + "SOURCES_LIST", + "TMP", + "CUDA_IS_AVAILABLE", + "CUDA_DEVICE_COUNT", +) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..8703d81fce7a50a1a182fb89c5098d88d321f263 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,83 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import shutil +from pathlib import Path + +from tests import TMP + + +def pytest_addoption(parser): + """ + Add custom command-line options to pytest. + + Args: + parser (pytest.config.Parser): The pytest parser object for adding custom command-line options. 
+ + Returns: + (None) + """ + parser.addoption("--slow", action="store_true", default=False, help="Run slow tests") + + +def pytest_collection_modifyitems(config, items): + """ + Modify the list of test items to exclude tests marked as slow if the --slow option is not specified. + + Args: + config (pytest.config.Config): The pytest configuration object that provides access to command-line options. + items (list): The list of collected pytest item objects to be modified based on the presence of --slow option. + + Returns: + (None) The function modifies the 'items' list in place, and does not return a value. + """ + if not config.getoption("--slow"): + # Remove the item entirely from the list of test items if it's marked as 'slow' + items[:] = [item for item in items if "slow" not in item.keywords] + + +def pytest_sessionstart(session): + """ + Initialize session configurations for pytest. + + This function is automatically called by pytest after the 'Session' object has been created but before performing + test collection. It sets the initial seeds and prepares the temporary directory for the test session. + + Args: + session (pytest.Session): The pytest session object. + + Returns: + (None) + """ + from ultralytics.utils.torch_utils import init_seeds + + init_seeds() + shutil.rmtree(TMP, ignore_errors=True) # delete any existing tests/tmp directory + TMP.mkdir(parents=True, exist_ok=True) # create a new empty directory + + +def pytest_terminal_summary(terminalreporter, exitstatus, config): + """ + Cleanup operations after pytest session. + + This function is automatically called by pytest at the end of the entire test session. It removes certain files + and directories used during testing. + + Args: + terminalreporter (pytest.terminal.TerminalReporter): The terminal reporter object used for terminal output. + exitstatus (int): The exit status of the test run. + config (pytest.config.Config): The pytest config object. 
+ + Returns: + (None) + """ + from ultralytics.utils import WEIGHTS_DIR + + # Remove files + models = [path for x in ["*.onnx", "*.torchscript"] for path in WEIGHTS_DIR.rglob(x)] + for file in ["decelera_portrait_min.mov", "bus.jpg", "yolo11n.onnx", "yolo11n.torchscript"] + models: + Path(file).unlink(missing_ok=True) + + # Remove directories + models = [path for x in ["*.mlpackage", "*_openvino_model"] for path in WEIGHTS_DIR.rglob(x)] + for directory in [WEIGHTS_DIR / "path with spaces", TMP.parents[1] / ".pytest_cache", TMP] + models: + shutil.rmtree(directory, ignore_errors=True) diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..aab6d8b4ac7d2dcd5a903e10a355fcb1c1449377 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,122 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import subprocess + +import pytest +from PIL import Image + +from tests import CUDA_DEVICE_COUNT, CUDA_IS_AVAILABLE +from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS +from ultralytics.utils import ASSETS, WEIGHTS_DIR, checks +from ultralytics.utils.torch_utils import TORCH_1_9 + +# Constants +TASK_MODEL_DATA = [(task, WEIGHTS_DIR / TASK2MODEL[task], TASK2DATA[task]) for task in TASKS] +MODELS = [WEIGHTS_DIR / TASK2MODEL[task] for task in TASKS] + + +def run(cmd): + """Execute a shell command using subprocess.""" + subprocess.run(cmd.split(), check=True) + + +def test_special_modes(): + """Test various special command-line modes for YOLO functionality.""" + run("yolo help") + run("yolo checks") + run("yolo version") + run("yolo settings reset") + run("yolo cfg") + + +@pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA) +def test_train(task, model, data): + """Test YOLO training for different tasks, models, and datasets.""" + run(f"yolo train {task} model={model} data={data} imgsz=32 epochs=1 cache=disk") + + +@pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA) +def test_val(task, model, data): + """Test YOLO validation process for specified task, model, and data using a shell command.""" + run(f"yolo val {task} model={model} data={data} imgsz=32 save_txt save_json") + + +@pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA) +def test_predict(task, model, data): + """Test YOLO prediction on provided sample assets for specified task and model.""" + run(f"yolo predict model={model} source={ASSETS} imgsz=32 save save_crop save_txt") + + +@pytest.mark.parametrize("model", MODELS) +def test_export(model): + """Test exporting a YOLO model to TorchScript format.""" + run(f"yolo export model={model} format=torchscript imgsz=32") + + +def test_rtdetr(task="detect", model="yolov8n-rtdetr.yaml", data="coco8.yaml"): + """Test the RTDETR functionality within Ultralytics for detection tasks using specified model and data.""" + # Warning: must use imgsz=640 (note also add coma, spaces, fraction=0.25 args to test single-image training) + run(f"yolo train {task} model={model} data={data} --imgsz= 160 epochs =1, cache = disk fraction=0.25") + run(f"yolo predict {task} model={model} source={ASSETS / 'bus.jpg'} imgsz=160 save save_crop save_txt") + if TORCH_1_9: + weights = WEIGHTS_DIR / "rtdetr-l.pt" + run(f"yolo predict {task} model={weights} source={ASSETS / 'bus.jpg'} imgsz=160 save save_crop save_txt") + + +@pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="MobileSAM with CLIP is not supported in Python 3.12") +def test_fastsam(task="segment", model=WEIGHTS_DIR / "FastSAM-s.pt", data="coco8-seg.yaml"): + """Test 
FastSAM model for segmenting objects in images using various prompts within Ultralytics.""" + source = ASSETS / "bus.jpg" + + run(f"yolo segment val {task} model={model} data={data} imgsz=32") + run(f"yolo segment predict model={model} source={source} imgsz=32 save save_crop save_txt") + + from ultralytics import FastSAM + from ultralytics.models.sam import Predictor + + # Create a FastSAM model + sam_model = FastSAM(model) # or FastSAM-x.pt + + # Run inference on an image + for s in (source, Image.open(source)): + everything_results = sam_model(s, device="cpu", retina_masks=True, imgsz=320, conf=0.4, iou=0.9) + + # Remove small regions + new_masks, _ = Predictor.remove_small_regions(everything_results[0].masks.data, min_area=20) + + # Run inference with bboxes and points and texts prompt at the same time + sam_model(source, bboxes=[439, 437, 524, 709], points=[[200, 200]], labels=[1], texts="a photo of a dog") + + +def test_mobilesam(): + """Test MobileSAM segmentation with point prompts using Ultralytics.""" + from ultralytics import SAM + + # Load the model + model = SAM(WEIGHTS_DIR / "mobile_sam.pt") + + # Source + source = ASSETS / "zidane.jpg" + + # Predict a segment based on a 1D point prompt and 1D labels. + model.predict(source, points=[900, 370], labels=[1]) + + # Predict a segment based on 3D points and 2D labels (multiple points per object). + model.predict(source, points=[[[900, 370], [1000, 100]]], labels=[[1, 1]]) + + # Predict a segment based on a box prompt + model.predict(source, bboxes=[439, 437, 524, 709], save=True) + + # Predict all + # model(source) + + +# Slow Tests ----------------------------------------------------------------------------------------------------------- +@pytest.mark.slow +@pytest.mark.parametrize("task,model,data", TASK_MODEL_DATA) +@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason="CUDA is not available") +@pytest.mark.skipif(CUDA_DEVICE_COUNT < 2, reason="DDP is not available") +def test_train_gpu(task, model, data): + """Test YOLO training on GPU(s) for various tasks and models.""" + run(f"yolo train {task} model={model} data={data} imgsz=32 epochs=1 device=0") # single GPU + run(f"yolo train {task} model={model} data={data} imgsz=32 epochs=1 device=0,1") # multi GPU diff --git a/tests/test_cuda.py b/tests/test_cuda.py new file mode 100644 index 0000000000000000000000000000000000000000..d94f95bd06b7efd61189d310537f3298d384024f --- /dev/null +++ b/tests/test_cuda.py @@ -0,0 +1,155 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from itertools import product +from pathlib import Path + +import pytest +import torch + +from tests import CUDA_DEVICE_COUNT, CUDA_IS_AVAILABLE, MODEL, SOURCE +from ultralytics import YOLO +from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS +from ultralytics.utils import ASSETS, WEIGHTS_DIR +from ultralytics.utils.checks import check_amp + + +def test_checks(): + """Validate CUDA settings against torch CUDA functions.""" + assert torch.cuda.is_available() == CUDA_IS_AVAILABLE + assert torch.cuda.device_count() == CUDA_DEVICE_COUNT + + +@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason="CUDA is not available") +def test_amp(): + """Test AMP training checks.""" + model = YOLO("yolo11n.pt").model.cuda() + assert check_amp(model) + + +@pytest.mark.slow +@pytest.mark.skipif(True, reason="CUDA export tests disabled pending additional Ultralytics GPU server availability") +@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason="CUDA is not available") +@pytest.mark.parametrize( + "task, dynamic, 
int8, half, batch", + [ # generate all combinations but exclude those where both int8 and half are True + (task, dynamic, int8, half, batch) + # Note: tests reduced below pending compute availability expansion as GPU CI runner utilization is high + # for task, dynamic, int8, half, batch in product(TASKS, [True, False], [True, False], [True, False], [1, 2]) + for task, dynamic, int8, half, batch in product(TASKS, [True], [True], [False], [2]) + if not (int8 and half) # exclude cases where both int8 and half are True + ], +) +def test_export_engine_matrix(task, dynamic, int8, half, batch): + """Test YOLO model export to TensorRT format for various configurations and run inference.""" + file = YOLO(TASK2MODEL[task]).export( + format="engine", + imgsz=32, + dynamic=dynamic, + int8=int8, + half=half, + batch=batch, + data=TASK2DATA[task], + workspace=1, # reduce workspace GB for less resource utilization during testing + simplify=True, # use 'onnxslim' + ) + YOLO(file)([SOURCE] * batch, imgsz=64 if dynamic else 32) # exported model inference + Path(file).unlink() # cleanup + Path(file).with_suffix(".cache").unlink() if int8 else None # cleanup INT8 cache + + +@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason="CUDA is not available") +def test_train(): + """Test model training on a minimal dataset using available CUDA devices.""" + device = 0 if CUDA_DEVICE_COUNT == 1 else [0, 1] + YOLO(MODEL).train(data="coco8.yaml", imgsz=64, epochs=1, device=device) # requires imgsz>=64 + + +@pytest.mark.slow +@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason="CUDA is not available") +def test_predict_multiple_devices(): + """Validate model prediction consistency across CPU and CUDA devices.""" + model = YOLO("yolo11n.pt") + model = model.cpu() + assert str(model.device) == "cpu" + _ = model(SOURCE) # CPU inference + assert str(model.device) == "cpu" + + model = model.to("cuda:0") + assert str(model.device) == "cuda:0" + _ = model(SOURCE) # CUDA inference + assert str(model.device) == "cuda:0" + + model = model.cpu() + assert str(model.device) == "cpu" + _ = model(SOURCE) # CPU inference + assert str(model.device) == "cpu" + + model = model.cuda() + assert str(model.device) == "cuda:0" + _ = model(SOURCE) # CUDA inference + assert str(model.device) == "cuda:0" + + +@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason="CUDA is not available") +def test_autobatch(): + """Check optimal batch size for YOLO model training using autobatch utility.""" + from ultralytics.utils.autobatch import check_train_batch_size + + check_train_batch_size(YOLO(MODEL).model.cuda(), imgsz=128, amp=True) + + +@pytest.mark.slow +@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason="CUDA is not available") +def test_utils_benchmarks(): + """Profile YOLO models for performance benchmarks.""" + from ultralytics.utils.benchmarks import ProfileModels + + # Pre-export a dynamic engine model to use dynamic inference + YOLO(MODEL).export(format="engine", imgsz=32, dynamic=True, batch=1) + ProfileModels([MODEL], imgsz=32, half=False, min_time=1, num_timed_runs=3, num_warmup_runs=1).profile() + + +@pytest.mark.skipif(not CUDA_IS_AVAILABLE, reason="CUDA is not available") +def test_predict_sam(): + """Test SAM model predictions using different prompts, including bounding boxes and point annotations.""" + from ultralytics import SAM + from ultralytics.models.sam import Predictor as SAMPredictor + + # Load a model + model = SAM(WEIGHTS_DIR / "sam2.1_b.pt") + + # Display model information (optional) + model.info() + + # Run inference + model(SOURCE, 
device=0) + + # Run inference with bboxes prompt + model(SOURCE, bboxes=[439, 437, 524, 709], device=0) + + # Run inference with no labels + model(ASSETS / "zidane.jpg", points=[900, 370], device=0) + + # Run inference with 1D points and 1D labels + model(ASSETS / "zidane.jpg", points=[900, 370], labels=[1], device=0) + + # Run inference with 2D points and 1D labels + model(ASSETS / "zidane.jpg", points=[[900, 370]], labels=[1], device=0) + + # Run inference with multiple 2D points and 1D labels + model(ASSETS / "zidane.jpg", points=[[400, 370], [900, 370]], labels=[1, 1], device=0) + + # Run inference with 3D points and 2D labels (multiple points per object) + model(ASSETS / "zidane.jpg", points=[[[900, 370], [1000, 100]]], labels=[[1, 1]], device=0) + + # Create SAMPredictor + overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024, model=WEIGHTS_DIR / "mobile_sam.pt") + predictor = SAMPredictor(overrides=overrides) + + # Set image + predictor.set_image(ASSETS / "zidane.jpg") # set with image file + # predictor(bboxes=[439, 437, 524, 709]) + # predictor(points=[900, 370], labels=[1]) + + # Reset image + predictor.reset_image() diff --git a/tests/test_engine.py b/tests/test_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..fe95a5ca5dd557cb9fc82ff2e61755d7bcd6e81b --- /dev/null +++ b/tests/test_engine.py @@ -0,0 +1,131 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import sys +from unittest import mock + +from tests import MODEL +from ultralytics import YOLO +from ultralytics.cfg import get_cfg +from ultralytics.engine.exporter import Exporter +from ultralytics.models.yolo import classify, detect, segment +from ultralytics.utils import ASSETS, DEFAULT_CFG, WEIGHTS_DIR + + +def test_func(*args): # noqa + """Test function callback for evaluating YOLO model performance metrics.""" + print("callback test passed") + + +def test_export(): + """Tests the model exporting function by adding a callback and asserting its execution.""" + exporter = Exporter() + exporter.add_callback("on_export_start", test_func) + assert test_func in exporter.callbacks["on_export_start"], "callback test failed" + f = exporter(model=YOLO("yolo11n.yaml").model) + YOLO(f)(ASSETS) # exported model inference + + +def test_detect(): + """Test YOLO object detection training, validation, and prediction functionality.""" + overrides = {"data": "coco8.yaml", "model": "yolo11n.yaml", "imgsz": 32, "epochs": 1, "save": False} + cfg = get_cfg(DEFAULT_CFG) + cfg.data = "coco8.yaml" + cfg.imgsz = 32 + + # Trainer + trainer = detect.DetectionTrainer(overrides=overrides) + trainer.add_callback("on_train_start", test_func) + assert test_func in trainer.callbacks["on_train_start"], "callback test failed" + trainer.train() + + # Validator + val = detect.DetectionValidator(args=cfg) + val.add_callback("on_val_start", test_func) + assert test_func in val.callbacks["on_val_start"], "callback test failed" + val(model=trainer.best) # validate best.pt + + # Predictor + pred = detect.DetectionPredictor(overrides={"imgsz": [64, 64]}) + pred.add_callback("on_predict_start", test_func) + assert test_func in pred.callbacks["on_predict_start"], "callback test failed" + # Confirm there is no issue with sys.argv being empty. 
+ with mock.patch.object(sys, "argv", []): + result = pred(source=ASSETS, model=MODEL) + assert len(result), "predictor test failed" + + overrides["resume"] = trainer.last + trainer = detect.DetectionTrainer(overrides=overrides) + try: + trainer.train() + except Exception as e: + print(f"Expected exception caught: {e}") + return + + raise Exception("Resume test failed!") + + +def test_segment(): + """Tests image segmentation training, validation, and prediction pipelines using YOLO models.""" + overrides = {"data": "coco8-seg.yaml", "model": "yolo11n-seg.yaml", "imgsz": 32, "epochs": 1, "save": False} + cfg = get_cfg(DEFAULT_CFG) + cfg.data = "coco8-seg.yaml" + cfg.imgsz = 32 + # YOLO(CFG_SEG).train(**overrides) # works + + # Trainer + trainer = segment.SegmentationTrainer(overrides=overrides) + trainer.add_callback("on_train_start", test_func) + assert test_func in trainer.callbacks["on_train_start"], "callback test failed" + trainer.train() + + # Validator + val = segment.SegmentationValidator(args=cfg) + val.add_callback("on_val_start", test_func) + assert test_func in val.callbacks["on_val_start"], "callback test failed" + val(model=trainer.best) # validate best.pt + + # Predictor + pred = segment.SegmentationPredictor(overrides={"imgsz": [64, 64]}) + pred.add_callback("on_predict_start", test_func) + assert test_func in pred.callbacks["on_predict_start"], "callback test failed" + result = pred(source=ASSETS, model=WEIGHTS_DIR / "yolo11n-seg.pt") + assert len(result), "predictor test failed" + + # Test resume + overrides["resume"] = trainer.last + trainer = segment.SegmentationTrainer(overrides=overrides) + try: + trainer.train() + except Exception as e: + print(f"Expected exception caught: {e}") + return + + raise Exception("Resume test failed!") + + +def test_classify(): + """Test image classification including training, validation, and prediction phases.""" + overrides = {"data": "imagenet10", "model": "yolo11n-cls.yaml", "imgsz": 32, "epochs": 1, "save": False} + cfg = get_cfg(DEFAULT_CFG) + cfg.data = "imagenet10" + cfg.imgsz = 32 + # YOLO(CFG_SEG).train(**overrides) # works + + # Trainer + trainer = classify.ClassificationTrainer(overrides=overrides) + trainer.add_callback("on_train_start", test_func) + assert test_func in trainer.callbacks["on_train_start"], "callback test failed" + trainer.train() + + # Validator + val = classify.ClassificationValidator(args=cfg) + val.add_callback("on_val_start", test_func) + assert test_func in val.callbacks["on_val_start"], "callback test failed" + val(model=trainer.best) + + # Predictor + pred = classify.ClassificationPredictor(overrides={"imgsz": [64, 64]}) + pred.add_callback("on_predict_start", test_func) + assert test_func in pred.callbacks["on_predict_start"], "callback test failed" + result = pred(source=ASSETS, model=trainer.best) + assert len(result), "predictor test failed" diff --git a/tests/test_exports.py b/tests/test_exports.py new file mode 100644 index 0000000000000000000000000000000000000000..34dc32de1bc6e2ca220ec880bc5351fc07e02a23 --- /dev/null +++ b/tests/test_exports.py @@ -0,0 +1,216 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import shutil +import uuid +from itertools import product +from pathlib import Path + +import pytest + +from tests import MODEL, SOURCE +from ultralytics import YOLO +from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS +from ultralytics.utils import ( + IS_RASPBERRYPI, + LINUX, + MACOS, + WINDOWS, + checks, +) +from ultralytics.utils.torch_utils import TORCH_1_9, TORCH_1_13
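
The export tests that follow all exercise the same round trip: export a small checkpoint at `imgsz=32`, reload the exported artifact with `YOLO(file)`, and run a quick inference to confirm it loads and predicts. A minimal sketch of that pattern outside pytest, assuming `yolo11n.pt` and the bundled `bus.jpg` asset are available locally (Ultralytics downloads the weights on first use):

```python
from ultralytics import YOLO
from ultralytics.utils import ASSETS  # bundled sample images, e.g. bus.jpg

# Export -> reload -> predict, mirroring the tests below
model = YOLO("yolo11n.pt")  # weights auto-download on first use
onnx_file = model.export(format="onnx", dynamic=True, imgsz=32)  # returns the exported file path
YOLO(onnx_file)(ASSETS / "bus.jpg", imgsz=32)  # run inference with the exported ONNX model
```

The matrix variants below repeat this loop over `product(TASKS, ...)` configurations, skipping the unsupported `int8 and half` pairing.
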
+ + +def test_export_torchscript(): + """Test YOLO model exporting to TorchScript format for compatibility and correctness.""" + file = YOLO(MODEL).export(format="torchscript", optimize=False, imgsz=32) + YOLO(file)(SOURCE, imgsz=32) # exported model inference + + +def test_export_onnx(): + """Test YOLO model export to ONNX format with dynamic axes.""" + file = YOLO(MODEL).export(format="onnx", dynamic=True, imgsz=32) + YOLO(file)(SOURCE, imgsz=32) # exported model inference + + +@pytest.mark.skipif(not TORCH_1_13, reason="OpenVINO requires torch>=1.13") +def test_export_openvino(): + """Test YOLO exports to OpenVINO format for model inference compatibility.""" + file = YOLO(MODEL).export(format="openvino", imgsz=32) + YOLO(file)(SOURCE, imgsz=32) # exported model inference + + +@pytest.mark.slow +@pytest.mark.skipif(not TORCH_1_13, reason="OpenVINO requires torch>=1.13") +@pytest.mark.parametrize( + "task, dynamic, int8, half, batch", + [ # generate all combinations but exclude those where both int8 and half are True + (task, dynamic, int8, half, batch) + for task, dynamic, int8, half, batch in product(TASKS, [True, False], [True, False], [True, False], [1, 2]) + if not (int8 and half) # exclude cases where both int8 and half are True + ], +) +def test_export_openvino_matrix(task, dynamic, int8, half, batch): + """Test YOLO model exports to OpenVINO under various configuration matrix conditions.""" + file = YOLO(TASK2MODEL[task]).export( + format="openvino", + imgsz=32, + dynamic=dynamic, + int8=int8, + half=half, + batch=batch, + data=TASK2DATA[task], + ) + if WINDOWS: + # Use unique filenames due to Windows file permissions bug possibly due to latent threaded use + # See https://github.com/ultralytics/ultralytics/actions/runs/8957949304/job/24601616830?pr=10423 + file = Path(file) + file = file.rename(file.with_stem(f"{file.stem}-{uuid.uuid4()}")) + YOLO(file)([SOURCE] * batch, imgsz=64 if dynamic else 32) # exported model inference + shutil.rmtree(file, ignore_errors=True) # retry in case of potential lingering multi-threaded file usage errors + + +@pytest.mark.slow +@pytest.mark.parametrize( + "task, dynamic, int8, half, batch, simplify", product(TASKS, [True, False], [False], [False], [1, 2], [True, False]) +) +def test_export_onnx_matrix(task, dynamic, int8, half, batch, simplify): + """Test YOLO exports to ONNX format with various configurations and parameters.""" + file = YOLO(TASK2MODEL[task]).export( + format="onnx", + imgsz=32, + dynamic=dynamic, + int8=int8, + half=half, + batch=batch, + simplify=simplify, + ) + YOLO(file)([SOURCE] * batch, imgsz=64 if dynamic else 32) # exported model inference + Path(file).unlink() # cleanup + + +@pytest.mark.slow +@pytest.mark.parametrize("task, dynamic, int8, half, batch", product(TASKS, [False], [False], [False], [1, 2])) +def test_export_torchscript_matrix(task, dynamic, int8, half, batch): + """Tests YOLO model exports to TorchScript format under varied configurations.""" + file = YOLO(TASK2MODEL[task]).export( + format="torchscript", + imgsz=32, + dynamic=dynamic, + int8=int8, + half=half, + batch=batch, + ) + YOLO(file)([SOURCE] * 3, imgsz=64 if dynamic else 32) # exported model inference at batch=3 + Path(file).unlink() # cleanup + + +@pytest.mark.slow +@pytest.mark.skipif(not MACOS, reason="CoreML inference only supported on macOS") +@pytest.mark.skipif(not TORCH_1_9, reason="CoreML>=7.2 not supported with PyTorch<=1.8") +@pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="CoreML not supported in Python 3.12") 
+@pytest.mark.parametrize( + "task, dynamic, int8, half, batch", + [ # generate all combinations but exclude those where both int8 and half are True + (task, dynamic, int8, half, batch) + for task, dynamic, int8, half, batch in product(TASKS, [False], [True, False], [True, False], [1]) + if not (int8 and half) # exclude cases where both int8 and half are True + ], +) +def test_export_coreml_matrix(task, dynamic, int8, half, batch): + """Test YOLO exports to CoreML format with various parameter configurations.""" + file = YOLO(TASK2MODEL[task]).export( + format="coreml", + imgsz=32, + dynamic=dynamic, + int8=int8, + half=half, + batch=batch, + ) + YOLO(file)([SOURCE] * batch, imgsz=32) # exported model inference at batch=3 + shutil.rmtree(file) # cleanup + + +@pytest.mark.slow +@pytest.mark.skipif(not checks.IS_PYTHON_MINIMUM_3_10, reason="TFLite export requires Python>=3.10") +@pytest.mark.skipif(not LINUX, reason="Test disabled as TF suffers from install conflicts on Windows and macOS") +@pytest.mark.parametrize( + "task, dynamic, int8, half, batch", + [ # generate all combinations but exclude those where both int8 and half are True + (task, dynamic, int8, half, batch) + for task, dynamic, int8, half, batch in product(TASKS, [False], [True, False], [True, False], [1]) + if not (int8 and half) # exclude cases where both int8 and half are True + ], +) +def test_export_tflite_matrix(task, dynamic, int8, half, batch): + """Test YOLO exports to TFLite format considering various export configurations.""" + file = YOLO(TASK2MODEL[task]).export( + format="tflite", + imgsz=32, + dynamic=dynamic, + int8=int8, + half=half, + batch=batch, + ) + YOLO(file)([SOURCE] * batch, imgsz=32) # exported model inference at batch=3 + Path(file).unlink() # cleanup + + +@pytest.mark.skipif(not TORCH_1_9, reason="CoreML>=7.2 not supported with PyTorch<=1.8") +@pytest.mark.skipif(WINDOWS, reason="CoreML not supported on Windows") # RuntimeError: BlobWriter not loaded +@pytest.mark.skipif(IS_RASPBERRYPI, reason="CoreML not supported on Raspberry Pi") +@pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="CoreML not supported in Python 3.12") +def test_export_coreml(): + """Test YOLO exports to CoreML format, optimized for macOS only.""" + if MACOS: + file = YOLO(MODEL).export(format="coreml", imgsz=32) + YOLO(file)(SOURCE, imgsz=32) # model prediction only supported on macOS for nms=False models + else: + YOLO(MODEL).export(format="coreml", nms=True, imgsz=32) + + +@pytest.mark.skipif(not checks.IS_PYTHON_MINIMUM_3_10, reason="TFLite export requires Python>=3.10") +@pytest.mark.skipif(not LINUX, reason="Test disabled as TF suffers from install conflicts on Windows and macOS") +def test_export_tflite(): + """Test YOLO exports to TFLite format under specific OS and Python version conditions.""" + model = YOLO(MODEL) + file = model.export(format="tflite", imgsz=32) + YOLO(file)(SOURCE, imgsz=32) + + +@pytest.mark.skipif(True, reason="Test disabled") +@pytest.mark.skipif(not LINUX, reason="TF suffers from install conflicts on Windows and macOS") +def test_export_pb(): + """Test YOLO exports to TensorFlow's Protobuf (*.pb) format.""" + model = YOLO(MODEL) + file = model.export(format="pb", imgsz=32) + YOLO(file)(SOURCE, imgsz=32) + + +@pytest.mark.skipif(True, reason="Test disabled as Paddle protobuf and ONNX protobuf requirements conflict.") +def test_export_paddle(): + """Test YOLO exports to Paddle format, noting protobuf conflicts with ONNX.""" + YOLO(MODEL).export(format="paddle", imgsz=32) + + +@pytest.mark.slow 
+@pytest.mark.skipif(IS_RASPBERRYPI, reason="MNN not supported on Raspberry Pi") +def test_export_mnn(): + """Test YOLO exports to MNN format (WARNING: MNN test must precede NCNN test or CI error on Windows).""" + file = YOLO(MODEL).export(format="mnn", imgsz=32) + YOLO(file)(SOURCE, imgsz=32) # exported model inference + + +@pytest.mark.slow +def test_export_ncnn(): + """Test YOLO exports to NCNN format.""" + file = YOLO(MODEL).export(format="ncnn", imgsz=32) + YOLO(file)(SOURCE, imgsz=32) # exported model inference + + +@pytest.mark.skipif(True, reason="Test disabled as keras and tensorflow version conflicts with tflite export.") +@pytest.mark.skipif(not LINUX or MACOS, reason="Skipping test on Windows and Macos") +def test_export_imx(): + """Test YOLOv8n exports to IMX format.""" + model = YOLO("yolov8n.pt") + file = model.export(format="imx", imgsz=32) + YOLO(file)(SOURCE, imgsz=32) diff --git a/tests/test_integrations.py b/tests/test_integrations.py new file mode 100644 index 0000000000000000000000000000000000000000..8067a1787f142f0f3d8f0834713ee6709d671115 --- /dev/null +++ b/tests/test_integrations.py @@ -0,0 +1,150 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import contextlib +import os +import subprocess +import time +from pathlib import Path + +import pytest + +from tests import MODEL, SOURCE, TMP +from ultralytics import YOLO, download +from ultralytics.utils import DATASETS_DIR, SETTINGS +from ultralytics.utils.checks import check_requirements + + +@pytest.mark.skipif(not check_requirements("ray", install=False), reason="ray[tune] not installed") +def test_model_ray_tune(): + """Tune YOLO model using Ray for hyperparameter optimization.""" + YOLO("yolo11n-cls.yaml").tune( + use_ray=True, data="imagenet10", grace_period=1, iterations=1, imgsz=32, epochs=1, plots=False, device="cpu" + ) + + +@pytest.mark.skipif(not check_requirements("mlflow", install=False), reason="mlflow not installed") +def test_mlflow(): + """Test training with MLflow tracking enabled (see https://mlflow.org/ for details).""" + SETTINGS["mlflow"] = True + YOLO("yolo11n-cls.yaml").train(data="imagenet10", imgsz=32, epochs=3, plots=False, device="cpu") + SETTINGS["mlflow"] = False + + +@pytest.mark.skipif(True, reason="Test failing in scheduled CI https://github.com/ultralytics/ultralytics/pull/8868") +@pytest.mark.skipif(not check_requirements("mlflow", install=False), reason="mlflow not installed") +def test_mlflow_keep_run_active(): + """Ensure MLflow run status matches MLFLOW_KEEP_RUN_ACTIVE environment variable settings.""" + import mlflow + + SETTINGS["mlflow"] = True + run_name = "Test Run" + os.environ["MLFLOW_RUN"] = run_name + + # Test with MLFLOW_KEEP_RUN_ACTIVE=True + os.environ["MLFLOW_KEEP_RUN_ACTIVE"] = "True" + YOLO("yolo11n-cls.yaml").train(data="imagenet10", imgsz=32, epochs=1, plots=False, device="cpu") + status = mlflow.active_run().info.status + assert status == "RUNNING", "MLflow run should be active when MLFLOW_KEEP_RUN_ACTIVE=True" + + run_id = mlflow.active_run().info.run_id + + # Test with MLFLOW_KEEP_RUN_ACTIVE=False + os.environ["MLFLOW_KEEP_RUN_ACTIVE"] = "False" + YOLO("yolo11n-cls.yaml").train(data="imagenet10", imgsz=32, epochs=1, plots=False, device="cpu") + status = mlflow.get_run(run_id=run_id).info.status + assert status == "FINISHED", "MLflow run should be ended when MLFLOW_KEEP_RUN_ACTIVE=False" + + # Test with MLFLOW_KEEP_RUN_ACTIVE not set + os.environ.pop("MLFLOW_KEEP_RUN_ACTIVE", None) + YOLO("yolo11n-cls.yaml").train(data="imagenet10", 
imgsz=32, epochs=1, plots=False, device="cpu") + status = mlflow.get_run(run_id=run_id).info.status + assert status == "FINISHED", "MLflow run should be ended by default when MLFLOW_KEEP_RUN_ACTIVE is not set" + SETTINGS["mlflow"] = False + + +@pytest.mark.skipif(not check_requirements("tritonclient", install=False), reason="tritonclient[all] not installed") +def test_triton(): + """ + Test NVIDIA Triton Server functionalities with YOLO model. + + See https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver. + """ + check_requirements("tritonclient[all]") + from tritonclient.http import InferenceServerClient # noqa + + # Create variables + model_name = "yolo" + triton_repo = TMP / "triton_repo" # Triton repo path + triton_model = triton_repo / model_name # Triton model path + + # Export model to ONNX + f = YOLO(MODEL).export(format="onnx", dynamic=True) + + # Prepare Triton repo + (triton_model / "1").mkdir(parents=True, exist_ok=True) + Path(f).rename(triton_model / "1" / "model.onnx") + (triton_model / "config.pbtxt").touch() + + # Define image https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver + tag = "nvcr.io/nvidia/tritonserver:23.09-py3" # 6.4 GB + + # Pull the image + subprocess.call(f"docker pull {tag}", shell=True) + + # Run the Triton server and capture the container ID + container_id = ( + subprocess.check_output( + f"docker run -d --rm -v {triton_repo}:/models -p 8000:8000 {tag} tritonserver --model-repository=/models", + shell=True, + ) + .decode("utf-8") + .strip() + ) + + # Wait for the Triton server to start + triton_client = InferenceServerClient(url="localhost:8000", verbose=False, ssl=False) + + # Wait until model is ready + for _ in range(10): + with contextlib.suppress(Exception): + assert triton_client.is_model_ready(model_name) + break + time.sleep(1) + + # Check Triton inference + YOLO(f"http://localhost:8000/{model_name}", "detect")(SOURCE) # exported model inference + + # Kill and remove the container at the end of the test + subprocess.call(f"docker kill {container_id}", shell=True) + + +@pytest.mark.skipif(not check_requirements("pycocotools", install=False), reason="pycocotools not installed") +def test_pycocotools(): + """Validate YOLO model predictions on COCO dataset using pycocotools.""" + from ultralytics.models.yolo.detect import DetectionValidator + from ultralytics.models.yolo.pose import PoseValidator + from ultralytics.models.yolo.segment import SegmentationValidator + + # Download annotations after each dataset downloads first + url = "https://github.com/ultralytics/assets/releases/download/v0.0.0/" + + args = {"model": "yolo11n.pt", "data": "coco8.yaml", "save_json": True, "imgsz": 64} + validator = DetectionValidator(args=args) + validator() + validator.is_coco = True + download(f"{url}instances_val2017.json", dir=DATASETS_DIR / "coco8/annotations") + _ = validator.eval_json(validator.stats) + + args = {"model": "yolo11n-seg.pt", "data": "coco8-seg.yaml", "save_json": True, "imgsz": 64} + validator = SegmentationValidator(args=args) + validator() + validator.is_coco = True + download(f"{url}instances_val2017.json", dir=DATASETS_DIR / "coco8-seg/annotations") + _ = validator.eval_json(validator.stats) + + args = {"model": "yolo11n-pose.pt", "data": "coco8-pose.yaml", "save_json": True, "imgsz": 64} + validator = PoseValidator(args=args) + validator() + validator.is_coco = True + download(f"{url}person_keypoints_val2017.json", dir=DATASETS_DIR / "coco8-pose/annotations") + _ = validator.eval_json(validator.stats) diff --git 
a/tests/test_python.py b/tests/test_python.py new file mode 100644 index 0000000000000000000000000000000000000000..644176fb4828668a83559e0c2ecdf915854ef85e --- /dev/null +++ b/tests/test_python.py @@ -0,0 +1,615 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import contextlib +import csv +import urllib +from copy import copy +from pathlib import Path + +import cv2 +import numpy as np +import pytest +import torch +import yaml +from PIL import Image + +from tests import CFG, MODEL, SOURCE, SOURCES_LIST, TMP +from ultralytics import RTDETR, YOLO +from ultralytics.cfg import MODELS, TASK2DATA, TASKS +from ultralytics.data.build import load_inference_source +from ultralytics.utils import ( + ASSETS, + DEFAULT_CFG, + DEFAULT_CFG_PATH, + LOGGER, + ONLINE, + ROOT, + WEIGHTS_DIR, + WINDOWS, + checks, + is_dir_writeable, + is_github_action_running, +) +from ultralytics.utils.downloads import download +from ultralytics.utils.torch_utils import TORCH_1_9 + +IS_TMP_WRITEABLE = is_dir_writeable(TMP) # WARNING: must be run once tests start as TMP does not exist on tests/init + + +def test_model_forward(): + """Test the forward pass of the YOLO model.""" + model = YOLO(CFG) + model(source=None, imgsz=32, augment=True) # also test no source and augment + + +def test_model_methods(): + """Test various methods and properties of the YOLO model to ensure correct functionality.""" + model = YOLO(MODEL) + + # Model methods + model.info(verbose=True, detailed=True) + model = model.reset_weights() + model = model.load(MODEL) + model.to("cpu") + model.fuse() + model.clear_callback("on_train_start") + model.reset_callbacks() + + # Model properties + _ = model.names + _ = model.device + _ = model.transforms + _ = model.task_map + + +def test_model_profile(): + """Test profiling of the YOLO model with `profile=True` to assess performance and resource usage.""" + from ultralytics.nn.tasks import DetectionModel + + model = DetectionModel() # build model + im = torch.randn(1, 3, 64, 64) # requires min imgsz=64 + _ = model.predict(im, profile=True) + + +@pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable") +def test_predict_txt(): + """Tests YOLO predictions with file, directory, and pattern sources listed in a text file.""" + file = TMP / "sources_multi_row.txt" + with open(file, "w") as f: + for src in SOURCES_LIST: + f.write(f"{src}\n") + results = YOLO(MODEL)(source=file, imgsz=32) + assert len(results) == 7 # 1 + 2 + 2 + 2 = 7 images + + +@pytest.mark.skipif(True, reason="disabled for testing") +@pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable") +def test_predict_csv_multi_row(): + """Tests YOLO predictions with sources listed in multiple rows of a CSV file.""" + file = TMP / "sources_multi_row.csv" + with open(file, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["source"]) + writer.writerows([[src] for src in SOURCES_LIST]) + results = YOLO(MODEL)(source=file, imgsz=32) + assert len(results) == 7 # 1 + 2 + 2 + 2 = 7 images + + +@pytest.mark.skipif(True, reason="disabled for testing") +@pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable") +def test_predict_csv_single_row(): + """Tests YOLO predictions with sources listed in a single row of a CSV file.""" + file = TMP / "sources_single_row.csv" + with open(file, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(SOURCES_LIST) + results = YOLO(MODEL)(source=file, imgsz=32) + assert len(results) == 7 # 1 + 2 + 2 + 2 = 7 images + + 
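
The source-list tests above feed a `*.txt` file (one source per line) to the predictor and expect one result per resolved image; the expected count of 7 comes from `SOURCES_LIST` expanding to 1 + 2 + 2 + 2 images. A minimal sketch of that usage, assuming a writable working directory and the bundled sample assets:

```python
from pathlib import Path

from ultralytics import YOLO
from ultralytics.utils import ASSETS  # bundled sample images

# One source per line; directories and glob patterns expand to multiple images
sources = Path("sources.txt")
sources.write_text(f"{ASSETS / 'bus.jpg'}\n{ASSETS / 'zidane.jpg'}\n")

results = YOLO("yolo11n.pt")(source=sources, imgsz=32)
assert len(results) == 2  # one Results object per listed image
```
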
+@pytest.mark.parametrize("model_name", MODELS) +def test_predict_img(model_name): + """Test YOLO model predictions on various image input types and sources, including online images.""" + model = YOLO(WEIGHTS_DIR / model_name) + im = cv2.imread(str(SOURCE)) # uint8 numpy array + assert len(model(source=Image.open(SOURCE), save=True, verbose=True, imgsz=32)) == 1 # PIL + assert len(model(source=im, save=True, save_txt=True, imgsz=32)) == 1 # ndarray + assert len(model(torch.rand((2, 3, 32, 32)), imgsz=32)) == 2 # batch-size 2 Tensor, FP32 0.0-1.0 RGB order + assert len(model(source=[im, im], save=True, save_txt=True, imgsz=32)) == 2 # batch + assert len(list(model(source=[im, im], save=True, stream=True, imgsz=32))) == 2 # stream + assert len(model(torch.zeros(320, 640, 3).numpy().astype(np.uint8), imgsz=32)) == 1 # tensor to numpy + batch = [ + str(SOURCE), # filename + Path(SOURCE), # Path + "https://github.com/ultralytics/assets/releases/download/v0.0.0/zidane.jpg" if ONLINE else SOURCE, # URI + cv2.imread(str(SOURCE)), # OpenCV + Image.open(SOURCE), # PIL + np.zeros((320, 640, 3), dtype=np.uint8), # numpy + ] + assert len(model(batch, imgsz=32)) == len(batch) # multiple sources in a batch + + +@pytest.mark.parametrize("model", MODELS) +def test_predict_visualize(model): + """Test model prediction methods with 'visualize=True' to generate and display prediction visualizations.""" + YOLO(WEIGHTS_DIR / model)(SOURCE, imgsz=32, visualize=True) + + +def test_predict_grey_and_4ch(): + """Test YOLO prediction on SOURCE converted to greyscale and 4-channel images with various filenames.""" + im = Image.open(SOURCE) + directory = TMP / "im4" + directory.mkdir(parents=True, exist_ok=True) + + source_greyscale = directory / "greyscale.jpg" + source_rgba = directory / "4ch.png" + source_non_utf = directory / "non_UTF_测试文件_tést_image.jpg" + source_spaces = directory / "image with spaces.jpg" + + im.convert("L").save(source_greyscale) # greyscale + im.convert("RGBA").save(source_rgba) # 4-ch PNG with alpha + im.save(source_non_utf) # non-UTF characters in filename + im.save(source_spaces) # spaces in filename + + # Inference + model = YOLO(MODEL) + for f in source_rgba, source_greyscale, source_non_utf, source_spaces: + for source in Image.open(f), cv2.imread(str(f)), f: + results = model(source, save=True, verbose=True, imgsz=32) + assert len(results) == 1 # verify that an image was run + f.unlink() # cleanup + + +@pytest.mark.slow +@pytest.mark.skipif(not ONLINE, reason="environment is offline") +@pytest.mark.skipif(is_github_action_running(), reason="No auth https://github.com/JuanBindez/pytubefix/issues/166") +def test_youtube(): + """Test YOLO model on a YouTube video stream, handling potential network-related errors.""" + model = YOLO(MODEL) + try: + model.predict("https://youtu.be/G17sBkb38XQ", imgsz=96, save=True) + # Handle internet connection errors and 'urllib.error.HTTPError: HTTP Error 429: Too Many Requests' + except (urllib.error.HTTPError, ConnectionError) as e: + LOGGER.warning(f"WARNING: YouTube Test Error: {e}") + + +@pytest.mark.skipif(not ONLINE, reason="environment is offline") +@pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable") +def test_track_stream(): + """ + Tests streaming tracking on a short 10 frame video using ByteTrack tracker and different GMC methods. + + Note imgsz=160 required for tracking for higher confidence and better matches. 
+ """ + video_url = "https://github.com/ultralytics/assets/releases/download/v0.0.0/decelera_portrait_min.mov" + model = YOLO(MODEL) + model.track(video_url, imgsz=160, tracker="bytetrack.yaml") + model.track(video_url, imgsz=160, tracker="botsort.yaml", save_frames=True) # test frame saving also + + # Test Global Motion Compensation (GMC) methods + for gmc in "orb", "sift", "ecc": + with open(ROOT / "cfg/trackers/botsort.yaml", encoding="utf-8") as f: + data = yaml.safe_load(f) + tracker = TMP / f"botsort-{gmc}.yaml" + data["gmc_method"] = gmc + with open(tracker, "w", encoding="utf-8") as f: + yaml.safe_dump(data, f) + model.track(video_url, imgsz=160, tracker=tracker) + + +def test_val(): + """Test the validation mode of the YOLO model.""" + YOLO(MODEL).val(data="coco8.yaml", imgsz=32, save_hybrid=True) + + +def test_train_scratch(): + """Test training the YOLO model from scratch using the provided configuration.""" + model = YOLO(CFG) + model.train(data="coco8.yaml", epochs=2, imgsz=32, cache="disk", batch=-1, close_mosaic=1, name="model") + model(SOURCE) + + +def test_train_pretrained(): + """Test training of the YOLO model starting from a pre-trained checkpoint.""" + model = YOLO(WEIGHTS_DIR / "yolo11n-seg.pt") + model.train(data="coco8-seg.yaml", epochs=1, imgsz=32, cache="ram", copy_paste=0.5, mixup=0.5, name=0) + model(SOURCE) + + +def test_all_model_yamls(): + """Test YOLO model creation for all available YAML configurations in the `cfg/models` directory.""" + for m in (ROOT / "cfg" / "models").rglob("*.yaml"): + if "rtdetr" in m.name: + if TORCH_1_9: # torch<=1.8 issue - TypeError: __init__() got an unexpected keyword argument 'batch_first' + _ = RTDETR(m.name)(SOURCE, imgsz=640) # must be 640 + else: + YOLO(m.name) + + +@pytest.mark.skipif(WINDOWS, reason="Windows slow CI export bug https://github.com/ultralytics/ultralytics/pull/16003") +def test_workflow(): + """Test the complete workflow including training, validation, prediction, and exporting.""" + model = YOLO(MODEL) + model.train(data="coco8.yaml", epochs=1, imgsz=32, optimizer="SGD") + model.val(imgsz=32) + model.predict(SOURCE, imgsz=32) + model.export(format="torchscript") # WARNING: Windows slow CI export bug + + +def test_predict_callback_and_setup(): + """Test callback functionality during YOLO prediction setup and execution.""" + + def on_predict_batch_end(predictor): + """Callback function that handles operations at the end of a prediction batch.""" + path, im0s, _ = predictor.batch + im0s = im0s if isinstance(im0s, list) else [im0s] + bs = [predictor.dataset.bs for _ in range(len(path))] + predictor.results = zip(predictor.results, im0s, bs) # results is List[batch_size] + + model = YOLO(MODEL) + model.add_callback("on_predict_batch_end", on_predict_batch_end) + + dataset = load_inference_source(source=SOURCE) + bs = dataset.bs # noqa access predictor properties + results = model.predict(dataset, stream=True, imgsz=160) # source already setup + for r, im0, bs in results: + print("test_callback", im0.shape) + print("test_callback", bs) + boxes = r.boxes # Boxes object for bbox outputs + print(boxes) + + +@pytest.mark.parametrize("model", MODELS) +def test_results(model): + """Ensure YOLO model predictions can be processed and printed in various formats.""" + results = YOLO(WEIGHTS_DIR / model)([SOURCE, SOURCE], imgsz=160) + for r in results: + r = r.cpu().numpy() + print(r, len(r), r.path) # print numpy attributes + r = r.to(device="cpu", dtype=torch.float32) + r.save_txt(txt_file=TMP / "runs/tests/label.txt", 
save_conf=True) + r.save_crop(save_dir=TMP / "runs/tests/crops/") + r.to_json(normalize=True) + r.to_df(decimals=3) + r.to_csv() + r.to_xml() + r.plot(pil=True) + r.plot(conf=True, boxes=True) + print(r, len(r), r.path) # print after methods + + +def test_labels_and_crops(): + """Test output from prediction args for saving YOLO detection labels and crops; ensures accurate saving.""" + imgs = [SOURCE, ASSETS / "zidane.jpg"] + results = YOLO(WEIGHTS_DIR / "yolo11n.pt")(imgs, imgsz=160, save_txt=True, save_crop=True) + save_path = Path(results[0].save_dir) + for r in results: + im_name = Path(r.path).stem + cls_idxs = r.boxes.cls.int().tolist() + # Check correct detections + assert cls_idxs == ([0, 7, 0, 0] if r.path.endswith("bus.jpg") else [0, 0, 0]) # bus.jpg and zidane.jpg classes + # Check label path + labels = save_path / f"labels/{im_name}.txt" + assert labels.exists() + # Check detections match label count + assert len(r.boxes.data) == len([line for line in labels.read_text().splitlines() if line]) + # Check crops path and files + crop_dirs = list((save_path / "crops").iterdir()) + crop_files = [f for p in crop_dirs for f in p.glob("*")] + # Crop directories match detections + assert all(r.names.get(c) in {d.name for d in crop_dirs} for c in cls_idxs) + # Same number of crops as detections + assert len([f for f in crop_files if im_name in f.name]) == len(r.boxes.data) + + +@pytest.mark.skipif(not ONLINE, reason="environment is offline") +def test_data_utils(): + """Test utility functions in ultralytics/data/utils.py, including dataset stats and auto-splitting.""" + from ultralytics.data.utils import HUBDatasetStats, autosplit + from ultralytics.utils.downloads import zip_directory + + # from ultralytics.utils.files import WorkingDirectory + # with WorkingDirectory(ROOT.parent / 'tests'): + + for task in TASKS: + file = Path(TASK2DATA[task]).with_suffix(".zip") # i.e. 
coco8.zip + download(f"https://github.com/ultralytics/hub/raw/main/example_datasets/{file}", unzip=False, dir=TMP) + stats = HUBDatasetStats(TMP / file, task=task) + stats.get_json(save=True) + stats.process_images() + + autosplit(TMP / "coco8") + zip_directory(TMP / "coco8/images/val") # zip + + +@pytest.mark.skipif(not ONLINE, reason="environment is offline") +def test_data_converter(): + """Test dataset conversion functions from COCO to YOLO format and class mappings.""" + from ultralytics.data.converter import coco80_to_coco91_class, convert_coco + + file = "instances_val2017.json" + download(f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{file}", dir=TMP) + convert_coco(labels_dir=TMP, save_dir=TMP / "yolo_labels", use_segments=True, use_keypoints=False, cls91to80=True) + coco80_to_coco91_class() + + +def test_data_annotator(): + """Automatically annotate data using specified detection and segmentation models.""" + from ultralytics.data.annotator import auto_annotate + + auto_annotate( + ASSETS, + det_model=WEIGHTS_DIR / "yolo11n.pt", + sam_model=WEIGHTS_DIR / "mobile_sam.pt", + output_dir=TMP / "auto_annotate_labels", + ) + + +def test_events(): + """Test event sending functionality.""" + from ultralytics.hub.utils import Events + + events = Events() + events.enabled = True + cfg = copy(DEFAULT_CFG) # does not require deepcopy + cfg.mode = "test" + events(cfg) + + +def test_cfg_init(): + """Test configuration initialization utilities from the 'ultralytics.cfg' module.""" + from ultralytics.cfg import check_dict_alignment, copy_default_cfg, smart_value + + with contextlib.suppress(SyntaxError): + check_dict_alignment({"a": 1}, {"b": 2}) + copy_default_cfg() + (Path.cwd() / DEFAULT_CFG_PATH.name.replace(".yaml", "_copy.yaml")).unlink(missing_ok=False) + [smart_value(x) for x in ["none", "true", "false"]] + + +def test_utils_init(): + """Test initialization utilities in the Ultralytics library.""" + from ultralytics.utils import get_git_branch, get_git_origin_url, get_ubuntu_version, is_github_action_running + + get_ubuntu_version() + is_github_action_running() + get_git_origin_url() + get_git_branch() + + +def test_utils_checks(): + """Test various utility checks for filenames, git status, requirements, image sizes, and versions.""" + checks.check_yolov5u_filename("yolov5n.pt") + checks.git_describe(ROOT) + checks.check_requirements() # check requirements.txt + checks.check_imgsz([600, 600], max_dim=1) + checks.check_imshow(warn=True) + checks.check_version("ultralytics", "8.0.0") + checks.print_args() + + +@pytest.mark.skipif(WINDOWS, reason="Windows profiling is extremely slow (cause unknown)") +def test_utils_benchmarks(): + """Benchmark model performance using 'ProfileModels' from 'ultralytics.utils.benchmarks'.""" + from ultralytics.utils.benchmarks import ProfileModels + + ProfileModels(["yolo11n.yaml"], imgsz=32, min_time=1, num_timed_runs=3, num_warmup_runs=1).profile() + + +def test_utils_torchutils(): + """Test Torch utility functions including profiling and FLOP calculations.""" + from ultralytics.nn.modules.conv import Conv + from ultralytics.utils.torch_utils import get_flops_with_torch_profiler, profile, time_sync + + x = torch.randn(1, 64, 20, 20) + m = Conv(64, 64, k=1, s=2) + + profile(x, [m], n=3) + get_flops_with_torch_profiler(m) + time_sync() + + +def test_utils_ops(): + """Test utility operations functions for coordinate transformation and normalization.""" + from ultralytics.utils.ops import ( + ltwh2xywh, + ltwh2xyxy, + make_divisible, + 
xywh2ltwh, + xywh2xyxy, + xywhn2xyxy, + xywhr2xyxyxyxy, + xyxy2ltwh, + xyxy2xywh, + xyxy2xywhn, + xyxyxyxy2xywhr, + ) + + make_divisible(17, torch.tensor([8])) + + boxes = torch.rand(10, 4) # xywh + torch.allclose(boxes, xyxy2xywh(xywh2xyxy(boxes))) + torch.allclose(boxes, xyxy2xywhn(xywhn2xyxy(boxes))) + torch.allclose(boxes, ltwh2xywh(xywh2ltwh(boxes))) + torch.allclose(boxes, xyxy2ltwh(ltwh2xyxy(boxes))) + + boxes = torch.rand(10, 5) # xywhr for OBB + boxes[:, 4] = torch.randn(10) * 30 + torch.allclose(boxes, xyxyxyxy2xywhr(xywhr2xyxyxyxy(boxes)), rtol=1e-3) + + +def test_utils_files(): + """Test file handling utilities including file age, date, and paths with spaces.""" + from ultralytics.utils.files import file_age, file_date, get_latest_run, spaces_in_path + + file_age(SOURCE) + file_date(SOURCE) + get_latest_run(ROOT / "runs") + + path = TMP / "path/with spaces" + path.mkdir(parents=True, exist_ok=True) + with spaces_in_path(path) as new_path: + print(new_path) + + +@pytest.mark.slow +def test_utils_patches_torch_save(): + """Test torch_save backoff when _torch_save raises RuntimeError to ensure robustness.""" + from unittest.mock import MagicMock, patch + + from ultralytics.utils.patches import torch_save + + mock = MagicMock(side_effect=RuntimeError) + + with patch("ultralytics.utils.patches._torch_save", new=mock): + with pytest.raises(RuntimeError): + torch_save(torch.zeros(1), TMP / "test.pt") + + assert mock.call_count == 4, "torch_save was not attempted the expected number of times" + + +def test_nn_modules_conv(): + """Test Convolutional Neural Network modules including CBAM, Conv2, and ConvTranspose.""" + from ultralytics.nn.modules.conv import CBAM, Conv2, ConvTranspose, DWConvTranspose2d, Focus + + c1, c2 = 8, 16 # input and output channels + x = torch.zeros(4, c1, 10, 10) # BCHW + + # Run all modules not otherwise covered in tests + DWConvTranspose2d(c1, c2)(x) + ConvTranspose(c1, c2)(x) + Focus(c1, c2)(x) + CBAM(c1)(x) + + # Fuse ops + m = Conv2(c1, c2) + m.fuse_convs() + m(x) + + +def test_nn_modules_block(): + """Test various blocks in neural network modules including C1, C3TR, BottleneckCSP, C3Ghost, and C3x.""" + from ultralytics.nn.modules.block import C1, C3TR, BottleneckCSP, C3Ghost, C3x + + c1, c2 = 8, 16 # input and output channels + x = torch.zeros(4, c1, 10, 10) # BCHW + + # Run all modules not otherwise covered in tests + C1(c1, c2)(x) + C3x(c1, c2)(x) + C3TR(c1, c2)(x) + C3Ghost(c1, c2)(x) + BottleneckCSP(c1, c2)(x) + + +@pytest.mark.skipif(not ONLINE, reason="environment is offline") +def test_hub(): + """Test Ultralytics HUB functionalities (e.g. 
export formats, logout).""" + from ultralytics.hub import export_fmts_hub, logout + from ultralytics.hub.utils import smart_request + + export_fmts_hub() + logout() + smart_request("GET", "https://github.com", progress=True) + + +@pytest.fixture +def image(): + """Load and return an image from a predefined source using OpenCV.""" + return cv2.imread(str(SOURCE)) + + +@pytest.mark.parametrize( + "auto_augment, erasing, force_color_jitter", + [ + (None, 0.0, False), + ("randaugment", 0.5, True), + ("augmix", 0.2, False), + ("autoaugment", 0.0, True), + ], +) +def test_classify_transforms_train(image, auto_augment, erasing, force_color_jitter): + """Tests classification transforms during training with various augmentations to ensure proper functionality.""" + from ultralytics.data.augment import classify_augmentations + + transform = classify_augmentations( + size=224, + mean=(0.5, 0.5, 0.5), + std=(0.5, 0.5, 0.5), + scale=(0.08, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), + hflip=0.5, + vflip=0.5, + auto_augment=auto_augment, + hsv_h=0.015, + hsv_s=0.4, + hsv_v=0.4, + force_color_jitter=force_color_jitter, + erasing=erasing, + ) + + transformed_image = transform(Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))) + + assert transformed_image.shape == (3, 224, 224) + assert torch.is_tensor(transformed_image) + assert transformed_image.dtype == torch.float32 + + +@pytest.mark.slow +@pytest.mark.skipif(not ONLINE, reason="environment is offline") +def test_model_tune(): + """Tune YOLO model for performance improvement.""" + YOLO("yolo11n-pose.pt").tune(data="coco8-pose.yaml", plots=False, imgsz=32, epochs=1, iterations=2, device="cpu") + YOLO("yolo11n-cls.pt").tune(data="imagenet10", plots=False, imgsz=32, epochs=1, iterations=2, device="cpu") + + +def test_model_embeddings(): + """Test YOLO model embeddings.""" + model_detect = YOLO(MODEL) + model_segment = YOLO(WEIGHTS_DIR / "yolo11n-seg.pt") + + for batch in [SOURCE], [SOURCE, SOURCE]: # test batch size 1 and 2 + assert len(model_detect.embed(source=batch, imgsz=32)) == len(batch) + assert len(model_segment.embed(source=batch, imgsz=32)) == len(batch) + + +@pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="YOLOWorld with CLIP is not supported in Python 3.12") +def test_yolo_world(): + """Tests YOLO world models with CLIP support, including detection and training scenarios.""" + model = YOLO(WEIGHTS_DIR / "yolov8s-world.pt") # no YOLO11n-world model yet + model.set_classes(["tree", "window"]) + model(SOURCE, conf=0.01) + + model = YOLO(WEIGHTS_DIR / "yolov8s-worldv2.pt") # no YOLO11n-world model yet + # Training from a pretrained model. Eval is included at the final stage of training. 
+    # Use dota8.yaml which has fewer categories to reduce the inference time of CLIP model
+    model.train(
+        data="dota8.yaml",
+        epochs=1,
+        imgsz=32,
+        cache="disk",
+        close_mosaic=1,
+    )
+
+    # Test WorldTrainerFromScratch
+    from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
+
+    model = YOLO("yolov8s-worldv2.yaml")  # no YOLO11n-world model yet
+    model.train(
+        data={"train": {"yolo_data": ["dota8.yaml"]}, "val": {"yolo_data": ["dota8.yaml"]}},
+        epochs=1,
+        imgsz=32,
+        cache="disk",
+        close_mosaic=1,
+        trainer=WorldTrainerFromScratch,
+    )
+
+
+def test_yolov10():
+    """Test YOLOv10 model training, validation, and prediction steps with minimal configurations."""
+    model = YOLO("yolov10n.yaml")
+    # train/val/predict
+    model.train(data="coco8.yaml", epochs=1, imgsz=32, close_mosaic=1, cache="disk")
+    model.val(data="coco8.yaml", imgsz=32)
+    model.predict(imgsz=32, save_txt=True, save_crop=True, augment=True)
+    model(SOURCE)
diff --git a/tests/test_solutions.py b/tests/test_solutions.py
new file mode 100644
index 0000000000000000000000000000000000000000..056a056fbc1e0321b4ce506ac146e2a2108aec7e
--- /dev/null
+++ b/tests/test_solutions.py
@@ -0,0 +1,94 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+import cv2
+import pytest
+
+from tests import TMP
+from ultralytics import YOLO, solutions
+from ultralytics.utils import ASSETS_URL, WEIGHTS_DIR
+from ultralytics.utils.downloads import safe_download
+
+DEMO_VIDEO = "solutions_ci_demo.mp4"
+POSE_VIDEO = "solution_ci_pose_demo.mp4"
+
+
+@pytest.mark.slow
+def test_major_solutions():
+    """Test the object counting, heatmap, speed estimation, trackzone and queue management solutions."""
+    safe_download(url=f"{ASSETS_URL}/{DEMO_VIDEO}", dir=TMP)
+    cap = cv2.VideoCapture(str(TMP / DEMO_VIDEO))
+    assert cap.isOpened(), "Error reading video file"
+    region_points = [(20, 400), (1080, 400), (1080, 360), (20, 360)]
+    counter = solutions.ObjectCounter(region=region_points, model="yolo11n.pt", show=False)  # Test object counter
+    heatmap = solutions.Heatmap(colormap=cv2.COLORMAP_PARULA, model="yolo11n.pt", show=False)  # Test heatmaps
+    heatmap_count = solutions.Heatmap(
+        colormap=cv2.COLORMAP_PARULA, model="yolo11n.pt", show=False, region=region_points
+    )  # Test heatmaps with object counting
+    speed = solutions.SpeedEstimator(region=region_points, model="yolo11n.pt", show=False)  # Test speed estimation
+    queue = solutions.QueueManager(region=region_points, model="yolo11n.pt", show=False)  # Test queue manager
+    line_analytics = solutions.Analytics(analytics_type="line", model="yolo11n.pt", show=False)  # line analytics
+    pie_analytics = solutions.Analytics(analytics_type="pie", model="yolo11n.pt", show=False)  # pie analytics
+    bar_analytics = solutions.Analytics(analytics_type="bar", model="yolo11n.pt", show=False)  # bar analytics
+    area_analytics = solutions.Analytics(analytics_type="area", model="yolo11n.pt", show=False)  # area analytics
+    trackzone = solutions.TrackZone(region=region_points, model="yolo11n.pt", show=False)  # Test trackzone
+    frame_count = 0  # Required for analytics
+    while cap.isOpened():
+        success, im0 = cap.read()
+        if not success:
+            break
+        frame_count += 1
+        original_im0 = im0.copy()
+        _ = counter.count(original_im0.copy())
+        _ = heatmap.generate_heatmap(original_im0.copy())
+        _ = heatmap_count.generate_heatmap(original_im0.copy())
+        _ = speed.estimate_speed(original_im0.copy())
+        _ = queue.process_queue(original_im0.copy())
+        _ = line_analytics.process_data(original_im0.copy(),
frame_count) + _ = pie_analytics.process_data(original_im0.copy(), frame_count) + _ = bar_analytics.process_data(original_im0.copy(), frame_count) + _ = area_analytics.process_data(original_im0.copy(), frame_count) + _ = trackzone.trackzone(original_im0.copy()) + cap.release() + + # Test workouts monitoring + safe_download(url=f"{ASSETS_URL}/{POSE_VIDEO}", dir=TMP) + cap = cv2.VideoCapture(str(TMP / POSE_VIDEO)) + assert cap.isOpened(), "Error reading video file" + gym = solutions.AIGym(kpts=[5, 11, 13], show=False) + while cap.isOpened(): + success, im0 = cap.read() + if not success: + break + _ = gym.monitor(im0) + cap.release() + + +@pytest.mark.slow +def test_instance_segmentation(): + """Test the instance segmentation solution.""" + from ultralytics.utils.plotting import Annotator, colors + + model = YOLO(WEIGHTS_DIR / "yolo11n-seg.pt") + names = model.names + cap = cv2.VideoCapture(TMP / DEMO_VIDEO) + assert cap.isOpened(), "Error reading video file" + while cap.isOpened(): + success, im0 = cap.read() + if not success: + break + results = model.predict(im0) + annotator = Annotator(im0, line_width=2) + if results[0].masks is not None: + clss = results[0].boxes.cls.cpu().tolist() + masks = results[0].masks.xy + for mask, cls in zip(masks, clss): + color = colors(int(cls), True) + annotator.seg_bbox(mask=mask, mask_color=color, label=names[int(cls)]) + cap.release() + cv2.destroyAllWindows() + + +@pytest.mark.slow +def test_streamlit_predict(): + """Test streamlit predict live inference solution.""" + solutions.Inference().inference() diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d15296f1584c3773b59a050e31e3e32a0acea102 --- /dev/null +++ b/ultralytics/__init__.py @@ -0,0 +1,29 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +__version__ = "8.3.63" + +import os + +# Set ENV variables (place before imports) +if not os.environ.get("OMP_NUM_THREADS"): + os.environ["OMP_NUM_THREADS"] = "1" # default for reduced CPU utilization during training + +from ultralytics.models import NAS, RTDETR, SAM, YOLO, FastSAM, YOLOWorld +from ultralytics.utils import ASSETS, SETTINGS +from ultralytics.utils.checks import check_yolo as checks +from ultralytics.utils.downloads import download + +settings = SETTINGS +__all__ = ( + "__version__", + "ASSETS", + "YOLO", + "YOLOWorld", + "NAS", + "SAM", + "FastSAM", + "RTDETR", + "checks", + "download", + "settings", +) diff --git a/ultralytics/assets/bus.jpg b/ultralytics/assets/bus.jpg new file mode 100644 index 0000000000000000000000000000000000000000..41bda57f9b2b3fb2c5ba1163fd738e151f2b81d7 --- /dev/null +++ b/ultralytics/assets/bus.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c02019c4979c191eb739ddd944445ef408dad5679acab6fd520ef9d434bfbc63 +size 137419 diff --git a/ultralytics/assets/zidane.jpg b/ultralytics/assets/zidane.jpg new file mode 100644 index 0000000000000000000000000000000000000000..eeab1cdcb282b0e026a57c5bf85df36024b4e1f6 Binary files /dev/null and b/ultralytics/assets/zidane.jpg differ diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8625f7c99cf4995822c53e14930de5f9a59db3a2 --- /dev/null +++ b/ultralytics/cfg/__init__.py @@ -0,0 +1,1025 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import shutil +import subprocess +import sys +from pathlib import Path +from types import SimpleNamespace +from 
typing import Dict, List, Union + +import cv2 + +from ultralytics.utils import ( + ASSETS, + DEFAULT_CFG, + DEFAULT_CFG_DICT, + DEFAULT_CFG_PATH, + DEFAULT_SOL_DICT, + IS_VSCODE, + LOGGER, + RANK, + ROOT, + RUNS_DIR, + SETTINGS, + SETTINGS_FILE, + TESTS_RUNNING, + IterableSimpleNamespace, + __version__, + checks, + colorstr, + deprecation_warn, + vscode_msg, + yaml_load, + yaml_print, +) + +# Define valid solutions +SOLUTION_MAP = { + "count": ("ObjectCounter", "count"), + "heatmap": ("Heatmap", "generate_heatmap"), + "queue": ("QueueManager", "process_queue"), + "speed": ("SpeedEstimator", "estimate_speed"), + "workout": ("AIGym", "monitor"), + "analytics": ("Analytics", "process_data"), + "trackzone": ("TrackZone", "trackzone"), + "inference": ("Inference", "inference"), + "help": None, +} + +# Define valid tasks and modes +MODES = {"train", "val", "predict", "export", "track", "benchmark"} +TASKS = {"detect", "segment", "classify", "pose", "obb"} +TASK2DATA = { + "detect": "coco8.yaml", + "segment": "coco8-seg.yaml", + "classify": "imagenet10", + "pose": "coco8-pose.yaml", + "obb": "dota8.yaml", +} +TASK2MODEL = { + "detect": "yolo11n.pt", + "segment": "yolo11n-seg.pt", + "classify": "yolo11n-cls.pt", + "pose": "yolo11n-pose.pt", + "obb": "yolo11n-obb.pt", +} +TASK2METRIC = { + "detect": "metrics/mAP50-95(B)", + "segment": "metrics/mAP50-95(M)", + "classify": "metrics/accuracy_top1", + "pose": "metrics/mAP50-95(P)", + "obb": "metrics/mAP50-95(B)", +} +MODELS = {TASK2MODEL[task] for task in TASKS} + +ARGV = sys.argv or ["", ""] # sometimes sys.argv = [] +SOLUTIONS_HELP_MSG = f""" + Arguments received: {str(["yolo"] + ARGV[1:])}. Ultralytics 'yolo solutions' usage overview: + + yolo solutions SOLUTION ARGS + + Where SOLUTION (optional) is one of {list(SOLUTION_MAP.keys())[:-1]} + ARGS (optional) are any number of custom 'arg=value' pairs like 'show_in=True' that override defaults + at https://docs.ultralytics.com/usage/cfg + + 1. Call object counting solution + yolo solutions count source="path/to/video/file.mp4" region=[(20, 400), (1080, 400), (1080, 360), (20, 360)] + + 2. Call heatmaps solution + yolo solutions heatmap colormap=cv2.COLORMAP_PARULA model=yolo11n.pt + + 3. Call queue management solution + yolo solutions queue region=[(20, 400), (1080, 400), (1080, 360), (20, 360)] model=yolo11n.pt + + 4. Call workouts monitoring solution for push-ups + yolo solutions workout model=yolo11n-pose.pt kpts=[6, 8, 10] + + 5. Generate analytical graphs + yolo solutions analytics analytics_type="pie" + + 6. Track objects within specific zones + yolo solutions trackzone source="path/to/video/file.mp4" region=[(150, 150), (1130, 150), (1130, 570), (150, 570)] + + 7. Streamlit real-time webcam inference GUI + yolo streamlit-predict + """ +CLI_HELP_MSG = f""" + Arguments received: {str(["yolo"] + ARGV[1:])}. Ultralytics 'yolo' commands use the following syntax: + + yolo TASK MODE ARGS + + Where TASK (optional) is one of {TASKS} + MODE (required) is one of {MODES} + ARGS (optional) are any number of custom 'arg=value' pairs like 'imgsz=320' that override defaults. + See all ARGS at https://docs.ultralytics.com/usage/cfg or with 'yolo cfg' + + 1. Train a detection model for 10 epochs with an initial learning_rate of 0.01 + yolo train data=coco8.yaml model=yolo11n.pt epochs=10 lr0=0.01 + + 2. Predict a YouTube video using a pretrained segmentation model at image size 320: + yolo predict model=yolo11n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320 + + 3. 
Val a pretrained detection model at batch-size 1 and image size 640: + yolo val model=yolo11n.pt data=coco8.yaml batch=1 imgsz=640 + + 4. Export a YOLO11n classification model to ONNX format at image size 224 by 128 (no TASK required) + yolo export model=yolo11n-cls.pt format=onnx imgsz=224,128 + + 5. Ultralytics solutions usage + yolo solutions count or in {list(SOLUTION_MAP.keys())[1:-1]} source="path/to/video/file.mp4" + + 6. Run special commands: + yolo help + yolo checks + yolo version + yolo settings + yolo copy-cfg + yolo cfg + yolo solutions help + + Docs: https://docs.ultralytics.com + Solutions: https://docs.ultralytics.com/solutions/ + Community: https://community.ultralytics.com + GitHub: https://github.com/ultralytics/ultralytics + """ + +# Define keys for arg type checks +CFG_FLOAT_KEYS = { # integer or float arguments, i.e. x=2 and x=2.0 + "warmup_epochs", + "box", + "cls", + "dfl", + "degrees", + "shear", + "time", + "workspace", + "batch", +} +CFG_FRACTION_KEYS = { # fractional float arguments with 0.0<=values<=1.0 + "dropout", + "lr0", + "lrf", + "momentum", + "weight_decay", + "warmup_momentum", + "warmup_bias_lr", + "hsv_h", + "hsv_s", + "hsv_v", + "translate", + "scale", + "perspective", + "flipud", + "fliplr", + "bgr", + "mosaic", + "mixup", + "copy_paste", + "conf", + "iou", + "fraction", +} +CFG_INT_KEYS = { # integer-only arguments + "epochs", + "patience", + "workers", + "seed", + "close_mosaic", + "mask_ratio", + "max_det", + "vid_stride", + "line_width", + "nbs", + "save_period", +} +CFG_BOOL_KEYS = { # boolean-only arguments + "save", + "exist_ok", + "verbose", + "deterministic", + "single_cls", + "rect", + "cos_lr", + "overlap_mask", + "val", + "save_json", + "save_hybrid", + "half", + "dnn", + "plots", + "show", + "save_txt", + "save_conf", + "save_crop", + "save_frames", + "show_labels", + "show_conf", + "visualize", + "augment", + "agnostic_nms", + "retina_masks", + "show_boxes", + "keras", + "optimize", + "int8", + "dynamic", + "simplify", + "nms", + "profile", + "multi_scale", +} + + +def cfg2dict(cfg): + """ + Converts a configuration object to a dictionary. + + Args: + cfg (str | Path | Dict | SimpleNamespace): Configuration object to be converted. Can be a file path, + a string, a dictionary, or a SimpleNamespace object. + + Returns: + (Dict): Configuration object in dictionary format. + + Examples: + Convert a YAML file path to a dictionary: + >>> config_dict = cfg2dict("config.yaml") + + Convert a SimpleNamespace to a dictionary: + >>> from types import SimpleNamespace + >>> config_sn = SimpleNamespace(param1="value1", param2="value2") + >>> config_dict = cfg2dict(config_sn) + + Pass through an already existing dictionary: + >>> config_dict = cfg2dict({"param1": "value1", "param2": "value2"}) + + Notes: + - If cfg is a path or string, it's loaded as YAML and converted to a dictionary. + - If cfg is a SimpleNamespace object, it's converted to a dictionary using vars(). + - If cfg is already a dictionary, it's returned unchanged. + """ + if isinstance(cfg, (str, Path)): + cfg = yaml_load(cfg) # load dict + elif isinstance(cfg, SimpleNamespace): + cfg = vars(cfg) # convert to dict + return cfg + + +def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, overrides: Dict = None): + """ + Load and merge configuration data from a file or dictionary, with optional overrides. + + Args: + cfg (str | Path | Dict | SimpleNamespace): Configuration data source. Can be a file path, dictionary, or + SimpleNamespace object. 
+ overrides (Dict | None): Dictionary containing key-value pairs to override the base configuration. + + Returns: + (SimpleNamespace): Namespace containing the merged configuration arguments. + + Examples: + >>> from ultralytics.cfg import get_cfg + >>> config = get_cfg() # Load default configuration + >>> config_with_overrides = get_cfg("path/to/config.yaml", overrides={"epochs": 50, "batch_size": 16}) + + Notes: + - If both `cfg` and `overrides` are provided, the values in `overrides` will take precedence. + - Special handling ensures alignment and correctness of the configuration, such as converting numeric + `project` and `name` to strings and validating configuration keys and values. + - The function performs type and value checks on the configuration data. + """ + cfg = cfg2dict(cfg) + + # Merge overrides + if overrides: + overrides = cfg2dict(overrides) + if "save_dir" not in cfg: + overrides.pop("save_dir", None) # special override keys to ignore + check_dict_alignment(cfg, overrides) + cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides) + + # Special handling for numeric project/name + for k in "project", "name": + if k in cfg and isinstance(cfg[k], (int, float)): + cfg[k] = str(cfg[k]) + if cfg.get("name") == "model": # assign model to 'name' arg + cfg["name"] = str(cfg.get("model", "")).split(".")[0] + LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.") + + # Type and Value checks + check_cfg(cfg) + + # Return instance + return IterableSimpleNamespace(**cfg) + + +def check_cfg(cfg, hard=True): + """ + Checks configuration argument types and values for the Ultralytics library. + + This function validates the types and values of configuration arguments, ensuring correctness and converting + them if necessary. It checks for specific key types defined in global variables such as CFG_FLOAT_KEYS, + CFG_FRACTION_KEYS, CFG_INT_KEYS, and CFG_BOOL_KEYS. + + Args: + cfg (Dict): Configuration dictionary to validate. + hard (bool): If True, raises exceptions for invalid types and values; if False, attempts to convert them. + + Examples: + >>> config = { + ... "epochs": 50, # valid integer + ... "lr0": 0.01, # valid float + ... "momentum": 1.2, # invalid float (out of 0.0-1.0 range) + ... "save": "true", # invalid bool + ... } + >>> check_cfg(config, hard=False) + >>> print(config) + {'epochs': 50, 'lr0': 0.01, 'momentum': 1.2, 'save': False} # corrected 'save' key + + Notes: + - The function modifies the input dictionary in-place. + - None values are ignored as they may be from optional arguments. + - Fraction keys are checked to be within the range [0.0, 1.0]. + """ + for k, v in cfg.items(): + if v is not None: # None values may be from optional args + if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)): + if hard: + raise TypeError( + f"'{k}={v}' is of invalid type {type(v).__name__}. " + f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')" + ) + cfg[k] = float(v) + elif k in CFG_FRACTION_KEYS: + if not isinstance(v, (int, float)): + if hard: + raise TypeError( + f"'{k}={v}' is of invalid type {type(v).__name__}. " + f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')" + ) + cfg[k] = v = float(v) + if not (0.0 <= v <= 1.0): + raise ValueError(f"'{k}={v}' is an invalid value. Valid '{k}' values are between 0.0 and 1.0.") + elif k in CFG_INT_KEYS and not isinstance(v, int): + if hard: + raise TypeError( + f"'{k}={v}' is of invalid type {type(v).__name__}. '{k}' must be an int (i.e. 
'{k}=8')" + ) + cfg[k] = int(v) + elif k in CFG_BOOL_KEYS and not isinstance(v, bool): + if hard: + raise TypeError( + f"'{k}={v}' is of invalid type {type(v).__name__}. " + f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')" + ) + cfg[k] = bool(v) + + +def get_save_dir(args, name=None): + """ + Returns the directory path for saving outputs, derived from arguments or default settings. + + Args: + args (SimpleNamespace): Namespace object containing configurations such as 'project', 'name', 'task', + 'mode', and 'save_dir'. + name (str | None): Optional name for the output directory. If not provided, it defaults to 'args.name' + or the 'args.mode'. + + Returns: + (Path): Directory path where outputs should be saved. + + Examples: + >>> from types import SimpleNamespace + >>> args = SimpleNamespace(project="my_project", task="detect", mode="train", exist_ok=True) + >>> save_dir = get_save_dir(args) + >>> print(save_dir) + my_project/detect/train + """ + if getattr(args, "save_dir", None): + save_dir = args.save_dir + else: + from ultralytics.utils.files import increment_path + + project = args.project or (ROOT.parent / "tests/tmp/runs" if TESTS_RUNNING else RUNS_DIR) / args.task + name = name or args.name or f"{args.mode}" + save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in {-1, 0} else True) + + return Path(save_dir) + + +def _handle_deprecation(custom): + """ + Handles deprecated configuration keys by mapping them to current equivalents with deprecation warnings. + + Args: + custom (Dict): Configuration dictionary potentially containing deprecated keys. + + Examples: + >>> custom_config = {"boxes": True, "hide_labels": "False", "line_thickness": 2} + >>> _handle_deprecation(custom_config) + >>> print(custom_config) + {'show_boxes': True, 'show_labels': True, 'line_width': 2} + + Notes: + This function modifies the input dictionary in-place, replacing deprecated keys with their current + equivalents. It also handles value conversions where necessary, such as inverting boolean values for + 'hide_labels' and 'hide_conf'. + """ + for key in custom.copy().keys(): + if key == "boxes": + deprecation_warn(key, "show_boxes") + custom["show_boxes"] = custom.pop("boxes") + if key == "hide_labels": + deprecation_warn(key, "show_labels") + custom["show_labels"] = custom.pop("hide_labels") == "False" + if key == "hide_conf": + deprecation_warn(key, "show_conf") + custom["show_conf"] = custom.pop("hide_conf") == "False" + if key == "line_thickness": + deprecation_warn(key, "line_width") + custom["line_width"] = custom.pop("line_thickness") + if key == "label_smoothing": + deprecation_warn(key) + custom.pop("label_smoothing") + + return custom + + +def check_dict_alignment(base: Dict, custom: Dict, e=None): + """ + Checks alignment between custom and base configuration dictionaries, handling deprecated keys and providing error + messages for mismatched keys. + + Args: + base (Dict): The base configuration dictionary containing valid keys. + custom (Dict): The custom configuration dictionary to be checked for alignment. + e (Exception | None): Optional error instance passed by the calling function. + + Raises: + SystemExit: If mismatched keys are found between the custom and base dictionaries. + + Examples: + >>> base_cfg = {"epochs": 50, "lr0": 0.01, "batch_size": 16} + >>> custom_cfg = {"epoch": 100, "lr": 0.02, "batch_size": 32} + >>> try: + ... check_dict_alignment(base_cfg, custom_cfg) + ... except SystemExit: + ... 
print("Mismatched keys found") + + Notes: + - Suggests corrections for mismatched keys based on similarity to valid keys. + - Automatically replaces deprecated keys in the custom configuration with updated equivalents. + - Prints detailed error messages for each mismatched key to help users correct their configurations. + """ + custom = _handle_deprecation(custom) + base_keys, custom_keys = (set(x.keys()) for x in (base, custom)) + if mismatched := [k for k in custom_keys if k not in base_keys]: + from difflib import get_close_matches + + string = "" + for x in mismatched: + matches = get_close_matches(x, base_keys) # key list + matches = [f"{k}={base[k]}" if base.get(k) is not None else k for k in matches] + match_str = f"Similar arguments are i.e. {matches}." if matches else "" + string += f"'{colorstr('red', 'bold', x)}' is not a valid YOLO argument. {match_str}\n" + raise SyntaxError(string + CLI_HELP_MSG) from e + + +def merge_equals_args(args: List[str]) -> List[str]: + """ + Merges arguments around isolated '=' in a list of strings and joins fragments with brackets. + + This function handles the following cases: + 1. ['arg', '=', 'val'] becomes ['arg=val'] + 2. ['arg=', 'val'] becomes ['arg=val'] + 3. ['arg', '=val'] becomes ['arg=val'] + 4. Joins fragments with brackets, e.g., ['imgsz=[3,', '640,', '640]'] becomes ['imgsz=[3,640,640]'] + + Args: + args (List[str]): A list of strings where each element represents an argument or fragment. + + Returns: + List[str]: A list of strings where the arguments around isolated '=' are merged and fragments with brackets are joined. + + Examples: + >>> args = ["arg1", "=", "value", "arg2=", "value2", "arg3", "=value3", "imgsz=[3,", "640,", "640]"] + >>> merge_and_join_args(args) + ['arg1=value', 'arg2=value2', 'arg3=value3', 'imgsz=[3,640,640]'] + """ + new_args = [] + current = "" + depth = 0 + + i = 0 + while i < len(args): + arg = args[i] + + # Handle equals sign merging + if arg == "=" and 0 < i < len(args) - 1: # merge ['arg', '=', 'val'] + new_args[-1] += f"={args[i + 1]}" + i += 2 + continue + elif arg.endswith("=") and i < len(args) - 1 and "=" not in args[i + 1]: # merge ['arg=', 'val'] + new_args.append(f"{arg}{args[i + 1]}") + i += 2 + continue + elif arg.startswith("=") and i > 0: # merge ['arg', '=val'] + new_args[-1] += arg + i += 1 + continue + + # Handle bracket joining + depth += arg.count("[") - arg.count("]") + current += arg + if depth == 0: + new_args.append(current) + current = "" + + i += 1 + + # Append any remaining current string + if current: + new_args.append(current) + + return new_args + + +def handle_yolo_hub(args: List[str]) -> None: + """ + Handles Ultralytics HUB command-line interface (CLI) commands for authentication. + + This function processes Ultralytics HUB CLI commands such as login and logout. It should be called when executing a + script with arguments related to HUB authentication. + + Args: + args (List[str]): A list of command line arguments. The first argument should be either 'login' + or 'logout'. For 'login', an optional second argument can be the API key. + + Examples: + ```bash + yolo login YOUR_API_KEY + ``` + + Notes: + - The function imports the 'hub' module from ultralytics to perform login and logout operations. + - For the 'login' command, if no API key is provided, an empty string is passed to the login function. + - The 'logout' command does not require any additional arguments. 
+ """ + from ultralytics import hub + + if args[0] == "login": + key = args[1] if len(args) > 1 else "" + # Log in to Ultralytics HUB using the provided API key + hub.login(key) + elif args[0] == "logout": + # Log out from Ultralytics HUB + hub.logout() + + +def handle_yolo_settings(args: List[str]) -> None: + """ + Handles YOLO settings command-line interface (CLI) commands. + + This function processes YOLO settings CLI commands such as reset and updating individual settings. It should be + called when executing a script with arguments related to YOLO settings management. + + Args: + args (List[str]): A list of command line arguments for YOLO settings management. + + Examples: + >>> handle_yolo_settings(["reset"]) # Reset YOLO settings + >>> handle_yolo_settings(["default_cfg_path=yolo11n.yaml"]) # Update a specific setting + + Notes: + - If no arguments are provided, the function will display the current settings. + - The 'reset' command will delete the existing settings file and create new default settings. + - Other arguments are treated as key-value pairs to update specific settings. + - The function will check for alignment between the provided settings and the existing ones. + - After processing, the updated settings will be displayed. + - For more information on handling YOLO settings, visit: + https://docs.ultralytics.com/quickstart/#ultralytics-settings + """ + url = "https://docs.ultralytics.com/quickstart/#ultralytics-settings" # help URL + try: + if any(args): + if args[0] == "reset": + SETTINGS_FILE.unlink() # delete the settings file + SETTINGS.reset() # create new settings + LOGGER.info("Settings reset successfully") # inform the user that settings have been reset + else: # save a new setting + new = dict(parse_key_value_pair(a) for a in args) + check_dict_alignment(SETTINGS, new) + SETTINGS.update(new) + + print(SETTINGS) # print the current settings + LOGGER.info(f"💡 Learn more about Ultralytics Settings at {url}") + except Exception as e: + LOGGER.warning(f"WARNING ⚠️ settings error: '{e}'. Please see {url} for help.") + + +def handle_yolo_solutions(args: List[str]) -> None: + """ + Processes YOLO solutions arguments and runs the specified computer vision solutions pipeline. + + Args: + args (List[str]): Command-line arguments for configuring and running the Ultralytics YOLO + solutions: https://docs.ultralytics.com/solutions/, It can include solution name, source, + and other configuration parameters. + + Returns: + None: The function processes video frames and saves the output but doesn't return any value. + + Examples: + Run people counting solution with default settings: + >>> handle_yolo_solutions(["count"]) + + Run analytics with custom configuration: + >>> handle_yolo_solutions(["analytics", "conf=0.25", "source=path/to/video/file.mp4"]) + + Run inference with custom configuration, requires Streamlit version 1.29.0 or higher. 
+ >>> handle_yolo_solutions(["inference", "model=yolo11n.pt"]) + + Notes: + - Default configurations are merged from DEFAULT_SOL_DICT and DEFAULT_CFG_DICT + - Arguments can be provided in the format 'key=value' or as boolean flags + - Available solutions are defined in SOLUTION_MAP with their respective classes and methods + - If an invalid solution is provided, defaults to 'count' solution + - Output videos are saved in 'runs/solution/{solution_name}' directory + - For 'analytics' solution, frame numbers are tracked for generating analytical graphs + - Video processing can be interrupted by pressing 'q' + - Processes video frames sequentially and saves output in .avi format + - If no source is specified, downloads and uses a default sample video\ + - The inference solution will be launched using the 'streamlit run' command. + - The Streamlit app file is located in the Ultralytics package directory. + """ + full_args_dict = {**DEFAULT_SOL_DICT, **DEFAULT_CFG_DICT} # arguments dictionary + overrides = {} + + # check dictionary alignment + for arg in merge_equals_args(args): + arg = arg.lstrip("-").rstrip(",") + if "=" in arg: + try: + k, v = parse_key_value_pair(arg) + overrides[k] = v + except (NameError, SyntaxError, ValueError, AssertionError) as e: + check_dict_alignment(full_args_dict, {arg: ""}, e) + elif arg in full_args_dict and isinstance(full_args_dict.get(arg), bool): + overrides[arg] = True + check_dict_alignment(full_args_dict, overrides) # dict alignment + + # Get solution name + if args and args[0] in SOLUTION_MAP: + if args[0] != "help": + s_n = args.pop(0) # Extract the solution name directly + else: + LOGGER.info(SOLUTIONS_HELP_MSG) + else: + LOGGER.warning( + f"⚠️ No valid solution provided. Using default 'count'. Available: {', '.join(SOLUTION_MAP.keys())}" + ) + s_n = "count" # Default solution if none provided + + if args and args[0] == "help": # Add check for return if user call `yolo solutions help` + return + + if s_n == "inference": + checks.check_requirements("streamlit>=1.29.0") + LOGGER.info("💡 Loading Ultralytics live inference app...") + subprocess.run( + [ # Run subprocess with Streamlit custom argument + "streamlit", + "run", + str(ROOT / "solutions/streamlit_inference.py"), + "--server.headless", + "true", + overrides.pop("model", "yolo11n.pt"), + ] + ) + else: + cls, method = SOLUTION_MAP[s_n] # solution class name, method name and default source + + from ultralytics import solutions # import ultralytics solutions + + solution = getattr(solutions, cls)(IS_CLI=True, **overrides) # get solution class i.e ObjectCounter + process = getattr( + solution, method + ) # get specific function of class for processing i.e, count from ObjectCounter + + cap = cv2.VideoCapture(solution.CFG["source"]) # read the video file + + # extract width, height and fps of the video file, create save directory and initialize video writer + import os # for directory creation + from pathlib import Path + + from ultralytics.utils.files import increment_path # for output directory path update + + w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) + if s_n == "analytics": # analytical graphs follow fixed shape for output i.e w=1920, h=1080 + w, h = 1920, 1080 + save_dir = increment_path(Path("runs") / "solutions" / "exp", exist_ok=False) + save_dir.mkdir(parents=True, exist_ok=True) # create the output directory + vw = cv2.VideoWriter(os.path.join(save_dir, "solution.avi"), cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)) + + try: # 
Process video frames + f_n = 0 # frame number, required for analytical graphs + while cap.isOpened(): + success, frame = cap.read() + if not success: + break + frame = process(frame, f_n := f_n + 1) if s_n == "analytics" else process(frame) + vw.write(frame) + if cv2.waitKey(1) & 0xFF == ord("q"): + break + finally: + cap.release() + + +def parse_key_value_pair(pair: str = "key=value"): + """ + Parses a key-value pair string into separate key and value components. + + Args: + pair (str): A string containing a key-value pair in the format "key=value". + + Returns: + key (str): The parsed key. + value (str): The parsed value. + + Raises: + AssertionError: If the value is missing or empty. + + Examples: + >>> key, value = parse_key_value_pair("model=yolo11n.pt") + >>> print(f"Key: {key}, Value: {value}") + Key: model, Value: yolo11n.pt + + >>> key, value = parse_key_value_pair("epochs=100") + >>> print(f"Key: {key}, Value: {value}") + Key: epochs, Value: 100 + + Notes: + - The function splits the input string on the first '=' character. + - Leading and trailing whitespace is removed from both key and value. + - An assertion error is raised if the value is empty after stripping. + """ + k, v = pair.split("=", 1) # split on first '=' sign + k, v = k.strip(), v.strip() # remove spaces + assert v, f"missing '{k}' value" + return k, smart_value(v) + + +def smart_value(v): + """ + Converts a string representation of a value to its appropriate Python type. + + This function attempts to convert a given string into a Python object of the most appropriate type. It handles + conversions to None, bool, int, float, and other types that can be evaluated safely. + + Args: + v (str): The string representation of the value to be converted. + + Returns: + (Any): The converted value. The type can be None, bool, int, float, or the original string if no conversion + is applicable. + + Examples: + >>> smart_value("42") + 42 + >>> smart_value("3.14") + 3.14 + >>> smart_value("True") + True + >>> smart_value("None") + None + >>> smart_value("some_string") + 'some_string' + + Notes: + - The function uses a case-insensitive comparison for boolean and None values. + - For other types, it attempts to use Python's eval() function, which can be unsafe if used on untrusted input. + - If no conversion is possible, the original string is returned. + """ + v_lower = v.lower() + if v_lower == "none": + return None + elif v_lower == "true": + return True + elif v_lower == "false": + return False + else: + try: + return eval(v) + except Exception: + return v + + +def entrypoint(debug=""): + """ + Ultralytics entrypoint function for parsing and executing command-line arguments. + + This function serves as the main entry point for the Ultralytics CLI, parsing command-line arguments and + executing the corresponding tasks such as training, validation, prediction, exporting models, and more. + + Args: + debug (str): Space-separated string of command-line arguments for debugging purposes. 
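+            If empty, arguments are taken from the process command line (ARGV) instead, so calling
+            `entrypoint()` with no debug string behaves like invoking `yolo` from a shell.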
+ + Examples: + Train a detection model for 10 epochs with an initial learning_rate of 0.01: + >>> entrypoint("train data=coco8.yaml model=yolo11n.pt epochs=10 lr0=0.01") + + Predict a YouTube video using a pretrained segmentation model at image size 320: + >>> entrypoint("predict model=yolo11n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320") + + Validate a pretrained detection model at batch-size 1 and image size 640: + >>> entrypoint("val model=yolo11n.pt data=coco8.yaml batch=1 imgsz=640") + + Notes: + - If no arguments are passed, the function will display the usage help message. + - For a list of all available commands and their arguments, see the provided help messages and the + Ultralytics documentation at https://docs.ultralytics.com. + """ + args = (debug.split(" ") if debug else ARGV)[1:] + if not args: # no arguments passed + LOGGER.info(CLI_HELP_MSG) + return + + special = { + "help": lambda: LOGGER.info(CLI_HELP_MSG), + "checks": checks.collect_system_info, + "version": lambda: LOGGER.info(__version__), + "settings": lambda: handle_yolo_settings(args[1:]), + "cfg": lambda: yaml_print(DEFAULT_CFG_PATH), + "hub": lambda: handle_yolo_hub(args[1:]), + "login": lambda: handle_yolo_hub(args), + "logout": lambda: handle_yolo_hub(args), + "copy-cfg": copy_default_cfg, + "solutions": lambda: handle_yolo_solutions(args[1:]), + } + full_args_dict = {**DEFAULT_CFG_DICT, **{k: None for k in TASKS}, **{k: None for k in MODES}, **special} + + # Define common misuses of special commands, i.e. -h, -help, --help + special.update({k[0]: v for k, v in special.items()}) # singular + special.update({k[:-1]: v for k, v in special.items() if len(k) > 1 and k.endswith("s")}) # singular + special = {**special, **{f"-{k}": v for k, v in special.items()}, **{f"--{k}": v for k, v in special.items()}} + + overrides = {} # basic overrides, i.e. imgsz=320 + for a in merge_equals_args(args): # merge spaces around '=' sign + if a.startswith("--"): + LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require leading dashes '--', updating to '{a[2:]}'.") + a = a[2:] + if a.endswith(","): + LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.") + a = a[:-1] + if "=" in a: + try: + k, v = parse_key_value_pair(a) + if k == "cfg" and v is not None: # custom.yaml passed + LOGGER.info(f"Overriding {DEFAULT_CFG_PATH} with {v}") + overrides = {k: val for k, val in yaml_load(checks.check_yaml(v)).items() if k != "cfg"} + else: + overrides[k] = v + except (NameError, SyntaxError, ValueError, AssertionError) as e: + check_dict_alignment(full_args_dict, {a: ""}, e) + + elif a in TASKS: + overrides["task"] = a + elif a in MODES: + overrides["mode"] = a + elif a.lower() in special: + special[a.lower()]() + return + elif a in DEFAULT_CFG_DICT and isinstance(DEFAULT_CFG_DICT[a], bool): + overrides[a] = True # auto-True for default bool args, i.e. 'yolo show' sets show=True + elif a in DEFAULT_CFG_DICT: + raise SyntaxError( + f"'{colorstr('red', 'bold', a)}' is a valid YOLO argument but is missing an '=' sign " + f"to set its value, i.e. try '{a}={DEFAULT_CFG_DICT[a]}'\n{CLI_HELP_MSG}" + ) + else: + check_dict_alignment(full_args_dict, {a: ""}) + + # Check keys + check_dict_alignment(full_args_dict, overrides) + + # Mode + mode = overrides.get("mode") + if mode is None: + mode = DEFAULT_CFG.mode or "predict" + LOGGER.warning(f"WARNING ⚠️ 'mode' argument is missing. Valid modes are {MODES}. 
Using default 'mode={mode}'.") + elif mode not in MODES: + raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}") + + # Task + task = overrides.pop("task", None) + if task: + if task == "classify" and mode == "track": + raise ValueError( + f"❌ Classification doesn't support 'mode=track'. Valid modes for classification are" + f" {MODES - {'track'}}.\n{CLI_HELP_MSG}" + ) + elif task not in TASKS: + if task == "track": + LOGGER.warning( + "WARNING ⚠️ invalid 'task=track', setting 'task=detect' and 'mode=track'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}." + ) + task, mode = "detect", "track" + else: + raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}") + if "model" not in overrides: + overrides["model"] = TASK2MODEL[task] + + # Model + model = overrides.pop("model", DEFAULT_CFG.model) + if model is None: + model = "yolo11n.pt" + LOGGER.warning(f"WARNING ⚠️ 'model' argument is missing. Using default 'model={model}'.") + overrides["model"] = model + stem = Path(model).stem.lower() + if "rtdetr" in stem: # guess architecture + from ultralytics import RTDETR + + model = RTDETR(model) # no task argument + elif "fastsam" in stem: + from ultralytics import FastSAM + + model = FastSAM(model) + elif "sam_" in stem or "sam2_" in stem or "sam2.1_" in stem: + from ultralytics import SAM + + model = SAM(model) + else: + from ultralytics import YOLO + + model = YOLO(model, task=task) + if isinstance(overrides.get("pretrained"), str): + model.load(overrides["pretrained"]) + + # Task Update + if task != model.task: + if task: + LOGGER.warning( + f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. " + f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model." + ) + task = model.task + + # Mode + if mode in {"predict", "track"} and "source" not in overrides: + overrides["source"] = ( + "https://ultralytics.com/images/boats.jpg" if task == "obb" else DEFAULT_CFG.source or ASSETS + ) + LOGGER.warning(f"WARNING ⚠️ 'source' argument is missing. Using default 'source={overrides['source']}'.") + elif mode in {"train", "val"}: + if "data" not in overrides and "resume" not in overrides: + overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data) + LOGGER.warning(f"WARNING ⚠️ 'data' argument is missing. Using default 'data={overrides['data']}'.") + elif mode == "export": + if "format" not in overrides: + overrides["format"] = DEFAULT_CFG.format or "torchscript" + LOGGER.warning(f"WARNING ⚠️ 'format' argument is missing. Using default 'format={overrides['format']}'.") + + # Run command in python + getattr(model, mode)(**overrides) # default args from model + + # Show help + LOGGER.info(f"💡 Learn more at https://docs.ultralytics.com/modes/{mode}") + + # Recommend VS Code extension + if IS_VSCODE and SETTINGS.get("vscode_msg", True): + LOGGER.info(vscode_msg()) + + +# Special modes -------------------------------------------------------------------------------------------------------- +def copy_default_cfg(): + """ + Copies the default configuration file and creates a new one with '_copy' appended to its name. + + This function duplicates the existing default configuration file (DEFAULT_CFG_PATH) and saves it + with '_copy' appended to its name in the current working directory. It provides a convenient way + to create a custom configuration file based on the default settings. 
+ + Examples: + >>> copy_default_cfg() + # Output: default.yaml copied to /path/to/current/directory/default_copy.yaml + # Example YOLO command with this new custom cfg: + # yolo cfg='/path/to/current/directory/default_copy.yaml' imgsz=320 batch=8 + + Notes: + - The new configuration file is created in the current working directory. + - After copying, the function prints a message with the new file's location and an example + YOLO command demonstrating how to use the new configuration file. + - This function is useful for users who want to modify the default configuration without + altering the original file. + """ + new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace(".yaml", "_copy.yaml") + shutil.copy2(DEFAULT_CFG_PATH, new_file) + LOGGER.info( + f"{DEFAULT_CFG_PATH} copied to {new_file}\n" + f"Example YOLO command with this new custom cfg:\n yolo cfg='{new_file}' imgsz=320 batch=8" + ) + + +if __name__ == "__main__": + # Example: entrypoint(debug='yolo predict model=yolo11n.pt') + entrypoint(debug="") diff --git a/ultralytics/cfg/datasets/Argoverse.yaml b/ultralytics/cfg/datasets/Argoverse.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5e05023d7797ca6e7221c4226fcc3e552bd3edb3 --- /dev/null +++ b/ultralytics/cfg/datasets/Argoverse.yaml @@ -0,0 +1,75 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Argoverse-HD dataset (ring-front-center camera) https://www.cs.cmu.edu/~mengtial/proj/streaming/ by Argo AI +# Documentation: https://docs.ultralytics.com/datasets/detect/argoverse/ +# Example usage: yolo train data=Argoverse.yaml +# parent +# ├── ultralytics +# └── datasets +# └── Argoverse ← downloads here (31.5 GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/Argoverse # dataset root dir +train: Argoverse-1.1/images/train/ # train images (relative to 'path') 39384 images +val: Argoverse-1.1/images/val/ # val images (relative to 'path') 15062 images +test: Argoverse-1.1/images/test/ # test images (optional) https://eval.ai/web/challenges/challenge-page/800/overview + +# Classes +names: + 0: person + 1: bicycle + 2: car + 3: motorcycle + 4: bus + 5: truck + 6: traffic_light + 7: stop_sign + +# Download script/URL (optional) --------------------------------------------------------------------------------------- +download: | + import json + from tqdm import tqdm + from ultralytics.utils.downloads import download + from pathlib import Path + + def argoverse2yolo(set): + labels = {} + a = json.load(open(set, "rb")) + for annot in tqdm(a['annotations'], desc=f"Converting {set} to YOLOv5 format..."): + img_id = annot['image_id'] + img_name = a['images'][img_id]['name'] + img_label_name = f'{img_name[:-3]}txt' + + cls = annot['category_id'] # instance class id + x_center, y_center, width, height = annot['bbox'] + x_center = (x_center + width / 2) / 1920.0 # offset and scale + y_center = (y_center + height / 2) / 1200.0 # offset and scale + width /= 1920.0 # scale + height /= 1200.0 # scale + + img_dir = set.parents[2] / 'Argoverse-1.1' / 'labels' / a['seq_dirs'][a['images'][annot['image_id']]['sid']] + if not img_dir.exists(): + img_dir.mkdir(parents=True, exist_ok=True) + + k = str(img_dir / img_label_name) + if k not in labels: + labels[k] = [] + labels[k].append(f"{cls} {x_center} {y_center} {width} {height}\n") + + for k in labels: + with open(k, "w") as f: + f.writelines(labels[k]) + + + # Download 'https://argoverse-hd.s3.us-east-2.amazonaws.com/Argoverse-HD-Full.zip' (deprecated S3 link) + dir = Path(yaml['path']) # dataset root dir + urls = ['https://drive.google.com/file/d/1st9qW3BeIwQsnR0t8mRpvbsSWIo16ACi/view?usp=drive_link'] + print("\n\nWARNING: Argoverse dataset MUST be downloaded manually, autodownload will NOT work.") + print(f"WARNING: Manually download Argoverse dataset '{urls[0]}' to '{dir}' and re-run your command.\n\n") + # download(urls, dir=dir) + + # Convert + annotations_dir = 'Argoverse-HD/annotations/' + (dir / 'Argoverse-1.1' / 'tracking').rename(dir / 'Argoverse-1.1' / 'images') # rename 'tracking' to 'images' + for d in "train.json", "val.json": + argoverse2yolo(dir / annotations_dir / d) # convert Argoverse annotations to YOLO labels diff --git a/ultralytics/cfg/datasets/DOTAv1.5.yaml b/ultralytics/cfg/datasets/DOTAv1.5.yaml new file mode 100644 index 0000000000000000000000000000000000000000..26c73808d7b253dea0b4555394bae246a815076d --- /dev/null +++ b/ultralytics/cfg/datasets/DOTAv1.5.yaml @@ -0,0 +1,37 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# DOTA 1.5 dataset https://captain-whu.github.io/DOTA/index.html for object detection in aerial images by Wuhan University +# Documentation: https://docs.ultralytics.com/datasets/obb/dota-v2/ +# Example usage: yolo train model=yolov8n-obb.pt data=DOTAv1.5.yaml +# parent +# ├── ultralytics +# └── datasets +# └── dota1.5 ← downloads here (2GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/DOTAv1.5 # dataset root dir +train: images/train # train images (relative to 'path') 1411 images +val: images/val # val images (relative to 'path') 458 images +test: images/test # test images (optional) 937 images + +# Classes for DOTA 1.5 +names: + 0: plane + 1: ship + 2: storage tank + 3: baseball diamond + 4: tennis court + 5: basketball court + 6: ground track field + 7: harbor + 8: bridge + 9: large vehicle + 10: small vehicle + 11: helicopter + 12: roundabout + 13: soccer ball field + 14: swimming pool + 15: container crane + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/DOTAv1.5.zip diff --git a/ultralytics/cfg/datasets/DOTAv1.yaml b/ultralytics/cfg/datasets/DOTAv1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5e71d2188d50d3efc80562ebb3fa7b65e07d2b5f --- /dev/null +++ b/ultralytics/cfg/datasets/DOTAv1.yaml @@ -0,0 +1,36 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# DOTA 1.0 dataset https://captain-whu.github.io/DOTA/index.html for object detection in aerial images by Wuhan University +# Documentation: https://docs.ultralytics.com/datasets/obb/dota-v2/ +# Example usage: yolo train model=yolov8n-obb.pt data=DOTAv1.yaml +# parent +# ├── ultralytics +# └── datasets +# └── dota1 ← downloads here (2GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/DOTAv1 # dataset root dir +train: images/train # train images (relative to 'path') 1411 images +val: images/val # val images (relative to 'path') 458 images +test: images/test # test images (optional) 937 images + +# Classes for DOTA 1.0 +names: + 0: plane + 1: ship + 2: storage tank + 3: baseball diamond + 4: tennis court + 5: basketball court + 6: ground track field + 7: harbor + 8: bridge + 9: large vehicle + 10: small vehicle + 11: helicopter + 12: roundabout + 13: soccer ball field + 14: swimming pool + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/DOTAv1.zip diff --git a/ultralytics/cfg/datasets/GlobalWheat2020.yaml b/ultralytics/cfg/datasets/GlobalWheat2020.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9dff73d7cd2c9be88e7d936208f649dc306cbdf9 --- /dev/null +++ b/ultralytics/cfg/datasets/GlobalWheat2020.yaml @@ -0,0 +1,54 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Global Wheat 2020 dataset https://www.global-wheat.com/ by University of Saskatchewan +# Documentation: https://docs.ultralytics.com/datasets/detect/globalwheat2020/ +# Example usage: yolo train data=GlobalWheat2020.yaml +# parent +# ├── ultralytics +# └── datasets +# └── GlobalWheat2020 ← downloads here (7.0 GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/GlobalWheat2020 # dataset root dir +train: # train images (relative to 'path') 3422 images + - images/arvalis_1 + - images/arvalis_2 + - images/arvalis_3 + - images/ethz_1 + - images/rres_1 + - images/inrae_1 + - images/usask_1 +val: # val images (relative to 'path') 748 images (WARNING: train set contains ethz_1) + - images/ethz_1 +test: # test images (optional) 1276 images + - images/utokyo_1 + - images/utokyo_2 + - images/nau_1 + - images/uq_1 + +# Classes +names: + 0: wheat_head + +# Download script/URL (optional) --------------------------------------------------------------------------------------- +download: | + from ultralytics.utils.downloads import download + from pathlib import Path + + # Download + dir = Path(yaml['path']) # dataset root dir + urls = ['https://zenodo.org/record/4298502/files/global-wheat-codalab-official.zip', + 'https://github.com/ultralytics/assets/releases/download/v0.0.0/GlobalWheat2020_labels.zip'] + download(urls, dir=dir) + + # Make Directories + for p in 'annotations', 'images', 'labels': + (dir / p).mkdir(parents=True, exist_ok=True) + + # Move + for p in 'arvalis_1', 'arvalis_2', 'arvalis_3', 'ethz_1', 'rres_1', 'inrae_1', 'usask_1', \ + 'utokyo_1', 'utokyo_2', 'nau_1', 'uq_1': + (dir / 'global-wheat-codalab-official' / p).rename(dir / 'images' / p) # move to /images + f = (dir / 'global-wheat-codalab-official' / p).with_suffix('.json') # json file + if f.exists(): + f.rename((dir / 'annotations' / p).with_suffix('.json')) # move to /annotations diff --git a/ultralytics/cfg/datasets/ImageNet.yaml b/ultralytics/cfg/datasets/ImageNet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..92e398a8fa846a8921e11078b277c748967f1ec6 --- /dev/null +++ b/ultralytics/cfg/datasets/ImageNet.yaml @@ -0,0 +1,2025 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# ImageNet-1k dataset https://www.image-net.org/index.php by Stanford University +# Simplified class names from https://github.com/anishathalye/imagenet-simple-labels +# Documentation: https://docs.ultralytics.com/datasets/classify/imagenet/ +# Example usage: yolo train task=classify data=imagenet +# parent +# ├── ultralytics +# └── datasets +# └── imagenet ← downloads here (144 GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/imagenet # dataset root dir +train: train # train images (relative to 'path') 1281167 images +val: val # val images (relative to 'path') 50000 images +test: # test images (optional) + +# Classes +names: + 0: tench + 1: goldfish + 2: great white shark + 3: tiger shark + 4: hammerhead shark + 5: electric ray + 6: stingray + 7: cock + 8: hen + 9: ostrich + 10: brambling + 11: goldfinch + 12: house finch + 13: junco + 14: indigo bunting + 15: American robin + 16: bulbul + 17: jay + 18: magpie + 19: chickadee + 20: American dipper + 21: kite + 22: bald eagle + 23: vulture + 24: great grey owl + 25: fire salamander + 26: smooth newt + 27: newt + 28: spotted salamander + 29: axolotl + 30: American bullfrog + 31: tree frog + 32: tailed frog + 33: loggerhead sea turtle + 34: leatherback sea turtle + 35: mud turtle + 36: terrapin + 37: box turtle + 38: banded gecko + 39: green iguana + 40: Carolina anole + 41: desert grassland whiptail lizard + 42: agama + 43: frilled-necked lizard + 44: alligator lizard + 45: Gila monster + 46: European green lizard + 47: chameleon + 48: Komodo dragon + 49: Nile crocodile + 50: American alligator + 51: triceratops + 52: worm snake + 53: ring-necked snake + 54: eastern hog-nosed snake + 55: smooth green snake + 56: kingsnake + 57: garter snake + 58: water snake + 59: vine snake + 60: night snake + 61: boa constrictor + 62: African rock python + 63: Indian cobra + 64: green mamba + 65: sea snake + 66: Saharan horned viper + 67: eastern diamondback rattlesnake + 68: sidewinder + 69: trilobite + 70: harvestman + 71: scorpion + 72: yellow garden spider + 73: barn spider + 74: European garden spider + 75: southern black widow + 76: tarantula + 77: wolf spider + 78: tick + 79: centipede + 80: black grouse + 81: ptarmigan + 82: ruffed grouse + 83: prairie grouse + 84: peacock + 85: quail + 86: partridge + 87: grey parrot + 88: macaw + 89: sulphur-crested cockatoo + 90: lorikeet + 91: coucal + 92: bee eater + 93: hornbill + 94: hummingbird + 95: jacamar + 96: toucan + 97: duck + 98: red-breasted merganser + 99: goose + 100: black swan + 101: tusker + 102: echidna + 103: platypus + 104: wallaby + 105: koala + 106: wombat + 107: jellyfish + 108: sea anemone + 109: brain coral + 110: flatworm + 111: nematode + 112: conch + 113: snail + 114: slug + 115: sea slug + 116: chiton + 117: chambered nautilus + 118: Dungeness crab + 119: rock crab + 120: fiddler crab + 121: red king crab + 122: American lobster + 123: spiny lobster + 124: crayfish + 125: hermit crab + 126: isopod + 127: white stork + 128: black stork + 129: spoonbill + 130: flamingo + 131: little blue heron + 132: great egret + 133: bittern + 134: crane (bird) + 135: limpkin + 136: common gallinule + 137: American coot + 138: bustard + 139: ruddy turnstone + 140: dunlin + 141: common redshank + 142: dowitcher + 143: oystercatcher + 144: pelican + 145: king penguin + 146: albatross + 147: grey whale + 148: killer whale + 149: dugong + 150: sea lion + 151: Chihuahua + 152: Japanese Chin + 153: Maltese + 154: Pekingese + 155: Shih Tzu + 156: King Charles Spaniel + 157: Papillon + 158: toy terrier + 159: Rhodesian Ridgeback + 160: Afghan Hound + 161: Basset Hound + 162: Beagle + 163: Bloodhound + 164: Bluetick Coonhound + 165: Black and Tan Coonhound + 166: Treeing Walker Coonhound + 167: English foxhound + 168: Redbone Coonhound + 169: borzoi + 170: Irish Wolfhound + 171: Italian Greyhound + 172: Whippet + 173: Ibizan Hound + 174: Norwegian Elkhound + 175: Otterhound + 176: Saluki + 177: Scottish 
Deerhound + 178: Weimaraner + 179: Staffordshire Bull Terrier + 180: American Staffordshire Terrier + 181: Bedlington Terrier + 182: Border Terrier + 183: Kerry Blue Terrier + 184: Irish Terrier + 185: Norfolk Terrier + 186: Norwich Terrier + 187: Yorkshire Terrier + 188: Wire Fox Terrier + 189: Lakeland Terrier + 190: Sealyham Terrier + 191: Airedale Terrier + 192: Cairn Terrier + 193: Australian Terrier + 194: Dandie Dinmont Terrier + 195: Boston Terrier + 196: Miniature Schnauzer + 197: Giant Schnauzer + 198: Standard Schnauzer + 199: Scottish Terrier + 200: Tibetan Terrier + 201: Australian Silky Terrier + 202: Soft-coated Wheaten Terrier + 203: West Highland White Terrier + 204: Lhasa Apso + 205: Flat-Coated Retriever + 206: Curly-coated Retriever + 207: Golden Retriever + 208: Labrador Retriever + 209: Chesapeake Bay Retriever + 210: German Shorthaired Pointer + 211: Vizsla + 212: English Setter + 213: Irish Setter + 214: Gordon Setter + 215: Brittany + 216: Clumber Spaniel + 217: English Springer Spaniel + 218: Welsh Springer Spaniel + 219: Cocker Spaniels + 220: Sussex Spaniel + 221: Irish Water Spaniel + 222: Kuvasz + 223: Schipperke + 224: Groenendael + 225: Malinois + 226: Briard + 227: Australian Kelpie + 228: Komondor + 229: Old English Sheepdog + 230: Shetland Sheepdog + 231: collie + 232: Border Collie + 233: Bouvier des Flandres + 234: Rottweiler + 235: German Shepherd Dog + 236: Dobermann + 237: Miniature Pinscher + 238: Greater Swiss Mountain Dog + 239: Bernese Mountain Dog + 240: Appenzeller Sennenhund + 241: Entlebucher Sennenhund + 242: Boxer + 243: Bullmastiff + 244: Tibetan Mastiff + 245: French Bulldog + 246: Great Dane + 247: St. Bernard + 248: husky + 249: Alaskan Malamute + 250: Siberian Husky + 251: Dalmatian + 252: Affenpinscher + 253: Basenji + 254: pug + 255: Leonberger + 256: Newfoundland + 257: Pyrenean Mountain Dog + 258: Samoyed + 259: Pomeranian + 260: Chow Chow + 261: Keeshond + 262: Griffon Bruxellois + 263: Pembroke Welsh Corgi + 264: Cardigan Welsh Corgi + 265: Toy Poodle + 266: Miniature Poodle + 267: Standard Poodle + 268: Mexican hairless dog + 269: grey wolf + 270: Alaskan tundra wolf + 271: red wolf + 272: coyote + 273: dingo + 274: dhole + 275: African wild dog + 276: hyena + 277: red fox + 278: kit fox + 279: Arctic fox + 280: grey fox + 281: tabby cat + 282: tiger cat + 283: Persian cat + 284: Siamese cat + 285: Egyptian Mau + 286: cougar + 287: lynx + 288: leopard + 289: snow leopard + 290: jaguar + 291: lion + 292: tiger + 293: cheetah + 294: brown bear + 295: American black bear + 296: polar bear + 297: sloth bear + 298: mongoose + 299: meerkat + 300: tiger beetle + 301: ladybug + 302: ground beetle + 303: longhorn beetle + 304: leaf beetle + 305: dung beetle + 306: rhinoceros beetle + 307: weevil + 308: fly + 309: bee + 310: ant + 311: grasshopper + 312: cricket + 313: stick insect + 314: cockroach + 315: mantis + 316: cicada + 317: leafhopper + 318: lacewing + 319: dragonfly + 320: damselfly + 321: red admiral + 322: ringlet + 323: monarch butterfly + 324: small white + 325: sulphur butterfly + 326: gossamer-winged butterfly + 327: starfish + 328: sea urchin + 329: sea cucumber + 330: cottontail rabbit + 331: hare + 332: Angora rabbit + 333: hamster + 334: porcupine + 335: fox squirrel + 336: marmot + 337: beaver + 338: guinea pig + 339: common sorrel + 340: zebra + 341: pig + 342: wild boar + 343: warthog + 344: hippopotamus + 345: ox + 346: water buffalo + 347: bison + 348: ram + 349: bighorn sheep + 350: Alpine ibex + 351: hartebeest 
+ 352: impala + 353: gazelle + 354: dromedary + 355: llama + 356: weasel + 357: mink + 358: European polecat + 359: black-footed ferret + 360: otter + 361: skunk + 362: badger + 363: armadillo + 364: three-toed sloth + 365: orangutan + 366: gorilla + 367: chimpanzee + 368: gibbon + 369: siamang + 370: guenon + 371: patas monkey + 372: baboon + 373: macaque + 374: langur + 375: black-and-white colobus + 376: proboscis monkey + 377: marmoset + 378: white-headed capuchin + 379: howler monkey + 380: titi + 381: Geoffroy's spider monkey + 382: common squirrel monkey + 383: ring-tailed lemur + 384: indri + 385: Asian elephant + 386: African bush elephant + 387: red panda + 388: giant panda + 389: snoek + 390: eel + 391: coho salmon + 392: rock beauty + 393: clownfish + 394: sturgeon + 395: garfish + 396: lionfish + 397: pufferfish + 398: abacus + 399: abaya + 400: academic gown + 401: accordion + 402: acoustic guitar + 403: aircraft carrier + 404: airliner + 405: airship + 406: altar + 407: ambulance + 408: amphibious vehicle + 409: analog clock + 410: apiary + 411: apron + 412: waste container + 413: assault rifle + 414: backpack + 415: bakery + 416: balance beam + 417: balloon + 418: ballpoint pen + 419: Band-Aid + 420: banjo + 421: baluster + 422: barbell + 423: barber chair + 424: barbershop + 425: barn + 426: barometer + 427: barrel + 428: wheelbarrow + 429: baseball + 430: basketball + 431: bassinet + 432: bassoon + 433: swimming cap + 434: bath towel + 435: bathtub + 436: station wagon + 437: lighthouse + 438: beaker + 439: military cap + 440: beer bottle + 441: beer glass + 442: bell-cot + 443: bib + 444: tandem bicycle + 445: bikini + 446: ring binder + 447: binoculars + 448: birdhouse + 449: boathouse + 450: bobsleigh + 451: bolo tie + 452: poke bonnet + 453: bookcase + 454: bookstore + 455: bottle cap + 456: bow + 457: bow tie + 458: brass + 459: bra + 460: breakwater + 461: breastplate + 462: broom + 463: bucket + 464: buckle + 465: bulletproof vest + 466: high-speed train + 467: butcher shop + 468: taxicab + 469: cauldron + 470: candle + 471: cannon + 472: canoe + 473: can opener + 474: cardigan + 475: car mirror + 476: carousel + 477: tool kit + 478: carton + 479: car wheel + 480: automated teller machine + 481: cassette + 482: cassette player + 483: castle + 484: catamaran + 485: CD player + 486: cello + 487: mobile phone + 488: chain + 489: chain-link fence + 490: chain mail + 491: chainsaw + 492: chest + 493: chiffonier + 494: chime + 495: china cabinet + 496: Christmas stocking + 497: church + 498: movie theater + 499: cleaver + 500: cliff dwelling + 501: cloak + 502: clogs + 503: cocktail shaker + 504: coffee mug + 505: coffeemaker + 506: coil + 507: combination lock + 508: computer keyboard + 509: confectionery store + 510: container ship + 511: convertible + 512: corkscrew + 513: cornet + 514: cowboy boot + 515: cowboy hat + 516: cradle + 517: crane (machine) + 518: crash helmet + 519: crate + 520: infant bed + 521: Crock Pot + 522: croquet ball + 523: crutch + 524: cuirass + 525: dam + 526: desk + 527: desktop computer + 528: rotary dial telephone + 529: diaper + 530: digital clock + 531: digital watch + 532: dining table + 533: dishcloth + 534: dishwasher + 535: disc brake + 536: dock + 537: dog sled + 538: dome + 539: doormat + 540: drilling rig + 541: drum + 542: drumstick + 543: dumbbell + 544: Dutch oven + 545: electric fan + 546: electric guitar + 547: electric locomotive + 548: entertainment center + 549: envelope + 550: espresso machine + 551: face powder + 552: 
feather boa + 553: filing cabinet + 554: fireboat + 555: fire engine + 556: fire screen sheet + 557: flagpole + 558: flute + 559: folding chair + 560: football helmet + 561: forklift + 562: fountain + 563: fountain pen + 564: four-poster bed + 565: freight car + 566: French horn + 567: frying pan + 568: fur coat + 569: garbage truck + 570: gas mask + 571: gas pump + 572: goblet + 573: go-kart + 574: golf ball + 575: golf cart + 576: gondola + 577: gong + 578: gown + 579: grand piano + 580: greenhouse + 581: grille + 582: grocery store + 583: guillotine + 584: barrette + 585: hair spray + 586: half-track + 587: hammer + 588: hamper + 589: hair dryer + 590: hand-held computer + 591: handkerchief + 592: hard disk drive + 593: harmonica + 594: harp + 595: harvester + 596: hatchet + 597: holster + 598: home theater + 599: honeycomb + 600: hook + 601: hoop skirt + 602: horizontal bar + 603: horse-drawn vehicle + 604: hourglass + 605: iPod + 606: clothes iron + 607: jack-o'-lantern + 608: jeans + 609: jeep + 610: T-shirt + 611: jigsaw puzzle + 612: pulled rickshaw + 613: joystick + 614: kimono + 615: knee pad + 616: knot + 617: lab coat + 618: ladle + 619: lampshade + 620: laptop computer + 621: lawn mower + 622: lens cap + 623: paper knife + 624: library + 625: lifeboat + 626: lighter + 627: limousine + 628: ocean liner + 629: lipstick + 630: slip-on shoe + 631: lotion + 632: speaker + 633: loupe + 634: sawmill + 635: magnetic compass + 636: mail bag + 637: mailbox + 638: tights + 639: tank suit + 640: manhole cover + 641: maraca + 642: marimba + 643: mask + 644: match + 645: maypole + 646: maze + 647: measuring cup + 648: medicine chest + 649: megalith + 650: microphone + 651: microwave oven + 652: military uniform + 653: milk can + 654: minibus + 655: miniskirt + 656: minivan + 657: missile + 658: mitten + 659: mixing bowl + 660: mobile home + 661: Model T + 662: modem + 663: monastery + 664: monitor + 665: moped + 666: mortar + 667: square academic cap + 668: mosque + 669: mosquito net + 670: scooter + 671: mountain bike + 672: tent + 673: computer mouse + 674: mousetrap + 675: moving van + 676: muzzle + 677: nail + 678: neck brace + 679: necklace + 680: nipple + 681: notebook computer + 682: obelisk + 683: oboe + 684: ocarina + 685: odometer + 686: oil filter + 687: organ + 688: oscilloscope + 689: overskirt + 690: bullock cart + 691: oxygen mask + 692: packet + 693: paddle + 694: paddle wheel + 695: padlock + 696: paintbrush + 697: pajamas + 698: palace + 699: pan flute + 700: paper towel + 701: parachute + 702: parallel bars + 703: park bench + 704: parking meter + 705: passenger car + 706: patio + 707: payphone + 708: pedestal + 709: pencil case + 710: pencil sharpener + 711: perfume + 712: Petri dish + 713: photocopier + 714: plectrum + 715: Pickelhaube + 716: picket fence + 717: pickup truck + 718: pier + 719: piggy bank + 720: pill bottle + 721: pillow + 722: ping-pong ball + 723: pinwheel + 724: pirate ship + 725: pitcher + 726: hand plane + 727: planetarium + 728: plastic bag + 729: plate rack + 730: plow + 731: plunger + 732: Polaroid camera + 733: pole + 734: police van + 735: poncho + 736: billiard table + 737: soda bottle + 738: pot + 739: potter's wheel + 740: power drill + 741: prayer rug + 742: printer + 743: prison + 744: projectile + 745: projector + 746: hockey puck + 747: punching bag + 748: purse + 749: quill + 750: quilt + 751: race car + 752: racket + 753: radiator + 754: radio + 755: radio telescope + 756: rain barrel + 757: recreational vehicle + 758: reel + 759: 
reflex camera + 760: refrigerator + 761: remote control + 762: restaurant + 763: revolver + 764: rifle + 765: rocking chair + 766: rotisserie + 767: eraser + 768: rugby ball + 769: ruler + 770: running shoe + 771: safe + 772: safety pin + 773: salt shaker + 774: sandal + 775: sarong + 776: saxophone + 777: scabbard + 778: weighing scale + 779: school bus + 780: schooner + 781: scoreboard + 782: CRT screen + 783: screw + 784: screwdriver + 785: seat belt + 786: sewing machine + 787: shield + 788: shoe store + 789: shoji + 790: shopping basket + 791: shopping cart + 792: shovel + 793: shower cap + 794: shower curtain + 795: ski + 796: ski mask + 797: sleeping bag + 798: slide rule + 799: sliding door + 800: slot machine + 801: snorkel + 802: snowmobile + 803: snowplow + 804: soap dispenser + 805: soccer ball + 806: sock + 807: solar thermal collector + 808: sombrero + 809: soup bowl + 810: space bar + 811: space heater + 812: space shuttle + 813: spatula + 814: motorboat + 815: spider web + 816: spindle + 817: sports car + 818: spotlight + 819: stage + 820: steam locomotive + 821: through arch bridge + 822: steel drum + 823: stethoscope + 824: scarf + 825: stone wall + 826: stopwatch + 827: stove + 828: strainer + 829: tram + 830: stretcher + 831: couch + 832: stupa + 833: submarine + 834: suit + 835: sundial + 836: sunglass + 837: sunglasses + 838: sunscreen + 839: suspension bridge + 840: mop + 841: sweatshirt + 842: swimsuit + 843: swing + 844: switch + 845: syringe + 846: table lamp + 847: tank + 848: tape player + 849: teapot + 850: teddy bear + 851: television + 852: tennis ball + 853: thatched roof + 854: front curtain + 855: thimble + 856: threshing machine + 857: throne + 858: tile roof + 859: toaster + 860: tobacco shop + 861: toilet seat + 862: torch + 863: totem pole + 864: tow truck + 865: toy store + 866: tractor + 867: semi-trailer truck + 868: tray + 869: trench coat + 870: tricycle + 871: trimaran + 872: tripod + 873: triumphal arch + 874: trolleybus + 875: trombone + 876: tub + 877: turnstile + 878: typewriter keyboard + 879: umbrella + 880: unicycle + 881: upright piano + 882: vacuum cleaner + 883: vase + 884: vault + 885: velvet + 886: vending machine + 887: vestment + 888: viaduct + 889: violin + 890: volleyball + 891: waffle iron + 892: wall clock + 893: wallet + 894: wardrobe + 895: military aircraft + 896: sink + 897: washing machine + 898: water bottle + 899: water jug + 900: water tower + 901: whiskey jug + 902: whistle + 903: wig + 904: window screen + 905: window shade + 906: Windsor tie + 907: wine bottle + 908: wing + 909: wok + 910: wooden spoon + 911: wool + 912: split-rail fence + 913: shipwreck + 914: yawl + 915: yurt + 916: website + 917: comic book + 918: crossword + 919: traffic sign + 920: traffic light + 921: dust jacket + 922: menu + 923: plate + 924: guacamole + 925: consomme + 926: hot pot + 927: trifle + 928: ice cream + 929: ice pop + 930: baguette + 931: bagel + 932: pretzel + 933: cheeseburger + 934: hot dog + 935: mashed potato + 936: cabbage + 937: broccoli + 938: cauliflower + 939: zucchini + 940: spaghetti squash + 941: acorn squash + 942: butternut squash + 943: cucumber + 944: artichoke + 945: bell pepper + 946: cardoon + 947: mushroom + 948: Granny Smith + 949: strawberry + 950: orange + 951: lemon + 952: fig + 953: pineapple + 954: banana + 955: jackfruit + 956: custard apple + 957: pomegranate + 958: hay + 959: carbonara + 960: chocolate syrup + 961: dough + 962: meatloaf + 963: pizza + 964: pot pie + 965: burrito + 966: red wine + 967: 
espresso + 968: cup + 969: eggnog + 970: alp + 971: bubble + 972: cliff + 973: coral reef + 974: geyser + 975: lakeshore + 976: promontory + 977: shoal + 978: seashore + 979: valley + 980: volcano + 981: baseball player + 982: bridegroom + 983: scuba diver + 984: rapeseed + 985: daisy + 986: yellow lady's slipper + 987: corn + 988: acorn + 989: rose hip + 990: horse chestnut seed + 991: coral fungus + 992: agaric + 993: gyromitra + 994: stinkhorn mushroom + 995: earth star + 996: hen-of-the-woods + 997: bolete + 998: ear + 999: toilet paper + +# Imagenet class codes to human-readable names +map: + n01440764: tench + n01443537: goldfish + n01484850: great_white_shark + n01491361: tiger_shark + n01494475: hammerhead + n01496331: electric_ray + n01498041: stingray + n01514668: cock + n01514859: hen + n01518878: ostrich + n01530575: brambling + n01531178: goldfinch + n01532829: house_finch + n01534433: junco + n01537544: indigo_bunting + n01558993: robin + n01560419: bulbul + n01580077: jay + n01582220: magpie + n01592084: chickadee + n01601694: water_ouzel + n01608432: kite + n01614925: bald_eagle + n01616318: vulture + n01622779: great_grey_owl + n01629819: European_fire_salamander + n01630670: common_newt + n01631663: eft + n01632458: spotted_salamander + n01632777: axolotl + n01641577: bullfrog + n01644373: tree_frog + n01644900: tailed_frog + n01664065: loggerhead + n01665541: leatherback_turtle + n01667114: mud_turtle + n01667778: terrapin + n01669191: box_turtle + n01675722: banded_gecko + n01677366: common_iguana + n01682714: American_chameleon + n01685808: whiptail + n01687978: agama + n01688243: frilled_lizard + n01689811: alligator_lizard + n01692333: Gila_monster + n01693334: green_lizard + n01694178: African_chameleon + n01695060: Komodo_dragon + n01697457: African_crocodile + n01698640: American_alligator + n01704323: triceratops + n01728572: thunder_snake + n01728920: ringneck_snake + n01729322: hognose_snake + n01729977: green_snake + n01734418: king_snake + n01735189: garter_snake + n01737021: water_snake + n01739381: vine_snake + n01740131: night_snake + n01742172: boa_constrictor + n01744401: rock_python + n01748264: Indian_cobra + n01749939: green_mamba + n01751748: sea_snake + n01753488: horned_viper + n01755581: diamondback + n01756291: sidewinder + n01768244: trilobite + n01770081: harvestman + n01770393: scorpion + n01773157: black_and_gold_garden_spider + n01773549: barn_spider + n01773797: garden_spider + n01774384: black_widow + n01774750: tarantula + n01775062: wolf_spider + n01776313: tick + n01784675: centipede + n01795545: black_grouse + n01796340: ptarmigan + n01797886: ruffed_grouse + n01798484: prairie_chicken + n01806143: peacock + n01806567: quail + n01807496: partridge + n01817953: African_grey + n01818515: macaw + n01819313: sulphur-crested_cockatoo + n01820546: lorikeet + n01824575: coucal + n01828970: bee_eater + n01829413: hornbill + n01833805: hummingbird + n01843065: jacamar + n01843383: toucan + n01847000: drake + n01855032: red-breasted_merganser + n01855672: goose + n01860187: black_swan + n01871265: tusker + n01872401: echidna + n01873310: platypus + n01877812: wallaby + n01882714: koala + n01883070: wombat + n01910747: jellyfish + n01914609: sea_anemone + n01917289: brain_coral + n01924916: flatworm + n01930112: nematode + n01943899: conch + n01944390: snail + n01945685: slug + n01950731: sea_slug + n01955084: chiton + n01968897: chambered_nautilus + n01978287: Dungeness_crab + n01978455: rock_crab + n01980166: fiddler_crab + n01981276: king_crab 
+ n01983481: American_lobster + n01984695: spiny_lobster + n01985128: crayfish + n01986214: hermit_crab + n01990800: isopod + n02002556: white_stork + n02002724: black_stork + n02006656: spoonbill + n02007558: flamingo + n02009229: little_blue_heron + n02009912: American_egret + n02011460: bittern + n02012849: crane_(bird) + n02013706: limpkin + n02017213: European_gallinule + n02018207: American_coot + n02018795: bustard + n02025239: ruddy_turnstone + n02027492: red-backed_sandpiper + n02028035: redshank + n02033041: dowitcher + n02037110: oystercatcher + n02051845: pelican + n02056570: king_penguin + n02058221: albatross + n02066245: grey_whale + n02071294: killer_whale + n02074367: dugong + n02077923: sea_lion + n02085620: Chihuahua + n02085782: Japanese_spaniel + n02085936: Maltese_dog + n02086079: Pekinese + n02086240: Shih-Tzu + n02086646: Blenheim_spaniel + n02086910: papillon + n02087046: toy_terrier + n02087394: Rhodesian_ridgeback + n02088094: Afghan_hound + n02088238: basset + n02088364: beagle + n02088466: bloodhound + n02088632: bluetick + n02089078: black-and-tan_coonhound + n02089867: Walker_hound + n02089973: English_foxhound + n02090379: redbone + n02090622: borzoi + n02090721: Irish_wolfhound + n02091032: Italian_greyhound + n02091134: whippet + n02091244: Ibizan_hound + n02091467: Norwegian_elkhound + n02091635: otterhound + n02091831: Saluki + n02092002: Scottish_deerhound + n02092339: Weimaraner + n02093256: Staffordshire_bullterrier + n02093428: American_Staffordshire_terrier + n02093647: Bedlington_terrier + n02093754: Border_terrier + n02093859: Kerry_blue_terrier + n02093991: Irish_terrier + n02094114: Norfolk_terrier + n02094258: Norwich_terrier + n02094433: Yorkshire_terrier + n02095314: wire-haired_fox_terrier + n02095570: Lakeland_terrier + n02095889: Sealyham_terrier + n02096051: Airedale + n02096177: cairn + n02096294: Australian_terrier + n02096437: Dandie_Dinmont + n02096585: Boston_bull + n02097047: miniature_schnauzer + n02097130: giant_schnauzer + n02097209: standard_schnauzer + n02097298: Scotch_terrier + n02097474: Tibetan_terrier + n02097658: silky_terrier + n02098105: soft-coated_wheaten_terrier + n02098286: West_Highland_white_terrier + n02098413: Lhasa + n02099267: flat-coated_retriever + n02099429: curly-coated_retriever + n02099601: golden_retriever + n02099712: Labrador_retriever + n02099849: Chesapeake_Bay_retriever + n02100236: German_short-haired_pointer + n02100583: vizsla + n02100735: English_setter + n02100877: Irish_setter + n02101006: Gordon_setter + n02101388: Brittany_spaniel + n02101556: clumber + n02102040: English_springer + n02102177: Welsh_springer_spaniel + n02102318: cocker_spaniel + n02102480: Sussex_spaniel + n02102973: Irish_water_spaniel + n02104029: kuvasz + n02104365: schipperke + n02105056: groenendael + n02105162: malinois + n02105251: briard + n02105412: kelpie + n02105505: komondor + n02105641: Old_English_sheepdog + n02105855: Shetland_sheepdog + n02106030: collie + n02106166: Border_collie + n02106382: Bouvier_des_Flandres + n02106550: Rottweiler + n02106662: German_shepherd + n02107142: Doberman + n02107312: miniature_pinscher + n02107574: Greater_Swiss_Mountain_dog + n02107683: Bernese_mountain_dog + n02107908: Appenzeller + n02108000: EntleBucher + n02108089: boxer + n02108422: bull_mastiff + n02108551: Tibetan_mastiff + n02108915: French_bulldog + n02109047: Great_Dane + n02109525: Saint_Bernard + n02109961: Eskimo_dog + n02110063: malamute + n02110185: Siberian_husky + n02110341: dalmatian + n02110627: 
affenpinscher + n02110806: basenji + n02110958: pug + n02111129: Leonberg + n02111277: Newfoundland + n02111500: Great_Pyrenees + n02111889: Samoyed + n02112018: Pomeranian + n02112137: chow + n02112350: keeshond + n02112706: Brabancon_griffon + n02113023: Pembroke + n02113186: Cardigan + n02113624: toy_poodle + n02113712: miniature_poodle + n02113799: standard_poodle + n02113978: Mexican_hairless + n02114367: timber_wolf + n02114548: white_wolf + n02114712: red_wolf + n02114855: coyote + n02115641: dingo + n02115913: dhole + n02116738: African_hunting_dog + n02117135: hyena + n02119022: red_fox + n02119789: kit_fox + n02120079: Arctic_fox + n02120505: grey_fox + n02123045: tabby + n02123159: tiger_cat + n02123394: Persian_cat + n02123597: Siamese_cat + n02124075: Egyptian_cat + n02125311: cougar + n02127052: lynx + n02128385: leopard + n02128757: snow_leopard + n02128925: jaguar + n02129165: lion + n02129604: tiger + n02130308: cheetah + n02132136: brown_bear + n02133161: American_black_bear + n02134084: ice_bear + n02134418: sloth_bear + n02137549: mongoose + n02138441: meerkat + n02165105: tiger_beetle + n02165456: ladybug + n02167151: ground_beetle + n02168699: long-horned_beetle + n02169497: leaf_beetle + n02172182: dung_beetle + n02174001: rhinoceros_beetle + n02177972: weevil + n02190166: fly + n02206856: bee + n02219486: ant + n02226429: grasshopper + n02229544: cricket + n02231487: walking_stick + n02233338: cockroach + n02236044: mantis + n02256656: cicada + n02259212: leafhopper + n02264363: lacewing + n02268443: dragonfly + n02268853: damselfly + n02276258: admiral + n02277742: ringlet + n02279972: monarch + n02280649: cabbage_butterfly + n02281406: sulphur_butterfly + n02281787: lycaenid + n02317335: starfish + n02319095: sea_urchin + n02321529: sea_cucumber + n02325366: wood_rabbit + n02326432: hare + n02328150: Angora + n02342885: hamster + n02346627: porcupine + n02356798: fox_squirrel + n02361337: marmot + n02363005: beaver + n02364673: guinea_pig + n02389026: sorrel + n02391049: zebra + n02395406: hog + n02396427: wild_boar + n02397096: warthog + n02398521: hippopotamus + n02403003: ox + n02408429: water_buffalo + n02410509: bison + n02412080: ram + n02415577: bighorn + n02417914: ibex + n02422106: hartebeest + n02422699: impala + n02423022: gazelle + n02437312: Arabian_camel + n02437616: llama + n02441942: weasel + n02442845: mink + n02443114: polecat + n02443484: black-footed_ferret + n02444819: otter + n02445715: skunk + n02447366: badger + n02454379: armadillo + n02457408: three-toed_sloth + n02480495: orangutan + n02480855: gorilla + n02481823: chimpanzee + n02483362: gibbon + n02483708: siamang + n02484975: guenon + n02486261: patas + n02486410: baboon + n02487347: macaque + n02488291: langur + n02488702: colobus + n02489166: proboscis_monkey + n02490219: marmoset + n02492035: capuchin + n02492660: howler_monkey + n02493509: titi + n02493793: spider_monkey + n02494079: squirrel_monkey + n02497673: Madagascar_cat + n02500267: indri + n02504013: Indian_elephant + n02504458: African_elephant + n02509815: lesser_panda + n02510455: giant_panda + n02514041: barracouta + n02526121: eel + n02536864: coho + n02606052: rock_beauty + n02607072: anemone_fish + n02640242: sturgeon + n02641379: gar + n02643566: lionfish + n02655020: puffer + n02666196: abacus + n02667093: abaya + n02669723: academic_gown + n02672831: accordion + n02676566: acoustic_guitar + n02687172: aircraft_carrier + n02690373: airliner + n02692877: airship + n02699494: altar + n02701002: ambulance + 
n02704792: amphibian + n02708093: analog_clock + n02727426: apiary + n02730930: apron + n02747177: ashcan + n02749479: assault_rifle + n02769748: backpack + n02776631: bakery + n02777292: balance_beam + n02782093: balloon + n02783161: ballpoint + n02786058: Band_Aid + n02787622: banjo + n02788148: bannister + n02790996: barbell + n02791124: barber_chair + n02791270: barbershop + n02793495: barn + n02794156: barometer + n02795169: barrel + n02797295: barrow + n02799071: baseball + n02802426: basketball + n02804414: bassinet + n02804610: bassoon + n02807133: bathing_cap + n02808304: bath_towel + n02808440: bathtub + n02814533: beach_wagon + n02814860: beacon + n02815834: beaker + n02817516: bearskin + n02823428: beer_bottle + n02823750: beer_glass + n02825657: bell_cote + n02834397: bib + n02835271: bicycle-built-for-two + n02837789: bikini + n02840245: binder + n02841315: binoculars + n02843684: birdhouse + n02859443: boathouse + n02860847: bobsled + n02865351: bolo_tie + n02869837: bonnet + n02870880: bookcase + n02871525: bookshop + n02877765: bottlecap + n02879718: bow + n02883205: bow_tie + n02892201: brass + n02892767: brassiere + n02894605: breakwater + n02895154: breastplate + n02906734: broom + n02909870: bucket + n02910353: buckle + n02916936: bulletproof_vest + n02917067: bullet_train + n02927161: butcher_shop + n02930766: cab + n02939185: caldron + n02948072: candle + n02950826: cannon + n02951358: canoe + n02951585: can_opener + n02963159: cardigan + n02965783: car_mirror + n02966193: carousel + n02966687: carpenter's_kit + n02971356: carton + n02974003: car_wheel + n02977058: cash_machine + n02978881: cassette + n02979186: cassette_player + n02980441: castle + n02981792: catamaran + n02988304: CD_player + n02992211: cello + n02992529: cellular_telephone + n02999410: chain + n03000134: chainlink_fence + n03000247: chain_mail + n03000684: chain_saw + n03014705: chest + n03016953: chiffonier + n03017168: chime + n03018349: china_cabinet + n03026506: Christmas_stocking + n03028079: church + n03032252: cinema + n03041632: cleaver + n03042490: cliff_dwelling + n03045698: cloak + n03047690: clog + n03062245: cocktail_shaker + n03063599: coffee_mug + n03063689: coffeepot + n03065424: coil + n03075370: combination_lock + n03085013: computer_keyboard + n03089624: confectionery + n03095699: container_ship + n03100240: convertible + n03109150: corkscrew + n03110669: cornet + n03124043: cowboy_boot + n03124170: cowboy_hat + n03125729: cradle + n03126707: crane_(machine) + n03127747: crash_helmet + n03127925: crate + n03131574: crib + n03133878: Crock_Pot + n03134739: croquet_ball + n03141823: crutch + n03146219: cuirass + n03160309: dam + n03179701: desk + n03180011: desktop_computer + n03187595: dial_telephone + n03188531: diaper + n03196217: digital_clock + n03197337: digital_watch + n03201208: dining_table + n03207743: dishrag + n03207941: dishwasher + n03208938: disk_brake + n03216828: dock + n03218198: dogsled + n03220513: dome + n03223299: doormat + n03240683: drilling_platform + n03249569: drum + n03250847: drumstick + n03255030: dumbbell + n03259280: Dutch_oven + n03271574: electric_fan + n03272010: electric_guitar + n03272562: electric_locomotive + n03290653: entertainment_center + n03291819: envelope + n03297495: espresso_maker + n03314780: face_powder + n03325584: feather_boa + n03337140: file + n03344393: fireboat + n03345487: fire_engine + n03347037: fire_screen + n03355925: flagpole + n03372029: flute + n03376595: folding_chair + n03379051: football_helmet + n03384352: 
forklift + n03388043: fountain + n03388183: fountain_pen + n03388549: four-poster + n03393912: freight_car + n03394916: French_horn + n03400231: frying_pan + n03404251: fur_coat + n03417042: garbage_truck + n03424325: gasmask + n03425413: gas_pump + n03443371: goblet + n03444034: go-kart + n03445777: golf_ball + n03445924: golfcart + n03447447: gondola + n03447721: gong + n03450230: gown + n03452741: grand_piano + n03457902: greenhouse + n03459775: grille + n03461385: grocery_store + n03467068: guillotine + n03476684: hair_slide + n03476991: hair_spray + n03478589: half_track + n03481172: hammer + n03482405: hamper + n03483316: hand_blower + n03485407: hand-held_computer + n03485794: handkerchief + n03492542: hard_disc + n03494278: harmonica + n03495258: harp + n03496892: harvester + n03498962: hatchet + n03527444: holster + n03529860: home_theater + n03530642: honeycomb + n03532672: hook + n03534580: hoopskirt + n03535780: horizontal_bar + n03538406: horse_cart + n03544143: hourglass + n03584254: iPod + n03584829: iron + n03590841: jack-o'-lantern + n03594734: jean + n03594945: jeep + n03595614: jersey + n03598930: jigsaw_puzzle + n03599486: jinrikisha + n03602883: joystick + n03617480: kimono + n03623198: knee_pad + n03627232: knot + n03630383: lab_coat + n03633091: ladle + n03637318: lampshade + n03642806: laptop + n03649909: lawn_mower + n03657121: lens_cap + n03658185: letter_opener + n03661043: library + n03662601: lifeboat + n03666591: lighter + n03670208: limousine + n03673027: liner + n03676483: lipstick + n03680355: Loafer + n03690938: lotion + n03691459: loudspeaker + n03692522: loupe + n03697007: lumbermill + n03706229: magnetic_compass + n03709823: mailbag + n03710193: mailbox + n03710637: maillot_(tights) + n03710721: maillot_(tank_suit) + n03717622: manhole_cover + n03720891: maraca + n03721384: marimba + n03724870: mask + n03729826: matchstick + n03733131: maypole + n03733281: maze + n03733805: measuring_cup + n03742115: medicine_chest + n03743016: megalith + n03759954: microphone + n03761084: microwave + n03763968: military_uniform + n03764736: milk_can + n03769881: minibus + n03770439: miniskirt + n03770679: minivan + n03773504: missile + n03775071: mitten + n03775546: mixing_bowl + n03776460: mobile_home + n03777568: Model_T + n03777754: modem + n03781244: monastery + n03782006: monitor + n03785016: moped + n03786901: mortar + n03787032: mortarboard + n03788195: mosque + n03788365: mosquito_net + n03791053: motor_scooter + n03792782: mountain_bike + n03792972: mountain_tent + n03793489: mouse + n03794056: mousetrap + n03796401: moving_van + n03803284: muzzle + n03804744: nail + n03814639: neck_brace + n03814906: necklace + n03825788: nipple + n03832673: notebook + n03837869: obelisk + n03838899: oboe + n03840681: ocarina + n03841143: odometer + n03843555: oil_filter + n03854065: organ + n03857828: oscilloscope + n03866082: overskirt + n03868242: oxcart + n03868863: oxygen_mask + n03871628: packet + n03873416: paddle + n03874293: paddlewheel + n03874599: padlock + n03876231: paintbrush + n03877472: pajama + n03877845: palace + n03884397: panpipe + n03887697: paper_towel + n03888257: parachute + n03888605: parallel_bars + n03891251: park_bench + n03891332: parking_meter + n03895866: passenger_car + n03899768: patio + n03902125: pay-phone + n03903868: pedestal + n03908618: pencil_box + n03908714: pencil_sharpener + n03916031: perfume + n03920288: Petri_dish + n03924679: photocopier + n03929660: pick + n03929855: pickelhaube + n03930313: picket_fence + n03930630: pickup + 
n03933933: pier + n03935335: piggy_bank + n03937543: pill_bottle + n03938244: pillow + n03942813: ping-pong_ball + n03944341: pinwheel + n03947888: pirate + n03950228: pitcher + n03954731: plane + n03956157: planetarium + n03958227: plastic_bag + n03961711: plate_rack + n03967562: plow + n03970156: plunger + n03976467: Polaroid_camera + n03976657: pole + n03977966: police_van + n03980874: poncho + n03982430: pool_table + n03983396: pop_bottle + n03991062: pot + n03992509: potter's_wheel + n03995372: power_drill + n03998194: prayer_rug + n04004767: printer + n04005630: prison + n04008634: projectile + n04009552: projector + n04019541: puck + n04023962: punching_bag + n04026417: purse + n04033901: quill + n04033995: quilt + n04037443: racer + n04039381: racket + n04040759: radiator + n04041544: radio + n04044716: radio_telescope + n04049303: rain_barrel + n04065272: recreational_vehicle + n04067472: reel + n04069434: reflex_camera + n04070727: refrigerator + n04074963: remote_control + n04081281: restaurant + n04086273: revolver + n04090263: rifle + n04099969: rocking_chair + n04111531: rotisserie + n04116512: rubber_eraser + n04118538: rugby_ball + n04118776: rule + n04120489: running_shoe + n04125021: safe + n04127249: safety_pin + n04131690: saltshaker + n04133789: sandal + n04136333: sarong + n04141076: sax + n04141327: scabbard + n04141975: scale + n04146614: school_bus + n04147183: schooner + n04149813: scoreboard + n04152593: screen + n04153751: screw + n04154565: screwdriver + n04162706: seat_belt + n04179913: sewing_machine + n04192698: shield + n04200800: shoe_shop + n04201297: shoji + n04204238: shopping_basket + n04204347: shopping_cart + n04208210: shovel + n04209133: shower_cap + n04209239: shower_curtain + n04228054: ski + n04229816: ski_mask + n04235860: sleeping_bag + n04238763: slide_rule + n04239074: sliding_door + n04243546: slot + n04251144: snorkel + n04252077: snowmobile + n04252225: snowplow + n04254120: soap_dispenser + n04254680: soccer_ball + n04254777: sock + n04258138: solar_dish + n04259630: sombrero + n04263257: soup_bowl + n04264628: space_bar + n04265275: space_heater + n04266014: space_shuttle + n04270147: spatula + n04273569: speedboat + n04275548: spider_web + n04277352: spindle + n04285008: sports_car + n04286575: spotlight + n04296562: stage + n04310018: steam_locomotive + n04311004: steel_arch_bridge + n04311174: steel_drum + n04317175: stethoscope + n04325704: stole + n04326547: stone_wall + n04328186: stopwatch + n04330267: stove + n04332243: strainer + n04335435: streetcar + n04336792: stretcher + n04344873: studio_couch + n04346328: stupa + n04347754: submarine + n04350905: suit + n04355338: sundial + n04355933: sunglass + n04356056: sunglasses + n04357314: sunscreen + n04366367: suspension_bridge + n04367480: swab + n04370456: sweatshirt + n04371430: swimming_trunks + n04371774: swing + n04372370: switch + n04376876: syringe + n04380533: table_lamp + n04389033: tank + n04392985: tape_player + n04398044: teapot + n04399382: teddy + n04404412: television + n04409515: tennis_ball + n04417672: thatch + n04418357: theater_curtain + n04423845: thimble + n04428191: thresher + n04429376: throne + n04435653: tile_roof + n04442312: toaster + n04443257: tobacco_shop + n04447861: toilet_seat + n04456115: torch + n04458633: totem_pole + n04461696: tow_truck + n04462240: toyshop + n04465501: tractor + n04467665: trailer_truck + n04476259: tray + n04479046: trench_coat + n04482393: tricycle + n04483307: trimaran + n04485082: tripod + n04486054: triumphal_arch + 
n04487081: trolleybus + n04487394: trombone + n04493381: tub + n04501370: turnstile + n04505470: typewriter_keyboard + n04507155: umbrella + n04509417: unicycle + n04515003: upright + n04517823: vacuum + n04522168: vase + n04523525: vault + n04525038: velvet + n04525305: vending_machine + n04532106: vestment + n04532670: viaduct + n04536866: violin + n04540053: volleyball + n04542943: waffle_iron + n04548280: wall_clock + n04548362: wallet + n04550184: wardrobe + n04552348: warplane + n04553703: washbasin + n04554684: washer + n04557648: water_bottle + n04560804: water_jug + n04562935: water_tower + n04579145: whiskey_jug + n04579432: whistle + n04584207: wig + n04589890: window_screen + n04590129: window_shade + n04591157: Windsor_tie + n04591713: wine_bottle + n04592741: wing + n04596742: wok + n04597913: wooden_spoon + n04599235: wool + n04604644: worm_fence + n04606251: wreck + n04612504: yawl + n04613696: yurt + n06359193: web_site + n06596364: comic_book + n06785654: crossword_puzzle + n06794110: street_sign + n06874185: traffic_light + n07248320: book_jacket + n07565083: menu + n07579787: plate + n07583066: guacamole + n07584110: consomme + n07590611: hot_pot + n07613480: trifle + n07614500: ice_cream + n07615774: ice_lolly + n07684084: French_loaf + n07693725: bagel + n07695742: pretzel + n07697313: cheeseburger + n07697537: hotdog + n07711569: mashed_potato + n07714571: head_cabbage + n07714990: broccoli + n07715103: cauliflower + n07716358: zucchini + n07716906: spaghetti_squash + n07717410: acorn_squash + n07717556: butternut_squash + n07718472: cucumber + n07718747: artichoke + n07720875: bell_pepper + n07730033: cardoon + n07734744: mushroom + n07742313: Granny_Smith + n07745940: strawberry + n07747607: orange + n07749582: lemon + n07753113: fig + n07753275: pineapple + n07753592: banana + n07754684: jackfruit + n07760859: custard_apple + n07768694: pomegranate + n07802026: hay + n07831146: carbonara + n07836838: chocolate_sauce + n07860988: dough + n07871810: meat_loaf + n07873807: pizza + n07875152: potpie + n07880968: burrito + n07892512: red_wine + n07920052: espresso + n07930864: cup + n07932039: eggnog + n09193705: alp + n09229709: bubble + n09246464: cliff + n09256479: coral_reef + n09288635: geyser + n09332890: lakeside + n09399592: promontory + n09421951: sandbar + n09428293: seashore + n09468604: valley + n09472597: volcano + n09835506: ballplayer + n10148035: groom + n10565667: scuba_diver + n11879895: rapeseed + n11939491: daisy + n12057211: yellow_lady's_slipper + n12144580: corn + n12267677: acorn + n12620546: hip + n12768682: buckeye + n12985857: coral_fungus + n12998815: agaric + n13037406: gyromitra + n13040303: stinkhorn + n13044778: earthstar + n13052670: hen-of-the-woods + n13054560: bolete + n13133613: ear + n15075141: toilet_tissue + +# Download script/URL (optional) +download: yolo/data/scripts/get_imagenet.sh diff --git a/ultralytics/cfg/datasets/Objects365.yaml b/ultralytics/cfg/datasets/Objects365.yaml new file mode 100644 index 0000000000000000000000000000000000000000..89921364a5252d9d4eb54029953d648021c0484a --- /dev/null +++ b/ultralytics/cfg/datasets/Objects365.yaml @@ -0,0 +1,443 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Objects365 dataset https://www.objects365.org/ by Megvii +# Documentation: https://docs.ultralytics.com/datasets/detect/objects365/ +# Example usage: yolo train data=Objects365.yaml +# parent +# ├── ultralytics +# └── datasets +# └── Objects365 ← downloads here (712 GB = 367G data + 345G zips) + 
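The ImageNet.yaml above is a classification config rather than a detection one; its own header suggests `yolo train task=classify data=imagenet`. A rough Python equivalent is sketched below; the classification checkpoint name is an assumption, since this repository only publishes detection weights.

```python
from ultralytics import YOLO

# Classification sketch on ImageNet-1k. 'yolov8n-cls.pt' is an assumed upstream
# Ultralytics classification checkpoint; the YOLOv12 weights in this repo are
# detection models, so the model name here is purely illustrative.
model = YOLO("yolov8n-cls.pt")
model.train(data="imagenet", epochs=10, imgsz=224)
```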
+# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/Objects365 # dataset root dir +train: images/train # train images (relative to 'path') 1742289 images +val: images/val # val images (relative to 'path') 80000 images +test: # test images (optional) + +# Classes +names: + 0: Person + 1: Sneakers + 2: Chair + 3: Other Shoes + 4: Hat + 5: Car + 6: Lamp + 7: Glasses + 8: Bottle + 9: Desk + 10: Cup + 11: Street Lights + 12: Cabinet/shelf + 13: Handbag/Satchel + 14: Bracelet + 15: Plate + 16: Picture/Frame + 17: Helmet + 18: Book + 19: Gloves + 20: Storage box + 21: Boat + 22: Leather Shoes + 23: Flower + 24: Bench + 25: Potted Plant + 26: Bowl/Basin + 27: Flag + 28: Pillow + 29: Boots + 30: Vase + 31: Microphone + 32: Necklace + 33: Ring + 34: SUV + 35: Wine Glass + 36: Belt + 37: Monitor/TV + 38: Backpack + 39: Umbrella + 40: Traffic Light + 41: Speaker + 42: Watch + 43: Tie + 44: Trash bin Can + 45: Slippers + 46: Bicycle + 47: Stool + 48: Barrel/bucket + 49: Van + 50: Couch + 51: Sandals + 52: Basket + 53: Drum + 54: Pen/Pencil + 55: Bus + 56: Wild Bird + 57: High Heels + 58: Motorcycle + 59: Guitar + 60: Carpet + 61: Cell Phone + 62: Bread + 63: Camera + 64: Canned + 65: Truck + 66: Traffic cone + 67: Cymbal + 68: Lifesaver + 69: Towel + 70: Stuffed Toy + 71: Candle + 72: Sailboat + 73: Laptop + 74: Awning + 75: Bed + 76: Faucet + 77: Tent + 78: Horse + 79: Mirror + 80: Power outlet + 81: Sink + 82: Apple + 83: Air Conditioner + 84: Knife + 85: Hockey Stick + 86: Paddle + 87: Pickup Truck + 88: Fork + 89: Traffic Sign + 90: Balloon + 91: Tripod + 92: Dog + 93: Spoon + 94: Clock + 95: Pot + 96: Cow + 97: Cake + 98: Dining Table + 99: Sheep + 100: Hanger + 101: Blackboard/Whiteboard + 102: Napkin + 103: Other Fish + 104: Orange/Tangerine + 105: Toiletry + 106: Keyboard + 107: Tomato + 108: Lantern + 109: Machinery Vehicle + 110: Fan + 111: Green Vegetables + 112: Banana + 113: Baseball Glove + 114: Airplane + 115: Mouse + 116: Train + 117: Pumpkin + 118: Soccer + 119: Skiboard + 120: Luggage + 121: Nightstand + 122: Tea pot + 123: Telephone + 124: Trolley + 125: Head Phone + 126: Sports Car + 127: Stop Sign + 128: Dessert + 129: Scooter + 130: Stroller + 131: Crane + 132: Remote + 133: Refrigerator + 134: Oven + 135: Lemon + 136: Duck + 137: Baseball Bat + 138: Surveillance Camera + 139: Cat + 140: Jug + 141: Broccoli + 142: Piano + 143: Pizza + 144: Elephant + 145: Skateboard + 146: Surfboard + 147: Gun + 148: Skating and Skiing shoes + 149: Gas stove + 150: Donut + 151: Bow Tie + 152: Carrot + 153: Toilet + 154: Kite + 155: Strawberry + 156: Other Balls + 157: Shovel + 158: Pepper + 159: Computer Box + 160: Toilet Paper + 161: Cleaning Products + 162: Chopsticks + 163: Microwave + 164: Pigeon + 165: Baseball + 166: Cutting/chopping Board + 167: Coffee Table + 168: Side Table + 169: Scissors + 170: Marker + 171: Pie + 172: Ladder + 173: Snowboard + 174: Cookies + 175: Radiator + 176: Fire Hydrant + 177: Basketball + 178: Zebra + 179: Grape + 180: Giraffe + 181: Potato + 182: Sausage + 183: Tricycle + 184: Violin + 185: Egg + 186: Fire Extinguisher + 187: Candy + 188: Fire Truck + 189: Billiards + 190: Converter + 191: Bathtub + 192: Wheelchair + 193: Golf Club + 194: Briefcase + 195: Cucumber + 196: Cigar/Cigarette + 197: Paint Brush + 198: Pear + 199: Heavy Truck + 200: Hamburger + 201: Extractor + 202: Extension Cord + 203: Tong + 204: Tennis Racket + 205: Folder + 206: American Football + 207: 
earphone + 208: Mask + 209: Kettle + 210: Tennis + 211: Ship + 212: Swing + 213: Coffee Machine + 214: Slide + 215: Carriage + 216: Onion + 217: Green beans + 218: Projector + 219: Frisbee + 220: Washing Machine/Drying Machine + 221: Chicken + 222: Printer + 223: Watermelon + 224: Saxophone + 225: Tissue + 226: Toothbrush + 227: Ice cream + 228: Hot-air balloon + 229: Cello + 230: French Fries + 231: Scale + 232: Trophy + 233: Cabbage + 234: Hot dog + 235: Blender + 236: Peach + 237: Rice + 238: Wallet/Purse + 239: Volleyball + 240: Deer + 241: Goose + 242: Tape + 243: Tablet + 244: Cosmetics + 245: Trumpet + 246: Pineapple + 247: Golf Ball + 248: Ambulance + 249: Parking meter + 250: Mango + 251: Key + 252: Hurdle + 253: Fishing Rod + 254: Medal + 255: Flute + 256: Brush + 257: Penguin + 258: Megaphone + 259: Corn + 260: Lettuce + 261: Garlic + 262: Swan + 263: Helicopter + 264: Green Onion + 265: Sandwich + 266: Nuts + 267: Speed Limit Sign + 268: Induction Cooker + 269: Broom + 270: Trombone + 271: Plum + 272: Rickshaw + 273: Goldfish + 274: Kiwi fruit + 275: Router/modem + 276: Poker Card + 277: Toaster + 278: Shrimp + 279: Sushi + 280: Cheese + 281: Notepaper + 282: Cherry + 283: Pliers + 284: CD + 285: Pasta + 286: Hammer + 287: Cue + 288: Avocado + 289: Hami melon + 290: Flask + 291: Mushroom + 292: Screwdriver + 293: Soap + 294: Recorder + 295: Bear + 296: Eggplant + 297: Board Eraser + 298: Coconut + 299: Tape Measure/Ruler + 300: Pig + 301: Showerhead + 302: Globe + 303: Chips + 304: Steak + 305: Crosswalk Sign + 306: Stapler + 307: Camel + 308: Formula 1 + 309: Pomegranate + 310: Dishwasher + 311: Crab + 312: Hoverboard + 313: Meatball + 314: Rice Cooker + 315: Tuba + 316: Calculator + 317: Papaya + 318: Antelope + 319: Parrot + 320: Seal + 321: Butterfly + 322: Dumbbell + 323: Donkey + 324: Lion + 325: Urinal + 326: Dolphin + 327: Electric Drill + 328: Hair Dryer + 329: Egg tart + 330: Jellyfish + 331: Treadmill + 332: Lighter + 333: Grapefruit + 334: Game board + 335: Mop + 336: Radish + 337: Baozi + 338: Target + 339: French + 340: Spring Rolls + 341: Monkey + 342: Rabbit + 343: Pencil Case + 344: Yak + 345: Red Cabbage + 346: Binoculars + 347: Asparagus + 348: Barbell + 349: Scallop + 350: Noddles + 351: Comb + 352: Dumpling + 353: Oyster + 354: Table Tennis paddle + 355: Cosmetics Brush/Eyeliner Pencil + 356: Chainsaw + 357: Eraser + 358: Lobster + 359: Durian + 360: Okra + 361: Lipstick + 362: Cosmetics Mirror + 363: Curling + 364: Table Tennis + +# Download script/URL (optional) --------------------------------------------------------------------------------------- +download: | + from tqdm import tqdm + + from ultralytics.utils.checks import check_requirements + from ultralytics.utils.downloads import download + from ultralytics.utils.ops import xyxy2xywhn + + import numpy as np + from pathlib import Path + + check_requirements(('pycocotools>=2.0',)) + from pycocotools.coco import COCO + + # Make Directories + dir = Path(yaml['path']) # dataset root dir + for p in 'images', 'labels': + (dir / p).mkdir(parents=True, exist_ok=True) + for q in 'train', 'val': + (dir / p / q).mkdir(parents=True, exist_ok=True) + + # Train, Val Splits + for split, patches in [('train', 50 + 1), ('val', 43 + 1)]: + print(f"Processing {split} in {patches} patches ...") + images, labels = dir / 'images' / split, dir / 'labels' / split + + # Download + url = f"https://dorc.ks3-cn-beijing.ksyun.com/data-set/2020Objects365%E6%95%B0%E6%8D%AE%E9%9B%86/{split}/" + if split == 'train': + 
download([f'{url}zhiyuan_objv2_{split}.tar.gz'], dir=dir) # annotations json + download([f'{url}patch{i}.tar.gz' for i in range(patches)], dir=images, curl=True, threads=8) + elif split == 'val': + download([f'{url}zhiyuan_objv2_{split}.json'], dir=dir) # annotations json + download([f'{url}images/v1/patch{i}.tar.gz' for i in range(15 + 1)], dir=images, curl=True, threads=8) + download([f'{url}images/v2/patch{i}.tar.gz' for i in range(16, patches)], dir=images, curl=True, threads=8) + + # Move + for f in tqdm(images.rglob('*.jpg'), desc=f'Moving {split} images'): + f.rename(images / f.name) # move to /images/{split} + + # Labels + coco = COCO(dir / f'zhiyuan_objv2_{split}.json') + names = [x["name"] for x in coco.loadCats(coco.getCatIds())] + for cid, cat in enumerate(names): + catIds = coco.getCatIds(catNms=[cat]) + imgIds = coco.getImgIds(catIds=catIds) + for im in tqdm(coco.loadImgs(imgIds), desc=f'Class {cid + 1}/{len(names)} {cat}'): + width, height = im["width"], im["height"] + path = Path(im["file_name"]) # image filename + try: + with open(labels / path.with_suffix('.txt').name, 'a') as file: + annIds = coco.getAnnIds(imgIds=im["id"], catIds=catIds, iscrowd=None) + for a in coco.loadAnns(annIds): + x, y, w, h = a['bbox'] # bounding box in xywh (xy top-left corner) + xyxy = np.array([x, y, x + w, y + h])[None] # pixels(1,4) + x, y, w, h = xyxy2xywhn(xyxy, w=width, h=height, clip=True)[0] # normalized and clipped + file.write(f"{cid} {x:.5f} {y:.5f} {w:.5f} {h:.5f}\n") + except Exception as e: + print(e) diff --git a/ultralytics/cfg/datasets/SKU-110K.yaml b/ultralytics/cfg/datasets/SKU-110K.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a2c94ced1bc48f6afec204f77e1c28d4fb8884ce --- /dev/null +++ b/ultralytics/cfg/datasets/SKU-110K.yaml @@ -0,0 +1,58 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# SKU-110K retail items dataset https://github.com/eg4000/SKU110K_CVPR19 by Trax Retail +# Documentation: https://docs.ultralytics.com/datasets/detect/sku-110k/ +# Example usage: yolo train data=SKU-110K.yaml +# parent +# ├── ultralytics +# └── datasets +# └── SKU-110K ← downloads here (13.6 GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
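One detail of the Objects365 download script above that is easy to miss: COCO-style annotations store each box as a top-left corner plus width/height in pixels, while YOLO label files expect a normalized, center-based format, which is what the `xyxy2xywhn` call produces. A small standalone illustration with made-up numbers:

```python
import numpy as np

from ultralytics.utils.ops import xyxy2xywhn

# Made-up COCO-style box: top-left (100, 50), 200x100 px, inside a 640x480 image.
x, y, w, h = 100, 50, 200, 100
xyxy = np.array([[x, y, x + w, y + h]], dtype=np.float32)  # corners, shape (1, 4)
xc, yc, wn, hn = xyxy2xywhn(xyxy, w=640, h=480, clip=True)[0]  # normalized center/size
print(f"0 {xc:.5f} {yc:.5f} {wn:.5f} {hn:.5f}")  # -> "0 0.31250 0.20833 0.31250 0.20833"
```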
+path: ../datasets/SKU-110K # dataset root dir +train: train.txt # train images (relative to 'path') 8219 images +val: val.txt # val images (relative to 'path') 588 images +test: test.txt # test images (optional) 2936 images + +# Classes +names: + 0: object + +# Download script/URL (optional) --------------------------------------------------------------------------------------- +download: | + import shutil + from pathlib import Path + + import numpy as np + import pandas as pd + from tqdm import tqdm + + from ultralytics.utils.downloads import download + from ultralytics.utils.ops import xyxy2xywh + + # Download + dir = Path(yaml['path']) # dataset root dir + parent = Path(dir.parent) # download dir + urls = ['http://trax-geometry.s3.amazonaws.com/cvpr_challenge/SKU110K_fixed.tar.gz'] + download(urls, dir=parent) + + # Rename directories + if dir.exists(): + shutil.rmtree(dir) + (parent / 'SKU110K_fixed').rename(dir) # rename dir + (dir / 'labels').mkdir(parents=True, exist_ok=True) # create labels dir + + # Convert labels + names = 'image', 'x1', 'y1', 'x2', 'y2', 'class', 'image_width', 'image_height' # column names + for d in 'annotations_train.csv', 'annotations_val.csv', 'annotations_test.csv': + x = pd.read_csv(dir / 'annotations' / d, names=names).values # annotations + images, unique_images = x[:, 0], np.unique(x[:, 0]) + with open((dir / d).with_suffix('.txt').__str__().replace('annotations_', ''), 'w') as f: + f.writelines(f'./images/{s}\n' for s in unique_images) + for im in tqdm(unique_images, desc=f'Converting {dir / d}'): + cls = 0 # single-class dataset + with open((dir / 'labels' / im).with_suffix('.txt'), 'a') as f: + for r in x[images == im]: + w, h = r[6], r[7] # image width, height + xywh = xyxy2xywh(np.array([[r[1] / w, r[2] / h, r[3] / w, r[4] / h]]))[0] # instance + f.write(f"{cls} {xywh[0]:.5f} {xywh[1]:.5f} {xywh[2]:.5f} {xywh[3]:.5f}\n") # write label diff --git a/ultralytics/cfg/datasets/VOC.yaml b/ultralytics/cfg/datasets/VOC.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2eb06ffdeb4e9bc9b8ea9abe698e481abe52d367 --- /dev/null +++ b/ultralytics/cfg/datasets/VOC.yaml @@ -0,0 +1,100 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# PASCAL VOC dataset http://host.robots.ox.ac.uk/pascal/VOC by University of Oxford +# Documentation: # Documentation: https://docs.ultralytics.com/datasets/detect/voc/ +# Example usage: yolo train data=VOC.yaml +# parent +# ├── ultralytics +# └── datasets +# └── VOC ← downloads here (2.8 GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/VOC +train: # train images (relative to 'path') 16551 images + - images/train2012 + - images/train2007 + - images/val2012 + - images/val2007 +val: # val images (relative to 'path') 4952 images + - images/test2007 +test: # test images (optional) + - images/test2007 + +# Classes +names: + 0: aeroplane + 1: bicycle + 2: bird + 3: boat + 4: bottle + 5: bus + 6: car + 7: cat + 8: chair + 9: cow + 10: diningtable + 11: dog + 12: horse + 13: motorbike + 14: person + 15: pottedplant + 16: sheep + 17: sofa + 18: train + 19: tvmonitor + +# Download script/URL (optional) --------------------------------------------------------------------------------------- +download: | + import xml.etree.ElementTree as ET + + from tqdm import tqdm + from ultralytics.utils.downloads import download + from pathlib import Path + + def convert_label(path, lb_path, year, image_id): + def convert_box(size, box): + dw, dh = 1. / size[0], 1. / size[1] + x, y, w, h = (box[0] + box[1]) / 2.0 - 1, (box[2] + box[3]) / 2.0 - 1, box[1] - box[0], box[3] - box[2] + return x * dw, y * dh, w * dw, h * dh + + in_file = open(path / f'VOC{year}/Annotations/{image_id}.xml') + out_file = open(lb_path, 'w') + tree = ET.parse(in_file) + root = tree.getroot() + size = root.find('size') + w = int(size.find('width').text) + h = int(size.find('height').text) + + names = list(yaml['names'].values()) # names list + for obj in root.iter('object'): + cls = obj.find('name').text + if cls in names and int(obj.find('difficult').text) != 1: + xmlbox = obj.find('bndbox') + bb = convert_box((w, h), [float(xmlbox.find(x).text) for x in ('xmin', 'xmax', 'ymin', 'ymax')]) + cls_id = names.index(cls) # class id + out_file.write(" ".join(str(a) for a in (cls_id, *bb)) + '\n') + + + # Download + dir = Path(yaml['path']) # dataset root dir + url = 'https://github.com/ultralytics/assets/releases/download/v0.0.0/' + urls = [f'{url}VOCtrainval_06-Nov-2007.zip', # 446MB, 5012 images + f'{url}VOCtest_06-Nov-2007.zip', # 438MB, 4953 images + f'{url}VOCtrainval_11-May-2012.zip'] # 1.95GB, 17126 images + download(urls, dir=dir / 'images', curl=True, threads=3, exist_ok=True) # download and unzip over existing paths (required) + + # Convert + path = dir / 'images/VOCdevkit' + for year, image_set in ('2012', 'train'), ('2012', 'val'), ('2007', 'train'), ('2007', 'val'), ('2007', 'test'): + imgs_path = dir / 'images' / f'{image_set}{year}' + lbs_path = dir / 'labels' / f'{image_set}{year}' + imgs_path.mkdir(exist_ok=True, parents=True) + lbs_path.mkdir(exist_ok=True, parents=True) + + with open(path / f'VOC{year}/ImageSets/Main/{image_set}.txt') as f: + image_ids = f.read().strip().split() + for id in tqdm(image_ids, desc=f'{image_set}{year}'): + f = path / f'VOC{year}/JPEGImages/{id}.jpg' # old img path + lb_path = (lbs_path / f.name).with_suffix('.txt') # new label path + f.rename(imgs_path / f.name) # move image + convert_label(path, lb_path, year, id) # convert labels to YOLO format diff --git a/ultralytics/cfg/datasets/VisDrone.yaml b/ultralytics/cfg/datasets/VisDrone.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9fc7b45e435c21ce7d33d59bfe5d69d0446f9201 --- /dev/null +++ b/ultralytics/cfg/datasets/VisDrone.yaml @@ -0,0 +1,73 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# VisDrone2019-DET dataset https://github.com/VisDrone/VisDrone-Dataset by Tianjin University +# Documentation: https://docs.ultralytics.com/datasets/detect/visdrone/ +# Example usage: yolo train data=VisDrone.yaml +# parent +# 
├── ultralytics +# └── datasets +# └── VisDrone ← downloads here (2.3 GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/VisDrone # dataset root dir +train: VisDrone2019-DET-train/images # train images (relative to 'path') 6471 images +val: VisDrone2019-DET-val/images # val images (relative to 'path') 548 images +test: VisDrone2019-DET-test-dev/images # test images (optional) 1610 images + +# Classes +names: + 0: pedestrian + 1: people + 2: bicycle + 3: car + 4: van + 5: truck + 6: tricycle + 7: awning-tricycle + 8: bus + 9: motor + +# Download script/URL (optional) --------------------------------------------------------------------------------------- +download: | + import os + from pathlib import Path + + from ultralytics.utils.downloads import download + + def visdrone2yolo(dir): + from PIL import Image + from tqdm import tqdm + + def convert_box(size, box): + # Convert VisDrone box to YOLO xywh box + dw = 1. / size[0] + dh = 1. / size[1] + return (box[0] + box[2] / 2) * dw, (box[1] + box[3] / 2) * dh, box[2] * dw, box[3] * dh + + (dir / 'labels').mkdir(parents=True, exist_ok=True) # make labels directory + pbar = tqdm((dir / 'annotations').glob('*.txt'), desc=f'Converting {dir}') + for f in pbar: + img_size = Image.open((dir / 'images' / f.name).with_suffix('.jpg')).size + lines = [] + with open(f, 'r') as file: # read annotation.txt + for row in [x.split(',') for x in file.read().strip().splitlines()]: + if row[4] == '0': # VisDrone 'ignored regions' class 0 + continue + cls = int(row[5]) - 1 + box = convert_box(img_size, tuple(map(int, row[:4]))) + lines.append(f"{cls} {' '.join(f'{x:.6f}' for x in box)}\n") + with open(str(f).replace(f'{os.sep}annotations{os.sep}', f'{os.sep}labels{os.sep}'), 'w') as fl: + fl.writelines(lines) # write label.txt + + + # Download + dir = Path(yaml['path']) # dataset root dir + urls = ['https://github.com/ultralytics/assets/releases/download/v0.0.0/VisDrone2019-DET-train.zip', + 'https://github.com/ultralytics/assets/releases/download/v0.0.0/VisDrone2019-DET-val.zip', + 'https://github.com/ultralytics/assets/releases/download/v0.0.0/VisDrone2019-DET-test-dev.zip', + 'https://github.com/ultralytics/assets/releases/download/v0.0.0/VisDrone2019-DET-test-challenge.zip'] + download(urls, dir=dir, curl=True, threads=4) + + # Convert + for d in 'VisDrone2019-DET-train', 'VisDrone2019-DET-val', 'VisDrone2019-DET-test-dev': + visdrone2yolo(dir / d) # convert VisDrone annotations to YOLO labels diff --git a/ultralytics/cfg/datasets/african-wildlife.yaml b/ultralytics/cfg/datasets/african-wildlife.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b825f8f068b54a47e4f32cd105ca225f3f9f1f8a --- /dev/null +++ b/ultralytics/cfg/datasets/african-wildlife.yaml @@ -0,0 +1,25 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# African-wildlife dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/detect/african-wildlife/ +# Example usage: yolo train data=african-wildlife.yaml +# parent +# ├── ultralytics +# └── datasets +# └── african-wildlife ← downloads here (100 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
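To make the VisDrone conversion above concrete: each annotation row has the form `x,y,w,h,score,category,truncation,occlusion`; rows whose score field is 0 mark 'ignored regions' and are dropped, and the remaining 1-based categories are shifted down by one to match the 0-based `names` indices in this YAML. A tiny illustration on a single made-up row:

```python
# Made-up VisDrone annotation row inside an assumed 1360x765 image.
row = "684,8,273,116,1,4,0,0".split(",")  # x,y,w,h,score,category,trunc,occl
img_w, img_h = 1360, 765

if row[4] != "0":  # score 0 marks an 'ignored region' and is skipped
    cls = int(row[5]) - 1  # category 4 -> class index 3 ('car' in the names above)
    x, y, w, h = map(int, row[:4])
    # same top-left+size to normalized-center conversion as convert_box() above
    xc, yc = (x + w / 2) / img_w, (y + h / 2) / img_h
    wn, hn = w / img_w, h / img_h
    print(f"{cls} {xc:.6f} {yc:.6f} {wn:.6f} {hn:.6f}")
```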
+path: ../datasets/african-wildlife # dataset root dir +train: train/images # train images (relative to 'path') 1052 images +val: valid/images # val images (relative to 'path') 225 images +test: test/images # test images (relative to 'path') 227 images + +# Classes +names: + 0: buffalo + 1: elephant + 2: rhino + 3: zebra + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/african-wildlife.zip diff --git a/ultralytics/cfg/datasets/brain-tumor.yaml b/ultralytics/cfg/datasets/brain-tumor.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7a448e84afc76832914cde3a2f0d6208b8c78e29 --- /dev/null +++ b/ultralytics/cfg/datasets/brain-tumor.yaml @@ -0,0 +1,23 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Brain-tumor dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/detect/brain-tumor/ +# Example usage: yolo train data=brain-tumor.yaml +# parent +# ├── ultralytics +# └── datasets +# └── brain-tumor ← downloads here (4.05 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/brain-tumor # dataset root dir +train: train/images # train images (relative to 'path') 893 images +val: valid/images # val images (relative to 'path') 223 images +test: # test images (relative to 'path') + +# Classes +names: + 0: negative + 1: positive + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/brain-tumor.zip diff --git a/ultralytics/cfg/datasets/carparts-seg.yaml b/ultralytics/cfg/datasets/carparts-seg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9f15f9b06625e2a2575d24da262d317a9394e71a --- /dev/null +++ b/ultralytics/cfg/datasets/carparts-seg.yaml @@ -0,0 +1,44 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Carparts-seg dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/segment/carparts-seg/ +# Example usage: yolo train data=carparts-seg.yaml +# parent +# ├── ultralytics +# └── datasets +# └── carparts-seg ← downloads here (132 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
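+# This is an instance-segmentation dataset (note the /datasets/segment/ docs URL above). A minimal training
+# sketch, assuming a segmentation-capable checkpoint from upstream ultralytics rather than the detection-only
+# YOLOv12 release weights (e.g. yolov8n-seg.pt):
+#   from ultralytics import YOLO
+#   YOLO("yolov8n-seg.pt").train(data="carparts-seg.yaml", epochs=100, imgsz=640)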
+path: ../datasets/carparts-seg # dataset root dir +train: train/images # train images (relative to 'path') 3516 images +val: valid/images # val images (relative to 'path') 276 images +test: test/images # test images (relative to 'path') 401 images + +# Classes +names: + 0: back_bumper + 1: back_door + 2: back_glass + 3: back_left_door + 4: back_left_light + 5: back_light + 6: back_right_door + 7: back_right_light + 8: front_bumper + 9: front_door + 10: front_glass + 11: front_left_door + 12: front_left_light + 13: front_light + 14: front_right_door + 15: front_right_light + 16: hood + 17: left_mirror + 18: object + 19: right_mirror + 20: tailgate + 21: trunk + 22: wheel + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/carparts-seg.zip diff --git a/ultralytics/cfg/datasets/coco-pose.yaml b/ultralytics/cfg/datasets/coco-pose.yaml new file mode 100644 index 0000000000000000000000000000000000000000..353dcd721b0ed5d9a437775a121ee63cf6d916e7 --- /dev/null +++ b/ultralytics/cfg/datasets/coco-pose.yaml @@ -0,0 +1,39 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# COCO 2017 Keypoints dataset https://cocodataset.org by Microsoft +# Documentation: https://docs.ultralytics.com/datasets/pose/coco/ +# Example usage: yolo train data=coco-pose.yaml +# parent +# ├── ultralytics +# └── datasets +# └── coco-pose ← downloads here (20.1 GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/coco-pose # dataset root dir +train: train2017.txt # train images (relative to 'path') 56599 images +val: val2017.txt # val images (relative to 'path') 2346 images +test: test-dev2017.txt # 20288 of 40670 images, submit to https://codalab.lisn.upsaclay.fr/competitions/7403 + +# Keypoints +kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) +flip_idx: [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15] + +# Classes +names: + 0: person + +# Download script/URL (optional) +download: | + from ultralytics.utils.downloads import download + from pathlib import Path + + # Download labels + dir = Path(yaml['path']) # dataset root dir + url = 'https://github.com/ultralytics/assets/releases/download/v0.0.0/' + urls = [url + 'coco2017labels-pose.zip'] # labels + download(urls, dir=dir.parent) + # Download data + urls = ['http://images.cocodataset.org/zips/train2017.zip', # 19G, 118k images + 'http://images.cocodataset.org/zips/val2017.zip', # 1G, 5k images + 'http://images.cocodataset.org/zips/test2017.zip'] # 7G, 41k images (optional) + download(urls, dir=dir / 'images', threads=3) diff --git a/ultralytics/cfg/datasets/coco.yaml b/ultralytics/cfg/datasets/coco.yaml new file mode 100644 index 0000000000000000000000000000000000000000..255b2af1a10999a2271bc43bff650077e6ab020a --- /dev/null +++ b/ultralytics/cfg/datasets/coco.yaml @@ -0,0 +1,115 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# COCO 2017 dataset https://cocodataset.org by Microsoft +# Documentation: https://docs.ultralytics.com/datasets/detect/coco/ +# Example usage: yolo train data=coco.yaml +# parent +# ├── ultralytics +# └── datasets +# └── coco ← downloads here (20.1 GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
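+# The multi-line 'download' field below is Python, not YAML: when the dataset folder is missing, the
+# ultralytics data checker executes it with a variable named 'yaml' bound to this file's parsed contents,
+# which is why the script can read yaml['path']. A rough standalone equivalent (illustrative sketch only):
+#   import yaml as pyyaml
+#   data = pyyaml.safe_load(open("ultralytics/cfg/datasets/coco.yaml"))
+#   exec(data["download"], {"yaml": data})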
+path: ../datasets/coco # dataset root dir
+train: train2017.txt # train images (relative to 'path') 118287 images
+val: val2017.txt # val images (relative to 'path') 5000 images
+test: test-dev2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794
+
+# Classes
+names:
+ 0: person
+ 1: bicycle
+ 2: car
+ 3: motorcycle
+ 4: airplane
+ 5: bus
+ 6: train
+ 7: truck
+ 8: boat
+ 9: traffic light
+ 10: fire hydrant
+ 11: stop sign
+ 12: parking meter
+ 13: bench
+ 14: bird
+ 15: cat
+ 16: dog
+ 17: horse
+ 18: sheep
+ 19: cow
+ 20: elephant
+ 21: bear
+ 22: zebra
+ 23: giraffe
+ 24: backpack
+ 25: umbrella
+ 26: handbag
+ 27: tie
+ 28: suitcase
+ 29: frisbee
+ 30: skis
+ 31: snowboard
+ 32: sports ball
+ 33: kite
+ 34: baseball bat
+ 35: baseball glove
+ 36: skateboard
+ 37: surfboard
+ 38: tennis racket
+ 39: bottle
+ 40: wine glass
+ 41: cup
+ 42: fork
+ 43: knife
+ 44: spoon
+ 45: bowl
+ 46: banana
+ 47: apple
+ 48: sandwich
+ 49: orange
+ 50: broccoli
+ 51: carrot
+ 52: hot dog
+ 53: pizza
+ 54: donut
+ 55: cake
+ 56: chair
+ 57: couch
+ 58: potted plant
+ 59: bed
+ 60: dining table
+ 61: toilet
+ 62: tv
+ 63: laptop
+ 64: mouse
+ 65: remote
+ 66: keyboard
+ 67: cell phone
+ 68: microwave
+ 69: oven
+ 70: toaster
+ 71: sink
+ 72: refrigerator
+ 73: book
+ 74: clock
+ 75: vase
+ 76: scissors
+ 77: teddy bear
+ 78: hair drier
+ 79: toothbrush
+
+# Download script/URL (optional)
+download: |
+ from ultralytics.utils.downloads import download
+ from pathlib import Path
+
+ # Download labels
+ segments = True # segment or box labels
+ dir = Path(yaml['path']) # dataset root dir
+ url = 'https://github.com/ultralytics/assets/releases/download/v0.0.0/'
+ urls = [url + ('coco2017labels-segments.zip' if segments else 'coco2017labels.zip')] # labels
+ download(urls, dir=dir.parent)
+ # Download data
+ urls = ['http://images.cocodataset.org/zips/train2017.zip', # 19G, 118k images
+ 'http://images.cocodataset.org/zips/val2017.zip', # 1G, 5k images
+ 'http://images.cocodataset.org/zips/test2017.zip'] # 7G, 41k images (optional)
+ download(urls, dir=dir / 'images', threads=3)
diff --git a/ultralytics/cfg/datasets/coco128-seg.yaml b/ultralytics/cfg/datasets/coco128-seg.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b023c676300db3ca9909839b5cdf7ac709ba5949
--- /dev/null
+++ b/ultralytics/cfg/datasets/coco128-seg.yaml
@@ -0,0 +1,101 @@
+# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
+
+# COCO128-seg dataset https://www.kaggle.com/datasets/ultralytics/coco128 (first 128 images from COCO train2017) by Ultralytics
+# Documentation: https://docs.ultralytics.com/datasets/segment/coco/
+# Example usage: yolo train data=coco128.yaml
+# parent
+# ├── ultralytics
+# └── datasets
+# └── coco128-seg ← downloads here (7 MB)
+
+# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
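+# Label files here contain one polygon per object, "class x1 y1 x2 y2 x3 y3 ...", with all coordinates
+# normalized to image width/height, e.g. "0 0.68 0.24 0.71 0.31 0.65 0.33" (illustrative values only).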
+path: ../datasets/coco128-seg # dataset root dir +train: images/train2017 # train images (relative to 'path') 128 images +val: images/train2017 # val images (relative to 'path') 128 images +test: # test images (optional) + +# Classes +names: + 0: person + 1: bicycle + 2: car + 3: motorcycle + 4: airplane + 5: bus + 6: train + 7: truck + 8: boat + 9: traffic light + 10: fire hydrant + 11: stop sign + 12: parking meter + 13: bench + 14: bird + 15: cat + 16: dog + 17: horse + 18: sheep + 19: cow + 20: elephant + 21: bear + 22: zebra + 23: giraffe + 24: backpack + 25: umbrella + 26: handbag + 27: tie + 28: suitcase + 29: frisbee + 30: skis + 31: snowboard + 32: sports ball + 33: kite + 34: baseball bat + 35: baseball glove + 36: skateboard + 37: surfboard + 38: tennis racket + 39: bottle + 40: wine glass + 41: cup + 42: fork + 43: knife + 44: spoon + 45: bowl + 46: banana + 47: apple + 48: sandwich + 49: orange + 50: broccoli + 51: carrot + 52: hot dog + 53: pizza + 54: donut + 55: cake + 56: chair + 57: couch + 58: potted plant + 59: bed + 60: dining table + 61: toilet + 62: tv + 63: laptop + 64: mouse + 65: remote + 66: keyboard + 67: cell phone + 68: microwave + 69: oven + 70: toaster + 71: sink + 72: refrigerator + 73: book + 74: clock + 75: vase + 76: scissors + 77: teddy bear + 78: hair drier + 79: toothbrush + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/coco128-seg.zip diff --git a/ultralytics/cfg/datasets/coco128.yaml b/ultralytics/cfg/datasets/coco128.yaml new file mode 100644 index 0000000000000000000000000000000000000000..12ff0511bcd0df62be6f05762743543e1eb21524 --- /dev/null +++ b/ultralytics/cfg/datasets/coco128.yaml @@ -0,0 +1,101 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# COCO128 dataset https://www.kaggle.com/datasets/ultralytics/coco128 (first 128 images from COCO train2017) by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/detect/coco/ +# Example usage: yolo train data=coco128.yaml +# parent +# ├── ultralytics +# └── datasets +# └── coco128 ← downloads here (7 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
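+# As in coco128-seg above, 'train' and 'val' both point to images/train2017: this 128-image subset is meant
+# for quick smoke tests (a model that cannot overfit it usually indicates a pipeline bug), not for reporting
+# metrics. A minimal sanity-check sketch using this repo's model config:
+#   from ultralytics import YOLO
+#   YOLO("yolov12n.yaml").train(data="coco128.yaml", epochs=3, imgsz=640)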
+path: ../datasets/coco128 # dataset root dir +train: images/train2017 # train images (relative to 'path') 128 images +val: images/train2017 # val images (relative to 'path') 128 images +test: # test images (optional) + +# Classes +names: + 0: person + 1: bicycle + 2: car + 3: motorcycle + 4: airplane + 5: bus + 6: train + 7: truck + 8: boat + 9: traffic light + 10: fire hydrant + 11: stop sign + 12: parking meter + 13: bench + 14: bird + 15: cat + 16: dog + 17: horse + 18: sheep + 19: cow + 20: elephant + 21: bear + 22: zebra + 23: giraffe + 24: backpack + 25: umbrella + 26: handbag + 27: tie + 28: suitcase + 29: frisbee + 30: skis + 31: snowboard + 32: sports ball + 33: kite + 34: baseball bat + 35: baseball glove + 36: skateboard + 37: surfboard + 38: tennis racket + 39: bottle + 40: wine glass + 41: cup + 42: fork + 43: knife + 44: spoon + 45: bowl + 46: banana + 47: apple + 48: sandwich + 49: orange + 50: broccoli + 51: carrot + 52: hot dog + 53: pizza + 54: donut + 55: cake + 56: chair + 57: couch + 58: potted plant + 59: bed + 60: dining table + 61: toilet + 62: tv + 63: laptop + 64: mouse + 65: remote + 66: keyboard + 67: cell phone + 68: microwave + 69: oven + 70: toaster + 71: sink + 72: refrigerator + 73: book + 74: clock + 75: vase + 76: scissors + 77: teddy bear + 78: hair drier + 79: toothbrush + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/coco128.zip diff --git a/ultralytics/cfg/datasets/coco8-pose.yaml b/ultralytics/cfg/datasets/coco8-pose.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3e8af1e344804571664f5be7de288ba5ff7c3822 --- /dev/null +++ b/ultralytics/cfg/datasets/coco8-pose.yaml @@ -0,0 +1,26 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# COCO8-pose dataset (first 8 images from COCO train2017) by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/pose/coco8-pose/ +# Example usage: yolo train data=coco8-pose.yaml +# parent +# ├── ultralytics +# └── datasets +# └── coco8-pose ← downloads here (1 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/coco8-pose # dataset root dir +train: images/train # train images (relative to 'path') 4 images +val: images/val # val images (relative to 'path') 4 images +test: # test images (optional) + +# Keypoints +kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) +flip_idx: [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15] + +# Classes +names: + 0: person + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/coco8-pose.zip diff --git a/ultralytics/cfg/datasets/coco8-seg.yaml b/ultralytics/cfg/datasets/coco8-seg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1ea6b31004cbf5fe33914387532d7e46b1967aa8 --- /dev/null +++ b/ultralytics/cfg/datasets/coco8-seg.yaml @@ -0,0 +1,101 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# COCO8-seg dataset (first 8 images from COCO train2017) by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/segment/coco8-seg/ +# Example usage: yolo train data=coco8-seg.yaml +# parent +# ├── ultralytics +# └── datasets +# └── coco8-seg ← downloads here (1 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/coco8-seg # dataset root dir +train: images/train # train images (relative to 'path') 4 images +val: images/val # val images (relative to 'path') 4 images +test: # test images (optional) + +# Classes +names: + 0: person + 1: bicycle + 2: car + 3: motorcycle + 4: airplane + 5: bus + 6: train + 7: truck + 8: boat + 9: traffic light + 10: fire hydrant + 11: stop sign + 12: parking meter + 13: bench + 14: bird + 15: cat + 16: dog + 17: horse + 18: sheep + 19: cow + 20: elephant + 21: bear + 22: zebra + 23: giraffe + 24: backpack + 25: umbrella + 26: handbag + 27: tie + 28: suitcase + 29: frisbee + 30: skis + 31: snowboard + 32: sports ball + 33: kite + 34: baseball bat + 35: baseball glove + 36: skateboard + 37: surfboard + 38: tennis racket + 39: bottle + 40: wine glass + 41: cup + 42: fork + 43: knife + 44: spoon + 45: bowl + 46: banana + 47: apple + 48: sandwich + 49: orange + 50: broccoli + 51: carrot + 52: hot dog + 53: pizza + 54: donut + 55: cake + 56: chair + 57: couch + 58: potted plant + 59: bed + 60: dining table + 61: toilet + 62: tv + 63: laptop + 64: mouse + 65: remote + 66: keyboard + 67: cell phone + 68: microwave + 69: oven + 70: toaster + 71: sink + 72: refrigerator + 73: book + 74: clock + 75: vase + 76: scissors + 77: teddy bear + 78: hair drier + 79: toothbrush + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/coco8-seg.zip diff --git a/ultralytics/cfg/datasets/coco8.yaml b/ultralytics/cfg/datasets/coco8.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8200738b46d0c5414afc6b2c68027bcfe40b739c --- /dev/null +++ b/ultralytics/cfg/datasets/coco8.yaml @@ -0,0 +1,101 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# COCO8 dataset (first 8 images from COCO train2017) by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/detect/coco8/ +# Example usage: yolo train data=coco8.yaml +# parent +# ├── ultralytics +# └── datasets +# └── coco8 ← downloads here (1 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
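+# Detection label files use one object per line, "class x_center y_center width height", all normalized to
+# [0, 1]; this is exactly the format that the VOC and VisDrone conversion scripts earlier in this patch emit.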
+path: ../datasets/coco8 # dataset root dir +train: images/train # train images (relative to 'path') 4 images +val: images/val # val images (relative to 'path') 4 images +test: # test images (optional) + +# Classes +names: + 0: person + 1: bicycle + 2: car + 3: motorcycle + 4: airplane + 5: bus + 6: train + 7: truck + 8: boat + 9: traffic light + 10: fire hydrant + 11: stop sign + 12: parking meter + 13: bench + 14: bird + 15: cat + 16: dog + 17: horse + 18: sheep + 19: cow + 20: elephant + 21: bear + 22: zebra + 23: giraffe + 24: backpack + 25: umbrella + 26: handbag + 27: tie + 28: suitcase + 29: frisbee + 30: skis + 31: snowboard + 32: sports ball + 33: kite + 34: baseball bat + 35: baseball glove + 36: skateboard + 37: surfboard + 38: tennis racket + 39: bottle + 40: wine glass + 41: cup + 42: fork + 43: knife + 44: spoon + 45: bowl + 46: banana + 47: apple + 48: sandwich + 49: orange + 50: broccoli + 51: carrot + 52: hot dog + 53: pizza + 54: donut + 55: cake + 56: chair + 57: couch + 58: potted plant + 59: bed + 60: dining table + 61: toilet + 62: tv + 63: laptop + 64: mouse + 65: remote + 66: keyboard + 67: cell phone + 68: microwave + 69: oven + 70: toaster + 71: sink + 72: refrigerator + 73: book + 74: clock + 75: vase + 76: scissors + 77: teddy bear + 78: hair drier + 79: toothbrush + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/coco8.zip diff --git a/ultralytics/cfg/datasets/crack-seg.yaml b/ultralytics/cfg/datasets/crack-seg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..11bdd5f575fe7d8d9046636cb595ce707e8601a7 --- /dev/null +++ b/ultralytics/cfg/datasets/crack-seg.yaml @@ -0,0 +1,22 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Crack-seg dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/segment/crack-seg/ +# Example usage: yolo train data=crack-seg.yaml +# parent +# ├── ultralytics +# └── datasets +# └── crack-seg ← downloads here (91.2 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/crack-seg # dataset root dir +train: train/images # train images (relative to 'path') 3717 images +val: valid/images # val images (relative to 'path') 112 images +test: test/images # test images (relative to 'path') 200 images + +# Classes +names: + 0: crack + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/crack-seg.zip diff --git a/ultralytics/cfg/datasets/dog-pose.yaml b/ultralytics/cfg/datasets/dog-pose.yaml new file mode 100644 index 0000000000000000000000000000000000000000..447e542ce6c124533e25c7fa2a5caed88570d4ec --- /dev/null +++ b/ultralytics/cfg/datasets/dog-pose.yaml @@ -0,0 +1,24 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Dogs dataset http://vision.stanford.edu/aditya86/ImageNetDogs/ by Stanford +# Documentation: https://docs.ultralytics.com/datasets/pose/dog-pose/ +# Example usage: yolo train data=dog-pose.yaml +# parent +# ├── ultralytics +# └── datasets +# └── dog-pose ← downloads here (337 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
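+# Pose dataset: each label stores 24 keypoints per instance as (x, y, visibility) triplets, matching
+# kpt_shape below. A minimal training sketch, assuming a pose-capable checkpoint from upstream ultralytics
+# (e.g. yolov8n-pose.pt; the YOLOv12 release assets cover detection only):
+#   from ultralytics import YOLO
+#   YOLO("yolov8n-pose.pt").train(data="dog-pose.yaml", epochs=100, imgsz=640)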
+path: ../datasets/dog-pose # dataset root dir +train: train # train images (relative to 'path') 6773 images +val: val # val images (relative to 'path') 1703 images + +# Keypoints +kpt_shape: [24, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) + +# Classes +names: + 0: dog + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/dog-pose.zip diff --git a/ultralytics/cfg/datasets/dota8.yaml b/ultralytics/cfg/datasets/dota8.yaml new file mode 100644 index 0000000000000000000000000000000000000000..486d9e2effbd0e2161ff982d2df2aea16b6d9fa0 --- /dev/null +++ b/ultralytics/cfg/datasets/dota8.yaml @@ -0,0 +1,35 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# DOTA8 dataset 8 images from split DOTAv1 dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/obb/dota8/ +# Example usage: yolo train model=yolov8n-obb.pt data=dota8.yaml +# parent +# ├── ultralytics +# └── datasets +# └── dota8 ← downloads here (1MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/dota8 # dataset root dir +train: images/train # train images (relative to 'path') 4 images +val: images/val # val images (relative to 'path') 4 images + +# Classes for DOTA 1.0 +names: + 0: plane + 1: ship + 2: storage tank + 3: baseball diamond + 4: tennis court + 5: basketball court + 6: ground track field + 7: harbor + 8: bridge + 9: large vehicle + 10: small vehicle + 11: helicopter + 12: roundabout + 13: soccer ball field + 14: swimming pool + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/dota8.zip diff --git a/ultralytics/cfg/datasets/hand-keypoints.yaml b/ultralytics/cfg/datasets/hand-keypoints.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6d2f765c789c828efa2b993249c238e466985eb3 --- /dev/null +++ b/ultralytics/cfg/datasets/hand-keypoints.yaml @@ -0,0 +1,26 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Hand Keypoints dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/pose/hand-keypoints/ +# Example usage: yolo train data=hand-keypoints.yaml +# parent +# ├── ultralytics +# └── datasets +# └── hand-keypoints ← downloads here (369 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/hand-keypoints # dataset root dir +train: train # train images (relative to 'path') 18776 images +val: val # val images (relative to 'path') 7992 images + +# Keypoints +kpt_shape: [21, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) +flip_idx: + [0, 1, 2, 4, 3, 10, 11, 12, 13, 14, 5, 6, 7, 8, 9, 15, 16, 17, 18, 19, 20] + +# Classes +names: + 0: hand + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/hand-keypoints.zip diff --git a/ultralytics/cfg/datasets/lvis.yaml b/ultralytics/cfg/datasets/lvis.yaml new file mode 100644 index 0000000000000000000000000000000000000000..22030ac90797d95686282fcaba2b2c12b4a0fc00 --- /dev/null +++ b/ultralytics/cfg/datasets/lvis.yaml @@ -0,0 +1,1236 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# LVIS dataset http://www.lvisdataset.org by Facebook AI Research. 
+# Documentation: https://docs.ultralytics.com/datasets/detect/lvis/ +# Example usage: yolo train data=lvis.yaml +# parent +# ├── ultralytics +# └── datasets +# └── lvis ← downloads here (20.1 GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/lvis # dataset root dir +train: train.txt # train images (relative to 'path') 100170 images +val: val.txt # val images (relative to 'path') 19809 images +minival: minival.txt # minival images (relative to 'path') 5000 images + +names: + 0: aerosol can/spray can + 1: air conditioner + 2: airplane/aeroplane + 3: alarm clock + 4: alcohol/alcoholic beverage + 5: alligator/gator + 6: almond + 7: ambulance + 8: amplifier + 9: anklet/ankle bracelet + 10: antenna/aerial/transmitting aerial + 11: apple + 12: applesauce + 13: apricot + 14: apron + 15: aquarium/fish tank + 16: arctic/arctic type of shoe/galosh/golosh/rubber/rubber type of shoe/gumshoe + 17: armband + 18: armchair + 19: armoire + 20: armor/armour + 21: artichoke + 22: trash can/garbage can/wastebin/dustbin/trash barrel/trash bin + 23: ashtray + 24: asparagus + 25: atomizer/atomiser/spray/sprayer/nebulizer/nebuliser + 26: avocado + 27: award/accolade + 28: awning + 29: ax/axe + 30: baboon + 31: baby buggy/baby carriage/perambulator/pram/stroller + 32: basketball backboard + 33: backpack/knapsack/packsack/rucksack/haversack + 34: handbag/purse/pocketbook + 35: suitcase/baggage/luggage + 36: bagel/beigel + 37: bagpipe + 38: baguet/baguette + 39: bait/lure + 40: ball + 41: ballet skirt/tutu + 42: balloon + 43: bamboo + 44: banana + 45: Band Aid + 46: bandage + 47: bandanna/bandana + 48: banjo + 49: banner/streamer + 50: barbell + 51: barge + 52: barrel/cask + 53: barrette + 54: barrow/garden cart/lawn cart/wheelbarrow + 55: baseball base + 56: baseball + 57: baseball bat + 58: baseball cap/jockey cap/golf cap + 59: baseball glove/baseball mitt + 60: basket/handbasket + 61: basketball + 62: bass horn/sousaphone/tuba + 63: bat/bat animal + 64: bath mat + 65: bath towel + 66: bathrobe + 67: bathtub/bathing tub + 68: batter/batter food + 69: battery + 70: beachball + 71: bead + 72: bean curd/tofu + 73: beanbag + 74: beanie/beany + 75: bear + 76: bed + 77: bedpan + 78: bedspread/bedcover/bed covering/counterpane/spread + 79: cow + 80: beef/beef food/boeuf/boeuf food + 81: beeper/pager + 82: beer bottle + 83: beer can + 84: beetle + 85: bell + 86: bell pepper/capsicum + 87: belt + 88: belt buckle + 89: bench + 90: beret + 91: bib + 92: Bible + 93: bicycle/bike/bike bicycle + 94: visor/vizor + 95: billboard + 96: binder/ring-binder + 97: binoculars/field glasses/opera glasses + 98: bird + 99: birdfeeder + 100: birdbath + 101: birdcage + 102: birdhouse + 103: birthday cake + 104: birthday card + 105: pirate flag + 106: black sheep + 107: blackberry + 108: blackboard/chalkboard + 109: blanket + 110: blazer/sport jacket/sport coat/sports jacket/sports coat + 111: blender/liquidizer/liquidiser + 112: blimp + 113: blinker/flasher + 114: blouse + 115: blueberry + 116: gameboard + 117: boat/ship/ship boat + 118: bob/bobber/bobfloat + 119: bobbin/spool/reel + 120: bobby pin/hairgrip + 121: boiled egg/coddled egg + 122: bolo tie/bolo/bola tie/bola + 123: deadbolt + 124: bolt + 125: bonnet + 126: book + 127: bookcase + 128: booklet/brochure/leaflet/pamphlet + 129: bookmark/bookmarker + 130: boom microphone/microphone boom + 131: boot + 132: bottle + 133: bottle opener + 134: bouquet + 135: bow/bow weapon + 136: bow/bow 
decorative ribbons + 137: bow-tie/bowtie + 138: bowl + 139: pipe bowl + 140: bowler hat/bowler/derby hat/derby/plug hat + 141: bowling ball + 142: box + 143: boxing glove + 144: suspenders + 145: bracelet/bangle + 146: brass plaque + 147: brassiere/bra/bandeau + 148: bread-bin/breadbox + 149: bread + 150: breechcloth/breechclout/loincloth + 151: bridal gown/wedding gown/wedding dress + 152: briefcase + 153: broccoli + 154: broach + 155: broom + 156: brownie + 157: brussels sprouts + 158: bubble gum + 159: bucket/pail + 160: horse buggy + 161: horned cow + 162: bulldog + 163: bulldozer/dozer + 164: bullet train + 165: bulletin board/notice board + 166: bulletproof vest + 167: bullhorn/megaphone + 168: bun/roll + 169: bunk bed + 170: buoy + 171: burrito + 172: bus/bus vehicle/autobus/charabanc/double-decker/motorbus/motorcoach + 173: business card + 174: butter + 175: butterfly + 176: button + 177: cab/cab taxi/taxi/taxicab + 178: cabana + 179: cabin car/caboose + 180: cabinet + 181: locker/storage locker + 182: cake + 183: calculator + 184: calendar + 185: calf + 186: camcorder + 187: camel + 188: camera + 189: camera lens + 190: camper/camper vehicle/camping bus/motor home + 191: can/tin can + 192: can opener/tin opener + 193: candle/candlestick + 194: candle holder + 195: candy bar + 196: candy cane + 197: walking cane + 198: canister/canister + 199: canoe + 200: cantaloup/cantaloupe + 201: canteen + 202: cap/cap headwear + 203: bottle cap/cap/cap container lid + 204: cape + 205: cappuccino/coffee cappuccino + 206: car/car automobile/auto/auto automobile/automobile + 207: railcar/railcar part of a train/railway car/railway car part of a train/railroad car/railroad car part of a train + 208: elevator car + 209: car battery/automobile battery + 210: identity card + 211: card + 212: cardigan + 213: cargo ship/cargo vessel + 214: carnation + 215: horse carriage + 216: carrot + 217: tote bag + 218: cart + 219: carton + 220: cash register/register/register for cash transactions + 221: casserole + 222: cassette + 223: cast/plaster cast/plaster bandage + 224: cat + 225: cauliflower + 226: cayenne/cayenne spice/cayenne pepper/cayenne pepper spice/red pepper/red pepper spice + 227: CD player + 228: celery + 229: cellular telephone/cellular phone/cellphone/mobile phone/smart phone + 230: chain mail/ring mail/chain armor/chain armour/ring armor/ring armour + 231: chair + 232: chaise longue/chaise/daybed + 233: chalice + 234: chandelier + 235: chap + 236: checkbook/chequebook + 237: checkerboard + 238: cherry + 239: chessboard + 240: chicken/chicken animal + 241: chickpea/garbanzo + 242: chili/chili vegetable/chili pepper/chili pepper vegetable/chilli/chilli vegetable/chilly/chilly vegetable/chile/chile vegetable + 243: chime/gong + 244: chinaware + 245: crisp/crisp potato chip/potato chip + 246: poker chip + 247: chocolate bar + 248: chocolate cake + 249: chocolate milk + 250: chocolate mousse + 251: choker/collar/neckband + 252: chopping board/cutting board/chopping block + 253: chopstick + 254: Christmas tree + 255: slide + 256: cider/cyder + 257: cigar box + 258: cigarette + 259: cigarette case/cigarette pack + 260: cistern/water tank + 261: clarinet + 262: clasp + 263: cleansing agent/cleanser/cleaner + 264: cleat/cleat for securing rope + 265: clementine + 266: clip + 267: clipboard + 268: clippers/clippers for plants + 269: cloak + 270: clock/timepiece/timekeeper + 271: clock tower + 272: clothes hamper/laundry basket/clothes basket + 273: clothespin/clothes peg + 274: clutch bag + 275: coaster 
+ 276: coat + 277: coat hanger/clothes hanger/dress hanger + 278: coatrack/hatrack + 279: cock/rooster + 280: cockroach + 281: cocoa/cocoa beverage/hot chocolate/hot chocolate beverage/drinking chocolate + 282: coconut/cocoanut + 283: coffee maker/coffee machine + 284: coffee table/cocktail table + 285: coffeepot + 286: coil + 287: coin + 288: colander/cullender + 289: coleslaw/slaw + 290: coloring material/colouring material + 291: combination lock + 292: pacifier/teething ring + 293: comic book + 294: compass + 295: computer keyboard/keyboard/keyboard computer + 296: condiment + 297: cone/traffic cone + 298: control/controller + 299: convertible/convertible automobile + 300: sofa bed + 301: cooker + 302: cookie/cooky/biscuit/biscuit cookie + 303: cooking utensil + 304: cooler/cooler for food/ice chest + 305: cork/cork bottle plug/bottle cork + 306: corkboard + 307: corkscrew/bottle screw + 308: edible corn/corn/maize + 309: cornbread + 310: cornet/horn/trumpet + 311: cornice/valance/valance board/pelmet + 312: cornmeal + 313: corset/girdle + 314: costume + 315: cougar/puma/catamount/mountain lion/panther + 316: coverall + 317: cowbell + 318: cowboy hat/ten-gallon hat + 319: crab/crab animal + 320: crabmeat + 321: cracker + 322: crape/crepe/French pancake + 323: crate + 324: crayon/wax crayon + 325: cream pitcher + 326: crescent roll/croissant + 327: crib/cot + 328: crock pot/earthenware jar + 329: crossbar + 330: crouton + 331: crow + 332: crowbar/wrecking bar/pry bar + 333: crown + 334: crucifix + 335: cruise ship/cruise liner + 336: police cruiser/patrol car/police car/squad car + 337: crumb + 338: crutch + 339: cub/cub animal + 340: cube/square block + 341: cucumber/cuke + 342: cufflink + 343: cup + 344: trophy cup + 345: cupboard/closet + 346: cupcake + 347: hair curler/hair roller/hair crimper + 348: curling iron + 349: curtain/drapery + 350: cushion + 351: cylinder + 352: cymbal + 353: dagger + 354: dalmatian + 355: dartboard + 356: date/date fruit + 357: deck chair/beach chair + 358: deer/cervid + 359: dental floss/floss + 360: desk + 361: detergent + 362: diaper + 363: diary/journal + 364: die/dice + 365: dinghy/dory/rowboat + 366: dining table + 367: tux/tuxedo + 368: dish + 369: dish antenna + 370: dishrag/dishcloth + 371: dishtowel/tea towel + 372: dishwasher/dishwashing machine + 373: dishwasher detergent/dishwashing detergent/dishwashing liquid/dishsoap + 374: dispenser + 375: diving board + 376: Dixie cup/paper cup + 377: dog + 378: dog collar + 379: doll + 380: dollar/dollar bill/one dollar bill + 381: dollhouse/doll's house + 382: dolphin + 383: domestic ass/donkey + 384: doorknob/doorhandle + 385: doormat/welcome mat + 386: doughnut/donut + 387: dove + 388: dragonfly + 389: drawer + 390: underdrawers/boxers/boxershorts + 391: dress/frock + 392: dress hat/high hat/opera hat/silk hat/top hat + 393: dress suit + 394: dresser + 395: drill + 396: drone + 397: dropper/eye dropper + 398: drum/drum musical instrument + 399: drumstick + 400: duck + 401: duckling + 402: duct tape + 403: duffel bag/duffle bag/duffel/duffle + 404: dumbbell + 405: dumpster + 406: dustpan + 407: eagle + 408: earphone/earpiece/headphone + 409: earplug + 410: earring + 411: easel + 412: eclair + 413: eel + 414: egg/eggs + 415: egg roll/spring roll + 416: egg yolk/yolk/yolk egg + 417: eggbeater/eggwhisk + 418: eggplant/aubergine + 419: electric chair + 420: refrigerator + 421: elephant + 422: elk/moose + 423: envelope + 424: eraser + 425: escargot + 426: eyepatch + 427: falcon + 428: fan + 429: 
faucet/spigot/tap + 430: fedora + 431: ferret + 432: Ferris wheel + 433: ferry/ferryboat + 434: fig/fig fruit + 435: fighter jet/fighter aircraft/attack aircraft + 436: figurine + 437: file cabinet/filing cabinet + 438: file/file tool + 439: fire alarm/smoke alarm + 440: fire engine/fire truck + 441: fire extinguisher/extinguisher + 442: fire hose + 443: fireplace + 444: fireplug/fire hydrant/hydrant + 445: first-aid kit + 446: fish + 447: fish/fish food + 448: fishbowl/goldfish bowl + 449: fishing rod/fishing pole + 450: flag + 451: flagpole/flagstaff + 452: flamingo + 453: flannel + 454: flap + 455: flash/flashbulb + 456: flashlight/torch + 457: fleece + 458: flip-flop/flip-flop sandal + 459: flipper/flipper footwear/fin/fin footwear + 460: flower arrangement/floral arrangement + 461: flute glass/champagne flute + 462: foal + 463: folding chair + 464: food processor + 465: football/football American + 466: football helmet + 467: footstool/footrest + 468: fork + 469: forklift + 470: freight car + 471: French toast + 472: freshener/air freshener + 473: frisbee + 474: frog/toad/toad frog + 475: fruit juice + 476: frying pan/frypan/skillet + 477: fudge + 478: funnel + 479: futon + 480: gag/muzzle + 481: garbage + 482: garbage truck + 483: garden hose + 484: gargle/mouthwash + 485: gargoyle + 486: garlic/ail + 487: gasmask/respirator/gas helmet + 488: gazelle + 489: gelatin/jelly + 490: gemstone + 491: generator + 492: giant panda/panda/panda bear + 493: gift wrap + 494: ginger/gingerroot + 495: giraffe + 496: cincture/sash/waistband/waistcloth + 497: glass/glass drink container/drinking glass + 498: globe + 499: glove + 500: goat + 501: goggles + 502: goldfish + 503: golf club/golf-club + 504: golfcart + 505: gondola/gondola boat + 506: goose + 507: gorilla + 508: gourd + 509: grape + 510: grater + 511: gravestone/headstone/tombstone + 512: gravy boat/gravy holder + 513: green bean + 514: green onion/spring onion/scallion + 515: griddle + 516: grill/grille/grillwork/radiator grille + 517: grits/hominy grits + 518: grizzly/grizzly bear + 519: grocery bag + 520: guitar + 521: gull/seagull + 522: gun + 523: hairbrush + 524: hairnet + 525: hairpin + 526: halter top + 527: ham/jambon/gammon + 528: hamburger/beefburger/burger + 529: hammer + 530: hammock + 531: hamper + 532: hamster + 533: hair dryer + 534: hand glass/hand mirror + 535: hand towel/face towel + 536: handcart/pushcart/hand truck + 537: handcuff + 538: handkerchief + 539: handle/grip/handgrip + 540: handsaw/carpenter's saw + 541: hardback book/hardcover book + 542: harmonium/organ/organ musical instrument/reed organ/reed organ musical instrument + 543: hat + 544: hatbox + 545: veil + 546: headband + 547: headboard + 548: headlight/headlamp + 549: headscarf + 550: headset + 551: headstall/headstall for horses/headpiece/headpiece for horses + 552: heart + 553: heater/warmer + 554: helicopter + 555: helmet + 556: heron + 557: highchair/feeding chair + 558: hinge + 559: hippopotamus + 560: hockey stick + 561: hog/pig + 562: home plate/home plate baseball/home base/home base baseball + 563: honey + 564: fume hood/exhaust hood + 565: hook + 566: hookah/narghile/nargileh/sheesha/shisha/water pipe + 567: hornet + 568: horse + 569: hose/hosepipe + 570: hot-air balloon + 571: hotplate + 572: hot sauce + 573: hourglass + 574: houseboat + 575: hummingbird + 576: hummus/humus/hommos/hoummos/humous + 577: polar bear + 578: icecream + 579: popsicle + 580: ice maker + 581: ice pack/ice bag + 582: ice skate + 583: igniter/ignitor/lighter + 584: 
inhaler/inhalator + 585: iPod + 586: iron/iron for clothing/smoothing iron/smoothing iron for clothing + 587: ironing board + 588: jacket + 589: jam + 590: jar + 591: jean/blue jean/denim + 592: jeep/landrover + 593: jelly bean/jelly egg + 594: jersey/T-shirt/tee shirt + 595: jet plane/jet-propelled plane + 596: jewel/gem/precious stone + 597: jewelry/jewellery + 598: joystick + 599: jumpsuit + 600: kayak + 601: keg + 602: kennel/doghouse + 603: kettle/boiler + 604: key + 605: keycard + 606: kilt + 607: kimono + 608: kitchen sink + 609: kitchen table + 610: kite + 611: kitten/kitty + 612: kiwi fruit + 613: knee pad + 614: knife + 615: knitting needle + 616: knob + 617: knocker/knocker on a door/doorknocker + 618: koala/koala bear + 619: lab coat/laboratory coat + 620: ladder + 621: ladle + 622: ladybug/ladybeetle/ladybird beetle + 623: lamb/lamb animal + 624: lamb-chop/lambchop + 625: lamp + 626: lamppost + 627: lampshade + 628: lantern + 629: lanyard/laniard + 630: laptop computer/notebook computer + 631: lasagna/lasagne + 632: latch + 633: lawn mower + 634: leather + 635: legging/legging clothing/leging/leging clothing/leg covering + 636: Lego/Lego set + 637: legume + 638: lemon + 639: lemonade + 640: lettuce + 641: license plate/numberplate + 642: life buoy/lifesaver/life belt/life ring + 643: life jacket/life vest + 644: lightbulb + 645: lightning rod/lightning conductor + 646: lime + 647: limousine + 648: lion + 649: lip balm + 650: liquor/spirits/hard liquor/liqueur/cordial + 651: lizard + 652: log + 653: lollipop + 654: speaker/speaker stereo equipment + 655: loveseat + 656: machine gun + 657: magazine + 658: magnet + 659: mail slot + 660: mailbox/mailbox at home/letter box/letter box at home + 661: mallard + 662: mallet + 663: mammoth + 664: manatee + 665: mandarin orange + 666: manager/through + 667: manhole + 668: map + 669: marker + 670: martini + 671: mascot + 672: mashed potato + 673: masher + 674: mask/facemask + 675: mast + 676: mat/mat gym equipment/gym mat + 677: matchbox + 678: mattress + 679: measuring cup + 680: measuring stick/ruler/ruler measuring stick/measuring rod + 681: meatball + 682: medicine + 683: melon + 684: microphone + 685: microscope + 686: microwave oven + 687: milestone/milepost + 688: milk + 689: milk can + 690: milkshake + 691: minivan + 692: mint candy + 693: mirror + 694: mitten + 695: mixer/mixer kitchen tool/stand mixer + 696: money + 697: monitor/monitor computer equipment + 698: monkey + 699: motor + 700: motor scooter/scooter + 701: motor vehicle/automotive vehicle + 702: motorcycle + 703: mound/mound baseball/pitcher's mound + 704: mouse/mouse computer equipment/computer mouse + 705: mousepad + 706: muffin + 707: mug + 708: mushroom + 709: music stool/piano stool + 710: musical instrument/instrument/instrument musical + 711: nailfile + 712: napkin/table napkin/serviette + 713: neckerchief + 714: necklace + 715: necktie/tie/tie necktie + 716: needle + 717: nest + 718: newspaper/paper/paper newspaper + 719: newsstand + 720: nightshirt/nightwear/sleepwear/nightclothes + 721: nosebag/nosebag for animals/feedbag + 722: noseband/noseband for animals/nosepiece/nosepiece for animals + 723: notebook + 724: notepad + 725: nut + 726: nutcracker + 727: oar + 728: octopus/octopus food + 729: octopus/octopus animal + 730: oil lamp/kerosene lamp/kerosine lamp + 731: olive oil + 732: omelet/omelette + 733: onion + 734: orange/orange fruit + 735: orange juice + 736: ostrich + 737: ottoman/pouf/pouffe/hassock + 738: oven + 739: overalls/overalls clothing + 740: 
owl + 741: packet + 742: inkpad/inking pad/stamp pad + 743: pad + 744: paddle/boat paddle + 745: padlock + 746: paintbrush + 747: painting + 748: pajamas/pyjamas + 749: palette/pallet + 750: pan/pan for cooking/cooking pan + 751: pan/pan metal container + 752: pancake + 753: pantyhose + 754: papaya + 755: paper plate + 756: paper towel + 757: paperback book/paper-back book/softback book/soft-cover book + 758: paperweight + 759: parachute + 760: parakeet/parrakeet/parroket/paraquet/paroquet/parroquet + 761: parasail/parasail sports + 762: parasol/sunshade + 763: parchment + 764: parka/anorak + 765: parking meter + 766: parrot + 767: passenger car/passenger car part of a train/coach/coach part of a train + 768: passenger ship + 769: passport + 770: pastry + 771: patty/patty food + 772: pea/pea food + 773: peach + 774: peanut butter + 775: pear + 776: peeler/peeler tool for fruit and vegetables + 777: wooden leg/pegleg + 778: pegboard + 779: pelican + 780: pen + 781: pencil + 782: pencil box/pencil case + 783: pencil sharpener + 784: pendulum + 785: penguin + 786: pennant + 787: penny/penny coin + 788: pepper/peppercorn + 789: pepper mill/pepper grinder + 790: perfume + 791: persimmon + 792: person/baby/child/boy/girl/man/woman/human + 793: pet + 794: pew/pew church bench/church bench + 795: phonebook/telephone book/telephone directory + 796: phonograph record/phonograph recording/record/record phonograph recording + 797: piano + 798: pickle + 799: pickup truck + 800: pie + 801: pigeon + 802: piggy bank/penny bank + 803: pillow + 804: pin/pin non jewelry + 805: pineapple + 806: pinecone + 807: ping-pong ball + 808: pinwheel + 809: tobacco pipe + 810: pipe/piping + 811: pistol/handgun + 812: pita/pita bread/pocket bread + 813: pitcher/pitcher vessel for liquid/ewer + 814: pitchfork + 815: pizza + 816: place mat + 817: plate + 818: platter + 819: playpen + 820: pliers/plyers + 821: plow/plow farm equipment/plough/plough farm equipment + 822: plume + 823: pocket watch + 824: pocketknife + 825: poker/poker fire stirring tool/stove poker/fire hook + 826: pole/post + 827: polo shirt/sport shirt + 828: poncho + 829: pony + 830: pool table/billiard table/snooker table + 831: pop/pop soda/soda/soda pop/tonic/soft drink + 832: postbox/postbox public/mailbox/mailbox public + 833: postcard/postal card/mailing-card + 834: poster/placard + 835: pot + 836: flowerpot + 837: potato + 838: potholder + 839: pottery/clayware + 840: pouch + 841: power shovel/excavator/digger + 842: prawn/shrimp + 843: pretzel + 844: printer/printing machine + 845: projectile/projectile weapon/missile + 846: projector + 847: propeller/propellor + 848: prune + 849: pudding + 850: puffer/puffer fish/pufferfish/blowfish/globefish + 851: puffin + 852: pug-dog + 853: pumpkin + 854: puncher + 855: puppet/marionette + 856: puppy + 857: quesadilla + 858: quiche + 859: quilt/comforter + 860: rabbit + 861: race car/racing car + 862: racket/racquet + 863: radar + 864: radiator + 865: radio receiver/radio set/radio/tuner/tuner radio + 866: radish/daikon + 867: raft + 868: rag doll + 869: raincoat/waterproof jacket + 870: ram/ram animal + 871: raspberry + 872: rat + 873: razorblade + 874: reamer/reamer juicer/juicer/juice reamer + 875: rearview mirror + 876: receipt + 877: recliner/reclining chair/lounger/lounger chair + 878: record player/phonograph/phonograph record player/turntable + 879: reflector + 880: remote control + 881: rhinoceros + 882: rib/rib food + 883: rifle + 884: ring + 885: river boat + 886: road map + 887: robe + 888: 
rocking chair + 889: rodent + 890: roller skate + 891: Rollerblade + 892: rolling pin + 893: root beer + 894: router/router computer equipment + 895: rubber band/elastic band + 896: runner/runner carpet + 897: plastic bag/paper bag + 898: saddle/saddle on an animal + 899: saddle blanket/saddlecloth/horse blanket + 900: saddlebag + 901: safety pin + 902: sail + 903: salad + 904: salad plate/salad bowl + 905: salami + 906: salmon/salmon fish + 907: salmon/salmon food + 908: salsa + 909: saltshaker + 910: sandal/sandal type of shoe + 911: sandwich + 912: satchel + 913: saucepan + 914: saucer + 915: sausage + 916: sawhorse/sawbuck + 917: saxophone + 918: scale/scale measuring instrument + 919: scarecrow/strawman + 920: scarf + 921: school bus + 922: scissors + 923: scoreboard + 924: scraper + 925: screwdriver + 926: scrubbing brush + 927: sculpture + 928: seabird/seafowl + 929: seahorse + 930: seaplane/hydroplane + 931: seashell + 932: sewing machine + 933: shaker + 934: shampoo + 935: shark + 936: sharpener + 937: Sharpie + 938: shaver/shaver electric/electric shaver/electric razor + 939: shaving cream/shaving soap + 940: shawl + 941: shears + 942: sheep + 943: shepherd dog/sheepdog + 944: sherbert/sherbet + 945: shield + 946: shirt + 947: shoe/sneaker/sneaker type of shoe/tennis shoe + 948: shopping bag + 949: shopping cart + 950: short pants/shorts/shorts clothing/trunks/trunks clothing + 951: shot glass + 952: shoulder bag + 953: shovel + 954: shower head + 955: shower cap + 956: shower curtain + 957: shredder/shredder for paper + 958: signboard + 959: silo + 960: sink + 961: skateboard + 962: skewer + 963: ski + 964: ski boot + 965: ski parka/ski jacket + 966: ski pole + 967: skirt + 968: skullcap + 969: sled/sledge/sleigh + 970: sleeping bag + 971: sling/sling bandage/triangular bandage + 972: slipper/slipper footwear/carpet slipper/carpet slipper footwear + 973: smoothie + 974: snake/serpent + 975: snowboard + 976: snowman + 977: snowmobile + 978: soap + 979: soccer ball + 980: sock + 981: sofa/couch/lounge + 982: softball + 983: solar array/solar battery/solar panel + 984: sombrero + 985: soup + 986: soup bowl + 987: soupspoon + 988: sour cream/soured cream + 989: soya milk/soybean milk/soymilk + 990: space shuttle + 991: sparkler/sparkler fireworks + 992: spatula + 993: spear/lance + 994: spectacles/specs/eyeglasses/glasses + 995: spice rack + 996: spider + 997: crawfish/crayfish + 998: sponge + 999: spoon + 1000: sportswear/athletic wear/activewear + 1001: spotlight + 1002: squid/squid food/calamari/calamary + 1003: squirrel + 1004: stagecoach + 1005: stapler/stapler stapling machine + 1006: starfish/sea star + 1007: statue/statue sculpture + 1008: steak/steak food + 1009: steak knife + 1010: steering wheel + 1011: stepladder + 1012: step stool + 1013: stereo/stereo sound system + 1014: stew + 1015: stirrer + 1016: stirrup + 1017: stool + 1018: stop sign + 1019: brake light + 1020: stove/kitchen stove/range/range kitchen appliance/kitchen range/cooking stove + 1021: strainer + 1022: strap + 1023: straw/straw for drinking/drinking straw + 1024: strawberry + 1025: street sign + 1026: streetlight/street lamp + 1027: string cheese + 1028: stylus + 1029: subwoofer + 1030: sugar bowl + 1031: sugarcane/sugarcane plant + 1032: suit/suit clothing + 1033: sunflower + 1034: sunglasses + 1035: sunhat + 1036: surfboard + 1037: sushi + 1038: mop + 1039: sweat pants + 1040: sweatband + 1041: sweater + 1042: sweatshirt + 1043: sweet potato + 1044: swimsuit/swimwear/bathing suit/swimming 
costume/bathing costume/swimming trunks/bathing trunks + 1045: sword + 1046: syringe + 1047: Tabasco sauce + 1048: table-tennis table/ping-pong table + 1049: table + 1050: table lamp + 1051: tablecloth + 1052: tachometer + 1053: taco + 1054: tag + 1055: taillight/rear light + 1056: tambourine + 1057: army tank/armored combat vehicle/armoured combat vehicle + 1058: tank/tank storage vessel/storage tank + 1059: tank top/tank top clothing + 1060: tape/tape sticky cloth or paper + 1061: tape measure/measuring tape + 1062: tapestry + 1063: tarp + 1064: tartan/plaid + 1065: tassel + 1066: tea bag + 1067: teacup + 1068: teakettle + 1069: teapot + 1070: teddy bear + 1071: telephone/phone/telephone set + 1072: telephone booth/phone booth/call box/telephone box/telephone kiosk + 1073: telephone pole/telegraph pole/telegraph post + 1074: telephoto lens/zoom lens + 1075: television camera/tv camera + 1076: television set/tv/tv set + 1077: tennis ball + 1078: tennis racket + 1079: tequila + 1080: thermometer + 1081: thermos bottle + 1082: thermostat + 1083: thimble + 1084: thread/yarn + 1085: thumbtack/drawing pin/pushpin + 1086: tiara + 1087: tiger + 1088: tights/tights clothing/leotards + 1089: timer/stopwatch + 1090: tinfoil + 1091: tinsel + 1092: tissue paper + 1093: toast/toast food + 1094: toaster + 1095: toaster oven + 1096: toilet + 1097: toilet tissue/toilet paper/bathroom tissue + 1098: tomato + 1099: tongs + 1100: toolbox + 1101: toothbrush + 1102: toothpaste + 1103: toothpick + 1104: cover + 1105: tortilla + 1106: tow truck + 1107: towel + 1108: towel rack/towel rail/towel bar + 1109: toy + 1110: tractor/tractor farm equipment + 1111: traffic light + 1112: dirt bike + 1113: trailer truck/tractor trailer/trucking rig/articulated lorry/semi truck + 1114: train/train railroad vehicle/railroad train + 1115: trampoline + 1116: tray + 1117: trench coat + 1118: triangle/triangle musical instrument + 1119: tricycle + 1120: tripod + 1121: trousers/pants/pants clothing + 1122: truck + 1123: truffle/truffle chocolate/chocolate truffle + 1124: trunk + 1125: vat + 1126: turban + 1127: turkey/turkey food + 1128: turnip + 1129: turtle + 1130: turtleneck/turtleneck clothing/polo-neck + 1131: typewriter + 1132: umbrella + 1133: underwear/underclothes/underclothing/underpants + 1134: unicycle + 1135: urinal + 1136: urn + 1137: vacuum cleaner + 1138: vase + 1139: vending machine + 1140: vent/blowhole/air vent + 1141: vest/waistcoat + 1142: videotape + 1143: vinegar + 1144: violin/fiddle + 1145: vodka + 1146: volleyball + 1147: vulture + 1148: waffle + 1149: waffle iron + 1150: wagon + 1151: wagon wheel + 1152: walking stick + 1153: wall clock + 1154: wall socket/wall plug/electric outlet/electrical outlet/outlet/electric receptacle + 1155: wallet/billfold + 1156: walrus + 1157: wardrobe + 1158: washbasin/basin/basin for washing/washbowl/washstand/handbasin + 1159: automatic washer/washing machine + 1160: watch/wristwatch + 1161: water bottle + 1162: water cooler + 1163: water faucet/water tap/tap/tap water faucet + 1164: water heater/hot-water heater + 1165: water jug + 1166: water gun/squirt gun + 1167: water scooter/sea scooter/jet ski + 1168: water ski + 1169: water tower + 1170: watering can + 1171: watermelon + 1172: weathervane/vane/vane weathervane/wind vane + 1173: webcam + 1174: wedding cake/bridecake + 1175: wedding ring/wedding band + 1176: wet suit + 1177: wheel + 1178: wheelchair + 1179: whipped cream + 1180: whistle + 1181: wig + 1182: wind chime + 1183: windmill + 1184: window box/window box 
for plants + 1185: windshield wiper/windscreen wiper/wiper/wiper for windshield or screen + 1186: windsock/air sock/air-sleeve/wind sleeve/wind cone + 1187: wine bottle + 1188: wine bucket/wine cooler + 1189: wineglass + 1190: blinder/blinder for horses + 1191: wok + 1192: wolf + 1193: wooden spoon + 1194: wreath + 1195: wrench/spanner + 1196: wristband + 1197: wristlet/wrist band + 1198: yacht + 1199: yogurt/yoghurt/yoghourt + 1200: yoke/yoke animal equipment + 1201: zebra + 1202: zucchini/courgette + +# Download script/URL (optional) +download: | + from ultralytics.utils.downloads import download + from pathlib import Path + + # Download labels + dir = Path(yaml['path']) # dataset root dir + url = 'https://github.com/ultralytics/assets/releases/download/v0.0.0/' + urls = [url + 'lvis-labels-segments.zip'] # labels + download(urls, dir=dir.parent) + # Download data + urls = ['http://images.cocodataset.org/zips/train2017.zip', # 19G, 118k images + 'http://images.cocodataset.org/zips/val2017.zip', # 1G, 5k images + 'http://images.cocodataset.org/zips/test2017.zip'] # 7G, 41k images (optional) + download(urls, dir=dir / 'images', threads=3) diff --git a/ultralytics/cfg/datasets/medical-pills.yaml b/ultralytics/cfg/datasets/medical-pills.yaml new file mode 100644 index 0000000000000000000000000000000000000000..25507c8b9bef023c866aa6d6d41f8a3a6bf0958a --- /dev/null +++ b/ultralytics/cfg/datasets/medical-pills.yaml @@ -0,0 +1,22 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Medical-pills dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/detect/medical-pills/ +# Example usage: yolo train data=medical-pills.yaml +# parent +# ├── ultralytics +# └── datasets +# └── medical-pills ← downloads here (8.19 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/medical-pills # dataset root dir +train: train/images # train images (relative to 'path') 92 images +val: valid/images # val images (relative to 'path') 23 images +test: # test images (relative to 'path') + +# Classes +names: + 0: pill + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/medical-pills.zip diff --git a/ultralytics/cfg/datasets/open-images-v7.yaml b/ultralytics/cfg/datasets/open-images-v7.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6bd4e0bdcf588b0736dbfe33bb52edf1ea4ea3fc --- /dev/null +++ b/ultralytics/cfg/datasets/open-images-v7.yaml @@ -0,0 +1,661 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Open Images v7 dataset https://storage.googleapis.com/openimages/web/index.html by Google +# Documentation: https://docs.ultralytics.com/datasets/detect/open-images-v7/ +# Example usage: yolo train data=open-images-v7.yaml +# parent +# ├── ultralytics +# └── datasets +# └── open-images-v7 ← downloads here (561 GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
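+# The full Open Images V7 download is roughly 561 GB, so for quick experiments prefer inference with a
+# pretrained checkpoint over retraining. A minimal sketch, assuming an Open Images pretrained model published
+# by upstream ultralytics (e.g. yolov8n-oiv7.pt):
+#   from ultralytics import YOLO
+#   YOLO("yolov8n-oiv7.pt").predict("ultralytics/assets/bus.jpg")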
+path: ../datasets/open-images-v7 # dataset root dir +train: images/train # train images (relative to 'path') 1743042 images +val: images/val # val images (relative to 'path') 41620 images +test: # test images (optional) + +# Classes +names: + 0: Accordion + 1: Adhesive tape + 2: Aircraft + 3: Airplane + 4: Alarm clock + 5: Alpaca + 6: Ambulance + 7: Animal + 8: Ant + 9: Antelope + 10: Apple + 11: Armadillo + 12: Artichoke + 13: Auto part + 14: Axe + 15: Backpack + 16: Bagel + 17: Baked goods + 18: Balance beam + 19: Ball + 20: Balloon + 21: Banana + 22: Band-aid + 23: Banjo + 24: Barge + 25: Barrel + 26: Baseball bat + 27: Baseball glove + 28: Bat (Animal) + 29: Bathroom accessory + 30: Bathroom cabinet + 31: Bathtub + 32: Beaker + 33: Bear + 34: Bed + 35: Bee + 36: Beehive + 37: Beer + 38: Beetle + 39: Bell pepper + 40: Belt + 41: Bench + 42: Bicycle + 43: Bicycle helmet + 44: Bicycle wheel + 45: Bidet + 46: Billboard + 47: Billiard table + 48: Binoculars + 49: Bird + 50: Blender + 51: Blue jay + 52: Boat + 53: Bomb + 54: Book + 55: Bookcase + 56: Boot + 57: Bottle + 58: Bottle opener + 59: Bow and arrow + 60: Bowl + 61: Bowling equipment + 62: Box + 63: Boy + 64: Brassiere + 65: Bread + 66: Briefcase + 67: Broccoli + 68: Bronze sculpture + 69: Brown bear + 70: Building + 71: Bull + 72: Burrito + 73: Bus + 74: Bust + 75: Butterfly + 76: Cabbage + 77: Cabinetry + 78: Cake + 79: Cake stand + 80: Calculator + 81: Camel + 82: Camera + 83: Can opener + 84: Canary + 85: Candle + 86: Candy + 87: Cannon + 88: Canoe + 89: Cantaloupe + 90: Car + 91: Carnivore + 92: Carrot + 93: Cart + 94: Cassette deck + 95: Castle + 96: Cat + 97: Cat furniture + 98: Caterpillar + 99: Cattle + 100: Ceiling fan + 101: Cello + 102: Centipede + 103: Chainsaw + 104: Chair + 105: Cheese + 106: Cheetah + 107: Chest of drawers + 108: Chicken + 109: Chime + 110: Chisel + 111: Chopsticks + 112: Christmas tree + 113: Clock + 114: Closet + 115: Clothing + 116: Coat + 117: Cocktail + 118: Cocktail shaker + 119: Coconut + 120: Coffee + 121: Coffee cup + 122: Coffee table + 123: Coffeemaker + 124: Coin + 125: Common fig + 126: Common sunflower + 127: Computer keyboard + 128: Computer monitor + 129: Computer mouse + 130: Container + 131: Convenience store + 132: Cookie + 133: Cooking spray + 134: Corded phone + 135: Cosmetics + 136: Couch + 137: Countertop + 138: Cowboy hat + 139: Crab + 140: Cream + 141: Cricket ball + 142: Crocodile + 143: Croissant + 144: Crown + 145: Crutch + 146: Cucumber + 147: Cupboard + 148: Curtain + 149: Cutting board + 150: Dagger + 151: Dairy Product + 152: Deer + 153: Desk + 154: Dessert + 155: Diaper + 156: Dice + 157: Digital clock + 158: Dinosaur + 159: Dishwasher + 160: Dog + 161: Dog bed + 162: Doll + 163: Dolphin + 164: Door + 165: Door handle + 166: Doughnut + 167: Dragonfly + 168: Drawer + 169: Dress + 170: Drill (Tool) + 171: Drink + 172: Drinking straw + 173: Drum + 174: Duck + 175: Dumbbell + 176: Eagle + 177: Earrings + 178: Egg (Food) + 179: Elephant + 180: Envelope + 181: Eraser + 182: Face powder + 183: Facial tissue holder + 184: Falcon + 185: Fashion accessory + 186: Fast food + 187: Fax + 188: Fedora + 189: Filing cabinet + 190: Fire hydrant + 191: Fireplace + 192: Fish + 193: Flag + 194: Flashlight + 195: Flower + 196: Flowerpot + 197: Flute + 198: Flying disc + 199: Food + 200: Food processor + 201: Football + 202: Football helmet + 203: Footwear + 204: Fork + 205: Fountain + 206: Fox + 207: French fries + 208: French horn + 209: Frog + 210: Fruit + 211: Frying pan + 212: 
Furniture + 213: Garden Asparagus + 214: Gas stove + 215: Giraffe + 216: Girl + 217: Glasses + 218: Glove + 219: Goat + 220: Goggles + 221: Goldfish + 222: Golf ball + 223: Golf cart + 224: Gondola + 225: Goose + 226: Grape + 227: Grapefruit + 228: Grinder + 229: Guacamole + 230: Guitar + 231: Hair dryer + 232: Hair spray + 233: Hamburger + 234: Hammer + 235: Hamster + 236: Hand dryer + 237: Handbag + 238: Handgun + 239: Harbor seal + 240: Harmonica + 241: Harp + 242: Harpsichord + 243: Hat + 244: Headphones + 245: Heater + 246: Hedgehog + 247: Helicopter + 248: Helmet + 249: High heels + 250: Hiking equipment + 251: Hippopotamus + 252: Home appliance + 253: Honeycomb + 254: Horizontal bar + 255: Horse + 256: Hot dog + 257: House + 258: Houseplant + 259: Human arm + 260: Human beard + 261: Human body + 262: Human ear + 263: Human eye + 264: Human face + 265: Human foot + 266: Human hair + 267: Human hand + 268: Human head + 269: Human leg + 270: Human mouth + 271: Human nose + 272: Humidifier + 273: Ice cream + 274: Indoor rower + 275: Infant bed + 276: Insect + 277: Invertebrate + 278: Ipod + 279: Isopod + 280: Jacket + 281: Jacuzzi + 282: Jaguar (Animal) + 283: Jeans + 284: Jellyfish + 285: Jet ski + 286: Jug + 287: Juice + 288: Kangaroo + 289: Kettle + 290: Kitchen & dining room table + 291: Kitchen appliance + 292: Kitchen knife + 293: Kitchen utensil + 294: Kitchenware + 295: Kite + 296: Knife + 297: Koala + 298: Ladder + 299: Ladle + 300: Ladybug + 301: Lamp + 302: Land vehicle + 303: Lantern + 304: Laptop + 305: Lavender (Plant) + 306: Lemon + 307: Leopard + 308: Light bulb + 309: Light switch + 310: Lighthouse + 311: Lily + 312: Limousine + 313: Lion + 314: Lipstick + 315: Lizard + 316: Lobster + 317: Loveseat + 318: Luggage and bags + 319: Lynx + 320: Magpie + 321: Mammal + 322: Man + 323: Mango + 324: Maple + 325: Maracas + 326: Marine invertebrates + 327: Marine mammal + 328: Measuring cup + 329: Mechanical fan + 330: Medical equipment + 331: Microphone + 332: Microwave oven + 333: Milk + 334: Miniskirt + 335: Mirror + 336: Missile + 337: Mixer + 338: Mixing bowl + 339: Mobile phone + 340: Monkey + 341: Moths and butterflies + 342: Motorcycle + 343: Mouse + 344: Muffin + 345: Mug + 346: Mule + 347: Mushroom + 348: Musical instrument + 349: Musical keyboard + 350: Nail (Construction) + 351: Necklace + 352: Nightstand + 353: Oboe + 354: Office building + 355: Office supplies + 356: Orange + 357: Organ (Musical Instrument) + 358: Ostrich + 359: Otter + 360: Oven + 361: Owl + 362: Oyster + 363: Paddle + 364: Palm tree + 365: Pancake + 366: Panda + 367: Paper cutter + 368: Paper towel + 369: Parachute + 370: Parking meter + 371: Parrot + 372: Pasta + 373: Pastry + 374: Peach + 375: Pear + 376: Pen + 377: Pencil case + 378: Pencil sharpener + 379: Penguin + 380: Perfume + 381: Person + 382: Personal care + 383: Personal flotation device + 384: Piano + 385: Picnic basket + 386: Picture frame + 387: Pig + 388: Pillow + 389: Pineapple + 390: Pitcher (Container) + 391: Pizza + 392: Pizza cutter + 393: Plant + 394: Plastic bag + 395: Plate + 396: Platter + 397: Plumbing fixture + 398: Polar bear + 399: Pomegranate + 400: Popcorn + 401: Porch + 402: Porcupine + 403: Poster + 404: Potato + 405: Power plugs and sockets + 406: Pressure cooker + 407: Pretzel + 408: Printer + 409: Pumpkin + 410: Punching bag + 411: Rabbit + 412: Raccoon + 413: Racket + 414: Radish + 415: Ratchet (Device) + 416: Raven + 417: Rays and skates + 418: Red panda + 419: Refrigerator + 420: Remote control + 421: Reptile 
+ 422: Rhinoceros + 423: Rifle + 424: Ring binder + 425: Rocket + 426: Roller skates + 427: Rose + 428: Rugby ball + 429: Ruler + 430: Salad + 431: Salt and pepper shakers + 432: Sandal + 433: Sandwich + 434: Saucer + 435: Saxophone + 436: Scale + 437: Scarf + 438: Scissors + 439: Scoreboard + 440: Scorpion + 441: Screwdriver + 442: Sculpture + 443: Sea lion + 444: Sea turtle + 445: Seafood + 446: Seahorse + 447: Seat belt + 448: Segway + 449: Serving tray + 450: Sewing machine + 451: Shark + 452: Sheep + 453: Shelf + 454: Shellfish + 455: Shirt + 456: Shorts + 457: Shotgun + 458: Shower + 459: Shrimp + 460: Sink + 461: Skateboard + 462: Ski + 463: Skirt + 464: Skull + 465: Skunk + 466: Skyscraper + 467: Slow cooker + 468: Snack + 469: Snail + 470: Snake + 471: Snowboard + 472: Snowman + 473: Snowmobile + 474: Snowplow + 475: Soap dispenser + 476: Sock + 477: Sofa bed + 478: Sombrero + 479: Sparrow + 480: Spatula + 481: Spice rack + 482: Spider + 483: Spoon + 484: Sports equipment + 485: Sports uniform + 486: Squash (Plant) + 487: Squid + 488: Squirrel + 489: Stairs + 490: Stapler + 491: Starfish + 492: Stationary bicycle + 493: Stethoscope + 494: Stool + 495: Stop sign + 496: Strawberry + 497: Street light + 498: Stretcher + 499: Studio couch + 500: Submarine + 501: Submarine sandwich + 502: Suit + 503: Suitcase + 504: Sun hat + 505: Sunglasses + 506: Surfboard + 507: Sushi + 508: Swan + 509: Swim cap + 510: Swimming pool + 511: Swimwear + 512: Sword + 513: Syringe + 514: Table + 515: Table tennis racket + 516: Tablet computer + 517: Tableware + 518: Taco + 519: Tank + 520: Tap + 521: Tart + 522: Taxi + 523: Tea + 524: Teapot + 525: Teddy bear + 526: Telephone + 527: Television + 528: Tennis ball + 529: Tennis racket + 530: Tent + 531: Tiara + 532: Tick + 533: Tie + 534: Tiger + 535: Tin can + 536: Tire + 537: Toaster + 538: Toilet + 539: Toilet paper + 540: Tomato + 541: Tool + 542: Toothbrush + 543: Torch + 544: Tortoise + 545: Towel + 546: Tower + 547: Toy + 548: Traffic light + 549: Traffic sign + 550: Train + 551: Training bench + 552: Treadmill + 553: Tree + 554: Tree house + 555: Tripod + 556: Trombone + 557: Trousers + 558: Truck + 559: Trumpet + 560: Turkey + 561: Turtle + 562: Umbrella + 563: Unicycle + 564: Van + 565: Vase + 566: Vegetable + 567: Vehicle + 568: Vehicle registration plate + 569: Violin + 570: Volleyball (Ball) + 571: Waffle + 572: Waffle iron + 573: Wall clock + 574: Wardrobe + 575: Washing machine + 576: Waste container + 577: Watch + 578: Watercraft + 579: Watermelon + 580: Weapon + 581: Whale + 582: Wheel + 583: Wheelchair + 584: Whisk + 585: Whiteboard + 586: Willow + 587: Window + 588: Window blind + 589: Wine + 590: Wine glass + 591: Wine rack + 592: Winter melon + 593: Wok + 594: Woman + 595: Wood-burning stove + 596: Woodpecker + 597: Worm + 598: Wrench + 599: Zebra + 600: Zucchini + +# Download script/URL (optional) --------------------------------------------------------------------------------------- +download: | + from ultralytics.utils import LOGGER, SETTINGS, Path, is_ubuntu, get_ubuntu_version + from ultralytics.utils.checks import check_requirements, check_version + + check_requirements('fiftyone') + if is_ubuntu() and check_version(get_ubuntu_version(), '>=22.04'): + # Ubuntu>=22.04 patch https://github.com/voxel51/fiftyone/issues/2961#issuecomment-1666519347 + check_requirements('fiftyone-db-ubuntu2204') + + import fiftyone as fo + import fiftyone.zoo as foz + import warnings + + name = 'open-images-v7' + fraction = 1.0 # fraction of full 
dataset to use + LOGGER.warning('WARNING ⚠️ Open Images V7 dataset requires at least **561 GB of free space. Starting download...') + for split in 'train', 'validation': # 1743042 train, 41620 val images + train = split == 'train' + + # Load Open Images dataset + dataset = foz.load_zoo_dataset(name, + split=split, + label_types=['detections'], + dataset_dir=Path(SETTINGS['datasets_dir']) / 'fiftyone' / name, + max_samples=round((1743042 if train else 41620) * fraction)) + + # Define classes + if train: + classes = dataset.default_classes # all classes + # classes = dataset.distinct('ground_truth.detections.label') # only observed classes + + # Export to YOLO format + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning, module="fiftyone.utils.yolo") + dataset.export(export_dir=str(Path(SETTINGS['datasets_dir']) / name), + dataset_type=fo.types.YOLOv5Dataset, + label_field='ground_truth', + split='val' if split == 'validation' else split, + classes=classes, + overwrite=train) diff --git a/ultralytics/cfg/datasets/package-seg.yaml b/ultralytics/cfg/datasets/package-seg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..433ca04c7fe58a6f7da6130acda50e335901fbb1 --- /dev/null +++ b/ultralytics/cfg/datasets/package-seg.yaml @@ -0,0 +1,22 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Package-seg dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/segment/package-seg/ +# Example usage: yolo train data=package-seg.yaml +# parent +# ├── ultralytics +# └── datasets +# └── package-seg ← downloads here (102 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/package-seg # dataset root dir +train: train/images # train images (relative to 'path') 1920 images +val: valid/images # val images (relative to 'path') 89 images +test: test/images # test images (relative to 'path') 188 images + +# Classes +names: + 0: package + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/package-seg.zip diff --git a/ultralytics/cfg/datasets/signature.yaml b/ultralytics/cfg/datasets/signature.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5c9d5c338e95b566cd8e6f56294e32dc6c5ab323 --- /dev/null +++ b/ultralytics/cfg/datasets/signature.yaml @@ -0,0 +1,21 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Signature dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/detect/signature/ +# Example usage: yolo train data=signature.yaml +# parent +# ├── ultralytics +# └── datasets +# └── signature ← downloads here (11.2 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/signature # dataset root dir +train: train/images # train images (relative to 'path') 143 images +val: valid/images # val images (relative to 'path') 35 images + +# Classes +names: + 0: signature + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/signature.zip diff --git a/ultralytics/cfg/datasets/tiger-pose.yaml b/ultralytics/cfg/datasets/tiger-pose.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2b3f7b71761e475cff2aa3c97c8301c848e0844f --- /dev/null +++ b/ultralytics/cfg/datasets/tiger-pose.yaml @@ -0,0 +1,25 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Tiger Pose dataset by Ultralytics +# Documentation: https://docs.ultralytics.com/datasets/pose/tiger-pose/ +# Example usage: yolo train data=tiger-pose.yaml +# parent +# ├── ultralytics +# └── datasets +# └── tiger-pose ← downloads here (75.3 MB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] +path: ../datasets/tiger-pose # dataset root dir +train: train # train images (relative to 'path') 210 images +val: val # val images (relative to 'path') 53 images + +# Keypoints +kpt_shape: [12, 2] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) +flip_idx: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] + +# Classes +names: + 0: tiger + +# Download script/URL (optional) +download: https://github.com/ultralytics/assets/releases/download/v0.0.0/tiger-pose.zip diff --git a/ultralytics/cfg/datasets/xView.yaml b/ultralytics/cfg/datasets/xView.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ccef985974b74a22f0614da384210e32b0af59d4 --- /dev/null +++ b/ultralytics/cfg/datasets/xView.yaml @@ -0,0 +1,153 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# DIUx xView 2018 Challenge https://challenge.xviewdataset.org by U.S. National Geospatial-Intelligence Agency (NGA) +# -------- DOWNLOAD DATA MANUALLY and jar xf val_images.zip to 'datasets/xView' before running train command! -------- +# Documentation: https://docs.ultralytics.com/datasets/detect/xview/ +# Example usage: yolo train data=xView.yaml +# parent +# ├── ultralytics +# └── datasets +# └── xView ← downloads here (20.7 GB) + +# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 
+path: ../datasets/xView # dataset root dir +train: images/autosplit_train.txt # train images (relative to 'path') 90% of 847 train images +val: images/autosplit_val.txt # val images (relative to 'path') 10% of 847 train images + +# Classes +names: + 0: Fixed-wing Aircraft + 1: Small Aircraft + 2: Cargo Plane + 3: Helicopter + 4: Passenger Vehicle + 5: Small Car + 6: Bus + 7: Pickup Truck + 8: Utility Truck + 9: Truck + 10: Cargo Truck + 11: Truck w/Box + 12: Truck Tractor + 13: Trailer + 14: Truck w/Flatbed + 15: Truck w/Liquid + 16: Crane Truck + 17: Railway Vehicle + 18: Passenger Car + 19: Cargo Car + 20: Flat Car + 21: Tank car + 22: Locomotive + 23: Maritime Vessel + 24: Motorboat + 25: Sailboat + 26: Tugboat + 27: Barge + 28: Fishing Vessel + 29: Ferry + 30: Yacht + 31: Container Ship + 32: Oil Tanker + 33: Engineering Vehicle + 34: Tower crane + 35: Container Crane + 36: Reach Stacker + 37: Straddle Carrier + 38: Mobile Crane + 39: Dump Truck + 40: Haul Truck + 41: Scraper/Tractor + 42: Front loader/Bulldozer + 43: Excavator + 44: Cement Mixer + 45: Ground Grader + 46: Hut/Tent + 47: Shed + 48: Building + 49: Aircraft Hangar + 50: Damaged Building + 51: Facility + 52: Construction Site + 53: Vehicle Lot + 54: Helipad + 55: Storage Tank + 56: Shipping container lot + 57: Shipping Container + 58: Pylon + 59: Tower + +# Download script/URL (optional) --------------------------------------------------------------------------------------- +download: | + import json + import os + from pathlib import Path + + import numpy as np + from PIL import Image + from tqdm import tqdm + + from ultralytics.data.utils import autosplit + from ultralytics.utils.ops import xyxy2xywhn + + + def convert_labels(fname=Path('xView/xView_train.geojson')): + # Convert xView geoJSON labels to YOLO format + path = fname.parent + with open(fname) as f: + print(f'Loading {fname}...') + data = json.load(f) + + # Make dirs + labels = Path(path / 'labels' / 'train') + os.system(f'rm -rf {labels}') + labels.mkdir(parents=True, exist_ok=True) + + # xView classes 11-94 to 0-59 + xview_class2index = [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 1, 2, -1, 3, -1, 4, 5, 6, 7, 8, -1, 9, 10, 11, + 12, 13, 14, 15, -1, -1, 16, 17, 18, 19, 20, 21, 22, -1, 23, 24, 25, -1, 26, 27, -1, 28, -1, + 29, 30, 31, 32, 33, 34, 35, 36, 37, -1, 38, 39, 40, 41, 42, 43, 44, 45, -1, -1, -1, -1, 46, + 47, 48, 49, -1, 50, 51, -1, 52, -1, -1, -1, 53, 54, -1, 55, -1, -1, 56, -1, 57, -1, 58, 59] + + shapes = {} + for feature in tqdm(data['features'], desc=f'Converting {fname}'): + p = feature['properties'] + if p['bounds_imcoords']: + id = p['image_id'] + file = path / 'train_images' / id + if file.exists(): # 1395.tif missing + try: + box = np.array([int(num) for num in p['bounds_imcoords'].split(",")]) + assert box.shape[0] == 4, f'incorrect box shape {box.shape[0]}' + cls = p['type_id'] + cls = xview_class2index[int(cls)] # xView class to 0-59 + assert 59 >= cls >= 0, f'incorrect class index {cls}' + + # Write YOLO label + if id not in shapes: + shapes[id] = Image.open(file).size + box = xyxy2xywhn(box[None].astype(np.float64), w=shapes[id][0], h=shapes[id][1], clip=True) + with open((labels / id).with_suffix('.txt'), 'a') as f: + f.write(f"{cls} {' '.join(f'{x:.6f}' for x in box[0])}\n") # write label.txt + except Exception as e: + print(f'WARNING: skipping one label for {file}: {e}') + + + # Download manually from https://challenge.xviewdataset.org + dir = Path(yaml['path']) # dataset root dir + # urls =
['https://d307kc0mrhucc3.cloudfront.net/train_labels.zip', # train labels + # 'https://d307kc0mrhucc3.cloudfront.net/train_images.zip', # 15G, 847 train images + # 'https://d307kc0mrhucc3.cloudfront.net/val_images.zip'] # 5G, 282 val images (no labels) + # download(urls, dir=dir) + + # Convert labels + convert_labels(dir / 'xView_train.geojson') + + # Move images + images = Path(dir / 'images') + images.mkdir(parents=True, exist_ok=True) + Path(dir / 'train_images').rename(dir / 'images' / 'train') + Path(dir / 'val_images').rename(dir / 'images' / 'val') + + # Split + autosplit(dir / 'images' / 'train') diff --git a/ultralytics/cfg/default.yaml b/ultralytics/cfg/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7b67442bc68436d9e30df9c77c9219ed4d101a5b --- /dev/null +++ b/ultralytics/cfg/default.yaml @@ -0,0 +1,130 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Global configuration YAML with settings and hyperparameters for YOLO training, validation, prediction and export +# For documentation see https://docs.ultralytics.com/usage/cfg/ + +task: detect # (str) YOLO task, i.e. detect, segment, classify, pose, obb +mode: train # (str) YOLO mode, i.e. train, val, predict, export, track, benchmark + +# Train settings ------------------------------------------------------------------------------------------------------- +model: # (str, optional) path to model file, i.e. yolov8n.pt, yolov8n.yaml +data: # (str, optional) path to data file, i.e. coco8.yaml +epochs: 100 # (int) number of epochs to train for +time: # (float, optional) number of hours to train for, overrides epochs if supplied +patience: 100 # (int) epochs to wait for no observable improvement for early stopping of training +batch: 16 # (int) number of images per batch (-1 for AutoBatch) +imgsz: 640 # (int | list) input images size as int for train and val modes, or list[h,w] for predict and export modes +save: True # (bool) save train checkpoints and predict results +save_period: -1 # (int) Save checkpoint every x epochs (disabled if < 1) +cache: False # (bool) True/ram, disk or False. Use cache for data loading +device: # (int | str | list, optional) device to run on, i.e. 
cuda device=0 or device=0,1,2,3 or device=cpu +workers: 8 # (int) number of worker threads for data loading (per RANK if DDP) +project: # (str, optional) project name +name: # (str, optional) experiment name, results saved to 'project/name' directory +exist_ok: False # (bool) whether to overwrite existing experiment +pretrained: True # (bool | str) whether to use a pretrained model (bool) or a model to load weights from (str) +optimizer: auto # (str) optimizer to use, choices=[SGD, Adam, Adamax, AdamW, NAdam, RAdam, RMSProp, auto] +verbose: True # (bool) whether to print verbose output +seed: 0 # (int) random seed for reproducibility +deterministic: True # (bool) whether to enable deterministic mode +single_cls: False # (bool) train multi-class data as single-class +rect: False # (bool) rectangular training if mode='train' or rectangular validation if mode='val' +cos_lr: False # (bool) use cosine learning rate scheduler +close_mosaic: 10 # (int) disable mosaic augmentation for final epochs (0 to disable) +resume: False # (bool) resume training from last checkpoint +amp: True # (bool) Automatic Mixed Precision (AMP) training, choices=[True, False], True runs AMP check +fraction: 1.0 # (float) dataset fraction to train on (default is 1.0, all images in train set) +profile: False # (bool) profile ONNX and TensorRT speeds during training for loggers +freeze: None # (int | list, optional) freeze first n layers, or freeze list of layer indices during training +multi_scale: False # (bool) Whether to use multiscale during training +# Segmentation +overlap_mask: True # (bool) merge object masks into a single image mask during training (segment train only) +mask_ratio: 4 # (int) mask downsample ratio (segment train only) +# Classification +dropout: 0.0 # (float) use dropout regularization (classify train only) + +# Val/Test settings ---------------------------------------------------------------------------------------------------- +val: True # (bool) validate/test during training +split: val # (str) dataset split to use for validation, i.e. 'val', 'test' or 'train' +save_json: False # (bool) save results to JSON file +save_hybrid: False # (bool) save hybrid version of labels (labels + additional predictions) +conf: # (float, optional) object confidence threshold for detection (default 0.25 predict, 0.001 val) +iou: 0.7 # (float) intersection over union (IoU) threshold for NMS +max_det: 300 # (int) maximum number of detections per image +half: False # (bool) use half precision (FP16) +dnn: False # (bool) use OpenCV DNN for ONNX inference +plots: True # (bool) save plots and images during train/val + +# Predict settings ----------------------------------------------------------------------------------------------------- +source: # (str, optional) source directory for images or videos +vid_stride: 1 # (int) video frame-rate stride +stream_buffer: False # (bool) buffer all streaming frames (True) or return the most recent frame (False) +visualize: False # (bool) visualize model features +augment: False # (bool) apply image augmentation to prediction sources +agnostic_nms: False # (bool) class-agnostic NMS +classes: # (int | list[int], optional) filter results by class, i.e. 
classes=0, or classes=[0,2,3] +retina_masks: False # (bool) use high-resolution segmentation masks +embed: # (list[int], optional) return feature vectors/embeddings from given layers + +# Visualize settings --------------------------------------------------------------------------------------------------- +show: False # (bool) show predicted images and videos if environment allows +save_frames: False # (bool) save predicted individual video frames +save_txt: False # (bool) save results as .txt file +save_conf: False # (bool) save results with confidence scores +save_crop: False # (bool) save cropped images with results +show_labels: True # (bool) show prediction labels, i.e. 'person' +show_conf: True # (bool) show prediction confidence, i.e. '0.99' +show_boxes: True # (bool) show prediction boxes +line_width: # (int, optional) line width of the bounding boxes. Scaled to image size if None. + +# Export settings ------------------------------------------------------------------------------------------------------ +format: torchscript # (str) format to export to, choices at https://docs.ultralytics.com/modes/export/#export-formats +keras: False # (bool) use Keras +optimize: False # (bool) TorchScript: optimize for mobile +int8: False # (bool) CoreML/TF INT8 quantization +dynamic: False # (bool) ONNX/TF/TensorRT: dynamic axes +simplify: True # (bool) ONNX: simplify model using `onnxslim` +opset: # (int, optional) ONNX: opset version +workspace: None # (float, optional) TensorRT: workspace size (GiB), `None` will let TensorRT auto-allocate memory +nms: False # (bool) CoreML: add NMS + +# Hyperparameters ------------------------------------------------------------------------------------------------------ +lr0: 0.01 # (float) initial learning rate (i.e. SGD=1E-2, Adam=1E-3) +lrf: 0.01 # (float) final learning rate (lr0 * lrf) +momentum: 0.937 # (float) SGD momentum/Adam beta1 +weight_decay: 0.0005 # (float) optimizer weight decay 5e-4 +warmup_epochs: 3.0 # (float) warmup epochs (fractions ok) +warmup_momentum: 0.8 # (float) warmup initial momentum +warmup_bias_lr: 0.0 # 0.1 # (float) warmup initial bias lr +box: 7.5 # (float) box loss gain +cls: 0.5 # (float) cls loss gain (scale with pixels) +dfl: 1.5 # (float) dfl loss gain +pose: 12.0 # (float) pose loss gain +kobj: 1.0 # (float) keypoint obj loss gain +nbs: 64 # (int) nominal batch size +hsv_h: 0.015 # (float) image HSV-Hue augmentation (fraction) +hsv_s: 0.7 # (float) image HSV-Saturation augmentation (fraction) +hsv_v: 0.4 # (float) image HSV-Value augmentation (fraction) +degrees: 0.0 # (float) image rotation (+/- deg) +translate: 0.1 # (float) image translation (+/- fraction) +scale: 0.5 # (float) image scale (+/- gain) +shear: 0.0 # (float) image shear (+/- deg) +perspective: 0.0 # (float) image perspective (+/- fraction), range 0-0.001 +flipud: 0.0 # (float) image flip up-down (probability) +fliplr: 0.5 # (float) image flip left-right (probability) +bgr: 0.0 # (float) image channel BGR (probability) + +mosaic: 1.0 # (float) image mosaic (probability) +mixup: 0.0 # (float) image mixup (probability) +copy_paste: 0.1 # (float) segment copy-paste (probability) + +copy_paste_mode: "flip" # (str) the method to do copy_paste augmentation (flip, mixup) +auto_augment: randaugment # (str) auto augmentation policy for classification (randaugment, autoaugment, augmix) +erasing: 0.4 # (float) probability of random erasing during classification training (0-0.9), 0 means no erasing, must be less than 1.0.
+crop_fraction: 1.0 # (float) image crop fraction for classification (0.1-1), 1.0 means no crop, must be greater than 0. + +# Custom config.yaml --------------------------------------------------------------------------------------------------- +cfg: # (str, optional) for overriding defaults.yaml + +# Tracker settings ------------------------------------------------------------------------------------------------------ +tracker: botsort.yaml # (str) tracker type, choices=[botsort.yaml, bytetrack.yaml] diff --git a/ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml b/ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml new file mode 100644 index 0000000000000000000000000000000000000000..baedcb5dc5abcbf02bb58ff2f512e11cf5dbc2c8 --- /dev/null +++ b/ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml @@ -0,0 +1,24 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLO11-cls image classification model with ResNet18 backbone +# Model docs: https://docs.ultralytics.com/models/yolo11 +# Task docs: https://docs.ultralytics.com/tasks/classify + +# Parameters +nc: 10 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolo11n-cls.yaml' will call yolo11-cls.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 1024] + l: [1.00, 1.00, 1024] + x: [1.00, 1.25, 1024] + +# ResNet18 backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, TorchVision, [512, "resnet18", "DEFAULT", True, 2]] # truncate two layers from the end + +# YOLO11n head +head: + - [-1, 1, Classify, [nc]] # Classify diff --git a/ultralytics/cfg/models/11/yolo11-cls.yaml b/ultralytics/cfg/models/11/yolo11-cls.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7a6457c6d6d9728878aa9a4b8a4d76708f39e0e0 --- /dev/null +++ b/ultralytics/cfg/models/11/yolo11-cls.yaml @@ -0,0 +1,33 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLO11-cls image classification model +# Model docs: https://docs.ultralytics.com/models/yolo11 +# Task docs: https://docs.ultralytics.com/tasks/classify + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolo11n-cls.yaml' will call yolo11-cls.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.50, 0.25, 1024] # summary: 151 layers, 1633584 parameters, 1633584 gradients, 3.3 GFLOPs + s: [0.50, 0.50, 1024] # summary: 151 layers, 5545488 parameters, 5545488 gradients, 12.2 GFLOPs + m: [0.50, 1.00, 512] # summary: 187 layers, 10455696 parameters, 10455696 gradients, 39.7 GFLOPs + l: [1.00, 1.00, 512] # summary: 309 layers, 12937104 parameters, 12937104 gradients, 49.9 GFLOPs + x: [1.00, 1.50, 512] # summary: 309 layers, 28458544 parameters, 28458544 gradients, 111.1 GFLOPs + +# YOLO11n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 2, C3k2, [256, False, 0.25]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 2, C3k2, [512, False, 0.25]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 2, C3k2, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 2, C3k2, [1024, True]] + - [-1, 2, C2PSA, [1024]] # 9 + +# YOLO11n head +head: + - [-1, 1, Classify, [nc]] # Classify diff --git a/ultralytics/cfg/models/11/yolo11-obb.yaml b/ultralytics/cfg/models/11/yolo11-obb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8625c7cfdace7d120134711627fc86c7468be714 --- /dev/null +++ b/ultralytics/cfg/models/11/yolo11-obb.yaml @@ -0,0 +1,50 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLO11-obb Oriented Bounding Boxes (OBB) model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolo11 +# Task docs: https://docs.ultralytics.com/tasks/obb + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolo11n-obb.yaml' will call yolo11-obb.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.50, 0.25, 1024] # summary: 344 layers, 2695747 parameters, 2695731 gradients, 6.9 GFLOPs + s: [0.50, 0.50, 1024] # summary: 344 layers, 9744931 parameters, 9744915 gradients, 22.7 GFLOPs + m: [0.50, 1.00, 512] # summary: 434 layers, 20963523 parameters, 20963507 gradients, 72.2 GFLOPs + l: [1.00, 1.00, 512] # summary: 656 layers, 26220995 parameters, 26220979 gradients, 91.3 GFLOPs + x: [1.00, 1.50, 512] # summary: 656 layers, 58875331 parameters, 58875315 gradients, 204.3 GFLOPs + +# YOLO11n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 2, C3k2, [256, False, 0.25]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 2, C3k2, [512, False, 0.25]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 2, C3k2, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 2, C3k2, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 2, C2PSA, [1024]] # 10 + +# YOLO11n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 2, C3k2, [512, False]] # 13 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, OBB, [nc, 1]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/11/yolo11-pose.yaml b/ultralytics/cfg/models/11/yolo11-pose.yaml new file mode 100644 
index 0000000000000000000000000000000000000000..7470edac2fa377ad5913911f2607c47e2cc1eba9 --- /dev/null +++ b/ultralytics/cfg/models/11/yolo11-pose.yaml @@ -0,0 +1,51 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLO11-pose keypoints/pose estimation model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolo11 +# Task docs: https://docs.ultralytics.com/tasks/pose + +# Parameters +nc: 80 # number of classes +kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) +scales: # model compound scaling constants, i.e. 'model=yolo11n-pose.yaml' will call yolo11.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.50, 0.25, 1024] # summary: 344 layers, 2908507 parameters, 2908491 gradients, 7.7 GFLOPs + s: [0.50, 0.50, 1024] # summary: 344 layers, 9948811 parameters, 9948795 gradients, 23.5 GFLOPs + m: [0.50, 1.00, 512] # summary: 434 layers, 20973273 parameters, 20973257 gradients, 72.3 GFLOPs + l: [1.00, 1.00, 512] # summary: 656 layers, 26230745 parameters, 26230729 gradients, 91.4 GFLOPs + x: [1.00, 1.50, 512] # summary: 656 layers, 58889881 parameters, 58889865 gradients, 204.3 GFLOPs + +# YOLO11n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 2, C3k2, [256, False, 0.25]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 2, C3k2, [512, False, 0.25]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 2, C3k2, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 2, C3k2, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 2, C2PSA, [1024]] # 10 + +# YOLO11n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 2, C3k2, [512, False]] # 13 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, Pose, [nc, kpt_shape]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/11/yolo11-seg.yaml b/ultralytics/cfg/models/11/yolo11-seg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a569f4af84dfdd698f90438c67a3b121f777c143 --- /dev/null +++ b/ultralytics/cfg/models/11/yolo11-seg.yaml @@ -0,0 +1,50 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLO11-seg instance segmentation model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolo11 +# Task docs: https://docs.ultralytics.com/tasks/segment + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolo11n-seg.yaml' will call yolo11-seg.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.50, 0.25, 1024] # summary: 355 layers, 2876848 parameters, 2876832 gradients, 10.5 GFLOPs + s: [0.50, 0.50, 1024] # summary: 355 layers, 10113248 parameters, 10113232 gradients, 35.8 GFLOPs + m: [0.50, 1.00, 512] # summary: 445 layers, 22420896 parameters, 22420880 gradients, 123.9 GFLOPs + l: [1.00, 1.00, 512] # summary: 667 layers, 27678368 parameters, 27678352 gradients, 143.0 GFLOPs + x: [1.00, 1.50, 512] # summary: 667 layers, 62142656 parameters, 62142640 gradients, 320.2 GFLOPs + +# YOLO11n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 2, C3k2, [256, False, 0.25]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 2, C3k2, [512, False, 0.25]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 2, C3k2, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 2, C3k2, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 2, C2PSA, [1024]] # 10 + +# YOLO11n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 2, C3k2, [512, False]] # 13 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, Segment, [nc, 32, 256]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/11/yolo11.yaml b/ultralytics/cfg/models/11/yolo11.yaml new file mode 100644 index 0000000000000000000000000000000000000000..409465a1bb78944a9866505e5f0c2602ef76c99b --- /dev/null +++ b/ultralytics/cfg/models/11/yolo11.yaml @@ -0,0 +1,50 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLO11 object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolo11 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolo11n.yaml' will call yolo11.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs + s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs + m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs + l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs + x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs + +# YOLO11n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 2, C3k2, [256, False, 0.25]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 2, C3k2, [512, False, 0.25]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 2, C3k2, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 2, C3k2, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 2, C2PSA, [1024]] # 10 + +# YOLO11n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 2, C3k2, [512, False]] # 13 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/README.md b/ultralytics/cfg/models/README.md new file mode 100644 index 0000000000000000000000000000000000000000..68a9238384ec4c8f08c1db1fd6bad95e824d96d4 --- /dev/null +++ b/ultralytics/cfg/models/README.md @@ -0,0 +1,48 @@ +## Models + +Welcome to the [Ultralytics](https://www.ultralytics.com/) Models directory! Here you will find a wide variety of pre-configured model configuration files (`*.yaml`s) that can be used to create custom YOLO models. The models in this directory have been expertly crafted and fine-tuned by the Ultralytics team to provide the best performance for a wide range of object detection and image segmentation tasks. + +These model configurations cover a wide range of scenarios, from simple object detection to more complex tasks like instance segmentation and object tracking. They are also designed to run efficiently on a variety of hardware platforms, from CPUs to GPUs. Whether you are a seasoned machine learning practitioner or just getting started with YOLO, this directory provides a great starting point for your custom model development needs. + +To get started, simply browse through the models in this directory and find one that best suits your needs. Once you've selected a model, you can use the provided `*.yaml` file to train and deploy your custom YOLO model with ease. See full details at the Ultralytics [Docs](https://docs.ultralytics.com/models/), and if you need help or have any questions, feel free to reach out to the Ultralytics team for support. So, don't wait, start creating your custom YOLO model now! 
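+
+Note that most model YAMLs in this directory define a `scales:` table, so a scale letter appended to the file stem selects that row of depth/width/max-channels constants from the same file (e.g. `yolo11n.yaml` resolves to `yolo11.yaml` with scale `n`, as the comments in the configs above indicate). A minimal sketch of this behaviour, using the same `YOLO` class and `model.info()` call as the usage example below:
+
+```python
+from ultralytics import YOLO
+
+# The trailing "n" selects the n row of the scales table in yolo11.yaml (depth 0.50, width 0.25)
+model = YOLO("yolo11n.yaml")
+
+# Print the layer/parameter/GFLOPs summary to confirm which scale was built
+model.info()
+```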
+ +### Usage + +Model `*.yaml` files may be used directly in the [Command Line Interface (CLI)](https://docs.ultralytics.com/usage/cli/) with a `yolo` command: + +```bash +# Train a YOLO11n model using the coco8 dataset for 100 epochs +yolo task=detect mode=train model=yolo11n.yaml data=coco8.yaml epochs=100 +``` + +They may also be used directly in a Python environment, and accept the same [arguments](https://docs.ultralytics.com/usage/cfg/) as in the CLI example above: + +```python +from ultralytics import YOLO + +# Initialize a YOLO11n model from a YAML configuration file +model = YOLO("model.yaml") + +# If a pre-trained model is available, use it instead +# model = YOLO("model.pt") + +# Display model information +model.info() + +# Train the model using the COCO8 dataset for 100 epochs +model.train(data="coco8.yaml", epochs=100) +``` + +## Pre-trained Model Architectures + +Ultralytics supports many model architectures. Visit [Ultralytics Models](https://docs.ultralytics.com/models/) to view detailed information and usage. Any of these models can be used by loading their configurations or pretrained checkpoints if available. + +## Contribute New Models + +Have you trained a new YOLO variant or achieved state-of-the-art performance with specific tuning? We'd love to showcase your work in our Models section! Contributions from the community in the form of new models, architectures, or optimizations are highly valued and can significantly enrich our repository. + +By contributing to this section, you're helping us offer a wider array of model choices and configurations to the community. It's a fantastic way to share your knowledge and expertise while making the Ultralytics YOLO ecosystem even more versatile. + +To get started, please consult our [Contributing Guide](https://docs.ultralytics.com/help/contributing/) for step-by-step instructions on how to submit a Pull Request (PR) 🛠️. Your contributions are eagerly awaited! + +Let's join hands to extend the range and capabilities of the Ultralytics YOLO models 🙏! diff --git a/ultralytics/cfg/models/rt-detr/rtdetr-l.yaml b/ultralytics/cfg/models/rt-detr/rtdetr-l.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d8d6b4f410be5cfdea18c1d3dae48501e443fec2 --- /dev/null +++ b/ultralytics/cfg/models/rt-detr/rtdetr-l.yaml @@ -0,0 +1,53 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics RT-DETR-l hybrid object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/rtdetr +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' + # [depth, width, max_channels] + l: [1.00, 1.00, 1024] + +backbone: + # [from, repeats, module, args] + - [-1, 1, HGStem, [32, 48]] # 0-P2/4 + - [-1, 6, HGBlock, [48, 128, 3]] # stage 1 + + - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8 + - [-1, 6, HGBlock, [96, 512, 3]] # stage 2 + + - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 4-P3/16 + - [-1, 6, HGBlock, [192, 1024, 5, True, False]] # cm, c2, k, light, shortcut + - [-1, 6, HGBlock, [192, 1024, 5, True, True]] + - [-1, 6, HGBlock, [192, 1024, 5, True, True]] # stage 3 + + - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 8-P4/32 + - [-1, 6, HGBlock, [384, 2048, 5, True, False]] # stage 4 + +head: + - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 10 input_proj.2 + - [-1, 1, AIFI, [1024, 8]] + - [-1, 1, Conv, [256, 1, 1]] # 12, Y5, lateral_convs.0 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [7, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 input_proj.1 + - [[-2, -1], 1, Concat, [1]] + - [-1, 3, RepC3, [256]] # 16, fpn_blocks.0 + - [-1, 1, Conv, [256, 1, 1]] # 17, Y4, lateral_convs.1 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 19 input_proj.0 + - [[-2, -1], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, RepC3, [256]] # X3 (21), fpn_blocks.1 + + - [-1, 1, Conv, [256, 3, 2]] # 22, downsample_convs.0 + - [[-1, 17], 1, Concat, [1]] # cat Y4 + - [-1, 3, RepC3, [256]] # F4 (24), pan_blocks.0 + + - [-1, 1, Conv, [256, 3, 2]] # 25, downsample_convs.1 + - [[-1, 12], 1, Concat, [1]] # cat Y5 + - [-1, 3, RepC3, [256]] # F5 (27), pan_blocks.1 + + - [[21, 24, 27], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml b/ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b13e94512bd5f995937a995002d66c93bff7803f --- /dev/null +++ b/ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml @@ -0,0 +1,45 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics RT-DETR-ResNet101 hybrid object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/rtdetr +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' + # [depth, width, max_channels] + l: [1.00, 1.00, 1024] + +backbone: + # [from, repeats, module, args] + - [-1, 1, ResNetLayer, [3, 64, 1, True, 1]] # 0 + - [-1, 1, ResNetLayer, [64, 64, 1, False, 3]] # 1 + - [-1, 1, ResNetLayer, [256, 128, 2, False, 4]] # 2 + - [-1, 1, ResNetLayer, [512, 256, 2, False, 23]] # 3 + - [-1, 1, ResNetLayer, [1024, 512, 2, False, 3]] # 4 + +head: + - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 5 + - [-1, 1, AIFI, [1024, 8]] + - [-1, 1, Conv, [256, 1, 1]] # 7 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 9 + - [[-2, -1], 1, Concat, [1]] + - [-1, 3, RepC3, [256]] # 11 + - [-1, 1, Conv, [256, 1, 1]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [2, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 + - [[-2, -1], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, RepC3, [256]] # X3 (16), fpn_blocks.1 + + - [-1, 1, Conv, [256, 3, 2]] # 17, downsample_convs.0 + - [[-1, 12], 1, Concat, [1]] # cat Y4 + - [-1, 3, RepC3, [256]] # F4 (19), pan_blocks.0 + + - [-1, 1, Conv, [256, 3, 2]] # 20, downsample_convs.1 + - [[-1, 7], 1, Concat, [1]] # cat Y5 + - [-1, 3, RepC3, [256]] # F5 (22), pan_blocks.1 + + - [[16, 19, 22], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml b/ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8172ad4ed4c4fe9263b87d2595a61625d0644ad2 --- /dev/null +++ b/ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml @@ -0,0 +1,45 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics RT-DETR-ResNet50 hybrid object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/rtdetr +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' + # [depth, width, max_channels] + l: [1.00, 1.00, 1024] + +backbone: + # [from, repeats, module, args] + - [-1, 1, ResNetLayer, [3, 64, 1, True, 1]] # 0 + - [-1, 1, ResNetLayer, [64, 64, 1, False, 3]] # 1 + - [-1, 1, ResNetLayer, [256, 128, 2, False, 4]] # 2 + - [-1, 1, ResNetLayer, [512, 256, 2, False, 6]] # 3 + - [-1, 1, ResNetLayer, [1024, 512, 2, False, 3]] # 4 + +head: + - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 5 + - [-1, 1, AIFI, [1024, 8]] + - [-1, 1, Conv, [256, 1, 1]] # 7 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 9 + - [[-2, -1], 1, Concat, [1]] + - [-1, 3, RepC3, [256]] # 11 + - [-1, 1, Conv, [256, 1, 1]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [2, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 + - [[-2, -1], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, RepC3, [256]] # X3 (16), fpn_blocks.1 + + - [-1, 1, Conv, [256, 3, 2]] # 17, downsample_convs.0 + - [[-1, 12], 1, Concat, [1]] # cat Y4 + - [-1, 3, RepC3, [256]] # F4 (19), pan_blocks.0 + + - [-1, 1, Conv, [256, 3, 2]] # 20, downsample_convs.1 + - [[-1, 7], 1, Concat, [1]] # cat Y5 + - [-1, 3, RepC3, [256]] # F5 (22), pan_blocks.1 + + - [[16, 19, 22], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/rt-detr/rtdetr-x.yaml b/ultralytics/cfg/models/rt-detr/rtdetr-x.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f9c4a19c8ab919d06a68ff655d6ee788ac5b23a0 --- /dev/null +++ b/ultralytics/cfg/models/rt-detr/rtdetr-x.yaml @@ -0,0 +1,57 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics RT-DETR-x hybrid object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/rtdetr +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' + # [depth, width, max_channels] + x: [1.00, 1.00, 2048] + +backbone: + # [from, repeats, module, args] + - [-1, 1, HGStem, [32, 64]] # 0-P2/4 + - [-1, 6, HGBlock, [64, 128, 3]] # stage 1 + + - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8 + - [-1, 6, HGBlock, [128, 512, 3]] + - [-1, 6, HGBlock, [128, 512, 3, False, True]] # 4-stage 2 + + - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 5-P3/16 + - [-1, 6, HGBlock, [256, 1024, 5, True, False]] # cm, c2, k, light, shortcut + - [-1, 6, HGBlock, [256, 1024, 5, True, True]] + - [-1, 6, HGBlock, [256, 1024, 5, True, True]] + - [-1, 6, HGBlock, [256, 1024, 5, True, True]] + - [-1, 6, HGBlock, [256, 1024, 5, True, True]] # 10-stage 3 + + - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 11-P4/32 + - [-1, 6, HGBlock, [512, 2048, 5, True, False]] + - [-1, 6, HGBlock, [512, 2048, 5, True, True]] # 13-stage 4 + +head: + - [-1, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 14 input_proj.2 + - [-1, 1, AIFI, [2048, 8]] + - [-1, 1, Conv, [384, 1, 1]] # 16, Y5, lateral_convs.0 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [10, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 18 input_proj.1 + - [[-2, -1], 1, Concat, [1]] + - [-1, 3, RepC3, [384]] # 20, fpn_blocks.0 + - [-1, 1, Conv, [384, 1, 1]] # 21, Y4, lateral_convs.1 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [4, 1, Conv, [384, 1, 1, None, 1, 1, False]] # 23 input_proj.0 + - [[-2, -1], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, RepC3, [384]] # X3 (25), fpn_blocks.1 + + - [-1, 1, Conv, [384, 3, 2]] # 26, downsample_convs.0 + - [[-1, 21], 1, Concat, [1]] # cat Y4 + - [-1, 3, RepC3, [384]] # F4 (28), pan_blocks.0 + + - [-1, 1, Conv, [384, 3, 2]] # 29, downsample_convs.1 + - [[-1, 16], 1, Concat, [1]] # cat Y5 + - [-1, 3, RepC3, [384]] # F5 (31), pan_blocks.1 + + - [[25, 28, 31], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v10/yolov10b.yaml b/ultralytics/cfg/models/v10/yolov10b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..750379128cc77d22329cbe9315ce5a644b722baa --- /dev/null +++ b/ultralytics/cfg/models/v10/yolov10b.yaml @@ -0,0 +1,45 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLOv10b object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov10 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov10n.yaml' will call yolov10.yaml with scale 'n' + # [depth, width, max_channels] + b: [0.67, 1.00, 512] + +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2fCIB, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 1, PSA, [1024]] # 10 + +# YOLOv10.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2fCIB, [512, True]] # 13 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 16 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium) + + - [-1, 1, SCDown, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v10/yolov10l.yaml b/ultralytics/cfg/models/v10/yolov10l.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1dedd752e2372c186ffef0e10e8272227af1ce27 --- /dev/null +++ b/ultralytics/cfg/models/v10/yolov10l.yaml @@ -0,0 +1,45 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLOv10l object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov10 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov10n.yaml' will call yolov10.yaml with scale 'n' + # [depth, width, max_channels] + l: [1.00, 1.00, 512] + +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2fCIB, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 1, PSA, [1024]] # 10 + +# YOLOv10.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2fCIB, [512, True]] # 13 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 16 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium) + + - [-1, 1, SCDown, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v10/yolov10m.yaml b/ultralytics/cfg/models/v10/yolov10m.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6ba4020b3309322ca22dd84d107f3dd75802d5fa --- /dev/null +++ b/ultralytics/cfg/models/v10/yolov10m.yaml @@ -0,0 +1,45 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLOv10m object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov10 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov10n.yaml' will call yolov10.yaml with scale 'n' + # [depth, width, max_channels] + m: [0.67, 0.75, 768] + +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2fCIB, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 1, PSA, [1024]] # 10 + +# YOLOv10.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 13 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 16 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium) + + - [-1, 1, SCDown, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v10/yolov10n.yaml b/ultralytics/cfg/models/v10/yolov10n.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a9aa7018950c2897c6efb3f38ad780016db8817b --- /dev/null +++ b/ultralytics/cfg/models/v10/yolov10n.yaml @@ -0,0 +1,45 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLOv10n object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov10 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov10n.yaml' will call yolov10.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 1, PSA, [1024]] # 10 + +# YOLOv10.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 13 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 16 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 19 (P4/16-medium) + + - [-1, 1, SCDown, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2fCIB, [1024, True, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v10/yolov10s.yaml b/ultralytics/cfg/models/v10/yolov10s.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dbb678b277d72bd2aabca2473974193495157559 --- /dev/null +++ b/ultralytics/cfg/models/v10/yolov10s.yaml @@ -0,0 +1,45 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLOv10s object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov10 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov10n.yaml' will call yolov10.yaml with scale 'n' + # [depth, width, max_channels] + s: [0.33, 0.50, 1024] + +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2fCIB, [1024, True, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 1, PSA, [1024]] # 10 + +# YOLOv10.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 13 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 16 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 19 (P4/16-medium) + + - [-1, 1, SCDown, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2fCIB, [1024, True, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v10/yolov10x.yaml b/ultralytics/cfg/models/v10/yolov10x.yaml new file mode 100644 index 0000000000000000000000000000000000000000..57482133863ee7eee26b66f83d2d6567c9fa9baf --- /dev/null +++ b/ultralytics/cfg/models/v10/yolov10x.yaml @@ -0,0 +1,45 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLOv10x object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov10 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov10n.yaml' will call yolov10.yaml with scale 'n' + # [depth, width, max_channels] + x: [1.00, 1.25, 512] + +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2fCIB, [512, True]] + - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2fCIB, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + - [-1, 1, PSA, [1024]] # 10 + +# YOLOv10.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2fCIB, [512, True]] # 13 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 16 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 13], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2fCIB, [512, True]] # 19 (P4/16-medium) + + - [-1, 1, SCDown, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2fCIB, [1024, True]] # 22 (P5/32-large) + + - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v12/yolov12.yaml b/ultralytics/cfg/models/v12/yolov12.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dca26329faf534768a8b1589bf0339d063caa532 --- /dev/null +++ b/ultralytics/cfg/models/v12/yolov12.yaml @@ -0,0 +1,45 @@ +# YOLOv12 🚀, AGPL-3.0 license +# YOLOv12 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov12n.yaml' will call yolov12.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.50, 0.25, 1024] # summary: 465 layers, 2,603,056 parameters, 2,603,040 gradients, 6.7 GFLOPs + s: [0.50, 0.50, 1024] # summary: 465 layers, 9,285,632 parameters, 9,285,616 gradients, 21.7 GFLOPs + m: [0.50, 1.00, 512] # summary: 501 layers, 20,201,216 parameters, 20,201,200 gradients, 68.1 GFLOPs + l: [1.00, 1.00, 512] # summary: 831 layers, 26,454,880 parameters, 26,454,864 gradients, 89.7 GFLOPs + x: [1.00, 1.50, 512] # summary: 831 layers, 59,216,928 parameters, 59,216,912 gradients, 200.3 GFLOPs + +# YOLO12n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 2, C3k2, [256, False, 0.25]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 2, C3k2, [512, False, 0.25]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 4, A2C2f, [512, True, 4]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 4, A2C2f, [1024, True, 1]] # 8 + +# YOLO12n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 2, A2C2f, [512, False, -1]] # 11 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 2, A2C2f, [256, False, -1]] # 14 + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 11], 1, Concat, [1]] # cat head P4 + - [-1, 2, A2C2f, [512, False, -1]] # 17 + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 8], 1, Concat, [1]] # cat head P5 + - [-1, 2, C3k2, [1024, True]] # 20 (P5/32-large) + + - [[14, 17, 20], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v3/yolov3-spp.yaml b/ultralytics/cfg/models/v3/yolov3-spp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6aef25ab748bc84ee6ad054057029de97e7e1b14 --- /dev/null +++ b/ultralytics/cfg/models/v3/yolov3-spp.yaml @@ -0,0 +1,49 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv3-SPP object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov3 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# darknet53 backbone +backbone: + # [from, number, module, args] + - [-1, 1, Conv, [32, 3, 1]] # 0 + - [-1, 1, Conv, [64, 3, 2]] # 1-P1/2 + - [-1, 1, Bottleneck, [64]] + - [-1, 1, Conv, [128, 3, 2]] # 3-P2/4 + - [-1, 2, Bottleneck, [128]] + - [-1, 1, Conv, [256, 3, 2]] # 5-P3/8 + - [-1, 8, Bottleneck, [256]] + - [-1, 1, Conv, [512, 3, 2]] # 7-P4/16 + - [-1, 8, Bottleneck, [512]] + - [-1, 1, Conv, [1024, 3, 2]] # 9-P5/32 + - [-1, 4, Bottleneck, [1024]] # 10 + +# YOLOv3-SPP head +head: + - [-1, 1, Bottleneck, [1024, False]] + - [-1, 1, SPP, [512, [5, 9, 13]]] + - [-1, 1, Conv, [1024, 3, 1]] + - [-1, 1, Conv, [512, 1, 1]] + - [-1, 1, Conv, [1024, 3, 1]] # 15 (P5/32-large) + + - [-2, 1, Conv, [256, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 8], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, Bottleneck, [512, False]] + - [-1, 1, Bottleneck, [512, False]] + - [-1, 1, Conv, [256, 1, 1]] + - [-1, 1, Conv, [512, 3, 1]] # 22 (P4/16-medium) + + - [-2, 1, Conv, [128, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P3 + - [-1, 1, Bottleneck, [256, False]] + - [-1, 2, Bottleneck, [256, False]] # 27 (P3/8-small) + + - [[27, 22, 15], 1, Detect, [nc]] # Detect(P3, 
P4, P5) diff --git a/ultralytics/cfg/models/v3/yolov3-tiny.yaml b/ultralytics/cfg/models/v3/yolov3-tiny.yaml new file mode 100644 index 0000000000000000000000000000000000000000..91a0bb03f7d8f8436cd4a09cb8e46d8e38484c5c --- /dev/null +++ b/ultralytics/cfg/models/v3/yolov3-tiny.yaml @@ -0,0 +1,40 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv3-tiiny object detection model with P4/16 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov3 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# YOLOv3-tiny backbone +backbone: + # [from, number, module, args] + - [-1, 1, Conv, [16, 3, 1]] # 0 + - [-1, 1, nn.MaxPool2d, [2, 2, 0]] # 1-P1/2 + - [-1, 1, Conv, [32, 3, 1]] + - [-1, 1, nn.MaxPool2d, [2, 2, 0]] # 3-P2/4 + - [-1, 1, Conv, [64, 3, 1]] + - [-1, 1, nn.MaxPool2d, [2, 2, 0]] # 5-P3/8 + - [-1, 1, Conv, [128, 3, 1]] + - [-1, 1, nn.MaxPool2d, [2, 2, 0]] # 7-P4/16 + - [-1, 1, Conv, [256, 3, 1]] + - [-1, 1, nn.MaxPool2d, [2, 2, 0]] # 9-P5/32 + - [-1, 1, Conv, [512, 3, 1]] + - [-1, 1, nn.ZeroPad2d, [[0, 1, 0, 1]]] # 11 + - [-1, 1, nn.MaxPool2d, [2, 1, 0]] # 12 + +# YOLOv3-tiny head +head: + - [-1, 1, Conv, [1024, 3, 1]] + - [-1, 1, Conv, [256, 1, 1]] + - [-1, 1, Conv, [512, 3, 1]] # 15 (P5/32-large) + + - [-2, 1, Conv, [128, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 8], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, Conv, [256, 3, 1]] # 19 (P4/16-medium) + + - [[19, 15], 1, Detect, [nc]] # Detect(P4, P5) diff --git a/ultralytics/cfg/models/v3/yolov3.yaml b/ultralytics/cfg/models/v3/yolov3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..95c99de52be649df826589f73da346d4a75dd05f --- /dev/null +++ b/ultralytics/cfg/models/v3/yolov3.yaml @@ -0,0 +1,49 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv3 object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov3 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# darknet53 backbone +backbone: + # [from, number, module, args] + - [-1, 1, Conv, [32, 3, 1]] # 0 + - [-1, 1, Conv, [64, 3, 2]] # 1-P1/2 + - [-1, 1, Bottleneck, [64]] + - [-1, 1, Conv, [128, 3, 2]] # 3-P2/4 + - [-1, 2, Bottleneck, [128]] + - [-1, 1, Conv, [256, 3, 2]] # 5-P3/8 + - [-1, 8, Bottleneck, [256]] + - [-1, 1, Conv, [512, 3, 2]] # 7-P4/16 + - [-1, 8, Bottleneck, [512]] + - [-1, 1, Conv, [1024, 3, 2]] # 9-P5/32 + - [-1, 4, Bottleneck, [1024]] # 10 + +# YOLOv3 head +head: + - [-1, 1, Bottleneck, [1024, False]] + - [-1, 1, Conv, [512, 1, 1]] + - [-1, 1, Conv, [1024, 3, 1]] + - [-1, 1, Conv, [512, 1, 1]] + - [-1, 1, Conv, [1024, 3, 1]] # 15 (P5/32-large) + + - [-2, 1, Conv, [256, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 8], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, Bottleneck, [512, False]] + - [-1, 1, Bottleneck, [512, False]] + - [-1, 1, Conv, [256, 1, 1]] + - [-1, 1, Conv, [512, 3, 1]] # 22 (P4/16-medium) + + - [-2, 1, Conv, [128, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P3 + - [-1, 1, Bottleneck, [256, False]] + - [-1, 2, Bottleneck, [256, False]] # 27 (P3/8-small) + + - [[27, 22, 15], 1, Detect, [nc]] # Detect(P3, P4, P5) diff 
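The model files collected in this diff control capacity in one of two ways: the YOLOv3 variants above use the older `depth_multiple` / `width_multiple` pair, while the v5/v6/v8/v10/v12 configs carry per-scale `[depth, width, max_channels]` triples under `scales:`. The sketch below shows roughly how such a triple is applied to the repeat counts and channel widths in the `[from, repeats, module, args]` rows; `make_divisible` and `scale_layer` are illustrative helpers written for this note, and the exact rounding in the ultralytics model parser may differ.

```python
import math

def make_divisible(x, divisor=8):
    # Round a channel count up to the nearest multiple of `divisor`,
    # the usual convention for keeping conv widths hardware friendly.
    return math.ceil(x / divisor) * divisor

def scale_layer(repeats, channels, depth, width, max_channels):
    # depth multiplies block repeat counts, width multiplies channel counts,
    # and max_channels caps the widest layers before the width multiplier.
    scaled_repeats = max(round(repeats * depth), 1) if repeats > 1 else repeats
    scaled_channels = make_divisible(min(channels, max_channels) * width)
    return scaled_repeats, scaled_channels

# Example: a backbone row `[-1, 6, C2f, [512, True]]` under the 'n' scale
# [0.33, 0.25, 1024] comes out at 2 repeats and 128 output channels.
print(scale_layer(6, 512, 0.33, 0.25, 1024))  # (2, 128)
```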
--git a/ultralytics/cfg/models/v5/yolov5-p6.yaml b/ultralytics/cfg/models/v5/yolov5-p6.yaml new file mode 100644 index 0000000000000000000000000000000000000000..376d1aba90c24c0df93c14af4a45f5b98b6dbf02 --- /dev/null +++ b/ultralytics/cfg/models/v5/yolov5-p6.yaml @@ -0,0 +1,62 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv5 object detection model with P3/8 - P6/64 outputs +# Model docs: https://docs.ultralytics.com/models/yolov5 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov5n-p6.yaml' will call yolov5-p6.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 1024] + l: [1.00, 1.00, 1024] + x: [1.33, 1.25, 1024] + +# YOLOv5 v6.0 backbone +backbone: + # [from, number, module, args] + - [-1, 1, Conv, [64, 6, 2, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C3, [128]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C3, [256]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 9, C3, [512]] + - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32 + - [-1, 3, C3, [768]] + - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64 + - [-1, 3, C3, [1024]] + - [-1, 1, SPPF, [1024, 5]] # 11 + +# YOLOv5 v6.0 head +head: + - [-1, 1, Conv, [768, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 8], 1, Concat, [1]] # cat backbone P5 + - [-1, 3, C3, [768, False]] # 15 + + - [-1, 1, Conv, [512, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C3, [512, False]] # 19 + + - [-1, 1, Conv, [256, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C3, [256, False]] # 23 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 20], 1, Concat, [1]] # cat head P4 + - [-1, 3, C3, [512, False]] # 26 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 16], 1, Concat, [1]] # cat head P5 + - [-1, 3, C3, [768, False]] # 29 (P5/32-large) + + - [-1, 1, Conv, [768, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P6 + - [-1, 3, C3, [1024, False]] # 32 (P6/64-xlarge) + + - [[23, 26, 29, 32], 1, Detect, [nc]] # Detect(P3, P4, P5, P6) diff --git a/ultralytics/cfg/models/v5/yolov5.yaml b/ultralytics/cfg/models/v5/yolov5.yaml new file mode 100644 index 0000000000000000000000000000000000000000..76a4749ae4f102cdee4db2a14dae4912476736c4 --- /dev/null +++ b/ultralytics/cfg/models/v5/yolov5.yaml @@ -0,0 +1,51 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv5 object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov5 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov5n.yaml' will call yolov5.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 1024] + l: [1.00, 1.00, 1024] + x: [1.33, 1.25, 1024] + +# YOLOv5 v6.0 backbone +backbone: + # [from, number, module, args] + - [-1, 1, Conv, [64, 6, 2, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C3, [128]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C3, [256]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 9, C3, [512]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C3, [1024]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv5 v6.0 head +head: + - [-1, 1, Conv, [512, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C3, [512, False]] # 13 + + - [-1, 1, Conv, [256, 1, 1]] + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C3, [256, False]] # 17 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 14], 1, Concat, [1]] # cat head P4 + - [-1, 3, C3, [512, False]] # 20 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 3, C3, [1024, False]] # 23 (P5/32-large) + + - [[17, 20, 23], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v6/yolov6.yaml b/ultralytics/cfg/models/v6/yolov6.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0812ac7e9b52bfaa2e9a54630744ffc408d71690 --- /dev/null +++ b/ultralytics/cfg/models/v6/yolov6.yaml @@ -0,0 +1,56 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Meituan YOLOv6 object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov6 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +activation: nn.ReLU() # (optional) model default activation function +scales: # model compound scaling constants, i.e. 
'model=yolov6n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 768] + l: [1.00, 1.00, 512] + x: [1.00, 1.25, 512] + +# YOLOv6-3.0s backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 6, Conv, [128, 3, 1]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 12, Conv, [256, 3, 1]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 18, Conv, [512, 3, 1]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 6, Conv, [1024, 3, 1]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv6-3.0s head +head: + - [-1, 1, Conv, [256, 1, 1]] + - [-1, 1, nn.ConvTranspose2d, [256, 2, 2, 0]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, Conv, [256, 3, 1]] + - [-1, 9, Conv, [256, 3, 1]] # 14 + + - [-1, 1, Conv, [128, 1, 1]] + - [-1, 1, nn.ConvTranspose2d, [128, 2, 2, 0]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 1, Conv, [128, 3, 1]] + - [-1, 9, Conv, [128, 3, 1]] # 19 + + - [-1, 1, Conv, [128, 3, 2]] + - [[-1, 15], 1, Concat, [1]] # cat head P4 + - [-1, 1, Conv, [256, 3, 1]] + - [-1, 9, Conv, [256, 3, 1]] # 23 + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 10], 1, Concat, [1]] # cat head P5 + - [-1, 1, Conv, [512, 3, 1]] + - [-1, 9, Conv, [512, 3, 1]] # 27 + + - [[19, 23, 27], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml b/ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml new file mode 100644 index 0000000000000000000000000000000000000000..44cc00ebf2243977441dc3e492eb860f8c442267 --- /dev/null +++ b/ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml @@ -0,0 +1,28 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8-cls image classification model with ResNet101 backbone +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/classify + +# Parameters +nc: 1000 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 1024] + l: [1.00, 1.00, 1024] + x: [1.00, 1.25, 1024] + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, ResNetLayer, [3, 64, 1, True, 1]] # 0-P1/2 + - [-1, 1, ResNetLayer, [64, 64, 1, False, 3]] # 1-P2/4 + - [-1, 1, ResNetLayer, [256, 128, 2, False, 4]] # 2-P3/8 + - [-1, 1, ResNetLayer, [512, 256, 2, False, 23]] # 3-P4/16 + - [-1, 1, ResNetLayer, [1024, 512, 2, False, 3]] # 4-P5/32 + +# YOLOv8.0n head +head: + - [-1, 1, Classify, [nc]] # Classify diff --git a/ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml b/ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1d05e0753fcd22195eaecd4e5dde22fb58750a14 --- /dev/null +++ b/ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml @@ -0,0 +1,28 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8-cls image classification model with ResNet50 backbone +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/classify + +# Parameters +nc: 1000 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 1024] + l: [1.00, 1.00, 1024] + x: [1.00, 1.25, 1024] + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, ResNetLayer, [3, 64, 1, True, 1]] # 0-P1/2 + - [-1, 1, ResNetLayer, [64, 64, 1, False, 3]] # 1-P2/4 + - [-1, 1, ResNetLayer, [256, 128, 2, False, 4]] # 2-P3/8 + - [-1, 1, ResNetLayer, [512, 256, 2, False, 6]] # 3-P4/16 + - [-1, 1, ResNetLayer, [1024, 512, 2, False, 3]] # 4-P5/32 + +# YOLOv8.0n head +head: + - [-1, 1, Classify, [nc]] # Classify diff --git a/ultralytics/cfg/models/v8/yolov8-cls.yaml b/ultralytics/cfg/models/v8/yolov8-cls.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e346e5e1b76164894a01f0640d30145d4161ce53 --- /dev/null +++ b/ultralytics/cfg/models/v8/yolov8-cls.yaml @@ -0,0 +1,32 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8-cls image classification model with YOLO backbone +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/classify + +# Parameters +nc: 1000 # number of classes +scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 1024] + l: [1.00, 1.00, 1024] + x: [1.00, 1.25, 1024] + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + +# YOLOv8.0n head +head: + - [-1, 1, Classify, [nc]] # Classify diff --git a/ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml b/ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a98f23837bf0b82a6522303e731c4c32fc1f963f --- /dev/null +++ b/ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml @@ -0,0 +1,58 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8 object detection model with P2/4 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/detect +# Employs Ghost convolutions and modules proposed in Huawei's GhostNet in https://arxiv.org/abs/1911.11907v2 + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] # YOLOv8n-ghost-p2 summary: 491 layers, 2033944 parameters, 2033928 gradients, 13.8 GFLOPs + s: [0.33, 0.50, 1024] # YOLOv8s-ghost-p2 summary: 491 layers, 5562080 parameters, 5562064 gradients, 25.1 GFLOPs + m: [0.67, 0.75, 768] # YOLOv8m-ghost-p2 summary: 731 layers, 9031728 parameters, 9031712 gradients, 42.8 GFLOPs + l: [1.00, 1.00, 512] # YOLOv8l-ghost-p2 summary: 971 layers, 12214448 parameters, 12214432 gradients, 69.1 GFLOPs + x: [1.00, 1.25, 512] # YOLOv8x-ghost-p2 summary: 971 layers, 18664776 parameters, 18664760 gradients, 103.3 GFLOPs + +# YOLOv8.0-ghost backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, GhostConv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C3Ghost, [128, True]] + - [-1, 1, GhostConv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C3Ghost, [256, True]] + - [-1, 1, GhostConv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C3Ghost, [512, True]] + - [-1, 1, GhostConv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C3Ghost, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0-ghost-p2 head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C3Ghost, [512]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C3Ghost, [256]] # 15 (P3/8-small) + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 2], 1, Concat, [1]] # cat backbone P2 + - [-1, 3, C3Ghost, [128]] # 18 (P2/4-xsmall) + + - [-1, 1, GhostConv, [128, 3, 2]] + - [[-1, 15], 1, Concat, [1]] # cat head P3 + - [-1, 3, C3Ghost, [256]] # 21 (P3/8-small) + + - [-1, 1, GhostConv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C3Ghost, [512]] # 24 (P4/16-medium) + + - [-1, 1, GhostConv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C3Ghost, [1024]] # 27 (P5/32-large) + + - [[18, 21, 24, 27], 1, Detect, [nc]] # Detect(P2, P3, P4, P5) diff --git a/ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml b/ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml new file mode 100644 index 0000000000000000000000000000000000000000..956c2f0ad668aa81ba6f2116f24325457caf330b --- /dev/null +++ b/ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml @@ -0,0 +1,60 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8 object detection model with P3/8 - P6/64 outputs +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/detect +# Employs Ghost convolutions and modules proposed in Huawei's GhostNet in https://arxiv.org/abs/1911.11907v2 + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n-p6.yaml' will call yolov8-p6.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] # YOLOv8n-ghost-p6 summary: 529 layers, 2901100 parameters, 2901084 gradients, 5.8 GFLOPs + s: [0.33, 0.50, 1024] # YOLOv8s-ghost-p6 summary: 529 layers, 9520008 parameters, 9519992 gradients, 16.4 GFLOPs + m: [0.67, 0.75, 768] # YOLOv8m-ghost-p6 summary: 789 layers, 18002904 parameters, 18002888 gradients, 34.4 GFLOPs + l: [1.00, 1.00, 512] # YOLOv8l-ghost-p6 summary: 1049 layers, 21227584 parameters, 21227568 gradients, 55.3 GFLOPs + x: [1.00, 1.25, 512] # YOLOv8x-ghost-p6 summary: 1049 layers, 33057852 parameters, 33057836 gradients, 85.7 GFLOPs + +# YOLOv8.0-ghost backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, GhostConv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C3Ghost, [128, True]] + - [-1, 1, GhostConv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C3Ghost, [256, True]] + - [-1, 1, GhostConv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C3Ghost, [512, True]] + - [-1, 1, GhostConv, [768, 3, 2]] # 7-P5/32 + - [-1, 3, C3Ghost, [768, True]] + - [-1, 1, GhostConv, [1024, 3, 2]] # 9-P6/64 + - [-1, 3, C3Ghost, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 11 + +# YOLOv8.0-ghost-p6 head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 8], 1, Concat, [1]] # cat backbone P5 + - [-1, 3, C3Ghost, [768]] # 14 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C3Ghost, [512]] # 17 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C3Ghost, [256]] # 20 (P3/8-small) + + - [-1, 1, GhostConv, [256, 3, 2]] + - [[-1, 17], 1, Concat, [1]] # cat head P4 + - [-1, 3, C3Ghost, [512]] # 23 (P4/16-medium) + + - [-1, 1, GhostConv, [512, 3, 2]] + - [[-1, 14], 1, Concat, [1]] # cat head P5 + - [-1, 3, C3Ghost, [768]] # 26 (P5/32-large) + + - [-1, 1, GhostConv, [768, 3, 2]] + - [[-1, 11], 1, Concat, [1]] # cat head P6 + - [-1, 3, C3Ghost, [1024]] # 29 (P6/64-xlarge) + + - [[20, 23, 26, 29], 1, Detect, [nc]] # Detect(P3, P4, P5, P6) diff --git a/ultralytics/cfg/models/v8/yolov8-ghost.yaml b/ultralytics/cfg/models/v8/yolov8-ghost.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5888fb39bd0528995b3be15a82b7626661cb7358 --- /dev/null +++ b/ultralytics/cfg/models/v8/yolov8-ghost.yaml @@ -0,0 +1,50 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8 object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/detect +# Employs Ghost convolutions and modules proposed in Huawei's GhostNet in https://arxiv.org/abs/1911.11907v2 + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] # YOLOv8n-ghost summary: 403 layers, 1865316 parameters, 1865300 gradients, 5.8 GFLOPs + s: [0.33, 0.50, 1024] # YOLOv8s-ghost summary: 403 layers, 5960072 parameters, 5960056 gradients, 16.4 GFLOPs + m: [0.67, 0.75, 768] # YOLOv8m-ghost summary: 603 layers, 10336312 parameters, 10336296 gradients, 32.7 GFLOPs + l: [1.00, 1.00, 512] # YOLOv8l-ghost summary: 803 layers, 14277872 parameters, 14277856 gradients, 53.7 GFLOPs + x: [1.00, 1.25, 512] # YOLOv8x-ghost summary: 803 layers, 22229308 parameters, 22229292 gradients, 83.3 GFLOPs + +# YOLOv8.0n-ghost backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, GhostConv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C3Ghost, [128, True]] + - [-1, 1, GhostConv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C3Ghost, [256, True]] + - [-1, 1, GhostConv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C3Ghost, [512, True]] + - [-1, 1, GhostConv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C3Ghost, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C3Ghost, [512]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C3Ghost, [256]] # 15 (P3/8-small) + + - [-1, 1, GhostConv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C3Ghost, [512]] # 18 (P4/16-medium) + + - [-1, 1, GhostConv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C3Ghost, [1024]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v8/yolov8-obb.yaml b/ultralytics/cfg/models/v8/yolov8-obb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..909324c5bec4190e86626599adb71c78b8957f04 --- /dev/null +++ b/ultralytics/cfg/models/v8/yolov8-obb.yaml @@ -0,0 +1,49 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8-obb Oriented Bounding Boxes (OBB) model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/obb + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs + s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs + m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs + l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs + x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 15 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 18 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2f, [1024]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, OBB, [nc, 1]] # OBB(P3, P4, P5) diff --git a/ultralytics/cfg/models/v8/yolov8-p2.yaml b/ultralytics/cfg/models/v8/yolov8-p2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..676bc8348c4b264e03140f9d4ec2da7a65b02a8b --- /dev/null +++ b/ultralytics/cfg/models/v8/yolov8-p2.yaml @@ -0,0 +1,57 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8 object detection model with P2/4 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 768] + l: [1.00, 1.00, 512] + x: [1.00, 1.25, 512] + +# YOLOv8.0 backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0-p2 head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 15 (P3/8-small) + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 2], 1, Concat, [1]] # cat backbone P2 + - [-1, 3, C2f, [128]] # 18 (P2/4-xsmall) + + - [-1, 1, Conv, [128, 3, 2]] + - [[-1, 15], 1, Concat, [1]] # cat head P3 + - [-1, 3, C2f, [256]] # 21 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 24 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2f, [1024]] # 27 (P5/32-large) + + - [[18, 21, 24, 27], 1, Detect, [nc]] # Detect(P2, P3, P4, P5) diff --git a/ultralytics/cfg/models/v8/yolov8-p6.yaml b/ultralytics/cfg/models/v8/yolov8-p6.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3fde34981f80238aa213c0e331f3488335af0a9c --- /dev/null +++ b/ultralytics/cfg/models/v8/yolov8-p6.yaml @@ -0,0 +1,59 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8 object detection model with P3/8 - P6/64 outputs +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n-p6.yaml' will call yolov8-p6.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] # YOLOv8n-p6 summary (fused): 220 layers, 4976656 parameters, 42560 gradients, 8.7 GFLOPs + s: [0.33, 0.50, 1024] # YOLOv8s-p6 summary (fused): 220 layers, 17897168 parameters, 57920 gradients, 28.5 GFLOPs + m: [0.67, 0.75, 768] # YOLOv8m-p6 summary (fused): 285 layers, 44862352 parameters, 78400 gradients, 83.1 GFLOPs + l: [1.00, 1.00, 512] # YOLOv8l-p6 summary (fused): 350 layers, 62351440 parameters, 98880 gradients, 167.3 GFLOPs + x: [1.00, 1.25, 512] # YOLOv8x-p6 summary (fused): 350 layers, 97382352 parameters, 123456 gradients, 261.1 GFLOPs + +# YOLOv8.0x6 backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [768, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 11 + +# YOLOv8.0x6 head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 8], 1, Concat, [1]] # cat backbone P5 + - [-1, 3, C2, [768, False]] # 14 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2, [512, False]] # 17 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2, [256, False]] # 20 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 17], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2, [512, False]] # 23 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 14], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2, [768, False]] # 26 (P5/32-large) + + - [-1, 1, Conv, [768, 3, 2]] + - [[-1, 11], 1, Concat, [1]] # cat head P6 + - [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge) + + - [[20, 23, 26, 29], 1, Detect, [nc]] # Detect(P3, P4, P5, P6) diff --git a/ultralytics/cfg/models/v8/yolov8-pose-p6.yaml b/ultralytics/cfg/models/v8/yolov8-pose-p6.yaml new file mode 100644 index 0000000000000000000000000000000000000000..447a21aab0703b0f3a6520ba85b112ff6bd4f809 --- /dev/null +++ b/ultralytics/cfg/models/v8/yolov8-pose-p6.yaml @@ -0,0 +1,60 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8-pose keypoints/pose estimation model with P3/8 - P6/64 outputs +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/pose + +# Parameters +nc: 1 # number of classes +kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) +scales: # model compound scaling constants, i.e. 
'model=yolov8n-p6.yaml' will call yolov8-p6.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 768] + l: [1.00, 1.00, 512] + x: [1.00, 1.25, 512] + +# YOLOv8.0x6 backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [768, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 11 + +# YOLOv8.0x6 head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 8], 1, Concat, [1]] # cat backbone P5 + - [-1, 3, C2, [768, False]] # 14 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2, [512, False]] # 17 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2, [256, False]] # 20 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 17], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2, [512, False]] # 23 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 14], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2, [768, False]] # 26 (P5/32-large) + + - [-1, 1, Conv, [768, 3, 2]] + - [[-1, 11], 1, Concat, [1]] # cat head P6 + - [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge) + + - [[20, 23, 26, 29], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5, P6) diff --git a/ultralytics/cfg/models/v8/yolov8-pose.yaml b/ultralytics/cfg/models/v8/yolov8-pose.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c22bc435b578d58b1a432e1ee0f203562e73f076 --- /dev/null +++ b/ultralytics/cfg/models/v8/yolov8-pose.yaml @@ -0,0 +1,50 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8-pose keypoints/pose estimation model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/pose + +# Parameters +nc: 1 # number of classes +kpt_shape: [17, 3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) +scales: # model compound scaling constants, i.e. 
'model=yolov8n-pose.yaml' will call yolov8-pose.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 768] + l: [1.00, 1.00, 512] + x: [1.00, 1.25, 512] + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 15 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 18 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2f, [1024]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, Pose, [nc, kpt_shape]] # Pose(P3, P4, P5) diff --git a/ultralytics/cfg/models/v8/yolov8-rtdetr.yaml b/ultralytics/cfg/models/v8/yolov8-rtdetr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..50ec129ac1885898d1ca9b3bc94a196e92519efb --- /dev/null +++ b/ultralytics/cfg/models/v8/yolov8-rtdetr.yaml @@ -0,0 +1,49 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8-RTDETR hybrid object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/rtdetr +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs + s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs + m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs + l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs + x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 15 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 18 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2f, [1024]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v8/yolov8-seg-p6.yaml b/ultralytics/cfg/models/v8/yolov8-seg-p6.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4c7ba9bf4dddaa23adc5521ff51ab149ae09495a --- /dev/null +++ b/ultralytics/cfg/models/v8/yolov8-seg-p6.yaml @@ -0,0 +1,59 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8-seg instance segmentation model with P3/8 - P6/64 outputs +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/segment + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n-seg-p6.yaml' will call yolov8-seg-p6.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 768] + l: [1.00, 1.00, 512] + x: [1.00, 1.25, 512] + +# YOLOv8.0x6 backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [768, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [768, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 9-P6/64 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 11 + +# YOLOv8.0x6 head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 8], 1, Concat, [1]] # cat backbone P5 + - [-1, 3, C2, [768, False]] # 14 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2, [512, False]] # 17 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2, [256, False]] # 20 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 17], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2, [512, False]] # 23 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 14], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2, [768, False]] # 26 (P5/32-large) + + - [-1, 1, Conv, [768, 3, 2]] + - [[-1, 11], 1, Concat, [1]] # cat head P6 + - [-1, 3, C2, [1024, False]] # 29 (P6/64-xlarge) + + - [[20, 23, 26, 29], 1, Segment, [nc, 32, 256]] # Pose(P3, P4, P5, P6) diff --git a/ultralytics/cfg/models/v8/yolov8-seg.yaml b/ultralytics/cfg/models/v8/yolov8-seg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..52b1c7e9aedc4fbc04d5c4f48f4ad214ad6a51e2 --- /dev/null +++ b/ultralytics/cfg/models/v8/yolov8-seg.yaml @@ -0,0 +1,49 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8-seg instance segmentation model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/segment + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n-seg.yaml' will call yolov8-seg.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] + s: [0.33, 0.50, 1024] + m: [0.67, 0.75, 768] + l: [1.00, 1.00, 512] + x: [1.00, 1.25, 512] + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 15 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 18 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2f, [1024]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, Segment, [nc, 32, 256]] # Segment(P3, P4, P5) diff --git a/ultralytics/cfg/models/v8/yolov8-world.yaml b/ultralytics/cfg/models/v8/yolov8-world.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3c92e824ce684207ae08d9f5f48310c51208aed5 --- /dev/null +++ b/ultralytics/cfg/models/v8/yolov8-world.yaml @@ -0,0 +1,51 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8-World hybrid object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolo-world +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs + s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs + m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs + l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs + x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2fAttn, [512, 256, 8]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2fAttn, [256, 128, 4]] # 15 (P3/8-small) + + - [[15, 12, 9], 1, ImagePoolingAttn, [256]] # 16 (P3/8-small) + + - [15, 1, Conv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2fAttn, [512, 256, 8]] # 19 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2fAttn, [1024, 512, 16]] # 22 (P5/32-large) + + - [[15, 19, 22], 1, WorldDetect, [nc, 512, False]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v8/yolov8-worldv2.yaml b/ultralytics/cfg/models/v8/yolov8-worldv2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c6aaa2773315acca026287894a07fd5ad0f2559e --- /dev/null +++ b/ultralytics/cfg/models/v8/yolov8-worldv2.yaml @@ -0,0 +1,49 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8-Worldv2 hybrid object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolo-world +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs + s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs + m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs + l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs + x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2fAttn, [512, 256, 8]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2fAttn, [256, 128, 4]] # 15 (P3/8-small) + + - [15, 1, Conv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2fAttn, [512, 256, 8]] # 18 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2fAttn, [1024, 512, 16]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, WorldDetect, [nc, 512, True]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v8/yolov8.yaml b/ultralytics/cfg/models/v8/yolov8.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c7b9938ec34d4e2529821a76d6de52f2accbac06 --- /dev/null +++ b/ultralytics/cfg/models/v8/yolov8.yaml @@ -0,0 +1,49 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Ultralytics YOLOv8 object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov8 +# Task docs: https://docs.ultralytics.com/tasks/detect + +# Parameters +nc: 80 # number of classes +scales: # model compound scaling constants, i.e. 
'model=yolov8n.yaml' will call yolov8.yaml with scale 'n' + # [depth, width, max_channels] + n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs + s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs + m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs + l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs + x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs + +# YOLOv8.0n backbone +backbone: + # [from, repeats, module, args] + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 3, C2f, [128, True]] + - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8 + - [-1, 6, C2f, [256, True]] + - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 + - [-1, 6, C2f, [512, True]] + - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 + - [-1, 3, C2f, [1024, True]] + - [-1, 1, SPPF, [1024, 5]] # 9 + +# YOLOv8.0n head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 3, C2f, [512]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 3, C2f, [256]] # 15 (P3/8-small) + + - [-1, 1, Conv, [256, 3, 2]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 3, C2f, [512]] # 18 (P4/16-medium) + + - [-1, 1, Conv, [512, 3, 2]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 3, C2f, [1024]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v9/yolov9c-seg.yaml b/ultralytics/cfg/models/v9/yolov9c-seg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..14122cb83933b17f6b56a703cf8dd64bbd3c3fbc --- /dev/null +++ b/ultralytics/cfg/models/v9/yolov9c-seg.yaml @@ -0,0 +1,41 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLOv9c-seg instance segmentation model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov9 +# Task docs: https://docs.ultralytics.com/tasks/segment +# 654 layers, 27897120 parameters, 159.4 GFLOPs + +# Parameters +nc: 80 # number of classes + +# GELAN backbone +backbone: + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]] # 2 + - [-1, 1, ADown, [256]] # 3-P3/8 + - [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]] # 4 + - [-1, 1, ADown, [512]] # 5-P4/16 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 6 + - [-1, 1, ADown, [512]] # 7-P5/32 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 8 + - [-1, 1, SPPELAN, [512, 256]] # 9 + +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]] # 15 (P3/8-small) + + - [-1, 1, ADown, [256]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 18 (P4/16-medium) + + - [-1, 1, ADown, [512]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, Segment, [nc, 32, 256]] # Segment(P3, P4, P5) diff --git a/ultralytics/cfg/models/v9/yolov9c.yaml b/ultralytics/cfg/models/v9/yolov9c.yaml new file mode 100644 index 
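Every row in these configs is `[from, repeats, module, args]`, and the head wiring relies on the `from` field: `-1` means the previous layer, a non-negative number is an absolute layer index, and a list (as in `[[-1, 6], 1, Concat, [1]]`, "cat backbone P4") feeds several outputs into one module. Below is a small sketch that resolves those indices against the combined backbone + head list; the plain-PyYAML loading and the hard-coded path are only for illustration and are not how ultralytics itself builds the model.

```python
import yaml  # PyYAML

# Path as added by this diff; point it at a local checkout of the config.
CFG = "ultralytics/cfg/models/v12/yolov12.yaml"

with open(CFG, encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

layers = cfg["backbone"] + cfg["head"]

for i, (frm, repeats, module, args) in enumerate(layers):
    sources = frm if isinstance(frm, list) else [frm]
    # Negative entries are relative to the current row, others are absolute indices.
    resolved = [s if s >= 0 else i + s for s in sources]
    print(f"layer {i:2d}: {module:<12} repeats={repeats:<2} inputs={resolved}")
```

Run against `yolov12.yaml`, this reproduces the inline head comments, e.g. the first `Concat` at layer 10 pulls from layers 9 and 6 (backbone P4).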
0000000000000000000000000000000000000000..4fc1fcd13fdf945d5a91269065ae3b33d2f2a76f --- /dev/null +++ b/ultralytics/cfg/models/v9/yolov9c.yaml @@ -0,0 +1,41 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLOv9c object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov9 +# Task docs: https://docs.ultralytics.com/tasks/detect +# 618 layers, 25590912 parameters, 104.0 GFLOPs + +# Parameters +nc: 80 # number of classes + +# GELAN backbone +backbone: + - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4 + - [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]] # 2 + - [-1, 1, ADown, [256]] # 3-P3/8 + - [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]] # 4 + - [-1, 1, ADown, [512]] # 5-P4/16 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 6 + - [-1, 1, ADown, [512]] # 7-P5/32 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 8 + - [-1, 1, SPPELAN, [512, 256]] # 9 + +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]] # 15 (P3/8-small) + + - [-1, 1, ADown, [256]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 18 (P4/16-medium) + + - [-1, 1, ADown, [512]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v9/yolov9e-seg.yaml b/ultralytics/cfg/models/v9/yolov9e-seg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4361daac29358984827fea4c1eaa3478324a033a --- /dev/null +++ b/ultralytics/cfg/models/v9/yolov9e-seg.yaml @@ -0,0 +1,64 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLOv9e-seg instance segmentation model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov9 +# Task docs: https://docs.ultralytics.com/tasks/segment +# 1261 layers, 60512800 parameters, 248.4 GFLOPs + +# Parameters +nc: 80 # number of classes + +# GELAN backbone +backbone: + - [-1, 1, nn.Identity, []] + - [-1, 1, Conv, [64, 3, 2]] # 1-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 2-P2/4 + - [-1, 1, RepNCSPELAN4, [256, 128, 64, 2]] # 3 + - [-1, 1, ADown, [256]] # 4-P3/8 + - [-1, 1, RepNCSPELAN4, [512, 256, 128, 2]] # 5 + - [-1, 1, ADown, [512]] # 6-P4/16 + - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 7 + - [-1, 1, ADown, [1024]] # 8-P5/32 + - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 9 + + - [1, 1, CBLinear, [[64]]] # 10 + - [3, 1, CBLinear, [[64, 128]]] # 11 + - [5, 1, CBLinear, [[64, 128, 256]]] # 12 + - [7, 1, CBLinear, [[64, 128, 256, 512]]] # 13 + - [9, 1, CBLinear, [[64, 128, 256, 512, 1024]]] # 14 + + - [0, 1, Conv, [64, 3, 2]] # 15-P1/2 + - [[10, 11, 12, 13, 14, -1], 1, CBFuse, [[0, 0, 0, 0, 0]]] # 16 + - [-1, 1, Conv, [128, 3, 2]] # 17-P2/4 + - [[11, 12, 13, 14, -1], 1, CBFuse, [[1, 1, 1, 1]]] # 18 + - [-1, 1, RepNCSPELAN4, [256, 128, 64, 2]] # 19 + - [-1, 1, ADown, [256]] # 20-P3/8 + - [[12, 13, 14, -1], 1, CBFuse, [[2, 2, 2]]] # 21 + - [-1, 1, RepNCSPELAN4, [512, 256, 128, 2]] # 22 + - [-1, 1, ADown, [512]] # 23-P4/16 + - [[13, 14, -1], 1, CBFuse, [[3, 3]]] # 24 + - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 25 + - [-1, 1, ADown, [1024]] # 26-P5/32 + - [[14, -1], 1, CBFuse, [[4]]] # 27 + - [-1, 1, RepNCSPELAN4, [1024, 512, 
256, 2]] # 28 + - [-1, 1, SPPELAN, [512, 256]] # 29 + +# GELAN head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 25], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]] # 32 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 22], 1, Concat, [1]] # cat backbone P3 + - [-1, 1, RepNCSPELAN4, [256, 256, 128, 2]] # 35 (P3/8-small) + + - [-1, 1, ADown, [256]] + - [[-1, 32], 1, Concat, [1]] # cat head P4 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]] # 38 (P4/16-medium) + + - [-1, 1, ADown, [512]] + - [[-1, 29], 1, Concat, [1]] # cat head P5 + - [-1, 1, RepNCSPELAN4, [512, 1024, 512, 2]] # 41 (P5/32-large) + + - [[35, 38, 41], 1, Segment, [nc, 32, 256]] # Segment (P3, P4, P5) diff --git a/ultralytics/cfg/models/v9/yolov9e.yaml b/ultralytics/cfg/models/v9/yolov9e.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bba5597d0cf9d848925b30afe1e07c760cf8a62e --- /dev/null +++ b/ultralytics/cfg/models/v9/yolov9e.yaml @@ -0,0 +1,64 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLOv9e object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov9 +# Task docs: https://docs.ultralytics.com/tasks/detect +# 1225 layers, 58206592 parameters, 193.0 GFLOPs + +# Parameters +nc: 80 # number of classes + +# GELAN backbone +backbone: + - [-1, 1, nn.Identity, []] + - [-1, 1, Conv, [64, 3, 2]] # 1-P1/2 + - [-1, 1, Conv, [128, 3, 2]] # 2-P2/4 + - [-1, 1, RepNCSPELAN4, [256, 128, 64, 2]] # 3 + - [-1, 1, ADown, [256]] # 4-P3/8 + - [-1, 1, RepNCSPELAN4, [512, 256, 128, 2]] # 5 + - [-1, 1, ADown, [512]] # 6-P4/16 + - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 7 + - [-1, 1, ADown, [1024]] # 8-P5/32 + - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 9 + + - [1, 1, CBLinear, [[64]]] # 10 + - [3, 1, CBLinear, [[64, 128]]] # 11 + - [5, 1, CBLinear, [[64, 128, 256]]] # 12 + - [7, 1, CBLinear, [[64, 128, 256, 512]]] # 13 + - [9, 1, CBLinear, [[64, 128, 256, 512, 1024]]] # 14 + + - [0, 1, Conv, [64, 3, 2]] # 15-P1/2 + - [[10, 11, 12, 13, 14, -1], 1, CBFuse, [[0, 0, 0, 0, 0]]] # 16 + - [-1, 1, Conv, [128, 3, 2]] # 17-P2/4 + - [[11, 12, 13, 14, -1], 1, CBFuse, [[1, 1, 1, 1]]] # 18 + - [-1, 1, RepNCSPELAN4, [256, 128, 64, 2]] # 19 + - [-1, 1, ADown, [256]] # 20-P3/8 + - [[12, 13, 14, -1], 1, CBFuse, [[2, 2, 2]]] # 21 + - [-1, 1, RepNCSPELAN4, [512, 256, 128, 2]] # 22 + - [-1, 1, ADown, [512]] # 23-P4/16 + - [[13, 14, -1], 1, CBFuse, [[3, 3]]] # 24 + - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 25 + - [-1, 1, ADown, [1024]] # 26-P5/32 + - [[14, -1], 1, CBFuse, [[4]]] # 27 + - [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]] # 28 + - [-1, 1, SPPELAN, [512, 256]] # 29 + +# GELAN head +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 25], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]] # 32 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 22], 1, Concat, [1]] # cat backbone P3 + - [-1, 1, RepNCSPELAN4, [256, 256, 128, 2]] # 35 (P3/8-small) + + - [-1, 1, ADown, [256]] + - [[-1, 32], 1, Concat, [1]] # cat head P4 + - [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]] # 38 (P4/16-medium) + + - [-1, 1, ADown, [512]] + - [[-1, 29], 1, Concat, [1]] # cat head P5 + - [-1, 1, RepNCSPELAN4, [512, 1024, 512, 2]] # 41 (P5/32-large) + + - [[35, 38, 41], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v9/yolov9m.yaml b/ultralytics/cfg/models/v9/yolov9m.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..89bed65bebbed3d42eb804d19bcf8eefca7791f2 --- /dev/null +++ b/ultralytics/cfg/models/v9/yolov9m.yaml @@ -0,0 +1,41 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLOv9m object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov9 +# Task docs: https://docs.ultralytics.com/tasks/detect +# 603 layers, 20216160 parameters, 77.9 GFLOPs + +# Parameters +nc: 80 # number of classes + +# GELAN backbone +backbone: + - [-1, 1, Conv, [32, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [64, 3, 2]] # 1-P2/4 + - [-1, 1, RepNCSPELAN4, [128, 128, 64, 1]] # 2 + - [-1, 1, AConv, [240]] # 3-P3/8 + - [-1, 1, RepNCSPELAN4, [240, 240, 120, 1]] # 4 + - [-1, 1, AConv, [360]] # 5-P4/16 + - [-1, 1, RepNCSPELAN4, [360, 360, 180, 1]] # 6 + - [-1, 1, AConv, [480]] # 7-P5/32 + - [-1, 1, RepNCSPELAN4, [480, 480, 240, 1]] # 8 + - [-1, 1, SPPELAN, [480, 240]] # 9 + +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, RepNCSPELAN4, [360, 360, 180, 1]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 1, RepNCSPELAN4, [240, 240, 120, 1]] # 15 + + - [-1, 1, AConv, [180]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 1, RepNCSPELAN4, [360, 360, 180, 1]] # 18 (P4/16-medium) + + - [-1, 1, AConv, [240]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 1, RepNCSPELAN4, [480, 480, 240, 1]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v9/yolov9s.yaml b/ultralytics/cfg/models/v9/yolov9s.yaml new file mode 100644 index 0000000000000000000000000000000000000000..28891f4cebcdc3682bec403c539c79eb5bfe7834 --- /dev/null +++ b/ultralytics/cfg/models/v9/yolov9s.yaml @@ -0,0 +1,41 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLOv9s object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov9 +# Task docs: https://docs.ultralytics.com/tasks/detect +# 917 layers, 7318368 parameters, 27.6 GFLOPs + +# Parameters +nc: 80 # number of classes + +# GELAN backbone +backbone: + - [-1, 1, Conv, [32, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [64, 3, 2]] # 1-P2/4 + - [-1, 1, ELAN1, [64, 64, 32]] # 2 + - [-1, 1, AConv, [128]] # 3-P3/8 + - [-1, 1, RepNCSPELAN4, [128, 128, 64, 3]] # 4 + - [-1, 1, AConv, [192]] # 5-P4/16 + - [-1, 1, RepNCSPELAN4, [192, 192, 96, 3]] # 6 + - [-1, 1, AConv, [256]] # 7-P5/32 + - [-1, 1, RepNCSPELAN4, [256, 256, 128, 3]] # 8 + - [-1, 1, SPPELAN, [256, 128]] # 9 + +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, RepNCSPELAN4, [192, 192, 96, 3]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 1, RepNCSPELAN4, [128, 128, 64, 3]] # 15 + + - [-1, 1, AConv, [96]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 1, RepNCSPELAN4, [192, 192, 96, 3]] # 18 (P4/16-medium) + + - [-1, 1, AConv, [128]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 1, RepNCSPELAN4, [256, 256, 128, 3]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/models/v9/yolov9t.yaml b/ultralytics/cfg/models/v9/yolov9t.yaml new file mode 100644 index 0000000000000000000000000000000000000000..21a5bad86b93a15adc4fa00e749a4b5a0699abd3 --- /dev/null +++ b/ultralytics/cfg/models/v9/yolov9t.yaml @@ -0,0 +1,41 @@ +# 
Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# YOLOv9t object detection model with P3/8 - P5/32 outputs +# Model docs: https://docs.ultralytics.com/models/yolov9 +# Task docs: https://docs.ultralytics.com/tasks/detect +# 917 layers, 2128720 parameters, 8.5 GFLOPs + +# Parameters +nc: 80 # number of classes + +# GELAN backbone +backbone: + - [-1, 1, Conv, [16, 3, 2]] # 0-P1/2 + - [-1, 1, Conv, [32, 3, 2]] # 1-P2/4 + - [-1, 1, ELAN1, [32, 32, 16]] # 2 + - [-1, 1, AConv, [64]] # 3-P3/8 + - [-1, 1, RepNCSPELAN4, [64, 64, 32, 3]] # 4 + - [-1, 1, AConv, [96]] # 5-P4/16 + - [-1, 1, RepNCSPELAN4, [96, 96, 48, 3]] # 6 + - [-1, 1, AConv, [128]] # 7-P5/32 + - [-1, 1, RepNCSPELAN4, [128, 128, 64, 3]] # 8 + - [-1, 1, SPPELAN, [128, 64]] # 9 + +head: + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 6], 1, Concat, [1]] # cat backbone P4 + - [-1, 1, RepNCSPELAN4, [96, 96, 48, 3]] # 12 + + - [-1, 1, nn.Upsample, [None, 2, "nearest"]] + - [[-1, 4], 1, Concat, [1]] # cat backbone P3 + - [-1, 1, RepNCSPELAN4, [64, 64, 32, 3]] # 15 + + - [-1, 1, AConv, [48]] + - [[-1, 12], 1, Concat, [1]] # cat head P4 + - [-1, 1, RepNCSPELAN4, [96, 96, 48, 3]] # 18 (P4/16-medium) + + - [-1, 1, AConv, [64]] + - [[-1, 9], 1, Concat, [1]] # cat head P5 + - [-1, 1, RepNCSPELAN4, [128, 128, 64, 3]] # 21 (P5/32-large) + + - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5) diff --git a/ultralytics/cfg/solutions/default.yaml b/ultralytics/cfg/solutions/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a4afb49b324894c7caa6f6a1382cdc8acae10db9 --- /dev/null +++ b/ultralytics/cfg/solutions/default.yaml @@ -0,0 +1,24 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Global configuration YAML with settings and arguments for Ultralytics Solutions +# For documentation see https://docs.ultralytics.com/solutions/ + +# Object counting settings -------------------------------------------------------------------------------------------- +region: # (list[tuple[int, int]]) object counting, queue or speed estimation region points. +show_in: True # (bool) flag to display objects moving *into* the defined region +show_out: True # (bool) flag to display objects moving *out of* the defined region + +# Heatmaps settings ---------------------------------------------------------------------------------------------------- +colormap: # (int | str) colormap for heatmap; only OpenCV-supported colormaps can be used. + +# Workouts monitoring settings ----------------------------------------------------------------------------------------- +up_angle: 145.0 # (float) Workouts up_angle for counts, 145.0 is default value. +down_angle: 90 # (float) Workouts down_angle for counts, 90 is default value. +kpts: [6, 8, 10] # (list[int]) keypoints for workouts monitoring, i.e. for push-ups kpts have values of [6, 8, 10]. + +# Analytics settings --------------------------------------------------------------------------------------------------- +analytics_type: "line" # (str) analytics type, i.e. "line", "pie", "bar" or "area" charts. +json_file: # (str) parking system regions file path. 
+ +# Security alarm system settings --------------------------------------------------------------------------------------- +records: 5 # (int) total detection count required before sending a security alert email diff --git a/ultralytics/cfg/trackers/botsort.yaml b/ultralytics/cfg/trackers/botsort.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aedcee4860fe00e969aca4c51ff0975d931ac283 --- /dev/null +++ b/ultralytics/cfg/trackers/botsort.yaml @@ -0,0 +1,21 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Default Ultralytics settings for BoT-SORT tracker when using mode="track" +# For documentation and examples see https://docs.ultralytics.com/modes/track/ +# For BoT-SORT source code see https://github.com/NirAharon/BoT-SORT + +tracker_type: botsort # tracker type, ['botsort', 'bytetrack'] +track_high_thresh: 0.25 # threshold for the first association +track_low_thresh: 0.1 # threshold for the second association +new_track_thresh: 0.25 # threshold for initializing a new track if the detection does not match any existing tracks +track_buffer: 30 # buffer length used to calculate when to remove lost tracks +match_thresh: 0.8 # threshold for matching tracks +fuse_score: True # whether to fuse confidence scores with the IoU distances before matching +# min_box_area: 10 # threshold for minimum box area (for tracker evaluation, not used for now) + +# BoT-SORT settings +gmc_method: sparseOptFlow # method of global motion compensation +# ReID model-related thresholds (not supported yet) +proximity_thresh: 0.5 +appearance_thresh: 0.25 +with_reid: False diff --git a/ultralytics/cfg/trackers/bytetrack.yaml b/ultralytics/cfg/trackers/bytetrack.yaml new file mode 100644 index 0000000000000000000000000000000000000000..62071a3022da1a38fc8aaf74a538d3829a489502 --- /dev/null +++ b/ultralytics/cfg/trackers/bytetrack.yaml @@ -0,0 +1,14 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Default Ultralytics settings for ByteTrack tracker when using mode="track" +# For documentation and examples see https://docs.ultralytics.com/modes/track/ +# For ByteTrack source code see https://github.com/ifzhang/ByteTrack + +tracker_type: bytetrack # tracker type, ['botsort', 'bytetrack'] +track_high_thresh: 0.25 # threshold for the first association +track_low_thresh: 0.1 # threshold for the second association +new_track_thresh: 0.25 # threshold for initializing a new track if the detection does not match any existing tracks +track_buffer: 30 # buffer length used to calculate when to remove lost tracks +match_thresh: 0.8 # threshold for matching tracks +fuse_score: True # whether to fuse confidence scores with the IoU distances before matching +# min_box_area: 10 # threshold for minimum box area (for tracker evaluation, not used for now) diff --git a/ultralytics/data/__init__.py b/ultralytics/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a5d258d5df5f417903d2a42c9dce4174610a9804 --- /dev/null +++ b/ultralytics/data/__init__.py @@ -0,0 +1,26 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from .base import BaseDataset +from .build import build_dataloader, build_grounding, build_yolo_dataset, load_inference_source +from .dataset import ( + ClassificationDataset, + GroundingDataset, + SemanticDataset, + YOLOConcatDataset, + YOLODataset, + YOLOMultiModalDataset, +) + +__all__ = ( + "BaseDataset", + "ClassificationDataset", + "SemanticDataset", + "YOLODataset", + "YOLOMultiModalDataset", + "YOLOConcatDataset", + "GroundingDataset", + "build_yolo_dataset", + 
"build_grounding", + "build_dataloader", + "load_inference_source", +) diff --git a/ultralytics/data/annotator.py b/ultralytics/data/annotator.py new file mode 100644 index 0000000000000000000000000000000000000000..982e5de5ab1ed6e401bfea97d21548c47623ef09 --- /dev/null +++ b/ultralytics/data/annotator.py @@ -0,0 +1,72 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from pathlib import Path + +from ultralytics import SAM, YOLO + + +def auto_annotate( + data, + det_model="yolo11x.pt", + sam_model="sam_b.pt", + device="", + conf=0.25, + iou=0.45, + imgsz=640, + max_det=300, + classes=None, + output_dir=None, +): + """ + Automatically annotates images using a YOLO object detection model and a SAM segmentation model. + + This function processes images in a specified directory, detects objects using a YOLO model, and then generates + segmentation masks using a SAM model. The resulting annotations are saved as text files. + + Args: + data (str): Path to a folder containing images to be annotated. + det_model (str): Path or name of the pre-trained YOLO detection model. + sam_model (str): Path or name of the pre-trained SAM segmentation model. + device (str): Device to run the models on (e.g., 'cpu', 'cuda', '0'). + conf (float): Confidence threshold for detection model; default is 0.25. + iou (float): IoU threshold for filtering overlapping boxes in detection results; default is 0.45. + imgsz (int): Input image resize dimension; default is 640. + max_det (int): Limits detections per image to control outputs in dense scenes. + classes (list): Filters predictions to specified class IDs, returning only relevant detections. + output_dir (str | None): Directory to save the annotated results. If None, a default directory is created. + + Examples: + >>> from ultralytics.data.annotator import auto_annotate + >>> auto_annotate(data="ultralytics/assets", det_model="yolo11n.pt", sam_model="mobile_sam.pt") + + Notes: + - The function creates a new directory for output if not specified. + - Annotation results are saved as text files with the same names as the input images. + - Each line in the output text file represents a detected object with its class ID and segmentation points. 
+ """ + det_model = YOLO(det_model) + sam_model = SAM(sam_model) + + data = Path(data) + if not output_dir: + output_dir = data.parent / f"{data.stem}_auto_annotate_labels" + Path(output_dir).mkdir(exist_ok=True, parents=True) + + det_results = det_model( + data, stream=True, device=device, conf=conf, iou=iou, imgsz=imgsz, max_det=max_det, classes=classes + ) + + for result in det_results: + class_ids = result.boxes.cls.int().tolist() # noqa + if len(class_ids): + boxes = result.boxes.xyxy # Boxes object for bbox outputs + sam_results = sam_model(result.orig_img, bboxes=boxes, verbose=False, save=False, device=device) + segments = sam_results[0].masks.xyn # noqa + + with open(f"{Path(output_dir) / Path(result.path).stem}.txt", "w") as f: + for i in range(len(segments)): + s = segments[i] + if len(s) == 0: + continue + segment = map(str, segments[i].reshape(-1).tolist()) + f.write(f"{class_ids[i]} " + " ".join(segment) + "\n") diff --git a/ultralytics/data/augment.py b/ultralytics/data/augment.py new file mode 100644 index 0000000000000000000000000000000000000000..85b90148c2a1934008da812023ef4b360a7572e6 --- /dev/null +++ b/ultralytics/data/augment.py @@ -0,0 +1,2744 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import math +import random +from copy import deepcopy +from typing import Tuple, Union + +import cv2 +import numpy as np +import torch +from PIL import Image + +from ultralytics.data.utils import polygons2masks, polygons2masks_overlap +from ultralytics.utils import LOGGER, colorstr +from ultralytics.utils.checks import check_version +from ultralytics.utils.instance import Instances +from ultralytics.utils.metrics import bbox_ioa +from ultralytics.utils.ops import segment2box, xyxyxyxy2xywhr +from ultralytics.utils.torch_utils import TORCHVISION_0_10, TORCHVISION_0_11, TORCHVISION_0_13 + +DEFAULT_MEAN = (0.0, 0.0, 0.0) +DEFAULT_STD = (1.0, 1.0, 1.0) +DEFAULT_CROP_FRACTION = 1.0 + + +class BaseTransform: + """ + Base class for image transformations in the Ultralytics library. + + This class serves as a foundation for implementing various image processing operations, designed to be + compatible with both classification and semantic segmentation tasks. + + Methods: + apply_image: Applies image transformations to labels. + apply_instances: Applies transformations to object instances in labels. + apply_semantic: Applies semantic segmentation to an image. + __call__: Applies all label transformations to an image, instances, and semantic masks. + + Examples: + >>> transform = BaseTransform() + >>> labels = {"image": np.array(...), "instances": [...], "semantic": np.array(...)} + >>> transformed_labels = transform(labels) + """ + + def __init__(self) -> None: + """ + Initializes the BaseTransform object. + + This constructor sets up the base transformation object, which can be extended for specific image + processing tasks. It is designed to be compatible with both classification and semantic segmentation. + + Examples: + >>> transform = BaseTransform() + """ + pass + + def apply_image(self, labels): + """ + Applies image transformations to labels. + + This method is intended to be overridden by subclasses to implement specific image transformation + logic. In its base form, it returns the input labels unchanged. + + Args: + labels (Any): The input labels to be transformed. The exact type and structure of labels may + vary depending on the specific implementation. + + Returns: + (Any): The transformed labels. 
In the base implementation, this is identical to the input. + + Examples: + >>> transform = BaseTransform() + >>> original_labels = [1, 2, 3] + >>> transformed_labels = transform.apply_image(original_labels) + >>> print(transformed_labels) + [1, 2, 3] + """ + pass + + def apply_instances(self, labels): + """ + Applies transformations to object instances in labels. + + This method is responsible for applying various transformations to object instances within the given + labels. It is designed to be overridden by subclasses to implement specific instance transformation + logic. + + Args: + labels (Dict): A dictionary containing label information, including object instances. + + Returns: + (Dict): The modified labels dictionary with transformed object instances. + + Examples: + >>> transform = BaseTransform() + >>> labels = {"instances": Instances(xyxy=torch.rand(5, 4), cls=torch.randint(0, 80, (5,)))} + >>> transformed_labels = transform.apply_instances(labels) + """ + pass + + def apply_semantic(self, labels): + """ + Applies semantic segmentation transformations to an image. + + This method is intended to be overridden by subclasses to implement specific semantic segmentation + transformations. In its base form, it does not perform any operations. + + Args: + labels (Any): The input labels or semantic segmentation mask to be transformed. + + Returns: + (Any): The transformed semantic segmentation mask or labels. + + Examples: + >>> transform = BaseTransform() + >>> semantic_mask = np.zeros((100, 100), dtype=np.uint8) + >>> transformed_mask = transform.apply_semantic(semantic_mask) + """ + pass + + def __call__(self, labels): + """ + Applies all label transformations to an image, instances, and semantic masks. + + This method orchestrates the application of various transformations defined in the BaseTransform class + to the input labels. It sequentially calls the apply_image and apply_instances methods to process the + image and object instances, respectively. + + Args: + labels (Dict): A dictionary containing image data and annotations. Expected keys include 'img' for + the image data, and 'instances' for object instances. + + Returns: + (Dict): The input labels dictionary with transformed image and instances. + + Examples: + >>> transform = BaseTransform() + >>> labels = {"img": np.random.rand(640, 640, 3), "instances": []} + >>> transformed_labels = transform(labels) + """ + self.apply_image(labels) + self.apply_instances(labels) + self.apply_semantic(labels) + + +class Compose: + """ + A class for composing multiple image transformations. + + Attributes: + transforms (List[Callable]): A list of transformation functions to be applied sequentially. + + Methods: + __call__: Applies a series of transformations to input data. + append: Appends a new transform to the existing list of transforms. + insert: Inserts a new transform at a specified index in the list of transforms. + __getitem__: Retrieves a specific transform or a set of transforms using indexing. + __setitem__: Sets a specific transform or a set of transforms using indexing. + tolist: Converts the list of transforms to a standard Python list. + + Examples: + >>> transforms = [RandomFlip(), RandomPerspective(30)] + >>> compose = Compose(transforms) + >>> transformed_data = compose(data) + >>> compose.append(CenterCrop((224, 224))) + >>> compose.insert(0, RandomFlip()) + """ + + def __init__(self, transforms): + """ + Initializes the Compose object with a list of transforms. 
+ + Args: + transforms (List[Callable]): A list of callable transform objects to be applied sequentially. + + Examples: + >>> from ultralytics.data.augment import Compose, RandomHSV, RandomFlip + >>> transforms = [RandomHSV(), RandomFlip()] + >>> compose = Compose(transforms) + """ + self.transforms = transforms if isinstance(transforms, list) else [transforms] + + def __call__(self, data): + """ + Applies a series of transformations to input data. This method sequentially applies each transformation in the + Compose object's list of transforms to the input data. + + Args: + data (Any): The input data to be transformed. This can be of any type, depending on the + transformations in the list. + + Returns: + (Any): The transformed data after applying all transformations in sequence. + + Examples: + >>> transforms = [Transform1(), Transform2(), Transform3()] + >>> compose = Compose(transforms) + >>> transformed_data = compose(input_data) + """ + for t in self.transforms: + data = t(data) + return data + + def append(self, transform): + """ + Appends a new transform to the existing list of transforms. + + Args: + transform (BaseTransform): The transformation to be added to the composition. + + Examples: + >>> compose = Compose([RandomFlip(), RandomPerspective()]) + >>> compose.append(RandomHSV()) + """ + self.transforms.append(transform) + + def insert(self, index, transform): + """ + Inserts a new transform at a specified index in the existing list of transforms. + + Args: + index (int): The index at which to insert the new transform. + transform (BaseTransform): The transform object to be inserted. + + Examples: + >>> compose = Compose([Transform1(), Transform2()]) + >>> compose.insert(1, Transform3()) + >>> len(compose.transforms) + 3 + """ + self.transforms.insert(index, transform) + + def __getitem__(self, index: Union[list, int]) -> "Compose": + """ + Retrieves a specific transform or a set of transforms using indexing. + + Args: + index (int | List[int]): Index or list of indices of the transforms to retrieve. + + Returns: + (Compose): A new Compose object containing the selected transform(s). + + Raises: + AssertionError: If the index is not of type int or list. + + Examples: + >>> transforms = [RandomFlip(), RandomPerspective(10), RandomHSV(0.5, 0.5, 0.5)] + >>> compose = Compose(transforms) + >>> single_transform = compose[1] # Returns a Compose object with only RandomPerspective + >>> multiple_transforms = compose[0:2] # Returns a Compose object with RandomFlip and RandomPerspective + """ + assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}" + index = [index] if isinstance(index, int) else index + return Compose([self.transforms[i] for i in index]) + + def __setitem__(self, index: Union[list, int], value: Union[list, int]) -> None: + """ + Sets one or more transforms in the composition using indexing. + + Args: + index (int | List[int]): Index or list of indices to set transforms at. + value (Any | List[Any]): Transform or list of transforms to set at the specified index(es). + + Raises: + AssertionError: If index type is invalid, value type doesn't match index type, or index is out of range. 
+ + Examples: + >>> compose = Compose([Transform1(), Transform2(), Transform3()]) + >>> compose[1] = NewTransform() # Replace second transform + >>> compose[0:2] = [NewTransform1(), NewTransform2()] # Replace first two transforms + """ + assert isinstance(index, (int, list)), f"The indices should be either list or int type but got {type(index)}" + if isinstance(index, list): + assert isinstance(value, list), ( + f"The indices should be the same type as values, but got {type(index)} and {type(value)}" + ) + if isinstance(index, int): + index, value = [index], [value] + for i, v in zip(index, value): + assert i < len(self.transforms), f"list index {i} out of range {len(self.transforms)}." + self.transforms[i] = v + + def tolist(self): + """ + Converts the list of transforms to a standard Python list. + + Returns: + (List): A list containing all the transform objects in the Compose instance. + + Examples: + >>> transforms = [RandomFlip(), RandomPerspective(10), CenterCrop()] + >>> compose = Compose(transforms) + >>> transform_list = compose.tolist() + >>> print(len(transform_list)) + 3 + """ + return self.transforms + + def __repr__(self): + """ + Returns a string representation of the Compose object. + + Returns: + (str): A string representation of the Compose object, including the list of transforms. + + Examples: + >>> transforms = [RandomFlip(), RandomPerspective(degrees=10, translate=0.1, scale=0.1)] + >>> compose = Compose(transforms) + >>> print(compose) + Compose([ + RandomFlip(), + RandomPerspective(degrees=10, translate=0.1, scale=0.1) + ]) + """ + return f"{self.__class__.__name__}({', '.join([f'{t}' for t in self.transforms])})" + + +class BaseMixTransform: + """ + Base class for mix transformations like MixUp and Mosaic. + + This class provides a foundation for implementing mix transformations on datasets. It handles the + probability-based application of transforms and manages the mixing of multiple images and labels. + + Attributes: + dataset (Any): The dataset object containing images and labels. + pre_transform (Callable | None): Optional transform to apply before mixing. + p (float): Probability of applying the mix transformation. + + Methods: + __call__: Applies the mix transformation to the input labels. + _mix_transform: Abstract method to be implemented by subclasses for specific mix operations. + get_indexes: Abstract method to get indexes of images to be mixed. + _update_label_text: Updates label text for mixed images. + + Examples: + >>> class CustomMixTransform(BaseMixTransform): + ... def _mix_transform(self, labels): + ... # Implement custom mix logic here + ... return labels + ... + ... def get_indexes(self): + ... return [random.randint(0, len(self.dataset) - 1) for _ in range(3)] + >>> dataset = YourDataset() + >>> transform = CustomMixTransform(dataset, p=0.5) + >>> mixed_labels = transform(original_labels) + """ + + def __init__(self, dataset, pre_transform=None, p=0.0) -> None: + """ + Initializes the BaseMixTransform object for mix transformations like MixUp and Mosaic. + + This class serves as a base for implementing mix transformations in image processing pipelines. + + Args: + dataset (Any): The dataset object containing images and labels for mixing. + pre_transform (Callable | None): Optional transform to apply before mixing. + p (float): Probability of applying the mix transformation. Should be in the range [0.0, 1.0]. 
+ + Examples: + >>> dataset = YOLODataset("path/to/data") + >>> pre_transform = Compose([RandomFlip(), RandomPerspective()]) + >>> mix_transform = BaseMixTransform(dataset, pre_transform, p=0.5) + """ + self.dataset = dataset + self.pre_transform = pre_transform + self.p = p + + def __call__(self, labels): + """ + Applies pre-processing transforms and mixup/mosaic transforms to labels data. + + This method determines whether to apply the mix transform based on a probability factor. If applied, it + selects additional images, applies pre-transforms if specified, and then performs the mix transform. + + Args: + labels (Dict): A dictionary containing label data for an image. + + Returns: + (Dict): The transformed labels dictionary, which may include mixed data from other images. + + Examples: + >>> transform = BaseMixTransform(dataset, pre_transform=None, p=0.5) + >>> result = transform({"image": img, "bboxes": boxes, "cls": classes}) + """ + if random.uniform(0, 1) > self.p: + return labels + + # Get index of one or three other images + indexes = self.get_indexes() + if isinstance(indexes, int): + indexes = [indexes] + + # Get images information will be used for Mosaic or MixUp + mix_labels = [self.dataset.get_image_and_label(i) for i in indexes] + + if self.pre_transform is not None: + for i, data in enumerate(mix_labels): + mix_labels[i] = self.pre_transform(data) + labels["mix_labels"] = mix_labels + + # Update cls and texts + labels = self._update_label_text(labels) + # Mosaic or MixUp + labels = self._mix_transform(labels) + labels.pop("mix_labels", None) + return labels + + def _mix_transform(self, labels): + """ + Applies MixUp or Mosaic augmentation to the label dictionary. + + This method should be implemented by subclasses to perform specific mix transformations like MixUp or + Mosaic. It modifies the input label dictionary in-place with the augmented data. + + Args: + labels (Dict): A dictionary containing image and label data. Expected to have a 'mix_labels' key + with a list of additional image and label data for mixing. + + Returns: + (Dict): The modified labels dictionary with augmented data after applying the mix transform. + + Examples: + >>> transform = BaseMixTransform(dataset) + >>> labels = {"image": img, "bboxes": boxes, "mix_labels": [{"image": img2, "bboxes": boxes2}]} + >>> augmented_labels = transform._mix_transform(labels) + """ + raise NotImplementedError + + def get_indexes(self): + """ + Gets a list of shuffled indexes for mosaic augmentation. + + Returns: + (List[int]): A list of shuffled indexes from the dataset. + + Examples: + >>> transform = BaseMixTransform(dataset) + >>> indexes = transform.get_indexes() + >>> print(indexes) # [3, 18, 7, 2] + """ + raise NotImplementedError + + @staticmethod + def _update_label_text(labels): + """ + Updates label text and class IDs for mixed labels in image augmentation. + + This method processes the 'texts' and 'cls' fields of the input labels dictionary and any mixed labels, + creating a unified set of text labels and updating class IDs accordingly. + + Args: + labels (Dict): A dictionary containing label information, including 'texts' and 'cls' fields, + and optionally a 'mix_labels' field with additional label dictionaries. + + Returns: + (Dict): The updated labels dictionary with unified text labels and updated class IDs. + + Examples: + >>> labels = { + ... "texts": [["cat"], ["dog"]], + ... "cls": torch.tensor([[0], [1]]), + ... "mix_labels": [{"texts": [["bird"], ["fish"]], "cls": torch.tensor([[0], [1]])}], + ... 
} + >>> updated_labels = self._update_label_text(labels) + >>> print(updated_labels["texts"]) + [['cat'], ['dog'], ['bird'], ['fish']] + >>> print(updated_labels["cls"]) + tensor([[0], + [1]]) + >>> print(updated_labels["mix_labels"][0]["cls"]) + tensor([[2], + [3]]) + """ + if "texts" not in labels: + return labels + + mix_texts = sum([labels["texts"]] + [x["texts"] for x in labels["mix_labels"]], []) + mix_texts = list({tuple(x) for x in mix_texts}) + text2id = {text: i for i, text in enumerate(mix_texts)} + + for label in [labels] + labels["mix_labels"]: + for i, cls in enumerate(label["cls"].squeeze(-1).tolist()): + text = label["texts"][int(cls)] + label["cls"][i] = text2id[tuple(text)] + label["texts"] = mix_texts + return labels + + +class Mosaic(BaseMixTransform): + """ + Mosaic augmentation for image datasets. + + This class performs mosaic augmentation by combining multiple (4 or 9) images into a single mosaic image. + The augmentation is applied to a dataset with a given probability. + + Attributes: + dataset: The dataset on which the mosaic augmentation is applied. + imgsz (int): Image size (height and width) after mosaic pipeline of a single image. + p (float): Probability of applying the mosaic augmentation. Must be in the range 0-1. + n (int): The grid size, either 4 (for 2x2) or 9 (for 3x3). + border (Tuple[int, int]): Border size for width and height. + + Methods: + get_indexes: Returns a list of random indexes from the dataset. + _mix_transform: Applies mixup transformation to the input image and labels. + _mosaic3: Creates a 1x3 image mosaic. + _mosaic4: Creates a 2x2 image mosaic. + _mosaic9: Creates a 3x3 image mosaic. + _update_labels: Updates labels with padding. + _cat_labels: Concatenates labels and clips mosaic border instances. + + Examples: + >>> from ultralytics.data.augment import Mosaic + >>> dataset = YourDataset(...) # Your image dataset + >>> mosaic_aug = Mosaic(dataset, imgsz=640, p=0.5, n=4) + >>> augmented_labels = mosaic_aug(original_labels) + """ + + def __init__(self, dataset, imgsz=640, p=1.0, n=4): + """ + Initializes the Mosaic augmentation object. + + This class performs mosaic augmentation by combining multiple (4 or 9) images into a single mosaic image. + The augmentation is applied to a dataset with a given probability. + + Args: + dataset (Any): The dataset on which the mosaic augmentation is applied. + imgsz (int): Image size (height and width) after mosaic pipeline of a single image. + p (float): Probability of applying the mosaic augmentation. Must be in the range 0-1. + n (int): The grid size, either 4 (for 2x2) or 9 (for 3x3). + + Examples: + >>> from ultralytics.data.augment import Mosaic + >>> dataset = YourDataset(...) + >>> mosaic_aug = Mosaic(dataset, imgsz=640, p=0.5, n=4) + """ + assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}." + assert n in {4, 9}, "grid must be equal to 4 or 9." + super().__init__(dataset=dataset, p=p) + self.imgsz = imgsz + self.border = (-imgsz // 2, -imgsz // 2) # width, height + self.n = n + + def get_indexes(self, buffer=True): + """ + Returns a list of random indexes from the dataset for mosaic augmentation. + + This method selects random image indexes either from a buffer or from the entire dataset, depending on + the 'buffer' parameter. It is used to choose images for creating mosaic augmentations. + + Args: + buffer (bool): If True, selects images from the dataset buffer. If False, selects from the entire + dataset. 
+ + Returns: + (List[int]): A list of random image indexes. The length of the list is n-1, where n is the number + of images used in the mosaic (either 3 or 8, depending on whether n is 4 or 9). + + Examples: + >>> mosaic = Mosaic(dataset, imgsz=640, p=1.0, n=4) + >>> indexes = mosaic.get_indexes() + >>> print(len(indexes)) # Output: 3 + """ + if buffer: # select images from buffer + return random.choices(list(self.dataset.buffer), k=self.n - 1) + else: # select any images + return [random.randint(0, len(self.dataset) - 1) for _ in range(self.n - 1)] + + def _mix_transform(self, labels): + """ + Applies mosaic augmentation to the input image and labels. + + This method combines multiple images (3, 4, or 9) into a single mosaic image based on the 'n' attribute. + It ensures that rectangular annotations are not present and that there are other images available for + mosaic augmentation. + + Args: + labels (Dict): A dictionary containing image data and annotations. Expected keys include: + - 'rect_shape': Should be None as rect and mosaic are mutually exclusive. + - 'mix_labels': A list of dictionaries containing data for other images to be used in the mosaic. + + Returns: + (Dict): A dictionary containing the mosaic-augmented image and updated annotations. + + Raises: + AssertionError: If 'rect_shape' is not None or if 'mix_labels' is empty. + + Examples: + >>> mosaic = Mosaic(dataset, imgsz=640, p=1.0, n=4) + >>> augmented_data = mosaic._mix_transform(labels) + """ + assert labels.get("rect_shape", None) is None, "rect and mosaic are mutually exclusive." + assert len(labels.get("mix_labels", [])), "There are no other images for mosaic augment." + return ( + self._mosaic3(labels) if self.n == 3 else self._mosaic4(labels) if self.n == 4 else self._mosaic9(labels) + ) # This code is modified for mosaic3 method. + + def _mosaic3(self, labels): + """ + Creates a 1x3 image mosaic by combining three images. + + This method arranges three images in a horizontal layout, with the main image in the center and two + additional images on either side. It's part of the Mosaic augmentation technique used in object detection. + + Args: + labels (Dict): A dictionary containing image and label information for the main (center) image. + Must include 'img' key with the image array, and 'mix_labels' key with a list of two + dictionaries containing information for the side images. + + Returns: + (Dict): A dictionary with the mosaic image and updated labels. Keys include: + - 'img' (np.ndarray): The mosaic image array with shape (H, W, C). + - Other keys from the input labels, updated to reflect the new image dimensions. + + Examples: + >>> mosaic = Mosaic(dataset, imgsz=640, p=1.0, n=3) + >>> labels = { + ... "img": np.random.rand(480, 640, 3), + ... "mix_labels": [{"img": np.random.rand(480, 640, 3)} for _ in range(2)], + ... 
} + >>> result = mosaic._mosaic3(labels) + >>> print(result["img"].shape) + (640, 640, 3) + """ + mosaic_labels = [] + s = self.imgsz + for i in range(3): + labels_patch = labels if i == 0 else labels["mix_labels"][i - 1] + # Load image + img = labels_patch["img"] + h, w = labels_patch.pop("resized_shape") + + # Place img in img3 + if i == 0: # center + img3 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8) # base image with 3 tiles + h0, w0 = h, w + c = s, s, s + w, s + h # xmin, ymin, xmax, ymax (base) coordinates + elif i == 1: # right + c = s + w0, s, s + w0 + w, s + h + elif i == 2: # left + c = s - w, s + h0 - h, s, s + h0 + + padw, padh = c[:2] + x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coordinates + + img3[y1:y2, x1:x2] = img[y1 - padh :, x1 - padw :] # img3[ymin:ymax, xmin:xmax] + # hp, wp = h, w # height, width previous for next iteration + + # Labels assuming imgsz*2 mosaic size + labels_patch = self._update_labels(labels_patch, padw + self.border[0], padh + self.border[1]) + mosaic_labels.append(labels_patch) + final_labels = self._cat_labels(mosaic_labels) + + final_labels["img"] = img3[-self.border[0] : self.border[0], -self.border[1] : self.border[1]] + return final_labels + + def _mosaic4(self, labels): + """ + Creates a 2x2 image mosaic from four input images. + + This method combines four images into a single mosaic image by placing them in a 2x2 grid. It also + updates the corresponding labels for each image in the mosaic. + + Args: + labels (Dict): A dictionary containing image data and labels for the base image (index 0) and three + additional images (indices 1-3) in the 'mix_labels' key. + + Returns: + (Dict): A dictionary containing the mosaic image and updated labels. The 'img' key contains the mosaic + image as a numpy array, and other keys contain the combined and adjusted labels for all four images. + + Examples: + >>> mosaic = Mosaic(dataset, imgsz=640, p=1.0, n=4) + >>> labels = { + ... "img": np.random.rand(480, 640, 3), + ... "mix_labels": [{"img": np.random.rand(480, 640, 3)} for _ in range(3)], + ... 
} + >>> result = mosaic._mosaic4(labels) + >>> assert result["img"].shape == (1280, 1280, 3) + """ + mosaic_labels = [] + s = self.imgsz + yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.border) # mosaic center x, y + for i in range(4): + labels_patch = labels if i == 0 else labels["mix_labels"][i - 1] + # Load image + img = labels_patch["img"] + h, w = labels_patch.pop("resized_shape") + + # Place img in img4 + if i == 0: # top left + img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles + x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image) + x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image) + elif i == 1: # top right + x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc + x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h + elif i == 2: # bottom left + x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h) + x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h) + elif i == 3: # bottom right + x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h) + x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h) + + img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax] + padw = x1a - x1b + padh = y1a - y1b + + labels_patch = self._update_labels(labels_patch, padw, padh) + mosaic_labels.append(labels_patch) + final_labels = self._cat_labels(mosaic_labels) + final_labels["img"] = img4 + return final_labels + + def _mosaic9(self, labels): + """ + Creates a 3x3 image mosaic from the input image and eight additional images. + + This method combines nine images into a single mosaic image. The input image is placed at the center, + and eight additional images from the dataset are placed around it in a 3x3 grid pattern. + + Args: + labels (Dict): A dictionary containing the input image and its associated labels. It should have + the following keys: + - 'img' (numpy.ndarray): The input image. + - 'resized_shape' (Tuple[int, int]): The shape of the resized image (height, width). + - 'mix_labels' (List[Dict]): A list of dictionaries containing information for the additional + eight images, each with the same structure as the input labels. + + Returns: + (Dict): A dictionary containing the mosaic image and updated labels. It includes the following keys: + - 'img' (numpy.ndarray): The final mosaic image. + - Other keys from the input labels, updated to reflect the new mosaic arrangement. 
+ + Examples: + >>> mosaic = Mosaic(dataset, imgsz=640, p=1.0, n=9) + >>> input_labels = dataset[0] + >>> mosaic_result = mosaic._mosaic9(input_labels) + >>> mosaic_image = mosaic_result["img"] + """ + mosaic_labels = [] + s = self.imgsz + hp, wp = -1, -1 # height, width previous + for i in range(9): + labels_patch = labels if i == 0 else labels["mix_labels"][i - 1] + # Load image + img = labels_patch["img"] + h, w = labels_patch.pop("resized_shape") + + # Place img in img9 + if i == 0: # center + img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles + h0, w0 = h, w + c = s, s, s + w, s + h # xmin, ymin, xmax, ymax (base) coordinates + elif i == 1: # top + c = s, s - h, s + w, s + elif i == 2: # top right + c = s + wp, s - h, s + wp + w, s + elif i == 3: # right + c = s + w0, s, s + w0 + w, s + h + elif i == 4: # bottom right + c = s + w0, s + hp, s + w0 + w, s + hp + h + elif i == 5: # bottom + c = s + w0 - w, s + h0, s + w0, s + h0 + h + elif i == 6: # bottom left + c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h + elif i == 7: # left + c = s - w, s + h0 - h, s, s + h0 + elif i == 8: # top left + c = s - w, s + h0 - hp - h, s, s + h0 - hp + + padw, padh = c[:2] + x1, y1, x2, y2 = (max(x, 0) for x in c) # allocate coordinates + + # Image + img9[y1:y2, x1:x2] = img[y1 - padh :, x1 - padw :] # img9[ymin:ymax, xmin:xmax] + hp, wp = h, w # height, width previous for next iteration + + # Labels assuming imgsz*2 mosaic size + labels_patch = self._update_labels(labels_patch, padw + self.border[0], padh + self.border[1]) + mosaic_labels.append(labels_patch) + final_labels = self._cat_labels(mosaic_labels) + + final_labels["img"] = img9[-self.border[0] : self.border[0], -self.border[1] : self.border[1]] + return final_labels + + @staticmethod + def _update_labels(labels, padw, padh): + """ + Updates label coordinates with padding values. + + This method adjusts the bounding box coordinates of object instances in the labels by adding padding + values. It also denormalizes the coordinates if they were previously normalized. + + Args: + labels (Dict): A dictionary containing image and instance information. + padw (int): Padding width to be added to the x-coordinates. + padh (int): Padding height to be added to the y-coordinates. + + Returns: + (Dict): Updated labels dictionary with adjusted instance coordinates. + + Examples: + >>> labels = {"img": np.zeros((100, 100, 3)), "instances": Instances(...)} + >>> padw, padh = 50, 50 + >>> updated_labels = Mosaic._update_labels(labels, padw, padh) + """ + nh, nw = labels["img"].shape[:2] + labels["instances"].convert_bbox(format="xyxy") + labels["instances"].denormalize(nw, nh) + labels["instances"].add_padding(padw, padh) + return labels + + def _cat_labels(self, mosaic_labels): + """ + Concatenates and processes labels for mosaic augmentation. + + This method combines labels from multiple images used in mosaic augmentation, clips instances to the + mosaic border, and removes zero-area boxes. + + Args: + mosaic_labels (List[Dict]): A list of label dictionaries for each image in the mosaic. + + Returns: + (Dict): A dictionary containing concatenated and processed labels for the mosaic image, including: + - im_file (str): File path of the first image in the mosaic. + - ori_shape (Tuple[int, int]): Original shape of the first image. + - resized_shape (Tuple[int, int]): Shape of the mosaic image (imgsz * 2, imgsz * 2). + - cls (np.ndarray): Concatenated class labels. 
+ - instances (Instances): Concatenated instance annotations. + - mosaic_border (Tuple[int, int]): Mosaic border size. + - texts (List[str], optional): Text labels if present in the original labels. + + Examples: + >>> mosaic = Mosaic(dataset, imgsz=640) + >>> mosaic_labels = [{"cls": np.array([0, 1]), "instances": Instances(...)} for _ in range(4)] + >>> result = mosaic._cat_labels(mosaic_labels) + >>> print(result.keys()) + dict_keys(['im_file', 'ori_shape', 'resized_shape', 'cls', 'instances', 'mosaic_border']) + """ + if len(mosaic_labels) == 0: + return {} + cls = [] + instances = [] + imgsz = self.imgsz * 2 # mosaic imgsz + for labels in mosaic_labels: + cls.append(labels["cls"]) + instances.append(labels["instances"]) + # Final labels + final_labels = { + "im_file": mosaic_labels[0]["im_file"], + "ori_shape": mosaic_labels[0]["ori_shape"], + "resized_shape": (imgsz, imgsz), + "cls": np.concatenate(cls, 0), + "instances": Instances.concatenate(instances, axis=0), + "mosaic_border": self.border, + } + final_labels["instances"].clip(imgsz, imgsz) + good = final_labels["instances"].remove_zero_area_boxes() + final_labels["cls"] = final_labels["cls"][good] + if "texts" in mosaic_labels[0]: + final_labels["texts"] = mosaic_labels[0]["texts"] + return final_labels + + +class MixUp(BaseMixTransform): + """ + Applies MixUp augmentation to image datasets. + + This class implements the MixUp augmentation technique as described in the paper "mixup: Beyond Empirical Risk + Minimization" (https://arxiv.org/abs/1710.09412). MixUp combines two images and their labels using a random weight. + + Attributes: + dataset (Any): The dataset to which MixUp augmentation will be applied. + pre_transform (Callable | None): Optional transform to apply before MixUp. + p (float): Probability of applying MixUp augmentation. + + Methods: + get_indexes: Returns a random index from the dataset. + _mix_transform: Applies MixUp augmentation to the input labels. + + Examples: + >>> from ultralytics.data.augment import MixUp + >>> dataset = YourDataset(...) # Your image dataset + >>> mixup = MixUp(dataset, p=0.5) + >>> augmented_labels = mixup(original_labels) + """ + + def __init__(self, dataset, pre_transform=None, p=0.0) -> None: + """ + Initializes the MixUp augmentation object. + + MixUp is an image augmentation technique that combines two images by taking a weighted sum of their pixel + values and labels. This implementation is designed for use with the Ultralytics YOLO framework. + + Args: + dataset (Any): The dataset to which MixUp augmentation will be applied. + pre_transform (Callable | None): Optional transform to apply to images before MixUp. + p (float): Probability of applying MixUp augmentation to an image. Must be in the range [0, 1]. + + Examples: + >>> from ultralytics.data.dataset import YOLODataset + >>> dataset = YOLODataset("path/to/data.yaml") + >>> mixup = MixUp(dataset, pre_transform=None, p=0.5) + """ + super().__init__(dataset=dataset, pre_transform=pre_transform, p=p) + + def get_indexes(self): + """ + Get a random index from the dataset. + + This method returns a single random index from the dataset, which is used to select an image for MixUp + augmentation. + + Returns: + (int): A random integer index within the range of the dataset length. + + Examples: + >>> mixup = MixUp(dataset) + >>> index = mixup.get_indexes() + >>> print(index) + 42 + """ + return random.randint(0, len(self.dataset) - 1) + + def _mix_transform(self, labels): + """ + Applies MixUp augmentation to the input labels. 
+ + This method implements the MixUp augmentation technique as described in the paper + "mixup: Beyond Empirical Risk Minimization" (https://arxiv.org/abs/1710.09412). + + Args: + labels (Dict): A dictionary containing the original image and label information. + + Returns: + (Dict): A dictionary containing the mixed-up image and combined label information. + + Examples: + >>> mixer = MixUp(dataset) + >>> mixed_labels = mixer._mix_transform(labels) + """ + r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0 + labels2 = labels["mix_labels"][0] + labels["img"] = (labels["img"] * r + labels2["img"] * (1 - r)).astype(np.uint8) + labels["instances"] = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0) + labels["cls"] = np.concatenate([labels["cls"], labels2["cls"]], 0) + return labels + + +class RandomPerspective: + """ + Implements random perspective and affine transformations on images and corresponding annotations. + + This class applies random rotations, translations, scaling, shearing, and perspective transformations + to images and their associated bounding boxes, segments, and keypoints. It can be used as part of an + augmentation pipeline for object detection and instance segmentation tasks. + + Attributes: + degrees (float): Maximum absolute degree range for random rotations. + translate (float): Maximum translation as a fraction of the image size. + scale (float): Scaling factor range, e.g., scale=0.1 means 0.9-1.1. + shear (float): Maximum shear angle in degrees. + perspective (float): Perspective distortion factor. + border (Tuple[int, int]): Mosaic border size as (x, y). + pre_transform (Callable | None): Optional transform to apply before the random perspective. + + Methods: + affine_transform: Applies affine transformations to the input image. + apply_bboxes: Transforms bounding boxes using the affine matrix. + apply_segments: Transforms segments and generates new bounding boxes. + apply_keypoints: Transforms keypoints using the affine matrix. + __call__: Applies the random perspective transformation to images and annotations. + box_candidates: Filters transformed bounding boxes based on size and aspect ratio. + + Examples: + >>> transform = RandomPerspective(degrees=10, translate=0.1, scale=0.1, shear=10) + >>> image = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8) + >>> labels = {"img": image, "cls": np.array([0, 1]), "instances": Instances(...)} + >>> result = transform(labels) + >>> transformed_image = result["img"] + >>> transformed_instances = result["instances"] + """ + + def __init__( + self, degrees=0.0, translate=0.1, scale=0.5, shear=0.0, perspective=0.0, border=(0, 0), pre_transform=None + ): + """ + Initializes RandomPerspective object with transformation parameters. + + This class implements random perspective and affine transformations on images and corresponding bounding boxes, + segments, and keypoints. Transformations include rotation, translation, scaling, and shearing. + + Args: + degrees (float): Degree range for random rotations. + translate (float): Fraction of total width and height for random translation. + scale (float): Scaling factor interval, e.g., a scale factor of 0.5 allows a resize between 50%-150%. + shear (float): Shear intensity (angle in degrees). + perspective (float): Perspective distortion factor. + border (Tuple[int, int]): Tuple specifying mosaic border (top/bottom, left/right). 
+ pre_transform (Callable | None): Function/transform to apply to the image before starting the random + transformation. + + Examples: + >>> transform = RandomPerspective(degrees=10.0, translate=0.1, scale=0.5, shear=5.0) + >>> result = transform(labels) # Apply random perspective to labels + """ + self.degrees = degrees + self.translate = translate + self.scale = scale + self.shear = shear + self.perspective = perspective + self.border = border # mosaic border + self.pre_transform = pre_transform + + def affine_transform(self, img, border): + """ + Applies a sequence of affine transformations centered around the image center. + + This function performs a series of geometric transformations on the input image, including + translation, perspective change, rotation, scaling, and shearing. The transformations are + applied in a specific order to maintain consistency. + + Args: + img (np.ndarray): Input image to be transformed. + border (Tuple[int, int]): Border dimensions for the transformed image. + + Returns: + (Tuple[np.ndarray, np.ndarray, float]): A tuple containing: + - np.ndarray: Transformed image. + - np.ndarray: 3x3 transformation matrix. + - float: Scale factor applied during the transformation. + + Examples: + >>> import numpy as np + >>> img = np.random.rand(100, 100, 3) + >>> border = (10, 10) + >>> transformed_img, matrix, scale = affine_transform(img, border) + """ + # Center + C = np.eye(3, dtype=np.float32) + + C[0, 2] = -img.shape[1] / 2 # x translation (pixels) + C[1, 2] = -img.shape[0] / 2 # y translation (pixels) + + # Perspective + P = np.eye(3, dtype=np.float32) + P[2, 0] = random.uniform(-self.perspective, self.perspective) # x perspective (about y) + P[2, 1] = random.uniform(-self.perspective, self.perspective) # y perspective (about x) + + # Rotation and Scale + R = np.eye(3, dtype=np.float32) + a = random.uniform(-self.degrees, self.degrees) + # a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations + s = random.uniform(1 - self.scale, 1 + self.scale) + # s = 2 ** random.uniform(-scale, scale) + R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s) + + # Shear + S = np.eye(3, dtype=np.float32) + S[0, 1] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180) # x shear (deg) + S[1, 0] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180) # y shear (deg) + + # Translation + T = np.eye(3, dtype=np.float32) + T[0, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[0] # x translation (pixels) + T[1, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * self.size[1] # y translation (pixels) + + # Combined rotation matrix + M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT + # Affine image + if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any(): # image changed + if self.perspective: + img = cv2.warpPerspective(img, M, dsize=self.size, borderValue=(114, 114, 114)) + else: # affine + img = cv2.warpAffine(img, M[:2], dsize=self.size, borderValue=(114, 114, 114)) + return img, M, s + + def apply_bboxes(self, bboxes, M): + """ + Apply affine transformation to bounding boxes. + + This function applies an affine transformation to a set of bounding boxes using the provided + transformation matrix. + + Args: + bboxes (torch.Tensor): Bounding boxes in xyxy format with shape (N, 4), where N is the number + of bounding boxes. + M (torch.Tensor): Affine transformation matrix with shape (3, 3). 
+ + Returns: + (torch.Tensor): Transformed bounding boxes in xyxy format with shape (N, 4). + + Examples: + >>> bboxes = torch.tensor([[10, 10, 20, 20], [30, 30, 40, 40]]) + >>> M = torch.eye(3) + >>> transformed_bboxes = apply_bboxes(bboxes, M) + """ + n = len(bboxes) + if n == 0: + return bboxes + + xy = np.ones((n * 4, 3), dtype=bboxes.dtype) + xy[:, :2] = bboxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1 + xy = xy @ M.T # transform + xy = (xy[:, :2] / xy[:, 2:3] if self.perspective else xy[:, :2]).reshape(n, 8) # perspective rescale or affine + + # Create new boxes + x = xy[:, [0, 2, 4, 6]] + y = xy[:, [1, 3, 5, 7]] + return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1)), dtype=bboxes.dtype).reshape(4, n).T + + def apply_segments(self, segments, M): + """ + Apply affine transformations to segments and generate new bounding boxes. + + This function applies affine transformations to input segments and generates new bounding boxes based on + the transformed segments. It clips the transformed segments to fit within the new bounding boxes. + + Args: + segments (np.ndarray): Input segments with shape (N, M, 2), where N is the number of segments and M is the + number of points in each segment. + M (np.ndarray): Affine transformation matrix with shape (3, 3). + + Returns: + (Tuple[np.ndarray, np.ndarray]): A tuple containing: + - New bounding boxes with shape (N, 4) in xyxy format. + - Transformed and clipped segments with shape (N, M, 2). + + Examples: + >>> segments = np.random.rand(10, 500, 2) # 10 segments with 500 points each + >>> M = np.eye(3) # Identity transformation matrix + >>> new_bboxes, new_segments = apply_segments(segments, M) + """ + n, num = segments.shape[:2] + if n == 0: + return [], segments + + xy = np.ones((n * num, 3), dtype=segments.dtype) + segments = segments.reshape(-1, 2) + xy[:, :2] = segments + xy = xy @ M.T # transform + xy = xy[:, :2] / xy[:, 2:3] + segments = xy.reshape(n, -1, 2) + bboxes = np.stack([segment2box(xy, self.size[0], self.size[1]) for xy in segments], 0) + segments[..., 0] = segments[..., 0].clip(bboxes[:, 0:1], bboxes[:, 2:3]) + segments[..., 1] = segments[..., 1].clip(bboxes[:, 1:2], bboxes[:, 3:4]) + return bboxes, segments + + def apply_keypoints(self, keypoints, M): + """ + Applies affine transformation to keypoints. + + This method transforms the input keypoints using the provided affine transformation matrix. It handles + perspective rescaling if necessary and updates the visibility of keypoints that fall outside the image + boundaries after transformation. + + Args: + keypoints (np.ndarray): Array of keypoints with shape (N, 17, 3), where N is the number of instances, + 17 is the number of keypoints per instance, and 3 represents (x, y, visibility). + M (np.ndarray): 3x3 affine transformation matrix. + + Returns: + (np.ndarray): Transformed keypoints array with the same shape as input (N, 17, 3). 
+ + Examples: + >>> random_perspective = RandomPerspective() + >>> keypoints = np.random.rand(5, 17, 3) # 5 instances, 17 keypoints each + >>> M = np.eye(3) # Identity transformation + >>> transformed_keypoints = random_perspective.apply_keypoints(keypoints, M) + """ + n, nkpt = keypoints.shape[:2] + if n == 0: + return keypoints + xy = np.ones((n * nkpt, 3), dtype=keypoints.dtype) + visible = keypoints[..., 2].reshape(n * nkpt, 1) + xy[:, :2] = keypoints[..., :2].reshape(n * nkpt, 2) + xy = xy @ M.T # transform + xy = xy[:, :2] / xy[:, 2:3] # perspective rescale or affine + out_mask = (xy[:, 0] < 0) | (xy[:, 1] < 0) | (xy[:, 0] > self.size[0]) | (xy[:, 1] > self.size[1]) + visible[out_mask] = 0 + return np.concatenate([xy, visible], axis=-1).reshape(n, nkpt, 3) + + def __call__(self, labels): + """ + Applies random perspective and affine transformations to an image and its associated labels. + + This method performs a series of transformations including rotation, translation, scaling, shearing, + and perspective distortion on the input image and adjusts the corresponding bounding boxes, segments, + and keypoints accordingly. + + Args: + labels (Dict): A dictionary containing image data and annotations. + Must include: + 'img' (ndarray): The input image. + 'cls' (ndarray): Class labels. + 'instances' (Instances): Object instances with bounding boxes, segments, and keypoints. + May include: + 'mosaic_border' (Tuple[int, int]): Border size for mosaic augmentation. + + Returns: + (Dict): Transformed labels dictionary containing: + - 'img' (np.ndarray): The transformed image. + - 'cls' (np.ndarray): Updated class labels. + - 'instances' (Instances): Updated object instances. + - 'resized_shape' (Tuple[int, int]): New image shape after transformation. + + Examples: + >>> transform = RandomPerspective() + >>> image = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8) + >>> labels = { + ... "img": image, + ... "cls": np.array([0, 1, 2]), + ... "instances": Instances(bboxes=np.array([[10, 10, 50, 50], [100, 100, 150, 150]])), + ... } + >>> result = transform(labels) + >>> assert result["img"].shape[:2] == result["resized_shape"] + """ + if self.pre_transform and "mosaic_border" not in labels: + labels = self.pre_transform(labels) + labels.pop("ratio_pad", None) # do not need ratio pad + + img = labels["img"] + cls = labels["cls"] + instances = labels.pop("instances") + # Make sure the coord formats are right + instances.convert_bbox(format="xyxy") + instances.denormalize(*img.shape[:2][::-1]) + + border = labels.pop("mosaic_border", self.border) + self.size = img.shape[1] + border[1] * 2, img.shape[0] + border[0] * 2 # w, h + # M is affine matrix + # Scale for func:`box_candidates` + img, M, scale = self.affine_transform(img, border) + + bboxes = self.apply_bboxes(instances.bboxes, M) + + segments = instances.segments + keypoints = instances.keypoints + # Update bboxes if there are segments. 
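+        # Note: when polygon segments exist, the boxes produced by apply_bboxes above are discarded and
+        # regenerated from the warped polygons in apply_segments, since transforming every polygon point
+        # follows the perspective warp more closely than transforming only the four box corners.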
+ if len(segments): + bboxes, segments = self.apply_segments(segments, M) + + if keypoints is not None: + keypoints = self.apply_keypoints(keypoints, M) + new_instances = Instances(bboxes, segments, keypoints, bbox_format="xyxy", normalized=False) + # Clip + new_instances.clip(*self.size) + + # Filter instances + instances.scale(scale_w=scale, scale_h=scale, bbox_only=True) + # Make the bboxes have the same scale with new_bboxes + i = self.box_candidates( + box1=instances.bboxes.T, box2=new_instances.bboxes.T, area_thr=0.01 if len(segments) else 0.10 + ) + labels["instances"] = new_instances[i] + labels["cls"] = cls[i] + labels["img"] = img + labels["resized_shape"] = img.shape[:2] + return labels + + @staticmethod + def box_candidates(box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16): + """ + Compute candidate boxes for further processing based on size and aspect ratio criteria. + + This method compares boxes before and after augmentation to determine if they meet specified + thresholds for width, height, aspect ratio, and area. It's used to filter out boxes that have + been overly distorted or reduced by the augmentation process. + + Args: + box1 (numpy.ndarray): Original boxes before augmentation, shape (4, N) where n is the + number of boxes. Format is [x1, y1, x2, y2] in absolute coordinates. + box2 (numpy.ndarray): Augmented boxes after transformation, shape (4, N). Format is + [x1, y1, x2, y2] in absolute coordinates. + wh_thr (float): Width and height threshold in pixels. Boxes smaller than this in either + dimension are rejected. + ar_thr (float): Aspect ratio threshold. Boxes with an aspect ratio greater than this + value are rejected. + area_thr (float): Area ratio threshold. Boxes with an area ratio (new/old) less than + this value are rejected. + eps (float): Small epsilon value to prevent division by zero. + + Returns: + (numpy.ndarray): Boolean array of shape (n) indicating which boxes are candidates. + True values correspond to boxes that meet all criteria. + + Examples: + >>> random_perspective = RandomPerspective() + >>> box1 = np.array([[0, 0, 100, 100], [0, 0, 50, 50]]).T + >>> box2 = np.array([[10, 10, 90, 90], [5, 5, 45, 45]]).T + >>> candidates = random_perspective.box_candidates(box1, box2) + >>> print(candidates) + [True True] + """ + w1, h1 = box1[2] - box1[0], box1[3] - box1[1] + w2, h2 = box2[2] - box2[0], box2[3] - box2[1] + ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio + return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates + + +class RandomHSV: + """ + Randomly adjusts the Hue, Saturation, and Value (HSV) channels of an image. + + This class applies random HSV augmentation to images within predefined limits set by hgain, sgain, and vgain. + + Attributes: + hgain (float): Maximum variation for hue. Range is typically [0, 1]. + sgain (float): Maximum variation for saturation. Range is typically [0, 1]. + vgain (float): Maximum variation for value. Range is typically [0, 1]. + + Methods: + __call__: Applies random HSV augmentation to an image. 
+
+    Examples:
+        >>> import numpy as np
+        >>> from ultralytics.data.augment import RandomHSV
+        >>> augmenter = RandomHSV(hgain=0.5, sgain=0.5, vgain=0.5)
+        >>> image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
+        >>> labels = {"img": image}
+        >>> augmenter(labels)
+        >>> augmented_image = labels["img"]
+    """
+
+    def __init__(self, hgain=0.5, sgain=0.5, vgain=0.5) -> None:
+        """
+        Initializes the RandomHSV object for random HSV (Hue, Saturation, Value) augmentation.
+
+        This class applies random adjustments to the HSV channels of an image within specified limits.
+
+        Args:
+            hgain (float): Maximum variation for hue. Should be in the range [0, 1].
+            sgain (float): Maximum variation for saturation. Should be in the range [0, 1].
+            vgain (float): Maximum variation for value. Should be in the range [0, 1].
+
+        Examples:
+            >>> hsv_aug = RandomHSV(hgain=0.5, sgain=0.5, vgain=0.5)
+            >>> hsv_aug({"img": image})
+        """
+        self.hgain = hgain
+        self.sgain = sgain
+        self.vgain = vgain
+
+    def __call__(self, labels):
+        """
+        Applies random HSV augmentation to an image within predefined limits.
+
+        This method modifies the input image in place by randomly adjusting its Hue, Saturation, and Value (HSV)
+        channels. The adjustments are made within the limits set by hgain, sgain, and vgain during initialization.
+
+        Args:
+            labels (Dict): A dictionary containing image data and metadata. Must include an 'img' key with
+                the image as a numpy array.
+
+        Returns:
+            (Dict): The input 'labels' dictionary with the 'img' entry updated in place with the
+                HSV-augmented image.
+
+        Examples:
+            >>> hsv_augmenter = RandomHSV(hgain=0.5, sgain=0.5, vgain=0.5)
+            >>> labels = {"img": np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)}
+            >>> hsv_augmenter(labels)
+            >>> augmented_img = labels["img"]
+        """
+        img = labels["img"]
+        if self.hgain or self.sgain or self.vgain:
+            r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1  # random gains
+            hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
+            dtype = img.dtype  # uint8
+
+            x = np.arange(0, 256, dtype=r.dtype)
+            lut_hue = ((x * r[0]) % 180).astype(dtype)
+            lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
+            lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
+
+            im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
+            cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=img)  # no return needed
+        return labels
+
+
+class RandomFlip:
+    """
+    Applies a random horizontal or vertical flip to an image with a given probability.
+
+    This class performs random image flipping and updates corresponding instance annotations such as
+    bounding boxes and keypoints.
+
+    Attributes:
+        p (float): Probability of applying the flip. Must be between 0 and 1.
+        direction (str): Direction of flip, either 'horizontal' or 'vertical'.
+        flip_idx (array-like): Index mapping for flipping keypoints, if applicable.
+
+    Methods:
+        __call__: Applies the random flip transformation to an image and its annotations.
+
+    Examples:
+        >>> transform = RandomFlip(p=0.5, direction="horizontal")
+        >>> result = transform({"img": image, "instances": instances})
+        >>> flipped_image = result["img"]
+        >>> flipped_instances = result["instances"]
+    """
+
+    def __init__(self, p=0.5, direction="horizontal", flip_idx=None) -> None:
+        """
+        Initializes the RandomFlip class with probability and direction.
+
+        This class applies a random horizontal or vertical flip to an image with a given probability.
+ It also updates any instances (bounding boxes, keypoints, etc.) accordingly. + + Args: + p (float): The probability of applying the flip. Must be between 0 and 1. + direction (str): The direction to apply the flip. Must be 'horizontal' or 'vertical'. + flip_idx (List[int] | None): Index mapping for flipping keypoints, if any. + + Raises: + AssertionError: If direction is not 'horizontal' or 'vertical', or if p is not between 0 and 1. + + Examples: + >>> flip = RandomFlip(p=0.5, direction="horizontal") + >>> flip_with_idx = RandomFlip(p=0.7, direction="vertical", flip_idx=[1, 0, 3, 2, 5, 4]) + """ + assert direction in {"horizontal", "vertical"}, f"Support direction `horizontal` or `vertical`, got {direction}" + assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}." + + self.p = p + self.direction = direction + self.flip_idx = flip_idx + + def __call__(self, labels): + """ + Applies random flip to an image and updates any instances like bounding boxes or keypoints accordingly. + + This method randomly flips the input image either horizontally or vertically based on the initialized + probability and direction. It also updates the corresponding instances (bounding boxes, keypoints) to + match the flipped image. + + Args: + labels (Dict): A dictionary containing the following keys: + 'img' (numpy.ndarray): The image to be flipped. + 'instances' (ultralytics.utils.instance.Instances): An object containing bounding boxes and + optionally keypoints. + + Returns: + (Dict): The same dictionary with the flipped image and updated instances: + 'img' (numpy.ndarray): The flipped image. + 'instances' (ultralytics.utils.instance.Instances): Updated instances matching the flipped image. + + Examples: + >>> labels = {"img": np.random.rand(640, 640, 3), "instances": Instances(...)} + >>> random_flip = RandomFlip(p=0.5, direction="horizontal") + >>> flipped_labels = random_flip(labels) + """ + img = labels["img"] + instances = labels.pop("instances") + instances.convert_bbox(format="xywh") + h, w = img.shape[:2] + h = 1 if instances.normalized else h + w = 1 if instances.normalized else w + + # Flip up-down + if self.direction == "vertical" and random.random() < self.p: + img = np.flipud(img) + instances.flipud(h) + if self.direction == "horizontal" and random.random() < self.p: + img = np.fliplr(img) + instances.fliplr(w) + # For keypoints + if self.flip_idx is not None and instances.keypoints is not None: + instances.keypoints = np.ascontiguousarray(instances.keypoints[:, self.flip_idx, :]) + labels["img"] = np.ascontiguousarray(img) + labels["instances"] = instances + return labels + + +class LetterBox: + """ + Resize image and padding for detection, instance segmentation, pose. + + This class resizes and pads images to a specified shape while preserving aspect ratio. It also updates + corresponding labels and bounding boxes. + + Attributes: + new_shape (tuple): Target shape (height, width) for resizing. + auto (bool): Whether to use minimum rectangle. + scaleFill (bool): Whether to stretch the image to new_shape. + scaleup (bool): Whether to allow scaling up. If False, only scale down. + stride (int): Stride for rounding padding. + center (bool): Whether to center the image or align to top-left. + + Methods: + __call__: Resize and pad image, update labels and bounding boxes. 
+ + Examples: + >>> transform = LetterBox(new_shape=(640, 640)) + >>> result = transform(labels) + >>> resized_img = result["img"] + >>> updated_instances = result["instances"] + """ + + def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, center=True, stride=32): + """ + Initialize LetterBox object for resizing and padding images. + + This class is designed to resize and pad images for object detection, instance segmentation, and pose estimation + tasks. It supports various resizing modes including auto-sizing, scale-fill, and letterboxing. + + Args: + new_shape (Tuple[int, int]): Target size (height, width) for the resized image. + auto (bool): If True, use minimum rectangle to resize. If False, use new_shape directly. + scaleFill (bool): If True, stretch the image to new_shape without padding. + scaleup (bool): If True, allow scaling up. If False, only scale down. + center (bool): If True, center the placed image. If False, place image in top-left corner. + stride (int): Stride of the model (e.g., 32 for YOLOv5). + + Attributes: + new_shape (Tuple[int, int]): Target size for the resized image. + auto (bool): Flag for using minimum rectangle resizing. + scaleFill (bool): Flag for stretching image without padding. + scaleup (bool): Flag for allowing upscaling. + stride (int): Stride value for ensuring image size is divisible by stride. + + Examples: + >>> letterbox = LetterBox(new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32) + >>> resized_img = letterbox(original_img) + """ + self.new_shape = new_shape + self.auto = auto + self.scaleFill = scaleFill + self.scaleup = scaleup + self.stride = stride + self.center = center # Put the image in the middle or top-left + + def __call__(self, labels=None, image=None): + """ + Resizes and pads an image for object detection, instance segmentation, or pose estimation tasks. + + This method applies letterboxing to the input image, which involves resizing the image while maintaining its + aspect ratio and adding padding to fit the new shape. It also updates any associated labels accordingly. + + Args: + labels (Dict | None): A dictionary containing image data and associated labels, or empty dict if None. + image (np.ndarray | None): The input image as a numpy array. If None, the image is taken from 'labels'. + + Returns: + (Dict | Tuple): If 'labels' is provided, returns an updated dictionary with the resized and padded image, + updated labels, and additional metadata. If 'labels' is empty, returns a tuple containing the resized + and padded image, and a tuple of (ratio, (left_pad, top_pad)). 
+ + Examples: + >>> letterbox = LetterBox(new_shape=(640, 640)) + >>> result = letterbox(labels={"img": np.zeros((480, 640, 3)), "instances": Instances(...)}) + >>> resized_img = result["img"] + >>> updated_instances = result["instances"] + """ + if labels is None: + labels = {} + img = labels.get("img") if image is None else image + shape = img.shape[:2] # current shape [height, width] + new_shape = labels.pop("rect_shape", self.new_shape) + if isinstance(new_shape, int): + new_shape = (new_shape, new_shape) + + # Scale ratio (new / old) + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) + if not self.scaleup: # only scale down, do not scale up (for better val mAP) + r = min(r, 1.0) + + # Compute padding + ratio = r, r # width, height ratios + new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) + dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding + if self.auto: # minimum rectangle + dw, dh = np.mod(dw, self.stride), np.mod(dh, self.stride) # wh padding + elif self.scaleFill: # stretch + dw, dh = 0.0, 0.0 + new_unpad = (new_shape[1], new_shape[0]) + ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios + + if self.center: + dw /= 2 # divide padding into 2 sides + dh /= 2 + + if shape[::-1] != new_unpad: # resize + img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) + top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1)) + img = cv2.copyMakeBorder( + img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114) + ) # add border + if labels.get("ratio_pad"): + labels["ratio_pad"] = (labels["ratio_pad"], (left, top)) # for evaluation + + if len(labels): + labels = self._update_labels(labels, ratio, left, top) + labels["img"] = img + labels["resized_shape"] = new_shape + return labels + else: + return img + + @staticmethod + def _update_labels(labels, ratio, padw, padh): + """ + Updates labels after applying letterboxing to an image. + + This method modifies the bounding box coordinates of instances in the labels + to account for resizing and padding applied during letterboxing. + + Args: + labels (Dict): A dictionary containing image labels and instances. + ratio (Tuple[float, float]): Scaling ratios (width, height) applied to the image. + padw (float): Padding width added to the image. + padh (float): Padding height added to the image. + + Returns: + (Dict): Updated labels dictionary with modified instance coordinates. + + Examples: + >>> letterbox = LetterBox(new_shape=(640, 640)) + >>> labels = {"instances": Instances(...)} + >>> ratio = (0.5, 0.5) + >>> padw, padh = 10, 20 + >>> updated_labels = letterbox._update_labels(labels, ratio, padw, padh) + """ + labels["instances"].convert_bbox(format="xyxy") + labels["instances"].denormalize(*labels["img"].shape[:2][::-1]) + labels["instances"].scale(*ratio) + labels["instances"].add_padding(padw, padh) + return labels + + +class CopyPaste(BaseMixTransform): + """ + CopyPaste class for applying Copy-Paste augmentation to image datasets. + + This class implements the Copy-Paste augmentation technique as described in the paper "Simple Copy-Paste is a Strong + Data Augmentation Method for Instance Segmentation" (https://arxiv.org/abs/2012.07177). It combines objects from + different images to create new training samples. + + Attributes: + dataset (Any): The dataset to which Copy-Paste augmentation will be applied. 
+ pre_transform (Callable | None): Optional transform to apply before Copy-Paste. + p (float): Probability of applying Copy-Paste augmentation. + + Methods: + get_indexes: Returns a random index from the dataset. + _mix_transform: Applies Copy-Paste augmentation to the input labels. + __call__: Applies the Copy-Paste transformation to images and annotations. + + Examples: + >>> from ultralytics.data.augment import CopyPaste + >>> dataset = YourDataset(...) # Your image dataset + >>> copypaste = CopyPaste(dataset, p=0.5) + >>> augmented_labels = copypaste(original_labels) + """ + + def __init__(self, dataset=None, pre_transform=None, p=0.5, mode="flip") -> None: + """Initializes CopyPaste object with dataset, pre_transform, and probability of applying MixUp.""" + super().__init__(dataset=dataset, pre_transform=pre_transform, p=p) + assert mode in {"flip", "mixup"}, f"Expected `mode` to be `flip` or `mixup`, but got {mode}." + self.mode = mode + + def get_indexes(self): + """Returns a list of random indexes from the dataset for CopyPaste augmentation.""" + return random.randint(0, len(self.dataset) - 1) + + def _mix_transform(self, labels): + """Applies Copy-Paste augmentation to combine objects from another image into the current image.""" + labels2 = labels["mix_labels"][0] + return self._transform(labels, labels2) + + def __call__(self, labels): + """Applies Copy-Paste augmentation to an image and its labels.""" + if len(labels["instances"].segments) == 0 or self.p == 0: + return labels + if self.mode == "flip": + return self._transform(labels) + + # Get index of one or three other images + indexes = self.get_indexes() + if isinstance(indexes, int): + indexes = [indexes] + + # Get images information will be used for Mosaic or MixUp + mix_labels = [self.dataset.get_image_and_label(i) for i in indexes] + + if self.pre_transform is not None: + for i, data in enumerate(mix_labels): + mix_labels[i] = self.pre_transform(data) + labels["mix_labels"] = mix_labels + + # Update cls and texts + labels = self._update_label_text(labels) + # Mosaic or MixUp + labels = self._mix_transform(labels) + labels.pop("mix_labels", None) + return labels + + def _transform(self, labels1, labels2={}): + """Applies Copy-Paste augmentation to combine objects from another image into the current image.""" + im = labels1["img"] + cls = labels1["cls"] + h, w = im.shape[:2] + instances = labels1.pop("instances") + instances.convert_bbox(format="xyxy") + instances.denormalize(w, h) + + im_new = np.zeros(im.shape, np.uint8) + instances2 = labels2.pop("instances", None) + if instances2 is None: + instances2 = deepcopy(instances) + instances2.fliplr(w) + ioa = bbox_ioa(instances2.bboxes, instances.bboxes) # intersection over area, (N, M) + indexes = np.nonzero((ioa < 0.30).all(1))[0] # (N, ) + n = len(indexes) + sorted_idx = np.argsort(ioa.max(1)[indexes]) + indexes = indexes[sorted_idx] + for j in indexes[: round(self.p * n)]: + cls = np.concatenate((cls, labels2.get("cls", cls)[[j]]), axis=0) + instances = Instances.concatenate((instances, instances2[[j]]), axis=0) + cv2.drawContours(im_new, instances2.segments[[j]].astype(np.int32), -1, (1, 1, 1), cv2.FILLED) + + result = labels2.get("img", cv2.flip(im, 1)) # augment segments + i = im_new.astype(bool) + im[i] = result[i] + + labels1["img"] = im + labels1["cls"] = cls + labels1["instances"] = instances + return labels1 + + +class Albumentations: + """ + Albumentations transformations for image augmentation. 
+ + This class applies various image transformations using the Albumentations library. It includes operations such as + Blur, Median Blur, conversion to grayscale, Contrast Limited Adaptive Histogram Equalization (CLAHE), random changes + in brightness and contrast, RandomGamma, and image quality reduction through compression. + + Attributes: + p (float): Probability of applying the transformations. + transform (albumentations.Compose): Composed Albumentations transforms. + contains_spatial (bool): Indicates if the transforms include spatial operations. + + Methods: + __call__: Applies the Albumentations transformations to the input labels. + + Examples: + >>> transform = Albumentations(p=0.5) + >>> augmented_labels = transform(labels) + + Notes: + - The Albumentations package must be installed to use this class. + - If the package is not installed or an error occurs during initialization, the transform will be set to None. + - Spatial transforms are handled differently and require special processing for bounding boxes. + """ + + def __init__(self, p=1.0): + """ + Initialize the Albumentations transform object for YOLO bbox formatted parameters. + + This class applies various image augmentations using the Albumentations library, including Blur, Median Blur, + conversion to grayscale, Contrast Limited Adaptive Histogram Equalization, random changes of brightness and + contrast, RandomGamma, and image quality reduction through compression. + + Args: + p (float): Probability of applying the augmentations. Must be between 0 and 1. + + Attributes: + p (float): Probability of applying the augmentations. + transform (albumentations.Compose): Composed Albumentations transforms. + contains_spatial (bool): Indicates if the transforms include spatial transformations. + + Raises: + ImportError: If the Albumentations package is not installed. + Exception: For any other errors during initialization. + + Examples: + >>> transform = Albumentations(p=0.5) + >>> augmented = transform(image=image, bboxes=bboxes, class_labels=classes) + >>> augmented_image = augmented["image"] + >>> augmented_bboxes = augmented["bboxes"] + + Notes: + - Requires Albumentations version 1.0.3 or higher. + - Spatial transforms are handled differently to ensure bbox compatibility. + - Some transforms are applied with very low probability (0.01) by default. 
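+            - With the default transform list, only pixel-level operations are enabled: Blur, MedianBlur, ToGray
+              and CLAHE at p=0.01, while RandomBrightnessContrast, RandomGamma and ImageCompression are set to p=0.0.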
+ """ + self.p = p + self.transform = None + prefix = colorstr("albumentations: ") + + try: + import albumentations as A + + check_version(A.__version__, "1.0.3", hard=True) # version requirement + + # List of possible spatial transforms + spatial_transforms = { + "Affine", + "BBoxSafeRandomCrop", + "CenterCrop", + "CoarseDropout", + "Crop", + "CropAndPad", + "CropNonEmptyMaskIfExists", + "D4", + "ElasticTransform", + "Flip", + "GridDistortion", + "GridDropout", + "HorizontalFlip", + "Lambda", + "LongestMaxSize", + "MaskDropout", + "MixUp", + "Morphological", + "NoOp", + "OpticalDistortion", + "PadIfNeeded", + "Perspective", + "PiecewiseAffine", + "PixelDropout", + "RandomCrop", + "RandomCropFromBorders", + "RandomGridShuffle", + "RandomResizedCrop", + "RandomRotate90", + "RandomScale", + "RandomSizedBBoxSafeCrop", + "RandomSizedCrop", + "Resize", + "Rotate", + "SafeRotate", + "ShiftScaleRotate", + "SmallestMaxSize", + "Transpose", + "VerticalFlip", + "XYMasking", + } # from https://albumentations.ai/docs/getting_started/transforms_and_targets/#spatial-level-transforms + + # Transforms + T = [ + A.Blur(p=0.01), + A.MedianBlur(p=0.01), + A.ToGray(p=0.01), + A.CLAHE(p=0.01), + A.RandomBrightnessContrast(p=0.0), + A.RandomGamma(p=0.0), + A.ImageCompression(quality_lower=75, p=0.0), + ] + + # Compose transforms + self.contains_spatial = any(transform.__class__.__name__ in spatial_transforms for transform in T) + self.transform = ( + A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"])) + if self.contains_spatial + else A.Compose(T) + ) + if hasattr(self.transform, "set_random_seed"): + # Required for deterministic transforms in albumentations>=1.4.21 + self.transform.set_random_seed(torch.initial_seed()) + LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p)) + except ImportError: # package not installed, skip + pass + except Exception as e: + LOGGER.info(f"{prefix}{e}") + + def __call__(self, labels): + """ + Applies Albumentations transformations to input labels. + + This method applies a series of image augmentations using the Albumentations library. It can perform both + spatial and non-spatial transformations on the input image and its corresponding labels. + + Args: + labels (Dict): A dictionary containing image data and annotations. Expected keys are: + - 'img': numpy.ndarray representing the image + - 'cls': numpy.ndarray of class labels + - 'instances': object containing bounding boxes and other instance information + + Returns: + (Dict): The input dictionary with augmented image and updated annotations. + + Examples: + >>> transform = Albumentations(p=0.5) + >>> labels = { + ... "img": np.random.rand(640, 640, 3), + ... "cls": np.array([0, 1]), + ... "instances": Instances(bboxes=np.array([[0, 0, 1, 1], [0.5, 0.5, 0.8, 0.8]])), + ... } + >>> augmented = transform(labels) + >>> assert augmented["img"].shape == (640, 640, 3) + + Notes: + - The method applies transformations with probability self.p. + - Spatial transforms update bounding boxes, while non-spatial transforms only modify the image. + - Requires the Albumentations library to be installed. 
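+            - For spatial transforms only bounding boxes are updated; segment and keypoint support is still marked
+              as a TODO in the implementation.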
+ """ + if self.transform is None or random.random() > self.p: + return labels + + if self.contains_spatial: + cls = labels["cls"] + if len(cls): + im = labels["img"] + labels["instances"].convert_bbox("xywh") + labels["instances"].normalize(*im.shape[:2][::-1]) + bboxes = labels["instances"].bboxes + # TODO: add supports of segments and keypoints + new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed + if len(new["class_labels"]) > 0: # skip update if no bbox in new im + labels["img"] = new["image"] + labels["cls"] = np.array(new["class_labels"]) + bboxes = np.array(new["bboxes"], dtype=np.float32) + labels["instances"].update(bboxes=bboxes) + else: + labels["img"] = self.transform(image=labels["img"])["image"] # transformed + + return labels + + +class Format: + """ + A class for formatting image annotations for object detection, instance segmentation, and pose estimation tasks. + + This class standardizes image and instance annotations to be used by the `collate_fn` in PyTorch DataLoader. + + Attributes: + bbox_format (str): Format for bounding boxes. Options are 'xywh' or 'xyxy'. + normalize (bool): Whether to normalize bounding boxes. + return_mask (bool): Whether to return instance masks for segmentation. + return_keypoint (bool): Whether to return keypoints for pose estimation. + return_obb (bool): Whether to return oriented bounding boxes. + mask_ratio (int): Downsample ratio for masks. + mask_overlap (bool): Whether to overlap masks. + batch_idx (bool): Whether to keep batch indexes. + bgr (float): The probability to return BGR images. + + Methods: + __call__: Formats labels dictionary with image, classes, bounding boxes, and optionally masks and keypoints. + _format_img: Converts image from Numpy array to PyTorch tensor. + _format_segments: Converts polygon points to bitmap masks. + + Examples: + >>> formatter = Format(bbox_format="xywh", normalize=True, return_mask=True) + >>> formatted_labels = formatter(labels) + >>> img = formatted_labels["img"] + >>> bboxes = formatted_labels["bboxes"] + >>> masks = formatted_labels["masks"] + """ + + def __init__( + self, + bbox_format="xywh", + normalize=True, + return_mask=False, + return_keypoint=False, + return_obb=False, + mask_ratio=4, + mask_overlap=True, + batch_idx=True, + bgr=0.0, + ): + """ + Initializes the Format class with given parameters for image and instance annotation formatting. + + This class standardizes image and instance annotations for object detection, instance segmentation, and pose + estimation tasks, preparing them for use in PyTorch DataLoader's `collate_fn`. + + Args: + bbox_format (str): Format for bounding boxes. Options are 'xywh', 'xyxy', etc. + normalize (bool): Whether to normalize bounding boxes to [0,1]. + return_mask (bool): If True, returns instance masks for segmentation tasks. + return_keypoint (bool): If True, returns keypoints for pose estimation tasks. + return_obb (bool): If True, returns oriented bounding boxes. + mask_ratio (int): Downsample ratio for masks. + mask_overlap (bool): If True, allows mask overlap. + batch_idx (bool): If True, keeps batch indexes. + bgr (float): Probability of returning BGR images instead of RGB. + + Attributes: + bbox_format (str): Format for bounding boxes. + normalize (bool): Whether bounding boxes are normalized. + return_mask (bool): Whether to return instance masks. + return_keypoint (bool): Whether to return keypoints. + return_obb (bool): Whether to return oriented bounding boxes. + mask_ratio (int): Downsample ratio for masks. 
+ mask_overlap (bool): Whether masks can overlap. + batch_idx (bool): Whether to keep batch indexes. + bgr (float): The probability to return BGR images. + + Examples: + >>> format = Format(bbox_format="xyxy", return_mask=True, return_keypoint=False) + >>> print(format.bbox_format) + xyxy + """ + self.bbox_format = bbox_format + self.normalize = normalize + self.return_mask = return_mask # set False when training detection only + self.return_keypoint = return_keypoint + self.return_obb = return_obb + self.mask_ratio = mask_ratio + self.mask_overlap = mask_overlap + self.batch_idx = batch_idx # keep the batch indexes + self.bgr = bgr + + def __call__(self, labels): + """ + Formats image annotations for object detection, instance segmentation, and pose estimation tasks. + + This method standardizes the image and instance annotations to be used by the `collate_fn` in PyTorch + DataLoader. It processes the input labels dictionary, converting annotations to the specified format and + applying normalization if required. + + Args: + labels (Dict): A dictionary containing image and annotation data with the following keys: + - 'img': The input image as a numpy array. + - 'cls': Class labels for instances. + - 'instances': An Instances object containing bounding boxes, segments, and keypoints. + + Returns: + (Dict): A dictionary with formatted data, including: + - 'img': Formatted image tensor. + - 'cls': Class label's tensor. + - 'bboxes': Bounding boxes tensor in the specified format. + - 'masks': Instance masks tensor (if return_mask is True). + - 'keypoints': Keypoints tensor (if return_keypoint is True). + - 'batch_idx': Batch index tensor (if batch_idx is True). + + Examples: + >>> formatter = Format(bbox_format="xywh", normalize=True, return_mask=True) + >>> labels = {"img": np.random.rand(640, 640, 3), "cls": np.array([0, 1]), "instances": Instances(...)} + >>> formatted_labels = formatter(labels) + >>> print(formatted_labels.keys()) + """ + img = labels.pop("img") + h, w = img.shape[:2] + cls = labels.pop("cls") + instances = labels.pop("instances") + instances.convert_bbox(format=self.bbox_format) + instances.denormalize(w, h) + nl = len(instances) + + if self.return_mask: + if nl: + masks, instances, cls = self._format_segments(instances, cls, w, h) + masks = torch.from_numpy(masks) + else: + masks = torch.zeros( + 1 if self.mask_overlap else nl, img.shape[0] // self.mask_ratio, img.shape[1] // self.mask_ratio + ) + labels["masks"] = masks + labels["img"] = self._format_img(img) + labels["cls"] = torch.from_numpy(cls) if nl else torch.zeros(nl) + labels["bboxes"] = torch.from_numpy(instances.bboxes) if nl else torch.zeros((nl, 4)) + if self.return_keypoint: + labels["keypoints"] = torch.from_numpy(instances.keypoints) + if self.normalize: + labels["keypoints"][..., 0] /= w + labels["keypoints"][..., 1] /= h + if self.return_obb: + labels["bboxes"] = ( + xyxyxyxy2xywhr(torch.from_numpy(instances.segments)) if len(instances.segments) else torch.zeros((0, 5)) + ) + # NOTE: need to normalize obb in xywhr format for width-height consistency + if self.normalize: + labels["bboxes"][:, [0, 2]] /= w + labels["bboxes"][:, [1, 3]] /= h + # Then we can use collate_fn + if self.batch_idx: + labels["batch_idx"] = torch.zeros(nl) + return labels + + def _format_img(self, img): + """ + Formats an image for YOLO from a Numpy array to a PyTorch tensor. + + This function performs the following operations: + 1. Ensures the image has 3 dimensions (adds a channel dimension if needed). + 2. 
Transposes the image from HWC to CHW format. + 3. Optionally flips the color channels from RGB to BGR. + 4. Converts the image to a contiguous array. + 5. Converts the Numpy array to a PyTorch tensor. + + Args: + img (np.ndarray): Input image as a Numpy array with shape (H, W, C) or (H, W). + + Returns: + (torch.Tensor): Formatted image as a PyTorch tensor with shape (C, H, W). + + Examples: + >>> import numpy as np + >>> img = np.random.rand(100, 100, 3) + >>> formatted_img = self._format_img(img) + >>> print(formatted_img.shape) + torch.Size([3, 100, 100]) + """ + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + img = img.transpose(2, 0, 1) + img = np.ascontiguousarray(img[::-1] if random.uniform(0, 1) > self.bgr else img) + img = torch.from_numpy(img) + return img + + def _format_segments(self, instances, cls, w, h): + """ + Converts polygon segments to bitmap masks. + + Args: + instances (Instances): Object containing segment information. + cls (numpy.ndarray): Class labels for each instance. + w (int): Width of the image. + h (int): Height of the image. + + Returns: + masks (numpy.ndarray): Bitmap masks with shape (N, H, W) or (1, H, W) if mask_overlap is True. + instances (Instances): Updated instances object with sorted segments if mask_overlap is True. + cls (numpy.ndarray): Updated class labels, sorted if mask_overlap is True. + + Notes: + - If self.mask_overlap is True, masks are overlapped and sorted by area. + - If self.mask_overlap is False, each mask is represented separately. + - Masks are downsampled according to self.mask_ratio. + """ + segments = instances.segments + if self.mask_overlap: + masks, sorted_idx = polygons2masks_overlap((h, w), segments, downsample_ratio=self.mask_ratio) + masks = masks[None] # (640, 640) -> (1, 640, 640) + instances = instances[sorted_idx] + cls = cls[sorted_idx] + else: + masks = polygons2masks((h, w), segments, color=1, downsample_ratio=self.mask_ratio) + + return masks, instances, cls + + +class RandomLoadText: + """ + Randomly samples positive and negative texts and updates class indices accordingly. + + This class is responsible for sampling texts from a given set of class texts, including both positive + (present in the image) and negative (not present in the image) samples. It updates the class indices + to reflect the sampled texts and can optionally pad the text list to a fixed length. + + Attributes: + prompt_format (str): Format string for text prompts. + neg_samples (Tuple[int, int]): Range for randomly sampling negative texts. + max_samples (int): Maximum number of different text samples in one image. + padding (bool): Whether to pad texts to max_samples. + padding_value (str): The text used for padding when padding is True. + + Methods: + __call__: Processes the input labels and returns updated classes and texts. + + Examples: + >>> loader = RandomLoadText(prompt_format="Object: {}", neg_samples=(5, 10), max_samples=20) + >>> labels = {"cls": [0, 1, 2], "texts": [["cat"], ["dog"], ["bird"]], "instances": [...]} + >>> updated_labels = loader(labels) + >>> print(updated_labels["texts"]) + ['Object: cat', 'Object: dog', 'Object: bird', 'Object: elephant', 'Object: car'] + """ + + def __init__( + self, + prompt_format: str = "{}", + neg_samples: Tuple[int, int] = (80, 80), + max_samples: int = 80, + padding: bool = False, + padding_value: str = "", + ) -> None: + """ + Initializes the RandomLoadText class for randomly sampling positive and negative texts. 
+ + This class is designed to randomly sample positive texts and negative texts, and update the class + indices accordingly to the number of samples. It can be used for text-based object detection tasks. + + Args: + prompt_format (str): Format string for the prompt. Default is '{}'. The format string should + contain a single pair of curly braces {} where the text will be inserted. + neg_samples (Tuple[int, int]): A range to randomly sample negative texts. The first integer + specifies the minimum number of negative samples, and the second integer specifies the + maximum. Default is (80, 80). + max_samples (int): The maximum number of different text samples in one image. Default is 80. + padding (bool): Whether to pad texts to max_samples. If True, the number of texts will always + be equal to max_samples. Default is False. + padding_value (str): The padding text to use when padding is True. Default is an empty string. + + Attributes: + prompt_format (str): The format string for the prompt. + neg_samples (Tuple[int, int]): The range for sampling negative texts. + max_samples (int): The maximum number of text samples. + padding (bool): Whether padding is enabled. + padding_value (str): The value used for padding. + + Examples: + >>> random_load_text = RandomLoadText(prompt_format="Object: {}", neg_samples=(50, 100), max_samples=120) + >>> random_load_text.prompt_format + 'Object: {}' + >>> random_load_text.neg_samples + (50, 100) + >>> random_load_text.max_samples + 120 + """ + self.prompt_format = prompt_format + self.neg_samples = neg_samples + self.max_samples = max_samples + self.padding = padding + self.padding_value = padding_value + + def __call__(self, labels: dict) -> dict: + """ + Randomly samples positive and negative texts and updates class indices accordingly. + + This method samples positive texts based on the existing class labels in the image, and randomly + selects negative texts from the remaining classes. It then updates the class indices to match the + new sampled text order. + + Args: + labels (Dict): A dictionary containing image labels and metadata. Must include 'texts' and 'cls' keys. + + Returns: + (Dict): Updated labels dictionary with new 'cls' and 'texts' entries. + + Examples: + >>> loader = RandomLoadText(prompt_format="A photo of {}", neg_samples=(5, 10), max_samples=20) + >>> labels = {"cls": np.array([[0], [1], [2]]), "texts": [["dog"], ["cat"], ["bird"]]} + >>> updated_labels = loader(labels) + """ + assert "texts" in labels, "No texts found in labels." 
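+        # The classes present in the image are treated as positive samples and are always kept (capped at
+        # max_samples); the remaining text slots are then filled with randomly drawn negative classes, and
+        # class indices are remapped to the shuffled sample order below.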
+ class_texts = labels["texts"] + num_classes = len(class_texts) + cls = np.asarray(labels.pop("cls"), dtype=int) + pos_labels = np.unique(cls).tolist() + + if len(pos_labels) > self.max_samples: + pos_labels = random.sample(pos_labels, k=self.max_samples) + + neg_samples = min(min(num_classes, self.max_samples) - len(pos_labels), random.randint(*self.neg_samples)) + neg_labels = [i for i in range(num_classes) if i not in pos_labels] + neg_labels = random.sample(neg_labels, k=neg_samples) + + sampled_labels = pos_labels + neg_labels + random.shuffle(sampled_labels) + + label2ids = {label: i for i, label in enumerate(sampled_labels)} + valid_idx = np.zeros(len(labels["instances"]), dtype=bool) + new_cls = [] + for i, label in enumerate(cls.squeeze(-1).tolist()): + if label not in label2ids: + continue + valid_idx[i] = True + new_cls.append([label2ids[label]]) + labels["instances"] = labels["instances"][valid_idx] + labels["cls"] = np.array(new_cls) + + # Randomly select one prompt when there's more than one prompts + texts = [] + for label in sampled_labels: + prompts = class_texts[label] + assert len(prompts) > 0 + prompt = self.prompt_format.format(prompts[random.randrange(len(prompts))]) + texts.append(prompt) + + if self.padding: + valid_labels = len(pos_labels) + len(neg_labels) + num_padding = self.max_samples - valid_labels + if num_padding > 0: + texts += [self.padding_value] * num_padding + + labels["texts"] = texts + return labels + + +def v8_transforms(dataset, imgsz, hyp, stretch=False): + """ + Applies a series of image transformations for training. + + This function creates a composition of image augmentation techniques to prepare images for YOLO training. + It includes operations such as mosaic, copy-paste, random perspective, mixup, and various color adjustments. + + Args: + dataset (Dataset): The dataset object containing image data and annotations. + imgsz (int): The target image size for resizing. + hyp (Namespace): A dictionary of hyperparameters controlling various aspects of the transformations. + stretch (bool): If True, applies stretching to the image. If False, uses LetterBox resizing. + + Returns: + (Compose): A composition of image transformations to be applied to the dataset. 
+ + Examples: + >>> from ultralytics.data.dataset import YOLODataset + >>> from ultralytics.utils import IterableSimpleNamespace + >>> dataset = YOLODataset(img_path="path/to/images", imgsz=640) + >>> hyp = IterableSimpleNamespace(mosaic=1.0, copy_paste=0.5, degrees=10.0, translate=0.2, scale=0.9) + >>> transforms = v8_transforms(dataset, imgsz=640, hyp=hyp) + >>> augmented_data = transforms(dataset[0]) + """ + mosaic = Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic) + affine = RandomPerspective( + degrees=hyp.degrees, + translate=hyp.translate, + scale=hyp.scale, + shear=hyp.shear, + perspective=hyp.perspective, + pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)), + ) + + pre_transform = Compose([mosaic, affine]) + if hyp.copy_paste_mode == "flip": + pre_transform.insert(1, CopyPaste(p=hyp.copy_paste, mode=hyp.copy_paste_mode)) + else: + pre_transform.append( + CopyPaste( + dataset, + pre_transform=Compose([Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic), affine]), + p=hyp.copy_paste, + mode=hyp.copy_paste_mode, + ) + ) + flip_idx = dataset.data.get("flip_idx", []) # for keypoints augmentation + if dataset.use_keypoints: + kpt_shape = dataset.data.get("kpt_shape", None) + if len(flip_idx) == 0 and hyp.fliplr > 0.0: + hyp.fliplr = 0.0 + LOGGER.warning("WARNING ⚠️ No 'flip_idx' array defined in data.yaml, setting augmentation 'fliplr=0.0'") + elif flip_idx and (len(flip_idx) != kpt_shape[0]): + raise ValueError(f"data.yaml flip_idx={flip_idx} length must be equal to kpt_shape[0]={kpt_shape[0]}") + + return Compose( + [ + pre_transform, + MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup), + Albumentations(p=1.0), + RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v), + RandomFlip(direction="vertical", p=hyp.flipud), + RandomFlip(direction="horizontal", p=hyp.fliplr, flip_idx=flip_idx), + ] + ) # transforms + + +# Classification augmentations ----------------------------------------------------------------------------------------- +def classify_transforms( + size=224, + mean=DEFAULT_MEAN, + std=DEFAULT_STD, + interpolation="BILINEAR", + crop_fraction: float = DEFAULT_CROP_FRACTION, +): + """ + Creates a composition of image transforms for classification tasks. + + This function generates a sequence of torchvision transforms suitable for preprocessing images + for classification models during evaluation or inference. The transforms include resizing, + center cropping, conversion to tensor, and normalization. + + Args: + size (int | tuple): The target size for the transformed image. If an int, it defines the shortest edge. If a + tuple, it defines (height, width). + mean (tuple): Mean values for each RGB channel used in normalization. + std (tuple): Standard deviation values for each RGB channel used in normalization. + interpolation (str): Interpolation method of either 'NEAREST', 'BILINEAR' or 'BICUBIC'. + crop_fraction (float): Fraction of the image to be cropped. + + Returns: + (torchvision.transforms.Compose): A composition of torchvision transforms. 
+ + Examples: + >>> transforms = classify_transforms(size=224) + >>> img = Image.open("path/to/image.jpg") + >>> transformed_img = transforms(img) + """ + import torchvision.transforms as T # scope for faster 'import ultralytics' + + if isinstance(size, (tuple, list)): + assert len(size) == 2, f"'size' tuples must be length 2, not length {len(size)}" + scale_size = tuple(math.floor(x / crop_fraction) for x in size) + else: + scale_size = math.floor(size / crop_fraction) + scale_size = (scale_size, scale_size) + + # Aspect ratio is preserved, crops center within image, no borders are added, image is lost + if scale_size[0] == scale_size[1]: + # Simple case, use torchvision built-in Resize with the shortest edge mode (scalar size arg) + tfl = [T.Resize(scale_size[0], interpolation=getattr(T.InterpolationMode, interpolation))] + else: + # Resize the shortest edge to matching target dim for non-square target + tfl = [T.Resize(scale_size)] + tfl.extend( + [ + T.CenterCrop(size), + T.ToTensor(), + T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), + ] + ) + return T.Compose(tfl) + + +# Classification training augmentations -------------------------------------------------------------------------------- +def classify_augmentations( + size=224, + mean=DEFAULT_MEAN, + std=DEFAULT_STD, + scale=None, + ratio=None, + hflip=0.5, + vflip=0.0, + auto_augment=None, + hsv_h=0.015, # image HSV-Hue augmentation (fraction) + hsv_s=0.4, # image HSV-Saturation augmentation (fraction) + hsv_v=0.4, # image HSV-Value augmentation (fraction) + force_color_jitter=False, + erasing=0.0, + interpolation="BILINEAR", +): + """ + Creates a composition of image augmentation transforms for classification tasks. + + This function generates a set of image transformations suitable for training classification models. It includes + options for resizing, flipping, color jittering, auto augmentation, and random erasing. + + Args: + size (int): Target size for the image after transformations. + mean (tuple): Mean values for normalization, one per channel. + std (tuple): Standard deviation values for normalization, one per channel. + scale (tuple | None): Range of size of the origin size cropped. + ratio (tuple | None): Range of aspect ratio of the origin aspect ratio cropped. + hflip (float): Probability of horizontal flip. + vflip (float): Probability of vertical flip. + auto_augment (str | None): Auto augmentation policy. Can be 'randaugment', 'augmix', 'autoaugment' or None. + hsv_h (float): Image HSV-Hue augmentation factor. + hsv_s (float): Image HSV-Saturation augmentation factor. + hsv_v (float): Image HSV-Value augmentation factor. + force_color_jitter (bool): Whether to apply color jitter even if auto augment is enabled. + erasing (float): Probability of random erasing. + interpolation (str): Interpolation method of either 'NEAREST', 'BILINEAR' or 'BICUBIC'. + + Returns: + (torchvision.transforms.Compose): A composition of image augmentation transforms. 
+ + Examples: + >>> transforms = classify_augmentations(size=224, auto_augment="randaugment") + >>> augmented_image = transforms(original_image) + """ + # Transforms to apply if Albumentations not installed + import torchvision.transforms as T # scope for faster 'import ultralytics' + + if not isinstance(size, int): + raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)") + scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range + ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0)) # default imagenet ratio range + interpolation = getattr(T.InterpolationMode, interpolation) + primary_tfl = [T.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)] + if hflip > 0.0: + primary_tfl.append(T.RandomHorizontalFlip(p=hflip)) + if vflip > 0.0: + primary_tfl.append(T.RandomVerticalFlip(p=vflip)) + + secondary_tfl = [] + disable_color_jitter = False + if auto_augment: + assert isinstance(auto_augment, str), f"Provided argument should be string, but got type {type(auto_augment)}" + # color jitter is typically disabled if AA/RA on, + # this allows override without breaking old hparm cfgs + disable_color_jitter = not force_color_jitter + + if auto_augment == "randaugment": + if TORCHVISION_0_11: + secondary_tfl.append(T.RandAugment(interpolation=interpolation)) + else: + LOGGER.warning('"auto_augment=randaugment" requires torchvision >= 0.11.0. Disabling it.') + + elif auto_augment == "augmix": + if TORCHVISION_0_13: + secondary_tfl.append(T.AugMix(interpolation=interpolation)) + else: + LOGGER.warning('"auto_augment=augmix" requires torchvision >= 0.13.0. Disabling it.') + + elif auto_augment == "autoaugment": + if TORCHVISION_0_10: + secondary_tfl.append(T.AutoAugment(interpolation=interpolation)) + else: + LOGGER.warning('"auto_augment=autoaugment" requires torchvision >= 0.10.0. Disabling it.') + + else: + raise ValueError( + f'Invalid auto_augment policy: {auto_augment}. Should be one of "randaugment", ' + f'"augmix", "autoaugment" or None' + ) + + if not disable_color_jitter: + secondary_tfl.append(T.ColorJitter(brightness=hsv_v, contrast=hsv_v, saturation=hsv_s, hue=hsv_h)) + + final_tfl = [ + T.ToTensor(), + T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), + T.RandomErasing(p=erasing, inplace=True), + ] + + return T.Compose(primary_tfl + secondary_tfl + final_tfl) + + +# NOTE: keep this class for backward compatibility +class ClassifyLetterBox: + """ + A class for resizing and padding images for classification tasks. + + This class is designed to be part of a transformation pipeline, e.g., T.Compose([LetterBox(size), ToTensor()]). + It resizes and pads images to a specified size while maintaining the original aspect ratio. + + Attributes: + h (int): Target height of the image. + w (int): Target width of the image. + auto (bool): If True, automatically calculates the short side using stride. + stride (int): The stride value, used when 'auto' is True. + + Methods: + __call__: Applies the letterbox transformation to an input image. + + Examples: + >>> transform = ClassifyLetterBox(size=(640, 640), auto=False, stride=32) + >>> img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) + >>> result = transform(img) + >>> print(result.shape) + (640, 640, 3) + """ + + def __init__(self, size=(640, 640), auto=False, stride=32): + """ + Initializes the ClassifyLetterBox object for image preprocessing. + + This class is designed to be part of a transformation pipeline for image classification tasks. 
It resizes and + pads images to a specified size while maintaining the original aspect ratio. + + Args: + size (int | Tuple[int, int]): Target size for the letterboxed image. If an int, a square image of + (size, size) is created. If a tuple, it should be (height, width). + auto (bool): If True, automatically calculates the short side based on stride. Default is False. + stride (int): The stride value, used when 'auto' is True. Default is 32. + + Attributes: + h (int): Target height of the letterboxed image. + w (int): Target width of the letterboxed image. + auto (bool): Flag indicating whether to automatically calculate short side. + stride (int): Stride value for automatic short side calculation. + + Examples: + >>> transform = ClassifyLetterBox(size=224) + >>> img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) + >>> result = transform(img) + >>> print(result.shape) + (224, 224, 3) + """ + super().__init__() + self.h, self.w = (size, size) if isinstance(size, int) else size + self.auto = auto # pass max size integer, automatically solve for short side using stride + self.stride = stride # used with auto + + def __call__(self, im): + """ + Resizes and pads an image using the letterbox method. + + This method resizes the input image to fit within the specified dimensions while maintaining its aspect ratio, + then pads the resized image to match the target size. + + Args: + im (numpy.ndarray): Input image as a numpy array with shape (H, W, C). + + Returns: + (numpy.ndarray): Resized and padded image as a numpy array with shape (hs, ws, 3), where hs and ws are + the target height and width respectively. + + Examples: + >>> letterbox = ClassifyLetterBox(size=(640, 640)) + >>> image = np.random.randint(0, 255, (720, 1280, 3), dtype=np.uint8) + >>> resized_image = letterbox(image) + >>> print(resized_image.shape) + (640, 640, 3) + """ + imh, imw = im.shape[:2] + r = min(self.h / imh, self.w / imw) # ratio of new/old dimensions + h, w = round(imh * r), round(imw * r) # resized image dimensions + + # Calculate padding dimensions + hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else (self.h, self.w) + top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1) + + # Create padded image + im_out = np.full((hs, ws, 3), 114, dtype=im.dtype) + im_out[top : top + h, left : left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR) + return im_out + + +# NOTE: keep this class for backward compatibility +class CenterCrop: + """ + Applies center cropping to images for classification tasks. + + This class performs center cropping on input images, resizing them to a specified size while maintaining the aspect + ratio. It is designed to be part of a transformation pipeline, e.g., T.Compose([CenterCrop(size), ToTensor()]). + + Attributes: + h (int): Target height of the cropped image. + w (int): Target width of the cropped image. + + Methods: + __call__: Applies the center crop transformation to an input image. + + Examples: + >>> transform = CenterCrop(640) + >>> image = np.random.randint(0, 255, (1080, 1920, 3), dtype=np.uint8) + >>> cropped_image = transform(image) + >>> print(cropped_image.shape) + (640, 640, 3) + """ + + def __init__(self, size=640): + """ + Initializes the CenterCrop object for image preprocessing. + + This class is designed to be part of a transformation pipeline, e.g., T.Compose([CenterCrop(size), ToTensor()]). + It performs a center crop on input images to a specified size. 
+ + Args: + size (int | Tuple[int, int]): The desired output size of the crop. If size is an int, a square crop + (size, size) is made. If size is a sequence like (h, w), it is used as the output size. + + Returns: + (None): This method initializes the object and does not return anything. + + Examples: + >>> transform = CenterCrop(224) + >>> img = np.random.rand(300, 300, 3) + >>> cropped_img = transform(img) + >>> print(cropped_img.shape) + (224, 224, 3) + """ + super().__init__() + self.h, self.w = (size, size) if isinstance(size, int) else size + + def __call__(self, im): + """ + Applies center cropping to an input image. + + This method resizes and crops the center of the image using a letterbox method. It maintains the aspect + ratio of the original image while fitting it into the specified dimensions. + + Args: + im (numpy.ndarray | PIL.Image.Image): The input image as a numpy array of shape (H, W, C) or a + PIL Image object. + + Returns: + (numpy.ndarray): The center-cropped and resized image as a numpy array of shape (self.h, self.w, C). + + Examples: + >>> transform = CenterCrop(size=224) + >>> image = np.random.randint(0, 255, (640, 480, 3), dtype=np.uint8) + >>> cropped_image = transform(image) + >>> assert cropped_image.shape == (224, 224, 3) + """ + if isinstance(im, Image.Image): # convert from PIL to numpy array if required + im = np.asarray(im) + imh, imw = im.shape[:2] + m = min(imh, imw) # min dimension + top, left = (imh - m) // 2, (imw - m) // 2 + return cv2.resize(im[top : top + m, left : left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR) + + +# NOTE: keep this class for backward compatibility +class ToTensor: + """ + Converts an image from a numpy array to a PyTorch tensor. + + This class is designed to be part of a transformation pipeline, e.g., T.Compose([LetterBox(size), ToTensor()]). + + Attributes: + half (bool): If True, converts the image to half precision (float16). + + Methods: + __call__: Applies the tensor conversion to an input image. + + Examples: + >>> transform = ToTensor(half=True) + >>> img = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8) + >>> tensor_img = transform(img) + >>> print(tensor_img.shape, tensor_img.dtype) + torch.Size([3, 640, 640]) torch.float16 + + Notes: + The input image is expected to be in BGR format with shape (H, W, C). + The output tensor will be in RGB format with shape (C, H, W), normalized to [0, 1]. + """ + + def __init__(self, half=False): + """ + Initializes the ToTensor object for converting images to PyTorch tensors. + + This class is designed to be used as part of a transformation pipeline for image preprocessing in the + Ultralytics YOLO framework. It converts numpy arrays or PIL Images to PyTorch tensors, with an option + for half-precision (float16) conversion. + + Args: + half (bool): If True, converts the tensor to half precision (float16). Default is False. + + Examples: + >>> transform = ToTensor(half=True) + >>> img = np.random.rand(640, 640, 3) + >>> tensor_img = transform(img) + >>> print(tensor_img.dtype) + torch.float16 + """ + super().__init__() + self.half = half + + def __call__(self, im): + """ + Transforms an image from a numpy array to a PyTorch tensor. + + This method converts the input image from a numpy array to a PyTorch tensor, applying optional + half-precision conversion and normalization. The image is transposed from HWC to CHW format and + the color channels are reversed from BGR to RGB. 
+ + Args: + im (numpy.ndarray): Input image as a numpy array with shape (H, W, C) in BGR order. + + Returns: + (torch.Tensor): The transformed image as a PyTorch tensor in float32 or float16, normalized + to [0, 1] with shape (C, H, W) in RGB order. + + Examples: + >>> transform = ToTensor(half=True) + >>> img = np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8) + >>> tensor_img = transform(img) + >>> print(tensor_img.shape, tensor_img.dtype) + torch.Size([3, 640, 640]) torch.float16 + """ + im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1]) # HWC to CHW -> BGR to RGB -> contiguous + im = torch.from_numpy(im) # to torch + im = im.half() if self.half else im.float() # uint8 to fp16/32 + im /= 255.0 # 0-255 to 0.0-1.0 + return im diff --git a/ultralytics/data/base.py b/ultralytics/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..446b4ecf88b4e1193882ccc9681bc18c54559717 --- /dev/null +++ b/ultralytics/data/base.py @@ -0,0 +1,346 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import glob +import math +import os +import random +from copy import deepcopy +from multiprocessing.pool import ThreadPool +from pathlib import Path +from typing import Optional + +import cv2 +import numpy as np +import psutil +from torch.utils.data import Dataset + +from ultralytics.data.utils import FORMATS_HELP_MSG, HELP_URL, IMG_FORMATS +from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM + + +class BaseDataset(Dataset): + """ + Base dataset class for loading and processing image data. + + Args: + img_path (str): Path to the folder containing images. + imgsz (int, optional): Image size. Defaults to 640. + cache (bool, optional): Cache images to RAM or disk during training. Defaults to False. + augment (bool, optional): If True, data augmentation is applied. Defaults to True. + hyp (dict, optional): Hyperparameters to apply data augmentation. Defaults to None. + prefix (str, optional): Prefix to print in log messages. Defaults to ''. + rect (bool, optional): If True, rectangular training is used. Defaults to False. + batch_size (int, optional): Size of batches. Defaults to None. + stride (int, optional): Stride. Defaults to 32. + pad (float, optional): Padding. Defaults to 0.0. + single_cls (bool, optional): If True, single class training is used. Defaults to False. + classes (list): List of included classes. Default is None. + fraction (float): Fraction of dataset to utilize. Default is 1.0 (use all data). + + Attributes: + im_files (list): List of image file paths. + labels (list): List of label data dictionaries. + ni (int): Number of images in the dataset. + ims (list): List of loaded images. + npy_files (list): List of numpy file paths. + transforms (callable): Image transformation function. 
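+
+    Notes:
+        This is an abstract base class; subclasses such as `YOLODataset` implement `get_labels` and
+        `build_transforms`. A minimal sketch of a custom subclass is shown below (illustrative names and
+        shapes only, not the library's own implementation):
+
+        ```python
+        class ToyDataset(BaseDataset):
+            def get_labels(self):
+                # One label dict per image; empty arrays mean "no objects"
+                return [
+                    dict(
+                        im_file=f,
+                        shape=(640, 640),
+                        cls=np.zeros((0, 1)),
+                        bboxes=np.zeros((0, 4)),
+                        segments=[],
+                        keypoints=None,
+                        normalized=True,
+                        bbox_format="xywh",
+                    )
+                    for f in self.im_files
+                ]
+
+            def build_transforms(self, hyp=None):
+                # Identity transform; real subclasses return a Compose of augmentations
+                return lambda labels: labels
+        ```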
+ """ + + def __init__( + self, + img_path, + imgsz=640, + cache=False, + augment=True, + hyp=DEFAULT_CFG, + prefix="", + rect=False, + batch_size=16, + stride=32, + pad=0.5, + single_cls=False, + classes=None, + fraction=1.0, + ): + """Initialize BaseDataset with given configuration and options.""" + super().__init__() + self.img_path = img_path + self.imgsz = imgsz + self.augment = augment + self.single_cls = single_cls + self.prefix = prefix + self.fraction = fraction + self.im_files = self.get_img_files(self.img_path) + self.labels = self.get_labels() + self.update_labels(include_class=classes) # single_cls and include_class + self.ni = len(self.labels) # number of images + self.rect = rect + self.batch_size = batch_size + self.stride = stride + self.pad = pad + if self.rect: + assert self.batch_size is not None + self.set_rectangle() + + # Buffer thread for mosaic images + self.buffer = [] # buffer size = batch size + self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0 + + # Cache images (options are cache = True, False, None, "ram", "disk") + self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni + self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files] + self.cache = cache.lower() if isinstance(cache, str) else "ram" if cache is True else None + if self.cache == "ram" and self.check_cache_ram(): + if hyp.deterministic: + LOGGER.warning( + "WARNING ⚠️ cache='ram' may produce non-deterministic training results. " + "Consider cache='disk' as a deterministic alternative if your disk space allows." + ) + self.cache_images() + elif self.cache == "disk" and self.check_cache_disk(): + self.cache_images() + + # Transforms + self.transforms = self.build_transforms(hyp=hyp) + + def get_img_files(self, img_path): + """Read image files.""" + try: + f = [] # image files + for p in img_path if isinstance(img_path, list) else [img_path]: + p = Path(p) # os-agnostic + if p.is_dir(): # dir + f += glob.glob(str(p / "**" / "*.*"), recursive=True) + # F = list(p.rglob('*.*')) # pathlib + elif p.is_file(): # file + with open(p) as t: + t = t.read().strip().splitlines() + parent = str(p.parent) + os.sep + f += [x.replace("./", parent) if x.startswith("./") else x for x in t] # local to global path + # F += [p.parent / x.lstrip(os.sep) for x in t] # local to global path (pathlib) + else: + raise FileNotFoundError(f"{self.prefix}{p} does not exist") + im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS) + # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib + assert im_files, f"{self.prefix}No images found in {img_path}. 
{FORMATS_HELP_MSG}" + except Exception as e: + raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e + if self.fraction < 1: + im_files = im_files[: round(len(im_files) * self.fraction)] # retain a fraction of the dataset + return im_files + + def update_labels(self, include_class: Optional[list]): + """Update labels to include only these classes (optional).""" + include_class_array = np.array(include_class).reshape(1, -1) + for i in range(len(self.labels)): + if include_class is not None: + cls = self.labels[i]["cls"] + bboxes = self.labels[i]["bboxes"] + segments = self.labels[i]["segments"] + keypoints = self.labels[i]["keypoints"] + j = (cls == include_class_array).any(1) + self.labels[i]["cls"] = cls[j] + self.labels[i]["bboxes"] = bboxes[j] + if segments: + self.labels[i]["segments"] = [segments[si] for si, idx in enumerate(j) if idx] + if keypoints is not None: + self.labels[i]["keypoints"] = keypoints[j] + if self.single_cls: + self.labels[i]["cls"][:, 0] = 0 + + def load_image(self, i, rect_mode=True): + """Loads 1 image from dataset index 'i', returns (im, resized hw).""" + im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i] + if im is None: # not cached in RAM + if fn.exists(): # load npy + try: + im = np.load(fn) + except Exception as e: + LOGGER.warning(f"{self.prefix}WARNING ⚠️ Removing corrupt *.npy image file {fn} due to: {e}") + Path(fn).unlink(missing_ok=True) + im = cv2.imread(f) # BGR + else: # read image + im = cv2.imread(f) # BGR + if im is None: + raise FileNotFoundError(f"Image Not Found {f}") + + h0, w0 = im.shape[:2] # orig hw + if rect_mode: # resize long side to imgsz while maintaining aspect ratio + r = self.imgsz / max(h0, w0) # ratio + if r != 1: # if sizes are not equal + w, h = (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz)) + im = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR) + elif not (h0 == w0 == self.imgsz): # resize by stretching image to square imgsz + im = cv2.resize(im, (self.imgsz, self.imgsz), interpolation=cv2.INTER_LINEAR) + + # Add to buffer if training with augmentations + if self.augment: + self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized + self.buffer.append(i) + if 1 < len(self.buffer) >= self.max_buffer_length: # prevent empty buffer + j = self.buffer.pop(0) + if self.cache != "ram": + self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None + + return im, (h0, w0), im.shape[:2] + + return self.ims[i], self.im_hw0[i], self.im_hw[i] + + def cache_images(self): + """Cache images to memory or disk.""" + b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes + fcn, storage = (self.cache_images_to_disk, "Disk") if self.cache == "disk" else (self.load_image, "RAM") + with ThreadPool(NUM_THREADS) as pool: + results = pool.imap(fcn, range(self.ni)) + pbar = TQDM(enumerate(results), total=self.ni, disable=LOCAL_RANK > 0) + for i, x in pbar: + if self.cache == "disk": + b += self.npy_files[i].stat().st_size + else: # 'ram' + self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i) + b += self.ims[i].nbytes + pbar.desc = f"{self.prefix}Caching images ({b / gb:.1f}GB {storage})" + pbar.close() + + def cache_images_to_disk(self, i): + """Saves an image as an *.npy file for faster loading.""" + f = self.npy_files[i] + if not f.exists(): + np.save(f.as_posix(), cv2.imread(self.im_files[i]), allow_pickle=False) + + def check_cache_disk(self, safety_margin=0.5): + 
"""Check image caching requirements vs available disk space.""" + import shutil + + b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes + n = min(self.ni, 30) # extrapolate from 30 random images + for _ in range(n): + im_file = random.choice(self.im_files) + im = cv2.imread(im_file) + if im is None: + continue + b += im.nbytes + if not os.access(Path(im_file).parent, os.W_OK): + self.cache = None + LOGGER.info(f"{self.prefix}Skipping caching images to disk, directory not writeable ⚠️") + return False + disk_required = b * self.ni / n * (1 + safety_margin) # bytes required to cache dataset to disk + total, used, free = shutil.disk_usage(Path(self.im_files[0]).parent) + if disk_required > free: + self.cache = None + LOGGER.info( + f"{self.prefix}{disk_required / gb:.1f}GB disk space required, " + f"with {int(safety_margin * 100)}% safety margin but only " + f"{free / gb:.1f}/{total / gb:.1f}GB free, not caching images to disk ⚠️" + ) + return False + return True + + def check_cache_ram(self, safety_margin=0.5): + """Check image caching requirements vs available memory.""" + b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes + n = min(self.ni, 30) # extrapolate from 30 random images + for _ in range(n): + im = cv2.imread(random.choice(self.im_files)) # sample image + if im is None: + continue + ratio = self.imgsz / max(im.shape[0], im.shape[1]) # max(h, w) # ratio + b += im.nbytes * ratio**2 + mem_required = b * self.ni / n * (1 + safety_margin) # GB required to cache dataset into RAM + mem = psutil.virtual_memory() + if mem_required > mem.available: + self.cache = None + LOGGER.info( + f"{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images " + f"with {int(safety_margin * 100)}% safety margin but only " + f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, not caching images ⚠️" + ) + return False + return True + + def set_rectangle(self): + """Sets the shape of bounding boxes for YOLO detections as rectangles.""" + bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index + nb = bi[-1] + 1 # number of batches + + s = np.array([x.pop("shape") for x in self.labels]) # hw + ar = s[:, 0] / s[:, 1] # aspect ratio + irect = ar.argsort() + self.im_files = [self.im_files[i] for i in irect] + self.labels = [self.labels[i] for i in irect] + ar = ar[irect] + + # Set training image shapes + shapes = [[1, 1]] * nb + for i in range(nb): + ari = ar[bi == i] + mini, maxi = ari.min(), ari.max() + if maxi < 1: + shapes[i] = [maxi, 1] + elif mini > 1: + shapes[i] = [1, 1 / mini] + + self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride + self.batch = bi # batch index of image + + def __getitem__(self, index): + """Returns transformed label information for given index.""" + return self.transforms(self.get_image_and_label(index)) + + def get_image_and_label(self, index): + """Get and return label information from the dataset.""" + label = deepcopy(self.labels[index]) # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948 + label.pop("shape", None) # shape is for rect, remove it + label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index) + label["ratio_pad"] = ( + label["resized_shape"][0] / label["ori_shape"][0], + label["resized_shape"][1] / label["ori_shape"][1], + ) # for evaluation + if self.rect: + label["rect_shape"] = self.batch_shapes[self.batch[index]] + return self.update_labels_info(label) + + def __len__(self): + """Returns the 
length of the labels list for the dataset.""" + return len(self.labels) + + def update_labels_info(self, label): + """Custom your label format here.""" + return label + + def build_transforms(self, hyp=None): + """ + Users can customize augmentations here. + + Example: + ```python + if self.augment: + # Training transforms + return Compose([]) + else: + # Val transforms + return Compose([]) + ``` + """ + raise NotImplementedError + + def get_labels(self): + """ + Users can customize their own format here. + + Note: + Ensure output is a dictionary with the following keys: + ```python + dict( + im_file=im_file, + shape=shape, # format: (height, width) + cls=cls, + bboxes=bboxes, # xywh + segments=segments, # xy + keypoints=keypoints, # xy + normalized=True, # or False + bbox_format="xyxy", # or xywh, ltwh + ) + ``` + """ + raise NotImplementedError diff --git a/ultralytics/data/build.py b/ultralytics/data/build.py new file mode 100644 index 0000000000000000000000000000000000000000..33b31ba4fb32ffed6c8ae88c6b92777c998ab765 --- /dev/null +++ b/ultralytics/data/build.py @@ -0,0 +1,215 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import os +import random +from pathlib import Path + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import dataloader, distributed + +from ultralytics.data.dataset import GroundingDataset, YOLODataset, YOLOMultiModalDataset +from ultralytics.data.loaders import ( + LOADERS, + LoadImagesAndVideos, + LoadPilAndNumpy, + LoadScreenshots, + LoadStreams, + LoadTensor, + SourceTypes, + autocast_list, +) +from ultralytics.data.utils import IMG_FORMATS, PIN_MEMORY, VID_FORMATS +from ultralytics.utils import RANK, colorstr +from ultralytics.utils.checks import check_file + + +class InfiniteDataLoader(dataloader.DataLoader): + """ + Dataloader that reuses workers. + + Uses same syntax as vanilla DataLoader. + """ + + def __init__(self, *args, **kwargs): + """Dataloader that infinitely recycles workers, inherits from DataLoader.""" + super().__init__(*args, **kwargs) + object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler)) + self.iterator = super().__iter__() + + def __len__(self): + """Returns the length of the batch sampler's sampler.""" + return len(self.batch_sampler.sampler) + + def __iter__(self): + """Creates a sampler that repeats indefinitely.""" + for _ in range(len(self)): + yield next(self.iterator) + + def __del__(self): + """Ensure that workers are terminated.""" + if hasattr(self.iterator, "_workers"): + for w in self.iterator._workers: # force terminate + if w.is_alive(): + w.terminate() + self.iterator._shutdown_workers() # cleanup + + def reset(self): + """ + Reset iterator. + + This is useful when we want to modify settings of dataset while training. + """ + self.iterator = self._get_iterator() + + +class _RepeatSampler: + """ + Sampler that repeats forever. + + Args: + sampler (Dataset.sampler): The sampler to repeat. 
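+
+    Examples:
+        Normally wrapped by `InfiniteDataLoader` rather than used directly; a minimal illustration with a
+        standard PyTorch sampler (expected output shown as a comment):
+        >>> from torch.utils.data import SequentialSampler
+        >>> repeat = _RepeatSampler(SequentialSampler(range(3)))
+        >>> it = iter(repeat)
+        >>> [next(it) for _ in range(5)]  # [0, 1, 2, 0, 1] -> indices cycle indefinitely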
+ """ + + def __init__(self, sampler): + """Initializes an object that repeats a given sampler indefinitely.""" + self.sampler = sampler + + def __iter__(self): + """Iterates over the 'sampler' and yields its contents.""" + while True: + yield from iter(self.sampler) + + +def seed_worker(worker_id): # noqa + """Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader.""" + worker_seed = torch.initial_seed() % 2**32 + np.random.seed(worker_seed) + random.seed(worker_seed) + + +def build_yolo_dataset(cfg, img_path, batch, data, mode="train", rect=False, stride=32, multi_modal=False): + """Build YOLO Dataset.""" + dataset = YOLOMultiModalDataset if multi_modal else YOLODataset + return dataset( + img_path=img_path, + imgsz=cfg.imgsz, + batch_size=batch, + augment=mode == "train", # augmentation + hyp=cfg, # TODO: probably add a get_hyps_from_cfg function + rect=cfg.rect or rect, # rectangular batches + cache=cfg.cache or None, + single_cls=cfg.single_cls or False, + stride=int(stride), + pad=0.0 if mode == "train" else 0.5, + prefix=colorstr(f"{mode}: "), + task=cfg.task, + classes=cfg.classes, + data=data, + fraction=cfg.fraction if mode == "train" else 1.0, + ) + + +def build_grounding(cfg, img_path, json_file, batch, mode="train", rect=False, stride=32): + """Build YOLO Dataset.""" + return GroundingDataset( + img_path=img_path, + json_file=json_file, + imgsz=cfg.imgsz, + batch_size=batch, + augment=mode == "train", # augmentation + hyp=cfg, # TODO: probably add a get_hyps_from_cfg function + rect=cfg.rect or rect, # rectangular batches + cache=cfg.cache or None, + single_cls=cfg.single_cls or False, + stride=int(stride), + pad=0.0 if mode == "train" else 0.5, + prefix=colorstr(f"{mode}: "), + task=cfg.task, + classes=cfg.classes, + fraction=cfg.fraction if mode == "train" else 1.0, + ) + + +def build_dataloader(dataset, batch, workers, shuffle=True, rank=-1): + """Return an InfiniteDataLoader or DataLoader for training or validation set.""" + batch = min(batch, len(dataset)) + nd = torch.cuda.device_count() # number of CUDA devices + nw = min(os.cpu_count() // max(nd, 1), workers) # number of workers + sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle) + generator = torch.Generator() + generator.manual_seed(6148914691236517205 + RANK) + return InfiniteDataLoader( + dataset=dataset, + batch_size=batch, + shuffle=shuffle and sampler is None, + num_workers=nw, + sampler=sampler, + pin_memory=PIN_MEMORY, + collate_fn=getattr(dataset, "collate_fn", None), + worker_init_fn=seed_worker, + generator=generator, + ) + + +def check_source(source): + """Check source type and return corresponding flag values.""" + webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False + if isinstance(source, (str, int, Path)): # int for local usb camera + source = str(source) + is_file = Path(source).suffix[1:] in (IMG_FORMATS | VID_FORMATS) + is_url = source.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")) + webcam = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file) + screenshot = source.lower() == "screen" + if is_url and is_file: + source = check_file(source) # download + elif isinstance(source, LOADERS): + in_memory = True + elif isinstance(source, (list, tuple)): + source = autocast_list(source) # convert all list elements to PIL or np arrays + from_img = True + elif isinstance(source, (Image.Image, np.ndarray)): + from_img = True + elif isinstance(source, 
torch.Tensor): + tensor = True + else: + raise TypeError("Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict") + + return source, webcam, screenshot, from_img, in_memory, tensor + + +def load_inference_source(source=None, batch=1, vid_stride=1, buffer=False): + """ + Loads an inference source for object detection and applies necessary transformations. + + Args: + source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference. + batch (int, optional): Batch size for dataloaders. Default is 1. + vid_stride (int, optional): The frame interval for video sources. Default is 1. + buffer (bool, optional): Determined whether stream frames will be buffered. Default is False. + + Returns: + dataset (Dataset): A dataset object for the specified input source. + """ + source, stream, screenshot, from_img, in_memory, tensor = check_source(source) + source_type = source.source_type if in_memory else SourceTypes(stream, screenshot, from_img, tensor) + + # Dataloader + if tensor: + dataset = LoadTensor(source) + elif in_memory: + dataset = source + elif stream: + dataset = LoadStreams(source, vid_stride=vid_stride, buffer=buffer) + elif screenshot: + dataset = LoadScreenshots(source) + elif from_img: + dataset = LoadPilAndNumpy(source) + else: + dataset = LoadImagesAndVideos(source, batch=batch, vid_stride=vid_stride) + + # Attach source types to the dataset + setattr(dataset, "source_type", source_type) + + return dataset diff --git a/ultralytics/data/converter.py b/ultralytics/data/converter.py new file mode 100644 index 0000000000000000000000000000000000000000..05a316b4858a20e62c131e488e1b2f70cde26ed7 --- /dev/null +++ b/ultralytics/data/converter.py @@ -0,0 +1,702 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import json +import random +import shutil +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path + +import cv2 +import numpy as np +from PIL import Image + +from ultralytics.utils import DATASETS_DIR, LOGGER, NUM_THREADS, TQDM +from ultralytics.utils.downloads import download +from ultralytics.utils.files import increment_path + + +def coco91_to_coco80_class(): + """ + Converts 91-index COCO class IDs to 80-index COCO class IDs. + + Returns: + (list): A list of 91 class IDs where the index represents the 80-index class ID and the value is the + corresponding 91-index class ID. + """ + return [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + None, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + None, + 24, + 25, + None, + None, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + None, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + None, + 60, + None, + None, + 61, + None, + 62, + 63, + 64, + 65, + 66, + 67, + 68, + 69, + 70, + 71, + 72, + None, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + None, + ] + + +def coco80_to_coco91_class(): + r""" + Converts 80-index (val2014) to 91-index (paper). + For details see https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/. 
+ + Example: + ```python + import numpy as np + + a = np.loadtxt("data/coco.names", dtype="str", delimiter="\n") + b = np.loadtxt("data/coco_paper.names", dtype="str", delimiter="\n") + x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco + x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet + ``` + """ + return [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 27, + 28, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 65, + 67, + 70, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + ] + + +def convert_coco( + labels_dir="../coco/annotations/", + save_dir="coco_converted/", + use_segments=False, + use_keypoints=False, + cls91to80=True, + lvis=False, +): + """ + Converts COCO dataset annotations to a YOLO annotation format suitable for training YOLO models. + + Args: + labels_dir (str, optional): Path to directory containing COCO dataset annotation files. + save_dir (str, optional): Path to directory to save results to. + use_segments (bool, optional): Whether to include segmentation masks in the output. + use_keypoints (bool, optional): Whether to include keypoint annotations in the output. + cls91to80 (bool, optional): Whether to map 91 COCO class IDs to the corresponding 80 COCO class IDs. + lvis (bool, optional): Whether to convert data in lvis dataset way. + + Example: + ```python + from ultralytics.data.converter import convert_coco + + convert_coco("../datasets/coco/annotations/", use_segments=True, use_keypoints=False, cls91to80=False) + convert_coco( + "../datasets/lvis/annotations/", use_segments=True, use_keypoints=False, cls91to80=False, lvis=True + ) + ``` + + Output: + Generates output files in the specified output directory. + """ + # Create dataset directory + save_dir = increment_path(save_dir) # increment if save directory already exists + for p in save_dir / "labels", save_dir / "images": + p.mkdir(parents=True, exist_ok=True) # make dir + + # Convert classes + coco80 = coco91_to_coco80_class() + + # Import json + for json_file in sorted(Path(labels_dir).resolve().glob("*.json")): + lname = "" if lvis else json_file.stem.replace("instances_", "") + fn = Path(save_dir) / "labels" / lname # folder name + fn.mkdir(parents=True, exist_ok=True) + if lvis: + # NOTE: create folders for both train and val in advance, + # since LVIS val set contains images from COCO 2017 train in addition to the COCO 2017 val split. 
+ (fn / "train2017").mkdir(parents=True, exist_ok=True) + (fn / "val2017").mkdir(parents=True, exist_ok=True) + with open(json_file, encoding="utf-8") as f: + data = json.load(f) + + # Create image dict + images = {f"{x['id']:d}": x for x in data["images"]} + # Create image-annotations dict + imgToAnns = defaultdict(list) + for ann in data["annotations"]: + imgToAnns[ann["image_id"]].append(ann) + + image_txt = [] + # Write labels file + for img_id, anns in TQDM(imgToAnns.items(), desc=f"Annotations {json_file}"): + img = images[f"{img_id:d}"] + h, w = img["height"], img["width"] + f = str(Path(img["coco_url"]).relative_to("http://images.cocodataset.org")) if lvis else img["file_name"] + if lvis: + image_txt.append(str(Path("./images") / f)) + + bboxes = [] + segments = [] + keypoints = [] + for ann in anns: + if ann.get("iscrowd", False): + continue + # The COCO box format is [top left x, top left y, width, height] + box = np.array(ann["bbox"], dtype=np.float64) + box[:2] += box[2:] / 2 # xy top-left corner to center + box[[0, 2]] /= w # normalize x + box[[1, 3]] /= h # normalize y + if box[2] <= 0 or box[3] <= 0: # if w <= 0 and h <= 0 + continue + + cls = coco80[ann["category_id"] - 1] if cls91to80 else ann["category_id"] - 1 # class + box = [cls] + box.tolist() + if box not in bboxes: + bboxes.append(box) + if use_segments and ann.get("segmentation") is not None: + if len(ann["segmentation"]) == 0: + segments.append([]) + continue + elif len(ann["segmentation"]) > 1: + s = merge_multi_segment(ann["segmentation"]) + s = (np.concatenate(s, axis=0) / np.array([w, h])).reshape(-1).tolist() + else: + s = [j for i in ann["segmentation"] for j in i] # all segments concatenated + s = (np.array(s).reshape(-1, 2) / np.array([w, h])).reshape(-1).tolist() + s = [cls] + s + segments.append(s) + if use_keypoints and ann.get("keypoints") is not None: + keypoints.append( + box + (np.array(ann["keypoints"]).reshape(-1, 3) / np.array([w, h, 1])).reshape(-1).tolist() + ) + + # Write + with open((fn / f).with_suffix(".txt"), "a") as file: + for i in range(len(bboxes)): + if use_keypoints: + line = (*(keypoints[i]),) # cls, box, keypoints + else: + line = ( + *(segments[i] if use_segments and len(segments[i]) > 0 else bboxes[i]), + ) # cls, box or segments + file.write(("%g " * len(line)).rstrip() % line + "\n") + + if lvis: + with open((Path(save_dir) / json_file.name.replace("lvis_v1_", "").replace(".json", ".txt")), "a") as f: + f.writelines(f"{line}\n" for line in image_txt) + + LOGGER.info(f"{'LVIS' if lvis else 'COCO'} data converted successfully.\nResults saved to {save_dir.resolve()}") + + +def convert_segment_masks_to_yolo_seg(masks_dir, output_dir, classes): + """ + Converts a dataset of segmentation mask images to the YOLO segmentation format. + + This function takes the directory containing the binary format mask images and converts them into YOLO segmentation format. + The converted masks are saved in the specified output directory. + + Args: + masks_dir (str): The path to the directory where all mask images (png, jpg) are stored. + output_dir (str): The path to the directory where the converted YOLO segmentation masks will be stored. + classes (int): Total classes in the dataset i.e. 
for COCO classes=80 + + Example: + ```python + from ultralytics.data.converter import convert_segment_masks_to_yolo_seg + + # The classes here is the total classes in the dataset, for COCO dataset we have 80 classes + convert_segment_masks_to_yolo_seg("path/to/masks_directory", "path/to/output/directory", classes=80) + ``` + + Notes: + The expected directory structure for the masks is: + + - masks + ├─ mask_image_01.png or mask_image_01.jpg + ├─ mask_image_02.png or mask_image_02.jpg + ├─ mask_image_03.png or mask_image_03.jpg + └─ mask_image_04.png or mask_image_04.jpg + + After execution, the labels will be organized in the following structure: + + - output_dir + ├─ mask_yolo_01.txt + ├─ mask_yolo_02.txt + ├─ mask_yolo_03.txt + └─ mask_yolo_04.txt + """ + pixel_to_class_mapping = {i + 1: i for i in range(classes)} + for mask_path in Path(masks_dir).iterdir(): + if mask_path.suffix in {".png", ".jpg"}: + mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE) # Read the mask image in grayscale + img_height, img_width = mask.shape # Get image dimensions + LOGGER.info(f"Processing {mask_path} imgsz = {img_height} x {img_width}") + + unique_values = np.unique(mask) # Get unique pixel values representing different classes + yolo_format_data = [] + + for value in unique_values: + if value == 0: + continue # Skip background + class_index = pixel_to_class_mapping.get(value, -1) + if class_index == -1: + LOGGER.warning(f"Unknown class for pixel value {value} in file {mask_path}, skipping.") + continue + + # Create a binary mask for the current class and find contours + contours, _ = cv2.findContours( + (mask == value).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE + ) # Find contours + + for contour in contours: + if len(contour) >= 3: # YOLO requires at least 3 points for a valid segmentation + contour = contour.squeeze() # Remove single-dimensional entries + yolo_format = [class_index] + for point in contour: + # Normalize the coordinates + yolo_format.append(round(point[0] / img_width, 6)) # Rounding to 6 decimal places + yolo_format.append(round(point[1] / img_height, 6)) + yolo_format_data.append(yolo_format) + # Save Ultralytics YOLO format data to file + output_path = Path(output_dir) / f"{mask_path.stem}.txt" + with open(output_path, "w") as file: + for item in yolo_format_data: + line = " ".join(map(str, item)) + file.write(line + "\n") + LOGGER.info(f"Processed and stored at {output_path} imgsz = {img_height} x {img_width}") + + +def convert_dota_to_yolo_obb(dota_root_path: str): + """ + Converts DOTA dataset annotations to YOLO OBB (Oriented Bounding Box) format. + + The function processes images in the 'train' and 'val' folders of the DOTA dataset. For each image, it reads the + associated label from the original labels directory and writes new labels in YOLO OBB format to a new directory. + + Args: + dota_root_path (str): The root directory path of the DOTA dataset. 
+ + Example: + ```python + from ultralytics.data.converter import convert_dota_to_yolo_obb + + convert_dota_to_yolo_obb("path/to/DOTA") + ``` + + Notes: + The directory structure assumed for the DOTA dataset: + + - DOTA + ├─ images + │ ├─ train + │ └─ val + └─ labels + ├─ train_original + └─ val_original + + After execution, the function will organize the labels into: + + - DOTA + └─ labels + ├─ train + └─ val + """ + dota_root_path = Path(dota_root_path) + + # Class names to indices mapping + class_mapping = { + "plane": 0, + "ship": 1, + "storage-tank": 2, + "baseball-diamond": 3, + "tennis-court": 4, + "basketball-court": 5, + "ground-track-field": 6, + "harbor": 7, + "bridge": 8, + "large-vehicle": 9, + "small-vehicle": 10, + "helicopter": 11, + "roundabout": 12, + "soccer-ball-field": 13, + "swimming-pool": 14, + "container-crane": 15, + "airport": 16, + "helipad": 17, + } + + def convert_label(image_name, image_width, image_height, orig_label_dir, save_dir): + """Converts a single image's DOTA annotation to YOLO OBB format and saves it to a specified directory.""" + orig_label_path = orig_label_dir / f"{image_name}.txt" + save_path = save_dir / f"{image_name}.txt" + + with orig_label_path.open("r") as f, save_path.open("w") as g: + lines = f.readlines() + for line in lines: + parts = line.strip().split() + if len(parts) < 9: + continue + class_name = parts[8] + class_idx = class_mapping[class_name] + coords = [float(p) for p in parts[:8]] + normalized_coords = [ + coords[i] / image_width if i % 2 == 0 else coords[i] / image_height for i in range(8) + ] + formatted_coords = [f"{coord:.6g}" for coord in normalized_coords] + g.write(f"{class_idx} {' '.join(formatted_coords)}\n") + + for phase in ["train", "val"]: + image_dir = dota_root_path / "images" / phase + orig_label_dir = dota_root_path / "labels" / f"{phase}_original" + save_dir = dota_root_path / "labels" / phase + + save_dir.mkdir(parents=True, exist_ok=True) + + image_paths = list(image_dir.iterdir()) + for image_path in TQDM(image_paths, desc=f"Processing {phase} images"): + if image_path.suffix != ".png": + continue + image_name_without_ext = image_path.stem + img = cv2.imread(str(image_path)) + h, w = img.shape[:2] + convert_label(image_name_without_ext, w, h, orig_label_dir, save_dir) + + +def min_index(arr1, arr2): + """ + Find a pair of indexes with the shortest distance between two arrays of 2D points. + + Args: + arr1 (np.ndarray): A NumPy array of shape (N, 2) representing N 2D points. + arr2 (np.ndarray): A NumPy array of shape (M, 2) representing M 2D points. + + Returns: + (tuple): A tuple containing the indexes of the points with the shortest distance in arr1 and arr2 respectively. + """ + dis = ((arr1[:, None, :] - arr2[None, :, :]) ** 2).sum(-1) + return np.unravel_index(np.argmin(dis, axis=None), dis.shape) + + +def merge_multi_segment(segments): + """ + Merge multiple segments into one list by connecting the coordinates with the minimum distance between each segment. + This function connects these coordinates with a thin line to merge all segments into one. + + Args: + segments (List[List]): Original segmentations in COCO's JSON file. + Each element is a list of coordinates, like [segmentation1, segmentation2,...]. + + Returns: + s (List[np.ndarray]): A list of connected segments represented as NumPy arrays. 
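+
+    Example:
+        A toy illustration with two made-up square polygons in COCO's flat [x1, y1, x2, y2, ...] format:
+        ```python
+        segments = [[0, 0, 10, 0, 10, 10, 0, 10], [20, 0, 30, 0, 30, 10, 20, 10]]
+        merged = merge_multi_segment(segments)
+        xy = np.concatenate(merged, axis=0)  # single (N, 2) array of connected points
+        ```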
+ """ + s = [] + segments = [np.array(i).reshape(-1, 2) for i in segments] + idx_list = [[] for _ in range(len(segments))] + + # Record the indexes with min distance between each segment + for i in range(1, len(segments)): + idx1, idx2 = min_index(segments[i - 1], segments[i]) + idx_list[i - 1].append(idx1) + idx_list[i].append(idx2) + + # Use two round to connect all the segments + for k in range(2): + # Forward connection + if k == 0: + for i, idx in enumerate(idx_list): + # Middle segments have two indexes, reverse the index of middle segments + if len(idx) == 2 and idx[0] > idx[1]: + idx = idx[::-1] + segments[i] = segments[i][::-1, :] + + segments[i] = np.roll(segments[i], -idx[0], axis=0) + segments[i] = np.concatenate([segments[i], segments[i][:1]]) + # Deal with the first segment and the last one + if i in {0, len(idx_list) - 1}: + s.append(segments[i]) + else: + idx = [0, idx[1] - idx[0]] + s.append(segments[i][idx[0] : idx[1] + 1]) + + else: + for i in range(len(idx_list) - 1, -1, -1): + if i not in {0, len(idx_list) - 1}: + idx = idx_list[i] + nidx = abs(idx[1] - idx[0]) + s.append(segments[i][nidx:]) + return s + + +def yolo_bbox2segment(im_dir, save_dir=None, sam_model="sam_b.pt", device=None): + """ + Converts existing object detection dataset (bounding boxes) to segmentation dataset or oriented bounding box (OBB) + in YOLO format. Generates segmentation data using SAM auto-annotator as needed. + + Args: + im_dir (str | Path): Path to image directory to convert. + save_dir (str | Path): Path to save the generated labels, labels will be saved + into `labels-segment` in the same directory level of `im_dir` if save_dir is None. Default: None. + sam_model (str): Segmentation model to use for intermediate segmentation data; optional. + device (int | str): The specific device to run SAM models. Default: None. + + Notes: + The input directory structure assumed for dataset: + + - im_dir + ├─ 001.jpg + ├─ ... + └─ NNN.jpg + - labels + ├─ 001.txt + ├─ ... 
+ └─ NNN.txt + """ + from ultralytics import SAM + from ultralytics.data import YOLODataset + from ultralytics.utils import LOGGER + from ultralytics.utils.ops import xywh2xyxy + + # NOTE: add placeholder to pass class index check + dataset = YOLODataset(im_dir, data=dict(names=list(range(1000)))) + if len(dataset.labels[0]["segments"]) > 0: # if it's segment data + LOGGER.info("Segmentation labels detected, no need to generate new ones!") + return + + LOGGER.info("Detection labels detected, generating segment labels by SAM model!") + sam_model = SAM(sam_model) + for label in TQDM(dataset.labels, total=len(dataset.labels), desc="Generating segment labels"): + h, w = label["shape"] + boxes = label["bboxes"] + if len(boxes) == 0: # skip empty labels + continue + boxes[:, [0, 2]] *= w + boxes[:, [1, 3]] *= h + im = cv2.imread(label["im_file"]) + sam_results = sam_model(im, bboxes=xywh2xyxy(boxes), verbose=False, save=False, device=device) + label["segments"] = sam_results[0].masks.xyn + + save_dir = Path(save_dir) if save_dir else Path(im_dir).parent / "labels-segment" + save_dir.mkdir(parents=True, exist_ok=True) + for label in dataset.labels: + texts = [] + lb_name = Path(label["im_file"]).with_suffix(".txt").name + txt_file = save_dir / lb_name + cls = label["cls"] + for i, s in enumerate(label["segments"]): + if len(s) == 0: + continue + line = (int(cls[i]), *s.reshape(-1)) + texts.append(("%g " * len(line)).rstrip() % line) + with open(txt_file, "a") as f: + f.writelines(text + "\n" for text in texts) + LOGGER.info(f"Generated segment labels saved in {save_dir}") + + +def create_synthetic_coco_dataset(): + """ + Creates a synthetic COCO dataset with random images based on filenames from label lists. + + This function downloads COCO labels, reads image filenames from label list files, + creates synthetic images for train2017 and val2017 subsets, and organizes + them in the COCO dataset structure. It uses multithreading to generate images efficiently. + + Examples: + >>> from ultralytics.data.converter import create_synthetic_coco_dataset + >>> create_synthetic_coco_dataset() + + Notes: + - Requires internet connection to download label files. + - Generates random RGB images of varying sizes (480x480 to 640x640 pixels). + - Existing test2017 directory is removed as it's not needed. + - Reads image filenames from train2017.txt and val2017.txt files. 
+ """ + + def create_synthetic_image(image_file): + """Generates synthetic images with random sizes and colors for dataset augmentation or testing purposes.""" + if not image_file.exists(): + size = (random.randint(480, 640), random.randint(480, 640)) + Image.new( + "RGB", + size=size, + color=(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)), + ).save(image_file) + + # Download labels + dir = DATASETS_DIR / "coco" + url = "https://github.com/ultralytics/assets/releases/download/v0.0.0/" + label_zip = "coco2017labels-segments.zip" + download([url + label_zip], dir=dir.parent) + + # Create synthetic images + shutil.rmtree(dir / "labels" / "test2017", ignore_errors=True) # Remove test2017 directory as not needed + with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor: + for subset in ["train2017", "val2017"]: + subset_dir = dir / "images" / subset + subset_dir.mkdir(parents=True, exist_ok=True) + + # Read image filenames from label list file + label_list_file = dir / f"{subset}.txt" + if label_list_file.exists(): + with open(label_list_file) as f: + image_files = [dir / line.strip() for line in f] + + # Submit all tasks + futures = [executor.submit(create_synthetic_image, image_file) for image_file in image_files] + for _ in TQDM(as_completed(futures), total=len(futures), desc=f"Generating images for {subset}"): + pass # The actual work is done in the background + else: + print(f"Warning: Labels file {label_list_file} does not exist. Skipping image creation for {subset}.") + + print("Synthetic COCO dataset created successfully.") diff --git a/ultralytics/data/dataset.py b/ultralytics/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1817c35c50b9af954b79707c922e0ecd22d50130 --- /dev/null +++ b/ultralytics/data/dataset.py @@ -0,0 +1,521 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import json +from collections import defaultdict +from itertools import repeat +from multiprocessing.pool import ThreadPool +from pathlib import Path + +import cv2 +import numpy as np +import torch +from PIL import Image +from torch.utils.data import ConcatDataset + +from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr +from ultralytics.utils.ops import resample_segments +from ultralytics.utils.torch_utils import TORCHVISION_0_18 + +from .augment import ( + Compose, + Format, + Instances, + LetterBox, + RandomLoadText, + classify_augmentations, + classify_transforms, + v8_transforms, +) +from .base import BaseDataset +from .utils import ( + HELP_URL, + LOGGER, + get_hash, + img2label_paths, + load_dataset_cache_file, + save_dataset_cache_file, + verify_image, + verify_image_label, +) + +# Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8 +DATASET_CACHE_VERSION = "1.0.3" + + +class YOLODataset(BaseDataset): + """ + Dataset class for loading object detection and/or segmentation labels in YOLO format. + + Args: + data (dict, optional): A dataset YAML dictionary. Defaults to None. + task (str): An explicit arg to point current task, Defaults to 'detect'. + + Returns: + (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model. 
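+
+    Examples:
+        The path and dataset dictionary below are illustrative only:
+        >>> dataset = YOLODataset(img_path="path/to/images", data=dict(names={0: "person"}), task="detect")
+        >>> sample = dataset[0]  # dict with "img", "cls", "bboxes", "batch_idx", ... after transforms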
+ """ + + def __init__(self, *args, data=None, task="detect", **kwargs): + """Initializes the YOLODataset with optional configurations for segments and keypoints.""" + self.use_segments = task == "segment" + self.use_keypoints = task == "pose" + self.use_obb = task == "obb" + self.data = data + assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints." + super().__init__(*args, **kwargs) + + def cache_labels(self, path=Path("./labels.cache")): + """ + Cache dataset labels, check images and read shapes. + + Args: + path (Path): Path where to save the cache file. Default is Path("./labels.cache"). + + Returns: + (dict): labels. + """ + x = {"labels": []} + nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages + desc = f"{self.prefix}Scanning {path.parent / path.stem}..." + total = len(self.im_files) + nkpt, ndim = self.data.get("kpt_shape", (0, 0)) + if self.use_keypoints and (nkpt <= 0 or ndim not in {2, 3}): + raise ValueError( + "'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of " + "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'" + ) + with ThreadPool(NUM_THREADS) as pool: + results = pool.imap( + func=verify_image_label, + iterable=zip( + self.im_files, + self.label_files, + repeat(self.prefix), + repeat(self.use_keypoints), + repeat(len(self.data["names"])), + repeat(nkpt), + repeat(ndim), + ), + ) + pbar = TQDM(results, desc=desc, total=total) + for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar: + nm += nm_f + nf += nf_f + ne += ne_f + nc += nc_f + if im_file: + x["labels"].append( + { + "im_file": im_file, + "shape": shape, + "cls": lb[:, 0:1], # n, 1 + "bboxes": lb[:, 1:], # n, 4 + "segments": segments, + "keypoints": keypoint, + "normalized": True, + "bbox_format": "xywh", + } + ) + if msg: + msgs.append(msg) + pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt" + pbar.close() + + if msgs: + LOGGER.info("\n".join(msgs)) + if nf == 0: + LOGGER.warning(f"{self.prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}") + x["hash"] = get_hash(self.label_files + self.im_files) + x["results"] = nf, nm, ne, nc, len(self.im_files) + x["msgs"] = msgs # warnings + save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION) + return x + + def get_labels(self): + """Returns dictionary of labels for YOLO training.""" + self.label_files = img2label_paths(self.im_files) + cache_path = Path(self.label_files[0]).parent.with_suffix(".cache") + try: + cache, exists = load_dataset_cache_file(cache_path), True # attempt to load a *.cache file + assert cache["version"] == DATASET_CACHE_VERSION # matches current version + assert cache["hash"] == get_hash(self.label_files + self.im_files) # identical hash + except (FileNotFoundError, AssertionError, AttributeError): + cache, exists = self.cache_labels(cache_path), False # run cache ops + + # Display cache + nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total + if exists and LOCAL_RANK in {-1, 0}: + d = f"Scanning {cache_path}... 
{nf} images, {nm + ne} backgrounds, {nc} corrupt"
+            TQDM(None, desc=self.prefix + d, total=n, initial=n)  # display results
+            if cache["msgs"]:
+                LOGGER.info("\n".join(cache["msgs"]))  # display warnings
+
+        # Read cache
+        [cache.pop(k) for k in ("hash", "version", "msgs")]  # remove items
+        labels = cache["labels"]
+        if not labels:
+            LOGGER.warning(f"WARNING ⚠️ No images found in {cache_path}, training may not work correctly. {HELP_URL}")
+        self.im_files = [lb["im_file"] for lb in labels]  # update im_files
+
+        # Check if the dataset is all boxes or all segments
+        lengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels)
+        len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
+        if len_segments and len_boxes != len_segments:
+            LOGGER.warning(
+                f"WARNING ⚠️ Box and segment counts should be equal, but got len(segments) = {len_segments}, "
+                f"len(boxes) = {len_boxes}. To resolve this, only boxes will be used and all segments will be removed. "
+                "To avoid this, please supply either a detect or segment dataset, not a detect-segment mixed dataset."
+            )
+            for lb in labels:
+                lb["segments"] = []
+        if len_cls == 0:
+            LOGGER.warning(f"WARNING ⚠️ No labels found in {cache_path}, training may not work correctly. {HELP_URL}")
+        return labels
+
+    def build_transforms(self, hyp=None):
+        """Builds and appends transforms to the list."""
+        if self.augment:
+            hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
+            hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
+            transforms = v8_transforms(self, self.imgsz, hyp)
+        else:
+            transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
+        transforms.append(
+            Format(
+                bbox_format="xywh",
+                normalize=True,
+                return_mask=self.use_segments,
+                return_keypoint=self.use_keypoints,
+                return_obb=self.use_obb,
+                batch_idx=True,
+                mask_ratio=hyp.mask_ratio,
+                mask_overlap=hyp.overlap_mask,
+                bgr=hyp.bgr if self.augment else 0.0,  # only affects training
+            )
+        )
+        return transforms
+
+    def close_mosaic(self, hyp):
+        """Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations."""
+        hyp.mosaic = 0.0  # set mosaic ratio=0.0
+        hyp.copy_paste = 0.0  # keep the same behavior as previous v8 close-mosaic
+        hyp.mixup = 0.0  # keep the same behavior as previous v8 close-mosaic
+        self.transforms = self.build_transforms(hyp)
+
+    def update_labels_info(self, label):
+        """
+        Customize your label format here.
+
+        Note:
+            cls is no longer stored with bboxes; classification and semantic segmentation need an independent cls label.
+            Classification and semantic segmentation can also be supported by adding or removing dict keys here.
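+
+        Example:
+            A subclass could adjust or extend the label dict before instances are built; a minimal sketch using a
+            hypothetical extra key that is not part of the library API:
+            ```python
+            def update_labels_info(self, label):
+                label = super().update_labels_info(label)
+                label["my_extra_key"] = 0  # hypothetical field consumed by custom transforms
+                return label
+            ```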
+ """ + bboxes = label.pop("bboxes") + segments = label.pop("segments", []) + keypoints = label.pop("keypoints", None) + bbox_format = label.pop("bbox_format") + normalized = label.pop("normalized") + + # NOTE: do NOT resample oriented boxes + segment_resamples = 100 if self.use_obb else 1000 + if len(segments) > 0: + # make sure segments interpolate correctly if original length is greater than segment_resamples + max_len = max(len(s) for s in segments) + segment_resamples = (max_len + 1) if segment_resamples < max_len else segment_resamples + # list[np.array(segment_resamples, 2)] * num_samples + segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0) + else: + segments = np.zeros((0, segment_resamples, 2), dtype=np.float32) + label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized) + return label + + @staticmethod + def collate_fn(batch): + """Collates data samples into batches.""" + new_batch = {} + keys = batch[0].keys() + values = list(zip(*[list(b.values()) for b in batch])) + for i, k in enumerate(keys): + value = values[i] + if k == "img": + value = torch.stack(value, 0) + if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}: + value = torch.cat(value, 0) + new_batch[k] = value + new_batch["batch_idx"] = list(new_batch["batch_idx"]) + for i in range(len(new_batch["batch_idx"])): + new_batch["batch_idx"][i] += i # add target image index for build_targets() + new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0) + return new_batch + + +class YOLOMultiModalDataset(YOLODataset): + """ + Dataset class for loading object detection and/or segmentation labels in YOLO format. + + Args: + data (dict, optional): A dataset YAML dictionary. Defaults to None. + task (str): An explicit arg to point current task, Defaults to 'detect'. + + Returns: + (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model. + """ + + def __init__(self, *args, data=None, task="detect", **kwargs): + """Initializes a dataset object for object detection tasks with optional specifications.""" + super().__init__(*args, data=data, task=task, **kwargs) + + def update_labels_info(self, label): + """Add texts information for multi-modal model training.""" + labels = super().update_labels_info(label) + # NOTE: some categories are concatenated with its synonyms by `/`. + labels["texts"] = [v.split("/") for _, v in self.data["names"].items()] + return labels + + def build_transforms(self, hyp=None): + """Enhances data transformations with optional text augmentation for multi-modal training.""" + transforms = super().build_transforms(hyp) + if self.augment: + # NOTE: hard-coded the args for now. + transforms.insert(-1, RandomLoadText(max_samples=min(self.data["nc"], 80), padding=True)) + return transforms + + +class GroundingDataset(YOLODataset): + """Handles object detection tasks by loading annotations from a specified JSON file, supporting YOLO format.""" + + def __init__(self, *args, task="detect", json_file, **kwargs): + """Initializes a GroundingDataset for object detection, loading annotations from a specified JSON file.""" + assert task == "detect", "`GroundingDataset` only support `detect` task for now!" 
+ self.json_file = json_file + super().__init__(*args, task=task, data={}, **kwargs) + + def get_img_files(self, img_path): + """The image files would be read in `get_labels` function, return empty list here.""" + return [] + + def get_labels(self): + """Loads annotations from a JSON file, filters, and normalizes bounding boxes for each image.""" + labels = [] + LOGGER.info("Loading annotation file...") + with open(self.json_file) as f: + annotations = json.load(f) + images = {f"{x['id']:d}": x for x in annotations["images"]} + img_to_anns = defaultdict(list) + for ann in annotations["annotations"]: + img_to_anns[ann["image_id"]].append(ann) + for img_id, anns in TQDM(img_to_anns.items(), desc=f"Reading annotations {self.json_file}"): + img = images[f"{img_id:d}"] + h, w, f = img["height"], img["width"], img["file_name"] + im_file = Path(self.img_path) / f + if not im_file.exists(): + continue + self.im_files.append(str(im_file)) + bboxes = [] + cat2id = {} + texts = [] + for ann in anns: + if ann["iscrowd"]: + continue + box = np.array(ann["bbox"], dtype=np.float32) + box[:2] += box[2:] / 2 + box[[0, 2]] /= float(w) + box[[1, 3]] /= float(h) + if box[2] <= 0 or box[3] <= 0: + continue + + caption = img["caption"] + cat_name = " ".join([caption[t[0] : t[1]] for t in ann["tokens_positive"]]) + if cat_name not in cat2id: + cat2id[cat_name] = len(cat2id) + texts.append([cat_name]) + cls = cat2id[cat_name] # class + box = [cls] + box.tolist() + if box not in bboxes: + bboxes.append(box) + lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32) + labels.append( + { + "im_file": im_file, + "shape": (h, w), + "cls": lb[:, 0:1], # n, 1 + "bboxes": lb[:, 1:], # n, 4 + "normalized": True, + "bbox_format": "xywh", + "texts": texts, + } + ) + return labels + + def build_transforms(self, hyp=None): + """Configures augmentations for training with optional text loading; `hyp` adjusts augmentation intensity.""" + transforms = super().build_transforms(hyp) + if self.augment: + # NOTE: hard-coded the args for now. + transforms.insert(-1, RandomLoadText(max_samples=80, padding=True)) + return transforms + + +class YOLOConcatDataset(ConcatDataset): + """ + Dataset as a concatenation of multiple datasets. + + This class is useful to assemble different existing datasets. + """ + + @staticmethod + def collate_fn(batch): + """Collates data samples into batches.""" + return YOLODataset.collate_fn(batch) + + +# TODO: support semantic segmentation +class SemanticDataset(BaseDataset): + """ + Semantic Segmentation Dataset. + + This class is responsible for handling datasets used for semantic segmentation tasks. It inherits functionalities + from the BaseDataset class. + + Note: + This class is currently a placeholder and needs to be populated with methods and attributes for supporting + semantic segmentation tasks. + """ + + def __init__(self): + """Initialize a SemanticDataset object.""" + super().__init__() + + +class ClassificationDataset: + """ + Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like image + augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep + learning models, with optional image transformations and caching mechanisms to speed up training. + + This class allows for augmentations using both torchvision and Albumentations libraries, and supports caching images + in RAM or on disk to reduce IO overhead during training. 
Additionally, it implements a robust verification process + to ensure data integrity and consistency. + + Attributes: + cache_ram (bool): Indicates if caching in RAM is enabled. + cache_disk (bool): Indicates if caching on disk is enabled. + samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache + file (if caching on disk), and optionally the loaded image array (if caching in RAM). + torch_transforms (callable): PyTorch transforms to be applied to the images. + """ + + def __init__(self, root, args, augment=False, prefix=""): + """ + Initialize YOLO object with root, image size, augmentations, and cache settings. + + Args: + root (str): Path to the dataset directory where images are stored in a class-specific folder structure. + args (Namespace): Configuration containing dataset-related settings such as image size, augmentation + parameters, and cache settings. It includes attributes like `imgsz` (image size), `fraction` (fraction + of data to use), `scale`, `fliplr`, `flipud`, `cache` (disk or RAM caching for faster training), + `auto_augment`, `hsv_h`, `hsv_s`, `hsv_v`, and `crop_fraction`. + augment (bool, optional): Whether to apply augmentations to the dataset. Default is False. + prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and + debugging. Default is an empty string. + """ + import torchvision # scope for faster 'import ultralytics' + + # Base class assigned as attribute rather than used as base class to allow for scoping slow torchvision import + if TORCHVISION_0_18: # 'allow_empty' argument first introduced in torchvision 0.18 + self.base = torchvision.datasets.ImageFolder(root=root, allow_empty=True) + else: + self.base = torchvision.datasets.ImageFolder(root=root) + self.samples = self.base.samples + self.root = self.base.root + + # Initialize attributes + if augment and args.fraction < 1.0: # reduce training fraction + self.samples = self.samples[: round(len(self.samples) * args.fraction)] + self.prefix = colorstr(f"{prefix}: ") if prefix else "" + self.cache_ram = args.cache is True or str(args.cache).lower() == "ram" # cache images into RAM + if self.cache_ram: + LOGGER.warning( + "WARNING ⚠️ Classification `cache_ram` training has known memory leak in " + "https://github.com/ultralytics/ultralytics/issues/9824, setting `cache_ram=False`." 
+ ) + self.cache_ram = False + self.cache_disk = str(args.cache).lower() == "disk" # cache images on hard drive as uncompressed *.npy files + self.samples = self.verify_images() # filter out bad images + self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im + scale = (1.0 - args.scale, 1.0) # (0.08, 1.0) + self.torch_transforms = ( + classify_augmentations( + size=args.imgsz, + scale=scale, + hflip=args.fliplr, + vflip=args.flipud, + erasing=args.erasing, + auto_augment=args.auto_augment, + hsv_h=args.hsv_h, + hsv_s=args.hsv_s, + hsv_v=args.hsv_v, + ) + if augment + else classify_transforms(size=args.imgsz, crop_fraction=args.crop_fraction) + ) + + def __getitem__(self, i): + """Returns subset of data and targets corresponding to given indices.""" + f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image + if self.cache_ram: + if im is None: # Warning: two separate if statements required here, do not combine this with previous line + im = self.samples[i][3] = cv2.imread(f) + elif self.cache_disk: + if not fn.exists(): # load npy + np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False) + im = np.load(fn) + else: # read image + im = cv2.imread(f) # BGR + # Convert NumPy array to PIL image + im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB)) + sample = self.torch_transforms(im) + return {"img": sample, "cls": j} + + def __len__(self) -> int: + """Return the total number of samples in the dataset.""" + return len(self.samples) + + def verify_images(self): + """Verify all images in dataset.""" + desc = f"{self.prefix}Scanning {self.root}..." + path = Path(self.root).with_suffix(".cache") # *.cache file path + + try: + cache = load_dataset_cache_file(path) # attempt to load a *.cache file + assert cache["version"] == DATASET_CACHE_VERSION # matches current version + assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash + nf, nc, n, samples = cache.pop("results") # found, missing, empty, corrupt, total + if LOCAL_RANK in {-1, 0}: + d = f"{desc} {nf} images, {nc} corrupt" + TQDM(None, desc=d, total=n, initial=n) + if cache["msgs"]: + LOGGER.info("\n".join(cache["msgs"])) # display warnings + return samples + + except (FileNotFoundError, AssertionError, AttributeError): + # Run scan if *.cache retrieval failed + nf, nc, msgs, samples, x = 0, 0, [], [], {} + with ThreadPool(NUM_THREADS) as pool: + results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix))) + pbar = TQDM(results, desc=desc, total=len(self.samples)) + for sample, nf_f, nc_f, msg in pbar: + if nf_f: + samples.append(sample) + if msg: + msgs.append(msg) + nf += nf_f + nc += nc_f + pbar.desc = f"{desc} {nf} images, {nc} corrupt" + pbar.close() + if msgs: + LOGGER.info("\n".join(msgs)) + x["hash"] = get_hash([x[0] for x in self.samples]) + x["results"] = nf, nc, len(samples), samples + x["msgs"] = msgs # warnings + save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION) + return samples diff --git a/ultralytics/data/loaders.py b/ultralytics/data/loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..3a04bb0383d9ea8d21426327b0a1d309dd70f607 --- /dev/null +++ b/ultralytics/data/loaders.py @@ -0,0 +1,658 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import glob +import math +import os +import time +from dataclasses import dataclass +from pathlib import Path +from threading import Thread +from urllib.parse import urlparse + 
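+# Third-party and Ultralytics imports; `requests` is used in this module only for URL sources in `autocast_list`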
+import cv2 +import numpy as np +import requests +import torch +from PIL import Image + +from ultralytics.data.utils import FORMATS_HELP_MSG, IMG_FORMATS, VID_FORMATS +from ultralytics.utils import IS_COLAB, IS_KAGGLE, LOGGER, ops +from ultralytics.utils.checks import check_requirements +from ultralytics.utils.patches import imread + + +@dataclass +class SourceTypes: + """ + Class to represent various types of input sources for predictions. + + This class uses dataclass to define boolean flags for different types of input sources that can be used for + making predictions with YOLO models. + + Attributes: + stream (bool): Flag indicating if the input source is a video stream. + screenshot (bool): Flag indicating if the input source is a screenshot. + from_img (bool): Flag indicating if the input source is an image file. + + Examples: + >>> source_types = SourceTypes(stream=True, screenshot=False, from_img=False) + >>> print(source_types.stream) + True + >>> print(source_types.from_img) + False + """ + + stream: bool = False + screenshot: bool = False + from_img: bool = False + tensor: bool = False + + +class LoadStreams: + """ + Stream Loader for various types of video streams. + + Supports RTSP, RTMP, HTTP, and TCP streams. This class handles the loading and processing of multiple video + streams simultaneously, making it suitable for real-time video analysis tasks. + + Attributes: + sources (List[str]): The source input paths or URLs for the video streams. + vid_stride (int): Video frame-rate stride. + buffer (bool): Whether to buffer input streams. + running (bool): Flag to indicate if the streaming thread is running. + mode (str): Set to 'stream' indicating real-time capture. + imgs (List[List[np.ndarray]]): List of image frames for each stream. + fps (List[float]): List of FPS for each stream. + frames (List[int]): List of total frames for each stream. + threads (List[Thread]): List of threads for each stream. + shape (List[Tuple[int, int, int]]): List of shapes for each stream. + caps (List[cv2.VideoCapture]): List of cv2.VideoCapture objects for each stream. + bs (int): Batch size for processing. + + Methods: + update: Read stream frames in daemon thread. + close: Close stream loader and release resources. + __iter__: Returns an iterator object for the class. + __next__: Returns source paths, transformed, and original images for processing. + __len__: Return the length of the sources object. + + Examples: + >>> stream_loader = LoadStreams("rtsp://example.com/stream1.mp4") + >>> for sources, imgs, _ in stream_loader: + ... # Process the images + ... pass + >>> stream_loader.close() + + Notes: + - The class uses threading to efficiently load frames from multiple streams simultaneously. + - It automatically handles YouTube links, converting them to the best available stream URL. + - The class implements a buffer system to manage frame storage and retrieval. 
+ """ + + def __init__(self, sources="file.streams", vid_stride=1, buffer=False): + """Initialize stream loader for multiple video sources, supporting various stream types.""" + torch.backends.cudnn.benchmark = True # faster for fixed-size inference + self.buffer = buffer # buffer input streams + self.running = True # running flag for Thread + self.mode = "stream" + self.vid_stride = vid_stride # video frame-rate stride + + sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources] + n = len(sources) + self.bs = n + self.fps = [0] * n # frames per second + self.frames = [0] * n + self.threads = [None] * n + self.caps = [None] * n # video capture objects + self.imgs = [[] for _ in range(n)] # images + self.shape = [[] for _ in range(n)] # image shapes + self.sources = [ops.clean_str(x) for x in sources] # clean source names for later + for i, s in enumerate(sources): # index, source + # Start thread to read frames from video stream + st = f"{i + 1}/{n}: {s}... " + if urlparse(s).hostname in {"www.youtube.com", "youtube.com", "youtu.be"}: # if source is YouTube video + # YouTube format i.e. 'https://www.youtube.com/watch?v=Jsn8D3aC840' or 'https://youtu.be/Jsn8D3aC840' + s = get_best_youtube_url(s) + s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam + if s == 0 and (IS_COLAB or IS_KAGGLE): + raise NotImplementedError( + "'source=0' webcam not supported in Colab and Kaggle notebooks. " + "Try running 'source=0' in a local environment." + ) + self.caps[i] = cv2.VideoCapture(s) # store video capture object + if not self.caps[i].isOpened(): + raise ConnectionError(f"{st}Failed to open {s}") + w = int(self.caps[i].get(cv2.CAP_PROP_FRAME_WIDTH)) + h = int(self.caps[i].get(cv2.CAP_PROP_FRAME_HEIGHT)) + fps = self.caps[i].get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan + self.frames[i] = max(int(self.caps[i].get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float( + "inf" + ) # infinite stream fallback + self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback + + success, im = self.caps[i].read() # guarantee first frame + if not success or im is None: + raise ConnectionError(f"{st}Failed to read images from {s}") + self.imgs[i].append(im) + self.shape[i] = im.shape + self.threads[i] = Thread(target=self.update, args=([i, self.caps[i], s]), daemon=True) + LOGGER.info(f"{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)") + self.threads[i].start() + LOGGER.info("") # newline + + def update(self, i, cap, stream): + """Read stream frames in daemon thread and update image buffer.""" + n, f = 0, self.frames[i] # frame number, frame array + while self.running and cap.isOpened() and n < (f - 1): + if len(self.imgs[i]) < 30: # keep a <=30-image buffer + n += 1 + cap.grab() # .read() = .grab() followed by .retrieve() + if n % self.vid_stride == 0: + success, im = cap.retrieve() + if not success: + im = np.zeros(self.shape[i], dtype=np.uint8) + LOGGER.warning("WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.") + cap.open(stream) # re-open stream if signal was lost + if self.buffer: + self.imgs[i].append(im) + else: + self.imgs[i] = [im] + else: + time.sleep(0.01) # wait until the buffer is empty + + def close(self): + """Terminates stream loader, stops threads, and releases video capture resources.""" + self.running = False # stop flag for Thread + for thread in self.threads: + if thread.is_alive(): + thread.join(timeout=5) # Add timeout + for cap in self.caps: # Iterate through the 
stored VideoCapture objects + try: + cap.release() # release video capture + except Exception as e: + LOGGER.warning(f"WARNING ⚠️ Could not release VideoCapture object: {e}") + cv2.destroyAllWindows() + + def __iter__(self): + """Iterates through YOLO image feed and re-opens unresponsive streams.""" + self.count = -1 + return self + + def __next__(self): + """Returns the next batch of frames from multiple video streams for processing.""" + self.count += 1 + + images = [] + for i, x in enumerate(self.imgs): + # Wait until a frame is available in each buffer + while not x: + if not self.threads[i].is_alive() or cv2.waitKey(1) == ord("q"): # q to quit + self.close() + raise StopIteration + time.sleep(1 / min(self.fps)) + x = self.imgs[i] + if not x: + LOGGER.warning(f"WARNING ⚠️ Waiting for stream {i}") + + # Get and remove the first frame from imgs buffer + if self.buffer: + images.append(x.pop(0)) + + # Get the last frame, and clear the rest from the imgs buffer + else: + images.append(x.pop(-1) if x else np.zeros(self.shape[i], dtype=np.uint8)) + x.clear() + + return self.sources, images, [""] * self.bs + + def __len__(self): + """Return the number of video streams in the LoadStreams object.""" + return self.bs # 1E12 frames = 32 streams at 30 FPS for 30 years + + +class LoadScreenshots: + """ + Ultralytics screenshot dataloader for capturing and processing screen images. + + This class manages the loading of screenshot images for processing with YOLO. It is suitable for use with + `yolo predict source=screen`. + + Attributes: + source (str): The source input indicating which screen to capture. + screen (int): The screen number to capture. + left (int): The left coordinate for screen capture area. + top (int): The top coordinate for screen capture area. + width (int): The width of the screen capture area. + height (int): The height of the screen capture area. + mode (str): Set to 'stream' indicating real-time capture. + frame (int): Counter for captured frames. + sct (mss.mss): Screen capture object from `mss` library. + bs (int): Batch size, set to 1. + fps (int): Frames per second, set to 30. + monitor (Dict[str, int]): Monitor configuration details. + + Methods: + __iter__: Returns an iterator object. + __next__: Captures the next screenshot and returns it. + + Examples: + >>> loader = LoadScreenshots("0 100 100 640 480") # screen 0, top-left (100,100), 640x480 + >>> for source, im, im0s, vid_cap, s in loader: + ... 
print(f"Captured frame: {im.shape}") + """ + + def __init__(self, source): + """Initialize screenshot capture with specified screen and region parameters.""" + check_requirements("mss") + import mss # noqa + + source, *params = source.split() + self.screen, left, top, width, height = 0, None, None, None, None # default to full screen 0 + if len(params) == 1: + self.screen = int(params[0]) + elif len(params) == 4: + left, top, width, height = (int(x) for x in params) + elif len(params) == 5: + self.screen, left, top, width, height = (int(x) for x in params) + self.mode = "stream" + self.frame = 0 + self.sct = mss.mss() + self.bs = 1 + self.fps = 30 + + # Parse monitor shape + monitor = self.sct.monitors[self.screen] + self.top = monitor["top"] if top is None else (monitor["top"] + top) + self.left = monitor["left"] if left is None else (monitor["left"] + left) + self.width = width or monitor["width"] + self.height = height or monitor["height"] + self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height} + + def __iter__(self): + """Yields the next screenshot image from the specified screen or region for processing.""" + return self + + def __next__(self): + """Captures and returns the next screenshot as a numpy array using the mss library.""" + im0 = np.asarray(self.sct.grab(self.monitor))[:, :, :3] # BGRA to BGR + s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: " + + self.frame += 1 + return [str(self.screen)], [im0], [s] # screen, img, string + + +class LoadImagesAndVideos: + """ + A class for loading and processing images and videos for YOLO object detection. + + This class manages the loading and pre-processing of image and video data from various sources, including + single image files, video files, and lists of image and video paths. + + Attributes: + files (List[str]): List of image and video file paths. + nf (int): Total number of files (images and videos). + video_flag (List[bool]): Flags indicating whether a file is a video (True) or an image (False). + mode (str): Current mode, 'image' or 'video'. + vid_stride (int): Stride for video frame-rate. + bs (int): Batch size. + cap (cv2.VideoCapture): Video capture object for OpenCV. + frame (int): Frame counter for video. + frames (int): Total number of frames in the video. + count (int): Counter for iteration, initialized at 0 during __iter__(). + ni (int): Number of images. + + Methods: + __init__: Initialize the LoadImagesAndVideos object. + __iter__: Returns an iterator object for VideoStream or ImageFolder. + __next__: Returns the next batch of images or video frames along with their paths and metadata. + _new_video: Creates a new video capture object for the given path. + __len__: Returns the number of batches in the object. + + Examples: + >>> loader = LoadImagesAndVideos("path/to/data", batch=32, vid_stride=1) + >>> for paths, imgs, info in loader: + ... # Process batch of images or video frames + ... pass + + Notes: + - Supports various image formats including HEIC. + - Handles both local files and directories. + - Can read from a text file containing paths to images and videos. 
+ """ + + def __init__(self, path, batch=1, vid_stride=1): + """Initialize dataloader for images and videos, supporting various input formats.""" + parent = None + if isinstance(path, str) and Path(path).suffix == ".txt": # *.txt file with img/vid/dir on each line + parent = Path(path).parent + path = Path(path).read_text().splitlines() # list of sources + files = [] + for p in sorted(path) if isinstance(path, (list, tuple)) else [path]: + a = str(Path(p).absolute()) # do not use .resolve() https://github.com/ultralytics/ultralytics/issues/2912 + if "*" in a: + files.extend(sorted(glob.glob(a, recursive=True))) # glob + elif os.path.isdir(a): + files.extend(sorted(glob.glob(os.path.join(a, "*.*")))) # dir + elif os.path.isfile(a): + files.append(a) # files (absolute or relative to CWD) + elif parent and (parent / p).is_file(): + files.append(str((parent / p).absolute())) # files (relative to *.txt file parent) + else: + raise FileNotFoundError(f"{p} does not exist") + + # Define files as images or videos + images, videos = [], [] + for f in files: + suffix = f.split(".")[-1].lower() # Get file extension without the dot and lowercase + if suffix in IMG_FORMATS: + images.append(f) + elif suffix in VID_FORMATS: + videos.append(f) + ni, nv = len(images), len(videos) + + self.files = images + videos + self.nf = ni + nv # number of files + self.ni = ni # number of images + self.video_flag = [False] * ni + [True] * nv + self.mode = "video" if ni == 0 else "image" # default to video if no images + self.vid_stride = vid_stride # video frame-rate stride + self.bs = batch + if any(videos): + self._new_video(videos[0]) # new video + else: + self.cap = None + if self.nf == 0: + raise FileNotFoundError(f"No images or videos found in {p}. {FORMATS_HELP_MSG}") + + def __iter__(self): + """Iterates through image/video files, yielding source paths, images, and metadata.""" + self.count = 0 + return self + + def __next__(self): + """Returns the next batch of images or video frames with their paths and metadata.""" + paths, imgs, info = [], [], [] + while len(imgs) < self.bs: + if self.count >= self.nf: # end of file list + if imgs: + return paths, imgs, info # return last partial batch + else: + raise StopIteration + + path = self.files[self.count] + if self.video_flag[self.count]: + self.mode = "video" + if not self.cap or not self.cap.isOpened(): + self._new_video(path) + + success = False + for _ in range(self.vid_stride): + success = self.cap.grab() + if not success: + break # end of video or failure + + if success: + success, im0 = self.cap.retrieve() + if success: + self.frame += 1 + paths.append(path) + imgs.append(im0) + info.append(f"video {self.count + 1}/{self.nf} (frame {self.frame}/{self.frames}) {path}: ") + if self.frame == self.frames: # end of video + self.count += 1 + self.cap.release() + else: + # Move to the next file if the current video ended or failed to open + self.count += 1 + if self.cap: + self.cap.release() + if self.count < self.nf: + self._new_video(self.files[self.count]) + else: + # Handle image files (including HEIC) + self.mode = "image" + if path.split(".")[-1].lower() == "heic": + # Load HEIC image using Pillow with pillow-heif + check_requirements("pillow-heif") + + from pillow_heif import register_heif_opener + + register_heif_opener() # Register HEIF opener with Pillow + with Image.open(path) as img: + im0 = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR) # convert image to BGR nparray + else: + im0 = imread(path) # BGR + if im0 is None: + LOGGER.warning(f"WARNING ⚠️ 
Image Read Error {path}") + else: + paths.append(path) + imgs.append(im0) + info.append(f"image {self.count + 1}/{self.nf} {path}: ") + self.count += 1 # move to the next file + if self.count >= self.ni: # end of image list + break + + return paths, imgs, info + + def _new_video(self, path): + """Creates a new video capture object for the given path and initializes video-related attributes.""" + self.frame = 0 + self.cap = cv2.VideoCapture(path) + self.fps = int(self.cap.get(cv2.CAP_PROP_FPS)) + if not self.cap.isOpened(): + raise FileNotFoundError(f"Failed to open video {path}") + self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride) + + def __len__(self): + """Returns the number of files (images and videos) in the dataset.""" + return math.ceil(self.nf / self.bs) # number of batches + + +class LoadPilAndNumpy: + """ + Load images from PIL and Numpy arrays for batch processing. + + This class manages loading and pre-processing of image data from both PIL and Numpy formats. It performs basic + validation and format conversion to ensure that the images are in the required format for downstream processing. + + Attributes: + paths (List[str]): List of image paths or autogenerated filenames. + im0 (List[np.ndarray]): List of images stored as Numpy arrays. + mode (str): Type of data being processed, set to 'image'. + bs (int): Batch size, equivalent to the length of `im0`. + + Methods: + _single_check: Validate and format a single image to a Numpy array. + + Examples: + >>> from PIL import Image + >>> import numpy as np + >>> pil_img = Image.new("RGB", (100, 100)) + >>> np_img = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) + >>> loader = LoadPilAndNumpy([pil_img, np_img]) + >>> paths, images, _ = next(iter(loader)) + >>> print(f"Loaded {len(images)} images") + Loaded 2 images + """ + + def __init__(self, im0): + """Initializes a loader for PIL and Numpy images, converting inputs to a standardized format.""" + if not isinstance(im0, list): + im0 = [im0] + # use `image{i}.jpg` when Image.filename returns an empty path. + self.paths = [getattr(im, "filename", "") or f"image{i}.jpg" for i, im in enumerate(im0)] + self.im0 = [self._single_check(im) for im in im0] + self.mode = "image" + self.bs = len(self.im0) + + @staticmethod + def _single_check(im): + """Validate and format an image to numpy array, ensuring RGB order and contiguous memory.""" + assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}" + if isinstance(im, Image.Image): + if im.mode != "RGB": + im = im.convert("RGB") + im = np.asarray(im)[:, :, ::-1] + im = np.ascontiguousarray(im) # contiguous + return im + + def __len__(self): + """Returns the length of the 'im0' attribute, representing the number of loaded images.""" + return len(self.im0) + + def __next__(self): + """Returns the next batch of images, paths, and metadata for processing.""" + if self.count == 1: # loop only once as it's batch inference + raise StopIteration + self.count += 1 + return self.paths, self.im0, [""] * self.bs + + def __iter__(self): + """Iterates through PIL/numpy images, yielding paths, raw images, and metadata for processing.""" + self.count = 0 + return self + + +class LoadTensor: + """ + A class for loading and processing tensor data for object detection tasks. + + This class handles the loading and pre-processing of image data from PyTorch tensors, preparing them for + further processing in object detection pipelines. 
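+    Tensors are expected to be in BCHW layout with stride-divisible spatial dimensions; values above 1.0 are
+    assumed to be 0-255 and are rescaled to the 0.0-1.0 range with a warning (see `_single_check`).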
+ + Attributes: + im0 (torch.Tensor): The input tensor containing the image(s) with shape (B, C, H, W). + bs (int): Batch size, inferred from the shape of `im0`. + mode (str): Current processing mode, set to 'image'. + paths (List[str]): List of image paths or auto-generated filenames. + + Methods: + _single_check: Validates and formats an input tensor. + + Examples: + >>> import torch + >>> tensor = torch.rand(1, 3, 640, 640) + >>> loader = LoadTensor(tensor) + >>> paths, images, info = next(iter(loader)) + >>> print(f"Processed {len(images)} images") + """ + + def __init__(self, im0) -> None: + """Initialize LoadTensor object for processing torch.Tensor image data.""" + self.im0 = self._single_check(im0) + self.bs = self.im0.shape[0] + self.mode = "image" + self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)] + + @staticmethod + def _single_check(im, stride=32): + """Validates and formats a single image tensor, ensuring correct shape and normalization.""" + s = ( + f"WARNING ⚠️ torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) " + f"divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible." + ) + if len(im.shape) != 4: + if len(im.shape) != 3: + raise ValueError(s) + LOGGER.warning(s) + im = im.unsqueeze(0) + if im.shape[2] % stride or im.shape[3] % stride: + raise ValueError(s) + if im.max() > 1.0 + torch.finfo(im.dtype).eps: # torch.float32 eps is 1.2e-07 + LOGGER.warning( + f"WARNING ⚠️ torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. " + f"Dividing input by 255." + ) + im = im.float() / 255.0 + + return im + + def __iter__(self): + """Yields an iterator object for iterating through tensor image data.""" + self.count = 0 + return self + + def __next__(self): + """Yields the next batch of tensor images and metadata for processing.""" + if self.count == 1: + raise StopIteration + self.count += 1 + return self.paths, self.im0, [""] * self.bs + + def __len__(self): + """Returns the batch size of the tensor input.""" + return self.bs + + +def autocast_list(source): + """Merges a list of sources into a list of numpy arrays or PIL images for Ultralytics prediction.""" + files = [] + for im in source: + if isinstance(im, (str, Path)): # filename or uri + files.append(Image.open(requests.get(im, stream=True).raw if str(im).startswith("http") else im)) + elif isinstance(im, (Image.Image, np.ndarray)): # PIL or np Image + files.append(im) + else: + raise TypeError( + f"type {type(im).__name__} is not a supported Ultralytics prediction source type. \n" + f"See https://docs.ultralytics.com/modes/predict for supported source types." + ) + + return files + + +def get_best_youtube_url(url, method="pytube"): + """ + Retrieves the URL of the best quality MP4 video stream from a given YouTube video. + + Args: + url (str): The URL of the YouTube video. + method (str): The method to use for extracting video info. Options are "pytube", "pafy", and "yt-dlp". + Defaults to "pytube". + + Returns: + (str | None): The URL of the best quality MP4 video stream, or None if no suitable stream is found. + + Examples: + >>> url = "https://www.youtube.com/watch?v=dQw4w9WgXcQ" + >>> best_url = get_best_youtube_url(url) + >>> print(best_url) + https://rr4---sn-q4flrnek.googlevideo.com/videoplayback?expire=... + + Notes: + - Requires additional libraries based on the chosen method: pytubefix, pafy, or yt-dlp. + - The function prioritizes streams with at least 1080p resolution when available. 
+ - For the "yt-dlp" method, it looks for formats with video codec, no audio, and *.mp4 extension. + """ + if method == "pytube": + # Switched from pytube to pytubefix to resolve https://github.com/pytube/pytube/issues/1954 + check_requirements("pytubefix>=6.5.2") + from pytubefix import YouTube + + streams = YouTube(url).streams.filter(file_extension="mp4", only_video=True) + streams = sorted(streams, key=lambda s: s.resolution, reverse=True) # sort streams by resolution + for stream in streams: + if stream.resolution and int(stream.resolution[:-1]) >= 1080: # check if resolution is at least 1080p + return stream.url + + elif method == "pafy": + check_requirements(("pafy", "youtube_dl==2020.12.2")) + import pafy # noqa + + return pafy.new(url).getbestvideo(preftype="mp4").url + + elif method == "yt-dlp": + check_requirements("yt-dlp") + import yt_dlp + + with yt_dlp.YoutubeDL({"quiet": True}) as ydl: + info_dict = ydl.extract_info(url, download=False) # extract info + for f in reversed(info_dict.get("formats", [])): # reversed because best is usually last + # Find a format with video codec, no audio, *.mp4 extension at least 1920x1080 size + good_size = (f.get("width") or 0) >= 1920 or (f.get("height") or 0) >= 1080 + if good_size and f["vcodec"] != "none" and f["acodec"] == "none" and f["ext"] == "mp4": + return f.get("url") + + +# Define constants +LOADERS = (LoadStreams, LoadPilAndNumpy, LoadImagesAndVideos, LoadScreenshots) diff --git a/ultralytics/data/scripts/download_weights.sh b/ultralytics/data/scripts/download_weights.sh new file mode 100644 index 0000000000000000000000000000000000000000..f8a739f6d6199873b156da9bb24bf3c43edcff3d --- /dev/null +++ b/ultralytics/data/scripts/download_weights.sh @@ -0,0 +1,18 @@ +#!/bin/bash +# Ultralytics YOLO 🚀, AGPL-3.0 license +# Download latest models from https://github.com/ultralytics/assets/releases +# Example usage: bash ultralytics/data/scripts/download_weights.sh +# parent +# └── weights +# ├── yolov8n.pt ← downloads here +# ├── yolov8s.pt +# └── ... 
+ +python - < gap, f"invalid crop_size gap pair [{crop_size} {gap}]" + step = crop_size - gap + + xn = 1 if w <= crop_size else ceil((w - crop_size) / step + 1) + xs = [step * i for i in range(xn)] + if len(xs) > 1 and xs[-1] + crop_size > w: + xs[-1] = w - crop_size + + yn = 1 if h <= crop_size else ceil((h - crop_size) / step + 1) + ys = [step * i for i in range(yn)] + if len(ys) > 1 and ys[-1] + crop_size > h: + ys[-1] = h - crop_size + + start = np.array(list(itertools.product(xs, ys)), dtype=np.int64) + stop = start + crop_size + windows.append(np.concatenate([start, stop], axis=1)) + windows = np.concatenate(windows, axis=0) + + im_in_wins = windows.copy() + im_in_wins[:, 0::2] = np.clip(im_in_wins[:, 0::2], 0, w) + im_in_wins[:, 1::2] = np.clip(im_in_wins[:, 1::2], 0, h) + im_areas = (im_in_wins[:, 2] - im_in_wins[:, 0]) * (im_in_wins[:, 3] - im_in_wins[:, 1]) + win_areas = (windows[:, 2] - windows[:, 0]) * (windows[:, 3] - windows[:, 1]) + im_rates = im_areas / win_areas + if not (im_rates > im_rate_thr).any(): + max_rate = im_rates.max() + im_rates[abs(im_rates - max_rate) < eps] = 1 + return windows[im_rates > im_rate_thr] + + +def get_window_obj(anno, windows, iof_thr=0.7): + """Get objects for each window.""" + h, w = anno["ori_size"] + label = anno["label"] + if len(label): + label[:, 1::2] *= w + label[:, 2::2] *= h + iofs = bbox_iof(label[:, 1:], windows) + # Unnormalized and misaligned coordinates + return [(label[iofs[:, i] >= iof_thr]) for i in range(len(windows))] # window_anns + else: + return [np.zeros((0, 9), dtype=np.float32) for _ in range(len(windows))] # window_anns + + +def crop_and_save(anno, windows, window_objs, im_dir, lb_dir, allow_background_images=True): + """ + Crop images and save new labels. + + Args: + anno (dict): Annotation dict, including `filepath`, `label`, `ori_size` as its keys. + windows (list): A list of windows coordinates. + window_objs (list): A list of labels inside each window. + im_dir (str): The output directory path of images. + lb_dir (str): The output directory path of labels. + allow_background_images (bool): Whether to include background images without labels. + + Notes: + The directory structure assumed for the DOTA dataset: + - data_root + - images + - train + - val + - labels + - train + - val + """ + im = cv2.imread(anno["filepath"]) + name = Path(anno["filepath"]).stem + for i, window in enumerate(windows): + x_start, y_start, x_stop, y_stop = window.tolist() + new_name = f"{name}__{x_stop - x_start}__{x_start}___{y_start}" + patch_im = im[y_start:y_stop, x_start:x_stop] + ph, pw = patch_im.shape[:2] + + label = window_objs[i] + if len(label) or allow_background_images: + cv2.imwrite(str(Path(im_dir) / f"{new_name}.jpg"), patch_im) + if len(label): + label[:, 1::2] -= x_start + label[:, 2::2] -= y_start + label[:, 1::2] /= pw + label[:, 2::2] /= ph + + with open(Path(lb_dir) / f"{new_name}.txt", "w") as f: + for lb in label: + formatted_coords = [f"{coord:.6g}" for coord in lb[1:]] + f.write(f"{int(lb[0])} {' '.join(formatted_coords)}\n") + + +def split_images_and_labels(data_root, save_dir, split="train", crop_sizes=(1024,), gaps=(200,)): + """ + Split both images and labels. 
+ + Notes: + The directory structure assumed for the DOTA dataset: + - data_root + - images + - split + - labels + - split + and the output directory structure is: + - save_dir + - images + - split + - labels + - split + """ + im_dir = Path(save_dir) / "images" / split + im_dir.mkdir(parents=True, exist_ok=True) + lb_dir = Path(save_dir) / "labels" / split + lb_dir.mkdir(parents=True, exist_ok=True) + + annos = load_yolo_dota(data_root, split=split) + for anno in tqdm(annos, total=len(annos), desc=split): + windows = get_windows(anno["ori_size"], crop_sizes, gaps) + window_objs = get_window_obj(anno, windows) + crop_and_save(anno, windows, window_objs, str(im_dir), str(lb_dir)) + + +def split_trainval(data_root, save_dir, crop_size=1024, gap=200, rates=(1.0,)): + """ + Split train and val set of DOTA. + + Notes: + The directory structure assumed for the DOTA dataset: + - data_root + - images + - train + - val + - labels + - train + - val + and the output directory structure is: + - save_dir + - images + - train + - val + - labels + - train + - val + """ + crop_sizes, gaps = [], [] + for r in rates: + crop_sizes.append(int(crop_size / r)) + gaps.append(int(gap / r)) + for split in ["train", "val"]: + split_images_and_labels(data_root, save_dir, split, crop_sizes, gaps) + + +def split_test(data_root, save_dir, crop_size=1024, gap=200, rates=(1.0,)): + """ + Split test set of DOTA, labels are not included within this set. + + Notes: + The directory structure assumed for the DOTA dataset: + - data_root + - images + - test + and the output directory structure is: + - save_dir + - images + - test + """ + crop_sizes, gaps = [], [] + for r in rates: + crop_sizes.append(int(crop_size / r)) + gaps.append(int(gap / r)) + save_dir = Path(save_dir) / "images" / "test" + save_dir.mkdir(parents=True, exist_ok=True) + + im_dir = Path(data_root) / "images" / "test" + assert im_dir.exists(), f"Can't find {im_dir}, please check your data root." 
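+    # Each test image is tiled into sliding windows (one crop size per entry in `rates`, consecutive windows
+    # overlapping by the scaled `gap`) and every window is written out as an individual image patch below;
+    # no label files are produced for the test split.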
+ im_files = glob(str(im_dir / "*")) + for im_file in tqdm(im_files, total=len(im_files), desc="test"): + w, h = exif_size(Image.open(im_file)) + windows = get_windows((h, w), crop_sizes=crop_sizes, gaps=gaps) + im = cv2.imread(im_file) + name = Path(im_file).stem + for window in windows: + x_start, y_start, x_stop, y_stop = window.tolist() + new_name = f"{name}__{x_stop - x_start}__{x_start}___{y_start}" + patch_im = im[y_start:y_stop, x_start:x_stop] + cv2.imwrite(str(save_dir / f"{new_name}.jpg"), patch_im) + + +if __name__ == "__main__": + split_trainval(data_root="DOTAv2", save_dir="DOTAv2-split") + split_test(data_root="DOTAv2", save_dir="DOTAv2-split") diff --git a/ultralytics/data/utils.py b/ultralytics/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..50b597d86115ba713f877c5d672460b7e26266e5 --- /dev/null +++ b/ultralytics/data/utils.py @@ -0,0 +1,721 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import hashlib +import json +import os +import random +import subprocess +import time +import zipfile +from multiprocessing.pool import ThreadPool +from pathlib import Path +from tarfile import is_tarfile + +import cv2 +import numpy as np +from PIL import Image, ImageOps + +from ultralytics.nn.autobackend import check_class_names +from ultralytics.utils import ( + DATASETS_DIR, + LOGGER, + NUM_THREADS, + ROOT, + SETTINGS_FILE, + TQDM, + clean_url, + colorstr, + emojis, + is_dir_writeable, + yaml_load, + yaml_save, +) +from ultralytics.utils.checks import check_file, check_font, is_ascii +from ultralytics.utils.downloads import download, safe_download, unzip_file +from ultralytics.utils.ops import segments2boxes + +HELP_URL = "See https://docs.ultralytics.com/datasets for dataset formatting guidance." 
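+# Media suffixes below are matched case-insensitively against file extensions (without the leading dot)
+# by the dataset and loader classes.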
+IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm", "heic"} # image suffixes +VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"} # video suffixes +PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders +FORMATS_HELP_MSG = f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}" + + +def img2label_paths(img_paths): + """Define label paths as a function of image paths.""" + sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}" # /images/, /labels/ substrings + return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths] + + +def get_hash(paths): + """Returns a single hash value of a list of paths (files or dirs).""" + size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes + h = hashlib.sha256(str(size).encode()) # hash sizes + h.update("".join(paths).encode()) # hash paths + return h.hexdigest() # return hash + + +def exif_size(img: Image.Image): + """Returns exif-corrected PIL size.""" + s = img.size # (width, height) + if img.format == "JPEG": # only support JPEG images + try: + if exif := img.getexif(): + rotation = exif.get(274, None) # the EXIF key for the orientation tag is 274 + if rotation in {6, 8}: # rotation 270 or 90 + s = s[1], s[0] + except Exception: + pass + return s + + +def verify_image(args): + """Verify one image.""" + (im_file, cls), prefix = args + # Number (found, corrupt), message + nf, nc, msg = 0, 0, "" + try: + im = Image.open(im_file) + im.verify() # PIL verify + shape = exif_size(im) # image size + shape = (shape[1], shape[0]) # hw + assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels" + assert im.format.lower() in IMG_FORMATS, f"Invalid image format {im.format}. {FORMATS_HELP_MSG}" + if im.format.lower() in {"jpg", "jpeg"}: + with open(im_file, "rb") as f: + f.seek(-2, 2) + if f.read() != b"\xff\xd9": # corrupt JPEG + ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100) + msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved" + nf = 1 + except Exception as e: + nc = 1 + msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}" + return (im_file, cls), nf, nc, msg + + +def verify_image_label(args): + """Verify one image-label pair.""" + im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim = args + # Number (missing, found, empty, corrupt), message, segments, keypoints + nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None + try: + # Verify images + im = Image.open(im_file) + im.verify() # PIL verify + shape = exif_size(im) # image size + shape = (shape[1], shape[0]) # hw + assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels" + assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}. 
{FORMATS_HELP_MSG}" + if im.format.lower() in {"jpg", "jpeg"}: + with open(im_file, "rb") as f: + f.seek(-2, 2) + if f.read() != b"\xff\xd9": # corrupt JPEG + ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100) + msg = f"{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved" + + # Verify labels + if os.path.isfile(lb_file): + nf = 1 # label found + with open(lb_file) as f: + lb = [x.split() for x in f.read().strip().splitlines() if len(x)] + if any(len(x) > 6 for x in lb) and (not keypoint): # is segment + classes = np.array([x[0] for x in lb], dtype=np.float32) + segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb] # (cls, xy1...) + lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh) + lb = np.array(lb, dtype=np.float32) + if nl := len(lb): + if keypoint: + assert lb.shape[1] == (5 + nkpt * ndim), f"labels require {(5 + nkpt * ndim)} columns each" + points = lb[:, 5:].reshape(-1, ndim)[:, :2] + else: + assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected" + points = lb[:, 1:] + assert points.max() <= 1, f"non-normalized or out of bounds coordinates {points[points > 1]}" + assert lb.min() >= 0, f"negative label values {lb[lb < 0]}" + + # All labels + max_cls = lb[:, 0].max() # max label count + assert max_cls <= num_cls, ( + f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. " + f"Possible class labels are 0-{num_cls - 1}" + ) + _, i = np.unique(lb, axis=0, return_index=True) + if len(i) < nl: # duplicate row check + lb = lb[i] # remove duplicates + if segments: + segments = [segments[x] for x in i] + msg = f"{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed" + else: + ne = 1 # label empty + lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32) + else: + nm = 1 # label missing + lb = np.zeros((0, (5 + nkpt * ndim) if keypoints else 5), dtype=np.float32) + if keypoint: + keypoints = lb[:, 5:].reshape(-1, nkpt, ndim) + if ndim == 2: + kpt_mask = np.where((keypoints[..., 0] < 0) | (keypoints[..., 1] < 0), 0.0, 1.0).astype(np.float32) + keypoints = np.concatenate([keypoints, kpt_mask[..., None]], axis=-1) # (nl, nkpt, 3) + lb = lb[:, :5] + return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg + except Exception as e: + nc = 1 + msg = f"{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}" + return [None, None, None, None, None, nm, nf, ne, nc, msg] + + +def visualize_image_annotations(image_path, txt_path, label_map): + """ + Visualizes YOLO annotations (bounding boxes and class labels) on an image. + + This function reads an image and its corresponding annotation file in YOLO format, then + draws bounding boxes around detected objects and labels them with their respective class names. + The bounding box colors are assigned based on the class ID, and the text color is dynamically + adjusted for readability, depending on the background color's luminance. + + Args: + image_path (str): The path to the image file to annotate, and it can be in formats supported by PIL (e.g., .jpg, .png). + txt_path (str): The path to the annotation file in YOLO format, that should contain one line per object with: + - class_id (int): The class index. + - x_center (float): The X center of the bounding box (relative to image width). + - y_center (float): The Y center of the bounding box (relative to image height). + - width (float): The width of the bounding box (relative to image width). 
+ - height (float): The height of the bounding box (relative to image height). + label_map (dict): A dictionary that maps class IDs (integers) to class labels (strings). + + Example: + >>> label_map = {0: "cat", 1: "dog", 2: "bird"} # It should include all annotated classes details + >>> visualize_image_annotations("path/to/image.jpg", "path/to/annotations.txt", label_map) + """ + import matplotlib.pyplot as plt + + from ultralytics.utils.plotting import colors + + img = np.array(Image.open(image_path)) + img_height, img_width = img.shape[:2] + annotations = [] + with open(txt_path) as file: + for line in file: + class_id, x_center, y_center, width, height = map(float, line.split()) + x = (x_center - width / 2) * img_width + y = (y_center - height / 2) * img_height + w = width * img_width + h = height * img_height + annotations.append((x, y, w, h, int(class_id))) + fig, ax = plt.subplots(1) # Plot the image and annotations + for x, y, w, h, label in annotations: + color = tuple(c / 255 for c in colors(label, True)) # Get and normalize the RGB color + rect = plt.Rectangle((x, y), w, h, linewidth=2, edgecolor=color, facecolor="none") # Create a rectangle + ax.add_patch(rect) + luminance = 0.2126 * color[0] + 0.7152 * color[1] + 0.0722 * color[2] # Formula for luminance + ax.text(x, y - 5, label_map[label], color="white" if luminance < 0.5 else "black", backgroundcolor=color) + ax.imshow(img) + plt.show() + + +def polygon2mask(imgsz, polygons, color=1, downsample_ratio=1): + """ + Convert a list of polygons to a binary mask of the specified image size. + + Args: + imgsz (tuple): The size of the image as (height, width). + polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where + N is the number of polygons, and M is the number of points such that M % 2 = 0. + color (int, optional): The color value to fill in the polygons on the mask. Defaults to 1. + downsample_ratio (int, optional): Factor by which to downsample the mask. Defaults to 1. + + Returns: + (np.ndarray): A binary mask of the specified image size with the polygons filled in. + """ + mask = np.zeros(imgsz, dtype=np.uint8) + polygons = np.asarray(polygons, dtype=np.int32) + polygons = polygons.reshape((polygons.shape[0], -1, 2)) + cv2.fillPoly(mask, polygons, color=color) + nh, nw = (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio) + # Note: fillPoly first then resize is trying to keep the same loss calculation method when mask-ratio=1 + return cv2.resize(mask, (nw, nh)) + + +def polygons2masks(imgsz, polygons, color, downsample_ratio=1): + """ + Convert a list of polygons to a set of binary masks of the specified image size. + + Args: + imgsz (tuple): The size of the image as (height, width). + polygons (list[np.ndarray]): A list of polygons. Each polygon is an array with shape [N, M], where + N is the number of polygons, and M is the number of points such that M % 2 = 0. + color (int): The color value to fill in the polygons on the masks. + downsample_ratio (int, optional): Factor by which to downsample each mask. Defaults to 1. + + Returns: + (np.ndarray): A set of binary masks of the specified image size with the polygons filled in. 
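+
+    Example:
+        A minimal sketch with a single triangle polygon (shapes and values are illustrative):
+        ```python
+        import numpy as np
+
+        from ultralytics.data.utils import polygons2masks
+
+        polygons = [np.array([10, 10, 100, 10, 50, 80], dtype=np.float32)]  # one polygon as flat xy pairs
+        masks = polygons2masks((160, 160), polygons, color=1)
+        assert masks.shape == (1, 160, 160)
+        ```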
+ """ + return np.array([polygon2mask(imgsz, [x.reshape(-1)], color, downsample_ratio) for x in polygons]) + + +def polygons2masks_overlap(imgsz, segments, downsample_ratio=1): + """Return a (640, 640) overlap mask.""" + masks = np.zeros( + (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio), + dtype=np.int32 if len(segments) > 255 else np.uint8, + ) + areas = [] + ms = [] + for si in range(len(segments)): + mask = polygon2mask(imgsz, [segments[si].reshape(-1)], downsample_ratio=downsample_ratio, color=1) + ms.append(mask.astype(masks.dtype)) + areas.append(mask.sum()) + areas = np.asarray(areas) + index = np.argsort(-areas) + ms = np.array(ms)[index] + for i in range(len(segments)): + mask = ms[i] * (i + 1) + masks = masks + mask + masks = np.clip(masks, a_min=0, a_max=i + 1) + return masks, index + + +def find_dataset_yaml(path: Path) -> Path: + """ + Find and return the YAML file associated with a Detect, Segment or Pose dataset. + + This function searches for a YAML file at the root level of the provided directory first, and if not found, it + performs a recursive search. It prefers YAML files that have the same stem as the provided path. An AssertionError + is raised if no YAML file is found or if multiple YAML files are found. + + Args: + path (Path): The directory path to search for the YAML file. + + Returns: + (Path): The path of the found YAML file. + """ + files = list(path.glob("*.yaml")) or list(path.rglob("*.yaml")) # try root level first and then recursive + assert files, f"No YAML file found in '{path.resolve()}'" + if len(files) > 1: + files = [f for f in files if f.stem == path.stem] # prefer *.yaml files that match + assert len(files) == 1, f"Expected 1 YAML file in '{path.resolve()}', but found {len(files)}.\n{files}" + return files[0] + + +def check_det_dataset(dataset, autodownload=True): + """ + Download, verify, and/or unzip a dataset if not found locally. + + This function checks the availability of a specified dataset, and if not found, it has the option to download and + unzip the dataset. It then reads and parses the accompanying YAML data, ensuring key requirements are met and also + resolves paths related to the dataset. + + Args: + dataset (str): Path to the dataset or dataset descriptor (like a YAML file). + autodownload (bool, optional): Whether to automatically download the dataset if not found. Defaults to True. + + Returns: + (dict): Parsed dataset information and paths. 
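+
+    Example:
+        A minimal usage sketch (the dataset descriptor name is illustrative):
+        ```python
+        from ultralytics.data.utils import check_det_dataset
+
+        data = check_det_dataset("coco8.yaml")
+        print(data["names"], data["train"], data["val"])
+        ```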
+ """ + file = check_file(dataset) + + # Download (optional) + extract_dir = "" + if zipfile.is_zipfile(file) or is_tarfile(file): + new_dir = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False) + file = find_dataset_yaml(DATASETS_DIR / new_dir) + extract_dir, autodownload = file.parent, False + + # Read YAML + data = yaml_load(file, append_filename=True) # dictionary + + # Checks + for k in "train", "val": + if k not in data: + if k != "val" or "validation" not in data: + raise SyntaxError( + emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs.") + ) + LOGGER.info("WARNING ⚠️ renaming data YAML 'validation' key to 'val' to match YOLO format.") + data["val"] = data.pop("validation") # replace 'validation' key with 'val' key + if "names" not in data and "nc" not in data: + raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs.")) + if "names" in data and "nc" in data and len(data["names"]) != data["nc"]: + raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match.")) + if "names" not in data: + data["names"] = [f"class_{i}" for i in range(data["nc"])] + else: + data["nc"] = len(data["names"]) + + data["names"] = check_class_names(data["names"]) + + # Resolve paths + path = Path(extract_dir or data.get("path") or Path(data.get("yaml_file", "")).parent) # dataset root + if not path.is_absolute(): + path = (DATASETS_DIR / path).resolve() + + # Set paths + data["path"] = path # download scripts + for k in "train", "val", "test", "minival": + if data.get(k): # prepend path + if isinstance(data[k], str): + x = (path / data[k]).resolve() + if not x.exists() and data[k].startswith("../"): + x = (path / data[k][3:]).resolve() + data[k] = str(x) + else: + data[k] = [str((path / x).resolve()) for x in data[k]] + + # Parse YAML + val, s = (data.get(x) for x in ("val", "download")) + if val: + val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path + if not all(x.exists() for x in val): + name = clean_url(dataset) # dataset name with URL auth stripped + m = f"\nDataset '{name}' images not found ⚠️, missing path '{[x for x in val if not x.exists()][0]}'" + if s and autodownload: + LOGGER.warning(m) + else: + m += f"\nNote dataset download directory is '{DATASETS_DIR}'. You can update this in '{SETTINGS_FILE}'" + raise FileNotFoundError(m) + t = time.time() + r = None # success + if s.startswith("http") and s.endswith(".zip"): # URL + safe_download(url=s, dir=DATASETS_DIR, delete=True) + elif s.startswith("bash "): # bash script + LOGGER.info(f"Running {s} ...") + r = os.system(s) + else: # python script + exec(s, {"yaml": data}) + dt = f"({round(time.time() - t, 1)}s)" + s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in {0, None} else f"failure {dt} ❌" + LOGGER.info(f"Dataset download {s}\n") + check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf") # download fonts + + return data # dictionary + + +def check_cls_dataset(dataset, split=""): + """ + Checks a classification dataset such as Imagenet. + + This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information. + If the dataset is not found locally, it attempts to download the dataset from the internet and save it locally. + + Args: + dataset (str | Path): The name of the dataset. + split (str, optional): The split of the dataset. Either 'val', 'test', or ''. Defaults to ''. 
+ + Returns: + (dict): A dictionary containing the following keys: + - 'train' (Path): The directory path containing the training set of the dataset. + - 'val' (Path): The directory path containing the validation set of the dataset. + - 'test' (Path): The directory path containing the test set of the dataset. + - 'nc' (int): The number of classes in the dataset. + - 'names' (dict): A dictionary of class names in the dataset. + """ + # Download (optional if dataset=https://file.zip is passed directly) + if str(dataset).startswith(("http:/", "https:/")): + dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False) + elif Path(dataset).suffix in {".zip", ".tar", ".gz"}: + file = check_file(dataset) + dataset = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False) + + dataset = Path(dataset) + data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve() + if not data_dir.is_dir(): + LOGGER.warning(f"\nDataset not found ⚠️, missing path {data_dir}, attempting download...") + t = time.time() + if str(dataset) == "imagenet": + subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True) + else: + url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{dataset}.zip" + download(url, dir=data_dir.parent) + s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n" + LOGGER.info(s) + train_set = data_dir / "train" + val_set = ( + data_dir / "val" + if (data_dir / "val").exists() + else data_dir / "validation" + if (data_dir / "validation").exists() + else None + ) # data/test or data/val + test_set = data_dir / "test" if (data_dir / "test").exists() else None # data/val or data/test + if split == "val" and not val_set: + LOGGER.warning("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.") + elif split == "test" and not test_set: + LOGGER.warning("WARNING ⚠️ Dataset 'split=test' not found, using 'split=val' instead.") + + nc = len([x for x in (data_dir / "train").glob("*") if x.is_dir()]) # number of classes + names = [x.name for x in (data_dir / "train").iterdir() if x.is_dir()] # class names list + names = dict(enumerate(sorted(names))) + + # Print to console + for k, v in {"train": train_set, "val": val_set, "test": test_set}.items(): + prefix = f"{colorstr(f'{k}:')} {v}..." + if v is None: + LOGGER.info(prefix) + else: + files = [path for path in v.rglob("*.*") if path.suffix[1:].lower() in IMG_FORMATS] + nf = len(files) # number of files + nd = len({file.parent for file in files}) # number of directories + if nf == 0: + if k == "train": + raise FileNotFoundError(emojis(f"{dataset} '{k}:' no training images found ❌ ")) + else: + LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: WARNING ⚠️ no images found") + elif nd != nc: + LOGGER.warning(f"{prefix} found {nf} images in {nd} classes: ERROR ❌️ requires {nc} classes, not {nd}") + else: + LOGGER.info(f"{prefix} found {nf} images in {nd} classes ✅ ") + + return {"train": train_set, "val": val_set, "test": test_set, "nc": nc, "names": names} + + +class HUBDatasetStats: + """ + A class for generating HUB dataset JSON and `-hub` dataset directory. + + Args: + path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip). Default is 'coco8.yaml'. + task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Default is 'detect'. + autodownload (bool): Attempt to download dataset if not found locally. Default is False. 
+ + Example: + Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets + i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip. + ```python + from ultralytics.data.utils import HUBDatasetStats + + stats = HUBDatasetStats("path/to/coco8.zip", task="detect") # detect dataset + stats = HUBDatasetStats("path/to/coco8-seg.zip", task="segment") # segment dataset + stats = HUBDatasetStats("path/to/coco8-pose.zip", task="pose") # pose dataset + stats = HUBDatasetStats("path/to/dota8.zip", task="obb") # OBB dataset + stats = HUBDatasetStats("path/to/imagenet10.zip", task="classify") # classification dataset + + stats.get_json(save=True) + stats.process_images() + ``` + """ + + def __init__(self, path="coco8.yaml", task="detect", autodownload=False): + """Initialize class.""" + path = Path(path).resolve() + LOGGER.info(f"Starting HUB dataset checks for {path}....") + + self.task = task # detect, segment, pose, classify, obb + if self.task == "classify": + unzip_dir = unzip_file(path) + data = check_cls_dataset(unzip_dir) + data["path"] = unzip_dir + else: # detect, segment, pose, obb + _, data_dir, yaml_path = self._unzip(Path(path)) + try: + # Load YAML with checks + data = yaml_load(yaml_path) + data["path"] = "" # strip path since YAML should be in dataset root for all HUB datasets + yaml_save(yaml_path, data) + data = check_det_dataset(yaml_path, autodownload) # dict + data["path"] = data_dir # YAML path should be set to '' (relative) or parent (absolute) + except Exception as e: + raise Exception("error/HUB/dataset_stats/init") from e + + self.hub_dir = Path(f"{data['path']}-hub") + self.im_dir = self.hub_dir / "images" + self.stats = {"nc": len(data["names"]), "names": list(data["names"].values())} # statistics dictionary + self.data = data + + @staticmethod + def _unzip(path): + """Unzip data.zip.""" + if not str(path).endswith(".zip"): # path is data.yaml + return False, None, path + unzip_dir = unzip_file(path, path=path.parent) + assert unzip_dir.is_dir(), ( + f"Error unzipping {path}, {unzip_dir} not found. path/to/abc.zip MUST unzip to path/to/abc/" + ) + return True, str(unzip_dir), find_dataset_yaml(unzip_dir) # zipped, data_dir, yaml_path + + def _hub_ops(self, f): + """Saves a compressed image for HUB previews.""" + compress_one_image(f, self.im_dir / Path(f).name) # save to dataset-hub + + def get_json(self, save=False, verbose=False): + """Return dataset JSON for Ultralytics HUB.""" + + def _round(labels): + """Update labels to integer class and 4 decimal place floats.""" + if self.task == "detect": + coordinates = labels["bboxes"] + elif self.task in {"segment", "obb"}: # Segment and OBB use segments. 
OBB segments are normalized xyxyxyxy + coordinates = [x.flatten() for x in labels["segments"]] + elif self.task == "pose": + n, nk, nd = labels["keypoints"].shape + coordinates = np.concatenate((labels["bboxes"], labels["keypoints"].reshape(n, nk * nd)), 1) + else: + raise ValueError(f"Undefined dataset task={self.task}.") + zipped = zip(labels["cls"], coordinates) + return [[int(c[0]), *(round(float(x), 4) for x in points)] for c, points in zipped] + + for split in "train", "val", "test": + self.stats[split] = None # predefine + path = self.data.get(split) + + # Check split + if path is None: # no split + continue + files = [f for f in Path(path).rglob("*.*") if f.suffix[1:].lower() in IMG_FORMATS] # image files in split + if not files: # no images + continue + + # Get dataset statistics + if self.task == "classify": + from torchvision.datasets import ImageFolder # scope for faster 'import ultralytics' + + dataset = ImageFolder(self.data[split]) + + x = np.zeros(len(dataset.classes)).astype(int) + for im in dataset.imgs: + x[im[1]] += 1 + + self.stats[split] = { + "instance_stats": {"total": len(dataset), "per_class": x.tolist()}, + "image_stats": {"total": len(dataset), "unlabelled": 0, "per_class": x.tolist()}, + "labels": [{Path(k).name: v} for k, v in dataset.imgs], + } + else: + from ultralytics.data import YOLODataset + + dataset = YOLODataset(img_path=self.data[split], data=self.data, task=self.task) + x = np.array( + [ + np.bincount(label["cls"].astype(int).flatten(), minlength=self.data["nc"]) + for label in TQDM(dataset.labels, total=len(dataset), desc="Statistics") + ] + ) # shape(128x80) + self.stats[split] = { + "instance_stats": {"total": int(x.sum()), "per_class": x.sum(0).tolist()}, + "image_stats": { + "total": len(dataset), + "unlabelled": int(np.all(x == 0, 1).sum()), + "per_class": (x > 0).sum(0).tolist(), + }, + "labels": [{Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)], + } + + # Save, print and return + if save: + self.hub_dir.mkdir(parents=True, exist_ok=True) # makes dataset-hub/ + stats_path = self.hub_dir / "stats.json" + LOGGER.info(f"Saving {stats_path.resolve()}...") + with open(stats_path, "w") as f: + json.dump(self.stats, f) # save stats.json + if verbose: + LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False)) + return self.stats + + def process_images(self): + """Compress images for Ultralytics HUB.""" + from ultralytics.data import YOLODataset # ClassificationDataset + + self.im_dir.mkdir(parents=True, exist_ok=True) # makes dataset-hub/images/ + for split in "train", "val", "test": + if self.data.get(split) is None: + continue + dataset = YOLODataset(img_path=self.data[split], data=self.data) + with ThreadPool(NUM_THREADS) as pool: + for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f"{split} images"): + pass + LOGGER.info(f"Done. All images saved to {self.im_dir}") + return self.im_dir + + +def compress_one_image(f, f_new=None, max_dim=1920, quality=50): + """ + Compresses a single image file to reduced size while preserving its aspect ratio and quality using either the Python + Imaging Library (PIL) or OpenCV library. If the input image is smaller than the maximum dimension, it will not be + resized. + + Args: + f (str): The path to the input image file. + f_new (str, optional): The path to the output image file. If not specified, the input file will be overwritten. + max_dim (int, optional): The maximum dimension (width or height) of the output image. Default is 1920 pixels. 
+ quality (int, optional): The image compression quality as a percentage. Default is 50%. + + Example: + ```python + from pathlib import Path + from ultralytics.data.utils import compress_one_image + + for f in Path("path/to/dataset").rglob("*.jpg"): + compress_one_image(f) + ``` + """ + try: # use PIL + im = Image.open(f) + r = max_dim / max(im.height, im.width) # ratio + if r < 1.0: # image too large + im = im.resize((int(im.width * r), int(im.height * r))) + im.save(f_new or f, "JPEG", quality=quality, optimize=True) # save + except Exception as e: # use OpenCV + LOGGER.info(f"WARNING ⚠️ HUB ops PIL failure {f}: {e}") + im = cv2.imread(f) + im_height, im_width = im.shape[:2] + r = max_dim / max(im_height, im_width) # ratio + if r < 1.0: # image too large + im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA) + cv2.imwrite(str(f_new or f), im) + + +def autosplit(path=DATASETS_DIR / "coco8/images", weights=(0.9, 0.1, 0.0), annotated_only=False): + """ + Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files. + + Args: + path (Path, optional): Path to images directory. Defaults to DATASETS_DIR / 'coco8/images'. + weights (list | tuple, optional): Train, validation, and test split fractions. Defaults to (0.9, 0.1, 0.0). + annotated_only (bool, optional): If True, only images with an associated txt file are used. Defaults to False. + + Example: + ```python + from ultralytics.data.utils import autosplit + + autosplit() + ``` + """ + path = Path(path) # images dir + files = sorted(x for x in path.rglob("*.*") if x.suffix[1:].lower() in IMG_FORMATS) # image files only + n = len(files) # number of files + random.seed(0) # for reproducibility + indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split + + txt = ["autosplit_train.txt", "autosplit_val.txt", "autosplit_test.txt"] # 3 txt files + for x in txt: + if (path.parent / x).exists(): + (path.parent / x).unlink() # remove existing + + LOGGER.info(f"Autosplitting images from {path}" + ", using *.txt labeled images only" * annotated_only) + for i, img in TQDM(zip(indices, files), total=n): + if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label + with open(path.parent / txt[i], "a") as f: + f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n") # add image to txt file + + +def load_dataset_cache_file(path): + """Load an Ultralytics *.cache dictionary from path.""" + import gc + + gc.disable() # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585 + cache = np.load(str(path), allow_pickle=True).item() # load dict + gc.enable() + return cache + + +def save_dataset_cache_file(prefix, path, x, version): + """Save an Ultralytics dataset *.cache dictionary x to path.""" + x["version"] = version # add cache version + if is_dir_writeable(path.parent): + if path.exists(): + path.unlink() # remove *.cache file if exists + np.save(str(path), x) # save cache for next time + path.with_suffix(".cache.npy").rename(path) # remove .npy suffix + LOGGER.info(f"{prefix}New cache created: {path}") + else: + LOGGER.warning(f"{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.") diff --git a/ultralytics/engine/__init__.py b/ultralytics/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..77a19dcf0f8093de419453747db2e7e719f96349 --- /dev/null +++ b/ultralytics/engine/__init__.py @@ -0,0 +1 @@ +# Ultralytics 🚀 
AGPL-3.0 License - https://ultralytics.com/license diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py new file mode 100644 index 0000000000000000000000000000000000000000..37c994ef21e0bbb45ec65435c306466210b72e59 --- /dev/null +++ b/ultralytics/engine/exporter.py @@ -0,0 +1,1476 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +""" +Export a YOLO PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit. + +Format | `format=argument` | Model +--- | --- | --- +PyTorch | - | yolo11n.pt +TorchScript | `torchscript` | yolo11n.torchscript +ONNX | `onnx` | yolo11n.onnx +OpenVINO | `openvino` | yolo11n_openvino_model/ +TensorRT | `engine` | yolo11n.engine +CoreML | `coreml` | yolo11n.mlpackage +TensorFlow SavedModel | `saved_model` | yolo11n_saved_model/ +TensorFlow GraphDef | `pb` | yolo11n.pb +TensorFlow Lite | `tflite` | yolo11n.tflite +TensorFlow Edge TPU | `edgetpu` | yolo11n_edgetpu.tflite +TensorFlow.js | `tfjs` | yolo11n_web_model/ +PaddlePaddle | `paddle` | yolo11n_paddle_model/ +MNN | `mnn` | yolo11n.mnn +NCNN | `ncnn` | yolo11n_ncnn_model/ +IMX | `imx` | yolo11n_imx_model/ + +Requirements: + $ pip install "ultralytics[export]" + +Python: + from ultralytics import YOLO + model = YOLO('yolo11n.pt') + results = model.export(format='onnx') + +CLI: + $ yolo mode=export model=yolo11n.pt format=onnx + +Inference: + $ yolo predict model=yolo11n.pt # PyTorch + yolo11n.torchscript # TorchScript + yolo11n.onnx # ONNX Runtime or OpenCV DNN with dnn=True + yolo11n_openvino_model # OpenVINO + yolo11n.engine # TensorRT + yolo11n.mlpackage # CoreML (macOS-only) + yolo11n_saved_model # TensorFlow SavedModel + yolo11n.pb # TensorFlow GraphDef + yolo11n.tflite # TensorFlow Lite + yolo11n_edgetpu.tflite # TensorFlow Edge TPU + yolo11n_paddle_model # PaddlePaddle + yolo11n.mnn # MNN + yolo11n_ncnn_model # NCNN + yolo11n_imx_model # IMX + +TensorFlow.js: + $ cd .. 
&& git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example + $ npm install + $ ln -s ../../yolo11n_web_model public/yolo11n_web_model + $ npm start +""" + +import gc +import json +import os +import shutil +import subprocess +import time +import warnings +from copy import deepcopy +from datetime import datetime +from pathlib import Path + +import numpy as np +import torch + +from ultralytics.cfg import TASK2DATA, get_cfg +from ultralytics.data import build_dataloader +from ultralytics.data.dataset import YOLODataset +from ultralytics.data.utils import check_cls_dataset, check_det_dataset +from ultralytics.nn.autobackend import check_class_names, default_class_names +from ultralytics.nn.modules import C2f, Classify, Detect, RTDETRDecoder +from ultralytics.nn.tasks import DetectionModel, SegmentationModel, WorldModel +from ultralytics.utils import ( + ARM64, + DEFAULT_CFG, + IS_JETSON, + LINUX, + LOGGER, + MACOS, + PYTHON_VERSION, + ROOT, + WINDOWS, + __version__, + callbacks, + colorstr, + get_default_args, + yaml_save, +) +from ultralytics.utils.checks import ( + check_imgsz, + check_is_path_safe, + check_requirements, + check_version, + is_sudo_available, +) +from ultralytics.utils.downloads import attempt_download_asset, get_github_assets, safe_download +from ultralytics.utils.files import file_size, spaces_in_path +from ultralytics.utils.ops import Profile +from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_device + + +def export_formats(): + """Ultralytics YOLO export formats.""" + x = [ + ["PyTorch", "-", ".pt", True, True, []], + ["TorchScript", "torchscript", ".torchscript", True, True, ["batch", "optimize"]], + ["ONNX", "onnx", ".onnx", True, True, ["batch", "dynamic", "half", "opset", "simplify"]], + ["OpenVINO", "openvino", "_openvino_model", True, False, ["batch", "dynamic", "half", "int8"]], + ["TensorRT", "engine", ".engine", False, True, ["batch", "dynamic", "half", "int8", "simplify"]], + ["CoreML", "coreml", ".mlpackage", True, False, ["batch", "half", "int8", "nms"]], + ["TensorFlow SavedModel", "saved_model", "_saved_model", True, True, ["batch", "int8", "keras"]], + ["TensorFlow GraphDef", "pb", ".pb", True, True, ["batch"]], + ["TensorFlow Lite", "tflite", ".tflite", True, False, ["batch", "half", "int8"]], + ["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False, []], + ["TensorFlow.js", "tfjs", "_web_model", True, False, ["batch", "half", "int8"]], + ["PaddlePaddle", "paddle", "_paddle_model", True, True, ["batch"]], + ["MNN", "mnn", ".mnn", True, True, ["batch", "half", "int8"]], + ["NCNN", "ncnn", "_ncnn_model", True, True, ["batch", "half"]], + ["IMX", "imx", "_imx_model", True, True, ["int8"]], + ] + return dict(zip(["Format", "Argument", "Suffix", "CPU", "GPU", "Arguments"], zip(*x))) + + +def validate_args(format, passed_args, valid_args): + """ + Validates arguments based on format. + + Args: + format (str): The export format. + passed_args (Namespace): The arguments used during export. + valid_args (dict): List of valid arguments for the format. + + Raises: + AssertionError: If an argument that's not supported by the export format is used, or if format doesn't have the supported arguments listed. + """ + # Only check valid usage of these args + export_args = ["half", "int8", "dynamic", "keras", "nms", "batch"] + + assert valid_args is not None, f"ERROR ❌️ valid arguments for '{format}' not listed." 
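+    # Only arguments that differ from the exporter defaults are validated: each restricted arg above is
+    # compared against the defaults built below, and any non-default value must appear in valid_args,
+    # e.g. half=True with format="pb" raises an AssertionError since "half" is not listed for "pb".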
+ custom = {"batch": 1, "data": None, "device": None} # exporter defaults + default_args = get_cfg(DEFAULT_CFG, custom) + for arg in export_args: + not_default = getattr(passed_args, arg, None) != getattr(default_args, arg, None) + if not_default: + assert arg in valid_args, f"ERROR ❌️ argument '{arg}' is not supported for format='{format}'" + + +def gd_outputs(gd): + """TensorFlow GraphDef model output node names.""" + name_list, input_list = [], [] + for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef + name_list.append(node.name) + input_list.extend(node.input) + return sorted(f"{x}:0" for x in list(set(name_list) - set(input_list)) if not x.startswith("NoOp")) + + +def try_export(inner_func): + """YOLO export decorator, i.e. @try_export.""" + inner_args = get_default_args(inner_func) + + def outer_func(*args, **kwargs): + """Export a model.""" + prefix = inner_args["prefix"] + try: + with Profile() as dt: + f, model = inner_func(*args, **kwargs) + LOGGER.info(f"{prefix} export success ✅ {dt.t:.1f}s, saved as '{f}' ({file_size(f):.1f} MB)") + return f, model + except Exception as e: + LOGGER.error(f"{prefix} export failure ❌ {dt.t:.1f}s: {e}") + raise e + + return outer_func + + +class Exporter: + """ + A class for exporting a model. + + Attributes: + args (SimpleNamespace): Configuration for the exporter. + callbacks (list, optional): List of callback functions. Defaults to None. + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """ + Initializes the Exporter class. + + Args: + cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG. + overrides (dict, optional): Configuration overrides. Defaults to None. + _callbacks (dict, optional): Dictionary of callback functions. Defaults to None. + """ + self.args = get_cfg(cfg, overrides) + if self.args.format.lower() in {"coreml", "mlmodel"}: # fix attempt for protobuf<3.20.x errors + os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" # must run before TensorBoard callback + + self.callbacks = _callbacks or callbacks.get_default_callbacks() + callbacks.add_integration_callbacks(self) + + def __call__(self, model=None) -> str: + """Returns list of exported files/dirs after running callbacks.""" + self.run_callbacks("on_export_start") + t = time.time() + fmt = self.args.format.lower() # to lowercase + if fmt in {"tensorrt", "trt"}: # 'engine' aliases + fmt = "engine" + if fmt in {"mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"}: # 'coreml' aliases + fmt = "coreml" + fmts_dict = export_formats() + fmts = tuple(fmts_dict["Argument"][1:]) # available export formats + if fmt not in fmts: + import difflib + + # Get the closest match if format is invalid + matches = difflib.get_close_matches(fmt, fmts, n=1, cutoff=0.6) # 60% similarity required to match + if not matches: + raise ValueError(f"Invalid export format='{fmt}'. Valid formats are {fmts}") + LOGGER.warning(f"WARNING ⚠️ Invalid export format='{fmt}', updating to format='{matches[0]}'") + fmt = matches[0] + flags = [x == fmt for x in fmts] + if sum(flags) != 1: + raise ValueError(f"Invalid export format='{fmt}'. 
Valid formats are {fmts}") + ( + jit, + onnx, + xml, + engine, + coreml, + saved_model, + pb, + tflite, + edgetpu, + tfjs, + paddle, + mnn, + ncnn, + imx, + ) = flags # export booleans + is_tf_format = any((saved_model, pb, tflite, edgetpu, tfjs)) + + # Device + dla = None + if fmt == "engine" and self.args.device is None: + LOGGER.warning("WARNING ⚠️ TensorRT requires GPU export, automatically assigning device=0") + self.args.device = "0" + if fmt == "engine" and "dla" in str(self.args.device): # convert int/list to str first + dla = self.args.device.split(":")[-1] + self.args.device = "0" # update device to "0" + assert dla in {"0", "1"}, f"Expected self.args.device='dla:0' or 'dla:1, but got {self.args.device}." + self.device = select_device("cpu" if self.args.device is None else self.args.device) + + # Argument compatibility checks + fmt_keys = fmts_dict["Arguments"][flags.index(True) + 1] + validate_args(fmt, self.args, fmt_keys) + if imx and not self.args.int8: + LOGGER.warning("WARNING ⚠️ IMX only supports int8 export, setting int8=True.") + self.args.int8 = True + if not hasattr(model, "names"): + model.names = default_class_names() + model.names = check_class_names(model.names) + if self.args.half and self.args.int8: + LOGGER.warning("WARNING ⚠️ half=True and int8=True are mutually exclusive, setting half=False.") + self.args.half = False + if self.args.half and onnx and self.device.type == "cpu": + LOGGER.warning("WARNING ⚠️ half=True only compatible with GPU export, i.e. use device=0") + self.args.half = False + assert not self.args.dynamic, "half=True not compatible with dynamic=True, i.e. use only one." + self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size + if self.args.int8 and engine: + self.args.dynamic = True # enforce dynamic to export TensorRT INT8 + if self.args.optimize: + assert not ncnn, "optimize=True not compatible with format='ncnn', i.e. use optimize=False" + assert self.device.type == "cpu", "optimize=True not compatible with cuda devices, i.e. use device='cpu'" + if self.args.int8 and tflite: + assert not getattr(model, "end2end", False), "TFLite INT8 export not supported for end2end models." + if edgetpu: + if not LINUX: + raise SystemError("Edge TPU export only supported on Linux. See https://coral.ai/docs/edgetpu/compiler") + elif self.args.batch != 1: # see github.com/ultralytics/ultralytics/pull/13420 + LOGGER.warning("WARNING ⚠️ Edge TPU export requires batch size 1, setting batch=1.") + self.args.batch = 1 + if isinstance(model, WorldModel): + LOGGER.warning( + "WARNING ⚠️ YOLOWorld (original version) export is not supported to any format.\n" + "WARNING ⚠️ YOLOWorldv2 models (i.e. 'yolov8s-worldv2.pt') only support export to " + "(torchscript, onnx, openvino, engine, coreml) formats. " + "See https://docs.ultralytics.com/models/yolo-world for details." + ) + model.clip_model = None # openvino int8 export error: https://github.com/ultralytics/ultralytics/pull/18445 + if self.args.int8 and not self.args.data: + self.args.data = DEFAULT_CFG.data or TASK2DATA[getattr(model, "task", "detect")] # assign default data + LOGGER.warning( + "WARNING ⚠️ INT8 export requires a missing 'data' arg for calibration. " + f"Using default 'data={self.args.data}'." 
+ ) + + # Input + im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device) + file = Path( + getattr(model, "pt_path", None) or getattr(model, "yaml_file", None) or model.yaml.get("yaml_file", "") + ) + if file.suffix in {".yaml", ".yml"}: + file = Path(file.name) + + # Update model + model = deepcopy(model).to(self.device) + for p in model.parameters(): + p.requires_grad = False + model.eval() + model.float() + model = model.fuse() + + if imx: + from ultralytics.utils.torch_utils import FXModel + + model = FXModel(model) + for m in model.modules(): + if isinstance(m, Classify): + m.export = True + if isinstance(m, (Detect, RTDETRDecoder)): # includes all Detect subclasses like Segment, Pose, OBB + m.dynamic = self.args.dynamic + m.export = True + m.format = self.args.format + m.max_det = self.args.max_det + elif isinstance(m, C2f) and not is_tf_format: + # EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph + m.forward = m.forward_split + if isinstance(m, Detect) and imx: + from ultralytics.utils.tal import make_anchors + + m.anchors, m.strides = ( + x.transpose(0, 1) + for x in make_anchors( + torch.cat([s / m.stride.unsqueeze(-1) for s in self.imgsz], dim=1), m.stride, 0.5 + ) + ) + + y = None + for _ in range(2): + y = model(im) # dry runs + if self.args.half and onnx and self.device.type != "cpu": + im, model = im.half(), model.half() # to FP16 + + # Filter warnings + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) # suppress TracerWarning + warnings.filterwarnings("ignore", category=UserWarning) # suppress shape prim::Constant missing ONNX warning + warnings.filterwarnings("ignore", category=DeprecationWarning) # suppress CoreML np.bool deprecation warning + + # Assign + self.im = im + self.model = model + self.file = file + self.output_shape = ( + tuple(y.shape) + if isinstance(y, torch.Tensor) + else tuple(tuple(x.shape if isinstance(x, torch.Tensor) else []) for x in y) + ) + self.pretty_name = Path(self.model.yaml.get("yaml_file", self.file)).stem.replace("yolo", "YOLO") + data = model.args["data"] if hasattr(model, "args") and isinstance(model.args, dict) else "" + description = f"Ultralytics {self.pretty_name} model {f'trained on {data}' if data else ''}" + self.metadata = { + "description": description, + "author": "Ultralytics", + "date": datetime.now().isoformat(), + "version": __version__, + "license": "AGPL-3.0 License (https://ultralytics.com/license)", + "docs": "https://docs.ultralytics.com", + "stride": int(max(model.stride)), + "task": model.task, + "batch": self.args.batch, + "imgsz": self.imgsz, + "names": model.names, + "args": {k: v for k, v in self.args if k in fmt_keys}, + } # model metadata + if model.task == "pose": + self.metadata["kpt_shape"] = model.model[-1].kpt_shape + + LOGGER.info( + f"\n{colorstr('PyTorch:')} starting from '{file}' with input shape {tuple(im.shape)} BCHW and " + f"output shape(s) {self.output_shape} ({file_size(file):.1f} MB)" + ) + + # Exports + f = [""] * len(fmts) # exported filenames + if jit or ncnn: # TorchScript + f[0], _ = self.export_torchscript() + if engine: # TensorRT required before ONNX + f[1], _ = self.export_engine(dla=dla) + if onnx: # ONNX + f[2], _ = self.export_onnx() + if xml: # OpenVINO + f[3], _ = self.export_openvino() + if coreml: # CoreML + f[4], _ = self.export_coreml() + if is_tf_format: # TensorFlow formats + self.args.int8 |= edgetpu + f[5], keras_model = self.export_saved_model() + if pb or tfjs: # pb prerequisite to tfjs + f[6], _ = 
self.export_pb(keras_model=keras_model) + if tflite: + f[7], _ = self.export_tflite(keras_model=keras_model, nms=False, agnostic_nms=self.args.agnostic_nms) + if edgetpu: + f[8], _ = self.export_edgetpu(tflite_model=Path(f[5]) / f"{self.file.stem}_full_integer_quant.tflite") + if tfjs: + f[9], _ = self.export_tfjs() + if paddle: # PaddlePaddle + f[10], _ = self.export_paddle() + if mnn: # MNN + f[11], _ = self.export_mnn() + if ncnn: # NCNN + f[12], _ = self.export_ncnn() + if imx: + f[13], _ = self.export_imx() + + # Finish + f = [str(x) for x in f if x] # filter out '' and None + if any(f): + f = str(Path(f[-1])) + square = self.imgsz[0] == self.imgsz[1] + s = ( + "" + if square + else f"WARNING ⚠️ non-PyTorch val requires square images, 'imgsz={self.imgsz}' will not " + f"work. Use export 'imgsz={max(self.imgsz)}' if val is required." + ) + imgsz = self.imgsz[0] if square else str(self.imgsz)[1:-1].replace(" ", "") + predict_data = f"data={data}" if model.task == "segment" and fmt == "pb" else "" + q = "int8" if self.args.int8 else "half" if self.args.half else "" # quantization + LOGGER.info( + f"\nExport complete ({time.time() - t:.1f}s)" + f"\nResults saved to {colorstr('bold', file.parent.resolve())}" + f"\nPredict: yolo predict task={model.task} model={f} imgsz={imgsz} {q} {predict_data}" + f"\nValidate: yolo val task={model.task} model={f} imgsz={imgsz} data={data} {q} {s}" + f"\nVisualize: https://netron.app" + ) + + self.run_callbacks("on_export_end") + return f # return list of exported files/dirs + + def get_int8_calibration_dataloader(self, prefix=""): + """Build and return a dataloader suitable for calibration of INT8 models.""" + LOGGER.info(f"{prefix} collecting INT8 calibration images from 'data={self.args.data}'") + data = (check_cls_dataset if self.model.task == "classify" else check_det_dataset)(self.args.data) + # TensorRT INT8 calibration should use 2x batch size + batch = self.args.batch * (2 if self.args.format == "engine" else 1) + dataset = YOLODataset( + data[self.args.split or "val"], + data=data, + task=self.model.task, + imgsz=self.imgsz[0], + augment=False, + batch_size=batch, + ) + n = len(dataset) + if n < self.args.batch: + raise ValueError( + f"The calibration dataset ({n} images) must have at least as many images as the batch size ('batch={self.args.batch}')." 
+ ) + elif n < 300: + LOGGER.warning(f"{prefix} WARNING ⚠️ >300 images recommended for INT8 calibration, found {n} images.") + return build_dataloader(dataset, batch=batch, workers=0) # required for batch loading + + @try_export + def export_torchscript(self, prefix=colorstr("TorchScript:")): + """YOLO TorchScript model export.""" + LOGGER.info(f"\n{prefix} starting export with torch {torch.__version__}...") + f = self.file.with_suffix(".torchscript") + + ts = torch.jit.trace(self.model, self.im, strict=False) + extra_files = {"config.txt": json.dumps(self.metadata)} # torch._C.ExtraFilesMap() + if self.args.optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html + LOGGER.info(f"{prefix} optimizing for mobile...") + from torch.utils.mobile_optimizer import optimize_for_mobile + + optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files) + else: + ts.save(str(f), _extra_files=extra_files) + return f, None + + @try_export + def export_onnx(self, prefix=colorstr("ONNX:")): + """YOLO ONNX export.""" + requirements = ["onnx>=1.12.0"] + if self.args.simplify: + requirements += ["onnxslim", "onnxruntime" + ("-gpu" if torch.cuda.is_available() else "")] + check_requirements(requirements) + import onnx # noqa + + opset_version = self.args.opset or get_latest_opset() + LOGGER.info(f"\n{prefix} starting export with onnx {onnx.__version__} opset {opset_version}...") + f = str(self.file.with_suffix(".onnx")) + + output_names = ["output0", "output1"] if isinstance(self.model, SegmentationModel) else ["output0"] + dynamic = self.args.dynamic + if dynamic: + dynamic = {"images": {0: "batch", 2: "height", 3: "width"}} # shape(1,3,640,640) + if isinstance(self.model, SegmentationModel): + dynamic["output0"] = {0: "batch", 2: "anchors"} # shape(1, 116, 8400) + dynamic["output1"] = {0: "batch", 2: "mask_height", 3: "mask_width"} # shape(1,32,160,160) + elif isinstance(self.model, DetectionModel): + dynamic["output0"] = {0: "batch", 2: "anchors"} # shape(1, 84, 8400) + + torch.onnx.export( + self.model.cpu() if dynamic else self.model, # dynamic=True only compatible with cpu + self.im.cpu() if dynamic else self.im, + f, + verbose=False, + opset_version=opset_version, + do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False + input_names=["images"], + output_names=output_names, + dynamic_axes=dynamic or None, + ) + + # Checks + model_onnx = onnx.load(f) # load onnx model + + # Simplify + if self.args.simplify: + try: + import onnxslim + + LOGGER.info(f"{prefix} slimming with onnxslim {onnxslim.__version__}...") + model_onnx = onnxslim.slim(model_onnx) + + except Exception as e: + LOGGER.warning(f"{prefix} simplifier failure: {e}") + + # Metadata + for k, v in self.metadata.items(): + meta = model_onnx.metadata_props.add() + meta.key, meta.value = k, str(v) + + onnx.save(model_onnx, f) + return f, model_onnx + + @try_export + def export_openvino(self, prefix=colorstr("OpenVINO:")): + """YOLO OpenVINO export.""" + check_requirements("openvino>=2024.5.0") + import openvino as ov + + LOGGER.info(f"\n{prefix} starting export with openvino {ov.__version__}...") + assert TORCH_1_13, f"OpenVINO export requires torch>=1.13.0 but torch=={torch.__version__} is installed" + ov_model = ov.convert_model( + self.model, + input=None if self.args.dynamic else [self.im.shape], + example_input=self.im, + ) + + def serialize(ov_model, file): + """Set RT info, serialize and save metadata YAML.""" + ov_model.set_rt_info("YOLO", 
["model_info", "model_type"]) + ov_model.set_rt_info(True, ["model_info", "reverse_input_channels"]) + ov_model.set_rt_info(114, ["model_info", "pad_value"]) + ov_model.set_rt_info([255.0], ["model_info", "scale_values"]) + ov_model.set_rt_info(self.args.iou, ["model_info", "iou_threshold"]) + ov_model.set_rt_info([v.replace(" ", "_") for v in self.model.names.values()], ["model_info", "labels"]) + if self.model.task != "classify": + ov_model.set_rt_info("fit_to_window_letterbox", ["model_info", "resize_type"]) + + ov.runtime.save_model(ov_model, file, compress_to_fp16=self.args.half) + yaml_save(Path(file).parent / "metadata.yaml", self.metadata) # add metadata.yaml + + if self.args.int8: + fq = str(self.file).replace(self.file.suffix, f"_int8_openvino_model{os.sep}") + fq_ov = str(Path(fq) / self.file.with_suffix(".xml").name) + check_requirements("nncf>=2.14.0") + import nncf + + def transform_fn(data_item) -> np.ndarray: + """Quantization transform function.""" + data_item: torch.Tensor = data_item["img"] if isinstance(data_item, dict) else data_item + assert data_item.dtype == torch.uint8, "Input image must be uint8 for the quantization preprocessing" + im = data_item.numpy().astype(np.float32) / 255.0 # uint8 to fp16/32 and 0 - 255 to 0.0 - 1.0 + return np.expand_dims(im, 0) if im.ndim == 3 else im + + # Generate calibration data for integer quantization + ignored_scope = None + if isinstance(self.model.model[-1], Detect): + # Includes all Detect subclasses like Segment, Pose, OBB, WorldDetect + head_module_name = ".".join(list(self.model.named_modules())[-1][0].split(".")[:2]) + ignored_scope = nncf.IgnoredScope( # ignore operations + patterns=[ + f".*{head_module_name}/.*/Add", + f".*{head_module_name}/.*/Sub*", + f".*{head_module_name}/.*/Mul*", + f".*{head_module_name}/.*/Div*", + f".*{head_module_name}\\.dfl.*", + ], + types=["Sigmoid"], + ) + + quantized_ov_model = nncf.quantize( + model=ov_model, + calibration_dataset=nncf.Dataset(self.get_int8_calibration_dataloader(prefix), transform_fn), + preset=nncf.QuantizationPreset.MIXED, + ignored_scope=ignored_scope, + ) + serialize(quantized_ov_model, fq_ov) + return fq, None + + f = str(self.file).replace(self.file.suffix, f"_openvino_model{os.sep}") + f_ov = str(Path(f) / self.file.with_suffix(".xml").name) + + serialize(ov_model, f_ov) + return f, None + + @try_export + def export_paddle(self, prefix=colorstr("PaddlePaddle:")): + """YOLO Paddle export.""" + check_requirements(("paddlepaddle-gpu" if torch.cuda.is_available() else "paddlepaddle", "x2paddle")) + import x2paddle # noqa + from x2paddle.convert import pytorch2paddle # noqa + + LOGGER.info(f"\n{prefix} starting export with X2Paddle {x2paddle.__version__}...") + f = str(self.file).replace(self.file.suffix, f"_paddle_model{os.sep}") + + pytorch2paddle(module=self.model, save_dir=f, jit_type="trace", input_examples=[self.im]) # export + yaml_save(Path(f) / "metadata.yaml", self.metadata) # add metadata.yaml + return f, None + + @try_export + def export_mnn(self, prefix=colorstr("MNN:")): + """YOLOv8 MNN export using MNN https://github.com/alibaba/MNN.""" + f_onnx, _ = self.export_onnx() # get onnx model first + + check_requirements("MNN>=2.9.6") + import MNN # noqa + from MNN.tools import mnnconvert + + # Setup and checks + LOGGER.info(f"\n{prefix} starting export with MNN {MNN.version()}...") + assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}" + f = str(self.file.with_suffix(".mnn")) # MNN model file + args = ["", "-f", "ONNX", "--modelFile", f_onnx, 
"--MNNModel", f, "--bizCode", json.dumps(self.metadata)] + if self.args.int8: + args.extend(("--weightQuantBits", "8")) + if self.args.half: + args.append("--fp16") + mnnconvert.convert(args) + # remove scratch file for model convert optimize + convert_scratch = Path(self.file.parent / ".__convert_external_data.bin") + if convert_scratch.exists(): + convert_scratch.unlink() + return f, None + + @try_export + def export_ncnn(self, prefix=colorstr("NCNN:")): + """YOLO NCNN export using PNNX https://github.com/pnnx/pnnx.""" + check_requirements("ncnn") + import ncnn # noqa + + LOGGER.info(f"\n{prefix} starting export with NCNN {ncnn.__version__}...") + f = Path(str(self.file).replace(self.file.suffix, f"_ncnn_model{os.sep}")) + f_ts = self.file.with_suffix(".torchscript") + + name = Path("pnnx.exe" if WINDOWS else "pnnx") # PNNX filename + pnnx = name if name.is_file() else (ROOT / name) + if not pnnx.is_file(): + LOGGER.warning( + f"{prefix} WARNING ⚠️ PNNX not found. Attempting to download binary file from " + "https://github.com/pnnx/pnnx/.\nNote PNNX Binary file must be placed in current working directory " + f"or in {ROOT}. See PNNX repo for full installation instructions." + ) + system = "macos" if MACOS else "windows" if WINDOWS else "linux-aarch64" if ARM64 else "linux" + try: + release, assets = get_github_assets(repo="pnnx/pnnx") + asset = [x for x in assets if f"{system}.zip" in x][0] + assert isinstance(asset, str), "Unable to retrieve PNNX repo assets" # i.e. pnnx-20240410-macos.zip + LOGGER.info(f"{prefix} successfully found latest PNNX asset file {asset}") + except Exception as e: + release = "20240410" + asset = f"pnnx-{release}-{system}.zip" + LOGGER.warning(f"{prefix} WARNING ⚠️ PNNX GitHub assets not found: {e}, using default {asset}") + unzip_dir = safe_download(f"https://github.com/pnnx/pnnx/releases/download/{release}/{asset}", delete=True) + if check_is_path_safe(Path.cwd(), unzip_dir): # avoid path traversal security vulnerability + shutil.move(src=unzip_dir / name, dst=pnnx) # move binary to ROOT + pnnx.chmod(0o777) # set read, write, and execute permissions for everyone + shutil.rmtree(unzip_dir) # delete unzip dir + + ncnn_args = [ + f"ncnnparam={f / 'model.ncnn.param'}", + f"ncnnbin={f / 'model.ncnn.bin'}", + f"ncnnpy={f / 'model_ncnn.py'}", + ] + + pnnx_args = [ + f"pnnxparam={f / 'model.pnnx.param'}", + f"pnnxbin={f / 'model.pnnx.bin'}", + f"pnnxpy={f / 'model_pnnx.py'}", + f"pnnxonnx={f / 'model.pnnx.onnx'}", + ] + + cmd = [ + str(pnnx), + str(f_ts), + *ncnn_args, + *pnnx_args, + f"fp16={int(self.args.half)}", + f"device={self.device.type}", + f'inputshape="{[self.args.batch, 3, *self.imgsz]}"', + ] + f.mkdir(exist_ok=True) # make ncnn_model directory + LOGGER.info(f"{prefix} running '{' '.join(cmd)}'") + subprocess.run(cmd, check=True) + + # Remove debug files + pnnx_files = [x.split("=")[-1] for x in pnnx_args] + for f_debug in ("debug.bin", "debug.param", "debug2.bin", "debug2.param", *pnnx_files): + Path(f_debug).unlink(missing_ok=True) + + yaml_save(f / "metadata.yaml", self.metadata) # add metadata.yaml + return str(f), None + + @try_export + def export_coreml(self, prefix=colorstr("CoreML:")): + """YOLO CoreML export.""" + mlmodel = self.args.format.lower() == "mlmodel" # legacy *.mlmodel export format requested + check_requirements("coremltools>=6.0,<=6.2" if mlmodel else "coremltools>=7.0") + import coremltools as ct # noqa + + LOGGER.info(f"\n{prefix} starting export with coremltools {ct.__version__}...") + assert not WINDOWS, "CoreML export is not 
supported on Windows, please run on macOS or Linux." + assert self.args.batch == 1, "CoreML batch sizes > 1 are not supported. Please retry at 'batch=1'." + f = self.file.with_suffix(".mlmodel" if mlmodel else ".mlpackage") + if f.is_dir(): + shutil.rmtree(f) + if self.args.nms and getattr(self.model, "end2end", False): + LOGGER.warning(f"{prefix} WARNING ⚠️ 'nms=True' is not available for end2end models. Forcing 'nms=False'.") + self.args.nms = False + + bias = [0.0, 0.0, 0.0] + scale = 1 / 255 + classifier_config = None + if self.model.task == "classify": + classifier_config = ct.ClassifierConfig(list(self.model.names.values())) if self.args.nms else None + model = self.model + elif self.model.task == "detect": + model = IOSDetectModel(self.model, self.im) if self.args.nms else self.model + else: + if self.args.nms: + LOGGER.warning(f"{prefix} WARNING ⚠️ 'nms=True' is only available for Detect models like 'yolov8n.pt'.") + # TODO CoreML Segment and Pose model pipelining + model = self.model + + ts = torch.jit.trace(model.eval(), self.im, strict=False) # TorchScript model + ct_model = ct.convert( + ts, + inputs=[ct.ImageType("image", shape=self.im.shape, scale=scale, bias=bias)], + classifier_config=classifier_config, + convert_to="neuralnetwork" if mlmodel else "mlprogram", + ) + bits, mode = (8, "kmeans") if self.args.int8 else (16, "linear") if self.args.half else (32, None) + if bits < 32: + if "kmeans" in mode: + check_requirements("scikit-learn") # scikit-learn package required for k-means quantization + if mlmodel: + ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode) + elif bits == 8: # mlprogram already quantized to FP16 + import coremltools.optimize.coreml as cto + + op_config = cto.OpPalettizerConfig(mode="kmeans", nbits=bits, weight_threshold=512) + config = cto.OptimizationConfig(global_config=op_config) + ct_model = cto.palettize_weights(ct_model, config=config) + if self.args.nms and self.model.task == "detect": + if mlmodel: + # coremltools<=6.2 NMS export requires Python<3.11 + check_version(PYTHON_VERSION, "<3.11", name="Python ", hard=True) + weights_dir = None + else: + ct_model.save(str(f)) # save otherwise weights_dir does not exist + weights_dir = str(f / "Data/com.apple.CoreML/weights") + ct_model = self._pipeline_coreml(ct_model, weights_dir=weights_dir) + + m = self.metadata # metadata dict + ct_model.short_description = m.pop("description") + ct_model.author = m.pop("author") + ct_model.license = m.pop("license") + ct_model.version = m.pop("version") + ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items()}) + try: + ct_model.save(str(f)) # save *.mlpackage + except Exception as e: + LOGGER.warning( + f"{prefix} WARNING ⚠️ CoreML export to *.mlpackage failed ({e}), reverting to *.mlmodel export. " + f"Known coremltools Python 3.11 and Windows bugs https://github.com/apple/coremltools/issues/1928." + ) + f = f.with_suffix(".mlmodel") + ct_model.save(str(f)) + return f, ct_model + + @try_export + def export_engine(self, dla=None, prefix=colorstr("TensorRT:")): + """YOLO TensorRT export https://developer.nvidia.com/tensorrt.""" + assert self.im.device.type != "cpu", "export running on CPU but must be on GPU, i.e. 
use 'device=0'" + f_onnx, _ = self.export_onnx() # run before TRT import https://github.com/ultralytics/ultralytics/issues/7016 + + try: + import tensorrt as trt # noqa + except ImportError: + if LINUX: + check_requirements("tensorrt>7.0.0,!=10.1.0") + import tensorrt as trt # noqa + check_version(trt.__version__, ">=7.0.0", hard=True) + check_version(trt.__version__, "!=10.1.0", msg="https://github.com/ultralytics/ultralytics/pull/14239") + + # Setup and checks + LOGGER.info(f"\n{prefix} starting export with TensorRT {trt.__version__}...") + is_trt10 = int(trt.__version__.split(".")[0]) >= 10 # is TensorRT >= 10 + assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}" + f = self.file.with_suffix(".engine") # TensorRT engine file + logger = trt.Logger(trt.Logger.INFO) + if self.args.verbose: + logger.min_severity = trt.Logger.Severity.VERBOSE + + # Engine builder + builder = trt.Builder(logger) + config = builder.create_builder_config() + workspace = int(self.args.workspace * (1 << 30)) if self.args.workspace is not None else 0 + if is_trt10 and workspace > 0: + config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace) + elif workspace > 0: # TensorRT versions 7, 8 + config.max_workspace_size = workspace + flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + network = builder.create_network(flag) + half = builder.platform_has_fast_fp16 and self.args.half + int8 = builder.platform_has_fast_int8 and self.args.int8 + + # Optionally switch to DLA if enabled + if dla is not None: + if not IS_JETSON: + raise ValueError("DLA is only available on NVIDIA Jetson devices") + LOGGER.info(f"{prefix} enabling DLA on core {dla}...") + if not self.args.half and not self.args.int8: + raise ValueError( + "DLA requires either 'half=True' (FP16) or 'int8=True' (INT8) to be enabled. Please enable one of them and try again." + ) + config.default_device_type = trt.DeviceType.DLA + config.DLA_core = int(dla) + config.set_flag(trt.BuilderFlag.GPU_FALLBACK) + + # Read ONNX file + parser = trt.OnnxParser(network, logger) + if not parser.parse_from_file(f_onnx): + raise RuntimeError(f"failed to load ONNX file: {f_onnx}") + + # Network inputs + inputs = [network.get_input(i) for i in range(network.num_inputs)] + outputs = [network.get_output(i) for i in range(network.num_outputs)] + for inp in inputs: + LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}') + for out in outputs: + LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}') + + if self.args.dynamic: + shape = self.im.shape + if shape[0] <= 1: + LOGGER.warning(f"{prefix} WARNING ⚠️ 'dynamic=True' model requires max batch size, i.e. 
'batch=16'") + profile = builder.create_optimization_profile() + min_shape = (1, shape[1], 32, 32) # minimum input shape + max_shape = (*shape[:2], *(int(max(1, workspace) * d) for d in shape[2:])) # max input shape + for inp in inputs: + profile.set_shape(inp.name, min=min_shape, opt=shape, max=max_shape) + config.add_optimization_profile(profile) + + LOGGER.info(f"{prefix} building {'INT8' if int8 else 'FP' + ('16' if half else '32')} engine as {f}") + if int8: + config.set_flag(trt.BuilderFlag.INT8) + config.set_calibration_profile(profile) + config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED + + class EngineCalibrator(trt.IInt8Calibrator): + def __init__( + self, + dataset, # ultralytics.data.build.InfiniteDataLoader + batch: int, + cache: str = "", + ) -> None: + trt.IInt8Calibrator.__init__(self) + self.dataset = dataset + self.data_iter = iter(dataset) + self.algo = trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2 + self.batch = batch + self.cache = Path(cache) + + def get_algorithm(self) -> trt.CalibrationAlgoType: + """Get the calibration algorithm to use.""" + return self.algo + + def get_batch_size(self) -> int: + """Get the batch size to use for calibration.""" + return self.batch or 1 + + def get_batch(self, names) -> list: + """Get the next batch to use for calibration, as a list of device memory pointers.""" + try: + im0s = next(self.data_iter)["img"] / 255.0 + im0s = im0s.to("cuda") if im0s.device.type == "cpu" else im0s + return [int(im0s.data_ptr())] + except StopIteration: + # Return [] or None, signal to TensorRT there is no calibration data remaining + return None + + def read_calibration_cache(self) -> bytes: + """Use existing cache instead of calibrating again, otherwise, implicitly return None.""" + if self.cache.exists() and self.cache.suffix == ".cache": + return self.cache.read_bytes() + + def write_calibration_cache(self, cache) -> None: + """Write calibration cache to disk.""" + _ = self.cache.write_bytes(cache) + + # Load dataset w/ builder (for batching) and calibrate + config.int8_calibrator = EngineCalibrator( + dataset=self.get_int8_calibration_dataloader(prefix), + batch=2 * self.args.batch, # TensorRT INT8 calibration should use 2x batch size + cache=str(self.file.with_suffix(".cache")), + ) + + elif half: + config.set_flag(trt.BuilderFlag.FP16) + + # Free CUDA memory + del self.model + gc.collect() + torch.cuda.empty_cache() + + # Write file + build = builder.build_serialized_network if is_trt10 else builder.build_engine + with build(network, config) as engine, open(f, "wb") as t: + # Metadata + meta = json.dumps(self.metadata) + t.write(len(meta).to_bytes(4, byteorder="little", signed=True)) + t.write(meta.encode()) + # Model + t.write(engine if is_trt10 else engine.serialize()) + + return f, None + + @try_export + def export_saved_model(self, prefix=colorstr("TensorFlow SavedModel:")): + """YOLO TensorFlow SavedModel export.""" + cuda = torch.cuda.is_available() + try: + import tensorflow as tf # noqa + except ImportError: + suffix = "-macos" if MACOS else "-aarch64" if ARM64 else "" if cuda else "-cpu" + version = ">=2.0.0" + check_requirements(f"tensorflow{suffix}{version}") + import tensorflow as tf # noqa + check_requirements( + ( + "keras", # required by 'onnx2tf' package + "tf_keras", # required by 'onnx2tf' package + "sng4onnx>=1.0.1", # required by 'onnx2tf' package + "onnx_graphsurgeon>=0.3.26", # required by 'onnx2tf' package + "onnx>=1.12.0", + "onnx2tf>1.17.5,<=1.26.3", + "onnxslim>=0.1.31", + "tflite_support<=0.4.3" if IS_JETSON 
else "tflite_support", # fix ImportError 'GLIBCXX_3.4.29' + "flatbuffers>=23.5.26,<100", # update old 'flatbuffers' included inside tensorflow package + "onnxruntime-gpu" if cuda else "onnxruntime", + ), + cmds="--extra-index-url https://pypi.ngc.nvidia.com", # onnx_graphsurgeon only on NVIDIA + ) + + LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...") + check_version( + tf.__version__, + ">=2.0.0", + name="tensorflow", + verbose=True, + msg="https://github.com/ultralytics/ultralytics/issues/5161", + ) + import onnx2tf + + f = Path(str(self.file).replace(self.file.suffix, "_saved_model")) + if f.is_dir(): + shutil.rmtree(f) # delete output folder + + # Pre-download calibration file to fix https://github.com/PINTO0309/onnx2tf/issues/545 + onnx2tf_file = Path("calibration_image_sample_data_20x128x128x3_float32.npy") + if not onnx2tf_file.exists(): + attempt_download_asset(f"{onnx2tf_file}.zip", unzip=True, delete=True) + + # Export to ONNX + self.args.simplify = True + f_onnx, _ = self.export_onnx() + + # Export to TF + np_data = None + if self.args.int8: + tmp_file = f / "tmp_tflite_int8_calibration_images.npy" # int8 calibration images file + if self.args.data: + f.mkdir() + images = [batch["img"] for batch in self.get_int8_calibration_dataloader(prefix)] + images = torch.nn.functional.interpolate(torch.cat(images, 0).float(), size=self.imgsz).permute( + 0, 2, 3, 1 + ) + np.save(str(tmp_file), images.numpy().astype(np.float32)) # BHWC + np_data = [["images", tmp_file, [[[[0, 0, 0]]]], [[[[255, 255, 255]]]]]] + + LOGGER.info(f"{prefix} starting TFLite export with onnx2tf {onnx2tf.__version__}...") + keras_model = onnx2tf.convert( + input_onnx_file_path=f_onnx, + output_folder_path=str(f), + not_use_onnxsim=True, + verbosity="error", # note INT8-FP16 activation bug https://github.com/ultralytics/ultralytics/issues/15873 + output_integer_quantized_tflite=self.args.int8, + quant_type="per-tensor", # "per-tensor" (faster) or "per-channel" (slower but more accurate) + custom_input_op_name_np_data_path=np_data, + disable_group_convolution=True, # for end-to-end model compatibility + enable_batchmatmul_unfold=True, # for end-to-end model compatibility + ) + yaml_save(f / "metadata.yaml", self.metadata) # add metadata.yaml + + # Remove/rename TFLite models + if self.args.int8: + tmp_file.unlink(missing_ok=True) + for file in f.rglob("*_dynamic_range_quant.tflite"): + file.rename(file.with_name(file.stem.replace("_dynamic_range_quant", "_int8") + file.suffix)) + for file in f.rglob("*_integer_quant_with_int16_act.tflite"): + file.unlink() # delete extra fp16 activation TFLite files + + # Add TFLite metadata + for file in f.rglob("*.tflite"): + f.unlink() if "quant_with_int16_act.tflite" in str(f) else self._add_tflite_metadata(file) + + return str(f), keras_model # or keras_model = tf.saved_model.load(f, tags=None, options=None) + + @try_export + def export_pb(self, keras_model, prefix=colorstr("TensorFlow GraphDef:")): + """YOLO TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow.""" + import tensorflow as tf # noqa + from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 # noqa + + LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...") + f = self.file.with_suffix(".pb") + + m = tf.function(lambda x: keras_model(x)) # full model + m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)) + frozen_func = 
convert_variables_to_constants_v2(m) + frozen_func.graph.as_graph_def() + tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False) + return f, None + + @try_export + def export_tflite(self, keras_model, nms, agnostic_nms, prefix=colorstr("TensorFlow Lite:")): + """YOLO TensorFlow Lite export.""" + # BUG https://github.com/ultralytics/ultralytics/issues/13436 + import tensorflow as tf # noqa + + LOGGER.info(f"\n{prefix} starting export with tensorflow {tf.__version__}...") + saved_model = Path(str(self.file).replace(self.file.suffix, "_saved_model")) + if self.args.int8: + f = saved_model / f"{self.file.stem}_int8.tflite" # fp32 in/out + elif self.args.half: + f = saved_model / f"{self.file.stem}_float16.tflite" # fp32 in/out + else: + f = saved_model / f"{self.file.stem}_float32.tflite" + return str(f), None + + @try_export + def export_edgetpu(self, tflite_model="", prefix=colorstr("Edge TPU:")): + """YOLO Edge TPU export https://coral.ai/docs/edgetpu/models-intro/.""" + LOGGER.warning(f"{prefix} WARNING ⚠️ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185") + + cmd = "edgetpu_compiler --version" + help_url = "https://coral.ai/docs/edgetpu/compiler/" + assert LINUX, f"export only supported on Linux. See {help_url}" + if subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True).returncode != 0: + LOGGER.info(f"\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}") + for c in ( + "curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -", + 'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | ' + "sudo tee /etc/apt/sources.list.d/coral-edgetpu.list", + "sudo apt-get update", + "sudo apt-get install edgetpu-compiler", + ): + subprocess.run(c if is_sudo_available() else c.replace("sudo ", ""), shell=True, check=True) + ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1] + + LOGGER.info(f"\n{prefix} starting export with Edge TPU compiler {ver}...") + f = str(tflite_model).replace(".tflite", "_edgetpu.tflite") # Edge TPU model + + cmd = ( + "edgetpu_compiler " + f'--out_dir "{Path(f).parent}" ' + "--show_operations " + "--search_delegate " + "--delegate_search_step 30 " + "--timeout_sec 180 " + f'"{tflite_model}"' + ) + LOGGER.info(f"{prefix} running '{cmd}'") + subprocess.run(cmd, shell=True) + self._add_tflite_metadata(f) + return f, None + + @try_export + def export_tfjs(self, prefix=colorstr("TensorFlow.js:")): + """YOLO TensorFlow.js export.""" + check_requirements("tensorflowjs") + if ARM64: + # Fix error: `np.object` was a deprecated alias for the builtin `object` when exporting to TF.js on ARM64 + check_requirements("numpy==1.23.5") + import tensorflow as tf + import tensorflowjs as tfjs # noqa + + LOGGER.info(f"\n{prefix} starting export with tensorflowjs {tfjs.__version__}...") + f = str(self.file).replace(self.file.suffix, "_web_model") # js dir + f_pb = str(self.file.with_suffix(".pb")) # *.pb path + + gd = tf.Graph().as_graph_def() # TF GraphDef + with open(f_pb, "rb") as file: + gd.ParseFromString(file.read()) + outputs = ",".join(gd_outputs(gd)) + LOGGER.info(f"\n{prefix} output node names: {outputs}") + + quantization = "--quantize_float16" if self.args.half else "--quantize_uint8" if self.args.int8 else "" + with spaces_in_path(f_pb) as fpb_, spaces_in_path(f) as f_: # exporter can not handle spaces in path + cmd = ( + "tensorflowjs_converter " + 
f'--input_format=tf_frozen_model {quantization} --output_node_names={outputs} "{fpb_}" "{f_}"' + ) + LOGGER.info(f"{prefix} running '{cmd}'") + subprocess.run(cmd, shell=True) + + if " " in f: + LOGGER.warning(f"{prefix} WARNING ⚠️ your model may not work correctly with spaces in path '{f}'.") + + # Add metadata + yaml_save(Path(f) / "metadata.yaml", self.metadata) # add metadata.yaml + return f, None + + @try_export + def export_imx(self, prefix=colorstr("IMX:")): + """YOLO IMX export.""" + gptq = False + assert LINUX, ( + "export only supported on Linux. See https://developer.aitrios.sony-semicon.com/en/raspberrypi-ai-camera/documentation/imx500-converter" + ) + if getattr(self.model, "end2end", False): + raise ValueError("IMX export is not supported for end2end models.") + if "C2f" not in self.model.__str__(): + raise ValueError("IMX export is only supported for YOLOv8n detection models") + check_requirements(("model-compression-toolkit==2.1.1", "sony-custom-layers==0.2.0", "tensorflow==2.12.0")) + check_requirements("imx500-converter[pt]==3.14.3") # Separate requirements for imx500-converter + + import model_compression_toolkit as mct + import onnx + from sony_custom_layers.pytorch.object_detection.nms import multiclass_nms + + try: + out = subprocess.run( + ["java", "--version"], check=True, capture_output=True + ) # Java 17 is required for imx500-converter + if "openjdk 17" not in str(out.stdout): + raise FileNotFoundError + except FileNotFoundError: + c = ["apt", "install", "-y", "openjdk-17-jdk", "openjdk-17-jre"] + if is_sudo_available(): + c.insert(0, "sudo") + subprocess.run(c, check=True) + + def representative_dataset_gen(dataloader=self.get_int8_calibration_dataloader(prefix)): + for batch in dataloader: + img = batch["img"] + img = img / 255.0 + yield [img] + + tpc = mct.get_target_platform_capabilities( + fw_name="pytorch", target_platform_name="imx500", target_platform_version="v1" + ) + + config = mct.core.CoreConfig( + mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=10), + quantization_config=mct.core.QuantizationConfig(concat_threshold_update=True), + ) + + resource_utilization = mct.core.ResourceUtilization(weights_memory=3146176 * 0.76) + + quant_model = ( + mct.gptq.pytorch_gradient_post_training_quantization( # Perform Gradient-Based Post Training Quantization + model=self.model, + representative_data_gen=representative_dataset_gen, + target_resource_utilization=resource_utilization, + gptq_config=mct.gptq.get_pytorch_gptq_config(n_epochs=1000, use_hessian_based_weights=False), + core_config=config, + target_platform_capabilities=tpc, + )[0] + if gptq + else mct.ptq.pytorch_post_training_quantization( # Perform post training quantization + in_module=self.model, + representative_data_gen=representative_dataset_gen, + target_resource_utilization=resource_utilization, + core_config=config, + target_platform_capabilities=tpc, + )[0] + ) + + class NMSWrapper(torch.nn.Module): + def __init__( + self, + model: torch.nn.Module, + score_threshold: float = 0.001, + iou_threshold: float = 0.7, + max_detections: int = 300, + ): + """ + Wrapping PyTorch Module with multiclass_nms layer from sony_custom_layers. + + Args: + model (nn.Module): Model instance. + score_threshold (float): Score threshold for non-maximum suppression. + iou_threshold (float): Intersection over union threshold for non-maximum suppression. + max_detections (float): The number of detections to return. 
+ """ + super().__init__() + self.model = model + self.score_threshold = score_threshold + self.iou_threshold = iou_threshold + self.max_detections = max_detections + + def forward(self, images): + # model inference + outputs = self.model(images) + + boxes = outputs[0] + scores = outputs[1] + nms = multiclass_nms( + boxes=boxes, + scores=scores, + score_threshold=self.score_threshold, + iou_threshold=self.iou_threshold, + max_detections=self.max_detections, + ) + return nms + + quant_model = NMSWrapper( + model=quant_model, + score_threshold=self.args.conf or 0.001, + iou_threshold=self.args.iou, + max_detections=self.args.max_det, + ).to(self.device) + + f = Path(str(self.file).replace(self.file.suffix, "_imx_model")) + f.mkdir(exist_ok=True) + onnx_model = f / Path(str(self.file).replace(self.file.suffix, "_imx.onnx")) # js dir + mct.exporter.pytorch_export_model( + model=quant_model, save_model_path=onnx_model, repr_dataset=representative_dataset_gen + ) + + model_onnx = onnx.load(onnx_model) # load onnx model + for k, v in self.metadata.items(): + meta = model_onnx.metadata_props.add() + meta.key, meta.value = k, str(v) + + onnx.save(model_onnx, onnx_model) + + subprocess.run( + ["imxconv-pt", "-i", str(onnx_model), "-o", str(f), "--no-input-persistency", "--overwrite-output"], + check=True, + ) + + # Needed for imx models. + with open(f / "labels.txt", "w") as file: + file.writelines([f"{name}\n" for _, name in self.model.names.items()]) + + return f, None + + def _add_tflite_metadata(self, file): + """Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata.""" + import flatbuffers + + try: + # TFLite Support bug https://github.com/tensorflow/tflite-support/issues/954#issuecomment-2108570845 + from tensorflow_lite_support.metadata import metadata_schema_py_generated as schema # noqa + from tensorflow_lite_support.metadata.python import metadata # noqa + except ImportError: # ARM64 systems may not have the 'tensorflow_lite_support' package available + from tflite_support import metadata # noqa + from tflite_support import metadata_schema_py_generated as schema # noqa + + # Create model info + model_meta = schema.ModelMetadataT() + model_meta.name = self.metadata["description"] + model_meta.version = self.metadata["version"] + model_meta.author = self.metadata["author"] + model_meta.license = self.metadata["license"] + + # Label file + tmp_file = Path(file).parent / "temp_meta.txt" + with open(tmp_file, "w") as f: + f.write(str(self.metadata)) + + label_file = schema.AssociatedFileT() + label_file.name = tmp_file.name + label_file.type = schema.AssociatedFileType.TENSOR_AXIS_LABELS + + # Create input info + input_meta = schema.TensorMetadataT() + input_meta.name = "image" + input_meta.description = "Input image to be detected." 
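+        # Declare the input tensor content as an RGB image so TFLite metadata consumers interpret it correctly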
+ input_meta.content = schema.ContentT() + input_meta.content.contentProperties = schema.ImagePropertiesT() + input_meta.content.contentProperties.colorSpace = schema.ColorSpaceType.RGB + input_meta.content.contentPropertiesType = schema.ContentProperties.ImageProperties + + # Create output info + output1 = schema.TensorMetadataT() + output1.name = "output" + output1.description = "Coordinates of detected objects, class labels, and confidence score" + output1.associatedFiles = [label_file] + if self.model.task == "segment": + output2 = schema.TensorMetadataT() + output2.name = "output" + output2.description = "Mask protos" + output2.associatedFiles = [label_file] + + # Create subgraph info + subgraph = schema.SubGraphMetadataT() + subgraph.inputTensorMetadata = [input_meta] + subgraph.outputTensorMetadata = [output1, output2] if self.model.task == "segment" else [output1] + model_meta.subgraphMetadata = [subgraph] + + b = flatbuffers.Builder(0) + b.Finish(model_meta.Pack(b), metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + metadata_buf = b.Output() + + populator = metadata.MetadataPopulator.with_model_file(str(file)) + populator.load_metadata_buffer(metadata_buf) + populator.load_associated_files([str(tmp_file)]) + populator.populate() + tmp_file.unlink() + + def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr("CoreML Pipeline:")): + """YOLO CoreML pipeline.""" + import coremltools as ct # noqa + + LOGGER.info(f"{prefix} starting pipeline with coremltools {ct.__version__}...") + _, _, h, w = list(self.im.shape) # BCHW + + # Output shapes + spec = model.get_spec() + out0, out1 = iter(spec.description.output) + if MACOS: + from PIL import Image + + img = Image.new("RGB", (w, h)) # w=192, h=320 + out = model.predict({"image": img}) + out0_shape = out[out0.name].shape # (3780, 80) + out1_shape = out[out1.name].shape # (3780, 4) + else: # linux and windows can not run model.predict(), get sizes from PyTorch model output y + out0_shape = self.output_shape[2], self.output_shape[1] - 4 # (3780, 80) + out1_shape = self.output_shape[2], 4 # (3780, 4) + + # Checks + names = self.metadata["names"] + nx, ny = spec.description.input[0].type.imageType.width, spec.description.input[0].type.imageType.height + _, nc = out0_shape # number of anchors, number of classes + assert len(names) == nc, f"{len(names)} names found for nc={nc}" # check + + # Define output shapes (missing) + out0.type.multiArrayType.shape[:] = out0_shape # (3780, 80) + out1.type.multiArrayType.shape[:] = out1_shape # (3780, 4) + + # Model from spec + model = ct.models.MLModel(spec, weights_dir=weights_dir) + + # 3. 
Create NMS protobuf + nms_spec = ct.proto.Model_pb2.Model() + nms_spec.specificationVersion = 5 + for i in range(2): + decoder_output = model._spec.description.output[i].SerializeToString() + nms_spec.description.input.add() + nms_spec.description.input[i].ParseFromString(decoder_output) + nms_spec.description.output.add() + nms_spec.description.output[i].ParseFromString(decoder_output) + + nms_spec.description.output[0].name = "confidence" + nms_spec.description.output[1].name = "coordinates" + + output_sizes = [nc, 4] + for i in range(2): + ma_type = nms_spec.description.output[i].type.multiArrayType + ma_type.shapeRange.sizeRanges.add() + ma_type.shapeRange.sizeRanges[0].lowerBound = 0 + ma_type.shapeRange.sizeRanges[0].upperBound = -1 + ma_type.shapeRange.sizeRanges.add() + ma_type.shapeRange.sizeRanges[1].lowerBound = output_sizes[i] + ma_type.shapeRange.sizeRanges[1].upperBound = output_sizes[i] + del ma_type.shape[:] + + nms = nms_spec.nonMaximumSuppression + nms.confidenceInputFeatureName = out0.name # 1x507x80 + nms.coordinatesInputFeatureName = out1.name # 1x507x4 + nms.confidenceOutputFeatureName = "confidence" + nms.coordinatesOutputFeatureName = "coordinates" + nms.iouThresholdInputFeatureName = "iouThreshold" + nms.confidenceThresholdInputFeatureName = "confidenceThreshold" + nms.iouThreshold = 0.45 + nms.confidenceThreshold = 0.25 + nms.pickTop.perClass = True + nms.stringClassLabels.vector.extend(names.values()) + nms_model = ct.models.MLModel(nms_spec) + + # 4. Pipeline models together + pipeline = ct.models.pipeline.Pipeline( + input_features=[ + ("image", ct.models.datatypes.Array(3, ny, nx)), + ("iouThreshold", ct.models.datatypes.Double()), + ("confidenceThreshold", ct.models.datatypes.Double()), + ], + output_features=["confidence", "coordinates"], + ) + pipeline.add_model(model) + pipeline.add_model(nms_model) + + # Correct datatypes + pipeline.spec.description.input[0].ParseFromString(model._spec.description.input[0].SerializeToString()) + pipeline.spec.description.output[0].ParseFromString(nms_model._spec.description.output[0].SerializeToString()) + pipeline.spec.description.output[1].ParseFromString(nms_model._spec.description.output[1].SerializeToString()) + + # Update metadata + pipeline.spec.specificationVersion = 5 + pipeline.spec.description.metadata.userDefined.update( + {"IoU threshold": str(nms.iouThreshold), "Confidence threshold": str(nms.confidenceThreshold)} + ) + + # Save the model + model = ct.models.MLModel(pipeline.spec, weights_dir=weights_dir) + model.input_description["image"] = "Input image" + model.input_description["iouThreshold"] = f"(optional) IoU threshold override (default: {nms.iouThreshold})" + model.input_description["confidenceThreshold"] = ( + f"(optional) Confidence threshold override (default: {nms.confidenceThreshold})" + ) + model.output_description["confidence"] = 'Boxes × Class confidence (see user-defined metadata "classes")' + model.output_description["coordinates"] = "Boxes × [x, y, width, height] (relative to image size)" + LOGGER.info(f"{prefix} pipeline success") + return model + + def add_callback(self, event: str, callback): + """Appends the given callback.""" + self.callbacks[event].append(callback) + + def run_callbacks(self, event: str): + """Execute all callbacks for a given event.""" + for callback in self.callbacks.get(event, []): + callback(self) + + +class IOSDetectModel(torch.nn.Module): + """Wrap an Ultralytics YOLO model for Apple iOS CoreML export.""" + + def __init__(self, model, im): + """Initialize the 
IOSDetectModel class with a YOLO model and example image.""" + super().__init__() + _, _, h, w = im.shape # batch, channel, height, width + self.model = model + self.nc = len(model.names) # number of classes + if w == h: + self.normalize = 1.0 / w # scalar + else: + self.normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h]) # broadcast (slower, smaller) + + def forward(self, x): + """Normalize predictions of object detection model with input size-dependent factors.""" + xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1) + return cls, xywh * self.normalize # confidence (3780, 80), coordinates (3780, 4) diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py new file mode 100644 index 0000000000000000000000000000000000000000..9b38690031258b8971c39f7af21a16e63b567a9b --- /dev/null +++ b/ultralytics/engine/model.py @@ -0,0 +1,1173 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import inspect +from pathlib import Path +from typing import Any, Dict, List, Union + +import numpy as np +import torch +from PIL import Image + +from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir +from ultralytics.engine.results import Results +from ultralytics.hub import HUB_WEB_ROOT, HUBTrainingSession +from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load +from ultralytics.utils import ( + ARGV, + ASSETS, + DEFAULT_CFG_DICT, + LOGGER, + RANK, + SETTINGS, + callbacks, + checks, + emojis, + yaml_load, +) + + +class Model(nn.Module): + """ + A base class for implementing YOLO models, unifying APIs across different model types. + + This class provides a common interface for various operations related to YOLO models, such as training, + validation, prediction, exporting, and benchmarking. It handles different types of models, including those + loaded from local files, Ultralytics HUB, or Triton Server. + + Attributes: + callbacks (Dict): A dictionary of callback functions for various events during model operations. + predictor (BasePredictor): The predictor object used for making predictions. + model (nn.Module): The underlying PyTorch model. + trainer (BaseTrainer): The trainer object used for training the model. + ckpt (Dict): The checkpoint data if the model is loaded from a *.pt file. + cfg (str): The configuration of the model if loaded from a *.yaml file. + ckpt_path (str): The path to the checkpoint file. + overrides (Dict): A dictionary of overrides for model configuration. + metrics (Dict): The latest training/validation metrics. + session (HUBTrainingSession): The Ultralytics HUB session, if applicable. + task (str): The type of task the model is intended for. + model_name (str): The name of the model. + + Methods: + __call__: Alias for the predict method, enabling the model instance to be callable. + _new: Initializes a new model based on a configuration file. + _load: Loads a model from a checkpoint file. + _check_is_pytorch_model: Ensures that the model is a PyTorch model. + reset_weights: Resets the model's weights to their initial state. + load: Loads model weights from a specified file. + save: Saves the current state of the model to a file. + info: Logs or returns information about the model. + fuse: Fuses Conv2d and BatchNorm2d layers for optimized inference. + predict: Performs object detection predictions. + track: Performs object tracking. + val: Validates the model on a dataset. + benchmark: Benchmarks the model on various export formats. 
+ export: Exports the model to different formats. + train: Trains the model on a dataset. + tune: Performs hyperparameter tuning. + _apply: Applies a function to the model's tensors. + add_callback: Adds a callback function for an event. + clear_callback: Clears all callbacks for an event. + reset_callbacks: Resets all callbacks to their default functions. + + Examples: + >>> from ultralytics import YOLO + >>> model = YOLO("yolo11n.pt") + >>> results = model.predict("image.jpg") + >>> model.train(data="coco8.yaml", epochs=3) + >>> metrics = model.val() + >>> model.export(format="onnx") + """ + + def __init__( + self, + model: Union[str, Path] = "yolo11n.pt", + task: str = None, + verbose: bool = False, + ) -> None: + """ + Initializes a new instance of the YOLO model class. + + This constructor sets up the model based on the provided model path or name. It handles various types of + model sources, including local files, Ultralytics HUB models, and Triton Server models. The method + initializes several important attributes of the model and prepares it for operations like training, + prediction, or export. + + Args: + model (Union[str, Path]): Path or name of the model to load or create. Can be a local file path, a + model name from Ultralytics HUB, or a Triton Server model. + task (str | None): The task type associated with the YOLO model, specifying its application domain. + verbose (bool): If True, enables verbose output during the model's initialization and subsequent + operations. + + Raises: + FileNotFoundError: If the specified model file does not exist or is inaccessible. + ValueError: If the model file or configuration is invalid or unsupported. + ImportError: If required dependencies for specific model types (like HUB SDK) are not installed. + + Examples: + >>> model = Model("yolo11n.pt") + >>> model = Model("path/to/model.yaml", task="detect") + >>> model = Model("hub_model", verbose=True) + """ + super().__init__() + self.callbacks = callbacks.get_default_callbacks() + self.predictor = None # reuse predictor + self.model = None # model object + self.trainer = None # trainer object + self.ckpt = {} # if loaded from *.pt + self.cfg = None # if loaded from *.yaml + self.ckpt_path = None + self.overrides = {} # overrides for trainer object + self.metrics = None # validation/training metrics + self.session = None # HUB session + self.task = task # task type + model = str(model).strip() + + # Check if Ultralytics HUB model from https://hub.ultralytics.com + if self.is_hub_model(model): + # Fetch model from HUB + checks.check_requirements("hub-sdk>=0.0.12") + session = HUBTrainingSession.create_session(model) + model = session.model_file + if session.train_args: # training sent from HUB + self.session = session + + # Check if Triton Server model + elif self.is_triton_model(model): + self.model_name = self.model = model + self.overrides["task"] = task or "detect" # set `task=detect` if not explicitly set + return + + # Load or create new YOLO model + if Path(model).suffix in {".yaml", ".yml"}: + self._new(model, task=task, verbose=verbose) + else: + self._load(model, task=task) + + # Delete super().training for accessing self.model.training + del self.training + + def __call__( + self, + source: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor] = None, + stream: bool = False, + **kwargs: Any, + ) -> list: + """ + Alias for the predict method, enabling the model instance to be callable for predictions. 
+ + This method simplifies the process of making predictions by allowing the model instance to be called + directly with the required arguments. + + Args: + source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source of + the image(s) to make predictions on. Can be a file path, URL, PIL image, numpy array, PyTorch + tensor, or a list/tuple of these. + stream (bool): If True, treat the input source as a continuous stream for predictions. + **kwargs: Additional keyword arguments to configure the prediction process. + + Returns: + (List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a + Results object. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> results = model("https://ultralytics.com/images/bus.jpg") + >>> for r in results: + ... print(f"Detected {len(r)} objects in image") + """ + return self.predict(source, stream, **kwargs) + + @staticmethod + def is_triton_model(model: str) -> bool: + """ + Checks if the given model string is a Triton Server URL. + + This static method determines whether the provided model string represents a valid Triton Server URL by + parsing its components using urllib.parse.urlsplit(). + + Args: + model (str): The model string to be checked. + + Returns: + (bool): True if the model string is a valid Triton Server URL, False otherwise. + + Examples: + >>> Model.is_triton_model("http://localhost:8000/v2/models/yolov8n") + True + >>> Model.is_triton_model("yolo11n.pt") + False + """ + from urllib.parse import urlsplit + + url = urlsplit(model) + return url.netloc and url.path and url.scheme in {"http", "grpc"} + + @staticmethod + def is_hub_model(model: str) -> bool: + """ + Check if the provided model is an Ultralytics HUB model. + + This static method determines whether the given model string represents a valid Ultralytics HUB model + identifier. + + Args: + model (str): The model string to check. + + Returns: + (bool): True if the model is a valid Ultralytics HUB model, False otherwise. + + Examples: + >>> Model.is_hub_model("https://hub.ultralytics.com/models/MODEL") + True + >>> Model.is_hub_model("yolo11n.pt") + False + """ + return model.startswith(f"{HUB_WEB_ROOT}/models/") + + def _new(self, cfg: str, task=None, model=None, verbose=False) -> None: + """ + Initializes a new model and infers the task type from the model definitions. + + This method creates a new model instance based on the provided configuration file. It loads the model + configuration, infers the task type if not specified, and initializes the model using the appropriate + class from the task map. + + Args: + cfg (str): Path to the model configuration file in YAML format. + task (str | None): The specific task for the model. If None, it will be inferred from the config. + model (torch.nn.Module | None): A custom model instance. If provided, it will be used instead of creating + a new one. + verbose (bool): If True, displays model information during loading. + + Raises: + ValueError: If the configuration file is invalid or the task cannot be inferred. + ImportError: If the required dependencies for the specified task are not installed. 
+ + Examples: + >>> model = Model() + >>> model._new("yolov8n.yaml", task="detect", verbose=True) + """ + cfg_dict = yaml_model_load(cfg) + self.cfg = cfg + self.task = task or guess_model_task(cfg_dict) + self.model = (model or self._smart_load("model"))(cfg_dict, verbose=verbose and RANK == -1) # build model + self.overrides["model"] = self.cfg + self.overrides["task"] = self.task + + # Below added to allow export from YAMLs + self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # combine default and model args (prefer model args) + self.model.task = self.task + self.model_name = cfg + + def _load(self, weights: str, task=None) -> None: + """ + Loads a model from a checkpoint file or initializes it from a weights file. + + This method handles loading models from either .pt checkpoint files or other weight file formats. It sets + up the model, task, and related attributes based on the loaded weights. + + Args: + weights (str): Path to the model weights file to be loaded. + task (str | None): The task associated with the model. If None, it will be inferred from the model. + + Raises: + FileNotFoundError: If the specified weights file does not exist or is inaccessible. + ValueError: If the weights file format is unsupported or invalid. + + Examples: + >>> model = Model() + >>> model._load("yolo11n.pt") + >>> model._load("path/to/weights.pth", task="detect") + """ + if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): + weights = checks.check_file(weights, download_dir=SETTINGS["weights_dir"]) # download and return local file + weights = checks.check_model_file_from_stem(weights) # add suffix, i.e. yolov8n -> yolov8n.pt + + if Path(weights).suffix == ".pt": + self.model, self.ckpt = attempt_load_one_weight(weights) + self.task = self.model.args["task"] + self.overrides = self.model.args = self._reset_ckpt_args(self.model.args) + self.ckpt_path = self.model.pt_path + else: + weights = checks.check_file(weights) # runs in all cases, not redundant with above call + self.model, self.ckpt = weights, None + self.task = task or guess_model_task(weights) + self.ckpt_path = weights + self.overrides["model"] = weights + self.overrides["task"] = self.task + self.model_name = weights + + def _check_is_pytorch_model(self) -> None: + """ + Checks if the model is a PyTorch model and raises a TypeError if it's not. + + This method verifies that the model is either a PyTorch module or a .pt file. It's used to ensure that + certain operations that require a PyTorch model are only performed on compatible model types. + + Raises: + TypeError: If the model is not a PyTorch module or a .pt file. The error message provides detailed + information about supported model formats and operations. + + Examples: + >>> model = Model("yolo11n.pt") + >>> model._check_is_pytorch_model() # No error raised + >>> model = Model("yolov8n.onnx") + >>> model._check_is_pytorch_model() # Raises TypeError + """ + pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt" + pt_module = isinstance(self.model, nn.Module) + if not (pt_module or pt_str): + raise TypeError( + f"model='{self.model}' should be a *.pt PyTorch model to run this method, but is a different format. " + f"PyTorch models can train, val, predict and export, i.e. 'model.train(data=...)', but exported " + f"formats like ONNX, TensorRT etc. only support 'predict' and 'val' modes, " + f"i.e. 
'yolo predict model=yolov8n.onnx'.\nTo run CUDA or MPS inference please pass the device " + f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'" + ) + + def reset_weights(self) -> "Model": + """ + Resets the model's weights to their initial state. + + This method iterates through all modules in the model and resets their parameters if they have a + 'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True, + enabling them to be updated during training. + + Returns: + (Model): The instance of the class with reset weights. + + Raises: + AssertionError: If the model is not a PyTorch model. + + Examples: + >>> model = Model("yolo11n.pt") + >>> model.reset_weights() + """ + self._check_is_pytorch_model() + for m in self.model.modules(): + if hasattr(m, "reset_parameters"): + m.reset_parameters() + for p in self.model.parameters(): + p.requires_grad = True + return self + + def load(self, weights: Union[str, Path] = "yolo11n.pt") -> "Model": + """ + Loads parameters from the specified weights file into the model. + + This method supports loading weights from a file or directly from a weights object. It matches parameters by + name and shape and transfers them to the model. + + Args: + weights (Union[str, Path]): Path to the weights file or a weights object. + + Returns: + (Model): The instance of the class with loaded weights. + + Raises: + AssertionError: If the model is not a PyTorch model. + + Examples: + >>> model = Model() + >>> model.load("yolo11n.pt") + >>> model.load(Path("path/to/weights.pt")) + """ + self._check_is_pytorch_model() + if isinstance(weights, (str, Path)): + self.overrides["pretrained"] = weights # remember the weights for DDP training + weights, self.ckpt = attempt_load_one_weight(weights) + self.model.load(weights) + return self + + def save(self, filename: Union[str, Path] = "saved_model.pt") -> None: + """ + Saves the current model state to a file. + + This method exports the model's checkpoint (ckpt) to the specified filename. It includes metadata such as + the date, Ultralytics version, license information, and a link to the documentation. + + Args: + filename (Union[str, Path]): The name of the file to save the model to. + + Raises: + AssertionError: If the model is not a PyTorch model. + + Examples: + >>> model = Model("yolo11n.pt") + >>> model.save("my_model.pt") + """ + self._check_is_pytorch_model() + from copy import deepcopy + from datetime import datetime + + from ultralytics import __version__ + + updates = { + "model": deepcopy(self.model).half() if isinstance(self.model, nn.Module) else self.model, + "date": datetime.now().isoformat(), + "version": __version__, + "license": "AGPL-3.0 License (https://ultralytics.com/license)", + "docs": "https://docs.ultralytics.com", + } + torch.save({**self.ckpt, **updates}, filename) + + def info(self, detailed: bool = False, verbose: bool = True): + """ + Logs or returns model information. + + This method provides an overview or detailed information about the model, depending on the arguments + passed. It can control the verbosity of the output and return the information as a list. + + Args: + detailed (bool): If True, shows detailed information about the model layers and parameters. + verbose (bool): If True, prints the information. If False, returns the information as a list. + + Returns: + (List[str]): A list of strings containing various types of information about the model, including + model summary, layer details, and parameter counts. 
Empty if verbose is True. + + Raises: + TypeError: If the model is not a PyTorch model. + + Examples: + >>> model = Model("yolo11n.pt") + >>> model.info() # Prints model summary + >>> info_list = model.info(detailed=True, verbose=False) # Returns detailed info as a list + """ + self._check_is_pytorch_model() + return self.model.info(detailed=detailed, verbose=verbose) + + def fuse(self): + """ + Fuses Conv2d and BatchNorm2d layers in the model for optimized inference. + + This method iterates through the model's modules and fuses consecutive Conv2d and BatchNorm2d layers + into a single layer. This fusion can significantly improve inference speed by reducing the number of + operations and memory accesses required during forward passes. + + The fusion process typically involves folding the BatchNorm2d parameters (mean, variance, weight, and + bias) into the preceding Conv2d layer's weights and biases. This results in a single Conv2d layer that + performs both convolution and normalization in one step. + + Raises: + TypeError: If the model is not a PyTorch nn.Module. + + Examples: + >>> model = Model("yolo11n.pt") + >>> model.fuse() + >>> # Model is now fused and ready for optimized inference + """ + self._check_is_pytorch_model() + self.model.fuse() + + def embed( + self, + source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None, + stream: bool = False, + **kwargs: Any, + ) -> list: + """ + Generates image embeddings based on the provided source. + + This method is a wrapper around the 'predict()' method, focusing on generating embeddings from an image + source. It allows customization of the embedding process through various keyword arguments. + + Args: + source (str | Path | int | List | Tuple | np.ndarray | torch.Tensor): The source of the image for + generating embeddings. Can be a file path, URL, PIL image, numpy array, etc. + stream (bool): If True, predictions are streamed. + **kwargs: Additional keyword arguments for configuring the embedding process. + + Returns: + (List[torch.Tensor]): A list containing the image embeddings. + + Raises: + AssertionError: If the model is not a PyTorch model. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> image = "https://ultralytics.com/images/bus.jpg" + >>> embeddings = model.embed(image) + >>> print(embeddings[0].shape) + """ + if not kwargs.get("embed"): + kwargs["embed"] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed + return self.predict(source, stream, **kwargs) + + def predict( + self, + source: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor] = None, + stream: bool = False, + predictor=None, + **kwargs: Any, + ) -> List[Results]: + """ + Performs predictions on the given image source using the YOLO model. + + This method facilitates the prediction process, allowing various configurations through keyword arguments. + It supports predictions with custom predictors or the default predictor method. The method handles different + types of image sources and can operate in a streaming mode. + + Args: + source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source + of the image(s) to make predictions on. Accepts various types including file paths, URLs, PIL + images, numpy arrays, and torch tensors. + stream (bool): If True, treats the input source as a continuous stream for predictions. + predictor (BasePredictor | None): An instance of a custom predictor class for making predictions. + If None, the method uses a default predictor. 
+ **kwargs: Additional keyword arguments for configuring the prediction process. + + Returns: + (List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a + Results object. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> results = model.predict(source="path/to/image.jpg", conf=0.25) + >>> for r in results: + ... print(r.boxes.data) # print detection bounding boxes + + Notes: + - If 'source' is not provided, it defaults to the ASSETS constant with a warning. + - The method sets up a new predictor if not already present and updates its arguments with each call. + - For SAM-type models, 'prompts' can be passed as a keyword argument. + """ + if source is None: + source = ASSETS + LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.") + + is_cli = (ARGV[0].endswith("yolo") or ARGV[0].endswith("ultralytics")) and any( + x in ARGV for x in ("predict", "track", "mode=predict", "mode=track") + ) + + custom = {"conf": 0.25, "batch": 1, "save": is_cli, "mode": "predict"} # method defaults + args = {**self.overrides, **custom, **kwargs} # highest priority args on the right + prompts = args.pop("prompts", None) # for SAM-type models + + if not self.predictor: + self.predictor = (predictor or self._smart_load("predictor"))(overrides=args, _callbacks=self.callbacks) + self.predictor.setup_model(model=self.model, verbose=is_cli) + else: # only update args if predictor is already setup + self.predictor.args = get_cfg(self.predictor.args, args) + if "project" in args or "name" in args: + self.predictor.save_dir = get_save_dir(self.predictor.args) + if prompts and hasattr(self.predictor, "set_prompts"): # for SAM-type models + self.predictor.set_prompts(prompts) + return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream) + + def track( + self, + source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None, + stream: bool = False, + persist: bool = False, + **kwargs: Any, + ) -> List[Results]: + """ + Conducts object tracking on the specified input source using the registered trackers. + + This method performs object tracking using the model's predictors and optionally registered trackers. It handles + various input sources such as file paths or video streams, and supports customization through keyword arguments. + The method registers trackers if not already present and can persist them between calls. + + Args: + source (Union[str, Path, int, List, Tuple, np.ndarray, torch.Tensor], optional): Input source for object + tracking. Can be a file path, URL, or video stream. + stream (bool): If True, treats the input source as a continuous video stream. Defaults to False. + persist (bool): If True, persists trackers between different calls to this method. Defaults to False. + **kwargs: Additional keyword arguments for configuring the tracking process. + + Returns: + (List[ultralytics.engine.results.Results]): A list of tracking results, each a Results object. + + Raises: + AttributeError: If the predictor does not have registered trackers. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> results = model.track(source="path/to/video.mp4", show=True) + >>> for r in results: + ... print(r.boxes.id) # print tracking IDs + + Notes: + - This method sets a default confidence threshold of 0.1 for ByteTrack-based tracking. + - The tracking mode is explicitly set in the keyword arguments. + - Batch size is set to 1 for tracking in videos. 
+ """ + if not hasattr(self.predictor, "trackers"): + from ultralytics.trackers import register_tracker + + register_tracker(self, persist) + kwargs["conf"] = kwargs.get("conf") or 0.1 # ByteTrack-based method needs low confidence predictions as input + kwargs["batch"] = kwargs.get("batch") or 1 # batch-size 1 for tracking in videos + kwargs["mode"] = "track" + return self.predict(source=source, stream=stream, **kwargs) + + def val( + self, + validator=None, + **kwargs: Any, + ): + """ + Validates the model using a specified dataset and validation configuration. + + This method facilitates the model validation process, allowing for customization through various settings. It + supports validation with a custom validator or the default validation approach. The method combines default + configurations, method-specific defaults, and user-provided arguments to configure the validation process. + + Args: + validator (ultralytics.engine.validator.BaseValidator | None): An instance of a custom validator class for + validating the model. + **kwargs: Arbitrary keyword arguments for customizing the validation process. + + Returns: + (ultralytics.utils.metrics.DetMetrics): Validation metrics obtained from the validation process. + + Raises: + AssertionError: If the model is not a PyTorch model. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> results = model.val(data="coco8.yaml", imgsz=640) + >>> print(results.box.map) # Print mAP50-95 + """ + custom = {"rect": True} # method defaults + args = {**self.overrides, **custom, **kwargs, "mode": "val"} # highest priority args on the right + + validator = (validator or self._smart_load("validator"))(args=args, _callbacks=self.callbacks) + validator(model=self.model) + self.metrics = validator.metrics + return validator.metrics + + def benchmark( + self, + **kwargs: Any, + ): + """ + Benchmarks the model across various export formats to evaluate performance. + + This method assesses the model's performance in different export formats, such as ONNX, TorchScript, etc. + It uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is + configured using a combination of default configuration values, model-specific arguments, method-specific + defaults, and any additional user-provided keyword arguments. + + Args: + **kwargs: Arbitrary keyword arguments to customize the benchmarking process. These are combined with + default configurations, model-specific arguments, and method defaults. Common options include: + - data (str): Path to the dataset for benchmarking. + - imgsz (int | List[int]): Image size for benchmarking. + - half (bool): Whether to use half-precision (FP16) mode. + - int8 (bool): Whether to use int8 precision mode. + - device (str): Device to run the benchmark on (e.g., 'cpu', 'cuda'). + - verbose (bool): Whether to print detailed benchmark information. + + Returns: + (Dict): A dictionary containing the results of the benchmarking process, including metrics for + different export formats. + + Raises: + AssertionError: If the model is not a PyTorch model. 
+ + Examples: + >>> model = YOLO("yolo11n.pt") + >>> results = model.benchmark(data="coco8.yaml", imgsz=640, half=True) + >>> print(results) + """ + self._check_is_pytorch_model() + from ultralytics.utils.benchmarks import benchmark + + custom = {"verbose": False} # method defaults + args = {**DEFAULT_CFG_DICT, **self.model.args, **custom, **kwargs, "mode": "benchmark"} + return benchmark( + model=self, + data=kwargs.get("data"), # if no 'data' argument passed set data=None for default datasets + imgsz=args["imgsz"], + half=args["half"], + int8=args["int8"], + device=args["device"], + verbose=kwargs.get("verbose"), + ) + + def export( + self, + **kwargs: Any, + ) -> str: + """ + Exports the model to a different format suitable for deployment. + + This method facilitates the export of the model to various formats (e.g., ONNX, TorchScript) for deployment + purposes. It uses the 'Exporter' class for the export process, combining model-specific overrides, method + defaults, and any additional arguments provided. + + Args: + **kwargs: Arbitrary keyword arguments to customize the export process. These are combined with + the model's overrides and method defaults. Common arguments include: + format (str): Export format (e.g., 'onnx', 'engine', 'coreml'). + half (bool): Export model in half-precision. + int8 (bool): Export model in int8 precision. + device (str): Device to run the export on. + workspace (int): Maximum memory workspace size for TensorRT engines. + nms (bool): Add Non-Maximum Suppression (NMS) module to model. + simplify (bool): Simplify ONNX model. + + Returns: + (str): The path to the exported model file. + + Raises: + AssertionError: If the model is not a PyTorch model. + ValueError: If an unsupported export format is specified. + RuntimeError: If the export process fails due to errors. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> model.export(format="onnx", dynamic=True, simplify=True) + 'path/to/exported/model.onnx' + """ + self._check_is_pytorch_model() + from .exporter import Exporter + + custom = { + "imgsz": self.model.args["imgsz"], + "batch": 1, + "data": None, + "device": None, # reset to avoid multi-GPU errors + "verbose": False, + } # method defaults + args = {**self.overrides, **custom, **kwargs, "mode": "export"} # highest priority args on the right + return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model) + + def train( + self, + trainer=None, + **kwargs: Any, + ): + """ + Trains the model using the specified dataset and training configuration. + + This method facilitates model training with a range of customizable settings. It supports training with a + custom trainer or the default training approach. The method handles scenarios such as resuming training + from a checkpoint, integrating with Ultralytics HUB, and updating model and configuration after training. + + When using Ultralytics HUB, if the session has a loaded model, the method prioritizes HUB training + arguments and warns if local arguments are provided. It checks for pip updates and combines default + configurations, method-specific defaults, and user-provided arguments to configure the training process. + + Args: + trainer (BaseTrainer | None): Custom trainer instance for model training. If None, uses default. + **kwargs: Arbitrary keyword arguments for training configuration. Common options include: + data (str): Path to dataset configuration file. + epochs (int): Number of training epochs. + batch_size (int): Batch size for training. + imgsz (int): Input image size. 
+ device (str): Device to run training on (e.g., 'cuda', 'cpu'). + workers (int): Number of worker threads for data loading. + optimizer (str): Optimizer to use for training. + lr0 (float): Initial learning rate. + patience (int): Epochs to wait for no observable improvement for early stopping of training. + + Returns: + (Dict | None): Training metrics if available and training is successful; otherwise, None. + + Raises: + AssertionError: If the model is not a PyTorch model. + PermissionError: If there is a permission issue with the HUB session. + ModuleNotFoundError: If the HUB SDK is not installed. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> results = model.train(data="coco8.yaml", epochs=3) + """ + self._check_is_pytorch_model() + if hasattr(self.session, "model") and self.session.model.id: # Ultralytics HUB session with loaded model + if any(kwargs): + LOGGER.warning("WARNING ⚠️ using HUB training arguments, ignoring local training arguments.") + kwargs = self.session.train_args # overwrite kwargs + + checks.check_pip_update_available() + + overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides + custom = { + # NOTE: handle the case when 'cfg' includes 'data'. + "data": overrides.get("data") or DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task], + "model": self.overrides["model"], + "task": self.task, + } # method defaults + args = {**overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right + if args.get("resume"): + args["resume"] = self.ckpt_path + + self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks) + if not args.get("resume"): # manually set model only if not resuming + self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml) + self.model = self.trainer.model + + self.trainer.hub_session = self.session # attach optional HUB session + self.trainer.train() + # Update model and cfg after training + if RANK in {-1, 0}: + ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last + self.model, self.ckpt = attempt_load_one_weight(ckpt) + self.overrides = self.model.args + self.metrics = getattr(self.trainer.validator, "metrics", None) # TODO: no metrics returned by DDP + return self.metrics + + def tune( + self, + use_ray=False, + iterations=10, + *args: Any, + **kwargs: Any, + ): + """ + Conducts hyperparameter tuning for the model, with an option to use Ray Tune. + + This method supports two modes of hyperparameter tuning: using Ray Tune or a custom tuning method. + When Ray Tune is enabled, it leverages the 'run_ray_tune' function from the ultralytics.utils.tuner module. + Otherwise, it uses the internal 'Tuner' class for tuning. The method combines default, overridden, and + custom arguments to configure the tuning process. + + Args: + use_ray (bool): If True, uses Ray Tune for hyperparameter tuning. Defaults to False. + iterations (int): The number of tuning iterations to perform. Defaults to 10. + *args: Variable length argument list for additional arguments. + **kwargs: Arbitrary keyword arguments. These are combined with the model's overrides and defaults. + + Returns: + (Dict): A dictionary containing the results of the hyperparameter search. + + Raises: + AssertionError: If the model is not a PyTorch model. 
+ + Examples: + >>> model = YOLO("yolo11n.pt") + >>> results = model.tune(use_ray=True, iterations=20) + >>> print(results) + """ + self._check_is_pytorch_model() + if use_ray: + from ultralytics.utils.tuner import run_ray_tune + + return run_ray_tune(self, max_samples=iterations, *args, **kwargs) + else: + from .tuner import Tuner + + custom = {} # method defaults + args = {**self.overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right + return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations) + + def _apply(self, fn) -> "Model": + """ + Applies a function to model tensors that are not parameters or registered buffers. + + This method extends the functionality of the parent class's _apply method by additionally resetting the + predictor and updating the device in the model's overrides. It's typically used for operations like + moving the model to a different device or changing its precision. + + Args: + fn (Callable): A function to be applied to the model's tensors. This is typically a method like + to(), cpu(), cuda(), half(), or float(). + + Returns: + (Model): The model instance with the function applied and updated attributes. + + Raises: + AssertionError: If the model is not a PyTorch model. + + Examples: + >>> model = Model("yolo11n.pt") + >>> model = model._apply(lambda t: t.cuda()) # Move model to GPU + """ + self._check_is_pytorch_model() + self = super()._apply(fn) # noqa + self.predictor = None # reset predictor as device may have changed + self.overrides["device"] = self.device # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0' + return self + + @property + def names(self) -> Dict[int, str]: + """ + Retrieves the class names associated with the loaded model. + + This property returns the class names if they are defined in the model. It checks the class names for validity + using the 'check_class_names' function from the ultralytics.nn.autobackend module. If the predictor is not + initialized, it sets it up before retrieving the names. + + Returns: + (Dict[int, str]): A dict of class names associated with the model. + + Raises: + AttributeError: If the model or predictor does not have a 'names' attribute. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> print(model.names) + {0: 'person', 1: 'bicycle', 2: 'car', ...} + """ + from ultralytics.nn.autobackend import check_class_names + + if hasattr(self.model, "names"): + return check_class_names(self.model.names) + if not self.predictor: # export formats will not have predictor defined until predict() is called + self.predictor = self._smart_load("predictor")(overrides=self.overrides, _callbacks=self.callbacks) + self.predictor.setup_model(model=self.model, verbose=False) + return self.predictor.model.names + + @property + def device(self) -> torch.device: + """ + Retrieves the device on which the model's parameters are allocated. + + This property determines the device (CPU or GPU) where the model's parameters are currently stored. It is + applicable only to models that are instances of nn.Module. + + Returns: + (torch.device): The device (CPU/GPU) of the model. + + Raises: + AttributeError: If the model is not a PyTorch nn.Module instance. 
+ + Examples: + >>> model = YOLO("yolo11n.pt") + >>> print(model.device) + device(type='cuda', index=0) # if CUDA is available + >>> model = model.to("cpu") + >>> print(model.device) + device(type='cpu') + """ + return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None + + @property + def transforms(self): + """ + Retrieves the transformations applied to the input data of the loaded model. + + This property returns the transformations if they are defined in the model. The transforms + typically include preprocessing steps like resizing, normalization, and data augmentation + that are applied to input data before it is fed into the model. + + Returns: + (object | None): The transform object of the model if available, otherwise None. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> transforms = model.transforms + >>> if transforms: + ... print(f"Model transforms: {transforms}") + ... else: + ... print("No transforms defined for this model.") + """ + return self.model.transforms if hasattr(self.model, "transforms") else None + + def add_callback(self, event: str, func) -> None: + """ + Adds a callback function for a specified event. + + This method allows registering custom callback functions that are triggered on specific events during + model operations such as training or inference. Callbacks provide a way to extend and customize the + behavior of the model at various stages of its lifecycle. + + Args: + event (str): The name of the event to attach the callback to. Must be a valid event name recognized + by the Ultralytics framework. + func (Callable): The callback function to be registered. This function will be called when the + specified event occurs. + + Raises: + ValueError: If the event name is not recognized or is invalid. + + Examples: + >>> def on_train_start(trainer): + ... print("Training is starting!") + >>> model = YOLO("yolo11n.pt") + >>> model.add_callback("on_train_start", on_train_start) + >>> model.train(data="coco8.yaml", epochs=1) + """ + self.callbacks[event].append(func) + + def clear_callback(self, event: str) -> None: + """ + Clears all callback functions registered for a specified event. + + This method removes all custom and default callback functions associated with the given event. + It resets the callback list for the specified event to an empty list, effectively removing all + registered callbacks for that event. + + Args: + event (str): The name of the event for which to clear the callbacks. This should be a valid event name + recognized by the Ultralytics callback system. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> model.add_callback("on_train_start", lambda: print("Training started")) + >>> model.clear_callback("on_train_start") + >>> # All callbacks for 'on_train_start' are now removed + + Notes: + - This method affects both custom callbacks added by the user and default callbacks + provided by the Ultralytics framework. + - After calling this method, no callbacks will be executed for the specified event + until new ones are added. + - Use with caution as it removes all callbacks, including essential ones that might + be required for proper functioning of certain operations. + """ + self.callbacks[event] = [] + + def reset_callbacks(self) -> None: + """ + Resets all callbacks to their default functions. + + This method reinstates the default callback functions for all events, removing any custom callbacks that were + previously added. 
It iterates through all default callback events and replaces the current callbacks with the + default ones. + + The default callbacks are defined in the 'callbacks.default_callbacks' dictionary, which contains predefined + functions for various events in the model's lifecycle, such as on_train_start, on_epoch_end, etc. + + This method is useful when you want to revert to the original set of callbacks after making custom + modifications, ensuring consistent behavior across different runs or experiments. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> model.add_callback("on_train_start", custom_function) + >>> model.reset_callbacks() + # All callbacks are now reset to their default functions + """ + for event in callbacks.default_callbacks.keys(): + self.callbacks[event] = [callbacks.default_callbacks[event][0]] + + @staticmethod + def _reset_ckpt_args(args: dict) -> dict: + """ + Resets specific arguments when loading a PyTorch model checkpoint. + + This static method filters the input arguments dictionary to retain only a specific set of keys that are + considered important for model loading. It's used to ensure that only relevant arguments are preserved + when loading a model from a checkpoint, discarding any unnecessary or potentially conflicting settings. + + Args: + args (dict): A dictionary containing various model arguments and settings. + + Returns: + (dict): A new dictionary containing only the specified include keys from the input arguments. + + Examples: + >>> original_args = {"imgsz": 640, "data": "coco.yaml", "task": "detect", "batch": 16, "epochs": 100} + >>> reset_args = Model._reset_ckpt_args(original_args) + >>> print(reset_args) + {'imgsz': 640, 'data': 'coco.yaml', 'task': 'detect'} + """ + include = {"imgsz", "data", "task", "single_cls"} # only remember these arguments when loading a PyTorch model + return {k: v for k, v in args.items() if k in include} + + # def __getattr__(self, attr): + # """Raises error if object has no requested attribute.""" + # name = self.__class__.__name__ + # raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") + + def _smart_load(self, key: str): + """ + Loads the appropriate module based on the model task. + + This method dynamically selects and returns the correct module (model, trainer, validator, or predictor) + based on the current task of the model and the provided key. It uses the task_map attribute to determine + the correct module to load. + + Args: + key (str): The type of module to load. Must be one of 'model', 'trainer', 'validator', or 'predictor'. + + Returns: + (object): The loaded module corresponding to the specified key and current task. + + Raises: + NotImplementedError: If the specified key is not supported for the current task. + + Examples: + >>> model = Model(task="detect") + >>> predictor = model._smart_load("predictor") + >>> trainer = model._smart_load("trainer") + + Notes: + - This method is typically used internally by other methods of the Model class. + - The task_map attribute should be properly initialized with the correct mappings for each task. + """ + try: + return self.task_map[self.task][key] + except Exception as e: + name = self.__class__.__name__ + mode = inspect.stack()[1][3] # get the function name. 
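+            # inspect.stack()[1][3] resolves to the calling method's name (e.g. 'predict', 'train', 'val'),
+            # which is reported in the error below.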
+ raise NotImplementedError( + emojis(f"WARNING ⚠️ '{name}' model does not support '{mode}' mode for '{self.task}' task yet.") + ) from e + + @property + def task_map(self) -> dict: + """ + Provides a mapping from model tasks to corresponding classes for different modes. + + This property method returns a dictionary that maps each supported task (e.g., detect, segment, classify) + to a nested dictionary. The nested dictionary contains mappings for different operational modes + (model, trainer, validator, predictor) to their respective class implementations. + + The mapping allows for dynamic loading of appropriate classes based on the model's task and the + desired operational mode. This facilitates a flexible and extensible architecture for handling + various tasks and modes within the Ultralytics framework. + + Returns: + (Dict[str, Dict[str, Any]]): A dictionary where keys are task names (str) and values are + nested dictionaries. Each nested dictionary has keys 'model', 'trainer', 'validator', and + 'predictor', mapping to their respective class implementations. + + Examples: + >>> model = Model() + >>> task_map = model.task_map + >>> detect_class_map = task_map["detect"] + >>> segment_class_map = task_map["segment"] + + Note: + The actual implementation of this method may vary depending on the specific tasks and + classes supported by the Ultralytics framework. The docstring provides a general + description of the expected behavior and structure. + """ + raise NotImplementedError("Please provide task map for your model!") + + def eval(self): + """ + Sets the model to evaluation mode. + + This method changes the model's mode to evaluation, which affects layers like dropout and batch normalization + that behave differently during training and evaluation. + + Returns: + (Model): The model instance with evaluation mode set. + + Examples: + >> model = YOLO("yolo11n.pt") + >> model.eval() + """ + self.model.eval() + return self + + def __getattr__(self, name): + """ + Enables accessing model attributes directly through the Model class. + + This method provides a way to access attributes of the underlying model directly through the Model class + instance. It first checks if the requested attribute is 'model', in which case it returns the model from + the module dictionary. Otherwise, it delegates the attribute lookup to the underlying model. + + Args: + name (str): The name of the attribute to retrieve. + + Returns: + (Any): The requested attribute value. + + Raises: + AttributeError: If the requested attribute does not exist in the model. + + Examples: + >>> model = YOLO("yolo11n.pt") + >>> print(model.stride) + >>> print(model.task) + """ + return self._modules["model"] if name == "model" else getattr(self.model, name) diff --git a/ultralytics/engine/predictor.py b/ultralytics/engine/predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..d10fce5ab8123c1dc0f4a5f3c11f73802339d7ba --- /dev/null +++ b/ultralytics/engine/predictor.py @@ -0,0 +1,408 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +""" +Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc. 
+ +Usage - sources: + $ yolo mode=predict model=yolov8n.pt source=0 # webcam + img.jpg # image + vid.mp4 # video + screen # screenshot + path/ # directory + list.txt # list of images + list.streams # list of streams + 'path/*.jpg' # glob + 'https://youtu.be/LNwODJXcvt4' # YouTube + 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP, TCP stream + +Usage - formats: + $ yolo mode=predict model=yolov8n.pt # PyTorch + yolov8n.torchscript # TorchScript + yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True + yolov8n_openvino_model # OpenVINO + yolov8n.engine # TensorRT + yolov8n.mlpackage # CoreML (macOS-only) + yolov8n_saved_model # TensorFlow SavedModel + yolov8n.pb # TensorFlow GraphDef + yolov8n.tflite # TensorFlow Lite + yolov8n_edgetpu.tflite # TensorFlow Edge TPU + yolov8n_paddle_model # PaddlePaddle + yolov8n.mnn # MNN + yolov8n_ncnn_model # NCNN +""" + +import platform +import re +import threading +from pathlib import Path + +import cv2 +import numpy as np +import torch + +from ultralytics.cfg import get_cfg, get_save_dir +from ultralytics.data import load_inference_source +from ultralytics.data.augment import LetterBox, classify_transforms +from ultralytics.nn.autobackend import AutoBackend +from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops +from ultralytics.utils.checks import check_imgsz, check_imshow +from ultralytics.utils.files import increment_path +from ultralytics.utils.torch_utils import select_device, smart_inference_mode + +STREAM_WARNING = """ +WARNING ⚠️ inference results will accumulate in RAM unless `stream=True` is passed, causing potential out-of-memory +errors for large sources or long-running streams and videos. See https://docs.ultralytics.com/modes/predict/ for help. + +Example: + results = model(source=..., stream=True) # generator of Results objects + for r in results: + boxes = r.boxes # Boxes object for bbox outputs + masks = r.masks # Masks object for segment masks outputs + probs = r.probs # Class probabilities for classification outputs +""" + + +class BasePredictor: + """ + BasePredictor. + + A base class for creating predictors. + + Attributes: + args (SimpleNamespace): Configuration for the predictor. + save_dir (Path): Directory to save results. + done_warmup (bool): Whether the predictor has finished setup. + model (nn.Module): Model used for prediction. + data (dict): Data configuration. + device (torch.device): Device used for prediction. + dataset (Dataset): Dataset used for prediction. + vid_writer (dict): Dictionary of {save_path: video_writer, ...} writer for saving video output. + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """ + Initializes the BasePredictor class. + + Args: + cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG. + overrides (dict, optional): Configuration overrides. Defaults to None. 
+ """ + self.args = get_cfg(cfg, overrides) + self.save_dir = get_save_dir(self.args) + if self.args.conf is None: + self.args.conf = 0.25 # default conf=0.25 + self.done_warmup = False + if self.args.show: + self.args.show = check_imshow(warn=True) + + # Usable if setup is done + self.model = None + self.data = self.args.data # data_dict + self.imgsz = None + self.device = None + self.dataset = None + self.vid_writer = {} # dict of {save_path: video_writer, ...} + self.plotted_img = None + self.source_type = None + self.seen = 0 + self.windows = [] + self.batch = None + self.results = None + self.transforms = None + self.callbacks = _callbacks or callbacks.get_default_callbacks() + self.txt_path = None + self._lock = threading.Lock() # for automatic thread-safe inference + callbacks.add_integration_callbacks(self) + + def preprocess(self, im): + """ + Prepares input image before inference. + + Args: + im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list. + """ + not_tensor = not isinstance(im, torch.Tensor) + if not_tensor: + im = np.stack(self.pre_transform(im)) + im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW, (n, 3, h, w) + im = np.ascontiguousarray(im) # contiguous + im = torch.from_numpy(im) + + im = im.to(self.device) + im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32 + if not_tensor: + im /= 255 # 0 - 255 to 0.0 - 1.0 + return im + + def inference(self, im, *args, **kwargs): + """Runs inference on a given image using the specified model and arguments.""" + visualize = ( + increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True) + if self.args.visualize and (not self.source_type.tensor) + else False + ) + return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs) + + def pre_transform(self, im): + """ + Pre-transform input image before inference. + + Args: + im (List(np.ndarray)): (N, 3, h, w) for tensor, [(h, w, 3) x N] for list. + + Returns: + (list): A list of transformed images. + """ + same_shapes = len({x.shape for x in im}) == 1 + letterbox = LetterBox( + self.imgsz, + auto=same_shapes and (self.model.pt or (getattr(self.model, "dynamic", False) and not self.model.imx)), + stride=self.model.stride, + ) + return [letterbox(image=x) for x in im] + + def postprocess(self, preds, img, orig_imgs): + """Post-processes predictions for an image and returns them.""" + return preds + + def __call__(self, source=None, model=None, stream=False, *args, **kwargs): + """Performs inference on an image or stream.""" + self.stream = stream + if stream: + return self.stream_inference(source, model, *args, **kwargs) + else: + return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one + + def predict_cli(self, source=None, model=None): + """ + Method used for Command Line Interface (CLI) prediction. + + This function is designed to run predictions using the CLI. It sets up the source and model, then processes + the inputs in a streaming manner. This method ensures that no outputs accumulate in memory by consuming the + generator without storing results. + + Note: + Do not modify this function or remove the generator. The generator ensures that no outputs are + accumulated in memory, which is critical for preventing memory issues during long-running predictions. 
+ """ + gen = self.stream_inference(source, model) + for _ in gen: # sourcery skip: remove-empty-nested-block, noqa + pass + + def setup_source(self, source): + """Sets up source and inference mode.""" + self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size + self.transforms = ( + getattr( + self.model.model, + "transforms", + classify_transforms(self.imgsz[0], crop_fraction=self.args.crop_fraction), + ) + if self.args.task == "classify" + else None + ) + self.dataset = load_inference_source( + source=source, + batch=self.args.batch, + vid_stride=self.args.vid_stride, + buffer=self.args.stream_buffer, + ) + self.source_type = self.dataset.source_type + if not getattr(self, "stream", True) and ( + self.source_type.stream + or self.source_type.screenshot + or len(self.dataset) > 1000 # many images + or any(getattr(self.dataset, "video_flag", [False])) + ): # videos + LOGGER.warning(STREAM_WARNING) + self.vid_writer = {} + + @smart_inference_mode() + def stream_inference(self, source=None, model=None, *args, **kwargs): + """Streams real-time inference on camera feed and saves results to file.""" + if self.args.verbose: + LOGGER.info("") + + # Setup model + if not self.model: + self.setup_model(model) + + with self._lock: # for thread-safe inference + # Setup source every time predict is called + self.setup_source(source if source is not None else self.args.source) + + # Check if save_dir/ label file exists + if self.args.save or self.args.save_txt: + (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) + + # Warmup model + if not self.done_warmup: + self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz)) + self.done_warmup = True + + self.seen, self.windows, self.batch = 0, [], None + profilers = ( + ops.Profile(device=self.device), + ops.Profile(device=self.device), + ops.Profile(device=self.device), + ) + self.run_callbacks("on_predict_start") + for self.batch in self.dataset: + self.run_callbacks("on_predict_batch_start") + paths, im0s, s = self.batch + + # Preprocess + with profilers[0]: + im = self.preprocess(im0s) + + # Inference + with profilers[1]: + preds = self.inference(im, *args, **kwargs) + if self.args.embed: + yield from [preds] if isinstance(preds, torch.Tensor) else preds # yield embedding tensors + continue + + # Postprocess + with profilers[2]: + self.results = self.postprocess(preds, im, im0s) + self.run_callbacks("on_predict_postprocess_end") + + # Visualize, save, write results + n = len(im0s) + for i in range(n): + self.seen += 1 + self.results[i].speed = { + "preprocess": profilers[0].dt * 1e3 / n, + "inference": profilers[1].dt * 1e3 / n, + "postprocess": profilers[2].dt * 1e3 / n, + } + if self.args.verbose or self.args.save or self.args.save_txt or self.args.show: + s[i] += self.write_results(i, Path(paths[i]), im, s) + + # Print batch results + if self.args.verbose: + LOGGER.info("\n".join(s)) + + self.run_callbacks("on_predict_batch_end") + yield from self.results + + # Release assets + for v in self.vid_writer.values(): + if isinstance(v, cv2.VideoWriter): + v.release() + + # Print final results + if self.args.verbose and self.seen: + t = tuple(x.t / self.seen * 1e3 for x in profilers) # speeds per image + LOGGER.info( + f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape " + f"{(min(self.args.batch, self.seen), 3, *im.shape[2:])}" % t + ) + if self.args.save or self.args.save_txt or 
self.args.save_crop: + nl = len(list(self.save_dir.glob("labels/*.txt"))) # number of labels + s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else "" + LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}") + self.run_callbacks("on_predict_end") + + def setup_model(self, model, verbose=True): + """Initialize YOLO model with given parameters and set it to evaluation mode.""" + self.model = AutoBackend( + weights=model or self.args.model, + device=select_device(self.args.device, verbose=verbose), + dnn=self.args.dnn, + data=self.args.data, + fp16=self.args.half, + batch=self.args.batch, + fuse=True, + verbose=verbose, + ) + + self.device = self.model.device # update device + self.args.half = self.model.fp16 # update half + self.model.eval() + + def write_results(self, i, p, im, s): + """Write inference results to a file or directory.""" + string = "" # print string + if len(im.shape) == 3: + im = im[None] # expand for batch dim + if self.source_type.stream or self.source_type.from_img or self.source_type.tensor: # batch_size >= 1 + string += f"{i}: " + frame = self.dataset.count + else: + match = re.search(r"frame (\d+)/", s[i]) + frame = int(match[1]) if match else None # 0 if frame undetermined + + self.txt_path = self.save_dir / "labels" / (p.stem + ("" if self.dataset.mode == "image" else f"_{frame}")) + string += "{:g}x{:g} ".format(*im.shape[2:]) + result = self.results[i] + result.save_dir = self.save_dir.__str__() # used in other locations + string += f"{result.verbose()}{result.speed['inference']:.1f}ms" + + # Add predictions to image + if self.args.save or self.args.show: + self.plotted_img = result.plot( + line_width=self.args.line_width, + boxes=self.args.show_boxes, + conf=self.args.show_conf, + labels=self.args.show_labels, + im_gpu=None if self.args.retina_masks else im[i], + ) + + # Save results + if self.args.save_txt: + result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf) + if self.args.save_crop: + result.save_crop(save_dir=self.save_dir / "crops", file_name=self.txt_path.stem) + if self.args.show: + self.show(str(p)) + if self.args.save: + self.save_predicted_images(str(self.save_dir / p.name), frame) + + return string + + def save_predicted_images(self, save_path="", frame=0): + """Save video predictions as mp4 at specified path.""" + im = self.plotted_img + + # Save videos and streams + if self.dataset.mode in {"stream", "video"}: + fps = self.dataset.fps if self.dataset.mode == "video" else 30 + frames_path = f"{save_path.split('.', 1)[0]}_frames/" + if save_path not in self.vid_writer: # new video + if self.args.save_frames: + Path(frames_path).mkdir(parents=True, exist_ok=True) + suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG") + self.vid_writer[save_path] = cv2.VideoWriter( + filename=str(Path(save_path).with_suffix(suffix)), + fourcc=cv2.VideoWriter_fourcc(*fourcc), + fps=fps, # integer required, floats produce error in MP4 codec + frameSize=(im.shape[1], im.shape[0]), # (width, height) + ) + + # Save video + self.vid_writer[save_path].write(im) + if self.args.save_frames: + cv2.imwrite(f"{frames_path}{frame}.jpg", im) + + # Save images + else: + cv2.imwrite(str(Path(save_path).with_suffix(".jpg")), im) # save to JPG for best support + + def show(self, p=""): + """Display an image in a window using the OpenCV imshow function.""" + im = self.plotted_img + if platform.system() == "Linux" and p not in self.windows: + self.windows.append(p) + 
cv2.namedWindow(p, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux) + cv2.resizeWindow(p, im.shape[1], im.shape[0]) # (width, height) + cv2.imshow(p, im) + cv2.waitKey(300 if self.dataset.mode == "image" else 1) # 1 millisecond + + def run_callbacks(self, event: str): + """Runs all registered callbacks for a specific event.""" + for callback in self.callbacks.get(event, []): + callback(self) + + def add_callback(self, event: str, func): + """Add callback.""" + self.callbacks[event].append(func) diff --git a/ultralytics/engine/results.py b/ultralytics/engine/results.py new file mode 100644 index 0000000000000000000000000000000000000000..5ef0b7784b60d6ead8595f407c4ce2b7624fc436 --- /dev/null +++ b/ultralytics/engine/results.py @@ -0,0 +1,1740 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +""" +Ultralytics Results, Boxes and Masks classes for handling inference results. + +Usage: See https://docs.ultralytics.com/modes/predict/ +""" + +from copy import deepcopy +from functools import lru_cache +from pathlib import Path + +import numpy as np +import torch + +from ultralytics.data.augment import LetterBox +from ultralytics.utils import LOGGER, SimpleClass, ops +from ultralytics.utils.checks import check_requirements +from ultralytics.utils.plotting import Annotator, colors, save_one_box +from ultralytics.utils.torch_utils import smart_inference_mode + + +class BaseTensor(SimpleClass): + """ + Base tensor class with additional methods for easy manipulation and device handling. + + Attributes: + data (torch.Tensor | np.ndarray): Prediction data such as bounding boxes, masks, or keypoints. + orig_shape (Tuple[int, int]): Original shape of the image, typically in the format (height, width). + + Methods: + cpu: Return a copy of the tensor stored in CPU memory. + numpy: Returns a copy of the tensor as a numpy array. + cuda: Moves the tensor to GPU memory, returning a new instance if necessary. + to: Return a copy of the tensor with the specified device and dtype. + + Examples: + >>> import torch + >>> data = torch.tensor([[1, 2, 3], [4, 5, 6]]) + >>> orig_shape = (720, 1280) + >>> base_tensor = BaseTensor(data, orig_shape) + >>> cpu_tensor = base_tensor.cpu() + >>> numpy_array = base_tensor.numpy() + >>> gpu_tensor = base_tensor.cuda() + """ + + def __init__(self, data, orig_shape) -> None: + """ + Initialize BaseTensor with prediction data and the original shape of the image. + + Args: + data (torch.Tensor | np.ndarray): Prediction data such as bounding boxes, masks, or keypoints. + orig_shape (Tuple[int, int]): Original shape of the image in (height, width) format. + + Examples: + >>> import torch + >>> data = torch.tensor([[1, 2, 3], [4, 5, 6]]) + >>> orig_shape = (720, 1280) + >>> base_tensor = BaseTensor(data, orig_shape) + """ + assert isinstance(data, (torch.Tensor, np.ndarray)), "data must be torch.Tensor or np.ndarray" + self.data = data + self.orig_shape = orig_shape + + @property + def shape(self): + """ + Returns the shape of the underlying data tensor. + + Returns: + (Tuple[int, ...]): The shape of the data tensor. + + Examples: + >>> data = torch.rand(100, 4) + >>> base_tensor = BaseTensor(data, orig_shape=(720, 1280)) + >>> print(base_tensor.shape) + (100, 4) + """ + return self.data.shape + + def cpu(self): + """ + Returns a copy of the tensor stored in CPU memory. + + Returns: + (BaseTensor): A new BaseTensor object with the data tensor moved to CPU memory. 
+ + Examples: + >>> data = torch.tensor([[1, 2, 3], [4, 5, 6]]).cuda() + >>> base_tensor = BaseTensor(data, orig_shape=(720, 1280)) + >>> cpu_tensor = base_tensor.cpu() + >>> isinstance(cpu_tensor, BaseTensor) + True + >>> cpu_tensor.data.device + device(type='cpu') + """ + return self if isinstance(self.data, np.ndarray) else self.__class__(self.data.cpu(), self.orig_shape) + + def numpy(self): + """ + Returns a copy of the tensor as a numpy array. + + Returns: + (np.ndarray): A numpy array containing the same data as the original tensor. + + Examples: + >>> data = torch.tensor([[1, 2, 3], [4, 5, 6]]) + >>> orig_shape = (720, 1280) + >>> base_tensor = BaseTensor(data, orig_shape) + >>> numpy_array = base_tensor.numpy() + >>> print(type(numpy_array)) + + """ + return self if isinstance(self.data, np.ndarray) else self.__class__(self.data.numpy(), self.orig_shape) + + def cuda(self): + """ + Moves the tensor to GPU memory. + + Returns: + (BaseTensor): A new BaseTensor instance with the data moved to GPU memory if it's not already a + numpy array, otherwise returns self. + + Examples: + >>> import torch + >>> from ultralytics.engine.results import BaseTensor + >>> data = torch.tensor([[1, 2, 3], [4, 5, 6]]) + >>> base_tensor = BaseTensor(data, orig_shape=(720, 1280)) + >>> gpu_tensor = base_tensor.cuda() + >>> print(gpu_tensor.data.device) + cuda:0 + """ + return self.__class__(torch.as_tensor(self.data).cuda(), self.orig_shape) + + def to(self, *args, **kwargs): + """ + Return a copy of the tensor with the specified device and dtype. + + Args: + *args (Any): Variable length argument list to be passed to torch.Tensor.to(). + **kwargs (Any): Arbitrary keyword arguments to be passed to torch.Tensor.to(). + + Returns: + (BaseTensor): A new BaseTensor instance with the data moved to the specified device and/or dtype. + + Examples: + >>> base_tensor = BaseTensor(torch.randn(3, 4), orig_shape=(480, 640)) + >>> cuda_tensor = base_tensor.to("cuda") + >>> float16_tensor = base_tensor.to(dtype=torch.float16) + """ + return self.__class__(torch.as_tensor(self.data).to(*args, **kwargs), self.orig_shape) + + def __len__(self): # override len(results) + """ + Returns the length of the underlying data tensor. + + Returns: + (int): The number of elements in the first dimension of the data tensor. + + Examples: + >>> data = torch.tensor([[1, 2, 3], [4, 5, 6]]) + >>> base_tensor = BaseTensor(data, orig_shape=(720, 1280)) + >>> len(base_tensor) + 2 + """ + return len(self.data) + + def __getitem__(self, idx): + """ + Returns a new BaseTensor instance containing the specified indexed elements of the data tensor. + + Args: + idx (int | List[int] | torch.Tensor): Index or indices to select from the data tensor. + + Returns: + (BaseTensor): A new BaseTensor instance containing the indexed data. + + Examples: + >>> data = torch.tensor([[1, 2, 3], [4, 5, 6]]) + >>> base_tensor = BaseTensor(data, orig_shape=(720, 1280)) + >>> result = base_tensor[0] # Select the first row + >>> print(result.data) + tensor([1, 2, 3]) + """ + return self.__class__(self.data[idx], self.orig_shape) + + +class Results(SimpleClass): + """ + A class for storing and manipulating inference results. + + This class encapsulates the functionality for handling detection, segmentation, pose estimation, + and classification results from YOLO models. + + Attributes: + orig_img (numpy.ndarray): Original image as a numpy array. + orig_shape (Tuple[int, int]): Original image shape in (height, width) format. 
+ boxes (Boxes | None): Object containing detection bounding boxes. + masks (Masks | None): Object containing detection masks. + probs (Probs | None): Object containing class probabilities for classification tasks. + keypoints (Keypoints | None): Object containing detected keypoints for each object. + obb (OBB | None): Object containing oriented bounding boxes. + speed (Dict[str, float | None]): Dictionary of preprocess, inference, and postprocess speeds. + names (Dict[int, str]): Dictionary mapping class IDs to class names. + path (str): Path to the image file. + _keys (Tuple[str, ...]): Tuple of attribute names for internal use. + + Methods: + update: Updates object attributes with new detection results. + cpu: Returns a copy of the Results object with all tensors on CPU memory. + numpy: Returns a copy of the Results object with all tensors as numpy arrays. + cuda: Returns a copy of the Results object with all tensors on GPU memory. + to: Returns a copy of the Results object with tensors on a specified device and dtype. + new: Returns a new Results object with the same image, path, and names. + plot: Plots detection results on an input image, returning an annotated image. + show: Shows annotated results on screen. + save: Saves annotated results to file. + verbose: Returns a log string for each task, detailing detections and classifications. + save_txt: Saves detection results to a text file. + save_crop: Saves cropped detection images. + tojson: Converts detection results to JSON format. + + Examples: + >>> results = model("path/to/image.jpg") + >>> for result in results: + ... print(result.boxes) # Print detection boxes + ... result.show() # Display the annotated image + ... result.save(filename="result.jpg") # Save annotated image + """ + + def __init__( + self, orig_img, path, names, boxes=None, masks=None, probs=None, keypoints=None, obb=None, speed=None + ) -> None: + """ + Initialize the Results class for storing and manipulating inference results. + + Args: + orig_img (numpy.ndarray): The original image as a numpy array. + path (str): The path to the image file. + names (Dict): A dictionary of class names. + boxes (torch.Tensor | None): A 2D tensor of bounding box coordinates for each detection. + masks (torch.Tensor | None): A 3D tensor of detection masks, where each mask is a binary image. + probs (torch.Tensor | None): A 1D tensor of probabilities of each class for classification task. + keypoints (torch.Tensor | None): A 2D tensor of keypoint coordinates for each detection. + obb (torch.Tensor | None): A 2D tensor of oriented bounding box coordinates for each detection. + speed (Dict | None): A dictionary containing preprocess, inference, and postprocess speeds (ms/image). 
+ + Examples: + >>> results = model("path/to/image.jpg") + >>> result = results[0] # Get the first result + >>> boxes = result.boxes # Get the boxes for the first result + >>> masks = result.masks # Get the masks for the first result + + Notes: + For the default pose model, keypoint indices for human body pose estimation are: + 0: Nose, 1: Left Eye, 2: Right Eye, 3: Left Ear, 4: Right Ear + 5: Left Shoulder, 6: Right Shoulder, 7: Left Elbow, 8: Right Elbow + 9: Left Wrist, 10: Right Wrist, 11: Left Hip, 12: Right Hip + 13: Left Knee, 14: Right Knee, 15: Left Ankle, 16: Right Ankle + """ + self.orig_img = orig_img + self.orig_shape = orig_img.shape[:2] + self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None # native size boxes + self.masks = Masks(masks, self.orig_shape) if masks is not None else None # native size or imgsz masks + self.probs = Probs(probs) if probs is not None else None + self.keypoints = Keypoints(keypoints, self.orig_shape) if keypoints is not None else None + self.obb = OBB(obb, self.orig_shape) if obb is not None else None + self.speed = speed if speed is not None else {"preprocess": None, "inference": None, "postprocess": None} + self.names = names + self.path = path + self.save_dir = None + self._keys = "boxes", "masks", "probs", "keypoints", "obb" + + def __getitem__(self, idx): + """ + Return a Results object for a specific index of inference results. + + Args: + idx (int | slice): Index or slice to retrieve from the Results object. + + Returns: + (Results): A new Results object containing the specified subset of inference results. + + Examples: + >>> results = model("path/to/image.jpg") # Perform inference + >>> single_result = results[0] # Get the first result + >>> subset_results = results[1:4] # Get a slice of results + """ + return self._apply("__getitem__", idx) + + def __len__(self): + """ + Return the number of detections in the Results object. + + Returns: + (int): The number of detections, determined by the length of the first non-empty attribute + (boxes, masks, probs, keypoints, or obb). + + Examples: + >>> results = Results(orig_img, path, names, boxes=torch.rand(5, 4)) + >>> len(results) + 5 + """ + for k in self._keys: + v = getattr(self, k) + if v is not None: + return len(v) + + def update(self, boxes=None, masks=None, probs=None, obb=None): + """ + Updates the Results object with new detection data. + + This method allows updating the boxes, masks, probabilities, and oriented bounding boxes (OBB) of the + Results object. It ensures that boxes are clipped to the original image shape. + + Args: + boxes (torch.Tensor | None): A tensor of shape (N, 6) containing bounding box coordinates and + confidence scores. The format is (x1, y1, x2, y2, conf, class). + masks (torch.Tensor | None): A tensor of shape (N, H, W) containing segmentation masks. + probs (torch.Tensor | None): A tensor of shape (num_classes,) containing class probabilities. + obb (torch.Tensor | None): A tensor of shape (N, 5) containing oriented bounding box coordinates. 
+ + Examples: + >>> results = model("image.jpg") + >>> new_boxes = torch.tensor([[100, 100, 200, 200, 0.9, 0]]) + >>> results[0].update(boxes=new_boxes) + """ + if boxes is not None: + self.boxes = Boxes(ops.clip_boxes(boxes, self.orig_shape), self.orig_shape) + if masks is not None: + self.masks = Masks(masks, self.orig_shape) + if probs is not None: + self.probs = probs + if obb is not None: + self.obb = OBB(obb, self.orig_shape) + + def _apply(self, fn, *args, **kwargs): + """ + Applies a function to all non-empty attributes and returns a new Results object with modified attributes. + + This method is internally called by methods like .to(), .cuda(), .cpu(), etc. + + Args: + fn (str): The name of the function to apply. + *args (Any): Variable length argument list to pass to the function. + **kwargs (Any): Arbitrary keyword arguments to pass to the function. + + Returns: + (Results): A new Results object with attributes modified by the applied function. + + Examples: + >>> results = model("path/to/image.jpg") + >>> for result in results: + ... result_cuda = result.cuda() + ... result_cpu = result.cpu() + """ + r = self.new() + for k in self._keys: + v = getattr(self, k) + if v is not None: + setattr(r, k, getattr(v, fn)(*args, **kwargs)) + return r + + def cpu(self): + """ + Returns a copy of the Results object with all its tensors moved to CPU memory. + + This method creates a new Results object with all tensor attributes (boxes, masks, probs, keypoints, obb) + transferred to CPU memory. It's useful for moving data from GPU to CPU for further processing or saving. + + Returns: + (Results): A new Results object with all tensor attributes on CPU memory. + + Examples: + >>> results = model("path/to/image.jpg") # Perform inference + >>> cpu_result = results[0].cpu() # Move the first result to CPU + >>> print(cpu_result.boxes.device) # Output: cpu + """ + return self._apply("cpu") + + def numpy(self): + """ + Converts all tensors in the Results object to numpy arrays. + + Returns: + (Results): A new Results object with all tensors converted to numpy arrays. + + Examples: + >>> results = model("path/to/image.jpg") + >>> numpy_result = results[0].numpy() + >>> type(numpy_result.boxes.data) + + + Notes: + This method creates a new Results object, leaving the original unchanged. It's useful for + interoperability with numpy-based libraries or when CPU-based operations are required. + """ + return self._apply("numpy") + + def cuda(self): + """ + Moves all tensors in the Results object to GPU memory. + + Returns: + (Results): A new Results object with all tensors moved to CUDA device. + + Examples: + >>> results = model("path/to/image.jpg") + >>> cuda_results = results[0].cuda() # Move first result to GPU + >>> for result in results: + ... result_cuda = result.cuda() # Move each result to GPU + """ + return self._apply("cuda") + + def to(self, *args, **kwargs): + """ + Moves all tensors in the Results object to the specified device and dtype. + + Args: + *args (Any): Variable length argument list to be passed to torch.Tensor.to(). + **kwargs (Any): Arbitrary keyword arguments to be passed to torch.Tensor.to(). + + Returns: + (Results): A new Results object with all tensors moved to the specified device and dtype. 
+ + Examples: + >>> results = model("path/to/image.jpg") + >>> result_cuda = results[0].to("cuda") # Move first result to GPU + >>> result_cpu = results[0].to("cpu") # Move first result to CPU + >>> result_half = results[0].to(dtype=torch.float16) # Convert first result to half precision + """ + return self._apply("to", *args, **kwargs) + + def new(self): + """ + Creates a new Results object with the same image, path, names, and speed attributes. + + Returns: + (Results): A new Results object with copied attributes from the original instance. + + Examples: + >>> results = model("path/to/image.jpg") + >>> new_result = results[0].new() + """ + return Results(orig_img=self.orig_img, path=self.path, names=self.names, speed=self.speed) + + def plot( + self, + conf=True, + line_width=None, + font_size=None, + font="Arial.ttf", + pil=False, + img=None, + im_gpu=None, + kpt_radius=5, + kpt_line=True, + labels=True, + boxes=True, + masks=True, + probs=True, + show=False, + save=False, + filename=None, + color_mode="class", + ): + """ + Plots detection results on an input RGB image. + + Args: + conf (bool): Whether to plot detection confidence scores. + line_width (float | None): Line width of bounding boxes. If None, scaled to image size. + font_size (float | None): Font size for text. If None, scaled to image size. + font (str): Font to use for text. + pil (bool): Whether to return the image as a PIL Image. + img (np.ndarray | None): Image to plot on. If None, uses original image. + im_gpu (torch.Tensor | None): Normalized image on GPU for faster mask plotting. + kpt_radius (int): Radius of drawn keypoints. + kpt_line (bool): Whether to draw lines connecting keypoints. + labels (bool): Whether to plot labels of bounding boxes. + boxes (bool): Whether to plot bounding boxes. + masks (bool): Whether to plot masks. + probs (bool): Whether to plot classification probabilities. + show (bool): Whether to display the annotated image. + save (bool): Whether to save the annotated image. + filename (str | None): Filename to save image if save is True. + color_mode (bool): Specify the color mode, e.g., 'instance' or 'class'. Default to 'class'. + + Returns: + (np.ndarray): Annotated image as a numpy array. + + Examples: + >>> results = model("image.jpg") + >>> for result in results: + ... im = result.plot() + ... im.show() + """ + assert color_mode in {"instance", "class"}, f"Expected color_mode='instance' or 'class', not {color_mode}." 
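+        # If no override image is provided and the original image is a normalized (B, C, H, W) tensor, convert its
+        # first item to a contiguous HWC uint8 numpy array so the Annotator can draw on it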
+ if img is None and isinstance(self.orig_img, torch.Tensor): + img = (self.orig_img[0].detach().permute(1, 2, 0).contiguous() * 255).to(torch.uint8).cpu().numpy() + + names = self.names + is_obb = self.obb is not None + pred_boxes, show_boxes = self.obb if is_obb else self.boxes, boxes + pred_masks, show_masks = self.masks, masks + pred_probs, show_probs = self.probs, probs + annotator = Annotator( + deepcopy(self.orig_img if img is None else img), + line_width, + font_size, + font, + pil or (pred_probs is not None and show_probs), # Classify tasks default to pil=True + example=names, + ) + + # Plot Segment results + if pred_masks and show_masks: + if im_gpu is None: + img = LetterBox(pred_masks.shape[1:])(image=annotator.result()) + im_gpu = ( + torch.as_tensor(img, dtype=torch.float16, device=pred_masks.data.device) + .permute(2, 0, 1) + .flip(0) + .contiguous() + / 255 + ) + idx = ( + pred_boxes.id + if pred_boxes.id is not None and color_mode == "instance" + else pred_boxes.cls + if pred_boxes and color_mode == "class" + else reversed(range(len(pred_masks))) + ) + annotator.masks(pred_masks.data, colors=[colors(x, True) for x in idx], im_gpu=im_gpu) + + # Plot Detect results + if pred_boxes is not None and show_boxes: + for i, d in enumerate(reversed(pred_boxes)): + c, d_conf, id = int(d.cls), float(d.conf) if conf else None, None if d.id is None else int(d.id.item()) + name = ("" if id is None else f"id:{id} ") + names[c] + label = (f"{name} {d_conf:.2f}" if conf else name) if labels else None + box = d.xyxyxyxy.reshape(-1, 4, 2).squeeze() if is_obb else d.xyxy.squeeze() + annotator.box_label( + box, + label, + color=colors( + c + if color_mode == "class" + else id + if id is not None + else i + if color_mode == "instance" + else None, + True, + ), + rotated=is_obb, + ) + + # Plot Classify results + if pred_probs is not None and show_probs: + text = ",\n".join(f"{names[j] if names else j} {pred_probs.data[j]:.2f}" for j in pred_probs.top5) + x = round(self.orig_shape[0] * 0.03) + annotator.text([x, x], text, txt_color=(255, 255, 255)) # TODO: allow setting colors + + # Plot Pose results + if self.keypoints is not None: + for i, k in enumerate(reversed(self.keypoints.data)): + annotator.kpts( + k, + self.orig_shape, + radius=kpt_radius, + kpt_line=kpt_line, + kpt_color=colors(i, True) if color_mode == "instance" else None, + ) + + # Show results + if show: + annotator.show(self.path) + + # Save results + if save: + annotator.save(filename) + + return annotator.result() + + def show(self, *args, **kwargs): + """ + Display the image with annotated inference results. + + This method plots the detection results on the original image and displays it. It's a convenient way to + visualize the model's predictions directly. + + Args: + *args (Any): Variable length argument list to be passed to the `plot()` method. + **kwargs (Any): Arbitrary keyword arguments to be passed to the `plot()` method. + + Examples: + >>> results = model("path/to/image.jpg") + >>> results[0].show() # Display the first result + >>> for result in results: + ... result.show() # Display all results + """ + self.plot(show=True, *args, **kwargs) + + def save(self, filename=None, *args, **kwargs): + """ + Saves annotated inference results image to file. + + This method plots the detection results on the original image and saves the annotated image to a file. It + utilizes the `plot` method to generate the annotated image and then saves it to the specified filename. 
+ + Args: + filename (str | Path | None): The filename to save the annotated image. If None, a default filename + is generated based on the original image path. + *args (Any): Variable length argument list to be passed to the `plot` method. + **kwargs (Any): Arbitrary keyword arguments to be passed to the `plot` method. + + Examples: + >>> results = model("path/to/image.jpg") + >>> for result in results: + ... result.save("annotated_image.jpg") + >>> # Or with custom plot arguments + >>> for result in results: + ... result.save("annotated_image.jpg", conf=False, line_width=2) + """ + if not filename: + filename = f"results_{Path(self.path).name}" + self.plot(save=True, filename=filename, *args, **kwargs) + return filename + + def verbose(self): + """ + Returns a log string for each task in the results, detailing detection and classification outcomes. + + This method generates a human-readable string summarizing the detection and classification results. It includes + the number of detections for each class and the top probabilities for classification tasks. + + Returns: + (str): A formatted string containing a summary of the results. For detection tasks, it includes the + number of detections per class. For classification tasks, it includes the top 5 class probabilities. + + Examples: + >>> results = model("path/to/image.jpg") + >>> for result in results: + ... print(result.verbose()) + 2 persons, 1 car, 3 traffic lights, + dog 0.92, cat 0.78, horse 0.64, + + Notes: + - If there are no detections, the method returns "(no detections), " for detection tasks. + - For classification tasks, it returns the top 5 class probabilities and their corresponding class names. + - The returned string is comma-separated and ends with a comma and a space. + """ + log_string = "" + probs = self.probs + if len(self) == 0: + return log_string if probs is not None else f"{log_string}(no detections), " + if probs is not None: + log_string += f"{', '.join(f'{self.names[j]} {probs.data[j]:.2f}' for j in probs.top5)}, " + if boxes := self.boxes: + for c in boxes.cls.unique(): + n = (boxes.cls == c).sum() # detections per class + log_string += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " + return log_string + + def save_txt(self, txt_file, save_conf=False): + """ + Save detection results to a text file. + + Args: + txt_file (str | Path): Path to the output text file. + save_conf (bool): Whether to include confidence scores in the output. + + Returns: + (str): Path to the saved text file. + + Examples: + >>> from ultralytics import YOLO + >>> model = YOLO("yolo11n.pt") + >>> results = model("path/to/image.jpg") + >>> for result in results: + ... result.save_txt("output.txt") + + Notes: + - The file will contain one line per detection or classification with the following structure: + - For detections: `class confidence x_center y_center width height` + - For classifications: `confidence class_name` + - For masks and keypoints, the specific formats will vary accordingly. + - The function will create the output directory if it does not exist. + - If save_conf is False, the confidence scores will be excluded from the output. + - Existing contents of the file will not be overwritten; new results will be appended. 
+ """ + is_obb = self.obb is not None + boxes = self.obb if is_obb else self.boxes + masks = self.masks + probs = self.probs + kpts = self.keypoints + texts = [] + if probs is not None: + # Classify + [texts.append(f"{probs.data[j]:.2f} {self.names[j]}") for j in probs.top5] + elif boxes: + # Detect/segment/pose + for j, d in enumerate(boxes): + c, conf, id = int(d.cls), float(d.conf), None if d.id is None else int(d.id.item()) + line = (c, *(d.xyxyxyxyn.view(-1) if is_obb else d.xywhn.view(-1))) + if masks: + seg = masks[j].xyn[0].copy().reshape(-1) # reversed mask.xyn, (n,2) to (n*2) + line = (c, *seg) + if kpts is not None: + kpt = torch.cat((kpts[j].xyn, kpts[j].conf[..., None]), 2) if kpts[j].has_visible else kpts[j].xyn + line += (*kpt.reshape(-1).tolist(),) + line += (conf,) * save_conf + (() if id is None else (id,)) + texts.append(("%g " * len(line)).rstrip() % line) + + if texts: + Path(txt_file).parent.mkdir(parents=True, exist_ok=True) # make directory + with open(txt_file, "a") as f: + f.writelines(text + "\n" for text in texts) + + def save_crop(self, save_dir, file_name=Path("im.jpg")): + """ + Saves cropped detection images to specified directory. + + This method saves cropped images of detected objects to a specified directory. Each crop is saved in a + subdirectory named after the object's class, with the filename based on the input file_name. + + Args: + save_dir (str | Path): Directory path where cropped images will be saved. + file_name (str | Path): Base filename for the saved cropped images. Default is Path("im.jpg"). + + Notes: + - This method does not support Classify or Oriented Bounding Box (OBB) tasks. + - Crops are saved as 'save_dir/class_name/file_name.jpg'. + - The method will create necessary subdirectories if they don't exist. + - Original image is copied before cropping to avoid modifying the original. + + Examples: + >>> results = model("path/to/image.jpg") + >>> for result in results: + ... result.save_crop(save_dir="path/to/crops", file_name="detection") + """ + if self.probs is not None: + LOGGER.warning("WARNING ⚠️ Classify task do not support `save_crop`.") + return + if self.obb is not None: + LOGGER.warning("WARNING ⚠️ OBB task do not support `save_crop`.") + return + for d in self.boxes: + save_one_box( + d.xyxy, + self.orig_img.copy(), + file=Path(save_dir) / self.names[int(d.cls)] / Path(file_name).with_suffix(".jpg"), + BGR=True, + ) + + def summary(self, normalize=False, decimals=5): + """ + Converts inference results to a summarized dictionary with optional normalization for box coordinates. + + This method creates a list of detection dictionaries, each containing information about a single + detection or classification result. For classification tasks, it returns the top class and its + confidence. For detection tasks, it includes class information, bounding box coordinates, and + optionally mask segments and keypoints. + + Args: + normalize (bool): Whether to normalize bounding box coordinates by image dimensions. Defaults to False. + decimals (int): Number of decimal places to round the output values to. Defaults to 5. + + Returns: + (List[Dict]): A list of dictionaries, each containing summarized information for a single + detection or classification result. The structure of each dictionary varies based on the + task type (classification or detection) and available information (boxes, masks, keypoints). 
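+                For detection results, each dictionary contains "name", "class", "confidence", and "box" keys, plus
+                "track_id", "segments", and "keypoints" entries when tracking IDs, masks, or keypoints are available.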
+ + Examples: + >>> results = model("image.jpg") + >>> summary = results[0].summary() + >>> print(summary) + """ + # Create list of detection dictionaries + results = [] + if self.probs is not None: + class_id = self.probs.top1 + results.append( + { + "name": self.names[class_id], + "class": class_id, + "confidence": round(self.probs.top1conf.item(), decimals), + } + ) + return results + + is_obb = self.obb is not None + data = self.obb if is_obb else self.boxes + h, w = self.orig_shape if normalize else (1, 1) + for i, row in enumerate(data): # xyxy, track_id if tracking, conf, class_id + class_id, conf = int(row.cls), round(row.conf.item(), decimals) + box = (row.xyxyxyxy if is_obb else row.xyxy).squeeze().reshape(-1, 2).tolist() + xy = {} + for j, b in enumerate(box): + xy[f"x{j + 1}"] = round(b[0] / w, decimals) + xy[f"y{j + 1}"] = round(b[1] / h, decimals) + result = {"name": self.names[class_id], "class": class_id, "confidence": conf, "box": xy} + if data.is_track: + result["track_id"] = int(row.id.item()) # track ID + if self.masks: + result["segments"] = { + "x": (self.masks.xy[i][:, 0] / w).round(decimals).tolist(), + "y": (self.masks.xy[i][:, 1] / h).round(decimals).tolist(), + } + if self.keypoints is not None: + x, y, visible = self.keypoints[i].data[0].cpu().unbind(dim=1) # torch Tensor + result["keypoints"] = { + "x": (x / w).numpy().round(decimals).tolist(), # decimals named argument required + "y": (y / h).numpy().round(decimals).tolist(), + "visible": visible.numpy().round(decimals).tolist(), + } + results.append(result) + + return results + + def to_df(self, normalize=False, decimals=5): + """ + Converts detection results to a Pandas Dataframe. + + This method converts the detection results into Pandas Dataframe format. It includes information + about detected objects such as bounding boxes, class names, confidence scores, and optionally + segmentation masks and keypoints. + + Args: + normalize (bool): Whether to normalize the bounding box coordinates by the image dimensions. + If True, coordinates will be returned as float values between 0 and 1. Defaults to False. + decimals (int): Number of decimal places to round the output values to. Defaults to 5. + + Returns: + (DataFrame): A Pandas Dataframe containing all the information in results in an organized way. + + Examples: + >>> results = model("path/to/image.jpg") + >>> df_result = results[0].to_df() + >>> print(df_result) + """ + import pandas as pd # scope for faster 'import ultralytics' + + return pd.DataFrame(self.summary(normalize=normalize, decimals=decimals)) + + def to_csv(self, normalize=False, decimals=5, *args, **kwargs): + """ + Converts detection results to a CSV format. + + This method serializes the detection results into a CSV format. It includes information + about detected objects such as bounding boxes, class names, confidence scores, and optionally + segmentation masks and keypoints. + + Args: + normalize (bool): Whether to normalize the bounding box coordinates by the image dimensions. + If True, coordinates will be returned as float values between 0 and 1. Defaults to False. + decimals (int): Number of decimal places to round the output values to. Defaults to 5. + *args (Any): Variable length argument list to be passed to pandas.DataFrame.to_csv(). + **kwargs (Any): Arbitrary keyword arguments to be passed to pandas.DataFrame.to_csv(). + + + Returns: + (str): CSV containing all the information in results in an organized way. 
+
+        Examples:
+            >>> results = model("path/to/image.jpg")
+            >>> csv_result = results[0].to_csv()
+            >>> print(csv_result)
+        """
+        return self.to_df(normalize=normalize, decimals=decimals).to_csv(*args, **kwargs)
+
+    def to_xml(self, normalize=False, decimals=5, *args, **kwargs):
+        """
+        Converts detection results to XML format.
+
+        This method serializes the detection results into an XML format. It includes information
+        about detected objects such as bounding boxes, class names, confidence scores, and optionally
+        segmentation masks and keypoints.
+
+        Args:
+            normalize (bool): Whether to normalize the bounding box coordinates by the image dimensions.
+                If True, coordinates will be returned as float values between 0 and 1. Defaults to False.
+            decimals (int): Number of decimal places to round the output values to. Defaults to 5.
+            *args (Any): Variable length argument list to be passed to pandas.DataFrame.to_xml().
+            **kwargs (Any): Arbitrary keyword arguments to be passed to pandas.DataFrame.to_xml().
+
+        Returns:
+            (str): An XML string containing all the information in results in an organized way.
+
+        Examples:
+            >>> results = model("path/to/image.jpg")
+            >>> xml_result = results[0].to_xml()
+            >>> print(xml_result)
+        """
+        check_requirements("lxml")
+        df = self.to_df(normalize=normalize, decimals=decimals)
+        return '<?xml version="1.0" encoding="utf-8"?>\n<root></root>\n' if df.empty else df.to_xml(*args, **kwargs)
+
+    def tojson(self, normalize=False, decimals=5):
+        """Deprecated version of to_json()."""
+        LOGGER.warning("WARNING ⚠️ 'result.tojson()' is deprecated, replace with 'result.to_json()'.")
+        return self.to_json(normalize, decimals)
+
+    def to_json(self, normalize=False, decimals=5):
+        """
+        Converts detection results to JSON format.
+
+        This method serializes the detection results into a JSON-compatible format. It includes information
+        about detected objects such as bounding boxes, class names, confidence scores, and optionally
+        segmentation masks and keypoints.
+
+        Args:
+            normalize (bool): Whether to normalize the bounding box coordinates by the image dimensions.
+                If True, coordinates will be returned as float values between 0 and 1. Defaults to False.
+            decimals (int): Number of decimal places to round the output values to. Defaults to 5.
+
+        Returns:
+            (str): A JSON string containing the serialized detection results.
+
+        Examples:
+            >>> results = model("path/to/image.jpg")
+            >>> json_result = results[0].to_json()
+            >>> print(json_result)
+
+        Notes:
+            - For classification tasks, the JSON will contain class probabilities instead of bounding boxes.
+            - For object detection tasks, the JSON will include bounding box coordinates, class names, and
+              confidence scores.
+            - If available, segmentation masks and keypoints will also be included in the JSON output.
+            - The method uses the `summary` method internally to generate the data structure before
+              converting it to JSON.
+        """
+        import json
+
+        return json.dumps(self.summary(normalize=normalize, decimals=decimals), indent=2)
+
+
+class Boxes(BaseTensor):
+    """
+    A class for managing and manipulating detection boxes.
+
+    This class provides functionality for handling detection boxes, including their coordinates, confidence scores,
+    class labels, and optional tracking IDs. It supports various box formats and offers methods for easy manipulation
+    and conversion between different coordinate systems.
+
+    Attributes:
+        data (torch.Tensor | numpy.ndarray): The raw tensor containing detection boxes and associated data.
+ orig_shape (Tuple[int, int]): The original image dimensions (height, width). + is_track (bool): Indicates whether tracking IDs are included in the box data. + xyxy (torch.Tensor | numpy.ndarray): Boxes in [x1, y1, x2, y2] format. + conf (torch.Tensor | numpy.ndarray): Confidence scores for each box. + cls (torch.Tensor | numpy.ndarray): Class labels for each box. + id (torch.Tensor | numpy.ndarray): Tracking IDs for each box (if available). + xywh (torch.Tensor | numpy.ndarray): Boxes in [x, y, width, height] format. + xyxyn (torch.Tensor | numpy.ndarray): Normalized [x1, y1, x2, y2] boxes relative to orig_shape. + xywhn (torch.Tensor | numpy.ndarray): Normalized [x, y, width, height] boxes relative to orig_shape. + + Methods: + cpu(): Returns a copy of the object with all tensors on CPU memory. + numpy(): Returns a copy of the object with all tensors as numpy arrays. + cuda(): Returns a copy of the object with all tensors on GPU memory. + to(*args, **kwargs): Returns a copy of the object with tensors on specified device and dtype. + + Examples: + >>> import torch + >>> boxes_data = torch.tensor([[100, 50, 150, 100, 0.9, 0], [200, 150, 300, 250, 0.8, 1]]) + >>> orig_shape = (480, 640) # height, width + >>> boxes = Boxes(boxes_data, orig_shape) + >>> print(boxes.xyxy) + >>> print(boxes.conf) + >>> print(boxes.cls) + >>> print(boxes.xywhn) + """ + + def __init__(self, boxes, orig_shape) -> None: + """ + Initialize the Boxes class with detection box data and the original image shape. + + This class manages detection boxes, providing easy access and manipulation of box coordinates, + confidence scores, class identifiers, and optional tracking IDs. It supports multiple formats + for box coordinates, including both absolute and normalized forms. + + Args: + boxes (torch.Tensor | np.ndarray): A tensor or numpy array with detection boxes of shape + (num_boxes, 6) or (num_boxes, 7). Columns should contain + [x1, y1, x2, y2, confidence, class, (optional) track_id]. + orig_shape (Tuple[int, int]): The original image shape as (height, width). Used for normalization. + + Attributes: + data (torch.Tensor): The raw tensor containing detection boxes and their associated data. + orig_shape (Tuple[int, int]): The original image size, used for normalization. + is_track (bool): Indicates whether tracking IDs are included in the box data. + + Examples: + >>> import torch + >>> boxes = torch.tensor([[100, 50, 150, 100, 0.9, 0]]) + >>> orig_shape = (480, 640) + >>> detection_boxes = Boxes(boxes, orig_shape) + >>> print(detection_boxes.xyxy) + tensor([[100., 50., 150., 100.]]) + """ + if boxes.ndim == 1: + boxes = boxes[None, :] + n = boxes.shape[-1] + assert n in {6, 7}, f"expected 6 or 7 values but got {n}" # xyxy, track_id, conf, cls + super().__init__(boxes, orig_shape) + self.is_track = n == 7 + self.orig_shape = orig_shape + + @property + def xyxy(self): + """ + Returns bounding boxes in [x1, y1, x2, y2] format. + + Returns: + (torch.Tensor | numpy.ndarray): A tensor or numpy array of shape (n, 4) containing bounding box + coordinates in [x1, y1, x2, y2] format, where n is the number of boxes. + + Examples: + >>> results = model("image.jpg") + >>> boxes = results[0].boxes + >>> xyxy = boxes.xyxy + >>> print(xyxy) + """ + return self.data[:, :4] + + @property + def conf(self): + """ + Returns the confidence scores for each detection box. + + Returns: + (torch.Tensor | numpy.ndarray): A 1D tensor or array containing confidence scores for each detection, + with shape (N,) where N is the number of detections. 
+ + Examples: + >>> boxes = Boxes(torch.tensor([[10, 20, 30, 40, 0.9, 0]]), orig_shape=(100, 100)) + >>> conf_scores = boxes.conf + >>> print(conf_scores) + tensor([0.9000]) + """ + return self.data[:, -2] + + @property + def cls(self): + """ + Returns the class ID tensor representing category predictions for each bounding box. + + Returns: + (torch.Tensor | numpy.ndarray): A tensor or numpy array containing the class IDs for each detection box. + The shape is (N,), where N is the number of boxes. + + Examples: + >>> results = model("image.jpg") + >>> boxes = results[0].boxes + >>> class_ids = boxes.cls + >>> print(class_ids) # tensor([0., 2., 1.]) + """ + return self.data[:, -1] + + @property + def id(self): + """ + Returns the tracking IDs for each detection box if available. + + Returns: + (torch.Tensor | None): A tensor containing tracking IDs for each box if tracking is enabled, + otherwise None. Shape is (N,) where N is the number of boxes. + + Examples: + >>> results = model.track("path/to/video.mp4") + >>> for result in results: + ... boxes = result.boxes + ... if boxes.is_track: + ... track_ids = boxes.id + ... print(f"Tracking IDs: {track_ids}") + ... else: + ... print("Tracking is not enabled for these boxes.") + + Notes: + - This property is only available when tracking is enabled (i.e., when `is_track` is True). + - The tracking IDs are typically used to associate detections across multiple frames in video analysis. + """ + return self.data[:, -3] if self.is_track else None + + @property + @lru_cache(maxsize=2) # maxsize 1 should suffice + def xywh(self): + """ + Convert bounding boxes from [x1, y1, x2, y2] format to [x, y, width, height] format. + + Returns: + (torch.Tensor | numpy.ndarray): Boxes in [x_center, y_center, width, height] format, where x_center, y_center are the coordinates of + the center point of the bounding box, width, height are the dimensions of the bounding box and the + shape of the returned tensor is (N, 4), where N is the number of boxes. + + Examples: + >>> boxes = Boxes(torch.tensor([[100, 50, 150, 100], [200, 150, 300, 250]]), orig_shape=(480, 640)) + >>> xywh = boxes.xywh + >>> print(xywh) + tensor([[100.0000, 50.0000, 50.0000, 50.0000], + [200.0000, 150.0000, 100.0000, 100.0000]]) + """ + return ops.xyxy2xywh(self.xyxy) + + @property + @lru_cache(maxsize=2) + def xyxyn(self): + """ + Returns normalized bounding box coordinates relative to the original image size. + + This property calculates and returns the bounding box coordinates in [x1, y1, x2, y2] format, + normalized to the range [0, 1] based on the original image dimensions. + + Returns: + (torch.Tensor | numpy.ndarray): Normalized bounding box coordinates with shape (N, 4), where N is + the number of boxes. Each row contains [x1, y1, x2, y2] values normalized to [0, 1]. + + Examples: + >>> boxes = Boxes(torch.tensor([[100, 50, 300, 400, 0.9, 0]]), orig_shape=(480, 640)) + >>> normalized = boxes.xyxyn + >>> print(normalized) + tensor([[0.1562, 0.1042, 0.4688, 0.8333]]) + """ + xyxy = self.xyxy.clone() if isinstance(self.xyxy, torch.Tensor) else np.copy(self.xyxy) + xyxy[..., [0, 2]] /= self.orig_shape[1] + xyxy[..., [1, 3]] /= self.orig_shape[0] + return xyxy + + @property + @lru_cache(maxsize=2) + def xywhn(self): + """ + Returns normalized bounding boxes in [x, y, width, height] format. + + This property calculates and returns the normalized bounding box coordinates in the format + [x_center, y_center, width, height], where all values are relative to the original image dimensions. 
+ + Returns: + (torch.Tensor | numpy.ndarray): Normalized bounding boxes with shape (N, 4), where N is the + number of boxes. Each row contains [x_center, y_center, width, height] values normalized + to [0, 1] based on the original image dimensions. + + Examples: + >>> boxes = Boxes(torch.tensor([[100, 50, 150, 100, 0.9, 0]]), orig_shape=(480, 640)) + >>> normalized = boxes.xywhn + >>> print(normalized) + tensor([[0.1953, 0.1562, 0.0781, 0.1042]]) + """ + xywh = ops.xyxy2xywh(self.xyxy) + xywh[..., [0, 2]] /= self.orig_shape[1] + xywh[..., [1, 3]] /= self.orig_shape[0] + return xywh + + +class Masks(BaseTensor): + """ + A class for storing and manipulating detection masks. + + This class extends BaseTensor and provides functionality for handling segmentation masks, + including methods for converting between pixel and normalized coordinates. + + Attributes: + data (torch.Tensor | numpy.ndarray): The raw tensor or array containing mask data. + orig_shape (tuple): Original image shape in (height, width) format. + xy (List[numpy.ndarray]): A list of segments in pixel coordinates. + xyn (List[numpy.ndarray]): A list of normalized segments. + + Methods: + cpu(): Returns a copy of the Masks object with the mask tensor on CPU memory. + numpy(): Returns a copy of the Masks object with the mask tensor as a numpy array. + cuda(): Returns a copy of the Masks object with the mask tensor on GPU memory. + to(*args, **kwargs): Returns a copy of the Masks object with the mask tensor on specified device and dtype. + + Examples: + >>> masks_data = torch.rand(1, 160, 160) + >>> orig_shape = (720, 1280) + >>> masks = Masks(masks_data, orig_shape) + >>> pixel_coords = masks.xy + >>> normalized_coords = masks.xyn + """ + + def __init__(self, masks, orig_shape) -> None: + """ + Initialize the Masks class with detection mask data and the original image shape. + + Args: + masks (torch.Tensor | np.ndarray): Detection masks with shape (num_masks, height, width). + orig_shape (tuple): The original image shape as (height, width). Used for normalization. + + Examples: + >>> import torch + >>> from ultralytics.engine.results import Masks + >>> masks = torch.rand(10, 160, 160) # 10 masks of 160x160 resolution + >>> orig_shape = (720, 1280) # Original image shape + >>> mask_obj = Masks(masks, orig_shape) + """ + if masks.ndim == 2: + masks = masks[None, :] + super().__init__(masks, orig_shape) + + @property + @lru_cache(maxsize=1) + def xyn(self): + """ + Returns normalized xy-coordinates of the segmentation masks. + + This property calculates and caches the normalized xy-coordinates of the segmentation masks. The coordinates + are normalized relative to the original image shape. + + Returns: + (List[numpy.ndarray]): A list of numpy arrays, where each array contains the normalized xy-coordinates + of a single segmentation mask. Each array has shape (N, 2), where N is the number of points in the + mask contour. + + Examples: + >>> results = model("image.jpg") + >>> masks = results[0].masks + >>> normalized_coords = masks.xyn + >>> print(normalized_coords[0]) # Normalized coordinates of the first mask + """ + return [ + ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=True) + for x in ops.masks2segments(self.data) + ] + + @property + @lru_cache(maxsize=1) + def xy(self): + """ + Returns the [x, y] pixel coordinates for each segment in the mask tensor. + + This property calculates and returns a list of pixel coordinates for each segmentation mask in the + Masks object. 
The coordinates are scaled to match the original image dimensions. + + Returns: + (List[numpy.ndarray]): A list of numpy arrays, where each array contains the [x, y] pixel + coordinates for a single segmentation mask. Each array has shape (N, 2), where N is the + number of points in the segment. + + Examples: + >>> results = model("image.jpg") + >>> masks = results[0].masks + >>> xy_coords = masks.xy + >>> print(len(xy_coords)) # Number of masks + >>> print(xy_coords[0].shape) # Shape of first mask's coordinates + """ + return [ + ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=False) + for x in ops.masks2segments(self.data) + ] + + +class Keypoints(BaseTensor): + """ + A class for storing and manipulating detection keypoints. + + This class encapsulates functionality for handling keypoint data, including coordinate manipulation, + normalization, and confidence values. + + Attributes: + data (torch.Tensor): The raw tensor containing keypoint data. + orig_shape (Tuple[int, int]): The original image dimensions (height, width). + has_visible (bool): Indicates whether visibility information is available for keypoints. + xy (torch.Tensor): Keypoint coordinates in [x, y] format. + xyn (torch.Tensor): Normalized keypoint coordinates in [x, y] format, relative to orig_shape. + conf (torch.Tensor): Confidence values for each keypoint, if available. + + Methods: + cpu(): Returns a copy of the keypoints tensor on CPU memory. + numpy(): Returns a copy of the keypoints tensor as a numpy array. + cuda(): Returns a copy of the keypoints tensor on GPU memory. + to(*args, **kwargs): Returns a copy of the keypoints tensor with specified device and dtype. + + Examples: + >>> import torch + >>> from ultralytics.engine.results import Keypoints + >>> keypoints_data = torch.rand(1, 17, 3) # 1 detection, 17 keypoints, (x, y, conf) + >>> orig_shape = (480, 640) # Original image shape (height, width) + >>> keypoints = Keypoints(keypoints_data, orig_shape) + >>> print(keypoints.xy.shape) # Access xy coordinates + >>> print(keypoints.conf) # Access confidence values + >>> keypoints_cpu = keypoints.cpu() # Move keypoints to CPU + """ + + @smart_inference_mode() # avoid keypoints < conf in-place error + def __init__(self, keypoints, orig_shape) -> None: + """ + Initializes the Keypoints object with detection keypoints and original image dimensions. + + This method processes the input keypoints tensor, handling both 2D and 3D formats. For 3D tensors + (x, y, confidence), it masks out low-confidence keypoints by setting their coordinates to zero. + + Args: + keypoints (torch.Tensor): A tensor containing keypoint data. Shape can be either: + - (num_objects, num_keypoints, 2) for x, y coordinates only + - (num_objects, num_keypoints, 3) for x, y coordinates and confidence scores + orig_shape (Tuple[int, int]): The original image dimensions (height, width). + + Examples: + >>> kpts = torch.rand(1, 17, 3) # 1 object, 17 keypoints (COCO format), x,y,conf + >>> orig_shape = (720, 1280) # Original image height, width + >>> keypoints = Keypoints(kpts, orig_shape) + """ + if keypoints.ndim == 2: + keypoints = keypoints[None, :] + if keypoints.shape[2] == 3: # x, y, conf + mask = keypoints[..., 2] < 0.5 # points with conf < 0.5 (not visible) + keypoints[..., :2][mask] = 0 + super().__init__(keypoints, orig_shape) + self.has_visible = self.data.shape[-1] == 3 + + @property + @lru_cache(maxsize=1) + def xy(self): + """ + Returns x, y coordinates of keypoints. 
+ + Returns: + (torch.Tensor): A tensor containing the x, y coordinates of keypoints with shape (N, K, 2), where N is + the number of detections and K is the number of keypoints per detection. + + Examples: + >>> results = model("image.jpg") + >>> keypoints = results[0].keypoints + >>> xy = keypoints.xy + >>> print(xy.shape) # (N, K, 2) + >>> print(xy[0]) # x, y coordinates of keypoints for first detection + + Notes: + - The returned coordinates are in pixel units relative to the original image dimensions. + - If keypoints were initialized with confidence values, only keypoints with confidence >= 0.5 are returned. + - This property uses LRU caching to improve performance on repeated access. + """ + return self.data[..., :2] + + @property + @lru_cache(maxsize=1) + def xyn(self): + """ + Returns normalized coordinates (x, y) of keypoints relative to the original image size. + + Returns: + (torch.Tensor | numpy.ndarray): A tensor or array of shape (N, K, 2) containing normalized keypoint + coordinates, where N is the number of instances, K is the number of keypoints, and the last + dimension contains [x, y] values in the range [0, 1]. + + Examples: + >>> keypoints = Keypoints(torch.rand(1, 17, 2), orig_shape=(480, 640)) + >>> normalized_kpts = keypoints.xyn + >>> print(normalized_kpts.shape) + torch.Size([1, 17, 2]) + """ + xy = self.xy.clone() if isinstance(self.xy, torch.Tensor) else np.copy(self.xy) + xy[..., 0] /= self.orig_shape[1] + xy[..., 1] /= self.orig_shape[0] + return xy + + @property + @lru_cache(maxsize=1) + def conf(self): + """ + Returns confidence values for each keypoint. + + Returns: + (torch.Tensor | None): A tensor containing confidence scores for each keypoint if available, + otherwise None. Shape is (num_detections, num_keypoints) for batched data or (num_keypoints,) + for single detection. + + Examples: + >>> keypoints = Keypoints(torch.rand(1, 17, 3), orig_shape=(640, 640)) # 1 detection, 17 keypoints + >>> conf = keypoints.conf + >>> print(conf.shape) # torch.Size([1, 17]) + """ + return self.data[..., 2] if self.has_visible else None + + +class Probs(BaseTensor): + """ + A class for storing and manipulating classification probabilities. + + This class extends BaseTensor and provides methods for accessing and manipulating + classification probabilities, including top-1 and top-5 predictions. + + Attributes: + data (torch.Tensor | numpy.ndarray): The raw tensor or array containing classification probabilities. + orig_shape (tuple | None): The original image shape as (height, width). Not used in this class. + top1 (int): Index of the class with the highest probability. + top5 (List[int]): Indices of the top 5 classes by probability. + top1conf (torch.Tensor | numpy.ndarray): Confidence score of the top 1 class. + top5conf (torch.Tensor | numpy.ndarray): Confidence scores of the top 5 classes. + + Methods: + cpu(): Returns a copy of the probabilities tensor on CPU memory. + numpy(): Returns a copy of the probabilities tensor as a numpy array. + cuda(): Returns a copy of the probabilities tensor on GPU memory. + to(*args, **kwargs): Returns a copy of the probabilities tensor with specified device and dtype. + + Examples: + >>> probs = torch.tensor([0.1, 0.3, 0.6]) + >>> p = Probs(probs) + >>> print(p.top1) + 2 + >>> print(p.top5) + [2, 1, 0] + >>> print(p.top1conf) + tensor(0.6000) + >>> print(p.top5conf) + tensor([0.6000, 0.3000, 0.1000]) + """ + + def __init__(self, probs, orig_shape=None) -> None: + """ + Initialize the Probs class with classification probabilities. 
+ + This class stores and manages classification probabilities, providing easy access to top predictions and their + confidences. + + Args: + probs (torch.Tensor | np.ndarray): A 1D tensor or array of classification probabilities. + orig_shape (tuple | None): The original image shape as (height, width). Not used in this class but kept for + consistency with other result classes. + + Attributes: + data (torch.Tensor | np.ndarray): The raw tensor or array containing classification probabilities. + top1 (int): Index of the top 1 class. + top5 (List[int]): Indices of the top 5 classes. + top1conf (torch.Tensor | np.ndarray): Confidence of the top 1 class. + top5conf (torch.Tensor | np.ndarray): Confidences of the top 5 classes. + + Examples: + >>> import torch + >>> probs = torch.tensor([0.1, 0.3, 0.2, 0.4]) + >>> p = Probs(probs) + >>> print(p.top1) + 3 + >>> print(p.top1conf) + tensor(0.4000) + >>> print(p.top5) + [3, 1, 2, 0] + """ + super().__init__(probs, orig_shape) + + @property + @lru_cache(maxsize=1) + def top1(self): + """ + Returns the index of the class with the highest probability. + + Returns: + (int): Index of the class with the highest probability. + + Examples: + >>> probs = Probs(torch.tensor([0.1, 0.3, 0.6])) + >>> probs.top1 + 2 + """ + return int(self.data.argmax()) + + @property + @lru_cache(maxsize=1) + def top5(self): + """ + Returns the indices of the top 5 class probabilities. + + Returns: + (List[int]): A list containing the indices of the top 5 class probabilities, sorted in descending order. + + Examples: + >>> probs = Probs(torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5])) + >>> print(probs.top5) + [4, 3, 2, 1, 0] + """ + return (-self.data).argsort(0)[:5].tolist() # this way works with both torch and numpy. + + @property + @lru_cache(maxsize=1) + def top1conf(self): + """ + Returns the confidence score of the highest probability class. + + This property retrieves the confidence score (probability) of the class with the highest predicted probability + from the classification results. + + Returns: + (torch.Tensor | numpy.ndarray): A tensor containing the confidence score of the top 1 class. + + Examples: + >>> results = model("image.jpg") # classify an image + >>> probs = results[0].probs # get classification probabilities + >>> top1_confidence = probs.top1conf # get confidence of top 1 class + >>> print(f"Top 1 class confidence: {top1_confidence.item():.4f}") + """ + return self.data[self.top1] + + @property + @lru_cache(maxsize=1) + def top5conf(self): + """ + Returns confidence scores for the top 5 classification predictions. + + This property retrieves the confidence scores corresponding to the top 5 class probabilities + predicted by the model. It provides a quick way to access the most likely class predictions + along with their associated confidence levels. + + Returns: + (torch.Tensor | numpy.ndarray): A tensor or array containing the confidence scores for the + top 5 predicted classes, sorted in descending order of probability. + + Examples: + >>> results = model("image.jpg") + >>> probs = results[0].probs + >>> top5_conf = probs.top5conf + >>> print(top5_conf) # Prints confidence scores for top 5 classes + """ + return self.data[self.top5] + + +class OBB(BaseTensor): + """ + A class for storing and manipulating Oriented Bounding Boxes (OBB). + + This class provides functionality to handle oriented bounding boxes, including conversion between + different formats, normalization, and access to various properties of the boxes. 
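+
+    Each row of the underlying tensor stores [x_center, y_center, width, height, rotation, (track_id), confidence,
+    class], which is why 7 values per box (no tracking) or 8 values per box (with tracking) are accepted.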
+ + Attributes: + data (torch.Tensor): The raw OBB tensor containing box coordinates and associated data. + orig_shape (tuple): Original image size as (height, width). + is_track (bool): Indicates whether tracking IDs are included in the box data. + xywhr (torch.Tensor | numpy.ndarray): Boxes in [x_center, y_center, width, height, rotation] format. + conf (torch.Tensor | numpy.ndarray): Confidence scores for each box. + cls (torch.Tensor | numpy.ndarray): Class labels for each box. + id (torch.Tensor | numpy.ndarray): Tracking IDs for each box, if available. + xyxyxyxy (torch.Tensor | numpy.ndarray): Boxes in 8-point [x1, y1, x2, y2, x3, y3, x4, y4] format. + xyxyxyxyn (torch.Tensor | numpy.ndarray): Normalized 8-point coordinates relative to orig_shape. + xyxy (torch.Tensor | numpy.ndarray): Axis-aligned bounding boxes in [x1, y1, x2, y2] format. + + Methods: + cpu(): Returns a copy of the OBB object with all tensors on CPU memory. + numpy(): Returns a copy of the OBB object with all tensors as numpy arrays. + cuda(): Returns a copy of the OBB object with all tensors on GPU memory. + to(*args, **kwargs): Returns a copy of the OBB object with tensors on specified device and dtype. + + Examples: + >>> boxes = torch.tensor([[100, 50, 150, 100, 30, 0.9, 0]]) # xywhr, conf, cls + >>> obb = OBB(boxes, orig_shape=(480, 640)) + >>> print(obb.xyxyxyxy) + >>> print(obb.conf) + >>> print(obb.cls) + """ + + def __init__(self, boxes, orig_shape) -> None: + """ + Initialize an OBB (Oriented Bounding Box) instance with oriented bounding box data and original image shape. + + This class stores and manipulates Oriented Bounding Boxes (OBB) for object detection tasks. It provides + various properties and methods to access and transform the OBB data. + + Args: + boxes (torch.Tensor | numpy.ndarray): A tensor or numpy array containing the detection boxes, + with shape (num_boxes, 7) or (num_boxes, 8). The last two columns contain confidence and class values. + If present, the third last column contains track IDs, and the fifth column contains rotation. + orig_shape (Tuple[int, int]): Original image size, in the format (height, width). + + Attributes: + data (torch.Tensor | numpy.ndarray): The raw OBB tensor. + orig_shape (Tuple[int, int]): The original image shape. + is_track (bool): Whether the boxes include tracking IDs. + + Raises: + AssertionError: If the number of values per box is not 7 or 8. + + Examples: + >>> import torch + >>> boxes = torch.rand(3, 7) # 3 boxes with 7 values each + >>> orig_shape = (640, 480) + >>> obb = OBB(boxes, orig_shape) + >>> print(obb.xywhr) # Access the boxes in xywhr format + """ + if boxes.ndim == 1: + boxes = boxes[None, :] + n = boxes.shape[-1] + assert n in {7, 8}, f"expected 7 or 8 values but got {n}" # xywh, rotation, track_id, conf, cls + super().__init__(boxes, orig_shape) + self.is_track = n == 8 + self.orig_shape = orig_shape + + @property + def xywhr(self): + """ + Returns boxes in [x_center, y_center, width, height, rotation] format. + + Returns: + (torch.Tensor | numpy.ndarray): A tensor or numpy array containing the oriented bounding boxes with format + [x_center, y_center, width, height, rotation]. The shape is (N, 5) where N is the number of boxes. + + Examples: + >>> results = model("image.jpg") + >>> obb = results[0].obb + >>> xywhr = obb.xywhr + >>> print(xywhr.shape) + torch.Size([3, 5]) + """ + return self.data[:, :5] + + @property + def conf(self): + """ + Returns the confidence scores for Oriented Bounding Boxes (OBBs). 
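+
+        A minimal illustration with a single 7-value box [x, y, w, h, r, conf, cls] (values are arbitrary):
+
+        >>> import torch
+        >>> OBB(torch.tensor([[100.0, 50.0, 40.0, 20.0, 0.3, 0.9, 0.0]]), orig_shape=(480, 640)).conf
+        tensor([0.9000])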
+ + This property retrieves the confidence values associated with each OBB detection. The confidence score + represents the model's certainty in the detection. + + Returns: + (torch.Tensor | numpy.ndarray): A tensor or numpy array of shape (N,) containing confidence scores + for N detections, where each score is in the range [0, 1]. + + Examples: + >>> results = model("image.jpg") + >>> obb_result = results[0].obb + >>> confidence_scores = obb_result.conf + >>> print(confidence_scores) + """ + return self.data[:, -2] + + @property + def cls(self): + """ + Returns the class values of the oriented bounding boxes. + + Returns: + (torch.Tensor | numpy.ndarray): A tensor or numpy array containing the class values for each oriented + bounding box. The shape is (N,), where N is the number of boxes. + + Examples: + >>> results = model("image.jpg") + >>> result = results[0] + >>> obb = result.obb + >>> class_values = obb.cls + >>> print(class_values) + """ + return self.data[:, -1] + + @property + def id(self): + """ + Returns the tracking IDs of the oriented bounding boxes (if available). + + Returns: + (torch.Tensor | numpy.ndarray | None): A tensor or numpy array containing the tracking IDs for each + oriented bounding box. Returns None if tracking IDs are not available. + + Examples: + >>> results = model("image.jpg", tracker=True) # Run inference with tracking + >>> for result in results: + ... if result.obb is not None: + ... track_ids = result.obb.id + ... if track_ids is not None: + ... print(f"Tracking IDs: {track_ids}") + """ + return self.data[:, -3] if self.is_track else None + + @property + @lru_cache(maxsize=2) + def xyxyxyxy(self): + """ + Converts OBB format to 8-point (xyxyxyxy) coordinate format for rotated bounding boxes. + + Returns: + (torch.Tensor | numpy.ndarray): Rotated bounding boxes in xyxyxyxy format with shape (N, 4, 2), where N is + the number of boxes. Each box is represented by 4 points (x, y), starting from the top-left corner and + moving clockwise. + + Examples: + >>> obb = OBB(torch.tensor([[100, 100, 50, 30, 0.5, 0.9, 0]]), orig_shape=(640, 640)) + >>> xyxyxyxy = obb.xyxyxyxy + >>> print(xyxyxyxy.shape) + torch.Size([1, 4, 2]) + """ + return ops.xywhr2xyxyxyxy(self.xywhr) + + @property + @lru_cache(maxsize=2) + def xyxyxyxyn(self): + """ + Converts rotated bounding boxes to normalized xyxyxyxy format. + + Returns: + (torch.Tensor | numpy.ndarray): Normalized rotated bounding boxes in xyxyxyxy format with shape (N, 4, 2), + where N is the number of boxes. Each box is represented by 4 points (x, y), normalized relative to + the original image dimensions. + + Examples: + >>> obb = OBB(torch.rand(10, 7), orig_shape=(640, 480)) # 10 random OBBs + >>> normalized_boxes = obb.xyxyxyxyn + >>> print(normalized_boxes.shape) + torch.Size([10, 4, 2]) + """ + xyxyxyxyn = self.xyxyxyxy.clone() if isinstance(self.xyxyxyxy, torch.Tensor) else np.copy(self.xyxyxyxy) + xyxyxyxyn[..., 0] /= self.orig_shape[1] + xyxyxyxyn[..., 1] /= self.orig_shape[0] + return xyxyxyxyn + + @property + @lru_cache(maxsize=2) + def xyxy(self): + """ + Converts oriented bounding boxes (OBB) to axis-aligned bounding boxes in xyxy format. + + This property calculates the minimal enclosing rectangle for each oriented bounding box and returns it in + xyxy format (x1, y1, x2, y2). This is useful for operations that require axis-aligned bounding boxes, such + as IoU calculation with non-rotated boxes. 
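+
+        As an illustrative piece of arithmetic: a 10x10 box centered at (50, 50) and rotated by 45 degrees has a
+        half-diagonal of 5 * sqrt(2) ≈ 7.07, so its minimal enclosing axis-aligned box is approximately
+        [42.93, 42.93, 57.07, 57.07].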
+ + Returns: + (torch.Tensor | numpy.ndarray): Axis-aligned bounding boxes in xyxy format with shape (N, 4), where N + is the number of boxes. Each row contains [x1, y1, x2, y2] coordinates. + + Examples: + >>> import torch + >>> from ultralytics import YOLO + >>> model = YOLO("yolov8n-obb.pt") + >>> results = model("path/to/image.jpg") + >>> for result in results: + ... obb = result.obb + ... if obb is not None: + ... xyxy_boxes = obb.xyxy + ... print(xyxy_boxes.shape) # (N, 4) + + Notes: + - This method approximates the OBB by its minimal enclosing rectangle. + - The returned format is compatible with standard object detection metrics and visualization tools. + - The property uses caching to improve performance for repeated access. + """ + x = self.xyxyxyxy[..., 0] + y = self.xyxyxyxy[..., 1] + return ( + torch.stack([x.amin(1), y.amin(1), x.amax(1), y.amax(1)], -1) + if isinstance(x, torch.Tensor) + else np.stack([x.min(1), y.min(1), x.max(1), y.max(1)], -1) + ) diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..a373cd82528a8ba2e4b742245b1829aba960be47 --- /dev/null +++ b/ultralytics/engine/trainer.py @@ -0,0 +1,820 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +""" +Train a model on a dataset. + +Usage: + $ yolo mode=train model=yolov8n.pt data=coco8.yaml imgsz=640 epochs=100 batch=16 +""" + +import gc +import math +import os +import subprocess +import time +import warnings +from copy import copy, deepcopy +from datetime import datetime, timedelta +from pathlib import Path + +import numpy as np +import torch +from torch import distributed as dist +from torch import nn, optim + +from ultralytics.cfg import get_cfg, get_save_dir +from ultralytics.data.utils import check_cls_dataset, check_det_dataset +from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights +from ultralytics.utils import ( + DEFAULT_CFG, + LOCAL_RANK, + LOGGER, + RANK, + TQDM, + __version__, + callbacks, + clean_url, + colorstr, + emojis, + yaml_save, +) +from ultralytics.utils.autobatch import check_train_batch_size +from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args +from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command +from ultralytics.utils.files import get_latest_run +from ultralytics.utils.torch_utils import ( + TORCH_2_4, + EarlyStopping, + ModelEMA, + autocast, + convert_optimizer_state_dict_to_fp16, + init_seeds, + one_cycle, + select_device, + strip_optimizer, + torch_distributed_zero_first, +) + + +class BaseTrainer: + """ + A base class for creating trainers. + + Attributes: + args (SimpleNamespace): Configuration for the trainer. + validator (BaseValidator): Validator instance. + model (nn.Module): Model instance. + callbacks (defaultdict): Dictionary of callbacks. + save_dir (Path): Directory to save results. + wdir (Path): Directory to save weights. + last (Path): Path to the last checkpoint. + best (Path): Path to the best checkpoint. + save_period (int): Save checkpoint every x epochs (disabled if < 1). + batch_size (int): Batch size for training. + epochs (int): Number of epochs to train for. + start_epoch (int): Starting epoch for training. + device (torch.device): Device to use for training. + amp (bool): Flag to enable AMP (Automatic Mixed Precision). + scaler (amp.GradScaler): Gradient scaler for AMP. + data (str): Path to data. + trainset (torch.utils.data.Dataset): Training dataset. 
+ testset (torch.utils.data.Dataset): Testing dataset. + ema (nn.Module): EMA (Exponential Moving Average) of the model. + resume (bool): Resume training from a checkpoint. + lf (nn.Module): Loss function. + scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler. + best_fitness (float): The best fitness value achieved. + fitness (float): Current fitness value. + loss (float): Current loss value. + tloss (float): Total loss value. + loss_names (list): List of loss names. + csv (Path): Path to results CSV file. + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """ + Initializes the BaseTrainer class. + + Args: + cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG. + overrides (dict, optional): Configuration overrides. Defaults to None. + """ + self.args = get_cfg(cfg, overrides) + self.check_resume(overrides) + self.device = select_device(self.args.device, self.args.batch) + self.validator = None + self.metrics = None + self.plots = {} + init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic) + + # Dirs + self.save_dir = get_save_dir(self.args) + self.args.name = self.save_dir.name # update name for loggers + self.wdir = self.save_dir / "weights" # weights dir + if RANK in {-1, 0}: + self.wdir.mkdir(parents=True, exist_ok=True) # make dir + self.args.save_dir = str(self.save_dir) + yaml_save(self.save_dir / "args.yaml", vars(self.args)) # save run args + self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt" # checkpoint paths + self.save_period = self.args.save_period + + self.batch_size = self.args.batch + self.epochs = self.args.epochs or 100 # in case users accidentally pass epochs=None with timed training + self.start_epoch = 0 + if RANK == -1: + print_args(vars(self.args)) + + # Device + if self.device.type in {"cpu", "mps"}: + self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading + + # Model and Dataset + self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. yolov8n -> yolov8n.pt + with torch_distributed_zero_first(LOCAL_RANK): # avoid auto-downloading dataset multiple times + self.trainset, self.testset = self.get_dataset() + self.ema = None + + # Optimization utils init + self.lf = None + self.scheduler = None + + # Epoch level metrics + self.best_fitness = None + self.fitness = None + self.loss = None + self.tloss = None + self.loss_names = ["Loss"] + self.csv = self.save_dir / "results.csv" + self.plot_idx = [0, 1, 2] + + # HUB + self.hub_session = None + + # Callbacks + self.callbacks = _callbacks or callbacks.get_default_callbacks() + if RANK in {-1, 0}: + callbacks.add_integration_callbacks(self) + + def add_callback(self, event: str, callback): + """Appends the given callback.""" + self.callbacks[event].append(callback) + + def set_callback(self, event: str, callback): + """Overrides the existing callbacks with the given callback.""" + self.callbacks[event] = [callback] + + def run_callbacks(self, event: str): + """Run all existing callbacks associated with a particular event.""" + for callback in self.callbacks.get(event, []): + callback(self) + + def train(self): + """Allow device='', device=None on Multi-GPU systems to default to device=0.""" + if isinstance(self.args.device, str) and len(self.args.device): # i.e. device='0' or device='0,1,2,3' + world_size = len(self.args.device.split(",")) + elif isinstance(self.args.device, (tuple, list)): # i.e. 
device=[0, 1, 2, 3] (multi-GPU from CLI is list) + world_size = len(self.args.device) + elif self.args.device in {"cpu", "mps"}: # i.e. device='cpu' or 'mps' + world_size = 0 + elif torch.cuda.is_available(): # i.e. device=None or device='' or device=number + world_size = 1 # default to device 0 + else: # i.e. device=None or device='' + world_size = 0 + + # Run subprocess if DDP training, else train normally + if world_size > 1 and "LOCAL_RANK" not in os.environ: + # Argument checks + if self.args.rect: + LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'") + self.args.rect = False + if self.args.batch < 1.0: + LOGGER.warning( + "WARNING ⚠️ 'batch<1' for AutoBatch is incompatible with Multi-GPU training, setting " + "default 'batch=16'" + ) + self.args.batch = 16 + + # Command + cmd, file = generate_ddp_command(world_size, self) + try: + LOGGER.info(f"{colorstr('DDP:')} debug command {' '.join(cmd)}") + subprocess.run(cmd, check=True) + except Exception as e: + raise e + finally: + ddp_cleanup(self, str(file)) + + else: + self._do_train(world_size) + + def _setup_scheduler(self): + """Initialize training learning rate scheduler.""" + if self.args.cos_lr: + self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf'] + else: + self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf # linear + self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf) + + def _setup_ddp(self, world_size): + """Initializes and sets the DistributedDataParallel parameters for training.""" + torch.cuda.set_device(RANK) + self.device = torch.device("cuda", RANK) + # LOGGER.info(f'DDP info: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}') + os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1" # set to enforce timeout + dist.init_process_group( + backend="nccl" if dist.is_nccl_available() else "gloo", + timeout=timedelta(seconds=10800), # 3 hours + rank=RANK, + world_size=world_size, + ) + + def _setup_train(self, world_size): + """Builds dataloaders and optimizer on correct rank process.""" + # Model + self.run_callbacks("on_pretrain_routine_start") + ckpt = self.setup_model() + self.model = self.model.to(self.device) + self.set_model_attributes() + + # Freeze layers + freeze_list = ( + self.args.freeze + if isinstance(self.args.freeze, list) + else range(self.args.freeze) + if isinstance(self.args.freeze, int) + else [] + ) + always_freeze_names = [".dfl"] # always freeze these layers + freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names + for k, v in self.model.named_parameters(): + # v.register_hook(lambda x: torch.nan_to_num(x)) # NaN to 0 (commented for erratic training results) + if any(x in k for x in freeze_layer_names): + LOGGER.info(f"Freezing layer '{k}'") + v.requires_grad = False + elif not v.requires_grad and v.dtype.is_floating_point: # only floating point Tensor can require gradients + LOGGER.info( + f"WARNING ⚠️ setting 'requires_grad=True' for frozen layer '{k}'. " + "See ultralytics.engine.trainer for customization of frozen layers." 
+ ) + v.requires_grad = True + + # Check AMP + self.amp = torch.tensor(self.args.amp).to(self.device) # True or False + if self.amp and RANK in {-1, 0}: # Single-GPU and DDP + callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them + self.amp = torch.tensor(check_amp(self.model), device=self.device) + callbacks.default_callbacks = callbacks_backup # restore callbacks + if RANK > -1 and world_size > 1: # DDP + dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None) + self.amp = bool(self.amp) # as boolean + self.scaler = ( + torch.amp.GradScaler("cuda", enabled=self.amp) if TORCH_2_4 else torch.cuda.amp.GradScaler(enabled=self.amp) + ) + if world_size > 1: + self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True) + self.set_model_attributes() # set again after DDP wrapper + + # Check imgsz + gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32) # grid size (max stride) + self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1) + self.stride = gs # for multiscale training + + # Batch size + if self.batch_size < 1 and RANK == -1: # single-GPU only, estimate best batch size + self.args.batch = self.batch_size = self.auto_batch() + + # Dataloaders + batch_size = self.batch_size // max(world_size, 1) + self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=LOCAL_RANK, mode="train") + if RANK in {-1, 0}: + # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects. + self.test_loader = self.get_dataloader( + self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val" + ) + self.validator = self.get_validator() + metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val") + self.metrics = dict(zip(metric_keys, [0] * len(metric_keys))) + self.ema = ModelEMA(self.model) + if self.args.plots: + self.plot_training_labels() + + # Optimizer + self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing + weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs # scale weight_decay + iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs + self.optimizer = self.build_optimizer( + model=self.model, + name=self.args.optimizer, + lr=self.args.lr0, + momentum=self.args.momentum, + decay=weight_decay, + iterations=iterations, + ) + # Scheduler + self._setup_scheduler() + self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False + self.resume_training(ckpt) + self.scheduler.last_epoch = self.start_epoch - 1 # do not move + self.run_callbacks("on_pretrain_routine_end") + + def _do_train(self, world_size=1): + """Train completed, evaluate and plot if specified by arguments.""" + if world_size > 1: + self._setup_ddp(world_size) + self._setup_train(world_size) + + nb = len(self.train_loader) # number of batches + nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1 # warmup iterations + last_opt_step = -1 + self.epoch_time = None + self.epoch_time_start = time.time() + self.train_time_start = time.time() + self.run_callbacks("on_train_start") + LOGGER.info( + f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n" + f"Using {self.train_loader.num_workers * (world_size or 1)} dataloader workers\n" + f"Logging 
results to {colorstr('bold', self.save_dir)}\n" + f"Starting training for " + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...") + ) + if self.args.close_mosaic: + base_idx = (self.epochs - self.args.close_mosaic) * nb + self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2]) + epoch = self.start_epoch + self.optimizer.zero_grad() # zero any resumed gradients to ensure stability on train start + while True: + self.epoch = epoch + self.run_callbacks("on_train_epoch_start") + with warnings.catch_warnings(): + warnings.simplefilter("ignore") # suppress 'Detected lr_scheduler.step() before optimizer.step()' + self.scheduler.step() + + self.model.train() + if RANK != -1: + self.train_loader.sampler.set_epoch(epoch) + pbar = enumerate(self.train_loader) + # Update dataloader attributes (optional) + if epoch == (self.epochs - self.args.close_mosaic): + self._close_dataloader_mosaic() + self.train_loader.reset() + + if RANK in {-1, 0}: + LOGGER.info(self.progress_string()) + pbar = TQDM(enumerate(self.train_loader), total=nb) + self.tloss = None + for i, batch in pbar: + self.run_callbacks("on_train_batch_start") + # Warmup + ni = i + nb * epoch + if ni <= nw: + xi = [0, nw] # x interp + self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round())) + for j, x in enumerate(self.optimizer.param_groups): + # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0 + x["lr"] = np.interp( + ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)] + ) + if "momentum" in x: + x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum]) + + # Forward + with autocast(self.amp): + batch = self.preprocess_batch(batch) + self.loss, self.loss_items = self.model(batch) + if RANK != -1: + self.loss *= world_size + self.tloss = ( + (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items + ) + + # Backward + self.scaler.scale(self.loss).backward() + + # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html + if ni - last_opt_step >= self.accumulate: + self.optimizer_step() + last_opt_step = ni + + # Timed stopping + if self.args.time: + self.stop = (time.time() - self.train_time_start) > (self.args.time * 3600) + if RANK != -1: # if DDP training + broadcast_list = [self.stop if RANK == 0 else None] + dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks + self.stop = broadcast_list[0] + if self.stop: # training time exceeded + break + + # Log + if RANK in {-1, 0}: + loss_length = self.tloss.shape[0] if len(self.tloss.shape) else 1 + pbar.set_description( + ("%11s" * 2 + "%11.4g" * (2 + loss_length)) + % ( + f"{epoch + 1}/{self.epochs}", + f"{self._get_memory():.3g}G", # (GB) GPU memory util + *(self.tloss if loss_length > 1 else torch.unsqueeze(self.tloss, 0)), # losses + batch["cls"].shape[0], # batch size, i.e. 
8 + batch["img"].shape[-1], # imgsz, i.e 640 + ) + ) + self.run_callbacks("on_batch_end") + if self.args.plots and ni in self.plot_idx: + self.plot_training_samples(batch, ni) + + self.run_callbacks("on_train_batch_end") + + self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers + self.run_callbacks("on_train_epoch_end") + if RANK in {-1, 0}: + final_epoch = epoch + 1 >= self.epochs + self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"]) + + # Validation + if self.args.val or final_epoch or self.stopper.possible_stop or self.stop: + self.metrics, self.fitness = self.validate() + self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr}) + self.stop |= self.stopper(epoch + 1, self.fitness) or final_epoch + if self.args.time: + self.stop |= (time.time() - self.train_time_start) > (self.args.time * 3600) + + # Save model + if self.args.save or final_epoch: + self.save_model() + self.run_callbacks("on_model_save") + + # Scheduler + t = time.time() + self.epoch_time = t - self.epoch_time_start + self.epoch_time_start = t + if self.args.time: + mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1) + self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time) + self._setup_scheduler() + self.scheduler.last_epoch = self.epoch # do not move + self.stop |= epoch >= self.epochs # stop if exceeded epochs + self.run_callbacks("on_fit_epoch_end") + self._clear_memory() + + # Early Stopping + if RANK != -1: # if DDP training + broadcast_list = [self.stop if RANK == 0 else None] + dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks + self.stop = broadcast_list[0] + if self.stop: + break # must break all DDP ranks + epoch += 1 + + if RANK in {-1, 0}: + # Do final val with best.pt + seconds = time.time() - self.train_time_start + LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.") + self.final_eval() + if self.args.plots: + self.plot_metrics() + self.run_callbacks("on_train_end") + self._clear_memory() + self.run_callbacks("teardown") + + def auto_batch(self, max_num_obj=0): + """Get batch size by calculating memory occupation of model.""" + return check_train_batch_size( + model=self.model, + imgsz=self.args.imgsz, + amp=self.amp, + batch=self.batch_size, + max_num_obj=max_num_obj, + ) # returns batch size + + def _get_memory(self): + """Get accelerator memory utilization in GB.""" + if self.device.type == "mps": + memory = torch.mps.driver_allocated_memory() + elif self.device.type == "cpu": + memory = 0 + else: + memory = torch.cuda.memory_reserved() + return memory / 1e9 + + def _clear_memory(self): + """Clear accelerator memory on different platforms.""" + gc.collect() + if self.device.type == "mps": + torch.mps.empty_cache() + elif self.device.type == "cpu": + return + else: + torch.cuda.empty_cache() + + def read_results_csv(self): + """Read results.csv into a dict using pandas.""" + import pandas as pd # scope for faster 'import ultralytics' + + return pd.read_csv(self.csv).to_dict(orient="list") + + def save_model(self): + """Save model training checkpoints with additional metadata.""" + import io + + # Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls) + buffer = io.BytesIO() + torch.save( + { + "epoch": self.epoch, + "best_fitness": self.best_fitness, + "model": None, # resume and final checkpoints derive from EMA + "ema": 
deepcopy(self.ema.ema).half(), + "updates": self.ema.updates, + "optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())), + "train_args": vars(self.args), # save as dict + "train_metrics": {**self.metrics, **{"fitness": self.fitness}}, + "train_results": self.read_results_csv(), + "date": datetime.now().isoformat(), + "version": __version__, + "license": "AGPL-3.0 (https://ultralytics.com/license)", + "docs": "https://docs.ultralytics.com", + }, + buffer, + ) + serialized_ckpt = buffer.getvalue() # get the serialized content to save + + # Save checkpoints + self.last.write_bytes(serialized_ckpt) # save last.pt + if self.best_fitness == self.fitness: + self.best.write_bytes(serialized_ckpt) # save best.pt + if (self.save_period > 0) and (self.epoch % self.save_period == 0): + (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt' + # if self.args.close_mosaic and self.epoch == (self.epochs - self.args.close_mosaic - 1): + # (self.wdir / "last_mosaic.pt").write_bytes(serialized_ckpt) # save mosaic checkpoint + + def get_dataset(self): + """ + Get train, val path from data dict if it exists. + + Returns None if data format is not recognized. + """ + try: + if self.args.task == "classify": + data = check_cls_dataset(self.args.data) + elif self.args.data.split(".")[-1] in {"yaml", "yml"} or self.args.task in { + "detect", + "segment", + "pose", + "obb", + }: + data = check_det_dataset(self.args.data) + if "yaml_file" in data: + self.args.data = data["yaml_file"] # for validating 'yolo train data=url.zip' usage + except Exception as e: + raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e + self.data = data + return data["train"], data.get("val") or data.get("test") + + def setup_model(self): + """Load/create/download model for any task.""" + if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed + return + + cfg, weights = self.model, None + ckpt = None + if str(self.model).endswith(".pt"): + weights, ckpt = attempt_load_one_weight(self.model) + cfg = weights.yaml + elif isinstance(self.args.pretrained, (str, Path)): + weights, _ = attempt_load_one_weight(self.args.pretrained) + self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1) # calls Model(cfg, weights) + return ckpt + + def optimizer_step(self): + """Perform a single step of the training optimizer with gradient clipping and EMA update.""" + self.scaler.unscale_(self.optimizer) # unscale gradients + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0) # clip gradients + self.scaler.step(self.optimizer) + self.scaler.update() + self.optimizer.zero_grad() + if self.ema: + self.ema.update(self.model) + + def preprocess_batch(self, batch): + """Allows custom preprocessing model inputs and ground truths depending on task type.""" + return batch + + def validate(self): + """ + Runs validation on test set using self.validator. + + The returned dict is expected to contain "fitness" key. 
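+
+        Note:
+            If the validator does not report a "fitness" key, the negative detached training loss is used as a
+            stand-in fitness value, so a lower loss still ranks as a better checkpoint.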
+ """ + metrics = self.validator(self) + fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found + if not self.best_fitness or self.best_fitness < fitness: + self.best_fitness = fitness + return metrics, fitness + + def get_model(self, cfg=None, weights=None, verbose=True): + """Get model and raise NotImplementedError for loading cfg files.""" + raise NotImplementedError("This task trainer doesn't support loading cfg files") + + def get_validator(self): + """Returns a NotImplementedError when the get_validator function is called.""" + raise NotImplementedError("get_validator function not implemented in trainer") + + def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"): + """Returns dataloader derived from torch.data.Dataloader.""" + raise NotImplementedError("get_dataloader function not implemented in trainer") + + def build_dataset(self, img_path, mode="train", batch=None): + """Build dataset.""" + raise NotImplementedError("build_dataset function not implemented in trainer") + + def label_loss_items(self, loss_items=None, prefix="train"): + """ + Returns a loss dict with labelled training loss items tensor. + + Note: + This is not needed for classification but necessary for segmentation & detection + """ + return {"loss": loss_items} if loss_items is not None else ["loss"] + + def set_model_attributes(self): + """To set or update model parameters before training.""" + self.model.names = self.data["names"] + + def build_targets(self, preds, targets): + """Builds target tensors for training YOLO model.""" + pass + + def progress_string(self): + """Returns a string describing training progress.""" + return "" + + # TODO: may need to put these following functions into callback + def plot_training_samples(self, batch, ni): + """Plots training samples during YOLO training.""" + pass + + def plot_training_labels(self): + """Plots training labels for YOLO model.""" + pass + + def save_metrics(self, metrics): + """Saves training metrics to a CSV file.""" + keys, vals = list(metrics.keys()), list(metrics.values()) + n = len(metrics) + 2 # number of cols + s = "" if self.csv.exists() else (("%s," * n % tuple(["epoch", "time"] + keys)).rstrip(",") + "\n") # header + t = time.time() - self.train_time_start + with open(self.csv, "a") as f: + f.write(s + ("%.6g," * n % tuple([self.epoch + 1, t] + vals)).rstrip(",") + "\n") + + def plot_metrics(self): + """Plot and display metrics visually.""" + pass + + def on_plot(self, name, data=None): + """Registers plots (e.g. 
to be consumed in callbacks).""" + path = Path(name) + self.plots[path] = {"data": data, "timestamp": time.time()} + + def final_eval(self): + """Performs final evaluation and validation for object detection YOLO model.""" + ckpt = {} + for f in self.last, self.best: + if f.exists(): + if f is self.last: + ckpt = strip_optimizer(f) + elif f is self.best: + k = "train_results" # update best.pt train_metrics from last.pt + strip_optimizer(f, updates={k: ckpt[k]} if k in ckpt else None) + LOGGER.info(f"\nValidating {f}...") + self.validator.args.plots = self.args.plots + self.metrics = self.validator(model=f) + self.metrics.pop("fitness", None) + self.run_callbacks("on_fit_epoch_end") + + def check_resume(self, overrides): + """Check if resume checkpoint exists and update arguments accordingly.""" + resume = self.args.resume + if resume: + try: + exists = isinstance(resume, (str, Path)) and Path(resume).exists() + last = Path(check_file(resume) if exists else get_latest_run()) + + # Check that resume data YAML exists, otherwise strip to force re-download of dataset + ckpt_args = attempt_load_weights(last).args + if not Path(ckpt_args["data"]).exists(): + ckpt_args["data"] = self.args.data + + resume = True + self.args = get_cfg(ckpt_args) + self.args.model = self.args.resume = str(last) # reinstate model + for k in ( + "imgsz", + "batch", + "device", + "close_mosaic", + ): # allow arg updates to reduce memory or update device on resume + if k in overrides: + setattr(self.args, k, overrides[k]) + + except Exception as e: + raise FileNotFoundError( + "Resume checkpoint not found. Please pass a valid checkpoint to resume from, " + "i.e. 'yolo train resume model=path/to/last.pt'" + ) from e + self.resume = resume + + def resume_training(self, ckpt): + """Resume YOLO training from given epoch and best fitness.""" + if ckpt is None or not self.resume: + return + best_fitness = 0.0 + start_epoch = ckpt.get("epoch", -1) + 1 + if ckpt.get("optimizer", None) is not None: + self.optimizer.load_state_dict(ckpt["optimizer"]) # optimizer + best_fitness = ckpt["best_fitness"] + if self.ema and ckpt.get("ema"): + self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict()) # EMA + self.ema.updates = ckpt["updates"] + assert start_epoch > 0, ( + f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n" + f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'" + ) + LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs") + if self.epochs < start_epoch: + LOGGER.info( + f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs." 
+ ) + self.epochs += ckpt["epoch"] # finetune additional epochs + self.best_fitness = best_fitness + self.start_epoch = start_epoch + if start_epoch > (self.epochs - self.args.close_mosaic): + self._close_dataloader_mosaic() + + def _close_dataloader_mosaic(self): + """Update dataloaders to stop using mosaic augmentation.""" + if hasattr(self.train_loader.dataset, "mosaic"): + self.train_loader.dataset.mosaic = False + if hasattr(self.train_loader.dataset, "close_mosaic"): + LOGGER.info("Closing dataloader mosaic") + self.train_loader.dataset.close_mosaic(hyp=copy(self.args)) + + def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5): + """ + Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum, + weight decay, and number of iterations. + + Args: + model (torch.nn.Module): The model for which to build an optimizer. + name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected + based on the number of iterations. Default: 'auto'. + lr (float, optional): The learning rate for the optimizer. Default: 0.001. + momentum (float, optional): The momentum factor for the optimizer. Default: 0.9. + decay (float, optional): The weight decay for the optimizer. Default: 1e-5. + iterations (float, optional): The number of iterations, which determines the optimizer if + name is 'auto'. Default: 1e5. + + Returns: + (torch.optim.Optimizer): The constructed optimizer. + """ + g = [], [], [] # optimizer parameter groups + bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d() + if name == "auto": + LOGGER.info( + f"{colorstr('optimizer:')} 'optimizer=auto' found, " + f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and " + f"determining best 'optimizer', 'lr0' and 'momentum' automatically... " + ) + nc = getattr(model, "nc", 10) # number of classes + lr_fit = round(0.002 * 5 / (4 + nc), 6) # lr0 fit equation to 6 decimal places + name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9) + self.args.warmup_bias_lr = 0.0 # no higher than 0.01 for Adam + + for module_name, module in model.named_modules(): + for param_name, param in module.named_parameters(recurse=False): + fullname = f"{module_name}.{param_name}" if module_name else param_name + if "bias" in fullname: # bias (no decay) + g[2].append(param) + elif isinstance(module, bn): # weight (no decay) + g[1].append(param) + else: # weight (with decay) + g[0].append(param) + + optimizers = {"Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "SGD", "auto"} + name = {x.lower(): x for x in optimizers}.get(name.lower()) + if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}: + optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0) + elif name == "RMSProp": + optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum) + elif name == "SGD": + optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True) + else: + raise NotImplementedError( + f"Optimizer '{name}' not found in list of available optimizers {optimizers}. " + "Request support for addition optimizers at https://github.com/ultralytics/ultralytics." 
+ ) + + optimizer.add_param_group({"params": g[0], "weight_decay": decay}) # add g0 with weight_decay + optimizer.add_param_group({"params": g[1], "weight_decay": 0.0}) # add g1 (BatchNorm2d weights) + LOGGER.info( + f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups " + f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)" + ) + return optimizer diff --git a/ultralytics/engine/tuner.py b/ultralytics/engine/tuner.py new file mode 100644 index 0000000000000000000000000000000000000000..0df109c7575225bb6be00a9bd92773e3f4c1e67c --- /dev/null +++ b/ultralytics/engine/tuner.py @@ -0,0 +1,242 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +""" +Module provides functionalities for hyperparameter tuning of the Ultralytics YOLO models for object detection, instance +segmentation, image classification, pose estimation, and multi-object tracking. + +Hyperparameter tuning is the process of systematically searching for the optimal set of hyperparameters +that yield the best model performance. This is particularly crucial in deep learning models like YOLO, +where small changes in hyperparameters can lead to significant differences in model accuracy and efficiency. + +Example: + Tune hyperparameters for YOLOv8n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations. + ```python + from ultralytics import YOLO + + model = YOLO("yolo11n.pt") + model.tune(data="coco8.yaml", epochs=10, iterations=300, optimizer="AdamW", plots=False, save=False, val=False) + ``` +""" + +import random +import shutil +import subprocess +import time + +import numpy as np +import torch + +from ultralytics.cfg import get_cfg, get_save_dir +from ultralytics.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, remove_colorstr, yaml_print, yaml_save +from ultralytics.utils.plotting import plot_tune_results + + +class Tuner: + """ + Class responsible for hyperparameter tuning of YOLO models. + + The class evolves YOLO model hyperparameters over a given number of iterations + by mutating them according to the search space and retraining the model to evaluate their performance. + + Attributes: + space (dict): Hyperparameter search space containing bounds and scaling factors for mutation. + tune_dir (Path): Directory where evolution logs and results will be saved. + tune_csv (Path): Path to the CSV file where evolution logs are saved. + + Methods: + _mutate(hyp: dict) -> dict: + Mutates the given hyperparameters within the bounds specified in `self.space`. + + __call__(): + Executes the hyperparameter evolution across multiple iterations. + + Example: + Tune hyperparameters for YOLOv8n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations. + ```python + from ultralytics import YOLO + + model = YOLO("yolo11n.pt") + model.tune(data="coco8.yaml", epochs=10, iterations=300, optimizer="AdamW", plots=False, save=False, val=False) + ``` + + Tune with custom search space. + ```python + from ultralytics import YOLO + + model = YOLO("yolo11n.pt") + model.tune(space={key1: val1, key2: val2}) # custom search space dictionary + ``` + """ + + def __init__(self, args=DEFAULT_CFG, _callbacks=None): + """ + Initialize the Tuner with configurations. + + Args: + args (dict, optional): Configuration for hyperparameter evolution. 
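+
+        Note:
+            A "space" entry supplied through `args` overrides the default search space; each value is a
+            (min, max) or (min, max, gain) tuple. Mirroring the class-level example (assuming `model` is a YOLO
+            instance, with illustrative bounds):
+            >>> model.tune(space={"lr0": (1e-5, 1e-1), "momentum": (0.7, 0.98, 0.3)})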
+ """ + self.space = args.pop("space", None) or { # key: (min, max, gain(optional)) + # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']), + "lr0": (1e-5, 1e-1), # initial learning rate (i.e. SGD=1E-2, Adam=1E-3) + "lrf": (0.0001, 0.1), # final OneCycleLR learning rate (lr0 * lrf) + "momentum": (0.7, 0.98, 0.3), # SGD momentum/Adam beta1 + "weight_decay": (0.0, 0.001), # optimizer weight decay 5e-4 + "warmup_epochs": (0.0, 5.0), # warmup epochs (fractions ok) + "warmup_momentum": (0.0, 0.95), # warmup initial momentum + "box": (1.0, 20.0), # box loss gain + "cls": (0.2, 4.0), # cls loss gain (scale with pixels) + "dfl": (0.4, 6.0), # dfl loss gain + "hsv_h": (0.0, 0.1), # image HSV-Hue augmentation (fraction) + "hsv_s": (0.0, 0.9), # image HSV-Saturation augmentation (fraction) + "hsv_v": (0.0, 0.9), # image HSV-Value augmentation (fraction) + "degrees": (0.0, 45.0), # image rotation (+/- deg) + "translate": (0.0, 0.9), # image translation (+/- fraction) + "scale": (0.0, 0.95), # image scale (+/- gain) + "shear": (0.0, 10.0), # image shear (+/- deg) + "perspective": (0.0, 0.001), # image perspective (+/- fraction), range 0-0.001 + "flipud": (0.0, 1.0), # image flip up-down (probability) + "fliplr": (0.0, 1.0), # image flip left-right (probability) + "bgr": (0.0, 1.0), # image channel bgr (probability) + "mosaic": (0.0, 1.0), # image mixup (probability) + "mixup": (0.0, 1.0), # image mixup (probability) + "copy_paste": (0.0, 1.0), # segment copy-paste (probability) + } + self.args = get_cfg(overrides=args) + self.tune_dir = get_save_dir(self.args, name=self.args.name or "tune") + self.args.name = None # reset to not affect training directory + self.tune_csv = self.tune_dir / "tune_results.csv" + self.callbacks = _callbacks or callbacks.get_default_callbacks() + self.prefix = colorstr("Tuner: ") + callbacks.add_integration_callbacks(self) + LOGGER.info( + f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n" + f"{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning" + ) + + def _mutate(self, parent="single", n=5, mutation=0.8, sigma=0.2): + """ + Mutates the hyperparameters based on bounds and scaling factors specified in `self.space`. + + Args: + parent (str): Parent selection method: 'single' or 'weighted'. + n (int): Number of parents to consider. + mutation (float): Probability of a parameter mutation in any given iteration. + sigma (float): Standard deviation for Gaussian random number generator. + + Returns: + (dict): A dictionary containing mutated hyperparameters. 
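+
+        Note:
+            Each selected hyperparameter is multiplied by a random factor clipped to [0.3, 3.0] and then clamped back
+            into its (min, max) bounds from `self.space`. For example, an `lr0` of 0.01 mutated by a factor of 3.0
+            becomes 0.03, which still lies inside the default (1e-5, 1e-1) range and is therefore kept
+            (apart from the final rounding step).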
+ """ + if self.tune_csv.exists(): # if CSV file exists: select best hyps and mutate + # Select parent(s) + x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1) + fitness = x[:, 0] # first column + n = min(n, len(x)) # number of previous results to consider + x = x[np.argsort(-fitness)][:n] # top n mutations + w = x[:, 0] - x[:, 0].min() + 1e-6 # weights (sum > 0) + if parent == "single" or len(x) == 1: + # x = x[random.randint(0, n - 1)] # random selection + x = x[random.choices(range(n), weights=w)[0]] # weighted selection + elif parent == "weighted": + x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination + + # Mutate + r = np.random # method + r.seed(int(time.time())) + g = np.array([v[2] if len(v) == 3 else 1.0 for v in self.space.values()]) # gains 0-1 + ng = len(self.space) + v = np.ones(ng) + while all(v == 1): # mutate until a change occurs (prevent duplicates) + v = (g * (r.random(ng) < mutation) * r.randn(ng) * r.random() * sigma + 1).clip(0.3, 3.0) + hyp = {k: float(x[i + 1] * v[i]) for i, k in enumerate(self.space.keys())} + else: + hyp = {k: getattr(self.args, k) for k in self.space.keys()} + + # Constrain to limits + for k, v in self.space.items(): + hyp[k] = max(hyp[k], v[0]) # lower limit + hyp[k] = min(hyp[k], v[1]) # upper limit + hyp[k] = round(hyp[k], 5) # significant digits + + return hyp + + def __call__(self, model=None, iterations=10, cleanup=True): + """ + Executes the hyperparameter evolution process when the Tuner instance is called. + + This method iterates through the number of iterations, performing the following steps in each iteration: + 1. Load the existing hyperparameters or initialize new ones. + 2. Mutate the hyperparameters using the `mutate` method. + 3. Train a YOLO model with the mutated hyperparameters. + 4. Log the fitness score and mutated hyperparameters to a CSV file. + + Args: + model (Model): A pre-initialized YOLO model to be used for training. + iterations (int): The number of generations to run the evolution for. + cleanup (bool): Whether to delete iteration weights to reduce storage space used during tuning. + + Note: + The method utilizes the `self.tune_csv` Path object to read and log hyperparameters and fitness scores. + Ensure this path is set correctly in the Tuner instance. 
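+
+            Each appended CSV row is the rounded fitness followed by the mutated value of every key in `self.space`,
+            in definition order. A sketch for inspecting the log afterwards (the path is illustrative and depends on
+            `tune_dir`):
+            >>> import numpy as np
+            >>> x = np.loadtxt("runs/detect/tune/tune_results.csv", ndmin=2, delimiter=",", skiprows=1)
+            >>> best_row = x[x[:, 0].argmax()]  # fitness is the first column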
+ """ + t0 = time.time() + best_save_dir, best_metrics = None, None + (self.tune_dir / "weights").mkdir(parents=True, exist_ok=True) + for i in range(iterations): + # Mutate hyperparameters + mutated_hyp = self._mutate() + LOGGER.info(f"{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}") + + metrics = {} + train_args = {**vars(self.args), **mutated_hyp} + save_dir = get_save_dir(get_cfg(train_args)) + weights_dir = save_dir / "weights" + try: + # Train YOLO model with mutated hyperparameters (run in subprocess to avoid dataloader hang) + cmd = ["yolo", "train", *(f"{k}={v}" for k, v in train_args.items())] + return_code = subprocess.run(" ".join(cmd), check=True, shell=True).returncode + ckpt_file = weights_dir / ("best.pt" if (weights_dir / "best.pt").exists() else "last.pt") + metrics = torch.load(ckpt_file)["train_metrics"] + assert return_code == 0, "training failed" + + except Exception as e: + LOGGER.warning(f"WARNING ❌️ training failure for hyperparameter tuning iteration {i + 1}\n{e}") + + # Save results and mutated_hyp to CSV + fitness = metrics.get("fitness", 0.0) + log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()] + headers = "" if self.tune_csv.exists() else (",".join(["fitness"] + list(self.space.keys())) + "\n") + with open(self.tune_csv, "a") as f: + f.write(headers + ",".join(map(str, log_row)) + "\n") + + # Get best results + x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1) + fitness = x[:, 0] # first column + best_idx = fitness.argmax() + best_is_current = best_idx == i + if best_is_current: + best_save_dir = save_dir + best_metrics = {k: round(v, 5) for k, v in metrics.items()} + for ckpt in weights_dir.glob("*.pt"): + shutil.copy2(ckpt, self.tune_dir / "weights") + elif cleanup: + shutil.rmtree(weights_dir, ignore_errors=True) # remove iteration weights/ dir to reduce storage space + + # Plot tune results + plot_tune_results(self.tune_csv) + + # Save and print tune results + header = ( + f"{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n" + f"{self.prefix}Results saved to {colorstr('bold', self.tune_dir)}\n" + f"{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n" + f"{self.prefix}Best fitness metrics are {best_metrics}\n" + f"{self.prefix}Best fitness model is {best_save_dir}\n" + f"{self.prefix}Best fitness hyperparameters are printed below.\n" + ) + LOGGER.info("\n" + header) + data = {k: float(x[best_idx, i + 1]) for i, k in enumerate(self.space.keys())} + yaml_save( + self.tune_dir / "best_hyperparameters.yaml", + data=data, + header=remove_colorstr(header.replace(self.prefix, "# ")) + "\n", + ) + yaml_print(self.tune_dir / "best_hyperparameters.yaml") diff --git a/ultralytics/engine/validator.py b/ultralytics/engine/validator.py new file mode 100644 index 0000000000000000000000000000000000000000..6dc8026f28176bbe1b3b880146f21e479fc7eee5 --- /dev/null +++ b/ultralytics/engine/validator.py @@ -0,0 +1,341 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +""" +Check a model's accuracy on a test or val split of a dataset. 
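+
+Python usage (roughly equivalent to the first CLI example below):
+    from ultralytics import YOLO
+    metrics = YOLO("yolov8n.pt").val(data="coco8.yaml", imgsz=640)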
+ +Usage: + $ yolo mode=val model=yolov8n.pt data=coco8.yaml imgsz=640 + +Usage - formats: + $ yolo mode=val model=yolov8n.pt # PyTorch + yolov8n.torchscript # TorchScript + yolov8n.onnx # ONNX Runtime or OpenCV DNN with dnn=True + yolov8n_openvino_model # OpenVINO + yolov8n.engine # TensorRT + yolov8n.mlpackage # CoreML (macOS-only) + yolov8n_saved_model # TensorFlow SavedModel + yolov8n.pb # TensorFlow GraphDef + yolov8n.tflite # TensorFlow Lite + yolov8n_edgetpu.tflite # TensorFlow Edge TPU + yolov8n_paddle_model # PaddlePaddle + yolov8n.mnn # MNN + yolov8n_ncnn_model # NCNN +""" + +import json +import time +from pathlib import Path + +import numpy as np +import torch + +from ultralytics.cfg import get_cfg, get_save_dir +from ultralytics.data.utils import check_cls_dataset, check_det_dataset +from ultralytics.nn.autobackend import AutoBackend +from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis +from ultralytics.utils.checks import check_imgsz +from ultralytics.utils.ops import Profile +from ultralytics.utils.torch_utils import de_parallel, select_device, smart_inference_mode + + +class BaseValidator: + """ + BaseValidator. + + A base class for creating validators. + + Attributes: + args (SimpleNamespace): Configuration for the validator. + dataloader (DataLoader): Dataloader to use for validation. + pbar (tqdm): Progress bar to update during validation. + model (nn.Module): Model to validate. + data (dict): Data dictionary. + device (torch.device): Device to use for validation. + batch_i (int): Current batch index. + training (bool): Whether the model is in training mode. + names (dict): Class names. + seen: Records the number of images seen so far during validation. + stats: Placeholder for statistics during validation. + confusion_matrix: Placeholder for a confusion matrix. + nc: Number of classes. + iouv: (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05. + jdict (dict): Dictionary to store JSON validation results. + speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective + batch processing times in milliseconds. + save_dir (Path): Directory to save results. + plots (dict): Dictionary to store plots for visualization. + callbacks (dict): Dictionary to store various callback functions. + """ + + def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): + """ + Initializes a BaseValidator instance. + + Args: + dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation. + save_dir (Path, optional): Directory to save results. + pbar (tqdm.tqdm): Progress bar for displaying progress. + args (SimpleNamespace): Configuration for the validator. + _callbacks (dict): Dictionary to store various callback functions. 
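+
+        Note:
+            Task-specific subclasses are normally constructed for you, either by `model.val()` or by the trainer's
+            `get_validator()`; the resulting validator is then simply called, e.g. `metrics = validator(model="best.pt")`
+            as done in `BaseTrainer.final_eval`.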
+ """ + self.args = get_cfg(overrides=args) + self.dataloader = dataloader + self.pbar = pbar + self.stride = None + self.data = None + self.device = None + self.batch_i = None + self.training = True + self.names = None + self.seen = None + self.stats = None + self.confusion_matrix = None + self.nc = None + self.iouv = None + self.jdict = None + self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} + + self.save_dir = save_dir or get_save_dir(self.args) + (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) + if self.args.conf is None: + self.args.conf = 0.001 # default conf=0.001 + self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1) + + self.plots = {} + self.callbacks = _callbacks or callbacks.get_default_callbacks() + + @smart_inference_mode() + def __call__(self, trainer=None, model=None): + """Executes validation process, running inference on dataloader and computing performance metrics.""" + self.training = trainer is not None + augment = self.args.augment and (not self.training) + if self.training: + self.device = trainer.device + self.data = trainer.data + # force FP16 val during training + self.args.half = self.device.type != "cpu" and trainer.amp + model = trainer.ema.ema or trainer.model + model = model.half() if self.args.half else model.float() + # self.model = model + self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device) + self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1) + model.eval() + else: + if str(self.args.model).endswith(".yaml") and model is None: + LOGGER.warning("WARNING ⚠️ validating an untrained model YAML will result in 0 mAP.") + callbacks.add_integration_callbacks(self) + model = AutoBackend( + weights=model or self.args.model, + device=select_device(self.args.device, self.args.batch), + dnn=self.args.dnn, + data=self.args.data, + fp16=self.args.half, + ) + # self.model = model + self.device = model.device # update device + self.args.half = model.fp16 # update half + stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine + imgsz = check_imgsz(self.args.imgsz, stride=stride) + if engine: + self.args.batch = model.batch_size + elif not pt and not jit: + self.args.batch = model.metadata.get("batch", 1) # export.py models default to batch-size 1 + LOGGER.info(f"Setting batch={self.args.batch} input of shape ({self.args.batch}, 3, {imgsz}, {imgsz})") + + if str(self.args.data).split(".")[-1] in {"yaml", "yml"}: + self.data = check_det_dataset(self.args.data) + elif self.args.task == "classify": + self.data = check_cls_dataset(self.args.data, split=self.args.split) + else: + raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌")) + + if self.device.type in {"cpu", "mps"}: + self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading + if not pt: + self.args.rect = False + self.stride = model.stride # used in get_dataloader() for padding + self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch) + + model.eval() + model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz)) # warmup + + self.run_callbacks("on_val_start") + dt = ( + Profile(device=self.device), + Profile(device=self.device), + Profile(device=self.device), + Profile(device=self.device), + ) + bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader)) + self.init_metrics(de_parallel(model)) + 
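+        # Main loop (below): each batch is timed through the four Profile contexts in `dt`
+        # (preprocess, inference, training-only loss, postprocess) before metrics are updated,
+        # and the first 3 batches are optionally plotted.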
self.jdict = [] # empty before each val + for batch_i, batch in enumerate(bar): + self.run_callbacks("on_val_batch_start") + self.batch_i = batch_i + # Preprocess + with dt[0]: + batch = self.preprocess(batch) + + # Inference + with dt[1]: + preds = model(batch["img"], augment=augment) + + # Loss + with dt[2]: + if self.training: + self.loss += model.loss(batch, preds)[1] + + # Postprocess + with dt[3]: + preds = self.postprocess(preds) + + self.update_metrics(preds, batch) + if self.args.plots and batch_i < 3: + self.plot_val_samples(batch, batch_i) + self.plot_predictions(batch, preds, batch_i) + + self.run_callbacks("on_val_batch_end") + stats = self.get_stats() + self.check_stats(stats) + self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt))) + self.finalize_metrics() + self.print_results() + self.run_callbacks("on_val_end") + if self.training: + model.float() + results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")} + return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats + else: + LOGGER.info( + "Speed: {:.1f}ms preprocess, {:.1f}ms inference, {:.1f}ms loss, {:.1f}ms postprocess per image".format( + *tuple(self.speed.values()) + ) + ) + if self.args.save_json and self.jdict: + with open(str(self.save_dir / "predictions.json"), "w") as f: + LOGGER.info(f"Saving {f.name}...") + json.dump(self.jdict, f) # flatten and save + stats = self.eval_json(stats) # update stats + if self.args.plots or self.args.save_json: + LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}") + return stats + + def match_predictions(self, pred_classes, true_classes, iou, use_scipy=False): + """ + Matches predictions to ground truth objects (pred_classes, true_classes) using IoU. + + Args: + pred_classes (torch.Tensor): Predicted class indices of shape(N,). + true_classes (torch.Tensor): Target class indices of shape(M,). + iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground of truth + use_scipy (bool): Whether to use scipy for matching (more precise). + + Returns: + (torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds. 
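+
+ Examples:
+ A small worked sketch (DetectionValidator and the checkpoint/dataset names are assumptions used only
+ to obtain a concrete validator; `iouv` is assigned by hand because it is normally set in `init_metrics()`):
+ ```python
+ import torch
+
+ from ultralytics.models.yolo.detect import DetectionValidator
+
+ validator = DetectionValidator(args=dict(model="yolov8n.pt", data="coco8.yaml"))
+ validator.iouv = torch.linspace(0.5, 0.95, 10)
+ pred_classes = torch.tensor([0, 0])  # two detections of class 0
+ true_classes = torch.tensor([0])  # one ground-truth object of class 0
+ iou = torch.tensor([[0.90, 0.40]])  # pairwise IoU, one row per ground-truth object
+ correct = validator.match_predictions(pred_classes, true_classes, iou)
+ print(correct.shape)  # torch.Size([2, 10]); detection 0 is correct at thresholds up to 0.90
+ ```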
+ """ + # Dx10 matrix, where D - detections, 10 - IoU thresholds + correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool) + # LxD matrix where L - labels (rows), D - detections (columns) + correct_class = true_classes[:, None] == pred_classes + iou = iou * correct_class # zero out the wrong classes + iou = iou.cpu().numpy() + for i, threshold in enumerate(self.iouv.cpu().tolist()): + if use_scipy: + # WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708 + import scipy # scope import to avoid importing for all commands + + cost_matrix = iou * (iou >= threshold) + if cost_matrix.any(): + labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix) + valid = cost_matrix[labels_idx, detections_idx] > 0 + if valid.any(): + correct[detections_idx[valid], i] = True + else: + matches = np.nonzero(iou >= threshold) # IoU > threshold and classes match + matches = np.array(matches).T + if matches.shape[0]: + if matches.shape[0] > 1: + matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]] + matches = matches[np.unique(matches[:, 1], return_index=True)[1]] + # matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 0], return_index=True)[1]] + correct[matches[:, 1].astype(int), i] = True + return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device) + + def add_callback(self, event: str, callback): + """Appends the given callback.""" + self.callbacks[event].append(callback) + + def run_callbacks(self, event: str): + """Runs all callbacks associated with a specified event.""" + for callback in self.callbacks.get(event, []): + callback(self) + + def get_dataloader(self, dataset_path, batch_size): + """Get data loader from dataset path and batch size.""" + raise NotImplementedError("get_dataloader function not implemented for this validator") + + def build_dataset(self, img_path): + """Build dataset.""" + raise NotImplementedError("build_dataset function not implemented in validator") + + def preprocess(self, batch): + """Preprocesses an input batch.""" + return batch + + def postprocess(self, preds): + """Preprocesses the predictions.""" + return preds + + def init_metrics(self, model): + """Initialize performance metrics for the YOLO model.""" + pass + + def update_metrics(self, preds, batch): + """Updates metrics based on predictions and batch.""" + pass + + def finalize_metrics(self, *args, **kwargs): + """Finalizes and returns all metrics.""" + pass + + def get_stats(self): + """Returns statistics about the model's performance.""" + return {} + + def check_stats(self, stats): + """Checks statistics.""" + pass + + def print_results(self): + """Prints the results of the model's predictions.""" + pass + + def get_desc(self): + """Get description of the YOLO model.""" + pass + + @property + def metric_keys(self): + """Returns the metric keys used in YOLO training/validation.""" + return [] + + def on_plot(self, name, data=None): + """Registers plots (e.g. 
to be consumed in callbacks).""" + self.plots[Path(name)] = {"data": data, "timestamp": time.time()} + + # TODO: may need to put these following functions into callback + def plot_val_samples(self, batch, ni): + """Plots validation samples during training.""" + pass + + def plot_predictions(self, batch, preds, ni): + """Plots YOLO model predictions on batch images.""" + pass + + def pred_to_json(self, preds, batch): + """Convert predictions to JSON format.""" + pass + + def eval_json(self, stats): + """Evaluate and return JSON format of prediction statistics.""" + pass diff --git a/ultralytics/hub/__init__.py b/ultralytics/hub/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..74c0dfeda7b03f6e7ad4f1c65d942fceae8745c0 --- /dev/null +++ b/ultralytics/hub/__init__.py @@ -0,0 +1,146 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import requests + +from ultralytics.data.utils import HUBDatasetStats +from ultralytics.hub.auth import Auth +from ultralytics.hub.session import HUBTrainingSession +from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, events +from ultralytics.utils import LOGGER, SETTINGS, checks + +__all__ = ( + "PREFIX", + "HUB_WEB_ROOT", + "HUBTrainingSession", + "login", + "logout", + "reset_model", + "export_fmts_hub", + "export_model", + "get_export", + "check_dataset", + "events", +) + + +def login(api_key: str = None, save=True) -> bool: + """ + Log in to the Ultralytics HUB API using the provided API key. + + The session is not stored; a new session is created when needed using the saved SETTINGS or the HUB_API_KEY + environment variable if successfully authenticated. + + Args: + api_key (str, optional): API key to use for authentication. + If not provided, it will be retrieved from SETTINGS or HUB_API_KEY environment variable. + save (bool, optional): Whether to save the API key to SETTINGS if authentication is successful. + + Returns: + (bool): True if authentication is successful, False otherwise. + """ + checks.check_requirements("hub-sdk>=0.0.12") + from hub_sdk import HUBClient + + api_key_url = f"{HUB_WEB_ROOT}/settings?tab=api+keys" # set the redirect URL + saved_key = SETTINGS.get("api_key") + active_key = api_key or saved_key + credentials = {"api_key": active_key} if active_key and active_key != "" else None # set credentials + + client = HUBClient(credentials) # initialize HUBClient + + if client.authenticated: + # Successfully authenticated with HUB + + if save and client.api_key != saved_key: + SETTINGS.update({"api_key": client.api_key}) # update settings with valid API key + + # Set message based on whether key was provided or retrieved from settings + log_message = ( + "New authentication successful ✅" if client.api_key == api_key or not credentials else "Authenticated ✅" + ) + LOGGER.info(f"{PREFIX}{log_message}") + + return True + else: + # Failed to authenticate with HUB + LOGGER.info(f"{PREFIX}Get API key from {api_key_url} and then run 'yolo login API_KEY'") + return False + + +def logout(): + """ + Log out of Ultralytics HUB by removing the API key from the settings file. To log in again, use 'yolo login'. + + Example: + ```python + from ultralytics import hub + + hub.logout() + ``` + """ + SETTINGS["api_key"] = "" + LOGGER.info(f"{PREFIX}logged out ✅. 
To log in again, use 'yolo login'.") + + +def reset_model(model_id=""): + """Reset a trained model to an untrained state.""" + r = requests.post(f"{HUB_API_ROOT}/model-reset", json={"modelId": model_id}, headers={"x-api-key": Auth().api_key}) + if r.status_code == 200: + LOGGER.info(f"{PREFIX}Model reset successfully") + return + LOGGER.warning(f"{PREFIX}Model reset failure {r.status_code} {r.reason}") + + +def export_fmts_hub(): + """Returns a list of HUB-supported export formats.""" + from ultralytics.engine.exporter import export_formats + + return list(export_formats()["Argument"][1:]) + ["ultralytics_tflite", "ultralytics_coreml"] + + +def export_model(model_id="", format="torchscript"): + """Export a model to all formats.""" + assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}" + r = requests.post( + f"{HUB_API_ROOT}/v1/models/{model_id}/export", json={"format": format}, headers={"x-api-key": Auth().api_key} + ) + assert r.status_code == 200, f"{PREFIX}{format} export failure {r.status_code} {r.reason}" + LOGGER.info(f"{PREFIX}{format} export started ✅") + + +def get_export(model_id="", format="torchscript"): + """Get an exported model dictionary with download URL.""" + assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}" + r = requests.post( + f"{HUB_API_ROOT}/get-export", + json={"apiKey": Auth().api_key, "modelId": model_id, "format": format}, + headers={"x-api-key": Auth().api_key}, + ) + assert r.status_code == 200, f"{PREFIX}{format} get_export failure {r.status_code} {r.reason}" + return r.json() + + +def check_dataset(path: str, task: str) -> None: + """ + Function for error-checking HUB dataset Zip file before upload. It checks a dataset for errors before it is uploaded + to the HUB. Usage examples are given below. + + Args: + path (str): Path to data.zip (with data.yaml inside data.zip). + task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify', 'obb'. + + Example: + Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets + i.e. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip. + ```python + from ultralytics.hub import check_dataset + + check_dataset("path/to/coco8.zip", task="detect") # detect dataset + check_dataset("path/to/coco8-seg.zip", task="segment") # segment dataset + check_dataset("path/to/coco8-pose.zip", task="pose") # pose dataset + check_dataset("path/to/dota8.zip", task="obb") # OBB dataset + check_dataset("path/to/imagenet10.zip", task="classify") # classification dataset + ``` + """ + HUBDatasetStats(path=path, task=task).get_json() + LOGGER.info(f"Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.") diff --git a/ultralytics/hub/auth.py b/ultralytics/hub/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..2e62739f31c75cb9b9e33807ed634d71e39e8a77 --- /dev/null +++ b/ultralytics/hub/auth.py @@ -0,0 +1,140 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import requests + +from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, request_with_credentials +from ultralytics.utils import IS_COLAB, LOGGER, SETTINGS, emojis + +API_KEY_URL = f"{HUB_WEB_ROOT}/settings?tab=api+keys" + + +class Auth: + """ + Manages authentication processes including API key handling, cookie-based authentication, and header generation. 
+ + The class supports different methods of authentication: + 1. Directly using an API key. + 2. Authenticating using browser cookies (specifically in Google Colab). + 3. Prompting the user to enter an API key. + + Attributes: + id_token (str or bool): Token used for identity verification, initialized as False. + api_key (str or bool): API key for authentication, initialized as False. + model_key (bool): Placeholder for model key, initialized as False. + """ + + id_token = api_key = model_key = False + + def __init__(self, api_key="", verbose=False): + """ + Initialize Auth class and authenticate user. + + Handles API key validation, Google Colab authentication, and new key requests. Updates SETTINGS upon successful + authentication. + + Args: + api_key (str): API key or combined key_id format. + verbose (bool): Enable verbose logging. + """ + # Split the input API key in case it contains a combined key_model and keep only the API key part + api_key = api_key.split("_")[0] + + # Set API key attribute as value passed or SETTINGS API key if none passed + self.api_key = api_key or SETTINGS.get("api_key", "") + + # If an API key is provided + if self.api_key: + # If the provided API key matches the API key in the SETTINGS + if self.api_key == SETTINGS.get("api_key"): + # Log that the user is already logged in + if verbose: + LOGGER.info(f"{PREFIX}Authenticated ✅") + return + else: + # Attempt to authenticate with the provided API key + success = self.authenticate() + # If the API key is not provided and the environment is a Google Colab notebook + elif IS_COLAB: + # Attempt to authenticate using browser cookies + success = self.auth_with_cookies() + else: + # Request an API key + success = self.request_api_key() + + # Update SETTINGS with the new API key after successful authentication + if success: + SETTINGS.update({"api_key": self.api_key}) + # Log that the new login was successful + if verbose: + LOGGER.info(f"{PREFIX}New authentication successful ✅") + elif verbose: + LOGGER.info(f"{PREFIX}Get API key from {API_KEY_URL} and then run 'yolo login API_KEY'") + + def request_api_key(self, max_attempts=3): + """ + Prompt the user to input their API key. + + Returns the model ID. + """ + import getpass + + for attempts in range(max_attempts): + LOGGER.info(f"{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}") + input_key = getpass.getpass(f"Enter API key from {API_KEY_URL} ") + self.api_key = input_key.split("_")[0] # remove model id if present + if self.authenticate(): + return True + raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌")) + + def authenticate(self) -> bool: + """ + Attempt to authenticate with the server using either id_token or API key. + + Returns: + (bool): True if authentication is successful, False otherwise. + """ + try: + if header := self.get_auth_header(): + r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header) + if not r.json().get("success", False): + raise ConnectionError("Unable to authenticate.") + return True + raise ConnectionError("User has not authenticated locally.") + except ConnectionError: + self.id_token = self.api_key = False # reset invalid + LOGGER.warning(f"{PREFIX}Invalid API key ⚠️") + return False + + def auth_with_cookies(self) -> bool: + """ + Attempt to fetch authentication via cookies and set id_token. User must be logged in to HUB and running in a + supported browser. + + Returns: + (bool): True if authentication is successful, False otherwise. 
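+
+ Examples:
+ A minimal sketch of the intended flow; it is only meaningful inside a Google Colab session that is
+ already logged in to HUB, and "API_KEY" below is a placeholder (outside Colab the method returns False):
+ ```python
+ from ultralytics.hub.auth import Auth
+
+ auth = Auth("API_KEY")
+ colab_ok = auth.auth_with_cookies()
+ ```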
+ """ + if not IS_COLAB: + return False # Currently only works with Colab + try: + authn = request_with_credentials(f"{HUB_API_ROOT}/v1/auth/auto") + if authn.get("success", False): + self.id_token = authn.get("data", {}).get("idToken", None) + self.authenticate() + return True + raise ConnectionError("Unable to fetch browser authentication details.") + except ConnectionError: + self.id_token = False # reset invalid + return False + + def get_auth_header(self): + """ + Get the authentication header for making API requests. + + Returns: + (dict): The authentication header if id_token or API key is set, None otherwise. + """ + if self.id_token: + return {"authorization": f"Bearer {self.id_token}"} + elif self.api_key: + return {"x-api-key": self.api_key} + # else returns None diff --git a/ultralytics/hub/google/__init__.py b/ultralytics/hub/google/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0acd2dd26da7f5ff583a4ec694a2b0d4030ae9c8 --- /dev/null +++ b/ultralytics/hub/google/__init__.py @@ -0,0 +1,159 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import concurrent.futures +import statistics +import time +from typing import List, Optional, Tuple + +import requests + + +class GCPRegions: + """ + A class for managing and analyzing Google Cloud Platform (GCP) regions. + + This class provides functionality to initialize, categorize, and analyze GCP regions based on their + geographical location, tier classification, and network latency. + + Attributes: + regions (Dict[str, Tuple[int, str, str]]): A dictionary of GCP regions with their tier, city, and country. + + Methods: + tier1: Returns a list of tier 1 GCP regions. + tier2: Returns a list of tier 2 GCP regions. + lowest_latency: Determines the GCP region(s) with the lowest network latency. + + Examples: + >>> from ultralytics.hub.google import GCPRegions + >>> regions = GCPRegions() + >>> lowest_latency_region = regions.lowest_latency(verbose=True, attempts=3) + >>> print(f"Lowest latency region: {lowest_latency_region[0][0]}") + """ + + def __init__(self): + """Initializes the GCPRegions class with predefined Google Cloud Platform regions and their details.""" + self.regions = { + "asia-east1": (1, "Taiwan", "China"), + "asia-east2": (2, "Hong Kong", "China"), + "asia-northeast1": (1, "Tokyo", "Japan"), + "asia-northeast2": (1, "Osaka", "Japan"), + "asia-northeast3": (2, "Seoul", "South Korea"), + "asia-south1": (2, "Mumbai", "India"), + "asia-south2": (2, "Delhi", "India"), + "asia-southeast1": (2, "Jurong West", "Singapore"), + "asia-southeast2": (2, "Jakarta", "Indonesia"), + "australia-southeast1": (2, "Sydney", "Australia"), + "australia-southeast2": (2, "Melbourne", "Australia"), + "europe-central2": (2, "Warsaw", "Poland"), + "europe-north1": (1, "Hamina", "Finland"), + "europe-southwest1": (1, "Madrid", "Spain"), + "europe-west1": (1, "St. 
Ghislain", "Belgium"), + "europe-west10": (2, "Berlin", "Germany"), + "europe-west12": (2, "Turin", "Italy"), + "europe-west2": (2, "London", "United Kingdom"), + "europe-west3": (2, "Frankfurt", "Germany"), + "europe-west4": (1, "Eemshaven", "Netherlands"), + "europe-west6": (2, "Zurich", "Switzerland"), + "europe-west8": (1, "Milan", "Italy"), + "europe-west9": (1, "Paris", "France"), + "me-central1": (2, "Doha", "Qatar"), + "me-west1": (1, "Tel Aviv", "Israel"), + "northamerica-northeast1": (2, "Montreal", "Canada"), + "northamerica-northeast2": (2, "Toronto", "Canada"), + "southamerica-east1": (2, "São Paulo", "Brazil"), + "southamerica-west1": (2, "Santiago", "Chile"), + "us-central1": (1, "Iowa", "United States"), + "us-east1": (1, "South Carolina", "United States"), + "us-east4": (1, "Northern Virginia", "United States"), + "us-east5": (1, "Columbus", "United States"), + "us-south1": (1, "Dallas", "United States"), + "us-west1": (1, "Oregon", "United States"), + "us-west2": (2, "Los Angeles", "United States"), + "us-west3": (2, "Salt Lake City", "United States"), + "us-west4": (2, "Las Vegas", "United States"), + } + + def tier1(self) -> List[str]: + """Returns a list of GCP regions classified as tier 1 based on predefined criteria.""" + return [region for region, info in self.regions.items() if info[0] == 1] + + def tier2(self) -> List[str]: + """Returns a list of GCP regions classified as tier 2 based on predefined criteria.""" + return [region for region, info in self.regions.items() if info[0] == 2] + + @staticmethod + def _ping_region(region: str, attempts: int = 1) -> Tuple[str, float, float, float, float]: + """Pings a specified GCP region and returns latency statistics: mean, min, max, and standard deviation.""" + url = f"https://{region}-docker.pkg.dev" + latencies = [] + for _ in range(attempts): + try: + start_time = time.time() + _ = requests.head(url, timeout=5) + latency = (time.time() - start_time) * 1000 # convert latency to milliseconds + if latency != float("inf"): + latencies.append(latency) + except requests.RequestException: + pass + if not latencies: + return region, float("inf"), float("inf"), float("inf"), float("inf") + + std_dev = statistics.stdev(latencies) if len(latencies) > 1 else 0 + return region, statistics.mean(latencies), std_dev, min(latencies), max(latencies) + + def lowest_latency( + self, + top: int = 1, + verbose: bool = False, + tier: Optional[int] = None, + attempts: int = 1, + ) -> List[Tuple[str, float, float, float, float]]: + """ + Determines the GCP regions with the lowest latency based on ping tests. + + Args: + top (int): Number of top regions to return. + verbose (bool): If True, prints detailed latency information for all tested regions. + tier (int | None): Filter regions by tier (1 or 2). If None, all regions are tested. + attempts (int): Number of ping attempts per region. + + Returns: + (List[Tuple[str, float, float, float, float]]): List of tuples containing region information and + latency statistics. Each tuple contains (region, mean_latency, std_dev, min_latency, max_latency). 
+ + Examples: + >>> regions = GCPRegions() + >>> results = regions.lowest_latency(top=3, verbose=True, tier=1, attempts=2) + >>> print(results[0][0]) # Print the name of the lowest latency region + """ + if verbose: + print(f"Testing GCP regions for latency (with {attempts} {'retry' if attempts == 1 else 'attempts'})...") + + regions_to_test = [k for k, v in self.regions.items() if v[0] == tier] if tier else list(self.regions.keys()) + with concurrent.futures.ThreadPoolExecutor(max_workers=50) as executor: + results = list(executor.map(lambda r: self._ping_region(r, attempts), regions_to_test)) + + sorted_results = sorted(results, key=lambda x: x[1]) + + if verbose: + print(f"{'Region':<25} {'Location':<35} {'Tier':<5} Latency (ms)") + for region, mean, std, min_, max_ in sorted_results: + tier, city, country = self.regions[region] + location = f"{city}, {country}" + if mean == float("inf"): + print(f"{region:<25} {location:<35} {tier:<5} Timeout") + else: + print(f"{region:<25} {location:<35} {tier:<5} {mean:.0f} ± {std:.0f} ({min_:.0f} - {max_:.0f})") + print(f"\nLowest latency region{'s' if top > 1 else ''}:") + for region, mean, std, min_, max_ in sorted_results[:top]: + tier, city, country = self.regions[region] + location = f"{city}, {country}" + print(f"{region} ({location}, {mean:.0f} ± {std:.0f} ms ({min_:.0f} - {max_:.0f}))") + + return sorted_results[:top] + + +# Usage example +if __name__ == "__main__": + regions = GCPRegions() + top_3_latency_tier1 = regions.lowest_latency(top=3, verbose=True, tier=1, attempts=3) diff --git a/ultralytics/hub/session.py b/ultralytics/hub/session.py new file mode 100644 index 0000000000000000000000000000000000000000..37fba131359ede53a7fb212920530c5e2fd92cba --- /dev/null +++ b/ultralytics/hub/session.py @@ -0,0 +1,390 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import shutil +import threading +import time +from http import HTTPStatus +from pathlib import Path +from urllib.parse import parse_qs, urlparse + +import requests + +from ultralytics.hub.utils import HELP_MSG, HUB_WEB_ROOT, PREFIX, TQDM +from ultralytics.utils import IS_COLAB, LOGGER, SETTINGS, __version__, checks, emojis +from ultralytics.utils.errors import HUBModelError + +AGENT_NAME = f"python-{__version__}-colab" if IS_COLAB else f"python-{__version__}-local" + + +class HUBTrainingSession: + """ + HUB training session for Ultralytics HUB YOLO models. Handles model initialization, heartbeats, and checkpointing. + + Attributes: + model_id (str): Identifier for the YOLO model being trained. + model_url (str): URL for the model in Ultralytics HUB. + rate_limits (dict): Rate limits for different API calls (in seconds). + timers (dict): Timers for rate limiting. + metrics_queue (dict): Queue for the model's metrics. + model (dict): Model data fetched from Ultralytics HUB. + """ + + def __init__(self, identifier): + """ + Initialize the HUBTrainingSession with the provided model identifier. + + Args: + identifier (str): Model identifier used to initialize the HUB training session. + It can be a URL string or a model key with specific format. + + Raises: + ValueError: If the provided model identifier is invalid. + ConnectionError: If connecting with global API key is not supported. + ModuleNotFoundError: If hub-sdk package is not installed. 
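+
+ Examples:
+ A minimal sketch; MODEL_ID is a placeholder for a real Ultralytics HUB model identifier, and the
+ `create_session` classmethod defined below is used because it returns None instead of raising when
+ authentication fails:
+ ```python
+ from ultralytics.hub.session import HUBTrainingSession
+
+ session = HUBTrainingSession.create_session("https://hub.ultralytics.com/models/MODEL_ID")
+ ```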
+ """ + from hub_sdk import HUBClient + + self.rate_limits = {"metrics": 3, "ckpt": 900, "heartbeat": 300} # rate limits (seconds) + self.metrics_queue = {} # holds metrics for each epoch until upload + self.metrics_upload_failed_queue = {} # holds metrics for each epoch if upload failed + self.timers = {} # holds timers in ultralytics/utils/callbacks/hub.py + self.model = None + self.model_url = None + self.model_file = None + self.train_args = None + + # Parse input + api_key, model_id, self.filename = self._parse_identifier(identifier) + + # Get credentials + active_key = api_key or SETTINGS.get("api_key") + credentials = {"api_key": active_key} if active_key else None # set credentials + + # Initialize client + self.client = HUBClient(credentials) + + # Load models + try: + if model_id: + self.load_model(model_id) # load existing model + else: + self.model = self.client.model() # load empty model + except Exception: + if identifier.startswith(f"{HUB_WEB_ROOT}/models/") and not self.client.authenticated: + LOGGER.warning( + f"{PREFIX}WARNING ⚠️ Please log in using 'yolo login API_KEY'. " + "You can find your API Key at: https://hub.ultralytics.com/settings?tab=api+keys." + ) + + @classmethod + def create_session(cls, identifier, args=None): + """Class method to create an authenticated HUBTrainingSession or return None.""" + try: + session = cls(identifier) + if args and not identifier.startswith(f"{HUB_WEB_ROOT}/models/"): # not a HUB model URL + session.create_model(args) + assert session.model.id, "HUB model not loaded correctly" + return session + # PermissionError and ModuleNotFoundError indicate hub-sdk not installed + except (PermissionError, ModuleNotFoundError, AssertionError): + return None + + def load_model(self, model_id): + """Loads an existing model from Ultralytics HUB using the provided model identifier.""" + self.model = self.client.model(model_id) + if not self.model.data: # then model does not exist + raise ValueError(emojis("❌ The specified HUB model does not exist")) # TODO: improve error handling + + self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}" + if self.model.is_trained(): + print(emojis(f"Loading trained HUB model {self.model_url} 🚀")) + url = self.model.get_weights_url("best") # download URL with auth + self.model_file = checks.check_file(url, download_dir=Path(SETTINGS["weights_dir"]) / "hub" / self.model.id) + return + + # Set training args and start heartbeats for HUB to monitor agent + self._set_train_args() + self.model.start_heartbeat(self.rate_limits["heartbeat"]) + LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀") + + def create_model(self, model_args): + """Initializes a HUB training session with the specified model identifier.""" + payload = { + "config": { + "batchSize": model_args.get("batch", -1), + "epochs": model_args.get("epochs", 300), + "imageSize": model_args.get("imgsz", 640), + "patience": model_args.get("patience", 100), + "device": str(model_args.get("device", "")), # convert None to string + "cache": str(model_args.get("cache", "ram")), # convert True, False, None to string + }, + "dataset": {"name": model_args.get("data")}, + "lineage": { + "architecture": {"name": self.filename.replace(".pt", "").replace(".yaml", "")}, + "parent": {}, + }, + "meta": {"name": self.filename}, + } + + if self.filename.endswith(".pt"): + payload["lineage"]["parent"]["name"] = self.filename + + self.model.create_model(payload) + + # Model could not be created + # TODO: improve error handling + if not self.model.id: + return None + + 
self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}" + + # Start heartbeats for HUB to monitor agent + self.model.start_heartbeat(self.rate_limits["heartbeat"]) + + LOGGER.info(f"{PREFIX}View model at {self.model_url} 🚀") + + @staticmethod + def _parse_identifier(identifier): + """ + Parses the given identifier to determine the type of identifier and extract relevant components. + + The method supports different identifier formats: + - A HUB model URL https://hub.ultralytics.com/models/MODEL + - A HUB model URL with API Key https://hub.ultralytics.com/models/MODEL?api_key=APIKEY + - A local filename that ends with '.pt' or '.yaml' + + Args: + identifier (str): The identifier string to be parsed. + + Returns: + (tuple): A tuple containing the API key, model ID, and filename as applicable. + + Raises: + HUBModelError: If the identifier format is not recognized. + """ + api_key, model_id, filename = None, None, None + if Path(identifier).suffix in {".pt", ".yaml"}: + filename = identifier + elif identifier.startswith(f"{HUB_WEB_ROOT}/models/"): + parsed_url = urlparse(identifier) + model_id = Path(parsed_url.path).stem # handle possible final backslash robustly + query_params = parse_qs(parsed_url.query) # dictionary, i.e. {"api_key": ["API_KEY_HERE"]} + api_key = query_params.get("api_key", [None])[0] + else: + raise HUBModelError(f"model='{identifier} invalid, correct format is {HUB_WEB_ROOT}/models/MODEL_ID") + return api_key, model_id, filename + + def _set_train_args(self): + """ + Initializes training arguments and creates a model entry on the Ultralytics HUB. + + This method sets up training arguments based on the model's state and updates them with any additional + arguments provided. It handles different states of the model, such as whether it's resumable, pretrained, + or requires specific file setup. + + Raises: + ValueError: If the model is already trained, if required dataset information is missing, or if there are + issues with the provided training arguments. + """ + if self.model.is_resumable(): + # Model has saved weights + self.train_args = {"data": self.model.get_dataset_url(), "resume": True} + self.model_file = self.model.get_weights_url("last") + else: + # Model has no saved weights + self.train_args = self.model.data.get("train_args") # new response + + # Set the model file as either a *.pt or *.yaml file + self.model_file = ( + self.model.get_weights_url("parent") if self.model.is_pretrained() else self.model.get_architecture() + ) + + if "data" not in self.train_args: + # RF bug - datasets are sometimes not exported + raise ValueError("Dataset may still be processing. Please wait a minute and try again.") + + self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u + self.model_id = self.model.id + + def request_queue( + self, + request_func, + retry=3, + timeout=30, + thread=True, + verbose=True, + progress_total=None, + stream_response=None, + *args, + **kwargs, + ): + """Attempts to execute `request_func` with retries, timeout handling, optional threading, and progress.""" + + def retry_request(): + """Attempts to call `request_func` with retries, timeout, and optional threading.""" + t0 = time.time() # Record the start time for the timeout + response = None + for i in range(retry + 1): + if (time.time() - t0) > timeout: + LOGGER.warning(f"{PREFIX}Timeout for request reached. 
{HELP_MSG}") + break # Timeout reached, exit loop + + response = request_func(*args, **kwargs) + if response is None: + LOGGER.warning(f"{PREFIX}Received no response from the request. {HELP_MSG}") + time.sleep(2**i) # Exponential backoff before retrying + continue # Skip further processing and retry + + if progress_total: + self._show_upload_progress(progress_total, response) + elif stream_response: + self._iterate_content(response) + + if HTTPStatus.OK <= response.status_code < HTTPStatus.MULTIPLE_CHOICES: + # if request related to metrics upload + if kwargs.get("metrics"): + self.metrics_upload_failed_queue = {} + return response # Success, no need to retry + + if i == 0: + # Initial attempt, check status code and provide messages + message = self._get_failure_message(response, retry, timeout) + + if verbose: + LOGGER.warning(f"{PREFIX}{message} {HELP_MSG} ({response.status_code})") + + if not self._should_retry(response.status_code): + LOGGER.warning(f"{PREFIX}Request failed. {HELP_MSG} ({response.status_code}") + break # Not an error that should be retried, exit loop + + time.sleep(2**i) # Exponential backoff for retries + + # if request related to metrics upload and exceed retries + if response is None and kwargs.get("metrics"): + self.metrics_upload_failed_queue.update(kwargs.get("metrics")) + + return response + + if thread: + # Start a new thread to run the retry_request function + threading.Thread(target=retry_request, daemon=True).start() + else: + # If running in the main thread, call retry_request directly + return retry_request() + + @staticmethod + def _should_retry(status_code): + """Determines if a request should be retried based on the HTTP status code.""" + retry_codes = { + HTTPStatus.REQUEST_TIMEOUT, + HTTPStatus.BAD_GATEWAY, + HTTPStatus.GATEWAY_TIMEOUT, + } + return status_code in retry_codes + + def _get_failure_message(self, response: requests.Response, retry: int, timeout: int): + """ + Generate a retry message based on the response status code. + + Args: + response: The HTTP response object. + retry: The number of retry attempts allowed. + timeout: The maximum timeout duration. + + Returns: + (str): The retry message. + """ + if self._should_retry(response.status_code): + return f"Retrying {retry}x for {timeout}s." if retry else "" + elif response.status_code == HTTPStatus.TOO_MANY_REQUESTS: # rate limit + headers = response.headers + return ( + f"Rate limit reached ({headers['X-RateLimit-Remaining']}/{headers['X-RateLimit-Limit']}). " + f"Please retry after {headers['Retry-After']}s." + ) + else: + try: + return response.json().get("message", "No JSON message.") + except AttributeError: + return "Unable to read JSON." + + def upload_metrics(self): + """Upload model metrics to Ultralytics HUB.""" + return self.request_queue(self.model.upload_metrics, metrics=self.metrics_queue.copy(), thread=True) + + def upload_model( + self, + epoch: int, + weights: str, + is_best: bool = False, + map: float = 0.0, + final: bool = False, + ) -> None: + """ + Upload a model checkpoint to Ultralytics HUB. + + Args: + epoch (int): The current training epoch. + weights (str): Path to the model weights file. + is_best (bool): Indicates if the current model is the best one so far. + map (float): Mean average precision of the model. + final (bool): Indicates if the model is the final model after training. 
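+
+ Examples:
+ A minimal sketch; `session` is assumed to be an existing HUBTrainingSession and the weights path is a
+ placeholder for a checkpoint produced by training:
+ ```python
+ session.upload_model(epoch=10, weights="runs/detect/train/weights/best.pt", is_best=True, map=0.52)
+ ```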
+ """ + weights = Path(weights) + if not weights.is_file(): + last = weights.with_name(f"last{weights.suffix}") + if final and last.is_file(): + LOGGER.warning( + f"{PREFIX} WARNING ⚠️ Model 'best.pt' not found, copying 'last.pt' to 'best.pt' and uploading. " + "This often happens when resuming training in transient environments like Google Colab. " + "For more reliable training, consider using Ultralytics HUB Cloud. " + "Learn more at https://docs.ultralytics.com/hub/cloud-training." + ) + shutil.copy(last, weights) # copy last.pt to best.pt + else: + LOGGER.warning(f"{PREFIX} WARNING ⚠️ Model upload issue. Missing model {weights}.") + return + + self.request_queue( + self.model.upload_model, + epoch=epoch, + weights=str(weights), + is_best=is_best, + map=map, + final=final, + retry=10, + timeout=3600, + thread=not final, + progress_total=weights.stat().st_size if final else None, # only show progress if final + stream_response=True, + ) + + @staticmethod + def _show_upload_progress(content_length: int, response: requests.Response) -> None: + """ + Display a progress bar to track the upload progress of a file download. + + Args: + content_length (int): The total size of the content to be downloaded in bytes. + response (requests.Response): The response object from the file download request. + + Returns: + None + """ + with TQDM(total=content_length, unit="B", unit_scale=True, unit_divisor=1024) as pbar: + for data in response.iter_content(chunk_size=1024): + pbar.update(len(data)) + + @staticmethod + def _iterate_content(response: requests.Response) -> None: + """ + Process the streamed HTTP response data. + + Args: + response (requests.Response): The response object from the file download request. + + Returns: + None + """ + for _ in response.iter_content(chunk_size=1024): + pass # Do nothing with data chunks diff --git a/ultralytics/hub/utils.py b/ultralytics/hub/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5f83758894755c7a09f67923eca11b5d5f3b5799 --- /dev/null +++ b/ultralytics/hub/utils.py @@ -0,0 +1,246 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import os +import platform +import random +import threading +import time +from pathlib import Path + +import requests + +from ultralytics.utils import ( + ARGV, + ENVIRONMENT, + IS_COLAB, + IS_GIT_DIR, + IS_PIP_PACKAGE, + LOGGER, + ONLINE, + RANK, + SETTINGS, + TESTS_RUNNING, + TQDM, + TryExcept, + __version__, + colorstr, + get_git_origin_url, +) +from ultralytics.utils.downloads import GITHUB_ASSETS_NAMES + +HUB_API_ROOT = os.environ.get("ULTRALYTICS_HUB_API", "https://api.ultralytics.com") +HUB_WEB_ROOT = os.environ.get("ULTRALYTICS_HUB_WEB", "https://hub.ultralytics.com") + +PREFIX = colorstr("Ultralytics HUB: ") +HELP_MSG = "If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance." + + +def request_with_credentials(url: str) -> any: + """ + Make an AJAX request with cookies attached in a Google Colab environment. + + Args: + url (str): The URL to make the request to. + + Returns: + (any): The response data from the AJAX request. + + Raises: + OSError: If the function is not run in a Google Colab environment. 
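+
+ Examples:
+ A sketch of the intended Colab-only use; the endpoint shown mirrors the one used by
+ Auth.auth_with_cookies and is otherwise illustrative:
+ ```python
+ from ultralytics.hub.utils import HUB_API_ROOT, request_with_credentials
+
+ data = request_with_credentials(f"{HUB_API_ROOT}/v1/auth/auto")
+ authenticated = data.get("success", False)
+ ```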
+ """ + if not IS_COLAB: + raise OSError("request_with_credentials() must run in a Colab environment") + from google.colab import output # noqa + from IPython import display # noqa + + display.display( + display.Javascript( + f""" + window._hub_tmp = new Promise((resolve, reject) => {{ + const timeout = setTimeout(() => reject("Failed authenticating existing browser session"), 5000) + fetch("{url}", {{ + method: 'POST', + credentials: 'include' + }}) + .then((response) => resolve(response.json())) + .then((json) => {{ + clearTimeout(timeout); + }}).catch((err) => {{ + clearTimeout(timeout); + reject(err); + }}); + }}); + """ + ) + ) + return output.eval_js("_hub_tmp") + + +def requests_with_progress(method, url, **kwargs): + """ + Make an HTTP request using the specified method and URL, with an optional progress bar. + + Args: + method (str): The HTTP method to use (e.g. 'GET', 'POST'). + url (str): The URL to send the request to. + **kwargs (any): Additional keyword arguments to pass to the underlying `requests.request` function. + + Returns: + (requests.Response): The response object from the HTTP request. + + Note: + - If 'progress' is set to True, the progress bar will display the download progress for responses with a known + content length. + - If 'progress' is a number then progress bar will display assuming content length = progress. + """ + progress = kwargs.pop("progress", False) + if not progress: + return requests.request(method, url, **kwargs) + response = requests.request(method, url, stream=True, **kwargs) + total = int(response.headers.get("content-length", 0) if isinstance(progress, bool) else progress) # total size + try: + pbar = TQDM(total=total, unit="B", unit_scale=True, unit_divisor=1024) + for data in response.iter_content(chunk_size=1024): + pbar.update(len(data)) + pbar.close() + except requests.exceptions.ChunkedEncodingError: # avoid 'Connection broken: IncompleteRead' warnings + response.close() + return response + + +def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbose=True, progress=False, **kwargs): + """ + Makes an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout. + + Args: + method (str): The HTTP method to use for the request. Choices are 'post' and 'get'. + url (str): The URL to make the request to. + retry (int, optional): Number of retries to attempt before giving up. Default is 3. + timeout (int, optional): Timeout in seconds after which the function will give up retrying. Default is 30. + thread (bool, optional): Whether to execute the request in a separate daemon thread. Default is True. + code (int, optional): An identifier for the request, used for logging purposes. Default is -1. + verbose (bool, optional): A flag to determine whether to print out to console or not. Default is True. + progress (bool, optional): Whether to show a progress bar during the request. Default is False. + **kwargs (any): Keyword arguments to be passed to the requests function specified in method. + + Returns: + (requests.Response): The HTTP response object. If the request is executed in a separate thread, returns None. 
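+
+ Examples:
+ A minimal sketch; the URL and payload are placeholders, and thread=False is passed so the response is
+ returned directly instead of the request running in a daemon thread:
+ ```python
+ from ultralytics.hub.utils import smart_request
+
+ r = smart_request("post", "https://httpbin.org/post", json={"ping": 1}, retry=0, verbose=False, thread=False)
+ ```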
+ """ + retry_codes = (408, 500) # retry only these codes + + @TryExcept(verbose=verbose) + def func(func_method, func_url, **func_kwargs): + """Make HTTP requests with retries and timeouts, with optional progress tracking.""" + r = None # response + t0 = time.time() # initial time for timer + for i in range(retry + 1): + if (time.time() - t0) > timeout: + break + r = requests_with_progress(func_method, func_url, **func_kwargs) # i.e. get(url, data, json, files) + if r.status_code < 300: # return codes in the 2xx range are generally considered "good" or "successful" + break + try: + m = r.json().get("message", "No JSON message.") + except AttributeError: + m = "Unable to read JSON." + if i == 0: + if r.status_code in retry_codes: + m += f" Retrying {retry}x for {timeout}s." if retry else "" + elif r.status_code == 429: # rate limit + h = r.headers # response headers + m = ( + f"Rate limit reached ({h['X-RateLimit-Remaining']}/{h['X-RateLimit-Limit']}). " + f"Please retry after {h['Retry-After']}s." + ) + if verbose: + LOGGER.warning(f"{PREFIX}{m} {HELP_MSG} ({r.status_code} #{code})") + if r.status_code not in retry_codes: + return r + time.sleep(2**i) # exponential standoff + return r + + args = method, url + kwargs["progress"] = progress + if thread: + threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start() + else: + return func(*args, **kwargs) + + +class Events: + """ + A class for collecting anonymous event analytics. Event analytics are enabled when sync=True in settings and + disabled when sync=False. Run 'yolo settings' to see and update settings. + + Attributes: + url (str): The URL to send anonymous events. + rate_limit (float): The rate limit in seconds for sending events. + metadata (dict): A dictionary containing metadata about the environment. + enabled (bool): A flag to enable or disable Events based on certain conditions. + """ + + url = "https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw" + + def __init__(self): + """Initializes the Events object with default values for events, rate_limit, and metadata.""" + self.events = [] # events list + self.rate_limit = 30.0 # rate limit (seconds) + self.t = 0.0 # rate limit timer (seconds) + self.metadata = { + "cli": Path(ARGV[0]).name == "yolo", + "install": "git" if IS_GIT_DIR else "pip" if IS_PIP_PACKAGE else "other", + "python": ".".join(platform.python_version_tuple()[:2]), # i.e. 3.10 + "version": __version__, + "env": ENVIRONMENT, + "session_id": round(random.random() * 1e15), + "engagement_time_msec": 1000, + } + self.enabled = ( + SETTINGS["sync"] + and RANK in {-1, 0} + and not TESTS_RUNNING + and ONLINE + and (IS_PIP_PACKAGE or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git") + ) + + def __call__(self, cfg): + """ + Attempts to add a new event to the events list and send events if the rate limit is reached. + + Args: + cfg (IterableSimpleNamespace): The configuration object containing mode and task information. 
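+
+ Examples:
+ A minimal sketch using the module-level `events` instance created at the bottom of this file; the call
+ is a no-op unless analytics are enabled in SETTINGS:
+ ```python
+ from ultralytics.cfg import get_cfg
+ from ultralytics.hub.utils import events
+
+ cfg = get_cfg()  # default configuration; cfg.task and cfg.mode populate the event payload
+ events(cfg)
+ ```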
+ """ + if not self.enabled: + # Events disabled, do nothing + return + + # Attempt to add to events + if len(self.events) < 25: # Events list limited to 25 events (drop any events past this) + params = { + **self.metadata, + "task": cfg.task, + "model": cfg.model if cfg.model in GITHUB_ASSETS_NAMES else "custom", + } + if cfg.mode == "export": + params["format"] = cfg.format + self.events.append({"name": cfg.mode, "params": params}) + + # Check rate limit + t = time.time() + if (t - self.t) < self.rate_limit: + # Time is under rate limiter, wait to send + return + + # Time is over rate limiter, send now + data = {"client_id": SETTINGS["uuid"], "events": self.events} # SHA-256 anonymized UUID hash and events list + + # POST equivalent to requests.post(self.url, json=data) + smart_request("post", self.url, json=data, retry=0, verbose=False) + + # Reset events and rate limit timer + self.events = [] + self.t = t + + +# Run below code on hub/utils init ------------------------------------------------------------------------------------- +events = Events() diff --git a/ultralytics/models/__init__.py b/ultralytics/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ead1e9230417102878174b3a011608f3c3d450db --- /dev/null +++ b/ultralytics/models/__init__.py @@ -0,0 +1,9 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from .fastsam import FastSAM +from .nas import NAS +from .rtdetr import RTDETR +from .sam import SAM +from .yolo import YOLO, YOLOWorld + +__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld" # allow simpler import diff --git a/ultralytics/models/fastsam/__init__.py b/ultralytics/models/fastsam/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8c224ac8f9e8ef5f78b50558e4bb159674b1ba42 --- /dev/null +++ b/ultralytics/models/fastsam/__init__.py @@ -0,0 +1,7 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from .model import FastSAM +from .predict import FastSAMPredictor +from .val import FastSAMValidator + +__all__ = "FastSAMPredictor", "FastSAM", "FastSAMValidator" diff --git a/ultralytics/models/fastsam/model.py b/ultralytics/models/fastsam/model.py new file mode 100644 index 0000000000000000000000000000000000000000..f9deb7a12b95586a3e93d705fa6ceec2962a3e28 --- /dev/null +++ b/ultralytics/models/fastsam/model.py @@ -0,0 +1,55 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from pathlib import Path + +from ultralytics.engine.model import Model + +from .predict import FastSAMPredictor +from .val import FastSAMValidator + + +class FastSAM(Model): + """ + FastSAM model interface. + + Example: + ```python + from ultralytics import FastSAM + + model = FastSAM("last.pt") + results = model.predict("ultralytics/assets/bus.jpg") + ``` + """ + + def __init__(self, model="FastSAM-x.pt"): + """Call the __init__ method of the parent class (YOLO) with the updated default model.""" + if str(model) == "FastSAM.pt": + model = "FastSAM-x.pt" + assert Path(model).suffix not in {".yaml", ".yml"}, "FastSAM models only support pre-trained models." + super().__init__(model=model, task="segment") + + def predict(self, source, stream=False, bboxes=None, points=None, labels=None, texts=None, **kwargs): + """ + Perform segmentation prediction on image or video source. + + Supports prompted segmentation with bounding boxes, points, labels, and texts. + + Args: + source (str | PIL.Image | numpy.ndarray): Input source. 
+ stream (bool): Enable real-time streaming. + bboxes (list): Bounding box coordinates for prompted segmentation. + points (list): Points for prompted segmentation. + labels (list): Labels for prompted segmentation. + texts (list): Texts for prompted segmentation. + **kwargs (Any): Additional keyword arguments. + + Returns: + (list): Model predictions. + """ + prompts = dict(bboxes=bboxes, points=points, labels=labels, texts=texts) + return super().predict(source, stream, prompts=prompts, **kwargs) + + @property + def task_map(self): + """Returns a dictionary mapping segment task to corresponding predictor and validator classes.""" + return {"segment": {"predictor": FastSAMPredictor, "validator": FastSAMValidator}} diff --git a/ultralytics/models/fastsam/predict.py b/ultralytics/models/fastsam/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..0d019afb9e28e835958e52baaa1a81eaf5d467bd --- /dev/null +++ b/ultralytics/models/fastsam/predict.py @@ -0,0 +1,150 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import torch +from PIL import Image + +from ultralytics.models.yolo.segment import SegmentationPredictor +from ultralytics.utils import DEFAULT_CFG, checks +from ultralytics.utils.metrics import box_iou +from ultralytics.utils.ops import scale_masks + +from .utils import adjust_bboxes_to_image_border + + +class FastSAMPredictor(SegmentationPredictor): + """ + FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks in Ultralytics + YOLO framework. + + This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It + adjusts post-processing steps to incorporate mask prediction and non-max suppression while optimizing for single- + class segmentation. + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """Initializes a FastSAMPredictor for fast SAM segmentation tasks in Ultralytics YOLO framework.""" + super().__init__(cfg, overrides, _callbacks) + self.prompts = {} + + def postprocess(self, preds, img, orig_imgs): + """Applies box postprocess for FastSAM predictions.""" + bboxes = self.prompts.pop("bboxes", None) + points = self.prompts.pop("points", None) + labels = self.prompts.pop("labels", None) + texts = self.prompts.pop("texts", None) + results = super().postprocess(preds, img, orig_imgs) + for result in results: + full_box = torch.tensor( + [0, 0, result.orig_shape[1], result.orig_shape[0]], device=preds[0].device, dtype=torch.float32 + ) + boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape) + idx = torch.nonzero(box_iou(full_box[None], boxes) > 0.9).flatten() + if idx.numel() != 0: + result.boxes.xyxy[idx] = full_box + + return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts) + + def prompt(self, results, bboxes=None, points=None, labels=None, texts=None): + """ + Internal function for image segmentation inference based on cues like bounding boxes, points, and masks. + Leverages SAM's specialized architecture for prompt-based, real-time segmentation. + + Args: + results (Results | List[Results]): The original inference results from FastSAM models without any prompts. + bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format. + points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels. + labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 
1 = foreground, 0 = background. + texts (str | List[str], optional): Textual prompts, a list contains string objects. + + Returns: + (List[Results]): The output results determined by prompts. + """ + if bboxes is None and points is None and texts is None: + return results + prompt_results = [] + if not isinstance(results, list): + results = [results] + for result in results: + if len(result) == 0: + prompt_results.append(result) + continue + masks = result.masks.data + if masks.shape[1:] != result.orig_shape: + masks = scale_masks(masks[None], result.orig_shape)[0] + # bboxes prompt + idx = torch.zeros(len(result), dtype=torch.bool, device=self.device) + if bboxes is not None: + bboxes = torch.as_tensor(bboxes, dtype=torch.int32, device=self.device) + bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes + bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0]) + mask_areas = torch.stack([masks[:, b[1] : b[3], b[0] : b[2]].sum(dim=(1, 2)) for b in bboxes]) + full_mask_areas = torch.sum(masks, dim=(1, 2)) + + union = bbox_areas[:, None] + full_mask_areas - mask_areas + idx[torch.argmax(mask_areas / union, dim=1)] = True + if points is not None: + points = torch.as_tensor(points, dtype=torch.int32, device=self.device) + points = points[None] if points.ndim == 1 else points + if labels is None: + labels = torch.ones(points.shape[0]) + labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device) + assert len(labels) == len(points), ( + f"Excepted `labels` got same size as `point`, but got {len(labels)} and {len(points)}" + ) + point_idx = ( + torch.ones(len(result), dtype=torch.bool, device=self.device) + if labels.sum() == 0 # all negative points + else torch.zeros(len(result), dtype=torch.bool, device=self.device) + ) + for point, label in zip(points, labels): + point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = bool(label) + idx |= point_idx + if texts is not None: + if isinstance(texts, str): + texts = [texts] + crop_ims, filter_idx = [], [] + for i, b in enumerate(result.boxes.xyxy.tolist()): + x1, y1, x2, y2 = (int(x) for x in b) + if masks[i].sum() <= 100: + filter_idx.append(i) + continue + crop_ims.append(Image.fromarray(result.orig_img[y1:y2, x1:x2, ::-1])) + similarity = self._clip_inference(crop_ims, texts) + text_idx = torch.argmax(similarity, dim=-1) # (M, ) + if len(filter_idx): + text_idx += (torch.tensor(filter_idx, device=self.device)[None] <= int(text_idx)).sum(0) + idx[text_idx] = True + + prompt_results.append(result[idx]) + + return prompt_results + + def _clip_inference(self, images, texts): + """ + CLIP Inference process. + + Args: + images (List[PIL.Image]): A list of source images and each of them should be PIL.Image type with RGB channel order. + texts (List[str]): A list of prompt texts and each of them should be string object. + + Returns: + (torch.Tensor): The similarity between given images and texts. 
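+
+ Examples:
+ This helper is normally exercised indirectly by passing text prompts to FastSAM; a minimal sketch using
+ the default weights and sample image referenced elsewhere in this module:
+ ```python
+ from ultralytics import FastSAM
+
+ model = FastSAM("FastSAM-x.pt")
+ results = model.predict("ultralytics/assets/bus.jpg", texts="a photo of a bus")
+ ```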
+ """ + try: + import clip + except ImportError: + checks.check_requirements("git+https://github.com/ultralytics/CLIP.git") + import clip + if (not hasattr(self, "clip_model")) or (not hasattr(self, "clip_preprocess")): + self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device) + images = torch.stack([self.clip_preprocess(image).to(self.device) for image in images]) + tokenized_text = clip.tokenize(texts).to(self.device) + image_features = self.clip_model.encode_image(images) + text_features = self.clip_model.encode_text(tokenized_text) + image_features /= image_features.norm(dim=-1, keepdim=True) # (N, 512) + text_features /= text_features.norm(dim=-1, keepdim=True) # (M, 512) + return (image_features * text_features[:, None]).sum(-1) # (M, N) + + def set_prompts(self, prompts): + """Set prompts in advance.""" + self.prompts = prompts diff --git a/ultralytics/models/fastsam/utils.py b/ultralytics/models/fastsam/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a8e1aa172ba78a5d10bcb00b0b944c6ea1cdf3d4 --- /dev/null +++ b/ultralytics/models/fastsam/utils.py @@ -0,0 +1,24 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + + +def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20): + """ + Adjust bounding boxes to stick to image border if they are within a certain threshold. + + Args: + boxes (torch.Tensor): (n, 4) + image_shape (tuple): (height, width) + threshold (int): pixel threshold + + Returns: + adjusted_boxes (torch.Tensor): adjusted bounding boxes + """ + # Image dimensions + h, w = image_shape + + # Adjust boxes + boxes[boxes[:, 0] < threshold, 0] = 0 # x1 + boxes[boxes[:, 1] < threshold, 1] = 0 # y1 + boxes[boxes[:, 2] > w - threshold, 2] = w # x2 + boxes[boxes[:, 3] > h - threshold, 3] = h # y2 + return boxes diff --git a/ultralytics/models/fastsam/val.py b/ultralytics/models/fastsam/val.py new file mode 100644 index 0000000000000000000000000000000000000000..aa130dbfc9a7f72eafd3b20655e431cc4b617d31 --- /dev/null +++ b/ultralytics/models/fastsam/val.py @@ -0,0 +1,40 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.models.yolo.segment import SegmentationValidator +from ultralytics.utils.metrics import SegmentMetrics + + +class FastSAMValidator(SegmentationValidator): + """ + Custom validation class for fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework. + + Extends the SegmentationValidator class, customizing the validation process specifically for fast SAM. This class + sets the task to 'segment' and uses the SegmentMetrics for evaluation. Additionally, plotting features are disabled + to avoid errors during validation. + + Attributes: + dataloader: The data loader object used for validation. + save_dir (str): The directory where validation results will be saved. + pbar: A progress bar object. + args: Additional arguments for customization. + _callbacks: List of callback functions to be invoked during validation. + """ + + def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): + """ + Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics. + + Args: + dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation. + save_dir (Path, optional): Directory to save results. + pbar (tqdm.tqdm): Progress bar for displaying progress. + args (SimpleNamespace): Configuration for the validator. 
+ _callbacks (dict): Dictionary to store various callback functions. + + Notes: + Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors. + """ + super().__init__(dataloader, save_dir, pbar, args, _callbacks) + self.args.task = "segment" + self.args.plots = False # disable ConfusionMatrix and other plots to avoid errors + self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot) diff --git a/ultralytics/models/nas/__init__.py b/ultralytics/models/nas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c36c0a42f0331f925a7086d82613b7e4f729b7bb --- /dev/null +++ b/ultralytics/models/nas/__init__.py @@ -0,0 +1,7 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from .model import NAS +from .predict import NASPredictor +from .val import NASValidator + +__all__ = "NASPredictor", "NASValidator", "NAS" diff --git a/ultralytics/models/nas/model.py b/ultralytics/models/nas/model.py new file mode 100644 index 0000000000000000000000000000000000000000..10fd72b4e46f82b264a580e1c01ccc171900efdb --- /dev/null +++ b/ultralytics/models/nas/model.py @@ -0,0 +1,94 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +""" +YOLO-NAS model interface. + +Example: + ```python + from ultralytics import NAS + + model = NAS("yolo_nas_s") + results = model.predict("ultralytics/assets/bus.jpg") + ``` +""" + +from pathlib import Path + +import torch + +from ultralytics.engine.model import Model +from ultralytics.utils.downloads import attempt_download_asset +from ultralytics.utils.torch_utils import model_info + +from .predict import NASPredictor +from .val import NASValidator + + +class NAS(Model): + """ + YOLO NAS model for object detection. + + This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine. + It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models. + + Example: + ```python + from ultralytics import NAS + + model = NAS("yolo_nas_s") + results = model.predict("ultralytics/assets/bus.jpg") + ``` + + Attributes: + model (str): Path to the pre-trained model or model name. Defaults to 'yolo_nas_s.pt'. + + Note: + YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files. + """ + + def __init__(self, model="yolo_nas_s.pt") -> None: + """Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model.""" + assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models." 
+ super().__init__(model, task="detect") + + def _load(self, weights: str, task=None) -> None: + """Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided.""" + import super_gradients + + suffix = Path(weights).suffix + if suffix == ".pt": + self.model = torch.load(attempt_download_asset(weights)) + + elif suffix == "": + self.model = super_gradients.training.models.get(weights, pretrained_weights="coco") + + # Override the forward method to ignore additional arguments + def new_forward(x, *args, **kwargs): + """Ignore additional __call__ arguments.""" + return self.model._original_forward(x) + + self.model._original_forward = self.model.forward + self.model.forward = new_forward + + # Standardize model + self.model.fuse = lambda verbose=True: self.model + self.model.stride = torch.tensor([32]) + self.model.names = dict(enumerate(self.model._class_names)) + self.model.is_fused = lambda: False # for info() + self.model.yaml = {} # for info() + self.model.pt_path = weights # for export() + self.model.task = "detect" # for export() + + def info(self, detailed=False, verbose=True): + """ + Logs model info. + + Args: + detailed (bool): Show detailed information about model. + verbose (bool): Controls verbosity. + """ + return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640) + + @property + def task_map(self): + """Returns a dictionary mapping tasks to respective predictor and validator classes.""" + return {"detect": {"predictor": NASPredictor, "validator": NASValidator}} diff --git a/ultralytics/models/nas/predict.py b/ultralytics/models/nas/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..e140900e7bafdb44e0e8faf821f3e7c3e705ec6e --- /dev/null +++ b/ultralytics/models/nas/predict.py @@ -0,0 +1,57 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import torch + +from ultralytics.engine.predictor import BasePredictor +from ultralytics.engine.results import Results +from ultralytics.utils import ops + + +class NASPredictor(BasePredictor): + """ + Ultralytics YOLO NAS Predictor for object detection. + + This class extends the `BasePredictor` from Ultralytics engine and is responsible for post-processing the + raw predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and + scaling the bounding boxes to fit the original image dimensions. + + Attributes: + args (Namespace): Namespace containing various configurations for post-processing. + + Example: + ```python + from ultralytics import NAS + + model = NAS("yolo_nas_s") + predictor = model.predictor + # Assumes that raw_preds, img, orig_imgs are available + results = predictor.postprocess(raw_preds, img, orig_imgs) + ``` + + Note: + Typically, this class is not instantiated directly. It is used internally within the `NAS` class. 
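`_load` above adapts a third-party super-gradients network to the Ultralytics engine by swapping in a `forward` wrapper that drops any extra arguments the engine passes, then attaching the attributes the engine expects (`fuse`, `stride`, `names`, `task`, ...). A minimal sketch of that wrapping pattern on a plain `nn.Module` (the module and attribute values are stand-ins, not the real YOLO-NAS network):

```python
import torch
from torch import nn

model = nn.Linear(4, 2)  # stand-in for a loaded YOLO-NAS network

def new_forward(x, *args, **kwargs):
    """Ignore any extra positional or keyword arguments the engine may pass."""
    return model._original_forward(x)

model._original_forward = model.forward  # keep a handle on the original bound forward
model.forward = new_forward              # instance attribute shadows the class method

# Attributes the Ultralytics engine looks for on a detection model
model.fuse = lambda verbose=True: model
model.stride = torch.tensor([32])
model.names = {0: "person", 1: "bicycle"}  # real code uses dict(enumerate(model._class_names))
model.task = "detect"

out = model(torch.randn(1, 4), augment=False, visualize=True)  # extra kwargs are ignored
print(out.shape)  # torch.Size([1, 2])
```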
+ """ + + def postprocess(self, preds_in, img, orig_imgs): + """Postprocess predictions and returns a list of Results objects.""" + # Cat boxes and class scores + boxes = ops.xyxy2xywh(preds_in[0][0]) + preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1) + + preds = ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + agnostic=self.args.agnostic_nms, + max_det=self.args.max_det, + classes=self.args.classes, + ) + + if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list + orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) + + results = [] + for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]): + pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) + results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred)) + return results diff --git a/ultralytics/models/nas/val.py b/ultralytics/models/nas/val.py new file mode 100644 index 0000000000000000000000000000000000000000..c3d0f37e37ac1b5a02ea45a606fa020472f11429 --- /dev/null +++ b/ultralytics/models/nas/val.py @@ -0,0 +1,50 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import torch + +from ultralytics.models.yolo.detect import DetectionValidator +from ultralytics.utils import ops + +__all__ = ["NASValidator"] + + +class NASValidator(DetectionValidator): + """ + Ultralytics YOLO NAS Validator for object detection. + + Extends `DetectionValidator` from the Ultralytics models package and is designed to post-process the raw predictions + generated by YOLO NAS models. It performs non-maximum suppression to remove overlapping and low-confidence boxes, + ultimately producing the final detections. + + Attributes: + args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU. + lb (torch.Tensor): Optional tensor for multilabel NMS. + + Example: + ```python + from ultralytics import NAS + + model = NAS("yolo_nas_s") + validator = model.validator + # Assumes that raw_preds are available + final_preds = validator.postprocess(raw_preds) + ``` + + Note: + This class is generally not instantiated directly but is used internally within the `NAS` class. 
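YOLO-NAS returns boxes and class scores as two separate tensors, so `postprocess` above concatenates them and permutes to the `(batch, 4 + num_classes, num_anchors)` layout that `ops.non_max_suppression` expects (after converting the boxes to xywh). A shape-only sketch with dummy tensors (sizes are illustrative):

```python
import torch

num_anchors, num_classes = 8400, 80
raw_boxes = torch.rand(1, num_anchors, 4)             # xywh boxes (converted via ops.xyxy2xywh in the real code)
raw_scores = torch.rand(1, num_anchors, num_classes)  # per-class confidences

preds = torch.cat((raw_boxes, raw_scores), dim=-1).permute(0, 2, 1)
print(preds.shape)  # torch.Size([1, 84, 8400]) -> layout expected by ops.non_max_suppression
```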
+ """ + + def postprocess(self, preds_in): + """Apply Non-maximum suppression to prediction outputs.""" + boxes = ops.xyxy2xywh(preds_in[0][0]) + preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1) + return ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + labels=self.lb, + multi_label=False, + agnostic=self.args.single_cls or self.args.agnostic_nms, + max_det=self.args.max_det, + max_time_img=0.5, + ) diff --git a/ultralytics/models/rtdetr/__init__.py b/ultralytics/models/rtdetr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a6d038d652cfcfb7e8ddf0424981b558dbbeb270 --- /dev/null +++ b/ultralytics/models/rtdetr/__init__.py @@ -0,0 +1,7 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from .model import RTDETR +from .predict import RTDETRPredictor +from .val import RTDETRValidator + +__all__ = "RTDETRPredictor", "RTDETRValidator", "RTDETR" diff --git a/ultralytics/models/rtdetr/model.py b/ultralytics/models/rtdetr/model.py new file mode 100644 index 0000000000000000000000000000000000000000..fa4123a8a24c3b482ac5dbebcc1ad8134c4e4674 --- /dev/null +++ b/ultralytics/models/rtdetr/model.py @@ -0,0 +1,54 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +""" +Interface for Baidu's RT-DETR, a Vision Transformer-based real-time object detector. RT-DETR offers real-time +performance and high accuracy, excelling in accelerated backends like CUDA with TensorRT. It features an efficient +hybrid encoder and IoU-aware query selection for enhanced detection accuracy. + +For more information on RT-DETR, visit: https://arxiv.org/pdf/2304.08069.pdf +""" + +from ultralytics.engine.model import Model +from ultralytics.nn.tasks import RTDETRDetectionModel + +from .predict import RTDETRPredictor +from .train import RTDETRTrainer +from .val import RTDETRValidator + + +class RTDETR(Model): + """ + Interface for Baidu's RT-DETR model. This Vision Transformer-based object detector provides real-time performance + with high accuracy. It supports efficient hybrid encoding, IoU-aware query selection, and adaptable inference speed. + + Attributes: + model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'. + """ + + def __init__(self, model="rtdetr-l.pt") -> None: + """ + Initializes the RT-DETR model with the given pre-trained model file. Supports .pt and .yaml formats. + + Args: + model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'. + + Raises: + NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'. + """ + super().__init__(model=model, task="detect") + + @property + def task_map(self) -> dict: + """ + Returns a task map for RT-DETR, associating tasks with corresponding Ultralytics classes. + + Returns: + dict: A dictionary mapping task names to Ultralytics task classes for the RT-DETR model. 
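Unlike the NAS and SAM wrappers in this diff, the `RTDETR` class docstring carries no usage snippet; a minimal one, mirroring the interface defined here and assuming the top-level `ultralytics` package re-exports `RTDETR` as in upstream Ultralytics, might look like:

```python
from ultralytics import RTDETR

# Load a pre-trained RT-DETR checkpoint and run detection on one image
model = RTDETR("rtdetr-l.pt")
results = model.predict("path/to/image.jpg", conf=0.25)
print(results[0].boxes)
```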
+ """ + return { + "detect": { + "predictor": RTDETRPredictor, + "validator": RTDETRValidator, + "trainer": RTDETRTrainer, + "model": RTDETRDetectionModel, + } + } diff --git a/ultralytics/models/rtdetr/predict.py b/ultralytics/models/rtdetr/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..782cc2f640de876e281d53022d8ffee282ebd770 --- /dev/null +++ b/ultralytics/models/rtdetr/predict.py @@ -0,0 +1,84 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import torch + +from ultralytics.data.augment import LetterBox +from ultralytics.engine.predictor import BasePredictor +from ultralytics.engine.results import Results +from ultralytics.utils import ops + + +class RTDETRPredictor(BasePredictor): + """ + RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions using + Baidu's RT-DETR model. + + This class leverages the power of Vision Transformers to provide real-time object detection while maintaining + high accuracy. It supports key features like efficient hybrid encoding and IoU-aware query selection. + + Example: + ```python + from ultralytics.utils import ASSETS + from ultralytics.models.rtdetr import RTDETRPredictor + + args = dict(model="rtdetr-l.pt", source=ASSETS) + predictor = RTDETRPredictor(overrides=args) + predictor.predict_cli() + ``` + + Attributes: + imgsz (int): Image size for inference (must be square and scale-filled). + args (dict): Argument overrides for the predictor. + """ + + def postprocess(self, preds, img, orig_imgs): + """ + Postprocess the raw predictions from the model to generate bounding boxes and confidence scores. + + The method filters detections based on confidence and class if specified in `self.args`. + + Args: + preds (list): List of [predictions, extra] from the model. + img (torch.Tensor): Processed input images. + orig_imgs (list or torch.Tensor): Original, unprocessed images. + + Returns: + (list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence scores, + and class labels. + """ + if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference + preds = [preds, None] + + nd = preds[0].shape[-1] + bboxes, scores = preds[0].split((4, nd - 4), dim=-1) + + if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list + orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) + + results = [] + for bbox, score, orig_img, img_path in zip(bboxes, scores, orig_imgs, self.batch[0]): # (300, 4) + bbox = ops.xywh2xyxy(bbox) + max_score, cls = score.max(-1, keepdim=True) # (300, 1) + idx = max_score.squeeze(-1) > self.args.conf # (300, ) + if self.args.classes is not None: + idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx + pred = torch.cat([bbox, max_score, cls], dim=-1)[idx] # filter + oh, ow = orig_img.shape[:2] + pred[..., [0, 2]] *= ow + pred[..., [1, 3]] *= oh + results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred)) + return results + + def pre_transform(self, im): + """ + Pre-transforms the input images before feeding them into the model for inference. The input images are + letterboxed to ensure a square aspect ratio and scale-filled. The size must be square(640) and scaleFilled. + + Args: + im (list[np.ndarray] |torch.Tensor): Input images of shape (N,3,h,w) for tensor, [(h,w,3) x N] for list. + + Returns: + (list): List of pre-transformed images ready for model inference. 
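RT-DETR predicts a fixed set of queries (typically 300) with normalized xywh boxes and per-class scores, so `postprocess` above only thresholds and rescales rather than running NMS. A small sketch of that decode step for a single image, with an inline `xywh2xyxy` helper standing in for `ops.xywh2xyxy` (tensor sizes and the 0.25 threshold are illustrative):

```python
import torch

def xywh2xyxy(x):
    """Convert center-based (x, y, w, h) boxes to corner-based (x1, y1, x2, y2)."""
    y = x.clone()
    y[..., 0] = x[..., 0] - x[..., 2] / 2
    y[..., 1] = x[..., 1] - x[..., 3] / 2
    y[..., 2] = x[..., 0] + x[..., 2] / 2
    y[..., 3] = x[..., 1] + x[..., 3] / 2
    return y

num_queries, num_classes = 300, 80
bbox = torch.rand(num_queries, 4)             # normalized xywh in [0, 1]
score = torch.rand(num_queries, num_classes)  # per-class scores

bbox = xywh2xyxy(bbox)
max_score, cls = score.max(-1, keepdim=True)  # best class per query
keep = max_score.squeeze(-1) > 0.25           # confidence filter
pred = torch.cat([bbox, max_score, cls.float()], dim=-1)[keep]

oh, ow = 480, 640                             # original image height and width
pred[..., [0, 2]] *= ow                       # x coordinates back to pixels
pred[..., [1, 3]] *= oh                       # y coordinates back to pixels
print(pred.shape)                             # (kept_queries, 6): x1, y1, x2, y2, conf, cls
```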
+ """ + letterbox = LetterBox(self.imgsz, auto=False, scaleFill=True) + return [letterbox(image=x) for x in im] diff --git a/ultralytics/models/rtdetr/train.py b/ultralytics/models/rtdetr/train.py new file mode 100644 index 0000000000000000000000000000000000000000..2fc30f9f39aaf51c55a239f4e63f454cad6a657c --- /dev/null +++ b/ultralytics/models/rtdetr/train.py @@ -0,0 +1,105 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from copy import copy + +import torch + +from ultralytics.models.yolo.detect import DetectionTrainer +from ultralytics.nn.tasks import RTDETRDetectionModel +from ultralytics.utils import RANK, colorstr + +from .val import RTDETRDataset, RTDETRValidator + + +class RTDETRTrainer(DetectionTrainer): + """ + Trainer class for the RT-DETR model developed by Baidu for real-time object detection. Extends the DetectionTrainer + class for YOLO to adapt to the specific features and architecture of RT-DETR. This model leverages Vision + Transformers and has capabilities like IoU-aware query selection and adaptable inference speed. + + Notes: + - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument. + - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching. + + Example: + ```python + from ultralytics.models.rtdetr.train import RTDETRTrainer + + args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3) + trainer = RTDETRTrainer(overrides=args) + trainer.train() + ``` + """ + + def get_model(self, cfg=None, weights=None, verbose=True): + """ + Initialize and return an RT-DETR model for object detection tasks. + + Args: + cfg (dict, optional): Model configuration. Defaults to None. + weights (str, optional): Path to pre-trained model weights. Defaults to None. + verbose (bool): Verbose logging if True. Defaults to True. + + Returns: + (RTDETRDetectionModel): Initialized model. + """ + model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1) + if weights: + model.load(weights) + return model + + def build_dataset(self, img_path, mode="val", batch=None): + """ + Build and return an RT-DETR dataset for training or validation. + + Args: + img_path (str): Path to the folder containing images. + mode (str): Dataset mode, either 'train' or 'val'. + batch (int, optional): Batch size for rectangle training. Defaults to None. + + Returns: + (RTDETRDataset): Dataset object for the specific mode. + """ + return RTDETRDataset( + img_path=img_path, + imgsz=self.args.imgsz, + batch_size=batch, + augment=mode == "train", + hyp=self.args, + rect=False, + cache=self.args.cache or None, + single_cls=self.args.single_cls or False, + prefix=colorstr(f"{mode}: "), + classes=self.args.classes, + data=self.data, + fraction=self.args.fraction if mode == "train" else 1.0, + ) + + def get_validator(self): + """ + Returns a DetectionValidator suitable for RT-DETR model validation. + + Returns: + (RTDETRValidator): Validator object for model validation. + """ + self.loss_names = "giou_loss", "cls_loss", "l1_loss" + return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) + + def preprocess_batch(self, batch): + """ + Preprocess a batch of images. Scales and converts the images to float format. + + Args: + batch (dict): Dictionary containing a batch of images, bboxes, and labels. + + Returns: + (dict): Preprocessed batch. 
+ """ + batch = super().preprocess_batch(batch) + bs = len(batch["img"]) + batch_idx = batch["batch_idx"] + gt_bbox, gt_class = [], [] + for i in range(bs): + gt_bbox.append(batch["bboxes"][batch_idx == i].to(batch_idx.device)) + gt_class.append(batch["cls"][batch_idx == i].to(device=batch_idx.device, dtype=torch.long)) + return batch diff --git a/ultralytics/models/rtdetr/val.py b/ultralytics/models/rtdetr/val.py new file mode 100644 index 0000000000000000000000000000000000000000..a218b4af5aea9b2c52936f09c88293758201c0b0 --- /dev/null +++ b/ultralytics/models/rtdetr/val.py @@ -0,0 +1,135 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import torch + +from ultralytics.data import YOLODataset +from ultralytics.data.augment import Compose, Format, v8_transforms +from ultralytics.models.yolo.detect import DetectionValidator +from ultralytics.utils import colorstr, ops + +__all__ = ("RTDETRValidator",) # tuple or list + + +class RTDETRDataset(YOLODataset): + """ + Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class. + + This specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for + real-time detection and tracking tasks. + """ + + def __init__(self, *args, data=None, **kwargs): + """Initialize the RTDETRDataset class by inheriting from the YOLODataset class.""" + super().__init__(*args, data=data, **kwargs) + + # NOTE: add stretch version load_image for RTDETR mosaic + def load_image(self, i, rect_mode=False): + """Loads 1 image from dataset index 'i', returns (im, resized hw).""" + return super().load_image(i=i, rect_mode=rect_mode) + + def build_transforms(self, hyp=None): + """Temporary, only for evaluation.""" + if self.augment: + hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0 + hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0 + transforms = v8_transforms(self, self.imgsz, hyp, stretch=True) + else: + # transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scaleFill=True)]) + transforms = Compose([]) + transforms.append( + Format( + bbox_format="xywh", + normalize=True, + return_mask=self.use_segments, + return_keypoint=self.use_keypoints, + batch_idx=True, + mask_ratio=hyp.mask_ratio, + mask_overlap=hyp.overlap_mask, + ) + ) + return transforms + + +class RTDETRValidator(DetectionValidator): + """ + RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for + the RT-DETR (Real-Time DETR) object detection model. + + The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for + post-processing, and updates evaluation metrics accordingly. + + Example: + ```python + from ultralytics.models.rtdetr import RTDETRValidator + + args = dict(model="rtdetr-l.pt", data="coco8.yaml") + validator = RTDETRValidator(args=args) + validator() + ``` + + Note: + For further details on the attributes and methods, refer to the parent DetectionValidator class. + """ + + def build_dataset(self, img_path, mode="val", batch=None): + """ + Build an RTDETR Dataset. + + Args: + img_path (str): Path to the folder containing images. + mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode. + batch (int, optional): Size of batches, this is for `rect`. Defaults to None. 
+ """ + return RTDETRDataset( + img_path=img_path, + imgsz=self.args.imgsz, + batch_size=batch, + augment=False, # no augmentation + hyp=self.args, + rect=False, # no rect + cache=self.args.cache or None, + prefix=colorstr(f"{mode}: "), + data=self.data, + ) + + def postprocess(self, preds): + """Apply Non-maximum suppression to prediction outputs.""" + if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference + preds = [preds, None] + + bs, _, nd = preds[0].shape + bboxes, scores = preds[0].split((4, nd - 4), dim=-1) + bboxes *= self.args.imgsz + outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs + for i, bbox in enumerate(bboxes): # (300, 4) + bbox = ops.xywh2xyxy(bbox) + score, cls = scores[i].max(-1) # (300, ) + # Do not need threshold for evaluation as only got 300 boxes here + # idx = score > self.args.conf + pred = torch.cat([bbox, score[..., None], cls[..., None]], dim=-1) # filter + # Sort by confidence to correctly get internal metrics + pred = pred[score.argsort(descending=True)] + outputs[i] = pred # [idx] + + return outputs + + def _prepare_batch(self, si, batch): + """Prepares a batch for training or inference by applying transformations.""" + idx = batch["batch_idx"] == si + cls = batch["cls"][idx].squeeze(-1) + bbox = batch["bboxes"][idx] + ori_shape = batch["ori_shape"][si] + imgsz = batch["img"].shape[2:] + ratio_pad = batch["ratio_pad"][si] + if len(cls): + bbox = ops.xywh2xyxy(bbox) # target boxes + bbox[..., [0, 2]] *= ori_shape[1] # native-space pred + bbox[..., [1, 3]] *= ori_shape[0] # native-space pred + return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad} + + def _prepare_pred(self, pred, pbatch): + """Prepares and returns a batch with transformed bounding boxes and class labels.""" + predn = pred.clone() + predn[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz # native-space pred + predn[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz # native-space pred + return predn.float() diff --git a/ultralytics/models/sam/__init__.py b/ultralytics/models/sam/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2d9de7b64ea566a6cdd26aeb1959fb4cc08a598e --- /dev/null +++ b/ultralytics/models/sam/__init__.py @@ -0,0 +1,6 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from .model import SAM +from .predict import Predictor, SAM2Predictor, SAM2VideoPredictor + +__all__ = "SAM", "Predictor", "SAM2Predictor", "SAM2VideoPredictor" # tuple or list diff --git a/ultralytics/models/sam/amg.py b/ultralytics/models/sam/amg.py new file mode 100644 index 0000000000000000000000000000000000000000..4abce4cd7db716ce3fbabe82f47b3e863451dd7d --- /dev/null +++ b/ultralytics/models/sam/amg.py @@ -0,0 +1,193 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import math +from itertools import product +from typing import Any, Generator, List, Tuple + +import numpy as np +import torch + + +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: + """Determines if bounding boxes are near the edge of a cropped image region using a specified tolerance.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + boxes = uncrop_boxes_xyxy(boxes, crop_box).float() + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, 
:], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + """Yields batches of data from input arguments with specified batch size for efficient processing.""" + assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs." + n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] + + +def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. + + The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at + high and low values. + + Args: + masks (torch.Tensor): Batch of predicted mask logits. + mask_threshold (float): Threshold value for creating binary masks. + threshold_offset (float): Offset applied to the threshold for creating high and low binary masks. + + Returns: + (torch.Tensor): Stability scores for each mask in the batch. + + Notes: + - One mask is always contained inside the other. + - Memory is saved by preventing unnecessary cast to torch.int64. + + Examples: + >>> masks = torch.rand(10, 256, 256) # Batch of 10 masks + >>> mask_threshold = 0.5 + >>> threshold_offset = 0.1 + >>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset) + """ + intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32) + unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32) + return intersections / unions + + +def build_point_grid(n_per_side: int) -> np.ndarray: + """Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1] for image segmentation tasks.""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + return np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + + +def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> List[np.ndarray]: + """Generates point grids for multiple crop layers with varying scales and densities.""" + return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)] + + +def generate_crop_boxes( + im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: + """Generates crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.""" + crop_boxes, layer_idxs = [], [] + im_h, im_w = im_size + short_side = min(im_h, im_w) + + # Original image + crop_boxes.append([0, 0, im_w, im_h]) + layer_idxs.append(0) + + def crop_len(orig_len, n_crops, overlap): + """Crops bounding boxes to the size of the input image.""" + return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) + + for i_layer in range(n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_w = crop_len(im_w, n_crops_per_side, overlap) + crop_h = crop_len(im_h, n_crops_per_side, overlap) + + 
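`build_point_grid` and `batch_iterator` above supply SAM's automatic mask generator with a regular grid of prompt points and then feed them to the model in fixed-size chunks. A quick sketch of both, importing from the module added in this diff:

```python
import numpy as np
from ultralytics.models.sam.amg import batch_iterator, build_point_grid

points = build_point_grid(n_per_side=4)  # 16 evenly spaced points in [0, 1] x [0, 1]
print(points.shape, points[0])           # (16, 2) [0.125 0.125]

labels = np.ones(len(points))
for pts, lbls in batch_iterator(6, points, labels):
    print(pts.shape, lbls.shape)         # chunks of 6, 6, then the remaining 4
```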
crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] + + # Crops in XYWH format + for x0, y0 in product(crop_box_x0, crop_box_y0): + box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + """Uncrop bounding boxes by adding the crop box offset to their coordinates.""" + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + return boxes + offset + + +def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + """Uncrop points by adding the crop box offset to their coordinates.""" + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0]], device=points.device) + # Check if points has a channel dimension + if len(points.shape) == 3: + offset = offset.unsqueeze(1) + return points + offset + + +def uncrop_masks(masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int) -> torch.Tensor: + """Uncrop masks by padding them to the original image size, handling coordinate transformations.""" + x0, y0, x1, y1 = crop_box + if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) + pad = (x0, pad_x - x0, y0, pad_y - y0) + return torch.nn.functional.pad(masks, pad, value=0) + + +def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tuple[np.ndarray, bool]: + """Removes small disconnected regions or holes in a mask based on area threshold and mode.""" + import cv2 # type: ignore + + assert mode in {"holes", "islands"}, f"Provided mode {mode} is invalid" + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if not small_regions: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + # If every region is below threshold, keep largest + fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """Calculates bounding boxes in XYXY format around binary masks, handling empty masks and various input shapes.""" + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + masks = masks.flatten(0, -3) if len(shape) > 2 else masks.unsqueeze(0) + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ 
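Together, `generate_crop_boxes` and the `uncrop_*` helpers above implement SAM's crop-and-merge strategy: the image is tiled into overlapping crops per layer, predictions are made per crop, and boxes, points, and masks are shifted back into full-image coordinates by adding the crop offset. A small worked example (importing from the module added in this diff; the exact crop sizes follow the `crop_len` rounding above):

```python
import torch
from ultralytics.models.sam.amg import generate_crop_boxes, uncrop_boxes_xyxy

# im_size is (height, width); one extra layer gives a 2x2 grid of overlapping crops
crop_boxes, layer_idxs = generate_crop_boxes(im_size=(480, 640), n_layers=1, overlap_ratio=0.25)
print(len(crop_boxes), crop_boxes[0])  # 5 [0, 0, 640, 480] -> the full image is always layer 0

# A box predicted inside the last crop is shifted back by that crop's (x0, y0) offset
boxes_in_crop = torch.tensor([[10.0, 20.0, 50.0, 60.0]])
print(uncrop_boxes_xyxy(boxes_in_crop, crop_boxes[-1]))  # tensor([[270., 200., 310., 240.]])
```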
= torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + return out.reshape(*shape[:-2], 4) if len(shape) > 2 else out[0] diff --git a/ultralytics/models/sam/build.py b/ultralytics/models/sam/build.py new file mode 100644 index 0000000000000000000000000000000000000000..47c9d5a345ba4d9f74c54b1f2427874d82739fa4 --- /dev/null +++ b/ultralytics/models/sam/build.py @@ -0,0 +1,358 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial + +import torch + +from ultralytics.utils.downloads import attempt_download_asset + +from .modules.decoders import MaskDecoder +from .modules.encoders import FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder +from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer +from .modules.sam import SAM2Model, SAMModel +from .modules.tiny_encoder import TinyViT +from .modules.transformer import TwoWayTransformer + + +def build_sam_vit_h(checkpoint=None): + """Builds and returns a Segment Anything Model (SAM) h-size model with specified encoder parameters.""" + return _build_sam( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + ) + + +def build_sam_vit_l(checkpoint=None): + """Builds and returns a Segment Anything Model (SAM) l-size model with specified encoder parameters.""" + return _build_sam( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + ) + + +def build_sam_vit_b(checkpoint=None): + """Constructs and returns a Segment Anything Model (SAM) with b-size architecture and optional checkpoint.""" + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + ) + + +def build_mobile_sam(checkpoint=None): + """Builds and returns a Mobile Segment Anything Model (Mobile-SAM) for efficient image segmentation.""" + return _build_sam( + encoder_embed_dim=[64, 128, 160, 320], + encoder_depth=[2, 2, 6, 2], + encoder_num_heads=[2, 4, 5, 10], + encoder_global_attn_indexes=None, + mobile_sam=True, + checkpoint=checkpoint, + ) + + +def build_sam2_t(checkpoint=None): + """Builds and returns a Segment Anything Model 2 (SAM2) tiny-size model with specified architecture parameters.""" + return _build_sam2( + encoder_embed_dim=96, + encoder_stages=[1, 2, 7, 2], + encoder_num_heads=1, + encoder_global_att_blocks=[5, 7, 9], + encoder_window_spec=[8, 4, 14, 7], + encoder_backbone_channel_list=[768, 384, 192, 96], + checkpoint=checkpoint, + ) + + +def build_sam2_s(checkpoint=None): + """Builds and returns a small-size Segment Anything Model (SAM2) with specified architecture parameters.""" + return _build_sam2( + encoder_embed_dim=96, + encoder_stages=[1, 2, 11, 2], + encoder_num_heads=1, 
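`batched_mask_to_box` above recovers tight XYXY boxes from binary masks by projecting each mask onto its rows and columns and reading off the first and last occupied indices, returning `[0, 0, 0, 0]` for empty masks. A toy check (importing from the module added in this diff):

```python
import torch
from ultralytics.models.sam.amg import batched_mask_to_box

masks = torch.zeros(2, 8, 8, dtype=torch.bool)
masks[0, 2:5, 3:7] = True  # occupies rows 2-4 and columns 3-6
# masks[1] stays empty, so its box collapses to [0, 0, 0, 0]

print(batched_mask_to_box(masks))
# tensor([[3, 2, 6, 4],
#         [0, 0, 0, 0]])
```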
+ encoder_global_att_blocks=[7, 10, 13], + encoder_window_spec=[8, 4, 14, 7], + encoder_backbone_channel_list=[768, 384, 192, 96], + checkpoint=checkpoint, + ) + + +def build_sam2_b(checkpoint=None): + """Builds and returns a SAM2 base-size model with specified architecture parameters.""" + return _build_sam2( + encoder_embed_dim=112, + encoder_stages=[2, 3, 16, 3], + encoder_num_heads=2, + encoder_global_att_blocks=[12, 16, 20], + encoder_window_spec=[8, 4, 14, 7], + encoder_window_spatial_size=[14, 14], + encoder_backbone_channel_list=[896, 448, 224, 112], + checkpoint=checkpoint, + ) + + +def build_sam2_l(checkpoint=None): + """Builds and returns a large-size Segment Anything Model (SAM2) with specified architecture parameters.""" + return _build_sam2( + encoder_embed_dim=144, + encoder_stages=[2, 6, 36, 4], + encoder_num_heads=2, + encoder_global_att_blocks=[23, 33, 43], + encoder_window_spec=[8, 4, 16, 8], + encoder_backbone_channel_list=[1152, 576, 288, 144], + checkpoint=checkpoint, + ) + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, + mobile_sam=False, +): + """ + Builds a Segment Anything Model (SAM) with specified encoder parameters. + + Args: + encoder_embed_dim (int | List[int]): Embedding dimension for the encoder. + encoder_depth (int | List[int]): Depth of the encoder. + encoder_num_heads (int | List[int]): Number of attention heads in the encoder. + encoder_global_attn_indexes (List[int] | None): Indexes for global attention in the encoder. + checkpoint (str | None): Path to the model checkpoint file. + mobile_sam (bool): Whether to build a Mobile-SAM model. + + Returns: + (SAMModel): A Segment Anything Model instance with the specified architecture. + + Examples: + >>> sam = _build_sam(768, 12, 12, [2, 5, 8, 11]) + >>> sam = _build_sam([64, 128, 160, 320], [2, 2, 6, 2], [2, 4, 5, 10], None, mobile_sam=True) + """ + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + image_encoder = ( + TinyViT( + img_size=1024, + in_chans=3, + num_classes=1000, + embed_dims=encoder_embed_dim, + depths=encoder_depth, + num_heads=encoder_num_heads, + window_sizes=[7, 7, 14, 7], + mlp_ratio=4.0, + drop_rate=0.0, + drop_path_rate=0.0, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=0.8, + ) + if mobile_sam + else ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ) + ) + sam = SAMModel( + image_encoder=image_encoder, + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + if checkpoint is not None: + checkpoint = attempt_download_asset(checkpoint) + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + 
sam.load_state_dict(state_dict) + sam.eval() + return sam + + +def _build_sam2( + encoder_embed_dim=1280, + encoder_stages=[2, 6, 36, 4], + encoder_num_heads=2, + encoder_global_att_blocks=[7, 15, 23, 31], + encoder_backbone_channel_list=[1152, 576, 288, 144], + encoder_window_spatial_size=[7, 7], + encoder_window_spec=[8, 4, 16, 8], + checkpoint=None, +): + """ + Builds and returns a Segment Anything Model 2 (SAM2) with specified architecture parameters. + + Args: + encoder_embed_dim (int): Embedding dimension for the encoder. + encoder_stages (List[int]): Number of blocks in each stage of the encoder. + encoder_num_heads (int): Number of attention heads in the encoder. + encoder_global_att_blocks (List[int]): Indices of global attention blocks in the encoder. + encoder_backbone_channel_list (List[int]): Channel dimensions for each level of the encoder backbone. + encoder_window_spatial_size (List[int]): Spatial size of the window for position embeddings. + encoder_window_spec (List[int]): Window specifications for each stage of the encoder. + checkpoint (str | None): Path to the checkpoint file for loading pre-trained weights. + + Returns: + (SAM2Model): A configured and initialized SAM2 model. + + Examples: + >>> sam2_model = _build_sam2(encoder_embed_dim=96, encoder_stages=[1, 2, 7, 2]) + >>> sam2_model.eval() + """ + image_encoder = ImageEncoder( + trunk=Hiera( + embed_dim=encoder_embed_dim, + num_heads=encoder_num_heads, + stages=encoder_stages, + global_att_blocks=encoder_global_att_blocks, + window_pos_embed_bkg_spatial_size=encoder_window_spatial_size, + window_spec=encoder_window_spec, + ), + neck=FpnNeck( + d_model=256, + backbone_channel_list=encoder_backbone_channel_list, + fpn_top_down_levels=[2, 3], + fpn_interp_model="nearest", + ), + scalp=1, + ) + memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer()) + memory_encoder = MemoryEncoder(out_dim=64) + + is_sam2_1 = checkpoint is not None and "sam2.1" in checkpoint + sam2 = SAM2Model( + image_encoder=image_encoder, + memory_attention=memory_attention, + memory_encoder=memory_encoder, + num_maskmem=7, + image_size=1024, + sigmoid_scale_for_mem_enc=20.0, + sigmoid_bias_for_mem_enc=-10.0, + use_mask_input_as_output_without_sam=True, + directly_add_no_mem_embed=True, + use_high_res_features_in_sam=True, + multimask_output_in_sam=True, + iou_prediction_use_sigmoid=True, + use_obj_ptrs_in_encoder=True, + add_tpos_enc_to_obj_ptrs=True, + only_obj_ptrs_in_the_past_for_eval=True, + pred_obj_scores=True, + pred_obj_scores_mlp=True, + fixed_no_obj_ptr=True, + multimask_output_for_tracking=True, + use_multimask_token_for_obj_ptr=True, + multimask_min_pt_num=0, + multimask_max_pt_num=1, + use_mlp_for_obj_ptr_proj=True, + compile_image_encoder=False, + no_obj_embed_spatial=is_sam2_1, + proj_tpos_enc_in_obj_ptrs=is_sam2_1, + use_signed_tpos_enc_to_obj_ptrs=is_sam2_1, + sam_mask_decoder_extra_args=dict( + dynamic_multimask_via_stability=True, + dynamic_multimask_stability_delta=0.05, + dynamic_multimask_stability_thresh=0.98, + ), + ) + + if checkpoint is not None: + checkpoint = attempt_download_asset(checkpoint) + with open(checkpoint, "rb") as f: + state_dict = torch.load(f)["model"] + sam2.load_state_dict(state_dict) + sam2.eval() + return sam2 + + +sam_model_map = { + "sam_h.pt": build_sam_vit_h, + "sam_l.pt": build_sam_vit_l, + "sam_b.pt": build_sam_vit_b, + "mobile_sam.pt": build_mobile_sam, + "sam2_t.pt": build_sam2_t, + "sam2_s.pt": build_sam2_s, + "sam2_b.pt": build_sam2_b, 
+ "sam2_l.pt": build_sam2_l, + "sam2.1_t.pt": build_sam2_t, + "sam2.1_s.pt": build_sam2_s, + "sam2.1_b.pt": build_sam2_b, + "sam2.1_l.pt": build_sam2_l, +} + + +def build_sam(ckpt="sam_b.pt"): + """ + Builds and returns a Segment Anything Model (SAM) based on the provided checkpoint. + + Args: + ckpt (str | Path): Path to the checkpoint file or name of a pre-defined SAM model. + + Returns: + (SAMModel | SAM2Model): A configured and initialized SAM or SAM2 model instance. + + Raises: + FileNotFoundError: If the provided checkpoint is not a supported SAM model. + + Examples: + >>> sam_model = build_sam("sam_b.pt") + >>> sam_model = build_sam("path/to/custom_checkpoint.pt") + + Notes: + Supported pre-defined models include: + - SAM: 'sam_h.pt', 'sam_l.pt', 'sam_b.pt', 'mobile_sam.pt' + - SAM2: 'sam2_t.pt', 'sam2_s.pt', 'sam2_b.pt', 'sam2_l.pt' + """ + model_builder = None + ckpt = str(ckpt) # to allow Path ckpt types + for k in sam_model_map.keys(): + if ckpt.endswith(k): + model_builder = sam_model_map.get(k) + + if not model_builder: + raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}") + + return model_builder(ckpt) diff --git a/ultralytics/models/sam/model.py b/ultralytics/models/sam/model.py new file mode 100644 index 0000000000000000000000000000000000000000..d9fb501b795e23d4e54bdd6f988569ce3c4dfbe1 --- /dev/null +++ b/ultralytics/models/sam/model.py @@ -0,0 +1,175 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +""" +SAM model interface. + +This module provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for real-time image +segmentation tasks. The SAM model allows for promptable segmentation with unparalleled versatility in image analysis, +and has been trained on the SA-1B dataset. It features zero-shot performance capabilities, enabling it to adapt to new +image distributions and tasks without prior knowledge. + +Key Features: + - Promptable segmentation + - Real-time performance + - Zero-shot transfer capabilities + - Trained on SA-1B dataset +""" + +from pathlib import Path + +from ultralytics.engine.model import Model +from ultralytics.utils.torch_utils import model_info + +from .build import build_sam +from .predict import Predictor, SAM2Predictor + + +class SAM(Model): + """ + SAM (Segment Anything Model) interface class for real-time image segmentation tasks. + + This class provides an interface to the Segment Anything Model (SAM) from Ultralytics, designed for + promptable segmentation with versatility in image analysis. It supports various prompts such as bounding + boxes, points, or labels, and features zero-shot performance capabilities. + + Attributes: + model (torch.nn.Module): The loaded SAM model. + is_sam2 (bool): Indicates whether the model is SAM2 variant. + task (str): The task type, set to "segment" for SAM models. + + Methods: + predict: Performs segmentation prediction on the given image or video source. + info: Logs information about the SAM model. + + Examples: + >>> sam = SAM("sam_b.pt") + >>> results = sam.predict("image.jpg", points=[[500, 375]]) + >>> for r in results: + >>> print(f"Detected {len(r.masks)} masks") + """ + + def __init__(self, model="sam_b.pt") -> None: + """ + Initializes the SAM (Segment Anything Model) instance. + + Args: + model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension. + + Raises: + NotImplementedError: If the model file extension is not .pt or .pth. 
+ + Examples: + >>> sam = SAM("sam_b.pt") + >>> print(sam.is_sam2) + """ + if model and Path(model).suffix not in {".pt", ".pth"}: + raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.") + self.is_sam2 = "sam2" in Path(model).stem + super().__init__(model=model, task="segment") + + def _load(self, weights: str, task=None): + """ + Loads the specified weights into the SAM model. + + This method initializes the SAM model with the provided weights file, setting up the model architecture + and loading the pre-trained parameters. + + Args: + weights (str): Path to the weights file. Should be a .pt or .pth file containing the model parameters. + task (str | None): Task name. If provided, it specifies the particular task the model is being loaded for. + + Examples: + >>> sam = SAM("sam_b.pt") + >>> sam._load("path/to/custom_weights.pt") + """ + self.model = build_sam(weights) + + def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs): + """ + Performs segmentation prediction on the given image or video source. + + Args: + source (str | PIL.Image | numpy.ndarray): Path to the image or video file, or a PIL.Image object, or + a numpy.ndarray object. + stream (bool): If True, enables real-time streaming. + bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation. + points (List[List[float]] | None): List of points for prompted segmentation. + labels (List[int] | None): List of labels for prompted segmentation. + **kwargs (Any): Additional keyword arguments for prediction. + + Returns: + (List): The model predictions. + + Examples: + >>> sam = SAM("sam_b.pt") + >>> results = sam.predict("image.jpg", points=[[500, 375]]) + >>> for r in results: + ... print(f"Detected {len(r.masks)} masks") + """ + overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024) + kwargs = {**overrides, **kwargs} + prompts = dict(bboxes=bboxes, points=points, labels=labels) + return super().predict(source, stream, prompts=prompts, **kwargs) + + def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs): + """ + Performs segmentation prediction on the given image or video source. + + This method is an alias for the 'predict' method, providing a convenient way to call the SAM model + for segmentation tasks. + + Args: + source (str | PIL.Image | numpy.ndarray | None): Path to the image or video file, or a PIL.Image + object, or a numpy.ndarray object. + stream (bool): If True, enables real-time streaming. + bboxes (List[List[float]] | None): List of bounding box coordinates for prompted segmentation. + points (List[List[float]] | None): List of points for prompted segmentation. + labels (List[int] | None): List of labels for prompted segmentation. + **kwargs (Any): Additional keyword arguments to be passed to the predict method. + + Returns: + (List): The model predictions, typically containing segmentation masks and other relevant information. + + Examples: + >>> sam = SAM("sam_b.pt") + >>> results = sam("image.jpg", points=[[500, 375]]) + >>> print(f"Detected {len(results[0].masks)} masks") + """ + return self.predict(source, stream, bboxes, points, labels, **kwargs) + + def info(self, detailed=False, verbose=True): + """ + Logs information about the SAM model. + + This method provides details about the Segment Anything Model (SAM), including its architecture, + parameters, and computational requirements. 
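The `predict`/`__call__` examples above use point prompts; box prompts go through the same path via `bboxes`, and `labels` marks points as foreground (1) or background (0). A small illustrative call (the image path and coordinates are placeholders, not test assets):

```python
from ultralytics import SAM

model = SAM("sam_b.pt")

# Segment whatever lies inside a box prompt
results = model("path/to/image.jpg", bboxes=[[100, 100, 400, 400]])

# Combine a foreground point with a background point
results = model("path/to/image.jpg", points=[[250, 250], [50, 50]], labels=[1, 0])
print(len(results[0].masks))
```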
+ + Args: + detailed (bool): If True, displays detailed information about the model layers and operations. + verbose (bool): If True, prints the information to the console. + + Returns: + (tuple): A tuple containing the model's information (string representations of the model). + + Examples: + >>> sam = SAM("sam_b.pt") + >>> info = sam.info() + >>> print(info[0]) # Print summary information + """ + return model_info(self.model, detailed=detailed, verbose=verbose) + + @property + def task_map(self): + """ + Provides a mapping from the 'segment' task to its corresponding 'Predictor'. + + Returns: + (Dict[str, Type[Predictor]]): A dictionary mapping the 'segment' task to its corresponding Predictor + class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor. + + Examples: + >>> sam = SAM("sam_b.pt") + >>> task_map = sam.task_map + >>> print(task_map) + {'segment': } + """ + return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}} diff --git a/ultralytics/models/sam/modules/__init__.py b/ultralytics/models/sam/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..77a19dcf0f8093de419453747db2e7e719f96349 --- /dev/null +++ b/ultralytics/models/sam/modules/__init__.py @@ -0,0 +1 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license diff --git a/ultralytics/models/sam/modules/blocks.py b/ultralytics/models/sam/modules/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..9abcc4406e261158397c29a7838dba7714fea220 --- /dev/null +++ b/ultralytics/models/sam/modules/blocks.py @@ -0,0 +1,1129 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import copy +import math +from functools import partial +from typing import Any, Optional, Tuple, Type, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from ultralytics.nn.modules import MLP, LayerNorm2d, MLPBlock + +from .transformer import Attention, TwoWayAttentionBlock, TwoWayTransformer +from .utils import add_decomposed_rel_pos, apply_rotary_enc, compute_axial_cis, window_partition, window_unpartition + + +class DropPath(nn.Module): + """ + Implements stochastic depth regularization for neural networks during training. + + Attributes: + drop_prob (float): Probability of dropping a path during training. + scale_by_keep (bool): Whether to scale the output by the keep probability. + + Methods: + forward: Applies stochastic depth to input tensor during training, with optional scaling. + + Examples: + >>> drop_path = DropPath(drop_prob=0.2, scale_by_keep=True) + >>> x = torch.randn(32, 64, 224, 224) + >>> output = drop_path(x) + """ + + def __init__(self, drop_prob=0.0, scale_by_keep=True): + """Initialize DropPath module for stochastic depth regularization during training.""" + super().__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + """Applies stochastic depth to input tensor during training, with optional scaling.""" + if self.drop_prob == 0.0 or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and self.scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class MaskDownSampler(nn.Module): + """ + A mask downsampling and embedding module for efficient processing of input masks. 
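`DropPath` above implements stochastic depth: during training each sample's branch is zeroed with probability `drop_prob` and survivors are rescaled by `1 / keep_prob` so the expected output is unchanged, while in `eval()` mode it is the identity. A quick sketch of that behavior (tensor sizes are illustrative):

```python
import torch
from ultralytics.models.sam.modules.blocks import DropPath

torch.manual_seed(0)
drop_path = DropPath(drop_prob=0.5)
x = torch.ones(4, 3)  # 4 samples in the batch

drop_path.train()
print(drop_path(x))   # roughly half the rows zeroed, the remaining rows scaled to 2.0

drop_path.eval()
print(drop_path(x))   # identity in eval mode
```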
+ + This class implements a mask downsampler that progressively reduces the spatial dimensions of input masks + while expanding their channel dimensions using convolutional layers, layer normalization, and activation + functions. + + Attributes: + encoder (nn.Sequential): A sequential container of convolutional layers, layer normalization, and + activation functions for downsampling and embedding masks. + + Methods: + forward: Downsamples and encodes input mask to embed_dim channels. + + Examples: + >>> mask_downsampler = MaskDownSampler(embed_dim=256, kernel_size=4, stride=4, padding=0, total_stride=16) + >>> input_mask = torch.randn(1, 1, 256, 256) + >>> output = mask_downsampler(input_mask) + >>> print(output.shape) + torch.Size([1, 256, 16, 16]) + """ + + def __init__( + self, + embed_dim=256, + kernel_size=4, + stride=4, + padding=0, + total_stride=16, + activation=nn.GELU, + ): + """Initializes a mask downsampler module for progressive downsampling and channel expansion.""" + super().__init__() + num_layers = int(math.log2(total_stride) // math.log2(stride)) + assert stride**num_layers == total_stride + self.encoder = nn.Sequential() + mask_in_chans, mask_out_chans = 1, 1 + for _ in range(num_layers): + mask_out_chans = mask_in_chans * (stride**2) + self.encoder.append( + nn.Conv2d( + mask_in_chans, + mask_out_chans, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + ) + self.encoder.append(LayerNorm2d(mask_out_chans)) + self.encoder.append(activation()) + mask_in_chans = mask_out_chans + + self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) + + def forward(self, x): + """Downsamples and encodes input mask to embed_dim channels using convolutional layers and LayerNorm2d.""" + return self.encoder(x) + + +class CXBlock(nn.Module): + """ + ConvNeXt Block for efficient feature extraction in convolutional neural networks. + + This block implements a modified version of the ConvNeXt architecture, offering improved performance and + flexibility in feature extraction. + + Attributes: + dwconv (nn.Conv2d): Depthwise or standard 2D convolution layer. + norm (LayerNorm2d): Layer normalization applied to channels. + pwconv1 (nn.Linear): First pointwise convolution implemented as a linear layer. + act (nn.GELU): GELU activation function. + pwconv2 (nn.Linear): Second pointwise convolution implemented as a linear layer. + gamma (nn.Parameter | None): Learnable scale parameter for layer scaling. + drop_path (nn.Module): DropPath layer for stochastic depth regularization. + + Methods: + forward: Processes the input tensor through the ConvNeXt block. + + Examples: + >>> import torch + >>> x = torch.randn(1, 64, 56, 56) + >>> block = CXBlock(dim=64, kernel_size=7, padding=3) + >>> output = block(x) + >>> print(output.shape) + torch.Size([1, 64, 56, 56]) + """ + + def __init__( + self, + dim, + kernel_size=7, + padding=3, + drop_path=0.0, + layer_scale_init_value=1e-6, + use_dwconv=True, + ): + """ + Initialize a ConvNeXt Block for efficient feature extraction in convolutional neural networks. + + This block implements a modified version of the ConvNeXt architecture, offering improved performance and + flexibility in feature extraction. + + Args: + dim (int): Number of input channels. + kernel_size (int): Size of the convolutional kernel. + padding (int): Padding size for the convolution. + drop_path (float): Stochastic depth rate. + layer_scale_init_value (float): Initial value for Layer Scale. + use_dwconv (bool): Whether to use depthwise convolution. 
+ + Examples: + >>> block = CXBlock(dim=64, kernel_size=7, padding=3) + >>> x = torch.randn(1, 64, 32, 32) + >>> output = block(x) + >>> print(output.shape) + torch.Size([1, 64, 32, 32]) + """ + super().__init__() + self.dwconv = nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=padding, + groups=dim if use_dwconv else 1, + ) # depthwise conv + self.norm = LayerNorm2d(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x): + """Applies ConvNeXt block operations to input tensor, including convolutions and residual connection.""" + input = x + x = self.dwconv(x) + x = self.norm(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class Fuser(nn.Module): + """ + A module for fusing features through multiple layers of a neural network. + + This class applies a series of identical layers to an input tensor, optionally projecting the input first. + + Attributes: + proj (nn.Module): An optional input projection layer. Identity if no projection is needed. + layers (nn.ModuleList): A list of identical layers to be applied sequentially. + + Methods: + forward: Applies the fuser to an input tensor. + + Examples: + >>> layer = CXBlock(dim=256) + >>> fuser = Fuser(layer, num_layers=3, dim=256, input_projection=True) + >>> x = torch.randn(1, 256, 32, 32) + >>> output = fuser(x) + >>> print(output.shape) + torch.Size([1, 256, 32, 32]) + """ + + def __init__(self, layer, num_layers, dim=None, input_projection=False): + """ + Initializes the Fuser module for feature fusion through multiple layers. + + This module creates a sequence of identical layers and optionally applies an input projection. + + Args: + layer (nn.Module): The layer to be replicated in the fuser. + num_layers (int): The number of times to replicate the layer. + dim (int | None): The dimension for input projection, if used. + input_projection (bool): Whether to use input projection. + + Examples: + >>> layer = nn.Linear(64, 64) + >>> fuser = Fuser(layer, num_layers=3, dim=64, input_projection=True) + >>> input_tensor = torch.randn(1, 64) + >>> output = fuser(input_tensor) + """ + super().__init__() + self.proj = nn.Identity() + self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)]) + + if input_projection: + assert dim is not None + self.proj = nn.Conv2d(dim, dim, kernel_size=1) + + def forward(self, x): + """Applies a series of layers to the input tensor, optionally projecting it first.""" + x = self.proj(x) + for layer in self.layers: + x = layer(x) + return x + + +class SAM2TwoWayAttentionBlock(TwoWayAttentionBlock): + """ + A two-way attention block for performing self-attention and cross-attention in both directions. + + This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on + sparse inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and + cross-attention from dense to sparse inputs. 
+ + Attributes: + self_attn (Attention): Self-attention layer for queries. + norm1 (nn.LayerNorm): Layer normalization after the first attention block. + cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys. + norm2 (nn.LayerNorm): Layer normalization after the second attention block. + mlp (MLP): MLP block for transforming query embeddings. + norm3 (nn.LayerNorm): Layer normalization after the MLP block. + norm4 (nn.LayerNorm): Layer normalization after the third attention block. + cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries. + skip_first_layer_pe (bool): Flag to skip positional encoding in the first layer. + + Methods: + forward: Processes input through the attention blocks and MLP. + + Examples: + >>> block = SAM2TwoWayAttentionBlock(embedding_dim=256, num_heads=8) + >>> sparse_input = torch.randn(1, 100, 256) + >>> dense_input = torch.randn(1, 256, 16, 16) + >>> sparse_output, dense_output = block(sparse_input, dense_input) + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + Initializes a SAM2TwoWayAttentionBlock for performing self-attention and cross-attention in two directions. + + This block extends the TwoWayAttentionBlock and consists of four main components: self-attention on sparse + inputs, cross-attention from sparse to dense inputs, an MLP block on sparse inputs, and cross-attention + from dense to sparse inputs. + + Args: + embedding_dim (int): The channel dimension of the embeddings. + num_heads (int): The number of heads in the attention layers. + mlp_dim (int): The hidden dimension of the MLP block. + activation (Type[nn.Module]): The activation function of the MLP block. + attention_downsample_rate (int): The downsample rate for attention computations. + skip_first_layer_pe (bool): Whether to skip the positional encoding in the first layer. + + Examples: + >>> block = SAM2TwoWayAttentionBlock(embedding_dim=256, num_heads=8, mlp_dim=2048) + >>> sparse_inputs = torch.randn(1, 100, 256) + >>> dense_inputs = torch.randn(1, 256, 32, 32) + >>> sparse_outputs, dense_outputs = block(sparse_inputs, dense_inputs) + """ + super().__init__(embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate, skip_first_layer_pe) + self.mlp = MLP(embedding_dim, mlp_dim, embedding_dim, num_layers=2, act=activation) + + +class SAM2TwoWayTransformer(TwoWayTransformer): + """ + A Two-Way Transformer module for simultaneous attention to image and query points. + + This class extends the TwoWayTransformer, implementing a specialized transformer decoder that attends to an + input image using queries with supplied positional embeddings. It is particularly useful for tasks like + object detection, image segmentation, and point cloud processing. + + Attributes: + depth (int): Number of layers in the transformer. + embedding_dim (int): Channel dimension for input embeddings. + num_heads (int): Number of heads for multihead attention. + mlp_dim (int): Internal channel dimension for the MLP block. + layers (nn.ModuleList): List of SAM2TwoWayAttentionBlock layers comprising the transformer. + final_attn_token_to_image (Attention): Final attention layer from queries to image. + norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries. 
+ + Methods: + forward: Processes input image embeddings and query embeddings through the transformer. + + Examples: + >>> transformer = SAM2TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048) + >>> image_embedding = torch.randn(1, 256, 64, 64) + >>> query_embedding = torch.randn(1, 100, 256) + >>> output = transformer(image_embedding, query_embedding) + >>> print(output[0].shape, output[1].shape) + torch.Size([1, 100, 256]) torch.Size([1, 256, 64, 64]) + """ + + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + Initializes a SAM2TwoWayTransformer instance. + + This transformer decoder attends to an input image using queries with supplied positional embeddings. + It is designed for tasks like object detection, image segmentation, and point cloud processing. + + Args: + depth (int): Number of layers in the transformer. + embedding_dim (int): Channel dimension for the input embeddings. + num_heads (int): Number of heads for multihead attention. Must divide embedding_dim. + mlp_dim (int): Channel dimension internal to the MLP block. + activation (Type[nn.Module]): Activation function to use in the MLP block. + attention_downsample_rate (int): Downsampling rate for attention computations. + + Examples: + >>> transformer = SAM2TwoWayTransformer(depth=5, embedding_dim=256, num_heads=8, mlp_dim=2048) + >>> transformer + SAM2TwoWayTransformer( + (layers): ModuleList( + (0-4): 5 x SAM2TwoWayAttentionBlock(...) + ) + (final_attn_token_to_image): Attention(...) + (norm_final_attn): LayerNorm(...) + ) + """ + super().__init__(depth, embedding_dim, num_heads, mlp_dim, activation, attention_downsample_rate) + self.layers = nn.ModuleList() + for i in range(depth): + self.layers.append( + SAM2TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + +class RoPEAttention(Attention): + """ + Implements rotary position encoding for attention mechanisms in transformer architectures. + + This class extends the base Attention class by incorporating Rotary Position Encoding (RoPE) to enhance + the positional awareness of the attention mechanism. + + Attributes: + compute_cis (Callable): Function to compute axial complex numbers for rotary encoding. + freqs_cis (Tensor): Precomputed frequency tensor for rotary encoding. + rope_k_repeat (bool): Flag to repeat query RoPE to match key length for cross-attention to memories. + + Methods: + forward: Applies rotary position encoding and computes attention between query, key, and value tensors. 
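+
+    Notes:
+        The cached freqs_cis tensor is recomputed inside forward() whenever the query length changes, and
+        rope_k_repeat must be True when query and key lengths differ (e.g. cross-attention to memories).
+        The trailing num_k_exclude_rope key tokens are excluded from rotary encoding.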
+ + Examples: + >>> rope_attn = RoPEAttention(embedding_dim=256, num_heads=8, rope_theta=10000.0, feat_sizes=(32, 32)) + >>> q = torch.randn(1, 1024, 256) + >>> k = torch.randn(1, 1024, 256) + >>> v = torch.randn(1, 1024, 256) + >>> output = rope_attn(q, k, v) + >>> print(output.shape) + torch.Size([1, 1024, 256]) + """ + + def __init__( + self, + *args, + rope_theta=10000.0, + rope_k_repeat=False, + feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution + **kwargs, + ): + """Initializes RoPEAttention with rotary position encoding for enhanced positional awareness.""" + super().__init__(*args, **kwargs) + + self.compute_cis = partial(compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta) + freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) + self.freqs_cis = freqs_cis + self.rope_k_repeat = rope_k_repeat # repeat q rope to match k length, needed for cross-attention to memories + + def forward(self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0) -> Tensor: + """Applies rotary position encoding and computes attention between query, key, and value tensors.""" + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Apply rotary position encoding + w = h = math.sqrt(q.shape[-2]) + self.freqs_cis = self.freqs_cis.to(q.device) + if self.freqs_cis.shape[0] != q.shape[-2]: + self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device) + if q.shape[-2] != k.shape[-2]: + assert self.rope_k_repeat + + num_k_rope = k.size(-2) - num_k_exclude_rope + q, k[:, :, :num_k_rope] = apply_rotary_enc( + q, + k[:, :, :num_k_rope], + freqs_cis=self.freqs_cis, + repeat_freqs_k=self.rope_k_repeat, + ) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out + + +def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor: + """Applies pooling and optional normalization to a tensor, handling spatial dimension permutations.""" + if pool is None: + return x + # (B, H, W, C) -> (B, C, H, W) + x = x.permute(0, 3, 1, 2) + x = pool(x) + # (B, C, H', W') -> (B, H', W', C) + x = x.permute(0, 2, 3, 1) + if norm: + x = norm(x) + + return x + + +class MultiScaleAttention(nn.Module): + """ + Implements multiscale self-attention with optional query pooling for efficient feature extraction. + + This class provides a flexible implementation of multiscale attention, allowing for optional + downsampling of query features through pooling. It's designed to enhance the model's ability to + capture multiscale information in visual tasks. + + Attributes: + dim (int): Input dimension of the feature map. + dim_out (int): Output dimension of the attention module. + num_heads (int): Number of attention heads. + scale (float): Scaling factor for dot-product attention. + q_pool (nn.Module | None): Optional pooling module for query features. + qkv (nn.Linear): Linear projection for query, key, and value. + proj (nn.Linear): Output projection. + + Methods: + forward: Applies multiscale attention to the input tensor. 
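+
+    Notes:
+        Inputs and outputs use (B, H, W, C) layout. Attention is computed with torch's
+        scaled_dot_product_attention, and q_pool, when provided, is applied to the queries only, reducing the
+        output spatial resolution.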
+ + Examples: + >>> import torch + >>> from torch import nn + >>> x = torch.randn(1, 64, 64, 256) + >>> msa = MultiScaleAttention(dim=256, dim_out=256, num_heads=8) + >>> output = msa(x) + >>> print(output.shape) + torch.Size([1, 64, 64, 256]) + """ + + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + q_pool: nn.Module = None, + ): + """Initializes multiscale attention with optional query pooling for efficient feature extraction.""" + super().__init__() + + self.dim = dim + self.dim_out = dim_out + + self.num_heads = num_heads + head_dim = dim_out // num_heads + self.scale = head_dim**-0.5 + + self.q_pool = q_pool + self.qkv = nn.Linear(dim, dim_out * 3) + self.proj = nn.Linear(dim_out, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Applies multiscale attention with optional query pooling to extract multiscale features.""" + B, H, W, _ = x.shape + # qkv with shape (B, H * W, 3, nHead, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1) + # q, k, v with shape (B, H * W, nheads, C) + q, k, v = torch.unbind(qkv, 2) + + # Q pooling (for downsample at stage changes) + if self.q_pool: + q = do_pool(q.reshape(B, H, W, -1), self.q_pool) + H, W = q.shape[1:3] # downsampled shape + q = q.reshape(B, H * W, self.num_heads, -1) + + # Torch's SDPA expects [B, nheads, H*W, C] so we transpose + x = F.scaled_dot_product_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + ) + # Transpose back + x = x.transpose(1, 2) + x = x.reshape(B, H, W, -1) + + x = self.proj(x) + + return x + + +class MultiScaleBlock(nn.Module): + """ + A multiscale attention block with window partitioning and query pooling for efficient vision transformers. + + This class implements a multiscale attention mechanism with optional window partitioning and downsampling, + designed for use in vision transformer architectures. + + Attributes: + dim (int): Input dimension of the block. + dim_out (int): Output dimension of the block. + norm1 (nn.Module): First normalization layer. + window_size (int): Size of the window for partitioning. + pool (nn.Module | None): Pooling layer for query downsampling. + q_stride (Tuple[int, int] | None): Stride for query pooling. + attn (MultiScaleAttention): Multi-scale attention module. + drop_path (nn.Module): Drop path layer for regularization. + norm2 (nn.Module): Second normalization layer. + mlp (MLP): Multi-layer perceptron module. + proj (nn.Linear | None): Projection layer for dimension mismatch. + + Methods: + forward: Processes input tensor through the multiscale block. 
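+
+    Notes:
+        When q_stride is set, queries are max-pooled before attention so the block downsamples spatially; when
+        dim != dim_out, the skip connection is projected with a linear layer. A window_size of 0 disables window
+        partitioning and uses global attention.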
+ + Examples: + >>> block = MultiScaleBlock(dim=256, dim_out=512, num_heads=8, window_size=7) + >>> x = torch.randn(1, 56, 56, 256) + >>> output = block(x) + >>> print(output.shape) + torch.Size([1, 28, 28, 512]) + """ + + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + mlp_ratio: float = 4.0, + drop_path: float = 0.0, + norm_layer: Union[nn.Module, str] = "LayerNorm", + q_stride: Tuple[int, int] = None, + act_layer: nn.Module = nn.GELU, + window_size: int = 0, + ): + """Initializes a multiscale attention block with window partitioning and optional query pooling.""" + super().__init__() + + if isinstance(norm_layer, str): + norm_layer = partial(getattr(nn, norm_layer), eps=1e-6) + + self.dim = dim + self.dim_out = dim_out + self.norm1 = norm_layer(dim) + + self.window_size = window_size + + self.pool, self.q_stride = None, q_stride + if self.q_stride: + self.pool = nn.MaxPool2d(kernel_size=q_stride, stride=q_stride, ceil_mode=False) + + self.attn = MultiScaleAttention( + dim, + dim_out, + num_heads=num_heads, + q_pool=self.pool, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim_out) + self.mlp = MLP( + dim_out, + int(dim_out * mlp_ratio), + dim_out, + num_layers=2, + act=act_layer, + ) + + if dim != dim_out: + self.proj = nn.Linear(dim, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Processes input through multiscale attention and MLP, with optional windowing and downsampling.""" + shortcut = x # B, H, W, C + x = self.norm1(x) + + # Skip connection + if self.dim != self.dim_out: + shortcut = do_pool(self.proj(x), self.pool) + + # Window partition + window_size = self.window_size + if window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, window_size) + + # Window Attention + Q Pooling (if stage change) + x = self.attn(x) + if self.q_stride: + # Shapes have changed due to Q pooling + window_size = self.window_size // self.q_stride[0] + H, W = shortcut.shape[1:3] + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + pad_hw = (H + pad_h, W + pad_w) + + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, window_size, pad_hw, (H, W)) + + x = shortcut + self.drop_path(x) + # MLP + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PositionEmbeddingSine(nn.Module): + """ + A module for generating sinusoidal positional embeddings for 2D inputs like images. + + This class implements sinusoidal position encoding for 2D spatial positions, which can be used in + transformer-based models for computer vision tasks. + + Attributes: + num_pos_feats (int): Number of positional features (half of the embedding dimension). + temperature (int): Temperature parameter for the sinusoidal functions. + normalize (bool): Whether to normalize the positional embeddings. + scale (float): Scaling factor for the embeddings when normalize is True. + cache (Dict): Cache for storing precomputed embeddings. + + Methods: + _encode_xy: Encodes 2D positions using sine and cosine functions. + encode_boxes: Encodes box coordinates and dimensions into positional embeddings. + encode_points: Encodes 2D point coordinates with sinusoidal positional embeddings. + forward: Generates sinusoidal position embeddings for 2D inputs. 
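+
+    Notes:
+        Embeddings are cached per input spatial size, so repeated forward passes over same-sized inputs reuse the
+        cached tensor instead of recomputing it.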
+ + Examples: + >>> pos_emb = PositionEmbeddingSine(num_pos_feats=128) + >>> x = torch.randn(1, 3, 224, 224) + >>> embeddings = pos_emb(x) + >>> print(embeddings.shape) + torch.Size([1, 256, 224, 224]) + """ + + def __init__( + self, + num_pos_feats, + temperature: int = 10000, + normalize: bool = True, + scale: Optional[float] = None, + ): + """Initializes sinusoidal position embeddings for 2D image inputs.""" + super().__init__() + assert num_pos_feats % 2 == 0, "Expecting even model width" + self.num_pos_feats = num_pos_feats // 2 + self.temperature = temperature + self.normalize = normalize + if scale is not None and not normalize: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + self.cache = {} + + def _encode_xy(self, x, y): + """Encodes 2D positions using sine/cosine functions for transformer positional embeddings.""" + assert len(x) == len(y) and x.ndim == y.ndim == 1 + x_embed = x * self.scale + y_embed = y * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, None] / dim_t + pos_y = y_embed[:, None] / dim_t + pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1) + pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1) + return pos_x, pos_y + + @torch.no_grad() + def encode_boxes(self, x, y, w, h): + """Encodes box coordinates and dimensions into positional embeddings for detection.""" + pos_x, pos_y = self._encode_xy(x, y) + return torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) + + encode = encode_boxes # Backwards compatibility + + @torch.no_grad() + def encode_points(self, x, y, labels): + """Encodes 2D points with sinusoidal embeddings and appends labels.""" + (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape + assert bx == by and nx == ny and bx == bl and nx == nl + pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten()) + pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1) + return torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2) + + @torch.no_grad() + def forward(self, x: torch.Tensor): + """Generates sinusoidal position embeddings for 2D inputs like images.""" + cache_key = (x.shape[-2], x.shape[-1]) + if cache_key in self.cache: + return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) + y_embed = ( + torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) + .view(1, -1, 1) + .repeat(x.shape[0], 1, x.shape[-1]) + ) + x_embed = ( + torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device) + .view(1, 1, -1) + .repeat(x.shape[0], x.shape[-2], 1) + ) + + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + self.cache[cache_key] = pos[0] + return pos + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random 
spatial frequencies. + + This class generates positional embeddings for input coordinates using random spatial frequencies. It is + particularly useful for transformer-based models that require position information. + + Attributes: + positional_encoding_gaussian_matrix (torch.Tensor): A buffer containing random values for encoding. + + Methods: + _pe_encoding: Positionally encodes points that are normalized to [0,1]. + forward: Generates positional encoding for a grid of the specified size. + forward_with_coords: Positionally encodes points that are not normalized to [0,1]. + + Examples: + >>> pe = PositionEmbeddingRandom(num_pos_feats=64) + >>> size = (32, 32) + >>> encoding = pe(size) + >>> print(encoding.shape) + torch.Size([128, 32, 32]) + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + """Initializes random spatial frequency position embedding for transformers.""" + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((2, num_pos_feats))) + + # Set non-deterministic for forward() error 'cumsum_cuda_kernel does not have a deterministic implementation' + torch.use_deterministic_algorithms(False) + torch.backends.cudnn.deterministic = False + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Encodes normalized [0,1] coordinates using random spatial frequencies.""" + # Assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # Outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generates positional encoding for a grid using random spatial frequencies.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor: + """Positionally encodes input coordinates, normalizing them to [0,1] based on the given image size.""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C + + +class Block(nn.Module): + """ + Transformer block with support for window attention and residual propagation. + + This class implements a transformer block that can use either global or windowed self-attention, + followed by a feed-forward network. It supports relative positional embeddings and is designed + for use in vision transformer architectures. + + Attributes: + norm1 (nn.Module): First normalization layer. + attn (REAttention): Self-attention layer with optional relative positional encoding. + norm2 (nn.Module): Second normalization layer. + mlp (MLPBlock): Multi-layer perceptron block. + window_size (int): Size of attention window. If 0, global attention is used. + + Methods: + forward: Processes input through the transformer block. 
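+
+    Notes:
+        When window_size > 0, the (B, H, W, C) input is padded and partitioned into windows before attention and
+        un-partitioned afterwards; the MLP residual is then applied on the full feature map.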
+
+    Examples:
+        >>> import torch
+        >>> block = Block(dim=256, num_heads=8, window_size=7)
+        >>> x = torch.randn(1, 56, 56, 256)
+        >>> output = block(x)
+        >>> print(output.shape)
+        torch.Size([1, 56, 56, 256])
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        qkv_bias: bool = True,
+        norm_layer: Type[nn.Module] = nn.LayerNorm,
+        act_layer: Type[nn.Module] = nn.GELU,
+        use_rel_pos: bool = False,
+        rel_pos_zero_init: bool = True,
+        window_size: int = 0,
+        input_size: Optional[Tuple[int, int]] = None,
+    ) -> None:
+        """
+        Initializes a transformer block with optional window attention and relative positional embeddings.
+
+        This constructor sets up a transformer block that can use either global or windowed self-attention,
+        followed by a feed-forward network. It supports relative positional embeddings and is designed
+        for use in vision transformer architectures.
+
+        Args:
+            dim (int): Number of input channels.
+            num_heads (int): Number of attention heads in the self-attention layer.
+            mlp_ratio (float): Ratio of mlp hidden dimension to embedding dimension.
+            qkv_bias (bool): If True, adds a learnable bias to query, key, value projections.
+            norm_layer (Type[nn.Module]): Type of normalization layer to use.
+            act_layer (Type[nn.Module]): Type of activation function to use in the MLP block.
+            use_rel_pos (bool): If True, uses relative positional embeddings in attention.
+            rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.
+            window_size (int): Size of attention window. If 0, uses global attention.
+            input_size (Optional[Tuple[int, int]]): Input resolution for calculating relative positional parameter size.
+
+        Examples:
+            >>> block = Block(dim=256, num_heads=8, window_size=7)
+            >>> x = torch.randn(1, 56, 56, 256)
+            >>> output = block(x)
+            >>> print(output.shape)
+            torch.Size([1, 56, 56, 256])
+        """
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = REAttention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            use_rel_pos=use_rel_pos,
+            rel_pos_zero_init=rel_pos_zero_init,
+            input_size=input_size if window_size == 0 else (window_size, window_size),
+        )
+
+        self.norm2 = norm_layer(dim)
+        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
+
+        self.window_size = window_size
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Processes input through transformer block with optional windowed self-attention and residual connection."""
+        shortcut = x
+        x = self.norm1(x)
+        # Window partition
+        if self.window_size > 0:
+            H, W = x.shape[1], x.shape[2]
+            x, pad_hw = window_partition(x, self.window_size)
+
+        x = self.attn(x)
+        # Reverse window partition
+        if self.window_size > 0:
+            x = window_unpartition(x, self.window_size, pad_hw, (H, W))
+
+        x = shortcut + x
+        return x + self.mlp(self.norm2(x))
+
+
+class REAttention(nn.Module):
+    """
+    Relative Position Embedding Attention module for efficient self-attention in transformer architectures.
+
+    This class implements a multi-head attention mechanism with optional decomposed relative positional
+    embeddings, designed for use in vision transformer models such as SAM's image encoder.
+
+    Attributes:
+        num_heads (int): Number of attention heads.
+        scale (float): Scaling factor for dot-product attention.
+        qkv (nn.Linear): Combined linear projection for query, key, and value.
+        proj (nn.Linear): Output projection.
+        use_rel_pos (bool): Whether decomposed relative positional embeddings are added to the attention map.
+        rel_pos_h (nn.Parameter): Relative positional embeddings for height, created when use_rel_pos is True.
+        rel_pos_w (nn.Parameter): Relative positional embeddings for width, created when use_rel_pos is True.
+
+    Methods:
+        forward: Applies multi-head attention with optional relative positional encoding to the input tensor.
+
+    Examples:
+        >>> attention = REAttention(dim=256, num_heads=8, input_size=(32, 32))
+        >>> x = torch.randn(1, 32, 32, 256)
+        >>> output = attention(x)
+        >>> print(output.shape)
+        torch.Size([1, 32, 32, 256])
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        qkv_bias: bool = True,
+        use_rel_pos: bool = False,
+        rel_pos_zero_init: bool = True,
+        input_size: Optional[Tuple[int, int]] = None,
+    ) -> None:
+        """
+        Initializes a Relative Position Attention module for transformer-based architectures.
+
+        This module implements multi-head attention with optional relative positional encodings, designed
+        specifically for vision tasks in transformer models.
+
+        Args:
+            dim (int): Number of input channels.
+            num_heads (int): Number of attention heads. Default is 8.
+            qkv_bias (bool): If True, adds a learnable bias to query, key, value projections. Default is True.
+            use_rel_pos (bool): If True, uses relative positional encodings. Default is False.
+            rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero. Default is True.
+            input_size (Tuple[int, int] | None): Input resolution for calculating relative positional parameter size.
+                Required if use_rel_pos is True. Default is None.
+
+        Examples:
+            >>> attention = REAttention(dim=256, num_heads=8, input_size=(32, 32))
+            >>> x = torch.randn(1, 32, 32, 256)
+            >>> output = attention(x)
+            >>> print(output.shape)
+            torch.Size([1, 32, 32, 256])
+        """
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = head_dim**-0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.proj = nn.Linear(dim, dim)
+
+        self.use_rel_pos = use_rel_pos
+        if self.use_rel_pos:
+            assert input_size is not None, "Input size must be provided if using relative positional encoding."
+            # Initialize relative positional embeddings
+            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
+            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Applies multi-head attention with optional relative positional encoding to input tensor."""
+        B, H, W, _ = x.shape
+        # qkv with shape (3, B, nHead, H * W, C)
+        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        # q, k, v with shape (B * nHead, H * W, C)
+        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
+
+        attn = (q * self.scale) @ k.transpose(-2, -1)
+
+        if self.use_rel_pos:
+            attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
+
+        attn = attn.softmax(dim=-1)
+        x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
+        return self.proj(x)
+
+
+class PatchEmbed(nn.Module):
+    """
+    Image to Patch Embedding module for vision transformer architectures.
+
+    This module converts an input image into a sequence of patch embeddings using a convolutional layer.
+ It is commonly used as the first layer in vision transformer architectures to transform image data + into a suitable format for subsequent transformer blocks. + + Attributes: + proj (nn.Conv2d): Convolutional layer for projecting image patches to embeddings. + + Methods: + forward: Applies patch embedding to the input tensor. + + Examples: + >>> patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768) + >>> x = torch.randn(1, 3, 224, 224) + >>> output = patch_embed(x) + >>> print(output.shape) + torch.Size([1, 768, 14, 14]) + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Initializes the PatchEmbed module for converting image patches to embeddings. + + This module is typically used as the first layer in vision transformer architectures to transform + image data into a suitable format for subsequent transformer blocks. + + Args: + kernel_size (Tuple[int, int]): Size of the convolutional kernel for patch extraction. + stride (Tuple[int, int]): Stride of the convolutional operation. + padding (Tuple[int, int]): Padding applied to the input before convolution. + in_chans (int): Number of input image channels. + embed_dim (int): Dimensionality of the output patch embeddings. + + Examples: + >>> patch_embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768) + >>> x = torch.randn(1, 3, 224, 224) + >>> output = patch_embed(x) + >>> print(output.shape) + torch.Size([1, 768, 14, 14]) + """ + super().__init__() + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Computes patch embedding by applying convolution and transposing resulting tensor.""" + return self.proj(x).permute(0, 2, 3, 1) # B C H W -> B H W C diff --git a/ultralytics/models/sam/modules/decoders.py b/ultralytics/models/sam/modules/decoders.py new file mode 100644 index 0000000000000000000000000000000000000000..ee9497f6c67db655b4e3877325fae1307aa5f470 --- /dev/null +++ b/ultralytics/models/sam/modules/decoders.py @@ -0,0 +1,518 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn + +from ultralytics.nn.modules import MLP, LayerNorm2d + + +class MaskDecoder(nn.Module): + """ + Decoder module for generating masks and their associated quality scores using a transformer architecture. + + This class predicts masks given image and prompt embeddings, utilizing a transformer to process the inputs and + generate mask predictions along with their quality scores. + + Attributes: + transformer_dim (int): Channel dimension for the transformer module. + transformer (nn.Module): Transformer module used for mask prediction. + num_multimask_outputs (int): Number of masks to predict for disambiguating masks. + iou_token (nn.Embedding): Embedding for the IoU token. + num_mask_tokens (int): Number of mask tokens. + mask_tokens (nn.Embedding): Embedding for the mask tokens. + output_upscaling (nn.Sequential): Neural network sequence for upscaling the output. + output_hypernetworks_mlps (nn.ModuleList): Hypernetwork MLPs for generating masks. + iou_prediction_head (nn.Module): MLP for predicting mask quality. + + Methods: + forward: Predicts masks given image and prompt embeddings. 
+ predict_masks: Internal method for mask prediction. + + Examples: + >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module) + >>> masks, iou_pred = decoder( + ... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, multimask_output=True + ... ) + >>> print(f"Predicted masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}") + """ + + def __init__( + self, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + ) -> None: + """ + Initializes the MaskDecoder module for generating masks and their quality scores. + + Args: + transformer_dim (int): Channel dimension for the transformer module. + transformer (nn.Module): Transformer module used for mask prediction. + num_multimask_outputs (int): Number of masks to predict for disambiguating masks. + activation (Type[nn.Module]): Type of activation to use when upscaling masks. + iou_head_depth (int): Depth of the MLP used to predict mask quality. + iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality. + + Examples: + >>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6) + >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer) + >>> print(decoder) + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)] + ) + + self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predicts masks given image and prompt embeddings. + + Args: + image_embeddings (torch.Tensor): Embeddings from the image encoder. + image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings. + sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes. + dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs. + multimask_output (bool): Whether to return multiple masks or a single mask. + + Returns: + (Tuple[torch.Tensor, torch.Tensor]): A tuple containing: + - masks (torch.Tensor): Batched predicted masks. + - iou_pred (torch.Tensor): Batched predictions of mask quality. 
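+
+        Notes:
+            With multimask_output=True, the disambiguation masks (token indices 1 onward) are returned;
+            otherwise only the single mask from token index 0 is returned, with the matching IoU predictions.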
+ + Examples: + >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module) + >>> image_emb = torch.rand(1, 256, 64, 64) + >>> image_pe = torch.rand(1, 256, 64, 64) + >>> sparse_emb = torch.rand(1, 2, 256) + >>> dense_emb = torch.rand(1, 256, 64, 64) + >>> masks, iou_pred = decoder(image_emb, image_pe, sparse_emb, dense_emb, multimask_output=True) + >>> print(f"Masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}") + """ + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + + # Select the correct mask or masks for output + mask_slice = slice(1, None) if multimask_output else slice(0, 1) + masks = masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] + + # Prepare output + return masks, iou_pred + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks and quality scores using image and prompt embeddings via transformer architecture.""" + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.shape[0], -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [ + self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens) + ] + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + +class SAM2MaskDecoder(nn.Module): + """ + Transformer-based decoder for predicting instance segmentation masks from image and prompt embeddings. + + This class extends the functionality of the MaskDecoder, incorporating additional features such as + high-resolution feature processing, dynamic multimask output, and object score prediction. + + Attributes: + transformer_dim (int): Channel dimension of the transformer. + transformer (nn.Module): Transformer used to predict masks. + num_multimask_outputs (int): Number of masks to predict when disambiguating masks. + iou_token (nn.Embedding): Embedding for IOU token. + num_mask_tokens (int): Total number of mask tokens. + mask_tokens (nn.Embedding): Embedding for mask tokens. + pred_obj_scores (bool): Whether to predict object scores. + obj_score_token (nn.Embedding): Embedding for object score token. + use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer. 
+ output_upscaling (nn.Sequential): Upscaling layers for output. + use_high_res_features (bool): Whether to use high-resolution features. + conv_s0 (nn.Conv2d): Convolutional layer for high-resolution features (s0). + conv_s1 (nn.Conv2d): Convolutional layer for high-resolution features (s1). + output_hypernetworks_mlps (nn.ModuleList): List of MLPs for output hypernetworks. + iou_prediction_head (MLP): MLP for IOU prediction. + pred_obj_score_head (nn.Linear | MLP): Linear layer or MLP for object score prediction. + dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability. + dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability. + dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability. + + Methods: + forward: Predicts masks given image and prompt embeddings. + predict_masks: Predicts instance segmentation masks from image and prompt embeddings. + _get_stability_scores: Computes mask stability scores based on IoU between thresholds. + _dynamic_multimask_via_stability: Dynamically selects the most stable mask output. + + Examples: + >>> image_embeddings = torch.rand(1, 256, 64, 64) + >>> image_pe = torch.rand(1, 256, 64, 64) + >>> sparse_prompt_embeddings = torch.rand(1, 2, 256) + >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64) + >>> decoder = SAM2MaskDecoder(256, transformer) + >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward( + ... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False + ... ) + """ + + def __init__( + self, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + use_high_res_features: bool = False, + iou_prediction_use_sigmoid=False, + dynamic_multimask_via_stability=False, + dynamic_multimask_stability_delta=0.05, + dynamic_multimask_stability_thresh=0.98, + pred_obj_scores: bool = False, + pred_obj_scores_mlp: bool = False, + use_multimask_token_for_obj_ptr: bool = False, + ) -> None: + """ + Initializes the SAM2MaskDecoder module for predicting instance segmentation masks. + + This decoder extends the functionality of MaskDecoder, incorporating additional features such as + high-resolution feature processing, dynamic multimask output, and object score prediction. + + Args: + transformer_dim (int): Channel dimension of the transformer. + transformer (nn.Module): Transformer used to predict masks. + num_multimask_outputs (int): Number of masks to predict when disambiguating masks. + activation (Type[nn.Module]): Type of activation to use when upscaling masks. + iou_head_depth (int): Depth of the MLP used to predict mask quality. + iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality. + use_high_res_features (bool): Whether to use high-resolution features. + iou_prediction_use_sigmoid (bool): Whether to use sigmoid for IOU prediction. + dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability. + dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability. + dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability. + pred_obj_scores (bool): Whether to predict object scores. + pred_obj_scores_mlp (bool): Whether to use MLP for object score prediction. + use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer. 
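+
+        Notes:
+            When dynamic_multimask_via_stability is True, single-mask outputs whose stability score falls below
+            dynamic_multimask_stability_thresh are replaced at inference time by the multimask output with the
+            highest predicted IoU.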
+ + Examples: + >>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6) + >>> decoder = SAM2MaskDecoder(transformer_dim=256, transformer=transformer) + >>> print(decoder) + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.pred_obj_scores = pred_obj_scores + if self.pred_obj_scores: + self.obj_score_token = nn.Embedding(1, transformer_dim) + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.use_high_res_features = use_high_res_features + if use_high_res_features: + self.conv_s0 = nn.Conv2d(transformer_dim, transformer_dim // 8, kernel_size=1, stride=1) + self.conv_s1 = nn.Conv2d(transformer_dim, transformer_dim // 4, kernel_size=1, stride=1) + + self.output_hypernetworks_mlps = nn.ModuleList( + [MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)] + ) + + self.iou_prediction_head = MLP( + transformer_dim, + iou_head_hidden_dim, + self.num_mask_tokens, + iou_head_depth, + sigmoid=iou_prediction_use_sigmoid, + ) + if self.pred_obj_scores: + self.pred_obj_score_head = nn.Linear(transformer_dim, 1) + if pred_obj_scores_mlp: + self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3) + + # When outputting a single mask, optionally we can dynamically fall back to the best + # multimask output token if the single mask output token gives low stability scores. + self.dynamic_multimask_via_stability = dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predicts masks given image and prompt embeddings. + + Args: + image_embeddings (torch.Tensor): Embeddings from the image encoder with shape (B, C, H, W). + image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings (B, C, H, W). + sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes with shape (B, N, C). + dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs with shape (B, C, H, W). + multimask_output (bool): Whether to return multiple masks or a single mask. + repeat_image (bool): Flag to repeat the image embeddings. + high_res_features (List[torch.Tensor] | None): Optional high-resolution features. + + Returns: + (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]): A tuple containing: + - masks (torch.Tensor): Batched predicted masks with shape (B, N, H, W). + - iou_pred (torch.Tensor): Batched predictions of mask quality with shape (B, N). 
+ - sam_tokens_out (torch.Tensor): Batched SAM token for mask output with shape (B, N, C). + - object_score_logits (torch.Tensor): Batched object score logits with shape (B, 1). + + Examples: + >>> image_embeddings = torch.rand(1, 256, 64, 64) + >>> image_pe = torch.rand(1, 256, 64, 64) + >>> sparse_prompt_embeddings = torch.rand(1, 2, 256) + >>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64) + >>> decoder = SAM2MaskDecoder(256, transformer) + >>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward( + ... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False + ... ) + """ + masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + repeat_image=repeat_image, + high_res_features=high_res_features, + ) + + # Select the correct mask or masks for output + if multimask_output: + masks = masks[:, 1:, :, :] + iou_pred = iou_pred[:, 1:] + elif self.dynamic_multimask_via_stability and not self.training: + masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) + else: + masks = masks[:, 0:1, :, :] + iou_pred = iou_pred[:, 0:1] + + if multimask_output and self.use_multimask_token_for_obj_ptr: + sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape + else: + # Take the mask output token. Here we *always* use the token for single mask output. + # At test time, even if we track after 1-click (and using multimask_output=True), + # we still take the single mask token here. The rationale is that we always track + # after multiple clicks during training, so the past tokens seen during training + # are always the single mask token (and we'll let it be the object-memory token). 
+ sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape + + # Prepare output + return masks, iou_pred, sam_tokens_out, object_score_logits + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts instance segmentation masks from image and prompt embeddings using a transformer.""" + # Concatenate output tokens + s = 0 + if self.pred_obj_scores: + output_tokens = torch.cat( + [ + self.obj_score_token.weight, + self.iou_token.weight, + self.mask_tokens.weight, + ], + dim=0, + ) + s = 1 + else: + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + if repeat_image: + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + else: + assert image_embeddings.shape[0] == tokens.shape[0] + src = image_embeddings + src = src + dense_prompt_embeddings + assert image_pe.size(0) == 1, "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, s, :] + mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + if not self.use_high_res_features: + upscaled_embedding = self.output_upscaling(src) + else: + dc1, ln1, act1, dc2, act2 = self.output_upscaling + feat_s0, feat_s1 = high_res_features + upscaled_embedding = act1(ln1(dc1(src) + feat_s1)) + upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0) + + hyper_in_list: List[torch.Tensor] = [ + self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens) + ] + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + if self.pred_obj_scores: + assert s == 1 + object_score_logits = self.pred_obj_score_head(hs[:, 0, :]) + else: + # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1 + object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1) + + return masks, iou_pred, mask_tokens_out, object_score_logits + + def _get_stability_scores(self, mask_logits): + """Computes mask stability scores based on IoU between upper and lower thresholds.""" + mask_logits = mask_logits.flatten(-2) + stability_delta = self.dynamic_multimask_stability_delta + area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() + area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() + return torch.where(area_u > 0, area_i / area_u, 1.0) + + def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): + """ + Dynamically selects the most stable mask output based on stability scores and IoU predictions. + + This method is used when outputting a single mask. 
If the stability score from the current single-mask + output (based on output token 0) falls below a threshold, it instead selects from multi-mask outputs + (based on output tokens 1-3) the mask with the highest predicted IoU score. This ensures a valid mask + for both clicking and tracking scenarios. + + Args: + all_mask_logits (torch.Tensor): Logits for all predicted masks, shape (B, N, H, W) where B is + batch size, N is number of masks (typically 4), and H, W are mask dimensions. + all_iou_scores (torch.Tensor): Predicted IoU scores for all masks, shape (B, N). + + Returns: + (Tuple[torch.Tensor, torch.Tensor]): + - mask_logits_out (torch.Tensor): Selected mask logits, shape (B, 1, H, W). + - iou_scores_out (torch.Tensor): Selected IoU scores, shape (B, 1). + + Examples: + >>> decoder = SAM2MaskDecoder(...) + >>> all_mask_logits = torch.rand(2, 4, 256, 256) # 2 images, 4 masks each + >>> all_iou_scores = torch.rand(2, 4) + >>> mask_logits, iou_scores = decoder._dynamic_multimask_via_stability(all_mask_logits, all_iou_scores) + >>> print(mask_logits.shape, iou_scores.shape) + torch.Size([2, 1, 256, 256]) torch.Size([2, 1]) + """ + # The best mask from multimask output tokens (1~3) + multimask_logits = all_mask_logits[:, 1:, :, :] + multimask_iou_scores = all_iou_scores[:, 1:] + best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) + batch_inds = torch.arange(multimask_iou_scores.size(0), device=all_iou_scores.device) + best_multimask_logits = multimask_logits[batch_inds, best_scores_inds] + best_multimask_logits = best_multimask_logits.unsqueeze(1) + best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds] + best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1) + + # The mask from singlemask output token 0 and its stability score + singlemask_logits = all_mask_logits[:, 0:1, :, :] + singlemask_iou_scores = all_iou_scores[:, 0:1] + stability_scores = self._get_stability_scores(singlemask_logits) + is_stable = stability_scores >= self.dynamic_multimask_stability_thresh + + # Dynamically fall back to best multimask output upon low stability scores. + mask_logits_out = torch.where( + is_stable[..., None, None].expand_as(singlemask_logits), + singlemask_logits, + best_multimask_logits, + ) + iou_scores_out = torch.where( + is_stable.expand_as(singlemask_iou_scores), + singlemask_iou_scores, + best_multimask_iou_scores, + ) + return mask_logits_out, iou_scores_out diff --git a/ultralytics/models/sam/modules/encoders.py b/ultralytics/models/sam/modules/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..a6e9fae887a074134ea9173ad8bd09a43b4f1e2e --- /dev/null +++ b/ultralytics/models/sam/modules/encoders.py @@ -0,0 +1,794 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from typing import List, Optional, Tuple, Type + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ultralytics.nn.modules import LayerNorm2d + +from .blocks import ( + Block, + CXBlock, + Fuser, + MaskDownSampler, + MultiScaleBlock, + PatchEmbed, + PositionEmbeddingRandom, + PositionEmbeddingSine, +) + + +class ImageEncoderViT(nn.Module): + """ + An image encoder using Vision Transformer (ViT) architecture for encoding images into a compact latent space. + + This class processes images by splitting them into patches, applying transformer blocks, and generating a final + encoded representation through a neck module. + + Attributes: + img_size (int): Dimension of input images, assumed to be square. 
+ patch_embed (PatchEmbed): Module for patch embedding. + pos_embed (nn.Parameter | None): Absolute positional embedding for patches. + blocks (nn.ModuleList): List of transformer blocks for processing patch embeddings. + neck (nn.Sequential): Neck module to further process the output. + + Methods: + forward: Processes input through patch embedding, positional embedding, blocks, and neck. + + Examples: + >>> import torch + >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12) + >>> input_image = torch.randn(1, 3, 224, 224) + >>> output = encoder(input_image) + >>> print(output.shape) + """ + + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Initializes an ImageEncoderViT instance for encoding images using Vision Transformer architecture. + + Args: + img_size (int): Input image size, assumed to be square. + patch_size (int): Size of image patches. + in_chans (int): Number of input image channels. + embed_dim (int): Dimension of patch embeddings. + depth (int): Number of transformer blocks. + num_heads (int): Number of attention heads in each block. + mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension. + out_chans (int): Number of output channels from the neck module. + qkv_bias (bool): If True, adds learnable bias to query, key, value projections. + norm_layer (Type[nn.Module]): Type of normalization layer to use. + act_layer (Type[nn.Module]): Type of activation layer to use. + use_abs_pos (bool): If True, uses absolute positional embeddings. + use_rel_pos (bool): If True, adds relative positional embeddings to attention maps. + rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero. + window_size (int): Size of attention window for windowed attention blocks. + global_attn_indexes (Tuple[int, ...]): Indices of blocks that use global attention. + + Attributes: + img_size (int): Dimension of input images. + patch_embed (PatchEmbed): Module for patch embedding. + pos_embed (nn.Parameter | None): Absolute positional embedding for patches. + blocks (nn.ModuleList): List of transformer blocks. + neck (nn.Sequential): Neck module for final processing. + + Examples: + >>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12) + >>> input_image = torch.randn(1, 3, 224, 224) + >>> output = encoder(input_image) + >>> print(output.shape) + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. 
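+            # Shape (1, img_size // patch_size, img_size // patch_size, embed_dim); forward() rescales it with
+            # F.interpolate when img_size differs from the 1024-pixel pretraining resolution.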
+ self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Processes input through patch embedding, positional embedding, transformer blocks, and neck module.""" + x = self.patch_embed(x) + if self.pos_embed is not None: + pos_embed = ( + F.interpolate(self.pos_embed.permute(0, 3, 1, 2), scale_factor=self.img_size / 1024).permute(0, 2, 3, 1) + if self.img_size != 1024 + else self.pos_embed + ) + x = x + pos_embed + for blk in self.blocks: + x = blk(x) + return self.neck(x.permute(0, 3, 1, 2)) + + +class PromptEncoder(nn.Module): + """ + Encodes different types of prompts for input to SAM's mask decoder, producing sparse and dense embeddings. + + Attributes: + embed_dim (int): Dimension of the embeddings. + input_image_size (Tuple[int, int]): Size of the input image as (H, W). + image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W). + pe_layer (PositionEmbeddingRandom): Module for random position embedding. + num_point_embeddings (int): Number of point embeddings for different types of points. + point_embeddings (nn.ModuleList): List of point embeddings. + not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label. + mask_input_size (Tuple[int, int]): Size of the input mask. + mask_downscaling (nn.Sequential): Neural network for downscaling the mask. + no_mask_embed (nn.Embedding): Embedding for cases where no mask is provided. + + Methods: + get_dense_pe: Returns the positional encoding used to encode point prompts. + forward: Embeds different types of prompts, returning both sparse and dense embeddings. + + Examples: + >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16) + >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5))) + >>> boxes = torch.rand(1, 2, 2) + >>> masks = torch.rand(1, 1, 256, 256) + >>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks) + >>> print(sparse_embeddings.shape, dense_embeddings.shape) + torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64]) + """ + + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Initializes the PromptEncoder module for encoding various types of prompts. + + This module encodes different types of prompts (points, boxes, masks) for input to SAM's mask decoder, + producing both sparse and dense embeddings. + + Args: + embed_dim (int): The dimension of the embeddings. + image_embedding_size (Tuple[int, int]): The spatial size of the image embedding as (H, W). + input_image_size (Tuple[int, int]): The padded size of the input image as (H, W). 
+ mask_in_chans (int): The number of hidden channels used for encoding input masks. + activation (Type[nn.Module]): The activation function to use when encoding input masks. + + Attributes: + embed_dim (int): Dimension of the embeddings. + input_image_size (Tuple[int, int]): Size of the input image as (H, W). + image_embedding_size (Tuple[int, int]): Spatial size of the image embedding as (H, W). + pe_layer (PositionEmbeddingRandom): Module for random position embedding. + num_point_embeddings (int): Number of point embeddings for different types of points. + point_embeddings (nn.ModuleList): List of point embeddings. + not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label. + mask_input_size (Tuple[int, int]): Size of the input mask. + mask_downscaling (nn.Sequential): Neural network for downscaling the mask. + + Examples: + >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16) + >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5))) + >>> boxes = torch.rand(1, 2, 2) + >>> masks = torch.rand(1, 1, 256, 256) + >>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks) + >>> print(sparse_embeddings.shape, dense_embeddings.shape) + torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64]) + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the dense positional encoding used for encoding point prompts. + + This method generates a positional encoding for a dense set of points matching the shape of the image + encoding. The encoding is used to provide spatial information to the model when processing point prompts. + + Returns: + (torch.Tensor): Positional encoding tensor with shape (1, embed_dim, H, W), where H and W are the + height and width of the image embedding size, respectively. 
+ + Examples: + >>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16) + >>> dense_pe = prompt_encoder.get_dense_pe() + >>> print(dense_pe.shape) + torch.Size([1, 256, 64, 64]) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor: + """Embeds point prompts by applying positional encoding and label-specific embeddings.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + point_embedding[labels == 2] += self.point_embeddings[2].weight + point_embedding[labels == 3] += self.point_embeddings[3].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts by applying positional encoding and adding corner embeddings.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs by downscaling and processing through convolutional layers.""" + return self.mask_downscaling(masks) + + @staticmethod + def _get_batch_size( + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """Gets the batch size of the output given the batch size of the input prompts.""" + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + """Returns the device of the first point embedding's weight tensor.""" + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense embeddings. + + Args: + points (Tuple[torch.Tensor, torch.Tensor] | None): Point coordinates and labels to embed. The first + tensor contains coordinates with shape (B, N, 2), and the second tensor contains labels with + shape (B, N). + boxes (torch.Tensor | None): Boxes to embed with shape (B, M, 2, 2), where M is the number of boxes. + masks (torch.Tensor | None): Masks to embed with shape (B, 1, H, W). + + Returns: + (Tuple[torch.Tensor, torch.Tensor]): A tuple containing: + - sparse_embeddings (torch.Tensor): Sparse embeddings for points and boxes with shape (B, N, embed_dim). + - dense_embeddings (torch.Tensor): Dense embeddings for masks of shape (B, embed_dim, embed_H, embed_W). 
+ + Examples: + >>> encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16) + >>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5))) + >>> boxes = torch.rand(1, 2, 2, 2) + >>> masks = torch.rand(1, 1, 256, 256) + >>> sparse_emb, dense_emb = encoder(points, boxes, masks) + >>> print(sparse_emb.shape, dense_emb.shape) + torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64]) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class MemoryEncoder(nn.Module): + """ + Encodes pixel features and masks into a memory representation for efficient image segmentation. + + This class processes pixel-level features and masks, fusing them to generate encoded memory representations + suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model). + + Attributes: + mask_downsampler (MaskDownSampler): Module for downsampling input masks. + pix_feat_proj (nn.Conv2d): Convolutional layer for projecting pixel features. + fuser (Fuser): Module for fusing pixel features and masks. + position_encoding (PositionEmbeddingSine): Module for adding positional encoding to features. + out_proj (nn.Module): Output projection layer, either nn.Identity or nn.Conv2d. + + Methods: + forward: Processes input pixel features and masks to generate encoded memory representations. 
+ + Examples: + >>> import torch + >>> encoder = MemoryEncoder(out_dim=256, in_dim=256) + >>> pix_feat = torch.randn(1, 256, 64, 64) + >>> masks = torch.randn(1, 1, 64, 64) + >>> encoded_feat, pos = encoder(pix_feat, masks) + >>> print(encoded_feat.shape, pos.shape) + torch.Size([1, 256, 64, 64]) torch.Size([1, 128, 64, 64]) + """ + + def __init__( + self, + out_dim, + in_dim=256, # in_dim of pix_feats + ): + """Initializes the MemoryEncoder for encoding pixel features and masks into memory representations.""" + super().__init__() + + self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1) + + self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) + self.fuser = Fuser(CXBlock(dim=256), num_layers=2) + self.position_encoding = PositionEmbeddingSine(num_pos_feats=64) + self.out_proj = nn.Identity() + if out_dim != in_dim: + self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) + + def forward( + self, + pix_feat: torch.Tensor, + masks: torch.Tensor, + skip_mask_sigmoid: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Processes pixel features and masks to generate encoded memory representations for segmentation.""" + if not skip_mask_sigmoid: + masks = F.sigmoid(masks) + masks = self.mask_downsampler(masks) + + # Fuse pix_feats and downsampled masks, in case the visual features are on CPU, cast them to CUDA + pix_feat = pix_feat.to(masks.device) + + x = self.pix_feat_proj(pix_feat) + x = x + masks + x = self.fuser(x) + x = self.out_proj(x) + + pos = self.position_encoding(x).to(x.dtype) + + return {"vision_features": x, "vision_pos_enc": [pos]} + + +class ImageEncoder(nn.Module): + """ + Encodes images using a trunk-neck architecture, producing multiscale features and positional encodings. + + This class combines a trunk network for feature extraction with a neck network for feature refinement + and positional encoding generation. It can optionally discard the lowest resolution features. + + Attributes: + trunk (nn.Module): The trunk network for initial feature extraction. + neck (nn.Module): The neck network for feature refinement and positional encoding generation. + scalp (int): Number of lowest resolution feature levels to discard. + + Methods: + forward: Processes the input image through the trunk and neck networks. + + Examples: + >>> trunk = SomeTrunkNetwork() + >>> neck = SomeNeckNetwork() + >>> encoder = ImageEncoder(trunk, neck, scalp=1) + >>> image = torch.randn(1, 3, 224, 224) + >>> output = encoder(image) + >>> print(output.keys()) + dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn']) + """ + + def __init__( + self, + trunk: nn.Module, + neck: nn.Module, + scalp: int = 0, + ): + """Initializes the ImageEncoder with trunk and neck networks for feature extraction and refinement.""" + super().__init__() + self.trunk = trunk + self.neck = neck + self.scalp = scalp + assert self.trunk.channel_list == self.neck.backbone_channel_list, ( + f"Channel dims of trunk {self.trunk.channel_list} and neck {self.neck.backbone_channel_list} do not match." 
+ ) + + def forward(self, sample: torch.Tensor): + """Encodes input through patch embedding, positional embedding, transformer blocks, and neck module.""" + features, pos = self.neck(self.trunk(sample)) + if self.scalp > 0: + # Discard the lowest resolution features + features, pos = features[: -self.scalp], pos[: -self.scalp] + + src = features[-1] + return { + "vision_features": src, + "vision_pos_enc": pos, + "backbone_fpn": features, + } + + +class FpnNeck(nn.Module): + """ + A Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models. + + This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing, + similar to ViT positional embedding interpolation. + + Attributes: + position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding module. + convs (nn.ModuleList): List of convolutional layers for each backbone level. + backbone_channel_list (List[int]): List of channel dimensions from the backbone. + fpn_interp_model (str): Interpolation mode for FPN feature resizing. + fuse_type (str): Type of feature fusion, either 'sum' or 'avg'. + fpn_top_down_levels (List[int]): Levels to have top-down features in outputs. + + Methods: + forward: Performs forward pass through the FPN neck. + + Examples: + >>> backbone_channels = [64, 128, 256, 512] + >>> fpn_neck = FpnNeck(256, backbone_channels) + >>> inputs = [torch.rand(1, c, 32, 32) for c in backbone_channels] + >>> outputs, positions = fpn_neck(inputs) + >>> print(len(outputs), len(positions)) + 4 4 + """ + + def __init__( + self, + d_model: int, + backbone_channel_list: List[int], + kernel_size: int = 1, + stride: int = 1, + padding: int = 0, + fpn_interp_model: str = "bilinear", + fuse_type: str = "sum", + fpn_top_down_levels: Optional[List[int]] = None, + ): + """ + Initializes a modified Feature Pyramid Network (FPN) neck. + + This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing, + similar to ViT positional embedding interpolation. + + Args: + d_model (int): Dimension of the model. + backbone_channel_list (List[int]): List of channel dimensions from the backbone. + kernel_size (int): Kernel size for the convolutional layers. + stride (int): Stride for the convolutional layers. + padding (int): Padding for the convolutional layers. + fpn_interp_model (str): Interpolation mode for FPN feature resizing. + fuse_type (str): Type of feature fusion, either 'sum' or 'avg'. + fpn_top_down_levels (Optional[List[int]]): Levels to have top-down features in outputs. + + Examples: + >>> backbone_channels = [64, 128, 256, 512] + >>> fpn_neck = FpnNeck(256, backbone_channels) + >>> print(fpn_neck) + """ + super().__init__() + self.position_encoding = PositionEmbeddingSine(num_pos_feats=256) + self.convs = nn.ModuleList() + self.backbone_channel_list = backbone_channel_list + for dim in backbone_channel_list: + current = nn.Sequential() + current.add_module( + "conv", + nn.Conv2d( + in_channels=dim, + out_channels=d_model, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ), + ) + + self.convs.append(current) + self.fpn_interp_model = fpn_interp_model + assert fuse_type in {"sum", "avg"} + self.fuse_type = fuse_type + + # levels to have top-down features in its outputs + # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 + # have top-down propagation, while outputs of level 0 and level 1 have only + # lateral features from the same backbone level. 
+ if fpn_top_down_levels is None: + # default is to have top-down features on all levels + fpn_top_down_levels = range(len(self.convs)) + self.fpn_top_down_levels = list(fpn_top_down_levels) + + def forward(self, xs: List[torch.Tensor]): + """ + Performs forward pass through the Feature Pyramid Network (FPN) neck. + + This method processes a list of input tensors from the backbone through the FPN, applying lateral connections + and top-down feature fusion. It generates output feature maps and corresponding positional encodings. + + Args: + xs (List[torch.Tensor]): List of input tensors from the backbone, each with shape (B, C, H, W). + + Returns: + (Tuple[List[torch.Tensor], List[torch.Tensor]]): A tuple containing: + - out (List[torch.Tensor]): List of output feature maps after FPN processing, each with shape + (B, d_model, H, W). + - pos (List[torch.Tensor]): List of positional encodings corresponding to each output feature map. + + Examples: + >>> fpn_neck = FpnNeck(d_model=256, backbone_channel_list=[64, 128, 256, 512]) + >>> inputs = [torch.rand(1, c, 32, 32) for c in [64, 128, 256, 512]] + >>> outputs, positions = fpn_neck(inputs) + >>> print(len(outputs), len(positions)) + 4 4 + """ + out = [None] * len(self.convs) + pos = [None] * len(self.convs) + assert len(xs) == len(self.convs) + # fpn forward pass + # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py + prev_features = None + # forward in top-down order (from low to high resolution) + n = len(self.convs) - 1 + for i in range(n, -1, -1): + x = xs[i] + lateral_features = self.convs[n - i](x) + if i in self.fpn_top_down_levels and prev_features is not None: + top_down_features = F.interpolate( + prev_features.to(dtype=torch.float32), + scale_factor=2.0, + mode=self.fpn_interp_model, + align_corners=(None if self.fpn_interp_model == "nearest" else False), + antialias=False, + ) + prev_features = lateral_features + top_down_features + if self.fuse_type == "avg": + prev_features /= 2 + else: + prev_features = lateral_features + x_out = prev_features + out[i] = x_out + pos[i] = self.position_encoding(x_out).to(x_out.dtype) + + return out, pos + + +class Hiera(nn.Module): + """ + Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks. + + This class implements a Hiera model, which is a hierarchical vision transformer architecture designed for + efficient multiscale feature extraction. It uses a series of transformer blocks organized into stages, + with optional pooling and global attention mechanisms. + + Attributes: + window_spec (Tuple[int, ...]): Window sizes for each stage. + q_stride (Tuple[int, int]): Downsampling stride between stages. + stage_ends (List[int]): Indices of the last block in each stage. + q_pool_blocks (List[int]): Indices of blocks where pooling is applied. + return_interm_layers (bool): Whether to return intermediate layer outputs. + patch_embed (PatchEmbed): Module for patch embedding. + global_att_blocks (Tuple[int, ...]): Indices of blocks with global attention. + window_pos_embed_bkg_spatial_size (Tuple[int, int]): Spatial size for window positional embedding background. + pos_embed (nn.Parameter): Positional embedding for the background. + pos_embed_window (nn.Parameter): Positional embedding for the window. + blocks (nn.ModuleList): List of MultiScaleBlock modules. + channel_list (List[int]): List of output channel dimensions for each stage. 
+ + Methods: + _get_pos_embed: Generates positional embeddings by interpolating and combining window and background embeddings. + forward: Performs the forward pass through the Hiera model. + + Examples: + >>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3)) + >>> input_tensor = torch.randn(1, 3, 224, 224) + >>> output_features = model(input_tensor) + >>> for feat in output_features: + ... print(feat.shape) + """ + + def __init__( + self, + embed_dim: int = 96, # initial embed dim + num_heads: int = 1, # initial number of heads + drop_path_rate: float = 0.0, # stochastic depth + q_pool: int = 3, # number of q_pool stages + q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages + stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage + dim_mul: float = 2.0, # dim_mul factor at stage shift + head_mul: float = 2.0, # head_mul factor at stage shift + window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), + # window size per stage, when not using global att. + window_spec: Tuple[int, ...] = ( + 8, + 4, + 14, + 7, + ), + # global attn in these blocks + global_att_blocks: Tuple[int, ...] = ( + 12, + 16, + 20, + ), + return_interm_layers=True, # return feats from every stage + ): + """Initializes the Hiera model, configuring its hierarchical vision transformer architecture.""" + super().__init__() + + assert len(stages) == len(window_spec) + self.window_spec = window_spec + + depth = sum(stages) + self.q_stride = q_stride + self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] + assert 0 <= q_pool <= len(self.stage_ends[:-1]) + self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] + self.return_interm_layers = return_interm_layers + + self.patch_embed = PatchEmbed( + embed_dim=embed_dim, + kernel_size=(7, 7), + stride=(4, 4), + padding=(3, 3), + ) + # Which blocks have global att? 
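+ # With the default stages (2, 3, 16, 3) the stage ends are blocks (1, 4, 20, 23), so the default
+ # global_att_blocks (12, 16, 20) all fall in the third (16-block) stage; window_size is set to 0
+ # for these blocks below, i.e. they use global rather than windowed attention.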
+ self.global_att_blocks = global_att_blocks + + # Windowed positional embedding (https://arxiv.org/abs/2311.05613) + self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size + self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)) + self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + cur_stage = 1 + self.blocks = nn.ModuleList() + + for i in range(depth): + dim_out = embed_dim + # lags by a block, so first block of + # next stage uses an initial window size + # of previous stage and final window size of current stage + window_size = self.window_spec[cur_stage - 1] + + if self.global_att_blocks is not None: + window_size = 0 if i in self.global_att_blocks else window_size + + if i - 1 in self.stage_ends: + dim_out = int(embed_dim * dim_mul) + num_heads = int(num_heads * head_mul) + cur_stage += 1 + + block = MultiScaleBlock( + dim=embed_dim, + dim_out=dim_out, + num_heads=num_heads, + drop_path=dpr[i], + q_stride=self.q_stride if i in self.q_pool_blocks else None, + window_size=window_size, + ) + + embed_dim = dim_out + self.blocks.append(block) + + self.channel_list = ( + [self.blocks[i].dim_out for i in self.stage_ends[::-1]] + if return_interm_layers + else [self.blocks[-1].dim_out] + ) + + def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: + """Generates positional embeddings by interpolating and combining window and background embeddings.""" + h, w = hw + window_embed = self.pos_embed_window + pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") + pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)]) + pos_embed = pos_embed.permute(0, 2, 3, 1) + return pos_embed + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + """Performs forward pass through Hiera model, extracting multiscale features from input images.""" + x = self.patch_embed(x) + # x: (B, H, W, C) + + # Add pos embed + x = x + self._get_pos_embed(x.shape[1:3]) + + outputs = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers): + feats = x.permute(0, 3, 1, 2) + outputs.append(feats) + + return outputs diff --git a/ultralytics/models/sam/modules/memory_attention.py b/ultralytics/models/sam/modules/memory_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..14998f37a923f236b0a4620403bf4a0d229d4d06 --- /dev/null +++ b/ultralytics/models/sam/modules/memory_attention.py @@ -0,0 +1,237 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import copy +from typing import Optional + +import torch +from torch import Tensor, nn + +from .blocks import RoPEAttention + + +class MemoryAttentionLayer(nn.Module): + """ + Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks. + + This class combines self-attention, cross-attention, and feedforward components to process input tensors and + generate memory-based attention outputs. + + Attributes: + d_model (int): Dimensionality of the model. + dim_feedforward (int): Dimensionality of the feedforward network. + dropout_value (float): Dropout rate for regularization. + self_attn (RoPEAttention): Self-attention mechanism using RoPE (Rotary Position Embedding). 
+ cross_attn_image (RoPEAttention): Cross-attention mechanism for image processing. + linear1 (nn.Linear): First linear layer of the feedforward network. + linear2 (nn.Linear): Second linear layer of the feedforward network. + norm1 (nn.LayerNorm): Layer normalization for self-attention output. + norm2 (nn.LayerNorm): Layer normalization for cross-attention output. + norm3 (nn.LayerNorm): Layer normalization for feedforward network output. + dropout1 (nn.Dropout): Dropout layer after self-attention. + dropout2 (nn.Dropout): Dropout layer after cross-attention. + dropout3 (nn.Dropout): Dropout layer after feedforward network. + activation (nn.ReLU): Activation function for the feedforward network. + pos_enc_at_attn (bool): Flag to add positional encoding at attention. + pos_enc_at_cross_attn_queries (bool): Flag to add positional encoding to cross-attention queries. + pos_enc_at_cross_attn_keys (bool): Flag to add positional encoding to cross-attention keys. + + Methods: + forward: Performs the full memory attention operation on input tensors. + _forward_sa: Performs self-attention on input tensor. + _forward_ca: Performs cross-attention between target and memory tensors. + + Examples: + >>> layer = MemoryAttentionLayer(d_model=256, dim_feedforward=2048, dropout=0.1) + >>> tgt = torch.randn(1, 100, 256) + >>> memory = torch.randn(1, 100, 64) + >>> pos = torch.randn(1, 100, 256) + >>> query_pos = torch.randn(1, 100, 256) + >>> output = layer(tgt, memory, pos, query_pos) + >>> print(output.shape) + torch.Size([1, 100, 256]) + """ + + def __init__( + self, + d_model: int = 256, + dim_feedforward: int = 2048, + dropout: float = 0.1, + pos_enc_at_attn: bool = False, + pos_enc_at_cross_attn_keys: bool = True, + pos_enc_at_cross_attn_queries: bool = False, + ): + """Initializes a memory attention layer with self-attention, cross-attention, and feedforward components.""" + super().__init__() + self.d_model = d_model + self.dim_feedforward = dim_feedforward + self.dropout_value = dropout + self.self_attn = RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1) + self.cross_attn_image = RoPEAttention( + rope_k_repeat=True, + embedding_dim=256, + num_heads=1, + downsample_rate=1, + kv_in_dim=64, + ) + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = nn.ReLU() + + # Where to add pos enc + self.pos_enc_at_attn = pos_enc_at_attn + self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries + self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys + + def _forward_sa(self, tgt, query_pos): + """Performs self-attention on input tensor using positional encoding and RoPE attention mechanism.""" + tgt2 = self.norm1(tgt) + q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 + tgt2 = self.self_attn(q, k, v=tgt2) + tgt = tgt + self.dropout1(tgt2) + return tgt + + def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): + """Performs cross-attention between target and memory tensors using RoPEAttention mechanism.""" + kwds = {} + if num_k_exclude_rope > 0: + assert isinstance(self.cross_attn_image, RoPEAttention) + kwds = {"num_k_exclude_rope": num_k_exclude_rope} + + # Cross-Attention + tgt2 
= self.norm2(tgt) + tgt2 = self.cross_attn_image( + q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, + k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, + v=memory, + **kwds, + ) + tgt = tgt + self.dropout2(tgt2) + return tgt + + def forward( + self, + tgt, + memory, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + num_k_exclude_rope: int = 0, + ) -> torch.Tensor: + """Processes input tensors using self-attention, cross-attention, and MLP for memory-based attention.""" + tgt = self._forward_sa(tgt, query_pos) + tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) + # MLP + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + +class MemoryAttention(nn.Module): + """ + Memory attention module for processing sequential data with self and cross-attention mechanisms. + + This class implements a multi-layer attention mechanism that combines self-attention and cross-attention + for processing sequential data, particularly useful in transformer-like architectures. + + Attributes: + d_model (int): The dimension of the model's hidden state. + layers (nn.ModuleList): A list of MemoryAttentionLayer modules. + num_layers (int): The number of attention layers. + norm (nn.LayerNorm): Layer normalization applied to the output. + pos_enc_at_input (bool): Whether to apply positional encoding at the input. + batch_first (bool): Whether the input tensors are in batch-first format. + + Methods: + forward: Processes input tensors through the attention layers. + + Examples: + >>> d_model = 256 + >>> layer = MemoryAttentionLayer(d_model) + >>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3) + >>> curr = torch.randn(10, 32, d_model) # (seq_len, batch_size, d_model) + >>> memory = torch.randn(20, 32, d_model) # (mem_len, batch_size, d_model) + >>> curr_pos = torch.randn(10, 32, d_model) + >>> memory_pos = torch.randn(20, 32, d_model) + >>> output = attention(curr, memory, curr_pos, memory_pos) + >>> print(output.shape) + torch.Size([10, 32, 256]) + """ + + def __init__( + self, + d_model: int, + pos_enc_at_input: bool, + layer: nn.Module, + num_layers: int, + batch_first: bool = True, # Do layers expect batch first input? 
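+ # Note: when batch_first=True, the MemoryAttentionLayer blocks run on batch-first tensors, but
+ # forward() still accepts seq-first (L, B, C) inputs and transposes them internally (and back on output).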
+ ): + """Initializes MemoryAttention module with layers and normalization for attention processing.""" + super().__init__() + self.d_model = d_model + self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)]) + self.num_layers = num_layers + self.norm = nn.LayerNorm(d_model) + self.pos_enc_at_input = pos_enc_at_input + self.batch_first = batch_first + + def forward( + self, + curr: torch.Tensor, # self-attention inputs + memory: torch.Tensor, # cross-attention inputs + curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs + memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs + num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* + ): + """Processes input tensors through multiple attention layers, applying self and cross-attention mechanisms.""" + if isinstance(curr, list): + assert isinstance(curr_pos, list) + assert len(curr) == len(curr_pos) == 1 + curr, curr_pos = ( + curr[0], + curr_pos[0], + ) + + assert curr.shape[1] == memory.shape[1], "Batch size must be the same for curr and memory" + + output = curr + if self.pos_enc_at_input and curr_pos is not None: + output = output + 0.1 * curr_pos + + if self.batch_first: + # Convert to batch first + output = output.transpose(0, 1) + curr_pos = curr_pos.transpose(0, 1) + memory = memory.transpose(0, 1) + memory_pos = memory_pos.transpose(0, 1) + + for layer in self.layers: + kwds = {} + if isinstance(layer.cross_attn_image, RoPEAttention): + kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} + + output = layer( + tgt=output, + memory=memory, + pos=memory_pos, + query_pos=curr_pos, + **kwds, + ) + normed_output = self.norm(output) + + if self.batch_first: + # Convert back to seq first + normed_output = normed_output.transpose(0, 1) + curr_pos = curr_pos.transpose(0, 1) + + return normed_output diff --git a/ultralytics/models/sam/modules/sam.py b/ultralytics/models/sam/modules/sam.py new file mode 100644 index 0000000000000000000000000000000000000000..420a4c3b0d62bedece384921604cb8900328980e --- /dev/null +++ b/ultralytics/models/sam/modules/sam.py @@ -0,0 +1,1013 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.init import trunc_normal_ + +from ultralytics.nn.modules import MLP + +from .blocks import SAM2TwoWayTransformer +from .decoders import MaskDecoder, SAM2MaskDecoder +from .encoders import ImageEncoderViT, PromptEncoder +from .utils import get_1d_sine_pe, select_closest_cond_frames + +# a large negative value as a placeholder score for missing objects +NO_OBJ_SCORE = -1024.0 + + +class SAMModel(nn.Module): + """ + Segment Anything Model (SAM) for object segmentation tasks. + + This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images + and input prompts. + + Attributes: + mask_threshold (float): Threshold value for mask prediction. + image_encoder (ImageEncoderViT): Backbone for encoding images into embeddings. + prompt_encoder (PromptEncoder): Encoder for various types of input prompts. + mask_decoder (MaskDecoder): Predicts object masks from image and prompt embeddings. + + Methods: + __init__: Initializes the SAMModel with encoders, decoder, and normalization parameters. 
+ + Examples: + >>> image_encoder = ImageEncoderViT(...) + >>> prompt_encoder = PromptEncoder(...) + >>> mask_decoder = MaskDecoder(...) + >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder) + >>> # Further usage depends on SAMPredictor class + + Notes: + All forward() operations are implemented in the SAMPredictor class. + """ + + mask_threshold: float = 0.0 + + def __init__( + self, + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = (123.675, 116.28, 103.53), + pixel_std: List[float] = (58.395, 57.12, 57.375), + ) -> None: + """ + Initialize the SAMModel class to predict object masks from an image and input prompts. + + Args: + image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts. + pixel_mean (List[float]): Mean values for normalizing pixels in the input image. + pixel_std (List[float]): Std values for normalizing pixels in the input image. + + Examples: + >>> image_encoder = ImageEncoderViT(...) + >>> prompt_encoder = PromptEncoder(...) + >>> mask_decoder = MaskDecoder(...) + >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder) + >>> # Further usage depends on SAMPredictor class + + Notes: + All forward() operations moved to SAMPredictor. + """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + def set_imgsz(self, imgsz): + """ + Set image size to make model compatible with different image sizes. + + Args: + imgsz (Tuple[int, int]): The size of the input image. + """ + if hasattr(self.image_encoder, "set_imgsz"): + self.image_encoder.set_imgsz(imgsz) + self.prompt_encoder.input_image_size = imgsz + self.prompt_encoder.image_embedding_size = [x // 16 for x in imgsz] # 16 is fixed as patch size of ViT model + self.image_encoder.img_size = imgsz[0] + + +class SAM2Model(torch.nn.Module): + """ + SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities. + + This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms + for temporal consistency and efficient tracking of objects across frames. + + Attributes: + mask_threshold (float): Threshold value for mask prediction. + image_encoder (ImageEncoderViT): Visual encoder for extracting image features. + memory_attention (nn.Module): Module for attending to memory features. + memory_encoder (nn.Module): Encoder for generating memory representations. + num_maskmem (int): Number of accessible memory frames. + image_size (int): Size of input images. + backbone_stride (int): Stride of the backbone network output. + sam_prompt_embed_dim (int): Dimension of SAM prompt embeddings. + sam_image_embedding_size (int): Size of SAM image embeddings. + sam_prompt_encoder (PromptEncoder): Encoder for processing input prompts. + sam_mask_decoder (SAM2MaskDecoder): Decoder for generating object masks. + obj_ptr_proj (nn.Module): Projection layer for object pointers. + obj_ptr_tpos_proj (nn.Module): Projection for temporal positional encoding in object pointers. 
+ + Methods: + forward_image: Processes image batch through encoder to extract multi-level features. + track_step: Performs a single tracking step, updating object masks and memory features. + + Examples: + >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder) + >>> image_batch = torch.rand(1, 3, 512, 512) + >>> features = model.forward_image(image_batch) + >>> track_results = model.track_step(0, True, features, None, None, None, {}) + """ + + mask_threshold: float = 0.0 + + def __init__( + self, + image_encoder, + memory_attention, + memory_encoder, + num_maskmem=7, + image_size=512, + backbone_stride=16, + sigmoid_scale_for_mem_enc=1.0, + sigmoid_bias_for_mem_enc=0.0, + binarize_mask_from_pts_for_mem_enc=False, + use_mask_input_as_output_without_sam=False, + max_cond_frames_in_attn=-1, + directly_add_no_mem_embed=False, + use_high_res_features_in_sam=False, + multimask_output_in_sam=False, + multimask_min_pt_num=1, + multimask_max_pt_num=1, + multimask_output_for_tracking=False, + use_multimask_token_for_obj_ptr: bool = False, + iou_prediction_use_sigmoid=False, + memory_temporal_stride_for_eval=1, + non_overlap_masks_for_mem_enc=False, + use_obj_ptrs_in_encoder=False, + max_obj_ptrs_in_encoder=16, + add_tpos_enc_to_obj_ptrs=True, + proj_tpos_enc_in_obj_ptrs=False, + use_signed_tpos_enc_to_obj_ptrs=False, + only_obj_ptrs_in_the_past_for_eval=False, + pred_obj_scores: bool = False, + pred_obj_scores_mlp: bool = False, + fixed_no_obj_ptr: bool = False, + soft_no_obj_ptr: bool = False, + use_mlp_for_obj_ptr_proj: bool = False, + no_obj_embed_spatial: bool = False, + sam_mask_decoder_extra_args=None, + compile_image_encoder: bool = False, + ): + """ + Initializes the SAM2Model for video object segmentation with memory-based tracking. + + Args: + image_encoder (nn.Module): Visual encoder for extracting image features. + memory_attention (nn.Module): Module for attending to memory features. + memory_encoder (nn.Module): Encoder for generating memory representations. + num_maskmem (int): Number of accessible memory frames. Default is 7 (1 input frame + 6 previous frames). + image_size (int): Size of input images. + backbone_stride (int): Stride of the image backbone output. + sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability. + sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability. + binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames + with clicks during evaluation. + use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM + prompt encoder and mask decoder on frames with mask input. + max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention. + -1 means no limit. + directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the + first frame. + use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder. + multimask_output_in_sam (bool): Whether to output multiple (3) masks for the first click on initial + conditioning frames. + multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM. + multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM. + multimask_output_for_tracking (bool): Whether to use multimask output for tracking. + use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers. 
+ iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1]. + memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation. + non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in + memory encoder during evaluation. + use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder. + max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder + cross-attention. + add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in + the encoder. + proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional + encoding in object pointers. + use_signed_tpos_enc_to_obj_ptrs (bool): whether to use signed distance (instead of unsigned absolute distance) + in the temporal positional encoding in the object pointers, only relevant when both `use_obj_ptrs_in_encoder=True` + and `add_tpos_enc_to_obj_ptrs=True`. + only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past + during evaluation. + pred_obj_scores (bool): Whether to predict if there is an object in the frame. + pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores. + fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present. + soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation. + use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection. + no_obj_embed_spatial (bool): Whether add no obj embedding to spatial frames. + sam_mask_decoder_extra_args (Dict | None): Extra arguments for constructing the SAM mask decoder. + compile_image_encoder (bool): Whether to compile the image encoder for faster inference. + + Examples: + >>> image_encoder = ImageEncoderViT(...) + >>> memory_attention = SAM2TwoWayTransformer(...) + >>> memory_encoder = nn.Sequential(...) + >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder) + >>> image_batch = torch.rand(1, 3, 512, 512) + >>> features = model.forward_image(image_batch) + >>> track_results = model.track_step(0, True, features, None, None, None, {}) + """ + super().__init__() + + # Part 1: the image backbone + self.image_encoder = image_encoder + # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting + self.use_high_res_features_in_sam = use_high_res_features_in_sam + self.num_feature_levels = 3 if use_high_res_features_in_sam else 1 + self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder + self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder + if use_obj_ptrs_in_encoder: + # A conv layer to downsample the mask prompt to stride 4 (the same stride as + # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale, + # so that it can be fed into the SAM mask decoder to generate a pointer. 
+ self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4) + self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs + if proj_tpos_enc_in_obj_ptrs: + assert add_tpos_enc_to_obj_ptrs # these options need to be used together + self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs + self.use_signed_tpos_enc_to_obj_ptrs = use_signed_tpos_enc_to_obj_ptrs + self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval + + # Part 2: memory attention to condition current frame's visual features + # with memories (and obj ptrs) from past frames + self.memory_attention = memory_attention + self.hidden_dim = memory_attention.d_model + + # Part 3: memory encoder for the previous frame's outputs + self.memory_encoder = memory_encoder + self.mem_dim = self.hidden_dim + if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"): + # if there is compression of memories along channel dim + self.mem_dim = self.memory_encoder.out_proj.weight.shape[0] + self.num_maskmem = num_maskmem # Number of memories accessible + # Temporal encoding of the memories + self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim)) + trunc_normal_(self.maskmem_tpos_enc, std=0.02) + # a single token to indicate no memory embedding from previous frames + self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + trunc_normal_(self.no_mem_embed, std=0.02) + trunc_normal_(self.no_mem_pos_enc, std=0.02) + self.directly_add_no_mem_embed = directly_add_no_mem_embed + # Apply sigmoid to the output raw mask logits (to turn them from + # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder + self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc + self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc + self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc + self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc + self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval + # On frames with mask input, whether to directly output the input mask without + # using a SAM prompt encoder + mask decoder + self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam + self.multimask_output_in_sam = multimask_output_in_sam + self.multimask_min_pt_num = multimask_min_pt_num + self.multimask_max_pt_num = multimask_max_pt_num + self.multimask_output_for_tracking = multimask_output_for_tracking + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid + + # Part 4: SAM-style prompt encoder (for both mask and point inputs) + # and SAM-style mask decoder for the final mask output + self.image_size = image_size + self.backbone_stride = backbone_stride + self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args + self.pred_obj_scores = pred_obj_scores + self.pred_obj_scores_mlp = pred_obj_scores_mlp + self.fixed_no_obj_ptr = fixed_no_obj_ptr + self.soft_no_obj_ptr = soft_no_obj_ptr + if self.fixed_no_obj_ptr: + assert self.pred_obj_scores + assert self.use_obj_ptrs_in_encoder + if self.pred_obj_scores and self.use_obj_ptrs_in_encoder: + self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim)) + trunc_normal_(self.no_obj_ptr, std=0.02) + self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj + self.no_obj_embed_spatial = None + if no_obj_embed_spatial: + 
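+ # Learned (1, mem_dim) embedding added to spatial memory features for frames where no object is present.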
self.no_obj_embed_spatial = torch.nn.Parameter(torch.zeros(1, self.mem_dim)) + trunc_normal_(self.no_obj_embed_spatial, std=0.02) + + self._build_sam_heads() + self.max_cond_frames_in_attn = max_cond_frames_in_attn + + # Model compilation + if compile_image_encoder: + # Compile the forward function (not the full module) to allow loading checkpoints. + print("Image encoder compilation is enabled. First forward pass will be slow.") + self.image_encoder.forward = torch.compile( + self.image_encoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + + @property + def device(self): + """Returns the device on which the model's parameters are stored.""" + return next(self.parameters()).device + + def forward(self, *args, **kwargs): + """Processes image and prompt inputs to generate object masks and scores in video sequences.""" + raise NotImplementedError( + "Please use the corresponding methods in SAM2VideoPredictor for inference." + "See notebooks/video_predictor_example.ipynb for an example." + ) + + def _build_sam_heads(self): + """Builds SAM-style prompt encoder and mask decoder for image segmentation tasks.""" + self.sam_prompt_embed_dim = self.hidden_dim + self.sam_image_embedding_size = self.image_size // self.backbone_stride + + # Build PromptEncoder and MaskDecoder from SAM (hyperparameters like `mask_in_chans=16` are from SAM code) + self.sam_prompt_encoder = PromptEncoder( + embed_dim=self.sam_prompt_embed_dim, + image_embedding_size=( + self.sam_image_embedding_size, + self.sam_image_embedding_size, + ), + input_image_size=(self.image_size, self.image_size), + mask_in_chans=16, + ) + self.sam_mask_decoder = SAM2MaskDecoder( + num_multimask_outputs=3, + transformer=SAM2TwoWayTransformer( + depth=2, + embedding_dim=self.sam_prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=self.sam_prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + use_high_res_features=self.use_high_res_features_in_sam, + iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid, + pred_obj_scores=self.pred_obj_scores, + pred_obj_scores_mlp=self.pred_obj_scores_mlp, + use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr, + **(self.sam_mask_decoder_extra_args or {}), + ) + if self.use_obj_ptrs_in_encoder: + # a linear projection on SAM output tokens to turn them into object pointers + self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim) + if self.use_mlp_for_obj_ptr_proj: + self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3) + else: + self.obj_ptr_proj = torch.nn.Identity() + if self.proj_tpos_enc_in_obj_ptrs: + # a linear projection on temporal positional encoding in object pointers to + # avoid potential interference with spatial positional encoding + self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim) + else: + self.obj_ptr_tpos_proj = torch.nn.Identity() + + def _forward_sam_heads( + self, + backbone_features, + point_inputs=None, + mask_inputs=None, + high_res_features=None, + multimask_output=False, + ): + """ + Forward pass through SAM prompt encoders and mask heads. + + This method processes image features and optional point/mask inputs to generate object masks and scores. + + Args: + backbone_features (torch.Tensor): Image features with shape (B, C, H, W). + point_inputs (Dict[str, torch.Tensor] | None): Dictionary containing point prompts. 
+ 'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute + pixel-unit coordinates in (x, y) format for P input points. + 'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks, + 0 means negative clicks, and -1 means padding. + mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the + same spatial size as the image. + high_res_features (List[torch.Tensor] | None): List of two feature maps with shapes + (B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps + for SAM decoder. + multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False, + output only 1 mask and its IoU estimate. + + Returns: + (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]): + low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits. + high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits. + ious: Tensor of shape (B, M) with estimated IoU for each output mask. + low_res_masks: Tensor of shape (B, 1, H*4, W*4) with the best low-resolution mask. + high_res_masks: Tensor of shape (B, 1, H*16, W*16) with the best high-resolution mask. + obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask. + object_score_logits: Tensor of shape (B) with object score logits. + + Where M is 3 if multimask_output=True, and 1 if multimask_output=False. + + Examples: + >>> backbone_features = torch.rand(1, 256, 32, 32) + >>> point_inputs = {"point_coords": torch.rand(1, 2, 2), "point_labels": torch.tensor([[1, 0]])} + >>> mask_inputs = torch.rand(1, 1, 512, 512) + >>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs) + >>> ( + ... low_res_multimasks, + ... high_res_multimasks, + ... ious, + ... low_res_masks, + ... high_res_masks, + ... obj_ptr, + ... object_score_logits, + ... ) = results + """ + B = backbone_features.size(0) + device = backbone_features.device + assert backbone_features.size(1) == self.sam_prompt_embed_dim + assert backbone_features.size(2) == self.sam_image_embedding_size + assert backbone_features.size(3) == self.sam_image_embedding_size + + # a) Handle point prompts + if point_inputs is not None: + sam_point_coords = point_inputs["point_coords"] + sam_point_labels = point_inputs["point_labels"] + assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B + else: + # If no points are provide, pad with an empty point (with label -1) + sam_point_coords = torch.zeros(B, 1, 2, device=device) + sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) + + # b) Handle mask prompts + if mask_inputs is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) + if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: + sam_mask_prompt = F.interpolate( + mask_inputs.float(), + size=self.sam_prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + else: + sam_mask_prompt = mask_inputs + else: + # Otherwise, simply feed None (and SAM's prompt encoder will add + # a learned `no_mask_embed` to indicate no mask input in this case). 
+ sam_mask_prompt = None + + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( + points=(sam_point_coords, sam_point_labels), + boxes=None, + masks=sam_mask_prompt, + ) + low_res_multimasks, ious, sam_output_tokens, object_score_logits = self.sam_mask_decoder( + image_embeddings=backbone_features, + image_pe=self.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=False, # the image is already batched + high_res_features=high_res_features, + ) + if self.pred_obj_scores: + is_obj_appearing = object_score_logits > 0 + + # Spatial memory mask is a *hard* choice between obj and no obj, consistent with actual mask prediction + low_res_multimasks = torch.where(is_obj_appearing[:, None, None], low_res_multimasks, NO_OBJ_SCORE) + + # convert masks from possibly bfloat16 (or float16) to float32 + # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) + low_res_multimasks = low_res_multimasks.float() + high_res_multimasks = F.interpolate( + low_res_multimasks, + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + + sam_output_token = sam_output_tokens[:, 0] + if multimask_output: + # take the best mask prediction (with the highest IoU estimation) + best_iou_inds = torch.argmax(ious, dim=-1) + batch_inds = torch.arange(B, device=device) + low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + if sam_output_tokens.size(1) > 1: + sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] + else: + low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks + + # Extract object pointer from the SAM output token (with occlusion handling) + obj_ptr = self.obj_ptr_proj(sam_output_token) + if self.pred_obj_scores: + # Allow *soft* no obj ptr, unlike for masks + if self.soft_no_obj_ptr: + lambda_is_obj_appearing = object_score_logits.sigmoid() + else: + lambda_is_obj_appearing = is_obj_appearing.float() + + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): + """Processes mask inputs directly as output, bypassing SAM encoder/decoder.""" + # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid). 
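The multimask branch above keeps, per batch element, whichever of the candidate masks has the highest predicted IoU. A toy illustration of that argmax plus advanced-indexing selection (all shapes assumed):

```python
import torch

B, M, H, W = 2, 3, 64, 64                     # assumed batch size, 3 candidate masks
low_res_multimasks = torch.randn(B, M, H, W)  # candidate mask logits
ious = torch.rand(B, M)                       # predicted IoU per candidate

best_iou_inds = torch.argmax(ious, dim=-1)    # (B,) index of the best candidate
batch_inds = torch.arange(B)
low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)  # (B, 1, H, W)
print(low_res_masks.shape)  # torch.Size([2, 1, 64, 64])
```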
+ out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05 + mask_inputs_float = mask_inputs.float() + high_res_masks = mask_inputs_float * out_scale + out_bias + low_res_masks = F.interpolate( + high_res_masks, + size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + # a dummy IoU prediction of all 1's under mask input + ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float() + if not self.use_obj_ptrs_in_encoder: + # all zeros as a dummy object pointer (of shape [B, C]) + obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device) + else: + # produce an object pointer using the SAM decoder from the mask input + _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads( + backbone_features=backbone_features, + mask_inputs=self.mask_downsample(mask_inputs_float), + high_res_features=high_res_features, + ) + # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem; + # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying + # on the object_scores from the SAM decoder. + is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1) + is_obj_appearing = is_obj_appearing[..., None] + lambda_is_obj_appearing = is_obj_appearing.float() + object_score_logits = out_scale * lambda_is_obj_appearing + out_bias + if self.pred_obj_scores: + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_masks, + high_res_masks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def forward_image(self, img_batch: torch.Tensor): + """Processes image batch through encoder to extract multi-level features for SAM model.""" + backbone_out = self.image_encoder(img_batch) + if self.use_high_res_features_in_sam: + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0]) + backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1]) + return backbone_out + + def _prepare_backbone_features(self, backbone_out): + """Prepares and flattens visual features from the image backbone output for further processing.""" + assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]) + assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels + + feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :] + vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :] + + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] + # flatten NxCxHxW to HWxNxC + vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] + vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds] + + return backbone_out, vision_feats, vision_pos_embeds, feat_sizes + + def _prepare_memory_conditioned_features( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + ): + """Prepares memory-conditioned features by fusing current frame's visual features with previous memories.""" + B = current_vision_feats[-1].size(1) # 
batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + device = current_vision_feats[-1].device + # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images. + # In this case, we skip the fusion with any memory. + if self.num_maskmem == 0: # Disable memory and skip fusion + return current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + num_obj_ptr_tokens = 0 + tpos_sign_mul = -1 if track_in_reverse else 1 + # Step 1: condition the visual features of the current frame on previous memories + if not is_init_cond_frame: + # Retrieve the memories encoded with the maskmem backbone + to_cat_memory, to_cat_memory_pos_embed = [], [] + # Add conditioning frame's output first (all cond frames have t_pos=0 for + # when getting temporal positional embedding below) + assert len(output_dict["cond_frame_outputs"]) > 0 + # Select a maximum number of temporally closest cond frames for cross attention + cond_outputs = output_dict["cond_frame_outputs"] + selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames( + frame_idx, cond_outputs, self.max_cond_frames_in_attn + ) + t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()] + # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory + # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1 + # We also allow taking the memory frame non-consecutively (with r>1), in which case + # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame. + r = 1 if self.training else self.memory_temporal_stride_for_eval + for t_pos in range(1, self.num_maskmem): + t_rel = self.num_maskmem - t_pos # how many frames before current frame + if t_rel == 1: + # for t_rel == 1, we take the last frame (regardless of r) + prev_frame_idx = frame_idx + t_rel if track_in_reverse else frame_idx - t_rel + elif not track_in_reverse: + # first find the nearest frame among every r-th frames before this frame + # for r=1, this would be (frame_idx - 2) + prev_frame_idx = ((frame_idx - 2) // r) * r + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx - (t_rel - 2) * r + else: + # first find the nearest frame among every r-th frames after this frame + # for r=1, this would be (frame_idx + 2) + prev_frame_idx = -(-(frame_idx + 2) // r) * r + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx + (t_rel - 2) * r + out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) + if out is None: + # If an unselected conditioning frame is among the last (self.num_maskmem - 1) + # frames, we still attend to it as if it's a non-conditioning frame. + out = unselected_cond_outputs.get(prev_frame_idx, None) + t_pos_and_prevs.append((t_pos, out)) + + for t_pos, prev in t_pos_and_prevs: + if prev is None: + continue # skip padding frames + # "maskmem_features" might have been offloaded to CPU in demo use cases, + # so we load it back to inference device (it's a no-op if it's already on device). 
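The index arithmetic above for picking non-conditioning memory frames is easier to see in isolation. The sketch below reproduces it for forward tracking with assumed values (`num_maskmem=7`, eval stride `r=2`, current frame 20): the immediately previous frame is always kept, and the remaining slots are filled from an r-strided grid of earlier frames:

```python
# Assumed example values, forward tracking only (track_in_reverse=False).
num_maskmem, r, frame_idx = 7, 2, 20

t_pos_and_prev_idxs = []
for t_pos in range(1, num_maskmem):
    t_rel = num_maskmem - t_pos          # how many frames before the current frame
    if t_rel == 1:
        prev_frame_idx = frame_idx - 1   # always take the immediately previous frame
    else:
        # nearest frame on the r-strided grid at or before (frame_idx - 2), then step back
        prev_frame_idx = ((frame_idx - 2) // r) * r - (t_rel - 2) * r
    t_pos_and_prev_idxs.append((t_pos, prev_frame_idx))

print(t_pos_and_prev_idxs)  # [(1, 10), (2, 12), (3, 14), (4, 16), (5, 18), (6, 19)]
```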
+ feats = prev["maskmem_features"].to(device=device, non_blocking=True) + to_cat_memory.append(feats.flatten(2).permute(2, 0, 1)) + # Spatial positional encoding (it might have been offloaded to CPU in eval) + maskmem_enc = prev["maskmem_pos_enc"][-1].to(device=device) + maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) + # Temporal positional encoding + maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1] + to_cat_memory_pos_embed.append(maskmem_enc) + + # Construct the list of past object pointers + if self.use_obj_ptrs_in_encoder: + max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder) + # First add those object pointers from selected conditioning frames + # (optionally, only include object pointers in the past during evaluation) + if not self.training and self.only_obj_ptrs_in_the_past_for_eval: + ptr_cond_outputs = { + t: out + for t, out in selected_cond_outputs.items() + if (t >= frame_idx if track_in_reverse else t <= frame_idx) + } + else: + ptr_cond_outputs = selected_cond_outputs + pos_and_ptrs = [ + # Temporal pos encoding contains how far away each pointer is from current frame + ( + ( + (frame_idx - t) * tpos_sign_mul + if self.use_signed_tpos_enc_to_obj_ptrs + else abs(frame_idx - t) + ), + out["obj_ptr"], + ) + for t, out in ptr_cond_outputs.items() + ] + # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame + for t_diff in range(1, max_obj_ptrs_in_encoder): + t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff + if t < 0 or (num_frames is not None and t >= num_frames): + break + out = output_dict["non_cond_frame_outputs"].get(t, unselected_cond_outputs.get(t, None)) + if out is not None: + pos_and_ptrs.append((t_diff, out["obj_ptr"])) + # If we have at least one object pointer, add them to the across attention + if pos_and_ptrs: + pos_list, ptrs_list = zip(*pos_and_ptrs) + # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape + obj_ptrs = torch.stack(ptrs_list, dim=0) + # a temporal positional embedding based on how far each object pointer is from + # the current frame (sine embedding normalized by the max pointer num). 
+ if self.add_tpos_enc_to_obj_ptrs: + t_diff_max = max_obj_ptrs_in_encoder - 1 + tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim + obj_pos = torch.tensor(pos_list, device=device) + obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) + obj_pos = self.obj_ptr_tpos_proj(obj_pos) + obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) + else: + obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim) + if self.mem_dim < C: + # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C + obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim) + obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1) + obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0) + to_cat_memory.append(obj_ptrs) + to_cat_memory_pos_embed.append(obj_pos) + num_obj_ptr_tokens = obj_ptrs.shape[0] + else: + num_obj_ptr_tokens = 0 + else: + # for initial conditioning frames, encode them without using any previous memory + if self.directly_add_no_mem_embed: + # directly add no-mem embedding (instead of using the transformer encoder) + pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + # Use a dummy token on the first frame (to avoid empty memory input to transformer encoder) + to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)] + to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)] + + # Step 2: Concatenate the memories and forward through the transformer encoder + memory = torch.cat(to_cat_memory, dim=0) + memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0) + + pix_feat_with_mem = self.memory_attention( + curr=current_vision_feats, + curr_pos=current_vision_pos_embeds, + memory=memory, + memory_pos=memory_pos_embed, + num_obj_ptr_tokens=num_obj_ptr_tokens, + ) + # reshape the output (HW)BC => BCHW + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + def _encode_new_memory( + self, + current_vision_feats, + feat_sizes, + pred_masks_high_res, + object_score_logits, + is_mask_from_pts, + ): + """Encodes frame features and masks into a new memory representation for video segmentation.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + # top-level feature, (HW)BC => BCHW + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + if self.non_overlap_masks_for_mem_enc and not self.training: + # optionally, apply non-overlapping constraints to the masks (it's applied + # in the batch dimension and should only be used during eval, where all + # the objects come from the same video under batch size 1). 
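When `mem_dim` is smaller than the hidden dimension `C`, each object pointer is split into `C // mem_dim` shorter tokens and its temporal position embedding is repeated to match, as in the reshape above. A shape-only sketch with assumed sizes:

```python
import torch

num_ptrs, B, C, mem_dim = 4, 1, 256, 64       # assumed sizes (C // mem_dim == 4)
obj_ptrs = torch.randn(num_ptrs, B, C)        # stacked object pointers
obj_pos = torch.randn(num_ptrs, B, mem_dim)   # one positional embedding per pointer

obj_ptrs = obj_ptrs.reshape(-1, B, C // mem_dim, mem_dim)  # (num_ptrs, B, 4, mem_dim)
obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)      # (num_ptrs * 4, B, mem_dim)
obj_pos = obj_pos.repeat_interleave(C // mem_dim, dim=0)   # repeat positions per split token
print(obj_ptrs.shape, obj_pos.shape)  # torch.Size([16, 1, 64]) torch.Size([16, 1, 64])
```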
+ pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res) + # scale the raw mask logits with a temperature before applying sigmoid + binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts + if binarize and not self.training: + mask_for_mem = (pred_masks_high_res > 0).float() + else: + # apply sigmoid on the raw mask logits to turn them into range (0, 1) + mask_for_mem = torch.sigmoid(pred_masks_high_res) + # apply scale and bias terms to the sigmoid probabilities + if self.sigmoid_scale_for_mem_enc != 1.0: + mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc + if self.sigmoid_bias_for_mem_enc != 0.0: + mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc + maskmem_out = self.memory_encoder(pix_feat, mask_for_mem, skip_mask_sigmoid=True) # sigmoid already applied + maskmem_features = maskmem_out["vision_features"] + maskmem_pos_enc = maskmem_out["vision_pos_enc"] + # add a no-object embedding to the spatial memory to indicate that the frame + # is predicted to be occluded (i.e. no object is appearing in the frame) + if self.no_obj_embed_spatial is not None: + is_obj_appearing = (object_score_logits > 0).float() + maskmem_features += (1 - is_obj_appearing[..., None, None]) * self.no_obj_embed_spatial[ + ..., None, None + ].expand(*maskmem_features.shape) + + return maskmem_features, maskmem_pos_enc + + def _track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + ): + """Performs a single tracking step, updating object masks and memory features based on current frame inputs.""" + current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + if len(current_vision_feats) > 1: + high_res_features = [ + x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) + for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) + ] + else: + high_res_features = None + if mask_inputs is not None and self.use_mask_input_as_output_without_sam: + # When use_mask_input_as_output_without_sam=True, we directly output the mask input + # (see it as a GT mask) without using a SAM prompt encoder + mask decoder. + pix_feat = current_vision_feats[-1].permute(1, 2, 0) + pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) + sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs) + else: + # fused the visual feature with previous memory features in the memory bank + pix_feat = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats[-1:], + current_vision_pos_embeds=current_vision_pos_embeds[-1:], + feat_sizes=feat_sizes[-1:], + output_dict=output_dict, + num_frames=num_frames, + track_in_reverse=track_in_reverse, + ) + # apply SAM-style segmentation head + # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, + # e.g. 
in demo where such logits come from earlier interaction instead of correction sampling + # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) + if prev_sam_mask_logits is not None: + assert point_inputs is not None and mask_inputs is None + mask_inputs = prev_sam_mask_logits + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + return current_out, sam_outputs, high_res_features, pix_feat + + def _encode_memory_in_output( + self, + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ): + """Finally run the memory encoder on the predicted mask to encode, it into a new memory feature (that can be + used in future frames). + """ + if run_mem_encoder and self.num_maskmem > 0: + high_res_masks_for_mem_enc = high_res_masks + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks_for_mem_enc, + object_score_logits=object_score_logits, + is_mask_from_pts=(point_inputs is not None), + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + else: + current_out["maskmem_features"] = None + current_out["maskmem_pos_enc"] = None + + def track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + # Whether to run the memory encoder on the predicted masks. Sometimes we might want + # to skip the memory encoder with `run_mem_encoder=False`. For example, + # in demo we might call `track_step` multiple times for each user click, + # and only encode the memory when the user finalizes their clicks. And in ablation + # settings like SAM training on static images, we don't need the memory encoder. + run_mem_encoder=True, + # The previously predicted SAM mask logits (which can be fed together with new clicks in demo). 
+ prev_sam_mask_logits=None, + ): + """Performs a single tracking step, updating object masks and memory features based on current frame inputs.""" + current_out, sam_outputs, _, _ = self._track_step( + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse, + prev_sam_mask_logits, + ) + _, _, _, low_res_masks, high_res_masks, obj_ptr, object_score_logits = sam_outputs + + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + if not self.training: + # Only add this in inference (to avoid unused param in activation checkpointing; + # it's mainly used in the demo to encode spatial memories w/ consolidated masks) + current_out["object_score_logits"] = object_score_logits + + # Run memory encoder on the predicted mask to encode it into a new memory feature (for use in future frames) + self._encode_memory_in_output( + current_vision_feats, + feat_sizes, + point_inputs, + run_mem_encoder, + high_res_masks, + object_score_logits, + current_out, + ) + + return current_out + + def _use_multimask(self, is_init_cond_frame, point_inputs): + """Determines whether to use multiple mask outputs in the SAM head based on configuration and inputs.""" + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) + return ( + self.multimask_output_in_sam + and (is_init_cond_frame or self.multimask_output_for_tracking) + and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num) + ) + + @staticmethod + def _apply_non_overlapping_constraints(pred_masks): + """Applies non-overlapping constraints to masks, keeping the highest scoring object per location.""" + batch_size = pred_masks.size(0) + if batch_size == 1: + return pred_masks + + device = pred_masks.device + # "max_obj_inds": object index of the object with the highest score at each location + max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) + # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks` + batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] + keep = max_obj_inds == batch_obj_inds + # suppress overlapping regions' scores below -10.0 so that the foreground regions + # don't overlap (here sigmoid(-10.0)=4.5398e-05) + pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) + return pred_masks + + def set_binarize(self, binarize=False): + """Set binarize for VideoPredictor.""" + self.binarize_mask_from_pts_for_mem_enc = binarize + + def set_imgsz(self, imgsz): + """ + Set image size to make model compatible with different image sizes. + + Args: + imgsz (Tuple[int, int]): The size of the input image. 
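`_apply_non_overlapping_constraints` above keeps, at each pixel, only the object whose mask score is highest across the batch dimension and clamps every other object's score to at most -10.0 (near-zero probability after sigmoid). A toy check of that behaviour:

```python
import torch

pred_masks = torch.randn(3, 1, 8, 8)  # 3 objects stacked along the batch dimension (assumed)

max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)           # winner per pixel
batch_obj_inds = torch.arange(pred_masks.size(0))[:, None, None, None]
keep = max_obj_inds == batch_obj_inds
suppressed = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))

# After suppression, at most one object can be "foreground" (score > 0) at any pixel.
print(((suppressed > 0).sum(dim=0) <= 1).all().item())  # True
```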
+ """ + self.image_size = imgsz[0] + self.sam_prompt_encoder.input_image_size = imgsz + self.sam_prompt_encoder.image_embedding_size = [x // 16 for x in imgsz] # fixed ViT patch size of 16 diff --git a/ultralytics/models/sam/modules/tiny_encoder.py b/ultralytics/models/sam/modules/tiny_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..1b181f7a06b2324bba5c9c310f862354feab9fbb --- /dev/null +++ b/ultralytics/models/sam/modules/tiny_encoder.py @@ -0,0 +1,1013 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +# -------------------------------------------------------- +# TinyViT Model Architecture +# Copyright (c) 2022 Microsoft +# Adapted from LeViT and Swin Transformer +# LeViT: (https://github.com/facebookresearch/levit) +# Swin: (https://github.com/microsoft/swin-transformer) +# Build the TinyViT Model +# -------------------------------------------------------- + +import itertools +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from ultralytics.nn.modules import LayerNorm2d +from ultralytics.utils.instance import to_2tuple + + +class Conv2d_BN(torch.nn.Sequential): + """ + A sequential container that performs 2D convolution followed by batch normalization. + + Attributes: + c (torch.nn.Conv2d): 2D convolution layer. + 1 (torch.nn.BatchNorm2d): Batch normalization layer. + + Methods: + __init__: Initializes the Conv2d_BN with specified parameters. + + Args: + a (int): Number of input channels. + b (int): Number of output channels. + ks (int): Kernel size for the convolution. Defaults to 1. + stride (int): Stride for the convolution. Defaults to 1. + pad (int): Padding for the convolution. Defaults to 0. + dilation (int): Dilation factor for the convolution. Defaults to 1. + groups (int): Number of groups for the convolution. Defaults to 1. + bn_weight_init (float): Initial value for batch normalization weight. Defaults to 1. + + Examples: + >>> conv_bn = Conv2d_BN(3, 64, ks=3, stride=1, pad=1) + >>> input_tensor = torch.randn(1, 3, 224, 224) + >>> output = conv_bn(input_tensor) + >>> print(output.shape) + """ + + def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1): + """Initializes a sequential container with 2D convolution followed by batch normalization.""" + super().__init__() + self.add_module("c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) + bn = torch.nn.BatchNorm2d(b) + torch.nn.init.constant_(bn.weight, bn_weight_init) + torch.nn.init.constant_(bn.bias, 0) + self.add_module("bn", bn) + + +class PatchEmbed(nn.Module): + """ + Embeds images into patches and projects them into a specified embedding dimension. + + Attributes: + patches_resolution (Tuple[int, int]): Resolution of the patches after embedding. + num_patches (int): Total number of patches. + in_chans (int): Number of input channels. + embed_dim (int): Dimension of the embedding. + seq (nn.Sequential): Sequence of convolutional and activation layers for patch embedding. + + Methods: + forward: Processes the input tensor through the patch embedding sequence. 
+ + Examples: + >>> import torch + >>> patch_embed = PatchEmbed(in_chans=3, embed_dim=96, resolution=224, activation=nn.GELU) + >>> x = torch.randn(1, 3, 224, 224) + >>> output = patch_embed(x) + >>> print(output.shape) + """ + + def __init__(self, in_chans, embed_dim, resolution, activation): + """Initializes patch embedding with convolutional layers for image-to-patch conversion and projection.""" + super().__init__() + img_size: Tuple[int, int] = to_2tuple(resolution) + self.patches_resolution = (img_size[0] // 4, img_size[1] // 4) + self.num_patches = self.patches_resolution[0] * self.patches_resolution[1] + self.in_chans = in_chans + self.embed_dim = embed_dim + n = embed_dim + self.seq = nn.Sequential( + Conv2d_BN(in_chans, n // 2, 3, 2, 1), + activation(), + Conv2d_BN(n // 2, n, 3, 2, 1), + ) + + def forward(self, x): + """Processes input tensor through patch embedding sequence, converting images to patch embeddings.""" + return self.seq(x) + + +class MBConv(nn.Module): + """ + Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture. + + Attributes: + in_chans (int): Number of input channels. + hidden_chans (int): Number of hidden channels. + out_chans (int): Number of output channels. + conv1 (Conv2d_BN): First convolutional layer. + act1 (nn.Module): First activation function. + conv2 (Conv2d_BN): Depthwise convolutional layer. + act2 (nn.Module): Second activation function. + conv3 (Conv2d_BN): Final convolutional layer. + act3 (nn.Module): Third activation function. + drop_path (nn.Module): Drop path layer (Identity for inference). + + Methods: + forward: Performs the forward pass through the MBConv layer. + + Examples: + >>> in_chans, out_chans = 32, 64 + >>> mbconv = MBConv(in_chans, out_chans, expand_ratio=4, activation=nn.ReLU, drop_path=0.1) + >>> x = torch.randn(1, in_chans, 56, 56) + >>> output = mbconv(x) + >>> print(output.shape) + torch.Size([1, 64, 56, 56]) + """ + + def __init__(self, in_chans, out_chans, expand_ratio, activation, drop_path): + """Initializes the MBConv layer with specified input/output channels, expansion ratio, and activation.""" + super().__init__() + self.in_chans = in_chans + self.hidden_chans = int(in_chans * expand_ratio) + self.out_chans = out_chans + + self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1) + self.act1 = activation() + + self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans, ks=3, stride=1, pad=1, groups=self.hidden_chans) + self.act2 = activation() + + self.conv3 = Conv2d_BN(self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0) + self.act3 = activation() + + # NOTE: `DropPath` is needed only for training. + # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path = nn.Identity() + + def forward(self, x): + """Implements the forward pass of MBConv, applying convolutions and skip connection.""" + shortcut = x + x = self.conv1(x) + x = self.act1(x) + x = self.conv2(x) + x = self.act2(x) + x = self.conv3(x) + x = self.drop_path(x) + x += shortcut + return self.act3(x) + + +class PatchMerging(nn.Module): + """ + Merges neighboring patches in the feature map and projects to a new dimension. + + This class implements a patch merging operation that combines spatial information and adjusts the feature + dimension. It uses a series of convolutional layers with batch normalization to achieve this. + + Attributes: + input_resolution (Tuple[int, int]): The input resolution (height, width) of the feature map. + dim (int): The input dimension of the feature map. 
+ out_dim (int): The output dimension after merging and projection. + act (nn.Module): The activation function used between convolutions. + conv1 (Conv2d_BN): The first convolutional layer for dimension projection. + conv2 (Conv2d_BN): The second convolutional layer for spatial merging. + conv3 (Conv2d_BN): The third convolutional layer for final projection. + + Methods: + forward: Applies the patch merging operation to the input tensor. + + Examples: + >>> input_resolution = (56, 56) + >>> patch_merging = PatchMerging(input_resolution, dim=64, out_dim=128, activation=nn.ReLU) + >>> x = torch.randn(4, 64, 56, 56) + >>> output = patch_merging(x) + >>> print(output.shape) + """ + + def __init__(self, input_resolution, dim, out_dim, activation): + """Initializes the PatchMerging module for merging and projecting neighboring patches in feature maps.""" + super().__init__() + + self.input_resolution = input_resolution + self.dim = dim + self.out_dim = out_dim + self.act = activation() + self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0) + stride_c = 1 if out_dim in {320, 448, 576} else 2 + self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim) + self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0) + + def forward(self, x): + """Applies patch merging and dimension projection to the input feature map.""" + if x.ndim == 3: + H, W = self.input_resolution + B = len(x) + # (B, C, H, W) + x = x.view(B, H, W, -1).permute(0, 3, 1, 2) + + x = self.conv1(x) + x = self.act(x) + + x = self.conv2(x) + x = self.act(x) + x = self.conv3(x) + return x.flatten(2).transpose(1, 2) + + +class ConvLayer(nn.Module): + """ + Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv). + + This layer optionally applies downsample operations to the output and supports gradient checkpointing. + + Attributes: + dim (int): Dimensionality of the input and output. + input_resolution (Tuple[int, int]): Resolution of the input image. + depth (int): Number of MBConv layers in the block. + use_checkpoint (bool): Whether to use gradient checkpointing to save memory. + blocks (nn.ModuleList): List of MBConv layers. + downsample (Optional[Callable]): Function for downsampling the output. + + Methods: + forward: Processes the input through the convolutional layers. + + Examples: + >>> input_tensor = torch.randn(1, 64, 56, 56) + >>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU) + >>> output = conv_layer(input_tensor) + >>> print(output.shape) + """ + + def __init__( + self, + dim, + input_resolution, + depth, + activation, + drop_path=0.0, + downsample=None, + use_checkpoint=False, + out_dim=None, + conv_expand_ratio=4.0, + ): + """ + Initializes the ConvLayer with the given dimensions and settings. + + This layer consists of multiple MobileNetV3-style inverted bottleneck convolutions (MBConv) and + optionally applies downsampling to the output. + + Args: + dim (int): The dimensionality of the input and output. + input_resolution (Tuple[int, int]): The resolution of the input image. + depth (int): The number of MBConv layers in the block. + activation (Callable): Activation function applied after each convolution. + drop_path (float | List[float]): Drop path rate. Single float or a list of floats for each MBConv. + downsample (Optional[Callable]): Function for downsampling the output. None to skip downsampling. + use_checkpoint (bool): Whether to use gradient checkpointing to save memory. + out_dim (Optional[int]): The dimensionality of the output. 
None means it will be the same as `dim`. + conv_expand_ratio (float): Expansion ratio for the MBConv layers. + + Examples: + >>> input_tensor = torch.randn(1, 64, 56, 56) + >>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU) + >>> output = conv_layer(input_tensor) + >>> print(output.shape) + """ + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # Build blocks + self.blocks = nn.ModuleList( + [ + MBConv( + dim, + dim, + conv_expand_ratio, + activation, + drop_path[i] if isinstance(drop_path, list) else drop_path, + ) + for i in range(depth) + ] + ) + + # Patch merging layer + self.downsample = ( + None + if downsample is None + else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation) + ) + + def forward(self, x): + """Processes input through convolutional layers, applying MBConv blocks and optional downsampling.""" + for blk in self.blocks: + x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x) + return x if self.downsample is None else self.downsample(x) + + +class Mlp(nn.Module): + """ + Multi-layer Perceptron (MLP) module for transformer architectures. + + This module applies layer normalization, two fully-connected layers with an activation function in between, + and dropout. It is commonly used in transformer-based architectures. + + Attributes: + norm (nn.LayerNorm): Layer normalization applied to the input. + fc1 (nn.Linear): First fully-connected layer. + fc2 (nn.Linear): Second fully-connected layer. + act (nn.Module): Activation function applied after the first fully-connected layer. + drop (nn.Dropout): Dropout layer applied after the activation function. + + Methods: + forward: Applies the MLP operations on the input tensor. + + Examples: + >>> import torch + >>> from torch import nn + >>> mlp = Mlp(in_features=256, hidden_features=512, out_features=256, act_layer=nn.GELU, drop=0.1) + >>> x = torch.randn(32, 100, 256) + >>> output = mlp(x) + >>> print(output.shape) + torch.Size([32, 100, 256]) + """ + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): + """Initializes a multi-layer perceptron with configurable input, hidden, and output dimensions.""" + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.norm = nn.LayerNorm(in_features) + self.fc1 = nn.Linear(in_features, hidden_features) + self.fc2 = nn.Linear(hidden_features, out_features) + self.act = act_layer() + self.drop = nn.Dropout(drop) + + def forward(self, x): + """Applies MLP operations: layer norm, FC layers, activation, and dropout to the input tensor.""" + x = self.norm(x) + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + return self.drop(x) + + +class Attention(torch.nn.Module): + """ + Multi-head attention module with spatial awareness and trainable attention biases. + + This module implements a multi-head attention mechanism with support for spatial awareness, applying + attention biases based on spatial resolution. It includes trainable attention biases for each unique + offset between spatial positions in the resolution grid. + + Attributes: + num_heads (int): Number of attention heads. + scale (float): Scaling factor for attention scores. + key_dim (int): Dimensionality of the keys and queries. + nh_kd (int): Product of num_heads and key_dim. + d (int): Dimensionality of the value vectors. 
+ dh (int): Product of d and num_heads. + attn_ratio (float): Attention ratio affecting the dimensions of the value vectors. + norm (nn.LayerNorm): Layer normalization applied to input. + qkv (nn.Linear): Linear layer for computing query, key, and value projections. + proj (nn.Linear): Linear layer for final projection. + attention_biases (nn.Parameter): Learnable attention biases. + attention_bias_idxs (Tensor): Indices for attention biases. + ab (Tensor): Cached attention biases for inference, deleted during training. + + Methods: + train: Sets the module in training mode and handles the 'ab' attribute. + forward: Performs the forward pass of the attention mechanism. + + Examples: + >>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14)) + >>> x = torch.randn(1, 196, 256) + >>> output = attn(x) + >>> print(output.shape) + torch.Size([1, 196, 256]) + """ + + def __init__( + self, + dim, + key_dim, + num_heads=8, + attn_ratio=4, + resolution=(14, 14), + ): + """ + Initializes the Attention module for multi-head attention with spatial awareness. + + This module implements a multi-head attention mechanism with support for spatial awareness, applying + attention biases based on spatial resolution. It includes trainable attention biases for each unique + offset between spatial positions in the resolution grid. + + Args: + dim (int): The dimensionality of the input and output. + key_dim (int): The dimensionality of the keys and queries. + num_heads (int): Number of attention heads. Default is 8. + attn_ratio (float): Attention ratio, affecting the dimensions of the value vectors. Default is 4. + resolution (Tuple[int, int]): Spatial resolution of the input feature map. Default is (14, 14). + + Raises: + AssertionError: If 'resolution' is not a tuple of length 2. 
+ + Examples: + >>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14)) + >>> x = torch.randn(1, 196, 256) + >>> output = attn(x) + >>> print(output.shape) + torch.Size([1, 196, 256]) + """ + super().__init__() + + assert isinstance(resolution, tuple) and len(resolution) == 2, "'resolution' argument not tuple of length 2" + self.num_heads = num_heads + self.scale = key_dim**-0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + h = self.dh + nh_kd * 2 + + self.norm = nn.LayerNorm(dim) + self.qkv = nn.Linear(dim, h) + self.proj = nn.Linear(self.dh, dim) + + points = list(itertools.product(range(resolution[0]), range(resolution[1]))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(N, N), persistent=False) + + @torch.no_grad() + def train(self, mode=True): + """Performs multi-head attention with spatial awareness and trainable attention biases.""" + super().train(mode) + if mode and hasattr(self, "ab"): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): # x + """Applies multi-head attention with spatial awareness and trainable attention biases.""" + B, N, _ = x.shape # B, N, C + + # Normalization + x = self.norm(x) + + qkv = self.qkv(x) + # (B, N, num_heads, d) + q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3) + # (B, num_heads, N, d) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + self.ab = self.ab.to(self.attention_biases.device) + + attn = (q @ k.transpose(-2, -1)) * self.scale + ( + self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab + ) + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + return self.proj(x) + + +class TinyViTBlock(nn.Module): + """ + TinyViT Block that applies self-attention and a local convolution to the input. + + This block is a key component of the TinyViT architecture, combining self-attention mechanisms with + local convolutions to process input features efficiently. + + Attributes: + dim (int): The dimensionality of the input and output. + input_resolution (Tuple[int, int]): Spatial resolution of the input feature map. + num_heads (int): Number of attention heads. + window_size (int): Size of the attention window. + mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension. + drop_path (nn.Module): Stochastic depth layer, identity function during inference. + attn (Attention): Self-attention module. + mlp (Mlp): Multi-layer perceptron module. + local_conv (Conv2d_BN): Depth-wise local convolution layer. + + Methods: + forward: Processes the input through the TinyViT block. + extra_repr: Returns a string with extra information about the block's parameters. 
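The trainable attention biases above are shared across all pairs of positions that have the same absolute `(dy, dx)` offset; the double loop builds a lookup table from each position pair to its bias index. A small standalone version of that bookkeeping for a 3x3 grid:

```python
import itertools

import torch

resolution = (3, 3)  # assumed small grid for illustration
points = list(itertools.product(range(resolution[0]), range(resolution[1])))
attention_offsets, idxs = {}, []
for p1 in points:
    for p2 in points:
        offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
        if offset not in attention_offsets:
            attention_offsets[offset] = len(attention_offsets)
        idxs.append(attention_offsets[offset])

N = len(points)
attention_bias_idxs = torch.LongTensor(idxs).view(N, N)
# 9 unique absolute offsets on a 3x3 grid, indexed for all 9x9 position pairs.
print(len(attention_offsets), attention_bias_idxs.shape)  # 9 torch.Size([9, 9])
```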
+ + Examples: + >>> input_tensor = torch.randn(1, 196, 192) + >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3) + >>> output = block(input_tensor) + >>> print(output.shape) + torch.Size([1, 196, 192]) + """ + + def __init__( + self, + dim, + input_resolution, + num_heads, + window_size=7, + mlp_ratio=4.0, + drop=0.0, + drop_path=0.0, + local_conv_size=3, + activation=nn.GELU, + ): + """ + Initializes a TinyViT block with self-attention and local convolution. + + This block is a key component of the TinyViT architecture, combining self-attention mechanisms with + local convolutions to process input features efficiently. + + Args: + dim (int): Dimensionality of the input and output features. + input_resolution (Tuple[int, int]): Spatial resolution of the input feature map (height, width). + num_heads (int): Number of attention heads. + window_size (int): Size of the attention window. Must be greater than 0. + mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension. + drop (float): Dropout rate. + drop_path (float): Stochastic depth rate. + local_conv_size (int): Kernel size of the local convolution. + activation (torch.nn.Module): Activation function for MLP. + + Raises: + AssertionError: If window_size is not greater than 0. + AssertionError: If dim is not divisible by num_heads. + + Examples: + >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3) + >>> input_tensor = torch.randn(1, 196, 192) + >>> output = block(input_tensor) + >>> print(output.shape) + torch.Size([1, 196, 192]) + """ + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + assert window_size > 0, "window_size must be greater than 0" + self.window_size = window_size + self.mlp_ratio = mlp_ratio + + # NOTE: `DropPath` is needed only for training. + # self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.drop_path = nn.Identity() + + assert dim % num_heads == 0, "dim must be divisible by num_heads" + head_dim = dim // num_heads + + window_resolution = (window_size, window_size) + self.attn = Attention(dim, head_dim, num_heads, attn_ratio=1, resolution=window_resolution) + + mlp_hidden_dim = int(dim * mlp_ratio) + mlp_activation = activation + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=mlp_activation, drop=drop) + + pad = local_conv_size // 2 + self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim) + + def forward(self, x): + """Applies self-attention, local convolution, and MLP operations to the input tensor.""" + h, w = self.input_resolution + b, hw, c = x.shape # batch, height*width, channels + assert hw == h * w, "input feature has wrong size" + res_x = x + if h == self.window_size and w == self.window_size: + x = self.attn(x) + else: + x = x.view(b, h, w, c) + pad_b = (self.window_size - h % self.window_size) % self.window_size + pad_r = (self.window_size - w % self.window_size) % self.window_size + padding = pad_b > 0 or pad_r > 0 + if padding: + x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) + + pH, pW = h + pad_b, w + pad_r + nH = pH // self.window_size + nW = pW // self.window_size + + # Window partition + x = ( + x.view(b, nH, self.window_size, nW, self.window_size, c) + .transpose(2, 3) + .reshape(b * nH * nW, self.window_size * self.window_size, c) + ) + x = self.attn(x) + + # Window reverse + x = x.view(b, nH, nW, self.window_size, self.window_size, c).transpose(2, 3).reshape(b, pH, pW, c) + if padding: + x = x[:, :h, :w].contiguous() + + x = x.view(b, hw, c) + + x = res_x + self.drop_path(x) + x = x.transpose(1, 2).reshape(b, c, h, w) + x = self.local_conv(x) + x = x.view(b, c, hw).transpose(1, 2) + + return x + self.drop_path(self.mlp(x)) + + def extra_repr(self) -> str: + """ + Returns a string representation of the TinyViTBlock's parameters. + + This method provides a formatted string containing key information about the TinyViTBlock, including its + dimension, input resolution, number of attention heads, window size, and MLP ratio. + + Returns: + (str): A formatted string containing the block's parameters. + + Examples: + >>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0) + >>> print(block.extra_repr()) + dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0 + """ + return ( + f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " + f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}" + ) + + +class BasicLayer(nn.Module): + """ + A basic TinyViT layer for one stage in a TinyViT architecture. + + This class represents a single layer in the TinyViT model, consisting of multiple TinyViT blocks + and an optional downsampling operation. + + Attributes: + dim (int): The dimensionality of the input and output features. + input_resolution (Tuple[int, int]): Spatial resolution of the input feature map. + depth (int): Number of TinyViT blocks in this layer. + use_checkpoint (bool): Whether to use gradient checkpointing to save memory. + blocks (nn.ModuleList): List of TinyViT blocks that make up this layer. + downsample (nn.Module | None): Downsample layer at the end of the layer, if specified. + + Methods: + forward: Processes the input through the layer's blocks and optional downsampling. + extra_repr: Returns a string with the layer's parameters for printing. 
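The window partition and reverse in `TinyViTBlock.forward` above are pure reshapes. The sketch below checks, with assumed sizes and no attention applied, that the roundtrip reproduces the input exactly:

```python
import torch

b, h, w, c, ws = 2, 14, 14, 32, 7  # assumed sizes; h and w are multiples of the window size
x = torch.randn(b, h, w, c)
nH, nW = h // ws, w // ws

# Window partition: (b, h, w, c) -> (b * nH * nW, ws * ws, c)
windows = x.view(b, nH, ws, nW, ws, c).transpose(2, 3).reshape(b * nH * nW, ws * ws, c)

# Window reverse: back to (b, h, w, c)
restored = windows.view(b, nH, nW, ws, ws, c).transpose(2, 3).reshape(b, h, w, c)
print(windows.shape, torch.equal(x, restored))  # torch.Size([8, 49, 32]) True
```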
+ + Examples: + >>> input_tensor = torch.randn(1, 3136, 192) + >>> layer = BasicLayer(dim=192, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7) + >>> output = layer(input_tensor) + >>> print(output.shape) + torch.Size([1, 784, 384]) + """ + + def __init__( + self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4.0, + drop=0.0, + drop_path=0.0, + downsample=None, + use_checkpoint=False, + local_conv_size=3, + activation=nn.GELU, + out_dim=None, + ): + """ + Initializes a BasicLayer in the TinyViT architecture. + + This layer consists of multiple TinyViT blocks and an optional downsampling operation. It is designed to + process feature maps at a specific resolution and dimensionality within the TinyViT model. + + Args: + dim (int): Dimensionality of the input and output features. + input_resolution (Tuple[int, int]): Spatial resolution of the input feature map (height, width). + depth (int): Number of TinyViT blocks in this layer. + num_heads (int): Number of attention heads in each TinyViT block. + window_size (int): Size of the local window for attention computation. + mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension. + drop (float): Dropout rate. + drop_path (float | List[float]): Stochastic depth rate. Can be a float or a list of floats for each block. + downsample (nn.Module | None): Downsampling layer at the end of the layer. None to skip downsampling. + use_checkpoint (bool): Whether to use gradient checkpointing to save memory. + local_conv_size (int): Kernel size for the local convolution in each TinyViT block. + activation (nn.Module): Activation function used in the MLP. + out_dim (int | None): Output dimension after downsampling. None means it will be the same as `dim`. + + Raises: + ValueError: If `drop_path` is a list and its length doesn't match `depth`. + + Examples: + >>> layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7) + >>> x = torch.randn(1, 56 * 56, 96) + >>> output = layer(x) + >>> print(output.shape) + """ + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # Build blocks + self.blocks = nn.ModuleList( + [ + TinyViTBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + drop=drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + local_conv_size=local_conv_size, + activation=activation, + ) + for i in range(depth) + ] + ) + + # Patch merging layer + self.downsample = ( + None + if downsample is None + else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation) + ) + + def forward(self, x): + """Processes input through TinyViT blocks and optional downsampling.""" + for blk in self.blocks: + x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x) + return x if self.downsample is None else self.downsample(x) + + def extra_repr(self) -> str: + """Returns a string with the layer's parameters for printing.""" + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + +class TinyViT(nn.Module): + """ + TinyViT: A compact vision transformer architecture for efficient image classification and feature extraction. + + This class implements the TinyViT model, which combines elements of vision transformers and convolutional + neural networks for improved efficiency and performance on vision tasks. 
+ + Attributes: + img_size (int): Input image size. + num_classes (int): Number of classification classes. + depths (List[int]): Number of blocks in each stage. + num_layers (int): Total number of layers in the network. + mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension. + patch_embed (PatchEmbed): Module for patch embedding. + patches_resolution (Tuple[int, int]): Resolution of embedded patches. + layers (nn.ModuleList): List of network layers. + norm_head (nn.LayerNorm): Layer normalization for the classifier head. + head (nn.Linear): Linear layer for final classification. + neck (nn.Sequential): Neck module for feature refinement. + + Methods: + set_layer_lr_decay: Sets layer-wise learning rate decay. + _init_weights: Initializes weights for linear and normalization layers. + no_weight_decay_keywords: Returns keywords for parameters that should not use weight decay. + forward_features: Processes input through the feature extraction layers. + forward: Performs a forward pass through the entire network. + + Examples: + >>> model = TinyViT(img_size=224, num_classes=1000) + >>> x = torch.randn(1, 3, 224, 224) + >>> features = model.forward_features(x) + >>> print(features.shape) + torch.Size([1, 256, 64, 64]) + """ + + def __init__( + self, + img_size=224, + in_chans=3, + num_classes=1000, + embed_dims=(96, 192, 384, 768), + depths=(2, 2, 6, 2), + num_heads=(3, 6, 12, 24), + window_sizes=(7, 7, 14, 7), + mlp_ratio=4.0, + drop_rate=0.0, + drop_path_rate=0.1, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=1.0, + ): + """ + Initializes the TinyViT model. + + This constructor sets up the TinyViT architecture, including patch embedding, multiple layers of + attention and convolution blocks, and a classification head. + + Args: + img_size (int): Size of the input image. Default is 224. + in_chans (int): Number of input channels. Default is 3. + num_classes (int): Number of classes for classification. Default is 1000. + embed_dims (Tuple[int, int, int, int]): Embedding dimensions for each stage. + Default is (96, 192, 384, 768). + depths (Tuple[int, int, int, int]): Number of blocks in each stage. Default is (2, 2, 6, 2). + num_heads (Tuple[int, int, int, int]): Number of attention heads in each stage. + Default is (3, 6, 12, 24). + window_sizes (Tuple[int, int, int, int]): Window sizes for each stage. Default is (7, 7, 14, 7). + mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. Default is 4.0. + drop_rate (float): Dropout rate. Default is 0.0. + drop_path_rate (float): Stochastic depth rate. Default is 0.1. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default is False. + mbconv_expand_ratio (float): Expansion ratio for MBConv layer. Default is 4.0. + local_conv_size (int): Kernel size for local convolutions. Default is 3. + layer_lr_decay (float): Layer-wise learning rate decay factor. Default is 1.0. 
+ + Examples: + >>> model = TinyViT(img_size=224, num_classes=1000) + >>> x = torch.randn(1, 3, 224, 224) + >>> output = model(x) + >>> print(output.shape) + torch.Size([1, 1000]) + """ + super().__init__() + self.img_size = img_size + self.num_classes = num_classes + self.depths = depths + self.num_layers = len(depths) + self.mlp_ratio = mlp_ratio + + activation = nn.GELU + + self.patch_embed = PatchEmbed( + in_chans=in_chans, embed_dim=embed_dims[0], resolution=img_size, activation=activation + ) + + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # Stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # Build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + kwargs = dict( + dim=embed_dims[i_layer], + input_resolution=( + patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)), + patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)), + ), + # input_resolution=(patches_resolution[0] // (2 ** i_layer), + # patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)], + activation=activation, + ) + if i_layer == 0: + layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs) + else: + layer = BasicLayer( + num_heads=num_heads[i_layer], + window_size=window_sizes[i_layer], + mlp_ratio=self.mlp_ratio, + drop=drop_rate, + local_conv_size=local_conv_size, + **kwargs, + ) + self.layers.append(layer) + + # Classifier head + self.norm_head = nn.LayerNorm(embed_dims[-1]) + self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + + # Init weights + self.apply(self._init_weights) + self.set_layer_lr_decay(layer_lr_decay) + self.neck = nn.Sequential( + nn.Conv2d( + embed_dims[-1], + 256, + kernel_size=1, + bias=False, + ), + LayerNorm2d(256), + nn.Conv2d( + 256, + 256, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(256), + ) + + def set_layer_lr_decay(self, layer_lr_decay): + """Sets layer-wise learning rate decay for the TinyViT model based on depth.""" + decay_rate = layer_lr_decay + + # Layers -> blocks (depth) + depth = sum(self.depths) + lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)] + + def _set_lr_scale(m, scale): + """Sets the learning rate scale for each layer in the model based on the layer's depth.""" + for p in m.parameters(): + p.lr_scale = scale + + self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0])) + i = 0 + for layer in self.layers: + for block in layer.blocks: + block.apply(lambda x: _set_lr_scale(x, lr_scales[i])) + i += 1 + if layer.downsample is not None: + layer.downsample.apply(lambda x: _set_lr_scale(x, lr_scales[i - 1])) + assert i == depth + for m in [self.norm_head, self.head]: + m.apply(lambda x: _set_lr_scale(x, lr_scales[-1])) + + for k, p in self.named_parameters(): + p.param_name = k + + def _check_lr_scale(m): + """Checks if the learning rate scale attribute is present in module's parameters.""" + for p in m.parameters(): + assert hasattr(p, "lr_scale"), p.param_name + + self.apply(_check_lr_scale) + + @staticmethod + def _init_weights(m): + """Initializes weights for linear and normalization layers in the TinyViT model.""" + if 
isinstance(m, nn.Linear): + # NOTE: This initialization is needed only for training. + # trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay_keywords(self): + """Returns a set of keywords for parameters that should not use weight decay.""" + return {"attention_biases"} + + def forward_features(self, x): + """Processes input through feature extraction layers, returning spatial features.""" + x = self.patch_embed(x) # x input is (N, C, H, W) + + x = self.layers[0](x) + start_i = 1 + + for i in range(start_i, len(self.layers)): + layer = self.layers[i] + x = layer(x) + batch, _, channel = x.shape + x = x.view(batch, self.patches_resolution[0] // 4, self.patches_resolution[1] // 4, channel) + x = x.permute(0, 3, 1, 2) + return self.neck(x) + + def forward(self, x): + """Performs the forward pass through the TinyViT model, extracting features from the input image.""" + return self.forward_features(x) + + def set_imgsz(self, imgsz=[1024, 1024]): + """ + Set image size to make model compatible with different image sizes. + + Args: + imgsz (Tuple[int, int]): The size of the input image. + """ + imgsz = [s // 4 for s in imgsz] + self.patches_resolution = imgsz + for i, layer in enumerate(self.layers): + input_resolution = ( + imgsz[0] // (2 ** (i - 1 if i == 3 else i)), + imgsz[1] // (2 ** (i - 1 if i == 3 else i)), + ) + layer.input_resolution = input_resolution + if layer.downsample is not None: + layer.downsample.input_resolution = input_resolution + if isinstance(layer, BasicLayer): + for b in layer.blocks: + b.input_resolution = input_resolution diff --git a/ultralytics/models/sam/modules/transformer.py b/ultralytics/models/sam/modules/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..9b9c2bf6121cf15190db5183e1366f276654a052 --- /dev/null +++ b/ultralytics/models/sam/modules/transformer.py @@ -0,0 +1,373 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import math +from typing import Tuple, Type + +import torch +from torch import Tensor, nn + +from ultralytics.nn.modules import MLPBlock + + +class TwoWayTransformer(nn.Module): + """ + A Two-Way Transformer module for simultaneous attention to image and query points. + + This class implements a specialized transformer decoder that attends to an input image using queries with + supplied positional embeddings. It's useful for tasks like object detection, image segmentation, and point + cloud processing. + + Attributes: + depth (int): Number of layers in the transformer. + embedding_dim (int): Channel dimension for input embeddings. + num_heads (int): Number of heads for multihead attention. + mlp_dim (int): Internal channel dimension for the MLP block. + layers (nn.ModuleList): List of TwoWayAttentionBlock layers composing the transformer. + final_attn_token_to_image (Attention): Final attention layer from queries to image. + norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries. + + Methods: + forward: Processes image and point embeddings through the transformer. 
+ + Examples: + >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048) + >>> image_embedding = torch.randn(1, 256, 32, 32) + >>> image_pe = torch.randn(1, 256, 32, 32) + >>> point_embedding = torch.randn(1, 100, 256) + >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding) + >>> print(output_queries.shape, output_image.shape) + """ + + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + Initialize a Two-Way Transformer for simultaneous attention to image and query points. + + Args: + depth (int): Number of layers in the transformer. + embedding_dim (int): Channel dimension for input embeddings. + num_heads (int): Number of heads for multihead attention. Must divide embedding_dim. + mlp_dim (int): Internal channel dimension for the MLP block. + activation (Type[nn.Module]): Activation function to use in the MLP block. + attention_downsample_rate (int): Downsampling rate for attention mechanism. + + Attributes: + depth (int): Number of layers in the transformer. + embedding_dim (int): Channel dimension for input embeddings. + num_heads (int): Number of heads for multihead attention. + mlp_dim (int): Internal channel dimension for the MLP block. + layers (nn.ModuleList): List of TwoWayAttentionBlock layers. + final_attn_token_to_image (Attention): Final attention layer from queries to image. + norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries. + + Examples: + >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048) + >>> image_embedding = torch.randn(1, 256, 32, 32) + >>> image_pe = torch.randn(1, 256, 32, 32) + >>> point_embedding = torch.randn(1, 100, 256) + >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding) + >>> print(output_queries.shape, output_image.shape) + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Processes image and point embeddings through the Two-Way Transformer. + + Args: + image_embedding (torch.Tensor): Image to attend to, with shape (B, embedding_dim, H, W). + image_pe (torch.Tensor): Positional encoding to add to the image, with same shape as image_embedding. + point_embedding (torch.Tensor): Embedding to add to query points, with shape (B, N_points, embedding_dim). + + Returns: + (Tuple[torch.Tensor, torch.Tensor]): Processed point_embedding and image_embedding. 
+ + Examples: + >>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048) + >>> image_embedding = torch.randn(1, 256, 32, 32) + >>> image_pe = torch.randn(1, 256, 32, 32) + >>> point_embedding = torch.randn(1, 100, 256) + >>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding) + >>> print(output_queries.shape, output_image.shape) + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + """ + A two-way attention block for simultaneous attention to image and query points. + + This class implements a specialized transformer block with four main layers: self-attention on sparse inputs, + cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense + inputs to sparse inputs. + + Attributes: + self_attn (Attention): Self-attention layer for queries. + norm1 (nn.LayerNorm): Layer normalization after self-attention. + cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys. + norm2 (nn.LayerNorm): Layer normalization after token-to-image attention. + mlp (MLPBlock): MLP block for transforming query embeddings. + norm3 (nn.LayerNorm): Layer normalization after MLP block. + norm4 (nn.LayerNorm): Layer normalization after image-to-token attention. + cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries. + skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer. + + Methods: + forward: Applies self-attention and cross-attention to queries and keys. + + Examples: + >>> embedding_dim, num_heads = 256, 8 + >>> block = TwoWayAttentionBlock(embedding_dim, num_heads) + >>> queries = torch.randn(1, 100, embedding_dim) + >>> keys = torch.randn(1, 1000, embedding_dim) + >>> query_pe = torch.randn(1, 100, embedding_dim) + >>> key_pe = torch.randn(1, 1000, embedding_dim) + >>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe) + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + Initializes a TwoWayAttentionBlock for simultaneous attention to image and query points. + + This block implements a specialized transformer layer with four main components: self-attention on sparse + inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention + of dense inputs to sparse inputs. + + Args: + embedding_dim (int): Channel dimension of the embeddings. + num_heads (int): Number of attention heads in the attention layers. + mlp_dim (int): Hidden dimension of the MLP block. + activation (Type[nn.Module]): Activation function for the MLP block. 
+ attention_downsample_rate (int): Downsampling rate for the attention mechanism. + skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer. + + Examples: + >>> embedding_dim, num_heads = 256, 8 + >>> block = TwoWayAttentionBlock(embedding_dim, num_heads) + >>> queries = torch.randn(1, 100, embedding_dim) + >>> keys = torch.randn(1, 1000, embedding_dim) + >>> query_pe = torch.randn(1, 100, embedding_dim) + >>> key_pe = torch.randn(1, 1000, embedding_dim) + >>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe) + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]: + """Applies two-way attention to process query and key embeddings in a transformer block.""" + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer with downscaling capability for embedding size after projection. + + This class implements a multi-head attention mechanism with the option to downsample the internal + dimension of queries, keys, and values. + + Attributes: + embedding_dim (int): Dimensionality of input embeddings. + kv_in_dim (int): Dimensionality of key and value inputs. + internal_dim (int): Internal dimension after downsampling. + num_heads (int): Number of attention heads. + q_proj (nn.Linear): Linear projection for queries. + k_proj (nn.Linear): Linear projection for keys. + v_proj (nn.Linear): Linear projection for values. + out_proj (nn.Linear): Linear projection for output. + + Methods: + _separate_heads: Separates input tensor into attention heads. + _recombine_heads: Recombines separated attention heads. + forward: Computes attention output for given query, key, and value tensors. 
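+
+    Notes:
+        When `downsample_rate` > 1, queries, keys, and values are projected to an internal dimension of
+        `embedding_dim / downsample_rate` before attention, and the output projection maps the result
+        back to `embedding_dim`.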
+ + Examples: + >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2) + >>> q = torch.randn(1, 100, 256) + >>> k = v = torch.randn(1, 50, 256) + >>> output = attn(q, k, v) + >>> print(output.shape) + torch.Size([1, 100, 256]) + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + kv_in_dim: int = None, + ) -> None: + """ + Initializes the Attention module with specified dimensions and settings. + + This class implements a multi-head attention mechanism with optional downsampling of the internal + dimension for queries, keys, and values. + + Args: + embedding_dim (int): Dimensionality of input embeddings. + num_heads (int): Number of attention heads. + downsample_rate (int): Factor by which internal dimensions are downsampled. Defaults to 1. + kv_in_dim (int | None): Dimensionality of key and value inputs. If None, uses embedding_dim. + + Raises: + AssertionError: If num_heads does not evenly divide the internal dim (embedding_dim / downsample_rate). + + Examples: + >>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2) + >>> q = torch.randn(1, 100, 256) + >>> k = v = torch.randn(1, 50, 256) + >>> output = attn(q, k, v) + >>> print(output.shape) + torch.Size([1, 100, 256]) + """ + super().__init__() + self.embedding_dim = embedding_dim + self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + @staticmethod + def _separate_heads(x: Tensor, num_heads: int) -> Tensor: + """Separates the input tensor into the specified number of attention heads.""" + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + @staticmethod + def _recombine_heads(x: Tensor) -> Tensor: + """Recombines separated attention heads into a single tensor.""" + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + """Applies multi-head attention to query, key, and value tensors with optional downsampling.""" + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + return self.out_proj(out) diff --git a/ultralytics/models/sam/modules/utils.py b/ultralytics/models/sam/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6751b87da2e2cab7892db1fb4a814ea65a9d01c7 --- /dev/null +++ b/ultralytics/models/sam/modules/utils.py @@ -0,0 +1,293 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from typing import Tuple + +import torch +import torch.nn.functional as F + + +def 
select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num): + """ + Selects the closest conditioning frames to a given frame index. + + Args: + frame_idx (int): Current frame index. + cond_frame_outputs (Dict[int, Any]): Dictionary of conditioning frame outputs keyed by frame indices. + max_cond_frame_num (int): Maximum number of conditioning frames to select. + + Returns: + (Tuple[Dict[int, Any], Dict[int, Any]]): A tuple containing two dictionaries: + - selected_outputs: Selected items from cond_frame_outputs. + - unselected_outputs: Items not selected from cond_frame_outputs. + + Examples: + >>> frame_idx = 5 + >>> cond_frame_outputs = {1: "a", 3: "b", 7: "c", 9: "d"} + >>> max_cond_frame_num = 2 + >>> selected, unselected = select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num) + >>> print(selected) + {3: 'b', 7: 'c'} + >>> print(unselected) + {1: 'a', 9: 'd'} + """ + if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num: + selected_outputs = cond_frame_outputs + unselected_outputs = {} + else: + assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames" + selected_outputs = {} + + # the closest conditioning frame before `frame_idx` (if any) + idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None) + if idx_before is not None: + selected_outputs[idx_before] = cond_frame_outputs[idx_before] + + # the closest conditioning frame after `frame_idx` (if any) + idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None) + if idx_after is not None: + selected_outputs[idx_after] = cond_frame_outputs[idx_after] + + # add other temporally closest conditioning frames until reaching a total + # of `max_cond_frame_num` conditioning frames. 
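+        # Remaining slots are filled greedily by absolute temporal distance to `frame_idx`.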
+ num_remain = max_cond_frame_num - len(selected_outputs) + inds_remain = sorted( + (t for t in cond_frame_outputs if t not in selected_outputs), + key=lambda x: abs(x - frame_idx), + )[:num_remain] + selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain) + unselected_outputs = {t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs} + + return selected_outputs, unselected_outputs + + +def get_1d_sine_pe(pos_inds, dim, temperature=10000): + """Generates 1D sinusoidal positional embeddings for given positions and dimensions.""" + pe_dim = dim // 2 + dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) + dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) + + pos_embed = pos_inds.unsqueeze(-1) / dim_t + pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) + return pos_embed + + +def init_t_xy(end_x: int, end_y: int): + """Initializes 1D and 2D coordinate tensors for a grid of specified dimensions.""" + t = torch.arange(end_x * end_y, dtype=torch.float32) + t_x = (t % end_x).float() + t_y = torch.div(t, end_x, rounding_mode="floor").float() + return t_x, t_y + + +def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): + """Computes axial complex exponential positional encodings for 2D spatial positions in a grid.""" + freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) + + t_x, t_y = init_t_xy(end_x, end_y) + freqs_x = torch.outer(t_x, freqs_x) + freqs_y = torch.outer(t_y, freqs_y) + freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) + freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) + return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + """Reshapes frequency tensor for broadcasting with input tensor, ensuring dimensional compatibility.""" + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) + shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_enc( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, + repeat_freqs_k: bool = False, +): + """Applies rotary positional encoding to query and key tensors using complex-valued frequency components.""" + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) if xk.shape[-2] != 0 else None + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + if xk_ is None: + # no keys to rotate, due to dropout + return xq_out.type_as(xq).to(xq.device), xk + # repeat freqs along seq_len dim to match k seq_len + if repeat_freqs_k: + r = xk_.shape[-2] // xq_.shape[-2] + freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) + + +def window_partition(x, window_size): + """ + Partitions input tensor into non-overlapping windows with padding if needed. + + Args: + x (torch.Tensor): Input tensor with shape (B, H, W, C). + window_size (int): Size of each window. + + Returns: + (Tuple[torch.Tensor, Tuple[int, int]]): A tuple containing: + - windows (torch.Tensor): Partitioned windows with shape (B * num_windows, window_size, window_size, C). 
+ - (Hp, Wp) (Tuple[int, int]): Padded height and width before partition. + + Examples: + >>> x = torch.randn(1, 16, 16, 3) + >>> windows, (Hp, Wp) = window_partition(x, window_size=4) + >>> print(windows.shape, Hp, Wp) + torch.Size([16, 4, 4, 3]) 16 16 + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition(windows, window_size, pad_hw, hw): + """ + Unpartitions windowed sequences into original sequences and removes padding. + + This function reverses the windowing process, reconstructing the original input from windowed segments + and removing any padding that was added during the windowing process. + + Args: + windows (torch.Tensor): Input tensor of windowed sequences with shape (B * num_windows, window_size, + window_size, C), where B is the batch size, num_windows is the number of windows, window_size is + the size of each window, and C is the number of channels. + window_size (int): Size of each window. + pad_hw (Tuple[int, int]): Padded height and width (Hp, Wp) of the input before windowing. + hw (Tuple[int, int]): Original height and width (H, W) of the input before padding and windowing. + + Returns: + (torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W + are the original height and width, and C is the number of channels. + + Examples: + >>> windows = torch.rand(32, 8, 8, 64) # 32 windows of size 8x8 with 64 channels + >>> pad_hw = (16, 16) # Padded height and width + >>> hw = (15, 14) # Original height and width + >>> x = window_unpartition(windows, window_size=8, pad_hw=pad_hw, hw=hw) + >>> print(x.shape) + torch.Size([1, 15, 14, 64]) + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Extracts relative positional embeddings based on query and key sizes. + + Args: + q_size (int): Size of the query. + k_size (int): Size of the key. + rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative + distance and C is the embedding dimension. + + Returns: + (torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size, + k_size, C). + + Examples: + >>> q_size, k_size = 8, 16 + >>> rel_pos = torch.randn(31, 64) # 31 = 2 * max(8, 16) - 1 + >>> extracted_pos = get_rel_pos(q_size, k_size, rel_pos) + >>> print(extracted_pos.shape) + torch.Size([8, 16, 64]) + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. 
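+        # Resample the (L, C) table along its first axis to `max_rel_dist` entries via linear interpolation.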
+ rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Adds decomposed Relative Positional Embeddings to the attention map. + + This function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2 + paper. It enhances the attention mechanism by incorporating spatial relationships between query and key + positions. + + Args: + attn (torch.Tensor): Attention map with shape (B, q_h * q_w, k_h * k_w). + q (torch.Tensor): Query tensor in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (torch.Tensor): Relative position embeddings for height axis with shape (Lh, C). + rel_pos_w (torch.Tensor): Relative position embeddings for width axis with shape (Lw, C). + q_size (Tuple[int, int]): Spatial sequence size of query q as (q_h, q_w). + k_size (Tuple[int, int]): Spatial sequence size of key k as (k_h, k_w). + + Returns: + (torch.Tensor): Updated attention map with added relative positional embeddings, shape + (B, q_h * q_w, k_h * k_w). + + Examples: + >>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8 + >>> attn = torch.rand(B, q_h * q_w, k_h * k_w) + >>> q = torch.rand(B, q_h * q_w, C) + >>> rel_pos_h = torch.rand(2 * max(q_h, k_h) - 1, C) + >>> rel_pos_w = torch.rand(2 * max(q_w, k_w) - 1, C) + >>> q_size, k_size = (q_h, q_w), (k_h, k_w) + >>> updated_attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size) + >>> print(updated_attn.shape) + torch.Size([1, 64, 64]) + + References: + https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view( + B, q_h * q_w, k_h * k_w + ) + + return attn diff --git a/ultralytics/models/sam/predict.py b/ultralytics/models/sam/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..345fc7c98fe45b38c073967ea7b7cf1af3921755 --- /dev/null +++ b/ultralytics/models/sam/predict.py @@ -0,0 +1,1605 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +""" +Generate predictions using the Segment Anything Model (SAM). + +SAM is an advanced image segmentation model offering features like promptable segmentation and zero-shot performance. +This module contains the implementation of the prediction logic and auxiliary utilities required to perform segmentation +using SAM. It forms an integral part of the Ultralytics framework and is designed for high-performance, real-time image +segmentation tasks. 
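+
+Examples:
+    A minimal sketch of prompt-based prediction (the weight file, image path, and box prompt below are
+    illustrative placeholders, and `Predictor` is assumed to be exported by this package):
+
+    >>> from ultralytics.models.sam import Predictor
+    >>> predictor = Predictor(overrides={"model": "sam_b.pt"})
+    >>> predictor.set_image("path/to/image.jpg")
+    >>> results = predictor(bboxes=[[100, 100, 200, 200]])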
+""" + +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn.functional as F + +from ultralytics.data.augment import LetterBox +from ultralytics.engine.predictor import BasePredictor +from ultralytics.engine.results import Results +from ultralytics.utils import DEFAULT_CFG, ops +from ultralytics.utils.torch_utils import select_device, smart_inference_mode + +from .amg import ( + batch_iterator, + batched_mask_to_box, + build_all_layer_point_grids, + calculate_stability_score, + generate_crop_boxes, + is_box_near_crop_edge, + remove_small_regions, + uncrop_boxes_xyxy, + uncrop_masks, +) +from .build import build_sam + + +class Predictor(BasePredictor): + """ + Predictor class for SAM, enabling real-time image segmentation with promptable capabilities. + + This class extends BasePredictor and implements the Segment Anything Model (SAM) for advanced image + segmentation tasks. It supports various input prompts like points, bounding boxes, and masks for + fine-grained control over segmentation results. + + Attributes: + args (SimpleNamespace): Configuration arguments for the predictor. + model (torch.nn.Module): The loaded SAM model. + device (torch.device): The device (CPU or GPU) on which the model is loaded. + im (torch.Tensor): The preprocessed input image. + features (torch.Tensor): Extracted image features. + prompts (Dict): Dictionary to store various types of prompts (e.g., bboxes, points, masks). + segment_all (bool): Flag to indicate if full image segmentation should be performed. + mean (torch.Tensor): Mean values for image normalization. + std (torch.Tensor): Standard deviation values for image normalization. + + Methods: + preprocess: Prepares input images for model inference. + pre_transform: Performs initial transformations on the input image. + inference: Performs segmentation inference based on input prompts. + prompt_inference: Internal function for prompt-based segmentation inference. + generate: Generates segmentation masks for an entire image. + setup_model: Initializes the SAM model for inference. + get_model: Builds and returns a SAM model. + postprocess: Post-processes model outputs to generate final results. + setup_source: Sets up the data source for inference. + set_image: Sets and preprocesses a single image for inference. + get_im_features: Extracts image features using the SAM image encoder. + set_prompts: Sets prompts for subsequent inference. + reset_image: Resets the current image and its features. + remove_small_regions: Removes small disconnected regions and holes from masks. + + Examples: + >>> predictor = Predictor() + >>> predictor.setup_model(model_path="sam_model.pt") + >>> predictor.set_image("image.jpg") + >>> bboxes = [[100, 100, 200, 200]] + >>> results = predictor(bboxes=bboxes) + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """ + Initialize the Predictor with configuration, overrides, and callbacks. + + Sets up the Predictor object for SAM (Segment Anything Model) and applies any configuration overrides or + callbacks provided. Initializes task-specific settings for SAM, such as retina_masks being set to True + for optimal results. + + Args: + cfg (Dict): Configuration dictionary containing default settings. + overrides (Dict | None): Dictionary of values to override default configuration. + _callbacks (Dict | None): Dictionary of callback functions to customize behavior. 
+ + Examples: + >>> predictor_example = Predictor(cfg=DEFAULT_CFG) + >>> predictor_example_with_imgsz = Predictor(overrides={"imgsz": 640}) + >>> predictor_example_with_callback = Predictor(_callbacks={"on_predict_start": custom_callback}) + """ + if overrides is None: + overrides = {} + overrides.update(dict(task="segment", mode="predict", batch=1)) + super().__init__(cfg, overrides, _callbacks) + self.args.retina_masks = True + self.im = None + self.features = None + self.prompts = {} + self.segment_all = False + + def preprocess(self, im): + """ + Preprocess the input image for model inference. + + This method prepares the input image by applying transformations and normalization. It supports both + torch.Tensor and list of np.ndarray as input formats. + + Args: + im (torch.Tensor | List[np.ndarray]): Input image(s) in BCHW tensor format or list of HWC numpy arrays. + + Returns: + im (torch.Tensor): The preprocessed image tensor, normalized and converted to the appropriate dtype. + + Examples: + >>> predictor = Predictor() + >>> image = torch.rand(1, 3, 640, 640) + >>> preprocessed_image = predictor.preprocess(image) + """ + if self.im is not None: + return self.im + not_tensor = not isinstance(im, torch.Tensor) + if not_tensor: + im = np.stack(self.pre_transform(im)) + im = im[..., ::-1].transpose((0, 3, 1, 2)) + im = np.ascontiguousarray(im) + im = torch.from_numpy(im) + + im = im.to(self.device) + im = im.half() if self.model.fp16 else im.float() + if not_tensor: + im = (im - self.mean) / self.std + return im + + def pre_transform(self, im): + """ + Perform initial transformations on the input image for preprocessing. + + This method applies transformations such as resizing to prepare the image for further preprocessing. + Currently, batched inference is not supported; hence the list length should be 1. + + Args: + im (List[np.ndarray]): List containing a single image in HWC numpy array format. + + Returns: + (List[np.ndarray]): List containing the transformed image. + + Raises: + AssertionError: If the input list contains more than one image. + + Examples: + >>> predictor = Predictor() + >>> image = np.random.rand(480, 640, 3) # Single HWC image + >>> transformed = predictor.pre_transform([image]) + >>> print(len(transformed)) + 1 + """ + assert len(im) == 1, "SAM model does not currently support batched inference" + letterbox = LetterBox(self.args.imgsz, auto=False, center=False) + return [letterbox(image=x) for x in im] + + def inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False, *args, **kwargs): + """ + Perform image segmentation inference based on the given input cues, using the currently loaded image. + + This method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt + encoder, and mask decoder for real-time and promptable segmentation tasks. + + Args: + im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W). + bboxes (np.ndarray | List | None): Bounding boxes with shape (N, 4), in XYXY format. + points (np.ndarray | List | None): Points indicating object locations with shape (N, 2), in pixels. + labels (np.ndarray | List | None): Labels for point prompts, shape (N,). 1 = foreground, 0 = background. + masks (np.ndarray | None): Low-resolution masks from previous predictions, shape (N, H, W). For SAM H=W=256. + multimask_output (bool): Flag to return multiple masks. Helpful for ambiguous prompts. + *args (Any): Additional positional arguments. 
+ **kwargs (Any): Additional keyword arguments. + + Returns: + (np.ndarray): The output masks in shape (C, H, W), where C is the number of generated masks. + (np.ndarray): An array of length C containing quality scores predicted by the model for each mask. + (np.ndarray): Low-resolution logits of shape (C, H, W) for subsequent inference, where H=W=256. + + Examples: + >>> predictor = Predictor() + >>> predictor.setup_model(model_path="sam_model.pt") + >>> predictor.set_image("image.jpg") + >>> results = predictor(bboxes=[[0, 0, 100, 100]]) + """ + # Override prompts if any stored in self.prompts + bboxes = self.prompts.pop("bboxes", bboxes) + points = self.prompts.pop("points", points) + masks = self.prompts.pop("masks", masks) + labels = self.prompts.pop("labels", labels) + + if all(i is None for i in [bboxes, points, masks]): + return self.generate(im, *args, **kwargs) + + return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output) + + def prompt_inference(self, im, bboxes=None, points=None, labels=None, masks=None, multimask_output=False): + """ + Performs image segmentation inference based on input cues using SAM's specialized architecture. + + This internal function leverages the Segment Anything Model (SAM) for prompt-based, real-time segmentation. + It processes various input prompts such as bounding boxes, points, and masks to generate segmentation masks. + + Args: + im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W). + bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4). + points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels. + labels (np.ndarray | List | None): Point prompt labels with shape (N) or (N, num_points). 1 for foreground, 0 for background. + masks (np.ndarray | None): Low-res masks from previous predictions with shape (N, H, W). For SAM, H=W=256. + multimask_output (bool): Flag to return multiple masks for ambiguous prompts. + + Raises: + AssertionError: If the number of points don't match the number of labels, in case labels were passed. + + Returns: + (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks. + (np.ndarray): Quality scores predicted by the model for each mask, with length C. + + Examples: + >>> predictor = Predictor() + >>> im = torch.rand(1, 3, 1024, 1024) + >>> bboxes = [[100, 100, 200, 200]] + >>> masks, scores, logits = predictor.prompt_inference(im, bboxes=bboxes) + """ + features = self.get_im_features(im) if self.features is None else self.features + + bboxes, points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks) + points = (points, labels) if points is not None else None + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder(points=points, boxes=bboxes, masks=masks) + + # Predict masks + pred_masks, pred_scores = self.model.mask_decoder( + image_embeddings=features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, ) + # `d` could be 1 or 3 depends on `multimask_output`. + return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1) + + def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None): + """ + Prepares and transforms the input prompts for processing based on the destination shape. 
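+
+        Prompts are expected in the pixel coordinates of the original image; box and point coordinates are
+        rescaled by the letterbox ratio so they align with the preprocessed input of shape `dst_shape`
+        (no rescaling is applied in segment-all mode).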
+ + Args: + dst_shape (tuple): The target shape (height, width) for the prompts. + bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4). + points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels. + labels (np.ndarray | List | None): Point prompt labels with shape (N) or (N, num_points). 1 for foreground, 0 for background. + masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array. + + Raises: + AssertionError: If the number of points don't match the number of labels, in case labels were passed. + + Returns: + (tuple): A tuple containing transformed bounding boxes, points, labels, and masks. + """ + src_shape = self.batch[1][0].shape[:2] + r = 1.0 if self.segment_all else min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1]) + # Transform input prompts + if points is not None: + points = torch.as_tensor(points, dtype=torch.float32, device=self.device) + points = points[None] if points.ndim == 1 else points + # Assuming labels are all positive if users don't pass labels. + if labels is None: + labels = np.ones(points.shape[:-1]) + labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device) + assert points.shape[-2] == labels.shape[-1], ( + f"Number of points {points.shape[-2]} should match number of labels {labels.shape[-1]}." + ) + points *= r + if points.ndim == 2: + # (N, 2) --> (N, 1, 2), (N, ) --> (N, 1) + points, labels = points[:, None, :], labels[:, None] + if bboxes is not None: + bboxes = torch.as_tensor(bboxes, dtype=torch.float32, device=self.device) + bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes + bboxes *= r + if masks is not None: + masks = torch.as_tensor(masks, dtype=torch.float32, device=self.device).unsqueeze(1) + return bboxes, points, labels, masks + + def generate( + self, + im, + crop_n_layers=0, + crop_overlap_ratio=512 / 1500, + crop_downscale_factor=1, + point_grids=None, + points_stride=32, + points_batch_size=64, + conf_thres=0.88, + stability_score_thresh=0.95, + stability_score_offset=0.95, + crop_nms_thresh=0.7, + ): + """ + Perform image segmentation using the Segment Anything Model (SAM). + + This method segments an entire image into constituent parts by leveraging SAM's advanced architecture + and real-time performance capabilities. It can optionally work on image crops for finer segmentation. + + Args: + im (torch.Tensor): Input tensor representing the preprocessed image with shape (N, C, H, W). + crop_n_layers (int): Number of layers for additional mask predictions on image crops. + crop_overlap_ratio (float): Overlap between crops, scaled down in subsequent layers. + crop_downscale_factor (int): Scaling factor for sampled points-per-side in each layer. + point_grids (List[np.ndarray] | None): Custom grids for point sampling normalized to [0,1]. + points_stride (int): Number of points to sample along each side of the image. + points_batch_size (int): Batch size for the number of points processed simultaneously. + conf_thres (float): Confidence threshold [0,1] for filtering based on mask quality prediction. + stability_score_thresh (float): Stability threshold [0,1] for mask filtering based on stability. + stability_score_offset (float): Offset value for calculating stability score. + crop_nms_thresh (float): IoU cutoff for NMS to remove duplicate masks between crops. + + Returns: + pred_masks (torch.Tensor): Segmented masks with shape (N, H, W). 
+ pred_scores (torch.Tensor): Confidence scores for each mask with shape (N,). + pred_bboxes (torch.Tensor): Bounding boxes for each mask with shape (N, 4). + + Examples: + >>> predictor = Predictor() + >>> im = torch.rand(1, 3, 1024, 1024) # Example input image + >>> masks, scores, boxes = predictor.generate(im) + """ + import torchvision # scope for faster 'import ultralytics' + + self.segment_all = True + ih, iw = im.shape[2:] + crop_regions, layer_idxs = generate_crop_boxes((ih, iw), crop_n_layers, crop_overlap_ratio) + if point_grids is None: + point_grids = build_all_layer_point_grids(points_stride, crop_n_layers, crop_downscale_factor) + pred_masks, pred_scores, pred_bboxes, region_areas = [], [], [], [] + for crop_region, layer_idx in zip(crop_regions, layer_idxs): + x1, y1, x2, y2 = crop_region + w, h = x2 - x1, y2 - y1 + area = torch.tensor(w * h, device=im.device) + points_scale = np.array([[w, h]]) # w, h + # Crop image and interpolate to input size + crop_im = F.interpolate(im[..., y1:y2, x1:x2], (ih, iw), mode="bilinear", align_corners=False) + # (num_points, 2) + points_for_image = point_grids[layer_idx] * points_scale + crop_masks, crop_scores, crop_bboxes = [], [], [] + for (points,) in batch_iterator(points_batch_size, points_for_image): + pred_mask, pred_score = self.prompt_inference(crop_im, points=points, multimask_output=True) + # Interpolate predicted masks to input size + pred_mask = F.interpolate(pred_mask[None], (h, w), mode="bilinear", align_corners=False)[0] + idx = pred_score > conf_thres + pred_mask, pred_score = pred_mask[idx], pred_score[idx] + + stability_score = calculate_stability_score( + pred_mask, self.model.mask_threshold, stability_score_offset + ) + idx = stability_score > stability_score_thresh + pred_mask, pred_score = pred_mask[idx], pred_score[idx] + # Bool type is much more memory-efficient. + pred_mask = pred_mask > self.model.mask_threshold + # (N, 4) + pred_bbox = batched_mask_to_box(pred_mask).float() + keep_mask = ~is_box_near_crop_edge(pred_bbox, crop_region, [0, 0, iw, ih]) + if not torch.all(keep_mask): + pred_bbox, pred_mask, pred_score = pred_bbox[keep_mask], pred_mask[keep_mask], pred_score[keep_mask] + + crop_masks.append(pred_mask) + crop_bboxes.append(pred_bbox) + crop_scores.append(pred_score) + + # Do nms within this crop + crop_masks = torch.cat(crop_masks) + crop_bboxes = torch.cat(crop_bboxes) + crop_scores = torch.cat(crop_scores) + keep = torchvision.ops.nms(crop_bboxes, crop_scores, self.args.iou) # NMS + crop_bboxes = uncrop_boxes_xyxy(crop_bboxes[keep], crop_region) + crop_masks = uncrop_masks(crop_masks[keep], crop_region, ih, iw) + crop_scores = crop_scores[keep] + + pred_masks.append(crop_masks) + pred_bboxes.append(crop_bboxes) + pred_scores.append(crop_scores) + region_areas.append(area.expand(len(crop_masks))) + + pred_masks = torch.cat(pred_masks) + pred_bboxes = torch.cat(pred_bboxes) + pred_scores = torch.cat(pred_scores) + region_areas = torch.cat(region_areas) + + # Remove duplicate masks between crops + if len(crop_regions) > 1: + scores = 1 / region_areas + keep = torchvision.ops.nms(pred_bboxes, scores, crop_nms_thresh) + pred_masks, pred_bboxes, pred_scores = pred_masks[keep], pred_bboxes[keep], pred_scores[keep] + + return pred_masks, pred_scores, pred_bboxes + + def setup_model(self, model=None, verbose=True): + """ + Initializes the Segment Anything Model (SAM) for inference. 
+ + This method sets up the SAM model by allocating it to the appropriate device and initializing the necessary + parameters for image normalization and other Ultralytics compatibility settings. + + Args: + model (torch.nn.Module | None): A pretrained SAM model. If None, a new model is built based on config. + verbose (bool): If True, prints selected device information. + + Examples: + >>> predictor = Predictor() + >>> predictor.setup_model(model=sam_model, verbose=True) + """ + device = select_device(self.args.device, verbose=verbose) + if model is None: + model = self.get_model() + model.eval() + self.model = model.to(device) + self.device = device + self.mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1).to(device) + self.std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1).to(device) + + # Ultralytics compatibility settings + self.model.pt = False + self.model.triton = False + self.model.stride = 32 + self.model.fp16 = False + self.done_warmup = True + + def get_model(self): + """Retrieves or builds the Segment Anything Model (SAM) for image segmentation tasks.""" + return build_sam(self.args.model) + + def postprocess(self, preds, img, orig_imgs): + """ + Post-processes SAM's inference outputs to generate object detection masks and bounding boxes. + + This method scales masks and boxes to the original image size and applies a threshold to the mask + predictions. It leverages SAM's advanced architecture for real-time, promptable segmentation tasks. + + Args: + preds (Tuple[torch.Tensor]): The output from SAM model inference, containing: + - pred_masks (torch.Tensor): Predicted masks with shape (N, 1, H, W). + - pred_scores (torch.Tensor): Confidence scores for each mask with shape (N, 1). + - pred_bboxes (torch.Tensor, optional): Predicted bounding boxes if segment_all is True. + img (torch.Tensor): The processed input image tensor with shape (C, H, W). + orig_imgs (List[np.ndarray] | torch.Tensor): The original, unprocessed images. + + Returns: + results (List[Results]): List of Results objects containing detection masks, bounding boxes, and other + metadata for each processed image. + + Examples: + >>> predictor = Predictor() + >>> preds = predictor.inference(img) + >>> results = predictor.postprocess(preds, img, orig_imgs) + """ + # (N, 1, H, W), (N, 1) + pred_masks, pred_scores = preds[:2] + pred_bboxes = preds[2] if self.segment_all else None + names = dict(enumerate(str(i) for i in range(len(pred_masks)))) + + if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list + orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) + + results = [] + for masks, orig_img, img_path in zip([pred_masks], orig_imgs, self.batch[0]): + if len(masks) == 0: + masks, pred_bboxes = None, torch.zeros((0, 6), device=pred_masks.device) + else: + masks = ops.scale_masks(masks[None].float(), orig_img.shape[:2], padding=False)[0] + masks = masks > self.model.mask_threshold # to bool + if pred_bboxes is not None: + pred_bboxes = ops.scale_boxes(img.shape[2:], pred_bboxes.float(), orig_img.shape, padding=False) + else: + pred_bboxes = batched_mask_to_box(masks) + # NOTE: SAM models do not return cls info. This `cls` here is just a placeholder for consistency. + cls = torch.arange(len(pred_masks), dtype=torch.int32, device=pred_masks.device) + pred_bboxes = torch.cat([pred_bboxes, pred_scores[:, None], cls[:, None]], dim=-1) + results.append(Results(orig_img, path=img_path, names=names, masks=masks, boxes=pred_bboxes)) + # Reset segment-all mode. 
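+        # (segment-all only applies to the current `generate()` call, not to subsequent prompted predictions)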
+ self.segment_all = False + return results + + def setup_source(self, source): + """ + Sets up the data source for inference. + + This method configures the data source from which images will be fetched for inference. It supports + various input types such as image files, directories, video files, and other compatible data sources. + + Args: + source (str | Path | None): The path or identifier for the image data source. Can be a file path, + directory path, URL, or other supported source types. + + Examples: + >>> predictor = Predictor() + >>> predictor.setup_source("path/to/images") + >>> predictor.setup_source("video.mp4") + >>> predictor.setup_source(None) # Uses default source if available + + Notes: + - If source is None, the method may use a default source if configured. + - The method adapts to different source types and prepares them for subsequent inference steps. + - Supported source types may include local files, directories, URLs, and video streams. + """ + if source is not None: + super().setup_source(source) + + def set_image(self, image): + """ + Preprocesses and sets a single image for inference. + + This method prepares the model for inference on a single image by setting up the model if not already + initialized, configuring the data source, and preprocessing the image for feature extraction. It + ensures that only one image is set at a time and extracts image features for subsequent use. + + Args: + image (str | np.ndarray): Path to the image file as a string, or a numpy array representing + an image read by cv2. + + Raises: + AssertionError: If more than one image is attempted to be set. + + Examples: + >>> predictor = Predictor() + >>> predictor.set_image("path/to/image.jpg") + >>> predictor.set_image(cv2.imread("path/to/image.jpg")) + + Notes: + - This method should be called before performing inference on a new image. + - The extracted features are stored in the `self.features` attribute for later use. + """ + if self.model is None: + self.setup_model(model=None) + self.setup_source(image) + assert len(self.dataset) == 1, "`set_image` only supports setting one image!" + for batch in self.dataset: + im = self.preprocess(batch[1]) + self.features = self.get_im_features(im) + break + + def get_im_features(self, im): + """Extracts image features using the SAM model's image encoder for subsequent mask prediction.""" + assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], ( + f"SAM models only support square image size, but got {self.imgsz}." + ) + self.model.set_imgsz(self.imgsz) + return self.model.image_encoder(im) + + def set_prompts(self, prompts): + """Sets prompts for subsequent inference operations.""" + self.prompts = prompts + + def reset_image(self): + """Resets the current image and its features, clearing them for subsequent inference.""" + self.im = None + self.features = None + + @staticmethod + def remove_small_regions(masks, min_area=0, nms_thresh=0.7): + """ + Remove small disconnected regions and holes from segmentation masks. + + This function performs post-processing on segmentation masks generated by the Segment Anything Model (SAM). + It removes small disconnected regions and holes from the input masks, and then performs Non-Maximum + Suppression (NMS) to eliminate any newly created duplicate boxes. + + Args: + masks (torch.Tensor): Segmentation masks to be processed, with shape (N, H, W) where N is the number of + masks, H is height, and W is width. 
+ min_area (int): Minimum area threshold for removing disconnected regions and holes. Regions smaller than + this will be removed. + nms_thresh (float): IoU threshold for the NMS algorithm to remove duplicate boxes. + + Returns: + new_masks (torch.Tensor): Processed masks with small regions removed, shape (N, H, W). + keep (List[int]): Indices of remaining masks after NMS, for filtering corresponding boxes. + + Examples: + >>> masks = torch.rand(5, 640, 640) > 0.5 # 5 random binary masks + >>> new_masks, keep = remove_small_regions(masks, min_area=100, nms_thresh=0.7) + >>> print(f"Original masks: {masks.shape}, Processed masks: {new_masks.shape}") + >>> print(f"Indices of kept masks: {keep}") + """ + import torchvision # scope for faster 'import ultralytics' + + if len(masks) == 0: + return masks + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for mask in masks: + mask = mask.cpu().numpy().astype(np.uint8) + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and 1 to unchanged masks so NMS prefers masks not needing postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + new_masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(new_masks) + keep = torchvision.ops.nms(boxes.float(), torch.as_tensor(scores), nms_thresh) + + return new_masks[keep].to(device=masks.device, dtype=masks.dtype), keep + + +class SAM2Predictor(Predictor): + """ + SAM2Predictor class for advanced image segmentation using Segment Anything Model 2 architecture. + + This class extends the base Predictor class to implement SAM2-specific functionality for image + segmentation tasks. It provides methods for model initialization, feature extraction, and + prompt-based inference. + + Attributes: + _bb_feat_sizes (List[Tuple[int, int]]): Feature sizes for different backbone levels. + model (torch.nn.Module): The loaded SAM2 model. + device (torch.device): The device (CPU or GPU) on which the model is loaded. + features (Dict[str, torch.Tensor]): Cached image features for efficient inference. + segment_all (bool): Flag to indicate if all segments should be predicted. + prompts (Dict): Dictionary to store various types of prompts for inference. + + Methods: + get_model: Retrieves and initializes the SAM2 model. + prompt_inference: Performs image segmentation inference based on various prompts. + set_image: Preprocesses and sets a single image for inference. + get_im_features: Extracts and processes image features using SAM2's image encoder. + + Examples: + >>> predictor = SAM2Predictor(cfg) + >>> predictor.set_image("path/to/image.jpg") + >>> bboxes = [[100, 100, 200, 200]] + >>> result = predictor(bboxes=bboxes)[0] + >>> print(f"Predicted {len(result.masks)} masks with average score {result.boxes.conf.mean():.2f}") + """ + + _bb_feat_sizes = [ + (256, 256), + (128, 128), + (64, 64), + ] + + def get_model(self): + """Retrieves and initializes the Segment Anything Model 2 (SAM2) for image segmentation tasks.""" + return build_sam(self.args.model) + + def prompt_inference( + self, + im, + bboxes=None, + points=None, + labels=None, + masks=None, + multimask_output=False, + img_idx=-1, + ): + """ + Performs image segmentation inference based on various prompts using SAM2 architecture. 
+ + This method leverages the Segment Anything Model 2 (SAM2) to generate segmentation masks for input images + based on provided prompts such as bounding boxes, points, or existing masks. It supports both single and + multi-object prediction scenarios. + + Args: + im (torch.Tensor): Preprocessed input image tensor with shape (N, C, H, W). + bboxes (np.ndarray | List[List[float]] | None): Bounding boxes in XYXY format with shape (N, 4). + points (np.ndarray | List[List[float]] | None): Object location points with shape (N, 2), in pixels. + labels (np.ndarray | List[int] | None): Point prompt labels with shape (N,). 1 = foreground, 0 = background. + masks (np.ndarray | None): Low-resolution masks from previous predictions with shape (N, H, W). + multimask_output (bool): Flag to return multiple masks for ambiguous prompts. + img_idx (int): Index of the image in the batch to process. + + Returns: + (np.ndarray): Output masks with shape (C, H, W), where C is the number of generated masks. + (np.ndarray): Quality scores for each mask, with length C. + + Examples: + >>> predictor = SAM2Predictor(cfg) + >>> image = torch.rand(1, 3, 640, 640) + >>> bboxes = [[100, 100, 200, 200]] + >>> result = predictor(image, bboxes=bboxes)[0] + >>> print(f"Generated {result.masks.shape[0]} masks with average score {result.boxes.conf.mean():.2f}") + + Notes: + - The method supports batched inference for multiple objects when points or bboxes are provided. + - Input prompts (bboxes, points) are automatically scaled to match the input image dimensions. + - When both bboxes and points are provided, they are merged into a single 'points' input for the model. + + References: + - SAM2 Paper: [Add link to SAM2 paper when available] + """ + features = self.get_im_features(im) if self.features is None else self.features + + points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks) + points = (points, labels) if points is not None else None + + sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( + points=points, + boxes=None, + masks=masks, + ) + # Predict masks + batched_mode = points is not None and points[0].shape[0] > 1 # multi object prediction + high_res_features = [feat_level[img_idx].unsqueeze(0) for feat_level in features["high_res_feats"]] + pred_masks, pred_scores, _, _ = self.model.sam_mask_decoder( + image_embeddings=features["image_embed"][img_idx].unsqueeze(0), + image_pe=self.model.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=batched_mode, + high_res_features=high_res_features, + ) + # (N, d, H, W) --> (N*d, H, W), (N, d) --> (N*d, ) + # `d` could be 1 or 3 depends on `multimask_output`. + return pred_masks.flatten(0, 1), pred_scores.flatten(0, 1) + + def _prepare_prompts(self, dst_shape, bboxes=None, points=None, labels=None, masks=None): + """ + Prepares and transforms the input prompts for processing based on the destination shape. + + Args: + dst_shape (tuple): The target shape (height, width) for the prompts. + bboxes (np.ndarray | List | None): Bounding boxes in XYXY format with shape (N, 4). + points (np.ndarray | List | None): Points indicating object locations with shape (N, 2) or (N, num_points, 2), in pixels. + labels (np.ndarray | List | None): Point prompt labels with shape (N,) or (N, num_points). 1 for foreground, 0 for background. 
+ masks (List | np.ndarray, Optional): Masks for the objects, where each mask is a 2D array. + + Raises: + AssertionError: If the number of points don't match the number of labels, in case labels were passed. + + Returns: + (tuple): A tuple containing transformed points, labels, and masks. + """ + bboxes, points, labels, masks = super()._prepare_prompts(dst_shape, bboxes, points, labels, masks) + if bboxes is not None: + bboxes = bboxes.view(-1, 2, 2) + bbox_labels = torch.tensor([[2, 3]], dtype=torch.int32, device=bboxes.device).expand(len(bboxes), -1) + # NOTE: merge "boxes" and "points" into a single "points" input + # (where boxes are added at the beginning) to model.sam_prompt_encoder + if points is not None: + points = torch.cat([bboxes, points], dim=1) + labels = torch.cat([bbox_labels, labels], dim=1) + else: + points, labels = bboxes, bbox_labels + return points, labels, masks + + def set_image(self, image): + """ + Preprocesses and sets a single image for inference using the SAM2 model. + + This method initializes the model if not already done, configures the data source to the specified image, + and preprocesses the image for feature extraction. It supports setting only one image at a time. + + Args: + image (str | np.ndarray): Path to the image file as a string, or a numpy array representing the image. + + Raises: + AssertionError: If more than one image is attempted to be set. + + Examples: + >>> predictor = SAM2Predictor() + >>> predictor.set_image("path/to/image.jpg") + >>> predictor.set_image(np.array([...])) # Using a numpy array + + Notes: + - This method must be called before performing any inference on a new image. + - The method caches the extracted features for efficient subsequent inferences on the same image. + - Only one image can be set at a time. To process multiple images, call this method for each new image. + """ + if self.model is None: + self.setup_model(model=None) + self.setup_source(image) + assert len(self.dataset) == 1, "`set_image` only supports setting one image!" + for batch in self.dataset: + im = self.preprocess(batch[1]) + self.features = self.get_im_features(im) + break + + def get_im_features(self, im): + """Extracts image features from the SAM image encoder for subsequent processing.""" + assert isinstance(self.imgsz, (tuple, list)) and self.imgsz[0] == self.imgsz[1], ( + f"SAM 2 models only support square image size, but got {self.imgsz}." + ) + self.model.set_imgsz(self.imgsz) + self._bb_feat_sizes = [[x // (4 * i) for x in self.imgsz] for i in [1, 2, 4]] + + backbone_out = self.model.forward_image(im) + _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) + if self.model.directly_add_no_mem_embed: + vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + feats = [ + feat.permute(1, 2, 0).view(1, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) + ][::-1] + return {"image_embed": feats[-1], "high_res_feats": feats[:-1]} + + +class SAM2VideoPredictor(SAM2Predictor): + """ + SAM2VideoPredictor to handle user interactions with videos and manage inference states. + + This class extends the functionality of SAM2Predictor to support video processing and maintains + the state of inference operations. It includes configurations for managing non-overlapping masks, + clearing memory for non-conditional inputs, and setting up callbacks for prediction events. + + Attributes: + inference_state (Dict): A dictionary to store the current state of inference operations. 
+ non_overlap_masks (bool): A flag indicating whether masks should be non-overlapping. + clear_non_cond_mem_around_input (bool): A flag to control clearing non-conditional memory around inputs. + clear_non_cond_mem_for_multi_obj (bool): A flag to control clearing non-conditional memory for multi-object scenarios. + callbacks (Dict): A dictionary of callbacks for various prediction lifecycle events. + + Args: + cfg (Dict, Optional): Configuration settings for the predictor. Defaults to DEFAULT_CFG. + overrides (Dict, Optional): Additional configuration overrides. Defaults to None. + _callbacks (List, Optional): Custom callbacks to be added. Defaults to None. + + Note: + The `fill_hole_area` attribute is defined but not used in the current implementation. + """ + + # fill_hole_area = 8 # not used + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """ + Initialize the predictor with configuration and optional overrides. + + This constructor initializes the SAM2VideoPredictor with a given configuration, applies any + specified overrides, and sets up the inference state along with certain flags + that control the behavior of the predictor. + + Args: + cfg (Dict): Configuration dictionary containing default settings. + overrides (Dict | None): Dictionary of values to override default configuration. + _callbacks (Dict | None): Dictionary of callback functions to customize behavior. + + Examples: + >>> predictor = SAM2VideoPredictor(cfg=DEFAULT_CFG) + >>> predictor_example_with_imgsz = SAM2VideoPredictor(overrides={"imgsz": 640}) + >>> predictor_example_with_callback = SAM2VideoPredictor(_callbacks={"on_predict_start": custom_callback}) + """ + super().__init__(cfg, overrides, _callbacks) + self.inference_state = {} + self.non_overlap_masks = True + self.clear_non_cond_mem_around_input = False + self.clear_non_cond_mem_for_multi_obj = False + self.callbacks["on_predict_start"].append(self.init_state) + + def get_model(self): + """ + Retrieves and configures the model with binarization enabled. + + Note: + This method overrides the base class implementation to set the binarize flag to True. + """ + model = super().get_model() + model.set_binarize(True) + return model + + def inference(self, im, bboxes=None, points=None, labels=None, masks=None): + """ + Perform image segmentation inference based on the given input cues, using the currently loaded image. This + method leverages SAM's (Segment Anything Model) architecture consisting of image encoder, prompt encoder, and + mask decoder for real-time and promptable segmentation tasks. + + Args: + im (torch.Tensor): The preprocessed input image in tensor format, with shape (N, C, H, W). + bboxes (np.ndarray | List, optional): Bounding boxes with shape (N, 4), in XYXY format. + points (np.ndarray | List, optional): Points indicating object locations with shape (N, 2), in pixels. + labels (np.ndarray | List, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background. + masks (np.ndarray, optional): Low-resolution masks from previous predictions shape (N,H,W). For SAM H=W=256. + + Returns: + (np.ndarray): The output masks in shape CxHxW, where C is the number of generated masks. + (np.ndarray): An array of length C containing quality scores predicted by the model for each mask. 
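+
+        Examples:
+            >>> # Illustrative sketch only: the weights file, video path, and box prompt are placeholders, and the
+            >>> # call pattern assumes the standard SAM prompt-passing interface used elsewhere in this module.
+            >>> predictor = SAM2VideoPredictor(overrides={"model": "sam2_b.pt", "imgsz": 640})
+            >>> results = predictor(source="path/to/video.mp4", bboxes=[[100, 100, 200, 200]])
+            >>> # Prompts are consumed on the first frame; later frames are tracked from the memory bank.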
+ """ + # Override prompts if any stored in self.prompts + bboxes = self.prompts.pop("bboxes", bboxes) + points = self.prompts.pop("points", points) + masks = self.prompts.pop("masks", masks) + + frame = self.dataset.frame + self.inference_state["im"] = im + output_dict = self.inference_state["output_dict"] + if len(output_dict["cond_frame_outputs"]) == 0: # initialize prompts + points, labels, masks = self._prepare_prompts(im.shape[2:], bboxes, points, labels, masks) + if points is not None: + for i in range(len(points)): + self.add_new_prompts(obj_id=i, points=points[[i]], labels=labels[[i]], frame_idx=frame) + elif masks is not None: + for i in range(len(masks)): + self.add_new_prompts(obj_id=i, masks=masks[[i]], frame_idx=frame) + self.propagate_in_video_preflight() + + consolidated_frame_inds = self.inference_state["consolidated_frame_inds"] + batch_size = len(self.inference_state["obj_idx_to_id"]) + if len(output_dict["cond_frame_outputs"]) == 0: + raise RuntimeError("No points are provided; please add points first") + + if frame in consolidated_frame_inds["cond_frame_outputs"]: + storage_key = "cond_frame_outputs" + current_out = output_dict[storage_key][frame] + if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1): + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(frame) + elif frame in consolidated_frame_inds["non_cond_frame_outputs"]: + storage_key = "non_cond_frame_outputs" + current_out = output_dict[storage_key][frame] + else: + storage_key = "non_cond_frame_outputs" + current_out = self._run_single_frame_inference( + output_dict=output_dict, + frame_idx=frame, + batch_size=batch_size, + is_init_cond_frame=False, + point_inputs=None, + mask_inputs=None, + reverse=False, + run_mem_encoder=True, + ) + output_dict[storage_key][frame] = current_out + # Create slices of per-object outputs for subsequent interaction with each + # individual object after tracking. + self._add_output_per_object(frame, current_out, storage_key) + self.inference_state["frames_already_tracked"].append(frame) + pred_masks = current_out["pred_masks"].flatten(0, 1) + pred_masks = pred_masks[(pred_masks > self.model.mask_threshold).sum((1, 2)) > 0] # filter blank masks + + return pred_masks, torch.ones(len(pred_masks), dtype=pred_masks.dtype, device=pred_masks.device) + + def postprocess(self, preds, img, orig_imgs): + """ + Post-processes the predictions to apply non-overlapping constraints if required. + + This method extends the post-processing functionality by applying non-overlapping constraints + to the predicted masks if the `non_overlap_masks` flag is set to True. This ensures that + the masks do not overlap, which can be useful for certain applications. + + Args: + preds (Tuple[torch.Tensor]): The predictions from the model. + img (torch.Tensor): The processed image tensor. + orig_imgs (List[np.ndarray]): The original images before processing. + + Returns: + results (list): The post-processed predictions. + + Note: + If `non_overlap_masks` is True, the method applies constraints to ensure non-overlapping masks. 
+ """ + results = super().postprocess(preds, img, orig_imgs) + if self.non_overlap_masks: + for result in results: + if result.masks is None or len(result.masks) == 0: + continue + result.masks.data = self.model._apply_non_overlapping_constraints(result.masks.data.unsqueeze(0))[0] + return results + + @smart_inference_mode() + def add_new_prompts( + self, + obj_id, + points=None, + labels=None, + masks=None, + frame_idx=0, + ): + """ + Adds new points or masks to a specific frame for a given object ID. + + This method updates the inference state with new prompts (points or masks) for a specified + object and frame index. It ensures that the prompts are either points or masks, but not both, + and updates the internal state accordingly. It also handles the generation of new segmentations + based on the provided prompts and the existing state. + + Args: + obj_id (int): The ID of the object to which the prompts are associated. + points (torch.Tensor, Optional): The coordinates of the points of interest. Defaults to None. + labels (torch.Tensor, Optional): The labels corresponding to the points. Defaults to None. + masks (torch.Tensor, optional): Binary masks for the object. Defaults to None. + frame_idx (int, optional): The index of the frame to which the prompts are applied. Defaults to 0. + + Returns: + (tuple): A tuple containing the flattened predicted masks and a tensor of ones indicating the number of objects. + + Raises: + AssertionError: If both `masks` and `points` are provided, or neither is provided. + + Note: + - Only one type of prompt (either points or masks) can be added per call. + - If the frame is being tracked for the first time, it is treated as an initial conditioning frame. + - The method handles the consolidation of outputs and resizing of masks to the original video resolution. + """ + assert (masks is None) ^ (points is None), "'masks' and 'points' prompts are not compatible with each other." + obj_idx = self._obj_id_to_idx(obj_id) + + point_inputs = None + pop_key = "point_inputs_per_obj" + if points is not None: + point_inputs = {"point_coords": points, "point_labels": labels} + self.inference_state["point_inputs_per_obj"][obj_idx][frame_idx] = point_inputs + pop_key = "mask_inputs_per_obj" + self.inference_state["mask_inputs_per_obj"][obj_idx][frame_idx] = masks + self.inference_state[pop_key][obj_idx].pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + is_init_cond_frame = frame_idx not in self.inference_state["frames_already_tracked"] + obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.model.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # Get any previously predicted mask logits on this object and feed it along with + # the new clicks into the SAM mask decoder. 
+ prev_sam_mask_logits = None + # lookup temporary output dict first, which contains the most recent output + # (if not found, then lookup conditioning and non-conditioning frame output) + if point_inputs is not None: + prev_out = ( + obj_temp_output_dict[storage_key].get(frame_idx) + or obj_output_dict["cond_frame_outputs"].get(frame_idx) + or obj_output_dict["non_cond_frame_outputs"].get(frame_idx) + ) + + if prev_out is not None and prev_out.get("pred_masks") is not None: + prev_sam_mask_logits = prev_out["pred_masks"].to(device=self.device, non_blocking=True) + # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues. + prev_sam_mask_logits.clamp_(-32.0, 32.0) + current_out = self._run_single_frame_inference( + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=point_inputs, + mask_inputs=masks, + reverse=False, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + consolidated_out = self._consolidate_temp_output_across_obj( + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + ) + pred_masks = consolidated_out["pred_masks"].flatten(0, 1) + return pred_masks.flatten(0, 1), torch.ones(1, dtype=pred_masks.dtype, device=pred_masks.device) + + @smart_inference_mode() + def propagate_in_video_preflight(self): + """ + Prepare inference_state and consolidate temporary outputs before tracking. + + This method marks the start of tracking, disallowing the addition of new objects until the session is reset. + It consolidates temporary outputs from `temp_output_dict_per_obj` and merges them into `output_dict`. + Additionally, it clears non-conditioning memory around input frames and ensures that the state is consistent + with the provided inputs. + """ + # Tracking has started and we don't allow adding new objects until session is reset. + self.inference_state["tracking_has_started"] = True + batch_size = len(self.inference_state["obj_idx_to_id"]) + + # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and + # add them into "output_dict". + temp_output_dict_per_obj = self.inference_state["temp_output_dict_per_obj"] + output_dict = self.inference_state["output_dict"] + # "consolidated_frame_inds" contains indices of those frames where consolidated + # temporary outputs have been added (either in this call or any previous calls + # to `propagate_in_video_preflight`). 
+ consolidated_frame_inds = self.inference_state["consolidated_frame_inds"] + for is_cond in {False, True}: + # Separately consolidate conditioning and non-conditioning temp outputs + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Find all the frames that contain temporary outputs for any objects + # (these should be the frames that have just received clicks for mask inputs + # via `add_new_points` or `add_new_mask`) + temp_frame_inds = set() + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + temp_frame_inds.update(obj_temp_output_dict[storage_key].keys()) + consolidated_frame_inds[storage_key].update(temp_frame_inds) + # consolidate the temporary output across all objects on this frame + for frame_idx in temp_frame_inds: + consolidated_out = self._consolidate_temp_output_across_obj( + frame_idx, is_cond=is_cond, run_mem_encoder=True + ) + # merge them into "output_dict" and also create per-object slices + output_dict[storage_key][frame_idx] = consolidated_out + self._add_output_per_object(frame_idx, consolidated_out, storage_key) + if self.clear_non_cond_mem_around_input and (self.clear_non_cond_mem_for_multi_obj or batch_size <= 1): + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(frame_idx) + + # clear temporary outputs in `temp_output_dict_per_obj` + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + obj_temp_output_dict[storage_key].clear() + + # edge case: if an output is added to "cond_frame_outputs", we remove any prior + # output on the same frame in "non_cond_frame_outputs" + for frame_idx in output_dict["cond_frame_outputs"]: + output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for obj_output_dict in self.inference_state["output_dict_per_obj"].values(): + for frame_idx in obj_output_dict["cond_frame_outputs"]: + obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + assert frame_idx in output_dict["cond_frame_outputs"] + consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) + + # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames + # with either points or mask inputs (which should be true under a correct workflow). + all_consolidated_frame_inds = ( + consolidated_frame_inds["cond_frame_outputs"] | consolidated_frame_inds["non_cond_frame_outputs"] + ) + input_frames_inds = set() + for point_inputs_per_frame in self.inference_state["point_inputs_per_obj"].values(): + input_frames_inds.update(point_inputs_per_frame.keys()) + for mask_inputs_per_frame in self.inference_state["mask_inputs_per_obj"].values(): + input_frames_inds.update(mask_inputs_per_frame.keys()) + assert all_consolidated_frame_inds == input_frames_inds + + @staticmethod + def init_state(predictor): + """ + Initialize an inference state for the predictor. + + This function sets up the initial state required for performing inference on video data. + It includes initializing various dictionaries and ordered dictionaries that will store + inputs, outputs, and other metadata relevant to the tracking process. + + Args: + predictor (SAM2VideoPredictor): The predictor object for which to initialize the state. 
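+
+        Examples:
+            >>> # Illustrative note: `init_state` is registered as an "on_predict_start" callback in __init__,
+            >>> # so it normally runs automatically once a video source has been set up by the pipeline.
+            >>> predictor = SAM2VideoPredictor(overrides={"imgsz": 640})
+            >>> # predictor.inference_state starts as {} and is populated here when prediction on a video begins.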
+ """ + if len(predictor.inference_state) > 0: # means initialized + return + assert predictor.dataset is not None + assert predictor.dataset.mode == "video" + + inference_state = { + "num_frames": predictor.dataset.frames, + "point_inputs_per_obj": {}, # inputs points on each frame + "mask_inputs_per_obj": {}, # inputs mask on each frame + "constants": {}, # values that don't change across frames (so we only need to hold one copy of them) + # mapping between client-side object id and model-side object index + "obj_id_to_idx": OrderedDict(), + "obj_idx_to_id": OrderedDict(), + "obj_ids": [], + # A storage to hold the model's tracking results and states on each frame + "output_dict": { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + }, + # Slice (view) of each object tracking results, sharing the same memory with "output_dict" + "output_dict_per_obj": {}, + # A temporary storage to hold new outputs when user interact with a frame + # to add clicks or mask (it's merged into "output_dict" before propagation starts) + "temp_output_dict_per_obj": {}, + # Frames that already holds consolidated outputs from click or mask inputs + # (we directly use their consolidated outputs during tracking) + "consolidated_frame_inds": { + "cond_frame_outputs": set(), # set containing frame indices + "non_cond_frame_outputs": set(), # set containing frame indices + }, + # metadata for each tracking frame (e.g. which direction it's tracked) + "tracking_has_started": False, + "frames_already_tracked": [], + } + predictor.inference_state = inference_state + + def get_im_features(self, im, batch=1): + """ + Extracts and processes image features using SAM2's image encoder for subsequent segmentation tasks. + + Args: + im (torch.Tensor): The input image tensor. + batch (int, optional): The batch size for expanding features if there are multiple prompts. Defaults to 1. + + Returns: + vis_feats (torch.Tensor): The visual features extracted from the image. + vis_pos_embed (torch.Tensor): The positional embeddings for the visual features. + feat_sizes (List(Tuple[int])): A list containing the sizes of the extracted features. + + Note: + - If `batch` is greater than 1, the features are expanded to fit the batch size. + - The method leverages the model's `_prepare_backbone_features` method to prepare the backbone features. + """ + backbone_out = self.model.forward_image(im) + if batch > 1: # expand features if there's more than one prompt + for i, feat in enumerate(backbone_out["backbone_fpn"]): + backbone_out["backbone_fpn"][i] = feat.expand(batch, -1, -1, -1) + for i, pos in enumerate(backbone_out["vision_pos_enc"]): + pos = pos.expand(batch, -1, -1, -1) + backbone_out["vision_pos_enc"][i] = pos + _, vis_feats, vis_pos_embed, feat_sizes = self.model._prepare_backbone_features(backbone_out) + return vis_feats, vis_pos_embed, feat_sizes + + def _obj_id_to_idx(self, obj_id): + """ + Map client-side object id to model-side object index. + + Args: + obj_id (int): The unique identifier of the object provided by the client side. + + Returns: + obj_idx (int): The index of the object on the model side. + + Raises: + RuntimeError: If an attempt is made to add a new object after tracking has started. + + Note: + - The method updates or retrieves mappings between object IDs and indices stored in + `inference_state`. + - It ensures that new objects can only be added before tracking commences. 
+ - It maintains two-way mappings between IDs and indices (`obj_id_to_idx` and `obj_idx_to_id`). + - Additional data structures are initialized for the new object to store inputs and outputs. + """ + obj_idx = self.inference_state["obj_id_to_idx"].get(obj_id, None) + if obj_idx is not None: + return obj_idx + + # This is a new object id not sent to the server before. We only allow adding + # new objects *before* the tracking starts. + allow_new_object = not self.inference_state["tracking_has_started"] + if allow_new_object: + # get the next object slot + obj_idx = len(self.inference_state["obj_id_to_idx"]) + self.inference_state["obj_id_to_idx"][obj_id] = obj_idx + self.inference_state["obj_idx_to_id"][obj_idx] = obj_id + self.inference_state["obj_ids"] = list(self.inference_state["obj_id_to_idx"]) + # set up input and output structures for this object + self.inference_state["point_inputs_per_obj"][obj_idx] = {} + self.inference_state["mask_inputs_per_obj"][obj_idx] = {} + self.inference_state["output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + self.inference_state["temp_output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + return obj_idx + else: + raise RuntimeError( + f"Cannot add new object id {obj_id} after tracking starts. " + f"All existing object ids: {self.inference_state['obj_ids']}. " + f"Please call 'reset_state' to restart from scratch." + ) + + def _run_single_frame_inference( + self, + output_dict, + frame_idx, + batch_size, + is_init_cond_frame, + point_inputs, + mask_inputs, + reverse, + run_mem_encoder, + prev_sam_mask_logits=None, + ): + """ + Run tracking on a single frame based on current inputs and previous memory. + + Args: + output_dict (Dict): The dictionary containing the output states of the tracking process. + frame_idx (int): The index of the current frame. + batch_size (int): The batch size for processing the frame. + is_init_cond_frame (bool): Indicates if the current frame is an initial conditioning frame. + point_inputs (Dict, Optional): Input points and their labels. Defaults to None. + mask_inputs (torch.Tensor, Optional): Input binary masks. Defaults to None. + reverse (bool): Indicates if the tracking should be performed in reverse order. + run_mem_encoder (bool): Indicates if the memory encoder should be executed. + prev_sam_mask_logits (torch.Tensor, Optional): Previous mask logits for the current object. Defaults to None. + + Returns: + current_out (dict): A dictionary containing the output of the tracking step, including updated features and predictions. + + Raises: + AssertionError: If both `point_inputs` and `mask_inputs` are provided, or neither is provided. + + Note: + - The method assumes that `point_inputs` and `mask_inputs` are mutually exclusive. + - The method retrieves image features using the `get_im_features` method. + - The `maskmem_pos_enc` is assumed to be constant across frames, hence only one copy is stored. + - The `fill_holes_in_mask_scores` function is commented out and currently unsupported due to CUDA extension requirements. 
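+
+        Examples:
+            >>> # Sketch of the internal call pattern (mirrors the call made in `add_new_prompts` above);
+            >>> # not intended to be run standalone, since it needs a prepared inference_state.
+            >>> # current_out = self._run_single_frame_inference(
+            >>> #     output_dict=obj_output_dict, frame_idx=frame_idx, batch_size=1, is_init_cond_frame=True,
+            >>> #     point_inputs=point_inputs, mask_inputs=None, reverse=False, run_mem_encoder=False)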
+ """ + # Retrieve correct image features + current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features( + self.inference_state["im"], batch_size + ) + + # point and mask should not appear as input simultaneously on the same frame + assert point_inputs is None or mask_inputs is None + current_out = self.model.track_step( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + output_dict=output_dict, + num_frames=self.inference_state["num_frames"], + track_in_reverse=reverse, + run_mem_encoder=run_mem_encoder, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + + maskmem_features = current_out["maskmem_features"] + if maskmem_features is not None: + current_out["maskmem_features"] = maskmem_features.to( + dtype=torch.float16, device=self.device, non_blocking=True + ) + # NOTE: Do not support the `fill_holes_in_mask_scores` function since it needs cuda extensions + # potentially fill holes in the predicted masks + # if self.fill_hole_area > 0: + # pred_masks = current_out["pred_masks"].to(self.device, non_blocking=True) + # pred_masks = fill_holes_in_mask_scores(pred_masks, self.fill_hole_area) + + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + current_out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(current_out["maskmem_pos_enc"]) + return current_out + + def _get_maskmem_pos_enc(self, out_maskmem_pos_enc): + """ + Caches and manages the positional encoding for mask memory across frames and objects. + + This method optimizes storage by caching the positional encoding (`maskmem_pos_enc`) for + mask memory, which is constant across frames and objects, thus reducing the amount of + redundant information stored during an inference session. It checks if the positional + encoding has already been cached; if not, it caches a slice of the provided encoding. + If the batch size is greater than one, it expands the cached positional encoding to match + the current batch size. + + Args: + out_maskmem_pos_enc (List[torch.Tensor] or None): The positional encoding for mask memory. + Should be a list of tensors or None. + + Returns: + out_maskmem_pos_enc (List[torch.Tensor]): The positional encoding for mask memory, either cached or expanded. + + Note: + - The method assumes that `out_maskmem_pos_enc` is a list of tensors or None. + - Only a single object's slice is cached since the encoding is the same across objects. + - The method checks if the positional encoding has already been cached in the session's constants. + - If the batch size is greater than one, the cached encoding is expanded to fit the batch size. 
+ """ + model_constants = self.inference_state["constants"] + # "out_maskmem_pos_enc" should be either a list of tensors or None + if out_maskmem_pos_enc is not None: + if "maskmem_pos_enc" not in model_constants: + assert isinstance(out_maskmem_pos_enc, list) + # only take the slice for one object, since it's same across objects + maskmem_pos_enc = [x[:1].clone() for x in out_maskmem_pos_enc] + model_constants["maskmem_pos_enc"] = maskmem_pos_enc + else: + maskmem_pos_enc = model_constants["maskmem_pos_enc"] + # expand the cached maskmem_pos_enc to the actual batch size + batch_size = out_maskmem_pos_enc[0].size(0) + if batch_size > 1: + out_maskmem_pos_enc = [x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc] + return out_maskmem_pos_enc + + def _consolidate_temp_output_across_obj( + self, + frame_idx, + is_cond=False, + run_mem_encoder=False, + ): + """ + Consolidates per-object temporary outputs into a single output for all objects. + + This method combines the temporary outputs for each object on a given frame into a unified + output. It fills in any missing objects either from the main output dictionary or leaves + placeholders if they do not exist in the main output. Optionally, it can re-run the memory + encoder after applying non-overlapping constraints to the object scores. + + Args: + frame_idx (int): The index of the frame for which to consolidate outputs. + is_cond (bool, Optional): Indicates if the frame is considered a conditioning frame. + Defaults to False. + run_mem_encoder (bool, Optional): Specifies whether to run the memory encoder after + consolidating the outputs. Defaults to False. + + Returns: + consolidated_out (dict): A consolidated output dictionary containing the combined results for all objects. + + Note: + - The method initializes the consolidated output with placeholder values for missing objects. + - It searches for outputs in both the temporary and main output dictionaries. + - If `run_mem_encoder` is True, it applies non-overlapping constraints and re-runs the memory encoder. + - The `maskmem_features` and `maskmem_pos_enc` are only populated when `run_mem_encoder` is True. + """ + batch_size = len(self.inference_state["obj_idx_to_id"]) + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc" + # will be added when rerunning the memory encoder after applying non-overlapping + # constraints to object scores. Its "pred_masks" are prefilled with a large + # negative value (NO_OBJ_SCORE) to represent missing objects. + consolidated_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + "pred_masks": torch.full( + size=(batch_size, 1, self.imgsz[0] // 4, self.imgsz[1] // 4), + fill_value=-1024.0, + dtype=torch.float32, + device=self.device, + ), + "obj_ptr": torch.full( + size=(batch_size, self.model.hidden_dim), + fill_value=-1024.0, + dtype=torch.float32, + device=self.device, + ), + "object_score_logits": torch.full( + size=(batch_size, 1), + # default to 10.0 for object_score_logits, i.e. 
assuming the object is + # present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder` + fill_value=10.0, + dtype=torch.float32, + device=self.device, + ), + } + for obj_idx in range(batch_size): + obj_temp_output_dict = self.inference_state["temp_output_dict_per_obj"][obj_idx] + obj_output_dict = self.inference_state["output_dict_per_obj"][obj_idx] + out = ( + obj_temp_output_dict[storage_key].get(frame_idx) + # If the object doesn't appear in "temp_output_dict_per_obj" on this frame, + # we fall back and look up its previous output in "output_dict_per_obj". + # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in + # "output_dict_per_obj" to find a previous output for this object. + or obj_output_dict["cond_frame_outputs"].get(frame_idx) + or obj_output_dict["non_cond_frame_outputs"].get(frame_idx) + ) + # If the object doesn't appear in "output_dict_per_obj" either, we skip it + # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE + # placeholder above) and set its object pointer to be a dummy pointer. + if out is None: + # Fill in dummy object pointers for those objects without any inputs or + # tracking outcomes on this frame (only do it under `run_mem_encoder=True`, + # i.e. when we need to build the memory for tracking). + if run_mem_encoder: + # fill object pointer with a dummy pointer (based on an empty mask) + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = self._get_empty_mask_ptr(frame_idx) + continue + # Add the temporary object output mask to consolidated output mask + consolidated_out["pred_masks"][obj_idx : obj_idx + 1] = out["pred_masks"] + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"] + + # Optionally, apply non-overlapping constraints on the consolidated scores and rerun the memory encoder + if run_mem_encoder: + high_res_masks = F.interpolate( + consolidated_out["pred_masks"], + size=self.imgsz, + mode="bilinear", + align_corners=False, + ) + if self.model.non_overlap_masks_for_mem_enc: + high_res_masks = self.model._apply_non_overlapping_constraints(high_res_masks) + consolidated_out["maskmem_features"], consolidated_out["maskmem_pos_enc"] = self._run_memory_encoder( + batch_size=batch_size, + high_res_masks=high_res_masks, + is_mask_from_pts=True, # these frames are what the user interacted with + object_score_logits=consolidated_out["object_score_logits"], + ) + + return consolidated_out + + def _get_empty_mask_ptr(self, frame_idx): + """ + Get a dummy object pointer based on an empty mask on the current frame. + + Args: + frame_idx (int): The index of the current frame for which to generate the dummy object pointer. + + Returns: + (torch.Tensor): A tensor representing the dummy object pointer generated from the empty mask. 
+ """ + # Retrieve correct image features + current_vision_feats, current_vision_pos_embeds, feat_sizes = self.get_im_features(self.inference_state["im"]) + + # Feed the empty mask and image feature above to get a dummy object pointer + current_out = self.model.track_step( + frame_idx=frame_idx, + is_init_cond_frame=True, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=None, + # A dummy (empty) mask with a single object + mask_inputs=torch.zeros((1, 1, *self.imgsz), dtype=torch.float32, device=self.device), + output_dict={}, + num_frames=self.inference_state["num_frames"], + track_in_reverse=False, + run_mem_encoder=False, + prev_sam_mask_logits=None, + ) + return current_out["obj_ptr"] + + def _run_memory_encoder(self, batch_size, high_res_masks, object_score_logits, is_mask_from_pts): + """ + Run the memory encoder on masks. + + This is usually after applying non-overlapping constraints to object scores. Since their scores changed, their + memory also needs to be computed again with the memory encoder. + + Args: + batch_size (int): The batch size for processing the frame. + high_res_masks (torch.Tensor): High-resolution masks for which to compute the memory. + object_score_logits (torch.Tensor): Logits representing the object scores. + is_mask_from_pts (bool): Indicates if the mask is derived from point interactions. + + Returns: + (tuple[torch.Tensor, torch.Tensor]): A tuple containing the encoded mask features and positional encoding. + """ + # Retrieve correct image features + current_vision_feats, _, feat_sizes = self.get_im_features(self.inference_state["im"], batch_size) + maskmem_features, maskmem_pos_enc = self.model._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks, + is_mask_from_pts=is_mask_from_pts, + object_score_logits=object_score_logits, + ) + + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc(maskmem_pos_enc) + return maskmem_features.to(dtype=torch.float16, device=self.device, non_blocking=True), maskmem_pos_enc + + def _add_output_per_object(self, frame_idx, current_out, storage_key): + """ + Split a multi-object output into per-object output slices and add them into Output_Dict_Per_Obj. + + The resulting slices share the same tensor storage. + + Args: + frame_idx (int): The index of the current frame. + current_out (Dict): The current output dictionary containing multi-object outputs. + storage_key (str): The key used to store the output in the per-object output dictionary. 
+ """ + maskmem_features = current_out["maskmem_features"] + assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor) + + maskmem_pos_enc = current_out["maskmem_pos_enc"] + assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list) + + for obj_idx, obj_output_dict in self.inference_state["output_dict_per_obj"].items(): + obj_slice = slice(obj_idx, obj_idx + 1) + obj_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + "pred_masks": current_out["pred_masks"][obj_slice], + "obj_ptr": current_out["obj_ptr"][obj_slice], + } + if maskmem_features is not None: + obj_out["maskmem_features"] = maskmem_features[obj_slice] + if maskmem_pos_enc is not None: + obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc] + obj_output_dict[storage_key][frame_idx] = obj_out + + def _clear_non_cond_mem_around_input(self, frame_idx): + """ + Remove the non-conditioning memory around the input frame. + + When users provide correction clicks, the surrounding frames' non-conditioning memories can still contain outdated + object appearance information and could confuse the model. This method clears those non-conditioning memories + surrounding the interacted frame to avoid giving the model both old and new information about the object. + + Args: + frame_idx (int): The index of the current frame where user interaction occurred. + """ + r = self.model.memory_temporal_stride_for_eval + frame_idx_begin = frame_idx - r * self.model.num_maskmem + frame_idx_end = frame_idx + r * self.model.num_maskmem + for t in range(frame_idx_begin, frame_idx_end + 1): + self.inference_state["output_dict"]["non_cond_frame_outputs"].pop(t, None) + for obj_output_dict in self.inference_state["output_dict_per_obj"].values(): + obj_output_dict["non_cond_frame_outputs"].pop(t, None) diff --git a/ultralytics/models/utils/__init__.py b/ultralytics/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..77a19dcf0f8093de419453747db2e7e719f96349 --- /dev/null +++ b/ultralytics/models/utils/__init__.py @@ -0,0 +1 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license diff --git a/ultralytics/models/utils/loss.py b/ultralytics/models/utils/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..42f437439b838e63b281fc34622cd714f2fd1996 --- /dev/null +++ b/ultralytics/models/utils/loss.py @@ -0,0 +1,357 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ultralytics.utils.loss import FocalLoss, VarifocalLoss +from ultralytics.utils.metrics import bbox_iou + +from .ops import HungarianMatcher + + +class DETRLoss(nn.Module): + """ + DETR (DEtection TRansformer) Loss class. This class calculates and returns the different loss components for the + DETR object detection model. It computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary + losses. + + Attributes: + nc (int): The number of classes. + loss_gain (dict): Coefficients for different loss components. + aux_loss (bool): Whether to compute auxiliary losses. + use_fl (bool): Use FocalLoss or not. + use_vfl (bool): Use VarifocalLoss or not. + use_uni_match (bool): Whether to use a fixed layer to assign labels for the auxiliary branch. + uni_match_ind (int): The fixed indices of a layer to use if `use_uni_match` is True. + matcher (HungarianMatcher): Object to compute matching cost and indices. 
+ fl (FocalLoss or None): Focal Loss object if `use_fl` is True, otherwise None. + vfl (VarifocalLoss or None): Varifocal Loss object if `use_vfl` is True, otherwise None. + device (torch.device): Device on which tensors are stored. + """ + + def __init__( + self, nc=80, loss_gain=None, aux_loss=True, use_fl=True, use_vfl=False, use_uni_match=False, uni_match_ind=0 + ): + """ + Initialize DETR loss function with customizable components and gains. + + Uses default loss_gain if not provided. Initializes HungarianMatcher with + preset cost gains. Supports auxiliary losses and various loss types. + + Args: + nc (int): Number of classes. + loss_gain (dict): Coefficients for different loss components. + aux_loss (bool): Use auxiliary losses from each decoder layer. + use_fl (bool): Use FocalLoss. + use_vfl (bool): Use VarifocalLoss. + use_uni_match (bool): Use fixed layer for auxiliary branch label assignment. + uni_match_ind (int): Index of fixed layer for uni_match. + """ + super().__init__() + + if loss_gain is None: + loss_gain = {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1} + self.nc = nc + self.matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2}) + self.loss_gain = loss_gain + self.aux_loss = aux_loss + self.fl = FocalLoss() if use_fl else None + self.vfl = VarifocalLoss() if use_vfl else None + + self.use_uni_match = use_uni_match + self.uni_match_ind = uni_match_ind + self.device = None + + def _get_loss_class(self, pred_scores, targets, gt_scores, num_gts, postfix=""): + """Computes the classification loss based on predictions, target values, and ground truth scores.""" + # Logits: [b, query, num_classes], gt_class: list[[n, 1]] + name_class = f"loss_class{postfix}" + bs, nq = pred_scores.shape[:2] + # one_hot = F.one_hot(targets, self.nc + 1)[..., :-1] # (bs, num_queries, num_classes) + one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device) + one_hot.scatter_(2, targets.unsqueeze(-1), 1) + one_hot = one_hot[..., :-1] + gt_scores = gt_scores.view(bs, nq, 1) * one_hot + + if self.fl: + if num_gts and self.vfl: + loss_cls = self.vfl(pred_scores, gt_scores, one_hot) + else: + loss_cls = self.fl(pred_scores, one_hot.float()) + loss_cls /= max(num_gts, 1) / nq + else: + loss_cls = nn.BCEWithLogitsLoss(reduction="none")(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss + + return {name_class: loss_cls.squeeze() * self.loss_gain["class"]} + + def _get_loss_bbox(self, pred_bboxes, gt_bboxes, postfix=""): + """Computes bounding box and GIoU losses for predicted and ground truth bounding boxes.""" + # Boxes: [b, query, 4], gt_bbox: list[[n, 4]] + name_bbox = f"loss_bbox{postfix}" + name_giou = f"loss_giou{postfix}" + + loss = {} + if len(gt_bboxes) == 0: + loss[name_bbox] = torch.tensor(0.0, device=self.device) + loss[name_giou] = torch.tensor(0.0, device=self.device) + return loss + + loss[name_bbox] = self.loss_gain["bbox"] * F.l1_loss(pred_bboxes, gt_bboxes, reduction="sum") / len(gt_bboxes) + loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True) + loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes) + loss[name_giou] = self.loss_gain["giou"] * loss[name_giou] + return {k: v.squeeze() for k, v in loss.items()} + + # This function is for future RT-DETR Segment models + # def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''): + # # masks: [b, query, h, w], gt_mask: list[[n, H, W]] + # name_mask = f'loss_mask{postfix}' + # name_dice = f'loss_dice{postfix}' + # + # loss = 
{} + # if sum(len(a) for a in gt_mask) == 0: + # loss[name_mask] = torch.tensor(0., device=self.device) + # loss[name_dice] = torch.tensor(0., device=self.device) + # return loss + # + # num_gts = len(gt_mask) + # src_masks, target_masks = self._get_assigned_bboxes(masks, gt_mask, match_indices) + # src_masks = F.interpolate(src_masks.unsqueeze(0), size=target_masks.shape[-2:], mode='bilinear')[0] + # # TODO: torch does not have `sigmoid_focal_loss`, but it's not urgent since we don't use mask branch for now. + # loss[name_mask] = self.loss_gain['mask'] * F.sigmoid_focal_loss(src_masks, target_masks, + # torch.tensor([num_gts], dtype=torch.float32)) + # loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts) + # return loss + + # This function is for future RT-DETR Segment models + # @staticmethod + # def _dice_loss(inputs, targets, num_gts): + # inputs = F.sigmoid(inputs).flatten(1) + # targets = targets.flatten(1) + # numerator = 2 * (inputs * targets).sum(1) + # denominator = inputs.sum(-1) + targets.sum(-1) + # loss = 1 - (numerator + 1) / (denominator + 1) + # return loss.sum() / num_gts + + def _get_loss_aux( + self, + pred_bboxes, + pred_scores, + gt_bboxes, + gt_cls, + gt_groups, + match_indices=None, + postfix="", + masks=None, + gt_mask=None, + ): + """Get auxiliary losses.""" + # NOTE: loss class, bbox, giou, mask, dice + loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device) + if match_indices is None and self.use_uni_match: + match_indices = self.matcher( + pred_bboxes[self.uni_match_ind], + pred_scores[self.uni_match_ind], + gt_bboxes, + gt_cls, + gt_groups, + masks=masks[self.uni_match_ind] if masks is not None else None, + gt_mask=gt_mask, + ) + for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)): + aux_masks = masks[i] if masks is not None else None + loss_ = self._get_loss( + aux_bboxes, + aux_scores, + gt_bboxes, + gt_cls, + gt_groups, + masks=aux_masks, + gt_mask=gt_mask, + postfix=postfix, + match_indices=match_indices, + ) + loss[0] += loss_[f"loss_class{postfix}"] + loss[1] += loss_[f"loss_bbox{postfix}"] + loss[2] += loss_[f"loss_giou{postfix}"] + # if masks is not None and gt_mask is not None: + # loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix) + # loss[3] += loss_[f'loss_mask{postfix}'] + # loss[4] += loss_[f'loss_dice{postfix}'] + + loss = { + f"loss_class_aux{postfix}": loss[0], + f"loss_bbox_aux{postfix}": loss[1], + f"loss_giou_aux{postfix}": loss[2], + } + # if masks is not None and gt_mask is not None: + # loss[f'loss_mask_aux{postfix}'] = loss[3] + # loss[f'loss_dice_aux{postfix}'] = loss[4] + return loss + + @staticmethod + def _get_index(match_indices): + """Returns batch indices, source indices, and destination indices from provided match indices.""" + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)]) + src_idx = torch.cat([src for (src, _) in match_indices]) + dst_idx = torch.cat([dst for (_, dst) in match_indices]) + return (batch_idx, src_idx), dst_idx + + def _get_assigned_bboxes(self, pred_bboxes, gt_bboxes, match_indices): + """Assigns predicted bounding boxes to ground truth bounding boxes based on the match indices.""" + pred_assigned = torch.cat( + [ + t[i] if len(i) > 0 else torch.zeros(0, t.shape[-1], device=self.device) + for t, (i, _) in zip(pred_bboxes, match_indices) + ] + ) + gt_assigned = torch.cat( + [ + t[j] if len(j) > 0 else torch.zeros(0, t.shape[-1], device=self.device) + for t, (_, 
j) in zip(gt_bboxes, match_indices) + ] + ) + return pred_assigned, gt_assigned + + def _get_loss( + self, + pred_bboxes, + pred_scores, + gt_bboxes, + gt_cls, + gt_groups, + masks=None, + gt_mask=None, + postfix="", + match_indices=None, + ): + """Get losses.""" + if match_indices is None: + match_indices = self.matcher( + pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=masks, gt_mask=gt_mask + ) + + idx, gt_idx = self._get_index(match_indices) + pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx] + + bs, nq = pred_scores.shape[:2] + targets = torch.full((bs, nq), self.nc, device=pred_scores.device, dtype=gt_cls.dtype) + targets[idx] = gt_cls[gt_idx] + + gt_scores = torch.zeros([bs, nq], device=pred_scores.device) + if len(gt_bboxes): + gt_scores[idx] = bbox_iou(pred_bboxes.detach(), gt_bboxes, xywh=True).squeeze(-1) + + return { + **self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix), + **self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix), + # **(self._get_loss_mask(masks, gt_mask, match_indices, postfix) if masks is not None and gt_mask is not None else {}) + } + + def forward(self, pred_bboxes, pred_scores, batch, postfix="", **kwargs): + """ + Calculate loss for predicted bounding boxes and scores. + + Args: + pred_bboxes (torch.Tensor): Predicted bounding boxes, shape [l, b, query, 4]. + pred_scores (torch.Tensor): Predicted class scores, shape [l, b, query, num_classes]. + batch (dict): Batch information containing: + cls (torch.Tensor): Ground truth classes, shape [num_gts]. + bboxes (torch.Tensor): Ground truth bounding boxes, shape [num_gts, 4]. + gt_groups (List[int]): Number of ground truths for each image in the batch. + postfix (str): Postfix for loss names. + **kwargs (Any): Additional arguments, may include 'match_indices'. + + Returns: + (dict): Computed losses, including main and auxiliary (if enabled). + + Note: + Uses last elements of pred_bboxes and pred_scores for main loss, and the rest for auxiliary losses if + self.aux_loss is True. + """ + self.device = pred_bboxes.device + match_indices = kwargs.get("match_indices", None) + gt_cls, gt_bboxes, gt_groups = batch["cls"], batch["bboxes"], batch["gt_groups"] + + total_loss = self._get_loss( + pred_bboxes[-1], pred_scores[-1], gt_bboxes, gt_cls, gt_groups, postfix=postfix, match_indices=match_indices + ) + + if self.aux_loss: + total_loss.update( + self._get_loss_aux( + pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, postfix + ) + ) + + return total_loss + + +class RTDETRDetectionLoss(DETRLoss): + """ + Real-Time DeepTracker (RT-DETR) Detection Loss class that extends the DETRLoss. + + This class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well as + an additional denoising training loss when provided with denoising metadata. + """ + + def forward(self, preds, batch, dn_bboxes=None, dn_scores=None, dn_meta=None): + """ + Forward pass to compute the detection loss. + + Args: + preds (tuple): Predicted bounding boxes and scores. + batch (dict): Batch data containing ground truth information. + dn_bboxes (torch.Tensor, optional): Denoising bounding boxes. Default is None. + dn_scores (torch.Tensor, optional): Denoising scores. Default is None. + dn_meta (dict, optional): Metadata for denoising. Default is None. + + Returns: + (dict): Dictionary containing the total loss and, if applicable, the denoising loss. 
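+
+        Examples:
+            >>> # Illustrative sketch; the shapes in the comments follow DETRLoss.forward and are placeholders.
+            >>> loss_fn = RTDETRDetectionLoss(nc=80)
+            >>> # preds = (pred_bboxes, pred_scores)  # shapes [l, b, query, 4] and [l, b, query, 80]
+            >>> # batch = {"cls": gt_cls, "bboxes": gt_bboxes, "gt_groups": gt_groups}
+            >>> # losses = loss_fn(preds, batch, dn_bboxes=dn_bboxes, dn_scores=dn_scores, dn_meta=dn_meta)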
+ """ + pred_bboxes, pred_scores = preds + total_loss = super().forward(pred_bboxes, pred_scores, batch) + + # Check for denoising metadata to compute denoising training loss + if dn_meta is not None: + dn_pos_idx, dn_num_group = dn_meta["dn_pos_idx"], dn_meta["dn_num_group"] + assert len(batch["gt_groups"]) == len(dn_pos_idx) + + # Get the match indices for denoising + match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch["gt_groups"]) + + # Compute the denoising training loss + dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix="_dn", match_indices=match_indices) + total_loss.update(dn_loss) + else: + # If no denoising metadata is provided, set denoising loss to zero + total_loss.update({f"{k}_dn": torch.tensor(0.0, device=self.device) for k in total_loss.keys()}) + + return total_loss + + @staticmethod + def get_dn_match_indices(dn_pos_idx, dn_num_group, gt_groups): + """ + Get the match indices for denoising. + + Args: + dn_pos_idx (List[torch.Tensor]): List of tensors containing positive indices for denoising. + dn_num_group (int): Number of denoising groups. + gt_groups (List[int]): List of integers representing the number of ground truths for each image. + + Returns: + (List[tuple]): List of tuples containing matched indices for denoising. + """ + dn_match_indices = [] + idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0) + for i, num_gt in enumerate(gt_groups): + if num_gt > 0: + gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i] + gt_idx = gt_idx.repeat(dn_num_group) + assert len(dn_pos_idx[i]) == len(gt_idx), "Expected the same length, " + f"but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively." + dn_match_indices.append((dn_pos_idx[i], gt_idx)) + else: + dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long))) + return dn_match_indices diff --git a/ultralytics/models/utils/ops.py b/ultralytics/models/utils/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..eafc744078ca2c9d8f3080c6f4964a1824500b58 --- /dev/null +++ b/ultralytics/models/utils/ops.py @@ -0,0 +1,259 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import torch +import torch.nn as nn +import torch.nn.functional as F +from scipy.optimize import linear_sum_assignment + +from ultralytics.utils.metrics import bbox_iou +from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh + + +class HungarianMatcher(nn.Module): + """ + A module implementing the HungarianMatcher, which is a differentiable module to solve the assignment problem in an + end-to-end fashion. + + HungarianMatcher performs optimal assignment over the predicted and ground truth bounding boxes using a cost + function that considers classification scores, bounding box coordinates, and optionally, mask predictions. + + Attributes: + cost_gain (dict): Dictionary of cost coefficients: 'class', 'bbox', 'giou', 'mask', and 'dice'. + use_fl (bool): Indicates whether to use Focal Loss for the classification cost calculation. + with_mask (bool): Indicates whether the model makes mask predictions. + num_sample_points (int): The number of sample points used in mask cost calculation. + alpha (float): The alpha factor in Focal Loss calculation. + gamma (float): The gamma factor in Focal Loss calculation. + + Methods: + forward(pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None): Computes the + assignment between predictions and ground truths for a batch. 
+ _cost_mask(bs, num_gts, masks=None, gt_mask=None): Computes the mask cost and dice cost if masks are predicted. + """ + + def __init__(self, cost_gain=None, use_fl=True, with_mask=False, num_sample_points=12544, alpha=0.25, gamma=2.0): + """Initializes a HungarianMatcher module for optimal assignment of predicted and ground truth bounding boxes.""" + super().__init__() + if cost_gain is None: + cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1} + self.cost_gain = cost_gain + self.use_fl = use_fl + self.with_mask = with_mask + self.num_sample_points = num_sample_points + self.alpha = alpha + self.gamma = gamma + + def forward(self, pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=None, gt_mask=None): + """ + Forward pass for HungarianMatcher. This function computes costs based on prediction and ground truth + (classification cost, L1 cost between boxes and GIoU cost between boxes) and finds the optimal matching between + predictions and ground truth based on these costs. + + Args: + pred_bboxes (Tensor): Predicted bounding boxes with shape [batch_size, num_queries, 4]. + pred_scores (Tensor): Predicted scores with shape [batch_size, num_queries, num_classes]. + gt_cls (torch.Tensor): Ground truth classes with shape [num_gts, ]. + gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape [num_gts, 4]. + gt_groups (List[int]): List of length equal to batch size, containing the number of ground truths for + each image. + masks (Tensor, optional): Predicted masks with shape [batch_size, num_queries, height, width]. + Defaults to None. + gt_mask (List[Tensor], optional): List of ground truth masks, each with shape [num_masks, Height, Width]. + Defaults to None. + + Returns: + (List[Tuple[Tensor, Tensor]]): A list of size batch_size, each element is a tuple (index_i, index_j), where: + - index_i is the tensor of indices of the selected predictions (in order) + - index_j is the tensor of indices of the corresponding selected ground truth targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + bs, nq, nc = pred_scores.shape + + if sum(gt_groups) == 0: + return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)] + + # We flatten to compute the cost matrices in a batch + # [batch_size * num_queries, num_classes] + pred_scores = pred_scores.detach().view(-1, nc) + pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1) + # [batch_size * num_queries, 4] + pred_bboxes = pred_bboxes.detach().view(-1, 4) + + # Compute the classification cost + pred_scores = pred_scores[:, gt_cls] + if self.use_fl: + neg_cost_class = (1 - self.alpha) * (pred_scores**self.gamma) * (-(1 - pred_scores + 1e-8).log()) + pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log()) + cost_class = pos_cost_class - neg_cost_class + else: + cost_class = -pred_scores + + # Compute the L1 cost between boxes + cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1) # (bs*num_queries, num_gt) + + # Compute the GIoU cost between boxes, (bs*num_queries, num_gt) + cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1) + + # Final cost matrix + C = ( + self.cost_gain["class"] * cost_class + + self.cost_gain["bbox"] * cost_bbox + + self.cost_gain["giou"] * cost_giou + ) + # Compute the mask cost and dice cost + if self.with_mask: + C += 
self._cost_mask(bs, gt_groups, masks, gt_mask) + + # Set invalid values (NaNs and infinities) to 0 (fixes ValueError: matrix contains invalid numeric entries) + C[C.isnan() | C.isinf()] = 0.0 + + C = C.view(bs, nq, -1).cpu() + indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))] + gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0) # (idx for queries, idx for gt) + return [ + (torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k]) + for k, (i, j) in enumerate(indices) + ] + + # This function is for future RT-DETR Segment models + # def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None): + # assert masks is not None and gt_mask is not None, 'Make sure the input has `mask` and `gt_mask`' + # # all masks share the same set of points for efficient matching + # sample_points = torch.rand([bs, 1, self.num_sample_points, 2]) + # sample_points = 2.0 * sample_points - 1.0 + # + # out_mask = F.grid_sample(masks.detach(), sample_points, align_corners=False).squeeze(-2) + # out_mask = out_mask.flatten(0, 1) + # + # tgt_mask = torch.cat(gt_mask).unsqueeze(1) + # sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0]) + # tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2]) + # + # with torch.amp.autocast("cuda", enabled=False): + # # binary cross entropy cost + # pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none') + # neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none') + # cost_mask = torch.matmul(pos_cost_mask, tgt_mask.T) + torch.matmul(neg_cost_mask, 1 - tgt_mask.T) + # cost_mask /= self.num_sample_points + # + # # dice cost + # out_mask = F.sigmoid(out_mask) + # numerator = 2 * torch.matmul(out_mask, tgt_mask.T) + # denominator = out_mask.sum(-1, keepdim=True) + tgt_mask.sum(-1).unsqueeze(0) + # cost_dice = 1 - (numerator + 1) / (denominator + 1) + # + # C = self.cost_gain['mask'] * cost_mask + self.cost_gain['dice'] * cost_dice + # return C + + +def get_cdn_group( + batch, num_classes, num_queries, class_embed, num_dn=100, cls_noise_ratio=0.5, box_noise_scale=1.0, training=False +): + """ + Get contrastive denoising training group. This function creates a contrastive denoising training group with positive + and negative samples from the ground truths (gt). It applies noise to the class labels and bounding box coordinates, + and returns the modified labels, bounding boxes, attention mask and meta information. + + Args: + batch (dict): A dict that includes 'gt_cls' (torch.Tensor with shape [num_gts, ]), 'gt_bboxes' + (torch.Tensor with shape [num_gts, 4]), 'gt_groups' (List(int)) which is a list of batch size length + indicating the number of gts of each image. + num_classes (int): Number of classes. + num_queries (int): Number of queries. + class_embed (torch.Tensor): Embedding weights to map class labels to embedding space. + num_dn (int, optional): Number of denoising. Defaults to 100. + cls_noise_ratio (float, optional): Noise ratio for class labels. Defaults to 0.5. + box_noise_scale (float, optional): Noise scale for bounding box coordinates. Defaults to 1.0. + training (bool, optional): If it's in training mode. Defaults to False. + + Returns: + (Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Dict]]): The modified class embeddings, + bounding boxes, attention mask and meta information for denoising. 
If not in training mode or 'num_dn' + is less than or equal to 0, the function returns None for all elements in the tuple. + """ + if (not training) or num_dn <= 0: + return None, None, None, None + gt_groups = batch["gt_groups"] + total_num = sum(gt_groups) + max_nums = max(gt_groups) + if max_nums == 0: + return None, None, None, None + + num_group = num_dn // max_nums + num_group = 1 if num_group == 0 else num_group + # Pad gt to max_num of a batch + bs = len(gt_groups) + gt_cls = batch["cls"] # (bs*num, ) + gt_bbox = batch["bboxes"] # bs*num, 4 + b_idx = batch["batch_idx"] + + # Each group has positive and negative queries. + dn_cls = gt_cls.repeat(2 * num_group) # (2*num_group*bs*num, ) + dn_bbox = gt_bbox.repeat(2 * num_group, 1) # 2*num_group*bs*num, 4 + dn_b_idx = b_idx.repeat(2 * num_group).view(-1) # (2*num_group*bs*num, ) + + # Positive and negative mask + # (bs*num*num_group, ), the second total_num*num_group part as negative samples + neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num + + if cls_noise_ratio > 0: + # Half of bbox prob + mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5) + idx = torch.nonzero(mask).squeeze(-1) + # Randomly put a new one here + new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device) + dn_cls[idx] = new_label + + if box_noise_scale > 0: + known_bbox = xywh2xyxy(dn_bbox) + + diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale # 2*num_group*bs*num, 4 + + rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0 + rand_part = torch.rand_like(dn_bbox) + rand_part[neg_idx] += 1.0 + rand_part *= rand_sign + known_bbox += rand_part * diff + known_bbox.clip_(min=0.0, max=1.0) + dn_bbox = xyxy2xywh(known_bbox) + dn_bbox = torch.logit(dn_bbox, eps=1e-6) # inverse sigmoid + + num_dn = int(max_nums * 2 * num_group) # total denoising queries + # class_embed = torch.cat([class_embed, torch.zeros([1, class_embed.shape[-1]], device=class_embed.device)]) + dn_cls_embed = class_embed[dn_cls] # bs*num * 2 * num_group, 256 + padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device) + padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device) + + map_indices = torch.cat([torch.tensor(range(num), dtype=torch.long) for num in gt_groups]) + pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0) + + map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)]) + padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed + padding_bbox[(dn_b_idx, map_indices)] = dn_bbox + + tgt_size = num_dn + num_queries + attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool) + # Match query cannot see the reconstruct + attn_mask[num_dn:, :num_dn] = True + # Reconstruct cannot see each other + for i in range(num_group): + if i == 0: + attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True + if i == num_group - 1: + attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * i * 2] = True + else: + attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True + attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * 2 * i] = True + dn_meta = { + "dn_pos_idx": [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)], + "dn_num_group": num_group, + "dn_num_split": [num_dn, num_queries], + } + + return ( + padding_cls.to(class_embed.device), + padding_bbox.to(class_embed.device), + 
attn_mask.to(class_embed.device), + dn_meta, + ) diff --git a/ultralytics/models/yolo/__init__.py b/ultralytics/models/yolo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..95006d437ccc1f3d6bcb3822b94a08834498b828 --- /dev/null +++ b/ultralytics/models/yolo/__init__.py @@ -0,0 +1,7 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.models.yolo import classify, detect, obb, pose, segment, world + +from .model import YOLO, YOLOWorld + +__all__ = "classify", "segment", "detect", "pose", "obb", "world", "YOLO", "YOLOWorld" diff --git a/ultralytics/models/yolo/classify/__init__.py b/ultralytics/models/yolo/classify/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3a10629229f8a8e5769480d004bbdbe42a633e79 --- /dev/null +++ b/ultralytics/models/yolo/classify/__init__.py @@ -0,0 +1,7 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.models.yolo.classify.predict import ClassificationPredictor +from ultralytics.models.yolo.classify.train import ClassificationTrainer +from ultralytics.models.yolo.classify.val import ClassificationValidator + +__all__ = "ClassificationPredictor", "ClassificationTrainer", "ClassificationValidator" diff --git a/ultralytics/models/yolo/classify/predict.py b/ultralytics/models/yolo/classify/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..35526266b5a65e2b5e73feac290fe49b761f57e6 --- /dev/null +++ b/ultralytics/models/yolo/classify/predict.py @@ -0,0 +1,60 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import cv2 +import torch +from PIL import Image + +from ultralytics.engine.predictor import BasePredictor +from ultralytics.engine.results import Results +from ultralytics.utils import DEFAULT_CFG, ops + + +class ClassificationPredictor(BasePredictor): + """ + A class extending the BasePredictor class for prediction based on a classification model. + + Notes: + - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'. 
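+
+    A minimal sketch of consuming predictions programmatically through the YOLO wrapper, which maps the 'classify'
+    task to this predictor (the image path below is illustrative); class probabilities are attached to each Results
+    object by postprocess():
+
+        ```python
+        from ultralytics import YOLO
+
+        model = YOLO("yolov8n-cls.pt")
+        results = model.predict("path/to/image.jpg")
+        print(results[0].probs)  # classification probabilities set in postprocess()
+        ```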
+ + Example: + ```python + from ultralytics.utils import ASSETS + from ultralytics.models.yolo.classify import ClassificationPredictor + + args = dict(model="yolov8n-cls.pt", source=ASSETS) + predictor = ClassificationPredictor(overrides=args) + predictor.predict_cli() + ``` + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """Initializes ClassificationPredictor setting the task to 'classify'.""" + super().__init__(cfg, overrides, _callbacks) + self.args.task = "classify" + self._legacy_transform_name = "ultralytics.yolo.data.augment.ToTensor" + + def preprocess(self, img): + """Converts input image to model-compatible data type.""" + if not isinstance(img, torch.Tensor): + is_legacy_transform = any( + self._legacy_transform_name in str(transform) for transform in self.transforms.transforms + ) + if is_legacy_transform: # to handle legacy transforms + img = torch.stack([self.transforms(im) for im in img], dim=0) + else: + img = torch.stack( + [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0 + ) + img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device) + return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32 + + def postprocess(self, preds, img, orig_imgs): + """Post-processes predictions to return Results objects.""" + if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list + orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) + + preds = preds[0] if isinstance(preds, (list, tuple)) else preds + return [ + Results(orig_img, path=img_path, names=self.model.names, probs=pred) + for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]) + ] diff --git a/ultralytics/models/yolo/classify/train.py b/ultralytics/models/yolo/classify/train.py new file mode 100644 index 0000000000000000000000000000000000000000..fba69664f872e242d1974bb2d02616b203acaa90 --- /dev/null +++ b/ultralytics/models/yolo/classify/train.py @@ -0,0 +1,153 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from copy import copy + +import torch + +from ultralytics.data import ClassificationDataset, build_dataloader +from ultralytics.engine.trainer import BaseTrainer +from ultralytics.models import yolo +from ultralytics.nn.tasks import ClassificationModel +from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK +from ultralytics.utils.plotting import plot_images, plot_results +from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first + + +class ClassificationTrainer(BaseTrainer): + """ + A class extending the BaseTrainer class for training based on a classification model. + + Notes: + - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'. 
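+
+    A sketch of the torchvision usage mentioned in the note above: setup_model() resolves model names found in
+    torchvision.models, and reshape_outputs() adapts the classifier head to the dataset's class count (dataset and
+    epoch values are illustrative):
+
+        ```python
+        from ultralytics.models.yolo.classify import ClassificationTrainer
+
+        args = dict(model="resnet18", data="imagenet10", epochs=3)
+        trainer = ClassificationTrainer(overrides=args)
+        trainer.train()
+        ```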
+ + Example: + ```python + from ultralytics.models.yolo.classify import ClassificationTrainer + + args = dict(model="yolov8n-cls.pt", data="imagenet10", epochs=3) + trainer = ClassificationTrainer(overrides=args) + trainer.train() + ``` + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """Initialize a ClassificationTrainer object with optional configuration overrides and callbacks.""" + if overrides is None: + overrides = {} + overrides["task"] = "classify" + if overrides.get("imgsz") is None: + overrides["imgsz"] = 224 + super().__init__(cfg, overrides, _callbacks) + + def set_model_attributes(self): + """Set the YOLO model's class names from the loaded dataset.""" + self.model.names = self.data["names"] + + def get_model(self, cfg=None, weights=None, verbose=True): + """Returns a modified PyTorch model configured for training YOLO.""" + model = ClassificationModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1) + if weights: + model.load(weights) + + for m in model.modules(): + if not self.args.pretrained and hasattr(m, "reset_parameters"): + m.reset_parameters() + if isinstance(m, torch.nn.Dropout) and self.args.dropout: + m.p = self.args.dropout # set dropout + for p in model.parameters(): + p.requires_grad = True # for training + return model + + def setup_model(self): + """Load, create or download model for any task.""" + import torchvision # scope for faster 'import ultralytics' + + if str(self.model) in torchvision.models.__dict__: + self.model = torchvision.models.__dict__[self.model]( + weights="IMAGENET1K_V1" if self.args.pretrained else None + ) + ckpt = None + else: + ckpt = super().setup_model() + ClassificationModel.reshape_outputs(self.model, self.data["nc"]) + return ckpt + + def build_dataset(self, img_path, mode="train", batch=None): + """Creates a ClassificationDataset instance given an image path, and mode (train/test etc.).""" + return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode) + + def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"): + """Returns PyTorch DataLoader with transforms to preprocess images for inference.""" + with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP + dataset = self.build_dataset(dataset_path, mode) + + loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank) + # Attach inference transforms + if mode != "train": + if is_parallel(self.model): + self.model.module.transforms = loader.dataset.torch_transforms + else: + self.model.transforms = loader.dataset.torch_transforms + return loader + + def preprocess_batch(self, batch): + """Preprocesses a batch of images and classes.""" + batch["img"] = batch["img"].to(self.device) + batch["cls"] = batch["cls"].to(self.device) + return batch + + def progress_string(self): + """Returns a formatted string showing training progress.""" + return ("\n" + "%11s" * (4 + len(self.loss_names))) % ( + "Epoch", + "GPU_mem", + *self.loss_names, + "Instances", + "Size", + ) + + def get_validator(self): + """Returns an instance of ClassificationValidator for validation.""" + self.loss_names = ["loss"] + return yolo.classify.ClassificationValidator( + self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks + ) + + def label_loss_items(self, loss_items=None, prefix="train"): + """ + Returns a loss dict with labelled training loss items tensor. 
+ + Not needed for classification but necessary for segmentation & detection + """ + keys = [f"{prefix}/{x}" for x in self.loss_names] + if loss_items is None: + return keys + loss_items = [round(float(loss_items), 5)] + return dict(zip(keys, loss_items)) + + def plot_metrics(self): + """Plots metrics from a CSV file.""" + plot_results(file=self.csv, classify=True, on_plot=self.on_plot) # save results.png + + def final_eval(self): + """Evaluate trained model and save validation results.""" + for f in self.last, self.best: + if f.exists(): + strip_optimizer(f) # strip optimizers + if f is self.best: + LOGGER.info(f"\nValidating {f}...") + self.validator.args.data = self.args.data + self.validator.args.plots = self.args.plots + self.metrics = self.validator(model=f) + self.metrics.pop("fitness", None) + self.run_callbacks("on_fit_epoch_end") + + def plot_training_samples(self, batch, ni): + """Plots training samples with their annotations.""" + plot_images( + images=batch["img"], + batch_idx=torch.arange(len(batch["img"])), + cls=batch["cls"].view(-1), # warning: use .view(), not .squeeze() for Classify models + fname=self.save_dir / f"train_batch{ni}.jpg", + on_plot=self.on_plot, + ) diff --git a/ultralytics/models/yolo/classify/val.py b/ultralytics/models/yolo/classify/val.py new file mode 100644 index 0000000000000000000000000000000000000000..0ae9a012c19aa36f8c64eea61235809e6b3c0cfe --- /dev/null +++ b/ultralytics/models/yolo/classify/val.py @@ -0,0 +1,117 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import torch + +from ultralytics.data import ClassificationDataset, build_dataloader +from ultralytics.engine.validator import BaseValidator +from ultralytics.utils import LOGGER +from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix +from ultralytics.utils.plotting import plot_images + + +class ClassificationValidator(BaseValidator): + """ + A class extending the BaseValidator class for validation based on a classification model. + + Notes: + - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'. 
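+
+    Validation reports top-1 and top-5 accuracy through ClassifyMetrics, and passing plots=True also saves
+    confusion-matrix plots in finalize_metrics(). A sketch with an illustrative dataset:
+
+        ```python
+        from ultralytics.models.yolo.classify import ClassificationValidator
+
+        args = dict(model="yolov8n-cls.pt", data="imagenet10", plots=True)
+        validator = ClassificationValidator(args=args)
+        validator()  # prints top-1 / top-5 accuracy via print_results()
+        ```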
+ + Example: + ```python + from ultralytics.models.yolo.classify import ClassificationValidator + + args = dict(model="yolov8n-cls.pt", data="imagenet10") + validator = ClassificationValidator(args=args) + validator() + ``` + """ + + def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): + """Initializes ClassificationValidator instance with args, dataloader, save_dir, and progress bar.""" + super().__init__(dataloader, save_dir, pbar, args, _callbacks) + self.targets = None + self.pred = None + self.args.task = "classify" + self.metrics = ClassifyMetrics() + + def get_desc(self): + """Returns a formatted string summarizing classification metrics.""" + return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc") + + def init_metrics(self, model): + """Initialize confusion matrix, class names, and top-1 and top-5 accuracy.""" + self.names = model.names + self.nc = len(model.names) + self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, task="classify") + self.pred = [] + self.targets = [] + + def preprocess(self, batch): + """Preprocesses input batch and returns it.""" + batch["img"] = batch["img"].to(self.device, non_blocking=True) + batch["img"] = batch["img"].half() if self.args.half else batch["img"].float() + batch["cls"] = batch["cls"].to(self.device) + return batch + + def update_metrics(self, preds, batch): + """Updates running metrics with model predictions and batch targets.""" + n5 = min(len(self.names), 5) + self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu()) + self.targets.append(batch["cls"].type(torch.int32).cpu()) + + def finalize_metrics(self, *args, **kwargs): + """Finalizes metrics of the model such as confusion_matrix and speed.""" + self.confusion_matrix.process_cls_preds(self.pred, self.targets) + if self.args.plots: + for normalize in True, False: + self.confusion_matrix.plot( + save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot + ) + self.metrics.speed = self.speed + self.metrics.confusion_matrix = self.confusion_matrix + self.metrics.save_dir = self.save_dir + + def postprocess(self, preds): + """Preprocesses the classification predictions.""" + return preds[0] if isinstance(preds, (list, tuple)) else preds + + def get_stats(self): + """Returns a dictionary of metrics obtained by processing targets and predictions.""" + self.metrics.process(self.targets, self.pred) + return self.metrics.results_dict + + def build_dataset(self, img_path): + """Creates and returns a ClassificationDataset instance using given image path and preprocessing parameters.""" + return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split) + + def get_dataloader(self, dataset_path, batch_size): + """Builds and returns a data loader for classification tasks with given parameters.""" + dataset = self.build_dataset(dataset_path) + return build_dataloader(dataset, batch_size, self.args.workers, rank=-1) + + def print_results(self): + """Prints evaluation metrics for YOLO object detection model.""" + pf = "%22s" + "%11.3g" * len(self.metrics.keys) # print format + LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5)) + + def plot_val_samples(self, batch, ni): + """Plot validation image samples.""" + plot_images( + images=batch["img"], + batch_idx=torch.arange(len(batch["img"])), + cls=batch["cls"].view(-1), # warning: use .view(), not .squeeze() for Classify models + fname=self.save_dir / f"val_batch{ni}_labels.jpg", + 
names=self.names, + on_plot=self.on_plot, + ) + + def plot_predictions(self, batch, preds, ni): + """Plots predicted bounding boxes on input images and saves the result.""" + plot_images( + batch["img"], + batch_idx=torch.arange(len(batch["img"])), + cls=torch.argmax(preds, dim=1), + fname=self.save_dir / f"val_batch{ni}_pred.jpg", + names=self.names, + on_plot=self.on_plot, + ) # pred diff --git a/ultralytics/models/yolo/detect/__init__.py b/ultralytics/models/yolo/detect/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..caece94ae0c06d51e8a4a4c5a083d74a6c731c90 --- /dev/null +++ b/ultralytics/models/yolo/detect/__init__.py @@ -0,0 +1,7 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from .predict import DetectionPredictor +from .train import DetectionTrainer +from .val import DetectionValidator + +__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator" diff --git a/ultralytics/models/yolo/detect/predict.py b/ultralytics/models/yolo/detect/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..4d9da8966eb8fdfb023db46e0f4d9679a576966e --- /dev/null +++ b/ultralytics/models/yolo/detect/predict.py @@ -0,0 +1,41 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.engine.predictor import BasePredictor +from ultralytics.engine.results import Results +from ultralytics.utils import ops + + +class DetectionPredictor(BasePredictor): + """ + A class extending the BasePredictor class for prediction based on a detection model. + + Example: + ```python + from ultralytics.utils import ASSETS + from ultralytics.models.yolo.detect import DetectionPredictor + + args = dict(model="yolo11n.pt", source=ASSETS) + predictor = DetectionPredictor(overrides=args) + predictor.predict_cli() + ``` + """ + + def postprocess(self, preds, img, orig_imgs): + """Post-processes predictions and returns a list of Results objects.""" + preds = ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + agnostic=self.args.agnostic_nms, + max_det=self.args.max_det, + classes=self.args.classes, + ) + + if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list + orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) + + results = [] + for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]): + pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) + results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred)) + return results diff --git a/ultralytics/models/yolo/detect/train.py b/ultralytics/models/yolo/detect/train.py new file mode 100644 index 0000000000000000000000000000000000000000..eea16e73af15924415c1b47aaacd8a8948739f27 --- /dev/null +++ b/ultralytics/models/yolo/detect/train.py @@ -0,0 +1,150 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import math +import random +from copy import copy + +import numpy as np +import torch.nn as nn + +from ultralytics.data import build_dataloader, build_yolo_dataset +from ultralytics.engine.trainer import BaseTrainer +from ultralytics.models import yolo +from ultralytics.nn.tasks import DetectionModel +from ultralytics.utils import LOGGER, RANK +from ultralytics.utils.plotting import plot_images, plot_labels, plot_results +from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first + + +class DetectionTrainer(BaseTrainer): + """ + A class extending the BaseTrainer class for training based on a detection model. 
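+
+    Beyond the basic usage in the example below, preprocess_batch() implements optional multi-scale training; the
+    sketch assumes a `multi_scale` flag is exposed in the training configuration, and dataset/epoch values are
+    illustrative:
+
+        ```python
+        from ultralytics.models.yolo.detect import DetectionTrainer
+
+        args = dict(model="yolo11n.pt", data="coco8.yaml", epochs=3, imgsz=640, multi_scale=True)
+        trainer = DetectionTrainer(overrides=args)
+        trainer.train()
+        ```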
+ + Example: + ```python + from ultralytics.models.yolo.detect import DetectionTrainer + + args = dict(model="yolo11n.pt", data="coco8.yaml", epochs=3) + trainer = DetectionTrainer(overrides=args) + trainer.train() + ``` + """ + + def build_dataset(self, img_path, mode="train", batch=None): + """ + Build YOLO Dataset. + + Args: + img_path (str): Path to the folder containing images. + mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode. + batch (int, optional): Size of batches, this is for `rect`. Defaults to None. + """ + gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32) + return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs) + + def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"): + """Construct and return dataloader.""" + assert mode in {"train", "val"}, f"Mode must be 'train' or 'val', not {mode}." + with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP + dataset = self.build_dataset(dataset_path, mode, batch_size) + shuffle = mode == "train" + if getattr(dataset, "rect", False) and shuffle: + LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False") + shuffle = False + workers = self.args.workers if mode == "train" else self.args.workers * 2 + return build_dataloader(dataset, batch_size, workers, shuffle, rank) # return dataloader + + def preprocess_batch(self, batch): + """Preprocesses a batch of images by scaling and converting to float.""" + batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255 + if self.args.multi_scale: + imgs = batch["img"] + sz = ( + random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride)) + // self.stride + * self.stride + ) # size + sf = sz / max(imgs.shape[2:]) # scale factor + if sf != 1: + ns = [ + math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:] + ] # new shape (stretched to gs-multiple) + imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False) + batch["img"] = imgs + return batch + + def set_model_attributes(self): + """Nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps).""" + # self.args.box *= 3 / nl # scale to layers + # self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers + # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers + self.model.nc = self.data["nc"] # attach number of classes to model + self.model.names = self.data["names"] # attach class names to model + self.model.args = self.args # attach hyperparameters to model + # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc + + def get_model(self, cfg=None, weights=None, verbose=True): + """Return a YOLO detection model.""" + model = DetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1) + if weights: + model.load(weights) + return model + + def get_validator(self): + """Returns a DetectionValidator for YOLO model validation.""" + self.loss_names = "box_loss", "cls_loss", "dfl_loss" + return yolo.detect.DetectionValidator( + self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks + ) + + def label_loss_items(self, loss_items=None, prefix="train"): + """ + Returns a loss dict with labelled training loss items tensor. 
+ + Not needed for classification but necessary for segmentation & detection + """ + keys = [f"{prefix}/{x}" for x in self.loss_names] + if loss_items is not None: + loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats + return dict(zip(keys, loss_items)) + else: + return keys + + def progress_string(self): + """Returns a formatted string of training progress with epoch, GPU memory, loss, instances and size.""" + return ("\n" + "%11s" * (4 + len(self.loss_names))) % ( + "Epoch", + "GPU_mem", + *self.loss_names, + "Instances", + "Size", + ) + + def plot_training_samples(self, batch, ni): + """Plots training samples with their annotations.""" + plot_images( + images=batch["img"], + batch_idx=batch["batch_idx"], + cls=batch["cls"].squeeze(-1), + bboxes=batch["bboxes"], + paths=batch["im_file"], + fname=self.save_dir / f"train_batch{ni}.jpg", + on_plot=self.on_plot, + ) + + def plot_metrics(self): + """Plots metrics from a CSV file.""" + plot_results(file=self.csv, on_plot=self.on_plot) # save results.png + + def plot_training_labels(self): + """Create a labeled training plot of the YOLO model.""" + boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0) + cls = np.concatenate([lb["cls"] for lb in self.train_loader.dataset.labels], 0) + plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot) + + def auto_batch(self): + """Get batch size by calculating memory occupation of model.""" + train_dataset = self.build_dataset(self.trainset, mode="train", batch=16) + # 4 for mosaic augmentation + max_num_obj = max(len(label["cls"]) for label in train_dataset.labels) * 4 + return super().auto_batch(max_num_obj) diff --git a/ultralytics/models/yolo/detect/val.py b/ultralytics/models/yolo/detect/val.py new file mode 100644 index 0000000000000000000000000000000000000000..d5fcbfe5bcc94ea5f0f59076a2e39543063a9a6a --- /dev/null +++ b/ultralytics/models/yolo/detect/val.py @@ -0,0 +1,337 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import os +from pathlib import Path + +import numpy as np +import torch + +from ultralytics.data import build_dataloader, build_yolo_dataset, converter +from ultralytics.engine.validator import BaseValidator +from ultralytics.utils import LOGGER, ops +from ultralytics.utils.checks import check_requirements +from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou +from ultralytics.utils.plotting import output_to_target, plot_images + + +class DetectionValidator(BaseValidator): + """ + A class extending the BaseValidator class for validation based on a detection model. 
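+
+    When validating COCO or LVIS data with save_json=True, predictions are serialized by pred_to_json() and scored
+    in eval_json() with pycocotools or lvis. A sketch with illustrative arguments, in addition to the example below:
+
+        ```python
+        from ultralytics.models.yolo.detect import DetectionValidator
+
+        args = dict(model="yolo11n.pt", data="coco.yaml", save_json=True)
+        validator = DetectionValidator(args=args)
+        validator()  # eval_json() reads predictions.json and reports COCO-style mAP
+        ```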
+ + Example: + ```python + from ultralytics.models.yolo.detect import DetectionValidator + + args = dict(model="yolo11n.pt", data="coco8.yaml") + validator = DetectionValidator(args=args) + validator() + ``` + """ + + def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): + """Initialize detection model with necessary variables and settings.""" + super().__init__(dataloader, save_dir, pbar, args, _callbacks) + self.nt_per_class = None + self.nt_per_image = None + self.is_coco = False + self.is_lvis = False + self.class_map = None + self.args.task = "detect" + self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot) + self.iouv = torch.linspace(0.5, 0.95, 10) # IoU vector for mAP@0.5:0.95 + self.niou = self.iouv.numel() + self.lb = [] # for autolabelling + if self.args.save_hybrid: + LOGGER.warning( + "WARNING ⚠️ 'save_hybrid=True' will append ground truth to predictions for autolabelling.\n" + "WARNING ⚠️ 'save_hybrid=True' will cause incorrect mAP.\n" + ) + + def preprocess(self, batch): + """Preprocesses batch of images for YOLO training.""" + batch["img"] = batch["img"].to(self.device, non_blocking=True) + batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255 + for k in ["batch_idx", "cls", "bboxes"]: + batch[k] = batch[k].to(self.device) + + if self.args.save_hybrid: + height, width = batch["img"].shape[2:] + nb = len(batch["img"]) + bboxes = batch["bboxes"] * torch.tensor((width, height, width, height), device=self.device) + self.lb = [ + torch.cat([batch["cls"][batch["batch_idx"] == i], bboxes[batch["batch_idx"] == i]], dim=-1) + for i in range(nb) + ] + + return batch + + def init_metrics(self, model): + """Initialize evaluation metrics for YOLO.""" + val = self.data.get(self.args.split, "") # validation path + self.is_coco = ( + isinstance(val, str) + and "coco" in val + and (val.endswith(f"{os.sep}val2017.txt") or val.endswith(f"{os.sep}test-dev2017.txt")) + ) # is COCO + self.is_lvis = isinstance(val, str) and "lvis" in val and not self.is_coco # is LVIS + self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1, len(model.names) + 1)) + self.args.save_json |= self.args.val and (self.is_coco or self.is_lvis) and not self.training # run final val + self.names = model.names + self.nc = len(model.names) + self.metrics.names = self.names + self.metrics.plot = self.args.plots + self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf) + self.seen = 0 + self.jdict = [] + self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[]) + + def get_desc(self): + """Return a formatted string summarizing class metrics of YOLO model.""" + return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)") + + def postprocess(self, preds): + """Apply Non-maximum suppression to prediction outputs.""" + return ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + labels=self.lb, + multi_label=True, + agnostic=self.args.single_cls or self.args.agnostic_nms, + max_det=self.args.max_det, + ) + + def _prepare_batch(self, si, batch): + """Prepares a batch of images and annotations for validation.""" + idx = batch["batch_idx"] == si + cls = batch["cls"][idx].squeeze(-1) + bbox = batch["bboxes"][idx] + ori_shape = batch["ori_shape"][si] + imgsz = batch["img"].shape[2:] + ratio_pad = batch["ratio_pad"][si] + if len(cls): + bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]] # target 
boxes + ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad) # native-space labels + return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad} + + def _prepare_pred(self, pred, pbatch): + """Prepares a batch of images and annotations for validation.""" + predn = pred.clone() + ops.scale_boxes( + pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"] + ) # native-space pred + return predn + + def update_metrics(self, preds, batch): + """Metrics.""" + for si, pred in enumerate(preds): + self.seen += 1 + npr = len(pred) + stat = dict( + conf=torch.zeros(0, device=self.device), + pred_cls=torch.zeros(0, device=self.device), + tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), + ) + pbatch = self._prepare_batch(si, batch) + cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox") + nl = len(cls) + stat["target_cls"] = cls + stat["target_img"] = cls.unique() + if npr == 0: + if nl: + for k in self.stats.keys(): + self.stats[k].append(stat[k]) + if self.args.plots: + self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls) + continue + + # Predictions + if self.args.single_cls: + pred[:, 5] = 0 + predn = self._prepare_pred(pred, pbatch) + stat["conf"] = predn[:, 4] + stat["pred_cls"] = predn[:, 5] + + # Evaluate + if nl: + stat["tp"] = self._process_batch(predn, bbox, cls) + if self.args.plots: + self.confusion_matrix.process_batch(predn, bbox, cls) + for k in self.stats.keys(): + self.stats[k].append(stat[k]) + + # Save + if self.args.save_json: + self.pred_to_json(predn, batch["im_file"][si]) + if self.args.save_txt: + self.save_one_txt( + predn, + self.args.save_conf, + pbatch["ori_shape"], + self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt", + ) + + def finalize_metrics(self, *args, **kwargs): + """Set final values for metrics speed and confusion matrix.""" + self.metrics.speed = self.speed + self.metrics.confusion_matrix = self.confusion_matrix + + def get_stats(self): + """Returns metrics statistics and results dictionary.""" + stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()} # to numpy + self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=self.nc) + self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=self.nc) + stats.pop("target_img", None) + if len(stats) and stats["tp"].any(): + self.metrics.process(**stats) + return self.metrics.results_dict + + def print_results(self): + """Prints training/validation set metrics per class.""" + pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys) # print format + LOGGER.info(pf % ("all", self.seen, self.nt_per_class.sum(), *self.metrics.mean_results())) + if self.nt_per_class.sum() == 0: + LOGGER.warning(f"WARNING ⚠️ no labels found in {self.args.task} set, can not compute metrics without labels") + + # Print results per class + if self.args.verbose and not self.training and self.nc > 1 and len(self.stats): + for i, c in enumerate(self.metrics.ap_class_index): + LOGGER.info( + pf % (self.names[c], self.nt_per_image[c], self.nt_per_class[c], *self.metrics.class_result(i)) + ) + + if self.args.plots: + for normalize in True, False: + self.confusion_matrix.plot( + save_dir=self.save_dir, names=self.names.values(), normalize=normalize, on_plot=self.on_plot + ) + + def _process_batch(self, detections, gt_bboxes, gt_cls): + """ + Return correct prediction matrix. 
+ + Args: + detections (torch.Tensor): Tensor of shape (N, 6) representing detections where each detection is + (x1, y1, x2, y2, conf, class). + gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground-truth bounding box coordinates. Each + bounding box is of the format: (x1, y1, x2, y2). + gt_cls (torch.Tensor): Tensor of shape (M,) representing target class indices. + + Returns: + (torch.Tensor): Correct prediction matrix of shape (N, 10) for 10 IoU levels. + + Note: + The function does not return any value directly usable for metrics calculation. Instead, it provides an + intermediate representation used for evaluating predictions against ground truth. + """ + iou = box_iou(gt_bboxes, detections[:, :4]) + return self.match_predictions(detections[:, 5], gt_cls, iou) + + def build_dataset(self, img_path, mode="val", batch=None): + """ + Build YOLO Dataset. + + Args: + img_path (str): Path to the folder containing images. + mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode. + batch (int, optional): Size of batches, this is for `rect`. Defaults to None. + """ + return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride) + + def get_dataloader(self, dataset_path, batch_size): + """Construct and return dataloader.""" + dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val") + return build_dataloader(dataset, batch_size, self.args.workers, shuffle=False, rank=-1) # return dataloader + + def plot_val_samples(self, batch, ni): + """Plot validation image samples.""" + plot_images( + batch["img"], + batch["batch_idx"], + batch["cls"].squeeze(-1), + batch["bboxes"], + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_labels.jpg", + names=self.names, + on_plot=self.on_plot, + ) + + def plot_predictions(self, batch, preds, ni): + """Plots predicted bounding boxes on input images and saves the result.""" + plot_images( + batch["img"], + *output_to_target(preds, max_det=self.args.max_det), + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_pred.jpg", + names=self.names, + on_plot=self.on_plot, + ) # pred + + def save_one_txt(self, predn, save_conf, shape, file): + """Save YOLO detections to a txt file in normalized coordinates in a specific format.""" + from ultralytics.engine.results import Results + + Results( + np.zeros((shape[0], shape[1]), dtype=np.uint8), + path=None, + names=self.names, + boxes=predn[:, :6], + ).save_txt(file, save_conf=save_conf) + + def pred_to_json(self, predn, filename): + """Serialize YOLO predictions to COCO json format.""" + stem = Path(filename).stem + image_id = int(stem) if stem.isnumeric() else stem + box = ops.xyxy2xywh(predn[:, :4]) # xywh + box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner + for p, b in zip(predn.tolist(), box.tolist()): + self.jdict.append( + { + "image_id": image_id, + "category_id": self.class_map[int(p[5])], + "bbox": [round(x, 3) for x in b], + "score": round(p[4], 5), + } + ) + + def eval_json(self, stats): + """Evaluates YOLO output in JSON format and returns performance statistics.""" + if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict): + pred_json = self.save_dir / "predictions.json" # predictions + anno_json = ( + self.data["path"] + / "annotations" + / ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json") + ) # annotations + pkg = "pycocotools" if self.is_coco else "lvis" + LOGGER.info(f"\nEvaluating {pkg} mAP using 
{pred_json} and {anno_json}...") + try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb + for x in pred_json, anno_json: + assert x.is_file(), f"{x} file not found" + check_requirements("pycocotools>=2.0.6" if self.is_coco else "lvis>=0.5.3") + if self.is_coco: + from pycocotools.coco import COCO # noqa + from pycocotools.cocoeval import COCOeval # noqa + + anno = COCO(str(anno_json)) # init annotations api + pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path) + val = COCOeval(anno, pred, "bbox") + else: + from lvis import LVIS, LVISEval + + anno = LVIS(str(anno_json)) # init annotations api + pred = anno._load_json(str(pred_json)) # init predictions api (must pass string, not Path) + val = LVISEval(anno, pred, "bbox") + val.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval + val.evaluate() + val.accumulate() + val.summarize() + if self.is_lvis: + val.print_results() # explicitly call print_results + # update mAP50-95 and mAP50 + stats[self.metrics.keys[-1]], stats[self.metrics.keys[-2]] = ( + val.stats[:2] if self.is_coco else [val.results["AP50"], val.results["AP"]] + ) + except Exception as e: + LOGGER.warning(f"{pkg} unable to run: {e}") + return stats diff --git a/ultralytics/models/yolo/model.py b/ultralytics/models/yolo/model.py new file mode 100644 index 0000000000000000000000000000000000000000..03bf15355928e7f1bf01903055e1524bfe8556d5 --- /dev/null +++ b/ultralytics/models/yolo/model.py @@ -0,0 +1,111 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from pathlib import Path + +from ultralytics.engine.model import Model +from ultralytics.models import yolo +from ultralytics.nn.tasks import ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel +from ultralytics.utils import ROOT, yaml_load + + +class YOLO(Model): + """YOLO (You Only Look Once) object detection model.""" + + def __init__(self, model="yolo11n.pt", task=None, verbose=False): + """Initialize YOLO model, switching to YOLOWorld if model filename contains '-world'.""" + path = Path(model) + if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}: # if YOLOWorld PyTorch model + new_instance = YOLOWorld(path, verbose=verbose) + self.__class__ = type(new_instance) + self.__dict__ = new_instance.__dict__ + else: + # Continue with default YOLO initialization + super().__init__(model=model, task=task, verbose=verbose) + + @property + def task_map(self): + """Map head to model, trainer, validator, and predictor classes.""" + return { + "classify": { + "model": ClassificationModel, + "trainer": yolo.classify.ClassificationTrainer, + "validator": yolo.classify.ClassificationValidator, + "predictor": yolo.classify.ClassificationPredictor, + }, + "detect": { + "model": DetectionModel, + "trainer": yolo.detect.DetectionTrainer, + "validator": yolo.detect.DetectionValidator, + "predictor": yolo.detect.DetectionPredictor, + }, + "segment": { + "model": SegmentationModel, + "trainer": yolo.segment.SegmentationTrainer, + "validator": yolo.segment.SegmentationValidator, + "predictor": yolo.segment.SegmentationPredictor, + }, + "pose": { + "model": PoseModel, + "trainer": yolo.pose.PoseTrainer, + "validator": yolo.pose.PoseValidator, + "predictor": yolo.pose.PosePredictor, + }, + "obb": { + "model": OBBModel, + "trainer": yolo.obb.OBBTrainer, + "validator": yolo.obb.OBBValidator, + "predictor": yolo.obb.OBBPredictor, + }, + } + + +class YOLOWorld(Model): 
+ """YOLO-World object detection model.""" + + def __init__(self, model="yolov8s-world.pt", verbose=False) -> None: + """ + Initialize YOLOv8-World model with a pre-trained model file. + + Loads a YOLOv8-World model for object detection. If no custom class names are provided, it assigns default + COCO class names. + + Args: + model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats. + verbose (bool): If True, prints additional information during initialization. + """ + super().__init__(model=model, task="detect", verbose=verbose) + + # Assign default COCO class names when there are no custom names + if not hasattr(self.model, "names"): + self.model.names = yaml_load(ROOT / "cfg/datasets/coco8.yaml").get("names") + + @property + def task_map(self): + """Map head to model, validator, and predictor classes.""" + return { + "detect": { + "model": WorldModel, + "validator": yolo.detect.DetectionValidator, + "predictor": yolo.detect.DetectionPredictor, + "trainer": yolo.world.WorldTrainer, + } + } + + def set_classes(self, classes): + """ + Set classes. + + Args: + classes (List(str)): A list of categories i.e. ["person"]. + """ + self.model.set_classes(classes) + # Remove background if it's given + background = " " + if background in classes: + classes.remove(background) + self.model.names = classes + + # Reset method class names + # self.predictor = None # reset predictor otherwise old names remain + if self.predictor: + self.predictor.model.names = classes diff --git a/ultralytics/models/yolo/obb/__init__.py b/ultralytics/models/yolo/obb/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..61e3e3c6a82b9addf0206bfe0bab63fa34c26108 --- /dev/null +++ b/ultralytics/models/yolo/obb/__init__.py @@ -0,0 +1,7 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from .predict import OBBPredictor +from .train import OBBTrainer +from .val import OBBValidator + +__all__ = "OBBPredictor", "OBBTrainer", "OBBValidator" diff --git a/ultralytics/models/yolo/obb/predict.py b/ultralytics/models/yolo/obb/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..1608b355adf7405356c0c3ab26a2821bd25a2230 --- /dev/null +++ b/ultralytics/models/yolo/obb/predict.py @@ -0,0 +1,53 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import torch + +from ultralytics.engine.results import Results +from ultralytics.models.yolo.detect.predict import DetectionPredictor +from ultralytics.utils import DEFAULT_CFG, ops + + +class OBBPredictor(DetectionPredictor): + """ + A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model. 
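+
+    Rotated boxes are assembled in postprocess() as (xywh, rotation, confidence, class) and exposed on each Results
+    object through its `obb` field. A minimal sketch of programmatic use through the YOLO wrapper, which maps the
+    'obb' task to this predictor (the image path is illustrative):
+
+        ```python
+        from ultralytics import YOLO
+
+        model = YOLO("yolov8n-obb.pt")
+        results = model.predict("path/to/image.jpg")
+        print(results[0].obb)  # rotated boxes: xywh, rotation angle, confidence, class
+        ```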
+ + Example: + ```python + from ultralytics.utils import ASSETS + from ultralytics.models.yolo.obb import OBBPredictor + + args = dict(model="yolov8n-obb.pt", source=ASSETS) + predictor = OBBPredictor(overrides=args) + predictor.predict_cli() + ``` + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """Initializes OBBPredictor with optional model and data configuration overrides.""" + super().__init__(cfg, overrides, _callbacks) + self.args.task = "obb" + + def postprocess(self, preds, img, orig_imgs): + """Post-processes predictions and returns a list of Results objects.""" + preds = ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + agnostic=self.args.agnostic_nms, + max_det=self.args.max_det, + nc=len(self.model.names), + classes=self.args.classes, + rotated=True, + ) + + if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list + orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) + + results = [] + for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]): + rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1)) + rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True) + # xywh, r, conf, cls + obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1) + results.append(Results(orig_img, path=img_path, names=self.model.names, obb=obb)) + return results diff --git a/ultralytics/models/yolo/obb/train.py b/ultralytics/models/yolo/obb/train.py new file mode 100644 index 0000000000000000000000000000000000000000..41b7478b0b38a8230e4fc3491acf48dde9f658e6 --- /dev/null +++ b/ultralytics/models/yolo/obb/train.py @@ -0,0 +1,44 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from copy import copy + +from ultralytics.models import yolo +from ultralytics.nn.tasks import OBBModel +from ultralytics.utils import DEFAULT_CFG, RANK + + +class OBBTrainer(yolo.detect.DetectionTrainer): + """ + A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model. 
+ + Example: + ```python + from ultralytics.models.yolo.obb import OBBTrainer + + args = dict(model="yolov8n-obb.pt", data="dota8.yaml", epochs=3) + trainer = OBBTrainer(overrides=args) + trainer.train() + ``` + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """Initialize a OBBTrainer object with given arguments.""" + if overrides is None: + overrides = {} + overrides["task"] = "obb" + super().__init__(cfg, overrides, _callbacks) + + def get_model(self, cfg=None, weights=None, verbose=True): + """Return OBBModel initialized with specified config and weights.""" + model = OBBModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1) + if weights: + model.load(weights) + + return model + + def get_validator(self): + """Return an instance of OBBValidator for validation of YOLO model.""" + self.loss_names = "box_loss", "cls_loss", "dfl_loss" + return yolo.obb.OBBValidator( + self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks + ) diff --git a/ultralytics/models/yolo/obb/val.py b/ultralytics/models/yolo/obb/val.py new file mode 100644 index 0000000000000000000000000000000000000000..d75f16295978da19a3beadc6fc4cc8a1209963c2 --- /dev/null +++ b/ultralytics/models/yolo/obb/val.py @@ -0,0 +1,203 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from pathlib import Path + +import torch + +from ultralytics.models.yolo.detect import DetectionValidator +from ultralytics.utils import LOGGER, ops +from ultralytics.utils.metrics import OBBMetrics, batch_probiou +from ultralytics.utils.plotting import output_to_rotated_target, plot_images + + +class OBBValidator(DetectionValidator): + """ + A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model. + + Example: + ```python + from ultralytics.models.yolo.obb import OBBValidator + + args = dict(model="yolov8n-obb.pt", data="dota8.yaml") + validator = OBBValidator(args=args) + validator(model=args["model"]) + ``` + """ + + def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): + """Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.""" + super().__init__(dataloader, save_dir, pbar, args, _callbacks) + self.args.task = "obb" + self.metrics = OBBMetrics(save_dir=self.save_dir, plot=True, on_plot=self.on_plot) + + def init_metrics(self, model): + """Initialize evaluation metrics for YOLO.""" + super().init_metrics(model) + val = self.data.get(self.args.split, "") # validation path + self.is_dota = isinstance(val, str) and "DOTA" in val # is COCO + + def postprocess(self, preds): + """Apply Non-maximum suppression to prediction outputs.""" + return ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + labels=self.lb, + nc=self.nc, + multi_label=True, + agnostic=self.args.single_cls or self.args.agnostic_nms, + max_det=self.args.max_det, + rotated=True, + ) + + def _process_batch(self, detections, gt_bboxes, gt_cls): + """ + Perform computation of the correct prediction matrix for a batch of detections and ground truth bounding boxes. + + Args: + detections (torch.Tensor): A tensor of shape (N, 7) representing the detected bounding boxes and associated + data. Each detection is represented as (x1, y1, x2, y2, conf, class, angle). + gt_bboxes (torch.Tensor): A tensor of shape (M, 5) representing the ground truth bounding boxes. Each box is + represented as (x1, y1, x2, y2, angle). 
+ gt_cls (torch.Tensor): A tensor of shape (M,) representing class labels for the ground truth bounding boxes. + + Returns: + (torch.Tensor): The correct prediction matrix with shape (N, 10), which includes 10 IoU (Intersection over + Union) levels for each detection, indicating the accuracy of predictions compared to the ground truth. + + Example: + ```python + detections = torch.rand(100, 7) # 100 sample detections + gt_bboxes = torch.rand(50, 5) # 50 sample ground truth boxes + gt_cls = torch.randint(0, 5, (50,)) # 50 ground truth class labels + correct_matrix = OBBValidator._process_batch(detections, gt_bboxes, gt_cls) + ``` + + Note: + This method relies on `batch_probiou` to calculate IoU between detections and ground truth bounding boxes. + """ + iou = batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1)) + return self.match_predictions(detections[:, 5], gt_cls, iou) + + def _prepare_batch(self, si, batch): + """Prepares and returns a batch for OBB validation.""" + idx = batch["batch_idx"] == si + cls = batch["cls"][idx].squeeze(-1) + bbox = batch["bboxes"][idx] + ori_shape = batch["ori_shape"][si] + imgsz = batch["img"].shape[2:] + ratio_pad = batch["ratio_pad"][si] + if len(cls): + bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]) # target boxes + ops.scale_boxes(imgsz, bbox, ori_shape, ratio_pad=ratio_pad, xywh=True) # native-space labels + return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad} + + def _prepare_pred(self, pred, pbatch): + """Prepares and returns a batch for OBB validation with scaled and padded bounding boxes.""" + predn = pred.clone() + ops.scale_boxes( + pbatch["imgsz"], predn[:, :4], pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True + ) # native-space pred + return predn + + def plot_predictions(self, batch, preds, ni): + """Plots predicted bounding boxes on input images and saves the result.""" + plot_images( + batch["img"], + *output_to_rotated_target(preds, max_det=self.args.max_det), + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_pred.jpg", + names=self.names, + on_plot=self.on_plot, + ) # pred + + def pred_to_json(self, predn, filename): + """Serialize YOLO predictions to COCO json format.""" + stem = Path(filename).stem + image_id = int(stem) if stem.isnumeric() else stem + rbox = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1) + poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8) + for i, (r, b) in enumerate(zip(rbox.tolist(), poly.tolist())): + self.jdict.append( + { + "image_id": image_id, + "category_id": self.class_map[int(predn[i, 5].item())], + "score": round(predn[i, 4].item(), 5), + "rbox": [round(x, 3) for x in r], + "poly": [round(x, 3) for x in b], + } + ) + + def save_one_txt(self, predn, save_conf, shape, file): + """Save YOLO detections to a txt file in normalized coordinates in a specific format.""" + import numpy as np + + from ultralytics.engine.results import Results + + rboxes = torch.cat([predn[:, :4], predn[:, -1:]], dim=-1) + # xywh, r, conf, cls + obb = torch.cat([rboxes, predn[:, 4:6]], dim=-1) + Results( + np.zeros((shape[0], shape[1]), dtype=np.uint8), + path=None, + names=self.names, + obb=obb, + ).save_txt(file, save_conf=save_conf) + + def eval_json(self, stats): + """Evaluates YOLO output in JSON format and returns performance statistics.""" + if self.args.save_json and self.is_dota and len(self.jdict): + import json + import re + from collections import defaultdict + + pred_json = self.save_dir / 
"predictions.json" # predictions + pred_txt = self.save_dir / "predictions_txt" # predictions + pred_txt.mkdir(parents=True, exist_ok=True) + data = json.load(open(pred_json)) + # Save split results + LOGGER.info(f"Saving predictions with DOTA format to {pred_txt}...") + for d in data: + image_id = d["image_id"] + score = d["score"] + classname = self.names[d["category_id"] - 1].replace(" ", "-") + p = d["poly"] + + with open(f"{pred_txt / f'Task1_{classname}'}.txt", "a") as f: + f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n") + # Save merged results, this could result slightly lower map than using official merging script, + # because of the probiou calculation. + pred_merged_txt = self.save_dir / "predictions_merged_txt" # predictions + pred_merged_txt.mkdir(parents=True, exist_ok=True) + merged_results = defaultdict(list) + LOGGER.info(f"Saving merged predictions with DOTA format to {pred_merged_txt}...") + for d in data: + image_id = d["image_id"].split("__")[0] + pattern = re.compile(r"\d+___\d+") + x, y = (int(c) for c in re.findall(pattern, d["image_id"])[0].split("___")) + bbox, score, cls = d["rbox"], d["score"], d["category_id"] - 1 + bbox[0] += x + bbox[1] += y + bbox.extend([score, cls]) + merged_results[image_id].append(bbox) + for image_id, bbox in merged_results.items(): + bbox = torch.tensor(bbox) + max_wh = torch.max(bbox[:, :2]).item() * 2 + c = bbox[:, 6:7] * max_wh # classes + scores = bbox[:, 5] # scores + b = bbox[:, :5].clone() + b[:, :2] += c + # 0.3 could get results close to the ones from official merging script, even slightly better. + i = ops.nms_rotated(b, scores, 0.3) + bbox = bbox[i] + + b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8) + for x in torch.cat([b, bbox[:, 5:7]], dim=-1).tolist(): + classname = self.names[int(x[-1])].replace(" ", "-") + p = [round(i, 3) for i in x[:-2]] # poly + score = round(x[-2], 3) + + with open(f"{pred_merged_txt / f'Task1_{classname}'}.txt", "a") as f: + f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n") + + return stats diff --git a/ultralytics/models/yolo/pose/__init__.py b/ultralytics/models/yolo/pose/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..396167b08f88632230296306abdec8eb508f8b78 --- /dev/null +++ b/ultralytics/models/yolo/pose/__init__.py @@ -0,0 +1,7 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from .predict import PosePredictor +from .train import PoseTrainer +from .val import PoseValidator + +__all__ = "PoseTrainer", "PoseValidator", "PosePredictor" diff --git a/ultralytics/models/yolo/pose/predict.py b/ultralytics/models/yolo/pose/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..a382d1388c71c02a8587ca4b41be8010f74961c8 --- /dev/null +++ b/ultralytics/models/yolo/pose/predict.py @@ -0,0 +1,56 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.engine.results import Results +from ultralytics.models.yolo.detect.predict import DetectionPredictor +from ultralytics.utils import DEFAULT_CFG, LOGGER, ops + + +class PosePredictor(DetectionPredictor): + """ + A class extending the DetectionPredictor class for prediction based on a pose model. 
+ + Example: + ```python + from ultralytics.utils import ASSETS + from ultralytics.models.yolo.pose import PosePredictor + + args = dict(model="yolov8n-pose.pt", source=ASSETS) + predictor = PosePredictor(overrides=args) + predictor.predict_cli() + ``` + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """Initializes PosePredictor, sets task to 'pose' and logs a warning for using 'mps' as device.""" + super().__init__(cfg, overrides, _callbacks) + self.args.task = "pose" + if isinstance(self.args.device, str) and self.args.device.lower() == "mps": + LOGGER.warning( + "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. " + "See https://github.com/ultralytics/ultralytics/issues/4031." + ) + + def postprocess(self, preds, img, orig_imgs): + """Return detection results for a given input image or list of images.""" + preds = ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + agnostic=self.args.agnostic_nms, + max_det=self.args.max_det, + classes=self.args.classes, + nc=len(self.model.names), + ) + + if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list + orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) + + results = [] + for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]): + pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape).round() + pred_kpts = pred[:, 6:].view(len(pred), *self.model.kpt_shape) if len(pred) else pred[:, 6:] + pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape) + results.append( + Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], keypoints=pred_kpts) + ) + return results diff --git a/ultralytics/models/yolo/pose/train.py b/ultralytics/models/yolo/pose/train.py new file mode 100644 index 0000000000000000000000000000000000000000..5360f2f7d7037658e4f46806779772f49f23f95b --- /dev/null +++ b/ultralytics/models/yolo/pose/train.py @@ -0,0 +1,79 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from copy import copy + +from ultralytics.models import yolo +from ultralytics.nn.tasks import PoseModel +from ultralytics.utils import DEFAULT_CFG, LOGGER +from ultralytics.utils.plotting import plot_images, plot_results + + +class PoseTrainer(yolo.detect.DetectionTrainer): + """ + A class extending the DetectionTrainer class for training based on a pose model. + + Example: + ```python + from ultralytics.models.yolo.pose import PoseTrainer + + args = dict(model="yolov8n-pose.pt", data="coco8-pose.yaml", epochs=3) + trainer = PoseTrainer(overrides=args) + trainer.train() + ``` + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """Initialize a PoseTrainer object with specified configurations and overrides.""" + if overrides is None: + overrides = {} + overrides["task"] = "pose" + super().__init__(cfg, overrides, _callbacks) + + if isinstance(self.args.device, str) and self.args.device.lower() == "mps": + LOGGER.warning( + "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. " + "See https://github.com/ultralytics/ultralytics/issues/4031." 
+ ) + + def get_model(self, cfg=None, weights=None, verbose=True): + """Get pose estimation model with specified configuration and weights.""" + model = PoseModel(cfg, ch=3, nc=self.data["nc"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose) + if weights: + model.load(weights) + + return model + + def set_model_attributes(self): + """Sets keypoints shape attribute of PoseModel.""" + super().set_model_attributes() + self.model.kpt_shape = self.data["kpt_shape"] + + def get_validator(self): + """Returns an instance of the PoseValidator class for validation.""" + self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss" + return yolo.pose.PoseValidator( + self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks + ) + + def plot_training_samples(self, batch, ni): + """Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints.""" + images = batch["img"] + kpts = batch["keypoints"] + cls = batch["cls"].squeeze(-1) + bboxes = batch["bboxes"] + paths = batch["im_file"] + batch_idx = batch["batch_idx"] + plot_images( + images, + batch_idx, + cls, + bboxes, + kpts=kpts, + paths=paths, + fname=self.save_dir / f"train_batch{ni}.jpg", + on_plot=self.on_plot, + ) + + def plot_metrics(self): + """Plots training/val metrics.""" + plot_results(file=self.csv, pose=True, on_plot=self.on_plot) # save results.png diff --git a/ultralytics/models/yolo/pose/val.py b/ultralytics/models/yolo/pose/val.py new file mode 100644 index 0000000000000000000000000000000000000000..67805af427aed99b2b45c377ae35e45c4702e260 --- /dev/null +++ b/ultralytics/models/yolo/pose/val.py @@ -0,0 +1,282 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from pathlib import Path + +import numpy as np +import torch + +from ultralytics.models.yolo.detect import DetectionValidator +from ultralytics.utils import LOGGER, ops +from ultralytics.utils.checks import check_requirements +from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, box_iou, kpt_iou +from ultralytics.utils.plotting import output_to_target, plot_images + + +class PoseValidator(DetectionValidator): + """ + A class extending the DetectionValidator class for validation based on a pose model. + + Example: + ```python + from ultralytics.models.yolo.pose import PoseValidator + + args = dict(model="yolov8n-pose.pt", data="coco8-pose.yaml") + validator = PoseValidator(args=args) + validator() + ``` + """ + + def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): + """Initialize a 'PoseValidator' object with custom parameters and assigned attributes.""" + super().__init__(dataloader, save_dir, pbar, args, _callbacks) + self.sigma = None + self.kpt_shape = None + self.args.task = "pose" + self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot) + if isinstance(self.args.device, str) and self.args.device.lower() == "mps": + LOGGER.warning( + "WARNING ⚠️ Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. " + "See https://github.com/ultralytics/ultralytics/issues/4031." 
+ ) + + def preprocess(self, batch): + """Preprocesses the batch by converting the 'keypoints' data into a float and moving it to the device.""" + batch = super().preprocess(batch) + batch["keypoints"] = batch["keypoints"].to(self.device).float() + return batch + + def get_desc(self): + """Returns description of evaluation metrics in string format.""" + return ("%22s" + "%11s" * 10) % ( + "Class", + "Images", + "Instances", + "Box(P", + "R", + "mAP50", + "mAP50-95)", + "Pose(P", + "R", + "mAP50", + "mAP50-95)", + ) + + def postprocess(self, preds): + """Apply non-maximum suppression and return detections with high confidence scores.""" + return ops.non_max_suppression( + preds, + self.args.conf, + self.args.iou, + labels=self.lb, + multi_label=True, + agnostic=self.args.single_cls or self.args.agnostic_nms, + max_det=self.args.max_det, + nc=self.nc, + ) + + def init_metrics(self, model): + """Initiate pose estimation metrics for YOLO model.""" + super().init_metrics(model) + self.kpt_shape = self.data["kpt_shape"] + is_pose = self.kpt_shape == [17, 3] + nkpt = self.kpt_shape[0] + self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt + self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[]) + + def _prepare_batch(self, si, batch): + """Prepares a batch for processing by converting keypoints to float and moving to device.""" + pbatch = super()._prepare_batch(si, batch) + kpts = batch["keypoints"][batch["batch_idx"] == si] + h, w = pbatch["imgsz"] + kpts = kpts.clone() + kpts[..., 0] *= w + kpts[..., 1] *= h + kpts = ops.scale_coords(pbatch["imgsz"], kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]) + pbatch["kpts"] = kpts + return pbatch + + def _prepare_pred(self, pred, pbatch): + """Prepares and scales keypoints in a batch for pose processing.""" + predn = super()._prepare_pred(pred, pbatch) + nk = pbatch["kpts"].shape[1] + pred_kpts = predn[:, 6:].view(len(predn), nk, -1) + ops.scale_coords(pbatch["imgsz"], pred_kpts, pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"]) + return predn, pred_kpts + + def update_metrics(self, preds, batch): + """Metrics.""" + for si, pred in enumerate(preds): + self.seen += 1 + npr = len(pred) + stat = dict( + conf=torch.zeros(0, device=self.device), + pred_cls=torch.zeros(0, device=self.device), + tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), + tp_p=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), + ) + pbatch = self._prepare_batch(si, batch) + cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox") + nl = len(cls) + stat["target_cls"] = cls + stat["target_img"] = cls.unique() + if npr == 0: + if nl: + for k in self.stats.keys(): + self.stats[k].append(stat[k]) + if self.args.plots: + self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls) + continue + + # Predictions + if self.args.single_cls: + pred[:, 5] = 0 + predn, pred_kpts = self._prepare_pred(pred, pbatch) + stat["conf"] = predn[:, 4] + stat["pred_cls"] = predn[:, 5] + + # Evaluate + if nl: + stat["tp"] = self._process_batch(predn, bbox, cls) + stat["tp_p"] = self._process_batch(predn, bbox, cls, pred_kpts, pbatch["kpts"]) + if self.args.plots: + self.confusion_matrix.process_batch(predn, bbox, cls) + + for k in self.stats.keys(): + self.stats[k].append(stat[k]) + + # Save + if self.args.save_json: + self.pred_to_json(predn, batch["im_file"][si]) + if self.args.save_txt: + self.save_one_txt( + predn, + pred_kpts, + self.args.save_conf, + pbatch["ori_shape"], + self.save_dir / "labels" / 
f"{Path(batch['im_file'][si]).stem}.txt",
+                )
+
+    def _process_batch(self, detections, gt_bboxes, gt_cls, pred_kpts=None, gt_kpts=None):
+        """
+        Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
+
+        Args:
+            detections (torch.Tensor): Tensor with shape (N, 6) representing detection boxes and scores, where each
+                detection is of the format (x1, y1, x2, y2, conf, class).
+            gt_bboxes (torch.Tensor): Tensor with shape (M, 4) representing ground truth bounding boxes, where each
+                box is of the format (x1, y1, x2, y2).
+            gt_cls (torch.Tensor): Tensor with shape (M,) representing ground truth class indices.
+            pred_kpts (torch.Tensor | None): Optional tensor with shape (N, 51) representing predicted keypoints, where
+                51 corresponds to 17 keypoints each having 3 values.
+            gt_kpts (torch.Tensor | None): Optional tensor with shape (M, 51) representing ground truth keypoints.
+
+        Returns:
+            (torch.Tensor): A tensor with shape (N, 10) representing the correct prediction matrix for 10 IoU levels,
+                where N is the number of detections.
+
+        Example:
+            ```python
+            detections = torch.rand(100, 6)  # 100 predictions: (x1, y1, x2, y2, conf, class)
+            gt_bboxes = torch.rand(50, 4)  # 50 ground truth boxes: (x1, y1, x2, y2)
+            gt_cls = torch.randint(0, 2, (50,))  # 50 ground truth class indices
+            pred_kpts = torch.rand(100, 51)  # 100 predicted keypoints
+            gt_kpts = torch.rand(50, 51)  # 50 ground truth keypoints
+            correct_preds = _process_batch(detections, gt_bboxes, gt_cls, pred_kpts, gt_kpts)
+            ```
+
+        Note:
+            The `0.53` scale factor used in the area computation is referenced from
+            https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
+        """
+        if pred_kpts is not None and gt_kpts is not None:
+            # `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
+            area = ops.xyxy2xywh(gt_bboxes)[:, 2:].prod(1) * 0.53
+            iou = kpt_iou(gt_kpts, pred_kpts, sigma=self.sigma, area=area)
+        else:  # boxes
+            iou = box_iou(gt_bboxes, detections[:, :4])
+
+        return self.match_predictions(detections[:, 5], gt_cls, iou)
+
+    def plot_val_samples(self, batch, ni):
+        """Plots and saves validation set samples with ground-truth bounding boxes and keypoints."""
+        plot_images(
+            batch["img"],
+            batch["batch_idx"],
+            batch["cls"].squeeze(-1),
+            batch["bboxes"],
+            kpts=batch["keypoints"],
+            paths=batch["im_file"],
+            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
+            names=self.names,
+            on_plot=self.on_plot,
+        )
+
+    def plot_predictions(self, batch, preds, ni):
+        """Plots predictions for YOLO model."""
+        pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape) for p in preds], 0)
+        plot_images(
+            batch["img"],
+            *output_to_target(preds, max_det=self.args.max_det),
+            kpts=pred_kpts,
+            paths=batch["im_file"],
+            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
+            names=self.names,
+            on_plot=self.on_plot,
+        )  # pred
+
+    def save_one_txt(self, predn, pred_kpts, save_conf, shape, file):
+        """Save YOLO detections to a txt file in normalized coordinates in a specific format."""
+        from ultralytics.engine.results import Results
+
+        Results(
+            np.zeros((shape[0], shape[1]), dtype=np.uint8),
+            path=None,
+            names=self.names,
+            boxes=predn[:, :6],
+            keypoints=pred_kpts,
+        ).save_txt(file, save_conf=save_conf)
+
+    def pred_to_json(self, predn, filename):
+        """Converts YOLO predictions to COCO JSON format."""
+        stem = Path(filename).stem
+        image_id = int(stem) if stem.isnumeric() else stem
+        box = ops.xyxy2xywh(predn[:, :4])  # xywh
+        box[:, :2] -= box[:, 
2:] / 2 # xy center to top-left corner + for p, b in zip(predn.tolist(), box.tolist()): + self.jdict.append( + { + "image_id": image_id, + "category_id": self.class_map[int(p[5])], + "bbox": [round(x, 3) for x in b], + "keypoints": p[6:], + "score": round(p[4], 5), + } + ) + + def eval_json(self, stats): + """Evaluates object detection model using COCO JSON format.""" + if self.args.save_json and self.is_coco and len(self.jdict): + anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json" # annotations + pred_json = self.save_dir / "predictions.json" # predictions + LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...") + try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb + check_requirements("pycocotools>=2.0.6") + from pycocotools.coco import COCO # noqa + from pycocotools.cocoeval import COCOeval # noqa + + for x in anno_json, pred_json: + assert x.is_file(), f"{x} file not found" + anno = COCO(str(anno_json)) # init annotations api + pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path) + for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "keypoints")]): + if self.is_coco: + eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval + eval.evaluate() + eval.accumulate() + eval.summarize() + idx = i * 4 + 2 + stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[ + :2 + ] # update mAP50-95 and mAP50 + except Exception as e: + LOGGER.warning(f"pycocotools unable to run: {e}") + return stats diff --git a/ultralytics/models/yolo/segment/__init__.py b/ultralytics/models/yolo/segment/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..36a921a9a36a0f0a0bf1bf03be9014c6886f6c6e --- /dev/null +++ b/ultralytics/models/yolo/segment/__init__.py @@ -0,0 +1,7 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from .predict import SegmentationPredictor +from .train import SegmentationTrainer +from .val import SegmentationValidator + +__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator" diff --git a/ultralytics/models/yolo/segment/predict.py b/ultralytics/models/yolo/segment/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..cc9dcfd5f3b341ab305a33c06e34a88322695489 --- /dev/null +++ b/ultralytics/models/yolo/segment/predict.py @@ -0,0 +1,55 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.engine.results import Results +from ultralytics.models.yolo.detect.predict import DetectionPredictor +from ultralytics.utils import DEFAULT_CFG, ops + + +class SegmentationPredictor(DetectionPredictor): + """ + A class extending the DetectionPredictor class for prediction based on a segmentation model. 
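+
+    The predictor applies non-max suppression to the raw output and then builds instance masks from the mask
+    prototypes, either at network input resolution (`ops.process_mask`) or at native image resolution when
+    `retina_masks` is enabled (`ops.process_mask_native`).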
+ + Example: + ```python + from ultralytics.utils import ASSETS + from ultralytics.models.yolo.segment import SegmentationPredictor + + args = dict(model="yolov8n-seg.pt", source=ASSETS) + predictor = SegmentationPredictor(overrides=args) + predictor.predict_cli() + ``` + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """Initializes the SegmentationPredictor with the provided configuration, overrides, and callbacks.""" + super().__init__(cfg, overrides, _callbacks) + self.args.task = "segment" + + def postprocess(self, preds, img, orig_imgs): + """Applies non-max suppression and processes detections for each image in an input batch.""" + p = ops.non_max_suppression( + preds[0], + self.args.conf, + self.args.iou, + agnostic=self.args.agnostic_nms, + max_det=self.args.max_det, + nc=len(self.model.names), + classes=self.args.classes, + ) + + if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list + orig_imgs = ops.convert_torch2numpy_batch(orig_imgs) + + results = [] + proto = preds[1][-1] if isinstance(preds[1], tuple) else preds[1] # tuple if PyTorch model or array if exported + for i, (pred, orig_img, img_path) in enumerate(zip(p, orig_imgs, self.batch[0])): + if not len(pred): # save empty boxes + masks = None + elif self.args.retina_masks: + pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) + masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC + else: + masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC + pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) + results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)) + return results diff --git a/ultralytics/models/yolo/segment/train.py b/ultralytics/models/yolo/segment/train.py new file mode 100644 index 0000000000000000000000000000000000000000..4737f6d6f74217c53f1b2a6c5b011823bb84f8d3 --- /dev/null +++ b/ultralytics/models/yolo/segment/train.py @@ -0,0 +1,62 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from copy import copy + +from ultralytics.models import yolo +from ultralytics.nn.tasks import SegmentationModel +from ultralytics.utils import DEFAULT_CFG, RANK +from ultralytics.utils.plotting import plot_images, plot_results + + +class SegmentationTrainer(yolo.detect.DetectionTrainer): + """ + A class extending the DetectionTrainer class for training based on a segmentation model. 
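+
+    The trainer builds a SegmentationModel, tracks a `seg_loss` term alongside the usual box, class and DFL losses,
+    and uses SegmentationValidator for evaluation.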
+ + Example: + ```python + from ultralytics.models.yolo.segment import SegmentationTrainer + + args = dict(model="yolov8n-seg.pt", data="coco8-seg.yaml", epochs=3) + trainer = SegmentationTrainer(overrides=args) + trainer.train() + ``` + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """Initialize a SegmentationTrainer object with given arguments.""" + if overrides is None: + overrides = {} + overrides["task"] = "segment" + super().__init__(cfg, overrides, _callbacks) + + def get_model(self, cfg=None, weights=None, verbose=True): + """Return SegmentationModel initialized with specified config and weights.""" + model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose and RANK == -1) + if weights: + model.load(weights) + + return model + + def get_validator(self): + """Return an instance of SegmentationValidator for validation of YOLO model.""" + self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss" + return yolo.segment.SegmentationValidator( + self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks + ) + + def plot_training_samples(self, batch, ni): + """Creates a plot of training sample images with labels and box coordinates.""" + plot_images( + batch["img"], + batch["batch_idx"], + batch["cls"].squeeze(-1), + batch["bboxes"], + masks=batch["masks"], + paths=batch["im_file"], + fname=self.save_dir / f"train_batch{ni}.jpg", + on_plot=self.on_plot, + ) + + def plot_metrics(self): + """Plots training/val metrics.""" + plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png diff --git a/ultralytics/models/yolo/segment/val.py b/ultralytics/models/yolo/segment/val.py new file mode 100644 index 0000000000000000000000000000000000000000..bd77ac7a888fc9ecc1c993f4f7c629c2129bcfa0 --- /dev/null +++ b/ultralytics/models/yolo/segment/val.py @@ -0,0 +1,318 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from multiprocessing.pool import ThreadPool +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F + +from ultralytics.models.yolo.detect import DetectionValidator +from ultralytics.utils import LOGGER, NUM_THREADS, ops +from ultralytics.utils.checks import check_requirements +from ultralytics.utils.metrics import SegmentMetrics, box_iou, mask_iou +from ultralytics.utils.plotting import output_to_target, plot_images + + +class SegmentationValidator(DetectionValidator): + """ + A class extending the DetectionValidator class for validation based on a segmentation model. 
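+
+    Besides the box metrics computed by DetectionValidator, this validator scores predicted masks with mask IoU
+    (`tp_m`) and can export per-instance RLE-encoded segmentations to COCO JSON for pycocotools evaluation.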
+ + Example: + ```python + from ultralytics.models.yolo.segment import SegmentationValidator + + args = dict(model="yolov8n-seg.pt", data="coco8-seg.yaml") + validator = SegmentationValidator(args=args) + validator() + ``` + """ + + def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): + """Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.""" + super().__init__(dataloader, save_dir, pbar, args, _callbacks) + self.plot_masks = None + self.process = None + self.args.task = "segment" + self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot) + + def preprocess(self, batch): + """Preprocesses batch by converting masks to float and sending to device.""" + batch = super().preprocess(batch) + batch["masks"] = batch["masks"].to(self.device).float() + return batch + + def init_metrics(self, model): + """Initialize metrics and select mask processing function based on save_json flag.""" + super().init_metrics(model) + self.plot_masks = [] + if self.args.save_json: + check_requirements("pycocotools>=2.0.6") + # more accurate vs faster + self.process = ops.process_mask_native if self.args.save_json or self.args.save_txt else ops.process_mask + self.stats = dict(tp_m=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[]) + + def get_desc(self): + """Return a formatted description of evaluation metrics.""" + return ("%22s" + "%11s" * 10) % ( + "Class", + "Images", + "Instances", + "Box(P", + "R", + "mAP50", + "mAP50-95)", + "Mask(P", + "R", + "mAP50", + "mAP50-95)", + ) + + def postprocess(self, preds): + """Post-processes YOLO predictions and returns output detections with proto.""" + p = ops.non_max_suppression( + preds[0], + self.args.conf, + self.args.iou, + labels=self.lb, + multi_label=True, + agnostic=self.args.single_cls or self.args.agnostic_nms, + max_det=self.args.max_det, + nc=self.nc, + ) + proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported + return p, proto + + def _prepare_batch(self, si, batch): + """Prepares a batch for training or inference by processing images and targets.""" + prepared_batch = super()._prepare_batch(si, batch) + midx = [si] if self.args.overlap_mask else batch["batch_idx"] == si + prepared_batch["masks"] = batch["masks"][midx] + return prepared_batch + + def _prepare_pred(self, pred, pbatch, proto): + """Prepares a batch for training or inference by processing images and targets.""" + predn = super()._prepare_pred(pred, pbatch) + pred_masks = self.process(proto, pred[:, 6:], pred[:, :4], shape=pbatch["imgsz"]) + return predn, pred_masks + + def update_metrics(self, preds, batch): + """Metrics.""" + for si, (pred, proto) in enumerate(zip(preds[0], preds[1])): + self.seen += 1 + npr = len(pred) + stat = dict( + conf=torch.zeros(0, device=self.device), + pred_cls=torch.zeros(0, device=self.device), + tp=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), + tp_m=torch.zeros(npr, self.niou, dtype=torch.bool, device=self.device), + ) + pbatch = self._prepare_batch(si, batch) + cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox") + nl = len(cls) + stat["target_cls"] = cls + stat["target_img"] = cls.unique() + if npr == 0: + if nl: + for k in self.stats.keys(): + self.stats[k].append(stat[k]) + if self.args.plots: + self.confusion_matrix.process_batch(detections=None, gt_bboxes=bbox, gt_cls=cls) + continue + + # Masks + gt_masks = pbatch.pop("masks") + # Predictions + if self.args.single_cls: + pred[:, 5] 
= 0 + predn, pred_masks = self._prepare_pred(pred, pbatch, proto) + stat["conf"] = predn[:, 4] + stat["pred_cls"] = predn[:, 5] + + # Evaluate + if nl: + stat["tp"] = self._process_batch(predn, bbox, cls) + stat["tp_m"] = self._process_batch( + predn, bbox, cls, pred_masks, gt_masks, self.args.overlap_mask, masks=True + ) + if self.args.plots: + self.confusion_matrix.process_batch(predn, bbox, cls) + + for k in self.stats.keys(): + self.stats[k].append(stat[k]) + + pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8) + if self.args.plots and self.batch_i < 3: + self.plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot + + # Save + if self.args.save_json: + self.pred_to_json( + predn, + batch["im_file"][si], + ops.scale_image( + pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(), + pbatch["ori_shape"], + ratio_pad=batch["ratio_pad"][si], + ), + ) + if self.args.save_txt: + self.save_one_txt( + predn, + pred_masks, + self.args.save_conf, + pbatch["ori_shape"], + self.save_dir / "labels" / f"{Path(batch['im_file'][si]).stem}.txt", + ) + + def finalize_metrics(self, *args, **kwargs): + """Sets speed and confusion matrix for evaluation metrics.""" + self.metrics.speed = self.speed + self.metrics.confusion_matrix = self.confusion_matrix + + def _process_batch(self, detections, gt_bboxes, gt_cls, pred_masks=None, gt_masks=None, overlap=False, masks=False): + """ + Compute correct prediction matrix for a batch based on bounding boxes and optional masks. + + Args: + detections (torch.Tensor): Tensor of shape (N, 6) representing detected bounding boxes and + associated confidence scores and class indices. Each row is of the format [x1, y1, x2, y2, conf, class]. + gt_bboxes (torch.Tensor): Tensor of shape (M, 4) representing ground truth bounding box coordinates. + Each row is of the format [x1, y1, x2, y2]. + gt_cls (torch.Tensor): Tensor of shape (M,) representing ground truth class indices. + pred_masks (torch.Tensor | None): Tensor representing predicted masks, if available. The shape should + match the ground truth masks. + gt_masks (torch.Tensor | None): Tensor of shape (M, H, W) representing ground truth masks, if available. + overlap (bool): Flag indicating if overlapping masks should be considered. + masks (bool): Flag indicating if the batch contains mask data. + + Returns: + (torch.Tensor): A correct prediction matrix of shape (N, 10), where 10 represents different IoU levels. + + Note: + - If `masks` is True, the function computes IoU between predicted and ground truth masks. + - If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU. 
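+            - When mask shapes differ, ground truth masks are resized to the prediction resolution with bilinear
+              interpolation and re-binarized at 0.5 before mask IoU is computed.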
+ + Example: + ```python + detections = torch.tensor([[25, 30, 200, 300, 0.8, 1], [50, 60, 180, 290, 0.75, 0]]) + gt_bboxes = torch.tensor([[24, 29, 199, 299], [55, 65, 185, 295]]) + gt_cls = torch.tensor([1, 0]) + correct_preds = validator._process_batch(detections, gt_bboxes, gt_cls) + ``` + """ + if masks: + if overlap: + nl = len(gt_cls) + index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1 + gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640) + gt_masks = torch.where(gt_masks == index, 1.0, 0.0) + if gt_masks.shape[1:] != pred_masks.shape[1:]: + gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0] + gt_masks = gt_masks.gt_(0.5) + iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1)) + else: # boxes + iou = box_iou(gt_bboxes, detections[:, :4]) + + return self.match_predictions(detections[:, 5], gt_cls, iou) + + def plot_val_samples(self, batch, ni): + """Plots validation samples with bounding box labels.""" + plot_images( + batch["img"], + batch["batch_idx"], + batch["cls"].squeeze(-1), + batch["bboxes"], + masks=batch["masks"], + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_labels.jpg", + names=self.names, + on_plot=self.on_plot, + ) + + def plot_predictions(self, batch, preds, ni): + """Plots batch predictions with masks and bounding boxes.""" + plot_images( + batch["img"], + *output_to_target(preds[0], max_det=15), # not set to self.args.max_det due to slow plotting speed + torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks, + paths=batch["im_file"], + fname=self.save_dir / f"val_batch{ni}_pred.jpg", + names=self.names, + on_plot=self.on_plot, + ) # pred + self.plot_masks.clear() + + def save_one_txt(self, predn, pred_masks, save_conf, shape, file): + """Save YOLO detections to a txt file in normalized coordinates in a specific format.""" + from ultralytics.engine.results import Results + + Results( + np.zeros((shape[0], shape[1]), dtype=np.uint8), + path=None, + names=self.names, + boxes=predn[:, :6], + masks=pred_masks, + ).save_txt(file, save_conf=save_conf) + + def pred_to_json(self, predn, filename, pred_masks): + """ + Save one JSON result. 
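+
+        Boxes are converted from xyxy to COCO xywh (top-left origin) and each predicted mask is RLE-encoded with
+        pycocotools before being appended to `self.jdict`.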
+ + Examples: + >>> result = {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236} + """ + from pycocotools.mask import encode # noqa + + def single_encode(x): + """Encode predicted masks as RLE and append results to jdict.""" + rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0] + rle["counts"] = rle["counts"].decode("utf-8") + return rle + + stem = Path(filename).stem + image_id = int(stem) if stem.isnumeric() else stem + box = ops.xyxy2xywh(predn[:, :4]) # xywh + box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner + pred_masks = np.transpose(pred_masks, (2, 0, 1)) + with ThreadPool(NUM_THREADS) as pool: + rles = pool.map(single_encode, pred_masks) + for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())): + self.jdict.append( + { + "image_id": image_id, + "category_id": self.class_map[int(p[5])], + "bbox": [round(x, 3) for x in b], + "score": round(p[4], 5), + "segmentation": rles[i], + } + ) + + def eval_json(self, stats): + """Return COCO-style object detection evaluation metrics.""" + if self.args.save_json and self.is_coco and len(self.jdict): + anno_json = self.data["path"] / "annotations/instances_val2017.json" # annotations + pred_json = self.save_dir / "predictions.json" # predictions + LOGGER.info(f"\nEvaluating pycocotools mAP using {pred_json} and {anno_json}...") + try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb + check_requirements("pycocotools>=2.0.6") + from pycocotools.coco import COCO # noqa + from pycocotools.cocoeval import COCOeval # noqa + + for x in anno_json, pred_json: + assert x.is_file(), f"{x} file not found" + anno = COCO(str(anno_json)) # init annotations api + pred = anno.loadRes(str(pred_json)) # init predictions api (must pass string, not Path) + for i, eval in enumerate([COCOeval(anno, pred, "bbox"), COCOeval(anno, pred, "segm")]): + if self.is_coco: + eval.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # im to eval + eval.evaluate() + eval.accumulate() + eval.summarize() + idx = i * 4 + 2 + stats[self.metrics.keys[idx + 1]], stats[self.metrics.keys[idx]] = eval.stats[ + :2 + ] # update mAP50-95 and mAP50 + except Exception as e: + LOGGER.warning(f"pycocotools unable to run: {e}") + return stats diff --git a/ultralytics/models/yolo/world/__init__.py b/ultralytics/models/yolo/world/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4380d244602c1db758195f87f7fb2c6aa8141536 --- /dev/null +++ b/ultralytics/models/yolo/world/__init__.py @@ -0,0 +1,5 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from .train import WorldTrainer + +__all__ = ["WorldTrainer"] diff --git a/ultralytics/models/yolo/world/train.py b/ultralytics/models/yolo/world/train.py new file mode 100644 index 0000000000000000000000000000000000000000..1a16a2d1d1eabb9e9a20fa48db89e694f8511c2f --- /dev/null +++ b/ultralytics/models/yolo/world/train.py @@ -0,0 +1,92 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import itertools + +from ultralytics.data import build_yolo_dataset +from ultralytics.models import yolo +from ultralytics.nn.tasks import WorldModel +from ultralytics.utils import DEFAULT_CFG, RANK, checks +from ultralytics.utils.torch_utils import de_parallel + + +def on_pretrain_routine_end(trainer): + """Callback.""" + if RANK in {-1, 0}: + # NOTE: for evaluation + names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())] + 
de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
+        device = next(trainer.model.parameters()).device
+        trainer.text_model, _ = trainer.clip.load("ViT-B/32", device=device)
+        for p in trainer.text_model.parameters():
+            p.requires_grad_(False)
+
+
+class WorldTrainer(yolo.detect.DetectionTrainer):
+    """
+    A class to fine-tune a world model on a closed-set dataset.
+
+    Example:
+        ```python
+        from ultralytics.models.yolo.world import WorldTrainer
+
+        args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=3)
+        trainer = WorldTrainer(overrides=args)
+        trainer.train()
+        ```
+    """
+
+    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
+        """Initialize a WorldTrainer object with given arguments."""
+        if overrides is None:
+            overrides = {}
+        super().__init__(cfg, overrides, _callbacks)
+
+        # Import and assign clip
+        try:
+            import clip
+        except ImportError:
+            checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
+            import clip
+        self.clip = clip
+
+    def get_model(self, cfg=None, weights=None, verbose=True):
+        """Return WorldModel initialized with specified config and weights."""
+        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
+        # NOTE: Following the official config, `nc` is hard-coded to 80 for now.
+        model = WorldModel(
+            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
+            ch=3,
+            nc=min(self.data["nc"], 80),
+            verbose=verbose and RANK == -1,
+        )
+        if weights:
+            model.load(weights)
+        self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end)
+
+        return model
+
+    def build_dataset(self, img_path, mode="train", batch=None):
+        """
+        Build YOLO Dataset.
+
+        Args:
+            img_path (str): Path to the folder containing images.
+            mode (str): `train` mode or `val` mode; users can customize different augmentations for each mode.
+            batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
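+
+        Returns:
+            (Dataset): YOLO dataset built by `build_yolo_dataset`; `multi_modal` is enabled in `train` mode so that
+                text labels are available for CLIP encoding.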
+ """ + gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32) + return build_yolo_dataset( + self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train" + ) + + def preprocess_batch(self, batch): + """Preprocesses a batch of images for YOLOWorld training, adjusting formatting and dimensions as needed.""" + batch = super().preprocess_batch(batch) + + # NOTE: add text features + texts = list(itertools.chain(*batch["texts"])) + text_token = self.clip.tokenize(texts).to(batch["img"].device) + txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype) # torch.float32 + txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True) + batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1]) + return batch diff --git a/ultralytics/models/yolo/world/train_world.py b/ultralytics/models/yolo/world/train_world.py new file mode 100644 index 0000000000000000000000000000000000000000..3cbdb2a4e77343603adff4b54a1d19a4d009360e --- /dev/null +++ b/ultralytics/models/yolo/world/train_world.py @@ -0,0 +1,109 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.data import YOLOConcatDataset, build_grounding, build_yolo_dataset +from ultralytics.data.utils import check_det_dataset +from ultralytics.models.yolo.world import WorldTrainer +from ultralytics.utils import DEFAULT_CFG +from ultralytics.utils.torch_utils import de_parallel + + +class WorldTrainerFromScratch(WorldTrainer): + """ + A class extending the WorldTrainer class for training a world model from scratch on open-set dataset. + + Example: + ```python + from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch + from ultralytics import YOLOWorld + + data = dict( + train=dict( + yolo_data=["Objects365.yaml"], + grounding_data=[ + dict( + img_path="../datasets/flickr30k/images", + json_file="../datasets/flickr30k/final_flickr_separateGT_train.json", + ), + dict( + img_path="../datasets/GQA/images", + json_file="../datasets/GQA/final_mixed_train_no_coco.json", + ), + ], + ), + val=dict(yolo_data=["lvis.yaml"]), + ) + + model = YOLOWorld("yolov8s-worldv2.yaml") + model.train(data=data, trainer=WorldTrainerFromScratch) + ``` + """ + + def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): + """Initialize a WorldTrainer object with given arguments.""" + if overrides is None: + overrides = {} + super().__init__(cfg, overrides, _callbacks) + + def build_dataset(self, img_path, mode="train", batch=None): + """ + Build YOLO Dataset. + + Args: + img_path (List[str] | str): Path to the folder containing images. + mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode. + batch (int, optional): Size of batches, this is for `rect`. Defaults to None. + """ + gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32) + if mode != "train": + return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs) + dataset = [ + build_yolo_dataset(self.args, im_path, batch, self.data, stride=gs, multi_modal=True) + if isinstance(im_path, str) + else build_grounding(self.args, im_path["img_path"], im_path["json_file"], batch, stride=gs) + for im_path in img_path + ] + return YOLOConcatDataset(dataset) if len(dataset) > 1 else dataset[0] + + def get_dataset(self): + """ + Get train, val path from data dict if it exists. + + Returns None if data format is not recognized. 
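+
+        Any grounding datasets listed under `grounding_data` are appended to the corresponding split, and `nc` and
+        `names` are taken from the single validation dataset.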
+ """ + final_data = {} + data_yaml = self.args.data + assert data_yaml.get("train", False), "train dataset not found" # object365.yaml + assert data_yaml.get("val", False), "validation dataset not found" # lvis.yaml + data = {k: [check_det_dataset(d) for d in v.get("yolo_data", [])] for k, v in data_yaml.items()} + assert len(data["val"]) == 1, f"Only support validating on 1 dataset for now, but got {len(data['val'])}." + val_split = "minival" if "lvis" in data["val"][0]["val"] else "val" + for d in data["val"]: + if d.get("minival") is None: # for lvis dataset + continue + d["minival"] = str(d["path"] / d["minival"]) + for s in ["train", "val"]: + final_data[s] = [d["train" if s == "train" else val_split] for d in data[s]] + # save grounding data if there's one + grounding_data = data_yaml[s].get("grounding_data") + if grounding_data is None: + continue + grounding_data = grounding_data if isinstance(grounding_data, list) else [grounding_data] + for g in grounding_data: + assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}" + final_data[s] += grounding_data + # NOTE: to make training work properly, set `nc` and `names` + final_data["nc"] = data["val"][0]["nc"] + final_data["names"] = data["val"][0]["names"] + self.data = final_data + return final_data["train"], final_data["val"][0] + + def plot_training_labels(self): + """DO NOT plot labels.""" + pass + + def final_eval(self): + """Performs final evaluation and validation for object detection YOLO-World model.""" + val = self.args.data["val"]["yolo_data"][0] + self.validator.args.data = val + self.validator.args.split = "minival" if isinstance(val, str) and "lvis" in val else "val" + return super().final_eval() diff --git a/ultralytics/nn/__init__.py b/ultralytics/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f3a54d51cb76e4b6356ef22ec574b6031b4640ea --- /dev/null +++ b/ultralytics/nn/__init__.py @@ -0,0 +1,29 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from .tasks import ( + BaseModel, + ClassificationModel, + DetectionModel, + SegmentationModel, + attempt_load_one_weight, + attempt_load_weights, + guess_model_scale, + guess_model_task, + parse_model, + torch_safe_load, + yaml_model_load, +) + +__all__ = ( + "attempt_load_one_weight", + "attempt_load_weights", + "parse_model", + "yaml_model_load", + "guess_model_task", + "guess_model_scale", + "torch_safe_load", + "DetectionModel", + "SegmentationModel", + "ClassificationModel", + "BaseModel", +) diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py new file mode 100644 index 0000000000000000000000000000000000000000..3d7af68b1a091a5e7183225dae5637d5a642f10e --- /dev/null +++ b/ultralytics/nn/autobackend.py @@ -0,0 +1,763 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import ast +import json +import platform +import zipfile +from collections import OrderedDict, namedtuple +from pathlib import Path + +import cv2 +import numpy as np +import torch +import torch.nn as nn +from PIL import Image + +from ultralytics.utils import ARM64, IS_JETSON, IS_RASPBERRYPI, LINUX, LOGGER, ROOT, yaml_load +from ultralytics.utils.checks import check_requirements, check_suffix, check_version, check_yaml +from ultralytics.utils.downloads import attempt_download_asset, is_url + + +def check_class_names(names): + """ + Check class names. + + Map imagenet class codes to human-readable names if required. Convert lists to dicts. 
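+
+    Raises:
+        KeyError: If the class indices in `names` are invalid for a dataset of that size.
+
+    Examples:
+        >>> check_class_names(["person", "bicycle"])
+        {0: 'person', 1: 'bicycle'}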
+ """ + if isinstance(names, list): # names is a list + names = dict(enumerate(names)) # convert to dict + if isinstance(names, dict): + # Convert 1) string keys to int, i.e. '0' to 0, and non-string values to strings, i.e. True to 'True' + names = {int(k): str(v) for k, v in names.items()} + n = len(names) + if max(names.keys()) >= n: + raise KeyError( + f"{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices " + f"{min(names.keys())}-{max(names.keys())} defined in your dataset YAML." + ) + if isinstance(names[0], str) and names[0].startswith("n0"): # imagenet class codes, i.e. 'n01440764' + names_map = yaml_load(ROOT / "cfg/datasets/ImageNet.yaml")["map"] # human-readable names + names = {k: names_map[v] for k, v in names.items()} + return names + + +def default_class_names(data=None): + """Applies default class names to an input YAML file or returns numerical class names.""" + if data: + try: + return yaml_load(check_yaml(data))["names"] + except Exception: + pass + return {i: f"class{i}" for i in range(999)} # return default if above errors + + +class AutoBackend(nn.Module): + """ + Handles dynamic backend selection for running inference using Ultralytics YOLO models. + + The AutoBackend class is designed to provide an abstraction layer for various inference engines. It supports a wide + range of formats, each with specific naming conventions as outlined below: + + Supported Formats and Naming Conventions: + | Format | File Suffix | + |-----------------------|-------------------| + | PyTorch | *.pt | + | TorchScript | *.torchscript | + | ONNX Runtime | *.onnx | + | ONNX OpenCV DNN | *.onnx (dnn=True) | + | OpenVINO | *openvino_model/ | + | CoreML | *.mlpackage | + | TensorRT | *.engine | + | TensorFlow SavedModel | *_saved_model/ | + | TensorFlow GraphDef | *.pb | + | TensorFlow Lite | *.tflite | + | TensorFlow Edge TPU | *_edgetpu.tflite | + | PaddlePaddle | *_paddle_model/ | + | MNN | *.mnn | + | NCNN | *_ncnn_model/ | + + This class offers dynamic backend switching capabilities based on the input model format, making it easier to deploy + models across various platforms. + """ + + @torch.no_grad() + def __init__( + self, + weights="yolo11n.pt", + device=torch.device("cpu"), + dnn=False, + data=None, + fp16=False, + batch=1, + fuse=True, + verbose=True, + ): + """ + Initialize the AutoBackend for inference. + + Args: + weights (str | torch.nn.Module): Path to the model weights file or a module instance. Defaults to 'yolo11n.pt'. + device (torch.device): Device to run the model on. Defaults to CPU. + dnn (bool): Use OpenCV DNN module for ONNX inference. Defaults to False. + data (str | Path | optional): Path to the additional data.yaml file containing class names. Optional. + fp16 (bool): Enable half-precision inference. Supported only on specific backends. Defaults to False. + batch (int): Batch-size to assume for inference. + fuse (bool): Fuse Conv2D + BatchNorm layers for optimization. Defaults to True. + verbose (bool): Enable verbose logging. Defaults to True. 
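+
+        Raises:
+            TypeError: If the model file is not in one of the supported export formats.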
+ """ + super().__init__() + w = str(weights[0] if isinstance(weights, list) else weights) + nn_module = isinstance(weights, torch.nn.Module) + ( + pt, + jit, + onnx, + xml, + engine, + coreml, + saved_model, + pb, + tflite, + edgetpu, + tfjs, + paddle, + mnn, + ncnn, + imx, + triton, + ) = self._model_type(w) + fp16 &= pt or jit or onnx or xml or engine or nn_module or triton # FP16 + nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH) + stride = 32 # default stride + model, metadata, task = None, None, None + + # Set device + cuda = torch.cuda.is_available() and device.type != "cpu" # use CUDA + if cuda and not any([nn_module, pt, jit, engine, onnx, paddle]): # GPU dataloader formats + device = torch.device("cpu") + cuda = False + + # Download if not local + if not (pt or triton or nn_module): + w = attempt_download_asset(w) + + # In-memory PyTorch model + if nn_module: + model = weights.to(device) + if fuse: + model = model.fuse(verbose=verbose) + if hasattr(model, "kpt_shape"): + kpt_shape = model.kpt_shape # pose-only + stride = max(int(model.stride.max()), 32) # model stride + names = model.module.names if hasattr(model, "module") else model.names # get class names + model.half() if fp16 else model.float() + self.model = model # explicitly assign for to(), cpu(), cuda(), half() + pt = True + + # PyTorch + elif pt: + from ultralytics.nn.tasks import attempt_load_weights + + model = attempt_load_weights( + weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse + ) + if hasattr(model, "kpt_shape"): + kpt_shape = model.kpt_shape # pose-only + stride = max(int(model.stride.max()), 32) # model stride + names = model.module.names if hasattr(model, "module") else model.names # get class names + model.half() if fp16 else model.float() + self.model = model # explicitly assign for to(), cpu(), cuda(), half() + + # TorchScript + elif jit: + LOGGER.info(f"Loading {w} for TorchScript inference...") + extra_files = {"config.txt": ""} # model metadata + model = torch.jit.load(w, _extra_files=extra_files, map_location=device) + model.half() if fp16 else model.float() + if extra_files["config.txt"]: # load metadata dict + metadata = json.loads(extra_files["config.txt"], object_hook=lambda x: dict(x.items())) + + # ONNX OpenCV DNN + elif dnn: + LOGGER.info(f"Loading {w} for ONNX OpenCV DNN inference...") + check_requirements("opencv-python>=4.5.4") + net = cv2.dnn.readNetFromONNX(w) + + # ONNX Runtime and IMX + elif onnx or imx: + LOGGER.info(f"Loading {w} for ONNX Runtime inference...") + check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime")) + if IS_RASPBERRYPI or IS_JETSON: + # Fix 'numpy.linalg._umath_linalg' has no attribute '_ilp64' for TF SavedModel on RPi and Jetson + check_requirements("numpy==1.23.5") + import onnxruntime + + providers = ["CPUExecutionProvider"] + if cuda and "CUDAExecutionProvider" in onnxruntime.get_available_providers(): + providers.insert(0, "CUDAExecutionProvider") + elif cuda: # Only log warning if CUDA was requested but unavailable + LOGGER.warning("WARNING ⚠️ Failed to start ONNX Runtime with CUDA. 
Using CPU...") + device = torch.device("cpu") + cuda = False + LOGGER.info(f"Using ONNX Runtime {providers[0]}") + if onnx: + session = onnxruntime.InferenceSession(w, providers=providers) + else: + check_requirements( + ["model-compression-toolkit==2.1.1", "sony-custom-layers[torch]==0.2.0", "onnxruntime-extensions"] + ) + w = next(Path(w).glob("*.onnx")) + LOGGER.info(f"Loading {w} for ONNX IMX inference...") + import mct_quantizers as mctq + from sony_custom_layers.pytorch.object_detection import nms_ort # noqa + + session = onnxruntime.InferenceSession( + w, mctq.get_ort_session_options(), providers=["CPUExecutionProvider"] + ) + task = "detect" + + output_names = [x.name for x in session.get_outputs()] + metadata = session.get_modelmeta().custom_metadata_map + dynamic = isinstance(session.get_outputs()[0].shape[0], str) + if not dynamic: + io = session.io_binding() + bindings = [] + for output in session.get_outputs(): + y_tensor = torch.empty(output.shape, dtype=torch.float16 if fp16 else torch.float32).to(device) + io.bind_output( + name=output.name, + device_type=device.type, + device_id=device.index if cuda else 0, + element_type=np.float16 if fp16 else np.float32, + shape=tuple(y_tensor.shape), + buffer_ptr=y_tensor.data_ptr(), + ) + bindings.append(y_tensor) + + # OpenVINO + elif xml: + LOGGER.info(f"Loading {w} for OpenVINO inference...") + check_requirements("openvino>=2024.0.0") + import openvino as ov + + core = ov.Core() + w = Path(w) + if not w.is_file(): # if not *.xml + w = next(w.glob("*.xml")) # get *.xml file from *_openvino_model dir + ov_model = core.read_model(model=str(w), weights=w.with_suffix(".bin")) + if ov_model.get_parameters()[0].get_layout().empty: + ov_model.get_parameters()[0].set_layout(ov.Layout("NCHW")) + + # OpenVINO inference modes are 'LATENCY', 'THROUGHPUT' (not recommended), or 'CUMULATIVE_THROUGHPUT' + inference_mode = "CUMULATIVE_THROUGHPUT" if batch > 1 else "LATENCY" + LOGGER.info(f"Using OpenVINO {inference_mode} mode for batch={batch} inference...") + ov_compiled_model = core.compile_model( + ov_model, + device_name="AUTO", # AUTO selects best available device, do not modify + config={"PERFORMANCE_HINT": inference_mode}, + ) + input_name = ov_compiled_model.input().get_any_name() + metadata = w.parent / "metadata.yaml" + + # TensorRT + elif engine: + LOGGER.info(f"Loading {w} for TensorRT inference...") + try: + import tensorrt as trt # noqa https://developer.nvidia.com/nvidia-tensorrt-download + except ImportError: + if LINUX: + check_requirements("tensorrt>7.0.0,!=10.1.0") + import tensorrt as trt # noqa + check_version(trt.__version__, ">=7.0.0", hard=True) + check_version(trt.__version__, "!=10.1.0", msg="https://github.com/ultralytics/ultralytics/pull/14239") + if device.type == "cpu": + device = torch.device("cuda:0") + Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr")) + logger = trt.Logger(trt.Logger.INFO) + # Read file + with open(w, "rb") as f, trt.Runtime(logger) as runtime: + try: + meta_len = int.from_bytes(f.read(4), byteorder="little") # read metadata length + metadata = json.loads(f.read(meta_len).decode("utf-8")) # read metadata + except UnicodeDecodeError: + f.seek(0) # engine file may lack embedded Ultralytics metadata + model = runtime.deserialize_cuda_engine(f.read()) # read engine + + # Model context + try: + context = model.create_execution_context() + except Exception as e: # model is None + LOGGER.error(f"ERROR: TensorRT model exported with a different version than {trt.__version__}\n") + 
raise e + + bindings = OrderedDict() + output_names = [] + fp16 = False # default updated below + dynamic = False + is_trt10 = not hasattr(model, "num_bindings") + num = range(model.num_io_tensors) if is_trt10 else range(model.num_bindings) + for i in num: + if is_trt10: + name = model.get_tensor_name(i) + dtype = trt.nptype(model.get_tensor_dtype(name)) + is_input = model.get_tensor_mode(name) == trt.TensorIOMode.INPUT + if is_input: + if -1 in tuple(model.get_tensor_shape(name)): + dynamic = True + context.set_input_shape(name, tuple(model.get_tensor_profile_shape(name, 0)[1])) + if dtype == np.float16: + fp16 = True + else: + output_names.append(name) + shape = tuple(context.get_tensor_shape(name)) + else: # TensorRT < 10.0 + name = model.get_binding_name(i) + dtype = trt.nptype(model.get_binding_dtype(i)) + is_input = model.binding_is_input(i) + if model.binding_is_input(i): + if -1 in tuple(model.get_binding_shape(i)): # dynamic + dynamic = True + context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[1])) + if dtype == np.float16: + fp16 = True + else: + output_names.append(name) + shape = tuple(context.get_binding_shape(i)) + im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device) + bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr())) + binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items()) + batch_size = bindings["images"].shape[0] # if dynamic, this is instead max batch size + + # CoreML + elif coreml: + LOGGER.info(f"Loading {w} for CoreML inference...") + import coremltools as ct + + model = ct.models.MLModel(w) + metadata = dict(model.user_defined_metadata) + + # TF SavedModel + elif saved_model: + LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...") + import tensorflow as tf + + keras = False # assume TF1 saved_model + model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w) + metadata = Path(w) / "metadata.yaml" + + # TF GraphDef + elif pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt + LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...") + import tensorflow as tf + + from ultralytics.engine.exporter import gd_outputs + + def wrap_frozen_graph(gd, inputs, outputs): + """Wrap frozen graphs for deployment.""" + x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped + ge = x.graph.as_graph_element + return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs)) + + gd = tf.Graph().as_graph_def() # TF GraphDef + with open(w, "rb") as f: + gd.ParseFromString(f.read()) + frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd)) + try: # find metadata in SavedModel alongside GraphDef + metadata = next(Path(w).resolve().parent.rglob(f"{Path(w).stem}_saved_model*/metadata.yaml")) + except StopIteration: + pass + + # TFLite or TFLite Edge TPU + elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python + try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu + from tflite_runtime.interpreter import Interpreter, load_delegate + except ImportError: + import tensorflow as tf + + Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate + if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime + device = device[3:] if str(device).startswith("tpu") else ":0" + LOGGER.info(f"Loading {w} on device {device[1:]} for TensorFlow Lite Edge TPU inference...") + delegate = 
{"Linux": "libedgetpu.so.1", "Darwin": "libedgetpu.1.dylib", "Windows": "edgetpu.dll"}[ + platform.system() + ] + interpreter = Interpreter( + model_path=w, + experimental_delegates=[load_delegate(delegate, options={"device": device})], + ) + device = "cpu" # Required, otherwise PyTorch will try to use the wrong device + else: # TFLite + LOGGER.info(f"Loading {w} for TensorFlow Lite inference...") + interpreter = Interpreter(model_path=w) # load TFLite model + interpreter.allocate_tensors() # allocate + input_details = interpreter.get_input_details() # inputs + output_details = interpreter.get_output_details() # outputs + # Load metadata + try: + with zipfile.ZipFile(w, "r") as model: + meta_file = model.namelist()[0] + metadata = ast.literal_eval(model.read(meta_file).decode("utf-8")) + except zipfile.BadZipFile: + pass + + # TF.js + elif tfjs: + raise NotImplementedError("YOLOv8 TF.js inference is not currently supported.") + + # PaddlePaddle + elif paddle: + LOGGER.info(f"Loading {w} for PaddlePaddle inference...") + check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle") + import paddle.inference as pdi # noqa + + w = Path(w) + if not w.is_file(): # if not *.pdmodel + w = next(w.rglob("*.pdmodel")) # get *.pdmodel file from *_paddle_model dir + config = pdi.Config(str(w), str(w.with_suffix(".pdiparams"))) + if cuda: + config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0) + predictor = pdi.create_predictor(config) + input_handle = predictor.get_input_handle(predictor.get_input_names()[0]) + output_names = predictor.get_output_names() + metadata = w.parents[1] / "metadata.yaml" + + # MNN + elif mnn: + LOGGER.info(f"Loading {w} for MNN inference...") + check_requirements("MNN") # requires MNN + import os + + import MNN + + config = {"precision": "low", "backend": "CPU", "numThread": (os.cpu_count() + 1) // 2} + rt = MNN.nn.create_runtime_manager((config,)) + net = MNN.nn.load_module_from_file(w, [], [], runtime_manager=rt, rearrange=True) + + def torch_to_mnn(x): + return MNN.expr.const(x.data_ptr(), x.shape) + + metadata = json.loads(net.get_info()["bizCode"]) + + # NCNN + elif ncnn: + LOGGER.info(f"Loading {w} for NCNN inference...") + check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn") # requires NCNN + import ncnn as pyncnn + + net = pyncnn.Net() + net.opt.use_vulkan_compute = cuda + w = Path(w) + if not w.is_file(): # if not *.param + w = next(w.glob("*.param")) # get *.param file from *_ncnn_model dir + net.load_param(str(w)) + net.load_model(str(w.with_suffix(".bin"))) + metadata = w.parent / "metadata.yaml" + + # NVIDIA Triton Inference Server + elif triton: + check_requirements("tritonclient[all]") + from ultralytics.utils.triton import TritonRemoteModel + + model = TritonRemoteModel(w) + metadata = model.metadata + + # Any other format (unsupported) + else: + from ultralytics.engine.exporter import export_formats + + raise TypeError( + f"model='{w}' is not a supported model format. Ultralytics supports: {export_formats()['Format']}\n" + f"See https://docs.ultralytics.com/modes/predict for help." 
+ ) + + # Load external metadata YAML + if isinstance(metadata, (str, Path)) and Path(metadata).exists(): + metadata = yaml_load(metadata) + if metadata and isinstance(metadata, dict): + for k, v in metadata.items(): + if k in {"stride", "batch"}: + metadata[k] = int(v) + elif k in {"imgsz", "names", "kpt_shape"} and isinstance(v, str): + metadata[k] = eval(v) + stride = metadata["stride"] + task = metadata["task"] + batch = metadata["batch"] + imgsz = metadata["imgsz"] + names = metadata["names"] + kpt_shape = metadata.get("kpt_shape") + elif not (pt or triton or nn_module): + LOGGER.warning(f"WARNING ⚠️ Metadata not found for 'model={weights}'") + + # Check names + if "names" not in locals(): # names missing + names = default_class_names(data) + names = check_class_names(names) + + # Disable gradients + if pt: + for p in model.parameters(): + p.requires_grad = False + + self.__dict__.update(locals()) # assign all variables to self + + def forward(self, im, augment=False, visualize=False, embed=None): + """ + Runs inference on the YOLOv8 MultiBackend model. + + Args: + im (torch.Tensor): The image tensor to perform inference on. + augment (bool): whether to perform data augmentation during inference, defaults to False + visualize (bool): whether to visualize the output predictions, defaults to False + embed (list, optional): A list of feature vectors/embeddings to return. + + Returns: + (tuple): Tuple containing the raw output tensor, and processed output for visualization (if visualize=True) + """ + b, ch, h, w = im.shape # batch, channel, height, width + if self.fp16 and im.dtype != torch.float16: + im = im.half() # to FP16 + if self.nhwc: + im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3) + + # PyTorch + if self.pt or self.nn_module: + y = self.model(im, augment=augment, visualize=visualize, embed=embed) + + # TorchScript + elif self.jit: + y = self.model(im) + + # ONNX OpenCV DNN + elif self.dnn: + im = im.cpu().numpy() # torch to numpy + self.net.setInput(im) + y = self.net.forward() + + # ONNX Runtime + elif self.onnx or self.imx: + if self.dynamic: + im = im.cpu().numpy() # torch to numpy + y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im}) + else: + if not self.cuda: + im = im.cpu() + self.io.bind_input( + name="images", + device_type=im.device.type, + device_id=im.device.index if im.device.type == "cuda" else 0, + element_type=np.float16 if self.fp16 else np.float32, + shape=tuple(im.shape), + buffer_ptr=im.data_ptr(), + ) + self.session.run_with_iobinding(self.io) + y = self.bindings + if self.imx: + # boxes, conf, cls + y = np.concatenate([y[0], y[1][:, :, None], y[2][:, :, None]], axis=-1) + + # OpenVINO + elif self.xml: + im = im.cpu().numpy() # FP32 + + if self.inference_mode in {"THROUGHPUT", "CUMULATIVE_THROUGHPUT"}: # optimized for larger batch-sizes + n = im.shape[0] # number of images in batch + results = [None] * n # preallocate list with None to match the number of images + + def callback(request, userdata): + """Places result in preallocated list using userdata index.""" + results[userdata] = request.results + + # Create AsyncInferQueue, set the callback and start asynchronous inference for each input image + async_queue = self.ov.runtime.AsyncInferQueue(self.ov_compiled_model) + async_queue.set_callback(callback) + for i in range(n): + # Start async inference with userdata=i to specify the position in results list + async_queue.start_async(inputs={self.input_name: im[i : i + 1]}, userdata=i) # keep image as BCHW 
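The dynamic-shape ONNX branch above falls back to a plain `session.run` call with a NumPy input. A minimal standalone sketch of that path, assuming a local `yolov12n.onnx` export exists (hypothetical filename) and using a zero image as input:

```python
# Minimal sketch of the plain ONNX Runtime path used above for dynamic models.
# Assumes a local "yolov12n.onnx" export (hypothetical filename).
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("yolov12n.onnx", providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name          # typically "images"
im = np.zeros((1, 3, 640, 640), dtype=np.float32)  # BCHW, FP32
outputs = session.run(None, {input_name: im})      # None returns every output
print([o.shape for o in outputs])
```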
+ async_queue.wait_all() # wait for all inference requests to complete + y = np.concatenate([list(r.values())[0] for r in results]) + + else: # inference_mode = "LATENCY", optimized for fastest first result at batch-size 1 + y = list(self.ov_compiled_model(im).values()) + + # TensorRT + elif self.engine: + if self.dynamic and im.shape != self.bindings["images"].shape: + if self.is_trt10: + self.context.set_input_shape("images", im.shape) + self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape) + for name in self.output_names: + self.bindings[name].data.resize_(tuple(self.context.get_tensor_shape(name))) + else: + i = self.model.get_binding_index("images") + self.context.set_binding_shape(i, im.shape) + self.bindings["images"] = self.bindings["images"]._replace(shape=im.shape) + for name in self.output_names: + i = self.model.get_binding_index(name) + self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i))) + + s = self.bindings["images"].shape + assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}" + self.binding_addrs["images"] = int(im.data_ptr()) + self.context.execute_v2(list(self.binding_addrs.values())) + y = [self.bindings[x].data for x in sorted(self.output_names)] + + # CoreML + elif self.coreml: + im = im[0].cpu().numpy() + im_pil = Image.fromarray((im * 255).astype("uint8")) + # im = im.resize((192, 320), Image.BILINEAR) + y = self.model.predict({"image": im_pil}) # coordinates are xywh normalized + if "confidence" in y: + raise TypeError( + "Ultralytics only supports inference of non-pipelined CoreML models exported with " + f"'nms=False', but 'model={w}' has an NMS pipeline created by an 'nms=True' export." + ) + # TODO: CoreML NMS inference handling + # from ultralytics.utils.ops import xywh2xyxy + # box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels + # conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float32) + # y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1) + y = list(y.values()) + if len(y) == 2 and len(y[1].shape) != 4: # segmentation model + y = list(reversed(y)) # reversed for segmentation models (pred, proto) + + # PaddlePaddle + elif self.paddle: + im = im.cpu().numpy().astype(np.float32) + self.input_handle.copy_from_cpu(im) + self.predictor.run() + y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names] + + # MNN + elif self.mnn: + input_var = self.torch_to_mnn(im) + output_var = self.net.onForward([input_var]) + y = [x.read() for x in output_var] + + # NCNN + elif self.ncnn: + mat_in = self.pyncnn.Mat(im[0].cpu().numpy()) + with self.net.create_extractor() as ex: + ex.input(self.net.input_names()[0], mat_in) + # WARNING: 'output_names' sorted as a temporary fix for https://github.com/pnnx/pnnx/issues/130 + y = [np.array(ex.extract(x)[1])[None] for x in sorted(self.net.output_names())] + + # NVIDIA Triton Inference Server + elif self.triton: + im = im.cpu().numpy() # torch to numpy + y = self.model(im) + + # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU) + else: + im = im.cpu().numpy() + if self.saved_model: # SavedModel + y = self.model(im, training=False) if self.keras else self.model(im) + if not isinstance(y, list): + y = [y] + elif self.pb: # GraphDef + y = self.frozen_func(x=self.tf.constant(im)) + else: # Lite or Edge TPU + details = self.input_details[0] + is_int = details["dtype"] in {np.int8, np.int16} # is TFLite quantized int8 or int16 model + if is_int: + scale, zero_point 
= details["quantization"] + im = (im / scale + zero_point).astype(details["dtype"]) # de-scale + self.interpreter.set_tensor(details["index"], im) + self.interpreter.invoke() + y = [] + for output in self.output_details: + x = self.interpreter.get_tensor(output["index"]) + if is_int: + scale, zero_point = output["quantization"] + x = (x.astype(np.float32) - zero_point) * scale # re-scale + if x.ndim == 3: # if task is not classification, excluding masks (ndim=4) as well + # Denormalize xywh by image size. See https://github.com/ultralytics/ultralytics/pull/1695 + # xywh are normalized in TFLite/EdgeTPU to mitigate quantization error of integer models + if x.shape[-1] == 6: # end-to-end model + x[:, :, [0, 2]] *= w + x[:, :, [1, 3]] *= h + else: + x[:, [0, 2]] *= w + x[:, [1, 3]] *= h + if self.task == "pose": + x[:, 5::3] *= w + x[:, 6::3] *= h + y.append(x) + # TF segment fixes: export is reversed vs ONNX export and protos are transposed + if len(y) == 2: # segment with (det, proto) output order reversed + if len(y[1].shape) != 4: + y = list(reversed(y)) # should be y = (1, 116, 8400), (1, 160, 160, 32) + if y[1].shape[-1] == 6: # end-to-end model + y = [y[1]] + else: + y[1] = np.transpose(y[1], (0, 3, 1, 2)) # should be y = (1, 116, 8400), (1, 32, 160, 160) + y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y] + + # for x in y: + # print(type(x), len(x)) if isinstance(x, (list, tuple)) else print(type(x), x.shape) # debug shapes + if isinstance(y, (list, tuple)): + if len(self.names) == 999 and (self.task == "segment" or len(y) == 2): # segments and names not defined + nc = y[0].shape[1] - y[1].shape[1] - 4 # y = (1, 32, 160, 160), (1, 116, 8400) + self.names = {i: f"class{i}" for i in range(nc)} + return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y] + else: + return self.from_numpy(y) + + def from_numpy(self, x): + """ + Convert a numpy array to a tensor. + + Args: + x (np.ndarray): The array to be converted. + + Returns: + (torch.Tensor): The converted tensor + """ + return torch.tensor(x).to(self.device) if isinstance(x, np.ndarray) else x + + def warmup(self, imgsz=(1, 3, 640, 640)): + """ + Warm up the model by running one forward pass with a dummy input. + + Args: + imgsz (tuple): The shape of the dummy input tensor in the format (batch_size, channels, height, width) + """ + import torchvision # noqa (import here so torchvision import time not recorded in postprocess time) + + warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module + if any(warmup_types) and (self.device.type != "cpu" or self.triton): + im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input + for _ in range(2 if self.jit else 1): + self.forward(im) # warmup + + @staticmethod + def _model_type(p="path/to/model.pt"): + """ + Takes a path to a model file and returns the model type. Possibles types are pt, jit, onnx, xml, engine, coreml, + saved_model, pb, tflite, edgetpu, tfjs, ncnn or paddle. + + Args: + p: path to the model file. 
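The integer TFLite branch above quantizes inputs with `x / scale + zero_point` and dequantizes outputs with `(x - zero_point) * scale`. A small NumPy round trip; the quantization parameters below are made-up illustrative values, not taken from a real export:

```python
# Round trip of the int8 (de)quantization arithmetic used in the TFLite branch above.
import numpy as np

scale, zero_point = 0.00392157, -128                  # illustrative params: roughly 1/255 and an int8 offset
x = np.random.rand(1, 3, 640, 640).astype(np.float32)

x_q = (x / scale + zero_point).astype(np.int8)        # quantize, as done for model inputs
x_d = (x_q.astype(np.float32) - zero_point) * scale   # dequantize, as done for model outputs
print(np.abs(x - x_d).max())                          # bounded by about one quantization step (scale)
```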
Defaults to path/to/model.pt + + Examples: + >>> model = AutoBackend(weights="path/to/model.onnx") + >>> model_type = model._model_type() # returns "onnx" + """ + from ultralytics.engine.exporter import export_formats + + sf = export_formats()["Suffix"] # export suffixes + if not is_url(p) and not isinstance(p, str): + check_suffix(p, sf) # checks + name = Path(p).name + types = [s in name for s in sf] + types[5] |= name.endswith(".mlmodel") # retain support for older Apple CoreML *.mlmodel formats + types[8] &= not types[9] # tflite &= not edgetpu + if any(types): + triton = False + else: + from urllib.parse import urlsplit + + url = urlsplit(p) + triton = bool(url.netloc) and bool(url.path) and url.scheme in {"http", "grpc"} + + return types + [triton] diff --git a/ultralytics/nn/modules/__init__.py b/ultralytics/nn/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2a168b3b085bbd184a53aecf1506c20c9b0f3dff --- /dev/null +++ b/ultralytics/nn/modules/__init__.py @@ -0,0 +1,165 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +""" +Ultralytics modules. + +Example: + Visualize a module with Netron. + ```python + from ultralytics.nn.modules import * + import torch + import os + + x = torch.ones(1, 128, 40, 40) + m = Conv(128, 128) + f = f"{m._get_name()}.onnx" + torch.onnx.export(m, x, f) + os.system(f"onnxslim {f} {f} && open {f}") # pip install onnxslim + ``` +""" + +from .block import ( + C1, + C2, + C2PSA, + C3, + C3TR, + CIB, + DFL, + ELAN1, + PSA, + SPP, + SPPELAN, + SPPF, + AConv, + ADown, + Attention, + BNContrastiveHead, + Bottleneck, + BottleneckCSP, + C2f, + C2fAttn, + C2fCIB, + C2fPSA, + C3Ghost, + C3k2, + C3x, + CBFuse, + CBLinear, + ContrastiveHead, + GhostBottleneck, + HGBlock, + HGStem, + ImagePoolingAttn, + Proto, + RepC3, + RepNCSPELAN4, + RepVGGDW, + ResNetLayer, + SCDown, + TorchVision, + A2C2f, +) +from .conv import ( + CBAM, + ChannelAttention, + Concat, + Conv, + Conv2, + ConvTranspose, + DWConv, + DWConvTranspose2d, + Focus, + GhostConv, + Index, + LightConv, + RepConv, + SpatialAttention, +) +from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment, WorldDetect, v10Detect +from .transformer import ( + AIFI, + MLP, + DeformableTransformerDecoder, + DeformableTransformerDecoderLayer, + LayerNorm2d, + MLPBlock, + MSDeformAttn, + TransformerBlock, + TransformerEncoderLayer, + TransformerLayer, +) + +__all__ = ( + "Conv", + "Conv2", + "LightConv", + "RepConv", + "DWConv", + "DWConvTranspose2d", + "ConvTranspose", + "Focus", + "GhostConv", + "ChannelAttention", + "SpatialAttention", + "CBAM", + "Concat", + "TransformerLayer", + "TransformerBlock", + "MLPBlock", + "LayerNorm2d", + "DFL", + "HGBlock", + "HGStem", + "SPP", + "SPPF", + "C1", + "C2", + "C3", + "C2f", + "C3k2", + "SCDown", + "C2fPSA", + "C2PSA", + "C2fAttn", + "C3x", + "C3TR", + "C3Ghost", + "GhostBottleneck", + "Bottleneck", + "BottleneckCSP", + "Proto", + "Detect", + "Segment", + "Pose", + "Classify", + "TransformerEncoderLayer", + "RepC3", + "RTDETRDecoder", + "AIFI", + "DeformableTransformerDecoder", + "DeformableTransformerDecoderLayer", + "MSDeformAttn", + "MLP", + "ResNetLayer", + "OBB", + "WorldDetect", + "v10Detect", + "ImagePoolingAttn", + "ContrastiveHead", + "BNContrastiveHead", + "RepNCSPELAN4", + "ADown", + "SPPELAN", + "CBFuse", + "CBLinear", + "AConv", + "ELAN1", + "RepVGGDW", + "CIB", + "C2fCIB", + "Attention", + "PSA", + "TorchVision", + "Index", + "A2C2f" +) diff --git a/ultralytics/nn/modules/activation.py 
b/ultralytics/nn/modules/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..cc6b44b47b6f96cde8a9a09aad001d38f8689406 --- /dev/null +++ b/ultralytics/nn/modules/activation.py @@ -0,0 +1,21 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +"""Activation modules.""" + +import torch +import torch.nn as nn + + +class AGLU(nn.Module): + """Unified activation function module from https://github.com/kostas1515/AGLU.""" + + def __init__(self, device=None, dtype=None) -> None: + """Initialize the Unified activation function.""" + super().__init__() + self.act = nn.Softplus(beta=-1.0) + self.lambd = nn.Parameter(nn.init.uniform_(torch.empty(1, device=device, dtype=dtype))) # lambda parameter + self.kappa = nn.Parameter(nn.init.uniform_(torch.empty(1, device=device, dtype=dtype))) # kappa parameter + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Compute the forward pass of the Unified activation function.""" + lam = torch.clamp(self.lambd, min=0.0001) + return torch.exp((1 / lam) * self.act((self.kappa * x) - torch.log(lam))) diff --git a/ultralytics/nn/modules/block.py b/ultralytics/nn/modules/block.py new file mode 100644 index 0000000000000000000000000000000000000000..c39b3d429613e5ff8db7ce3dbb05bb0007549f13 --- /dev/null +++ b/ultralytics/nn/modules/block.py @@ -0,0 +1,1355 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +"""Block modules.""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ultralytics.utils.torch_utils import fuse_conv_and_bn + +from .conv import Conv, DWConv, GhostConv, LightConv, RepConv, autopad +from .transformer import TransformerBlock + +__all__ = ( + "DFL", + "HGBlock", + "HGStem", + "SPP", + "SPPF", + "C1", + "C2", + "C3", + "C2f", + "C2fAttn", + "ImagePoolingAttn", + "ContrastiveHead", + "BNContrastiveHead", + "C3x", + "C3TR", + "C3Ghost", + "GhostBottleneck", + "Bottleneck", + "BottleneckCSP", + "Proto", + "RepC3", + "ResNetLayer", + "RepNCSPELAN4", + "ELAN1", + "ADown", + "AConv", + "SPPELAN", + "CBFuse", + "CBLinear", + "C3k2", + "C2fPSA", + "C2PSA", + "RepVGGDW", + "CIB", + "C2fCIB", + "Attention", + "PSA", + "SCDown", + "TorchVision", +) + + +class DFL(nn.Module): + """ + Integral module of Distribution Focal Loss (DFL). + + Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391 + """ + + def __init__(self, c1=16): + """Initialize a convolutional layer with a given number of input channels.""" + super().__init__() + self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False) + x = torch.arange(c1, dtype=torch.float) + self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1)) + self.c1 = c1 + + def forward(self, x): + """Applies a transformer layer on input tensor 'x' and returns a tensor.""" + b, _, a = x.shape # batch, channels, anchors + return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a) + # return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a) + + +class Proto(nn.Module): + """YOLOv8 mask Proto module for segmentation models.""" + + def __init__(self, c1, c_=256, c2=32): + """ + Initializes the YOLOv8 mask Proto module with specified number of protos and masks. + + Input arguments are ch_in, number of protos, number of masks. 
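The DFL block added above freezes a 1x1 convolution whose weights are `arange(c1)`, so applying it to each side's softmax yields the expected bin index (a soft-argmax) for the four box sides. A short check against an explicit expectation, assuming the default 16 bins and a random batch:

```python
# DFL output equals the expected bin index under each side's softmax distribution.
import torch
from ultralytics.nn.modules.block import DFL

c1, anchors = 16, 8400
dfl = DFL(c1)
x = torch.randn(2, 4 * c1, anchors)                      # raw per-bin logits, 4 sides per anchor

y = dfl(x)                                               # (2, 4, anchors), values in [0, c1 - 1]
p = x.view(2, 4, c1, anchors).softmax(2)                 # per-side bin probabilities
expected = (p * torch.arange(c1).float().view(1, 1, c1, 1)).sum(2)
print(torch.allclose(y, expected, atol=1e-5))            # True
```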
+ """ + super().__init__() + self.cv1 = Conv(c1, c_, k=3) + self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True) # nn.Upsample(scale_factor=2, mode='nearest') + self.cv2 = Conv(c_, c_, k=3) + self.cv3 = Conv(c_, c2) + + def forward(self, x): + """Performs a forward pass through layers using an upsampled input image.""" + return self.cv3(self.cv2(self.upsample(self.cv1(x)))) + + +class HGStem(nn.Module): + """ + StemBlock of PPHGNetV2 with 5 convolutions and one maxpool2d. + + https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py + """ + + def __init__(self, c1, cm, c2): + """Initialize the SPP layer with input/output channels and specified kernel sizes for max pooling.""" + super().__init__() + self.stem1 = Conv(c1, cm, 3, 2, act=nn.ReLU()) + self.stem2a = Conv(cm, cm // 2, 2, 1, 0, act=nn.ReLU()) + self.stem2b = Conv(cm // 2, cm, 2, 1, 0, act=nn.ReLU()) + self.stem3 = Conv(cm * 2, cm, 3, 2, act=nn.ReLU()) + self.stem4 = Conv(cm, c2, 1, 1, act=nn.ReLU()) + self.pool = nn.MaxPool2d(kernel_size=2, stride=1, padding=0, ceil_mode=True) + + def forward(self, x): + """Forward pass of a PPHGNetV2 backbone layer.""" + x = self.stem1(x) + x = F.pad(x, [0, 1, 0, 1]) + x2 = self.stem2a(x) + x2 = F.pad(x2, [0, 1, 0, 1]) + x2 = self.stem2b(x2) + x1 = self.pool(x) + x = torch.cat([x1, x2], dim=1) + x = self.stem3(x) + x = self.stem4(x) + return x + + +class HGBlock(nn.Module): + """ + HG_Block of PPHGNetV2 with 2 convolutions and LightConv. + + https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py + """ + + def __init__(self, c1, cm, c2, k=3, n=6, lightconv=False, shortcut=False, act=nn.ReLU()): + """Initializes a CSP Bottleneck with 1 convolution using specified input and output channels.""" + super().__init__() + block = LightConv if lightconv else Conv + self.m = nn.ModuleList(block(c1 if i == 0 else cm, cm, k=k, act=act) for i in range(n)) + self.sc = Conv(c1 + n * cm, c2 // 2, 1, 1, act=act) # squeeze conv + self.ec = Conv(c2 // 2, c2, 1, 1, act=act) # excitation conv + self.add = shortcut and c1 == c2 + + def forward(self, x): + """Forward pass of a PPHGNetV2 backbone layer.""" + y = [x] + y.extend(m(y[-1]) for m in self.m) + y = self.ec(self.sc(torch.cat(y, 1))) + return y + x if self.add else y + + +class SPP(nn.Module): + """Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729.""" + + def __init__(self, c1, c2, k=(5, 9, 13)): + """Initialize the SPP layer with input/output channels and pooling kernel sizes.""" + super().__init__() + c_ = c1 // 2 # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1) + self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k]) + + def forward(self, x): + """Forward pass of the SPP layer, performing spatial pyramid pooling.""" + x = self.cv1(x) + return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1)) + + +class SPPF(nn.Module): + """Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher.""" + + def __init__(self, c1, c2, k=5): + """ + Initializes the SPPF layer with given input/output channels and kernel size. + + This module is equivalent to SPP(k=(5, 9, 13)). 
+ """ + super().__init__() + c_ = c1 // 2 # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c_ * 4, c2, 1, 1) + self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) + + def forward(self, x): + """Forward pass through Ghost Convolution block.""" + y = [self.cv1(x)] + y.extend(self.m(y[-1]) for _ in range(3)) + return self.cv2(torch.cat(y, 1)) + + +class C1(nn.Module): + """CSP Bottleneck with 1 convolution.""" + + def __init__(self, c1, c2, n=1): + """Initializes the CSP Bottleneck with configurations for 1 convolution with arguments ch_in, ch_out, number.""" + super().__init__() + self.cv1 = Conv(c1, c2, 1, 1) + self.m = nn.Sequential(*(Conv(c2, c2, 3) for _ in range(n))) + + def forward(self, x): + """Applies cross-convolutions to input in the C3 module.""" + y = self.cv1(x) + return self.m(y) + y + + +class C2(nn.Module): + """CSP Bottleneck with 2 convolutions.""" + + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + """Initializes a CSP Bottleneck with 2 convolutions and optional shortcut connection.""" + super().__init__() + self.c = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, 2 * self.c, 1, 1) + self.cv2 = Conv(2 * self.c, c2, 1) # optional act=FReLU(c2) + # self.attention = ChannelAttention(2 * self.c) # or SpatialAttention() + self.m = nn.Sequential(*(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))) + + def forward(self, x): + """Forward pass through the CSP bottleneck with 2 convolutions.""" + a, b = self.cv1(x).chunk(2, 1) + return self.cv2(torch.cat((self.m(a), b), 1)) + + +class C2f(nn.Module): + """Faster Implementation of CSP Bottleneck with 2 convolutions.""" + + def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): + """Initializes a CSP bottleneck with 2 convolutions and n Bottleneck blocks for faster processing.""" + super().__init__() + self.c = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, 2 * self.c, 1, 1) + self.cv2 = Conv((2 + n) * self.c, c2, 1) # optional act=FReLU(c2) + self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)) + + def forward(self, x): + """Forward pass through C2f layer.""" + y = list(self.cv1(x).chunk(2, 1)) + y.extend(m(y[-1]) for m in self.m) + return self.cv2(torch.cat(y, 1)) + + def forward_split(self, x): + """Forward pass using split() instead of chunk().""" + y = self.cv1(x).split((self.c, self.c), 1) + y = [y[0], y[1]] + y.extend(m(y[-1]) for m in self.m) + return self.cv2(torch.cat(y, 1)) + + +class C3(nn.Module): + """CSP Bottleneck with 3 convolutions.""" + + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + """Initialize the CSP Bottleneck with given channels, number, shortcut, groups, and expansion values.""" + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c1, c_, 1, 1) + self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2) + self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=((1, 1), (3, 3)), e=1.0) for _ in range(n))) + + def forward(self, x): + """Forward pass through the CSP bottleneck with 2 convolutions.""" + return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1)) + + +class C3x(C3): + """C3 module with cross-convolutions.""" + + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + """Initialize C3TR instance and set default parameters.""" + super().__init__(c1, c2, n, shortcut, g, e) + self.c_ = int(c2 * e) + self.m = nn.Sequential(*(Bottleneck(self.c_, self.c_, shortcut, g, 
k=((1, 3), (3, 1)), e=1) for _ in range(n))) + + +class RepC3(nn.Module): + """Rep C3.""" + + def __init__(self, c1, c2, n=3, e=1.0): + """Initialize CSP Bottleneck with a single convolution using input channels, output channels, and number.""" + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c1, c_, 1, 1) + self.m = nn.Sequential(*[RepConv(c_, c_) for _ in range(n)]) + self.cv3 = Conv(c_, c2, 1, 1) if c_ != c2 else nn.Identity() + + def forward(self, x): + """Forward pass of RT-DETR neck layer.""" + return self.cv3(self.m(self.cv1(x)) + self.cv2(x)) + + +class C3TR(C3): + """C3 module with TransformerBlock().""" + + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + """Initialize C3Ghost module with GhostBottleneck().""" + super().__init__(c1, c2, n, shortcut, g, e) + c_ = int(c2 * e) + self.m = TransformerBlock(c_, c_, 4, n) + + +class C3Ghost(C3): + """C3 module with GhostBottleneck().""" + + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + """Initialize 'SPP' module with various pooling sizes for spatial pyramid pooling.""" + super().__init__(c1, c2, n, shortcut, g, e) + c_ = int(c2 * e) # hidden channels + self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n))) + + +class GhostBottleneck(nn.Module): + """Ghost Bottleneck https://github.com/huawei-noah/ghostnet.""" + + def __init__(self, c1, c2, k=3, s=1): + """Initializes GhostBottleneck module with arguments ch_in, ch_out, kernel, stride.""" + super().__init__() + c_ = c2 // 2 + self.conv = nn.Sequential( + GhostConv(c1, c_, 1, 1), # pw + DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw + GhostConv(c_, c2, 1, 1, act=False), # pw-linear + ) + self.shortcut = ( + nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity() + ) + + def forward(self, x): + """Applies skip connection and concatenation to input tensor.""" + return self.conv(x) + self.shortcut(x) + + +class Bottleneck(nn.Module): + """Standard bottleneck.""" + + def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): + """Initializes a standard bottleneck module with optional shortcut connection and configurable parameters.""" + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, k[0], 1) + self.cv2 = Conv(c_, c2, k[1], 1, g=g) + self.add = shortcut and c1 == c2 + + def forward(self, x): + """Applies the YOLO FPN to input data.""" + return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) + + +class BottleneckCSP(nn.Module): + """CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks.""" + + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + """Initializes the CSP Bottleneck given arguments for ch_in, ch_out, number, shortcut, groups, expansion.""" + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False) + self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False) + self.cv4 = Conv(2 * c_, c2, 1, 1) + self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3) + self.act = nn.SiLU() + self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) + + def forward(self, x): + """Applies a CSP bottleneck with 3 convolutions.""" + y1 = self.cv3(self.m(self.cv1(x))) + y2 = self.cv2(x) + return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1)))) + + +class ResNetBlock(nn.Module): + """ResNet block with standard convolution layers.""" + + def 
__init__(self, c1, c2, s=1, e=4): + """Initialize convolution with given parameters.""" + super().__init__() + c3 = e * c2 + self.cv1 = Conv(c1, c2, k=1, s=1, act=True) + self.cv2 = Conv(c2, c2, k=3, s=s, p=1, act=True) + self.cv3 = Conv(c2, c3, k=1, act=False) + self.shortcut = nn.Sequential(Conv(c1, c3, k=1, s=s, act=False)) if s != 1 or c1 != c3 else nn.Identity() + + def forward(self, x): + """Forward pass through the ResNet block.""" + return F.relu(self.cv3(self.cv2(self.cv1(x))) + self.shortcut(x)) + + +class ResNetLayer(nn.Module): + """ResNet layer with multiple ResNet blocks.""" + + def __init__(self, c1, c2, s=1, is_first=False, n=1, e=4): + """Initializes the ResNetLayer given arguments.""" + super().__init__() + self.is_first = is_first + + if self.is_first: + self.layer = nn.Sequential( + Conv(c1, c2, k=7, s=2, p=3, act=True), nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + ) + else: + blocks = [ResNetBlock(c1, c2, s, e=e)] + blocks.extend([ResNetBlock(e * c2, c2, 1, e=e) for _ in range(n - 1)]) + self.layer = nn.Sequential(*blocks) + + def forward(self, x): + """Forward pass through the ResNet layer.""" + return self.layer(x) + + +class MaxSigmoidAttnBlock(nn.Module): + """Max Sigmoid attention block.""" + + def __init__(self, c1, c2, nh=1, ec=128, gc=512, scale=False): + """Initializes MaxSigmoidAttnBlock with specified arguments.""" + super().__init__() + self.nh = nh + self.hc = c2 // nh + self.ec = Conv(c1, ec, k=1, act=False) if c1 != ec else None + self.gl = nn.Linear(gc, ec) + self.bias = nn.Parameter(torch.zeros(nh)) + self.proj_conv = Conv(c1, c2, k=3, s=1, act=False) + self.scale = nn.Parameter(torch.ones(1, nh, 1, 1)) if scale else 1.0 + + def forward(self, x, guide): + """Forward process.""" + bs, _, h, w = x.shape + + guide = self.gl(guide) + guide = guide.view(bs, -1, self.nh, self.hc) + embed = self.ec(x) if self.ec is not None else x + embed = embed.view(bs, self.nh, self.hc, h, w) + + aw = torch.einsum("bmchw,bnmc->bmhwn", embed, guide) + aw = aw.max(dim=-1)[0] + aw = aw / (self.hc**0.5) + aw = aw + self.bias[None, :, None, None] + aw = aw.sigmoid() * self.scale + + x = self.proj_conv(x) + x = x.view(bs, self.nh, -1, h, w) + x = x * aw.unsqueeze(2) + return x.view(bs, -1, h, w) + + +class C2fAttn(nn.Module): + """C2f module with an additional attn module.""" + + def __init__(self, c1, c2, n=1, ec=128, nh=1, gc=512, shortcut=False, g=1, e=0.5): + """Initializes C2f module with attention mechanism for enhanced feature extraction and processing.""" + super().__init__() + self.c = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, 2 * self.c, 1, 1) + self.cv2 = Conv((3 + n) * self.c, c2, 1) # optional act=FReLU(c2) + self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)) + self.attn = MaxSigmoidAttnBlock(self.c, self.c, gc=gc, ec=ec, nh=nh) + + def forward(self, x, guide): + """Forward pass through C2f layer.""" + y = list(self.cv1(x).chunk(2, 1)) + y.extend(m(y[-1]) for m in self.m) + y.append(self.attn(y[-1], guide)) + return self.cv2(torch.cat(y, 1)) + + def forward_split(self, x, guide): + """Forward pass using split() instead of chunk().""" + y = list(self.cv1(x).split((self.c, self.c), 1)) + y.extend(m(y[-1]) for m in self.m) + y.append(self.attn(y[-1], guide)) + return self.cv2(torch.cat(y, 1)) + + +class ImagePoolingAttn(nn.Module): + """ImagePoolingAttn: Enhance the text embeddings with image-aware information.""" + + def __init__(self, ec=256, ch=(), ct=512, nh=8, k=3, scale=False): + 
"""Initializes ImagePoolingAttn with specified arguments.""" + super().__init__() + + nf = len(ch) + self.query = nn.Sequential(nn.LayerNorm(ct), nn.Linear(ct, ec)) + self.key = nn.Sequential(nn.LayerNorm(ec), nn.Linear(ec, ec)) + self.value = nn.Sequential(nn.LayerNorm(ec), nn.Linear(ec, ec)) + self.proj = nn.Linear(ec, ct) + self.scale = nn.Parameter(torch.tensor([0.0]), requires_grad=True) if scale else 1.0 + self.projections = nn.ModuleList([nn.Conv2d(in_channels, ec, kernel_size=1) for in_channels in ch]) + self.im_pools = nn.ModuleList([nn.AdaptiveMaxPool2d((k, k)) for _ in range(nf)]) + self.ec = ec + self.nh = nh + self.nf = nf + self.hc = ec // nh + self.k = k + + def forward(self, x, text): + """Executes attention mechanism on input tensor x and guide tensor.""" + bs = x[0].shape[0] + assert len(x) == self.nf + num_patches = self.k**2 + x = [pool(proj(x)).view(bs, -1, num_patches) for (x, proj, pool) in zip(x, self.projections, self.im_pools)] + x = torch.cat(x, dim=-1).transpose(1, 2) + q = self.query(text) + k = self.key(x) + v = self.value(x) + + # q = q.reshape(1, text.shape[1], self.nh, self.hc).repeat(bs, 1, 1, 1) + q = q.reshape(bs, -1, self.nh, self.hc) + k = k.reshape(bs, -1, self.nh, self.hc) + v = v.reshape(bs, -1, self.nh, self.hc) + + aw = torch.einsum("bnmc,bkmc->bmnk", q, k) + aw = aw / (self.hc**0.5) + aw = F.softmax(aw, dim=-1) + + x = torch.einsum("bmnk,bkmc->bnmc", aw, v) + x = self.proj(x.reshape(bs, -1, self.ec)) + return x * self.scale + text + + +class ContrastiveHead(nn.Module): + """Implements contrastive learning head for region-text similarity in vision-language models.""" + + def __init__(self): + """Initializes ContrastiveHead with specified region-text similarity parameters.""" + super().__init__() + # NOTE: use -10.0 to keep the init cls loss consistency with other losses + self.bias = nn.Parameter(torch.tensor([-10.0])) + self.logit_scale = nn.Parameter(torch.ones([]) * torch.tensor(1 / 0.07).log()) + + def forward(self, x, w): + """Forward function of contrastive learning.""" + x = F.normalize(x, dim=1, p=2) + w = F.normalize(w, dim=-1, p=2) + x = torch.einsum("bchw,bkc->bkhw", x, w) + return x * self.logit_scale.exp() + self.bias + + +class BNContrastiveHead(nn.Module): + """ + Batch Norm Contrastive Head for YOLO-World using batch norm instead of l2-normalization. + + Args: + embed_dims (int): Embed dimensions of text and image features. 
+ """ + + def __init__(self, embed_dims: int): + """Initialize ContrastiveHead with region-text similarity parameters.""" + super().__init__() + self.norm = nn.BatchNorm2d(embed_dims) + # NOTE: use -10.0 to keep the init cls loss consistency with other losses + self.bias = nn.Parameter(torch.tensor([-10.0])) + # use -1.0 is more stable + self.logit_scale = nn.Parameter(-1.0 * torch.ones([])) + + def forward(self, x, w): + """Forward function of contrastive learning.""" + x = self.norm(x) + w = F.normalize(w, dim=-1, p=2) + x = torch.einsum("bchw,bkc->bkhw", x, w) + return x * self.logit_scale.exp() + self.bias + + +class RepBottleneck(Bottleneck): + """Rep bottleneck.""" + + def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): + """Initializes a RepBottleneck module with customizable in/out channels, shortcuts, groups and expansion.""" + super().__init__(c1, c2, shortcut, g, k, e) + c_ = int(c2 * e) # hidden channels + self.cv1 = RepConv(c1, c_, k[0], 1) + + +class RepCSP(C3): + """Repeatable Cross Stage Partial Network (RepCSP) module for efficient feature extraction.""" + + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): + """Initializes RepCSP layer with given channels, repetitions, shortcut, groups and expansion ratio.""" + super().__init__(c1, c2, n, shortcut, g, e) + c_ = int(c2 * e) # hidden channels + self.m = nn.Sequential(*(RepBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) + + +class RepNCSPELAN4(nn.Module): + """CSP-ELAN.""" + + def __init__(self, c1, c2, c3, c4, n=1): + """Initializes CSP-ELAN layer with specified channel sizes, repetitions, and convolutions.""" + super().__init__() + self.c = c3 // 2 + self.cv1 = Conv(c1, c3, 1, 1) + self.cv2 = nn.Sequential(RepCSP(c3 // 2, c4, n), Conv(c4, c4, 3, 1)) + self.cv3 = nn.Sequential(RepCSP(c4, c4, n), Conv(c4, c4, 3, 1)) + self.cv4 = Conv(c3 + (2 * c4), c2, 1, 1) + + def forward(self, x): + """Forward pass through RepNCSPELAN4 layer.""" + y = list(self.cv1(x).chunk(2, 1)) + y.extend((m(y[-1])) for m in [self.cv2, self.cv3]) + return self.cv4(torch.cat(y, 1)) + + def forward_split(self, x): + """Forward pass using split() instead of chunk().""" + y = list(self.cv1(x).split((self.c, self.c), 1)) + y.extend(m(y[-1]) for m in [self.cv2, self.cv3]) + return self.cv4(torch.cat(y, 1)) + + +class ELAN1(RepNCSPELAN4): + """ELAN1 module with 4 convolutions.""" + + def __init__(self, c1, c2, c3, c4): + """Initializes ELAN1 layer with specified channel sizes.""" + super().__init__(c1, c2, c3, c4) + self.c = c3 // 2 + self.cv1 = Conv(c1, c3, 1, 1) + self.cv2 = Conv(c3 // 2, c4, 3, 1) + self.cv3 = Conv(c4, c4, 3, 1) + self.cv4 = Conv(c3 + (2 * c4), c2, 1, 1) + + +class AConv(nn.Module): + """AConv.""" + + def __init__(self, c1, c2): + """Initializes AConv module with convolution layers.""" + super().__init__() + self.cv1 = Conv(c1, c2, 3, 2, 1) + + def forward(self, x): + """Forward pass through AConv layer.""" + x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True) + return self.cv1(x) + + +class ADown(nn.Module): + """ADown.""" + + def __init__(self, c1, c2): + """Initializes ADown module with convolution layers to downsample input from channels c1 to c2.""" + super().__init__() + self.c = c2 // 2 + self.cv1 = Conv(c1 // 2, self.c, 3, 2, 1) + self.cv2 = Conv(c1 // 2, self.c, 1, 1, 0) + + def forward(self, x): + """Forward pass through ADown layer.""" + x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True) + x1, x2 = x.chunk(2, 1) + x1 = self.cv1(x1) + x2 = torch.nn.functional.max_pool2d(x2, 3, 2, 
1) + x2 = self.cv2(x2) + return torch.cat((x1, x2), 1) + + +class SPPELAN(nn.Module): + """SPP-ELAN.""" + + def __init__(self, c1, c2, c3, k=5): + """Initializes SPP-ELAN block with convolution and max pooling layers for spatial pyramid pooling.""" + super().__init__() + self.c = c3 + self.cv1 = Conv(c1, c3, 1, 1) + self.cv2 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) + self.cv3 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) + self.cv4 = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) + self.cv5 = Conv(4 * c3, c2, 1, 1) + + def forward(self, x): + """Forward pass through SPPELAN layer.""" + y = [self.cv1(x)] + y.extend(m(y[-1]) for m in [self.cv2, self.cv3, self.cv4]) + return self.cv5(torch.cat(y, 1)) + + +class CBLinear(nn.Module): + """CBLinear.""" + + def __init__(self, c1, c2s, k=1, s=1, p=None, g=1): + """Initializes the CBLinear module, passing inputs unchanged.""" + super().__init__() + self.c2s = c2s + self.conv = nn.Conv2d(c1, sum(c2s), k, s, autopad(k, p), groups=g, bias=True) + + def forward(self, x): + """Forward pass through CBLinear layer.""" + return self.conv(x).split(self.c2s, dim=1) + + +class CBFuse(nn.Module): + """CBFuse.""" + + def __init__(self, idx): + """Initializes CBFuse module with layer index for selective feature fusion.""" + super().__init__() + self.idx = idx + + def forward(self, xs): + """Forward pass through CBFuse layer.""" + target_size = xs[-1].shape[2:] + res = [F.interpolate(x[self.idx[i]], size=target_size, mode="nearest") for i, x in enumerate(xs[:-1])] + return torch.sum(torch.stack(res + xs[-1:]), dim=0) + + +class C3f(nn.Module): + """Faster Implementation of CSP Bottleneck with 2 convolutions.""" + + def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): + """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups, + expansion. 
+ """ + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c1, c_, 1, 1) + self.cv3 = Conv((2 + n) * c_, c2, 1) # optional act=FReLU(c2) + self.m = nn.ModuleList(Bottleneck(c_, c_, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)) + + def forward(self, x): + """Forward pass through C2f layer.""" + y = [self.cv2(x), self.cv1(x)] + y.extend(m(y[-1]) for m in self.m) + return self.cv3(torch.cat(y, 1)) + + +class C3k2(C2f): + """Faster Implementation of CSP Bottleneck with 2 convolutions.""" + + def __init__(self, c1, c2, n=1, c3k=False, e=0.5, g=1, shortcut=True): + """Initializes the C3k2 module, a faster CSP Bottleneck with 2 convolutions and optional C3k blocks.""" + super().__init__(c1, c2, n, shortcut, g, e) + self.m = nn.ModuleList( + C3k(self.c, self.c, 2, shortcut, g) if c3k else Bottleneck(self.c, self.c, shortcut, g) for _ in range(n) + ) + + +class C3k(C3): + """C3k is a CSP bottleneck module with customizable kernel sizes for feature extraction in neural networks.""" + + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, k=3): + """Initializes the C3k module with specified channels, number of layers, and configurations.""" + super().__init__(c1, c2, n, shortcut, g, e) + c_ = int(c2 * e) # hidden channels + # self.m = nn.Sequential(*(RepBottleneck(c_, c_, shortcut, g, k=(k, k), e=1.0) for _ in range(n))) + self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=(k, k), e=1.0) for _ in range(n))) + + +class RepVGGDW(torch.nn.Module): + """RepVGGDW is a class that represents a depth wise separable convolutional block in RepVGG architecture.""" + + def __init__(self, ed) -> None: + """Initializes RepVGGDW with depthwise separable convolutional layers for efficient processing.""" + super().__init__() + self.conv = Conv(ed, ed, 7, 1, 3, g=ed, act=False) + self.conv1 = Conv(ed, ed, 3, 1, 1, g=ed, act=False) + self.dim = ed + self.act = nn.SiLU() + + def forward(self, x): + """ + Performs a forward pass of the RepVGGDW block. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Output tensor after applying the depth wise separable convolution. + """ + return self.act(self.conv(x) + self.conv1(x)) + + def forward_fuse(self, x): + """ + Performs a forward pass of the RepVGGDW block without fusing the convolutions. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Output tensor after applying the depth wise separable convolution. + """ + return self.act(self.conv(x)) + + @torch.no_grad() + def fuse(self): + """ + Fuses the convolutional layers in the RepVGGDW block. + + This method fuses the convolutional layers and updates the weights and biases accordingly. + """ + conv = fuse_conv_and_bn(self.conv.conv, self.conv.bn) + conv1 = fuse_conv_and_bn(self.conv1.conv, self.conv1.bn) + + conv_w = conv.weight + conv_b = conv.bias + conv1_w = conv1.weight + conv1_b = conv1.bias + + conv1_w = torch.nn.functional.pad(conv1_w, [2, 2, 2, 2]) + + final_conv_w = conv_w + conv1_w + final_conv_b = conv_b + conv1_b + + conv.weight.data.copy_(final_conv_w) + conv.bias.data.copy_(final_conv_b) + + self.conv = conv + del self.conv1 + + +class CIB(nn.Module): + """ + Conditional Identity Block (CIB) module. + + Args: + c1 (int): Number of input channels. + c2 (int): Number of output channels. + shortcut (bool, optional): Whether to add a shortcut connection. Defaults to True. + e (float, optional): Scaling factor for the hidden channels. Defaults to 0.5. 
+ lk (bool, optional): Whether to use RepVGGDW for the third convolutional layer. Defaults to False. + """ + + def __init__(self, c1, c2, shortcut=True, e=0.5, lk=False): + """Initializes the custom model with optional shortcut, scaling factor, and RepVGGDW layer.""" + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = nn.Sequential( + Conv(c1, c1, 3, g=c1), + Conv(c1, 2 * c_, 1), + RepVGGDW(2 * c_) if lk else Conv(2 * c_, 2 * c_, 3, g=2 * c_), + Conv(2 * c_, c2, 1), + Conv(c2, c2, 3, g=c2), + ) + + self.add = shortcut and c1 == c2 + + def forward(self, x): + """ + Forward pass of the CIB module. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + (torch.Tensor): Output tensor. + """ + return x + self.cv1(x) if self.add else self.cv1(x) + + +class C2fCIB(C2f): + """ + C2fCIB class represents a convolutional block with C2f and CIB modules. + + Args: + c1 (int): Number of input channels. + c2 (int): Number of output channels. + n (int, optional): Number of CIB modules to stack. Defaults to 1. + shortcut (bool, optional): Whether to use shortcut connection. Defaults to False. + lk (bool, optional): Whether to use local key connection. Defaults to False. + g (int, optional): Number of groups for grouped convolution. Defaults to 1. + e (float, optional): Expansion ratio for CIB modules. Defaults to 0.5. + """ + + def __init__(self, c1, c2, n=1, shortcut=False, lk=False, g=1, e=0.5): + """Initializes the module with specified parameters for channel, shortcut, local key, groups, and expansion.""" + super().__init__(c1, c2, n, shortcut, g, e) + self.m = nn.ModuleList(CIB(self.c, self.c, shortcut, e=1.0, lk=lk) for _ in range(n)) + + +class Attention(nn.Module): + """ + Attention module that performs self-attention on the input tensor. + + Args: + dim (int): The input tensor dimension. + num_heads (int): The number of attention heads. + attn_ratio (float): The ratio of the attention key dimension to the head dimension. + + Attributes: + num_heads (int): The number of attention heads. + head_dim (int): The dimension of each attention head. + key_dim (int): The dimension of the attention key. + scale (float): The scaling factor for the attention scores. + qkv (Conv): Convolutional layer for computing the query, key, and value. + proj (Conv): Convolutional layer for projecting the attended values. + pe (Conv): Convolutional layer for positional encoding. + """ + + def __init__(self, dim, num_heads=8, attn_ratio=0.5): + """Initializes multi-head attention module with query, key, and value convolutions and positional encoding.""" + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.key_dim = int(self.head_dim * attn_ratio) + self.scale = self.key_dim**-0.5 + nh_kd = self.key_dim * num_heads + h = dim + nh_kd * 2 + self.qkv = Conv(dim, h, 1, act=False) + self.proj = Conv(dim, dim, 1, act=False) + self.pe = Conv(dim, dim, 3, 1, g=dim, act=False) + + def forward(self, x): + """ + Forward pass of the Attention module. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + (torch.Tensor): The output tensor after self-attention. 
+ """ + B, C, H, W = x.shape + N = H * W + qkv = self.qkv(x) + q, k, v = qkv.view(B, self.num_heads, self.key_dim * 2 + self.head_dim, N).split( + [self.key_dim, self.key_dim, self.head_dim], dim=2 + ) + + attn = (q.transpose(-2, -1) @ k) * self.scale + attn = attn.softmax(dim=-1) + x = (v @ attn.transpose(-2, -1)).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W)) + x = self.proj(x) + return x + + +class PSABlock(nn.Module): + """ + PSABlock class implementing a Position-Sensitive Attention block for neural networks. + + This class encapsulates the functionality for applying multi-head attention and feed-forward neural network layers + with optional shortcut connections. + + Attributes: + attn (Attention): Multi-head attention module. + ffn (nn.Sequential): Feed-forward neural network module. + add (bool): Flag indicating whether to add shortcut connections. + + Methods: + forward: Performs a forward pass through the PSABlock, applying attention and feed-forward layers. + + Examples: + Create a PSABlock and perform a forward pass + >>> psablock = PSABlock(c=128, attn_ratio=0.5, num_heads=4, shortcut=True) + >>> input_tensor = torch.randn(1, 128, 32, 32) + >>> output_tensor = psablock(input_tensor) + """ + + def __init__(self, c, attn_ratio=0.5, num_heads=4, shortcut=True) -> None: + """Initializes the PSABlock with attention and feed-forward layers for enhanced feature extraction.""" + super().__init__() + + self.attn = Attention(c, attn_ratio=attn_ratio, num_heads=num_heads) + self.ffn = nn.Sequential(Conv(c, c * 2, 1), Conv(c * 2, c, 1, act=False)) + self.add = shortcut + + def forward(self, x): + """Executes a forward pass through PSABlock, applying attention and feed-forward layers to the input tensor.""" + x = x + self.attn(x) if self.add else self.attn(x) + x = x + self.ffn(x) if self.add else self.ffn(x) + return x + + +class PSA(nn.Module): + """ + PSA class for implementing Position-Sensitive Attention in neural networks. + + This class encapsulates the functionality for applying position-sensitive attention and feed-forward networks to + input tensors, enhancing feature extraction and processing capabilities. + + Attributes: + c (int): Number of hidden channels after applying the initial convolution. + cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c. + cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c. + attn (Attention): Attention module for position-sensitive attention. + ffn (nn.Sequential): Feed-forward network for further processing. + + Methods: + forward: Applies position-sensitive attention and feed-forward network to the input tensor. 
+ + Examples: + Create a PSA module and apply it to an input tensor + >>> psa = PSA(c1=128, c2=128, e=0.5) + >>> input_tensor = torch.randn(1, 128, 64, 64) + >>> output_tensor = psa.forward(input_tensor) + """ + + def __init__(self, c1, c2, e=0.5): + """Initializes the PSA module with input/output channels and attention mechanism for feature extraction.""" + super().__init__() + assert c1 == c2 + self.c = int(c1 * e) + self.cv1 = Conv(c1, 2 * self.c, 1, 1) + self.cv2 = Conv(2 * self.c, c1, 1) + + self.attn = Attention(self.c, attn_ratio=0.5, num_heads=self.c // 64) + self.ffn = nn.Sequential(Conv(self.c, self.c * 2, 1), Conv(self.c * 2, self.c, 1, act=False)) + + def forward(self, x): + """Executes forward pass in PSA module, applying attention and feed-forward layers to the input tensor.""" + a, b = self.cv1(x).split((self.c, self.c), dim=1) + b = b + self.attn(b) + b = b + self.ffn(b) + return self.cv2(torch.cat((a, b), 1)) + + +class C2PSA(nn.Module): + """ + C2PSA module with attention mechanism for enhanced feature extraction and processing. + + This module implements a convolutional block with attention mechanisms to enhance feature extraction and processing + capabilities. It includes a series of PSABlock modules for self-attention and feed-forward operations. + + Attributes: + c (int): Number of hidden channels. + cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c. + cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c. + m (nn.Sequential): Sequential container of PSABlock modules for attention and feed-forward operations. + + Methods: + forward: Performs a forward pass through the C2PSA module, applying attention and feed-forward operations. + + Notes: + This module essentially is the same as PSA module, but refactored to allow stacking more PSABlock modules. + + Examples: + >>> c2psa = C2PSA(c1=256, c2=256, n=3, e=0.5) + >>> input_tensor = torch.randn(1, 256, 64, 64) + >>> output_tensor = c2psa(input_tensor) + """ + + def __init__(self, c1, c2, n=1, e=0.5): + """Initializes the C2PSA module with specified input/output channels, number of layers, and expansion ratio.""" + super().__init__() + assert c1 == c2 + self.c = int(c1 * e) + self.cv1 = Conv(c1, 2 * self.c, 1, 1) + self.cv2 = Conv(2 * self.c, c1, 1) + + self.m = nn.Sequential(*(PSABlock(self.c, attn_ratio=0.5, num_heads=self.c // 64) for _ in range(n))) + + def forward(self, x): + """Processes the input tensor 'x' through a series of PSA blocks and returns the transformed tensor.""" + a, b = self.cv1(x).split((self.c, self.c), dim=1) + b = self.m(b) + return self.cv2(torch.cat((a, b), 1)) + + +class C2fPSA(C2f): + """ + C2fPSA module with enhanced feature extraction using PSA blocks. + + This class extends the C2f module by incorporating PSA blocks for improved attention mechanisms and feature extraction. + + Attributes: + c (int): Number of hidden channels. + cv1 (Conv): 1x1 convolution layer to reduce the number of input channels to 2*c. + cv2 (Conv): 1x1 convolution layer to reduce the number of output channels to c. + m (nn.ModuleList): List of PSA blocks for feature extraction. + + Methods: + forward: Performs a forward pass through the C2fPSA module. + forward_split: Performs a forward pass using split() instead of chunk(). 
+ + Examples: + >>> import torch + >>> from ultralytics.models.common import C2fPSA + >>> model = C2fPSA(c1=64, c2=64, n=3, e=0.5) + >>> x = torch.randn(1, 64, 128, 128) + >>> output = model(x) + >>> print(output.shape) + """ + + def __init__(self, c1, c2, n=1, e=0.5): + """Initializes the C2fPSA module, a variant of C2f with PSA blocks for enhanced feature extraction.""" + assert c1 == c2 + super().__init__(c1, c2, n=n, e=e) + self.m = nn.ModuleList(PSABlock(self.c, attn_ratio=0.5, num_heads=self.c // 64) for _ in range(n)) + + +class SCDown(nn.Module): + """ + SCDown module for downsampling with separable convolutions. + + This module performs downsampling using a combination of pointwise and depthwise convolutions, which helps in + efficiently reducing the spatial dimensions of the input tensor while maintaining the channel information. + + Attributes: + cv1 (Conv): Pointwise convolution layer that reduces the number of channels. + cv2 (Conv): Depthwise convolution layer that performs spatial downsampling. + + Methods: + forward: Applies the SCDown module to the input tensor. + + Examples: + >>> import torch + >>> from ultralytics import SCDown + >>> model = SCDown(c1=64, c2=128, k=3, s=2) + >>> x = torch.randn(1, 64, 128, 128) + >>> y = model(x) + >>> print(y.shape) + torch.Size([1, 128, 64, 64]) + """ + + def __init__(self, c1, c2, k, s): + """Initializes the SCDown module with specified input/output channels, kernel size, and stride.""" + super().__init__() + self.cv1 = Conv(c1, c2, 1, 1) + self.cv2 = Conv(c2, c2, k=k, s=s, g=c2, act=False) + + def forward(self, x): + """Applies convolution and downsampling to the input tensor in the SCDown module.""" + return self.cv2(self.cv1(x)) + + +class TorchVision(nn.Module): + """ + TorchVision module to allow loading any torchvision model. + + This class provides a way to load a model from the torchvision library, optionally load pre-trained weights, and customize the model by truncating or unwrapping layers. + + Attributes: + m (nn.Module): The loaded torchvision model, possibly truncated and unwrapped. + + Args: + c1 (int): Input channels. + c2 (): Output channels. + model (str): Name of the torchvision model to load. + weights (str, optional): Pre-trained weights to load. Default is "DEFAULT". + unwrap (bool, optional): If True, unwraps the model to a sequential containing all but the last `truncate` layers. Default is True. + truncate (int, optional): Number of layers to truncate from the end if `unwrap` is True. Default is 2. + split (bool, optional): Returns output from intermediate child modules as list. Default is False. 
+ """ + + def __init__(self, c1, c2, model, weights="DEFAULT", unwrap=True, truncate=2, split=False): + """Load the model and weights from torchvision.""" + import torchvision # scope for faster 'import ultralytics' + + super().__init__() + if hasattr(torchvision.models, "get_model"): + self.m = torchvision.models.get_model(model, weights=weights) + else: + self.m = torchvision.models.__dict__[model](pretrained=bool(weights)) + if unwrap: + layers = list(self.m.children())[:-truncate] + if isinstance(layers[0], nn.Sequential): # Second-level for some models like EfficientNet, Swin + layers = [*list(layers[0].children()), *layers[1:]] + self.m = nn.Sequential(*layers) + self.split = split + else: + self.split = False + self.m.head = self.m.heads = nn.Identity() + + def forward(self, x): + """Forward pass through the model.""" + if self.split: + y = [x] + y.extend(m(y[-1]) for m in self.m) + else: + y = self.m(x) + return y + +try: + from flash_attn.flash_attn_interface import flash_attn_func +except Exception: + # assert False, "import FlashAttention error! Please install FlashAttention first." + pass +from timm.models.layers import trunc_normal_ + +class AAttn(nn.Module): + """ + Area-attention module with the requirement of flash attention. + + Attributes: + dim (int): Number of hidden channels; + num_heads (int): Number of heads into which the attention mechanism is divided; + area (int, optional): Number of areas the feature map is divided. Defaults to 1. + + Methods: + forward: Performs a forward process of input tensor and outputs a tensor after the execution of the area attention mechanism. + + Examples: + >>> import torch + >>> from ultralytics.nn.modules import AAttn + >>> model = AAttn(dim=64, num_heads=2, area=4) + >>> x = torch.randn(2, 64, 128, 128) + >>> output = model(x) + >>> print(output.shape) + + Notes: + recommend that dim//num_heads be a multiple of 32 or 64. 
+ + """ + + def __init__(self, dim, num_heads, area=1): + """Initializes the area-attention module, a simple yet efficient attention module for YOLO.""" + super().__init__() + self.area = area + + self.num_heads = num_heads + self.head_dim = head_dim = dim // num_heads + all_head_dim = head_dim * self.num_heads + + self.qkv = Conv(dim, all_head_dim * 3, 1, act=False) + self.proj = Conv(all_head_dim, dim, 1, act=False) + self.pe = Conv(all_head_dim, dim, 7, 1, 3, g=dim, act=False) + + + def forward(self, x): + """Processes the input tensor 'x' through the area-attention""" + B, C, H, W = x.shape + N = H * W + + qkv = self.qkv(x).flatten(2).transpose(1, 2) + if self.area > 1: + qkv = qkv.reshape(B * self.area, N // self.area, C * 3) + B, N, _ = qkv.shape + q, k, v = qkv.view(B, N, self.num_heads, self.head_dim * 3).split( + [self.head_dim, self.head_dim, self.head_dim], dim=3 + ) + + if x.is_cuda: + x = flash_attn_func( + q.contiguous().half(), + k.contiguous().half(), + v.contiguous().half() + ).to(q.dtype) + else: + q = q.permute(0, 2, 3, 1) + k = k.permute(0, 2, 3, 1) + v = v.permute(0, 2, 3, 1) + attn = (q.transpose(-2, -1) @ k) * (self.head_dim ** -0.5) + max_attn = attn.max(dim=-1, keepdim=True).values + exp_attn = torch.exp(attn - max_attn) + attn = exp_attn / exp_attn.sum(dim=-1, keepdim=True) + x = (v @ attn.transpose(-2, -1)) + x = x.permute(0, 3, 1, 2) + v = v.permute(0, 3, 1, 2) + + if self.area > 1: + x = x.reshape(B // self.area, N * self.area, C) + v = v.reshape(B // self.area, N * self.area, C) + B, N, _ = x.shape + + x = x.reshape(B, H, W, C).permute(0, 3, 1, 2) + v = v.reshape(B, H, W, C).permute(0, 3, 1, 2) + + x = x + self.pe(v) + x = self.proj(x) + return x + + +class ABlock(nn.Module): + """ + ABlock class implementing a Area-Attention block with effective feature extraction. + + This class encapsulates the functionality for applying multi-head attention with feature map are dividing into areas + and feed-forward neural network layers. + + Attributes: + dim (int): Number of hidden channels; + num_heads (int): Number of heads into which the attention mechanism is divided; + mlp_ratio (float, optional): MLP expansion ratio (or MLP hidden dimension ratio). Defaults to 1.2; + area (int, optional): Number of areas the feature map is divided. Defaults to 1. + + Methods: + forward: Performs a forward pass through the ABlock, applying area-attention and feed-forward layers. + + Examples: + Create a ABlock and perform a forward pass + >>> model = ABlock(dim=64, num_heads=2, mlp_ratio=1.2, area=4) + >>> x = torch.randn(2, 64, 128, 128) + >>> output = model(x) + >>> print(output.shape) + + Notes: + recommend that dim//num_heads be a multiple of 32 or 64. 
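When the input is not on CUDA (or flash-attn is unavailable), AAttn above computes the softmax by hand, subtracting the row maximum before exponentiating. A tiny sketch showing this numerically stable form matches `torch.softmax`:

```python
# The max-subtraction trick in AAttn's CPU branch is a numerically stable softmax.
import torch

attn = torch.randn(2, 4, 64, 64) * 50                   # large logits to stress stability
max_attn = attn.max(dim=-1, keepdim=True).values
exp_attn = torch.exp(attn - max_attn)
stable = exp_attn / exp_attn.sum(dim=-1, keepdim=True)
print(torch.allclose(stable, attn.softmax(dim=-1), atol=1e-6))  # True
```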
+ """ + + def __init__(self, dim, num_heads, mlp_ratio=1.2, area=1): + """Initializes the ABlock with area-attention and feed-forward layers for faster feature extraction.""" + super().__init__() + + self.attn = AAttn(dim, num_heads=num_heads, area=area) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = nn.Sequential(Conv(dim, mlp_hidden_dim, 1), Conv(mlp_hidden_dim, dim, 1, act=False)) + + self.apply(self._init_weights) + + def _init_weights(self, m): + """Initialize weights using a truncated normal distribution.""" + if isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Conv2d) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + """Executes a forward pass through ABlock, applying area-attention and feed-forward layers to the input tensor.""" + x = x + self.attn(x) + x = x + self.mlp(x) + return x + + +class A2C2f(nn.Module): + """ + A2C2f module with residual enhanced feature extraction using ABlock blocks with area-attention. Also known as R-ELAN + + This class extends the C2f module by incorporating ABlock blocks for fast attention mechanisms and feature extraction. + + Attributes: + c1 (int): Number of input channels; + c2 (int): Number of output channels; + n (int, optional): Number of 2xABlock modules to stack. Defaults to 1; + a2 (bool, optional): Whether use area-attention. Defaults to True; + area (int, optional): Number of areas the feature map is divided. Defaults to 1; + residual (bool, optional): Whether use the residual (with layer scale). Defaults to False; + mlp_ratio (float, optional): MLP expansion ratio (or MLP hidden dimension ratio). Defaults to 1.2; + e (float, optional): Expansion ratio for R-ELAN modules. Defaults to 0.5. + g (int, optional): Number of groups for grouped convolution. Defaults to 1; + shortcut (bool, optional): Whether to use shortcut connection. Defaults to True; + + Methods: + forward: Performs a forward pass through the A2C2f module. + + Examples: + >>> import torch + >>> from ultralytics.nn.modules import A2C2f + >>> model = A2C2f(c1=64, c2=64, n=2, a2=True, area=4, residual=True, e=0.5) + >>> x = torch.randn(2, 64, 128, 128) + >>> output = model(x) + >>> print(output.shape) + """ + + def __init__(self, c1, c2, n=1, a2=True, area=1, residual=False, mlp_ratio=2.0, e=0.5, g=1, shortcut=True): + super().__init__() + c_ = int(c2 * e) # hidden channels + assert c_ % 32 == 0, "Dimension of ABlock be a multiple of 32." 
+ + # num_heads = c_ // 64 if c_ // 64 >= 2 else c_ // 32 + num_heads = c_ // 32 + + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv((1 + n) * c_, c2, 1) # optional act=FReLU(c2) + + init_values = 0.01 # or smaller + self.gamma = nn.Parameter(init_values * torch.ones((c2)), requires_grad=True) if a2 and residual else None + + self.m = nn.ModuleList( + nn.Sequential(*(ABlock(c_, num_heads, mlp_ratio, area) for _ in range(2))) if a2 else C3k(c_, c_, 2, shortcut, g) for _ in range(n) + ) + + def forward(self, x): + """Forward pass through R-ELAN layer.""" + y = [self.cv1(x)] + y.extend(m(y[-1]) for m in self.m) + if self.gamma is not None: + return x + (self.gamma * self.cv2(torch.cat(y, 1)).permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return self.cv2(torch.cat(y, 1)) diff --git a/ultralytics/nn/modules/conv.py b/ultralytics/nn/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..cc62c2950d9969f95ce2b82e19ee15cbfccf302c --- /dev/null +++ b/ultralytics/nn/modules/conv.py @@ -0,0 +1,350 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +"""Convolution modules.""" + +import math + +import numpy as np +import torch +import torch.nn as nn + +__all__ = ( + "Conv", + "Conv2", + "LightConv", + "DWConv", + "DWConvTranspose2d", + "ConvTranspose", + "Focus", + "GhostConv", + "ChannelAttention", + "SpatialAttention", + "CBAM", + "Concat", + "RepConv", + "Index", +) + + +def autopad(k, p=None, d=1): # kernel, padding, dilation + """Pad to 'same' shape outputs.""" + if d > 1: + k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size + if p is None: + p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad + return p + + +class Conv(nn.Module): + """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation).""" + + default_act = nn.SiLU() # default activation + + def __init__(self, c1, c2, k=1, s=1, bias=False, p=None, g=1, d=1, act=True): + """Initialize Conv layer with given arguments including activation.""" + super().__init__() + self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=bias) + self.bn = nn.BatchNorm2d(c2) + self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity() + + def forward(self, x): + """Apply convolution, batch normalization and activation to input tensor.""" + return self.act(self.bn(self.conv(x))) + + def forward_fuse(self, x): + """Apply convolution and activation without batch normalization.""" + return self.act(self.conv(x)) + + +class Conv2(Conv): + """Simplified RepConv module with Conv fusing.""" + + def __init__(self, c1, c2, k=3, s=1, p=None, g=1, d=1, act=True): + """Initialize Conv layer with given arguments including activation.""" + super().__init__(c1, c2, k, s, p, g=g, d=d, act=act) + self.cv2 = nn.Conv2d(c1, c2, 1, s, autopad(1, p, d), groups=g, dilation=d, bias=False) # add 1x1 conv + + def forward(self, x): + """Apply convolution, batch normalization and activation to input tensor.""" + return self.act(self.bn(self.conv(x) + self.cv2(x))) + + def forward_fuse(self, x): + """Apply fused convolution, batch normalization and activation to input tensor.""" + return self.act(self.bn(self.conv(x))) + + def fuse_convs(self): + """Fuse parallel convolutions.""" + w = torch.zeros_like(self.conv.weight.data) + i = [x // 2 for x in w.shape[2:]] + w[:, :, i[0] : i[0] + 1, i[1] : i[1] + 1] = self.cv2.weight.data.clone() + self.conv.weight.data += w + 
self.__delattr__("cv2") + self.forward = self.forward_fuse + + +class LightConv(nn.Module): + """ + Light convolution with args(ch_in, ch_out, kernel). + + https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py + """ + + def __init__(self, c1, c2, k=1, act=nn.ReLU()): + """Initialize Conv layer with given arguments including activation.""" + super().__init__() + self.conv1 = Conv(c1, c2, 1, act=False) + self.conv2 = DWConv(c2, c2, k, act=act) + + def forward(self, x): + """Apply 2 convolutions to input tensor.""" + return self.conv2(self.conv1(x)) + + +class DWConv(Conv): + """Depth-wise convolution.""" + + def __init__(self, c1, c2, k=1, s=1, d=1, act=True): # ch_in, ch_out, kernel, stride, dilation, activation + """Initialize Depth-wise convolution with given parameters.""" + super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act) + + +class DWConvTranspose2d(nn.ConvTranspose2d): + """Depth-wise transpose convolution.""" + + def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0): # ch_in, ch_out, kernel, stride, padding, padding_out + """Initialize DWConvTranspose2d class with given parameters.""" + super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2)) + + +class ConvTranspose(nn.Module): + """Convolution transpose 2d layer.""" + + default_act = nn.SiLU() # default activation + + def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True): + """Initialize ConvTranspose2d layer with batch normalization and activation function.""" + super().__init__() + self.conv_transpose = nn.ConvTranspose2d(c1, c2, k, s, p, bias=not bn) + self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity() + self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity() + + def forward(self, x): + """Applies transposed convolutions, batch normalization and activation to input.""" + return self.act(self.bn(self.conv_transpose(x))) + + def forward_fuse(self, x): + """Applies activation and convolution transpose operation to input.""" + return self.act(self.conv_transpose(x)) + + +class Focus(nn.Module): + """Focus wh information into c-space.""" + + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): + """Initializes Focus object with user defined channel, convolution, padding, group and activation values.""" + super().__init__() + self.conv = Conv(c1 * 4, c2, k, s, p, g, act=act) + # self.contract = Contract(gain=2) + + def forward(self, x): + """ + Applies convolution to concatenated tensor and returns the output. + + Input shape is (b,c,w,h) and output shape is (b,4c,w/2,h/2). + """ + return self.conv(torch.cat((x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]), 1)) + # return self.conv(self.contract(x)) + + +class GhostConv(nn.Module): + """Ghost Convolution https://github.com/huawei-noah/ghostnet.""" + + def __init__(self, c1, c2, k=1, s=1, g=1, act=True): + """Initializes Ghost Convolution module with primary and cheap operations for efficient feature learning.""" + super().__init__() + c_ = c2 // 2 # hidden channels + self.cv1 = Conv(c1, c_, k, s, None, g, act=act) + self.cv2 = Conv(c_, c_, 5, 1, None, c_, act=act) + + def forward(self, x): + """Forward propagation through a Ghost Bottleneck layer with skip connection.""" + y = self.cv1(x) + return torch.cat((y, self.cv2(y)), 1) + + +class RepConv(nn.Module): + """ + RepConv is a basic rep-style block, including training and deploy status. + + This module is used in RT-DETR. 
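+    At inference time the parallel 3x3, 1x1 and optional identity branches can be collapsed into a single 3x3
+    convolution by calling fuse_convs(), so deployment pays no extra cost for the multi-branch training structure.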
+ Based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py + """ + + default_act = nn.SiLU() # default activation + + def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False): + """Initializes Light Convolution layer with inputs, outputs & optional activation function.""" + super().__init__() + assert k == 3 and p == 1 + self.g = g + self.c1 = c1 + self.c2 = c2 + self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity() + + self.bn = nn.BatchNorm2d(num_features=c1) if bn and c2 == c1 and s == 1 else None + self.conv1 = Conv(c1, c2, k, s, p=p, g=g, act=False) + self.conv2 = Conv(c1, c2, 1, s, p=(p - k // 2), g=g, act=False) + + def forward_fuse(self, x): + """Forward process.""" + return self.act(self.conv(x)) + + def forward(self, x): + """Forward process.""" + id_out = 0 if self.bn is None else self.bn(x) + return self.act(self.conv1(x) + self.conv2(x) + id_out) + + def get_equivalent_kernel_bias(self): + """Returns equivalent kernel and bias by adding 3x3 kernel, 1x1 kernel and identity kernel with their biases.""" + kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1) + kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2) + kernelid, biasid = self._fuse_bn_tensor(self.bn) + return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid + + @staticmethod + def _pad_1x1_to_3x3_tensor(kernel1x1): + """Pads a 1x1 tensor to a 3x3 tensor.""" + if kernel1x1 is None: + return 0 + else: + return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1]) + + def _fuse_bn_tensor(self, branch): + """Generates appropriate kernels and biases for convolution by fusing branches of the neural network.""" + if branch is None: + return 0, 0 + if isinstance(branch, Conv): + kernel = branch.conv.weight + running_mean = branch.bn.running_mean + running_var = branch.bn.running_var + gamma = branch.bn.weight + beta = branch.bn.bias + eps = branch.bn.eps + elif isinstance(branch, nn.BatchNorm2d): + if not hasattr(self, "id_tensor"): + input_dim = self.c1 // self.g + kernel_value = np.zeros((self.c1, input_dim, 3, 3), dtype=np.float32) + for i in range(self.c1): + kernel_value[i, i % input_dim, 1, 1] = 1 + self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device) + kernel = self.id_tensor + running_mean = branch.running_mean + running_var = branch.running_var + gamma = branch.weight + beta = branch.bias + eps = branch.eps + std = (running_var + eps).sqrt() + t = (gamma / std).reshape(-1, 1, 1, 1) + return kernel * t, beta - running_mean * gamma / std + + def fuse_convs(self): + """Combines two convolution layers into a single layer and removes unused attributes from the class.""" + if hasattr(self, "conv"): + return + kernel, bias = self.get_equivalent_kernel_bias() + self.conv = nn.Conv2d( + in_channels=self.conv1.conv.in_channels, + out_channels=self.conv1.conv.out_channels, + kernel_size=self.conv1.conv.kernel_size, + stride=self.conv1.conv.stride, + padding=self.conv1.conv.padding, + dilation=self.conv1.conv.dilation, + groups=self.conv1.conv.groups, + bias=True, + ).requires_grad_(False) + self.conv.weight.data = kernel + self.conv.bias.data = bias + for para in self.parameters(): + para.detach_() + self.__delattr__("conv1") + self.__delattr__("conv2") + if hasattr(self, "nm"): + self.__delattr__("nm") + if hasattr(self, "bn"): + self.__delattr__("bn") + if hasattr(self, "id_tensor"): + self.__delattr__("id_tensor") + + +class ChannelAttention(nn.Module): + """Channel-attention module 
https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""
+
+    def __init__(self, channels: int) -> None:
+        """Initializes the class and sets the basic configurations and instance variables required."""
+        super().__init__()
+        self.pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
+        self.act = nn.Sigmoid()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Applies channel attention by rescaling the input with a sigmoid-gated descriptor from global pooling and a 1x1 convolution."""
+        return x * self.act(self.fc(self.pool(x)))
+
+
+class SpatialAttention(nn.Module):
+    """Spatial-attention module."""
+
+    def __init__(self, kernel_size=7):
+        """Initialize Spatial-attention module with kernel size argument."""
+        super().__init__()
+        assert kernel_size in {3, 7}, "kernel size must be 3 or 7"
+        padding = 3 if kernel_size == 7 else 1
+        self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
+        self.act = nn.Sigmoid()
+
+    def forward(self, x):
+        """Apply spatial attention to the input for feature recalibration."""
+        return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))
+
+
+class CBAM(nn.Module):
+    """Convolutional Block Attention Module."""
+
+    def __init__(self, c1, kernel_size=7):
+        """Initialize CBAM with given input channel (c1) and kernel size."""
+        super().__init__()
+        self.channel_attention = ChannelAttention(c1)
+        self.spatial_attention = SpatialAttention(kernel_size)
+
+    def forward(self, x):
+        """Applies channel attention followed by spatial attention to the input."""
+        return self.spatial_attention(self.channel_attention(x))
+
+
+class Concat(nn.Module):
+    """Concatenate a list of tensors along dimension."""
+
+    def __init__(self, dimension=1):
+        """Initialize the Concat module with the dimension along which to concatenate."""
+        super().__init__()
+        self.d = dimension
+
+    def forward(self, x):
+        """Concatenate the list of input tensors along the stored dimension."""
+        return torch.cat(x, self.d)
+
+
+class Index(nn.Module):
+    """Returns a particular index of the input."""
+
+    def __init__(self, c1, c2, index=0):
+        """Initialize the Index module with the list position to select."""
+        super().__init__()
+        self.index = index
+
+    def forward(self, x):
+        """
+        Forward pass.
+
+        Expects a list of tensors as input.
+ """ + return x[self.index] diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py new file mode 100644 index 0000000000000000000000000000000000000000..b9d0502466cfdb6944bc2711b66ba7bf72ac531a --- /dev/null +++ b/ultralytics/nn/modules/head.py @@ -0,0 +1,625 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +"""Model head modules.""" + +import copy +import math + +import torch +import torch.nn as nn +from torch.nn.init import constant_, xavier_uniform_ + +from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors + +from .block import DFL, BNContrastiveHead, ContrastiveHead, Proto +from .conv import Conv, DWConv +from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer +from .utils import bias_init_with_prob, linear_init + +__all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10Detect" + + +class Detect(nn.Module): + """YOLO Detect head for detection models.""" + + dynamic = False # force grid reconstruction + export = False # export mode + format = None # export format + end2end = False # end2end + max_det = 300 # max_det + shape = None + anchors = torch.empty(0) # init + strides = torch.empty(0) # init + legacy = False # backward compatibility for v3/v5/v8/v9 models + + def __init__(self, nc=80, ch=()): + """Initializes the YOLO detection layer with specified number of classes and channels.""" + super().__init__() + self.nc = nc # number of classes + self.nl = len(ch) # number of detection layers + self.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x) + self.no = nc + self.reg_max * 4 # number of outputs per anchor + self.stride = torch.zeros(self.nl) # strides computed during build + c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) # channels + self.cv2 = nn.ModuleList( + nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch + ) + self.cv3 = ( + nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch) + if self.legacy + else nn.ModuleList( + nn.Sequential( + nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)), + nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)), + nn.Conv2d(c3, self.nc, 1), + ) + for x in ch + ) + ) + self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity() + + if self.end2end: + self.one2one_cv2 = copy.deepcopy(self.cv2) + self.one2one_cv3 = copy.deepcopy(self.cv3) + + def forward(self, x): + """Concatenates and returns predicted bounding boxes and class probabilities.""" + if self.end2end: + return self.forward_end2end(x) + + for i in range(self.nl): + x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1) + if self.training: # Training path + return x + y = self._inference(x) + return y if self.export else (y, x) + + def forward_end2end(self, x): + """ + Performs forward pass of the v10Detect module. + + Args: + x (tensor): Input tensor. + + Returns: + (dict, tensor): If not in training mode, returns a dictionary containing the outputs of both one2many and one2one detections. + If in training mode, returns a dictionary containing the outputs of one2many and one2one detections separately. 
+ """ + x_detach = [xi.detach() for xi in x] + one2one = [ + torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl) + ] + for i in range(self.nl): + x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1) + if self.training: # Training path + return {"one2many": x, "one2one": one2one} + + y = self._inference(one2one) + y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc) + return y if self.export else (y, {"one2many": x, "one2one": one2one}) + + def _inference(self, x): + """Decode predicted bounding boxes and class probabilities based on multiple-level feature maps.""" + # Inference path + shape = x[0].shape # BCHW + x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2) + if self.format != "imx" and (self.dynamic or self.shape != shape): + self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)) + self.shape = shape + + if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}: # avoid TF FlexSplitV ops + box = x_cat[:, : self.reg_max * 4] + cls = x_cat[:, self.reg_max * 4 :] + else: + box, cls = x_cat.split((self.reg_max * 4, self.nc), 1) + + if self.export and self.format in {"tflite", "edgetpu"}: + # Precompute normalization factor to increase numerical stability + # See https://github.com/ultralytics/ultralytics/issues/7371 + grid_h = shape[2] + grid_w = shape[3] + grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1) + norm = self.strides / (self.stride[0] * grid_size) + dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2]) + elif self.export and self.format == "imx": + dbox = self.decode_bboxes( + self.dfl(box) * self.strides, self.anchors.unsqueeze(0) * self.strides, xywh=False + ) + return dbox.transpose(1, 2), cls.sigmoid().permute(0, 2, 1) + else: + dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides + + return torch.cat((dbox, cls.sigmoid()), 1) + + def bias_init(self): + """Initialize Detect() biases, WARNING: requires stride availability.""" + m = self # self.model[-1] # Detect() module + # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1 + # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency + for a, b, s in zip(m.cv2, m.cv3, m.stride): # from + a[-1].bias.data[:] = 1.0 # box + b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img) + if self.end2end: + for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride): # from + a[-1].bias.data[:] = 1.0 # box + b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img) + + def decode_bboxes(self, bboxes, anchors, xywh=True): + """Decode bounding boxes.""" + return dist2bbox(bboxes, anchors, xywh=xywh and (not self.end2end), dim=1) + + @staticmethod + def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80): + """ + Post-processes YOLO model predictions. + + Args: + preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension + format [x, y, w, h, class_probs]. + max_det (int): Maximum detections per image. + nc (int, optional): Number of classes. Default: 80. + + Returns: + (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and last + dimension format [x, y, w, h, max_class_prob, class_index]. 
+ """ + batch_size, anchors, _ = preds.shape # i.e. shape(16,8400,84) + boxes, scores = preds.split([4, nc], dim=-1) + index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1) + boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4)) + scores = scores.gather(dim=1, index=index.repeat(1, 1, nc)) + scores, index = scores.flatten(1).topk(min(max_det, anchors)) + i = torch.arange(batch_size)[..., None] # batch indices + return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1) + + +class Segment(Detect): + """YOLO Segment head for segmentation models.""" + + def __init__(self, nc=80, nm=32, npr=256, ch=()): + """Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers.""" + super().__init__(nc, ch) + self.nm = nm # number of masks + self.npr = npr # number of protos + self.proto = Proto(ch[0], self.npr, self.nm) # protos + + c4 = max(ch[0] // 4, self.nm) + self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch) + + def forward(self, x): + """Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients.""" + p = self.proto(x[0]) # mask protos + bs = p.shape[0] # batch size + + mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients + x = Detect.forward(self, x) + if self.training: + return x, mc, p + return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p)) + + +class OBB(Detect): + """YOLO OBB detection head for detection with rotation models.""" + + def __init__(self, nc=80, ne=1, ch=()): + """Initialize OBB with number of classes `nc` and layer channels `ch`.""" + super().__init__(nc, ch) + self.ne = ne # number of extra parameters + + c4 = max(ch[0] // 4, self.ne) + self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.ne, 1)) for x in ch) + + def forward(self, x): + """Concatenates and returns predicted bounding boxes and class probabilities.""" + bs = x[0].shape[0] # batch size + angle = torch.cat([self.cv4[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2) # OBB theta logits + # NOTE: set `angle` as an attribute so that `decode_bboxes` could use it. 
+ angle = (angle.sigmoid() - 0.25) * math.pi # [-pi/4, 3pi/4] + # angle = angle.sigmoid() * math.pi / 2 # [0, pi/2] + if not self.training: + self.angle = angle + x = Detect.forward(self, x) + if self.training: + return x, angle + return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle)) + + def decode_bboxes(self, bboxes, anchors): + """Decode rotated bounding boxes.""" + return dist2rbox(bboxes, self.angle, anchors, dim=1) + + +class Pose(Detect): + """YOLO Pose head for keypoints models.""" + + def __init__(self, nc=80, kpt_shape=(17, 3), ch=()): + """Initialize YOLO network with default parameters and Convolutional Layers.""" + super().__init__(nc, ch) + self.kpt_shape = kpt_shape # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) + self.nk = kpt_shape[0] * kpt_shape[1] # number of keypoints total + + c4 = max(ch[0] // 4, self.nk) + self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch) + + def forward(self, x): + """Perform forward pass through YOLO model and return predictions.""" + bs = x[0].shape[0] # batch size + kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w) + x = Detect.forward(self, x) + if self.training: + return x, kpt + pred_kpt = self.kpts_decode(bs, kpt) + return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt)) + + def kpts_decode(self, bs, kpts): + """Decodes keypoints.""" + ndim = self.kpt_shape[1] + if self.export: + if self.format in { + "tflite", + "edgetpu", + }: # required for TFLite export to avoid 'PLACEHOLDER_FOR_GREATER_OP_CODES' bug + # Precompute normalization factor to increase numerical stability + y = kpts.view(bs, *self.kpt_shape, -1) + grid_h, grid_w = self.shape[2], self.shape[3] + grid_size = torch.tensor([grid_w, grid_h], device=y.device).reshape(1, 2, 1) + norm = self.strides / (self.stride[0] * grid_size) + a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * norm + else: + # NCNN fix + y = kpts.view(bs, *self.kpt_shape, -1) + a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides + if ndim == 3: + a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2) + return a.view(bs, self.nk, -1) + else: + y = kpts.clone() + if ndim == 3: + y[:, 2::3] = y[:, 2::3].sigmoid() # sigmoid (WARNING: inplace .sigmoid_() Apple MPS bug) + y[:, 0::ndim] = (y[:, 0::ndim] * 2.0 + (self.anchors[0] - 0.5)) * self.strides + y[:, 1::ndim] = (y[:, 1::ndim] * 2.0 + (self.anchors[1] - 0.5)) * self.strides + return y + + +class Classify(nn.Module): + """YOLO classification head, i.e. 
x(b,c1,20,20) to x(b,c2).""" + + export = False # export mode + + def __init__(self, c1, c2, k=1, s=1, p=None, g=1): + """Initializes YOLO classification head to transform input tensor from (b,c1,20,20) to (b,c2) shape.""" + super().__init__() + c_ = 1280 # efficientnet_b0 size + self.conv = Conv(c1, c_, k, s, p, g) + self.pool = nn.AdaptiveAvgPool2d(1) # to x(b,c_,1,1) + self.drop = nn.Dropout(p=0.0, inplace=True) + self.linear = nn.Linear(c_, c2) # to x(b,c2) + + def forward(self, x): + """Performs a forward pass of the YOLO model on input image data.""" + if isinstance(x, list): + x = torch.cat(x, 1) + x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1))) + if self.training: + return x + y = x.softmax(1) # get final output + return y if self.export else (y, x) + + +class WorldDetect(Detect): + """Head for integrating YOLO detection models with semantic understanding from text embeddings.""" + + def __init__(self, nc=80, embed=512, with_bn=False, ch=()): + """Initialize YOLO detection layer with nc classes and layer channels ch.""" + super().__init__(nc, ch) + c3 = max(ch[0], min(self.nc, 100)) + self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, embed, 1)) for x in ch) + self.cv4 = nn.ModuleList(BNContrastiveHead(embed) if with_bn else ContrastiveHead() for _ in ch) + + def forward(self, x, text): + """Concatenates and returns predicted bounding boxes and class probabilities.""" + for i in range(self.nl): + x[i] = torch.cat((self.cv2[i](x[i]), self.cv4[i](self.cv3[i](x[i]), text)), 1) + if self.training: + return x + + # Inference path + shape = x[0].shape # BCHW + x_cat = torch.cat([xi.view(shape[0], self.nc + self.reg_max * 4, -1) for xi in x], 2) + if self.dynamic or self.shape != shape: + self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)) + self.shape = shape + + if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}: # avoid TF FlexSplitV ops + box = x_cat[:, : self.reg_max * 4] + cls = x_cat[:, self.reg_max * 4 :] + else: + box, cls = x_cat.split((self.reg_max * 4, self.nc), 1) + + if self.export and self.format in {"tflite", "edgetpu"}: + # Precompute normalization factor to increase numerical stability + # See https://github.com/ultralytics/ultralytics/issues/7371 + grid_h = shape[2] + grid_w = shape[3] + grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1) + norm = self.strides / (self.stride[0] * grid_size) + dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2]) + else: + dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides + + y = torch.cat((dbox, cls.sigmoid()), 1) + return y if self.export else (y, x) + + def bias_init(self): + """Initialize Detect() biases, WARNING: requires stride availability.""" + m = self # self.model[-1] # Detect() module + # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1 + # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency + for a, b, s in zip(m.cv2, m.cv3, m.stride): # from + a[-1].bias.data[:] = 1.0 # box + # b[-1].bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img) + + +class RTDETRDecoder(nn.Module): + """ + Real-Time Deformable Transformer Decoder (RTDETRDecoder) module for object detection. 
+ + This decoder module utilizes Transformer architecture along with deformable convolutions to predict bounding boxes + and class labels for objects in an image. It integrates features from multiple layers and runs through a series of + Transformer decoder layers to output the final predictions. + """ + + export = False # export mode + + def __init__( + self, + nc=80, + ch=(512, 1024, 2048), + hd=256, # hidden dim + nq=300, # num queries + ndp=4, # num decoder points + nh=8, # num head + ndl=6, # num decoder layers + d_ffn=1024, # dim of feedforward + dropout=0.0, + act=nn.ReLU(), + eval_idx=-1, + # Training args + nd=100, # num denoising + label_noise_ratio=0.5, + box_noise_scale=1.0, + learnt_init_query=False, + ): + """ + Initializes the RTDETRDecoder module with the given parameters. + + Args: + nc (int): Number of classes. Default is 80. + ch (tuple): Channels in the backbone feature maps. Default is (512, 1024, 2048). + hd (int): Dimension of hidden layers. Default is 256. + nq (int): Number of query points. Default is 300. + ndp (int): Number of decoder points. Default is 4. + nh (int): Number of heads in multi-head attention. Default is 8. + ndl (int): Number of decoder layers. Default is 6. + d_ffn (int): Dimension of the feed-forward networks. Default is 1024. + dropout (float): Dropout rate. Default is 0. + act (nn.Module): Activation function. Default is nn.ReLU. + eval_idx (int): Evaluation index. Default is -1. + nd (int): Number of denoising. Default is 100. + label_noise_ratio (float): Label noise ratio. Default is 0.5. + box_noise_scale (float): Box noise scale. Default is 1.0. + learnt_init_query (bool): Whether to learn initial query embeddings. Default is False. + """ + super().__init__() + self.hidden_dim = hd + self.nhead = nh + self.nl = len(ch) # num level + self.nc = nc + self.num_queries = nq + self.num_decoder_layers = ndl + + # Backbone feature projection + self.input_proj = nn.ModuleList(nn.Sequential(nn.Conv2d(x, hd, 1, bias=False), nn.BatchNorm2d(hd)) for x in ch) + # NOTE: simplified version but it's not consistent with .pt weights. 
+ # self.input_proj = nn.ModuleList(Conv(x, hd, act=False) for x in ch) + + # Transformer module + decoder_layer = DeformableTransformerDecoderLayer(hd, nh, d_ffn, dropout, act, self.nl, ndp) + self.decoder = DeformableTransformerDecoder(hd, decoder_layer, ndl, eval_idx) + + # Denoising part + self.denoising_class_embed = nn.Embedding(nc, hd) + self.num_denoising = nd + self.label_noise_ratio = label_noise_ratio + self.box_noise_scale = box_noise_scale + + # Decoder embedding + self.learnt_init_query = learnt_init_query + if learnt_init_query: + self.tgt_embed = nn.Embedding(nq, hd) + self.query_pos_head = MLP(4, 2 * hd, hd, num_layers=2) + + # Encoder head + self.enc_output = nn.Sequential(nn.Linear(hd, hd), nn.LayerNorm(hd)) + self.enc_score_head = nn.Linear(hd, nc) + self.enc_bbox_head = MLP(hd, hd, 4, num_layers=3) + + # Decoder head + self.dec_score_head = nn.ModuleList([nn.Linear(hd, nc) for _ in range(ndl)]) + self.dec_bbox_head = nn.ModuleList([MLP(hd, hd, 4, num_layers=3) for _ in range(ndl)]) + + self._reset_parameters() + + def forward(self, x, batch=None): + """Runs the forward pass of the module, returning bounding box and classification scores for the input.""" + from ultralytics.models.utils.ops import get_cdn_group + + # Input projection and embedding + feats, shapes = self._get_encoder_input(x) + + # Prepare denoising training + dn_embed, dn_bbox, attn_mask, dn_meta = get_cdn_group( + batch, + self.nc, + self.num_queries, + self.denoising_class_embed.weight, + self.num_denoising, + self.label_noise_ratio, + self.box_noise_scale, + self.training, + ) + + embed, refer_bbox, enc_bboxes, enc_scores = self._get_decoder_input(feats, shapes, dn_embed, dn_bbox) + + # Decoder + dec_bboxes, dec_scores = self.decoder( + embed, + refer_bbox, + feats, + shapes, + self.dec_bbox_head, + self.dec_score_head, + self.query_pos_head, + attn_mask=attn_mask, + ) + x = dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta + if self.training: + return x + # (bs, 300, 4+nc) + y = torch.cat((dec_bboxes.squeeze(0), dec_scores.squeeze(0).sigmoid()), -1) + return y if self.export else (y, x) + + def _generate_anchors(self, shapes, grid_size=0.05, dtype=torch.float32, device="cpu", eps=1e-2): + """Generates anchor bounding boxes for given shapes with specific grid size and validates them.""" + anchors = [] + for i, (h, w) in enumerate(shapes): + sy = torch.arange(end=h, dtype=dtype, device=device) + sx = torch.arange(end=w, dtype=dtype, device=device) + grid_y, grid_x = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx) + grid_xy = torch.stack([grid_x, grid_y], -1) # (h, w, 2) + + valid_WH = torch.tensor([w, h], dtype=dtype, device=device) + grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH # (1, h, w, 2) + wh = torch.ones_like(grid_xy, dtype=dtype, device=device) * grid_size * (2.0**i) + anchors.append(torch.cat([grid_xy, wh], -1).view(-1, h * w, 4)) # (1, h*w, 4) + + anchors = torch.cat(anchors, 1) # (1, h*w*nl, 4) + valid_mask = ((anchors > eps) & (anchors < 1 - eps)).all(-1, keepdim=True) # 1, h*w*nl, 1 + anchors = torch.log(anchors / (1 - anchors)) + anchors = anchors.masked_fill(~valid_mask, float("inf")) + return anchors, valid_mask + + def _get_encoder_input(self, x): + """Processes and returns encoder inputs by getting projection features from input and concatenating them.""" + # Get projection features + x = [self.input_proj[i](feat) for i, feat in enumerate(x)] + # Get encoder inputs + feats = [] + shapes = [] + for feat in x: + h, w = feat.shape[2:] + # [b, 
c, h, w] -> [b, h*w, c] + feats.append(feat.flatten(2).permute(0, 2, 1)) + # [nl, 2] + shapes.append([h, w]) + + # [b, h*w, c] + feats = torch.cat(feats, 1) + return feats, shapes + + def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None): + """Generates and prepares the input required for the decoder from the provided features and shapes.""" + bs = feats.shape[0] + # Prepare input for decoder + anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device) + features = self.enc_output(valid_mask * feats) # bs, h*w, 256 + + enc_outputs_scores = self.enc_score_head(features) # (bs, h*w, nc) + + # Query selection + # (bs, num_queries) + topk_ind = torch.topk(enc_outputs_scores.max(-1).values, self.num_queries, dim=1).indices.view(-1) + # (bs, num_queries) + batch_ind = torch.arange(end=bs, dtype=topk_ind.dtype).unsqueeze(-1).repeat(1, self.num_queries).view(-1) + + # (bs, num_queries, 256) + top_k_features = features[batch_ind, topk_ind].view(bs, self.num_queries, -1) + # (bs, num_queries, 4) + top_k_anchors = anchors[:, topk_ind].view(bs, self.num_queries, -1) + + # Dynamic anchors + static content + refer_bbox = self.enc_bbox_head(top_k_features) + top_k_anchors + + enc_bboxes = refer_bbox.sigmoid() + if dn_bbox is not None: + refer_bbox = torch.cat([dn_bbox, refer_bbox], 1) + enc_scores = enc_outputs_scores[batch_ind, topk_ind].view(bs, self.num_queries, -1) + + embeddings = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1) if self.learnt_init_query else top_k_features + if self.training: + refer_bbox = refer_bbox.detach() + if not self.learnt_init_query: + embeddings = embeddings.detach() + if dn_embed is not None: + embeddings = torch.cat([dn_embed, embeddings], 1) + + return embeddings, refer_bbox, enc_bboxes, enc_scores + + # TODO + def _reset_parameters(self): + """Initializes or resets the parameters of the model's various components with predefined weights and biases.""" + # Class and bbox head init + bias_cls = bias_init_with_prob(0.01) / 80 * self.nc + # NOTE: the weight initialization in `linear_init` would cause NaN when training with custom datasets. + # linear_init(self.enc_score_head) + constant_(self.enc_score_head.bias, bias_cls) + constant_(self.enc_bbox_head.layers[-1].weight, 0.0) + constant_(self.enc_bbox_head.layers[-1].bias, 0.0) + for cls_, reg_ in zip(self.dec_score_head, self.dec_bbox_head): + # linear_init(cls_) + constant_(cls_.bias, bias_cls) + constant_(reg_.layers[-1].weight, 0.0) + constant_(reg_.layers[-1].bias, 0.0) + + linear_init(self.enc_output[0]) + xavier_uniform_(self.enc_output[0].weight) + if self.learnt_init_query: + xavier_uniform_(self.tgt_embed.weight) + xavier_uniform_(self.query_pos_head.layers[0].weight) + xavier_uniform_(self.query_pos_head.layers[1].weight) + for layer in self.input_proj: + xavier_uniform_(layer[0].weight) + + +class v10Detect(Detect): + """ + v10 Detection head from https://arxiv.org/pdf/2405.14458. + + Args: + nc (int): Number of classes. + ch (tuple): Tuple of channel sizes. + + Attributes: + max_det (int): Maximum number of detections. + + Methods: + __init__(self, nc=80, ch=()): Initializes the v10Detect object. + forward(self, x): Performs forward pass of the v10Detect module. + bias_init(self): Initializes biases of the Detect module. 
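+
+    Notes:
+        With `end2end = True`, the head keeps a one-to-one prediction branch alongside the standard one-to-many
+        branch, which enables NMS-free inference in YOLOv10; the classification head itself is kept light with
+        depth-wise convolutions.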
+ + """ + + end2end = True + + def __init__(self, nc=80, ch=()): + """Initializes the v10Detect object with the specified number of classes and input channels.""" + super().__init__(nc, ch) + c3 = max(ch[0], min(self.nc, 100)) # channels + # Light cls head + self.cv3 = nn.ModuleList( + nn.Sequential( + nn.Sequential(Conv(x, x, 3, g=x), Conv(x, c3, 1)), + nn.Sequential(Conv(c3, c3, 3, g=c3), Conv(c3, c3, 1)), + nn.Conv2d(c3, self.nc, 1), + ) + for x in ch + ) + self.one2one_cv3 = copy.deepcopy(self.cv3) diff --git a/ultralytics/nn/modules/transformer.py b/ultralytics/nn/modules/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..c198736908eb8c85aeceb096c20c0b61136d9dd8 --- /dev/null +++ b/ultralytics/nn/modules/transformer.py @@ -0,0 +1,427 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +"""Transformer modules.""" + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.init import constant_, xavier_uniform_ + +from .conv import Conv +from .utils import _get_clones, inverse_sigmoid, multi_scale_deformable_attn_pytorch + +__all__ = ( + "TransformerEncoderLayer", + "TransformerLayer", + "TransformerBlock", + "MLPBlock", + "LayerNorm2d", + "AIFI", + "DeformableTransformerDecoder", + "DeformableTransformerDecoderLayer", + "MSDeformAttn", + "MLP", +) + + +class TransformerEncoderLayer(nn.Module): + """Defines a single layer of the transformer encoder.""" + + def __init__(self, c1, cm=2048, num_heads=8, dropout=0.0, act=nn.GELU(), normalize_before=False): + """Initialize the TransformerEncoderLayer with specified parameters.""" + super().__init__() + from ...utils.torch_utils import TORCH_1_9 + + if not TORCH_1_9: + raise ModuleNotFoundError( + "TransformerEncoderLayer() requires torch>=1.9 to use nn.MultiheadAttention(batch_first=True)." 
+ ) + self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True) + # Implementation of Feedforward model + self.fc1 = nn.Linear(c1, cm) + self.fc2 = nn.Linear(cm, c1) + + self.norm1 = nn.LayerNorm(c1) + self.norm2 = nn.LayerNorm(c1) + self.dropout = nn.Dropout(dropout) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.act = act + self.normalize_before = normalize_before + + @staticmethod + def with_pos_embed(tensor, pos=None): + """Add position embeddings to the tensor if provided.""" + return tensor if pos is None else tensor + pos + + def forward_post(self, src, src_mask=None, src_key_padding_mask=None, pos=None): + """Performs forward pass with post-normalization.""" + q = k = self.with_pos_embed(src, pos) + src2 = self.ma(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.fc2(self.dropout(self.act(self.fc1(src)))) + src = src + self.dropout2(src2) + return self.norm2(src) + + def forward_pre(self, src, src_mask=None, src_key_padding_mask=None, pos=None): + """Performs forward pass with pre-normalization.""" + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.ma(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.fc2(self.dropout(self.act(self.fc1(src2)))) + return src + self.dropout2(src2) + + def forward(self, src, src_mask=None, src_key_padding_mask=None, pos=None): + """Forward propagates the input through the encoder module.""" + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class AIFI(TransformerEncoderLayer): + """Defines the AIFI transformer layer.""" + + def __init__(self, c1, cm=2048, num_heads=8, dropout=0, act=nn.GELU(), normalize_before=False): + """Initialize the AIFI instance with specified parameters.""" + super().__init__(c1, cm, num_heads, dropout, act, normalize_before) + + def forward(self, x): + """Forward pass for the AIFI transformer layer.""" + c, h, w = x.shape[1:] + pos_embed = self.build_2d_sincos_position_embedding(w, h, c) + # Flatten [B, C, H, W] to [B, HxW, C] + x = super().forward(x.flatten(2).permute(0, 2, 1), pos=pos_embed.to(device=x.device, dtype=x.dtype)) + return x.permute(0, 2, 1).view([-1, c, h, w]).contiguous() + + @staticmethod + def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.0): + """Builds 2D sine-cosine position embedding.""" + assert embed_dim % 4 == 0, "Embed dimension must be divisible by 4 for 2D sin-cos position embedding" + grid_w = torch.arange(w, dtype=torch.float32) + grid_h = torch.arange(h, dtype=torch.float32) + grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij") + pos_dim = embed_dim // 4 + omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim + omega = 1.0 / (temperature**omega) + + out_w = grid_w.flatten()[..., None] @ omega[None] + out_h = grid_h.flatten()[..., None] @ omega[None] + + return torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], 1)[None] + + +class TransformerLayer(nn.Module): + """Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance).""" + + def __init__(self, c, num_heads): + """Initializes a self-attention mechanism using linear transformations and multi-head attention.""" + 
super().__init__() + self.q = nn.Linear(c, c, bias=False) + self.k = nn.Linear(c, c, bias=False) + self.v = nn.Linear(c, c, bias=False) + self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads) + self.fc1 = nn.Linear(c, c, bias=False) + self.fc2 = nn.Linear(c, c, bias=False) + + def forward(self, x): + """Apply a transformer block to the input x and return the output.""" + x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x + return self.fc2(self.fc1(x)) + x + + +class TransformerBlock(nn.Module): + """Vision Transformer https://arxiv.org/abs/2010.11929.""" + + def __init__(self, c1, c2, num_heads, num_layers): + """Initialize a Transformer module with position embedding and specified number of heads and layers.""" + super().__init__() + self.conv = None + if c1 != c2: + self.conv = Conv(c1, c2) + self.linear = nn.Linear(c2, c2) # learnable position embedding + self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers))) + self.c2 = c2 + + def forward(self, x): + """Forward propagates the input through the bottleneck module.""" + if self.conv is not None: + x = self.conv(x) + b, _, w, h = x.shape + p = x.flatten(2).permute(2, 0, 1) + return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h) + + +class MLPBlock(nn.Module): + """Implements a single block of a multi-layer perceptron.""" + + def __init__(self, embedding_dim, mlp_dim, act=nn.GELU): + """Initialize the MLPBlock with specified embedding dimension, MLP dimension, and activation function.""" + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass for the MLPBlock.""" + return self.lin2(self.act(self.lin1(x))) + + +class MLP(nn.Module): + """Implements a simple multi-layer perceptron (also called FFN).""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act=nn.ReLU, sigmoid=False): + """Initialize the MLP with specified input, hidden, output dimensions and number of layers.""" + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + self.sigmoid = sigmoid + self.act = act() + + def forward(self, x): + """Forward pass for the entire MLP.""" + for i, layer in enumerate(self.layers): + x = getattr(self, "act", nn.ReLU())(layer(x)) if i < self.num_layers - 1 else layer(x) + return x.sigmoid() if getattr(self, "sigmoid", False) else x + + +class LayerNorm2d(nn.Module): + """ + 2D Layer Normalization module inspired by Detectron2 and ConvNeXt implementations. + + Original implementations in + https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py + and + https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py. 
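+
+    Unlike nn.LayerNorm, this module normalizes over the channel dimension of an NCHW tensor (independently at each
+    spatial location), so it can be applied directly to convolutional feature maps.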
+ """ + + def __init__(self, num_channels, eps=1e-6): + """Initialize LayerNorm2d with the given parameters.""" + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x): + """Perform forward pass for 2D layer normalization.""" + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + return self.weight[:, None, None] * x + self.bias[:, None, None] + + +class MSDeformAttn(nn.Module): + """ + Multiscale Deformable Attention Module based on Deformable-DETR and PaddleDetection implementations. + + https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py + """ + + def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): + """Initialize MSDeformAttn with the given parameters.""" + super().__init__() + if d_model % n_heads != 0: + raise ValueError(f"d_model must be divisible by n_heads, but got {d_model} and {n_heads}") + _d_per_head = d_model // n_heads + # Better to set _d_per_head to a power of 2 which is more efficient in a CUDA implementation + assert _d_per_head * n_heads == d_model, "`d_model` must be divisible by `n_heads`" + + self.im2col_step = 64 + + self.d_model = d_model + self.n_levels = n_levels + self.n_heads = n_heads + self.n_points = n_points + + self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) + self.value_proj = nn.Linear(d_model, d_model) + self.output_proj = nn.Linear(d_model, d_model) + + self._reset_parameters() + + def _reset_parameters(self): + """Reset module parameters.""" + constant_(self.sampling_offsets.weight.data, 0.0) + thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = ( + (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) + .view(self.n_heads, 1, 1, 2) + .repeat(1, self.n_levels, self.n_points, 1) + ) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + constant_(self.attention_weights.weight.data, 0.0) + constant_(self.attention_weights.bias.data, 0.0) + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.0) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.0) + + def forward(self, query, refer_bbox, value, value_shapes, value_mask=None): + """ + Perform forward pass for multiscale deformable attention. 
+ + https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py + + Args: + query (torch.Tensor): [bs, query_length, C] + refer_bbox (torch.Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0), + bottom-right (1, 1), including padding area + value (torch.Tensor): [bs, value_length, C] + value_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements + + Returns: + output (Tensor): [bs, Length_{query}, C] + """ + bs, len_q = query.shape[:2] + len_v = value.shape[1] + assert sum(s[0] * s[1] for s in value_shapes) == len_v + + value = self.value_proj(value) + if value_mask is not None: + value = value.masked_fill(value_mask[..., None], float(0)) + value = value.view(bs, len_v, self.n_heads, self.d_model // self.n_heads) + sampling_offsets = self.sampling_offsets(query).view(bs, len_q, self.n_heads, self.n_levels, self.n_points, 2) + attention_weights = self.attention_weights(query).view(bs, len_q, self.n_heads, self.n_levels * self.n_points) + attention_weights = F.softmax(attention_weights, -1).view(bs, len_q, self.n_heads, self.n_levels, self.n_points) + # N, Len_q, n_heads, n_levels, n_points, 2 + num_points = refer_bbox.shape[-1] + if num_points == 2: + offset_normalizer = torch.as_tensor(value_shapes, dtype=query.dtype, device=query.device).flip(-1) + add = sampling_offsets / offset_normalizer[None, None, None, :, None, :] + sampling_locations = refer_bbox[:, :, None, :, None, :] + add + elif num_points == 4: + add = sampling_offsets / self.n_points * refer_bbox[:, :, None, :, None, 2:] * 0.5 + sampling_locations = refer_bbox[:, :, None, :, None, :2] + add + else: + raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {num_points}.") + output = multi_scale_deformable_attn_pytorch(value, value_shapes, sampling_locations, attention_weights) + return self.output_proj(output) + + +class DeformableTransformerDecoderLayer(nn.Module): + """ + Deformable Transformer Decoder Layer inspired by PaddleDetection and Deformable-DETR implementations. 
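+
+    Each layer applies self-attention over the object queries, deformable cross-attention over the multi-scale image
+    features, and a feed-forward network, each followed by dropout, a residual connection and layer normalization.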
+ + https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py + https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/deformable_transformer.py + """ + + def __init__(self, d_model=256, n_heads=8, d_ffn=1024, dropout=0.0, act=nn.ReLU(), n_levels=4, n_points=4): + """Initialize the DeformableTransformerDecoderLayer with the given parameters.""" + super().__init__() + + # Self attention + self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # Cross attention + self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) + self.dropout2 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + # FFN + self.linear1 = nn.Linear(d_model, d_ffn) + self.act = act + self.dropout3 = nn.Dropout(dropout) + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout4 = nn.Dropout(dropout) + self.norm3 = nn.LayerNorm(d_model) + + @staticmethod + def with_pos_embed(tensor, pos): + """Add positional embeddings to the input tensor, if provided.""" + return tensor if pos is None else tensor + pos + + def forward_ffn(self, tgt): + """Perform forward pass through the Feed-Forward Network part of the layer.""" + tgt2 = self.linear2(self.dropout3(self.act(self.linear1(tgt)))) + tgt = tgt + self.dropout4(tgt2) + return self.norm3(tgt) + + def forward(self, embed, refer_bbox, feats, shapes, padding_mask=None, attn_mask=None, query_pos=None): + """Perform the forward pass through the entire decoder layer.""" + # Self attention + q = k = self.with_pos_embed(embed, query_pos) + tgt = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), embed.transpose(0, 1), attn_mask=attn_mask)[ + 0 + ].transpose(0, 1) + embed = embed + self.dropout1(tgt) + embed = self.norm1(embed) + + # Cross attention + tgt = self.cross_attn( + self.with_pos_embed(embed, query_pos), refer_bbox.unsqueeze(2), feats, shapes, padding_mask + ) + embed = embed + self.dropout2(tgt) + embed = self.norm2(embed) + + # FFN + return self.forward_ffn(embed) + + +class DeformableTransformerDecoder(nn.Module): + """ + Implementation of Deformable Transformer Decoder based on PaddleDetection. 
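+
+    The reference boxes are refined iteratively: each layer predicts a delta that is added to the inverse-sigmoid of
+    the current reference and squashed back with a sigmoid, and `eval_idx` selects which layer's predictions are
+    returned at inference time.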
+ + https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/transformers/deformable_transformer.py + """ + + def __init__(self, hidden_dim, decoder_layer, num_layers, eval_idx=-1): + """Initialize the DeformableTransformerDecoder with the given parameters.""" + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.hidden_dim = hidden_dim + self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx + + def forward( + self, + embed, # decoder embeddings + refer_bbox, # anchor + feats, # image features + shapes, # feature shapes + bbox_head, + score_head, + pos_mlp, + attn_mask=None, + padding_mask=None, + ): + """Perform the forward pass through the entire decoder.""" + output = embed + dec_bboxes = [] + dec_cls = [] + last_refined_bbox = None + refer_bbox = refer_bbox.sigmoid() + for i, layer in enumerate(self.layers): + output = layer(output, refer_bbox, feats, shapes, padding_mask, attn_mask, pos_mlp(refer_bbox)) + + bbox = bbox_head[i](output) + refined_bbox = torch.sigmoid(bbox + inverse_sigmoid(refer_bbox)) + + if self.training: + dec_cls.append(score_head[i](output)) + if i == 0: + dec_bboxes.append(refined_bbox) + else: + dec_bboxes.append(torch.sigmoid(bbox + inverse_sigmoid(last_refined_bbox))) + elif i == self.eval_idx: + dec_cls.append(score_head[i](output)) + dec_bboxes.append(refined_bbox) + break + + last_refined_bbox = refined_bbox + refer_bbox = refined_bbox.detach() if self.training else refined_bbox + + return torch.stack(dec_bboxes), torch.stack(dec_cls) diff --git a/ultralytics/nn/modules/utils.py b/ultralytics/nn/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c7837ebe6c6c8b0de8639191116555ab51d0e575 --- /dev/null +++ b/ultralytics/nn/modules/utils.py @@ -0,0 +1,84 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +"""Module utils.""" + +import copy +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.init import uniform_ + +__all__ = "multi_scale_deformable_attn_pytorch", "inverse_sigmoid" + + +def _get_clones(module, n): + """Create a list of cloned modules from the given module.""" + return nn.ModuleList([copy.deepcopy(module) for _ in range(n)]) + + +def bias_init_with_prob(prior_prob=0.01): + """Initialize conv/fc bias value according to a given probability value.""" + return float(-np.log((1 - prior_prob) / prior_prob)) # return bias_init + + +def linear_init(module): + """Initialize the weights and biases of a linear module.""" + bound = 1 / math.sqrt(module.weight.shape[0]) + uniform_(module.weight, -bound, bound) + if hasattr(module, "bias") and module.bias is not None: + uniform_(module.bias, -bound, bound) + + +def inverse_sigmoid(x, eps=1e-5): + """Calculate the inverse sigmoid function for a tensor.""" + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1 / x2) + + +def multi_scale_deformable_attn_pytorch( + value: torch.Tensor, + value_spatial_shapes: torch.Tensor, + sampling_locations: torch.Tensor, + attention_weights: torch.Tensor, +) -> torch.Tensor: + """ + Multiscale deformable attention. 
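+
+    Pure-PyTorch reference implementation: value features from each level are sampled at the predicted locations with
+    F.grid_sample and then combined using the per-point attention weights.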
+ + https://github.com/IDEA-Research/detrex/blob/main/detrex/layers/multi_scale_deform_attn.py + """ + bs, _, num_heads, embed_dims = value.shape + _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape + value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) + sampling_grids = 2 * sampling_locations - 1 + sampling_value_list = [] + for level, (H_, W_) in enumerate(value_spatial_shapes): + # bs, H_*W_, num_heads, embed_dims -> + # bs, H_*W_, num_heads*embed_dims -> + # bs, num_heads*embed_dims, H_*W_ -> + # bs*num_heads, embed_dims, H_, W_ + value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_) + # bs, num_queries, num_heads, num_points, 2 -> + # bs, num_heads, num_queries, num_points, 2 -> + # bs*num_heads, num_queries, num_points, 2 + sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1) + # bs*num_heads, embed_dims, num_queries, num_points + sampling_value_l_ = F.grid_sample( + value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False + ) + sampling_value_list.append(sampling_value_l_) + # (bs, num_queries, num_heads, num_levels, num_points) -> + # (bs, num_heads, num_queries, num_levels, num_points) -> + # (bs, num_heads, 1, num_queries, num_levels*num_points) + attention_weights = attention_weights.transpose(1, 2).reshape( + bs * num_heads, 1, num_queries, num_levels * num_points + ) + output = ( + (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) + .sum(-1) + .view(bs, num_heads * embed_dims, num_queries) + ) + return output.transpose(1, 2).contiguous() diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py new file mode 100644 index 0000000000000000000000000000000000000000..44589c23cae667fec550585ac2780f255aea6f70 --- /dev/null +++ b/ultralytics/nn/tasks.py @@ -0,0 +1,1189 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import contextlib +import pickle +import re +import types +from copy import deepcopy +from pathlib import Path + +import thop +import torch +import torch.nn as nn + +from ultralytics.nn.modules import ( + AIFI, + C1, + C2, + C2PSA, + C3, + C3TR, + ELAN1, + OBB, + PSA, + SPP, + SPPELAN, + SPPF, + AConv, + ADown, + Bottleneck, + BottleneckCSP, + C2f, + C2fAttn, + C2fCIB, + C2fPSA, + C3Ghost, + C3k2, + C3x, + CBFuse, + CBLinear, + Classify, + Concat, + Conv, + Conv2, + ConvTranspose, + Detect, + DWConv, + DWConvTranspose2d, + Focus, + GhostBottleneck, + GhostConv, + HGBlock, + HGStem, + ImagePoolingAttn, + Index, + Pose, + RepC3, + RepConv, + RepNCSPELAN4, + RepVGGDW, + ResNetLayer, + RTDETRDecoder, + SCDown, + Segment, + TorchVision, + WorldDetect, + v10Detect, + A2C2f, +) +from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load +from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml +from ultralytics.utils.loss import ( + E2EDetectLoss, + v8ClassificationLoss, + v8DetectionLoss, + v8OBBLoss, + v8PoseLoss, + v8SegmentationLoss, +) +from ultralytics.utils.ops import make_divisible +from ultralytics.utils.plotting import feature_visualization +from ultralytics.utils.torch_utils import ( + fuse_conv_and_bn, + fuse_deconv_and_bn, + initialize_weights, + intersect_dicts, + model_info, + scale_img, + time_sync, +) + + +class BaseModel(nn.Module): + """The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family.""" + + def forward(self, x, *args, **kwargs): + """ + 
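A quick shape check for `multi_scale_deformable_attn_pytorch`, called directly with dummy tensors. The import path is the `utils` module added above; the level shapes and head sizes are arbitrary.

```python
import torch
from ultralytics.nn.modules.utils import multi_scale_deformable_attn_pytorch

bs, num_heads, embed_dims = 2, 8, 32
num_queries, num_levels, num_points = 100, 4, 4
shapes = [[32, 32], [16, 16], [8, 8], [4, 4]]  # (H, W) per level
num_values = sum(h * w for h, w in shapes)     # 1360 flattened locations across levels

value = torch.rand(bs, num_values, num_heads, embed_dims)
sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)
attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points).softmax(-1)

out = multi_scale_deformable_attn_pytorch(value, shapes, sampling_locations, attention_weights)
print(out.shape)  # torch.Size([2, 100, 256]) == (bs, num_queries, num_heads * embed_dims)
```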
Perform forward pass of the model for either training or inference. + + If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference. + + Args: + x (torch.Tensor | dict): Input tensor for inference, or dict with image tensor and labels for training. + *args (Any): Variable length argument list. + **kwargs (Any): Arbitrary keyword arguments. + + Returns: + (torch.Tensor): Loss if x is a dict (training), or network predictions (inference). + """ + if isinstance(x, dict): # for cases of training and validating while training. + return self.loss(x, *args, **kwargs) + return self.predict(x, *args, **kwargs) + + def predict(self, x, profile=False, visualize=False, augment=False, embed=None): + """ + Perform a forward pass through the network. + + Args: + x (torch.Tensor): The input tensor to the model. + profile (bool): Print the computation time of each layer if True, defaults to False. + visualize (bool): Save the feature maps of the model if True, defaults to False. + augment (bool): Augment image during prediction, defaults to False. + embed (list, optional): A list of feature vectors/embeddings to return. + + Returns: + (torch.Tensor): The last output of the model. + """ + if augment: + return self._predict_augment(x) + return self._predict_once(x, profile, visualize, embed) + + def _predict_once(self, x, profile=False, visualize=False, embed=None): + """ + Perform a forward pass through the network. + + Args: + x (torch.Tensor): The input tensor to the model. + profile (bool): Print the computation time of each layer if True, defaults to False. + visualize (bool): Save the feature maps of the model if True, defaults to False. + embed (list, optional): A list of feature vectors/embeddings to return. + + Returns: + (torch.Tensor): The last output of the model. + """ + y, dt, embeddings = [], [], [] # outputs + for m in self.model: + if m.f != -1: # if not from previous layer + x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers + if profile: + self._profile_one_layer(m, x, dt) + x = m(x) # run + y.append(x if m.i in self.save else None) # save output + if visualize: + feature_visualization(x, m.type, m.i, save_dir=visualize) + if embed and m.i in embed: + embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten + if m.i == max(embed): + return torch.unbind(torch.cat(embeddings, 1), dim=0) + return x + + def _predict_augment(self, x): + """Perform augmentations on input image x and return augmented inference.""" + LOGGER.warning( + f"WARNING ⚠️ {self.__class__.__name__} does not support 'augment=True' prediction. " + f"Reverting to single-scale prediction." + ) + return self._predict_once(x) + + def _profile_one_layer(self, m, x, dt): + """ + Profile the computation time and FLOPs of a single layer of the model on a given input. Appends the results to + the provided list. + + Args: + m (nn.Module): The layer to be profiled. + x (torch.Tensor): The input data to the layer. + dt (list): A list to store the computation time of the layer. 
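When `embed` layer indices are passed to `predict`, the intermediate feature maps are reduced to flat per-image vectors with adaptive average pooling and then concatenated across the requested layers. A standalone sketch of just that pooling step, with made-up feature-map sizes:

```python
import torch
import torch.nn as nn

p3 = torch.rand(1, 256, 80, 80)  # hypothetical intermediate feature maps (B, C, H, W)
p4 = torch.rand(1, 512, 40, 40)

vecs = [nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1) for x in (p3, p4)]
print([v.shape for v in vecs])   # [torch.Size([1, 256]), torch.Size([1, 512])]
print(torch.cat(vecs, 1).shape)  # torch.Size([1, 768]) -- one embedding vector per image
```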
+ + Returns: + None + """ + c = m == self.model[-1] and isinstance(x, list) # is final layer list, copy input as inplace fix + flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs + t = time_sync() + for _ in range(10): + m(x.copy() if c else x) + dt.append((time_sync() - t) * 100) + if m == self.model[0]: + LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module") + LOGGER.info(f"{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f} {m.type}") + if c: + LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total") + + def fuse(self, verbose=True): + """ + Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer, in order to improve the + computation efficiency. + + Returns: + (nn.Module): The fused model is returned. + """ + if not self.is_fused(): + for m in self.model.modules(): + if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, "bn"): + if isinstance(m, Conv2): + m.fuse_convs() + m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv + delattr(m, "bn") # remove batchnorm + m.forward = m.forward_fuse # update forward + if isinstance(m, ConvTranspose) and hasattr(m, "bn"): + m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn) + delattr(m, "bn") # remove batchnorm + m.forward = m.forward_fuse # update forward + if isinstance(m, RepConv): + m.fuse_convs() + m.forward = m.forward_fuse # update forward + if isinstance(m, RepVGGDW): + m.fuse() + m.forward = m.forward_fuse + self.info(verbose=verbose) + + return self + + def is_fused(self, thresh=10): + """ + Check if the model has less than a certain threshold of BatchNorm layers. + + Args: + thresh (int, optional): The threshold number of BatchNorm layers. Default is 10. + + Returns: + (bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise. + """ + bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d() + return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model + + def info(self, detailed=False, verbose=True, imgsz=640): + """ + Prints model information. + + Args: + detailed (bool): if True, prints out detailed information about the model. Defaults to False + verbose (bool): if True, prints out the model information. Defaults to False + imgsz (int): the size of the image that the model will be trained on. Defaults to 640 + """ + return model_info(self, detailed=detailed, verbose=verbose, imgsz=imgsz) + + def _apply(self, fn): + """ + Applies a function to all the tensors in the model that are not parameters or registered buffers. + + Args: + fn (function): the function to apply to the model + + Returns: + (BaseModel): An updated BaseModel object. + """ + self = super()._apply(fn) + m = self.model[-1] # Detect() + if isinstance(m, Detect): # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect + m.stride = fn(m.stride) + m.anchors = fn(m.anchors) + m.strides = fn(m.strides) + return self + + def load(self, weights, verbose=True): + """ + Load the weights into the model. + + Args: + weights (dict | torch.nn.Module): The pre-trained weights to be loaded. + verbose (bool, optional): Whether to log the transfer progress. Defaults to True. 
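`fuse()` relies on the standard Conv+BatchNorm folding identity: in eval mode BN is a per-channel affine transform, so it can be absorbed into the convolution's weights and bias. A self-contained re-derivation of that identity (not the library's `fuse_conv_and_bn` helper) that verifies it numerically:

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(8, 16, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(16).eval()
bn.running_mean.uniform_(-1, 1)   # pretend these statistics were learned during training
bn.running_var.uniform_(0.5, 2.0)
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-1, 1)

fused = nn.Conv2d(8, 16, 3, padding=1, bias=True)
with torch.no_grad():
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)       # gamma / sqrt(var + eps)
    fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))  # w' = w * scale
    fused.bias.copy_(bn.bias - bn.running_mean * scale)           # b' = beta - mean * scale

x = torch.rand(1, 8, 32, 32)
with torch.no_grad():
    print(torch.allclose(bn(conv(x)), fused(x), atol=1e-5))  # True: one conv now does the work of two layers
```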
+ """ + model = weights["model"] if isinstance(weights, dict) else weights # torchvision models are not dicts + csd = model.float().state_dict() # checkpoint state_dict as FP32 + csd = intersect_dicts(csd, self.state_dict()) # intersect + self.load_state_dict(csd, strict=False) # load + if verbose: + LOGGER.info(f"Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights") + + def loss(self, batch, preds=None): + """ + Compute loss. + + Args: + batch (dict): Batch to compute loss on + preds (torch.Tensor | List[torch.Tensor]): Predictions. + """ + if getattr(self, "criterion", None) is None: + self.criterion = self.init_criterion() + + preds = self.forward(batch["img"]) if preds is None else preds + return self.criterion(preds, batch) + + def init_criterion(self): + """Initialize the loss criterion for the BaseModel.""" + raise NotImplementedError("compute_loss() needs to be implemented by task heads") + + +class DetectionModel(BaseModel): + """YOLOv8 detection model.""" + + def __init__(self, cfg="yolov8n.yaml", ch=3, nc=None, verbose=True): # model, input channels, number of classes + """Initialize the YOLOv8 detection model with the given config and parameters.""" + super().__init__() + self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict + if self.yaml["backbone"][0][2] == "Silence": + LOGGER.warning( + "WARNING ⚠️ YOLOv9 `Silence` module is deprecated in favor of nn.Identity. " + "Please delete local *.pt file and re-download the latest model checkpoint." + ) + self.yaml["backbone"][0][2] = "nn.Identity" + + # Define model + ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels + if nc and nc != self.yaml["nc"]: + LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}") + self.yaml["nc"] = nc # override YAML value + self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist + self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict + self.inplace = self.yaml.get("inplace", True) + self.end2end = getattr(self.model[-1], "end2end", False) + + # Build strides + m = self.model[-1] # Detect() + if isinstance(m, Detect): # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect + s = 256 # 2x min stride + m.inplace = self.inplace + + def _forward(x): + """Performs a forward pass through the model, handling different Detect subclass types accordingly.""" + if self.end2end: + return self.forward(x)["one2many"] + return self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x) + + m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))]) # forward + self.stride = m.stride + m.bias_init() # only run once + else: + self.stride = torch.Tensor([32]) # default stride for i.e. 
RTDETR + + # Init weights, biases + initialize_weights(self) + if verbose: + self.info() + LOGGER.info("") + + def _predict_augment(self, x): + """Perform augmentations on input image x and return augmented inference and train outputs.""" + if getattr(self, "end2end", False) or self.__class__.__name__ != "DetectionModel": + LOGGER.warning("WARNING ⚠️ Model does not support 'augment=True', reverting to single-scale prediction.") + return self._predict_once(x) + img_size = x.shape[-2:] # height, width + s = [1, 0.83, 0.67] # scales + f = [None, 3, None] # flips (2-ud, 3-lr) + y = [] # outputs + for si, fi in zip(s, f): + xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max())) + yi = super().predict(xi)[0] # forward + yi = self._descale_pred(yi, fi, si, img_size) + y.append(yi) + y = self._clip_augmented(y) # clip augmented tails + return torch.cat(y, -1), None # augmented inference, train + + @staticmethod + def _descale_pred(p, flips, scale, img_size, dim=1): + """De-scale predictions following augmented inference (inverse operation).""" + p[:, :4] /= scale # de-scale + x, y, wh, cls = p.split((1, 1, 2, p.shape[dim] - 4), dim) + if flips == 2: + y = img_size[0] - y # de-flip ud + elif flips == 3: + x = img_size[1] - x # de-flip lr + return torch.cat((x, y, wh, cls), dim) + + def _clip_augmented(self, y): + """Clip YOLO augmented inference tails.""" + nl = self.model[-1].nl # number of detection layers (P3-P5) + g = sum(4**x for x in range(nl)) # grid points + e = 1 # exclude layer count + i = (y[0].shape[-1] // g) * sum(4**x for x in range(e)) # indices + y[0] = y[0][..., :-i] # large + i = (y[-1].shape[-1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices + y[-1] = y[-1][..., i:] # small + return y + + def init_criterion(self): + """Initialize the loss criterion for the DetectionModel.""" + return E2EDetectLoss(self) if getattr(self, "end2end", False) else v8DetectionLoss(self) + + +class OBBModel(DetectionModel): + """YOLOv8 Oriented Bounding Box (OBB) model.""" + + def __init__(self, cfg="yolov8n-obb.yaml", ch=3, nc=None, verbose=True): + """Initialize YOLOv8 OBB model with given config and parameters.""" + super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) + + def init_criterion(self): + """Initialize the loss criterion for the model.""" + return v8OBBLoss(self) + + +class SegmentationModel(DetectionModel): + """YOLOv8 segmentation model.""" + + def __init__(self, cfg="yolov8n-seg.yaml", ch=3, nc=None, verbose=True): + """Initialize YOLOv8 segmentation model with given config and parameters.""" + super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) + + def init_criterion(self): + """Initialize the loss criterion for the SegmentationModel.""" + return v8SegmentationLoss(self) + + +class PoseModel(DetectionModel): + """YOLOv8 pose model.""" + + def __init__(self, cfg="yolov8n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True): + """Initialize YOLOv8 Pose model.""" + if not isinstance(cfg, dict): + cfg = yaml_model_load(cfg) # load model YAML + if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg["kpt_shape"]): + LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}") + cfg["kpt_shape"] = data_kpt_shape + super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) + + def init_criterion(self): + """Initialize the loss criterion for the PoseModel.""" + return v8PoseLoss(self) + + +class ClassificationModel(BaseModel): + """YOLOv8 classification model.""" + + def __init__(self, 
cfg="yolov8n-cls.yaml", ch=3, nc=None, verbose=True): + """Init ClassificationModel with YAML, channels, number of classes, verbose flag.""" + super().__init__() + self._from_yaml(cfg, ch, nc, verbose) + + def _from_yaml(self, cfg, ch, nc, verbose): + """Set YOLOv8 model configurations and define the model architecture.""" + self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict + + # Define model + ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels + if nc and nc != self.yaml["nc"]: + LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}") + self.yaml["nc"] = nc # override YAML value + elif not nc and not self.yaml.get("nc", None): + raise ValueError("nc not specified. Must specify nc in model.yaml or function arguments.") + self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist + self.stride = torch.Tensor([1]) # no stride constraints + self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict + self.info() + + @staticmethod + def reshape_outputs(model, nc): + """Update a TorchVision classification model to class count 'n' if required.""" + name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1] # last module + if isinstance(m, Classify): # YOLO Classify() head + if m.linear.out_features != nc: + m.linear = nn.Linear(m.linear.in_features, nc) + elif isinstance(m, nn.Linear): # ResNet, EfficientNet + if m.out_features != nc: + setattr(model, name, nn.Linear(m.in_features, nc)) + elif isinstance(m, nn.Sequential): + types = [type(x) for x in m] + if nn.Linear in types: + i = len(types) - 1 - types[::-1].index(nn.Linear) # last nn.Linear index + if m[i].out_features != nc: + m[i] = nn.Linear(m[i].in_features, nc) + elif nn.Conv2d in types: + i = len(types) - 1 - types[::-1].index(nn.Conv2d) # last nn.Conv2d index + if m[i].out_channels != nc: + m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None) + + def init_criterion(self): + """Initialize the loss criterion for the ClassificationModel.""" + return v8ClassificationLoss() + + +class RTDETRDetectionModel(DetectionModel): + """ + RTDETR (Real-time DEtection and Tracking using Transformers) Detection Model class. + + This class is responsible for constructing the RTDETR architecture, defining loss functions, and facilitating both + the training and inference processes. RTDETR is an object detection and tracking model that extends from the + DetectionModel base class. + + Attributes: + cfg (str): The configuration file path or preset string. Default is 'rtdetr-l.yaml'. + ch (int): Number of input channels. Default is 3 (RGB). + nc (int, optional): Number of classes for object detection. Default is None. + verbose (bool): Specifies if summary statistics are shown during initialization. Default is True. + + Methods: + init_criterion: Initializes the criterion used for loss calculation. + loss: Computes and returns the loss during training. + predict: Performs a forward pass through the network and returns the output. + """ + + def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True): + """ + Initialize the RTDETRDetectionModel. + + Args: + cfg (str): Configuration file name or path. + ch (int): Number of input channels. + nc (int, optional): Number of classes. Defaults to None. + verbose (bool, optional): Print additional information during initialization. Defaults to True. 
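`ClassificationModel.reshape_outputs` locates the last classification layer and swaps it for one with the requested class count. A standalone sketch of the `nn.Sequential` branch on a toy head (the layer sizes are made up):

```python
import torch.nn as nn

head = nn.Sequential(nn.Flatten(), nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 1000))
nc = 10  # new number of classes

types = [type(x) for x in head]
if nn.Linear in types:
    i = len(types) - 1 - types[::-1].index(nn.Linear)  # index of the last nn.Linear
    if head[i].out_features != nc:
        head[i] = nn.Linear(head[i].in_features, nc)

print(head[-1])  # Linear(in_features=256, out_features=10, bias=True)
```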
+ """ + super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) + + def init_criterion(self): + """Initialize the loss criterion for the RTDETRDetectionModel.""" + from ultralytics.models.utils.loss import RTDETRDetectionLoss + + return RTDETRDetectionLoss(nc=self.nc, use_vfl=True) + + def loss(self, batch, preds=None): + """ + Compute the loss for the given batch of data. + + Args: + batch (dict): Dictionary containing image and label data. + preds (torch.Tensor, optional): Precomputed model predictions. Defaults to None. + + Returns: + (tuple): A tuple containing the total loss and main three losses in a tensor. + """ + if not hasattr(self, "criterion"): + self.criterion = self.init_criterion() + + img = batch["img"] + # NOTE: preprocess gt_bbox and gt_labels to list. + bs = len(img) + batch_idx = batch["batch_idx"] + gt_groups = [(batch_idx == i).sum().item() for i in range(bs)] + targets = { + "cls": batch["cls"].to(img.device, dtype=torch.long).view(-1), + "bboxes": batch["bboxes"].to(device=img.device), + "batch_idx": batch_idx.to(img.device, dtype=torch.long).view(-1), + "gt_groups": gt_groups, + } + + preds = self.predict(img, batch=targets) if preds is None else preds + dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1] + if dn_meta is None: + dn_bboxes, dn_scores = None, None + else: + dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta["dn_num_split"], dim=2) + dn_scores, dec_scores = torch.split(dec_scores, dn_meta["dn_num_split"], dim=2) + + dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes]) # (7, bs, 300, 4) + dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores]) + + loss = self.criterion( + (dec_bboxes, dec_scores), targets, dn_bboxes=dn_bboxes, dn_scores=dn_scores, dn_meta=dn_meta + ) + # NOTE: There are like 12 losses in RTDETR, backward with all losses but only show the main three losses. + return sum(loss.values()), torch.as_tensor( + [loss[k].detach() for k in ["loss_giou", "loss_class", "loss_bbox"]], device=img.device + ) + + def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None): + """ + Perform a forward pass through the model. + + Args: + x (torch.Tensor): The input tensor. + profile (bool, optional): If True, profile the computation time for each layer. Defaults to False. + visualize (bool, optional): If True, save feature maps for visualization. Defaults to False. + batch (dict, optional): Ground truth data for evaluation. Defaults to None. + augment (bool, optional): If True, perform data augmentation during inference. Defaults to False. + embed (list, optional): A list of feature vectors/embeddings to return. + + Returns: + (torch.Tensor): Model's output tensor. 
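The RT-DETR loss needs to know how many ground-truth boxes belong to each image, which it recovers from `batch_idx`. A tiny numeric sketch of that grouping:

```python
import torch

batch_idx = torch.tensor([0, 0, 0, 1, 2, 2])  # image index of every GT box in a batch of 3 images
bs = 3
gt_groups = [(batch_idx == i).sum().item() for i in range(bs)]
print(gt_groups)  # [3, 1, 2] -> image 0 has 3 boxes, image 1 has 1, image 2 has 2
```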
+ """ + y, dt, embeddings = [], [], [] # outputs + for m in self.model[:-1]: # except the head part + if m.f != -1: # if not from previous layer + x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers + if profile: + self._profile_one_layer(m, x, dt) + x = m(x) # run + y.append(x if m.i in self.save else None) # save output + if visualize: + feature_visualization(x, m.type, m.i, save_dir=visualize) + if embed and m.i in embed: + embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten + if m.i == max(embed): + return torch.unbind(torch.cat(embeddings, 1), dim=0) + head = self.model[-1] + x = head([y[j] for j in head.f], batch) # head inference + return x + + +class WorldModel(DetectionModel): + """YOLOv8 World Model.""" + + def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True): + """Initialize YOLOv8 world model with given config and parameters.""" + self.txt_feats = torch.randn(1, nc or 80, 512) # features placeholder + self.clip_model = None # CLIP model placeholder + super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) + + def set_classes(self, text, batch=80, cache_clip_model=True): + """Set classes in advance so that model could do offline-inference without clip model.""" + try: + import clip + except ImportError: + check_requirements("git+https://github.com/ultralytics/CLIP.git") + import clip + + if ( + not getattr(self, "clip_model", None) and cache_clip_model + ): # for backwards compatibility of models lacking clip_model attribute + self.clip_model = clip.load("ViT-B/32")[0] + model = self.clip_model if cache_clip_model else clip.load("ViT-B/32")[0] + device = next(model.parameters()).device + text_token = clip.tokenize(text).to(device) + txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)] + txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0) + txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True) + self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]) + self.model[-1].nc = len(text) + + def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None): + """ + Perform a forward pass through the model. + + Args: + x (torch.Tensor): The input tensor. + profile (bool, optional): If True, profile the computation time for each layer. Defaults to False. + visualize (bool, optional): If True, save feature maps for visualization. Defaults to False. + txt_feats (torch.Tensor): The text features, use it if it's given. Defaults to None. + augment (bool, optional): If True, perform data augmentation during inference. Defaults to False. + embed (list, optional): A list of feature vectors/embeddings to return. + + Returns: + (torch.Tensor): Model's output tensor. 
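`set_classes` L2-normalizes the encoded text prompts and reshapes them to `(1, num_classes, dim)` before caching them as `txt_feats`. A sketch of that post-processing with random stand-ins for the CLIP embeddings, so no CLIP install is needed:

```python
import torch

text = ["person", "bus", "traffic light"]
txt_feats = torch.randn(len(text), 512)                            # stand-in for CLIP encode_text() output
txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)  # unit-length rows
txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])

print(txt_feats.shape)         # torch.Size([1, 3, 512])
print(txt_feats.norm(dim=-1))  # approximately tensor([[1., 1., 1.]]) -- every class embedding has unit norm
```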
+ """ + txt_feats = (self.txt_feats if txt_feats is None else txt_feats).to(device=x.device, dtype=x.dtype) + if len(txt_feats) != len(x): + txt_feats = txt_feats.repeat(len(x), 1, 1) + ori_txt_feats = txt_feats.clone() + y, dt, embeddings = [], [], [] # outputs + for m in self.model: # except the head part + if m.f != -1: # if not from previous layer + x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers + if profile: + self._profile_one_layer(m, x, dt) + if isinstance(m, C2fAttn): + x = m(x, txt_feats) + elif isinstance(m, WorldDetect): + x = m(x, ori_txt_feats) + elif isinstance(m, ImagePoolingAttn): + txt_feats = m(x, txt_feats) + else: + x = m(x) # run + + y.append(x if m.i in self.save else None) # save output + if visualize: + feature_visualization(x, m.type, m.i, save_dir=visualize) + if embed and m.i in embed: + embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten + if m.i == max(embed): + return torch.unbind(torch.cat(embeddings, 1), dim=0) + return x + + def loss(self, batch, preds=None): + """ + Compute loss. + + Args: + batch (dict): Batch to compute loss on. + preds (torch.Tensor | List[torch.Tensor]): Predictions. + """ + if not hasattr(self, "criterion"): + self.criterion = self.init_criterion() + + if preds is None: + preds = self.forward(batch["img"], txt_feats=batch["txt_feats"]) + return self.criterion(preds, batch) + + +class Ensemble(nn.ModuleList): + """Ensemble of models.""" + + def __init__(self): + """Initialize an ensemble of models.""" + super().__init__() + + def forward(self, x, augment=False, profile=False, visualize=False): + """Function generates the YOLO network's final layer.""" + y = [module(x, augment, profile, visualize)[0] for module in self] + # y = torch.stack(y).max(0)[0] # max ensemble + # y = torch.stack(y).mean(0) # mean ensemble + y = torch.cat(y, 2) # nms ensemble, y shape(B, HW, C) + return y, None # inference, train output + + +# Functions ------------------------------------------------------------------------------------------------------------ + + +@contextlib.contextmanager +def temporary_modules(modules=None, attributes=None): + """ + Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`). + + This function can be used to change the module paths during runtime. It's useful when refactoring code, + where you've moved a module from one location to another, but you still want to support the old import + paths for backwards compatibility. + + Args: + modules (dict, optional): A dictionary mapping old module paths to new module paths. + attributes (dict, optional): A dictionary mapping old module attributes to new module attributes. + + Example: + ```python + with temporary_modules({"old.module": "new.module"}, {"old.module.attribute": "new.module.attribute"}): + import old.module # this will now import new.module + from old.module import attribute # this will now import new.module.attribute + ``` + + Note: + The changes are only in effect inside the context manager and are undone once the context manager exits. + Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger + applications or libraries. Use this function with caution. 
+ """ + if modules is None: + modules = {} + if attributes is None: + attributes = {} + import sys + from importlib import import_module + + try: + # Set attributes in sys.modules under their old name + for old, new in attributes.items(): + old_module, old_attr = old.rsplit(".", 1) + new_module, new_attr = new.rsplit(".", 1) + setattr(import_module(old_module), old_attr, getattr(import_module(new_module), new_attr)) + + # Set modules in sys.modules under their old name + for old, new in modules.items(): + sys.modules[old] = import_module(new) + + yield + finally: + # Remove the temporary module paths + for old in modules: + if old in sys.modules: + del sys.modules[old] + + +class SafeClass: + """A placeholder class to replace unknown classes during unpickling.""" + + def __init__(self, *args, **kwargs): + """Initialize SafeClass instance, ignoring all arguments.""" + pass + + def __call__(self, *args, **kwargs): + """Run SafeClass instance, ignoring all arguments.""" + pass + + +class SafeUnpickler(pickle.Unpickler): + """Custom Unpickler that replaces unknown classes with SafeClass.""" + + def find_class(self, module, name): + """Attempt to find a class, returning SafeClass if not among safe modules.""" + safe_modules = ( + "torch", + "collections", + "collections.abc", + "builtins", + "math", + "numpy", + # Add other modules considered safe + ) + if module in safe_modules: + return super().find_class(module, name) + else: + return SafeClass + + +def torch_safe_load(weight, safe_only=False): + """ + Attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the + error, logs a warning message, and attempts to install the missing module via the check_requirements() function. + After installation, the function again attempts to load the model using torch.load(). + + Args: + weight (str): The file path of the PyTorch model. + safe_only (bool): If True, replace unknown classes with SafeClass during loading. + + Example: + ```python + from ultralytics.nn.tasks import torch_safe_load + + ckpt, file = torch_safe_load("path/to/best.pt", safe_only=True) + ``` + + Returns: + ckpt (dict): The loaded model checkpoint. 
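A miniature, self-contained version of the SafeClass / SafeUnpickler pattern above: classes from modules outside the allow-list are silently replaced, while plain containers survive. The `Sneaky` class only stands in for an unexpected object embedded in a checkpoint.

```python
import io
import pickle

class SafeClass:
    """Placeholder that swallows any constructor arguments."""

    def __init__(self, *args, **kwargs):
        pass

class DemoUnpickler(pickle.Unpickler):
    SAFE_MODULES = {"builtins", "collections"}

    def find_class(self, module, name):
        if module in self.SAFE_MODULES:
            return super().find_class(module, name)
        return SafeClass  # neutralize anything else

class Sneaky:  # defined in this script, so it is not on the allow-list
    pass

payload = pickle.dumps({"ok": [1, 2, 3], "obj": Sneaky()})
restored = DemoUnpickler(io.BytesIO(payload)).load()
print(type(restored["obj"]).__name__, restored["ok"])  # SafeClass [1, 2, 3]
```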
+ file (str): The loaded filename + """ + from ultralytics.utils.downloads import attempt_download_asset + + check_suffix(file=weight, suffix=".pt") + file = attempt_download_asset(weight) # search online if missing locally + try: + with temporary_modules( + modules={ + "ultralytics.yolo.utils": "ultralytics.utils", + "ultralytics.yolo.v8": "ultralytics.models.yolo", + "ultralytics.yolo.data": "ultralytics.data", + }, + attributes={ + "ultralytics.nn.modules.block.Silence": "torch.nn.Identity", # YOLOv9e + "ultralytics.nn.tasks.YOLOv10DetectionModel": "ultralytics.nn.tasks.DetectionModel", # YOLOv10 + "ultralytics.utils.loss.v10DetectLoss": "ultralytics.utils.loss.E2EDetectLoss", # YOLOv10 + }, + ): + if safe_only: + # Load via custom pickle module + safe_pickle = types.ModuleType("safe_pickle") + safe_pickle.Unpickler = SafeUnpickler + safe_pickle.load = lambda file_obj: SafeUnpickler(file_obj).load() + with open(file, "rb") as f: + ckpt = torch.load(f, pickle_module=safe_pickle) + else: + ckpt = torch.load(file, map_location="cpu") + + except ModuleNotFoundError as e: # e.name is missing module name + if e.name == "models": + raise TypeError( + emojis( + f"ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained " + f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with " + f"YOLOv8 at https://github.com/ultralytics/ultralytics." + f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to " + f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolov8n.pt'" + ) + ) from e + LOGGER.warning( + f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in Ultralytics requirements." + f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future." + f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to " + f"run a command with an official Ultralytics model, i.e. 'yolo predict model=yolov8n.pt'" + ) + check_requirements(e.name) # install missing module + ckpt = torch.load(file, map_location="cpu") + + if not isinstance(ckpt, dict): + # File is likely a YOLO instance saved with i.e. torch.save(model, "saved_model.pt") + LOGGER.warning( + f"WARNING ⚠️ The file '{weight}' appears to be improperly saved or formatted. " + f"For optimal results, use model.save('filename.pt') to correctly save YOLO models." 
+ ) + ckpt = {"model": ckpt.model} + + return ckpt, file + + +def attempt_load_weights(weights, device=None, inplace=True, fuse=False): + """Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a.""" + ensemble = Ensemble() + for w in weights if isinstance(weights, list) else [weights]: + ckpt, w = torch_safe_load(w) # load ckpt + args = {**DEFAULT_CFG_DICT, **ckpt["train_args"]} if "train_args" in ckpt else None # combined args + model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model + + # Model compatibility updates + model.args = args # attach args to model + model.pt_path = w # attach *.pt file path to model + model.task = guess_model_task(model) + if not hasattr(model, "stride"): + model.stride = torch.tensor([32.0]) + + # Append + ensemble.append(model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval()) # model in eval mode + + # Module updates + for m in ensemble.modules(): + if hasattr(m, "inplace"): + m.inplace = inplace + elif isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"): + m.recompute_scale_factor = None # torch 1.11.0 compatibility + + # Return model + if len(ensemble) == 1: + return ensemble[-1] + + # Return ensemble + LOGGER.info(f"Ensemble created with {weights}\n") + for k in "names", "nc", "yaml": + setattr(ensemble, k, getattr(ensemble[0], k)) + ensemble.stride = ensemble[int(torch.argmax(torch.tensor([m.stride.max() for m in ensemble])))].stride + assert all(ensemble[0].nc == m.nc for m in ensemble), f"Models differ in class counts {[m.nc for m in ensemble]}" + return ensemble + + +def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False): + """Loads a single model weights.""" + ckpt, weight = torch_safe_load(weight) # load ckpt + args = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))} # combine model and default args, preferring model args + model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model + + # Model compatibility updates + model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model + model.pt_path = weight # attach *.pt file path to model + model.task = guess_model_task(model) + if not hasattr(model, "stride"): + model.stride = torch.tensor([32.0]) + + model = model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval() # model in eval mode + + # Module updates + for m in model.modules(): + if hasattr(m, "inplace"): + m.inplace = inplace + elif isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"): + m.recompute_scale_factor = None # torch 1.11.0 compatibility + + # Return model and ckpt + return model, ckpt + + +def parse_model(d, ch, verbose=True): # model_dict, input_channels(3) + """Parse a YOLO model.yaml dictionary into a PyTorch model.""" + import ast + + # Args + legacy = True # backward compatibility for v3/v5/v8/v9 models + max_channels = float("inf") + nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales")) + depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape")) + if scales: + scale = d.get("scale") + if not scale: + scale = tuple(scales.keys())[0] + LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.") + depth, width, max_channels = scales[scale] + + if act: + Conv.default_act = eval(act) # redefine default activation, i.e. 
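A short usage sketch for the single-weight loader, assuming a `yolov12n.pt` checkpoint from the release table above is already present in the working directory:

```python
from ultralytics.nn.tasks import attempt_load_one_weight

model, ckpt = attempt_load_one_weight("yolov12n.pt", device="cpu", fuse=True)
print(model.task)           # e.g. 'detect'
print(model.stride)         # strides attached to the detection head
print(sorted(ckpt.keys()))  # checkpoint dict keys, typically including 'model' plus training metadata
```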
Conv.default_act = nn.SiLU() + if verbose: + LOGGER.info(f"{colorstr('activation:')} {act}") # print + + if verbose: + LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}") + ch = [ch] + layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out + for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args + m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m] # get module + for j, a in enumerate(args): + if isinstance(a, str): + with contextlib.suppress(ValueError): + args[j] = locals()[a] if a in locals() else ast.literal_eval(a) + n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain + if m in { + Classify, + Conv, + ConvTranspose, + GhostConv, + Bottleneck, + GhostBottleneck, + SPP, + SPPF, + C2fPSA, + C2PSA, + DWConv, + Focus, + BottleneckCSP, + C1, + C2, + C2f, + C3k2, + RepNCSPELAN4, + ELAN1, + ADown, + AConv, + SPPELAN, + C2fAttn, + C3, + C3TR, + C3Ghost, + nn.ConvTranspose2d, + DWConvTranspose2d, + C3x, + RepC3, + PSA, + SCDown, + C2fCIB, + A2C2f, + }: + c1, c2 = ch[f], args[0] + if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output) + c2 = make_divisible(min(c2, max_channels) * width, 8) + if m is C2fAttn: + args[1] = make_divisible(min(args[1], max_channels // 2) * width, 8) # embed channels + args[2] = int( + max(round(min(args[2], max_channels // 2 // 32)) * width, 1) if args[2] > 1 else args[2] + ) # num heads + + args = [c1, c2, *args[1:]] + if m in { + BottleneckCSP, + C1, + C2, + C2f, + C3k2, + C2fAttn, + C3, + C3TR, + C3Ghost, + C3x, + RepC3, + C2fPSA, + C2fCIB, + C2PSA, + A2C2f, + }: + args.insert(2, n) # number of repeats + n = 1 + if m is C3k2: # for M/L/X sizes + legacy = False + if scale in "mlx": + args[3] = True + if m is A2C2f: + legacy = False + if scale in "lx": # for L/X sizes + args.append(True) + args.append(1.2) + elif m is AIFI: + args = [ch[f], *args] + elif m in {HGStem, HGBlock}: + c1, cm, c2 = ch[f], args[0], args[1] + args = [c1, cm, c2, *args[2:]] + if m is HGBlock: + args.insert(4, n) # number of repeats + n = 1 + elif m is ResNetLayer: + c2 = args[1] if args[3] else args[1] * 4 + elif m is nn.BatchNorm2d: + args = [ch[f]] + elif m is Concat: + c2 = sum(ch[x] for x in f) + elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn, v10Detect}: + args.append([ch[x] for x in f]) + if m is Segment: + args[2] = make_divisible(min(args[2], max_channels) * width, 8) + if m in {Detect, Segment, Pose, OBB}: + m.legacy = legacy + elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1 + args.insert(1, [ch[x] for x in f]) + elif m in {CBLinear, TorchVision, Index}: + c2 = args[0] + c1 = ch[f] + args = [c1, c2, *args[1:]] + elif m is CBFuse: + c2 = ch[f[-1]] + else: + c2 = ch[f] + + m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module + t = str(m)[8:-2].replace("__main__.", "") # module type + m_.np = sum(x.numel() for x in m_.parameters()) # number params + m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type + if verbose: + LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m_.np:10.0f} {t:<45}{str(args):<30}") # print + save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist + layers.append(m_) + if i == 0: + ch = [] + ch.append(c2) + return nn.Sequential(*layers), sorted(save) + + +def yaml_model_load(path): + """Load a YOLOv8 model from a YAML file.""" + path = Path(path) + if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)): 
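`parse_model` converts the nominal channel and repeat numbers from the YAML into per-scale values using the width/depth multiples and `make_divisible`. A quick numeric sketch with illustrative multiples (not taken from any specific model YAML):

```python
from ultralytics.utils.ops import make_divisible

depth, width, max_channels = 0.33, 0.25, 1024  # illustrative scale multiples

for c2 in (128, 256, 1024):  # nominal output channels listed in a YAML
    print(c2, "->", make_divisible(min(c2, max_channels) * width, 8))
# 128 -> 32, 256 -> 64, 1024 -> 256

for n in (1, 3, 6):  # nominal repeat counts
    print(n, "->", max(round(n * depth), 1) if n > 1 else n)
# 1 -> 1, 3 -> 1, 6 -> 2
```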
+ new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem) + LOGGER.warning(f"WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.") + path = path.with_name(new_stem + path.suffix) + + unified_path = re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", str(path)) # i.e. yolov8x.yaml -> yolov8.yaml + yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path) + d = yaml_load(yaml_file) # model dict + d["scale"] = guess_model_scale(path) + d["yaml_file"] = str(path) + return d + + +def guess_model_scale(model_path): + """ + Takes a path to a YOLO model's YAML file as input and extracts the size character of the model's scale. The function + uses regular expression matching to find the pattern of the model scale in the YAML file name, which is denoted by + n, s, m, l, or x. The function returns the size character of the model scale as a string. + + Args: + model_path (str | Path): The path to the YOLO model's YAML file. + + Returns: + (str): The size character of the model's scale, which can be n, s, m, l, or x. + """ + try: + return re.search(r"yolo[v]?\d+([nslmx])", Path(model_path).stem).group(1) # noqa, returns n, s, m, l, or x + except AttributeError: + return "" + + +def guess_model_task(model): + """ + Guess the task of a PyTorch model from its architecture or configuration. + + Args: + model (nn.Module | dict): PyTorch model or model configuration in YAML format. + + Returns: + (str): Task of the model ('detect', 'segment', 'classify', 'pose'). + + Raises: + SyntaxError: If the task of the model could not be determined. + """ + + def cfg2task(cfg): + """Guess from YAML dictionary.""" + m = cfg["head"][-1][-2].lower() # output module name + if m in {"classify", "classifier", "cls", "fc"}: + return "classify" + if "detect" in m: + return "detect" + if m == "segment": + return "segment" + if m == "pose": + return "pose" + if m == "obb": + return "obb" + + # Guess from model cfg + if isinstance(model, dict): + with contextlib.suppress(Exception): + return cfg2task(model) + # Guess from PyTorch model + if isinstance(model, nn.Module): # PyTorch model + for x in "model.args", "model.model.args", "model.model.model.args": + with contextlib.suppress(Exception): + return eval(x)["task"] + for x in "model.yaml", "model.model.yaml", "model.model.model.yaml": + with contextlib.suppress(Exception): + return cfg2task(eval(x)) + for m in model.modules(): + if isinstance(m, Segment): + return "segment" + elif isinstance(m, Classify): + return "classify" + elif isinstance(m, Pose): + return "pose" + elif isinstance(m, OBB): + return "obb" + elif isinstance(m, (Detect, WorldDetect, v10Detect)): + return "detect" + + # Guess from model filename + if isinstance(model, (str, Path)): + model = Path(model) + if "-seg" in model.stem or "segment" in model.parts: + return "segment" + elif "-cls" in model.stem or "classify" in model.parts: + return "classify" + elif "-pose" in model.stem or "pose" in model.parts: + return "pose" + elif "-obb" in model.stem or "obb" in model.parts: + return "obb" + elif "detect" in model.parts: + return "detect" + + # Unable to determine task from model + LOGGER.warning( + "WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. " + "Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'." 
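`guess_model_scale` relies entirely on the regular expression above, so its behaviour is easy to check in isolation:

```python
import re
from pathlib import Path

for name in ("yolov12n.yaml", "yolo11s.yaml", "yolov8x-seg.yaml", "custom.yaml"):
    match = re.search(r"yolo[v]?\d+([nslmx])", Path(name).stem)
    print(name, "->", match.group(1) if match else "(no scale)")
# yolov12n.yaml -> n, yolo11s.yaml -> s, yolov8x-seg.yaml -> x, custom.yaml -> (no scale)
```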
+ ) + return "detect" # assume detect diff --git a/ultralytics/solutions/__init__.py b/ultralytics/solutions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..635cb3ad7e2996b11fc08a330072d23c382c1351 --- /dev/null +++ b/ultralytics/solutions/__init__.py @@ -0,0 +1,30 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from .ai_gym import AIGym +from .analytics import Analytics +from .distance_calculation import DistanceCalculation +from .heatmap import Heatmap +from .object_counter import ObjectCounter +from .parking_management import ParkingManagement, ParkingPtsSelection +from .queue_management import QueueManager +from .region_counter import RegionCounter +from .security_alarm import SecurityAlarm +from .speed_estimation import SpeedEstimator +from .streamlit_inference import Inference +from .trackzone import TrackZone + +__all__ = ( + "AIGym", + "DistanceCalculation", + "Heatmap", + "ObjectCounter", + "ParkingManagement", + "ParkingPtsSelection", + "QueueManager", + "SpeedEstimator", + "Analytics", + "Inference", + "RegionCounter", + "TrackZone", + "SecurityAlarm", +) diff --git a/ultralytics/solutions/ai_gym.py b/ultralytics/solutions/ai_gym.py new file mode 100644 index 0000000000000000000000000000000000000000..4cf48544d717f4ef4c82f42fbac880408ce99e62 --- /dev/null +++ b/ultralytics/solutions/ai_gym.py @@ -0,0 +1,111 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.solutions.solutions import BaseSolution +from ultralytics.utils.plotting import Annotator + + +class AIGym(BaseSolution): + """ + A class to manage gym steps of people in a real-time video stream based on their poses. + + This class extends BaseSolution to monitor workouts using YOLO pose estimation models. It tracks and counts + repetitions of exercises based on predefined angle thresholds for up and down positions. + + Attributes: + count (List[int]): Repetition counts for each detected person. + angle (List[float]): Current angle of the tracked body part for each person. + stage (List[str]): Current exercise stage ('up', 'down', or '-') for each person. + initial_stage (str | None): Initial stage of the exercise. + up_angle (float): Angle threshold for considering the 'up' position of an exercise. + down_angle (float): Angle threshold for considering the 'down' position of an exercise. + kpts (List[int]): Indices of keypoints used for angle calculation. + annotator (Annotator): Object for drawing annotations on the image. + + Methods: + monitor: Processes a frame to detect poses, calculate angles, and count repetitions. 
+ + Examples: + >>> gym = AIGym(model="yolov8n-pose.pt") + >>> image = cv2.imread("gym_scene.jpg") + >>> processed_image = gym.monitor(image) + >>> cv2.imshow("Processed Image", processed_image) + >>> cv2.waitKey(0) + """ + + def __init__(self, **kwargs): + """Initializes AIGym for workout monitoring using pose estimation and predefined angles.""" + # Check if the model name ends with '-pose' + if "model" in kwargs and "-pose" not in kwargs["model"]: + kwargs["model"] = "yolo11n-pose.pt" + elif "model" not in kwargs: + kwargs["model"] = "yolo11n-pose.pt" + + super().__init__(**kwargs) + self.count = [] # List for counts, necessary where there are multiple objects in frame + self.angle = [] # List for angle, necessary where there are multiple objects in frame + self.stage = [] # List for stage, necessary where there are multiple objects in frame + + # Extract details from CFG single time for usage later + self.initial_stage = None + self.up_angle = float(self.CFG["up_angle"]) # Pose up predefined angle to consider up pose + self.down_angle = float(self.CFG["down_angle"]) # Pose down predefined angle to consider down pose + self.kpts = self.CFG["kpts"] # User selected kpts of workouts storage for further usage + + def monitor(self, im0): + """ + Monitors workouts using Ultralytics YOLO Pose Model. + + This function processes an input image to track and analyze human poses for workout monitoring. It uses + the YOLO Pose model to detect keypoints, estimate angles, and count repetitions based on predefined + angle thresholds. + + Args: + im0 (ndarray): Input image for processing. + + Returns: + (ndarray): Processed image with annotations for workout monitoring. + + Examples: + >>> gym = AIGym() + >>> image = cv2.imread("workout.jpg") + >>> processed_image = gym.monitor(image) + """ + # Extract tracks + tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"], **self.track_add_args)[0] + + if tracks.boxes.id is not None: + # Extract and check keypoints + if len(tracks) > len(self.count): + new_human = len(tracks) - len(self.count) + self.angle += [0] * new_human + self.count += [0] * new_human + self.stage += ["-"] * new_human + + # Initialize annotator + self.annotator = Annotator(im0, line_width=self.line_width) + + # Enumerate over keypoints + for ind, k in enumerate(reversed(tracks.keypoints.data)): + # Get keypoints and estimate the angle + kpts = [k[int(self.kpts[i])].cpu() for i in range(3)] + self.angle[ind] = self.annotator.estimate_pose_angle(*kpts) + im0 = self.annotator.draw_specific_points(k, self.kpts, radius=self.line_width * 3) + + # Determine stage and count logic based on angle thresholds + if self.angle[ind] < self.down_angle: + if self.stage[ind] == "up": + self.count[ind] += 1 + self.stage[ind] = "down" + elif self.angle[ind] > self.up_angle: + self.stage[ind] = "up" + + # Display angle, count, and stage text + self.annotator.plot_angle_and_count_and_stage( + angle_text=self.angle[ind], # angle text for display + count_text=self.count[ind], # count text for workouts + stage_text=self.stage[ind], # stage position text + center_kpt=k[int(self.kpts[1])], # center keypoint for display + ) + + self.display_output(im0) # Display output image, if environment support display + return im0 # return an image for writing or further usage diff --git a/ultralytics/solutions/analytics.py b/ultralytics/solutions/analytics.py new file mode 100644 index 0000000000000000000000000000000000000000..3a62e8c2e68d036985a2bf5996c9193c238d573f --- /dev/null +++ 
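The repetition counter in `monitor` is a small two-state machine driven by the joint angle: dropping below `down_angle` while in the "up" stage counts one rep. A standalone replica of that logic on a plain list of angles (the thresholds here are illustrative; the real values come from the solutions config):

```python
def count_reps(angles, up_angle=145.0, down_angle=90.0):
    """Mirror the stage/count update used in AIGym.monitor."""
    stage, count = "-", 0
    for angle in angles:
        if angle < down_angle:
            if stage == "up":
                count += 1
            stage = "down"
        elif angle > up_angle:
            stage = "up"
    return count

print(count_reps([160, 150, 100, 70, 80, 150, 160, 75, 150]))  # 2 full repetitions
```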
b/ultralytics/solutions/analytics.py @@ -0,0 +1,247 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from itertools import cycle + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas +from matplotlib.figure import Figure + +from ultralytics.solutions.solutions import BaseSolution # Import a parent class + + +class Analytics(BaseSolution): + """ + A class for creating and updating various types of charts for visual analytics. + + This class extends BaseSolution to provide functionality for generating line, bar, pie, and area charts + based on object detection and tracking data. + + Attributes: + type (str): The type of analytics chart to generate ('line', 'bar', 'pie', or 'area'). + x_label (str): Label for the x-axis. + y_label (str): Label for the y-axis. + bg_color (str): Background color of the chart frame. + fg_color (str): Foreground color of the chart frame. + title (str): Title of the chart window. + max_points (int): Maximum number of data points to display on the chart. + fontsize (int): Font size for text display. + color_cycle (cycle): Cyclic iterator for chart colors. + total_counts (int): Total count of detected objects (used for line charts). + clswise_count (Dict[str, int]): Dictionary for class-wise object counts. + fig (Figure): Matplotlib figure object for the chart. + ax (Axes): Matplotlib axes object for the chart. + canvas (FigureCanvas): Canvas for rendering the chart. + + Methods: + process_data: Processes image data and updates the chart. + update_graph: Updates the chart with new data points. + + Examples: + >>> analytics = Analytics(analytics_type="line") + >>> frame = cv2.imread("image.jpg") + >>> processed_frame = analytics.process_data(frame, frame_number=1) + >>> cv2.imshow("Analytics", processed_frame) + """ + + def __init__(self, **kwargs): + """Initialize Analytics class with various chart types for visual data representation.""" + super().__init__(**kwargs) + + self.type = self.CFG["analytics_type"] # extract type of analytics + self.x_label = "Classes" if self.type in {"bar", "pie"} else "Frame#" + self.y_label = "Total Counts" + + # Predefined data + self.bg_color = "#F3F3F3" # background color of frame + self.fg_color = "#111E68" # foreground color of frame + self.title = "Ultralytics Solutions" # window name + self.max_points = 45 # maximum points to be drawn on window + self.fontsize = 25 # text font size for display + figsize = (19.2, 10.8) # Set output image size 1920 * 1080 + self.color_cycle = cycle(["#DD00BA", "#042AFF", "#FF4447", "#7D24FF", "#BD00FF"]) + + self.total_counts = 0 # count variable for storing total counts i.e. 
for line + self.clswise_count = {} # dictionary for class-wise counts + + # Ensure line and area chart + if self.type in {"line", "area"}: + self.lines = {} + self.fig = Figure(facecolor=self.bg_color, figsize=figsize) + self.canvas = FigureCanvas(self.fig) # Set common axis properties + self.ax = self.fig.add_subplot(111, facecolor=self.bg_color) + if self.type == "line": + (self.line,) = self.ax.plot([], [], color="cyan", linewidth=self.line_width) + elif self.type in {"bar", "pie"}: + # Initialize bar or pie plot + self.fig, self.ax = plt.subplots(figsize=figsize, facecolor=self.bg_color) + self.canvas = FigureCanvas(self.fig) # Set common axis properties + self.ax.set_facecolor(self.bg_color) + self.color_mapping = {} + + if self.type == "pie": # Ensure pie chart is circular + self.ax.axis("equal") + + def process_data(self, im0, frame_number): + """ + Processes image data and runs object tracking to update analytics charts. + + Args: + im0 (np.ndarray): Input image for processing. + frame_number (int): Video frame number for plotting the data. + + Returns: + (np.ndarray): Processed image with updated analytics chart. + + Raises: + ModuleNotFoundError: If an unsupported chart type is specified. + + Examples: + >>> analytics = Analytics(analytics_type="line") + >>> frame = np.zeros((480, 640, 3), dtype=np.uint8) + >>> processed_frame = analytics.process_data(frame, frame_number=1) + """ + self.extract_tracks(im0) # Extract tracks + + if self.type == "line": + for _ in self.boxes: + self.total_counts += 1 + im0 = self.update_graph(frame_number=frame_number) + self.total_counts = 0 + elif self.type in {"pie", "bar", "area"}: + self.clswise_count = {} + for box, cls in zip(self.boxes, self.clss): + if self.names[int(cls)] in self.clswise_count: + self.clswise_count[self.names[int(cls)]] += 1 + else: + self.clswise_count[self.names[int(cls)]] = 1 + im0 = self.update_graph(frame_number=frame_number, count_dict=self.clswise_count, plot=self.type) + else: + raise ModuleNotFoundError(f"{self.type} chart is not supported ❌") + return im0 + + def update_graph(self, frame_number, count_dict=None, plot="line"): + """ + Updates the graph with new data for single or multiple classes. + + Args: + frame_number (int): The current frame number. + count_dict (Dict[str, int] | None): Dictionary with class names as keys and counts as values for multiple + classes. If None, updates a single line graph. + plot (str): Type of the plot. Options are 'line', 'bar', 'pie', or 'area'. + + Returns: + (np.ndarray): Updated image containing the graph. 
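For the bar, pie, and area charts, `process_data` reduces the tracked detections to a class-wise tally before handing it to `update_graph`. The same bookkeeping in isolation:

```python
names = {0: "person", 2: "car", 7: "truck"}  # class-id -> name mapping (illustrative)
clss = [0, 0, 2, 0, 2, 7]                    # class ids of the tracked boxes in one frame

clswise_count = {}
for cls in clss:
    label = names[int(cls)]
    clswise_count[label] = clswise_count.get(label, 0) + 1

print(clswise_count)  # {'person': 3, 'car': 2, 'truck': 1}
```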
+ + Examples: + >>> analytics = Analytics() + >>> frame_number = 10 + >>> count_dict = {"person": 5, "car": 3} + >>> updated_image = analytics.update_graph(frame_number, count_dict, plot="bar") + """ + if count_dict is None: + # Single line update + x_data = np.append(self.line.get_xdata(), float(frame_number)) + y_data = np.append(self.line.get_ydata(), float(self.total_counts)) + + if len(x_data) > self.max_points: + x_data, y_data = x_data[-self.max_points :], y_data[-self.max_points :] + + self.line.set_data(x_data, y_data) + self.line.set_label("Counts") + self.line.set_color("#7b0068") # Pink color + self.line.set_marker("*") + self.line.set_markersize(self.line_width * 5) + else: + labels = list(count_dict.keys()) + counts = list(count_dict.values()) + if plot == "area": + color_cycle = cycle(["#DD00BA", "#042AFF", "#FF4447", "#7D24FF", "#BD00FF"]) + # Multiple lines or area update + x_data = self.ax.lines[0].get_xdata() if self.ax.lines else np.array([]) + y_data_dict = {key: np.array([]) for key in count_dict.keys()} + if self.ax.lines: + for line, key in zip(self.ax.lines, count_dict.keys()): + y_data_dict[key] = line.get_ydata() + + x_data = np.append(x_data, float(frame_number)) + max_length = len(x_data) + for key in count_dict.keys(): + y_data_dict[key] = np.append(y_data_dict[key], float(count_dict[key])) + if len(y_data_dict[key]) < max_length: + y_data_dict[key] = np.pad(y_data_dict[key], (0, max_length - len(y_data_dict[key]))) + if len(x_data) > self.max_points: + x_data = x_data[1:] + for key in count_dict.keys(): + y_data_dict[key] = y_data_dict[key][1:] + + self.ax.clear() + for key, y_data in y_data_dict.items(): + color = next(color_cycle) + self.ax.fill_between(x_data, y_data, color=color, alpha=0.7) + self.ax.plot( + x_data, + y_data, + color=color, + linewidth=self.line_width, + marker="o", + markersize=self.line_width * 5, + label=f"{key} Data Points", + ) + if plot == "bar": + self.ax.clear() # clear bar data + for label in labels: # Map labels to colors + if label not in self.color_mapping: + self.color_mapping[label] = next(self.color_cycle) + colors = [self.color_mapping[label] for label in labels] + bars = self.ax.bar(labels, counts, color=colors) + for bar, count in zip(bars, counts): + self.ax.text( + bar.get_x() + bar.get_width() / 2, + bar.get_height(), + str(count), + ha="center", + va="bottom", + color=self.fg_color, + ) + # Create the legend using labels from the bars + for bar, label in zip(bars, labels): + bar.set_label(label) # Assign label to each bar + self.ax.legend(loc="upper left", fontsize=13, facecolor=self.fg_color, edgecolor=self.fg_color) + if plot == "pie": + total = sum(counts) + percentages = [size / total * 100 for size in counts] + start_angle = 90 + self.ax.clear() + + # Create pie chart and create legend labels with percentages + wedges, autotexts = self.ax.pie( + counts, labels=labels, startangle=start_angle, textprops={"color": self.fg_color}, autopct=None + ) + legend_labels = [f"{label} ({percentage:.1f}%)" for label, percentage in zip(labels, percentages)] + + # Assign the legend using the wedges and manually created labels + self.ax.legend(wedges, legend_labels, title="Classes", loc="center left", bbox_to_anchor=(1, 0, 0.5, 1)) + self.fig.subplots_adjust(left=0.1, right=0.75) # Adjust layout to fit the legend + + # Common plot settings + self.ax.set_facecolor("#f0f0f0") # Set to light gray or any other color you like + self.ax.set_title(self.title, color=self.fg_color, fontsize=self.fontsize) + 
self.ax.set_xlabel(self.x_label, color=self.fg_color, fontsize=self.fontsize - 3) + self.ax.set_ylabel(self.y_label, color=self.fg_color, fontsize=self.fontsize - 3) + + # Add and format legend + legend = self.ax.legend(loc="upper left", fontsize=13, facecolor=self.bg_color, edgecolor=self.bg_color) + for text in legend.get_texts(): + text.set_color(self.fg_color) + + # Redraw graph, update view, capture, and display the updated plot + self.ax.relim() + self.ax.autoscale_view() + self.canvas.draw() + im0 = np.array(self.canvas.renderer.buffer_rgba()) + im0 = cv2.cvtColor(im0[:, :, :3], cv2.COLOR_RGBA2BGR) + self.display_output(im0) + + return im0 # Return the image diff --git a/ultralytics/solutions/distance_calculation.py b/ultralytics/solutions/distance_calculation.py new file mode 100644 index 0000000000000000000000000000000000000000..c0d8e77b371fa1e557d421b4679cc32c3cb95bb9 --- /dev/null +++ b/ultralytics/solutions/distance_calculation.py @@ -0,0 +1,124 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import math + +import cv2 + +from ultralytics.solutions.solutions import BaseSolution +from ultralytics.utils.plotting import Annotator, colors + + +class DistanceCalculation(BaseSolution): + """ + A class to calculate distance between two objects in a real-time video stream based on their tracks. + + This class extends BaseSolution to provide functionality for selecting objects and calculating the distance + between them in a video stream using YOLO object detection and tracking. + + Attributes: + left_mouse_count (int): Counter for left mouse button clicks. + selected_boxes (Dict[int, List[float]]): Dictionary to store selected bounding boxes and their track IDs. + annotator (Annotator): An instance of the Annotator class for drawing on the image. + boxes (List[List[float]]): List of bounding boxes for detected objects. + track_ids (List[int]): List of track IDs for detected objects. + clss (List[int]): List of class indices for detected objects. + names (List[str]): List of class names that the model can detect. + centroids (List[List[int]]): List to store centroids of selected bounding boxes. + + Methods: + mouse_event_for_distance: Handles mouse events for selecting objects in the video stream. + calculate: Processes video frames and calculates the distance between selected objects. + + Examples: + >>> distance_calc = DistanceCalculation() + >>> frame = cv2.imread("frame.jpg") + >>> processed_frame = distance_calc.calculate(frame) + >>> cv2.imshow("Distance Calculation", processed_frame) + >>> cv2.waitKey(0) + """ + + def __init__(self, **kwargs): + """Initializes the DistanceCalculation class for measuring object distances in video streams.""" + super().__init__(**kwargs) + + # Mouse event information + self.left_mouse_count = 0 + self.selected_boxes = {} + + self.centroids = [] # Initialize empty list to store centroids + + def mouse_event_for_distance(self, event, x, y, flags, param): + """ + Handles mouse events to select regions in a real-time video stream for distance calculation. + + Args: + event (int): Type of mouse event (e.g., cv2.EVENT_MOUSEMOVE, cv2.EVENT_LBUTTONDOWN). + x (int): X-coordinate of the mouse pointer. + y (int): Y-coordinate of the mouse pointer. + flags (int): Flags associated with the event (e.g., cv2.EVENT_FLAG_CTRLKEY, cv2.EVENT_FLAG_SHIFTKEY). + param (Dict): Additional parameters passed to the function. 
+ + Examples: + >>> # Assuming 'dc' is an instance of DistanceCalculation + >>> cv2.setMouseCallback("window_name", dc.mouse_event_for_distance) + """ + if event == cv2.EVENT_LBUTTONDOWN: + self.left_mouse_count += 1 + if self.left_mouse_count <= 2: + for box, track_id in zip(self.boxes, self.track_ids): + if box[0] < x < box[2] and box[1] < y < box[3] and track_id not in self.selected_boxes: + self.selected_boxes[track_id] = box + + elif event == cv2.EVENT_RBUTTONDOWN: + self.selected_boxes = {} + self.left_mouse_count = 0 + + def calculate(self, im0): + """ + Processes a video frame and calculates the distance between two selected bounding boxes. + + This method extracts tracks from the input frame, annotates bounding boxes, and calculates the distance + between two user-selected objects if they have been chosen. + + Args: + im0 (numpy.ndarray): The input image frame to process. + + Returns: + (numpy.ndarray): The processed image frame with annotations and distance calculations. + + Examples: + >>> import numpy as np + >>> from ultralytics.solutions import DistanceCalculation + >>> dc = DistanceCalculation() + >>> frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) + >>> processed_frame = dc.calculate(frame) + """ + self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator + self.extract_tracks(im0) # Extract tracks + + # Iterate over bounding boxes, track ids and classes index + for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss): + self.annotator.box_label(box, color=colors(int(cls), True), label=self.names[int(cls)]) + + if len(self.selected_boxes) == 2: + for trk_id in self.selected_boxes.keys(): + if trk_id == track_id: + self.selected_boxes[track_id] = box + + if len(self.selected_boxes) == 2: + # Store user selected boxes in centroids list + self.centroids.extend( + [[int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2)] for box in self.selected_boxes.values()] + ) + # Calculate pixels distance + pixels_distance = math.sqrt( + (self.centroids[0][0] - self.centroids[1][0]) ** 2 + (self.centroids[0][1] - self.centroids[1][1]) ** 2 + ) + self.annotator.plot_distance_and_line(pixels_distance, self.centroids) + + self.centroids = [] + + self.display_output(im0) # display output with base class function + cv2.setMouseCallback("Ultralytics Solutions", self.mouse_event_for_distance) + + return im0 # return output image for more usage diff --git a/ultralytics/solutions/heatmap.py b/ultralytics/solutions/heatmap.py new file mode 100644 index 0000000000000000000000000000000000000000..98c79d8fc1b2cb69ee5e52dfee0532aeec2f0bcb --- /dev/null +++ b/ultralytics/solutions/heatmap.py @@ -0,0 +1,127 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import cv2 +import numpy as np + +from ultralytics.solutions.object_counter import ObjectCounter +from ultralytics.utils.plotting import Annotator + + +class Heatmap(ObjectCounter): + """ + A class to draw heatmaps in real-time video streams based on object tracks. + + This class extends the ObjectCounter class to generate and visualize heatmaps of object movements in video + streams. It uses tracked object positions to create a cumulative heatmap effect over time. + + Attributes: + initialized (bool): Flag indicating whether the heatmap has been initialized. + colormap (int): OpenCV colormap used for heatmap visualization. + heatmap (np.ndarray): Array storing the cumulative heatmap data. + annotator (Annotator): Object for drawing annotations on the image. 
+ + Methods: + heatmap_effect: Calculates and updates the heatmap effect for a given bounding box. + generate_heatmap: Generates and applies the heatmap effect to each frame. + + Examples: + >>> from ultralytics.solutions import Heatmap + >>> heatmap = Heatmap(model="yolov8n.pt", colormap=cv2.COLORMAP_JET) + >>> frame = cv2.imread("frame.jpg") + >>> processed_frame = heatmap.generate_heatmap(frame) + """ + + def __init__(self, **kwargs): + """Initializes the Heatmap class for real-time video stream heatmap generation based on object tracks.""" + super().__init__(**kwargs) + + self.initialized = False # bool variable for heatmap initialization + if self.region is not None: # check if user provided the region coordinates + self.initialize_region() + + # store colormap + self.colormap = cv2.COLORMAP_PARULA if self.CFG["colormap"] is None else self.CFG["colormap"] + self.heatmap = None + + def heatmap_effect(self, box): + """ + Efficiently calculates heatmap area and effect location for applying colormap. + + Args: + box (List[float]): Bounding box coordinates [x0, y0, x1, y1]. + + Examples: + >>> heatmap = Heatmap() + >>> box = [100, 100, 200, 200] + >>> heatmap.heatmap_effect(box) + """ + x0, y0, x1, y1 = map(int, box) + radius_squared = (min(x1 - x0, y1 - y0) // 2) ** 2 + + # Create a meshgrid with region of interest (ROI) for vectorized distance calculations + xv, yv = np.meshgrid(np.arange(x0, x1), np.arange(y0, y1)) + + # Calculate squared distances from the center + dist_squared = (xv - ((x0 + x1) // 2)) ** 2 + (yv - ((y0 + y1) // 2)) ** 2 + + # Create a mask of points within the radius + within_radius = dist_squared <= radius_squared + + # Update only the values within the bounding box in a single vectorized operation + self.heatmap[y0:y1, x0:x1][within_radius] += 2 + + def generate_heatmap(self, im0): + """ + Generate heatmap for each frame using Ultralytics. + + Args: + im0 (np.ndarray): Input image array for processing. + + Returns: + (np.ndarray): Processed image with heatmap overlay and object counts (if region is specified). 
+ + Examples: + >>> heatmap = Heatmap() + >>> im0 = cv2.imread("image.jpg") + >>> result = heatmap.generate_heatmap(im0) + """ + if not self.initialized: + self.heatmap = np.zeros_like(im0, dtype=np.float32) * 0.99 + self.initialized = True # Initialize heatmap only once + + self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator + self.extract_tracks(im0) # Extract tracks + + # Iterate over bounding boxes, track ids and classes index + for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss): + # Draw bounding box and counting region + self.heatmap_effect(box) + + if self.region is not None: + self.annotator.draw_region(reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2) + self.store_tracking_history(track_id, box) # Store track history + self.store_classwise_counts(cls) # store classwise counts in dict + current_centroid = ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2) + # Store tracking previous position and perform object counting + prev_position = None + if len(self.track_history[track_id]) > 1: + prev_position = self.track_history[track_id][-2] + self.count_objects(current_centroid, track_id, prev_position, cls) # Perform object counting + + if self.region is not None: + self.display_counts(im0) # Display the counts on the frame + + # Normalize, apply colormap to heatmap and combine with original image + if self.track_data.id is not None: + im0 = cv2.addWeighted( + im0, + 0.5, + cv2.applyColorMap( + cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8), self.colormap + ), + 0.5, + 0, + ) + + self.display_output(im0) # display output with base class function + return im0 # return output image for more usage diff --git a/ultralytics/solutions/object_counter.py b/ultralytics/solutions/object_counter.py new file mode 100644 index 0000000000000000000000000000000000000000..d202ca51f517e70f68f3dd9b67b04d1dffddd89a --- /dev/null +++ b/ultralytics/solutions/object_counter.py @@ -0,0 +1,203 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.solutions.solutions import BaseSolution +from ultralytics.utils.plotting import Annotator, colors + + +class ObjectCounter(BaseSolution): + """ + A class to manage the counting of objects in a real-time video stream based on their tracks. + + This class extends the BaseSolution class and provides functionality for counting objects moving in and out of a + specified region in a video stream. It supports both polygonal and linear regions for counting. + + Attributes: + in_count (int): Counter for objects moving inward. + out_count (int): Counter for objects moving outward. + counted_ids (List[int]): List of IDs of objects that have been counted. + classwise_counts (Dict[str, Dict[str, int]]): Dictionary for counts, categorized by object class. + region_initialized (bool): Flag indicating whether the counting region has been initialized. + show_in (bool): Flag to control display of inward count. + show_out (bool): Flag to control display of outward count. + + Methods: + count_objects: Counts objects within a polygonal or linear region. + store_classwise_counts: Initializes class-wise counts if not already present. + display_counts: Displays object counts on the frame. + count: Processes input data (frames or object tracks) and updates counts. 
+ + Examples: + >>> counter = ObjectCounter() + >>> frame = cv2.imread("frame.jpg") + >>> processed_frame = counter.count(frame) + >>> print(f"Inward count: {counter.in_count}, Outward count: {counter.out_count}") + """ + + def __init__(self, **kwargs): + """Initializes the ObjectCounter class for real-time object counting in video streams.""" + super().__init__(**kwargs) + + self.in_count = 0 # Counter for objects moving inward + self.out_count = 0 # Counter for objects moving outward + self.counted_ids = [] # List of IDs of objects that have been counted + self.classwise_counts = {} # Dictionary for counts, categorized by object class + self.region_initialized = False # Bool variable for region initialization + + self.show_in = self.CFG["show_in"] + self.show_out = self.CFG["show_out"] + + def count_objects(self, current_centroid, track_id, prev_position, cls): + """ + Counts objects within a polygonal or linear region based on their tracks. + + Args: + current_centroid (Tuple[float, float]): Current centroid values in the current frame. + track_id (int): Unique identifier for the tracked object. + prev_position (Tuple[float, float]): Last frame position coordinates (x, y) of the track. + cls (int): Class index for classwise count updates. + + Examples: + >>> counter = ObjectCounter() + >>> track_line = {1: [100, 200], 2: [110, 210], 3: [120, 220]} + >>> box = [130, 230, 150, 250] + >>> track_id = 1 + >>> prev_position = (120, 220) + >>> cls = 0 + >>> counter.count_objects(current_centroid, track_id, prev_position, cls) + """ + if prev_position is None or track_id in self.counted_ids: + return + + if len(self.region) == 2: # Linear region (defined as a line segment) + line = self.LineString(self.region) # Check if the line intersects the trajectory of the object + if line.intersects(self.LineString([prev_position, current_centroid])): + # Determine orientation of the region (vertical or horizontal) + if abs(self.region[0][0] - self.region[1][0]) < abs(self.region[0][1] - self.region[1][1]): + # Vertical region: Compare x-coordinates to determine direction + if current_centroid[0] > prev_position[0]: # Moving right + self.in_count += 1 + self.classwise_counts[self.names[cls]]["IN"] += 1 + else: # Moving left + self.out_count += 1 + self.classwise_counts[self.names[cls]]["OUT"] += 1 + # Horizontal region: Compare y-coordinates to determine direction + elif current_centroid[1] > prev_position[1]: # Moving downward + self.in_count += 1 + self.classwise_counts[self.names[cls]]["IN"] += 1 + else: # Moving upward + self.out_count += 1 + self.classwise_counts[self.names[cls]]["OUT"] += 1 + self.counted_ids.append(track_id) + + elif len(self.region) > 2: # Polygonal region + polygon = self.Polygon(self.region) + if polygon.contains(self.Point(current_centroid)): + # Determine motion direction for vertical or horizontal polygons + region_width = max(p[0] for p in self.region) - min(p[0] for p in self.region) + region_height = max(p[1] for p in self.region) - min(p[1] for p in self.region) + + if ( + region_width < region_height + and current_centroid[0] > prev_position[0] + or region_width >= region_height + and current_centroid[1] > prev_position[1] + ): # Moving right + self.in_count += 1 + self.classwise_counts[self.names[cls]]["IN"] += 1 + else: # Moving left + self.out_count += 1 + self.classwise_counts[self.names[cls]]["OUT"] += 1 + self.counted_ids.append(track_id) + + def store_classwise_counts(self, cls): + """ + Initialize class-wise counts for a specific object class if not already 
present. + + Args: + cls (int): Class index for classwise count updates. + + This method ensures that the 'classwise_counts' dictionary contains an entry for the specified class, + initializing 'IN' and 'OUT' counts to zero if the class is not already present. + + Examples: + >>> counter = ObjectCounter() + >>> counter.store_classwise_counts(0) # Initialize counts for class index 0 + >>> print(counter.classwise_counts) + {'person': {'IN': 0, 'OUT': 0}} + """ + if self.names[cls] not in self.classwise_counts: + self.classwise_counts[self.names[cls]] = {"IN": 0, "OUT": 0} + + def display_counts(self, im0): + """ + Displays object counts on the input image or frame. + + Args: + im0 (numpy.ndarray): The input image or frame to display counts on. + + Examples: + >>> counter = ObjectCounter() + >>> frame = cv2.imread("image.jpg") + >>> counter.display_counts(frame) + """ + labels_dict = { + str.capitalize(key): f"{'IN ' + str(value['IN']) if self.show_in else ''} " + f"{'OUT ' + str(value['OUT']) if self.show_out else ''}".strip() + for key, value in self.classwise_counts.items() + if value["IN"] != 0 or value["OUT"] != 0 + } + + if labels_dict: + self.annotator.display_analytics(im0, labels_dict, (104, 31, 17), (255, 255, 255), 10) + + def count(self, im0): + """ + Processes input data (frames or object tracks) and updates object counts. + + This method initializes the counting region, extracts tracks, draws bounding boxes and regions, updates + object counts, and displays the results on the input image. + + Args: + im0 (numpy.ndarray): The input image or frame to be processed. + + Returns: + (numpy.ndarray): The processed image with annotations and count information. + + Examples: + >>> counter = ObjectCounter() + >>> frame = cv2.imread("path/to/image.jpg") + >>> processed_frame = counter.count(frame) + """ + if not self.region_initialized: + self.initialize_region() + self.region_initialized = True + + self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator + self.extract_tracks(im0) # Extract tracks + + self.annotator.draw_region( + reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2 + ) # Draw region + + # Iterate over bounding boxes, track ids and classes index + for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss): + # Draw bounding box and counting region + self.annotator.box_label(box, label=self.names[cls], color=colors(cls, True)) + self.store_tracking_history(track_id, box) # Store track history + self.store_classwise_counts(cls) # store classwise counts in dict + + # Draw tracks of objects + self.annotator.draw_centroid_and_tracks( + self.track_line, color=colors(int(cls), True), track_thickness=self.line_width + ) + current_centroid = ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2) + # store previous position of track for object counting + prev_position = None + if len(self.track_history[track_id]) > 1: + prev_position = self.track_history[track_id][-2] + self.count_objects(current_centroid, track_id, prev_position, cls) # Perform object counting + + self.display_counts(im0) # Display the counts on the frame + self.display_output(im0) # display output with base class function + + return im0 # return output image for more usage diff --git a/ultralytics/solutions/parking_management.py b/ultralytics/solutions/parking_management.py new file mode 100644 index 0000000000000000000000000000000000000000..91f91936a32b77beba4bab7576da1cbcef632711 --- /dev/null +++ b/ultralytics/solutions/parking_management.py @@ -0,0 +1,246 @@ +# 
Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import json + +import cv2 +import numpy as np + +from ultralytics.solutions.solutions import BaseSolution +from ultralytics.utils import LOGGER +from ultralytics.utils.checks import check_requirements +from ultralytics.utils.plotting import Annotator + + +class ParkingPtsSelection: + """ + A class for selecting and managing parking zone points on images using a Tkinter-based UI. + + This class provides functionality to upload an image, select points to define parking zones, and save the + selected points to a JSON file. It uses Tkinter for the graphical user interface. + + Attributes: + tk (module): The Tkinter module for GUI operations. + filedialog (module): Tkinter's filedialog module for file selection operations. + messagebox (module): Tkinter's messagebox module for displaying message boxes. + master (tk.Tk): The main Tkinter window. + canvas (tk.Canvas): The canvas widget for displaying the image and drawing bounding boxes. + image (PIL.Image.Image): The uploaded image. + canvas_image (ImageTk.PhotoImage): The image displayed on the canvas. + rg_data (List[List[Tuple[int, int]]]): List of bounding boxes, each defined by 4 points. + current_box (List[Tuple[int, int]]): Temporary storage for the points of the current bounding box. + imgw (int): Original width of the uploaded image. + imgh (int): Original height of the uploaded image. + canvas_max_width (int): Maximum width of the canvas. + canvas_max_height (int): Maximum height of the canvas. + + Methods: + initialize_properties: Initializes the necessary properties. + upload_image: Uploads an image, resizes it to fit the canvas, and displays it. + on_canvas_click: Handles mouse clicks to add points for bounding boxes. + draw_box: Draws a bounding box on the canvas. + remove_last_bounding_box: Removes the last bounding box and redraws the canvas. + redraw_canvas: Redraws the canvas with the image and all bounding boxes. + save_to_json: Saves the bounding boxes to a JSON file. 
+ + Examples: + >>> parking_selector = ParkingPtsSelection() + >>> # Use the GUI to upload an image, select parking zones, and save the data + """ + + def __init__(self): + """Initializes the ParkingPtsSelection class, setting up UI and properties for parking zone point selection.""" + check_requirements("tkinter") + import tkinter as tk + from tkinter import filedialog, messagebox + + self.tk, self.filedialog, self.messagebox = tk, filedialog, messagebox + self.master = self.tk.Tk() # Reference to the main application window or parent widget + self.master.title("Ultralytics Parking Zones Points Selector") + self.master.resizable(False, False) + + self.canvas = self.tk.Canvas(self.master, bg="white") # Canvas widget for displaying images or graphics + self.canvas.pack(side=self.tk.BOTTOM) + + self.image = None # Variable to store the loaded image + self.canvas_image = None # Reference to the image displayed on the canvas + self.canvas_max_width = None # Maximum allowed width for the canvas + self.canvas_max_height = None # Maximum allowed height for the canvas + self.rg_data = None # Data related to region or annotation management + self.current_box = None # Stores the currently selected or active bounding box + self.imgh = None # Height of the current image + self.imgw = None # Width of the current image + + # Button frame with buttons + button_frame = self.tk.Frame(self.master) + button_frame.pack(side=self.tk.TOP) + + for text, cmd in [ + ("Upload Image", self.upload_image), + ("Remove Last BBox", self.remove_last_bounding_box), + ("Save", self.save_to_json), + ]: + self.tk.Button(button_frame, text=text, command=cmd).pack(side=self.tk.LEFT) + + self.initialize_properties() + self.master.mainloop() + + def initialize_properties(self): + """Initialize properties for image, canvas, bounding boxes, and dimensions.""" + self.image = self.canvas_image = None + self.rg_data, self.current_box = [], [] + self.imgw = self.imgh = 0 + self.canvas_max_width, self.canvas_max_height = 1280, 720 + + def upload_image(self): + """Uploads and displays an image on the canvas, resizing it to fit within specified dimensions.""" + from PIL import Image, ImageTk # scope because ImageTk requires tkinter package + + self.image = Image.open(self.filedialog.askopenfilename(filetypes=[("Image Files", "*.png *.jpg *.jpeg")])) + if not self.image: + return + + self.imgw, self.imgh = self.image.size + aspect_ratio = self.imgw / self.imgh + canvas_width = ( + min(self.canvas_max_width, self.imgw) if aspect_ratio > 1 else int(self.canvas_max_height * aspect_ratio) + ) + canvas_height = ( + min(self.canvas_max_height, self.imgh) if aspect_ratio <= 1 else int(canvas_width / aspect_ratio) + ) + + self.canvas.config(width=canvas_width, height=canvas_height) + self.canvas_image = ImageTk.PhotoImage(self.image.resize((canvas_width, canvas_height))) + self.canvas.create_image(0, 0, anchor=self.tk.NW, image=self.canvas_image) + self.canvas.bind("", self.on_canvas_click) + + self.rg_data.clear(), self.current_box.clear() + + def on_canvas_click(self, event): + """Handles mouse clicks to add points for bounding boxes on the canvas.""" + self.current_box.append((event.x, event.y)) + self.canvas.create_oval(event.x - 3, event.y - 3, event.x + 3, event.y + 3, fill="red") + if len(self.current_box) == 4: + self.rg_data.append(self.current_box.copy()) + self.draw_box(self.current_box) + self.current_box.clear() + + def draw_box(self, box): + """Draws a bounding box on the canvas using the provided coordinates.""" + for i in range(4): 
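+            # Connect corner i to corner (i + 1) % 4; the modulo wrap closes the quadrilateral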
+ self.canvas.create_line(box[i], box[(i + 1) % 4], fill="blue", width=2) + + def remove_last_bounding_box(self): + """Removes the last bounding box from the list and redraws the canvas.""" + if not self.rg_data: + self.messagebox.showwarning("Warning", "No bounding boxes to remove.") + return + self.rg_data.pop() + self.redraw_canvas() + + def redraw_canvas(self): + """Redraws the canvas with the image and all bounding boxes.""" + self.canvas.delete("all") + self.canvas.create_image(0, 0, anchor=self.tk.NW, image=self.canvas_image) + for box in self.rg_data: + self.draw_box(box) + + def save_to_json(self): + """Saves the selected parking zone points to a JSON file with scaled coordinates.""" + scale_w, scale_h = self.imgw / self.canvas.winfo_width(), self.imgh / self.canvas.winfo_height() + data = [{"points": [(int(x * scale_w), int(y * scale_h)) for x, y in box]} for box in self.rg_data] + + from io import StringIO # Function level import, as it's only required to store coordinates, not every frame + + write_buffer = StringIO() + json.dump(data, write_buffer, indent=4) + with open("bounding_boxes.json", "w", encoding="utf-8") as f: + f.write(write_buffer.getvalue()) + self.messagebox.showinfo("Success", "Bounding boxes saved to bounding_boxes.json") + + +class ParkingManagement(BaseSolution): + """ + Manages parking occupancy and availability using YOLO model for real-time monitoring and visualization. + + This class extends BaseSolution to provide functionality for parking lot management, including detection of + occupied spaces, visualization of parking regions, and display of occupancy statistics. + + Attributes: + json_file (str): Path to the JSON file containing parking region details. + json (List[Dict]): Loaded JSON data containing parking region information. + pr_info (Dict[str, int]): Dictionary storing parking information (Occupancy and Available spaces). + arc (Tuple[int, int, int]): RGB color tuple for available region visualization. + occ (Tuple[int, int, int]): RGB color tuple for occupied region visualization. + dc (Tuple[int, int, int]): RGB color tuple for centroid visualization of detected objects. + + Methods: + process_data: Processes model data for parking lot management and visualization. + + Examples: + >>> from ultralytics.solutions import ParkingManagement + >>> parking_manager = ParkingManagement(model="yolov8n.pt", json_file="parking_regions.json") + >>> print(f"Occupied spaces: {parking_manager.pr_info['Occupancy']}") + >>> print(f"Available spaces: {parking_manager.pr_info['Available']}") + """ + + def __init__(self, **kwargs): + """Initializes the parking management system with a YOLO model and visualization settings.""" + super().__init__(**kwargs) + + self.json_file = self.CFG["json_file"] # Load JSON data + if self.json_file is None: + LOGGER.warning("❌ json_file argument missing. Parking region details required.") + raise ValueError("❌ Json file path can not be empty") + + with open(self.json_file) as f: + self.json = json.load(f) + + self.pr_info = {"Occupancy": 0, "Available": 0} # dictionary for parking information + + self.arc = (0, 0, 255) # available region color + self.occ = (0, 255, 0) # occupied region color + self.dc = (255, 0, 189) # centroid color for each box + + def process_data(self, im0): + """ + Processes the model data for parking lot management. + + This function analyzes the input image, extracts tracks, and determines the occupancy status of parking + regions defined in the JSON file. 
It annotates the image with occupied and available parking spots, + and updates the parking information. + + Args: + im0 (np.ndarray): The input inference image. + + Examples: + >>> parking_manager = ParkingManagement(json_file="parking_regions.json") + >>> image = cv2.imread("parking_lot.jpg") + >>> parking_manager.process_data(image) + """ + self.extract_tracks(im0) # extract tracks from im0 + es, fs = len(self.json), 0 # empty slots, filled slots + annotator = Annotator(im0, self.line_width) # init annotator + + for region in self.json: + # Convert points to a NumPy array with the correct dtype and reshape properly + pts_array = np.array(region["points"], dtype=np.int32).reshape((-1, 1, 2)) + rg_occupied = False # occupied region initialization + for box, cls in zip(self.boxes, self.clss): + xc, yc = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2) + dist = cv2.pointPolygonTest(pts_array, (xc, yc), False) + if dist >= 0: + # cv2.circle(im0, (xc, yc), radius=self.line_width * 4, color=self.dc, thickness=-1) + annotator.display_objects_labels( + im0, self.model.names[int(cls)], (104, 31, 17), (255, 255, 255), xc, yc, 10 + ) + rg_occupied = True + break + fs, es = (fs + 1, es - 1) if rg_occupied else (fs, es) + # Plotting regions + cv2.polylines(im0, [pts_array], isClosed=True, color=self.occ if rg_occupied else self.arc, thickness=2) + + self.pr_info["Occupancy"], self.pr_info["Available"] = fs, es + + annotator.display_analytics(im0, self.pr_info, (104, 31, 17), (255, 255, 255), 10) + self.display_output(im0) # display output with base class function + return im0 # return output image for more usage diff --git a/ultralytics/solutions/queue_management.py b/ultralytics/solutions/queue_management.py new file mode 100644 index 0000000000000000000000000000000000000000..4fcf8fa7103c08063b36344f7b77eace7157f4dc --- /dev/null +++ b/ultralytics/solutions/queue_management.py @@ -0,0 +1,112 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.solutions.solutions import BaseSolution +from ultralytics.utils.plotting import Annotator, colors + + +class QueueManager(BaseSolution): + """ + Manages queue counting in real-time video streams based on object tracks. + + This class extends BaseSolution to provide functionality for tracking and counting objects within a specified + region in video frames. + + Attributes: + counts (int): The current count of objects in the queue. + rect_color (Tuple[int, int, int]): RGB color tuple for drawing the queue region rectangle. + region_length (int): The number of points defining the queue region. + annotator (Annotator): An instance of the Annotator class for drawing on frames. + track_line (List[Tuple[int, int]]): List of track line coordinates. + track_history (Dict[int, List[Tuple[int, int]]]): Dictionary storing tracking history for each object. + + Methods: + initialize_region: Initializes the queue region. + process_queue: Processes a single frame for queue management. + extract_tracks: Extracts object tracks from the current frame. + store_tracking_history: Stores the tracking history for an object. + display_output: Displays the processed output. 
+ + Examples: + >>> cap = cv2.VideoCapture("Path/to/video/file.mp4") + >>> queue_manager = QueueManager(region=[100, 100, 200, 200, 300, 300]) + >>> while cap.isOpened(): + >>> success, im0 = cap.read() + >>> if not success: + >>> break + >>> out = queue.process_queue(im0) + """ + + def __init__(self, **kwargs): + """Initializes the QueueManager with parameters for tracking and counting objects in a video stream.""" + super().__init__(**kwargs) + self.initialize_region() + self.counts = 0 # Queue counts Information + self.rect_color = (255, 255, 255) # Rectangle color + self.region_length = len(self.region) # Store region length for further usage + + def process_queue(self, im0): + """ + Processes the queue management for a single frame of video. + + Args: + im0 (numpy.ndarray): Input image for processing, typically a frame from a video stream. + + Returns: + (numpy.ndarray): Processed image with annotations, bounding boxes, and queue counts. + + This method performs the following steps: + 1. Resets the queue count for the current frame. + 2. Initializes an Annotator object for drawing on the image. + 3. Extracts tracks from the image. + 4. Draws the counting region on the image. + 5. For each detected object: + - Draws bounding boxes and labels. + - Stores tracking history. + - Draws centroids and tracks. + - Checks if the object is inside the counting region and updates the count. + 6. Displays the queue count on the image. + 7. Displays the processed output. + + Examples: + >>> queue_manager = QueueManager() + >>> frame = cv2.imread("frame.jpg") + >>> processed_frame = queue_manager.process_queue(frame) + """ + self.counts = 0 # Reset counts every frame + self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator + self.extract_tracks(im0) # Extract tracks + + self.annotator.draw_region( + reg_pts=self.region, color=self.rect_color, thickness=self.line_width * 2 + ) # Draw region + + for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss): + # Draw bounding box and counting region + self.annotator.box_label(box, label=self.names[cls], color=colors(track_id, True)) + self.store_tracking_history(track_id, box) # Store track history + + # Draw tracks of objects + self.annotator.draw_centroid_and_tracks( + self.track_line, color=colors(int(track_id), True), track_thickness=self.line_width + ) + + # Cache frequently accessed attributes + track_history = self.track_history.get(track_id, []) + + # store previous position of track and check if the object is inside the counting region + prev_position = None + if len(track_history) > 1: + prev_position = track_history[-2] + if self.region_length >= 3 and prev_position and self.r_s.contains(self.Point(self.track_line[-1])): + self.counts += 1 + + # Display queue counts + self.annotator.queue_counts_display( + f"Queue Counts : {str(self.counts)}", + points=self.region, + region_color=self.rect_color, + txt_color=(104, 31, 17), + ) + self.display_output(im0) # display output with base class function + + return im0 # return output image for more usage diff --git a/ultralytics/solutions/region_counter.py b/ultralytics/solutions/region_counter.py new file mode 100644 index 0000000000000000000000000000000000000000..5a2953f3c615133211d1d56c620af11a106cb138 --- /dev/null +++ b/ultralytics/solutions/region_counter.py @@ -0,0 +1,116 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.solutions.solutions import BaseSolution +from ultralytics.utils import LOGGER +from 
ultralytics.utils.plotting import Annotator, colors + + +class RegionCounter(BaseSolution): + """ + A class designed for real-time counting of objects within user-defined regions in a video stream. + + This class inherits from `BaseSolution` and offers functionalities to define polygonal regions in a video + frame, track objects, and count those objects that pass through each defined region. This makes it useful + for applications that require counting in specified areas, such as monitoring zones or segmented sections. + + Attributes: + region_template (dict): A template for creating new counting regions with default attributes including + the name, polygon coordinates, and display colors. + counting_regions (list): A list storing all defined regions, where each entry is based on `region_template` + and includes specific region settings like name, coordinates, and color. + + Methods: + add_region: Adds a new counting region with specified attributes, such as the region's name, polygon points, + region color, and text color. + count: Processes video frames to count objects in each region, drawing regions and displaying counts + on the frame. Handles object detection, region definition, and containment checks. + """ + + def __init__(self, **kwargs): + """Initializes the RegionCounter class for real-time counting in different regions of the video streams.""" + super().__init__(**kwargs) + self.region_template = { + "name": "Default Region", + "polygon": None, + "counts": 0, + "dragging": False, + "region_color": (255, 255, 255), + "text_color": (0, 0, 0), + } + self.counting_regions = [] + + def add_region(self, name, polygon_points, region_color, text_color): + """ + Adds a new region to the counting list based on the provided template with specific attributes. + + Args: + name (str): Name assigned to the new region. + polygon_points (list[tuple]): List of (x, y) coordinates defining the region's polygon. + region_color (tuple): BGR color for region visualization. + text_color (tuple): BGR color for the text within the region. + """ + region = self.region_template.copy() + region.update( + { + "name": name, + "polygon": self.Polygon(polygon_points), + "region_color": region_color, + "text_color": text_color, + } + ) + self.counting_regions.append(region) + + def count(self, im0): + """ + Processes the input frame to detect and count objects within each defined region. + + Args: + im0 (numpy.ndarray): Input image frame where objects and regions are annotated. + + Returns: + im0 (numpy.ndarray): Processed image frame with annotated counting information. 
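+
+        Examples:
+            >>> counter = RegionCounter()
+            >>> frame = cv2.imread("frame.jpg")
+            >>> processed_frame = counter.count(frame)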
+ """ + self.annotator = Annotator(im0, line_width=self.line_width) + self.extract_tracks(im0) + + # Region initialization and conversion + if self.region is None: + self.initialize_region() + regions = {"Region#01": self.region} + else: + regions = self.region if isinstance(self.region, dict) else {"Region#01": self.region} + + # Draw regions and process counts for each defined area + for idx, (region_name, reg_pts) in enumerate(regions.items(), start=1): + if not isinstance(reg_pts, list) or not all(isinstance(pt, tuple) for pt in reg_pts): + LOGGER.warning(f"Invalid region points for {region_name}: {reg_pts}") + continue # Skip invalid entries + color = colors(idx, True) + self.annotator.draw_region(reg_pts=reg_pts, color=color, thickness=self.line_width * 2) + self.add_region(region_name, reg_pts, color, self.annotator.get_txt_color()) + + # Prepare regions for containment check + for region in self.counting_regions: + region["prepared_polygon"] = self.prep(region["polygon"]) + + # Process bounding boxes and count objects within each region + for box, cls in zip(self.boxes, self.clss): + self.annotator.box_label(box, label=self.names[cls], color=colors(cls, True)) + bbox_center = ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2) + + for region in self.counting_regions: + if region["prepared_polygon"].contains(self.Point(bbox_center)): + region["counts"] += 1 + + # Display counts in each region + for region in self.counting_regions: + self.annotator.text_label( + region["polygon"].bounds, + label=str(region["counts"]), + color=region["region_color"], + txt_color=region["text_color"], + ) + region["counts"] = 0 # Reset count for next frame + + self.display_output(im0) + return im0 diff --git a/ultralytics/solutions/security_alarm.py b/ultralytics/solutions/security_alarm.py new file mode 100644 index 0000000000000000000000000000000000000000..e07119bc5bdb76e0d92436dc1fec4047be79fe0a --- /dev/null +++ b/ultralytics/solutions/security_alarm.py @@ -0,0 +1,144 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.solutions.solutions import BaseSolution +from ultralytics.utils import LOGGER +from ultralytics.utils.plotting import Annotator, colors + + +class SecurityAlarm(BaseSolution): + """ + A class to manage security alarm functionalities for real-time monitoring. + + This class extends the BaseSolution class and provides features to monitor + objects in a frame, send email notifications when specific thresholds are + exceeded for total detections, and annotate the output frame for visualization. + + Attributes: + email_sent (bool): Flag to track if an email has already been sent for the current event. + records (int): Threshold for the number of detected objects to trigger an alert. + + Methods: + authenticate: Sets up email server authentication for sending alerts. + send_email: Sends an email notification with details and an image attachment. + monitor: Monitors the frame, processes detections, and triggers alerts if thresholds are crossed. 
+ + Examples: + >>> security = SecurityAlarm() + >>> security.authenticate("abc@gmail.com", "1111222233334444", "xyz@gmail.com") + >>> frame = cv2.imread("frame.jpg") + >>> processed_frame = security.monitor(frame) + """ + + def __init__(self, **kwargs): + """Initializes the SecurityAlarm class with parameters for real-time object monitoring.""" + super().__init__(**kwargs) + self.email_sent = False + self.records = self.CFG["records"] + self.server = None + self.to_email = "" + self.from_email = "" + + def authenticate(self, from_email, password, to_email): + """ + Authenticates the email server for sending alert notifications. + + Args: + from_email (str): Sender's email address. + password (str): Password for the sender's email account. + to_email (str): Recipient's email address. + + This method initializes a secure connection with the SMTP server + and logs in using the provided credentials. + + Examples: + >>> alarm = SecurityAlarm() + >>> alarm.authenticate("sender@example.com", "password123", "recipient@example.com") + """ + import smtplib + + self.server = smtplib.SMTP("smtp.gmail.com: 587") + self.server.starttls() + self.server.login(from_email, password) + self.to_email = to_email + self.from_email = from_email + + def send_email(self, im0, records=5): + """ + Sends an email notification with an image attachment indicating the number of objects detected. + + Args: + im0 (numpy.ndarray): The input image or frame to be attached to the email. + records (int): The number of detected objects to be included in the email message. + + This method encodes the input image, composes the email message with + details about the detection, and sends it to the specified recipient. + + Examples: + >>> alarm = SecurityAlarm() + >>> frame = cv2.imread("path/to/image.jpg") + >>> alarm.send_email(frame, records=10) + """ + from email.mime.image import MIMEImage + from email.mime.multipart import MIMEMultipart + from email.mime.text import MIMEText + + import cv2 + + img_bytes = cv2.imencode(".jpg", im0)[1].tobytes() # Encode the image as JPEG + + # Create the email + message = MIMEMultipart() + message["From"] = self.from_email + message["To"] = self.to_email + message["Subject"] = "Security Alert" + + # Add the text message body + message_body = f"Ultralytics ALERT!!! {records} objects have been detected!!" + message.attach(MIMEText(message_body)) + + # Attach the image + image_attachment = MIMEImage(img_bytes, name="ultralytics.jpg") + message.attach(image_attachment) + + # Send the email + try: + self.server.send_message(message) + LOGGER.info("✅ Email sent successfully!") + except Exception as e: + print(f"❌ Failed to send email: {e}") + + def monitor(self, im0): + """ + Monitors the frame, processes object detections, and triggers alerts if thresholds are exceeded. + + Args: + im0 (numpy.ndarray): The input image or frame to be processed and annotated. + + This method processes the input frame, extracts detections, annotates the frame + with bounding boxes, and sends an email notification if the number of detected objects + surpasses the specified threshold and an alert has not already been sent. + + Returns: + (numpy.ndarray): The processed frame with annotations. 
+ + Examples: + >>> alarm = SecurityAlarm() + >>> frame = cv2.imread("path/to/image.jpg") + >>> processed_frame = alarm.monitor(frame) + """ + self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator + self.extract_tracks(im0) # Extract tracks + + # Iterate over bounding boxes, track ids and classes index + for box, cls in zip(self.boxes, self.clss): + # Draw bounding box + self.annotator.box_label(box, label=self.names[cls], color=colors(cls, True)) + + total_det = len(self.clss) + if total_det > self.records and not self.email_sent: # Only send email If not sent before + self.send_email(im0, total_det) + self.email_sent = True + + self.display_output(im0) # display output with base class function + + return im0 # return output image for more usage diff --git a/ultralytics/solutions/solutions.py b/ultralytics/solutions/solutions.py new file mode 100644 index 0000000000000000000000000000000000000000..8b526f4336fdf0140bbc87456e75494f90b039f6 --- /dev/null +++ b/ultralytics/solutions/solutions.py @@ -0,0 +1,178 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from collections import defaultdict + +import cv2 + +from ultralytics import YOLO +from ultralytics.utils import ASSETS_URL, DEFAULT_CFG_DICT, DEFAULT_SOL_DICT, LOGGER +from ultralytics.utils.checks import check_imshow, check_requirements + + +class BaseSolution: + """ + A base class for managing Ultralytics Solutions. + + This class provides core functionality for various Ultralytics Solutions, including model loading, object tracking, + and region initialization. + + Attributes: + LineString (shapely.geometry.LineString): Class for creating line string geometries. + Polygon (shapely.geometry.Polygon): Class for creating polygon geometries. + Point (shapely.geometry.Point): Class for creating point geometries. + CFG (Dict): Configuration dictionary loaded from a YAML file and updated with kwargs. + region (List[Tuple[int, int]]): List of coordinate tuples defining a region of interest. + line_width (int): Width of lines used in visualizations. + model (ultralytics.YOLO): Loaded YOLO model instance. + names (Dict[int, str]): Dictionary mapping class indices to class names. + env_check (bool): Flag indicating whether the environment supports image display. + track_history (collections.defaultdict): Dictionary to store tracking history for each object. + + Methods: + extract_tracks: Apply object tracking and extract tracks from an input image. + store_tracking_history: Store object tracking history for a given track ID and bounding box. + initialize_region: Initialize the counting region and line segment based on configuration. + display_output: Display the results of processing, including showing frames or saving results. + + Examples: + >>> solution = BaseSolution(model="yolov8n.pt", region=[(0, 0), (100, 0), (100, 100), (0, 100)]) + >>> solution.initialize_region() + >>> image = cv2.imread("image.jpg") + >>> solution.extract_tracks(image) + >>> solution.display_output(image) + """ + + def __init__(self, IS_CLI=False, **kwargs): + """ + Initializes the `BaseSolution` class with configuration settings and the YOLO model for Ultralytics solutions. + + IS_CLI (optional): Enables CLI mode if set. 
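+        **kwargs (Any): Additional configuration overrides (for example model, region, line_width, show) that are
+            merged into the default solution and model settings.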
+ """ + check_requirements("shapely>=2.0.0") + from shapely.geometry import LineString, Point, Polygon + from shapely.prepared import prep + + self.LineString = LineString + self.Polygon = Polygon + self.Point = Point + self.prep = prep + self.annotator = None # Initialize annotator + self.tracks = None + self.track_data = None + self.boxes = [] + self.clss = [] + self.track_ids = [] + self.track_line = None + self.r_s = None + + # Load config and update with args + DEFAULT_SOL_DICT.update(kwargs) + DEFAULT_CFG_DICT.update(kwargs) + self.CFG = {**DEFAULT_SOL_DICT, **DEFAULT_CFG_DICT} + LOGGER.info(f"Ultralytics Solutions: ✅ {DEFAULT_SOL_DICT}") + + self.region = self.CFG["region"] # Store region data for other classes usage + self.line_width = ( + self.CFG["line_width"] if self.CFG["line_width"] is not None else 2 + ) # Store line_width for usage + + # Load Model and store classes names + if self.CFG["model"] is None: + self.CFG["model"] = "yolo11n.pt" + self.model = YOLO(self.CFG["model"]) + self.names = self.model.names + + self.track_add_args = { # Tracker additional arguments for advance configuration + k: self.CFG[k] for k in ["verbose", "iou", "conf", "device", "max_det", "half", "tracker"] + } + + if IS_CLI and self.CFG["source"] is None: + d_s = "solutions_ci_demo.mp4" if "-pose" not in self.CFG["model"] else "solution_ci_pose_demo.mp4" + LOGGER.warning(f"⚠️ WARNING: source not provided. using default source {ASSETS_URL}/{d_s}") + from ultralytics.utils.downloads import safe_download + + safe_download(f"{ASSETS_URL}/{d_s}") # download source from ultralytics assets + self.CFG["source"] = d_s # set default source + + # Initialize environment and region setup + self.env_check = check_imshow(warn=True) + self.track_history = defaultdict(list) + + def extract_tracks(self, im0): + """ + Applies object tracking and extracts tracks from an input image or frame. + + Args: + im0 (ndarray): The input image or frame. + + Examples: + >>> solution = BaseSolution() + >>> frame = cv2.imread("path/to/image.jpg") + >>> solution.extract_tracks(frame) + """ + self.tracks = self.model.track(source=im0, persist=True, classes=self.CFG["classes"], **self.track_add_args) + + # Extract tracks for OBB or object detection + self.track_data = self.tracks[0].obb or self.tracks[0].boxes + + if self.track_data and self.track_data.id is not None: + self.boxes = self.track_data.xyxy.cpu() + self.clss = self.track_data.cls.cpu().tolist() + self.track_ids = self.track_data.id.int().cpu().tolist() + else: + LOGGER.warning("WARNING ⚠️ no tracks found!") + self.boxes, self.clss, self.track_ids = [], [], [] + + def store_tracking_history(self, track_id, box): + """ + Stores the tracking history of an object. + + This method updates the tracking history for a given object by appending the center point of its + bounding box to the track line. It maintains a maximum of 30 points in the tracking history. + + Args: + track_id (int): The unique identifier for the tracked object. + box (List[float]): The bounding box coordinates of the object in the format [x1, y1, x2, y2]. 
+ + Examples: + >>> solution = BaseSolution() + >>> solution.store_tracking_history(1, [100, 200, 300, 400]) + """ + # Store tracking history + self.track_line = self.track_history[track_id] + self.track_line.append(((box[0] + box[2]) / 2, (box[1] + box[3]) / 2)) + if len(self.track_line) > 30: + self.track_line.pop(0) + + def initialize_region(self): + """Initialize the counting region and line segment based on configuration settings.""" + if self.region is None: + self.region = [(20, 400), (1080, 400), (1080, 360), (20, 360)] + self.r_s = ( + self.Polygon(self.region) if len(self.region) >= 3 else self.LineString(self.region) + ) # region or line + + def display_output(self, im0): + """ + Display the results of the processing, which could involve showing frames, printing counts, or saving results. + + This method is responsible for visualizing the output of the object detection and tracking process. It displays + the processed frame with annotations, and allows for user interaction to close the display. + + Args: + im0 (numpy.ndarray): The input image or frame that has been processed and annotated. + + Examples: + >>> solution = BaseSolution() + >>> frame = cv2.imread("path/to/image.jpg") + >>> solution.display_output(frame) + + Notes: + - This method will only display output if the 'show' configuration is set to True and the environment + supports image display. + - The display can be closed by pressing the 'q' key. + """ + if self.CFG.get("show") and self.env_check: + cv2.imshow("Ultralytics Solutions", im0) + if cv2.waitKey(1) & 0xFF == ord("q"): + return diff --git a/ultralytics/solutions/speed_estimation.py b/ultralytics/solutions/speed_estimation.py new file mode 100644 index 0000000000000000000000000000000000000000..43eaceceb2732638347c98ffef0195837a64b824 --- /dev/null +++ b/ultralytics/solutions/speed_estimation.py @@ -0,0 +1,110 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from time import time + +import numpy as np + +from ultralytics.solutions.solutions import BaseSolution +from ultralytics.utils.plotting import Annotator, colors + + +class SpeedEstimator(BaseSolution): + """ + A class to estimate the speed of objects in a real-time video stream based on their tracks. + + This class extends the BaseSolution class and provides functionality for estimating object speeds using + tracking data in video streams. + + Attributes: + spd (Dict[int, float]): Dictionary storing speed data for tracked objects. + trkd_ids (List[int]): List of tracked object IDs that have already been speed-estimated. + trk_pt (Dict[int, float]): Dictionary storing previous timestamps for tracked objects. + trk_pp (Dict[int, Tuple[float, float]]): Dictionary storing previous positions for tracked objects. + annotator (Annotator): Annotator object for drawing on images. + region (List[Tuple[int, int]]): List of points defining the speed estimation region. + track_line (List[Tuple[float, float]]): List of points representing the object's track. + r_s (LineString): LineString object representing the speed estimation region. + + Methods: + initialize_region: Initializes the speed estimation region. + estimate_speed: Estimates the speed of objects based on tracking data. + store_tracking_history: Stores the tracking history for an object. + extract_tracks: Extracts tracks from the current frame. + display_output: Displays the output with annotations. 
+ + Examples: + >>> estimator = SpeedEstimator() + >>> frame = cv2.imread("frame.jpg") + >>> processed_frame = estimator.estimate_speed(frame) + >>> cv2.imshow("Speed Estimation", processed_frame) + """ + + def __init__(self, **kwargs): + """Initializes the SpeedEstimator object with speed estimation parameters and data structures.""" + super().__init__(**kwargs) + + self.initialize_region() # Initialize speed region + + self.spd = {} # set for speed data + self.trkd_ids = [] # list for already speed_estimated and tracked ID's + self.trk_pt = {} # set for tracks previous time + self.trk_pp = {} # set for tracks previous point + + def estimate_speed(self, im0): + """ + Estimates the speed of objects based on tracking data. + + Args: + im0 (np.ndarray): Input image for processing. Shape is typically (H, W, C) for RGB images. + + Returns: + (np.ndarray): Processed image with speed estimations and annotations. + + Examples: + >>> estimator = SpeedEstimator() + >>> image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) + >>> processed_image = estimator.estimate_speed(image) + """ + self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator + self.extract_tracks(im0) # Extract tracks + + self.annotator.draw_region( + reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2 + ) # Draw region + + for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss): + self.store_tracking_history(track_id, box) # Store track history + + # Check if track_id is already in self.trk_pp or trk_pt initialize if not + if track_id not in self.trk_pt: + self.trk_pt[track_id] = 0 + if track_id not in self.trk_pp: + self.trk_pp[track_id] = self.track_line[-1] + + speed_label = f"{int(self.spd[track_id])} km/h" if track_id in self.spd else self.names[int(cls)] + self.annotator.box_label(box, label=speed_label, color=colors(track_id, True)) # Draw bounding box + + # Draw tracks of objects + self.annotator.draw_centroid_and_tracks( + self.track_line, color=colors(int(track_id), True), track_thickness=self.line_width + ) + + # Calculate object speed and direction based on region intersection + if self.LineString([self.trk_pp[track_id], self.track_line[-1]]).intersects(self.r_s): + direction = "known" + else: + direction = "unknown" + + # Perform speed calculation and tracking updates if direction is valid + if direction == "known" and track_id not in self.trkd_ids: + self.trkd_ids.append(track_id) + time_difference = time() - self.trk_pt[track_id] + if time_difference > 0: + self.spd[track_id] = np.abs(self.track_line[-1][1] - self.trk_pp[track_id][1]) / time_difference + + self.trk_pt[track_id] = time() + self.trk_pp[track_id] = self.track_line[-1] + + self.display_output(im0) # display output with base class function + + return im0 # return output image for more usage diff --git a/ultralytics/solutions/streamlit_inference.py b/ultralytics/solutions/streamlit_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..50cc2584095d25c799d48314011741f1bef432ad --- /dev/null +++ b/ultralytics/solutions/streamlit_inference.py @@ -0,0 +1,190 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import io +from typing import Any + +import cv2 + +from ultralytics import YOLO +from ultralytics.utils import LOGGER +from ultralytics.utils.checks import check_requirements +from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS + + +class Inference: + """ + A class to perform object detection, image classification, image 
segmentation and pose estimation inference using + Streamlit and Ultralytics YOLO models. It provides the functionalities such as loading models, configuring settings, + uploading video files, and performing real-time inference. + + Attributes: + st (module): Streamlit module for UI creation. + temp_dict (dict): Temporary dictionary to store the model path. + model_path (str): Path to the loaded model. + model (YOLO): The YOLO model instance. + source (str): Selected video source. + enable_trk (str): Enable tracking option. + conf (float): Confidence threshold. + iou (float): IoU threshold for non-max suppression. + vid_file_name (str): Name of the uploaded video file. + selected_ind (list): List of selected class indices. + + Methods: + web_ui: Sets up the Streamlit web interface with custom HTML elements. + sidebar: Configures the Streamlit sidebar for model and inference settings. + source_upload: Handles video file uploads through the Streamlit interface. + configure: Configures the model and loads selected classes for inference. + inference: Performs real-time object detection inference. + + Examples: + >>> inf = solutions.Inference(model="path/to/model.pt") # Model is not necessary argument. + >>> inf.inference() + """ + + def __init__(self, **kwargs: Any): + """ + Initializes the Inference class, checking Streamlit requirements and setting up the model path. + + Args: + **kwargs (Any): Additional keyword arguments for model configuration. + """ + check_requirements("streamlit>=1.29.0") # scope imports for faster ultralytics package load speeds + import streamlit as st + + self.st = st # Reference to the Streamlit class instance + self.source = None # Placeholder for video or webcam source details + self.enable_trk = False # Flag to toggle object tracking + self.conf = 0.25 # Confidence threshold for detection + self.iou = 0.45 # Intersection-over-Union (IoU) threshold for non-maximum suppression + self.org_frame = None # Container for the original frame to be displayed + self.ann_frame = None # Container for the annotated frame to be displayed + self.vid_file_name = None # Holds the name of the video file + self.selected_ind = [] # List of selected classes for detection or tracking + self.model = None # Container for the loaded model instance + + self.temp_dict = {"model": None, **kwargs} + self.model_path = None # Store model file name with path + if self.temp_dict["model"] is not None: + self.model_path = self.temp_dict["model"] + + LOGGER.info(f"Ultralytics Solutions: ✅ {self.temp_dict}") + + def web_ui(self): + """Sets up the Streamlit web interface with custom HTML elements.""" + menu_style_cfg = """""" # Hide main menu style + + # Main title of streamlit application + main_title_cfg = """

<div><h1 style="color:#FF64DA; text-align:center; font-size:40px; margin-top:-50px; font-family: 'Archivo', sans-serif; margin-bottom:20px;">Ultralytics YOLO Streamlit Application</h1></div>

""" + + # Subtitle of streamlit application + sub_title_cfg = """

<div><h4 style="color:#042AFF; text-align:center; font-family: 'Archivo', sans-serif; margin-top:-15px; margin-bottom:50px;">Experience real-time object detection on your webcam with the power of Ultralytics YOLO! 🚀</h4></div>

""" + + # Set html page configuration and append custom HTML + self.st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide") + self.st.markdown(menu_style_cfg, unsafe_allow_html=True) + self.st.markdown(main_title_cfg, unsafe_allow_html=True) + self.st.markdown(sub_title_cfg, unsafe_allow_html=True) + + def sidebar(self): + """Configures the Streamlit sidebar for model and inference settings.""" + with self.st.sidebar: # Add Ultralytics LOGO + logo = "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg" + self.st.image(logo, width=250) + + self.st.sidebar.title("User Configuration") # Add elements to vertical setting menu + self.source = self.st.sidebar.selectbox( + "Video", + ("webcam", "video"), + ) # Add source selection dropdown + self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No")) # Enable object tracking + self.conf = float( + self.st.sidebar.slider("Confidence Threshold", 0.0, 1.0, self.conf, 0.01) + ) # Slider for confidence + self.iou = float(self.st.sidebar.slider("IoU Threshold", 0.0, 1.0, self.iou, 0.01)) # Slider for NMS threshold + + col1, col2 = self.st.columns(2) + self.org_frame = col1.empty() + self.ann_frame = col2.empty() + + def source_upload(self): + """Handles video file uploads through the Streamlit interface.""" + self.vid_file_name = "" + if self.source == "video": + vid_file = self.st.sidebar.file_uploader("Upload Video File", type=["mp4", "mov", "avi", "mkv"]) + if vid_file is not None: + g = io.BytesIO(vid_file.read()) # BytesIO Object + with open("ultralytics.mp4", "wb") as out: # Open temporary file as bytes + out.write(g.read()) # Read bytes into file + self.vid_file_name = "ultralytics.mp4" + elif self.source == "webcam": + self.vid_file_name = 0 + + def configure(self): + """Configures the model and loads selected classes for inference.""" + # Add dropdown menu for model selection + available_models = [x.replace("yolo", "YOLO") for x in GITHUB_ASSETS_STEMS if x.startswith("yolo11")] + if self.model_path: # If user provided the custom model, insert model without suffix as *.pt is added later + available_models.insert(0, self.model_path.split(".pt")[0]) + selected_model = self.st.sidebar.selectbox("Model", available_models) + + with self.st.spinner("Model is downloading..."): + self.model = YOLO(f"{selected_model.lower()}.pt") # Load the YOLO model + class_names = list(self.model.names.values()) # Convert dictionary to list of class names + self.st.success("Model loaded successfully!") + + # Multiselect box with class names and get indices of selected classes + selected_classes = self.st.sidebar.multiselect("Classes", class_names, default=class_names[:3]) + self.selected_ind = [class_names.index(option) for option in selected_classes] + + if not isinstance(self.selected_ind, list): # Ensure selected_options is a list + self.selected_ind = list(self.selected_ind) + + def inference(self): + """Performs real-time object detection inference.""" + self.web_ui() # Initialize the web interface + self.sidebar() # Create the sidebar + self.source_upload() # Upload the video source + self.configure() # Configure the app + + if self.st.sidebar.button("Start"): + stop_button = self.st.button("Stop") # Button to stop the inference + cap = cv2.VideoCapture(self.vid_file_name) # Capture the video + if not cap.isOpened(): + self.st.error("Could not open webcam.") + while cap.isOpened(): + success, frame = cap.read() + if not success: + self.st.warning("Failed to read frame from webcam. 
Please verify the webcam is connected properly.") + break + + # Store model predictions + if self.enable_trk == "Yes": + results = self.model.track( + frame, conf=self.conf, iou=self.iou, classes=self.selected_ind, persist=True + ) + else: + results = self.model(frame, conf=self.conf, iou=self.iou, classes=self.selected_ind) + annotated_frame = results[0].plot() # Add annotations on frame + + if stop_button: + cap.release() # Release the capture + self.st.stop() # Stop streamlit app + + self.org_frame.image(frame, channels="BGR") # Display original frame + self.ann_frame.image(annotated_frame, channels="BGR") # Display processed frame + + cap.release() # Release the capture + cv2.destroyAllWindows() # Destroy window + + +if __name__ == "__main__": + import sys # Import the sys module for accessing command-line arguments + + # Check if a model name is provided as a command-line argument + args = len(sys.argv) + model = sys.argv[1] if args > 1 else None # assign first argument as the model name + # Create an instance of the Inference class and run inference + Inference(model=model).inference() diff --git a/ultralytics/solutions/trackzone.py b/ultralytics/solutions/trackzone.py new file mode 100644 index 0000000000000000000000000000000000000000..b1d32f2d8e58f27ae661c67eb86ee787f778554d --- /dev/null +++ b/ultralytics/solutions/trackzone.py @@ -0,0 +1,68 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import cv2 +import numpy as np + +from ultralytics.solutions.solutions import BaseSolution +from ultralytics.utils.plotting import Annotator, colors + + +class TrackZone(BaseSolution): + """ + A class to manage region-based object tracking in a video stream. + + This class extends the BaseSolution class and provides functionality for tracking objects within a specific region + defined by a polygonal area. Objects outside the region are excluded from tracking. It supports dynamic initialization + of the region, allowing either a default region or a user-specified polygon. + + Attributes: + region (ndarray): The polygonal region for tracking, represented as a convex hull. + + Methods: + trackzone: Processes each frame of the video, applying region-based tracking. + + Examples: + >>> tracker = TrackZone() + >>> frame = cv2.imread("frame.jpg") + >>> processed_frame = tracker.trackzone(frame) + >>> cv2.imshow("Tracked Frame", processed_frame) + """ + + def __init__(self, **kwargs): + """Initializes the TrackZone class for tracking objects within a defined region in video streams.""" + super().__init__(**kwargs) + default_region = [(150, 150), (1130, 150), (1130, 570), (150, 570)] + self.region = cv2.convexHull(np.array(self.region or default_region, dtype=np.int32)) + + def trackzone(self, im0): + """ + Processes the input frame to track objects within a defined region. + + This method initializes the annotator, creates a mask for the specified region, extracts tracks + only from the masked area, and updates tracking information. Objects outside the region are ignored. + + Args: + im0 (numpy.ndarray): The input image or frame to be processed. + + Returns: + (numpy.ndarray): The processed image with tracking id and bounding boxes annotations. 
+ + Examples: + >>> tracker = TrackZone() + >>> frame = cv2.imread("path/to/image.jpg") + >>> tracker.trackzone(frame) + """ + self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator + # Create a mask for the region and extract tracks from the masked image + masked_frame = cv2.bitwise_and(im0, im0, mask=cv2.fillPoly(np.zeros_like(im0[:, :, 0]), [self.region], 255)) + self.extract_tracks(masked_frame) + + cv2.polylines(im0, [self.region], isClosed=True, color=(255, 255, 255), thickness=self.line_width * 2) + + # Iterate over boxes, track ids, classes indexes list and draw bounding boxes + for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss): + self.annotator.box_label(box, label=f"{self.names[cls]}:{track_id}", color=colors(track_id, True)) + + self.display_output(im0) # display output with base class function + + return im0 # return output image for more usage diff --git a/ultralytics/trackers/README.md b/ultralytics/trackers/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3743d5374c54f75ec33699d45a9b70fde9ec1e54 --- /dev/null +++ b/ultralytics/trackers/README.md @@ -0,0 +1,314 @@ +# Multi-Object Tracking with Ultralytics YOLO + +YOLOv8 trackers visualization + +Object tracking in the realm of video analytics is a critical task that not only identifies the location and class of objects within the frame but also maintains a unique ID for each detected object as the video progresses. The applications are limitless—ranging from surveillance and security to real-time sports analytics. + +## Why Choose Ultralytics YOLO for Object Tracking? + +The output from Ultralytics trackers is consistent with standard object detection but has the added value of object IDs. This makes it easy to track objects in video streams and perform subsequent analytics. Here's why you should consider using Ultralytics YOLO for your object tracking needs: + +- **Efficiency:** Process video streams in real-time without compromising accuracy. +- **Flexibility:** Supports multiple tracking algorithms and configurations. +- **Ease of Use:** Simple Python API and CLI options for quick integration and deployment. +- **Customizability:** Easy to use with custom trained YOLO models, allowing integration into domain-specific applications. + +**Video Tutorial:** [Object Detection and Tracking with Ultralytics YOLO](https://www.youtube.com/embed/hHyHmOtmEgs?si=VNZtXmm45Nb9s-N-). + +## Features at a Glance + +Ultralytics YOLO extends its object detection features to provide robust and versatile object tracking: + +- **Real-Time Tracking:** Seamlessly track objects in high-frame-rate videos. +- **Multiple Tracker Support:** Choose from a variety of established tracking algorithms. +- **Customizable Tracker Configurations:** Tailor the tracking algorithm to meet specific requirements by adjusting various parameters. + +## Available Trackers + +Ultralytics YOLO supports the following tracking algorithms. They can be enabled by passing the relevant YAML configuration file such as `tracker=tracker_type.yaml`: + +- [BoT-SORT](https://github.com/NirAharon/BoT-SORT) - Use `botsort.yaml` to enable this tracker. +- [ByteTrack](https://github.com/ifzhang/ByteTrack) - Use `bytetrack.yaml` to enable this tracker. + +The default tracker is BoT-SORT. + +## Tracking + +To run the tracker on video streams, use a trained Detect, Segment or Pose model such as YOLO11n, YOLO11n-seg and YOLO11n-pose. 
+ +#### Python + +```python +from ultralytics import YOLO + +# Load an official or custom model +model = YOLO("yolo11n.pt") # Load an official Detect model +model = YOLO("yolo11n-seg.pt") # Load an official Segment model +model = YOLO("yolo11n-pose.pt") # Load an official Pose model +model = YOLO("path/to/best.pt") # Load a custom trained model + +# Perform tracking with the model +results = model.track(source="https://youtu.be/LNwODJXcvt4", show=True) # Tracking with default tracker +results = model.track( + source="https://youtu.be/LNwODJXcvt4", show=True, tracker="bytetrack.yaml" +) # Tracking with ByteTrack tracker +``` + +#### CLI + +```bash +# Perform tracking with various models using the command line interface +yolo track model=yolo11n.pt source="https://youtu.be/LNwODJXcvt4" # Official Detect model +yolo track model=yolo11n-seg.pt source="https://youtu.be/LNwODJXcvt4" # Official Segment model +yolo track model=yolo11n-pose.pt source="https://youtu.be/LNwODJXcvt4" # Official Pose model +yolo track model=path/to/best.pt source="https://youtu.be/LNwODJXcvt4" # Custom trained model + +# Track using ByteTrack tracker +yolo track model=path/to/best.pt tracker="bytetrack.yaml" +``` + +As can be seen in the above usage, tracking is available for all Detect, Segment and Pose models run on videos or streaming sources. + +## Configuration + +### Tracking Arguments + +Tracking configuration shares properties with Predict mode, such as `conf`, `iou`, and `show`. For further configurations, refer to the [Predict](https://docs.ultralytics.com/modes/predict/) model page. + +#### Python + +```python +from ultralytics import YOLO + +# Configure the tracking parameters and run the tracker +model = YOLO("yolo11n.pt") +results = model.track(source="https://youtu.be/LNwODJXcvt4", conf=0.3, iou=0.5, show=True) +``` + +#### CLI + +```bash +# Configure tracking parameters and run the tracker using the command line interface +yolo track model=yolo11n.pt source="https://youtu.be/LNwODJXcvt4" conf=0.3, iou=0.5 show +``` + +### Tracker Selection + +Ultralytics also allows you to use a modified tracker configuration file. To do this, simply make a copy of a tracker config file (for example, `custom_tracker.yaml`) from [ultralytics/cfg/trackers](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/trackers) and modify any configurations (except the `tracker_type`) as per your needs. + +#### Python + +```python +from ultralytics import YOLO + +# Load the model and run the tracker with a custom configuration file +model = YOLO("yolo11n.pt") +results = model.track(source="https://youtu.be/LNwODJXcvt4", tracker="custom_tracker.yaml") +``` + +#### CLI + +```bash +# Load the model and run the tracker with a custom configuration file using the command line interface +yolo track model=yolo11n.pt source="https://youtu.be/LNwODJXcvt4" tracker='custom_tracker.yaml' +``` + +For a comprehensive list of tracking arguments, refer to the [ultralytics/cfg/trackers](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/trackers) page. + +## Python Examples + +### Persisting Tracks Loop + +Here is a Python script using OpenCV (`cv2`) and YOLO11 to run object tracking on video frames. This script still assumes you have already installed the necessary packages (`opencv-python` and `ultralytics`). The `persist=True` argument tells the tracker than the current image or frame is the next in a sequence and to expect tracks from the previous image in the current image. 
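The full frame-by-frame OpenCV loop is shown below. As a lighter-weight sketch (assuming the standard `stream=True` generator mode, with a placeholder video path), the same video can also be tracked without managing the capture loop yourself; in this form the whole video is a single source, so track IDs are maintained across frames without passing `persist=True`:

```python
from ultralytics import YOLO

# Load a model (any of the official or custom weights above)
model = YOLO("yolo11n.pt")

# stream=True returns a generator that yields one Results object per frame
for result in model.track(source="path/to/video.mp4", stream=True):
    boxes = result.boxes  # tracked boxes for this frame; boxes.id holds track IDs when assigned
```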
+ +#### Python + +```python +import cv2 + +from ultralytics import YOLO + +# Load the YOLO11 model +model = YOLO("yolo11n.pt") + +# Open the video file +video_path = "path/to/video.mp4" +cap = cv2.VideoCapture(video_path) + +# Loop through the video frames +while cap.isOpened(): + # Read a frame from the video + success, frame = cap.read() + + if success: + # Run YOLO11 tracking on the frame, persisting tracks between frames + results = model.track(frame, persist=True) + + # Visualize the results on the frame + annotated_frame = results[0].plot() + + # Display the annotated frame + cv2.imshow("YOLO11 Tracking", annotated_frame) + + # Break the loop if 'q' is pressed + if cv2.waitKey(1) & 0xFF == ord("q"): + break + else: + # Break the loop if the end of the video is reached + break + +# Release the video capture object and close the display window +cap.release() +cv2.destroyAllWindows() +``` + +Please note the change from `model(frame)` to `model.track(frame)`, which enables object tracking instead of simple detection. This modified script will run the tracker on each frame of the video, visualize the results, and display them in a window. The loop can be exited by pressing 'q'. + +### Plotting Tracks Over Time + +Visualizing object tracks over consecutive frames can provide valuable insights into the movement patterns and behavior of detected objects within a video. With Ultralytics YOLO11, plotting these tracks is a seamless and efficient process. + +In the following example, we demonstrate how to utilize YOLO11's tracking capabilities to plot the movement of detected objects across multiple video frames. This script involves opening a video file, reading it frame by frame, and utilizing the YOLO model to identify and track various objects. By retaining the center points of the detected bounding boxes and connecting them, we can draw lines that represent the paths followed by the tracked objects. 
+ +#### Python + +```python +from collections import defaultdict + +import cv2 +import numpy as np + +from ultralytics import YOLO + +# Load the YOLO11 model +model = YOLO("yolo11n.pt") + +# Open the video file +video_path = "path/to/video.mp4" +cap = cv2.VideoCapture(video_path) + +# Store the track history +track_history = defaultdict(lambda: []) + +# Loop through the video frames +while cap.isOpened(): + # Read a frame from the video + success, frame = cap.read() + + if success: + # Run YOLO11 tracking on the frame, persisting tracks between frames + results = model.track(frame, persist=True) + + # Get the boxes and track IDs + boxes = results[0].boxes.xywh.cpu() + track_ids = results[0].boxes.id.int().cpu().tolist() + + # Visualize the results on the frame + annotated_frame = results[0].plot() + + # Plot the tracks + for box, track_id in zip(boxes, track_ids): + x, y, w, h = box + track = track_history[track_id] + track.append((float(x), float(y))) # x, y center point + if len(track) > 30: # retain 90 tracks for 90 frames + track.pop(0) + + # Draw the tracking lines + points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2)) + cv2.polylines( + annotated_frame, + [points], + isClosed=False, + color=(230, 230, 230), + thickness=10, + ) + + # Display the annotated frame + cv2.imshow("YOLO11 Tracking", annotated_frame) + + # Break the loop if 'q' is pressed + if cv2.waitKey(1) & 0xFF == ord("q"): + break + else: + # Break the loop if the end of the video is reached + break + +# Release the video capture object and close the display window +cap.release() +cv2.destroyAllWindows() +``` + +### Multithreaded Tracking + +Multithreaded tracking provides the capability to run object tracking on multiple video streams simultaneously. This is particularly useful when handling multiple video inputs, such as from multiple surveillance cameras, where concurrent processing can greatly enhance efficiency and performance. + +In the provided Python script, we make use of Python's `threading` module to run multiple instances of the tracker concurrently. Each thread is responsible for running the tracker on one video file, and all the threads run simultaneously in the background. + +To ensure that each thread receives the correct parameters (the video file and the model to use), we define a function `run_tracker_in_thread` that accepts these parameters and contains the main tracking loop. This function reads the video frame by frame, runs the tracker, and displays the results. + +Two different models are used in this example: `yolo11n.pt` and `yolo11n-seg.pt`, each tracking objects in a different video file. The video files are specified in `video_file1` and `video_file2`. + +The `daemon=True` parameter in `threading.Thread` means that these threads will be closed as soon as the main program finishes. We then start the threads with `start()` and use `join()` to make the main thread wait until both tracker threads have finished. + +Finally, after all threads have completed their task, the windows displaying the results are closed using `cv2.destroyAllWindows()`. 
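The complete threaded script is shown below. A common variant (a sketch under the assumption that you prefer to construct each model inside its own thread rather than pass model objects between threads; the video paths are placeholders) is to hand each worker only the weight file name:

```python
import threading

from ultralytics import YOLO


def run_tracker_in_thread(filename, model_name):
    """Track a single video in its own thread, constructing an independent YOLO model per thread."""
    model = YOLO(model_name)  # each thread builds its own model instance
    for result in model.track(source=filename, stream=True, show=False):
        pass  # consume results frame by frame; add per-frame handling here


# Placeholder video paths for illustration
tracker_thread1 = threading.Thread(target=run_tracker_in_thread, args=("path/to/video1.mp4", "yolo11n.pt"), daemon=True)
tracker_thread2 = threading.Thread(target=run_tracker_in_thread, args=("path/to/video2.mp4", "yolo11n-seg.pt"), daemon=True)

tracker_thread1.start()
tracker_thread2.start()
tracker_thread1.join()
tracker_thread2.join()
```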
+ +#### Python + +```python +import threading + +import cv2 + +from ultralytics import YOLO + + +def run_tracker_in_thread(filename, model): + """Starts multi-thread tracking on video from `filename` using `model` and displays results frame by frame.""" + video = cv2.VideoCapture(filename) + frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + for _ in range(frames): + ret, frame = video.read() + if ret: + results = model.track(source=frame, persist=True) + res_plotted = results[0].plot() + cv2.imshow("p", res_plotted) + if cv2.waitKey(1) == ord("q"): + break + + +# Load the models +model1 = YOLO("yolo11n.pt") +model2 = YOLO("yolo11n-seg.pt") + +# Define the video files for the trackers +video_file1 = "path/to/video1.mp4" +video_file2 = "path/to/video2.mp4" + +# Create the tracker threads +tracker_thread1 = threading.Thread(target=run_tracker_in_thread, args=(video_file1, model1), daemon=True) +tracker_thread2 = threading.Thread(target=run_tracker_in_thread, args=(video_file2, model2), daemon=True) + +# Start the tracker threads +tracker_thread1.start() +tracker_thread2.start() + +# Wait for the tracker threads to finish +tracker_thread1.join() +tracker_thread2.join() + +# Clean up and close windows +cv2.destroyAllWindows() +``` + +This example can easily be extended to handle more video files and models by creating more threads and applying the same methodology. + +## Contribute New Trackers + +Are you proficient in multi-object tracking and have successfully implemented or adapted a tracking algorithm with Ultralytics YOLO? We invite you to contribute to our Trackers section in [ultralytics/cfg/trackers](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/trackers)! Your real-world applications and solutions could be invaluable for users working on tracking tasks. + +By contributing to this section, you help expand the scope of tracking solutions available within the Ultralytics YOLO framework, adding another layer of functionality and utility for the community. + +To initiate your contribution, please refer to our [Contributing Guide](https://docs.ultralytics.com/help/contributing/) for comprehensive instructions on submitting a Pull Request (PR) 🛠️. We are excited to see what you bring to the table! + +Together, let's enhance the tracking capabilities of the Ultralytics YOLO ecosystem 🙏! diff --git a/ultralytics/trackers/__init__.py b/ultralytics/trackers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2919511ba506cf9887d4fcd1014f4a57263f36ba --- /dev/null +++ b/ultralytics/trackers/__init__.py @@ -0,0 +1,7 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from .bot_sort import BOTSORT +from .byte_tracker import BYTETracker +from .track import register_tracker + +__all__ = "register_tracker", "BOTSORT", "BYTETracker" # allow simpler import diff --git a/ultralytics/trackers/basetrack.py b/ultralytics/trackers/basetrack.py new file mode 100644 index 0000000000000000000000000000000000000000..47b27269e2a92c4925cdc034be3d40efc1db4270 --- /dev/null +++ b/ultralytics/trackers/basetrack.py @@ -0,0 +1,124 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +"""Module defines the base classes and structures for object tracking in YOLO.""" + +from collections import OrderedDict + +import numpy as np + + +class TrackState: + """ + Enumeration class representing the possible states of an object being tracked. + + Attributes: + New (int): State when the object is newly detected. 
+ Tracked (int): State when the object is successfully tracked in subsequent frames. + Lost (int): State when the object is no longer tracked. + Removed (int): State when the object is removed from tracking. + + Examples: + >>> state = TrackState.New + >>> if state == TrackState.New: + >>> print("Object is newly detected.") + """ + + New = 0 + Tracked = 1 + Lost = 2 + Removed = 3 + + +class BaseTrack: + """ + Base class for object tracking, providing foundational attributes and methods. + + Attributes: + _count (int): Class-level counter for unique track IDs. + track_id (int): Unique identifier for the track. + is_activated (bool): Flag indicating whether the track is currently active. + state (TrackState): Current state of the track. + history (OrderedDict): Ordered history of the track's states. + features (List): List of features extracted from the object for tracking. + curr_feature (Any): The current feature of the object being tracked. + score (float): The confidence score of the tracking. + start_frame (int): The frame number where tracking started. + frame_id (int): The most recent frame ID processed by the track. + time_since_update (int): Frames passed since the last update. + location (tuple): The location of the object in the context of multi-camera tracking. + + Methods: + end_frame: Returns the ID of the last frame where the object was tracked. + next_id: Increments and returns the next global track ID. + activate: Abstract method to activate the track. + predict: Abstract method to predict the next state of the track. + update: Abstract method to update the track with new data. + mark_lost: Marks the track as lost. + mark_removed: Marks the track as removed. + reset_id: Resets the global track ID counter. + + Examples: + Initialize a new track and mark it as lost: + >>> track = BaseTrack() + >>> track.mark_lost() + >>> print(track.state) # Output: 2 (TrackState.Lost) + """ + + _count = 0 + + def __init__(self): + """ + Initializes a new track with a unique ID and foundational tracking attributes. 
+ + Examples: + Initialize a new track + >>> track = BaseTrack() + >>> print(track.track_id) + 0 + """ + self.track_id = 0 + self.is_activated = False + self.state = TrackState.New + self.history = OrderedDict() + self.features = [] + self.curr_feature = None + self.score = 0 + self.start_frame = 0 + self.frame_id = 0 + self.time_since_update = 0 + self.location = (np.inf, np.inf) + + @property + def end_frame(self): + """Returns the ID of the most recent frame where the object was tracked.""" + return self.frame_id + + @staticmethod + def next_id(): + """Increment and return the next unique global track ID for object tracking.""" + BaseTrack._count += 1 + return BaseTrack._count + + def activate(self, *args): + """Activates the track with provided arguments, initializing necessary attributes for tracking.""" + raise NotImplementedError + + def predict(self): + """Predicts the next state of the track based on the current state and tracking model.""" + raise NotImplementedError + + def update(self, *args, **kwargs): + """Updates the track with new observations and data, modifying its state and attributes accordingly.""" + raise NotImplementedError + + def mark_lost(self): + """Marks the track as lost by updating its state to TrackState.Lost.""" + self.state = TrackState.Lost + + def mark_removed(self): + """Marks the track as removed by setting its state to TrackState.Removed.""" + self.state = TrackState.Removed + + @staticmethod + def reset_id(): + """Reset the global track ID counter to its initial value.""" + BaseTrack._count = 0 diff --git a/ultralytics/trackers/bot_sort.py b/ultralytics/trackers/bot_sort.py new file mode 100644 index 0000000000000000000000000000000000000000..8a2d02e0e14c48d65d9ab20493740974a0c0f483 --- /dev/null +++ b/ultralytics/trackers/bot_sort.py @@ -0,0 +1,233 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from collections import deque + +import numpy as np + +from .basetrack import TrackState +from .byte_tracker import BYTETracker, STrack +from .utils import matching +from .utils.gmc import GMC +from .utils.kalman_filter import KalmanFilterXYWH + + +class BOTrack(STrack): + """ + An extended version of the STrack class for YOLOv8, adding object tracking features. + + This class extends the STrack class to include additional functionalities for object tracking, such as feature + smoothing, Kalman filter prediction, and reactivation of tracks. + + Attributes: + shared_kalman (KalmanFilterXYWH): A shared Kalman filter for all instances of BOTrack. + smooth_feat (np.ndarray): Smoothed feature vector. + curr_feat (np.ndarray): Current feature vector. + features (deque): A deque to store feature vectors with a maximum length defined by `feat_history`. + alpha (float): Smoothing factor for the exponential moving average of features. + mean (np.ndarray): The mean state of the Kalman filter. + covariance (np.ndarray): The covariance matrix of the Kalman filter. + + Methods: + update_features(feat): Update features vector and smooth it using exponential moving average. + predict(): Predicts the mean and covariance using Kalman filter. + re_activate(new_track, frame_id, new_id): Reactivates a track with updated features and optionally new ID. + update(new_track, frame_id): Update the YOLOv8 instance with new track and frame ID. + tlwh: Property that gets the current position in tlwh format `(top left x, top left y, width, height)`. + multi_predict(stracks): Predicts the mean and covariance of multiple object tracks using shared Kalman filter. 
+ convert_coords(tlwh): Converts tlwh bounding box coordinates to xywh format. + tlwh_to_xywh(tlwh): Convert bounding box to xywh format `(center x, center y, width, height)`. + + Examples: + Create a BOTrack instance and update its features + >>> bo_track = BOTrack(tlwh=[100, 50, 80, 40], score=0.9, cls=1, feat=np.random.rand(128)) + >>> bo_track.predict() + >>> new_track = BOTrack(tlwh=[110, 60, 80, 40], score=0.85, cls=1, feat=np.random.rand(128)) + >>> bo_track.update(new_track, frame_id=2) + """ + + shared_kalman = KalmanFilterXYWH() + + def __init__(self, tlwh, score, cls, feat=None, feat_history=50): + """ + Initialize a BOTrack object with temporal parameters, such as feature history, alpha, and current features. + + Args: + tlwh (np.ndarray): Bounding box coordinates in tlwh format (top left x, top left y, width, height). + score (float): Confidence score of the detection. + cls (int): Class ID of the detected object. + feat (np.ndarray | None): Feature vector associated with the detection. + feat_history (int): Maximum length of the feature history deque. + + Examples: + Initialize a BOTrack object with bounding box, score, class ID, and feature vector + >>> tlwh = np.array([100, 50, 80, 120]) + >>> score = 0.9 + >>> cls = 1 + >>> feat = np.random.rand(128) + >>> bo_track = BOTrack(tlwh, score, cls, feat) + """ + super().__init__(tlwh, score, cls) + + self.smooth_feat = None + self.curr_feat = None + if feat is not None: + self.update_features(feat) + self.features = deque([], maxlen=feat_history) + self.alpha = 0.9 + + def update_features(self, feat): + """Update the feature vector and apply exponential moving average smoothing.""" + feat /= np.linalg.norm(feat) + self.curr_feat = feat + if self.smooth_feat is None: + self.smooth_feat = feat + else: + self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat + self.features.append(feat) + self.smooth_feat /= np.linalg.norm(self.smooth_feat) + + def predict(self): + """Predicts the object's future state using the Kalman filter to update its mean and covariance.""" + mean_state = self.mean.copy() + if self.state != TrackState.Tracked: + mean_state[6] = 0 + mean_state[7] = 0 + + self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance) + + def re_activate(self, new_track, frame_id, new_id=False): + """Reactivates a track with updated features and optionally assigns a new ID.""" + if new_track.curr_feat is not None: + self.update_features(new_track.curr_feat) + super().re_activate(new_track, frame_id, new_id) + + def update(self, new_track, frame_id): + """Updates the YOLOv8 instance with new track information and the current frame ID.""" + if new_track.curr_feat is not None: + self.update_features(new_track.curr_feat) + super().update(new_track, frame_id) + + @property + def tlwh(self): + """Returns the current bounding box position in `(top left x, top left y, width, height)` format.""" + if self.mean is None: + return self._tlwh.copy() + ret = self.mean[:4].copy() + ret[:2] -= ret[2:] / 2 + return ret + + @staticmethod + def multi_predict(stracks): + """Predicts the mean and covariance for multiple object tracks using a shared Kalman filter.""" + if len(stracks) <= 0: + return + multi_mean = np.asarray([st.mean.copy() for st in stracks]) + multi_covariance = np.asarray([st.covariance for st in stracks]) + for i, st in enumerate(stracks): + if st.state != TrackState.Tracked: + multi_mean[i][6] = 0 + multi_mean[i][7] = 0 + multi_mean, multi_covariance = 
BOTrack.shared_kalman.multi_predict(multi_mean, multi_covariance) + for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)): + stracks[i].mean = mean + stracks[i].covariance = cov + + def convert_coords(self, tlwh): + """Converts tlwh bounding box coordinates to xywh format.""" + return self.tlwh_to_xywh(tlwh) + + @staticmethod + def tlwh_to_xywh(tlwh): + """Convert bounding box from tlwh (top-left-width-height) to xywh (center-x-center-y-width-height) format.""" + ret = np.asarray(tlwh).copy() + ret[:2] += ret[2:] / 2 + return ret + + +class BOTSORT(BYTETracker): + """ + An extended version of the BYTETracker class for YOLOv8, designed for object tracking with ReID and GMC algorithm. + + Attributes: + proximity_thresh (float): Threshold for spatial proximity (IoU) between tracks and detections. + appearance_thresh (float): Threshold for appearance similarity (ReID embeddings) between tracks and detections. + encoder (Any): Object to handle ReID embeddings, set to None if ReID is not enabled. + gmc (GMC): An instance of the GMC algorithm for data association. + args (Any): Parsed command-line arguments containing tracking parameters. + + Methods: + get_kalmanfilter(): Returns an instance of KalmanFilterXYWH for object tracking. + init_track(dets, scores, cls, img): Initialize track with detections, scores, and classes. + get_dists(tracks, detections): Get distances between tracks and detections using IoU and (optionally) ReID. + multi_predict(tracks): Predict and track multiple objects with YOLOv8 model. + + Examples: + Initialize BOTSORT and process detections + >>> bot_sort = BOTSORT(args, frame_rate=30) + >>> bot_sort.init_track(dets, scores, cls, img) + >>> bot_sort.multi_predict(tracks) + + Note: + The class is designed to work with the YOLOv8 object detection model and supports ReID only if enabled via args. + """ + + def __init__(self, args, frame_rate=30): + """ + Initialize YOLOv8 object with ReID module and GMC algorithm. + + Args: + args (object): Parsed command-line arguments containing tracking parameters. + frame_rate (int): Frame rate of the video being processed. 
+ + Examples: + Initialize BOTSORT with command-line arguments and a specified frame rate: + >>> args = parse_args() + >>> bot_sort = BOTSORT(args, frame_rate=30) + """ + super().__init__(args, frame_rate) + # ReID module + self.proximity_thresh = args.proximity_thresh + self.appearance_thresh = args.appearance_thresh + + if args.with_reid: + # Haven't supported BoT-SORT(reid) yet + self.encoder = None + self.gmc = GMC(method=args.gmc_method) + + def get_kalmanfilter(self): + """Returns an instance of KalmanFilterXYWH for predicting and updating object states in the tracking process.""" + return KalmanFilterXYWH() + + def init_track(self, dets, scores, cls, img=None): + """Initialize object tracks using detection bounding boxes, scores, class labels, and optional ReID features.""" + if len(dets) == 0: + return [] + if self.args.with_reid and self.encoder is not None: + features_keep = self.encoder.inference(img, dets) + return [BOTrack(xyxy, s, c, f) for (xyxy, s, c, f) in zip(dets, scores, cls, features_keep)] # detections + else: + return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] # detections + + def get_dists(self, tracks, detections): + """Calculates distances between tracks and detections using IoU and optionally ReID embeddings.""" + dists = matching.iou_distance(tracks, detections) + dists_mask = dists > self.proximity_thresh + + if self.args.fuse_score: + dists = matching.fuse_score(dists, detections) + + if self.args.with_reid and self.encoder is not None: + emb_dists = matching.embedding_distance(tracks, detections) / 2.0 + emb_dists[emb_dists > self.appearance_thresh] = 1.0 + emb_dists[dists_mask] = 1.0 + dists = np.minimum(dists, emb_dists) + return dists + + def multi_predict(self, tracks): + """Predicts the mean and covariance of multiple object tracks using a shared Kalman filter.""" + BOTrack.multi_predict(tracks) + + def reset(self): + """Resets the BOTSORT tracker to its initial state, clearing all tracked objects and internal states.""" + super().reset() + self.gmc.reset_params() diff --git a/ultralytics/trackers/byte_tracker.py b/ultralytics/trackers/byte_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..807f4ad667867da5469c4a5abe832009915e7703 --- /dev/null +++ b/ultralytics/trackers/byte_tracker.py @@ -0,0 +1,476 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import numpy as np + +from ..utils import LOGGER +from ..utils.ops import xywh2ltwh +from .basetrack import BaseTrack, TrackState +from .utils import matching +from .utils.kalman_filter import KalmanFilterXYAH + + +class STrack(BaseTrack): + """ + Single object tracking representation that uses Kalman filtering for state estimation. + + This class is responsible for storing all the information regarding individual tracklets and performs state updates + and predictions based on Kalman filter. + + Attributes: + shared_kalman (KalmanFilterXYAH): Shared Kalman filter that is used across all STrack instances for prediction. + _tlwh (np.ndarray): Private attribute to store top-left corner coordinates and width and height of bounding box. + kalman_filter (KalmanFilterXYAH): Instance of Kalman filter used for this particular object track. + mean (np.ndarray): Mean state estimate vector. + covariance (np.ndarray): Covariance of state estimate. + is_activated (bool): Boolean flag indicating if the track has been activated. + score (float): Confidence score of the track. + tracklet_len (int): Length of the tracklet. 
+ cls (Any): Class label for the object. + idx (int): Index or identifier for the object. + frame_id (int): Current frame ID. + start_frame (int): Frame where the object was first detected. + + Methods: + predict(): Predict the next state of the object using Kalman filter. + multi_predict(stracks): Predict the next states for multiple tracks. + multi_gmc(stracks, H): Update multiple track states using a homography matrix. + activate(kalman_filter, frame_id): Activate a new tracklet. + re_activate(new_track, frame_id, new_id): Reactivate a previously lost tracklet. + update(new_track, frame_id): Update the state of a matched track. + convert_coords(tlwh): Convert bounding box to x-y-aspect-height format. + tlwh_to_xyah(tlwh): Convert tlwh bounding box to xyah format. + + Examples: + Initialize and activate a new track + >>> track = STrack(xywh=[100, 200, 50, 80, 0], score=0.9, cls="person") + >>> track.activate(kalman_filter=KalmanFilterXYAH(), frame_id=1) + """ + + shared_kalman = KalmanFilterXYAH() + + def __init__(self, xywh, score, cls): + """ + Initialize a new STrack instance. + + Args: + xywh (List[float]): Bounding box coordinates and dimensions in the format (x, y, w, h, [a], idx), where + (x, y) is the center, (w, h) are width and height, [a] is optional aspect ratio, and idx is the id. + score (float): Confidence score of the detection. + cls (Any): Class label for the detected object. + + Examples: + >>> xywh = [100.0, 150.0, 50.0, 75.0, 1] + >>> score = 0.9 + >>> cls = "person" + >>> track = STrack(xywh, score, cls) + """ + super().__init__() + # xywh+idx or xywha+idx + assert len(xywh) in {5, 6}, f"expected 5 or 6 values but got {len(xywh)}" + self._tlwh = np.asarray(xywh2ltwh(xywh[:4]), dtype=np.float32) + self.kalman_filter = None + self.mean, self.covariance = None, None + self.is_activated = False + + self.score = score + self.tracklet_len = 0 + self.cls = cls + self.idx = xywh[-1] + self.angle = xywh[4] if len(xywh) == 6 else None + + def predict(self): + """Predicts the next state (mean and covariance) of the object using the Kalman filter.""" + mean_state = self.mean.copy() + if self.state != TrackState.Tracked: + mean_state[7] = 0 + self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance) + + @staticmethod + def multi_predict(stracks): + """Perform multi-object predictive tracking using Kalman filter for the provided list of STrack instances.""" + if len(stracks) <= 0: + return + multi_mean = np.asarray([st.mean.copy() for st in stracks]) + multi_covariance = np.asarray([st.covariance for st in stracks]) + for i, st in enumerate(stracks): + if st.state != TrackState.Tracked: + multi_mean[i][7] = 0 + multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance) + for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)): + stracks[i].mean = mean + stracks[i].covariance = cov + + @staticmethod + def multi_gmc(stracks, H=np.eye(2, 3)): + """Update state tracks positions and covariances using a homography matrix for multiple tracks.""" + if len(stracks) > 0: + multi_mean = np.asarray([st.mean.copy() for st in stracks]) + multi_covariance = np.asarray([st.covariance for st in stracks]) + + R = H[:2, :2] + R8x8 = np.kron(np.eye(4, dtype=float), R) + t = H[:2, 2] + + for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)): + mean = R8x8.dot(mean) + mean[:2] += t + cov = R8x8.dot(cov).dot(R8x8.transpose()) + + stracks[i].mean = mean + stracks[i].covariance = cov + + def activate(self, 
kalman_filter, frame_id): + """Activate a new tracklet using the provided Kalman filter and initialize its state and covariance.""" + self.kalman_filter = kalman_filter + self.track_id = self.next_id() + self.mean, self.covariance = self.kalman_filter.initiate(self.convert_coords(self._tlwh)) + + self.tracklet_len = 0 + self.state = TrackState.Tracked + if frame_id == 1: + self.is_activated = True + self.frame_id = frame_id + self.start_frame = frame_id + + def re_activate(self, new_track, frame_id, new_id=False): + """Reactivates a previously lost track using new detection data and updates its state and attributes.""" + self.mean, self.covariance = self.kalman_filter.update( + self.mean, self.covariance, self.convert_coords(new_track.tlwh) + ) + self.tracklet_len = 0 + self.state = TrackState.Tracked + self.is_activated = True + self.frame_id = frame_id + if new_id: + self.track_id = self.next_id() + self.score = new_track.score + self.cls = new_track.cls + self.angle = new_track.angle + self.idx = new_track.idx + + def update(self, new_track, frame_id): + """ + Update the state of a matched track. + + Args: + new_track (STrack): The new track containing updated information. + frame_id (int): The ID of the current frame. + + Examples: + Update the state of a track with new detection information + >>> track = STrack([100, 200, 50, 80, 0.9, 1]) + >>> new_track = STrack([105, 205, 55, 85, 0.95, 1]) + >>> track.update(new_track, 2) + """ + self.frame_id = frame_id + self.tracklet_len += 1 + + new_tlwh = new_track.tlwh + self.mean, self.covariance = self.kalman_filter.update( + self.mean, self.covariance, self.convert_coords(new_tlwh) + ) + self.state = TrackState.Tracked + self.is_activated = True + + self.score = new_track.score + self.cls = new_track.cls + self.angle = new_track.angle + self.idx = new_track.idx + + def convert_coords(self, tlwh): + """Convert a bounding box's top-left-width-height format to its x-y-aspect-height equivalent.""" + return self.tlwh_to_xyah(tlwh) + + @property + def tlwh(self): + """Returns the bounding box in top-left-width-height format from the current state estimate.""" + if self.mean is None: + return self._tlwh.copy() + ret = self.mean[:4].copy() + ret[2] *= ret[3] + ret[:2] -= ret[2:] / 2 + return ret + + @property + def xyxy(self): + """Converts bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format.""" + ret = self.tlwh.copy() + ret[2:] += ret[:2] + return ret + + @staticmethod + def tlwh_to_xyah(tlwh): + """Convert bounding box from tlwh format to center-x-center-y-aspect-height (xyah) format.""" + ret = np.asarray(tlwh).copy() + ret[:2] += ret[2:] / 2 + ret[2] /= ret[3] + return ret + + @property + def xywh(self): + """Returns the current position of the bounding box in (center x, center y, width, height) format.""" + ret = np.asarray(self.tlwh).copy() + ret[:2] += ret[2:] / 2 + return ret + + @property + def xywha(self): + """Returns position in (center x, center y, width, height, angle) format, warning if angle is missing.""" + if self.angle is None: + LOGGER.warning("WARNING ⚠️ `angle` attr not found, returning `xywh` instead.") + return self.xywh + return np.concatenate([self.xywh, self.angle[None]]) + + @property + def result(self): + """Returns the current tracking results in the appropriate bounding box format.""" + coords = self.xyxy if self.angle is None else self.xywha + return coords.tolist() + [self.track_id, self.score, self.cls, self.idx] + + def __repr__(self): + """Returns a string 
representation of the STrack object including start frame, end frame, and track ID.""" + return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})" + + +class BYTETracker: + """ + BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking. + + Responsible for initializing, updating, and managing the tracks for detected objects in a video sequence. + It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for predicting + the new object locations, and performs data association. + + Attributes: + tracked_stracks (List[STrack]): List of successfully activated tracks. + lost_stracks (List[STrack]): List of lost tracks. + removed_stracks (List[STrack]): List of removed tracks. + frame_id (int): The current frame ID. + args (Namespace): Command-line arguments. + max_time_lost (int): The maximum frames for a track to be considered as 'lost'. + kalman_filter (KalmanFilterXYAH): Kalman Filter object. + + Methods: + update(results, img=None): Updates object tracker with new detections. + get_kalmanfilter(): Returns a Kalman filter object for tracking bounding boxes. + init_track(dets, scores, cls, img=None): Initialize object tracking with detections. + get_dists(tracks, detections): Calculates the distance between tracks and detections. + multi_predict(tracks): Predicts the location of tracks. + reset_id(): Resets the ID counter of STrack. + joint_stracks(tlista, tlistb): Combines two lists of stracks. + sub_stracks(tlista, tlistb): Filters out the stracks present in the second list from the first list. + remove_duplicate_stracks(stracksa, stracksb): Removes duplicate stracks based on IoU. + + Examples: + Initialize BYTETracker and update with detection results + >>> tracker = BYTETracker(args, frame_rate=30) + >>> results = yolo_model.detect(image) + >>> tracked_objects = tracker.update(results) + """ + + def __init__(self, args, frame_rate=30): + """ + Initialize a BYTETracker instance for object tracking. + + Args: + args (Namespace): Command-line arguments containing tracking parameters. + frame_rate (int): Frame rate of the video sequence. 
+ + Examples: + Initialize BYTETracker with command-line arguments and a frame rate of 30 + >>> args = Namespace(track_buffer=30) + >>> tracker = BYTETracker(args, frame_rate=30) + """ + self.tracked_stracks = [] # type: list[STrack] + self.lost_stracks = [] # type: list[STrack] + self.removed_stracks = [] # type: list[STrack] + + self.frame_id = 0 + self.args = args + self.max_time_lost = int(frame_rate / 30.0 * args.track_buffer) + self.kalman_filter = self.get_kalmanfilter() + self.reset_id() + + def update(self, results, img=None): + """Updates the tracker with new detections and returns the current list of tracked objects.""" + self.frame_id += 1 + activated_stracks = [] + refind_stracks = [] + lost_stracks = [] + removed_stracks = [] + + scores = results.conf + bboxes = results.xywhr if hasattr(results, "xywhr") else results.xywh + # Add index + bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1) + cls = results.cls + + remain_inds = scores >= self.args.track_high_thresh + inds_low = scores > self.args.track_low_thresh + inds_high = scores < self.args.track_high_thresh + + inds_second = inds_low & inds_high + dets_second = bboxes[inds_second] + dets = bboxes[remain_inds] + scores_keep = scores[remain_inds] + scores_second = scores[inds_second] + cls_keep = cls[remain_inds] + cls_second = cls[inds_second] + + detections = self.init_track(dets, scores_keep, cls_keep, img) + # Add newly detected tracklets to tracked_stracks + unconfirmed = [] + tracked_stracks = [] # type: list[STrack] + for track in self.tracked_stracks: + if not track.is_activated: + unconfirmed.append(track) + else: + tracked_stracks.append(track) + # Step 2: First association, with high score detection boxes + strack_pool = self.joint_stracks(tracked_stracks, self.lost_stracks) + # Predict the current location with KF + self.multi_predict(strack_pool) + if hasattr(self, "gmc") and img is not None: + warp = self.gmc.apply(img, dets) + STrack.multi_gmc(strack_pool, warp) + STrack.multi_gmc(unconfirmed, warp) + + dists = self.get_dists(strack_pool, detections) + matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.args.match_thresh) + + for itracked, idet in matches: + track = strack_pool[itracked] + det = detections[idet] + if track.state == TrackState.Tracked: + track.update(det, self.frame_id) + activated_stracks.append(track) + else: + track.re_activate(det, self.frame_id, new_id=False) + refind_stracks.append(track) + # Step 3: Second association, with low score detection boxes association the untrack to the low score detections + detections_second = self.init_track(dets_second, scores_second, cls_second, img) + r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked] + # TODO + dists = matching.iou_distance(r_tracked_stracks, detections_second) + matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5) + for itracked, idet in matches: + track = r_tracked_stracks[itracked] + det = detections_second[idet] + if track.state == TrackState.Tracked: + track.update(det, self.frame_id) + activated_stracks.append(track) + else: + track.re_activate(det, self.frame_id, new_id=False) + refind_stracks.append(track) + + for it in u_track: + track = r_tracked_stracks[it] + if track.state != TrackState.Lost: + track.mark_lost() + lost_stracks.append(track) + # Deal with unconfirmed tracks, usually tracks with only one beginning frame + detections = [detections[i] for i in u_detection] + dists = 
self.get_dists(unconfirmed, detections) + matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7) + for itracked, idet in matches: + unconfirmed[itracked].update(detections[idet], self.frame_id) + activated_stracks.append(unconfirmed[itracked]) + for it in u_unconfirmed: + track = unconfirmed[it] + track.mark_removed() + removed_stracks.append(track) + # Step 4: Init new stracks + for inew in u_detection: + track = detections[inew] + if track.score < self.args.new_track_thresh: + continue + track.activate(self.kalman_filter, self.frame_id) + activated_stracks.append(track) + # Step 5: Update state + for track in self.lost_stracks: + if self.frame_id - track.end_frame > self.max_time_lost: + track.mark_removed() + removed_stracks.append(track) + + self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked] + self.tracked_stracks = self.joint_stracks(self.tracked_stracks, activated_stracks) + self.tracked_stracks = self.joint_stracks(self.tracked_stracks, refind_stracks) + self.lost_stracks = self.sub_stracks(self.lost_stracks, self.tracked_stracks) + self.lost_stracks.extend(lost_stracks) + self.lost_stracks = self.sub_stracks(self.lost_stracks, self.removed_stracks) + self.tracked_stracks, self.lost_stracks = self.remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks) + self.removed_stracks.extend(removed_stracks) + if len(self.removed_stracks) > 1000: + self.removed_stracks = self.removed_stracks[-999:] # clip remove stracks to 1000 maximum + + return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32) + + def get_kalmanfilter(self): + """Returns a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH.""" + return KalmanFilterXYAH() + + def init_track(self, dets, scores, cls, img=None): + """Initializes object tracking with given detections, scores, and class labels using the STrack algorithm.""" + return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else [] # detections + + def get_dists(self, tracks, detections): + """Calculates the distance between tracks and detections using IoU and optionally fuses scores.""" + dists = matching.iou_distance(tracks, detections) + if self.args.fuse_score: + dists = matching.fuse_score(dists, detections) + return dists + + def multi_predict(self, tracks): + """Predict the next states for multiple tracks using Kalman filter.""" + STrack.multi_predict(tracks) + + @staticmethod + def reset_id(): + """Resets the ID counter for STrack instances to ensure unique track IDs across tracking sessions.""" + STrack.reset_id() + + def reset(self): + """Resets the tracker by clearing all tracked, lost, and removed tracks and reinitializing the Kalman filter.""" + self.tracked_stracks = [] # type: list[STrack] + self.lost_stracks = [] # type: list[STrack] + self.removed_stracks = [] # type: list[STrack] + self.frame_id = 0 + self.kalman_filter = self.get_kalmanfilter() + self.reset_id() + + @staticmethod + def joint_stracks(tlista, tlistb): + """Combines two lists of STrack objects into a single list, ensuring no duplicates based on track IDs.""" + exists = {} + res = [] + for t in tlista: + exists[t.track_id] = 1 + res.append(t) + for t in tlistb: + tid = t.track_id + if not exists.get(tid, 0): + exists[tid] = 1 + res.append(t) + return res + + @staticmethod + def sub_stracks(tlista, tlistb): + """Filters out the stracks present in the second list from the first list.""" + track_ids_b = {t.track_id for t in tlistb} 
+ return [t for t in tlista if t.track_id not in track_ids_b] + + @staticmethod + def remove_duplicate_stracks(stracksa, stracksb): + """Removes duplicate stracks from two lists based on Intersection over Union (IoU) distance.""" + pdist = matching.iou_distance(stracksa, stracksb) + pairs = np.where(pdist < 0.15) + dupa, dupb = [], [] + for p, q in zip(*pairs): + timep = stracksa[p].frame_id - stracksa[p].start_frame + timeq = stracksb[q].frame_id - stracksb[q].start_frame + if timep > timeq: + dupb.append(q) + else: + dupa.append(p) + resa = [t for i, t in enumerate(stracksa) if i not in dupa] + resb = [t for i, t in enumerate(stracksb) if i not in dupb] + return resa, resb diff --git a/ultralytics/trackers/track.py b/ultralytics/trackers/track.py new file mode 100644 index 0000000000000000000000000000000000000000..e55db6d43d66d520ced5d7295c5d4bb359f96843 --- /dev/null +++ b/ultralytics/trackers/track.py @@ -0,0 +1,104 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from functools import partial +from pathlib import Path + +import torch + +from ultralytics.utils import IterableSimpleNamespace, yaml_load +from ultralytics.utils.checks import check_yaml + +from .bot_sort import BOTSORT +from .byte_tracker import BYTETracker + +# A mapping of tracker types to corresponding tracker classes +TRACKER_MAP = {"bytetrack": BYTETracker, "botsort": BOTSORT} + + +def on_predict_start(predictor: object, persist: bool = False) -> None: + """ + Initialize trackers for object tracking during prediction. + + Args: + predictor (object): The predictor object to initialize trackers for. + persist (bool): Whether to persist the trackers if they already exist. + + Raises: + AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'. + + Examples: + Initialize trackers for a predictor object: + >>> predictor = SomePredictorClass() + >>> on_predict_start(predictor, persist=True) + """ + if hasattr(predictor, "trackers") and persist: + return + + tracker = check_yaml(predictor.args.tracker) + cfg = IterableSimpleNamespace(**yaml_load(tracker)) + + if cfg.tracker_type not in {"bytetrack", "botsort"}: + raise AssertionError(f"Only 'bytetrack' and 'botsort' are supported for now, but got '{cfg.tracker_type}'") + + trackers = [] + for _ in range(predictor.dataset.bs): + tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30) + trackers.append(tracker) + if predictor.dataset.mode != "stream": # only need one tracker for other modes. + break + predictor.trackers = trackers + predictor.vid_path = [None] * predictor.dataset.bs # for determining when to reset tracker on new video + + +def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None: + """ + Postprocess detected boxes and update with object tracking. + + Args: + predictor (object): The predictor object containing the predictions. + persist (bool): Whether to persist the trackers if they already exist. 
+ + Examples: + Postprocess predictions and update with tracking + >>> predictor = YourPredictorClass() + >>> on_predict_postprocess_end(predictor, persist=True) + """ + path, im0s = predictor.batch[:2] + + is_obb = predictor.args.task == "obb" + is_stream = predictor.dataset.mode == "stream" + for i in range(len(im0s)): + tracker = predictor.trackers[i if is_stream else 0] + vid_path = predictor.save_dir / Path(path[i]).name + if not persist and predictor.vid_path[i if is_stream else 0] != vid_path: + tracker.reset() + predictor.vid_path[i if is_stream else 0] = vid_path + + det = (predictor.results[i].obb if is_obb else predictor.results[i].boxes).cpu().numpy() + if len(det) == 0: + continue + tracks = tracker.update(det, im0s[i]) + if len(tracks) == 0: + continue + idx = tracks[:, -1].astype(int) + predictor.results[i] = predictor.results[i][idx] + + update_args = {"obb" if is_obb else "boxes": torch.as_tensor(tracks[:, :-1])} + predictor.results[i].update(**update_args) + + +def register_tracker(model: object, persist: bool) -> None: + """ + Register tracking callbacks to the model for object tracking during prediction. + + Args: + model (object): The model object to register tracking callbacks for. + persist (bool): Whether to persist the trackers if they already exist. + + Examples: + Register tracking callbacks to a YOLO model + >>> model = YOLOModel() + >>> register_tracker(model, persist=True) + """ + model.add_callback("on_predict_start", partial(on_predict_start, persist=persist)) + model.add_callback("on_predict_postprocess_end", partial(on_predict_postprocess_end, persist=persist)) diff --git a/ultralytics/trackers/utils/__init__.py b/ultralytics/trackers/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..77a19dcf0f8093de419453747db2e7e719f96349 --- /dev/null +++ b/ultralytics/trackers/utils/__init__.py @@ -0,0 +1 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license diff --git a/ultralytics/trackers/utils/gmc.py b/ultralytics/trackers/utils/gmc.py new file mode 100644 index 0000000000000000000000000000000000000000..e3cd2dc88ca8caf4f452326b14b42e684a380f37 --- /dev/null +++ b/ultralytics/trackers/utils/gmc.py @@ -0,0 +1,377 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import copy + +import cv2 +import numpy as np + +from ultralytics.utils import LOGGER + + +class GMC: + """ + Generalized Motion Compensation (GMC) class for tracking and object detection in video frames. + + This class provides methods for tracking and detecting objects based on several tracking algorithms including ORB, + SIFT, ECC, and Sparse Optical Flow. It also supports downscaling of frames for computational efficiency. + + Attributes: + method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'. + downscale (int): Factor by which to downscale the frames for processing. + prevFrame (np.ndarray): Stores the previous frame for tracking. + prevKeyPoints (List): Stores the keypoints from the previous frame. + prevDescriptors (np.ndarray): Stores the descriptors from the previous frame. + initializedFirstFrame (bool): Flag to indicate if the first frame has been processed. + + Methods: + __init__: Initializes a GMC object with the specified method and downscale factor. + apply: Applies the chosen method to a raw frame and optionally uses provided detections. + apply_ecc: Applies the ECC algorithm to a raw frame. 
+ apply_features: Applies feature-based methods like ORB or SIFT to a raw frame. + apply_sparseoptflow: Applies the Sparse Optical Flow method to a raw frame. + reset_params: Resets the internal parameters of the GMC object. + + Examples: + Create a GMC object and apply it to a frame + >>> gmc = GMC(method="sparseOptFlow", downscale=2) + >>> frame = np.array([[1, 2, 3], [4, 5, 6]]) + >>> processed_frame = gmc.apply(frame) + >>> print(processed_frame) + array([[1, 2, 3], + [4, 5, 6]]) + """ + + def __init__(self, method: str = "sparseOptFlow", downscale: int = 2) -> None: + """ + Initialize a Generalized Motion Compensation (GMC) object with tracking method and downscale factor. + + Args: + method (str): The method used for tracking. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'. + downscale (int): Downscale factor for processing frames. + + Examples: + Initialize a GMC object with the 'sparseOptFlow' method and a downscale factor of 2 + >>> gmc = GMC(method="sparseOptFlow", downscale=2) + """ + super().__init__() + + self.method = method + self.downscale = max(1, downscale) + + if self.method == "orb": + self.detector = cv2.FastFeatureDetector_create(20) + self.extractor = cv2.ORB_create() + self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING) + + elif self.method == "sift": + self.detector = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20) + self.extractor = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20) + self.matcher = cv2.BFMatcher(cv2.NORM_L2) + + elif self.method == "ecc": + number_of_iterations = 5000 + termination_eps = 1e-6 + self.warp_mode = cv2.MOTION_EUCLIDEAN + self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps) + + elif self.method == "sparseOptFlow": + self.feature_params = dict( + maxCorners=1000, qualityLevel=0.01, minDistance=1, blockSize=3, useHarrisDetector=False, k=0.04 + ) + + elif self.method in {"none", "None", None}: + self.method = None + else: + raise ValueError(f"Error: Unknown GMC method:{method}") + + self.prevFrame = None + self.prevKeyPoints = None + self.prevDescriptors = None + self.initializedFirstFrame = False + + def apply(self, raw_frame: np.array, detections: list = None) -> np.array: + """ + Apply object detection on a raw frame using the specified method. + + Args: + raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C). + detections (List | None): List of detections to be used in the processing. + + Returns: + (np.ndarray): Processed frame with applied object detection. + + Examples: + >>> gmc = GMC(method="sparseOptFlow") + >>> raw_frame = np.random.rand(480, 640, 3) + >>> processed_frame = gmc.apply(raw_frame) + >>> print(processed_frame.shape) + (480, 640, 3) + """ + if self.method in {"orb", "sift"}: + return self.apply_features(raw_frame, detections) + elif self.method == "ecc": + return self.apply_ecc(raw_frame) + elif self.method == "sparseOptFlow": + return self.apply_sparseoptflow(raw_frame) + else: + return np.eye(2, 3) + + def apply_ecc(self, raw_frame: np.array) -> np.array: + """ + Apply the ECC (Enhanced Correlation Coefficient) algorithm to a raw frame for motion compensation. + + Args: + raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C). + + Returns: + (np.ndarray): The processed frame with the applied ECC transformation. 
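+
+        Notes:
+            If cv2.findTransformECC raises an exception, a warning is logged and the warp matrix falls back to
+            its identity initialization.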
+ + Examples: + >>> gmc = GMC(method="ecc") + >>> processed_frame = gmc.apply_ecc(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])) + >>> print(processed_frame) + [[1. 0. 0.] + [0. 1. 0.]] + """ + height, width, _ = raw_frame.shape + frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) + H = np.eye(2, 3, dtype=np.float32) + + # Downscale image + if self.downscale > 1.0: + frame = cv2.GaussianBlur(frame, (3, 3), 1.5) + frame = cv2.resize(frame, (width // self.downscale, height // self.downscale)) + + # Handle first frame + if not self.initializedFirstFrame: + # Initialize data + self.prevFrame = frame.copy() + + # Initialization done + self.initializedFirstFrame = True + + return H + + # Run the ECC algorithm. The results are stored in warp_matrix. + # (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria) + try: + (_, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1) + except Exception as e: + LOGGER.warning(f"WARNING: find transform failed. Set warp as identity {e}") + + return H + + def apply_features(self, raw_frame: np.array, detections: list = None) -> np.array: + """ + Apply feature-based methods like ORB or SIFT to a raw frame. + + Args: + raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C). + detections (List | None): List of detections to be used in the processing. + + Returns: + (np.ndarray): Processed frame. + + Examples: + >>> gmc = GMC(method="orb") + >>> raw_frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) + >>> processed_frame = gmc.apply_features(raw_frame) + >>> print(processed_frame.shape) + (2, 3) + """ + height, width, _ = raw_frame.shape + frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) + H = np.eye(2, 3) + + # Downscale image + if self.downscale > 1.0: + frame = cv2.resize(frame, (width // self.downscale, height // self.downscale)) + width = width // self.downscale + height = height // self.downscale + + # Find the keypoints + mask = np.zeros_like(frame) + mask[int(0.02 * height) : int(0.98 * height), int(0.02 * width) : int(0.98 * width)] = 255 + if detections is not None: + for det in detections: + tlbr = (det[:4] / self.downscale).astype(np.int_) + mask[tlbr[1] : tlbr[3], tlbr[0] : tlbr[2]] = 0 + + keypoints = self.detector.detect(frame, mask) + + # Compute the descriptors + keypoints, descriptors = self.extractor.compute(frame, keypoints) + + # Handle first frame + if not self.initializedFirstFrame: + # Initialize data + self.prevFrame = frame.copy() + self.prevKeyPoints = copy.copy(keypoints) + self.prevDescriptors = copy.copy(descriptors) + + # Initialization done + self.initializedFirstFrame = True + + return H + + # Match descriptors + knnMatches = self.matcher.knnMatch(self.prevDescriptors, descriptors, 2) + + # Filter matches based on smallest spatial distance + matches = [] + spatialDistances = [] + + maxSpatialDistance = 0.25 * np.array([width, height]) + + # Handle empty matches case + if len(knnMatches) == 0: + # Store to next iteration + self.prevFrame = frame.copy() + self.prevKeyPoints = copy.copy(keypoints) + self.prevDescriptors = copy.copy(descriptors) + + return H + + for m, n in knnMatches: + if m.distance < 0.9 * n.distance: + prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt + currKeyPointLocation = keypoints[m.trainIdx].pt + + spatialDistance = ( + prevKeyPointLocation[0] - currKeyPointLocation[0], + prevKeyPointLocation[1] - currKeyPointLocation[1], + ) + + if (np.abs(spatialDistance[0]) < 
maxSpatialDistance[0]) and ( + np.abs(spatialDistance[1]) < maxSpatialDistance[1] + ): + spatialDistances.append(spatialDistance) + matches.append(m) + + meanSpatialDistances = np.mean(spatialDistances, 0) + stdSpatialDistances = np.std(spatialDistances, 0) + + inliers = (spatialDistances - meanSpatialDistances) < 2.5 * stdSpatialDistances + + goodMatches = [] + prevPoints = [] + currPoints = [] + for i in range(len(matches)): + if inliers[i, 0] and inliers[i, 1]: + goodMatches.append(matches[i]) + prevPoints.append(self.prevKeyPoints[matches[i].queryIdx].pt) + currPoints.append(keypoints[matches[i].trainIdx].pt) + + prevPoints = np.array(prevPoints) + currPoints = np.array(currPoints) + + # Draw the keypoint matches on the output image + # if False: + # import matplotlib.pyplot as plt + # matches_img = np.hstack((self.prevFrame, frame)) + # matches_img = cv2.cvtColor(matches_img, cv2.COLOR_GRAY2BGR) + # W = self.prevFrame.shape[1] + # for m in goodMatches: + # prev_pt = np.array(self.prevKeyPoints[m.queryIdx].pt, dtype=np.int_) + # curr_pt = np.array(keypoints[m.trainIdx].pt, dtype=np.int_) + # curr_pt[0] += W + # color = np.random.randint(0, 255, 3) + # color = (int(color[0]), int(color[1]), int(color[2])) + # + # matches_img = cv2.line(matches_img, prev_pt, curr_pt, tuple(color), 1, cv2.LINE_AA) + # matches_img = cv2.circle(matches_img, prev_pt, 2, tuple(color), -1) + # matches_img = cv2.circle(matches_img, curr_pt, 2, tuple(color), -1) + # + # plt.figure() + # plt.imshow(matches_img) + # plt.show() + + # Find rigid matrix + if prevPoints.shape[0] > 4: + H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC) + + # Handle downscale + if self.downscale > 1.0: + H[0, 2] *= self.downscale + H[1, 2] *= self.downscale + else: + LOGGER.warning("WARNING: not enough matching points") + + # Store to next iteration + self.prevFrame = frame.copy() + self.prevKeyPoints = copy.copy(keypoints) + self.prevDescriptors = copy.copy(descriptors) + + return H + + def apply_sparseoptflow(self, raw_frame: np.array) -> np.array: + """ + Apply Sparse Optical Flow method to a raw frame. + + Args: + raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C). + + Returns: + (np.ndarray): Processed frame with shape (2, 3). + + Examples: + >>> gmc = GMC() + >>> result = gmc.apply_sparseoptflow(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])) + >>> print(result) + [[1. 0. 0.] + [0. 1. 
0.]] + """ + height, width, _ = raw_frame.shape + frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) + H = np.eye(2, 3) + + # Downscale image + if self.downscale > 1.0: + frame = cv2.resize(frame, (width // self.downscale, height // self.downscale)) + + # Find the keypoints + keypoints = cv2.goodFeaturesToTrack(frame, mask=None, **self.feature_params) + + # Handle first frame + if not self.initializedFirstFrame or self.prevKeyPoints is None: + self.prevFrame = frame.copy() + self.prevKeyPoints = copy.copy(keypoints) + self.initializedFirstFrame = True + return H + + # Find correspondences + matchedKeypoints, status, _ = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None) + + # Leave good correspondences only + prevPoints = [] + currPoints = [] + + for i in range(len(status)): + if status[i]: + prevPoints.append(self.prevKeyPoints[i]) + currPoints.append(matchedKeypoints[i]) + + prevPoints = np.array(prevPoints) + currPoints = np.array(currPoints) + + # Find rigid matrix + if (prevPoints.shape[0] > 4) and (prevPoints.shape[0] == prevPoints.shape[0]): + H, _ = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC) + + if self.downscale > 1.0: + H[0, 2] *= self.downscale + H[1, 2] *= self.downscale + else: + LOGGER.warning("WARNING: not enough matching points") + + self.prevFrame = frame.copy() + self.prevKeyPoints = copy.copy(keypoints) + + return H + + def reset_params(self) -> None: + """Reset the internal parameters including previous frame, keypoints, and descriptors.""" + self.prevFrame = None + self.prevKeyPoints = None + self.prevDescriptors = None + self.initializedFirstFrame = False diff --git a/ultralytics/trackers/utils/kalman_filter.py b/ultralytics/trackers/utils/kalman_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..8a212ba63a9ad027612929fbe441bc9da4d76dcd --- /dev/null +++ b/ultralytics/trackers/utils/kalman_filter.py @@ -0,0 +1,491 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import numpy as np +import scipy.linalg + + +class KalmanFilterXYAH: + """ + A KalmanFilterXYAH class for tracking bounding boxes in image space using a Kalman filter. + + Implements a simple Kalman filter for tracking bounding boxes in image space. The 8-dimensional state space + (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect ratio a, height h, and their + respective velocities. Object motion follows a constant velocity model, and bounding box location (x, y, a, h) is + taken as a direct observation of the state space (linear observation model). + + Attributes: + _motion_mat (np.ndarray): The motion matrix for the Kalman filter. + _update_mat (np.ndarray): The update matrix for the Kalman filter. + _std_weight_position (float): Standard deviation weight for position. + _std_weight_velocity (float): Standard deviation weight for velocity. + + Methods: + initiate: Creates a track from an unassociated measurement. + predict: Runs the Kalman filter prediction step. + project: Projects the state distribution to measurement space. + multi_predict: Runs the Kalman filter prediction step (vectorized version). + update: Runs the Kalman filter correction step. + gating_distance: Computes the gating distance between state distribution and measurements. 
+ + Examples: + Initialize the Kalman filter and create a track from a measurement + >>> kf = KalmanFilterXYAH() + >>> measurement = np.array([100, 200, 1.5, 50]) + >>> mean, covariance = kf.initiate(measurement) + >>> print(mean) + >>> print(covariance) + """ + + def __init__(self): + """ + Initialize Kalman filter model matrices with motion and observation uncertainty weights. + + The Kalman filter is initialized with an 8-dimensional state space (x, y, a, h, vx, vy, va, vh), where (x, y) + represents the bounding box center position, 'a' is the aspect ratio, 'h' is the height, and their respective + velocities are (vx, vy, va, vh). The filter uses a constant velocity model for object motion and a linear + observation model for bounding box location. + + Examples: + Initialize a Kalman filter for tracking: + >>> kf = KalmanFilterXYAH() + """ + ndim, dt = 4, 1.0 + + # Create Kalman filter model matrices + self._motion_mat = np.eye(2 * ndim, 2 * ndim) + for i in range(ndim): + self._motion_mat[i, ndim + i] = dt + self._update_mat = np.eye(ndim, 2 * ndim) + + # Motion and observation uncertainty are chosen relative to the current state estimate. These weights control + # the amount of uncertainty in the model. + self._std_weight_position = 1.0 / 20 + self._std_weight_velocity = 1.0 / 160 + + def initiate(self, measurement: np.ndarray) -> tuple: + """ + Create a track from an unassociated measurement. + + Args: + measurement (ndarray): Bounding box coordinates (x, y, a, h) with center position (x, y), aspect ratio a, + and height h. + + Returns: + (tuple[ndarray, ndarray]): Returns the mean vector (8-dimensional) and covariance matrix (8x8 dimensional) + of the new track. Unobserved velocities are initialized to 0 mean. + + Examples: + >>> kf = KalmanFilterXYAH() + >>> measurement = np.array([100, 50, 1.5, 200]) + >>> mean, covariance = kf.initiate(measurement) + """ + mean_pos = measurement + mean_vel = np.zeros_like(mean_pos) + mean = np.r_[mean_pos, mean_vel] + + std = [ + 2 * self._std_weight_position * measurement[3], + 2 * self._std_weight_position * measurement[3], + 1e-2, + 2 * self._std_weight_position * measurement[3], + 10 * self._std_weight_velocity * measurement[3], + 10 * self._std_weight_velocity * measurement[3], + 1e-5, + 10 * self._std_weight_velocity * measurement[3], + ] + covariance = np.diag(np.square(std)) + return mean, covariance + + def predict(self, mean: np.ndarray, covariance: np.ndarray) -> tuple: + """ + Run Kalman filter prediction step. + + Args: + mean (ndarray): The 8-dimensional mean vector of the object state at the previous time step. + covariance (ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step. + + Returns: + (tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved + velocities are initialized to 0 mean. 
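+
+        Notes:
+            The process noise standard deviations are scaled by the current box height (mean[3]), so larger
+            boxes receive proportionally larger position and velocity uncertainty.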
+ + Examples: + >>> kf = KalmanFilterXYAH() + >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0]) + >>> covariance = np.eye(8) + >>> predicted_mean, predicted_covariance = kf.predict(mean, covariance) + """ + std_pos = [ + self._std_weight_position * mean[3], + self._std_weight_position * mean[3], + 1e-2, + self._std_weight_position * mean[3], + ] + std_vel = [ + self._std_weight_velocity * mean[3], + self._std_weight_velocity * mean[3], + 1e-5, + self._std_weight_velocity * mean[3], + ] + motion_cov = np.diag(np.square(np.r_[std_pos, std_vel])) + + mean = np.dot(mean, self._motion_mat.T) + covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov + + return mean, covariance + + def project(self, mean: np.ndarray, covariance: np.ndarray) -> tuple: + """ + Project state distribution to measurement space. + + Args: + mean (ndarray): The state's mean vector (8 dimensional array). + covariance (ndarray): The state's covariance matrix (8x8 dimensional). + + Returns: + (tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate. + + Examples: + >>> kf = KalmanFilterXYAH() + >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0]) + >>> covariance = np.eye(8) + >>> projected_mean, projected_covariance = kf.project(mean, covariance) + """ + std = [ + self._std_weight_position * mean[3], + self._std_weight_position * mean[3], + 1e-1, + self._std_weight_position * mean[3], + ] + innovation_cov = np.diag(np.square(std)) + + mean = np.dot(self._update_mat, mean) + covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T)) + return mean, covariance + innovation_cov + + def multi_predict(self, mean: np.ndarray, covariance: np.ndarray) -> tuple: + """ + Run Kalman filter prediction step for multiple object states (Vectorized version). + + Args: + mean (ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step. + covariance (ndarray): The Nx8x8 covariance matrix of the object states at the previous time step. + + Returns: + (tuple[ndarray, ndarray]): Returns the mean matrix and covariance matrix of the predicted states. + The mean matrix has shape (N, 8) and the covariance matrix has shape (N, 8, 8). Unobserved velocities + are initialized to 0 mean. + + Examples: + >>> mean = np.random.rand(10, 8) # 10 object states + >>> covariance = np.random.rand(10, 8, 8) # Covariance matrices for 10 object states + >>> predicted_mean, predicted_covariance = kalman_filter.multi_predict(mean, covariance) + """ + std_pos = [ + self._std_weight_position * mean[:, 3], + self._std_weight_position * mean[:, 3], + 1e-2 * np.ones_like(mean[:, 3]), + self._std_weight_position * mean[:, 3], + ] + std_vel = [ + self._std_weight_velocity * mean[:, 3], + self._std_weight_velocity * mean[:, 3], + 1e-5 * np.ones_like(mean[:, 3]), + self._std_weight_velocity * mean[:, 3], + ] + sqr = np.square(np.r_[std_pos, std_vel]).T + + motion_cov = [np.diag(sqr[i]) for i in range(len(mean))] + motion_cov = np.asarray(motion_cov) + + mean = np.dot(mean, self._motion_mat.T) + left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2)) + covariance = np.dot(left, self._motion_mat.T) + motion_cov + + return mean, covariance + + def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray) -> tuple: + """ + Run Kalman filter correction step. + + Args: + mean (ndarray): The predicted state's mean vector (8 dimensional). + covariance (ndarray): The state's covariance matrix (8x8 dimensional). 
+ measurement (ndarray): The 4 dimensional measurement vector (x, y, a, h), where (x, y) is the center + position, a the aspect ratio, and h the height of the bounding box. + + Returns: + (tuple[ndarray, ndarray]): Returns the measurement-corrected state distribution. + + Examples: + >>> kf = KalmanFilterXYAH() + >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0]) + >>> covariance = np.eye(8) + >>> measurement = np.array([1, 1, 1, 1]) + >>> new_mean, new_covariance = kf.update(mean, covariance, measurement) + """ + projected_mean, projected_cov = self.project(mean, covariance) + + chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False) + kalman_gain = scipy.linalg.cho_solve( + (chol_factor, lower), np.dot(covariance, self._update_mat.T).T, check_finite=False + ).T + innovation = measurement - projected_mean + + new_mean = mean + np.dot(innovation, kalman_gain.T) + new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T)) + return new_mean, new_covariance + + def gating_distance( + self, + mean: np.ndarray, + covariance: np.ndarray, + measurements: np.ndarray, + only_position: bool = False, + metric: str = "maha", + ) -> np.ndarray: + """ + Compute gating distance between state distribution and measurements. + + A suitable distance threshold can be obtained from `chi2inv95`. If `only_position` is False, the chi-square + distribution has 4 degrees of freedom, otherwise 2. + + Args: + mean (ndarray): Mean vector over the state distribution (8 dimensional). + covariance (ndarray): Covariance of the state distribution (8x8 dimensional). + measurements (ndarray): An (N, 4) matrix of N measurements, each in format (x, y, a, h) where (x, y) is the + bounding box center position, a the aspect ratio, and h the height. + only_position (bool): If True, distance computation is done with respect to box center position only. + metric (str): The metric to use for calculating the distance. Options are 'gaussian' for the squared + Euclidean distance and 'maha' for the squared Mahalanobis distance. + + Returns: + (np.ndarray): Returns an array of length N, where the i-th element contains the squared distance between + (mean, covariance) and `measurements[i]`. + + Examples: + Compute gating distance using Mahalanobis metric: + >>> kf = KalmanFilterXYAH() + >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0]) + >>> covariance = np.eye(8) + >>> measurements = np.array([[1, 1, 1, 1], [2, 2, 1, 1]]) + >>> distances = kf.gating_distance(mean, covariance, measurements, only_position=False, metric="maha") + """ + mean, covariance = self.project(mean, covariance) + if only_position: + mean, covariance = mean[:2], covariance[:2, :2] + measurements = measurements[:, :2] + + d = measurements - mean + if metric == "gaussian": + return np.sum(d * d, axis=1) + elif metric == "maha": + cholesky_factor = np.linalg.cholesky(covariance) + z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True) + return np.sum(z * z, axis=0) # square maha + else: + raise ValueError("Invalid distance metric") + + +class KalmanFilterXYWH(KalmanFilterXYAH): + """ + A KalmanFilterXYWH class for tracking bounding boxes in image space using a Kalman filter. + + Implements a Kalman filter for tracking bounding boxes with state space (x, y, w, h, vx, vy, vw, vh), where + (x, y) is the center position, w is the width, h is the height, and vx, vy, vw, vh are their respective velocities. 
+ The object motion follows a constant velocity model, and the bounding box location (x, y, w, h) is taken as a direct + observation of the state space (linear observation model). + + Attributes: + _motion_mat (np.ndarray): The motion matrix for the Kalman filter. + _update_mat (np.ndarray): The update matrix for the Kalman filter. + _std_weight_position (float): Standard deviation weight for position. + _std_weight_velocity (float): Standard deviation weight for velocity. + + Methods: + initiate: Creates a track from an unassociated measurement. + predict: Runs the Kalman filter prediction step. + project: Projects the state distribution to measurement space. + multi_predict: Runs the Kalman filter prediction step in a vectorized manner. + update: Runs the Kalman filter correction step. + + Examples: + Create a Kalman filter and initialize a track + >>> kf = KalmanFilterXYWH() + >>> measurement = np.array([100, 50, 20, 40]) + >>> mean, covariance = kf.initiate(measurement) + >>> print(mean) + >>> print(covariance) + """ + + def initiate(self, measurement: np.ndarray) -> tuple: + """ + Create track from unassociated measurement. + + Args: + measurement (ndarray): Bounding box coordinates (x, y, w, h) with center position (x, y), width, and height. + + Returns: + (tuple[ndarray, ndarray]): Returns the mean vector (8 dimensional) and covariance matrix (8x8 dimensional) + of the new track. Unobserved velocities are initialized to 0 mean. + + Examples: + >>> kf = KalmanFilterXYWH() + >>> measurement = np.array([100, 50, 20, 40]) + >>> mean, covariance = kf.initiate(measurement) + >>> print(mean) + [100. 50. 20. 40. 0. 0. 0. 0.] + >>> print(covariance) + [[ 4. 0. 0. 0. 0. 0. 0. 0.] + [ 0. 4. 0. 0. 0. 0. 0. 0.] + [ 0. 0. 4. 0. 0. 0. 0. 0.] + [ 0. 0. 0. 4. 0. 0. 0. 0.] + [ 0. 0. 0. 0. 0.25 0. 0. 0.] + [ 0. 0. 0. 0. 0. 0.25 0. 0.] + [ 0. 0. 0. 0. 0. 0. 0.25 0.] + [ 0. 0. 0. 0. 0. 0. 0. 0.25]] + """ + mean_pos = measurement + mean_vel = np.zeros_like(mean_pos) + mean = np.r_[mean_pos, mean_vel] + + std = [ + 2 * self._std_weight_position * measurement[2], + 2 * self._std_weight_position * measurement[3], + 2 * self._std_weight_position * measurement[2], + 2 * self._std_weight_position * measurement[3], + 10 * self._std_weight_velocity * measurement[2], + 10 * self._std_weight_velocity * measurement[3], + 10 * self._std_weight_velocity * measurement[2], + 10 * self._std_weight_velocity * measurement[3], + ] + covariance = np.diag(np.square(std)) + return mean, covariance + + def predict(self, mean, covariance) -> tuple: + """ + Run Kalman filter prediction step. + + Args: + mean (ndarray): The 8-dimensional mean vector of the object state at the previous time step. + covariance (ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step. + + Returns: + (tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved + velocities are initialized to 0 mean. 
+ + Examples: + >>> kf = KalmanFilterXYWH() + >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0]) + >>> covariance = np.eye(8) + >>> predicted_mean, predicted_covariance = kf.predict(mean, covariance) + """ + std_pos = [ + self._std_weight_position * mean[2], + self._std_weight_position * mean[3], + self._std_weight_position * mean[2], + self._std_weight_position * mean[3], + ] + std_vel = [ + self._std_weight_velocity * mean[2], + self._std_weight_velocity * mean[3], + self._std_weight_velocity * mean[2], + self._std_weight_velocity * mean[3], + ] + motion_cov = np.diag(np.square(np.r_[std_pos, std_vel])) + + mean = np.dot(mean, self._motion_mat.T) + covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov + + return mean, covariance + + def project(self, mean, covariance) -> tuple: + """ + Project state distribution to measurement space. + + Args: + mean (ndarray): The state's mean vector (8 dimensional array). + covariance (ndarray): The state's covariance matrix (8x8 dimensional). + + Returns: + (tuple[ndarray, ndarray]): Returns the projected mean and covariance matrix of the given state estimate. + + Examples: + >>> kf = KalmanFilterXYWH() + >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0]) + >>> covariance = np.eye(8) + >>> projected_mean, projected_cov = kf.project(mean, covariance) + """ + std = [ + self._std_weight_position * mean[2], + self._std_weight_position * mean[3], + self._std_weight_position * mean[2], + self._std_weight_position * mean[3], + ] + innovation_cov = np.diag(np.square(std)) + + mean = np.dot(self._update_mat, mean) + covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T)) + return mean, covariance + innovation_cov + + def multi_predict(self, mean, covariance) -> tuple: + """ + Run Kalman filter prediction step (Vectorized version). + + Args: + mean (ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step. + covariance (ndarray): The Nx8x8 covariance matrix of the object states at the previous time step. + + Returns: + (tuple[ndarray, ndarray]): Returns the mean vector and covariance matrix of the predicted state. Unobserved + velocities are initialized to 0 mean. + + Examples: + >>> mean = np.random.rand(5, 8) # 5 objects with 8-dimensional state vectors + >>> covariance = np.random.rand(5, 8, 8) # 5 objects with 8x8 covariance matrices + >>> kf = KalmanFilterXYWH() + >>> predicted_mean, predicted_covariance = kf.multi_predict(mean, covariance) + """ + std_pos = [ + self._std_weight_position * mean[:, 2], + self._std_weight_position * mean[:, 3], + self._std_weight_position * mean[:, 2], + self._std_weight_position * mean[:, 3], + ] + std_vel = [ + self._std_weight_velocity * mean[:, 2], + self._std_weight_velocity * mean[:, 3], + self._std_weight_velocity * mean[:, 2], + self._std_weight_velocity * mean[:, 3], + ] + sqr = np.square(np.r_[std_pos, std_vel]).T + + motion_cov = [np.diag(sqr[i]) for i in range(len(mean))] + motion_cov = np.asarray(motion_cov) + + mean = np.dot(mean, self._motion_mat.T) + left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2)) + covariance = np.dot(left, self._motion_mat.T) + motion_cov + + return mean, covariance + + def update(self, mean, covariance, measurement) -> tuple: + """ + Run Kalman filter correction step. + + Args: + mean (ndarray): The predicted state's mean vector (8 dimensional). + covariance (ndarray): The state's covariance matrix (8x8 dimensional). 
+ measurement (ndarray): The 4 dimensional measurement vector (x, y, w, h), where (x, y) is the center + position, w the width, and h the height of the bounding box. + + Returns: + (tuple[ndarray, ndarray]): Returns the measurement-corrected state distribution. + + Examples: + >>> kf = KalmanFilterXYWH() + >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0]) + >>> covariance = np.eye(8) + >>> measurement = np.array([0.5, 0.5, 1.2, 1.2]) + >>> new_mean, new_covariance = kf.update(mean, covariance, measurement) + """ + return super().update(mean, covariance, measurement) diff --git a/ultralytics/trackers/utils/matching.py b/ultralytics/trackers/utils/matching.py new file mode 100644 index 0000000000000000000000000000000000000000..e4b354f1290eb38168cf4b516b7d06fff9dc5fe4 --- /dev/null +++ b/ultralytics/trackers/utils/matching.py @@ -0,0 +1,157 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import numpy as np +import scipy +from scipy.spatial.distance import cdist + +from ultralytics.utils.metrics import batch_probiou, bbox_ioa + +try: + import lap # for linear_assignment + + assert lap.__version__ # verify package is not directory +except (ImportError, AssertionError, AttributeError): + from ultralytics.utils.checks import check_requirements + + check_requirements("lap>=0.5.12") # https://github.com/gatagat/lap + import lap + + +def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True) -> tuple: + """ + Perform linear assignment using either the scipy or lap.lapjv method. + + Args: + cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M). + thresh (float): Threshold for considering an assignment valid. + use_lap (bool): Use lap.lapjv for the assignment. If False, scipy.optimize.linear_sum_assignment is used. + + Returns: + matched_indices (np.ndarray): Array of matched indices of shape (K, 2), where K is the number of matches. + unmatched_a (np.ndarray): Array of unmatched indices from the first set, with shape (L,). + unmatched_b (np.ndarray): Array of unmatched indices from the second set, with shape (M,). + + Examples: + >>> cost_matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> thresh = 5.0 + >>> matched_indices, unmatched_a, unmatched_b = linear_assignment(cost_matrix, thresh, use_lap=True) + """ + if cost_matrix.size == 0: + return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1])) + + if use_lap: + # Use lap.lapjv + # https://github.com/gatagat/lap + _, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh) + matches = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0] + unmatched_a = np.where(x < 0)[0] + unmatched_b = np.where(y < 0)[0] + else: + # Use scipy.optimize.linear_sum_assignment + # https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linear_sum_assignment.html + x, y = scipy.optimize.linear_sum_assignment(cost_matrix) # row x, col y + matches = np.asarray([[x[i], y[i]] for i in range(len(x)) if cost_matrix[x[i], y[i]] <= thresh]) + if len(matches) == 0: + unmatched_a = list(np.arange(cost_matrix.shape[0])) + unmatched_b = list(np.arange(cost_matrix.shape[1])) + else: + unmatched_a = list(set(np.arange(cost_matrix.shape[0])) - set(matches[:, 0])) + unmatched_b = list(set(np.arange(cost_matrix.shape[1])) - set(matches[:, 1])) + + return matches, unmatched_a, unmatched_b + + +def iou_distance(atracks: list, btracks: list) -> np.ndarray: + """ + Compute cost based on Intersection over Union (IoU) between tracks. 
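+
+    Oriented boxes (5-element xywha) are compared with batch_probiou, while axis-aligned boxes are compared
+    with bbox_ioa(iou=True); the result is returned as a cost matrix (1 - IoU).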
+ + Args: + atracks (list[STrack] | list[np.ndarray]): List of tracks 'a' or bounding boxes. + btracks (list[STrack] | list[np.ndarray]): List of tracks 'b' or bounding boxes. + + Returns: + (np.ndarray): Cost matrix computed based on IoU. + + Examples: + Compute IoU distance between two sets of tracks + >>> atracks = [np.array([0, 0, 10, 10]), np.array([20, 20, 30, 30])] + >>> btracks = [np.array([5, 5, 15, 15]), np.array([25, 25, 35, 35])] + >>> cost_matrix = iou_distance(atracks, btracks) + """ + if atracks and isinstance(atracks[0], np.ndarray) or btracks and isinstance(btracks[0], np.ndarray): + atlbrs = atracks + btlbrs = btracks + else: + atlbrs = [track.xywha if track.angle is not None else track.xyxy for track in atracks] + btlbrs = [track.xywha if track.angle is not None else track.xyxy for track in btracks] + + ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32) + if len(atlbrs) and len(btlbrs): + if len(atlbrs[0]) == 5 and len(btlbrs[0]) == 5: + ious = batch_probiou( + np.ascontiguousarray(atlbrs, dtype=np.float32), + np.ascontiguousarray(btlbrs, dtype=np.float32), + ).numpy() + else: + ious = bbox_ioa( + np.ascontiguousarray(atlbrs, dtype=np.float32), + np.ascontiguousarray(btlbrs, dtype=np.float32), + iou=True, + ) + return 1 - ious # cost matrix + + +def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -> np.ndarray: + """ + Compute distance between tracks and detections based on embeddings. + + Args: + tracks (list[STrack]): List of tracks, where each track contains embedding features. + detections (list[BaseTrack]): List of detections, where each detection contains embedding features. + metric (str): Metric for distance computation. Supported metrics include 'cosine', 'euclidean', etc. + + Returns: + (np.ndarray): Cost matrix computed based on embeddings with shape (N, M), where N is the number of tracks + and M is the number of detections. + + Examples: + Compute the embedding distance between tracks and detections using cosine metric + >>> tracks = [STrack(...), STrack(...)] # List of track objects with embedding features + >>> detections = [BaseTrack(...), BaseTrack(...)] # List of detection objects with embedding features + >>> cost_matrix = embedding_distance(tracks, detections, metric="cosine") + """ + cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32) + if cost_matrix.size == 0: + return cost_matrix + det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float32) + # for i, track in enumerate(tracks): + # cost_matrix[i, :] = np.maximum(0.0, cdist(track.smooth_feat.reshape(1,-1), det_features, metric)) + track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float32) + cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric)) # Normalized features + return cost_matrix + + +def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray: + """ + Fuses cost matrix with detection scores to produce a single similarity matrix. + + Args: + cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M). + detections (list[BaseTrack]): List of detections, each containing a score attribute. + + Returns: + (np.ndarray): Fused similarity matrix with shape (N, M). 
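+
+    Notes:
+        The returned matrix is 1 - (IoU similarity * detection score), i.e. a fused cost ready for use with
+        linear_assignment.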
+ + Examples: + Fuse a cost matrix with detection scores + >>> cost_matrix = np.random.rand(5, 10) # 5 tracks and 10 detections + >>> detections = [BaseTrack(score=np.random.rand()) for _ in range(10)] + >>> fused_matrix = fuse_score(cost_matrix, detections) + """ + if cost_matrix.size == 0: + return cost_matrix + iou_sim = 1 - cost_matrix + det_scores = np.array([det.score for det in detections]) + det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0) + fuse_sim = iou_sim * det_scores + return 1 - fuse_sim # fuse_cost diff --git a/ultralytics/utils/__init__.py b/ultralytics/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fa3039236f211abffdb320262fe44528deab98fd --- /dev/null +++ b/ultralytics/utils/__init__.py @@ -0,0 +1,1331 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import contextlib +import importlib.metadata +import inspect +import json +import logging.config +import os +import platform +import re +import subprocess +import sys +import threading +import time +import uuid +from pathlib import Path +from threading import Lock +from types import SimpleNamespace +from typing import Union +from urllib.parse import unquote + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import torch +import yaml +from tqdm import tqdm as tqdm_original + +from ultralytics import __version__ + +# PyTorch Multi-GPU DDP Constants +RANK = int(os.getenv("RANK", -1)) +LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html + +# Other Constants +ARGV = sys.argv or ["", ""] # sometimes sys.argv = [] +FILE = Path(__file__).resolve() +ROOT = FILE.parents[1] # YOLO +ASSETS = ROOT / "assets" # default images +ASSETS_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0" # assets GitHub URL +DEFAULT_CFG_PATH = ROOT / "cfg/default.yaml" +DEFAULT_SOL_CFG_PATH = ROOT / "cfg/solutions/default.yaml" # Ultralytics solutions yaml path +NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLO multiprocessing threads +AUTOINSTALL = str(os.getenv("YOLO_AUTOINSTALL", True)).lower() == "true" # global auto-install mode +VERBOSE = str(os.getenv("YOLO_VERBOSE", True)).lower() == "true" # global verbose mode +TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}" if VERBOSE else None # tqdm bar format +LOGGING_NAME = "ultralytics" +MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"]) # environment booleans +ARM64 = platform.machine() in {"arm64", "aarch64"} # ARM64 booleans +PYTHON_VERSION = platform.python_version() +TORCH_VERSION = torch.__version__ +TORCHVISION_VERSION = importlib.metadata.version("torchvision") # faster than importing torchvision +IS_VSCODE = os.environ.get("TERM_PROGRAM", False) == "vscode" +HELP_MSG = """ + Examples for running Ultralytics: + + 1. Install the ultralytics package: + + pip install ultralytics + + 2. Use the Python SDK: + + from ultralytics import YOLO + + # Load a model + model = YOLO("yolo11n.yaml") # build a new model from scratch + model = YOLO("yolo11n.pt") # load a pretrained model (recommended for training) + + # Use the model + results = model.train(data="coco8.yaml", epochs=3) # train the model + results = model.val() # evaluate model performance on the validation set + results = model("https://ultralytics.com/images/bus.jpg") # predict on an image + success = model.export(format="onnx") # export the model to ONNX format + + 3. 
Use the command line interface (CLI): + + Ultralytics 'yolo' CLI commands use the following syntax: + + yolo TASK MODE ARGS + + Where TASK (optional) is one of [detect, segment, classify, pose, obb] + MODE (required) is one of [train, val, predict, export, track, benchmark] + ARGS (optional) are any number of custom "arg=value" pairs like "imgsz=320" that override defaults. + See all ARGS at https://docs.ultralytics.com/usage/cfg or with "yolo cfg" + + - Train a detection model for 10 epochs with an initial learning_rate of 0.01 + yolo detect train data=coco8.yaml model=yolo11n.pt epochs=10 lr0=0.01 + + - Predict a YouTube video using a pretrained segmentation model at image size 320: + yolo segment predict model=yolo11n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320 + + - Val a pretrained detection model at batch-size 1 and image size 640: + yolo detect val model=yolo11n.pt data=coco8.yaml batch=1 imgsz=640 + + - Export a YOLO11n classification model to ONNX format at image size 224 by 128 (no TASK required) + yolo export model=yolo11n-cls.pt format=onnx imgsz=224,128 + + - Run special commands: + yolo help + yolo checks + yolo version + yolo settings + yolo copy-cfg + yolo cfg + + Docs: https://docs.ultralytics.com + Community: https://community.ultralytics.com + GitHub: https://github.com/ultralytics/ultralytics + """ + +# Settings and Environment Variables +torch.set_printoptions(linewidth=320, precision=4, profile="default") +np.set_printoptions(linewidth=320, formatter={"float_kind": "{:11.5g}".format}) # format short g, %precision=5 +cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader) +os.environ["NUMEXPR_MAX_THREADS"] = str(NUM_THREADS) # NumExpr max threads +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # for deterministic training to avoid CUDA warning +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # suppress verbose TF compiler warnings in Colab +os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR" # suppress "NNPACK.cpp could not initialize NNPACK" warnings +os.environ["KINETO_LOG_LEVEL"] = "5" # suppress verbose PyTorch profiler output when computing FLOPs + + +class TQDM(tqdm_original): + """ + A custom TQDM progress bar class that extends the original tqdm functionality. + + This class modifies the behavior of the original tqdm progress bar based on global settings and provides + additional customization options. + + Attributes: + disable (bool): Whether to disable the progress bar. Determined by the global VERBOSE setting and + any passed 'disable' argument. + bar_format (str): The format string for the progress bar. Uses the global TQDM_BAR_FORMAT if not + explicitly set. + + Methods: + __init__: Initializes the TQDM object with custom settings. + + Examples: + >>> from ultralytics.utils import TQDM + >>> for i in TQDM(range(100)): + ... # Your processing code here + ... pass + """ + + def __init__(self, *args, **kwargs): + """ + Initializes a custom TQDM progress bar. + + This class extends the original tqdm class to provide customized behavior for Ultralytics projects. + + Args: + *args (Any): Variable length argument list to be passed to the original tqdm constructor. + **kwargs (Any): Arbitrary keyword arguments to be passed to the original tqdm constructor. + + Notes: + - The progress bar is disabled if VERBOSE is False or if 'disable' is explicitly set to True in kwargs. + - The default bar format is set to TQDM_BAR_FORMAT unless overridden in kwargs. 
+ + Examples: + >>> from ultralytics.utils import TQDM + >>> for i in TQDM(range(100)): + ... # Your code here + ... pass + """ + kwargs["disable"] = not VERBOSE or kwargs.get("disable", False) # logical 'and' with default value if passed + kwargs.setdefault("bar_format", TQDM_BAR_FORMAT) # override default value if passed + super().__init__(*args, **kwargs) + + +class SimpleClass: + """ + A simple base class for creating objects with string representations of their attributes. + + This class provides a foundation for creating objects that can be easily printed or represented as strings, + showing all their non-callable attributes. It's useful for debugging and introspection of object states. + + Methods: + __str__: Returns a human-readable string representation of the object. + __repr__: Returns a machine-readable string representation of the object. + __getattr__: Provides a custom attribute access error message with helpful information. + + Examples: + >>> class MyClass(SimpleClass): + ... def __init__(self): + ... self.x = 10 + ... self.y = "hello" + >>> obj = MyClass() + >>> print(obj) + __main__.MyClass object with attributes: + + x: 10 + y: 'hello' + + Notes: + - This class is designed to be subclassed. It provides a convenient way to inspect object attributes. + - The string representation includes the module and class name of the object. + - Callable attributes and attributes starting with an underscore are excluded from the string representation. + """ + + def __str__(self): + """Return a human-readable string representation of the object.""" + attr = [] + for a in dir(self): + v = getattr(self, a) + if not callable(v) and not a.startswith("_"): + if isinstance(v, SimpleClass): + # Display only the module and class name for subclasses + s = f"{a}: {v.__module__}.{v.__class__.__name__} object" + else: + s = f"{a}: {repr(v)}" + attr.append(s) + return f"{self.__module__}.{self.__class__.__name__} object with attributes:\n\n" + "\n".join(attr) + + def __repr__(self): + """Return a machine-readable string representation of the object.""" + return self.__str__() + + def __getattr__(self, attr): + """Custom attribute access error message with helpful information.""" + name = self.__class__.__name__ + raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") + + +class IterableSimpleNamespace(SimpleNamespace): + """ + An iterable SimpleNamespace class that provides enhanced functionality for attribute access and iteration. + + This class extends the SimpleNamespace class with additional methods for iteration, string representation, + and attribute access. It is designed to be used as a convenient container for storing and accessing + configuration parameters. + + Methods: + __iter__: Returns an iterator of key-value pairs from the namespace's attributes. + __str__: Returns a human-readable string representation of the object. + __getattr__: Provides a custom attribute access error message with helpful information. + get: Retrieves the value of a specified key, or a default value if the key doesn't exist. + + Examples: + >>> cfg = IterableSimpleNamespace(a=1, b=2, c=3) + >>> for k, v in cfg: + ... print(f"{k}: {v}") + a: 1 + b: 2 + c: 3 + >>> print(cfg) + a=1 + b=2 + c=3 + >>> cfg.get("b") + 2 + >>> cfg.get("d", "default") + 'default' + + Notes: + This class is particularly useful for storing configuration parameters in a more accessible + and iterable format compared to a standard dictionary. 
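+
+        The default Ultralytics configuration object (DEFAULT_CFG, defined later in this module) is an instance
+        of this class built from the default.yaml configuration file.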
+ """ + + def __iter__(self): + """Return an iterator of key-value pairs from the namespace's attributes.""" + return iter(vars(self).items()) + + def __str__(self): + """Return a human-readable string representation of the object.""" + return "\n".join(f"{k}={v}" for k, v in vars(self).items()) + + def __getattr__(self, attr): + """Custom attribute access error message with helpful information.""" + name = self.__class__.__name__ + raise AttributeError( + f""" + '{name}' object has no attribute '{attr}'. This may be caused by a modified or out of date ultralytics + 'default.yaml' file.\nPlease update your code with 'pip install -U ultralytics' and if necessary replace + {DEFAULT_CFG_PATH} with the latest version from + https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/default.yaml + """ + ) + + def get(self, key, default=None): + """Return the value of the specified key if it exists; otherwise, return the default value.""" + return getattr(self, key, default) + + +def plt_settings(rcparams=None, backend="Agg"): + """ + Decorator to temporarily set rc parameters and the backend for a plotting function. + + Example: + decorator: @plt_settings({"font.size": 12}) + context manager: with plt_settings({"font.size": 12}): + + Args: + rcparams (dict): Dictionary of rc parameters to set. + backend (str, optional): Name of the backend to use. Defaults to 'Agg'. + + Returns: + (Callable): Decorated function with temporarily set rc parameters and backend. This decorator can be + applied to any function that needs to have specific matplotlib rc parameters and backend for its execution. + """ + if rcparams is None: + rcparams = {"font.size": 11} + + def decorator(func): + """Decorator to apply temporary rc parameters and backend to a function.""" + + def wrapper(*args, **kwargs): + """Sets rc parameters and backend, calls the original function, and restores the settings.""" + original_backend = plt.get_backend() + switch = backend.lower() != original_backend.lower() + if switch: + plt.close("all") # auto-close()ing of figures upon backend switching is deprecated since 3.8 + plt.switch_backend(backend) + + # Plot with backend and always revert to original backend + try: + with plt.rc_context(rcparams): + result = func(*args, **kwargs) + finally: + if switch: + plt.close("all") + plt.switch_backend(original_backend) + return result + + return wrapper + + return decorator + + +def set_logging(name="LOGGING_NAME", verbose=True): + """ + Sets up logging with UTF-8 encoding and configurable verbosity. + + This function configures logging for the Ultralytics library, setting the appropriate logging level and + formatter based on the verbosity flag and the current process rank. It handles special cases for Windows + environments where UTF-8 encoding might not be the default. + + Args: + name (str): Name of the logger. Defaults to "LOGGING_NAME". + verbose (bool): Flag to set logging level to INFO if True, ERROR otherwise. Defaults to True. + + Examples: + >>> set_logging(name="ultralytics", verbose=True) + >>> logger = logging.getLogger("ultralytics") + >>> logger.info("This is an info message") + + Notes: + - On Windows, this function attempts to reconfigure stdout to use UTF-8 encoding if possible. + - If reconfiguration is not possible, it falls back to a custom formatter that handles non-UTF-8 environments. + - The function sets up a StreamHandler with the appropriate formatter and level. + - The logger's propagate flag is set to False to prevent duplicate logging in parent loggers. 
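+        - In multi-GPU (DDP) training, ranks other than 0 use the ERROR level so only the main process logs INFO.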
+ """ + level = logging.INFO if verbose and RANK in {-1, 0} else logging.ERROR # rank in world for Multi-GPU trainings + + # Configure the console (stdout) encoding to UTF-8, with checks for compatibility + formatter = logging.Formatter("%(message)s") # Default formatter + if WINDOWS and hasattr(sys.stdout, "encoding") and sys.stdout.encoding != "utf-8": + + class CustomFormatter(logging.Formatter): + def format(self, record): + """Sets up logging with UTF-8 encoding and configurable verbosity.""" + return emojis(super().format(record)) + + try: + # Attempt to reconfigure stdout to use UTF-8 encoding if possible + if hasattr(sys.stdout, "reconfigure"): + sys.stdout.reconfigure(encoding="utf-8") + # For environments where reconfigure is not available, wrap stdout in a TextIOWrapper + elif hasattr(sys.stdout, "buffer"): + import io + + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") + else: + formatter = CustomFormatter("%(message)s") + except Exception as e: + print(f"Creating custom formatter for non UTF-8 environments due to {e}") + formatter = CustomFormatter("%(message)s") + + # Create and configure the StreamHandler with the appropriate formatter and level + stream_handler = logging.StreamHandler(sys.stdout) + stream_handler.setFormatter(formatter) + stream_handler.setLevel(level) + + # Set up the logger + logger = logging.getLogger(name) + logger.setLevel(level) + logger.addHandler(stream_handler) + logger.propagate = False + return logger + + +# Set logger +LOGGER = set_logging(LOGGING_NAME, verbose=VERBOSE) # define globally (used in train.py, val.py, predict.py, etc.) +for logger in "sentry_sdk", "urllib3.connectionpool": + logging.getLogger(logger).setLevel(logging.CRITICAL + 1) + + +def emojis(string=""): + """Return platform-dependent emoji-safe version of string.""" + return string.encode().decode("ascii", "ignore") if WINDOWS else string + + +class ThreadingLocked: + """ + A decorator class for ensuring thread-safe execution of a function or method. This class can be used as a decorator + to make sure that if the decorated function is called from multiple threads, only one thread at a time will be able + to execute the function. + + Attributes: + lock (threading.Lock): A lock object used to manage access to the decorated function. + + Example: + ```python + from ultralytics.utils import ThreadingLocked + + @ThreadingLocked() + def my_function(): + # Your code here + ``` + """ + + def __init__(self): + """Initializes the decorator class for thread-safe execution of a function or method.""" + self.lock = threading.Lock() + + def __call__(self, f): + """Run thread-safe execution of function or method.""" + from functools import wraps + + @wraps(f) + def decorated(*args, **kwargs): + """Applies thread-safety to the decorated function or method.""" + with self.lock: + return f(*args, **kwargs) + + return decorated + + +def yaml_save(file="data.yaml", data=None, header=""): + """ + Save YAML data to a file. + + Args: + file (str, optional): File name. Default is 'data.yaml'. + data (dict): Data to save in YAML format. + header (str, optional): YAML header to add. + + Returns: + (None): Data is saved to the specified file. 
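+
+    Examples:
+        Save a small dictionary to a YAML file (illustrative values):
+        >>> yaml_save("data.yaml", {"epochs": 3, "lr0": 0.01}, header="# Example header\n")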
+ """ + if data is None: + data = {} + file = Path(file) + if not file.parent.exists(): + # Create parent directories if they don't exist + file.parent.mkdir(parents=True, exist_ok=True) + + # Convert Path objects to strings + valid_types = int, float, str, bool, list, tuple, dict, type(None) + for k, v in data.items(): + if not isinstance(v, valid_types): + data[k] = str(v) + + # Dump data to file in YAML format + with open(file, "w", errors="ignore", encoding="utf-8") as f: + if header: + f.write(header) + yaml.safe_dump(data, f, sort_keys=False, allow_unicode=True) + + +def yaml_load(file="data.yaml", append_filename=False): + """ + Load YAML data from a file. + + Args: + file (str, optional): File name. Default is 'data.yaml'. + append_filename (bool): Add the YAML filename to the YAML dictionary. Default is False. + + Returns: + (dict): YAML data and file name. + """ + assert Path(file).suffix in {".yaml", ".yml"}, f"Attempting to load non-YAML file {file} with yaml_load()" + with open(file, errors="ignore", encoding="utf-8") as f: + s = f.read() # string + + # Remove special characters + if not s.isprintable(): + s = re.sub(r"[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+", "", s) + + # Add YAML filename to dict and return + data = yaml.safe_load(s) or {} # always return a dict (yaml.safe_load() may return None for empty files) + if append_filename: + data["yaml_file"] = str(file) + return data + + +def yaml_print(yaml_file: Union[str, Path, dict]) -> None: + """ + Pretty prints a YAML file or a YAML-formatted dictionary. + + Args: + yaml_file: The file path of the YAML file or a YAML-formatted dictionary. + + Returns: + (None) + """ + yaml_dict = yaml_load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file + dump = yaml.dump(yaml_dict, sort_keys=False, allow_unicode=True, width=float("inf")) + LOGGER.info(f"Printing '{colorstr('bold', 'black', yaml_file)}'\n\n{dump}") + + +# Default configuration +DEFAULT_CFG_DICT = yaml_load(DEFAULT_CFG_PATH) +DEFAULT_SOL_DICT = yaml_load(DEFAULT_SOL_CFG_PATH) # Ultralytics solutions configuration +for k, v in DEFAULT_CFG_DICT.items(): + if isinstance(v, str) and v.lower() == "none": + DEFAULT_CFG_DICT[k] = None +DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys() +DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT) + + +def read_device_model() -> str: + """ + Reads the device model information from the system and caches it for quick access. Used by is_jetson() and + is_raspberrypi(). + + Returns: + (str): Kernel release information. + """ + return platform.release().lower() + + +def is_ubuntu() -> bool: + """ + Check if the OS is Ubuntu. + + Returns: + (bool): True if OS is Ubuntu, False otherwise. + """ + try: + with open("/etc/os-release") as f: + return "ID=ubuntu" in f.read() + except FileNotFoundError: + return False + + +def is_colab(): + """ + Check if the current script is running inside a Google Colab notebook. + + Returns: + (bool): True if running inside a Colab notebook, False otherwise. + """ + return "COLAB_RELEASE_TAG" in os.environ or "COLAB_BACKEND_VERSION" in os.environ + + +def is_kaggle(): + """ + Check if the current script is running inside a Kaggle kernel. + + Returns: + (bool): True if running inside a Kaggle kernel, False otherwise. + """ + return os.environ.get("PWD") == "/kaggle/working" and os.environ.get("KAGGLE_URL_BASE") == "https://www.kaggle.com" + + +def is_jupyter(): + """ + Check if the current script is running inside a Jupyter Notebook. 
+ + Returns: + (bool): True if running inside a Jupyter Notebook, False otherwise. + + Note: + - Only works on Colab and Kaggle, other environments like Jupyterlab and Paperspace are not reliably detectable. + - "get_ipython" in globals() method suffers false positives when IPython package installed manually. + """ + return IS_COLAB or IS_KAGGLE + + +def is_runpod(): + """ + Check if the current script is running inside a RunPod container. + + Returns: + (bool): True if running in RunPod, False otherwise. + """ + return "RUNPOD_POD_ID" in os.environ + + +def is_docker() -> bool: + """ + Determine if the script is running inside a Docker container. + + Returns: + (bool): True if the script is running inside a Docker container, False otherwise. + """ + try: + with open("/proc/self/cgroup") as f: + return "docker" in f.read() + except Exception: + return False + + +def is_raspberrypi() -> bool: + """ + Determines if the Python environment is running on a Raspberry Pi by checking the device model information. + + Returns: + (bool): True if running on a Raspberry Pi, False otherwise. + """ + return "rpi" in DEVICE_MODEL + + +def is_jetson() -> bool: + """ + Determines if the Python environment is running on an NVIDIA Jetson device by checking the device model information. + + Returns: + (bool): True if running on an NVIDIA Jetson device, False otherwise. + """ + return "tegra" in DEVICE_MODEL + + +def is_online() -> bool: + """ + Check internet connectivity by attempting to connect to a known online host. + + Returns: + (bool): True if connection is successful, False otherwise. + """ + try: + assert str(os.getenv("YOLO_OFFLINE", "")).lower() != "true" # check if ENV var YOLO_OFFLINE="True" + import socket + + for dns in ("1.1.1.1", "8.8.8.8"): # check Cloudflare and Google DNS + socket.create_connection(address=(dns, 80), timeout=2.0).close() + return True + except Exception: + return False + + +def is_pip_package(filepath: str = __name__) -> bool: + """ + Determines if the file at the given filepath is part of a pip package. + + Args: + filepath (str): The filepath to check. + + Returns: + (bool): True if the file is part of a pip package, False otherwise. + """ + import importlib.util + + # Get the spec for the module + spec = importlib.util.find_spec(filepath) + + # Return whether the spec is not None and the origin is not None (indicating it is a package) + return spec is not None and spec.origin is not None + + +def is_dir_writeable(dir_path: Union[str, Path]) -> bool: + """ + Check if a directory is writeable. + + Args: + dir_path (str | Path): The path to the directory. + + Returns: + (bool): True if the directory is writeable, False otherwise. + """ + return os.access(str(dir_path), os.W_OK) + + +def is_pytest_running(): + """ + Determines whether pytest is currently running or not. + + Returns: + (bool): True if pytest is running, False otherwise. + """ + return ("PYTEST_CURRENT_TEST" in os.environ) or ("pytest" in sys.modules) or ("pytest" in Path(ARGV[0]).stem) + + +def is_github_action_running() -> bool: + """ + Determine if the current environment is a GitHub Actions runner. + + Returns: + (bool): True if the current environment is a GitHub Actions runner, False otherwise. + """ + return "GITHUB_ACTIONS" in os.environ and "GITHUB_WORKFLOW" in os.environ and "RUNNER_OS" in os.environ + + +def get_git_dir(): + """ + Determines whether the current file is part of a git repository and if so, returns the repository root directory. 
If + the current file is not part of a git repository, returns None. + + Returns: + (Path | None): Git root directory if found or None if not found. + """ + for d in Path(__file__).parents: + if (d / ".git").is_dir(): + return d + + +def is_git_dir(): + """ + Determines whether the current file is part of a git repository. If the current file is not part of a git + repository, returns None. + + Returns: + (bool): True if current file is part of a git repository. + """ + return GIT_DIR is not None + + +def get_git_origin_url(): + """ + Retrieves the origin URL of a git repository. + + Returns: + (str | None): The origin URL of the git repository or None if not git directory. + """ + if IS_GIT_DIR: + try: + origin = subprocess.check_output(["git", "config", "--get", "remote.origin.url"]) + return origin.decode().strip() + except subprocess.CalledProcessError: + return None + + +def get_git_branch(): + """ + Returns the current git branch name. If not in a git repository, returns None. + + Returns: + (str | None): The current git branch name or None if not a git directory. + """ + if IS_GIT_DIR: + try: + origin = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + return origin.decode().strip() + except subprocess.CalledProcessError: + return None + + +def get_default_args(func): + """ + Returns a dictionary of default arguments for a function. + + Args: + func (callable): The function to inspect. + + Returns: + (dict): A dictionary where each key is a parameter name, and each value is the default value of that parameter. + """ + signature = inspect.signature(func) + return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty} + + +def get_ubuntu_version(): + """ + Retrieve the Ubuntu version if the OS is Ubuntu. + + Returns: + (str): Ubuntu version or None if not an Ubuntu OS. + """ + if is_ubuntu(): + try: + with open("/etc/os-release") as f: + return re.search(r'VERSION_ID="(\d+\.\d+)"', f.read())[1] + except (FileNotFoundError, AttributeError): + return None + + +def get_user_config_dir(sub_dir="Ultralytics"): + """ + Return the appropriate config directory based on the environment operating system. + + Args: + sub_dir (str): The name of the subdirectory to create. + + Returns: + (Path): The path to the user config directory. + """ + if WINDOWS: + path = Path.home() / "AppData" / "Roaming" / sub_dir + elif MACOS: # macOS + path = Path.home() / "Library" / "Application Support" / sub_dir + elif LINUX: + path = Path.home() / ".config" / sub_dir + else: + raise ValueError(f"Unsupported operating system: {platform.system()}") + + # GCP and AWS lambda fix, only /tmp is writeable + if not is_dir_writeable(path.parent): + LOGGER.warning( + f"WARNING ⚠️ user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD." + "Alternatively you can define a YOLO_CONFIG_DIR environment variable for this path." 
+ ) + path = Path("/tmp") / sub_dir if is_dir_writeable("/tmp") else Path().cwd() / sub_dir + + # Create the subdirectory if it does not exist + path.mkdir(parents=True, exist_ok=True) + + return path + + +# Define constants (required below) +DEVICE_MODEL = read_device_model() # is_jetson() and is_raspberrypi() depend on this constant +ONLINE = is_online() +IS_COLAB = is_colab() +IS_KAGGLE = is_kaggle() +IS_DOCKER = is_docker() +IS_JETSON = is_jetson() +IS_JUPYTER = is_jupyter() +IS_PIP_PACKAGE = is_pip_package() +IS_RASPBERRYPI = is_raspberrypi() +GIT_DIR = get_git_dir() +IS_GIT_DIR = is_git_dir() +USER_CONFIG_DIR = Path(os.getenv("YOLO_CONFIG_DIR") or get_user_config_dir()) # Ultralytics settings dir +SETTINGS_FILE = USER_CONFIG_DIR / "settings.json" + + +def colorstr(*input): + r""" + Colors a string based on the provided color and style arguments. Utilizes ANSI escape codes. + See https://en.wikipedia.org/wiki/ANSI_escape_code for more details. + + This function can be called in two ways: + - colorstr('color', 'style', 'your string') + - colorstr('your string') + + In the second form, 'blue' and 'bold' will be applied by default. + + Args: + *input (str | Path): A sequence of strings where the first n-1 strings are color and style arguments, + and the last string is the one to be colored. + + Supported Colors and Styles: + Basic Colors: 'black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white' + Bright Colors: 'bright_black', 'bright_red', 'bright_green', 'bright_yellow', + 'bright_blue', 'bright_magenta', 'bright_cyan', 'bright_white' + Misc: 'end', 'bold', 'underline' + + Returns: + (str): The input string wrapped with ANSI escape codes for the specified color and style. + + Examples: + >>> colorstr("blue", "bold", "hello world") + >>> "\033[34m\033[1mhello world\033[0m" + """ + *args, string = input if len(input) > 1 else ("blue", "bold", input[0]) # color arguments, string + colors = { + "black": "\033[30m", # basic colors + "red": "\033[31m", + "green": "\033[32m", + "yellow": "\033[33m", + "blue": "\033[34m", + "magenta": "\033[35m", + "cyan": "\033[36m", + "white": "\033[37m", + "bright_black": "\033[90m", # bright colors + "bright_red": "\033[91m", + "bright_green": "\033[92m", + "bright_yellow": "\033[93m", + "bright_blue": "\033[94m", + "bright_magenta": "\033[95m", + "bright_cyan": "\033[96m", + "bright_white": "\033[97m", + "end": "\033[0m", # misc + "bold": "\033[1m", + "underline": "\033[4m", + } + return "".join(colors[x] for x in args) + f"{string}" + colors["end"] + + +def remove_colorstr(input_string): + """ + Removes ANSI escape codes from a string, effectively un-coloring it. + + Args: + input_string (str): The string to remove color and style from. + + Returns: + (str): A new string with all ANSI escape codes removed. + + Examples: + >>> remove_colorstr(colorstr("blue", "bold", "hello world")) + >>> "hello world" + """ + ansi_escape = re.compile(r"\x1B\[[0-9;]*[A-Za-z]") + return ansi_escape.sub("", input_string) + + +class TryExcept(contextlib.ContextDecorator): + """ + Ultralytics TryExcept class. Use as @TryExcept() decorator or 'with TryExcept():' context manager. 
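+    Exceptions raised inside the decorated function or context block are reported (when verbose=True) and then
+    suppressed, so execution continues after the block instead of propagating the error.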
+ + Examples: + As a decorator: + >>> @TryExcept(msg="Error occurred in func", verbose=True) + >>> def func(): + >>> # Function logic here + >>> pass + + As a context manager: + >>> with TryExcept(msg="Error occurred in block", verbose=True): + >>> # Code block here + >>> pass + """ + + def __init__(self, msg="", verbose=True): + """Initialize TryExcept class with optional message and verbosity settings.""" + self.msg = msg + self.verbose = verbose + + def __enter__(self): + """Executes when entering TryExcept context, initializes instance.""" + pass + + def __exit__(self, exc_type, value, traceback): + """Defines behavior when exiting a 'with' block, prints error message if necessary.""" + if self.verbose and value: + print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}")) + return True + + +class Retry(contextlib.ContextDecorator): + """ + Retry class for function execution with exponential backoff. + + Can be used as a decorator to retry a function on exceptions, up to a specified number of times with an + exponentially increasing delay between retries. + + Examples: + Example usage as a decorator: + >>> @Retry(times=3, delay=2) + >>> def test_func(): + >>> # Replace with function logic that may raise exceptions + >>> return True + """ + + def __init__(self, times=3, delay=2): + """Initialize Retry class with specified number of retries and delay.""" + self.times = times + self.delay = delay + self._attempts = 0 + + def __call__(self, func): + """Decorator implementation for Retry with exponential backoff.""" + + def wrapped_func(*args, **kwargs): + """Applies retries to the decorated function or method.""" + self._attempts = 0 + while self._attempts < self.times: + try: + return func(*args, **kwargs) + except Exception as e: + self._attempts += 1 + print(f"Retry {self._attempts}/{self.times} failed: {e}") + if self._attempts >= self.times: + raise e + time.sleep(self.delay * (2**self._attempts)) # exponential backoff delay + + return wrapped_func + + +def threaded(func): + """ + Multi-threads a target function by default and returns the thread or function result. + + Use as @threaded decorator. The function runs in a separate thread unless 'threaded=False' is passed. + """ + + def wrapper(*args, **kwargs): + """Multi-threads a given function based on 'threaded' kwarg and returns the thread or function result.""" + if kwargs.pop("threaded", True): # run in thread + thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True) + thread.start() + return thread + else: + return func(*args, **kwargs) + + return wrapper + + +def set_sentry(): + """ + Initialize the Sentry SDK for error tracking and reporting. Only used if sentry_sdk package is installed and + sync=True in settings. Run 'yolo settings' to see and update settings. + + Conditions required to send errors (ALL conditions must be met or no errors will be reported): + - sentry_sdk package is installed + - sync=True in YOLO settings + - pytest is not running + - running in a pip package installation + - running in a non-git directory + - running with rank -1 or 0 + - online environment + - CLI used to run package (checked with 'yolo' as the name of the main CLI command) + + The function also configures Sentry SDK to ignore KeyboardInterrupt and FileNotFoundError exceptions and to exclude + events with 'out of memory' in their exception message. + + Additionally, the function sets custom tags and user information for Sentry events. 
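+
+    Examples:
+        Reporting can be switched off by disabling the 'sync' setting, one of the required conditions above.
+        This is a CLI sketch; the 'yolo settings' command is described by SettingsManager further below:
+        >>> # Run in a terminal: yolo settings sync=False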
+ """ + if ( + not SETTINGS["sync"] + or RANK not in {-1, 0} + or Path(ARGV[0]).name != "yolo" + or TESTS_RUNNING + or not ONLINE + or not IS_PIP_PACKAGE + or IS_GIT_DIR + ): + return + # If sentry_sdk package is not installed then return and do not use Sentry + try: + import sentry_sdk # noqa + except ImportError: + return + + def before_send(event, hint): + """ + Modify the event before sending it to Sentry based on specific exception types and messages. + + Args: + event (dict): The event dictionary containing information about the error. + hint (dict): A dictionary containing additional information about the error. + + Returns: + dict: The modified event or None if the event should not be sent to Sentry. + """ + if "exc_info" in hint: + exc_type, exc_value, _ = hint["exc_info"] + if exc_type in {KeyboardInterrupt, FileNotFoundError} or "out of memory" in str(exc_value): + return None # do not send event + + event["tags"] = { + "sys_argv": ARGV[0], + "sys_argv_name": Path(ARGV[0]).name, + "install": "git" if IS_GIT_DIR else "pip" if IS_PIP_PACKAGE else "other", + "os": ENVIRONMENT, + } + return event + + sentry_sdk.init( + dsn="https://888e5a0778212e1d0314c37d4b9aae5d@o4504521589325824.ingest.us.sentry.io/4504521592406016", + debug=False, + auto_enabling_integrations=False, + traces_sample_rate=1.0, + release=__version__, + environment="runpod" if is_runpod() else "production", + before_send=before_send, + ignore_errors=[KeyboardInterrupt, FileNotFoundError], + ) + sentry_sdk.set_user({"id": SETTINGS["uuid"]}) # SHA-256 anonymized UUID hash + + +class JSONDict(dict): + """ + A dictionary-like class that provides JSON persistence for its contents. + + This class extends the built-in dictionary to automatically save its contents to a JSON file whenever they are + modified. It ensures thread-safe operations using a lock. + + Attributes: + file_path (Path): The path to the JSON file used for persistence. + lock (threading.Lock): A lock object to ensure thread-safe operations. + + Methods: + _load: Loads the data from the JSON file into the dictionary. + _save: Saves the current state of the dictionary to the JSON file. + __setitem__: Stores a key-value pair and persists it to disk. + __delitem__: Removes an item and updates the persistent storage. + update: Updates the dictionary and persists changes. + clear: Clears all entries and updates the persistent storage. + + Examples: + >>> json_dict = JSONDict("data.json") + >>> json_dict["key"] = "value" + >>> print(json_dict["key"]) + value + >>> del json_dict["key"] + >>> json_dict.update({"new_key": "new_value"}) + >>> json_dict.clear() + """ + + def __init__(self, file_path: Union[str, Path] = "data.json"): + """Initialize a JSONDict object with a specified file path for JSON persistence.""" + super().__init__() + self.file_path = Path(file_path) + self.lock = Lock() + self._load() + + def _load(self): + """Load the data from the JSON file into the dictionary.""" + try: + if self.file_path.exists(): + with open(self.file_path) as f: + self.update(json.load(f)) + except json.JSONDecodeError: + print(f"Error decoding JSON from {self.file_path}. 
Starting with an empty dictionary.") + except Exception as e: + print(f"Error reading from {self.file_path}: {e}") + + def _save(self): + """Save the current state of the dictionary to the JSON file.""" + try: + self.file_path.parent.mkdir(parents=True, exist_ok=True) + with open(self.file_path, "w") as f: + json.dump(dict(self), f, indent=2, default=self._json_default) + except Exception as e: + print(f"Error writing to {self.file_path}: {e}") + + @staticmethod + def _json_default(obj): + """Handle JSON serialization of Path objects.""" + if isinstance(obj, Path): + return str(obj) + raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable") + + def __setitem__(self, key, value): + """Store a key-value pair and persist to disk.""" + with self.lock: + super().__setitem__(key, value) + self._save() + + def __delitem__(self, key): + """Remove an item and update the persistent storage.""" + with self.lock: + super().__delitem__(key) + self._save() + + def __str__(self): + """Return a pretty-printed JSON string representation of the dictionary.""" + contents = json.dumps(dict(self), indent=2, ensure_ascii=False, default=self._json_default) + return f'JSONDict("{self.file_path}"):\n{contents}' + + def update(self, *args, **kwargs): + """Update the dictionary and persist changes.""" + with self.lock: + super().update(*args, **kwargs) + self._save() + + def clear(self): + """Clear all entries and update the persistent storage.""" + with self.lock: + super().clear() + self._save() + + +class SettingsManager(JSONDict): + """ + SettingsManager class for managing and persisting Ultralytics settings. + + This class extends JSONDict to provide JSON persistence for settings, ensuring thread-safe operations and default + values. It validates settings on initialization and provides methods to update or reset settings. + + Attributes: + file (Path): The path to the JSON file used for persistence. + version (str): The version of the settings schema. + defaults (Dict): A dictionary containing default settings. + help_msg (str): A help message for users on how to view and update settings. + + Methods: + _validate_settings: Validates the current settings and resets if necessary. + update: Updates settings, validating keys and types. + reset: Resets the settings to default and saves them. 
+ + Examples: + Initialize and update settings: + >>> settings = SettingsManager() + >>> settings.update(runs_dir="/new/runs/dir") + >>> print(settings["runs_dir"]) + /new/runs/dir + """ + + def __init__(self, file=SETTINGS_FILE, version="0.0.6"): + """Initializes the SettingsManager with default settings and loads user settings.""" + import hashlib + + from ultralytics.utils.torch_utils import torch_distributed_zero_first + + root = GIT_DIR or Path() + datasets_root = (root.parent if GIT_DIR and is_dir_writeable(root.parent) else root).resolve() + + self.file = Path(file) + self.version = version + self.defaults = { + "settings_version": version, # Settings schema version + "datasets_dir": str(datasets_root / "datasets"), # Datasets directory + "weights_dir": str(root / "weights"), # Model weights directory + "runs_dir": str(root / "runs"), # Experiment runs directory + "uuid": hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(), # SHA-256 anonymized UUID hash + "sync": True, # Enable synchronization + "api_key": "", # Ultralytics API Key + "openai_api_key": "", # OpenAI API Key + "clearml": True, # ClearML integration + "comet": True, # Comet integration + "dvc": True, # DVC integration + "hub": True, # Ultralytics HUB integration + "mlflow": True, # MLflow integration + "neptune": True, # Neptune integration + "raytune": True, # Ray Tune integration + "tensorboard": True, # TensorBoard logging + "wandb": False, # Weights & Biases logging + "vscode_msg": True, # VSCode messaging + } + + self.help_msg = ( + f"\nView Ultralytics Settings with 'yolo settings' or at '{self.file}'" + "\nUpdate Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. " + "For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings." + ) + + with torch_distributed_zero_first(RANK): + super().__init__(self.file) + + if not self.file.exists() or not self: # Check if file doesn't exist or is empty + LOGGER.info(f"Creating new Ultralytics Settings v{version} file ✅ {self.help_msg}") + self.reset() + + self._validate_settings() + + def _validate_settings(self): + """Validate the current settings and reset if necessary.""" + correct_keys = set(self.keys()) == set(self.defaults.keys()) + correct_types = all(isinstance(self.get(k), type(v)) for k, v in self.defaults.items()) + correct_version = self.get("settings_version", "") == self.version + + if not (correct_keys and correct_types and correct_version): + LOGGER.warning( + "WARNING ⚠️ Ultralytics settings reset to default values. This may be due to a possible problem " + f"with your settings or a recent ultralytics package update. {self.help_msg}" + ) + self.reset() + + if self.get("datasets_dir") == self.get("runs_dir"): + LOGGER.warning( + f"WARNING ⚠️ Ultralytics setting 'datasets_dir: {self.get('datasets_dir')}' " + f"must be different than 'runs_dir: {self.get('runs_dir')}'. " + f"Please change one to avoid possible issues during training. {self.help_msg}" + ) + + def __setitem__(self, key, value): + """Updates one key: value pair.""" + self.update({key: value}) + + def update(self, *args, **kwargs): + """Updates settings, validating keys and types.""" + for arg in args: + if isinstance(arg, dict): + kwargs.update(arg) + for k, v in kwargs.items(): + if k not in self.defaults: + raise KeyError(f"No Ultralytics setting '{k}'. {self.help_msg}") + t = type(self.defaults[k]) + if not isinstance(v, t): + raise TypeError( + f"Ultralytics setting '{k}' must be '{t.__name__}' type, not '{type(v).__name__}'. 
{self.help_msg}" + ) + super().update(*args, **kwargs) + + def reset(self): + """Resets the settings to default and saves them.""" + self.clear() + self.update(self.defaults) + + +def deprecation_warn(arg, new_arg=None): + """Issue a deprecation warning when a deprecated argument is used, suggesting an updated argument.""" + msg = f"WARNING ⚠️ '{arg}' is deprecated and will be removed in in the future." + if new_arg is not None: + msg += f" Use '{new_arg}' instead." + LOGGER.warning(msg) + + +def clean_url(url): + """Strip auth from URL, i.e. https://url.com/file.txt?auth -> https://url.com/file.txt.""" + url = Path(url).as_posix().replace(":/", "://") # Pathlib turns :// -> :/, as_posix() for Windows + return unquote(url).split("?")[0] # '%2F' to '/', split https://url.com/file.txt?auth + + +def url2file(url): + """Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt.""" + return Path(clean_url(url)).name + + +def vscode_msg(ext="ultralytics.ultralytics-snippets") -> str: + """Display a message to install Ultralytics-Snippets for VS Code if not already installed.""" + path = (USER_CONFIG_DIR.parents[2] if WINDOWS else USER_CONFIG_DIR.parents[1]) / ".vscode/extensions" + obs_file = path / ".obsolete" # file tracks uninstalled extensions, while source directory remains + installed = any(path.glob(f"{ext}*")) and ext not in (obs_file.read_text("utf-8") if obs_file.exists() else "") + url = "https://docs.ultralytics.com/integrations/vscode" + return "" if installed else f"{colorstr('VS Code:')} view Ultralytics VS Code Extension ⚡ at {url}" + + +# Run below code on utils init ------------------------------------------------------------------------------------ + +# Check first-install steps +PREFIX = colorstr("Ultralytics: ") +SETTINGS = SettingsManager() # initialize settings +PERSISTENT_CACHE = JSONDict(USER_CONFIG_DIR / "persistent_cache.json") # initialize persistent cache +DATASETS_DIR = Path(SETTINGS["datasets_dir"]) # global datasets directory +WEIGHTS_DIR = Path(SETTINGS["weights_dir"]) # global weights directory +RUNS_DIR = Path(SETTINGS["runs_dir"]) # global runs directory +ENVIRONMENT = ( + "Colab" + if IS_COLAB + else "Kaggle" + if IS_KAGGLE + else "Jupyter" + if IS_JUPYTER + else "Docker" + if IS_DOCKER + else platform.system() +) +TESTS_RUNNING = is_pytest_running() or is_github_action_running() +set_sentry() + +# Apply monkey patches +from ultralytics.utils.patches import imread, imshow, imwrite, torch_load, torch_save + +torch.load = torch_load +torch.save = torch_save +if WINDOWS: + # Apply cv2 patches for non-ASCII and non-UTF characters in image paths + cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow diff --git a/ultralytics/utils/autobatch.py b/ultralytics/utils/autobatch.py new file mode 100644 index 0000000000000000000000000000000000000000..085001a153cb14f89c07fb6b7bcc8d6b6c3a4564 --- /dev/null +++ b/ultralytics/utils/autobatch.py @@ -0,0 +1,106 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +"""Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch.""" + +import os +from copy import deepcopy + +import numpy as np +import torch + +from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr +from ultralytics.utils.torch_utils import autocast, profile + + +def check_train_batch_size(model, imgsz=640, amp=True, batch=-1, max_num_obj=1): + """ + Compute optimal YOLO training batch size using the autobatch() function. 
+ + Args: + model (torch.nn.Module): YOLO model to check batch size for. + imgsz (int, optional): Image size used for training. + amp (bool, optional): Use automatic mixed precision if True. + batch (float, optional): Fraction of GPU memory to use. If -1, use default. + max_num_obj (int, optional): The maximum number of objects from dataset. + + Returns: + (int): Optimal batch size computed using the autobatch() function. + + Note: + If 0.0 < batch < 1.0, it's used as the fraction of GPU memory to use. + Otherwise, a default fraction of 0.6 is used. + """ + with autocast(enabled=amp): + return autobatch( + deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6, max_num_obj=max_num_obj + ) + + +def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch, max_num_obj=1): + """ + Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory. + + Args: + model (torch.nn.module): YOLO model to compute batch size for. + imgsz (int, optional): The image size used as input for the YOLO model. Defaults to 640. + fraction (float, optional): The fraction of available CUDA memory to use. Defaults to 0.60. + batch_size (int, optional): The default batch size to use if an error is detected. Defaults to 16. + max_num_obj (int, optional): The maximum number of objects from dataset. + + Returns: + (int): The optimal batch size. + """ + # Check device + prefix = colorstr("AutoBatch: ") + LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz} at {fraction * 100}% CUDA memory utilization.") + device = next(model.parameters()).device # get model device + if device.type in {"cpu", "mps"}: + LOGGER.info(f"{prefix} ⚠️ intended for CUDA devices, using default batch-size {batch_size}") + return batch_size + if torch.backends.cudnn.benchmark: + LOGGER.info(f"{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}") + return batch_size + + # Inspect CUDA memory + gb = 1 << 30 # bytes to GiB (1024 ** 3) + d = f"CUDA:{os.getenv('CUDA_VISIBLE_DEVICES', '0').strip()[0]}" # 'CUDA:0' + properties = torch.cuda.get_device_properties(device) # device properties + t = properties.total_memory / gb # GiB total + r = torch.cuda.memory_reserved(device) / gb # GiB reserved + a = torch.cuda.memory_allocated(device) / gb # GiB allocated + f = t - (r + a) # GiB free + LOGGER.info(f"{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free") + + # Profile batch sizes + batch_sizes = [1, 2, 4, 8, 16] if t < 16 else [1, 2, 4, 8, 16, 32, 64] + try: + img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes] + results = profile(img, model, n=1, device=device, max_num_obj=max_num_obj) + + # Fit a solution + xy = [ + [x, y[2]] + for i, (x, y) in enumerate(zip(batch_sizes, results)) + if y # valid result + and isinstance(y[2], (int, float)) # is numeric + and 0 < y[2] < t # between 0 and GPU limit + and (i == 0 or not results[i - 1] or y[2] > results[i - 1][2]) # first item or increasing memory + ] + fit_x, fit_y = zip(*xy) if xy else ([], []) + p = np.polyfit(np.log(fit_x), np.log(fit_y), deg=1) # first-degree polynomial fit in log space + b = int(round(np.exp((np.log(f * fraction) - p[1]) / p[0]))) # y intercept (optimal batch size) + if None in results: # some sizes failed + i = results.index(None) # first fail index + if b >= batch_sizes[i]: # y intercept above failure point + b = batch_sizes[max(i - 1, 0)] # select prior safe point + if b < 1 or b > 1024: # b outside 
of safe range + LOGGER.info(f"{prefix}WARNING ⚠️ batch={b} outside safe range, using default batch-size {batch_size}.") + b = batch_size + + fraction = (np.exp(np.polyval(p, np.log(b))) + r + a) / t # predicted fraction + LOGGER.info(f"{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅") + return b + except Exception as e: + LOGGER.warning(f"{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.") + return batch_size + finally: + torch.cuda.empty_cache() diff --git a/ultralytics/utils/benchmarks.py b/ultralytics/utils/benchmarks.py new file mode 100644 index 0000000000000000000000000000000000000000..a161f3022e38386364d1220fd3a14e0438ab9582 --- /dev/null +++ b/ultralytics/utils/benchmarks.py @@ -0,0 +1,583 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +""" +Benchmark a YOLO model formats for speed and accuracy. + +Usage: + from ultralytics.utils.benchmarks import ProfileModels, benchmark + ProfileModels(['yolov8n.yaml', 'yolov8s.yaml']).profile() + benchmark(model='yolov8n.pt', imgsz=160) + +Format | `format=argument` | Model +--- | --- | --- +PyTorch | - | yolov8n.pt +TorchScript | `torchscript` | yolov8n.torchscript +ONNX | `onnx` | yolov8n.onnx +OpenVINO | `openvino` | yolov8n_openvino_model/ +TensorRT | `engine` | yolov8n.engine +CoreML | `coreml` | yolov8n.mlpackage +TensorFlow SavedModel | `saved_model` | yolov8n_saved_model/ +TensorFlow GraphDef | `pb` | yolov8n.pb +TensorFlow Lite | `tflite` | yolov8n.tflite +TensorFlow Edge TPU | `edgetpu` | yolov8n_edgetpu.tflite +TensorFlow.js | `tfjs` | yolov8n_web_model/ +PaddlePaddle | `paddle` | yolov8n_paddle_model/ +MNN | `mnn` | yolov8n.mnn +NCNN | `ncnn` | yolov8n_ncnn_model/ +""" + +import glob +import os +import platform +import re +import shutil +import time +from pathlib import Path + +import numpy as np +import torch.cuda +import yaml + +from ultralytics import YOLO, YOLOWorld +from ultralytics.cfg import TASK2DATA, TASK2METRIC +from ultralytics.engine.exporter import export_formats +from ultralytics.utils import ARM64, ASSETS, IS_JETSON, IS_RASPBERRYPI, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR +from ultralytics.utils.checks import IS_PYTHON_3_12, check_requirements, check_yolo +from ultralytics.utils.downloads import safe_download +from ultralytics.utils.files import file_size +from ultralytics.utils.torch_utils import get_cpu_info, select_device + + +def benchmark( + model=WEIGHTS_DIR / "yolo11n.pt", + data=None, + imgsz=160, + half=False, + int8=False, + device="cpu", + verbose=False, + eps=1e-3, +): + """ + Benchmark a YOLO model across different formats for speed and accuracy. + + Args: + model (str | Path): Path to the model file or directory. + data (str | None): Dataset to evaluate on, inherited from TASK2DATA if not passed. + imgsz (int): Image size for the benchmark. + half (bool): Use half-precision for the model if True. + int8 (bool): Use int8-precision for the model if True. + device (str): Device to run the benchmark on, either 'cpu' or 'cuda'. + verbose (bool | float): If True or a float, assert benchmarks pass with given metric. + eps (float): Epsilon value for divide by zero prevention. + + Returns: + (pandas.DataFrame): A pandas DataFrame with benchmark results for each format, including file size, metric, + and inference time. 
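+
+    Notes:
+        With verbose=False (the default), formats that fail a compatibility check, export, or inference step are
+        still listed in the returned DataFrame with a ❌/❎ status and None metrics instead of raising an error.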
+ + Examples: + Benchmark a YOLO model with default settings: + >>> from ultralytics.utils.benchmarks import benchmark + >>> benchmark(model="yolo11n.pt", imgsz=640) + """ + import pandas as pd # scope for faster 'import ultralytics' + + pd.options.display.max_columns = 10 + pd.options.display.width = 120 + device = select_device(device, verbose=False) + if isinstance(model, (str, Path)): + model = YOLO(model) + is_end2end = getattr(model.model.model[-1], "end2end", False) + + y = [] + t0 = time.time() + for i, (name, format, suffix, cpu, gpu, _) in enumerate(zip(*export_formats().values())): + emoji, filename = "❌", None # export defaults + try: + # Checks + if i == 7: # TF GraphDef + assert model.task != "obb", "TensorFlow GraphDef not supported for OBB task" + elif i == 9: # Edge TPU + assert LINUX and not ARM64, "Edge TPU export only supported on non-aarch64 Linux" + elif i in {5, 10}: # CoreML and TF.js + assert MACOS or LINUX, "CoreML and TF.js export only supported on macOS and Linux" + assert not IS_RASPBERRYPI, "CoreML and TF.js export not supported on Raspberry Pi" + assert not IS_JETSON, "CoreML and TF.js export not supported on NVIDIA Jetson" + if i in {5}: # CoreML + assert not IS_PYTHON_3_12, "CoreML not supported on Python 3.12" + if i in {6, 7, 8}: # TF SavedModel, TF GraphDef, and TFLite + assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet" + if i in {9, 10}: # TF EdgeTPU and TF.js + assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet" + if i == 11: # Paddle + assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet" + assert not is_end2end, "End-to-end models not supported by PaddlePaddle yet" + assert LINUX or MACOS, "Windows Paddle exports not supported yet" + if i == 12: # MNN + assert not isinstance(model, YOLOWorld), "YOLOWorldv2 MNN exports not supported yet" + if i == 13: # NCNN + assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet" + if i == 14: # IMX + assert not is_end2end + assert not isinstance(model, YOLOWorld), "YOLOWorldv2 IMX exports not supported" + assert model.task == "detect", "IMX only supported for detection task" + assert "C2f" in model.__str__(), "IMX only supported for YOLOv8" + if "cpu" in device.type: + assert cpu, "inference not supported on CPU" + if "cuda" in device.type: + assert gpu, "inference not supported on GPU" + + # Export + if format == "-": + filename = model.ckpt_path or model.cfg + exported_model = model # PyTorch format + else: + filename = model.export(imgsz=imgsz, format=format, half=half, int8=int8, device=device, verbose=False) + exported_model = YOLO(filename, task=model.task) + assert suffix in str(filename), "export failed" + emoji = "❎" # indicates export succeeded + + # Predict + assert model.task != "pose" or i != 7, "GraphDef Pose inference is not supported" + assert i not in {9, 10}, "inference not supported" # Edge TPU and TF.js are unsupported + assert i != 5 or platform.system() == "Darwin", "inference only supported on macOS>=10.13" # CoreML + if i in {13}: + assert not is_end2end, "End-to-end torch.topk operation is not supported for NCNN prediction yet" + exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half) + + # Validate + data = data or TASK2DATA[model.task] # task to dataset, i.e. coco8.yaml for task=detect + key = TASK2METRIC[model.task] # task to metric, i.e. 
metrics/mAP50-95(B) for task=detect + results = exported_model.val( + data=data, batch=1, imgsz=imgsz, plots=False, device=device, half=half, int8=int8, verbose=False + ) + metric, speed = results.results_dict[key], results.speed["inference"] + fps = round(1000 / (speed + eps), 2) # frames per second + y.append([name, "✅", round(file_size(filename), 1), round(metric, 4), round(speed, 2), fps]) + except Exception as e: + if verbose: + assert type(e) is AssertionError, f"Benchmark failure for {name}: {e}" + LOGGER.warning(f"ERROR ❌️ Benchmark failure for {name}: {e}") + y.append([name, emoji, round(file_size(filename), 1), None, None, None]) # mAP, t_inference + + # Print results + check_yolo(device=device) # print system info + df = pd.DataFrame(y, columns=["Format", "Status❔", "Size (MB)", key, "Inference time (ms/im)", "FPS"]) + + name = Path(model.ckpt_path).name + s = f"\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n" + LOGGER.info(s) + with open("benchmarks.log", "a", errors="ignore", encoding="utf-8") as f: + f.write(s) + + if verbose and isinstance(verbose, float): + metrics = df[key].array # values to compare to floor + floor = verbose # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n + assert all(x > floor for x in metrics if pd.notna(x)), f"Benchmark failure: metric(s) < floor {floor}" + + return df + + +class RF100Benchmark: + """Benchmark YOLO model performance across various formats for speed and accuracy.""" + + def __init__(self): + """Initialize the RF100Benchmark class for benchmarking YOLO model performance across various formats.""" + self.ds_names = [] + self.ds_cfg_list = [] + self.rf = None + self.val_metrics = ["class", "images", "targets", "precision", "recall", "map50", "map95"] + + def set_key(self, api_key): + """ + Set Roboflow API key for processing. + + Args: + api_key (str): The API key. + + Examples: + Set the Roboflow API key for accessing datasets: + >>> benchmark = RF100Benchmark() + >>> benchmark.set_key("your_roboflow_api_key") + """ + check_requirements("roboflow") + from roboflow import Roboflow + + self.rf = Roboflow(api_key=api_key) + + def parse_dataset(self, ds_link_txt="datasets_links.txt"): + """ + Parse dataset links and download datasets. + + Args: + ds_link_txt (str): Path to the file containing dataset links. + + Examples: + >>> benchmark = RF100Benchmark() + >>> benchmark.set_key("api_key") + >>> benchmark.parse_dataset("datasets_links.txt") + """ + (shutil.rmtree("rf-100"), os.mkdir("rf-100")) if os.path.exists("rf-100") else os.mkdir("rf-100") + os.chdir("rf-100") + os.mkdir("ultralytics-benchmarks") + safe_download("https://github.com/ultralytics/assets/releases/download/v0.0.0/datasets_links.txt") + + with open(ds_link_txt) as file: + for line in file: + try: + _, url, workspace, project, version = re.split("/+", line.strip()) + self.ds_names.append(project) + proj_version = f"{project}-{version}" + if not Path(proj_version).exists(): + self.rf.workspace(workspace).project(project).version(version).download("yolov8") + else: + print("Dataset already downloaded.") + self.ds_cfg_list.append(Path.cwd() / proj_version / "data.yaml") + except Exception: + continue + + return self.ds_names, self.ds_cfg_list + + @staticmethod + def fix_yaml(path): + """ + Fixes the train and validation paths in a given YAML file. + + Args: + path (str): Path to the YAML file to be fixed. 
+ + Examples: + >>> RF100Benchmark.fix_yaml("path/to/data.yaml") + """ + with open(path) as file: + yaml_data = yaml.safe_load(file) + yaml_data["train"] = "train/images" + yaml_data["val"] = "valid/images" + with open(path, "w") as file: + yaml.safe_dump(yaml_data, file) + + def evaluate(self, yaml_path, val_log_file, eval_log_file, list_ind): + """ + Evaluate model performance on validation results. + + Args: + yaml_path (str): Path to the YAML configuration file. + val_log_file (str): Path to the validation log file. + eval_log_file (str): Path to the evaluation log file. + list_ind (int): Index of the current dataset in the list. + + Returns: + (float): The mean average precision (mAP) value for the evaluated model. + + Examples: + Evaluate a model on a specific dataset + >>> benchmark = RF100Benchmark() + >>> benchmark.evaluate("path/to/data.yaml", "path/to/val_log.txt", "path/to/eval_log.txt", 0) + """ + skip_symbols = ["🚀", "⚠️", "💡", "❌"] + with open(yaml_path) as stream: + class_names = yaml.safe_load(stream)["names"] + with open(val_log_file, encoding="utf-8") as f: + lines = f.readlines() + eval_lines = [] + for line in lines: + if any(symbol in line for symbol in skip_symbols): + continue + entries = line.split(" ") + entries = list(filter(lambda val: val != "", entries)) + entries = [e.strip("\n") for e in entries] + eval_lines.extend( + { + "class": entries[0], + "images": entries[1], + "targets": entries[2], + "precision": entries[3], + "recall": entries[4], + "map50": entries[5], + "map95": entries[6], + } + for e in entries + if e in class_names or (e == "all" and "(AP)" not in entries and "(AR)" not in entries) + ) + map_val = 0.0 + if len(eval_lines) > 1: + print("There's more dicts") + for lst in eval_lines: + if lst["class"] == "all": + map_val = lst["map50"] + else: + print("There's only one dict res") + map_val = [res["map50"] for res in eval_lines][0] + + with open(eval_log_file, "a") as f: + f.write(f"{self.ds_names[list_ind]}: {map_val}\n") + + +class ProfileModels: + """ + ProfileModels class for profiling different models on ONNX and TensorRT. + + This class profiles the performance of different models, returning results such as model speed and FLOPs. + + Attributes: + paths (List[str]): Paths of the models to profile. + num_timed_runs (int): Number of timed runs for the profiling. + num_warmup_runs (int): Number of warmup runs before profiling. + min_time (float): Minimum number of seconds to profile for. + imgsz (int): Image size used in the models. + half (bool): Flag to indicate whether to use FP16 half-precision for TensorRT profiling. + trt (bool): Flag to indicate whether to profile using TensorRT. + device (torch.device): Device used for profiling. + + Methods: + profile: Profiles the models and prints the result. + + Examples: + Profile models and print results + >>> from ultralytics.utils.benchmarks import ProfileModels + >>> profiler = ProfileModels(["yolov8n.yaml", "yolov8s.yaml"], imgsz=640) + >>> profiler.profile() + """ + + def __init__( + self, + paths: list, + num_timed_runs=100, + num_warmup_runs=10, + min_time=60, + imgsz=640, + half=True, + trt=True, + device=None, + ): + """ + Initialize the ProfileModels class for profiling models. + + Args: + paths (List[str]): List of paths of the models to be profiled. + num_timed_runs (int): Number of timed runs for the profiling. + num_warmup_runs (int): Number of warmup runs before the actual profiling starts. + min_time (float): Minimum time in seconds for profiling a model. 
+ imgsz (int): Size of the image used during profiling. + half (bool): Flag to indicate whether to use FP16 half-precision for TensorRT profiling. + trt (bool): Flag to indicate whether to profile using TensorRT. + device (torch.device | None): Device used for profiling. If None, it is determined automatically. + + Notes: + FP16 'half' argument option removed for ONNX as slower on CPU than FP32. + + Examples: + Initialize and profile models + >>> from ultralytics.utils.benchmarks import ProfileModels + >>> profiler = ProfileModels(["yolov8n.yaml", "yolov8s.yaml"], imgsz=640) + >>> profiler.profile() + """ + self.paths = paths + self.num_timed_runs = num_timed_runs + self.num_warmup_runs = num_warmup_runs + self.min_time = min_time + self.imgsz = imgsz + self.half = half + self.trt = trt # run TensorRT profiling + self.device = device or torch.device(0 if torch.cuda.is_available() else "cpu") + + def profile(self): + """Profiles YOLO models for speed and accuracy across various formats including ONNX and TensorRT.""" + files = self.get_files() + + if not files: + print("No matching *.pt or *.onnx files found.") + return + + table_rows = [] + output = [] + for file in files: + engine_file = file.with_suffix(".engine") + if file.suffix in {".pt", ".yaml", ".yml"}: + model = YOLO(str(file)) + model.fuse() # to report correct params and GFLOPs in model.info() + model_info = model.info() + if self.trt and self.device.type != "cpu" and not engine_file.is_file(): + engine_file = model.export( + format="engine", + half=self.half, + imgsz=self.imgsz, + device=self.device, + verbose=False, + ) + onnx_file = model.export( + format="onnx", + imgsz=self.imgsz, + device=self.device, + verbose=False, + ) + elif file.suffix == ".onnx": + model_info = self.get_onnx_model_info(file) + onnx_file = file + else: + continue + + t_engine = self.profile_tensorrt_model(str(engine_file)) + t_onnx = self.profile_onnx_model(str(onnx_file)) + table_rows.append(self.generate_table_row(file.stem, t_onnx, t_engine, model_info)) + output.append(self.generate_results_dict(file.stem, t_onnx, t_engine, model_info)) + + self.print_table(table_rows) + return output + + def get_files(self): + """Returns a list of paths for all relevant model files given by the user.""" + files = [] + for path in self.paths: + path = Path(path) + if path.is_dir(): + extensions = ["*.pt", "*.onnx", "*.yaml"] + files.extend([file for ext in extensions for file in glob.glob(str(path / ext))]) + elif path.suffix in {".pt", ".yaml", ".yml"}: # add non-existing + files.append(str(path)) + else: + files.extend(glob.glob(str(path))) + + print(f"Profiling: {sorted(files)}") + return [Path(file) for file in sorted(files)] + + @staticmethod + def get_onnx_model_info(onnx_file: str): + """Extracts metadata from an ONNX model file including parameters, GFLOPs, and input shape.""" + return 0.0, 0.0, 0.0, 0.0 # return (num_layers, num_params, num_gradients, num_flops) + + @staticmethod + def iterative_sigma_clipping(data, sigma=2, max_iters=3): + """Applies iterative sigma clipping to data to remove outliers based on specified sigma and iteration count.""" + data = np.array(data) + for _ in range(max_iters): + mean, std = np.mean(data), np.std(data) + clipped_data = data[(data > mean - sigma * std) & (data < mean + sigma * std)] + if len(clipped_data) == len(data): + break + data = clipped_data + return data + + def profile_tensorrt_model(self, engine_file: str, eps: float = 1e-3): + """Profiles YOLO model performance with TensorRT, measuring average run time 
and standard deviation.""" + if not self.trt or not Path(engine_file).is_file(): + return 0.0, 0.0 + + # Model and input + model = YOLO(engine_file) + input_data = np.zeros((self.imgsz, self.imgsz, 3), dtype=np.uint8) # use uint8 for Classify + + # Warmup runs + elapsed = 0.0 + for _ in range(3): + start_time = time.time() + for _ in range(self.num_warmup_runs): + model(input_data, imgsz=self.imgsz, verbose=False) + elapsed = time.time() - start_time + + # Compute number of runs as higher of min_time or num_timed_runs + num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs * 50) + + # Timed runs + run_times = [] + for _ in TQDM(range(num_runs), desc=engine_file): + results = model(input_data, imgsz=self.imgsz, verbose=False) + run_times.append(results[0].speed["inference"]) # Convert to milliseconds + + run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3) # sigma clipping + return np.mean(run_times), np.std(run_times) + + def profile_onnx_model(self, onnx_file: str, eps: float = 1e-3): + """Profiles an ONNX model, measuring average inference time and standard deviation across multiple runs.""" + check_requirements("onnxruntime") + import onnxruntime as ort + + # Session with either 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider' + sess_options = ort.SessionOptions() + sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL + sess_options.intra_op_num_threads = 8 # Limit the number of threads + sess = ort.InferenceSession(onnx_file, sess_options, providers=["CPUExecutionProvider"]) + + input_tensor = sess.get_inputs()[0] + input_type = input_tensor.type + dynamic = not all(isinstance(dim, int) and dim >= 0 for dim in input_tensor.shape) # dynamic input shape + input_shape = (1, 3, self.imgsz, self.imgsz) if dynamic else input_tensor.shape + + # Mapping ONNX datatype to numpy datatype + if "float16" in input_type: + input_dtype = np.float16 + elif "float" in input_type: + input_dtype = np.float32 + elif "double" in input_type: + input_dtype = np.float64 + elif "int64" in input_type: + input_dtype = np.int64 + elif "int32" in input_type: + input_dtype = np.int32 + else: + raise ValueError(f"Unsupported ONNX datatype {input_type}") + + input_data = np.random.rand(*input_shape).astype(input_dtype) + input_name = input_tensor.name + output_name = sess.get_outputs()[0].name + + # Warmup runs + elapsed = 0.0 + for _ in range(3): + start_time = time.time() + for _ in range(self.num_warmup_runs): + sess.run([output_name], {input_name: input_data}) + elapsed = time.time() - start_time + + # Compute number of runs as higher of min_time or num_timed_runs + num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs) + + # Timed runs + run_times = [] + for _ in TQDM(range(num_runs), desc=onnx_file): + start_time = time.time() + sess.run([output_name], {input_name: input_data}) + run_times.append((time.time() - start_time) * 1000) # Convert to milliseconds + + run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=5) # sigma clipping + return np.mean(run_times), np.std(run_times) + + def generate_table_row(self, model_name, t_onnx, t_engine, model_info): + """Generates a table row string with model performance metrics including inference times and model details.""" + layers, params, gradients, flops = model_info + return ( + f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.1f}±{t_onnx[1]:.1f} ms | 
{t_engine[0]:.1f}±"
+            f"{t_engine[1]:.1f} ms | {params / 1e6:.1f} | {flops:.1f} |"
+        )
+
+    @staticmethod
+    def generate_results_dict(model_name, t_onnx, t_engine, model_info):
+        """Generates a dictionary of profiling results including model name, parameters, GFLOPs, and speed metrics."""
+        layers, params, gradients, flops = model_info
+        return {
+            "model/name": model_name,
+            "model/parameters": params,
+            "model/GFLOPs": round(flops, 3),
+            "model/speed_ONNX(ms)": round(t_onnx[0], 3),
+            "model/speed_TensorRT(ms)": round(t_engine[0], 3),
+        }
+
+    @staticmethod
+    def print_table(table_rows):
+        """Prints a formatted table of model profiling results, including speed and accuracy metrics."""
+        gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "GPU"
+        headers = [
+            "Model",
+            "size<br><sup>(pixels)",
+            "mAP<sup>val<br>50-95",
+            f"Speed<br><sup>CPU ({get_cpu_info()}) ONNX<br>(ms)",
+            f"Speed<br><sup>{gpu} TensorRT<br>(ms)",
+            "params<br><sup>(M)",
+            "FLOPs<br><sup>
(B)", + ] + header = "|" + "|".join(f" {h} " for h in headers) + "|" + separator = "|" + "|".join("-" * (len(h) + 2) for h in headers) + "|" + + print(f"\n\n{header}") + print(separator) + for row in table_rows: + print(row) diff --git a/ultralytics/utils/callbacks/__init__.py b/ultralytics/utils/callbacks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..920cc4fad9d6fa2b7b99707d2fe33941e4612e5b --- /dev/null +++ b/ultralytics/utils/callbacks/__init__.py @@ -0,0 +1,5 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from .base import add_integration_callbacks, default_callbacks, get_default_callbacks + +__all__ = "add_integration_callbacks", "default_callbacks", "get_default_callbacks" diff --git a/ultralytics/utils/callbacks/base.py b/ultralytics/utils/callbacks/base.py new file mode 100644 index 0000000000000000000000000000000000000000..11e0a8979e0c253110fb132b54cd09ee8fa524f5 --- /dev/null +++ b/ultralytics/utils/callbacks/base.py @@ -0,0 +1,217 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +"""Base callbacks.""" + +from collections import defaultdict +from copy import deepcopy + +# Trainer callbacks ---------------------------------------------------------------------------------------------------- + + +def on_pretrain_routine_start(trainer): + """Called before the pretraining routine starts.""" + pass + + +def on_pretrain_routine_end(trainer): + """Called after the pretraining routine ends.""" + pass + + +def on_train_start(trainer): + """Called when the training starts.""" + pass + + +def on_train_epoch_start(trainer): + """Called at the start of each training epoch.""" + pass + + +def on_train_batch_start(trainer): + """Called at the start of each training batch.""" + pass + + +def optimizer_step(trainer): + """Called when the optimizer takes a step.""" + pass + + +def on_before_zero_grad(trainer): + """Called before the gradients are set to zero.""" + pass + + +def on_train_batch_end(trainer): + """Called at the end of each training batch.""" + pass + + +def on_train_epoch_end(trainer): + """Called at the end of each training epoch.""" + pass + + +def on_fit_epoch_end(trainer): + """Called at the end of each fit epoch (train + val).""" + pass + + +def on_model_save(trainer): + """Called when the model is saved.""" + pass + + +def on_train_end(trainer): + """Called when the training ends.""" + pass + + +def on_params_update(trainer): + """Called when the model parameters are updated.""" + pass + + +def teardown(trainer): + """Called during the teardown of the training process.""" + pass + + +# Validator callbacks -------------------------------------------------------------------------------------------------- + + +def on_val_start(validator): + """Called when the validation starts.""" + pass + + +def on_val_batch_start(validator): + """Called at the start of each validation batch.""" + pass + + +def on_val_batch_end(validator): + """Called at the end of each validation batch.""" + pass + + +def on_val_end(validator): + """Called when the validation ends.""" + pass + + +# Predictor callbacks -------------------------------------------------------------------------------------------------- + + +def on_predict_start(predictor): + """Called when the prediction starts.""" + pass + + +def on_predict_batch_start(predictor): + """Called at the start of each prediction batch.""" + pass + + +def on_predict_batch_end(predictor): + """Called at the end of each prediction batch.""" + pass + + +def 
on_predict_postprocess_end(predictor): + """Called after the post-processing of the prediction ends.""" + pass + + +def on_predict_end(predictor): + """Called when the prediction ends.""" + pass + + +# Exporter callbacks --------------------------------------------------------------------------------------------------- + + +def on_export_start(exporter): + """Called when the model export starts.""" + pass + + +def on_export_end(exporter): + """Called when the model export ends.""" + pass + + +default_callbacks = { + # Run in trainer + "on_pretrain_routine_start": [on_pretrain_routine_start], + "on_pretrain_routine_end": [on_pretrain_routine_end], + "on_train_start": [on_train_start], + "on_train_epoch_start": [on_train_epoch_start], + "on_train_batch_start": [on_train_batch_start], + "optimizer_step": [optimizer_step], + "on_before_zero_grad": [on_before_zero_grad], + "on_train_batch_end": [on_train_batch_end], + "on_train_epoch_end": [on_train_epoch_end], + "on_fit_epoch_end": [on_fit_epoch_end], # fit = train + val + "on_model_save": [on_model_save], + "on_train_end": [on_train_end], + "on_params_update": [on_params_update], + "teardown": [teardown], + # Run in validator + "on_val_start": [on_val_start], + "on_val_batch_start": [on_val_batch_start], + "on_val_batch_end": [on_val_batch_end], + "on_val_end": [on_val_end], + # Run in predictor + "on_predict_start": [on_predict_start], + "on_predict_batch_start": [on_predict_batch_start], + "on_predict_postprocess_end": [on_predict_postprocess_end], + "on_predict_batch_end": [on_predict_batch_end], + "on_predict_end": [on_predict_end], + # Run in exporter + "on_export_start": [on_export_start], + "on_export_end": [on_export_end], +} + + +def get_default_callbacks(): + """ + Return a copy of the default_callbacks dictionary with lists as default values. + + Returns: + (defaultdict): A defaultdict with keys from default_callbacks and empty lists as default values. + """ + return defaultdict(list, deepcopy(default_callbacks)) + + +def add_integration_callbacks(instance): + """ + Add integration callbacks from various sources to the instance's callbacks. + + Args: + instance (Trainer, Predictor, Validator, Exporter): An object with a 'callbacks' attribute that is a dictionary + of callback lists. 
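+
+    Examples:
+        Minimal sketch, assuming `trainer` is an object exposing a `callbacks` dict of lists:
+        >>> from ultralytics.utils.callbacks.base import add_integration_callbacks
+        >>> add_integration_callbacks(trainer)  # appends HUB and logger callbacks in place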
+ """ + # Load HUB callbacks + from .hub import callbacks as hub_cb + + callbacks_list = [hub_cb] + + # Load training callbacks + if "Trainer" in instance.__class__.__name__: + from .clearml import callbacks as clear_cb + from .comet import callbacks as comet_cb + from .dvc import callbacks as dvc_cb + from .mlflow import callbacks as mlflow_cb + from .neptune import callbacks as neptune_cb + from .raytune import callbacks as tune_cb + from .tensorboard import callbacks as tb_cb + from .wb import callbacks as wb_cb + + callbacks_list.extend([clear_cb, comet_cb, dvc_cb, mlflow_cb, neptune_cb, tune_cb, tb_cb, wb_cb]) + + # Add the callbacks to the callbacks dictionary + for callbacks in callbacks_list: + for k, v in callbacks.items(): + if v not in instance.callbacks[k]: + instance.callbacks[k].append(v) diff --git a/ultralytics/utils/callbacks/clearml.py b/ultralytics/utils/callbacks/clearml.py new file mode 100644 index 0000000000000000000000000000000000000000..5afc7a3659f84f0c56ae79eabd20d3537932ef78 --- /dev/null +++ b/ultralytics/utils/callbacks/clearml.py @@ -0,0 +1,153 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING + +try: + assert not TESTS_RUNNING # do not log pytest + assert SETTINGS["clearml"] is True # verify integration is enabled + import clearml + from clearml import Task + + assert hasattr(clearml, "__version__") # verify package is not directory + +except (ImportError, AssertionError): + clearml = None + + +def _log_debug_samples(files, title="Debug Samples") -> None: + """ + Log files (images) as debug samples in the ClearML task. + + Args: + files (list): A list of file paths in PosixPath format. + title (str): A title that groups together images with the same values. + """ + import re + + if task := Task.current_task(): + for f in files: + if f.exists(): + it = re.search(r"_batch(\d+)", f.name) + iteration = int(it.groups()[0]) if it else 0 + task.get_logger().report_image( + title=title, series=f.name.replace(it.group(), ""), local_path=str(f), iteration=iteration + ) + + +def _log_plot(title, plot_path) -> None: + """ + Log an image as a plot in the plot section of ClearML. + + Args: + title (str): The title of the plot. + plot_path (str): The path to the saved image file. + """ + import matplotlib.image as mpimg + import matplotlib.pyplot as plt + + img = mpimg.imread(plot_path) + fig = plt.figure() + ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks + ax.imshow(img) + + Task.current_task().get_logger().report_matplotlib_figure( + title=title, series="", figure=fig, report_interactive=False + ) + + +def on_pretrain_routine_start(trainer): + """Runs at start of pretraining routine; initializes and connects/ logs task to ClearML.""" + try: + if task := Task.current_task(): + # WARNING: make sure the automatic pytorch and matplotlib bindings are disabled! 
+ # We are logging these plots and model files manually in the integration + from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO + from clearml.binding.matplotlib_bind import PatchedMatplotlib + + PatchPyTorchModelIO.update_current_task(None) + PatchedMatplotlib.update_current_task(None) + else: + task = Task.init( + project_name=trainer.args.project or "Ultralytics", + task_name=trainer.args.name, + tags=["Ultralytics"], + output_uri=True, + reuse_last_task_id=False, + auto_connect_frameworks={"pytorch": False, "matplotlib": False}, + ) + LOGGER.warning( + "ClearML Initialized a new task. If you want to run remotely, " + "please add clearml-init and connect your arguments before initializing YOLO." + ) + task.connect(vars(trainer.args), name="General") + except Exception as e: + LOGGER.warning(f"WARNING ⚠️ ClearML installed but not initialized correctly, not logging this run. {e}") + + +def on_train_epoch_end(trainer): + """Logs debug samples for the first epoch of YOLO training and report current training progress.""" + if task := Task.current_task(): + # Log debug samples + if trainer.epoch == 1: + _log_debug_samples(sorted(trainer.save_dir.glob("train_batch*.jpg")), "Mosaic") + # Report the current training progress + for k, v in trainer.label_loss_items(trainer.tloss, prefix="train").items(): + task.get_logger().report_scalar("train", k, v, iteration=trainer.epoch) + for k, v in trainer.lr.items(): + task.get_logger().report_scalar("lr", k, v, iteration=trainer.epoch) + + +def on_fit_epoch_end(trainer): + """Reports model information to logger at the end of an epoch.""" + if task := Task.current_task(): + # You should have access to the validation bboxes under jdict + task.get_logger().report_scalar( + title="Epoch Time", series="Epoch Time", value=trainer.epoch_time, iteration=trainer.epoch + ) + for k, v in trainer.metrics.items(): + task.get_logger().report_scalar("val", k, v, iteration=trainer.epoch) + if trainer.epoch == 0: + from ultralytics.utils.torch_utils import model_info_for_loggers + + for k, v in model_info_for_loggers(trainer).items(): + task.get_logger().report_single_value(k, v) + + +def on_val_end(validator): + """Logs validation results including labels and predictions.""" + if Task.current_task(): + # Log val_labels and val_pred + _log_debug_samples(sorted(validator.save_dir.glob("val*.jpg")), "Validation") + + +def on_train_end(trainer): + """Logs final model and its name on training completion.""" + if task := Task.current_task(): + # Log final results, CM matrix + PR plots + files = [ + "results.png", + "confusion_matrix.png", + "confusion_matrix_normalized.png", + *(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")), + ] + files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter + for f in files: + _log_plot(title=f.stem, plot_path=f) + # Report final metrics + for k, v in trainer.validator.metrics.results_dict.items(): + task.get_logger().report_single_value(k, v) + # Log the final model + task.update_output_model(model_path=str(trainer.best), model_name=trainer.args.name, auto_delete_file=False) + + +callbacks = ( + { + "on_pretrain_routine_start": on_pretrain_routine_start, + "on_train_epoch_end": on_train_epoch_end, + "on_fit_epoch_end": on_fit_epoch_end, + "on_val_end": on_val_end, + "on_train_end": on_train_end, + } + if clearml + else {} +) diff --git a/ultralytics/utils/callbacks/comet.py b/ultralytics/utils/callbacks/comet.py new file mode 100644 index 
0000000000000000000000000000000000000000..910e3c424d28e04c2c69107c4bab82a3420d76f2 --- /dev/null +++ b/ultralytics/utils/callbacks/comet.py @@ -0,0 +1,397 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops +from ultralytics.utils.metrics import ClassifyMetrics, DetMetrics, OBBMetrics, PoseMetrics, SegmentMetrics + +try: + assert not TESTS_RUNNING # do not log pytest + assert SETTINGS["comet"] is True # verify integration is enabled + import comet_ml + + assert hasattr(comet_ml, "__version__") # verify package is not directory + + import os + from pathlib import Path + + # Ensures certain logging functions only run for supported tasks + COMET_SUPPORTED_TASKS = ["detect"] + + # Names of plots created by Ultralytics that are logged to Comet + CONFUSION_MATRIX_PLOT_NAMES = "confusion_matrix", "confusion_matrix_normalized" + EVALUATION_PLOT_NAMES = "F1_curve", "P_curve", "R_curve", "PR_curve" + LABEL_PLOT_NAMES = "labels", "labels_correlogram" + SEGMENT_METRICS_PLOT_PREFIX = "Box", "Mask" + POSE_METRICS_PLOT_PREFIX = "Box", "Pose" + + _comet_image_prediction_count = 0 + +except (ImportError, AssertionError): + comet_ml = None + + +def _get_comet_mode(): + """Returns the mode of comet set in the environment variables, defaults to 'online' if not set.""" + return os.getenv("COMET_MODE", "online") + + +def _get_comet_model_name(): + """Returns the model name for Comet from the environment variable COMET_MODEL_NAME or defaults to 'Ultralytics'.""" + return os.getenv("COMET_MODEL_NAME", "Ultralytics") + + +def _get_eval_batch_logging_interval(): + """Get the evaluation batch logging interval from environment variable or use default value 1.""" + return int(os.getenv("COMET_EVAL_BATCH_LOGGING_INTERVAL", 1)) + + +def _get_max_image_predictions_to_log(): + """Get the maximum number of image predictions to log from the environment variables.""" + return int(os.getenv("COMET_MAX_IMAGE_PREDICTIONS", 100)) + + +def _scale_confidence_score(score): + """Scales the given confidence score by a factor specified in an environment variable.""" + scale = float(os.getenv("COMET_MAX_CONFIDENCE_SCORE", 100.0)) + return score * scale + + +def _should_log_confusion_matrix(): + """Determines if the confusion matrix should be logged based on the environment variable settings.""" + return os.getenv("COMET_EVAL_LOG_CONFUSION_MATRIX", "false").lower() == "true" + + +def _should_log_image_predictions(): + """Determines whether to log image predictions based on a specified environment variable.""" + return os.getenv("COMET_EVAL_LOG_IMAGE_PREDICTIONS", "true").lower() == "true" + + +def _get_experiment_type(mode, project_name): + """Return an experiment based on mode and project name.""" + if mode == "offline": + return comet_ml.OfflineExperiment(project_name=project_name) + + return comet_ml.Experiment(project_name=project_name) + + +def _create_experiment(args): + """Ensures that the experiment object is only created in a single process during distributed training.""" + if RANK not in {-1, 0}: + return + try: + comet_mode = _get_comet_mode() + _project_name = os.getenv("COMET_PROJECT_NAME", args.project) + experiment = _get_experiment_type(comet_mode, _project_name) + experiment.log_parameters(vars(args)) + experiment.log_others( + { + "eval_batch_logging_interval": _get_eval_batch_logging_interval(), + "log_confusion_matrix_on_eval": _should_log_confusion_matrix(), + "log_image_predictions": _should_log_image_predictions(), + 
"max_image_predictions": _get_max_image_predictions_to_log(), + } + ) + experiment.log_other("Created from", "ultralytics") + + except Exception as e: + LOGGER.warning(f"WARNING ⚠️ Comet installed but not initialized correctly, not logging this run. {e}") + + +def _fetch_trainer_metadata(trainer): + """Returns metadata for YOLO training including epoch and asset saving status.""" + curr_epoch = trainer.epoch + 1 + + train_num_steps_per_epoch = len(trainer.train_loader.dataset) // trainer.batch_size + curr_step = curr_epoch * train_num_steps_per_epoch + final_epoch = curr_epoch == trainer.epochs + + save = trainer.args.save + save_period = trainer.args.save_period + save_interval = curr_epoch % save_period == 0 + save_assets = save and save_period > 0 and save_interval and not final_epoch + + return dict(curr_epoch=curr_epoch, curr_step=curr_step, save_assets=save_assets, final_epoch=final_epoch) + + +def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad): + """ + YOLO resizes images during training and the label values are normalized based on this resized shape. + + This function rescales the bounding box labels to the original image shape. + """ + resized_image_height, resized_image_width = resized_image_shape + + # Convert normalized xywh format predictions to xyxy in resized scale format + box = ops.xywhn2xyxy(box, h=resized_image_height, w=resized_image_width) + # Scale box predictions from resized image scale back to original image scale + box = ops.scale_boxes(resized_image_shape, box, original_image_shape, ratio_pad) + # Convert bounding box format from xyxy to xywh for Comet logging + box = ops.xyxy2xywh(box) + # Adjust xy center to correspond top-left corner + box[:2] -= box[2:] / 2 + box = box.tolist() + + return box + + +def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None): + """Format ground truth annotations for detection.""" + indices = batch["batch_idx"] == img_idx + bboxes = batch["bboxes"][indices] + if len(bboxes) == 0: + LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes labels") + return None + + cls_labels = batch["cls"][indices].squeeze(1).tolist() + if class_name_map: + cls_labels = [str(class_name_map[label]) for label in cls_labels] + + original_image_shape = batch["ori_shape"][img_idx] + resized_image_shape = batch["resized_shape"][img_idx] + ratio_pad = batch["ratio_pad"][img_idx] + + data = [] + for box, label in zip(bboxes, cls_labels): + box = _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad) + data.append( + { + "boxes": [box], + "label": f"gt_{label}", + "score": _scale_confidence_score(1.0), + } + ) + + return {"name": "ground_truth", "data": data} + + +def _format_prediction_annotations_for_detection(image_path, metadata, class_label_map=None): + """Format YOLO predictions for object detection visualization.""" + stem = image_path.stem + image_id = int(stem) if stem.isnumeric() else stem + + predictions = metadata.get(image_id) + if not predictions: + LOGGER.debug(f"COMET WARNING: Image: {image_path} has no bounding boxes predictions") + return None + + data = [] + for prediction in predictions: + boxes = prediction["bbox"] + score = _scale_confidence_score(prediction["score"]) + cls_label = prediction["category_id"] + if class_label_map: + cls_label = str(class_label_map[cls_label]) + + data.append({"boxes": [boxes], "label": cls_label, "score": score}) + + return {"name": 
"prediction", "data": data} + + +def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, class_label_map): + """Join the ground truth and prediction annotations if they exist.""" + ground_truth_annotations = _format_ground_truth_annotations_for_detection( + img_idx, image_path, batch, class_label_map + ) + prediction_annotations = _format_prediction_annotations_for_detection( + image_path, prediction_metadata_map, class_label_map + ) + + annotations = [ + annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None + ] + return [annotations] if annotations else None + + +def _create_prediction_metadata_map(model_predictions): + """Create metadata map for model predictions by groupings them based on image ID.""" + pred_metadata_map = {} + for prediction in model_predictions: + pred_metadata_map.setdefault(prediction["image_id"], []) + pred_metadata_map[prediction["image_id"]].append(prediction) + + return pred_metadata_map + + +def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch): + """Log the confusion matrix to Comet experiment.""" + conf_mat = trainer.validator.confusion_matrix.matrix + names = list(trainer.data["names"].values()) + ["background"] + experiment.log_confusion_matrix( + matrix=conf_mat, labels=names, max_categories=len(names), epoch=curr_epoch, step=curr_step + ) + + +def _log_images(experiment, image_paths, curr_step, annotations=None): + """Logs images to the experiment with optional annotations.""" + if annotations: + for image_path, annotation in zip(image_paths, annotations): + experiment.log_image(image_path, name=image_path.stem, step=curr_step, annotations=annotation) + + else: + for image_path in image_paths: + experiment.log_image(image_path, name=image_path.stem, step=curr_step) + + +def _log_image_predictions(experiment, validator, curr_step): + """Logs predicted boxes for a single image during training.""" + global _comet_image_prediction_count + + task = validator.args.task + if task not in COMET_SUPPORTED_TASKS: + return + + jdict = validator.jdict + if not jdict: + return + + predictions_metadata_map = _create_prediction_metadata_map(jdict) + dataloader = validator.dataloader + class_label_map = validator.names + + batch_logging_interval = _get_eval_batch_logging_interval() + max_image_predictions = _get_max_image_predictions_to_log() + + for batch_idx, batch in enumerate(dataloader): + if (batch_idx + 1) % batch_logging_interval != 0: + continue + + image_paths = batch["im_file"] + for img_idx, image_path in enumerate(image_paths): + if _comet_image_prediction_count >= max_image_predictions: + return + + image_path = Path(image_path) + annotations = _fetch_annotations( + img_idx, + image_path, + batch, + predictions_metadata_map, + class_label_map, + ) + _log_images( + experiment, + [image_path], + curr_step, + annotations=annotations, + ) + _comet_image_prediction_count += 1 + + +def _log_plots(experiment, trainer): + """Logs evaluation plots and label plots for the experiment.""" + plot_filenames = None + if isinstance(trainer.validator.metrics, SegmentMetrics) and trainer.validator.metrics.task == "segment": + plot_filenames = [ + trainer.save_dir / f"{prefix}{plots}.png" + for plots in EVALUATION_PLOT_NAMES + for prefix in SEGMENT_METRICS_PLOT_PREFIX + ] + elif isinstance(trainer.validator.metrics, PoseMetrics): + plot_filenames = [ + trainer.save_dir / f"{prefix}{plots}.png" + for plots in EVALUATION_PLOT_NAMES + for prefix in POSE_METRICS_PLOT_PREFIX + ] + elif 
isinstance(trainer.validator.metrics, (DetMetrics, OBBMetrics)): + plot_filenames = [trainer.save_dir / f"{plots}.png" for plots in EVALUATION_PLOT_NAMES] + + if plot_filenames is not None: + _log_images(experiment, plot_filenames, None) + + confusion_matrix_filenames = [trainer.save_dir / f"{plots}.png" for plots in CONFUSION_MATRIX_PLOT_NAMES] + _log_images(experiment, confusion_matrix_filenames, None) + + if not isinstance(trainer.validator.metrics, ClassifyMetrics): + label_plot_filenames = [trainer.save_dir / f"{labels}.jpg" for labels in LABEL_PLOT_NAMES] + _log_images(experiment, label_plot_filenames, None) + + +def _log_model(experiment, trainer): + """Log the best-trained model to Comet.ml.""" + model_name = _get_comet_model_name() + experiment.log_model(model_name, file_or_folder=str(trainer.best), file_name="best.pt", overwrite=True) + + +def on_pretrain_routine_start(trainer): + """Creates or resumes a CometML experiment at the start of a YOLO pre-training routine.""" + experiment = comet_ml.get_global_experiment() + is_alive = getattr(experiment, "alive", False) + if not experiment or not is_alive: + _create_experiment(trainer.args) + + +def on_train_epoch_end(trainer): + """Log metrics and save batch images at the end of training epochs.""" + experiment = comet_ml.get_global_experiment() + if not experiment: + return + + metadata = _fetch_trainer_metadata(trainer) + curr_epoch = metadata["curr_epoch"] + curr_step = metadata["curr_step"] + + experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix="train"), step=curr_step, epoch=curr_epoch) + + +def on_fit_epoch_end(trainer): + """Logs model assets at the end of each epoch.""" + experiment = comet_ml.get_global_experiment() + if not experiment: + return + + metadata = _fetch_trainer_metadata(trainer) + curr_epoch = metadata["curr_epoch"] + curr_step = metadata["curr_step"] + save_assets = metadata["save_assets"] + + experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch) + experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch) + if curr_epoch == 1: + from ultralytics.utils.torch_utils import model_info_for_loggers + + experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch) + + if not save_assets: + return + + _log_model(experiment, trainer) + if _should_log_confusion_matrix(): + _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch) + if _should_log_image_predictions(): + _log_image_predictions(experiment, trainer.validator, curr_step) + + +def on_train_end(trainer): + """Perform operations at the end of training.""" + experiment = comet_ml.get_global_experiment() + if not experiment: + return + + metadata = _fetch_trainer_metadata(trainer) + curr_epoch = metadata["curr_epoch"] + curr_step = metadata["curr_step"] + plots = trainer.args.plots + + _log_model(experiment, trainer) + if plots: + _log_plots(experiment, trainer) + + _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch) + _log_image_predictions(experiment, trainer.validator, curr_step) + _log_images(experiment, trainer.save_dir.glob("train_batch*.jpg"), curr_step) + _log_images(experiment, trainer.save_dir.glob("val_batch*.jpg"), curr_step) + experiment.end() + + global _comet_image_prediction_count + _comet_image_prediction_count = 0 + + +callbacks = ( + { + "on_pretrain_routine_start": on_pretrain_routine_start, + "on_train_epoch_end": on_train_epoch_end, + "on_fit_epoch_end": on_fit_epoch_end, + "on_train_end": on_train_end, + } + if comet_ml + else {} +) diff 
--git a/ultralytics/utils/callbacks/dvc.py b/ultralytics/utils/callbacks/dvc.py new file mode 100644 index 0000000000000000000000000000000000000000..1cc0c632ecb3384ca62a872a1e285be4a05512ba --- /dev/null +++ b/ultralytics/utils/callbacks/dvc.py @@ -0,0 +1,145 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, checks + +try: + assert not TESTS_RUNNING # do not log pytest + assert SETTINGS["dvc"] is True # verify integration is enabled + import dvclive + + assert checks.check_version("dvclive", "2.11.0", verbose=True) + + import os + import re + from pathlib import Path + + # DVCLive logger instance + live = None + _processed_plots = {} + + # `on_fit_epoch_end` is called on final validation (probably need to be fixed) for now this is the way we + # distinguish final evaluation of the best model vs last epoch validation + _training_epoch = False + +except (ImportError, AssertionError, TypeError): + dvclive = None + + +def _log_images(path, prefix=""): + """Logs images at specified path with an optional prefix using DVCLive.""" + if live: + name = path.name + + # Group images by batch to enable sliders in UI + if m := re.search(r"_batch(\d+)", name): + ni = m[1] + new_stem = re.sub(r"_batch(\d+)", "_batch", path.stem) + name = (Path(new_stem) / ni).with_suffix(path.suffix) + + live.log_image(os.path.join(prefix, name), path) + + +def _log_plots(plots, prefix=""): + """Logs plot images for training progress if they have not been previously processed.""" + for name, params in plots.items(): + timestamp = params["timestamp"] + if _processed_plots.get(name) != timestamp: + _log_images(name, prefix) + _processed_plots[name] = timestamp + + +def _log_confusion_matrix(validator): + """Logs the confusion matrix for the given validator using DVCLive.""" + targets = [] + preds = [] + matrix = validator.confusion_matrix.matrix + names = list(validator.names.values()) + if validator.confusion_matrix.task == "detect": + names += ["background"] + + for ti, pred in enumerate(matrix.T.astype(int)): + for pi, num in enumerate(pred): + targets.extend([names[ti]] * num) + preds.extend([names[pi]] * num) + + live.log_sklearn_plot("confusion_matrix", targets, preds, name="cf.json", normalized=True) + + +def on_pretrain_routine_start(trainer): + """Initializes DVCLive logger for training metadata during pre-training routine.""" + try: + global live + live = dvclive.Live(save_dvc_exp=True, cache_images=True) + LOGGER.info("DVCLive is detected and auto logging is enabled (run 'yolo settings dvc=False' to disable).") + except Exception as e: + LOGGER.warning(f"WARNING ⚠️ DVCLive installed but not initialized correctly, not logging this run. 
{e}") + + +def on_pretrain_routine_end(trainer): + """Logs plots related to the training process at the end of the pretraining routine.""" + _log_plots(trainer.plots, "train") + + +def on_train_start(trainer): + """Logs the training parameters if DVCLive logging is active.""" + if live: + live.log_params(trainer.args) + + +def on_train_epoch_start(trainer): + """Sets the global variable _training_epoch value to True at the start of training each epoch.""" + global _training_epoch + _training_epoch = True + + +def on_fit_epoch_end(trainer): + """Logs training metrics and model info, and advances to next step on the end of each fit epoch.""" + global _training_epoch + if live and _training_epoch: + all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr} + for metric, value in all_metrics.items(): + live.log_metric(metric, value) + + if trainer.epoch == 0: + from ultralytics.utils.torch_utils import model_info_for_loggers + + for metric, value in model_info_for_loggers(trainer).items(): + live.log_metric(metric, value, plot=False) + + _log_plots(trainer.plots, "train") + _log_plots(trainer.validator.plots, "val") + + live.next_step() + _training_epoch = False + + +def on_train_end(trainer): + """Logs the best metrics, plots, and confusion matrix at the end of training if DVCLive is active.""" + if live: + # At the end log the best metrics. It runs validator on the best model internally. + all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr} + for metric, value in all_metrics.items(): + live.log_metric(metric, value, plot=False) + + _log_plots(trainer.plots, "val") + _log_plots(trainer.validator.plots, "val") + _log_confusion_matrix(trainer.validator) + + if trainer.best.exists(): + live.log_artifact(trainer.best, copy=True, type="model") + + live.end() + + +callbacks = ( + { + "on_pretrain_routine_start": on_pretrain_routine_start, + "on_pretrain_routine_end": on_pretrain_routine_end, + "on_train_start": on_train_start, + "on_train_epoch_start": on_train_epoch_start, + "on_fit_epoch_end": on_fit_epoch_end, + "on_train_end": on_train_end, + } + if dvclive + else {} +) diff --git a/ultralytics/utils/callbacks/hub.py b/ultralytics/utils/callbacks/hub.py new file mode 100644 index 0000000000000000000000000000000000000000..4709fbea8ba2d7d3b0e6501f9123658656660c05 --- /dev/null +++ b/ultralytics/utils/callbacks/hub.py @@ -0,0 +1,108 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import json +from time import time + +from ultralytics.hub import HUB_WEB_ROOT, PREFIX, HUBTrainingSession, events +from ultralytics.utils import LOGGER, RANK, SETTINGS + + +def on_pretrain_routine_start(trainer): + """Create a remote Ultralytics HUB session to log local model training.""" + if RANK in {-1, 0} and SETTINGS["hub"] is True and SETTINGS["api_key"] and trainer.hub_session is None: + trainer.hub_session = HUBTrainingSession.create_session(trainer.args.model, trainer.args) + + +def on_pretrain_routine_end(trainer): + """Logs info before starting timer for upload rate limit.""" + if session := getattr(trainer, "hub_session", None): + # Start timer for upload rate limit + session.timers = {"metrics": time(), "ckpt": time()} # start timer on session.rate_limit + + +def on_fit_epoch_end(trainer): + """Uploads training progress metrics at the end of each epoch.""" + if session := getattr(trainer, "hub_session", None): + # Upload metrics after val end + all_plots = { + 
**trainer.label_loss_items(trainer.tloss, prefix="train"), + **trainer.metrics, + } + if trainer.epoch == 0: + from ultralytics.utils.torch_utils import model_info_for_loggers + + all_plots = {**all_plots, **model_info_for_loggers(trainer)} + + session.metrics_queue[trainer.epoch] = json.dumps(all_plots) + + # If any metrics fail to upload, add them to the queue to attempt uploading again. + if session.metrics_upload_failed_queue: + session.metrics_queue.update(session.metrics_upload_failed_queue) + + if time() - session.timers["metrics"] > session.rate_limits["metrics"]: + session.upload_metrics() + session.timers["metrics"] = time() # reset timer + session.metrics_queue = {} # reset queue + + +def on_model_save(trainer): + """Saves checkpoints to Ultralytics HUB with rate limiting.""" + if session := getattr(trainer, "hub_session", None): + # Upload checkpoints with rate limiting + is_best = trainer.best_fitness == trainer.fitness + if time() - session.timers["ckpt"] > session.rate_limits["ckpt"]: + LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model.id}") + session.upload_model(trainer.epoch, trainer.last, is_best) + session.timers["ckpt"] = time() # reset timer + + +def on_train_end(trainer): + """Upload final model and metrics to Ultralytics HUB at the end of training.""" + if session := getattr(trainer, "hub_session", None): + # Upload final model and metrics with exponential standoff + LOGGER.info(f"{PREFIX}Syncing final model...") + session.upload_model( + trainer.epoch, + trainer.best, + map=trainer.metrics.get("metrics/mAP50-95(B)", 0), + final=True, + ) + session.alive = False # stop heartbeats + LOGGER.info(f"{PREFIX}Done ✅\n{PREFIX}View model at {session.model_url} 🚀") + + +def on_train_start(trainer): + """Run events on train start.""" + events(trainer.args) + + +def on_val_start(validator): + """Runs events on validation start.""" + events(validator.args) + + +def on_predict_start(predictor): + """Run events on predict start.""" + events(predictor.args) + + +def on_export_start(exporter): + """Run events on export start.""" + events(exporter.args) + + +callbacks = ( + { + "on_pretrain_routine_start": on_pretrain_routine_start, + "on_pretrain_routine_end": on_pretrain_routine_end, + "on_fit_epoch_end": on_fit_epoch_end, + "on_model_save": on_model_save, + "on_train_end": on_train_end, + "on_train_start": on_train_start, + "on_val_start": on_val_start, + "on_predict_start": on_predict_start, + "on_export_start": on_export_start, + } + if SETTINGS["hub"] is True + else {} +) # verify enabled diff --git a/ultralytics/utils/callbacks/mlflow.py b/ultralytics/utils/callbacks/mlflow.py new file mode 100644 index 0000000000000000000000000000000000000000..49c8a200ee62bb97ec1429510e0dd943e0b54d0e --- /dev/null +++ b/ultralytics/utils/callbacks/mlflow.py @@ -0,0 +1,137 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +""" +MLflow Logging for Ultralytics YOLO. + +This module enables MLflow logging for Ultralytics YOLO. It logs metrics, parameters, and model artifacts. +For setting up, a tracking URI should be specified. The logging can be customized using environment variables. + +Commands: + 1. To set a project name: + `export MLFLOW_EXPERIMENT_NAME=` or use the project= argument + + 2. To set a run name: + `export MLFLOW_RUN=` or use the name= argument + + 3. To start a local MLflow server: + mlflow server --backend-store-uri runs/mlflow + It will by default start a local server at http://127.0.0.1:5000. 
+ To specify a different URI, set the MLFLOW_TRACKING_URI environment variable. + + 4. To kill all running MLflow server instances: + ps aux | grep 'mlflow' | grep -v 'grep' | awk '{print $2}' | xargs kill -9 +""" + +from ultralytics.utils import LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorstr + +try: + import os + + assert not TESTS_RUNNING or "test_mlflow" in os.environ.get("PYTEST_CURRENT_TEST", "") # do not log pytest + assert SETTINGS["mlflow"] is True # verify integration is enabled + import mlflow + + assert hasattr(mlflow, "__version__") # verify package is not directory + from pathlib import Path + + PREFIX = colorstr("MLflow: ") + +except (ImportError, AssertionError): + mlflow = None + + +def sanitize_dict(x): + """Sanitize dictionary keys by removing parentheses and converting values to floats.""" + return {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()} + + +def on_pretrain_routine_end(trainer): + """ + Log training parameters to MLflow at the end of the pretraining routine. + + This function sets up MLflow logging based on environment variables and trainer arguments. It sets the tracking URI, + experiment name, and run name, then starts the MLflow run if not already active. It finally logs the parameters + from the trainer. + + Args: + trainer (ultralytics.engine.trainer.BaseTrainer): The training object with arguments and parameters to log. + + Global: + mlflow: The imported mlflow module to use for logging. + + Environment Variables: + MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'. + MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project. + MLFLOW_RUN: The name of the MLflow run. If not set, defaults to trainer.args.name. + MLFLOW_KEEP_RUN_ACTIVE: Boolean indicating whether to keep the MLflow run active after the end of training. 
+ """ + global mlflow + + uri = os.environ.get("MLFLOW_TRACKING_URI") or str(RUNS_DIR / "mlflow") + LOGGER.debug(f"{PREFIX} tracking uri: {uri}") + mlflow.set_tracking_uri(uri) + + # Set experiment and run names + experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") or trainer.args.project or "/Shared/Ultralytics" + run_name = os.environ.get("MLFLOW_RUN") or trainer.args.name + mlflow.set_experiment(experiment_name) + + mlflow.autolog() + try: + active_run = mlflow.active_run() or mlflow.start_run(run_name=run_name) + LOGGER.info(f"{PREFIX}logging run_id({active_run.info.run_id}) to {uri}") + if Path(uri).is_dir(): + LOGGER.info(f"{PREFIX}view at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri {uri}'") + LOGGER.info(f"{PREFIX}disable with 'yolo settings mlflow=False'") + mlflow.log_params(dict(trainer.args)) + except Exception as e: + LOGGER.warning(f"{PREFIX}WARNING ⚠️ Failed to initialize: {e}\n{PREFIX}WARNING ⚠️ Not tracking this run") + + +def on_train_epoch_end(trainer): + """Log training metrics at the end of each train epoch to MLflow.""" + if mlflow: + mlflow.log_metrics( + metrics={ + **sanitize_dict(trainer.lr), + **sanitize_dict(trainer.label_loss_items(trainer.tloss, prefix="train")), + }, + step=trainer.epoch, + ) + + +def on_fit_epoch_end(trainer): + """Log training metrics at the end of each fit epoch to MLflow.""" + if mlflow: + mlflow.log_metrics(metrics=sanitize_dict(trainer.metrics), step=trainer.epoch) + + +def on_train_end(trainer): + """Log model artifacts at the end of the training.""" + if not mlflow: + return + mlflow.log_artifact(str(trainer.best.parent)) # log save_dir/weights directory with best.pt and last.pt + for f in trainer.save_dir.glob("*"): # log all other files in save_dir + if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}: + mlflow.log_artifact(str(f)) + keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() == "true" + if keep_run_active: + LOGGER.info(f"{PREFIX}mlflow run still alive, remember to close it using mlflow.end_run()") + else: + mlflow.end_run() + LOGGER.debug(f"{PREFIX}mlflow run ended") + + LOGGER.info( + f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n{PREFIX}disable with 'yolo settings mlflow=False'" + ) + + +callbacks = ( + { + "on_pretrain_routine_end": on_pretrain_routine_end, + "on_train_epoch_end": on_train_epoch_end, + "on_fit_epoch_end": on_fit_epoch_end, + "on_train_end": on_train_end, + } + if mlflow + else {} +) diff --git a/ultralytics/utils/callbacks/neptune.py b/ultralytics/utils/callbacks/neptune.py new file mode 100644 index 0000000000000000000000000000000000000000..7adfdad1fdb9f45af9755ec780677ff995d66a20 --- /dev/null +++ b/ultralytics/utils/callbacks/neptune.py @@ -0,0 +1,116 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING + +try: + assert not TESTS_RUNNING # do not log pytest + assert SETTINGS["neptune"] is True # verify integration is enabled + import neptune + from neptune.types import File + + assert hasattr(neptune, "__version__") + + run = None # NeptuneAI experiment logger instance + +except (ImportError, AssertionError): + neptune = None + + +def _log_scalars(scalars, step=0): + """Log scalars to the NeptuneAI experiment logger.""" + if run: + for k, v in scalars.items(): + run[k].append(value=v, step=step) + + +def _log_images(imgs_dict, group=""): + """Log scalars to the NeptuneAI experiment logger.""" + if run: + for k, v in imgs_dict.items(): + 
run[f"{group}/{k}"].upload(File(v)) + + +def _log_plot(title, plot_path): + """ + Log plots to the NeptuneAI experiment logger. + + Args: + title (str): Title of the plot. + plot_path (PosixPath | str): Path to the saved image file. + """ + import matplotlib.image as mpimg + import matplotlib.pyplot as plt + + img = mpimg.imread(plot_path) + fig = plt.figure() + ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks + ax.imshow(img) + run[f"Plots/{title}"].upload(fig) + + +def on_pretrain_routine_start(trainer): + """Callback function called before the training routine starts.""" + try: + global run + run = neptune.init_run( + project=trainer.args.project or "Ultralytics", + name=trainer.args.name, + tags=["Ultralytics"], + ) + run["Configuration/Hyperparameters"] = {k: "" if v is None else v for k, v in vars(trainer.args).items()} + except Exception as e: + LOGGER.warning(f"WARNING ⚠️ NeptuneAI installed but not initialized correctly, not logging this run. {e}") + + +def on_train_epoch_end(trainer): + """Callback function called at end of each training epoch.""" + _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1) + _log_scalars(trainer.lr, trainer.epoch + 1) + if trainer.epoch == 1: + _log_images({f.stem: str(f) for f in trainer.save_dir.glob("train_batch*.jpg")}, "Mosaic") + + +def on_fit_epoch_end(trainer): + """Callback function called at end of each fit (train+val) epoch.""" + if run and trainer.epoch == 0: + from ultralytics.utils.torch_utils import model_info_for_loggers + + run["Configuration/Model"] = model_info_for_loggers(trainer) + _log_scalars(trainer.metrics, trainer.epoch + 1) + + +def on_val_end(validator): + """Callback function called at end of each validation.""" + if run: + # Log val_labels and val_pred + _log_images({f.stem: str(f) for f in validator.save_dir.glob("val*.jpg")}, "Validation") + + +def on_train_end(trainer): + """Callback function called at end of training.""" + if run: + # Log final results, CM matrix + PR plots + files = [ + "results.png", + "confusion_matrix.png", + "confusion_matrix_normalized.png", + *(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")), + ] + files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter + for f in files: + _log_plot(title=f.stem, plot_path=f) + # Log the final model + run[f"weights/{trainer.args.name or trainer.args.task}/{trainer.best.name}"].upload(File(str(trainer.best))) + + +callbacks = ( + { + "on_pretrain_routine_start": on_pretrain_routine_start, + "on_train_epoch_end": on_train_epoch_end, + "on_fit_epoch_end": on_fit_epoch_end, + "on_val_end": on_val_end, + "on_train_end": on_train_end, + } + if neptune + else {} +) diff --git a/ultralytics/utils/callbacks/raytune.py b/ultralytics/utils/callbacks/raytune.py new file mode 100644 index 0000000000000000000000000000000000000000..e7e01d0985f14e6b05f0604a01aa0a1def167535 --- /dev/null +++ b/ultralytics/utils/callbacks/raytune.py @@ -0,0 +1,28 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.utils import SETTINGS + +try: + assert SETTINGS["raytune"] is True # verify integration is enabled + import ray + from ray import tune + from ray.air import session + +except (ImportError, AssertionError): + tune = None + + +def on_fit_epoch_end(trainer): + """Sends training metrics to Ray Tune at end of each epoch.""" + if ray.train._internal.session._get_session(): # replacement for deprecated ray.tune.is_session_enabled() + 
metrics = trainer.metrics + session.report({**metrics, **{"epoch": trainer.epoch + 1}}) + + +callbacks = ( + { + "on_fit_epoch_end": on_fit_epoch_end, + } + if tune + else {} +) diff --git a/ultralytics/utils/callbacks/tensorboard.py b/ultralytics/utils/callbacks/tensorboard.py new file mode 100644 index 0000000000000000000000000000000000000000..2920fa23bf1361ed74e8113b8273bbdd8131e0b0 --- /dev/null +++ b/ultralytics/utils/callbacks/tensorboard.py @@ -0,0 +1,106 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr + +try: + # WARNING: do not move SummaryWriter import due to protobuf bug https://github.com/ultralytics/ultralytics/pull/4674 + from torch.utils.tensorboard import SummaryWriter + + assert not TESTS_RUNNING # do not log pytest + assert SETTINGS["tensorboard"] is True # verify integration is enabled + WRITER = None # TensorBoard SummaryWriter instance + PREFIX = colorstr("TensorBoard: ") + + # Imports below only required if TensorBoard enabled + import warnings + from copy import deepcopy + + from ultralytics.utils.torch_utils import de_parallel, torch + +except (ImportError, AssertionError, TypeError, AttributeError): + # TypeError for handling 'Descriptors cannot not be created directly.' protobuf errors in Windows + # AttributeError: module 'tensorflow' has no attribute 'io' if 'tensorflow' not installed + SummaryWriter = None + + +def _log_scalars(scalars, step=0): + """Logs scalar values to TensorBoard.""" + if WRITER: + for k, v in scalars.items(): + WRITER.add_scalar(k, v, step) + + +def _log_tensorboard_graph(trainer): + """Log model graph to TensorBoard.""" + # Input image + imgsz = trainer.args.imgsz + imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz + p = next(trainer.model.parameters()) # for device, type + im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=UserWarning) # suppress jit trace warning + warnings.simplefilter("ignore", category=torch.jit.TracerWarning) # suppress jit trace warning + + # Try simple method first (YOLO) + try: + trainer.model.eval() # place in .eval() mode to avoid BatchNorm statistics changes + WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), []) + LOGGER.info(f"{PREFIX}model graph visualization added ✅") + return + + except Exception: + # Fallback to TorchScript export steps (RTDETR) + try: + model = deepcopy(de_parallel(trainer.model)) + model.eval() + model = model.fuse(verbose=False) + for m in model.modules(): + if hasattr(m, "export"): # Detect, RTDETRDecoder (Segment and Pose use Detect base class) + m.export = True + m.format = "torchscript" + model(im) # dry run + WRITER.add_graph(torch.jit.trace(model, im, strict=False), []) + LOGGER.info(f"{PREFIX}model graph visualization added ✅") + except Exception as e: + LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard graph visualization failure {e}") + + +def on_pretrain_routine_start(trainer): + """Initialize TensorBoard logging with SummaryWriter.""" + if SummaryWriter: + try: + global WRITER + WRITER = SummaryWriter(str(trainer.save_dir)) + LOGGER.info(f"{PREFIX}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/") + except Exception as e: + LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. 
{e}") + + +def on_train_start(trainer): + """Log TensorBoard graph.""" + if WRITER: + _log_tensorboard_graph(trainer) + + +def on_train_epoch_end(trainer): + """Logs scalar statistics at the end of a training epoch.""" + _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1) + _log_scalars(trainer.lr, trainer.epoch + 1) + + +def on_fit_epoch_end(trainer): + """Logs epoch metrics at end of training epoch.""" + _log_scalars(trainer.metrics, trainer.epoch + 1) + + +callbacks = ( + { + "on_pretrain_routine_start": on_pretrain_routine_start, + "on_train_start": on_train_start, + "on_fit_epoch_end": on_fit_epoch_end, + "on_train_epoch_end": on_train_epoch_end, + } + if SummaryWriter + else {} +) diff --git a/ultralytics/utils/callbacks/wb.py b/ultralytics/utils/callbacks/wb.py new file mode 100644 index 0000000000000000000000000000000000000000..7242d51e3d8ee705ebc1edec00ab02fe391b64a2 --- /dev/null +++ b/ultralytics/utils/callbacks/wb.py @@ -0,0 +1,170 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.utils import SETTINGS, TESTS_RUNNING +from ultralytics.utils.torch_utils import model_info_for_loggers + +try: + assert not TESTS_RUNNING # do not log pytest + assert SETTINGS["wandb"] is True # verify integration is enabled + import wandb as wb + + assert hasattr(wb, "__version__") # verify package is not directory + _processed_plots = {} + +except (ImportError, AssertionError): + wb = None + + +def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"): + """ + Create and log a custom metric visualization to wandb.plot.pr_curve. + + This function crafts a custom metric visualization that mimics the behavior of the default wandb precision-recall + curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across + different classes. + + Args: + x (List): Values for the x-axis; expected to have length N. + y (List): Corresponding values for the y-axis; also expected to have length N. + classes (List): Labels identifying the class of each point; length N. + title (str, optional): Title for the plot; defaults to 'Precision Recall Curve'. + x_title (str, optional): Label for the x-axis; defaults to 'Recall'. + y_title (str, optional): Label for the y-axis; defaults to 'Precision'. + + Returns: + (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization. + """ + import pandas # scope for faster 'import ultralytics' + + df = pandas.DataFrame({"class": classes, "y": y, "x": x}).round(3) + fields = {"x": "x", "y": "y", "class": "class"} + string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title} + return wb.plot_table( + "wandb/area-under-curve/v0", wb.Table(dataframe=df), fields=fields, string_fields=string_fields + ) + + +def _plot_curve( + x, + y, + names=None, + id="precision-recall", + title="Precision Recall Curve", + x_title="Recall", + y_title="Precision", + num_x=100, + only_mean=False, +): + """ + Log a metric curve visualization. + + This function generates a metric curve based on input data and logs the visualization to wandb. + The curve can represent aggregated data (mean) or individual class data, depending on the 'only_mean' flag. + + Args: + x (np.ndarray): Data points for the x-axis with length N. + y (np.ndarray): Corresponding data points for the y-axis with shape CxN, where C is the number of classes. 
+ names (list, optional): Names of the classes corresponding to the y-axis data; length C. Defaults to []. + id (str, optional): Unique identifier for the logged data in wandb. Defaults to 'precision-recall'. + title (str, optional): Title for the visualization plot. Defaults to 'Precision Recall Curve'. + x_title (str, optional): Label for the x-axis. Defaults to 'Recall'. + y_title (str, optional): Label for the y-axis. Defaults to 'Precision'. + num_x (int, optional): Number of interpolated data points for visualization. Defaults to 100. + only_mean (bool, optional): Flag to indicate if only the mean curve should be plotted. Defaults to True. + + Note: + The function leverages the '_custom_table' function to generate the actual visualization. + """ + import numpy as np + + # Create new x + if names is None: + names = [] + x_new = np.linspace(x[0], x[-1], num_x).round(5) + + # Create arrays for logging + x_log = x_new.tolist() + y_log = np.interp(x_new, x, np.mean(y, axis=0)).round(3).tolist() + + if only_mean: + table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title]) + wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)}) + else: + classes = ["mean"] * len(x_log) + for i, yi in enumerate(y): + x_log.extend(x_new) # add new x + y_log.extend(np.interp(x_new, x, yi)) # interpolate y to new x + classes.extend([names[i]] * len(x_new)) # add class names + wb.log({id: _custom_table(x_log, y_log, classes, title, x_title, y_title)}, commit=False) + + +def _log_plots(plots, step): + """Logs plots from the input dictionary if they haven't been logged already at the specified step.""" + for name, params in plots.copy().items(): # shallow copy to prevent plots dict changing during iteration + timestamp = params["timestamp"] + if _processed_plots.get(name) != timestamp: + wb.run.log({name.stem: wb.Image(str(name))}, step=step) + _processed_plots[name] = timestamp + + +def on_pretrain_routine_start(trainer): + """Initiate and start project if module is present.""" + if not wb.run: + wb.init( + project=str(trainer.args.project).replace("/", "-") if trainer.args.project else "Ultralytics", + name=str(trainer.args.name).replace("/", "-"), + config=vars(trainer.args), + ) + + +def on_fit_epoch_end(trainer): + """Logs training metrics and model information at the end of an epoch.""" + wb.run.log(trainer.metrics, step=trainer.epoch + 1) + _log_plots(trainer.plots, step=trainer.epoch + 1) + _log_plots(trainer.validator.plots, step=trainer.epoch + 1) + if trainer.epoch == 0: + wb.run.log(model_info_for_loggers(trainer), step=trainer.epoch + 1) + + +def on_train_epoch_end(trainer): + """Log metrics and save images at the end of each training epoch.""" + wb.run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1) + wb.run.log(trainer.lr, step=trainer.epoch + 1) + if trainer.epoch == 1: + _log_plots(trainer.plots, step=trainer.epoch + 1) + + +def on_train_end(trainer): + """Save the best model as an artifact at end of training.""" + _log_plots(trainer.validator.plots, step=trainer.epoch + 1) + _log_plots(trainer.plots, step=trainer.epoch + 1) + art = wb.Artifact(type="model", name=f"run_{wb.run.id}_model") + if trainer.best.exists(): + art.add_file(trainer.best) + wb.run.log_artifact(art, aliases=["best"]) + # Check if we actually have plots to save + if trainer.args.plots and hasattr(trainer.validator.metrics, "curves_results"): + for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results): + 
x, y, x_title, y_title = curve_values + _plot_curve( + x, + y, + names=list(trainer.validator.metrics.names.values()), + id=f"curves/{curve_name}", + title=curve_name, + x_title=x_title, + y_title=y_title, + ) + wb.run.finish() # required or run continues on dashboard + + +callbacks = ( + { + "on_pretrain_routine_start": on_pretrain_routine_start, + "on_train_epoch_end": on_train_epoch_end, + "on_fit_epoch_end": on_fit_epoch_end, + "on_train_end": on_train_end, + } + if wb + else {} +) diff --git a/ultralytics/utils/checks.py b/ultralytics/utils/checks.py new file mode 100644 index 0000000000000000000000000000000000000000..b6de75e92ef9f47c8d6a2aac008bb1c9b945148b --- /dev/null +++ b/ultralytics/utils/checks.py @@ -0,0 +1,803 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import glob +import inspect +import math +import os +import platform +import re +import shutil +import subprocess +import time +from importlib import metadata +from pathlib import Path +from typing import Optional + +import cv2 +import numpy as np +import requests +import torch + +from ultralytics.utils import ( + ASSETS, + AUTOINSTALL, + IS_COLAB, + IS_GIT_DIR, + IS_KAGGLE, + IS_PIP_PACKAGE, + LINUX, + LOGGER, + MACOS, + ONLINE, + PYTHON_VERSION, + ROOT, + TORCHVISION_VERSION, + USER_CONFIG_DIR, + WINDOWS, + Retry, + SimpleNamespace, + ThreadingLocked, + TryExcept, + clean_url, + colorstr, + downloads, + emojis, + is_github_action_running, + url2file, +) + + +def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""): + """ + Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'. + + Args: + file_path (Path): Path to the requirements.txt file. + package (str, optional): Python package to use instead of requirements.txt file, i.e. package='ultralytics'. + + Returns: + (List[Dict[str, str]]): List of parsed requirements as dictionaries with `name` and `specifier` keys. + + Example: + ```python + from ultralytics.utils.checks import parse_requirements + + parse_requirements(package="ultralytics") + ``` + """ + if package: + requires = [x for x in metadata.distribution(package).requires if "extra == " not in x] + else: + requires = Path(file_path).read_text().splitlines() + + requirements = [] + for line in requires: + line = line.strip() + if line and not line.startswith("#"): + line = line.split("#")[0].strip() # ignore inline comments + if match := re.match(r"([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?", line): + requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else "")) + + return requirements + + +def parse_version(version="0.0.0") -> tuple: + """ + Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version. This + function replaces deprecated 'pkg_resources.parse_version(v)'. + + Args: + version (str): Version string, i.e. '2.0.1+cpu' + + Returns: + (tuple): Tuple of integers representing the numeric part of the version and the extra string, i.e. (2, 0, 1) + """ + try: + return tuple(map(int, re.findall(r"\d+", version)[:3])) # '2.0.1+cpu' -> (2, 0, 1) + except Exception as e: + LOGGER.warning(f"WARNING ⚠️ failure for parse_version({version}), returning (0, 0, 0): {e}") + return 0, 0, 0 + + +def is_ascii(s) -> bool: + """ + Check if a string is composed of only ASCII characters. + + Args: + s (str): String to be checked. + + Returns: + (bool): True if the string is composed only of ASCII characters, False otherwise. + """ + # Convert list, tuple, None, etc. 
to string + s = str(s) + + # Check if the string is composed of only ASCII characters + return all(ord(c) < 128 for c in s) + + + def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0): + """ + Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the + stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value. + + Args: + imgsz (int | List[int]): Image size. + stride (int): Stride value. + min_dim (int): Minimum number of dimensions. + max_dim (int): Maximum number of dimensions. + floor (int): Minimum allowed value for image size. + + Returns: + (List[int]): Updated image size. + """ + # Convert stride to integer if it is a tensor + stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride) + + # Convert image size to list if it is an integer + if isinstance(imgsz, int): + imgsz = [imgsz] + elif isinstance(imgsz, (list, tuple)): + imgsz = list(imgsz) + elif isinstance(imgsz, str): # i.e. '640' or '[640,640]' + imgsz = [int(imgsz)] if imgsz.isnumeric() else eval(imgsz) + else: + raise TypeError( + f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. " + f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'" + ) + + # Apply max_dim + if len(imgsz) > max_dim: + msg = ( + "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list " + "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'" + ) + if max_dim != 1: + raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}") + LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. {msg}") + imgsz = [max(imgsz)] + # Make image size a multiple of the stride + sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz] + + # Print warning message if image size was updated + if sz != imgsz: + LOGGER.warning(f"WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}") + + # Add missing dimensions if necessary + sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz + + return sz + + + def check_version( + current: str = "0.0.0", + required: str = "0.0.0", + name: str = "version", + hard: bool = False, + verbose: bool = False, + msg: str = "", +) -> bool: + """ + Check current version against the required version or range. + + Args: + current (str): Current version or package name to get version from. + required (str): Required version or range (in pip-style format). + name (str, optional): Name to be used in warning message. + hard (bool, optional): If True, raise an AssertionError if the requirement is not met. + verbose (bool, optional): If True, print warning message if requirement is not met. + msg (str, optional): Extra message to display if verbose. + + Returns: + (bool): True if requirement is met, False otherwise.
+ + Example: + ```python + # Check if current version is exactly 22.04 + check_version(current="22.04", required="==22.04") + + # Check if current version is greater than or equal to 22.04 + check_version(current="22.10", required="22.04") # assumes '>=' inequality if none passed + + # Check if current version is less than or equal to 22.04 + check_version(current="22.04", required="<=22.04") + + # Check if current version is between 20.04 (inclusive) and 22.04 (exclusive) + check_version(current="21.10", required=">20.04,<22.04") + ``` + """ + if not current: # if current is '' or None + LOGGER.warning(f"WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.") + return True + elif not current[0].isdigit(): # current is package name rather than version string, i.e. current='ultralytics' + try: + name = current # assigned package name to 'name' arg + current = metadata.version(current) # get version string from package name + except metadata.PackageNotFoundError as e: + if hard: + raise ModuleNotFoundError(emojis(f"WARNING ⚠️ {current} package is required but not installed")) from e + else: + return False + + if not required: # if required is '' or None + return True + + if "sys_platform" in required and ( # i.e. required='<2.4.0,>=1.8.0; sys_platform == "win32"' + (WINDOWS and "win32" not in required) + or (LINUX and "linux" not in required) + or (MACOS and "macos" not in required and "darwin" not in required) + ): + return True + + op = "" + version = "" + result = True + c = parse_version(current) # '1.2.3' -> (1, 2, 3) + for r in required.strip(",").split(","): + op, version = re.match(r"([^0-9]*)([\d.]+)", r).groups() # split '>=22.04' -> ('>=', '22.04') + if not op: + op = ">=" # assume >= if no op passed + v = parse_version(version) # '1.2.3' -> (1, 2, 3) + if op == "==" and c != v: + result = False + elif op == "!=" and c == v: + result = False + elif op == ">=" and not (c >= v): + result = False + elif op == "<=" and not (c <= v): + result = False + elif op == ">" and not (c > v): + result = False + elif op == "<" and not (c < v): + result = False + if not result: + warning = f"WARNING ⚠️ {name}{op}{version} is required, but {name}=={current} is currently installed {msg}" + if hard: + raise ModuleNotFoundError(emojis(warning)) # assert version requirements met + if verbose: + LOGGER.warning(warning) + return result + + +def check_latest_pypi_version(package_name="ultralytics"): + """ + Returns the latest version of a PyPI package without downloading or installing it. + + Args: + package_name (str): The name of the package to find the latest version for. + + Returns: + (str): The latest version of the package. + """ + try: + requests.packages.urllib3.disable_warnings() # Disable the InsecureRequestWarning + response = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3) + if response.status_code == 200: + return response.json()["info"]["version"] + except Exception: + return None + + +def check_pip_update_available(): + """ + Checks if a new version of the ultralytics package is available on PyPI. + + Returns: + (bool): True if an update is available, False otherwise. 
+ """ + if ONLINE and IS_PIP_PACKAGE: + try: + from ultralytics import __version__ + + latest = check_latest_pypi_version() + if check_version(__version__, f"<{latest}"): # check if current version is < latest version + LOGGER.info( + f"New https://pypi.org/project/ultralytics/{latest} available 😃 " + f"Update with 'pip install -U ultralytics'" + ) + return True + except Exception: + pass + return False + + +@ThreadingLocked() +def check_font(font="Arial.ttf"): + """ + Find font locally or download to user's configuration directory if it does not already exist. + + Args: + font (str): Path or name of font. + + Returns: + file (Path): Resolved font file path. + """ + from matplotlib import font_manager + + # Check USER_CONFIG_DIR + name = Path(font).name + file = USER_CONFIG_DIR / name + if file.exists(): + return file + + # Check system fonts + matches = [s for s in font_manager.findSystemFonts() if font in s] + if any(matches): + return matches[0] + + # Download to USER_CONFIG_DIR if missing + url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{name}" + if downloads.is_url(url, check=True): + downloads.safe_download(url=url, file=file) + return file + + +def check_python(minimum: str = "3.8.0", hard: bool = True, verbose: bool = False) -> bool: + """ + Check current python version against the required minimum version. + + Args: + minimum (str): Required minimum version of python. + hard (bool, optional): If True, raise an AssertionError if the requirement is not met. + verbose (bool, optional): If True, print warning message if requirement is not met. + + Returns: + (bool): Whether the installed Python version meets the minimum constraints. + """ + return check_version(PYTHON_VERSION, minimum, name="Python", hard=hard, verbose=verbose) + + +@TryExcept() +def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""): + """ + Check if installed dependencies meet YOLOv8 requirements and attempt to auto-update if needed. + + Args: + requirements (Union[Path, str, List[str]]): Path to a requirements.txt file, a single package requirement as a + string, or a list of package requirements as strings. + exclude (Tuple[str]): Tuple of package names to exclude from checking. + install (bool): If True, attempt to auto-update packages that don't meet requirements. + cmds (str): Additional commands to pass to the pip install command when auto-updating. + + Example: + ```python + from ultralytics.utils.checks import check_requirements + + # Check a requirements.txt file + check_requirements("path/to/requirements.txt") + + # Check a single package + check_requirements("ultralytics>=8.0.0") + + # Check multiple packages + check_requirements(["numpy", "ultralytics>=8.0.0"]) + ``` + """ + prefix = colorstr("red", "bold", "requirements:") + if isinstance(requirements, Path): # requirements.txt file + file = requirements.resolve() + assert file.exists(), f"{prefix} {file} not found, check failed." 
+ requirements = [f"{x.name}{x.specifier}" for x in parse_requirements(file) if x.name not in exclude] + elif isinstance(requirements, str): + requirements = [requirements] + + pkgs = [] + for r in requirements: + r_stripped = r.split("/")[-1].replace(".git", "") # replace git+https://org/repo.git -> 'repo' + match = re.match(r"([a-zA-Z0-9-_]+)([<>!=~]+.*)?", r_stripped) + name, required = match[1], match[2].strip() if match[2] else "" + try: + assert check_version(metadata.version(name), required) # exception if requirements not met + except (AssertionError, metadata.PackageNotFoundError): + pkgs.append(r) + + @Retry(times=2, delay=1) + def attempt_install(packages, commands): + """Attempt pip install command with retries on failure.""" + return subprocess.check_output(f"pip install --no-cache-dir {packages} {commands}", shell=True).decode() + + s = " ".join(f'"{x}"' for x in pkgs) # console string + if s: + if install and AUTOINSTALL: # check environment variable + n = len(pkgs) # number of packages updates + LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...") + try: + t = time.time() + assert ONLINE, "AutoUpdate skipped (offline)" + LOGGER.info(attempt_install(s, cmds)) + dt = time.time() - t + LOGGER.info( + f"{prefix} AutoUpdate success ✅ {dt:.1f}s, installed {n} package{'s' * (n > 1)}: {pkgs}\n" + f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n" + ) + except Exception as e: + LOGGER.warning(f"{prefix} ❌ {e}") + return False + else: + return False + + return True + + +def check_torchvision(): + """ + Checks the installed versions of PyTorch and Torchvision to ensure they're compatible. + + This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according + to the provided compatibility table based on: + https://github.com/pytorch/vision#installation. + + The compatibility table is a dictionary where the keys are PyTorch versions and the values are lists of compatible + Torchvision versions. 
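+
+ Example:
+ Illustrative usage (the check is read-only and prints a warning on an incompatible torch/torchvision pair):
+ ```python
+ from ultralytics.utils.checks import check_torchvision
+
+ check_torchvision()
+ ```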
+ """ + # Compatibility table + compatibility_table = { + "2.5": ["0.20"], + "2.4": ["0.19"], + "2.3": ["0.18"], + "2.2": ["0.17"], + "2.1": ["0.16"], + "2.0": ["0.15"], + "1.13": ["0.14"], + "1.12": ["0.13"], + } + + # Extract only the major and minor versions + v_torch = ".".join(torch.__version__.split("+")[0].split(".")[:2]) + if v_torch in compatibility_table: + compatible_versions = compatibility_table[v_torch] + v_torchvision = ".".join(TORCHVISION_VERSION.split("+")[0].split(".")[:2]) + if all(v_torchvision != v for v in compatible_versions): + print( + f"WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n" + f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or " + "'pip install -U torch torchvision' to update both.\n" + "For a full compatibility table see https://github.com/pytorch/vision#installation" + ) + + +def check_suffix(file="yolo11n.pt", suffix=".pt", msg=""): + """Check file(s) for acceptable suffix.""" + if file and suffix: + if isinstance(suffix, str): + suffix = (suffix,) + for f in file if isinstance(file, (list, tuple)) else [file]: + s = Path(f).suffix.lower().strip() # file suffix + if len(s): + assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}, not {s}" + + +def check_yolov5u_filename(file: str, verbose: bool = True): + """Replace legacy YOLOv5 filenames with updated YOLOv5u filenames.""" + if "yolov3" in file or "yolov5" in file: + if "u.yaml" in file: + file = file.replace("u.yaml", ".yaml") # i.e. yolov5nu.yaml -> yolov5n.yaml + elif ".pt" in file and "u" not in file: + original_file = file + file = re.sub(r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt", file) # i.e. yolov5n.pt -> yolov5nu.pt + file = re.sub(r"(.*yolov5([nsmlx])6)\.pt", "\\1u.pt", file) # i.e. yolov5n6.pt -> yolov5n6u.pt + file = re.sub(r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt", file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt + if file != original_file and verbose: + LOGGER.info( + f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are " + f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs " + f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n" + ) + return file + + +def check_model_file_from_stem(model="yolov8n"): + """Return a model filename from a valid model stem.""" + if model and not Path(model).suffix and Path(model).stem in downloads.GITHUB_ASSETS_STEMS: + return Path(model).with_suffix(".pt") # add suffix, i.e. 
yolov8n -> yolov8n.pt + else: + return model + + +def check_file(file, suffix="", download=True, download_dir=".", hard=True): + """Search/download file (if necessary) and return path.""" + check_suffix(file, suffix) # optional + file = str(file).strip() # convert to string and strip spaces + file = check_yolov5u_filename(file) # yolov5n -> yolov5nu + if ( + not file + or ("://" not in file and Path(file).exists()) # '://' check required in Windows Python<3.10 + or file.lower().startswith("grpc://") + ): # file exists or gRPC Triton images + return file + elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): # download + url = file # warning: Pathlib turns :// -> :/ + file = Path(download_dir) / url2file(file) # '%2F' to '/', split https://url.com/file.txt?auth + if file.exists(): + LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists + else: + downloads.safe_download(url=url, file=file, unzip=False) + return str(file) + else: # search + files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file)) # find file + if not files and hard: + raise FileNotFoundError(f"'{file}' does not exist") + elif len(files) > 1 and hard: + raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}") + return files[0] if len(files) else [] # return file + + +def check_yaml(file, suffix=(".yaml", ".yml"), hard=True): + """Search/download YAML file (if necessary) and return path, checking suffix.""" + return check_file(file, suffix, hard=hard) + + +def check_is_path_safe(basedir, path): + """ + Check if the resolved path is under the intended directory to prevent path traversal. + + Args: + basedir (Path | str): The intended directory. + path (Path | str): The path to check. + + Returns: + (bool): True if the path is safe, False otherwise. + """ + base_dir_resolved = Path(basedir).resolve() + path_resolved = Path(path).resolve() + + return path_resolved.exists() and path_resolved.parts[: len(base_dir_resolved.parts)] == base_dir_resolved.parts + + +def check_imshow(warn=False): + """Check if environment supports image displays.""" + try: + if LINUX: + assert not IS_COLAB and not IS_KAGGLE + assert "DISPLAY" in os.environ, "The DISPLAY environment variable isn't set." 
+ cv2.imshow("test", np.zeros((8, 8, 3), dtype=np.uint8)) # show a small 8-pixel image + cv2.waitKey(1) + cv2.destroyAllWindows() + cv2.waitKey(1) + return True + except Exception as e: + if warn: + LOGGER.warning(f"WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}") + return False + + +def check_yolo(verbose=True, device=""): + """Return a human-readable YOLO software and hardware summary.""" + import psutil + + from ultralytics.utils.torch_utils import select_device + + if IS_COLAB: + shutil.rmtree("sample_data", ignore_errors=True) # remove colab /sample_data directory + + if verbose: + # System info + gib = 1 << 30 # bytes per GiB + ram = psutil.virtual_memory().total + total, used, free = shutil.disk_usage("/") + s = f"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)" + try: + from IPython import display + + display.clear_output() # clear display if notebook + except ImportError: + pass + else: + s = "" + + select_device(device=device, newline=False) + LOGGER.info(f"Setup complete ✅ {s}") + + +def collect_system_info(): + """Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA.""" + import psutil + + from ultralytics.utils import ENVIRONMENT # scope to avoid circular import + from ultralytics.utils.torch_utils import get_cpu_info, get_gpu_info + + gib = 1 << 30 # bytes per GiB + cuda = torch and torch.cuda.is_available() + check_yolo() + total, used, free = shutil.disk_usage("/") + + info_dict = { + "OS": platform.platform(), + "Environment": ENVIRONMENT, + "Python": PYTHON_VERSION, + "Install": "git" if IS_GIT_DIR else "pip" if IS_PIP_PACKAGE else "other", + "RAM": f"{psutil.virtual_memory().total / gib:.2f} GB", + "Disk": f"{(total - free) / gib:.1f}/{total / gib:.1f} GB", + "CPU": get_cpu_info(), + "CPU count": os.cpu_count(), + "GPU": get_gpu_info(index=0) if cuda else None, + "GPU count": torch.cuda.device_count() if cuda else None, + "CUDA": torch.version.cuda if cuda else None, + } + LOGGER.info("\n" + "\n".join(f"{k:<20}{v}" for k, v in info_dict.items()) + "\n") + + package_info = {} + for r in parse_requirements(package="ultralytics"): + try: + current = metadata.version(r.name) + is_met = "✅ " if check_version(current, str(r.specifier), name=r.name, hard=True) else "❌ " + except metadata.PackageNotFoundError: + current = "(not installed)" + is_met = "❌ " + package_info[r.name] = f"{is_met}{current}{r.specifier}" + LOGGER.info(f"{r.name:<20}{package_info[r.name]}") + + info_dict["Package Info"] = package_info + + if is_github_action_running(): + github_info = { + "RUNNER_OS": os.getenv("RUNNER_OS"), + "GITHUB_EVENT_NAME": os.getenv("GITHUB_EVENT_NAME"), + "GITHUB_WORKFLOW": os.getenv("GITHUB_WORKFLOW"), + "GITHUB_ACTOR": os.getenv("GITHUB_ACTOR"), + "GITHUB_REPOSITORY": os.getenv("GITHUB_REPOSITORY"), + "GITHUB_REPOSITORY_OWNER": os.getenv("GITHUB_REPOSITORY_OWNER"), + } + LOGGER.info("\n" + "\n".join(f"{k}: {v}" for k, v in github_info.items())) + info_dict["GitHub Info"] = github_info + + return info_dict + + +def check_amp(model): + """ + Checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLO11 model. If the checks fail, it means + there are anomalies with AMP on the system that may cause NaN losses or zero-mAP results, so AMP will be disabled + during training. + + Args: + model (nn.Module): A YOLO11 model instance. 
+ + Example: + ```python + from ultralytics import YOLO + from ultralytics.utils.checks import check_amp + + model = YOLO("yolo11n.pt").model.cuda() + check_amp(model) + ``` + + Returns: + (bool): Returns True if the AMP functionality works correctly with YOLO11 model, else False. + """ + from ultralytics.utils.torch_utils import autocast + + device = next(model.parameters()).device # get model device + prefix = colorstr("AMP: ") + if device.type in {"cpu", "mps"}: + return False # AMP only used on CUDA devices + else: + # GPUs that have issues with AMP + pattern = re.compile( + r"(nvidia|geforce|quadro|tesla).*?(1660|1650|1630|t400|t550|t600|t1000|t1200|t2000|k40m)", re.IGNORECASE + ) + + gpu = torch.cuda.get_device_name(device) + if bool(pattern.search(gpu)): + LOGGER.warning( + f"{prefix}checks failed ❌. AMP training on {gpu} GPU may cause " + f"NaN losses or zero-mAP results, so AMP will be disabled during training." + ) + return False + + def amp_allclose(m, im): + """All close FP32 vs AMP results.""" + batch = [im] * 8 + imgsz = max(256, int(model.stride.max() * 4)) # max stride P5-32 and P6-64 + a = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data # FP32 inference + with autocast(enabled=True): + b = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data # AMP inference + del m + return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance + + im = ASSETS / "bus.jpg" # image to check + LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks...") + warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False." + try: + from ultralytics import YOLO + + # assert amp_allclose(YOLO("yolo11n.pt"), im) + assert amp_allclose(YOLO("yolov12n.pt"), im) + LOGGER.info(f"{prefix}checks passed ✅") + except ConnectionError: + LOGGER.warning( + f"{prefix}checks skipped ⚠️. Offline and unable to download YOLO11n for AMP checks. {warning_msg}" + ) + except (AttributeError, ModuleNotFoundError): + LOGGER.warning( + f"{prefix}checks skipped ⚠️. " + f"Unable to load YOLO11n for AMP checks due to possible Ultralytics package modifications. {warning_msg}" + ) + except AssertionError: + LOGGER.warning( + f"{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to " + f"NaN losses or zero-mAP results, so AMP will be disabled during training." + ) + return False + return True + + +def git_describe(path=ROOT): # path must be a directory + """Return human-readable git description, i.e. 
v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe.""" + try: + return subprocess.check_output(f"git -C {path} describe --tags --long --always", shell=True).decode()[:-1] + except Exception: + return "" + + +def print_args(args: Optional[dict] = None, show_file=True, show_func=False): + """Print function arguments (optional args dict).""" + + def strip_auth(v): + """Clean longer Ultralytics HUB URLs by stripping potential authentication information.""" + return clean_url(v) if (isinstance(v, str) and v.startswith("http") and len(v) > 100) else v + + x = inspect.currentframe().f_back # previous frame + file, _, func, _, _ = inspect.getframeinfo(x) + if args is None: # get args automatically + args, _, _, frm = inspect.getargvalues(x) + args = {k: v for k, v in frm.items() if k in args} + try: + file = Path(file).resolve().relative_to(ROOT).with_suffix("") + except ValueError: + file = Path(file).stem + s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "") + LOGGER.info(colorstr(s) + ", ".join(f"{k}={strip_auth(v)}" for k, v in args.items())) + + +def cuda_device_count() -> int: + """ + Get the number of NVIDIA GPUs available in the environment. + + Returns: + (int): The number of NVIDIA GPUs available. + """ + try: + # Run the nvidia-smi command and capture its output + output = subprocess.check_output( + ["nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits"], encoding="utf-8" + ) + + # Take the first line and strip any leading/trailing white space + first_line = output.strip().split("\n")[0] + + return int(first_line) + except (subprocess.CalledProcessError, FileNotFoundError, ValueError): + # If the command fails, nvidia-smi is not found, or output is not an integer, assume no GPUs are available + return 0 + + +def cuda_is_available() -> bool: + """ + Check if CUDA is available in the environment. + + Returns: + (bool): True if one or more NVIDIA GPUs are available, False otherwise. + """ + return cuda_device_count() > 0 + + +def is_sudo_available() -> bool: + """ + Check if the sudo command is available in the environment. + + Returns: + (bool): True if the sudo command is available, False otherwise. + """ + if WINDOWS: + return False + cmd = "sudo --version" + return subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0 + + +# Run checks and define constants +check_python("3.8", hard=False, verbose=True) # check python version +check_torchvision() # check torch-torchvision compatibility +IS_PYTHON_MINIMUM_3_10 = check_python("3.10", hard=False) +IS_PYTHON_3_12 = PYTHON_VERSION.startswith("3.12") diff --git a/ultralytics/utils/dist.py b/ultralytics/utils/dist.py new file mode 100644 index 0000000000000000000000000000000000000000..8b7e5bbe4cebad54733f050937446a60f27d9577 --- /dev/null +++ b/ultralytics/utils/dist.py @@ -0,0 +1,72 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import os +import shutil +import socket +import sys +import tempfile + +from . import USER_CONFIG_DIR +from .torch_utils import TORCH_1_9 + + +def find_free_network_port() -> int: + """ + Finds a free port on localhost. + + It is useful in single-node training when we don't want to connect to a real main node but have to set the + `MASTER_PORT` environment variable. 
+ """ + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] # port + + +def generate_ddp_file(trainer): + """Generates a DDP file and returns its file name.""" + module, name = f"{trainer.__class__.__module__}.{trainer.__class__.__name__}".rsplit(".", 1) + + content = f""" +# Ultralytics Multi-GPU training temp file (should be automatically deleted after use) +overrides = {vars(trainer.args)} + +if __name__ == "__main__": + from {module} import {name} + from ultralytics.utils import DEFAULT_CFG_DICT + + cfg = DEFAULT_CFG_DICT.copy() + cfg.update(save_dir='') # handle the extra key 'save_dir' + trainer = {name}(cfg=cfg, overrides=overrides) + trainer.args.model = "{getattr(trainer.hub_session, "model_url", trainer.args.model)}" + results = trainer.train() +""" + (USER_CONFIG_DIR / "DDP").mkdir(exist_ok=True) + with tempfile.NamedTemporaryFile( + prefix="_temp_", + suffix=f"{id(trainer)}.py", + mode="w+", + encoding="utf-8", + dir=USER_CONFIG_DIR / "DDP", + delete=False, + ) as file: + file.write(content) + return file.name + + +def generate_ddp_command(world_size, trainer): + """Generates and returns command for distributed training.""" + import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218 + + if not trainer.resume: + shutil.rmtree(trainer.save_dir) # remove the save_dir + file = generate_ddp_file(trainer) + dist_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch" + port = find_free_network_port() + cmd = [sys.executable, "-m", dist_cmd, "--nproc_per_node", f"{world_size}", "--master_port", f"{port}", file] + return cmd, file + + +def ddp_cleanup(trainer, file): + """Delete temp file if created.""" + if f"{id(trainer)}.py" in file: # if temp_file suffix in file + os.remove(file) diff --git a/ultralytics/utils/downloads.py b/ultralytics/utils/downloads.py new file mode 100644 index 0000000000000000000000000000000000000000..8a029c091b7e575e0f14a02780b1cdf897b5716c --- /dev/null +++ b/ultralytics/utils/downloads.py @@ -0,0 +1,510 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import re +import shutil +import subprocess +from itertools import repeat +from multiprocessing.pool import ThreadPool +from pathlib import Path +from urllib import parse, request + +import requests +import torch + +from ultralytics.utils import LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file + +# Define Ultralytics GitHub assets maintained at https://github.com/ultralytics/assets +GITHUB_ASSETS_REPO = "ultralytics/assets" +GITHUB_ASSETS_NAMES = ( + [f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb", "-oiv7")] + + [f"yolo11{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")] + + [f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")] + + [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")] + + [f"yolov8{k}-world.pt" for k in "smlx"] + + [f"yolov8{k}-worldv2.pt" for k in "smlx"] + + [f"yolov9{k}.pt" for k in "tsmce"] + + [f"yolov10{k}.pt" for k in "nsmblx"] + + [f"yolo_nas_{k}.pt" for k in "sml"] + + [f"sam_{k}.pt" for k in "bl"] + + [f"FastSAM-{k}.pt" for k in "sx"] + + [f"rtdetr-{k}.pt" for k in "lx"] + + ["mobile_sam.pt"] + + ["calibration_image_sample_data_20x128x128x3_float32.npy.zip"] +) +GITHUB_ASSETS_STEMS = [Path(k).stem for k in GITHUB_ASSETS_NAMES] + + +def is_url(url, check=False): + """ + Validates if the given string is a URL and optionally checks 
if the URL exists online. + + Args: + url (str): The string to be validated as a URL. + check (bool, optional): If True, performs an additional check to see if the URL exists online. + Defaults to False. + + Returns: + (bool): Returns True for a valid URL. If 'check' is True, also returns True if the URL exists online. + Returns False otherwise. + + Example: + ```python + valid = is_url("https://www.example.com") + ``` + """ + try: + url = str(url) + result = parse.urlparse(url) + assert all([result.scheme, result.netloc]) # check if is url + if check: + with request.urlopen(url) as response: + return response.getcode() == 200 # check if exists online + return True + except Exception: + return False + + +def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")): + """ + Deletes all ".DS_store" files under a specified directory. + + Args: + path (str, optional): The directory path where the ".DS_store" files should be deleted. + files_to_delete (tuple): The files to be deleted. + + Example: + ```python + from ultralytics.utils.downloads import delete_dsstore + + delete_dsstore("path/to/dir") + ``` + + Note: + ".DS_store" files are created by the Apple operating system and contain metadata about folders and files. They + are hidden system files and can cause issues when transferring files between different operating systems. + """ + for file in files_to_delete: + matches = list(Path(path).rglob(file)) + LOGGER.info(f"Deleting {file} files: {matches}") + for f in matches: + f.unlink() + + +def zip_directory(directory, compress=True, exclude=(".DS_Store", "__MACOSX"), progress=True): + """ + Zips the contents of a directory, excluding files containing strings in the exclude list. The resulting zip file is + named after the directory and placed alongside it. + + Args: + directory (str | Path): The path to the directory to be zipped. + compress (bool): Whether to compress the files while zipping. Default is True. + exclude (tuple, optional): A tuple of filename strings to be excluded. Defaults to ('.DS_Store', '__MACOSX'). + progress (bool, optional): Whether to display a progress bar. Defaults to True. + + Returns: + (Path): The path to the resulting zip file. + + Example: + ```python + from ultralytics.utils.downloads import zip_directory + + file = zip_directory("path/to/dir") + ``` + """ + from zipfile import ZIP_DEFLATED, ZIP_STORED, ZipFile + + delete_dsstore(directory) + directory = Path(directory) + if not directory.is_dir(): + raise FileNotFoundError(f"Directory '{directory}' does not exist.") + + # Unzip with progress bar + files_to_zip = [f for f in directory.rglob("*") if f.is_file() and all(x not in f.name for x in exclude)] + zip_file = directory.with_suffix(".zip") + compression = ZIP_DEFLATED if compress else ZIP_STORED + with ZipFile(zip_file, "w", compression) as f: + for file in TQDM(files_to_zip, desc=f"Zipping {directory} to {zip_file}...", unit="file", disable=not progress): + f.write(file, file.relative_to(directory)) + + return zip_file # return path to zip file + + +def unzip_file(file, path=None, exclude=(".DS_Store", "__MACOSX"), exist_ok=False, progress=True): + """ + Unzips a *.zip file to the specified path, excluding files containing strings in the exclude list. + + If the zipfile does not contain a single top-level directory, the function will create a new + directory with the same name as the zipfile (without the extension) to extract its contents. 
+ If a path is not provided, the function will use the parent directory of the zipfile as the default path. + + Args: + file (str | Path): The path to the zipfile to be extracted. + path (str, optional): The path to extract the zipfile to. Defaults to None. + exclude (tuple, optional): A tuple of filename strings to be excluded. Defaults to ('.DS_Store', '__MACOSX'). + exist_ok (bool, optional): Whether to overwrite existing contents if they exist. Defaults to False. + progress (bool, optional): Whether to display a progress bar. Defaults to True. + + Raises: + BadZipFile: If the provided file does not exist or is not a valid zipfile. + + Returns: + (Path): The path to the directory where the zipfile was extracted. + + Example: + ```python + from ultralytics.utils.downloads import unzip_file + + dir = unzip_file("path/to/file.zip") + ``` + """ + from zipfile import BadZipFile, ZipFile, is_zipfile + + if not (Path(file).exists() and is_zipfile(file)): + raise BadZipFile(f"File '{file}' does not exist or is a bad zip file.") + if path is None: + path = Path(file).parent # default path + + # Unzip the file contents + with ZipFile(file) as zipObj: + files = [f for f in zipObj.namelist() if all(x not in f for x in exclude)] + top_level_dirs = {Path(f).parts[0] for f in files} + + # Decide to unzip directly or unzip into a directory + unzip_as_dir = len(top_level_dirs) == 1 # (len(files) > 1 and not files[0].endswith("/")) + if unzip_as_dir: + # Zip has 1 top-level directory + extract_path = path # i.e. ../datasets + path = Path(path) / list(top_level_dirs)[0] # i.e. extract coco8/ dir to ../datasets/ + else: + # Zip has multiple files at top level + path = extract_path = Path(path) / Path(file).stem # i.e. extract multiple files to ../datasets/coco8/ + + # Check if destination directory already exists and contains files + if path.exists() and any(path.iterdir()) and not exist_ok: + # If it exists and is not empty, return the path without unzipping + LOGGER.warning(f"WARNING ⚠️ Skipping {file} unzip as destination directory {path} is not empty.") + return path + + for f in TQDM(files, desc=f"Unzipping {file} to {Path(path).resolve()}...", unit="file", disable=not progress): + # Ensure the file is within the extract_path to avoid path traversal security vulnerability + if ".." in Path(f).parts: + LOGGER.warning(f"Potentially insecure file path: {f}, skipping extraction.") + continue + zipObj.extract(f, extract_path) + + return path # return unzip dir + + +def check_disk_space(url="https://ultralytics.com/assets/coco8.zip", path=Path.cwd(), sf=1.5, hard=True): + """ + Check if there is sufficient disk space to download and store a file. + + Args: + url (str, optional): The URL to the file. Defaults to 'https://ultralytics.com/assets/coco8.zip'. + path (str | Path, optional): The path or drive to check the available free space on. + sf (float, optional): Safety factor, the multiplier for the required free space. Defaults to 1.5. + hard (bool, optional): Whether to throw an error or not on insufficient disk space. Defaults to True. + + Returns: + (bool): True if there is sufficient disk space, False otherwise. 
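+
+ Example:
+ Illustrative usage (hard=False logs a warning instead of raising MemoryError when space is insufficient):
+ ```python
+ from ultralytics.utils.downloads import check_disk_space
+
+ ok = check_disk_space("https://ultralytics.com/assets/coco8.zip", sf=2.0, hard=False)
+ ```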
+ """ + try: + r = requests.head(url) # response + assert r.status_code < 400, f"URL error for {url}: {r.status_code} {r.reason}" # check response + except Exception: + return True # requests issue, default to True + + # Check file size + gib = 1 << 30 # bytes per GiB + data = int(r.headers.get("Content-Length", 0)) / gib # file size (GB) + total, used, free = (x / gib for x in shutil.disk_usage(path)) # bytes + + if data * sf < free: + return True # sufficient space + + # Insufficient space + text = ( + f"WARNING ⚠️ Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, " + f"Please free {data * sf - free:.1f} GB additional disk space and try again." + ) + if hard: + raise MemoryError(text) + LOGGER.warning(text) + return False + + +def get_google_drive_file_info(link): + """ + Retrieves the direct download link and filename for a shareable Google Drive file link. + + Args: + link (str): The shareable link of the Google Drive file. + + Returns: + (str): Direct download URL for the Google Drive file. + (str): Original filename of the Google Drive file. If filename extraction fails, returns None. + + Example: + ```python + from ultralytics.utils.downloads import get_google_drive_file_info + + link = "https://drive.google.com/file/d/1cqT-cJgANNrhIHCrEufUYhQ4RqiWG_lJ/view?usp=drive_link" + url, filename = get_google_drive_file_info(link) + ``` + """ + file_id = link.split("/d/")[1].split("/view")[0] + drive_url = f"https://drive.google.com/uc?export=download&id={file_id}" + filename = None + + # Start session + with requests.Session() as session: + response = session.get(drive_url, stream=True) + if "quota exceeded" in str(response.content.lower()): + raise ConnectionError( + emojis( + f"❌ Google Drive file download quota exceeded. " + f"Please try again later or download this file manually at {link}." + ) + ) + for k, v in response.cookies.items(): + if k.startswith("download_warning"): + drive_url += f"&confirm={v}" # v is token + if cd := response.headers.get("content-disposition"): + filename = re.findall('filename="(.+)"', cd)[0] + return drive_url, filename + + +def safe_download( + url, + file=None, + dir=None, + unzip=True, + delete=False, + curl=False, + retry=3, + min_bytes=1e0, + exist_ok=False, + progress=True, +): + """ + Downloads files from a URL, with options for retrying, unzipping, and deleting the downloaded file. + + Args: + url (str): The URL of the file to be downloaded. + file (str, optional): The filename of the downloaded file. + If not provided, the file will be saved with the same name as the URL. + dir (str, optional): The directory to save the downloaded file. + If not provided, the file will be saved in the current working directory. + unzip (bool, optional): Whether to unzip the downloaded file. Default: True. + delete (bool, optional): Whether to delete the downloaded file after unzipping. Default: False. + curl (bool, optional): Whether to use curl command line tool for downloading. Default: False. + retry (int, optional): The number of times to retry the download in case of failure. Default: 3. + min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be considered + a successful download. Default: 1E0. + exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False. + progress (bool, optional): Whether to display a progress bar during the download. Default: True. 
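+
+ Returns:
+ (Path | None): The extraction directory when `unzip` is True and the downloaded file is an archive (or has no
+ suffix), otherwise None.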
+ + Example: + ```python + from ultralytics.utils.downloads import safe_download + + link = "https://ultralytics.com/assets/bus.jpg" + path = safe_download(link) + ``` + """ + gdrive = url.startswith("https://drive.google.com/") # check if the URL is a Google Drive link + if gdrive: + url, file = get_google_drive_file_info(url) + + f = Path(dir or ".") / (file or url2file(url)) # URL converted to filename + if "://" not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10) + f = Path(url) # filename + elif not f.is_file(): # URL and file do not exist + uri = (url if gdrive else clean_url(url)).replace( # cleaned and aliased url + "https://github.com/ultralytics/assets/releases/download/v0.0.0/", + "https://ultralytics.com/assets/", # assets alias + ) + desc = f"Downloading {uri} to '{f}'" + LOGGER.info(f"{desc}...") + f.parent.mkdir(parents=True, exist_ok=True) # make directory if missing + check_disk_space(url, path=f.parent) + for i in range(retry + 1): + try: + if curl or i > 0: # curl download with retry, continue + s = "sS" * (not progress) # silent + r = subprocess.run(["curl", "-#", f"-{s}L", url, "-o", f, "--retry", "3", "-C", "-"]).returncode + assert r == 0, f"Curl return value {r}" + else: # urllib download + method = "torch" + if method == "torch": + torch.hub.download_url_to_file(url, f, progress=progress) + else: + with request.urlopen(url) as response, TQDM( + total=int(response.getheader("Content-Length", 0)), + desc=desc, + disable=not progress, + unit="B", + unit_scale=True, + unit_divisor=1024, + ) as pbar: + with open(f, "wb") as f_opened: + for data in response: + f_opened.write(data) + pbar.update(len(data)) + + if f.exists(): + if f.stat().st_size > min_bytes: + break # success + f.unlink() # remove partial downloads + except Exception as e: + if i == 0 and not is_online(): + raise ConnectionError(emojis(f"❌ Download failure for {uri}. Environment is not online.")) from e + elif i >= retry: + raise ConnectionError(emojis(f"❌ Download failure for {uri}. Retry limit reached.")) from e + LOGGER.warning(f"⚠️ Download failure, retrying {i + 1}/{retry} {uri}...") + + if unzip and f.exists() and f.suffix in {"", ".zip", ".tar", ".gz"}: + from zipfile import is_zipfile + + unzip_dir = (dir or f.parent).resolve() # unzip to dir if provided else unzip in place + if is_zipfile(f): + unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress) # unzip + elif f.suffix in {".tar", ".gz"}: + LOGGER.info(f"Unzipping {f} to {unzip_dir}...") + subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True) + if delete: + f.unlink() # remove zip + return unzip_dir + + +def get_github_assets(repo="ultralytics/assets", version="latest", retry=False): + """ + Retrieve the specified version's tag and assets from a GitHub repository. If the version is not specified, the + function fetches the latest release assets. + + Args: + repo (str, optional): The GitHub repository in the format 'owner/repo'. Defaults to 'ultralytics/assets'. + version (str, optional): The release version to fetch assets from. Defaults to 'latest'. + retry (bool, optional): Flag to retry the request in case of a failure. Defaults to False. + + Returns: + (tuple): A tuple containing the release tag and a list of asset names. + + Example: + ```python + tag, assets = get_github_assets(repo="ultralytics/assets", version="latest") + ``` + """ + if version != "latest": + version = f"tags/{version}" # i.e. 
tags/v6.2 + url = f"https://api.github.com/repos/{repo}/releases/{version}" + r = requests.get(url) # github api + if r.status_code != 200 and r.reason != "rate limit exceeded" and retry: # failed and not 403 rate limit exceeded + r = requests.get(url) # try again + if r.status_code != 200: + LOGGER.warning(f"⚠️ GitHub assets check failure for {url}: {r.status_code} {r.reason}") + return "", [] + data = r.json() + return data["tag_name"], [x["name"] for x in data["assets"]] # tag, assets i.e. ['yolov8n.pt', 'yolov8s.pt', ...] + + +def attempt_download_asset(file, repo="ultralytics/assets", release="v8.3.0", **kwargs): + """ + Attempt to download a file from GitHub release assets if it is not found locally. The function checks for the file + locally first, then tries to download it from the specified GitHub repository release. + + Args: + file (str | Path): The filename or file path to be downloaded. + repo (str, optional): The GitHub repository in the format 'owner/repo'. Defaults to 'ultralytics/assets'. + release (str, optional): The specific release version to be downloaded. Defaults to 'v8.3.0'. + **kwargs (any): Additional keyword arguments for the download process. + + Returns: + (str): The path to the downloaded file. + + Example: + ```python + file_path = attempt_download_asset("yolo11n.pt", repo="ultralytics/assets", release="latest") + ``` + """ + from ultralytics.utils import SETTINGS # scoped for circular import + + if 'v12' in str(file): + repo = "sunsmarterjie/yolov12" + release = "v1.0" + + # YOLOv3/5u updates + file = str(file) + file = checks.check_yolov5u_filename(file) + file = Path(file.strip().replace("'", "")) + if file.exists(): + return str(file) + elif (SETTINGS["weights_dir"] / file).exists(): + return str(SETTINGS["weights_dir"] / file) + else: + # URL specified + name = Path(parse.unquote(str(file))).name # decode '%2F' to '/' etc. + download_url = f"https://github.com/{repo}/releases/download" + if str(file).startswith(("http:/", "https:/")): # download + url = str(file).replace(":/", "://") # Pathlib turns :// -> :/ + file = url2file(name) # parse authentication https://url.com/file.txt?auth... + if Path(file).is_file(): + LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists + else: + safe_download(url=url, file=file, min_bytes=1e5, **kwargs) + + elif repo == GITHUB_ASSETS_REPO and name in GITHUB_ASSETS_NAMES: + safe_download(url=f"{download_url}/{release}/{name}", file=file, min_bytes=1e5, **kwargs) + + else: + tag, assets = get_github_assets(repo, release) + if not assets: + tag, assets = get_github_assets(repo) # latest release + if name in assets: + safe_download(url=f"{download_url}/{tag}/{name}", file=file, min_bytes=1e5, **kwargs) + + return str(file) + + +def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3, exist_ok=False): + """ + Downloads files from specified URLs to a given directory. Supports concurrent downloads if multiple threads are + specified. + + Args: + url (str | list): The URL or list of URLs of the files to be downloaded. + dir (Path, optional): The directory where the files will be saved. Defaults to the current working directory. + unzip (bool, optional): Flag to unzip the files after downloading. Defaults to True. + delete (bool, optional): Flag to delete the zip files after extraction. Defaults to False. + curl (bool, optional): Flag to use curl for downloading. Defaults to False. + threads (int, optional): Number of threads to use for concurrent downloads. 
Defaults to 1. + retry (int, optional): Number of retries in case of download failure. Defaults to 3. + exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. Defaults to False. + + Example: + ```python + download("https://ultralytics.com/assets/example.zip", dir="path/to/dir", unzip=True) + ``` + """ + dir = Path(dir) + dir.mkdir(parents=True, exist_ok=True) # make directory + if threads > 1: + with ThreadPool(threads) as pool: + pool.map( + lambda x: safe_download( + url=x[0], + dir=x[1], + unzip=unzip, + delete=delete, + curl=curl, + retry=retry, + exist_ok=exist_ok, + progress=threads <= 1, + ), + zip(url, repeat(dir)), + ) + pool.close() + pool.join() + else: + for u in [url] if isinstance(url, (str, Path)) else url: + safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry, exist_ok=exist_ok) diff --git a/ultralytics/utils/errors.py b/ultralytics/utils/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..8cb7aae13f1acc7cfae790ac8aa2246f50485e5e --- /dev/null +++ b/ultralytics/utils/errors.py @@ -0,0 +1,22 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.utils import emojis + + +class HUBModelError(Exception): + """ + Custom exception class for handling errors related to model fetching in Ultralytics YOLO. + + This exception is raised when a requested model is not found or cannot be retrieved. + The message is also processed to include emojis for better user experience. + + Attributes: + message (str): The error message displayed when the exception is raised. + + Note: + The message is automatically processed through the 'emojis' function from the 'ultralytics.utils' package. + """ + + def __init__(self, message="Model not found. Please check model URL and try again."): + """Create an exception for when a model is not found.""" + super().__init__(emojis(message)) diff --git a/ultralytics/utils/files.py b/ultralytics/utils/files.py new file mode 100644 index 0000000000000000000000000000000000000000..0af6b0c23326a6232160e7227bb532013ffa9231 --- /dev/null +++ b/ultralytics/utils/files.py @@ -0,0 +1,222 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import contextlib +import glob +import os +import shutil +import tempfile +from contextlib import contextmanager +from datetime import datetime +from pathlib import Path + + +class WorkingDirectory(contextlib.ContextDecorator): + """ + A context manager and decorator for temporarily changing the working directory. + + This class allows for the temporary change of the working directory using a context manager or decorator. + It ensures that the original working directory is restored after the context or decorated function completes. + + Attributes: + dir (Path): The new directory to switch to. + cwd (Path): The original current working directory before the switch. + + Methods: + __enter__: Changes the current directory to the specified directory. + __exit__: Restores the original working directory on context exit. 
+ + Examples: + Using as a context manager: + >>> with WorkingDirectory('/path/to/new/dir'): + >>> # Perform operations in the new directory + >>> pass + + Using as a decorator: + >>> @WorkingDirectory('/path/to/new/dir') + >>> def some_function(): + >>> # Perform operations in the new directory + >>> pass + """ + + def __init__(self, new_dir): + """Sets the working directory to 'new_dir' upon instantiation for use with context managers or decorators.""" + self.dir = new_dir # new dir + self.cwd = Path.cwd().resolve() # current dir + + def __enter__(self): + """Changes the current working directory to the specified directory upon entering the context.""" + os.chdir(self.dir) + + def __exit__(self, exc_type, exc_val, exc_tb): # noqa + """Restores the original working directory when exiting the context.""" + os.chdir(self.cwd) + + +@contextmanager +def spaces_in_path(path): + """ + Context manager to handle paths with spaces in their names. If a path contains spaces, it replaces them with + underscores, copies the file/directory to the new path, executes the context code block, then copies the + file/directory back to its original location. + + Args: + path (str | Path): The original path that may contain spaces. + + Yields: + (Path): Temporary path with spaces replaced by underscores if spaces were present, otherwise the original path. + + Examples: + Use the context manager to handle paths with spaces: + >>> from ultralytics.utils.files import spaces_in_path + >>> with spaces_in_path('/path/with spaces') as new_path: + >>> # Your code here + """ + # If path has spaces, replace them with underscores + if " " in str(path): + string = isinstance(path, str) # input type + path = Path(path) + + # Create a temporary directory and construct the new path + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) / path.name.replace(" ", "_") + + # Copy file/directory + if path.is_dir(): + # tmp_path.mkdir(parents=True, exist_ok=True) + shutil.copytree(path, tmp_path) + elif path.is_file(): + tmp_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(path, tmp_path) + + try: + # Yield the temporary path + yield str(tmp_path) if string else tmp_path + + finally: + # Copy file/directory back + if tmp_path.is_dir(): + shutil.copytree(tmp_path, path, dirs_exist_ok=True) + elif tmp_path.is_file(): + shutil.copy2(tmp_path, path) # Copy back the file + + else: + # If there are no spaces, just yield the original path + yield path + + +def increment_path(path, exist_ok=False, sep="", mkdir=False): + """ + Increments a file or directory path, i.e., runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc. + + If the path exists and `exist_ok` is not True, the path will be incremented by appending a number and `sep` to + the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the + number will be appended directly to the end of the path. If `mkdir` is set to True, the path will be created as a + directory if it does not already exist. + + Args: + path (str | pathlib.Path): Path to increment. + exist_ok (bool): If True, the path will not be incremented and returned as-is. + sep (str): Separator to use between the path and the incrementation number. + mkdir (bool): Create a directory if it does not exist. + + Returns: + (pathlib.Path): Incremented path. 
+ + Examples: + Increment a directory path: + >>> from pathlib import Path + >>> path = Path("runs/exp") + >>> new_path = increment_path(path) + >>> print(new_path) + runs/exp2 + + Increment a file path: + >>> path = Path("runs/exp/results.txt") + >>> new_path = increment_path(path) + >>> print(new_path) + runs/exp/results2.txt + """ + path = Path(path) # os-agnostic + if path.exists() and not exist_ok: + path, suffix = (path.with_suffix(""), path.suffix) if path.is_file() else (path, "") + + # Method 1 + for n in range(2, 9999): + p = f"{path}{sep}{n}{suffix}" # increment path + if not os.path.exists(p): + break + path = Path(p) + + if mkdir: + path.mkdir(parents=True, exist_ok=True) # make directory + + return path + + +def file_age(path=__file__): + """Return days since the last modification of the specified file.""" + dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta + return dt.days # + dt.seconds / 86400 # fractional days + + +def file_date(path=__file__): + """Returns the file modification date in 'YYYY-M-D' format.""" + t = datetime.fromtimestamp(Path(path).stat().st_mtime) + return f"{t.year}-{t.month}-{t.day}" + + +def file_size(path): + """Returns the size of a file or directory in megabytes (MB).""" + if isinstance(path, (str, Path)): + mb = 1 << 20 # bytes to MiB (1024 ** 2) + path = Path(path) + if path.is_file(): + return path.stat().st_size / mb + elif path.is_dir(): + return sum(f.stat().st_size for f in path.glob("**/*") if f.is_file()) / mb + return 0.0 + + +def get_latest_run(search_dir="."): + """Returns the path to the most recent 'last.pt' file in the specified directory for resuming training.""" + last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True) + return max(last_list, key=os.path.getctime) if last_list else "" + + +def update_models(model_names=("yolo11n.pt",), source_dir=Path("."), update_names=False): + """ + Updates and re-saves specified YOLO models in an 'updated_models' subdirectory. + + Args: + model_names (Tuple[str, ...]): Model filenames to update. + source_dir (Path): Directory containing models and target subdirectory. + update_names (bool): Update model names from a data YAML. 
+ + Examples: + Update specified YOLO models and save them in 'updated_models' subdirectory: + >>> from ultralytics.utils.files import update_models + >>> model_names = ("yolo11n.pt", "yolov8s.pt") + >>> update_models(model_names, source_dir=Path("/models"), update_names=True) + """ + from ultralytics import YOLO + from ultralytics.nn.autobackend import default_class_names + + target_dir = source_dir / "updated_models" + target_dir.mkdir(parents=True, exist_ok=True) # Ensure target directory exists + + for model_name in model_names: + model_path = source_dir / model_name + print(f"Loading model from {model_path}") + + # Load model + model = YOLO(model_path) + model.half() + if update_names: # update model names from a dataset YAML + model.model.names = default_class_names("coco8.yaml") + + # Define new save path + save_path = target_dir / model_name + + # Save model using model.save() + print(f"Re-saving {model_name} model to {save_path}") + model.save(save_path) diff --git a/ultralytics/utils/instance.py b/ultralytics/utils/instance.py new file mode 100644 index 0000000000000000000000000000000000000000..e92a9614129755189709e3ce47376827dbeffa21 --- /dev/null +++ b/ultralytics/utils/instance.py @@ -0,0 +1,429 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from collections import abc +from itertools import repeat +from numbers import Number +from typing import List + +import numpy as np + +from .ops import ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh + + +def _ntuple(n): + """From PyTorch internals.""" + + def parse(x): + """Parse bounding boxes format between XYWH and LTWH.""" + return x if isinstance(x, abc.Iterable) else tuple(repeat(x, n)) + + return parse + + +to_2tuple = _ntuple(2) +to_4tuple = _ntuple(4) + +# `xyxy` means left top and right bottom +# `xywh` means center x, center y and width, height(YOLO format) +# `ltwh` means left top and width, height(COCO format) +_formats = ["xyxy", "xywh", "ltwh"] + +__all__ = ("Bboxes", "Instances") # tuple or list + + +class Bboxes: + """ + A class for handling bounding boxes. + + The class supports various bounding box formats like 'xyxy', 'xywh', and 'ltwh'. + Bounding box data should be provided in numpy arrays. + + Attributes: + bboxes (numpy.ndarray): The bounding boxes stored in a 2D numpy array. + format (str): The format of the bounding boxes ('xyxy', 'xywh', or 'ltwh'). + + Note: + This class does not handle normalization or denormalization of bounding boxes. 
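+
+ Examples:
+ A small illustrative sketch (coordinates are arbitrary demo values):
+ ```python
+ import numpy as np
+
+ from ultralytics.utils.instance import Bboxes
+
+ boxes = Bboxes(np.array([[10.0, 10.0, 50.0, 50.0]]), format="xyxy")
+ boxes.convert("xywh")  # in-place conversion to center-x, center-y, width, height
+ print(boxes.areas())  # [1600.]
+ ```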
+ """ + + def __init__(self, bboxes, format="xyxy") -> None: + """Initializes the Bboxes class with bounding box data in a specified format.""" + assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}" + bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes + assert bboxes.ndim == 2 + assert bboxes.shape[1] == 4 + self.bboxes = bboxes + self.format = format + # self.normalized = normalized + + def convert(self, format): + """Converts bounding box format from one type to another.""" + assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}" + if self.format == format: + return + elif self.format == "xyxy": + func = xyxy2xywh if format == "xywh" else xyxy2ltwh + elif self.format == "xywh": + func = xywh2xyxy if format == "xyxy" else xywh2ltwh + else: + func = ltwh2xyxy if format == "xyxy" else ltwh2xywh + self.bboxes = func(self.bboxes) + self.format = format + + def areas(self): + """Return box areas.""" + return ( + (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1]) # format xyxy + if self.format == "xyxy" + else self.bboxes[:, 3] * self.bboxes[:, 2] # format xywh or ltwh + ) + + # def denormalize(self, w, h): + # if not self.normalized: + # return + # assert (self.bboxes <= 1.0).all() + # self.bboxes[:, 0::2] *= w + # self.bboxes[:, 1::2] *= h + # self.normalized = False + # + # def normalize(self, w, h): + # if self.normalized: + # return + # assert (self.bboxes > 1.0).any() + # self.bboxes[:, 0::2] /= w + # self.bboxes[:, 1::2] /= h + # self.normalized = True + + def mul(self, scale): + """ + Multiply bounding box coordinates by scale factor(s). + + Args: + scale (int | tuple | list): Scale factor(s) for four coordinates. + If int, the same scale is applied to all coordinates. + """ + if isinstance(scale, Number): + scale = to_4tuple(scale) + assert isinstance(scale, (tuple, list)) + assert len(scale) == 4 + self.bboxes[:, 0] *= scale[0] + self.bboxes[:, 1] *= scale[1] + self.bboxes[:, 2] *= scale[2] + self.bboxes[:, 3] *= scale[3] + + def add(self, offset): + """ + Add offset to bounding box coordinates. + + Args: + offset (int | tuple | list): Offset(s) for four coordinates. + If int, the same offset is applied to all coordinates. + """ + if isinstance(offset, Number): + offset = to_4tuple(offset) + assert isinstance(offset, (tuple, list)) + assert len(offset) == 4 + self.bboxes[:, 0] += offset[0] + self.bboxes[:, 1] += offset[1] + self.bboxes[:, 2] += offset[2] + self.bboxes[:, 3] += offset[3] + + def __len__(self): + """Return the number of boxes.""" + return len(self.bboxes) + + @classmethod + def concatenate(cls, boxes_list: List["Bboxes"], axis=0) -> "Bboxes": + """ + Concatenate a list of Bboxes objects into a single Bboxes object. + + Args: + boxes_list (List[Bboxes]): A list of Bboxes objects to concatenate. + axis (int, optional): The axis along which to concatenate the bounding boxes. + Defaults to 0. + + Returns: + Bboxes: A new Bboxes object containing the concatenated bounding boxes. + + Note: + The input should be a list or tuple of Bboxes objects. 
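+
+ Example:
+ Illustrative usage with two single-box objects in the default 'xyxy' format:
+ ```python
+ import numpy as np
+
+ from ultralytics.utils.instance import Bboxes
+
+ a = Bboxes(np.array([[0, 0, 10, 10]], dtype=np.float32))
+ b = Bboxes(np.array([[5, 5, 15, 15]], dtype=np.float32))
+ merged = Bboxes.concatenate([a, b])  # len(merged) == 2
+ ```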
+ """ + assert isinstance(boxes_list, (list, tuple)) + if not boxes_list: + return cls(np.empty(0)) + assert all(isinstance(box, Bboxes) for box in boxes_list) + + if len(boxes_list) == 1: + return boxes_list[0] + return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis)) + + def __getitem__(self, index) -> "Bboxes": + """ + Retrieve a specific bounding box or a set of bounding boxes using indexing. + + Args: + index (int, slice, or np.ndarray): The index, slice, or boolean array to select + the desired bounding boxes. + + Returns: + Bboxes: A new Bboxes object containing the selected bounding boxes. + + Raises: + AssertionError: If the indexed bounding boxes do not form a 2-dimensional matrix. + + Note: + When using boolean indexing, make sure to provide a boolean array with the same + length as the number of bounding boxes. + """ + if isinstance(index, int): + return Bboxes(self.bboxes[index].reshape(1, -1)) + b = self.bboxes[index] + assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!" + return Bboxes(b) + + +class Instances: + """ + Container for bounding boxes, segments, and keypoints of detected objects in an image. + + Attributes: + _bboxes (Bboxes): Internal object for handling bounding box operations. + keypoints (ndarray): keypoints(x, y, visible) with shape [N, 17, 3]. Default is None. + normalized (bool): Flag indicating whether the bounding box coordinates are normalized. + segments (ndarray): Segments array with shape [N, 1000, 2] after resampling. + + Args: + bboxes (ndarray): An array of bounding boxes with shape [N, 4]. + segments (list | ndarray, optional): A list or array of object segments. Default is None. + keypoints (ndarray, optional): An array of keypoints with shape [N, 17, 3]. Default is None. + bbox_format (str, optional): The format of bounding boxes ('xywh' or 'xyxy'). Default is 'xywh'. + normalized (bool, optional): Whether the bounding box coordinates are normalized. Default is True. + + Examples: + ```python + # Create an Instances object + instances = Instances( + bboxes=np.array([[10, 10, 30, 30], [20, 20, 40, 40]]), + segments=[np.array([[5, 5], [10, 10]]), np.array([[15, 15], [20, 20]])], + keypoints=np.array([[[5, 5, 1], [10, 10, 1]], [[15, 15, 1], [20, 20, 1]]]), + ) + ``` + + Note: + The bounding box format is either 'xywh' or 'xyxy', and is determined by the `bbox_format` argument. + This class does not perform input validation, and it assumes the inputs are well-formed. + """ + + def __init__(self, bboxes, segments=None, keypoints=None, bbox_format="xywh", normalized=True) -> None: + """ + Initialize the object with bounding boxes, segments, and keypoints. + + Args: + bboxes (np.ndarray): Bounding boxes, shape [N, 4]. + segments (list | np.ndarray, optional): Segmentation masks. Defaults to None. + keypoints (np.ndarray, optional): Keypoints, shape [N, 17, 3] and format (x, y, visible). Defaults to None. + bbox_format (str, optional): Format of bboxes. Defaults to "xywh". + normalized (bool, optional): Whether the coordinates are normalized. Defaults to True. 
+ """ + self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format) + self.keypoints = keypoints + self.normalized = normalized + self.segments = segments + + def convert_bbox(self, format): + """Convert bounding box format.""" + self._bboxes.convert(format=format) + + @property + def bbox_areas(self): + """Calculate the area of bounding boxes.""" + return self._bboxes.areas() + + def scale(self, scale_w, scale_h, bbox_only=False): + """Similar to denormalize func but without normalized sign.""" + self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h)) + if bbox_only: + return + self.segments[..., 0] *= scale_w + self.segments[..., 1] *= scale_h + if self.keypoints is not None: + self.keypoints[..., 0] *= scale_w + self.keypoints[..., 1] *= scale_h + + def denormalize(self, w, h): + """Denormalizes boxes, segments, and keypoints from normalized coordinates.""" + if not self.normalized: + return + self._bboxes.mul(scale=(w, h, w, h)) + self.segments[..., 0] *= w + self.segments[..., 1] *= h + if self.keypoints is not None: + self.keypoints[..., 0] *= w + self.keypoints[..., 1] *= h + self.normalized = False + + def normalize(self, w, h): + """Normalize bounding boxes, segments, and keypoints to image dimensions.""" + if self.normalized: + return + self._bboxes.mul(scale=(1 / w, 1 / h, 1 / w, 1 / h)) + self.segments[..., 0] /= w + self.segments[..., 1] /= h + if self.keypoints is not None: + self.keypoints[..., 0] /= w + self.keypoints[..., 1] /= h + self.normalized = True + + def add_padding(self, padw, padh): + """Handle rect and mosaic situation.""" + assert not self.normalized, "you should add padding with absolute coordinates." + self._bboxes.add(offset=(padw, padh, padw, padh)) + self.segments[..., 0] += padw + self.segments[..., 1] += padh + if self.keypoints is not None: + self.keypoints[..., 0] += padw + self.keypoints[..., 1] += padh + + def __getitem__(self, index) -> "Instances": + """ + Retrieve a specific instance or a set of instances using indexing. + + Args: + index (int, slice, or np.ndarray): The index, slice, or boolean array to select + the desired instances. + + Returns: + Instances: A new Instances object containing the selected bounding boxes, + segments, and keypoints if present. + + Note: + When using boolean indexing, make sure to provide a boolean array with the same + length as the number of instances. 
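+
+ Example:
+ A minimal sketch (segments are placeholder zeros so that indexing has arrays to slice):
+ ```python
+ import numpy as np
+
+ from ultralytics.utils.instance import Instances
+
+ boxes = np.array([[10, 10, 20, 20], [30, 30, 15, 15]], dtype=np.float32)
+ instances = Instances(boxes, segments=np.zeros((2, 1000, 2), dtype=np.float32))
+ first = instances[0]  # integer index
+ subset = instances[np.array([True, False])]  # boolean indexing
+ ```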
+ """ + segments = self.segments[index] if len(self.segments) else self.segments + keypoints = self.keypoints[index] if self.keypoints is not None else None + bboxes = self.bboxes[index] + bbox_format = self._bboxes.format + return Instances( + bboxes=bboxes, + segments=segments, + keypoints=keypoints, + bbox_format=bbox_format, + normalized=self.normalized, + ) + + def flipud(self, h): + """Flips the coordinates of bounding boxes, segments, and keypoints vertically.""" + if self._bboxes.format == "xyxy": + y1 = self.bboxes[:, 1].copy() + y2 = self.bboxes[:, 3].copy() + self.bboxes[:, 1] = h - y2 + self.bboxes[:, 3] = h - y1 + else: + self.bboxes[:, 1] = h - self.bboxes[:, 1] + self.segments[..., 1] = h - self.segments[..., 1] + if self.keypoints is not None: + self.keypoints[..., 1] = h - self.keypoints[..., 1] + + def fliplr(self, w): + """Reverses the order of the bounding boxes and segments horizontally.""" + if self._bboxes.format == "xyxy": + x1 = self.bboxes[:, 0].copy() + x2 = self.bboxes[:, 2].copy() + self.bboxes[:, 0] = w - x2 + self.bboxes[:, 2] = w - x1 + else: + self.bboxes[:, 0] = w - self.bboxes[:, 0] + self.segments[..., 0] = w - self.segments[..., 0] + if self.keypoints is not None: + self.keypoints[..., 0] = w - self.keypoints[..., 0] + + def clip(self, w, h): + """Clips bounding boxes, segments, and keypoints values to stay within image boundaries.""" + ori_format = self._bboxes.format + self.convert_bbox(format="xyxy") + self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w) + self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h) + if ori_format != "xyxy": + self.convert_bbox(format=ori_format) + self.segments[..., 0] = self.segments[..., 0].clip(0, w) + self.segments[..., 1] = self.segments[..., 1].clip(0, h) + if self.keypoints is not None: + self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w) + self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h) + + def remove_zero_area_boxes(self): + """Remove zero-area boxes, i.e. after clipping some boxes may have zero width or height.""" + good = self.bbox_areas > 0 + if not all(good): + self._bboxes = self._bboxes[good] + if len(self.segments): + self.segments = self.segments[good] + if self.keypoints is not None: + self.keypoints = self.keypoints[good] + return good + + def update(self, bboxes, segments=None, keypoints=None): + """Updates instance variables.""" + self._bboxes = Bboxes(bboxes, format=self._bboxes.format) + if segments is not None: + self.segments = segments + if keypoints is not None: + self.keypoints = keypoints + + def __len__(self): + """Return the length of the instance list.""" + return len(self.bboxes) + + @classmethod + def concatenate(cls, instances_list: List["Instances"], axis=0) -> "Instances": + """ + Concatenates a list of Instances objects into a single Instances object. + + Args: + instances_list (List[Instances]): A list of Instances objects to concatenate. + axis (int, optional): The axis along which the arrays will be concatenated. Defaults to 0. + + Returns: + Instances: A new Instances object containing the concatenated bounding boxes, + segments, and keypoints if present. + + Note: + The `Instances` objects in the list should have the same properties, such as + the format of the bounding boxes, whether keypoints are present, and if the + coordinates are normalized. 
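+
+ Example:
+ Illustrative usage (segments are placeholder zeros of equal length, so no resampling is needed):
+ ```python
+ import numpy as np
+
+ from ultralytics.utils.instance import Instances
+
+ seg = np.zeros((1, 100, 2), dtype=np.float32)
+ a = Instances(np.array([[10, 10, 20, 20]], dtype=np.float32), segments=seg)
+ b = Instances(np.array([[30, 30, 15, 15]], dtype=np.float32), segments=seg.copy())
+ merged = Instances.concatenate([a, b])  # len(merged) == 2
+ ```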
+ """ + assert isinstance(instances_list, (list, tuple)) + if not instances_list: + return cls(np.empty(0)) + assert all(isinstance(instance, Instances) for instance in instances_list) + + if len(instances_list) == 1: + return instances_list[0] + + use_keypoint = instances_list[0].keypoints is not None + bbox_format = instances_list[0]._bboxes.format + normalized = instances_list[0].normalized + + cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis) + seg_len = [b.segments.shape[1] for b in instances_list] + if len(set(seg_len)) > 1: # resample segments if there's different length + max_len = max(seg_len) + cat_segments = np.concatenate( + [ + resample_segments(list(b.segments), max_len) + if len(b.segments) + else np.zeros((0, max_len, 2), dtype=np.float32) # re-generating empty segments + for b in instances_list + ], + axis=axis, + ) + else: + cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis) + cat_keypoints = np.concatenate([b.keypoints for b in instances_list], axis=axis) if use_keypoint else None + return cls(cat_boxes, cat_segments, cat_keypoints, bbox_format, normalized) + + @property + def bboxes(self): + """Return bounding boxes.""" + return self._bboxes.bboxes diff --git a/ultralytics/utils/loss.py b/ultralytics/utils/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..f42a57787e978fb0ba2439d4707ab977b8506b3a --- /dev/null +++ b/ultralytics/utils/loss.py @@ -0,0 +1,743 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ultralytics.utils.metrics import OKS_SIGMA +from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh +from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors +from ultralytics.utils.torch_utils import autocast + +from .metrics import bbox_iou, probiou +from .tal import bbox2dist + + +class VarifocalLoss(nn.Module): + """ + Varifocal loss by Zhang et al. + + https://arxiv.org/abs/2008.13367. + """ + + def __init__(self): + """Initialize the VarifocalLoss class.""" + super().__init__() + + @staticmethod + def forward(pred_score, gt_score, label, alpha=0.75, gamma=2.0): + """Computes varfocal loss.""" + weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label + with autocast(enabled=False): + loss = ( + (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight) + .mean(1) + .sum() + ) + return loss + + +class FocalLoss(nn.Module): + """Wraps focal loss around existing loss_fcn(), i.e. 
criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).""" + + def __init__(self): + """Initializer for FocalLoss class with no parameters.""" + super().__init__() + + @staticmethod + def forward(pred, label, gamma=1.5, alpha=0.25): + """Calculates and updates confusion matrix for object detection/classification tasks.""" + loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none") + # p_t = torch.exp(-loss) + # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability + + # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py + pred_prob = pred.sigmoid() # prob from logits + p_t = label * pred_prob + (1 - label) * (1 - pred_prob) + modulating_factor = (1.0 - p_t) ** gamma + loss *= modulating_factor + if alpha > 0: + alpha_factor = label * alpha + (1 - label) * (1 - alpha) + loss *= alpha_factor + return loss.mean(1).sum() + + +class DFLoss(nn.Module): + """Criterion class for computing DFL losses during training.""" + + def __init__(self, reg_max=16) -> None: + """Initialize the DFL module.""" + super().__init__() + self.reg_max = reg_max + + def __call__(self, pred_dist, target): + """ + Return sum of left and right DFL losses. + + Distribution Focal Loss (DFL) proposed in Generalized Focal Loss + https://ieeexplore.ieee.org/document/9792391 + """ + target = target.clamp_(0, self.reg_max - 1 - 0.01) + tl = target.long() # target left + tr = tl + 1 # target right + wl = tr - target # weight left + wr = 1 - wl # weight right + return ( + F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl + + F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr + ).mean(-1, keepdim=True) + + +class BboxLoss(nn.Module): + """Criterion class for computing training losses during training.""" + + def __init__(self, reg_max=16): + """Initialize the BboxLoss module with regularization maximum and DFL settings.""" + super().__init__() + self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None + + def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask): + """IoU loss.""" + weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1) + iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True) + loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum + + # DFL loss + if self.dfl_loss: + target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1) + loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight + loss_dfl = loss_dfl.sum() / target_scores_sum + else: + loss_dfl = torch.tensor(0.0).to(pred_dist.device) + + return loss_iou, loss_dfl + + +class RotatedBboxLoss(BboxLoss): + """Criterion class for computing training losses during training.""" + + def __init__(self, reg_max): + """Initialize the BboxLoss module with regularization maximum and DFL settings.""" + super().__init__(reg_max) + + def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask): + """IoU loss.""" + weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1) + iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask]) + loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum + + # DFL loss + if self.dfl_loss: + target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.dfl_loss.reg_max - 1) + loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), 
target_ltrb[fg_mask]) * weight + loss_dfl = loss_dfl.sum() / target_scores_sum + else: + loss_dfl = torch.tensor(0.0).to(pred_dist.device) + + return loss_iou, loss_dfl + + +class KeypointLoss(nn.Module): + """Criterion class for computing training losses.""" + + def __init__(self, sigmas) -> None: + """Initialize the KeypointLoss class.""" + super().__init__() + self.sigmas = sigmas + + def forward(self, pred_kpts, gt_kpts, kpt_mask, area): + """Calculates keypoint loss factor and Euclidean distance loss for predicted and actual keypoints.""" + d = (pred_kpts[..., 0] - gt_kpts[..., 0]).pow(2) + (pred_kpts[..., 1] - gt_kpts[..., 1]).pow(2) + kpt_loss_factor = kpt_mask.shape[1] / (torch.sum(kpt_mask != 0, dim=1) + 1e-9) + # e = d / (2 * (area * self.sigmas) ** 2 + 1e-9) # from formula + e = d / ((2 * self.sigmas).pow(2) * (area + 1e-9) * 2) # from cocoeval + return (kpt_loss_factor.view(-1, 1) * ((1 - torch.exp(-e)) * kpt_mask)).mean() + + +class v8DetectionLoss: + """Criterion class for computing training losses.""" + + def __init__(self, model, tal_topk=10): # model must be de-paralleled + """Initializes v8DetectionLoss with the model, defining model-related properties and BCE loss function.""" + device = next(model.parameters()).device # get model device + h = model.args # hyperparameters + + m = model.model[-1] # Detect() module + self.bce = nn.BCEWithLogitsLoss(reduction="none") + self.hyp = h + self.stride = m.stride # model strides + self.nc = m.nc # number of classes + self.no = m.nc + m.reg_max * 4 + self.reg_max = m.reg_max + self.device = device + + self.use_dfl = m.reg_max > 1 + + self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0) + self.bbox_loss = BboxLoss(m.reg_max).to(device) + self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device) + + def preprocess(self, targets, batch_size, scale_tensor): + """Preprocesses the target counts and matches with the input batch size to output a tensor.""" + nl, ne = targets.shape + if nl == 0: + out = torch.zeros(batch_size, 0, ne - 1, device=self.device) + else: + i = targets[:, 0] # image index + _, counts = i.unique(return_counts=True) + counts = counts.to(dtype=torch.int32) + out = torch.zeros(batch_size, counts.max(), ne - 1, device=self.device) + for j in range(batch_size): + matches = i == j + if n := matches.sum(): + out[j, :n] = targets[matches, 1:] + out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor)) + return out + + def bbox_decode(self, anchor_points, pred_dist): + """Decode predicted object bounding box coordinates from anchor points and distribution.""" + if self.use_dfl: + b, a, c = pred_dist.shape # batch, anchors, channels + pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype)) + # pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype)) + # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2) + return dist2bbox(pred_dist, anchor_points, xywh=False) + + def __call__(self, preds, batch): + """Calculate the sum of the loss for box, cls and dfl multiplied by batch size.""" + loss = torch.zeros(3, device=self.device) # box, cls, dfl + feats = preds[1] if isinstance(preds, tuple) else preds + pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split( + (self.reg_max * 4, self.nc), 1 + ) + + pred_scores = pred_scores.permute(0, 2, 1).contiguous() + pred_distri = 
pred_distri.permute(0, 2, 1).contiguous() + + dtype = pred_scores.dtype + batch_size = pred_scores.shape[0] + imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w) + anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5) + + # Targets + targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1) + targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]]) + gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy + mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0) + + # Pboxes + pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4) + # dfl_conf = pred_distri.view(batch_size, -1, 4, self.reg_max).detach().softmax(-1) + # dfl_conf = (dfl_conf.amax(-1).mean(-1) + dfl_conf.amax(-1).amin(-1)) / 2 + + _, target_bboxes, target_scores, fg_mask, _ = self.assigner( + # pred_scores.detach().sigmoid() * 0.8 + dfl_conf.unsqueeze(-1) * 0.2, + pred_scores.detach().sigmoid(), + (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), + anchor_points * stride_tensor, + gt_labels, + gt_bboxes, + mask_gt, + ) + + target_scores_sum = max(target_scores.sum(), 1) + + # Cls loss + # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way + loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE + + # Bbox loss + if fg_mask.sum(): + target_bboxes /= stride_tensor + loss[0], loss[2] = self.bbox_loss( + pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask + ) + + loss[0] *= self.hyp.box # box gain + loss[1] *= self.hyp.cls # cls gain + loss[2] *= self.hyp.dfl # dfl gain + + return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl) + + +class v8SegmentationLoss(v8DetectionLoss): + """Criterion class for computing training losses.""" + + def __init__(self, model): # model must be de-paralleled + """Initializes the v8SegmentationLoss class, taking a de-paralleled model as argument.""" + super().__init__(model) + self.overlap = model.args.overlap_mask + + def __call__(self, preds, batch): + """Calculate and return the loss for the YOLO model.""" + loss = torch.zeros(4, device=self.device) # box, cls, dfl + feats, pred_masks, proto = preds if len(preds) == 3 else preds[1] + batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width + pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split( + (self.reg_max * 4, self.nc), 1 + ) + + # B, grids, .. 
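+        # Reorder (batch, channels, anchors) -> (batch, anchors, channels) so that each anchor's box
+        # distribution, class scores and mask coefficients sit along the last dimension.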
+ pred_scores = pred_scores.permute(0, 2, 1).contiguous() + pred_distri = pred_distri.permute(0, 2, 1).contiguous() + pred_masks = pred_masks.permute(0, 2, 1).contiguous() + + dtype = pred_scores.dtype + imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w) + anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5) + + # Targets + try: + batch_idx = batch["batch_idx"].view(-1, 1) + targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1) + targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]]) + gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy + mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0) + except RuntimeError as e: + raise TypeError( + "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n" + "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, " + "i.e. 'yolo train model=yolov8n-seg.pt data=coco8.yaml'.\nVerify your dataset is a " + "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' " + "as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help." + ) from e + + # Pboxes + pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4) + + _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner( + pred_scores.detach().sigmoid(), + (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), + anchor_points * stride_tensor, + gt_labels, + gt_bboxes, + mask_gt, + ) + + target_scores_sum = max(target_scores.sum(), 1) + + # Cls loss + # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way + loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE + + if fg_mask.sum(): + # Bbox loss + loss[0], loss[3] = self.bbox_loss( + pred_distri, + pred_bboxes, + anchor_points, + target_bboxes / stride_tensor, + target_scores, + target_scores_sum, + fg_mask, + ) + # Masks loss + masks = batch["masks"].to(self.device).float() + if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample + masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0] + + loss[1] = self.calculate_segmentation_loss( + fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap + ) + + # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove + else: + loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss + + loss[0] *= self.hyp.box # box gain + loss[1] *= self.hyp.box # seg gain + loss[2] *= self.hyp.cls # cls gain + loss[3] *= self.hyp.dfl # dfl gain + + return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl) + + @staticmethod + def single_mask_loss( + gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor + ) -> torch.Tensor: + """ + Compute the instance segmentation loss for a single image. + + Args: + gt_mask (torch.Tensor): Ground truth mask of shape (n, H, W), where n is the number of objects. + pred (torch.Tensor): Predicted mask coefficients of shape (n, 32). + proto (torch.Tensor): Prototype masks of shape (32, H, W). + xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (n, 4). + area (torch.Tensor): Area of each ground truth bounding box of shape (n,). + + Returns: + (torch.Tensor): The calculated mask loss for a single image. 
+ + Notes: + The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the + predicted masks from the prototype masks and predicted mask coefficients. + """ + pred_mask = torch.einsum("in,nhw->ihw", pred, proto) # (n, 32) @ (32, 80, 80) -> (n, 80, 80) + loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none") + return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum() + + def calculate_segmentation_loss( + self, + fg_mask: torch.Tensor, + masks: torch.Tensor, + target_gt_idx: torch.Tensor, + target_bboxes: torch.Tensor, + batch_idx: torch.Tensor, + proto: torch.Tensor, + pred_masks: torch.Tensor, + imgsz: torch.Tensor, + overlap: bool, + ) -> torch.Tensor: + """ + Calculate the loss for instance segmentation. + + Args: + fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive. + masks (torch.Tensor): Ground truth masks of shape (BS, H, W) if `overlap` is False, otherwise (BS, ?, H, W). + target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors). + target_bboxes (torch.Tensor): Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4). + batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1). + proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W). + pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32). + imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W). + overlap (bool): Whether the masks in `masks` tensor overlap. + + Returns: + (torch.Tensor): The calculated loss for instance segmentation. + + Notes: + The batch loss can be computed for improved speed at higher memory usage. + For example, pred_mask can be computed as follows: + pred_mask = torch.einsum('in,nhw->ihw', pred, proto) # (i, 32) @ (32, 160, 160) -> (i, 160, 160) + """ + _, _, mask_h, mask_w = proto.shape + loss = 0 + + # Normalize to 0-1 + target_bboxes_normalized = target_bboxes / imgsz[[1, 0, 1, 0]] + + # Areas of target bboxes + marea = xyxy2xywh(target_bboxes_normalized)[..., 2:].prod(2) + + # Normalize to mask size + mxyxy = target_bboxes_normalized * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=proto.device) + + for i, single_i in enumerate(zip(fg_mask, target_gt_idx, pred_masks, proto, mxyxy, marea, masks)): + fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i, masks_i = single_i + if fg_mask_i.any(): + mask_idx = target_gt_idx_i[fg_mask_i] + if overlap: + gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1) + gt_mask = gt_mask.float() + else: + gt_mask = masks[batch_idx.view(-1) == i][mask_idx] + + loss += self.single_mask_loss( + gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i], marea_i[fg_mask_i] + ) + + # WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove + else: + loss += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss + + return loss / fg_mask.sum() + + +class v8PoseLoss(v8DetectionLoss): + """Criterion class for computing training losses.""" + + def __init__(self, model): # model must be de-paralleled + """Initializes v8PoseLoss with model, sets keypoint variables and declares a keypoint loss instance.""" + super().__init__(model) + self.kpt_shape = model.model[-1].kpt_shape + self.bce_pose = nn.BCEWithLogitsLoss() + is_pose = self.kpt_shape == [17, 3] + nkpt = self.kpt_shape[0] # number of keypoints + sigmas = 
torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt + self.keypoint_loss = KeypointLoss(sigmas=sigmas) + + def __call__(self, preds, batch): + """Calculate the total loss and detach it.""" + loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility + feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1] + pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split( + (self.reg_max * 4, self.nc), 1 + ) + + # B, grids, .. + pred_scores = pred_scores.permute(0, 2, 1).contiguous() + pred_distri = pred_distri.permute(0, 2, 1).contiguous() + pred_kpts = pred_kpts.permute(0, 2, 1).contiguous() + + dtype = pred_scores.dtype + imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w) + anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5) + + # Targets + batch_size = pred_scores.shape[0] + batch_idx = batch["batch_idx"].view(-1, 1) + targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1) + targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]]) + gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy + mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0) + + # Pboxes + pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4) + pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape)) # (b, h*w, 17, 3) + + _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner( + pred_scores.detach().sigmoid(), + (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), + anchor_points * stride_tensor, + gt_labels, + gt_bboxes, + mask_gt, + ) + + target_scores_sum = max(target_scores.sum(), 1) + + # Cls loss + # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way + loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE + + # Bbox loss + if fg_mask.sum(): + target_bboxes /= stride_tensor + loss[0], loss[4] = self.bbox_loss( + pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask + ) + keypoints = batch["keypoints"].to(self.device).float().clone() + keypoints[..., 0] *= imgsz[1] + keypoints[..., 1] *= imgsz[0] + + loss[1], loss[2] = self.calculate_keypoints_loss( + fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts + ) + + loss[0] *= self.hyp.box # box gain + loss[1] *= self.hyp.pose # pose gain + loss[2] *= self.hyp.kobj # kobj gain + loss[3] *= self.hyp.cls # cls gain + loss[4] *= self.hyp.dfl # dfl gain + + return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl) + + @staticmethod + def kpts_decode(anchor_points, pred_kpts): + """Decodes predicted keypoints to image coordinates.""" + y = pred_kpts.clone() + y[..., :2] *= 2.0 + y[..., 0] += anchor_points[:, [0]] - 0.5 + y[..., 1] += anchor_points[:, [1]] - 0.5 + return y + + def calculate_keypoints_loss( + self, masks, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts + ): + """ + Calculate the keypoints loss for the model. + + This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is + based on the difference between the predicted keypoints and ground truth keypoints. 
The keypoints object loss is + a binary classification loss that classifies whether a keypoint is present or not. + + Args: + masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors). + target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors). + keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim). + batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1). + stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1). + target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4). + pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim). + + Returns: + kpts_loss (torch.Tensor): The keypoints loss. + kpts_obj_loss (torch.Tensor): The keypoints object loss. + """ + batch_idx = batch_idx.flatten() + batch_size = len(masks) + + # Find the maximum number of keypoints in a single image + max_kpts = torch.unique(batch_idx, return_counts=True)[1].max() + + # Create a tensor to hold batched keypoints + batched_keypoints = torch.zeros( + (batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]), device=keypoints.device + ) + + # TODO: any idea how to vectorize this? + # Fill batched_keypoints with keypoints based on batch_idx + for i in range(batch_size): + keypoints_i = keypoints[batch_idx == i] + batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i + + # Expand dimensions of target_gt_idx to match the shape of batched_keypoints + target_gt_idx_expanded = target_gt_idx.unsqueeze(-1).unsqueeze(-1) + + # Use target_gt_idx_expanded to select keypoints from batched_keypoints + selected_keypoints = batched_keypoints.gather( + 1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2]) + ) + + # Divide coordinates by stride + selected_keypoints /= stride_tensor.view(1, -1, 1, 1) + + kpts_loss = 0 + kpts_obj_loss = 0 + + if masks.any(): + gt_kpt = selected_keypoints[masks] + area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True) + pred_kpt = pred_kpts[masks] + kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True) + kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area) # pose loss + + if pred_kpt.shape[-1] == 3: + kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float()) # keypoint obj loss + + return kpts_loss, kpts_obj_loss + + +class v8ClassificationLoss: + """Criterion class for computing training losses.""" + + def __call__(self, preds, batch): + """Compute the classification loss between predictions and true labels.""" + preds = preds[1] if isinstance(preds, (list, tuple)) else preds + loss = F.cross_entropy(preds, batch["cls"], reduction="mean") + loss_items = loss.detach() + return loss, loss_items + + +class v8OBBLoss(v8DetectionLoss): + """Calculates losses for object detection, classification, and box distribution in rotated YOLO models.""" + + def __init__(self, model): + """Initializes v8OBBLoss with model, assigner, and rotated bbox loss; note model must be de-paralleled.""" + super().__init__(model) + self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0) + self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device) + + def preprocess(self, targets, batch_size, scale_tensor): + """Preprocesses the target counts and matches with the input batch size to output a tensor.""" + if 
targets.shape[0] == 0: + out = torch.zeros(batch_size, 0, 6, device=self.device) + else: + i = targets[:, 0] # image index + _, counts = i.unique(return_counts=True) + counts = counts.to(dtype=torch.int32) + out = torch.zeros(batch_size, counts.max(), 6, device=self.device) + for j in range(batch_size): + matches = i == j + if n := matches.sum(): + bboxes = targets[matches, 2:] + bboxes[..., :4].mul_(scale_tensor) + out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1) + return out + + def __call__(self, preds, batch): + """Calculate and return the loss for the YOLO model.""" + loss = torch.zeros(3, device=self.device) # box, cls, dfl + feats, pred_angle = preds if isinstance(preds[0], list) else preds[1] + batch_size = pred_angle.shape[0] # batch size, number of masks, mask height, mask width + pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split( + (self.reg_max * 4, self.nc), 1 + ) + + # b, grids, .. + pred_scores = pred_scores.permute(0, 2, 1).contiguous() + pred_distri = pred_distri.permute(0, 2, 1).contiguous() + pred_angle = pred_angle.permute(0, 2, 1).contiguous() + + dtype = pred_scores.dtype + imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w) + anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5) + + # targets + try: + batch_idx = batch["batch_idx"].view(-1, 1) + targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1) + rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item() + targets = targets[(rw >= 2) & (rh >= 2)] # filter rboxes of tiny size to stabilize training + targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]]) + gt_labels, gt_bboxes = targets.split((1, 5), 2) # cls, xywhr + mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0) + except RuntimeError as e: + raise TypeError( + "ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n" + "This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, " + "i.e. 'yolo train model=yolov8n-obb.pt data=dota8.yaml'.\nVerify your dataset is a " + "correctly formatted 'OBB' dataset using 'data=dota8.yaml' " + "as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help." 
+ ) from e + + # Pboxes + pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle) # xyxy, (b, h*w, 4) + + bboxes_for_assigner = pred_bboxes.clone().detach() + # Only the first four elements need to be scaled + bboxes_for_assigner[..., :4] *= stride_tensor + _, target_bboxes, target_scores, fg_mask, _ = self.assigner( + pred_scores.detach().sigmoid(), + bboxes_for_assigner.type(gt_bboxes.dtype), + anchor_points * stride_tensor, + gt_labels, + gt_bboxes, + mask_gt, + ) + + target_scores_sum = max(target_scores.sum(), 1) + + # Cls loss + # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way + loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE + + # Bbox loss + if fg_mask.sum(): + target_bboxes[..., :4] /= stride_tensor + loss[0], loss[2] = self.bbox_loss( + pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask + ) + else: + loss[0] += (pred_angle * 0).sum() + + loss[0] *= self.hyp.box # box gain + loss[1] *= self.hyp.cls # cls gain + loss[2] *= self.hyp.dfl # dfl gain + + return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl) + + def bbox_decode(self, anchor_points, pred_dist, pred_angle): + """ + Decode predicted object bounding box coordinates from anchor points and distribution. + + Args: + anchor_points (torch.Tensor): Anchor points, (h*w, 2). + pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4). + pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1). + + Returns: + (torch.Tensor): Predicted rotated bounding boxes with angles, (bs, h*w, 5). + """ + if self.use_dfl: + b, a, c = pred_dist.shape # batch, anchors, channels + pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype)) + return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1) + + +class E2EDetectLoss: + """Criterion class for computing training losses.""" + + def __init__(self, model): + """Initialize E2EDetectLoss with one-to-many and one-to-one detection losses using the provided model.""" + self.one2many = v8DetectionLoss(model, tal_topk=10) + self.one2one = v8DetectionLoss(model, tal_topk=1) + + def __call__(self, preds, batch): + """Calculate the sum of the loss for box, cls and dfl multiplied by batch size.""" + preds = preds[1] if isinstance(preds, tuple) else preds + one2many = preds["one2many"] + loss_one2many = self.one2many(one2many, batch) + one2one = preds["one2one"] + loss_one2one = self.one2one(one2one, batch) + return loss_one2many[0] + loss_one2one[0], loss_one2many[1] + loss_one2one[1] diff --git a/ultralytics/utils/metrics.py b/ultralytics/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..2b5821cbbb158abd08c6b3d2a038c04c79a87cb3 --- /dev/null +++ b/ultralytics/utils/metrics.py @@ -0,0 +1,1294 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +"""Model validation metrics.""" + +import math +import warnings +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import torch + +from ultralytics.utils import LOGGER, SimpleClass, TryExcept, plt_settings + +OKS_SIGMA = ( + np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89]) + / 10.0 +) + + +def bbox_ioa(box1, box2, iou=False, eps=1e-7): + """ + Calculate the intersection over box2 area given box1 and box2. Boxes are in x1y1x2y2 format. 
+ + Args: + box1 (np.ndarray): A numpy array of shape (n, 4) representing n bounding boxes. + box2 (np.ndarray): A numpy array of shape (m, 4) representing m bounding boxes. + iou (bool): Calculate the standard IoU if True else return inter_area/box2_area. + eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7. + + Returns: + (np.ndarray): A numpy array of shape (n, m) representing the intersection over box2 area. + """ + # Get the coordinates of bounding boxes + b1_x1, b1_y1, b1_x2, b1_y2 = box1.T + b2_x1, b2_y1, b2_x2, b2_y2 = box2.T + + # Intersection area + inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * ( + np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1) + ).clip(0) + + # Box2 area + area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + if iou: + box1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1) + area = area + box1_area[:, None] - inter_area + + # Intersection over box2 area + return inter_area / (area + eps) + + +def box_iou(box1, box2, eps=1e-7): + """ + Calculate intersection-over-union (IoU) of boxes. Both sets of boxes are expected to be in (x1, y1, x2, y2) format. + Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py. + + Args: + box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes. + box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes. + eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7. + + Returns: + (torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2. + """ + # NOTE: Need .float() to get accurate iou values + # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2) + (a1, a2), (b1, b2) = box1.float().unsqueeze(1).chunk(2, 2), box2.float().unsqueeze(0).chunk(2, 2) + inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp_(0).prod(2) + + # IoU = inter / (area1 + area2 - inter) + return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps) + + +def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7): + """ + Calculates the Intersection over Union (IoU) between bounding boxes. + + This function supports various shapes for `box1` and `box2` as long as the last dimension is 4. + For instance, you may pass tensors shaped like (4,), (N, 4), (B, N, 4), or (B, N, 1, 4). + Internally, the code will split the last dimension into (x, y, w, h) if `xywh=True`, + or (x1, y1, x2, y2) if `xywh=False`. + + Args: + box1 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4. + box2 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4. + xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in + (x1, y1, x2, y2) format. Defaults to True. + GIoU (bool, optional): If True, calculate Generalized IoU. Defaults to False. + DIoU (bool, optional): If True, calculate Distance IoU. Defaults to False. + CIoU (bool, optional): If True, calculate Complete IoU. Defaults to False. + eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7. + + Returns: + (torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags. 
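+    Examples:
+        A minimal sketch with two hypothetical boxes in (x1, y1, x2, y2) format:
+        >>> import torch
+        >>> box1 = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
+        >>> box2 = torch.tensor([[5.0, 5.0, 15.0, 15.0]])
+        >>> iou = bbox_iou(box1, box2, xywh=False)  # plain IoU, shape (1, 1)
+        >>> ciou = bbox_iou(box1, box2, xywh=False, CIoU=True)  # Complete IoU, as used by BboxLoss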
+ """ + # Get the coordinates of bounding boxes + if xywh: # transform from xywh to xyxy + (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1) + w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2 + b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_ + b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_ + else: # x1, y1, x2, y2 = box1 + b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1) + b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1) + w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps + w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps + + # Intersection area + inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * ( + b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1) + ).clamp_(0) + + # Union Area + union = w1 * h1 + w2 * h2 - inter + eps + + # IoU + iou = inter / union + if CIoU or DIoU or GIoU: + cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # convex (smallest enclosing box) width + ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # convex height + if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1 + c2 = cw.pow(2) + ch.pow(2) + eps # convex diagonal squared + rho2 = ( + (b2_x1 + b2_x2 - b1_x1 - b1_x2).pow(2) + (b2_y1 + b2_y2 - b1_y1 - b1_y2).pow(2) + ) / 4 # center dist**2 + if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47 + v = (4 / math.pi**2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2) + with torch.no_grad(): + alpha = v / (v - iou + (1 + eps)) + return iou - (rho2 / c2 + v * alpha) # CIoU + return iou - rho2 / c2 # DIoU + c_area = cw * ch + eps # convex area + return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf + return iou # IoU + + +def mask_iou(mask1, mask2, eps=1e-7): + """ + Calculate masks IoU. + + Args: + mask1 (torch.Tensor): A tensor of shape (N, n) where N is the number of ground truth objects and n is the + product of image width and height. + mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the + product of image width and height. + eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7. + + Returns: + (torch.Tensor): A tensor of shape (N, M) representing masks IoU. + """ + intersection = torch.matmul(mask1, mask2.T).clamp_(0) + union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection # (area1 + area2) - intersection + return intersection / (union + eps) + + +def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7): + """ + Calculate Object Keypoint Similarity (OKS). + + Args: + kpt1 (torch.Tensor): A tensor of shape (N, 17, 3) representing ground truth keypoints. + kpt2 (torch.Tensor): A tensor of shape (M, 17, 3) representing predicted keypoints. + area (torch.Tensor): A tensor of shape (N,) representing areas from ground truth. + sigma (list): A list containing 17 values representing keypoint scales. + eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7. + + Returns: + (torch.Tensor): A tensor of shape (N, M) representing keypoint similarities. 
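+    Examples:
+        An illustrative call with random keypoints (shapes follow the Args above; values are arbitrary):
+        >>> import torch
+        >>> gt_kpts = torch.rand(2, 17, 3)    # 2 ground-truth instances
+        >>> pred_kpts = torch.rand(3, 17, 3)  # 3 predicted instances
+        >>> areas = torch.rand(2) * 100
+        >>> oks = kpt_iou(gt_kpts, pred_kpts, areas, sigma=list(OKS_SIGMA))  # (2, 3) similarity matrix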
+ """ + d = (kpt1[:, None, :, 0] - kpt2[..., 0]).pow(2) + (kpt1[:, None, :, 1] - kpt2[..., 1]).pow(2) # (N, M, 17) + sigma = torch.tensor(sigma, device=kpt1.device, dtype=kpt1.dtype) # (17, ) + kpt_mask = kpt1[..., 2] != 0 # (N, 17) + e = d / ((2 * sigma).pow(2) * (area[:, None, None] + eps) * 2) # from cocoeval + # e = d / ((area[None, :, None] + eps) * sigma) ** 2 / 2 # from formula + return ((-e).exp() * kpt_mask[:, None]).sum(-1) / (kpt_mask.sum(-1)[:, None] + eps) + + +def _get_covariance_matrix(boxes): + """ + Generating covariance matrix from obbs. + + Args: + boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format. + + Returns: + (torch.Tensor): Covariance matrices corresponding to original rotated bounding boxes. + """ + # Gaussian bounding boxes, ignore the center points (the first two columns) because they are not needed here. + gbbs = torch.cat((boxes[:, 2:4].pow(2) / 12, boxes[:, 4:]), dim=-1) + a, b, c = gbbs.split(1, dim=-1) + cos = c.cos() + sin = c.sin() + cos2 = cos.pow(2) + sin2 = sin.pow(2) + return a * cos2 + b * sin2, a * sin2 + b * cos2, (a - b) * cos * sin + + +def probiou(obb1, obb2, CIoU=False, eps=1e-7): + """ + Calculate probabilistic IoU between oriented bounding boxes. + + Implements the algorithm from https://arxiv.org/pdf/2106.06072v1.pdf. + + Args: + obb1 (torch.Tensor): Ground truth OBBs, shape (N, 5), format xywhr. + obb2 (torch.Tensor): Predicted OBBs, shape (N, 5), format xywhr. + CIoU (bool, optional): If True, calculate CIoU. Defaults to False. + eps (float, optional): Small value to avoid division by zero. Defaults to 1e-7. + + Returns: + (torch.Tensor): OBB similarities, shape (N,). + + Note: + OBB format: [center_x, center_y, width, height, rotation_angle]. + If CIoU is True, returns CIoU instead of IoU. + """ + x1, y1 = obb1[..., :2].split(1, dim=-1) + x2, y2 = obb2[..., :2].split(1, dim=-1) + a1, b1, c1 = _get_covariance_matrix(obb1) + a2, b2, c2 = _get_covariance_matrix(obb2) + + t1 = ( + ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps) + ) * 0.25 + t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5 + t3 = ( + ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2)) + / (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps) + + eps + ).log() * 0.5 + bd = (t1 + t2 + t3).clamp(eps, 100.0) + hd = (1.0 - (-bd).exp() + eps).sqrt() + iou = 1 - hd + if CIoU: # only include the wh aspect ratio part + w1, h1 = obb1[..., 2:4].split(1, dim=-1) + w2, h2 = obb2[..., 2:4].split(1, dim=-1) + v = (4 / math.pi**2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2) + with torch.no_grad(): + alpha = v / (v - iou + (1 + eps)) + return iou - v * alpha # CIoU + return iou + + +def batch_probiou(obb1, obb2, eps=1e-7): + """ + Calculate the prob IoU between oriented bounding boxes, https://arxiv.org/pdf/2106.06072v1.pdf. + + Args: + obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format. + obb2 (torch.Tensor | np.ndarray): A tensor of shape (M, 5) representing predicted obbs, with xywhr format. + eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7. + + Returns: + (torch.Tensor): A tensor of shape (N, M) representing obb similarities. 
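+    Examples:
+        A small sketch with arbitrary rotated boxes in xywhr format:
+        >>> import torch
+        >>> obb1 = torch.tensor([[50.0, 50.0, 20.0, 10.0, 0.0]])
+        >>> obb2 = torch.tensor([[50.0, 50.0, 20.0, 10.0, 0.3], [80.0, 80.0, 10.0, 10.0, 0.0]])
+        >>> sim = batch_probiou(obb1, obb2)  # (1, 2) pairwise similarities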
+ """ + obb1 = torch.from_numpy(obb1) if isinstance(obb1, np.ndarray) else obb1 + obb2 = torch.from_numpy(obb2) if isinstance(obb2, np.ndarray) else obb2 + + x1, y1 = obb1[..., :2].split(1, dim=-1) + x2, y2 = (x.squeeze(-1)[None] for x in obb2[..., :2].split(1, dim=-1)) + a1, b1, c1 = _get_covariance_matrix(obb1) + a2, b2, c2 = (x.squeeze(-1)[None] for x in _get_covariance_matrix(obb2)) + + t1 = ( + ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps) + ) * 0.25 + t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5 + t3 = ( + ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2)) + / (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps) + + eps + ).log() * 0.5 + bd = (t1 + t2 + t3).clamp(eps, 100.0) + hd = (1.0 - (-bd).exp() + eps).sqrt() + return 1 - hd + + +def smooth_bce(eps=0.1): + """ + Computes smoothed positive and negative Binary Cross-Entropy targets. + + This function calculates positive and negative label smoothing BCE targets based on a given epsilon value. + For implementation details, refer to https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441. + + Args: + eps (float, optional): The epsilon value for label smoothing. Defaults to 0.1. + + Returns: + (tuple): A tuple containing the positive and negative label smoothing BCE targets. + """ + return 1.0 - 0.5 * eps, 0.5 * eps + + +class ConfusionMatrix: + """ + A class for calculating and updating a confusion matrix for object detection and classification tasks. + + Attributes: + task (str): The type of task, either 'detect' or 'classify'. + matrix (np.ndarray): The confusion matrix, with dimensions depending on the task. + nc (int): The number of classes. + conf (float): The confidence threshold for detections. + iou_thres (float): The Intersection over Union threshold. + """ + + def __init__(self, nc, conf=0.25, iou_thres=0.45, task="detect"): + """Initialize attributes for the YOLO model.""" + self.task = task + self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc)) + self.nc = nc # number of classes + self.conf = 0.25 if conf in {None, 0.001} else conf # apply 0.25 if default val conf is passed + self.iou_thres = iou_thres + + def process_cls_preds(self, preds, targets): + """ + Update confusion matrix for classification task. + + Args: + preds (Array[N, min(nc,5)]): Predicted class labels. + targets (Array[N, 1]): Ground truth class labels. + """ + preds, targets = torch.cat(preds)[:, 0], torch.cat(targets) + for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()): + self.matrix[p][t] += 1 + + def process_batch(self, detections, gt_bboxes, gt_cls): + """ + Update confusion matrix for object detection task. + + Args: + detections (Array[N, 6] | Array[N, 7]): Detected bounding boxes and their associated information. + Each row should contain (x1, y1, x2, y2, conf, class) + or with an additional element `angle` when it's obb. + gt_bboxes (Array[M, 4]| Array[N, 5]): Ground truth bounding boxes with xyxy/xyxyr format. + gt_cls (Array[M]): The class labels. 
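+        Examples:
+            An illustrative update with one detection matching one label (default conf and IoU thresholds):
+            >>> import torch
+            >>> cm = ConfusionMatrix(nc=2)
+            >>> detections = torch.tensor([[0.0, 0.0, 10.0, 10.0, 0.9, 1.0]])  # x1, y1, x2, y2, conf, class
+            >>> gt_bboxes = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
+            >>> gt_cls = torch.tensor([1.0])
+            >>> cm.process_batch(detections, gt_bboxes, gt_cls)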
+ """ + if gt_cls.shape[0] == 0: # Check if labels is empty + if detections is not None: + detections = detections[detections[:, 4] > self.conf] + detection_classes = detections[:, 5].int() + for dc in detection_classes: + self.matrix[dc, self.nc] += 1 # false positives + return + if detections is None: + gt_classes = gt_cls.int() + for gc in gt_classes: + self.matrix[self.nc, gc] += 1 # background FN + return + + detections = detections[detections[:, 4] > self.conf] + gt_classes = gt_cls.int() + detection_classes = detections[:, 5].int() + is_obb = detections.shape[1] == 7 and gt_bboxes.shape[1] == 5 # with additional `angle` dimension + iou = ( + batch_probiou(gt_bboxes, torch.cat([detections[:, :4], detections[:, -1:]], dim=-1)) + if is_obb + else box_iou(gt_bboxes, detections[:, :4]) + ) + + x = torch.where(iou > self.iou_thres) + if x[0].shape[0]: + matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() + if x[0].shape[0] > 1: + matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 1], return_index=True)[1]] + matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 0], return_index=True)[1]] + else: + matches = np.zeros((0, 3)) + + n = matches.shape[0] > 0 + m0, m1, _ = matches.transpose().astype(int) + for i, gc in enumerate(gt_classes): + j = m0 == i + if n and sum(j) == 1: + self.matrix[detection_classes[m1[j]], gc] += 1 # correct + else: + self.matrix[self.nc, gc] += 1 # true background + + for i, dc in enumerate(detection_classes): + if not any(m1 == i): + self.matrix[dc, self.nc] += 1 # predicted background + + def matrix(self): + """Returns the confusion matrix.""" + return self.matrix + + def tp_fp(self): + """Returns true positives and false positives.""" + tp = self.matrix.diagonal() # true positives + fp = self.matrix.sum(1) - tp # false positives + # fn = self.matrix.sum(0) - tp # false negatives (missed detections) + return (tp[:-1], fp[:-1]) if self.task == "detect" else (tp, fp) # remove background class if task=detect + + @TryExcept("WARNING ⚠️ ConfusionMatrix plot failure") + @plt_settings() + def plot(self, normalize=True, save_dir="", names=(), on_plot=None): + """ + Plot the confusion matrix using seaborn and save it to a file. + + Args: + normalize (bool): Whether to normalize the confusion matrix. + save_dir (str): Directory where the plot will be saved. + names (tuple): Names of classes, used as labels on the plot. + on_plot (func): An optional callback to pass plots path and data when they are rendered. 
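+        Examples:
+            A minimal sketch (assumes seaborn is installed; writes confusion_matrix_normalized.png into save_dir):
+            >>> cm = ConfusionMatrix(nc=2)
+            >>> cm.plot(normalize=True, save_dir=".", names=("person", "car"))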
+ """ + import seaborn # scope for faster 'import ultralytics' + + array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1) # normalize columns + array[array < 0.005] = np.nan # don't annotate (would appear as 0.00) + + fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True) + nc, nn = self.nc, len(names) # number of classes, names + seaborn.set_theme(font_scale=1.0 if nc < 50 else 0.8) # for label size + labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels + ticklabels = (list(names) + ["background"]) if labels else "auto" + with warnings.catch_warnings(): + warnings.simplefilter("ignore") # suppress empty matrix RuntimeWarning: All-NaN slice encountered + seaborn.heatmap( + array, + ax=ax, + annot=nc < 30, + annot_kws={"size": 8}, + cmap="Blues", + fmt=".2f" if normalize else ".0f", + square=True, + vmin=0.0, + xticklabels=ticklabels, + yticklabels=ticklabels, + ).set_facecolor((1, 1, 1)) + title = "Confusion Matrix" + " Normalized" * normalize + ax.set_xlabel("True") + ax.set_ylabel("Predicted") + ax.set_title(title) + plot_fname = Path(save_dir) / f"{title.lower().replace(' ', '_')}.png" + fig.savefig(plot_fname, dpi=250) + plt.close(fig) + if on_plot: + on_plot(plot_fname) + + def print(self): + """Print the confusion matrix to the console.""" + for i in range(self.nc + 1): + LOGGER.info(" ".join(map(str, self.matrix[i]))) + + +def smooth(y, f=0.05): + """Box filter of fraction f.""" + nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd) + p = np.ones(nf // 2) # ones padding + yp = np.concatenate((p * y[0], y, p * y[-1]), 0) # y padded + return np.convolve(yp, np.ones(nf) / nf, mode="valid") # y-smoothed + + +@plt_settings() +def plot_pr_curve(px, py, ap, save_dir=Path("pr_curve.png"), names={}, on_plot=None): + """Plots a precision-recall curve.""" + fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) + py = np.stack(py, axis=1) + + if 0 < len(names) < 21: # display per-class legend if < 21 classes + for i, y in enumerate(py.T): + ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}") # plot(recall, precision) + else: + ax.plot(px, py, linewidth=1, color="grey") # plot(recall, precision) + + ax.plot(px, py.mean(1), linewidth=3, color="blue", label=f"all classes {ap[:, 0].mean():.3f} mAP@0.5") + ax.set_xlabel("Recall") + ax.set_ylabel("Precision") + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left") + ax.set_title("Precision-Recall Curve") + fig.savefig(save_dir, dpi=250) + plt.close(fig) + if on_plot: + on_plot(save_dir) + + +@plt_settings() +def plot_mc_curve(px, py, save_dir=Path("mc_curve.png"), names={}, xlabel="Confidence", ylabel="Metric", on_plot=None): + """Plots a metric-confidence curve.""" + fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) + + if 0 < len(names) < 21: # display per-class legend if < 21 classes + for i, y in enumerate(py): + ax.plot(px, y, linewidth=1, label=f"{names[i]}") # plot(confidence, metric) + else: + ax.plot(px, py.T, linewidth=1, color="grey") # plot(confidence, metric) + + y = smooth(py.mean(0), 0.05) + ax.plot(px, y, linewidth=3, color="blue", label=f"all classes {y.max():.2f} at {px[y.argmax()]:.3f}") + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left") + ax.set_title(f"{ylabel}-Confidence Curve") + fig.savefig(save_dir, dpi=250) + plt.close(fig) + if on_plot: + on_plot(save_dir) + + +def 
compute_ap(recall, precision): + """ + Compute the average precision (AP) given the recall and precision curves. + + Args: + recall (list): The recall curve. + precision (list): The precision curve. + + Returns: + (float): Average precision. + (np.ndarray): Precision envelope curve. + (np.ndarray): Modified recall curve with sentinel values added at the beginning and end. + """ + # Append sentinel values to beginning and end + mrec = np.concatenate(([0.0], recall, [1.0])) + mpre = np.concatenate(([1.0], precision, [0.0])) + + # Compute the precision envelope + mpre = np.flip(np.maximum.accumulate(np.flip(mpre))) + + # Integrate area under curve + method = "interp" # methods: 'continuous', 'interp' + if method == "interp": + x = np.linspace(0, 1, 101) # 101-point interp (COCO) + ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate + else: # 'continuous' + i = np.where(mrec[1:] != mrec[:-1])[0] # points where x-axis (recall) changes + ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve + + return ap, mpre, mrec + + +def ap_per_class( + tp, conf, pred_cls, target_cls, plot=False, on_plot=None, save_dir=Path(), names={}, eps=1e-16, prefix="" +): + """ + Computes the average precision per class for object detection evaluation. + + Args: + tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False). + conf (np.ndarray): Array of confidence scores of the detections. + pred_cls (np.ndarray): Array of predicted classes of the detections. + target_cls (np.ndarray): Array of true classes of the detections. + plot (bool, optional): Whether to plot PR curves or not. Defaults to False. + on_plot (func, optional): A callback to pass plots path and data when they are rendered. Defaults to None. + save_dir (Path, optional): Directory to save the PR curves. Defaults to an empty path. + names (dict, optional): Dict of class names to plot PR curves. Defaults to an empty tuple. + eps (float, optional): A small value to avoid division by zero. Defaults to 1e-16. + prefix (str, optional): A prefix string for saving the plot files. Defaults to an empty string. + + Returns: + tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class.Shape: (nc,). + fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class. Shape: (nc,). + p (np.ndarray): Precision values at threshold given by max F1 metric for each class. Shape: (nc,). + r (np.ndarray): Recall values at threshold given by max F1 metric for each class. Shape: (nc,). + f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class. Shape: (nc,). + ap (np.ndarray): Average precision for each class at different IoU thresholds. Shape: (nc, 10). + unique_classes (np.ndarray): An array of unique classes that have data. Shape: (nc,). + p_curve (np.ndarray): Precision curves for each class. Shape: (nc, 1000). + r_curve (np.ndarray): Recall curves for each class. Shape: (nc, 1000). + f1_curve (np.ndarray): F1-score curves for each class. Shape: (nc, 1000). + x (np.ndarray): X-axis values for the curves. Shape: (1000,). + prec_values (np.ndarray): Precision values at mAP@0.5 for each class. Shape: (nc, 1000). 
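+    Examples:
+        A toy evaluation with three detections scored over 10 IoU thresholds (all values arbitrary):
+        >>> import numpy as np
+        >>> tp = np.array([[True] * 10, [False] * 10, [True] * 10])  # per-detection correctness at each IoU threshold
+        >>> conf = np.array([0.9, 0.8, 0.7])
+        >>> pred_cls = np.array([0, 0, 1])
+        >>> target_cls = np.array([0, 1])
+        >>> tp_c, fp_c, p, r, f1, ap, classes, *curves = ap_per_class(tp, conf, pred_cls, target_cls)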
+ """ + # Sort by objectness + i = np.argsort(-conf) + tp, conf, pred_cls = tp[i], conf[i], pred_cls[i] + + # Find unique classes + unique_classes, nt = np.unique(target_cls, return_counts=True) + nc = unique_classes.shape[0] # number of classes, number of detections + + # Create Precision-Recall curve and compute AP for each class + x, prec_values = np.linspace(0, 1, 1000), [] + + # Average precision, precision and recall curves + ap, p_curve, r_curve = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000)) + for ci, c in enumerate(unique_classes): + i = pred_cls == c + n_l = nt[ci] # number of labels + n_p = i.sum() # number of predictions + if n_p == 0 or n_l == 0: + continue + + # Accumulate FPs and TPs + fpc = (1 - tp[i]).cumsum(0) + tpc = tp[i].cumsum(0) + + # Recall + recall = tpc / (n_l + eps) # recall curve + r_curve[ci] = np.interp(-x, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases + + # Precision + precision = tpc / (tpc + fpc) # precision curve + p_curve[ci] = np.interp(-x, -conf[i], precision[:, 0], left=1) # p at pr_score + + # AP from recall-precision curve + for j in range(tp.shape[1]): + ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j]) + if j == 0: + prec_values.append(np.interp(x, mrec, mpre)) # precision at mAP@0.5 + + prec_values = np.array(prec_values) # (nc, 1000) + + # Compute F1 (harmonic mean of precision and recall) + f1_curve = 2 * p_curve * r_curve / (p_curve + r_curve + eps) + names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data + names = dict(enumerate(names)) # to dict + if plot: + plot_pr_curve(x, prec_values, ap, save_dir / f"{prefix}PR_curve.png", names, on_plot=on_plot) + plot_mc_curve(x, f1_curve, save_dir / f"{prefix}F1_curve.png", names, ylabel="F1", on_plot=on_plot) + plot_mc_curve(x, p_curve, save_dir / f"{prefix}P_curve.png", names, ylabel="Precision", on_plot=on_plot) + plot_mc_curve(x, r_curve, save_dir / f"{prefix}R_curve.png", names, ylabel="Recall", on_plot=on_plot) + + i = smooth(f1_curve.mean(0), 0.1).argmax() # max F1 index + p, r, f1 = p_curve[:, i], r_curve[:, i], f1_curve[:, i] # max-F1 precision, recall, F1 values + tp = (r * nt).round() # true positives + fp = (tp / (p + eps) - tp).round() # false positives + return tp, fp, p, r, f1, ap, unique_classes.astype(int), p_curve, r_curve, f1_curve, x, prec_values + + +class Metric(SimpleClass): + """ + Class for computing evaluation metrics for YOLOv8 model. + + Attributes: + p (list): Precision for each class. Shape: (nc,). + r (list): Recall for each class. Shape: (nc,). + f1 (list): F1 score for each class. Shape: (nc,). + all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10). + ap_class_index (list): Index of class for each AP score. Shape: (nc,). + nc (int): Number of classes. + + Methods: + ap50(): AP at IoU threshold of 0.5 for all classes. Returns: List of AP scores. Shape: (nc,) or []. + ap(): AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: List of AP scores. Shape: (nc,) or []. + mp(): Mean precision of all classes. Returns: Float. + mr(): Mean recall of all classes. Returns: Float. + map50(): Mean AP at IoU threshold of 0.5 for all classes. Returns: Float. + map75(): Mean AP at IoU threshold of 0.75 for all classes. Returns: Float. + map(): Mean AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: Float. + mean_results(): Mean of results, returns mp, mr, map50, map. 
+ class_result(i): Class-aware result, returns p[i], r[i], ap50[i], ap[i]. + maps(): mAP of each class. Returns: Array of mAP scores, shape: (nc,). + fitness(): Model fitness as a weighted combination of metrics. Returns: Float. + update(results): Update metric attributes with new evaluation results. + """ + + def __init__(self) -> None: + """Initializes a Metric instance for computing evaluation metrics for the YOLOv8 model.""" + self.p = [] # (nc, ) + self.r = [] # (nc, ) + self.f1 = [] # (nc, ) + self.all_ap = [] # (nc, 10) + self.ap_class_index = [] # (nc, ) + self.nc = 0 + + @property + def ap50(self): + """ + Returns the Average Precision (AP) at an IoU threshold of 0.5 for all classes. + + Returns: + (np.ndarray, list): Array of shape (nc,) with AP50 values per class, or an empty list if not available. + """ + return self.all_ap[:, 0] if len(self.all_ap) else [] + + @property + def ap(self): + """ + Returns the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes. + + Returns: + (np.ndarray, list): Array of shape (nc,) with AP50-95 values per class, or an empty list if not available. + """ + return self.all_ap.mean(1) if len(self.all_ap) else [] + + @property + def mp(self): + """ + Returns the Mean Precision of all classes. + + Returns: + (float): The mean precision of all classes. + """ + return self.p.mean() if len(self.p) else 0.0 + + @property + def mr(self): + """ + Returns the Mean Recall of all classes. + + Returns: + (float): The mean recall of all classes. + """ + return self.r.mean() if len(self.r) else 0.0 + + @property + def map50(self): + """ + Returns the mean Average Precision (mAP) at an IoU threshold of 0.5. + + Returns: + (float): The mAP at an IoU threshold of 0.5. + """ + return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0 + + @property + def map75(self): + """ + Returns the mean Average Precision (mAP) at an IoU threshold of 0.75. + + Returns: + (float): The mAP at an IoU threshold of 0.75. + """ + return self.all_ap[:, 5].mean() if len(self.all_ap) else 0.0 + + @property + def map(self): + """ + Returns the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05. + + Returns: + (float): The mAP over IoU thresholds of 0.5 - 0.95 in steps of 0.05. + """ + return self.all_ap.mean() if len(self.all_ap) else 0.0 + + def mean_results(self): + """Mean of results, return mp, mr, map50, map.""" + return [self.mp, self.mr, self.map50, self.map] + + def class_result(self, i): + """Class-aware result, return p[i], r[i], ap50[i], ap[i].""" + return self.p[i], self.r[i], self.ap50[i], self.ap[i] + + @property + def maps(self): + """MAP of each class.""" + maps = np.zeros(self.nc) + self.map + for i, c in enumerate(self.ap_class_index): + maps[c] = self.ap[i] + return maps + + def fitness(self): + """Model fitness as a weighted combination of metrics.""" + w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95] + return (np.array(self.mean_results()) * w).sum() + + def update(self, results): + """ + Updates the evaluation metrics of the model with a new set of results. + + Args: + results (tuple): A tuple containing the following evaluation metrics: + - p (list): Precision for each class. Shape: (nc,). + - r (list): Recall for each class. Shape: (nc,). + - f1 (list): F1 score for each class. Shape: (nc,). + - all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10). + - ap_class_index (list): Index of class for each AP score. Shape: (nc,). 
+ + Side Effects: + Updates the class attributes `self.p`, `self.r`, `self.f1`, `self.all_ap`, and `self.ap_class_index` based + on the values provided in the `results` tuple. + """ + ( + self.p, + self.r, + self.f1, + self.all_ap, + self.ap_class_index, + self.p_curve, + self.r_curve, + self.f1_curve, + self.px, + self.prec_values, + ) = results + + @property + def curves(self): + """Returns a list of curves for accessing specific metrics curves.""" + return [] + + @property + def curves_results(self): + """Returns a list of curves for accessing specific metrics curves.""" + return [ + [self.px, self.prec_values, "Recall", "Precision"], + [self.px, self.f1_curve, "Confidence", "F1"], + [self.px, self.p_curve, "Confidence", "Precision"], + [self.px, self.r_curve, "Confidence", "Recall"], + ] + + +class DetMetrics(SimpleClass): + """ + Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP) of an + object detection model. + + Args: + save_dir (Path): A path to the directory where the output plots will be saved. Defaults to current directory. + plot (bool): A flag that indicates whether to plot precision-recall curves for each class. Defaults to False. + on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None. + names (dict of str): A dict of strings that represents the names of the classes. Defaults to an empty tuple. + + Attributes: + save_dir (Path): A path to the directory where the output plots will be saved. + plot (bool): A flag that indicates whether to plot the precision-recall curves for each class. + on_plot (func): An optional callback to pass plots path and data when they are rendered. + names (dict of str): A dict of strings that represents the names of the classes. + box (Metric): An instance of the Metric class for storing the results of the detection metrics. + speed (dict): A dictionary for storing the execution time of different parts of the detection process. + + Methods: + process(tp, conf, pred_cls, target_cls): Updates the metric results with the latest batch of predictions. + keys: Returns a list of keys for accessing the computed detection metrics. + mean_results: Returns a list of mean values for the computed detection metrics. + class_result(i): Returns a list of values for the computed detection metrics for a specific class. + maps: Returns a dictionary of mean average precision (mAP) values for different IoU thresholds. + fitness: Computes the fitness score based on the computed detection metrics. + ap_class_index: Returns a list of class indices sorted by their average precision (AP) values. + results_dict: Returns a dictionary that maps detection metric keys to their computed values. 
+ curves: TODO + curves_results: TODO + """ + + def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names={}) -> None: + """Initialize a DetMetrics instance with a save directory, plot flag, callback function, and class names.""" + self.save_dir = save_dir + self.plot = plot + self.on_plot = on_plot + self.names = names + self.box = Metric() + self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} + self.task = "detect" + + def process(self, tp, conf, pred_cls, target_cls): + """Process predicted results for object detection and update metrics.""" + results = ap_per_class( + tp, + conf, + pred_cls, + target_cls, + plot=self.plot, + save_dir=self.save_dir, + names=self.names, + on_plot=self.on_plot, + )[2:] + self.box.nc = len(self.names) + self.box.update(results) + + @property + def keys(self): + """Returns a list of keys for accessing specific metrics.""" + return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"] + + def mean_results(self): + """Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95.""" + return self.box.mean_results() + + def class_result(self, i): + """Return the result of evaluating the performance of an object detection model on a specific class.""" + return self.box.class_result(i) + + @property + def maps(self): + """Returns mean Average Precision (mAP) scores per class.""" + return self.box.maps + + @property + def fitness(self): + """Returns the fitness of box object.""" + return self.box.fitness() + + @property + def ap_class_index(self): + """Returns the average precision index per class.""" + return self.box.ap_class_index + + @property + def results_dict(self): + """Returns dictionary of computed performance metrics and statistics.""" + return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness])) + + @property + def curves(self): + """Returns a list of curves for accessing specific metrics curves.""" + return ["Precision-Recall(B)", "F1-Confidence(B)", "Precision-Confidence(B)", "Recall-Confidence(B)"] + + @property + def curves_results(self): + """Returns dictionary of computed performance metrics and statistics.""" + return self.box.curves_results + + +class SegmentMetrics(SimpleClass): + """ + Calculates and aggregates detection and segmentation metrics over a given set of classes. + + Args: + save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory. + plot (bool): Whether to save the detection and segmentation plots. Default is False. + on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None. + names (list): List of class names. Default is an empty list. + + Attributes: + save_dir (Path): Path to the directory where the output plots should be saved. + plot (bool): Whether to save the detection and segmentation plots. + on_plot (func): An optional callback to pass plots path and data when they are rendered. + names (list): List of class names. + box (Metric): An instance of the Metric class to calculate box detection metrics. + seg (Metric): An instance of the Metric class to calculate mask segmentation metrics. + speed (dict): Dictionary to store the time taken in different phases of inference. + + Methods: + process(tp_m, tp_b, conf, pred_cls, target_cls): Processes metrics over the given set of predictions. + mean_results(): Returns the mean of the detection and segmentation metrics over all the classes. 
+ class_result(i): Returns the detection and segmentation metrics of class `i`. + maps: Returns the mean Average Precision (mAP) scores for IoU thresholds ranging from 0.50 to 0.95. + fitness: Returns the fitness scores, which are a single weighted combination of metrics. + ap_class_index: Returns the list of indices of classes used to compute Average Precision (AP). + results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score. + """ + + def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None: + """Initialize a SegmentMetrics instance with a save directory, plot flag, callback function, and class names.""" + self.save_dir = save_dir + self.plot = plot + self.on_plot = on_plot + self.names = names + self.box = Metric() + self.seg = Metric() + self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} + self.task = "segment" + + def process(self, tp, tp_m, conf, pred_cls, target_cls): + """ + Processes the detection and segmentation metrics over the given set of predictions. + + Args: + tp (list): List of True Positive boxes. + tp_m (list): List of True Positive masks. + conf (list): List of confidence scores. + pred_cls (list): List of predicted classes. + target_cls (list): List of target classes. + """ + results_mask = ap_per_class( + tp_m, + conf, + pred_cls, + target_cls, + plot=self.plot, + on_plot=self.on_plot, + save_dir=self.save_dir, + names=self.names, + prefix="Mask", + )[2:] + self.seg.nc = len(self.names) + self.seg.update(results_mask) + results_box = ap_per_class( + tp, + conf, + pred_cls, + target_cls, + plot=self.plot, + on_plot=self.on_plot, + save_dir=self.save_dir, + names=self.names, + prefix="Box", + )[2:] + self.box.nc = len(self.names) + self.box.update(results_box) + + @property + def keys(self): + """Returns a list of keys for accessing metrics.""" + return [ + "metrics/precision(B)", + "metrics/recall(B)", + "metrics/mAP50(B)", + "metrics/mAP50-95(B)", + "metrics/precision(M)", + "metrics/recall(M)", + "metrics/mAP50(M)", + "metrics/mAP50-95(M)", + ] + + def mean_results(self): + """Return the mean metrics for bounding box and segmentation results.""" + return self.box.mean_results() + self.seg.mean_results() + + def class_result(self, i): + """Returns classification results for a specified class index.""" + return self.box.class_result(i) + self.seg.class_result(i) + + @property + def maps(self): + """Returns mAP scores for object detection and semantic segmentation models.""" + return self.box.maps + self.seg.maps + + @property + def fitness(self): + """Get the fitness score for both segmentation and bounding box models.""" + return self.seg.fitness() + self.box.fitness() + + @property + def ap_class_index(self): + """Boxes and masks have the same ap_class_index.""" + return self.box.ap_class_index + + @property + def results_dict(self): + """Returns results of object detection model for evaluation.""" + return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness])) + + @property + def curves(self): + """Returns a list of curves for accessing specific metrics curves.""" + return [ + "Precision-Recall(B)", + "F1-Confidence(B)", + "Precision-Confidence(B)", + "Recall-Confidence(B)", + "Precision-Recall(M)", + "F1-Confidence(M)", + "Precision-Confidence(M)", + "Recall-Confidence(M)", + ] + + @property + def curves_results(self): + """Returns dictionary of computed performance metrics and statistics.""" + return self.box.curves_results + 
self.seg.curves_results
+
+
+class PoseMetrics(SegmentMetrics):
+    """
+    Calculates and aggregates detection and pose metrics over a given set of classes.
+
+    Args:
+        save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory.
+        plot (bool): Whether to save the detection and pose plots. Default is False.
+        on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
+        names (list): List of class names. Default is an empty tuple.
+
+    Attributes:
+        save_dir (Path): Path to the directory where the output plots should be saved.
+        plot (bool): Whether to save the detection and pose plots.
+        on_plot (func): An optional callback to pass plots path and data when they are rendered.
+        names (list): List of class names.
+        box (Metric): An instance of the Metric class to calculate box detection metrics.
+        pose (Metric): An instance of the Metric class to calculate pose (keypoint) metrics.
+        speed (dict): Dictionary to store the time taken in different phases of inference.
+
+    Methods:
+        process(tp, tp_p, conf, pred_cls, target_cls): Processes metrics over the given set of predictions.
+        mean_results(): Returns the mean of the detection and pose metrics over all the classes.
+        class_result(i): Returns the detection and pose metrics of class `i`.
+        maps: Returns the mean Average Precision (mAP) scores for IoU thresholds ranging from 0.50 to 0.95.
+        fitness: Returns the fitness scores, which are a single weighted combination of metrics.
+        ap_class_index: Returns the list of indices of classes used to compute Average Precision (AP).
+        results_dict: Returns the dictionary containing all the detection and pose metrics and fitness score.
+    """
+
+    def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None:
+        """Initialize the PoseMetrics class with directory path, class names, and plotting options."""
+        super().__init__(save_dir, plot, on_plot, names)
+        self.save_dir = save_dir
+        self.plot = plot
+        self.on_plot = on_plot
+        self.names = names
+        self.box = Metric()
+        self.pose = Metric()
+        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
+        self.task = "pose"
+
+    def process(self, tp, tp_p, conf, pred_cls, target_cls):
+        """
+        Processes the detection and pose metrics over the given set of predictions.
+
+        Args:
+            tp (list): List of True Positive boxes.
+            tp_p (list): List of True Positive keypoints.
+            conf (list): List of confidence scores.
+            pred_cls (list): List of predicted classes.
+            target_cls (list): List of target classes.
+ """ + results_pose = ap_per_class( + tp_p, + conf, + pred_cls, + target_cls, + plot=self.plot, + on_plot=self.on_plot, + save_dir=self.save_dir, + names=self.names, + prefix="Pose", + )[2:] + self.pose.nc = len(self.names) + self.pose.update(results_pose) + results_box = ap_per_class( + tp, + conf, + pred_cls, + target_cls, + plot=self.plot, + on_plot=self.on_plot, + save_dir=self.save_dir, + names=self.names, + prefix="Box", + )[2:] + self.box.nc = len(self.names) + self.box.update(results_box) + + @property + def keys(self): + """Returns list of evaluation metric keys.""" + return [ + "metrics/precision(B)", + "metrics/recall(B)", + "metrics/mAP50(B)", + "metrics/mAP50-95(B)", + "metrics/precision(P)", + "metrics/recall(P)", + "metrics/mAP50(P)", + "metrics/mAP50-95(P)", + ] + + def mean_results(self): + """Return the mean results of box and pose.""" + return self.box.mean_results() + self.pose.mean_results() + + def class_result(self, i): + """Return the class-wise detection results for a specific class i.""" + return self.box.class_result(i) + self.pose.class_result(i) + + @property + def maps(self): + """Returns the mean average precision (mAP) per class for both box and pose detections.""" + return self.box.maps + self.pose.maps + + @property + def fitness(self): + """Computes classification metrics and speed using the `targets` and `pred` inputs.""" + return self.pose.fitness() + self.box.fitness() + + @property + def curves(self): + """Returns a list of curves for accessing specific metrics curves.""" + return [ + "Precision-Recall(B)", + "F1-Confidence(B)", + "Precision-Confidence(B)", + "Recall-Confidence(B)", + "Precision-Recall(P)", + "F1-Confidence(P)", + "Precision-Confidence(P)", + "Recall-Confidence(P)", + ] + + @property + def curves_results(self): + """Returns dictionary of computed performance metrics and statistics.""" + return self.box.curves_results + self.pose.curves_results + + +class ClassifyMetrics(SimpleClass): + """ + Class for computing classification metrics including top-1 and top-5 accuracy. + + Attributes: + top1 (float): The top-1 accuracy. + top5 (float): The top-5 accuracy. + speed (Dict[str, float]): A dictionary containing the time taken for each step in the pipeline. + fitness (float): The fitness of the model, which is equal to top-5 accuracy. + results_dict (Dict[str, Union[float, str]]): A dictionary containing the classification metrics and fitness. + keys (List[str]): A list of keys for the results_dict. + + Methods: + process(targets, pred): Processes the targets and predictions to compute classification metrics. 
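+
+    Example:
+        A small sketch with a single two-image batch (assuming this module is importable as
+        `ultralytics.utils.metrics`):
+        ```python
+        import torch
+
+        from ultralytics.utils.metrics import ClassifyMetrics
+
+        metrics = ClassifyMetrics()
+        targets = [torch.tensor([0, 2])]  # ground-truth class per image
+        pred = [torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]])]  # top-5 class indices per image
+        metrics.process(targets, pred)
+        print(metrics.top1, metrics.top5, metrics.fitness)  # 0.5 1.0 0.75 for this toy batch
+        ```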
+ """ + + def __init__(self) -> None: + """Initialize a ClassifyMetrics instance.""" + self.top1 = 0 + self.top5 = 0 + self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} + self.task = "classify" + + def process(self, targets, pred): + """Target classes and predicted classes.""" + pred, targets = torch.cat(pred), torch.cat(targets) + correct = (targets[:, None] == pred).float() + acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1) # (top1, top5) accuracy + self.top1, self.top5 = acc.mean(0).tolist() + + @property + def fitness(self): + """Returns mean of top-1 and top-5 accuracies as fitness score.""" + return (self.top1 + self.top5) / 2 + + @property + def results_dict(self): + """Returns a dictionary with model's performance metrics and fitness score.""" + return dict(zip(self.keys + ["fitness"], [self.top1, self.top5, self.fitness])) + + @property + def keys(self): + """Returns a list of keys for the results_dict property.""" + return ["metrics/accuracy_top1", "metrics/accuracy_top5"] + + @property + def curves(self): + """Returns a list of curves for accessing specific metrics curves.""" + return [] + + @property + def curves_results(self): + """Returns a list of curves for accessing specific metrics curves.""" + return [] + + +class OBBMetrics(SimpleClass): + """Metrics for evaluating oriented bounding box (OBB) detection, see https://arxiv.org/pdf/2106.06072.pdf.""" + + def __init__(self, save_dir=Path("."), plot=False, on_plot=None, names=()) -> None: + """Initialize an OBBMetrics instance with directory, plotting, callback, and class names.""" + self.save_dir = save_dir + self.plot = plot + self.on_plot = on_plot + self.names = names + self.box = Metric() + self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} + + def process(self, tp, conf, pred_cls, target_cls): + """Process predicted results for object detection and update metrics.""" + results = ap_per_class( + tp, + conf, + pred_cls, + target_cls, + plot=self.plot, + save_dir=self.save_dir, + names=self.names, + on_plot=self.on_plot, + )[2:] + self.box.nc = len(self.names) + self.box.update(results) + + @property + def keys(self): + """Returns a list of keys for accessing specific metrics.""" + return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"] + + def mean_results(self): + """Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95.""" + return self.box.mean_results() + + def class_result(self, i): + """Return the result of evaluating the performance of an object detection model on a specific class.""" + return self.box.class_result(i) + + @property + def maps(self): + """Returns mean Average Precision (mAP) scores per class.""" + return self.box.maps + + @property + def fitness(self): + """Returns the fitness of box object.""" + return self.box.fitness() + + @property + def ap_class_index(self): + """Returns the average precision index per class.""" + return self.box.ap_class_index + + @property + def results_dict(self): + """Returns dictionary of computed performance metrics and statistics.""" + return dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness])) + + @property + def curves(self): + """Returns a list of curves for accessing specific metrics curves.""" + return [] + + @property + def curves_results(self): + """Returns a list of curves for accessing specific metrics curves.""" + return [] diff --git a/ultralytics/utils/ops.py b/ultralytics/utils/ops.py new 
file mode 100644 index 0000000000000000000000000000000000000000..52b51552175af9db18f07997fd2c426b229b490f --- /dev/null +++ b/ultralytics/utils/ops.py @@ -0,0 +1,854 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import contextlib +import math +import re +import time + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F + +from ultralytics.utils import LOGGER +from ultralytics.utils.metrics import batch_probiou + + +class Profile(contextlib.ContextDecorator): + """ + YOLOv8 Profile class. Use as a decorator with @Profile() or as a context manager with 'with Profile():'. + + Example: + ```python + from ultralytics.utils.ops import Profile + + with Profile(device=device) as dt: + pass # slow operation here + + print(dt) # prints "Elapsed time is 9.5367431640625e-07 s" + ``` + """ + + def __init__(self, t=0.0, device: torch.device = None): + """ + Initialize the Profile class. + + Args: + t (float): Initial time. Defaults to 0.0. + device (torch.device): Devices used for model inference. Defaults to None (cpu). + """ + self.t = t + self.device = device + self.cuda = bool(device and str(device).startswith("cuda")) + + def __enter__(self): + """Start timing.""" + self.start = self.time() + return self + + def __exit__(self, type, value, traceback): # noqa + """Stop timing.""" + self.dt = self.time() - self.start # delta-time + self.t += self.dt # accumulate dt + + def __str__(self): + """Returns a human-readable string representing the accumulated elapsed time in the profiler.""" + return f"Elapsed time is {self.t} s" + + def time(self): + """Get current time.""" + if self.cuda: + torch.cuda.synchronize(self.device) + return time.time() + + +def segment2box(segment, width=640, height=640): + """ + Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy). + + Args: + segment (torch.Tensor): the segment label + width (int): the width of the image. Defaults to 640 + height (int): The height of the image. Defaults to 640 + + Returns: + (np.ndarray): the minimum and maximum x and y values of the segment. + """ + x, y = segment.T # segment xy + # any 3 out of 4 sides are outside the image, clip coordinates first, https://github.com/ultralytics/ultralytics/pull/18294 + if np.array([x.min() < 0, y.min() < 0, x.max() > width, y.max() > height]).sum() >= 3: + x = x.clip(0, width) + y = y.clip(0, height) + inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height) + x = x[inside] + y = y[inside] + return ( + np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype) + if any(x) + else np.zeros(4, dtype=segment.dtype) + ) # xyxy + + +def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding=True, xywh=False): + """ + Rescales bounding boxes (in the format of xyxy by default) from the shape of the image they were originally + specified in (img1_shape) to the shape of a different image (img0_shape). + + Args: + img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width). + boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2) + img0_shape (tuple): the shape of the target image, in the format of (height, width). + ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be + calculated based on the size difference between the two images. + padding (bool): If True, assuming the boxes is based on image augmented by yolo style. 
If False then do regular + rescaling. + xywh (bool): The box format is xywh or not, default=False. + + Returns: + boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2) + """ + if ratio_pad is None: # calculate from img0_shape + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new + pad = ( + round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1), + round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1), + ) # wh padding + else: + gain = ratio_pad[0][0] + pad = ratio_pad[1] + + if padding: + boxes[..., 0] -= pad[0] # x padding + boxes[..., 1] -= pad[1] # y padding + if not xywh: + boxes[..., 2] -= pad[0] # x padding + boxes[..., 3] -= pad[1] # y padding + boxes[..., :4] /= gain + return clip_boxes(boxes, img0_shape) + + +def make_divisible(x, divisor): + """ + Returns the nearest number that is divisible by the given divisor. + + Args: + x (int): The number to make divisible. + divisor (int | torch.Tensor): The divisor. + + Returns: + (int): The nearest number divisible by the divisor. + """ + if isinstance(divisor, torch.Tensor): + divisor = int(divisor.max()) # to int + return math.ceil(x / divisor) * divisor + + +def nms_rotated(boxes, scores, threshold=0.45): + """ + NMS for oriented bounding boxes using probiou and fast-nms. + + Args: + boxes (torch.Tensor): Rotated bounding boxes, shape (N, 5), format xywhr. + scores (torch.Tensor): Confidence scores, shape (N,). + threshold (float, optional): IoU threshold. Defaults to 0.45. + + Returns: + (torch.Tensor): Indices of boxes to keep after NMS. + """ + if len(boxes) == 0: + return np.empty((0,), dtype=np.int8) + sorted_idx = torch.argsort(scores, descending=True) + boxes = boxes[sorted_idx] + ious = batch_probiou(boxes, boxes).triu_(diagonal=1) + pick = torch.nonzero(ious.max(dim=0)[0] < threshold).squeeze_(-1) + return sorted_idx[pick] + + +def non_max_suppression( + prediction, + conf_thres=0.25, + iou_thres=0.45, + classes=None, + agnostic=False, + multi_label=False, + labels=(), + max_det=300, + nc=0, # number of classes (optional) + max_time_img=0.05, + max_nms=30000, + max_wh=7680, + in_place=True, + rotated=False, +): + """ + Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box. + + Args: + prediction (torch.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes) + containing the predicted boxes, classes, and masks. The tensor should be in the format + output by a model, such as YOLO. + conf_thres (float): The confidence threshold below which boxes will be filtered out. + Valid values are between 0.0 and 1.0. + iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS. + Valid values are between 0.0 and 1.0. + classes (List[int]): A list of class indices to consider. If None, all classes will be considered. + agnostic (bool): If True, the model is agnostic to the number of classes, and all + classes will be considered as one. + multi_label (bool): If True, each box may have multiple labels. + labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner + list contains the apriori labels for a given image. The list should be in the format + output by a dataloader, with each label being a tuple of (class_index, x1, y1, x2, y2). + max_det (int): The maximum number of boxes to keep after NMS. + nc (int, optional): The number of classes output by the model. Any indices after this will be considered masks. 
+ max_time_img (float): The maximum time (seconds) for processing one image. + max_nms (int): The maximum number of boxes into torchvision.ops.nms(). + max_wh (int): The maximum box width and height in pixels. + in_place (bool): If True, the input prediction tensor will be modified in place. + rotated (bool): If Oriented Bounding Boxes (OBB) are being passed for NMS. + + Returns: + (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of + shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns + (x1, y1, x2, y2, confidence, class, mask1, mask2, ...). + """ + import torchvision # scope for faster 'import ultralytics' + + # Checks + assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0" + assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0" + if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out) + prediction = prediction[0] # select only inference output + if classes is not None: + classes = torch.tensor(classes, device=prediction.device) + + if prediction.shape[-1] == 6: # end-to-end model (BNC, i.e. 1,300,6) + output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction] + if classes is not None: + output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output] + return output + + bs = prediction.shape[0] # batch size (BCN, i.e. 1,84,6300) + nc = nc or (prediction.shape[1] - 4) # number of classes + nm = prediction.shape[1] - nc - 4 # number of masks + mi = 4 + nc # mask start index + xc = prediction[:, 4:mi].amax(1) > conf_thres # candidates + + # Settings + # min_wh = 2 # (pixels) minimum box width and height + time_limit = 2.0 + max_time_img * bs # seconds to quit after + multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img) + + prediction = prediction.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84) + if not rotated: + if in_place: + prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy + else: + prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1) # xywh to xyxy + + t = time.time() + output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs + for xi, x in enumerate(prediction): # image index, image inference + # Apply constraints + # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height + x = x[xc[xi]] # confidence + + # Cat apriori labels if autolabelling + if labels and len(labels[xi]) and not rotated: + lb = labels[xi] + v = torch.zeros((len(lb), nc + nm + 4), device=x.device) + v[:, :4] = xywh2xyxy(lb[:, 1:5]) # box + v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls + x = torch.cat((x, v), 0) + + # If none remain process next image + if not x.shape[0]: + continue + + # Detections matrix nx6 (xyxy, conf, cls) + box, cls, mask = x.split((4, nc, nm), 1) + + if multi_label: + i, j = torch.where(cls > conf_thres) + x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1) + else: # best class only + conf, j = cls.max(1, keepdim=True) + x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres] + + # Filter by class + if classes is not None: + x = x[(x[:, 5:6] == classes).any(1)] + + # Check shape + n = x.shape[0] # number of boxes + if not n: # no boxes + continue + if n > max_nms: # excess boxes + x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence and remove excess boxes + + # Batched NMS + c = x[:, 5:6] * (0 if 
agnostic else max_wh) # classes + scores = x[:, 4] # scores + if rotated: + boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1) # xywhr + i = nms_rotated(boxes, scores, iou_thres) + else: + boxes = x[:, :4] + c # boxes (offset by class) + i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS + i = i[:max_det] # limit detections + + # # Experimental + # merge = False # use merge-NMS + # if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean) + # # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) + # from .metrics import box_iou + # iou = box_iou(boxes[i], boxes) > iou_thres # IoU matrix + # weights = iou * scores[None] # box weights + # x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes + # redundant = True # require redundant detections + # if redundant: + # i = i[iou.sum(1) > 1] # require redundancy + + output[xi] = x[i] + if (time.time() - t) > time_limit: + LOGGER.warning(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded") + break # time limit exceeded + + return output + + +def clip_boxes(boxes, shape): + """ + Takes a list of bounding boxes and a shape (height, width) and clips the bounding boxes to the shape. + + Args: + boxes (torch.Tensor): The bounding boxes to clip. + shape (tuple): The shape of the image. + + Returns: + (torch.Tensor | numpy.ndarray): The clipped boxes. + """ + if isinstance(boxes, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug) + boxes[..., 0] = boxes[..., 0].clamp(0, shape[1]) # x1 + boxes[..., 1] = boxes[..., 1].clamp(0, shape[0]) # y1 + boxes[..., 2] = boxes[..., 2].clamp(0, shape[1]) # x2 + boxes[..., 3] = boxes[..., 3].clamp(0, shape[0]) # y2 + else: # np.array (faster grouped) + boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1]) # x1, x2 + boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0]) # y1, y2 + return boxes + + +def clip_coords(coords, shape): + """ + Clip line coordinates to the image boundaries. + + Args: + coords (torch.Tensor | numpy.ndarray): A list of line coordinates. + shape (tuple): A tuple of integers representing the size of the image in the format (height, width). + + Returns: + (torch.Tensor | numpy.ndarray): Clipped coordinates + """ + if isinstance(coords, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug) + coords[..., 0] = coords[..., 0].clamp(0, shape[1]) # x + coords[..., 1] = coords[..., 1].clamp(0, shape[0]) # y + else: # np.array (faster grouped) + coords[..., 0] = coords[..., 0].clip(0, shape[1]) # x + coords[..., 1] = coords[..., 1].clip(0, shape[0]) # y + return coords + + +def scale_image(masks, im0_shape, ratio_pad=None): + """ + Takes a mask, and resizes it to the original image size. + + Args: + masks (np.ndarray): Resized and padded masks/images, [h, w, num]/[h, w, 3]. + im0_shape (tuple): The original image shape. + ratio_pad (tuple): The ratio of the padding to the original image. + + Returns: + masks (np.ndarray): The masks that are being returned with shape [h, w, num]. 
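+
+    Example:
+        A shape-only sketch mapping a letterboxed 640x640 mask back to a 480x640 source image
+        (shapes are illustrative only):
+        ```python
+        import numpy as np
+
+        from ultralytics.utils.ops import scale_image
+
+        padded = np.zeros((640, 640, 1), dtype=np.float32)  # mask rendered at the padded inference size
+        restored = scale_image(padded, (480, 640, 3))  # strip the padding and resize to the source shape
+        print(restored.shape)  # (480, 640, 1)
+        ```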
+ """ + # Rescale coordinates (xyxy) from im1_shape to im0_shape + im1_shape = masks.shape + if im1_shape[:2] == im0_shape[:2]: + return masks + if ratio_pad is None: # calculate from im0_shape + gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1]) # gain = old / new + pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2 # wh padding + else: + # gain = ratio_pad[0][0] + pad = ratio_pad[1] + top, left = int(pad[1]), int(pad[0]) # y, x + bottom, right = int(im1_shape[0] - pad[1]), int(im1_shape[1] - pad[0]) + + if len(masks.shape) < 2: + raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}') + masks = masks[top:bottom, left:right] + masks = cv2.resize(masks, (im0_shape[1], im0_shape[0])) + if len(masks.shape) == 2: + masks = masks[:, :, None] + + return masks + + +def xyxy2xywh(x): + """ + Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format where (x1, y1) is the + top-left corner and (x2, y2) is the bottom-right corner. + + Args: + x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format. + + Returns: + y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height) format. + """ + assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" + y = empty_like(x) # faster than clone/copy + y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center + y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center + y[..., 2] = x[..., 2] - x[..., 0] # width + y[..., 3] = x[..., 3] - x[..., 1] # height + return y + + +def xywh2xyxy(x): + """ + Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the + top-left corner and (x2, y2) is the bottom-right corner. Note: ops per 2 channels faster than per channel. + + Args: + x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x, y, width, height) format. + + Returns: + y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format. + """ + assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" + y = empty_like(x) # faster than clone/copy + xy = x[..., :2] # centers + wh = x[..., 2:] / 2 # half width-height + y[..., :2] = xy - wh # top left xy + y[..., 2:] = xy + wh # bottom right xy + return y + + +def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0): + """ + Convert normalized bounding box coordinates to pixel coordinates. + + Args: + x (np.ndarray | torch.Tensor): The bounding box coordinates. + w (int): Width of the image. Defaults to 640 + h (int): Height of the image. Defaults to 640 + padw (int): Padding width. Defaults to 0 + padh (int): Padding height. Defaults to 0 + Returns: + y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where + x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box. 
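+
+    Example:
+        A one-box sketch for a 640x480 image:
+        ```python
+        import numpy as np
+
+        from ultralytics.utils.ops import xywhn2xyxy
+
+        boxes = np.array([[0.5, 0.5, 0.5, 0.5]], dtype=np.float32)  # normalized cx, cy, w, h
+        print(xywhn2xyxy(boxes, w=640, h=480))  # [[160. 120. 480. 360.]]
+        ```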
+ """ + assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" + y = empty_like(x) # faster than clone/copy + y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x + y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y + y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw # bottom right x + y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh # bottom right y + return y + + +def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0): + """ + Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y, + width and height are normalized to image dimensions. + + Args: + x (np.ndarray | torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format. + w (int): The width of the image. Defaults to 640 + h (int): The height of the image. Defaults to 640 + clip (bool): If True, the boxes will be clipped to the image boundaries. Defaults to False + eps (float): The minimum value of the box's width and height. Defaults to 0.0 + + Returns: + y (np.ndarray | torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format + """ + if clip: + x = clip_boxes(x, (h - eps, w - eps)) + assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" + y = empty_like(x) # faster than clone/copy + y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center + y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center + y[..., 2] = (x[..., 2] - x[..., 0]) / w # width + y[..., 3] = (x[..., 3] - x[..., 1]) / h # height + return y + + +def xywh2ltwh(x): + """ + Convert the bounding box format from [x, y, w, h] to [x1, y1, w, h], where x1, y1 are the top-left coordinates. + + Args: + x (np.ndarray | torch.Tensor): The input tensor with the bounding box coordinates in the xywh format + + Returns: + y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format + """ + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x + y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y + return y + + +def xyxy2ltwh(x): + """ + Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right. + + Args: + x (np.ndarray | torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format + + Returns: + y (np.ndarray | torch.Tensor): The bounding box coordinates in the xyltwh format. + """ + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 2] = x[..., 2] - x[..., 0] # width + y[..., 3] = x[..., 3] - x[..., 1] # height + return y + + +def ltwh2xywh(x): + """ + Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center. + + Args: + x (torch.Tensor): the input tensor + + Returns: + y (np.ndarray | torch.Tensor): The bounding box coordinates in the xywh format. + """ + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 0] = x[..., 0] + x[..., 2] / 2 # center x + y[..., 1] = x[..., 1] + x[..., 3] / 2 # center y + return y + + +def xyxyxyxy2xywhr(x): + """ + Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation]. Rotation values are + returned in radians from 0 to pi/2. + + Args: + x (numpy.ndarray | torch.Tensor): Input box corners [xy1, xy2, xy3, xy4] of shape (n, 8). + + Returns: + (numpy.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format of shape (n, 5). 
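+
+    Example:
+        A sketch with one axis-aligned 4x2 rectangle; the exact (w, h, angle) split can vary with the
+        `cv2.minAreaRect` convention of the installed OpenCV version, but the centre is (2.0, 1.0):
+        ```python
+        import numpy as np
+
+        from ultralytics.utils.ops import xyxyxyxy2xywhr
+
+        corners = np.array([[0, 0, 4, 0, 4, 2, 0, 2]], dtype=np.float32)  # xy1, xy2, xy3, xy4
+        rbox = xyxyxyxy2xywhr(corners)  # -> [[cx, cy, w, h, r]]
+        ```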
+ """ + is_torch = isinstance(x, torch.Tensor) + points = x.cpu().numpy() if is_torch else x + points = points.reshape(len(x), -1, 2) + rboxes = [] + for pts in points: + # NOTE: Use cv2.minAreaRect to get accurate xywhr, + # especially some objects are cut off by augmentations in dataloader. + (cx, cy), (w, h), angle = cv2.minAreaRect(pts) + rboxes.append([cx, cy, w, h, angle / 180 * np.pi]) + return torch.tensor(rboxes, device=x.device, dtype=x.dtype) if is_torch else np.asarray(rboxes) + + +def xywhr2xyxyxyxy(x): + """ + Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4]. Rotation values should + be in radians from 0 to pi/2. + + Args: + x (numpy.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format of shape (n, 5) or (b, n, 5). + + Returns: + (numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 4, 2) or (b, n, 4, 2). + """ + cos, sin, cat, stack = ( + (torch.cos, torch.sin, torch.cat, torch.stack) + if isinstance(x, torch.Tensor) + else (np.cos, np.sin, np.concatenate, np.stack) + ) + + ctr = x[..., :2] + w, h, angle = (x[..., i : i + 1] for i in range(2, 5)) + cos_value, sin_value = cos(angle), sin(angle) + vec1 = [w / 2 * cos_value, w / 2 * sin_value] + vec2 = [-h / 2 * sin_value, h / 2 * cos_value] + vec1 = cat(vec1, -1) + vec2 = cat(vec2, -1) + pt1 = ctr + vec1 + vec2 + pt2 = ctr + vec1 - vec2 + pt3 = ctr - vec1 - vec2 + pt4 = ctr - vec1 + vec2 + return stack([pt1, pt2, pt3, pt4], -2) + + +def ltwh2xyxy(x): + """ + It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right. + + Args: + x (np.ndarray | torch.Tensor): the input image + + Returns: + y (np.ndarray | torch.Tensor): the xyxy coordinates of the bounding boxes. + """ + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[..., 2] = x[..., 2] + x[..., 0] # width + y[..., 3] = x[..., 3] + x[..., 1] # height + return y + + +def segments2boxes(segments): + """ + It converts segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh). + + Args: + segments (list): list of segments, each segment is a list of points, each point is a list of x, y coordinates + + Returns: + (np.ndarray): the xywh coordinates of the bounding boxes. + """ + boxes = [] + for s in segments: + x, y = s.T # segment xy + boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy + return xyxy2xywh(np.array(boxes)) # cls, xywh + + +def resample_segments(segments, n=1000): + """ + Inputs a list of segments (n,2) and returns a list of segments (n,2) up-sampled to n points each. + + Args: + segments (list): a list of (n,2) arrays, where n is the number of points in the segment. + n (int): number of points to resample the segment to. Defaults to 1000 + + Returns: + segments (list): the resampled segments. + """ + for i, s in enumerate(segments): + if len(s) == n: + continue + s = np.concatenate((s, s[0:1, :]), axis=0) + x = np.linspace(0, len(s) - 1, n - len(s) if len(s) < n else n) + xp = np.arange(len(s)) + x = np.insert(x, np.searchsorted(x, xp), xp) if len(s) < n else x + segments[i] = ( + np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)], dtype=np.float32).reshape(2, -1).T + ) # segment xy + return segments + + +def crop_mask(masks, boxes): + """ + It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box. 
+ + Args: + masks (torch.Tensor): [n, h, w] tensor of masks + boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form + + Returns: + (torch.Tensor): The masks are being cropped to the bounding box. + """ + _, h, w = masks.shape + x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1) + r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # rows shape(1,1,w) + c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # cols shape(1,h,1) + + return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2)) + + +def process_mask(protos, masks_in, bboxes, shape, upsample=False): + """ + Apply masks to bounding boxes using the output of the mask head. + + Args: + protos (torch.Tensor): A tensor of shape [mask_dim, mask_h, mask_w]. + masks_in (torch.Tensor): A tensor of shape [n, mask_dim], where n is the number of masks after NMS. + bboxes (torch.Tensor): A tensor of shape [n, 4], where n is the number of masks after NMS. + shape (tuple): A tuple of integers representing the size of the input image in the format (h, w). + upsample (bool): A flag to indicate whether to upsample the mask to the original image size. Default is False. + + Returns: + (torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w + are the height and width of the input image. The mask is applied to the bounding boxes. + """ + c, mh, mw = protos.shape # CHW + ih, iw = shape + masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw) # CHW + width_ratio = mw / iw + height_ratio = mh / ih + + downsampled_bboxes = bboxes.clone() + downsampled_bboxes[:, 0] *= width_ratio + downsampled_bboxes[:, 2] *= width_ratio + downsampled_bboxes[:, 3] *= height_ratio + downsampled_bboxes[:, 1] *= height_ratio + + masks = crop_mask(masks, downsampled_bboxes) # CHW + if upsample: + masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW + return masks.gt_(0.0) + + +def process_mask_native(protos, masks_in, bboxes, shape): + """ + It takes the output of the mask head, and crops it after upsampling to the bounding boxes. + + Args: + protos (torch.Tensor): [mask_dim, mask_h, mask_w] + masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms. + bboxes (torch.Tensor): [n, 4], n is number of masks after nms. + shape (tuple): The size of the input image (h,w). + + Returns: + masks (torch.Tensor): The returned masks with dimensions [h, w, n]. + """ + c, mh, mw = protos.shape # CHW + masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw) + masks = scale_masks(masks[None], shape)[0] # CHW + masks = crop_mask(masks, bboxes) # CHW + return masks.gt_(0.0) + + +def scale_masks(masks, shape, padding=True): + """ + Rescale segment masks to shape. + + Args: + masks (torch.Tensor): (N, C, H, W). + shape (tuple): Height and width. + padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular + rescaling. 
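+
+    Example:
+        A shape-only sketch rescaling letterboxed masks to a 480x640 image:
+        ```python
+        import torch
+
+        from ultralytics.utils.ops import scale_masks
+
+        masks = torch.rand(1, 1, 160, 160)  # (N, C, H, W) masks at the padded proto resolution
+        out = scale_masks(masks, (480, 640))  # crop the padding, then bilinear-resize to (h, w)
+        print(out.shape)  # torch.Size([1, 1, 480, 640])
+        ```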
+ """ + mh, mw = masks.shape[2:] + gain = min(mh / shape[0], mw / shape[1]) # gain = old / new + pad = [mw - shape[1] * gain, mh - shape[0] * gain] # wh padding + if padding: + pad[0] /= 2 + pad[1] /= 2 + top, left = (int(pad[1]), int(pad[0])) if padding else (0, 0) # y, x + bottom, right = (int(mh - pad[1]), int(mw - pad[0])) + masks = masks[..., top:bottom, left:right] + + masks = F.interpolate(masks, shape, mode="bilinear", align_corners=False) # NCHW + return masks + + +def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize=False, padding=True): + """ + Rescale segment coordinates (xy) from img1_shape to img0_shape. + + Args: + img1_shape (tuple): The shape of the image that the coords are from. + coords (torch.Tensor): the coords to be scaled of shape n,2. + img0_shape (tuple): the shape of the image that the segmentation is being applied to. + ratio_pad (tuple): the ratio of the image size to the padded image size. + normalize (bool): If True, the coordinates will be normalized to the range [0, 1]. Defaults to False. + padding (bool): If True, assuming the boxes is based on image augmented by yolo style. If False then do regular + rescaling. + + Returns: + coords (torch.Tensor): The scaled coordinates. + """ + if ratio_pad is None: # calculate from img0_shape + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new + pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding + else: + gain = ratio_pad[0][0] + pad = ratio_pad[1] + + if padding: + coords[..., 0] -= pad[0] # x padding + coords[..., 1] -= pad[1] # y padding + coords[..., 0] /= gain + coords[..., 1] /= gain + coords = clip_coords(coords, img0_shape) + if normalize: + coords[..., 0] /= img0_shape[1] # width + coords[..., 1] /= img0_shape[0] # height + return coords + + +def regularize_rboxes(rboxes): + """ + Regularize rotated boxes in range [0, pi/2]. + + Args: + rboxes (torch.Tensor): Input boxes of shape(N, 5) in xywhr format. + + Returns: + (torch.Tensor): The regularized boxes. + """ + x, y, w, h, t = rboxes.unbind(dim=-1) + # Swap edge and angle if h >= w + w_ = torch.where(w > h, w, h) + h_ = torch.where(w > h, h, w) + t = torch.where(w > h, t, t + math.pi / 2) % math.pi + return torch.stack([x, y, w_, h_, t], dim=-1) # regularized boxes + + +def masks2segments(masks, strategy="all"): + """ + It takes a list of masks(n,h,w) and returns a list of segments(n,xy). + + Args: + masks (torch.Tensor): the output of the model, which is a tensor of shape (batch_size, 160, 160) + strategy (str): 'all' or 'largest'. Defaults to all + + Returns: + segments (List): list of segment masks + """ + from ultralytics.data.converter import merge_multi_segment + + segments = [] + for x in masks.int().cpu().numpy().astype("uint8"): + c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0] + if c: + if strategy == "all": # merge and concatenate all segments + c = ( + np.concatenate(merge_multi_segment([x.reshape(-1, 2) for x in c])) + if len(c) > 1 + else c[0].reshape(-1, 2) + ) + elif strategy == "largest": # select largest segment + c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2) + else: + c = np.zeros((0, 2)) # no segments found + segments.append(c.astype("float32")) + return segments + + +def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray: + """ + Convert a batch of FP32 torch tensors (0.0-1.0) to a NumPy uint8 array (0-255), changing from BCHW to BHWC layout. 
+ + Args: + batch (torch.Tensor): Input tensor batch of shape (Batch, Channels, Height, Width) and dtype torch.float32. + + Returns: + (np.ndarray): Output NumPy array batch of shape (Batch, Height, Width, Channels) and dtype uint8. + """ + return (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).to(torch.uint8).cpu().numpy() + + +def clean_str(s): + """ + Cleans a string by replacing special characters with '_' character. + + Args: + s (str): a string needing special characters replaced + + Returns: + (str): a string with special characters replaced by an underscore _ + """ + return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s) + + +def empty_like(x): + """Creates empty torch.Tensor or np.ndarray with same shape as input and float32 dtype.""" + return ( + torch.empty_like(x, dtype=torch.float32) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=np.float32) + ) diff --git a/ultralytics/utils/patches.py b/ultralytics/utils/patches.py new file mode 100644 index 0000000000000000000000000000000000000000..1531cd7f8f62417f604ff126a0e8f11319fc6d39 --- /dev/null +++ b/ultralytics/utils/patches.py @@ -0,0 +1,104 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license +"""Monkey patches to update/extend functionality of existing functions.""" + +import time +from pathlib import Path + +import cv2 +import numpy as np +import torch + +# OpenCV Multilanguage-friendly functions ------------------------------------------------------------------------------ +_imshow = cv2.imshow # copy to avoid recursion errors + + +def imread(filename: str, flags: int = cv2.IMREAD_COLOR): + """ + Read an image from a file. + + Args: + filename (str): Path to the file to read. + flags (int, optional): Flag that can take values of cv2.IMREAD_*. Defaults to cv2.IMREAD_COLOR. + + Returns: + (np.ndarray): The read image. + """ + return cv2.imdecode(np.fromfile(filename, np.uint8), flags) + + +def imwrite(filename: str, img: np.ndarray, params=None): + """ + Write an image to a file. + + Args: + filename (str): Path to the file to write. + img (np.ndarray): Image to write. + params (list of ints, optional): Additional parameters. See OpenCV documentation. + + Returns: + (bool): True if the file was written, False otherwise. + """ + try: + cv2.imencode(Path(filename).suffix, img, params)[1].tofile(filename) + return True + except Exception: + return False + + +def imshow(winname: str, mat: np.ndarray): + """ + Displays an image in the specified window. + + Args: + winname (str): Name of the window. + mat (np.ndarray): Image to be shown. + """ + _imshow(winname.encode("unicode_escape").decode(), mat) + + +# PyTorch functions ---------------------------------------------------------------------------------------------------- +_torch_load = torch.load # copy to avoid recursion errors +_torch_save = torch.save + + +def torch_load(*args, **kwargs): + """ + Load a PyTorch model with updated arguments to avoid warnings. + + This function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings. + + Args: + *args (Any): Variable length argument list to pass to torch.load. + **kwargs (Any): Arbitrary keyword arguments to pass to torch.load. + + Returns: + (Any): The loaded PyTorch object. + + Note: + For PyTorch versions 2.0 and above, this function automatically sets 'weights_only=False' + if the argument is not provided, to avoid deprecation warnings. 
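+
+    Example:
+        A minimal round trip through the patched helpers (the file name is illustrative):
+        ```python
+        from ultralytics.utils.patches import torch_load, torch_save
+
+        torch_save({"step": 1}, "checkpoint.pt")  # retried up to 3 times on transient RuntimeError
+        ckpt = torch_load("checkpoint.pt")  # forwards to torch.load, adding weights_only=False where supported
+        ```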
+ """ + from ultralytics.utils.torch_utils import TORCH_1_13 + + if TORCH_1_13 and "weights_only" not in kwargs: + kwargs["weights_only"] = False + + return _torch_load(*args, **kwargs) + + +def torch_save(*args, **kwargs): + """ + Optionally use dill to serialize lambda functions where pickle does not, adding robustness with 3 retries and + exponential standoff in case of save failure. + + Args: + *args (tuple): Positional arguments to pass to torch.save. + **kwargs (Any): Keyword arguments to pass to torch.save. + """ + for i in range(4): # 3 retries + try: + return _torch_save(*args, **kwargs) + except RuntimeError as e: # unable to save, possibly waiting for device to flush or antivirus scan + if i == 3: + raise e + time.sleep((2**i) / 2) # exponential standoff: 0.5s, 1.0s, 2.0s diff --git a/ultralytics/utils/plotting.py b/ultralytics/utils/plotting.py new file mode 100644 index 0000000000000000000000000000000000000000..c308e8d2fa0b5ee268d43eb0cafaaa6fdcb7c255 --- /dev/null +++ b/ultralytics/utils/plotting.py @@ -0,0 +1,1378 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import math +import warnings +from pathlib import Path +from typing import Callable, Dict, List, Optional, Union + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import torch +from PIL import Image, ImageDraw, ImageFont +from PIL import __version__ as pil_version + +from ultralytics.utils import IS_COLAB, IS_KAGGLE, LOGGER, TryExcept, ops, plt_settings, threaded +from ultralytics.utils.checks import check_font, check_version, is_ascii +from ultralytics.utils.files import increment_path + + +class Colors: + """ + Ultralytics color palette https://docs.ultralytics.com/reference/utils/plotting/#ultralytics.utils.plotting.Colors. + + This class provides methods to work with the Ultralytics color palette, including converting hex color codes to + RGB values. + + Attributes: + palette (list of tuple): List of RGB color values. + n (int): The number of colors in the palette. + pose_palette (np.ndarray): A specific color palette array with dtype np.uint8. 
+ + ## Ultralytics Color Palette + + | Index | Color | HEX | RGB | + |-------|-------------------------------------------------------------------|-----------|-------------------| + | 0 | | `#042aff` | (4, 42, 255) | + | 1 | | `#0bdbeb` | (11, 219, 235) | + | 2 | | `#f3f3f3` | (243, 243, 243) | + | 3 | | `#00dfb7` | (0, 223, 183) | + | 4 | | `#111f68` | (17, 31, 104) | + | 5 | | `#ff6fdd` | (255, 111, 221) | + | 6 | | `#ff444f` | (255, 68, 79) | + | 7 | | `#cced00` | (204, 237, 0) | + | 8 | | `#00f344` | (0, 243, 68) | + | 9 | | `#bd00ff` | (189, 0, 255) | + | 10 | | `#00b4ff` | (0, 180, 255) | + | 11 | | `#dd00ba` | (221, 0, 186) | + | 12 | | `#00ffff` | (0, 255, 255) | + | 13 | | `#26c000` | (38, 192, 0) | + | 14 | | `#01ffb3` | (1, 255, 179) | + | 15 | | `#7d24ff` | (125, 36, 255) | + | 16 | | `#7b0068` | (123, 0, 104) | + | 17 | | `#ff1b6c` | (255, 27, 108) | + | 18 | | `#fc6d2f` | (252, 109, 47) | + | 19 | | `#a2ff0b` | (162, 255, 11) | + + ## Pose Color Palette + + | Index | Color | HEX | RGB | + |-------|-------------------------------------------------------------------|-----------|-------------------| + | 0 | | `#ff8000` | (255, 128, 0) | + | 1 | | `#ff9933` | (255, 153, 51) | + | 2 | | `#ffb266` | (255, 178, 102) | + | 3 | | `#e6e600` | (230, 230, 0) | + | 4 | | `#ff99ff` | (255, 153, 255) | + | 5 | | `#99ccff` | (153, 204, 255) | + | 6 | | `#ff66ff` | (255, 102, 255) | + | 7 | | `#ff33ff` | (255, 51, 255) | + | 8 | | `#66b2ff` | (102, 178, 255) | + | 9 | | `#3399ff` | (51, 153, 255) | + | 10 | | `#ff9999` | (255, 153, 153) | + | 11 | | `#ff6666` | (255, 102, 102) | + | 12 | | `#ff3333` | (255, 51, 51) | + | 13 | | `#99ff99` | (153, 255, 153) | + | 14 | | `#66ff66` | (102, 255, 102) | + | 15 | | `#33ff33` | (51, 255, 51) | + | 16 | | `#00ff00` | (0, 255, 0) | + | 17 | | `#0000ff` | (0, 0, 255) | + | 18 | | `#ff0000` | (255, 0, 0) | + | 19 | | `#ffffff` | (255, 255, 255) | + + !!! note "Ultralytics Brand Colors" + + For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand). Please use the official Ultralytics colors for all marketing materials. + """ + + def __init__(self): + """Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values().""" + hexs = ( + "042AFF", + "0BDBEB", + "F3F3F3", + "00DFB7", + "111F68", + "FF6FDD", + "FF444F", + "CCED00", + "00F344", + "BD00FF", + "00B4FF", + "DD00BA", + "00FFFF", + "26C000", + "01FFB3", + "7D24FF", + "7B0068", + "FF1B6C", + "FC6D2F", + "A2FF0B", + ) + self.palette = [self.hex2rgb(f"#{c}") for c in hexs] + self.n = len(self.palette) + self.pose_palette = np.array( + [ + [255, 128, 0], + [255, 153, 51], + [255, 178, 102], + [230, 230, 0], + [255, 153, 255], + [153, 204, 255], + [255, 102, 255], + [255, 51, 255], + [102, 178, 255], + [51, 153, 255], + [255, 153, 153], + [255, 102, 102], + [255, 51, 51], + [153, 255, 153], + [102, 255, 102], + [51, 255, 51], + [0, 255, 0], + [0, 0, 255], + [255, 0, 0], + [255, 255, 255], + ], + dtype=np.uint8, + ) + + def __call__(self, i, bgr=False): + """Converts hex color codes to RGB values.""" + c = self.palette[int(i) % self.n] + return (c[2], c[1], c[0]) if bgr else c + + @staticmethod + def hex2rgb(h): + """Converts hex color codes to RGB values (i.e. default PIL order).""" + return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4)) + + +colors = Colors() # create instance for 'from utils.plots import colors' + + +class Annotator: + """ + Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations. 
+ + Attributes: + im (Image.Image or numpy array): The image to annotate. + pil (bool): Whether to use PIL or cv2 for drawing annotations. + font (ImageFont.truetype or ImageFont.load_default): Font used for text annotations. + lw (float): Line width for drawing. + skeleton (List[List[int]]): Skeleton structure for keypoints. + limb_color (List[int]): Color palette for limbs. + kpt_color (List[int]): Color palette for keypoints. + """ + + def __init__(self, im, line_width=None, font_size=None, font="Arial.ttf", pil=False, example="abc"): + """Initialize the Annotator class with image and line width along with color palette for keypoints and limbs.""" + non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic + input_is_pil = isinstance(im, Image.Image) + self.pil = pil or non_ascii or input_is_pil + self.lw = line_width or max(round(sum(im.size if input_is_pil else im.shape) / 2 * 0.003), 2) + if self.pil: # use PIL + self.im = im if input_is_pil else Image.fromarray(im) + self.draw = ImageDraw.Draw(self.im) + try: + font = check_font("Arial.Unicode.ttf" if non_ascii else font) + size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12) + self.font = ImageFont.truetype(str(font), size) + except Exception: + self.font = ImageFont.load_default() + # Deprecation fix for w, h = getsize(string) -> _, _, w, h = getbox(string) + if check_version(pil_version, "9.2.0"): + self.font.getsize = lambda x: self.font.getbbox(x)[2:4] # text width, height + else: # use cv2 + assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images." + self.im = im if im.flags.writeable else im.copy() + self.tf = max(self.lw - 1, 1) # font thickness + self.sf = self.lw / 3 # font scale + # Pose + self.skeleton = [ + [16, 14], + [14, 12], + [17, 15], + [15, 13], + [12, 13], + [6, 12], + [7, 13], + [6, 7], + [6, 8], + [7, 9], + [8, 10], + [9, 11], + [2, 3], + [1, 2], + [1, 3], + [2, 4], + [3, 5], + [4, 6], + [5, 7], + ] + + self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]] + self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]] + self.dark_colors = { + (235, 219, 11), + (243, 243, 243), + (183, 223, 0), + (221, 111, 255), + (0, 237, 204), + (68, 243, 0), + (255, 255, 0), + (179, 255, 1), + (11, 255, 162), + } + self.light_colors = { + (255, 42, 4), + (79, 68, 255), + (255, 0, 189), + (255, 180, 0), + (186, 0, 221), + (0, 192, 38), + (255, 36, 125), + (104, 0, 123), + (108, 27, 255), + (47, 109, 252), + (104, 31, 17), + } + + def get_txt_color(self, color=(128, 128, 128), txt_color=(255, 255, 255)): + """ + Assign text color based on background color. + + Args: + color (tuple, optional): The background color of the rectangle for text (B, G, R). + txt_color (tuple, optional): The color of the text (R, G, B). + + Returns: + txt_color (tuple): Text color for label + """ + if color in self.dark_colors: + return 104, 31, 17 + elif color in self.light_colors: + return 255, 255, 255 + else: + return txt_color + + def circle_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), margin=2): + """ + Draws a label with a background circle centered within a given bounding box. + + Args: + box (tuple): The bounding box coordinates (x1, y1, x2, y2). + label (str): The text label to be displayed. + color (tuple, optional): The background color of the rectangle (B, G, R). + txt_color (tuple, optional): The color of the text (R, G, B). 
+ margin (int, optional): The margin between the text and the rectangle border. + """ + # If label have more than 3 characters, skip other characters, due to circle size + if len(label) > 3: + print( + f"Length of label is {len(label)}, initial 3 label characters will be considered for circle annotation!" + ) + label = label[:3] + + # Calculate the center of the box + x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2) + # Get the text size + text_size = cv2.getTextSize(str(label), cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.15, self.tf)[0] + # Calculate the required radius to fit the text with the margin + required_radius = int(((text_size[0] ** 2 + text_size[1] ** 2) ** 0.5) / 2) + margin + # Draw the circle with the required radius + cv2.circle(self.im, (x_center, y_center), required_radius, color, -1) + # Calculate the position for the text + text_x = x_center - text_size[0] // 2 + text_y = y_center + text_size[1] // 2 + # Draw the text + cv2.putText( + self.im, + str(label), + (text_x, text_y), + cv2.FONT_HERSHEY_SIMPLEX, + self.sf - 0.15, + self.get_txt_color(color, txt_color), + self.tf, + lineType=cv2.LINE_AA, + ) + + def text_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), margin=5): + """ + Draws a label with a background rectangle centered within a given bounding box. + + Args: + box (tuple): The bounding box coordinates (x1, y1, x2, y2). + label (str): The text label to be displayed. + color (tuple, optional): The background color of the rectangle (B, G, R). + txt_color (tuple, optional): The color of the text (R, G, B). + margin (int, optional): The margin between the text and the rectangle border. + """ + # Calculate the center of the bounding box + x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2) + # Get the size of the text + text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.1, self.tf)[0] + # Calculate the top-left corner of the text (to center it) + text_x = x_center - text_size[0] // 2 + text_y = y_center + text_size[1] // 2 + # Calculate the coordinates of the background rectangle + rect_x1 = text_x - margin + rect_y1 = text_y - text_size[1] - margin + rect_x2 = text_x + text_size[0] + margin + rect_y2 = text_y + margin + # Draw the background rectangle + cv2.rectangle(self.im, (rect_x1, rect_y1), (rect_x2, rect_y2), color, -1) + # Draw the text on top of the rectangle + cv2.putText( + self.im, + label, + (text_x, text_y), + cv2.FONT_HERSHEY_SIMPLEX, + self.sf - 0.1, + self.get_txt_color(color, txt_color), + self.tf, + lineType=cv2.LINE_AA, + ) + + def box_label(self, box, label="", color=(128, 128, 128), txt_color=(255, 255, 255), rotated=False): + """ + Draws a bounding box to image with label. + + Args: + box (tuple): The bounding box coordinates (x1, y1, x2, y2). + label (str): The text label to be displayed. + color (tuple, optional): The background color of the rectangle (B, G, R). + txt_color (tuple, optional): The color of the text (R, G, B). 
+ rotated (bool, optional): Variable used to check if task is OBB + """ + txt_color = self.get_txt_color(color, txt_color) + if isinstance(box, torch.Tensor): + box = box.tolist() + if self.pil or not is_ascii(label): + if rotated: + p1 = box[0] + self.draw.polygon([tuple(b) for b in box], width=self.lw, outline=color) # PIL requires tuple box + else: + p1 = (box[0], box[1]) + self.draw.rectangle(box, width=self.lw, outline=color) # box + if label: + w, h = self.font.getsize(label) # text width, height + outside = p1[1] >= h # label fits outside box + if p1[0] > self.im.size[0] - w: # size is (w, h), check if label extend beyond right side of image + p1 = self.im.size[0] - w, p1[1] + self.draw.rectangle( + (p1[0], p1[1] - h if outside else p1[1], p1[0] + w + 1, p1[1] + 1 if outside else p1[1] + h + 1), + fill=color, + ) + # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0 + self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font) + else: # cv2 + if rotated: + p1 = [int(b) for b in box[0]] + cv2.polylines(self.im, [np.asarray(box, dtype=int)], True, color, self.lw) # cv2 requires nparray box + else: + p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3])) + cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA) + if label: + w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height + h += 3 # add pixels to pad text + outside = p1[1] >= h # label fits outside box + if p1[0] > self.im.shape[1] - w: # shape is (h, w), check if label extend beyond right side of image + p1 = self.im.shape[1] - w, p1[1] + p2 = p1[0] + w, p1[1] - h if outside else p1[1] + h + cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled + cv2.putText( + self.im, + label, + (p1[0], p1[1] - 2 if outside else p1[1] + h - 1), + 0, + self.sf, + txt_color, + thickness=self.tf, + lineType=cv2.LINE_AA, + ) + + def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False): + """ + Plot masks on image. + + Args: + masks (tensor): Predicted masks on cuda, shape: [n, h, w] + colors (List[List[Int]]): Colors for predicted masks, [[r, g, b] * n] + im_gpu (tensor): Image is in cuda, shape: [3, h, w], range: [0, 1] + alpha (float): Mask transparency: 0.0 fully transparent, 1.0 opaque + retina_masks (bool): Whether to use high resolution masks or not. Defaults to False. 
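+
+        Examples:
+            A CPU-only sketch with random masks; shapes, colors, and values are illustrative only:
+            >>> import numpy as np
+            >>> import torch
+            >>> from ultralytics.utils.plotting import Annotator
+            >>> annotator = Annotator(np.zeros((480, 640, 3), dtype=np.uint8))
+            >>> masks = (torch.rand(2, 480, 640) > 0.5).float()  # two binary masks
+            >>> im_chw = torch.rand(3, 480, 640)  # image tensor in [0, 1], shape (3, h, w)
+            >>> annotator.masks(masks, colors=[[255, 0, 0], [0, 255, 0]], im_gpu=im_chw, retina_masks=True)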
+ """ + if self.pil: + # Convert to numpy first + self.im = np.asarray(self.im).copy() + if len(masks) == 0: + self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255 + if im_gpu.device != masks.device: + im_gpu = im_gpu.to(masks.device) + colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0 # shape(n,3) + colors = colors[:, None, None] # shape(n,1,1,3) + masks = masks.unsqueeze(3) # shape(n,h,w,1) + masks_color = masks * (colors * alpha) # shape(n,h,w,3) + + inv_alpha_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1) + mcs = masks_color.max(dim=0).values # shape(n,h,w,3) + + im_gpu = im_gpu.flip(dims=[0]) # flip channel + im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3) + im_gpu = im_gpu * inv_alpha_masks[-1] + mcs + im_mask = im_gpu * 255 + im_mask_np = im_mask.byte().cpu().numpy() + self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape) + if self.pil: + # Convert im back to PIL and update draw + self.fromarray(self.im) + + def kpts(self, kpts, shape=(640, 640), radius=None, kpt_line=True, conf_thres=0.25, kpt_color=None): + """ + Plot keypoints on the image. + + Args: + kpts (torch.Tensor): Keypoints, shape [17, 3] (x, y, confidence). + shape (tuple, optional): Image shape (h, w). Defaults to (640, 640). + radius (int, optional): Keypoint radius. Defaults to 5. + kpt_line (bool, optional): Draw lines between keypoints. Defaults to True. + conf_thres (float, optional): Confidence threshold. Defaults to 0.25. + kpt_color (tuple, optional): Keypoint color (B, G, R). Defaults to None. + + Note: + - `kpt_line=True` currently only supports human pose plotting. + - Modifies self.im in-place. + - If self.pil is True, converts image to numpy array and back to PIL. 
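+
+        Examples:
+            A rough sketch with random COCO-style keypoints; coordinates and confidences are synthetic:
+            >>> import numpy as np
+            >>> import torch
+            >>> from ultralytics.utils.plotting import Annotator
+            >>> annotator = Annotator(np.zeros((640, 640, 3), dtype=np.uint8))
+            >>> kpts = torch.rand(17, 3)
+            >>> kpts[:, :2] *= 640  # scale x, y to pixel coordinates; column 2 stays as confidence
+            >>> annotator.kpts(kpts, shape=(640, 640), conf_thres=0.25)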
+ """ + radius = radius if radius is not None else self.lw + if self.pil: + # Convert to numpy first + self.im = np.asarray(self.im).copy() + nkpt, ndim = kpts.shape + is_pose = nkpt == 17 and ndim in {2, 3} + kpt_line &= is_pose # `kpt_line=True` for now only supports human pose plotting + for i, k in enumerate(kpts): + color_k = kpt_color or (self.kpt_color[i].tolist() if is_pose else colors(i)) + x_coord, y_coord = k[0], k[1] + if x_coord % shape[1] != 0 and y_coord % shape[0] != 0: + if len(k) == 3: + conf = k[2] + if conf < conf_thres: + continue + cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, color_k, -1, lineType=cv2.LINE_AA) + + if kpt_line: + ndim = kpts.shape[-1] + for i, sk in enumerate(self.skeleton): + pos1 = (int(kpts[(sk[0] - 1), 0]), int(kpts[(sk[0] - 1), 1])) + pos2 = (int(kpts[(sk[1] - 1), 0]), int(kpts[(sk[1] - 1), 1])) + if ndim == 3: + conf1 = kpts[(sk[0] - 1), 2] + conf2 = kpts[(sk[1] - 1), 2] + if conf1 < conf_thres or conf2 < conf_thres: + continue + if pos1[0] % shape[1] == 0 or pos1[1] % shape[0] == 0 or pos1[0] < 0 or pos1[1] < 0: + continue + if pos2[0] % shape[1] == 0 or pos2[1] % shape[0] == 0 or pos2[0] < 0 or pos2[1] < 0: + continue + cv2.line( + self.im, + pos1, + pos2, + kpt_color or self.limb_color[i].tolist(), + thickness=int(np.ceil(self.lw / 2)), + lineType=cv2.LINE_AA, + ) + if self.pil: + # Convert im back to PIL and update draw + self.fromarray(self.im) + + def rectangle(self, xy, fill=None, outline=None, width=1): + """Add rectangle to image (PIL-only).""" + self.draw.rectangle(xy, fill, outline, width) + + def text(self, xy, text, txt_color=(255, 255, 255), anchor="top", box_style=False): + """Adds text to an image using PIL or cv2.""" + if anchor == "bottom": # start y from font bottom + w, h = self.font.getsize(text) # text width, height + xy[1] += 1 - h + if self.pil: + if box_style: + w, h = self.font.getsize(text) + self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=txt_color) + # Using `txt_color` for background and draw fg with white color + txt_color = (255, 255, 255) + if "\n" in text: + lines = text.split("\n") + _, h = self.font.getsize(text) + for line in lines: + self.draw.text(xy, line, fill=txt_color, font=self.font) + xy[1] += h + else: + self.draw.text(xy, text, fill=txt_color, font=self.font) + else: + if box_style: + w, h = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height + h += 3 # add pixels to pad text + outside = xy[1] >= h # label fits outside box + p2 = xy[0] + w, xy[1] - h if outside else xy[1] + h + cv2.rectangle(self.im, xy, p2, txt_color, -1, cv2.LINE_AA) # filled + # Using `txt_color` for background and draw fg with white color + txt_color = (255, 255, 255) + cv2.putText(self.im, text, xy, 0, self.sf, txt_color, thickness=self.tf, lineType=cv2.LINE_AA) + + def fromarray(self, im): + """Update self.im from a numpy array.""" + self.im = im if isinstance(im, Image.Image) else Image.fromarray(im) + self.draw = ImageDraw.Draw(self.im) + + def result(self): + """Return annotated image as array.""" + return np.asarray(self.im) + + def show(self, title=None): + """Show the annotated image.""" + im = Image.fromarray(np.asarray(self.im)[..., ::-1]) # Convert numpy array to PIL Image with RGB to BGR + if IS_COLAB or IS_KAGGLE: # can not use IS_JUPYTER as will run for all ipython environments + try: + display(im) # noqa - display() function only available in ipython environments + except ImportError as e: + LOGGER.warning(f"Unable to display image in Jupyter 
notebooks: {e}") + else: + im.show(title=title) + + def save(self, filename="image.jpg"): + """Save the annotated image to 'filename'.""" + cv2.imwrite(filename, np.asarray(self.im)) + + @staticmethod + def get_bbox_dimension(bbox=None): + """ + Calculate the area of a bounding box. + + Args: + bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max). + + Returns: + width (float): Width of the bounding box. + height (float): Height of the bounding box. + area (float): Area enclosed by the bounding box. + """ + x_min, y_min, x_max, y_max = bbox + width = x_max - x_min + height = y_max - y_min + return width, height, width * height + + def draw_region(self, reg_pts=None, color=(0, 255, 0), thickness=5): + """ + Draw region line. + + Args: + reg_pts (list): Region Points (for line 2 points, for region 4 points) + color (tuple): Region Color value + thickness (int): Region area thickness value + """ + cv2.polylines(self.im, [np.array(reg_pts, dtype=np.int32)], isClosed=True, color=color, thickness=thickness) + + # Draw small circles at the corner points + for point in reg_pts: + cv2.circle(self.im, (point[0], point[1]), thickness * 2, color, -1) # -1 fills the circle + + def draw_centroid_and_tracks(self, track, color=(255, 0, 255), track_thickness=2): + """ + Draw centroid point and track trails. + + Args: + track (list): object tracking points for trails display + color (tuple): tracks line color + track_thickness (int): track line thickness value + """ + points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2)) + cv2.polylines(self.im, [points], isClosed=False, color=color, thickness=track_thickness) + cv2.circle(self.im, (int(track[-1][0]), int(track[-1][1])), track_thickness * 2, color, -1) + + def queue_counts_display(self, label, points=None, region_color=(255, 255, 255), txt_color=(0, 0, 0)): + """ + Displays queue counts on an image centered at the points with customizable font size and colors. + + Args: + label (str): Queue counts label. + points (tuple): Region points for center point calculation to display text. + region_color (tuple): RGB queue region color. + txt_color (tuple): RGB text display color. + """ + x_values = [point[0] for point in points] + y_values = [point[1] for point in points] + center_x = sum(x_values) // len(points) + center_y = sum(y_values) // len(points) + + text_size = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] + text_width = text_size[0] + text_height = text_size[1] + + rect_width = text_width + 20 + rect_height = text_height + 20 + rect_top_left = (center_x - rect_width // 2, center_y - rect_height // 2) + rect_bottom_right = (center_x + rect_width // 2, center_y + rect_height // 2) + cv2.rectangle(self.im, rect_top_left, rect_bottom_right, region_color, -1) + + text_x = center_x - text_width // 2 + text_y = center_y + text_height // 2 + + # Draw text + cv2.putText( + self.im, + label, + (text_x, text_y), + 0, + fontScale=self.sf, + color=txt_color, + thickness=self.tf, + lineType=cv2.LINE_AA, + ) + + def display_objects_labels(self, im0, text, txt_color, bg_color, x_center, y_center, margin): + """ + Display the bounding boxes labels in parking management app. + + Args: + im0 (ndarray): Inference image. + text (str): Object/class name. + txt_color (tuple): Display color for text foreground. + bg_color (tuple): Display color for text background. + x_center (float): The x position center point for bounding box. + y_center (float): The y position center point for bounding box. 
+ margin (int): The gap between text and rectangle for better display. + """ + text_size = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0] + text_x = x_center - text_size[0] // 2 + text_y = y_center + text_size[1] // 2 + + rect_x1 = text_x - margin + rect_y1 = text_y - text_size[1] - margin + rect_x2 = text_x + text_size[0] + margin + rect_y2 = text_y + margin + cv2.rectangle(im0, (rect_x1, rect_y1), (rect_x2, rect_y2), bg_color, -1) + cv2.putText(im0, text, (text_x, text_y), 0, self.sf, txt_color, self.tf, lineType=cv2.LINE_AA) + + def display_analytics(self, im0, text, txt_color, bg_color, margin): + """ + Display the overall statistics for parking lots. + + Args: + im0 (ndarray): Inference image. + text (dict): Labels dictionary. + txt_color (tuple): Display color for text foreground. + bg_color (tuple): Display color for text background. + margin (int): Gap between text and rectangle for better display. + """ + horizontal_gap = int(im0.shape[1] * 0.02) + vertical_gap = int(im0.shape[0] * 0.01) + text_y_offset = 0 + for label, value in text.items(): + txt = f"{label}: {value}" + text_size = cv2.getTextSize(txt, 0, self.sf, self.tf)[0] + if text_size[0] < 5 or text_size[1] < 5: + text_size = (5, 5) + text_x = im0.shape[1] - text_size[0] - margin * 2 - horizontal_gap + text_y = text_y_offset + text_size[1] + margin * 2 + vertical_gap + rect_x1 = text_x - margin * 2 + rect_y1 = text_y - text_size[1] - margin * 2 + rect_x2 = text_x + text_size[0] + margin * 2 + rect_y2 = text_y + margin * 2 + cv2.rectangle(im0, (rect_x1, rect_y1), (rect_x2, rect_y2), bg_color, -1) + cv2.putText(im0, txt, (text_x, text_y), 0, self.sf, txt_color, self.tf, lineType=cv2.LINE_AA) + text_y_offset = rect_y2 + + @staticmethod + def estimate_pose_angle(a, b, c): + """ + Calculate the pose angle for object. + + Args: + a (float) : The value of pose point a + b (float): The value of pose point b + c (float): The value o pose point c + + Returns: + angle (degree): Degree value of angle between three points + """ + a, b, c = np.array(a), np.array(b), np.array(c) + radians = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0]) + angle = np.abs(radians * 180.0 / np.pi) + if angle > 180.0: + angle = 360 - angle + return angle + + def draw_specific_points(self, keypoints, indices=None, radius=2, conf_thres=0.25): + """ + Draw specific keypoints for gym steps counting. + + Args: + keypoints (list): Keypoints data to be plotted. + indices (list, optional): Keypoint indices to be plotted. Defaults to [2, 5, 7]. + radius (int, optional): Keypoint radius. Defaults to 2. + conf_thres (float, optional): Confidence threshold for keypoints. Defaults to 0.25. + + Returns: + (numpy.ndarray): Image with drawn keypoints. + + Note: + Keypoint format: [x, y] or [x, y, confidence]. + Modifies self.im in-place. + """ + indices = indices or [2, 5, 7] + points = [(int(k[0]), int(k[1])) for i, k in enumerate(keypoints) if i in indices and k[2] >= conf_thres] + + # Draw lines between consecutive points + for start, end in zip(points[:-1], points[1:]): + cv2.line(self.im, start, end, (0, 255, 0), 2, lineType=cv2.LINE_AA) + + # Draw circles for keypoints + for pt in points: + cv2.circle(self.im, pt, radius, (0, 0, 255), -1, lineType=cv2.LINE_AA) + + return self.im + + def plot_workout_information(self, display_text, position, color=(104, 31, 17), txt_color=(255, 255, 255)): + """ + Draw text with a background on the image. + + Args: + display_text (str): The text to be displayed. 
+ position (tuple): Coordinates (x, y) on the image where the text will be placed. + color (tuple, optional): Text background color + txt_color (tuple, optional): Text foreground color + """ + (text_width, text_height), _ = cv2.getTextSize(display_text, 0, self.sf, self.tf) + + # Draw background rectangle + cv2.rectangle( + self.im, + (position[0], position[1] - text_height - 5), + (position[0] + text_width + 10, position[1] - text_height - 5 + text_height + 10 + self.tf), + color, + -1, + ) + # Draw text + cv2.putText(self.im, display_text, position, 0, self.sf, txt_color, self.tf) + + return text_height + + def plot_angle_and_count_and_stage( + self, angle_text, count_text, stage_text, center_kpt, color=(104, 31, 17), txt_color=(255, 255, 255) + ): + """ + Plot the pose angle, count value, and step stage. + + Args: + angle_text (str): Angle value for workout monitoring + count_text (str): Counts value for workout monitoring + stage_text (str): Stage decision for workout monitoring + center_kpt (list): Centroid pose index for workout monitoring + color (tuple, optional): Text background color + txt_color (tuple, optional): Text foreground color + """ + # Format text + angle_text, count_text, stage_text = f" {angle_text:.2f}", f"Steps : {count_text}", f" {stage_text}" + + # Draw angle, count and stage text + angle_height = self.plot_workout_information( + angle_text, (int(center_kpt[0]), int(center_kpt[1])), color, txt_color + ) + count_height = self.plot_workout_information( + count_text, (int(center_kpt[0]), int(center_kpt[1]) + angle_height + 20), color, txt_color + ) + self.plot_workout_information( + stage_text, (int(center_kpt[0]), int(center_kpt[1]) + angle_height + count_height + 40), color, txt_color + ) + + def seg_bbox(self, mask, mask_color=(255, 0, 255), label=None, txt_color=(255, 255, 255)): + """ + Function for drawing segmented object in bounding box shape. + + Args: + mask (np.ndarray): A 2D array of shape (N, 2) containing the contour points of the segmented object. + mask_color (tuple): RGB color for the contour and label background. + label (str, optional): Text label for the object. If None, no label is drawn. + txt_color (tuple): RGB color for the label text. + """ + if mask.size == 0: # no masks to plot + return + + cv2.polylines(self.im, [np.int32([mask])], isClosed=True, color=mask_color, thickness=2) + text_size, _ = cv2.getTextSize(label, 0, self.sf, self.tf) + + if label: + cv2.rectangle( + self.im, + (int(mask[0][0]) - text_size[0] // 2 - 10, int(mask[0][1]) - text_size[1] - 10), + (int(mask[0][0]) + text_size[0] // 2 + 10, int(mask[0][1] + 10)), + mask_color, + -1, + ) + cv2.putText( + self.im, label, (int(mask[0][0]) - text_size[0] // 2, int(mask[0][1])), 0, self.sf, txt_color, self.tf + ) + + def sweep_annotator(self, line_x=0, line_y=0, label=None, color=(221, 0, 186), txt_color=(255, 255, 255)): + """ + Function for drawing a sweep annotation line and an optional label. + + Args: + line_x (int): The x-coordinate of the sweep line. + line_y (int): The y-coordinate limit of the sweep line. + label (str, optional): Text label to be drawn in center of sweep line. If None, no label is drawn. + color (tuple): RGB color for the line and label background. + txt_color (tuple): RGB color for the label text. 
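+
+        Examples:
+            An illustrative call on a blank frame; the coordinates and label text are arbitrary:
+            >>> import numpy as np
+            >>> from ultralytics.utils.plotting import Annotator
+            >>> annotator = Annotator(np.zeros((480, 640, 3), dtype=np.uint8))
+            >>> annotator.sweep_annotator(line_x=320, line_y=480, label="10")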
+ """ + # Draw the sweep line + cv2.line(self.im, (line_x, 0), (line_x, line_y), color, self.tf * 2) + + # Draw label, if provided + if label: + (text_width, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf, self.tf) + cv2.rectangle( + self.im, + (line_x - text_width // 2 - 10, line_y // 2 - text_height // 2 - 10), + (line_x + text_width // 2 + 10, line_y // 2 + text_height // 2 + 10), + color, + -1, + ) + cv2.putText( + self.im, + label, + (line_x - text_width // 2, line_y // 2 + text_height // 2), + cv2.FONT_HERSHEY_SIMPLEX, + self.sf, + txt_color, + self.tf, + ) + + def plot_distance_and_line( + self, pixels_distance, centroids, line_color=(104, 31, 17), centroid_color=(255, 0, 255) + ): + """ + Plot the distance and line on frame. + + Args: + pixels_distance (float): Pixels distance between two bbox centroids. + centroids (list): Bounding box centroids data. + line_color (tuple, optional): Distance line color. + centroid_color (tuple, optional): Bounding box centroid color. + """ + # Get the text size + text = f"Pixels Distance: {pixels_distance:.2f}" + (text_width_m, text_height_m), _ = cv2.getTextSize(text, 0, self.sf, self.tf) + + # Define corners with 10-pixel margin and draw rectangle + cv2.rectangle(self.im, (15, 25), (15 + text_width_m + 20, 25 + text_height_m + 20), line_color, -1) + + # Calculate the position for the text with a 10-pixel margin and draw text + text_position = (25, 25 + text_height_m + 10) + cv2.putText( + self.im, + text, + text_position, + 0, + self.sf, + (255, 255, 255), + self.tf, + cv2.LINE_AA, + ) + + cv2.line(self.im, centroids[0], centroids[1], line_color, 3) + cv2.circle(self.im, centroids[0], 6, centroid_color, -1) + cv2.circle(self.im, centroids[1], 6, centroid_color, -1) + + def visioneye(self, box, center_point, color=(235, 219, 11), pin_color=(255, 0, 255)): + """ + Function for pinpoint human-vision eye mapping and plotting. + + Args: + box (list): Bounding box coordinates + center_point (tuple): center point for vision eye view + color (tuple): object centroid and line color value + pin_color (tuple): visioneye point color value + """ + center_bbox = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2) + cv2.circle(self.im, center_point, self.tf * 2, pin_color, -1) + cv2.circle(self.im, center_bbox, self.tf * 2, color, -1) + cv2.line(self.im, center_point, center_bbox, color, self.tf) + + +@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395 +@plt_settings() +def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None): + """Plot training labels including class histograms and box statistics.""" + import pandas # scope for faster 'import ultralytics' + import seaborn # scope for faster 'import ultralytics' + + # Filter matplotlib>=3.7.2 warning and Seaborn use_inf and is_categorical FutureWarnings + warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight") + warnings.filterwarnings("ignore", category=FutureWarning) + + # Plot dataset labels + LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... 
") + nc = int(cls.max() + 1) # number of classes + boxes = boxes[:1000000] # limit to 1M boxes + x = pandas.DataFrame(boxes, columns=["x", "y", "width", "height"]) + + # Seaborn correlogram + seaborn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9)) + plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200) + plt.close() + + # Matplotlib labels + ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel() + y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8) + for i in range(nc): + y[2].patches[i].set_color([x / 255 for x in colors(i)]) + ax[0].set_ylabel("instances") + if 0 < len(names) < 30: + ax[0].set_xticks(range(len(names))) + ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10) + else: + ax[0].set_xlabel("classes") + seaborn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9) + seaborn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9) + + # Rectangles + boxes[:, 0:2] = 0.5 # center + boxes = ops.xywh2xyxy(boxes) * 1000 + img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255) + for cls, box in zip(cls[:500], boxes[:500]): + ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot + ax[1].imshow(img) + ax[1].axis("off") + + for a in [0, 1, 2, 3]: + for s in ["top", "right", "left", "bottom"]: + ax[a].spines[s].set_visible(False) + + fname = save_dir / "labels.jpg" + plt.savefig(fname, dpi=200) + plt.close() + if on_plot: + on_plot(fname) + + +def save_one_box(xyxy, im, file=Path("im.jpg"), gain=1.02, pad=10, square=False, BGR=False, save=True): + """ + Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop. + + This function takes a bounding box and an image, and then saves a cropped portion of the image according + to the bounding box. Optionally, the crop can be squared, and the function allows for gain and padding + adjustments to the bounding box. + + Args: + xyxy (torch.Tensor or list): A tensor or list representing the bounding box in xyxy format. + im (numpy.ndarray): The input image. + file (Path, optional): The path where the cropped image will be saved. Defaults to 'im.jpg'. + gain (float, optional): A multiplicative factor to increase the size of the bounding box. Defaults to 1.02. + pad (int, optional): The number of pixels to add to the width and height of the bounding box. Defaults to 10. + square (bool, optional): If True, the bounding box will be transformed into a square. Defaults to False. + BGR (bool, optional): If True, the image will be saved in BGR format, otherwise in RGB. Defaults to False. + save (bool, optional): If True, the cropped image will be saved to disk. Defaults to True. + + Returns: + (numpy.ndarray): The cropped image. 
+ + Example: + ```python + from ultralytics.utils.plotting import save_one_box + + xyxy = [50, 50, 150, 150] + im = cv2.imread("image.jpg") + cropped_im = save_one_box(xyxy, im, file="cropped.jpg", square=True) + ``` + """ + if not isinstance(xyxy, torch.Tensor): # may be list + xyxy = torch.stack(xyxy) + b = ops.xyxy2xywh(xyxy.view(-1, 4)) # boxes + if square: + b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square + b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad + xyxy = ops.xywh2xyxy(b).long() + xyxy = ops.clip_boxes(xyxy, im.shape) + crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR else -1)] + if save: + file.parent.mkdir(parents=True, exist_ok=True) # make directory + f = str(increment_path(file).with_suffix(".jpg")) + # cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue + Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB + return crop + + +@threaded +def plot_images( + images: Union[torch.Tensor, np.ndarray], + batch_idx: Union[torch.Tensor, np.ndarray], + cls: Union[torch.Tensor, np.ndarray], + bboxes: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.float32), + confs: Optional[Union[torch.Tensor, np.ndarray]] = None, + masks: Union[torch.Tensor, np.ndarray] = np.zeros(0, dtype=np.uint8), + kpts: Union[torch.Tensor, np.ndarray] = np.zeros((0, 51), dtype=np.float32), + paths: Optional[List[str]] = None, + fname: str = "images.jpg", + names: Optional[Dict[int, str]] = None, + on_plot: Optional[Callable] = None, + max_size: int = 1920, + max_subplots: int = 16, + save: bool = True, + conf_thres: float = 0.25, +) -> Optional[np.ndarray]: + """ + Plot image grid with labels, bounding boxes, masks, and keypoints. + + Args: + images: Batch of images to plot. Shape: (batch_size, channels, height, width). + batch_idx: Batch indices for each detection. Shape: (num_detections,). + cls: Class labels for each detection. Shape: (num_detections,). + bboxes: Bounding boxes for each detection. Shape: (num_detections, 4) or (num_detections, 5) for rotated boxes. + confs: Confidence scores for each detection. Shape: (num_detections,). + masks: Instance segmentation masks. Shape: (num_detections, height, width) or (1, height, width). + kpts: Keypoints for each detection. Shape: (num_detections, 51). + paths: List of file paths for each image in the batch. + fname: Output filename for the plotted image grid. + names: Dictionary mapping class indices to class names. + on_plot: Optional callback function to be called after saving the plot. + max_size: Maximum size of the output image grid. + max_subplots: Maximum number of subplots in the image grid. + save: Whether to save the plotted image grid to a file. + conf_thres: Confidence threshold for displaying detections. + + Returns: + np.ndarray: Plotted image grid as a numpy array if save is False, None otherwise. + + Note: + This function supports both tensor and numpy array inputs. It will automatically + convert tensor inputs to numpy arrays for processing. 
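+
+    Examples:
+        A small synthetic batch; the class names, boxes, and output filename are placeholders:
+        >>> import torch
+        >>> from ultralytics.utils.plotting import plot_images
+        >>> images = torch.rand(2, 3, 320, 320)
+        >>> batch_idx = torch.tensor([0, 1])
+        >>> cls = torch.tensor([0.0, 1.0])
+        >>> bboxes = torch.tensor([[0.5, 0.5, 0.4, 0.4], [0.3, 0.6, 0.2, 0.3]])  # normalized xywh
+        >>> plot_images(images, batch_idx, cls, bboxes=bboxes, names={0: "person", 1: "bus"}, fname="batch0.jpg")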
+ """ + if isinstance(images, torch.Tensor): + images = images.cpu().float().numpy() + if isinstance(cls, torch.Tensor): + cls = cls.cpu().numpy() + if isinstance(bboxes, torch.Tensor): + bboxes = bboxes.cpu().numpy() + if isinstance(masks, torch.Tensor): + masks = masks.cpu().numpy().astype(int) + if isinstance(kpts, torch.Tensor): + kpts = kpts.cpu().numpy() + if isinstance(batch_idx, torch.Tensor): + batch_idx = batch_idx.cpu().numpy() + + bs, _, h, w = images.shape # batch size, _, height, width + bs = min(bs, max_subplots) # limit plot images + ns = np.ceil(bs**0.5) # number of subplots (square) + if np.max(images[0]) <= 1: + images *= 255 # de-normalise (optional) + + # Build Image + mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init + for i in range(bs): + x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin + mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0) + + # Resize (optional) + scale = max_size / ns / max(h, w) + if scale < 1: + h = math.ceil(scale * h) + w = math.ceil(scale * w) + mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h))) + + # Annotate + fs = int((h + w) * ns * 0.01) # font size + annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names) + for i in range(bs): + x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin + annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders + if paths: + annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames + if len(cls) > 0: + idx = batch_idx == i + classes = cls[idx].astype("int") + labels = confs is None + + if len(bboxes): + boxes = bboxes[idx] + conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred) + if len(boxes): + if boxes[:, :4].max() <= 1.1: # if normalized with tolerance 0.1 + boxes[..., [0, 2]] *= w # scale to pixels + boxes[..., [1, 3]] *= h + elif scale < 1: # absolute coords need scale if image scales + boxes[..., :4] *= scale + boxes[..., 0] += x + boxes[..., 1] += y + is_obb = boxes.shape[-1] == 5 # xywhr + boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes) + for j, box in enumerate(boxes.astype(np.int64).tolist()): + c = classes[j] + color = colors(c) + c = names.get(c, c) if names else c + if labels or conf[j] > conf_thres: + label = f"{c}" if labels else f"{c} {conf[j]:.1f}" + annotator.box_label(box, label, color=color, rotated=is_obb) + + elif len(classes): + for c in classes: + color = colors(c) + c = names.get(c, c) if names else c + annotator.text((x, y), f"{c}", txt_color=color, box_style=True) + + # Plot keypoints + if len(kpts): + kpts_ = kpts[idx].copy() + if len(kpts_): + if kpts_[..., 0].max() <= 1.01 or kpts_[..., 1].max() <= 1.01: # if normalized with tolerance .01 + kpts_[..., 0] *= w # scale to pixels + kpts_[..., 1] *= h + elif scale < 1: # absolute coords need scale if image scales + kpts_ *= scale + kpts_[..., 0] += x + kpts_[..., 1] += y + for j in range(len(kpts_)): + if labels or conf[j] > conf_thres: + annotator.kpts(kpts_[j], conf_thres=conf_thres) + + # Plot masks + if len(masks): + if idx.shape[0] == masks.shape[0]: # overlap_masks=False + image_masks = masks[idx] + else: # overlap_masks=True + image_masks = masks[[i]] # (1, 640, 640) + nl = idx.sum() + index = np.arange(nl).reshape((nl, 1, 1)) + 1 + image_masks = np.repeat(image_masks, nl, axis=0) + image_masks = np.where(image_masks == index, 1.0, 0.0) + + im = np.asarray(annotator.im).copy() + for j in 
range(len(image_masks)): + if labels or conf[j] > conf_thres: + color = colors(classes[j]) + mh, mw = image_masks[j].shape + if mh != h or mw != w: + mask = image_masks[j].astype(np.uint8) + mask = cv2.resize(mask, (w, h)) + mask = mask.astype(bool) + else: + mask = image_masks[j].astype(bool) + try: + im[y : y + h, x : x + w, :][mask] = ( + im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6 + ) + except Exception: + pass + annotator.fromarray(im) + if not save: + return np.asarray(annotator.im) + annotator.im.save(fname) # save + if on_plot: + on_plot(fname) + + +@plt_settings() +def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False, classify=False, on_plot=None): + """ + Plot training results from a results CSV file. The function supports various types of data including segmentation, + pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located. + + Args: + file (str, optional): Path to the CSV file containing the training results. Defaults to 'path/to/results.csv'. + dir (str, optional): Directory where the CSV file is located if 'file' is not provided. Defaults to ''. + segment (bool, optional): Flag to indicate if the data is for segmentation. Defaults to False. + pose (bool, optional): Flag to indicate if the data is for pose estimation. Defaults to False. + classify (bool, optional): Flag to indicate if the data is for classification. Defaults to False. + on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument. + Defaults to None. + + Example: + ```python + from ultralytics.utils.plotting import plot_results + + plot_results("path/to/results.csv", segment=True) + ``` + """ + import pandas as pd # scope for faster 'import ultralytics' + from scipy.ndimage import gaussian_filter1d + + save_dir = Path(file).parent if file else Path(dir) + if classify: + fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True) + index = [2, 5, 3, 4] + elif segment: + fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True) + index = [2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 8, 9, 12, 13] + elif pose: + fig, ax = plt.subplots(2, 9, figsize=(21, 6), tight_layout=True) + index = [2, 3, 4, 5, 6, 7, 8, 11, 12, 15, 16, 17, 18, 19, 9, 10, 13, 14] + else: + fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True) + index = [2, 3, 4, 5, 6, 9, 10, 11, 7, 8] + ax = ax.ravel() + files = list(save_dir.glob("results*.csv")) + assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot." + for f in files: + try: + data = pd.read_csv(f) + s = [x.strip() for x in data.columns] + x = data.values[:, 0] + for i, j in enumerate(index): + y = data.values[:, j].astype("float") + # y[y == 0] = np.nan # don't show zero values + ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8) # actual results + ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2) # smoothing line + ax[i].set_title(s[j], fontsize=12) + # if j in {8, 9, 10}: # share train and val loss y axes + # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5]) + except Exception as e: + LOGGER.warning(f"WARNING: Plotting error for {f}: {e}") + ax[1].legend() + fname = save_dir / "results.png" + fig.savefig(fname, dpi=200) + plt.close() + if on_plot: + on_plot(fname) + + +def plt_color_scatter(v, f, bins=20, cmap="viridis", alpha=0.8, edgecolors="none"): + """ + Plots a scatter plot with points colored based on a 2D histogram. 
+ + Args: + v (array-like): Values for the x-axis. + f (array-like): Values for the y-axis. + bins (int, optional): Number of bins for the histogram. Defaults to 20. + cmap (str, optional): Colormap for the scatter plot. Defaults to 'viridis'. + alpha (float, optional): Alpha for the scatter plot. Defaults to 0.8. + edgecolors (str, optional): Edge colors for the scatter plot. Defaults to 'none'. + + Examples: + >>> v = np.random.rand(100) + >>> f = np.random.rand(100) + >>> plt_color_scatter(v, f) + """ + # Calculate 2D histogram and corresponding colors + hist, xedges, yedges = np.histogram2d(v, f, bins=bins) + colors = [ + hist[ + min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1), + min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1), + ] + for i in range(len(v)) + ] + + # Scatter plot + plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors) + + +def plot_tune_results(csv_file="tune_results.csv"): + """ + Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key + in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots. + + Args: + csv_file (str, optional): Path to the CSV file containing the tuning results. Defaults to 'tune_results.csv'. + + Examples: + >>> plot_tune_results("path/to/tune_results.csv") + """ + import pandas as pd # scope for faster 'import ultralytics' + from scipy.ndimage import gaussian_filter1d + + def _save_one_file(file): + """Save one matplotlib plot to 'file'.""" + plt.savefig(file, dpi=200) + plt.close() + LOGGER.info(f"Saved {file}") + + # Scatter plots for each hyperparameter + csv_file = Path(csv_file) + data = pd.read_csv(csv_file) + num_metrics_columns = 1 + keys = [x.strip() for x in data.columns][num_metrics_columns:] + x = data.values + fitness = x[:, 0] # fitness + j = np.argmax(fitness) # max fitness index + n = math.ceil(len(keys) ** 0.5) # columns and rows in plot + plt.figure(figsize=(10, 10), tight_layout=True) + for i, k in enumerate(keys): + v = x[:, i + num_metrics_columns] + mu = v[j] # best single result + plt.subplot(n, n, i + 1) + plt_color_scatter(v, fitness, cmap="viridis", alpha=0.8, edgecolors="none") + plt.plot(mu, fitness.max(), "k+", markersize=15) + plt.title(f"{k} = {mu:.3g}", fontdict={"size": 9}) # limit to 40 characters + plt.tick_params(axis="both", labelsize=8) # Set axis label size to 8 + if i % n != 0: + plt.yticks([]) + _save_one_file(csv_file.with_name("tune_scatter_plots.png")) + + # Fitness vs iteration + x = range(1, len(fitness) + 1) + plt.figure(figsize=(10, 6), tight_layout=True) + plt.plot(x, fitness, marker="o", linestyle="none", label="fitness") + plt.plot(x, gaussian_filter1d(fitness, sigma=3), ":", label="smoothed", linewidth=2) # smoothing line + plt.title("Fitness vs Iteration") + plt.xlabel("Iteration") + plt.ylabel("Fitness") + plt.grid(True) + plt.legend() + _save_one_file(csv_file.with_name("tune_fitness.png")) + + +def output_to_target(output, max_det=300): + """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting.""" + targets = [] + for i, o in enumerate(output): + box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1) + j = torch.full((conf.shape[0], 1), i) + targets.append(torch.cat((j, cls, ops.xyxy2xywh(box), conf), 1)) + targets = torch.cat(targets, 0).numpy() + return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1] + + +def output_to_rotated_target(output, max_det=300): + 
"""Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting.""" + targets = [] + for i, o in enumerate(output): + box, conf, cls, angle = o[:max_det].cpu().split((4, 1, 1, 1), 1) + j = torch.full((conf.shape[0], 1), i) + targets.append(torch.cat((j, cls, box, angle, conf), 1)) + targets = torch.cat(targets, 0).numpy() + return targets[:, 0], targets[:, 1], targets[:, 2:-1], targets[:, -1] + + +def feature_visualization(x, module_type, stage, n=32, save_dir=Path("runs/detect/exp")): + """ + Visualize feature maps of a given model module during inference. + + Args: + x (torch.Tensor): Features to be visualized. + module_type (str): Module type. + stage (int): Module stage within the model. + n (int, optional): Maximum number of feature maps to plot. Defaults to 32. + save_dir (Path, optional): Directory to save results. Defaults to Path('runs/detect/exp'). + """ + for m in {"Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder"}: # all model heads + if m in module_type: + return + if isinstance(x, torch.Tensor): + _, channels, height, width = x.shape # batch, channels, height, width + if height > 1 and width > 1: + f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename + + blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels + n = min(n, channels) # number of plots + _, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols + ax = ax.ravel() + plt.subplots_adjust(wspace=0.05, hspace=0.05) + for i in range(n): + ax[i].imshow(blocks[i].squeeze()) # cmap='gray' + ax[i].axis("off") + + LOGGER.info(f"Saving {f}... ({n}/{channels})") + plt.savefig(f, dpi=300, bbox_inches="tight") + plt.close() + np.save(str(f.with_suffix(".npy")), x[0].cpu().numpy()) # npy save diff --git a/ultralytics/utils/tal.py b/ultralytics/utils/tal.py new file mode 100644 index 0000000000000000000000000000000000000000..e4a40f5e241e5870f27dca75f7bbc3c48a86f7c0 --- /dev/null +++ b/ultralytics/utils/tal.py @@ -0,0 +1,385 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import torch +import torch.nn as nn + +from . import LOGGER +from .checks import check_version +from .metrics import bbox_iou, probiou +from .ops import xywhr2xyxyxyxy + +TORCH_1_10 = check_version(torch.__version__, "1.10.0") + + +class TaskAlignedAssigner(nn.Module): + """ + A task-aligned assigner for object detection. + + This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both + classification and localization information. + + Attributes: + topk (int): The number of top candidates to consider. + num_classes (int): The number of object classes. + alpha (float): The alpha parameter for the classification component of the task-aligned metric. + beta (float): The beta parameter for the localization component of the task-aligned metric. + eps (float): A small value to prevent division by zero. + """ + + def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9): + """Initialize a TaskAlignedAssigner object with customizable hyperparameters.""" + super().__init__() + self.topk = topk + self.num_classes = num_classes + self.bg_idx = num_classes + self.alpha = alpha + self.beta = beta + self.eps = eps + + @torch.no_grad() + def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt): + """ + Compute the task-aligned assignment. 
Reference code is available at + https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py. + + Args: + pd_scores (Tensor): shape(bs, num_total_anchors, num_classes) + pd_bboxes (Tensor): shape(bs, num_total_anchors, 4) + anc_points (Tensor): shape(num_total_anchors, 2) + gt_labels (Tensor): shape(bs, n_max_boxes, 1) + gt_bboxes (Tensor): shape(bs, n_max_boxes, 4) + mask_gt (Tensor): shape(bs, n_max_boxes, 1) + + Returns: + target_labels (Tensor): shape(bs, num_total_anchors) + target_bboxes (Tensor): shape(bs, num_total_anchors, 4) + target_scores (Tensor): shape(bs, num_total_anchors, num_classes) + fg_mask (Tensor): shape(bs, num_total_anchors) + target_gt_idx (Tensor): shape(bs, num_total_anchors) + """ + self.bs = pd_scores.shape[0] + self.n_max_boxes = gt_bboxes.shape[1] + device = gt_bboxes.device + + if self.n_max_boxes == 0: + return ( + torch.full_like(pd_scores[..., 0], self.bg_idx), + torch.zeros_like(pd_bboxes), + torch.zeros_like(pd_scores), + torch.zeros_like(pd_scores[..., 0]), + torch.zeros_like(pd_scores[..., 0]), + ) + + try: + return self._forward(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt) + except torch.OutOfMemoryError: + # Move tensors to CPU, compute, then move back to original device + LOGGER.warning("WARNING: CUDA OutOfMemoryError in TaskAlignedAssigner, using CPU") + cpu_tensors = [t.cpu() for t in (pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)] + result = self._forward(*cpu_tensors) + return tuple(t.to(device) for t in result) + + def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt): + """ + Compute the task-aligned assignment. Reference code is available at + https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py. 
+ + Args: + pd_scores (Tensor): shape(bs, num_total_anchors, num_classes) + pd_bboxes (Tensor): shape(bs, num_total_anchors, 4) + anc_points (Tensor): shape(num_total_anchors, 2) + gt_labels (Tensor): shape(bs, n_max_boxes, 1) + gt_bboxes (Tensor): shape(bs, n_max_boxes, 4) + mask_gt (Tensor): shape(bs, n_max_boxes, 1) + + Returns: + target_labels (Tensor): shape(bs, num_total_anchors) + target_bboxes (Tensor): shape(bs, num_total_anchors, 4) + target_scores (Tensor): shape(bs, num_total_anchors, num_classes) + fg_mask (Tensor): shape(bs, num_total_anchors) + target_gt_idx (Tensor): shape(bs, num_total_anchors) + """ + mask_pos, align_metric, overlaps = self.get_pos_mask( + pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt + ) + + target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes) + + # Assigned target + target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask) + + # Normalize + align_metric *= mask_pos + pos_align_metrics = align_metric.amax(dim=-1, keepdim=True) # b, max_num_obj + pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True) # b, max_num_obj + norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1) + target_scores = target_scores * norm_align_metric + + return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx + + def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt): + """Get in_gts mask, (b, max_num_obj, h*w).""" + mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes) + # Get anchor_align metric, (b, max_num_obj, h*w) + align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt) + # Get topk_metric mask, (b, max_num_obj, h*w) + mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool()) + # Merge all mask to a final mask, (b, max_num_obj, h*w) + mask_pos = mask_topk * mask_in_gts * mask_gt + + return mask_pos, align_metric, overlaps + + def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt): + """Compute alignment metric given predicted and ground truth bounding boxes.""" + na = pd_bboxes.shape[-2] + mask_gt = mask_gt.bool() # b, max_num_obj, h*w + overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device) + bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device) + + ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long) # 2, b, max_num_obj + ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes) # b, max_num_obj + ind[1] = gt_labels.squeeze(-1) # b, max_num_obj + # Get the scores of each grid for each gt cls + bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt] # b, max_num_obj, h*w + + # (b, max_num_obj, 1, 4), (b, 1, h*w, 4) + pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt] + gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt] + overlaps[mask_gt] = self.iou_calculation(gt_boxes, pd_boxes) + + align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta) + return align_metric, overlaps + + def iou_calculation(self, gt_bboxes, pd_bboxes): + """IoU calculation for horizontal bounding boxes.""" + return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0) + + def select_topk_candidates(self, metrics, 
largest=True, topk_mask=None): + """ + Select the top-k candidates based on the given metrics. + + Args: + metrics (Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size, + max_num_obj is the maximum number of objects, and h*w represents the + total number of anchor points. + largest (bool): If True, select the largest values; otherwise, select the smallest values. + topk_mask (Tensor): An optional boolean tensor of shape (b, max_num_obj, topk), where + topk is the number of top candidates to consider. If not provided, + the top-k values are automatically computed based on the given metrics. + + Returns: + (Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates. + """ + # (b, max_num_obj, topk) + topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest) + if topk_mask is None: + topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs) + # (b, max_num_obj, topk) + topk_idxs.masked_fill_(~topk_mask, 0) + + # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w) + count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device) + ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device) + for k in range(self.topk): + # Expand topk_idxs for each value of k and add 1 at the specified positions + count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones) + # count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device)) + # Filter invalid bboxes + count_tensor.masked_fill_(count_tensor > 1, 0) + + return count_tensor.to(metrics.dtype) + + def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask): + """ + Compute target labels, target bounding boxes, and target scores for the positive anchor points. + + Args: + gt_labels (Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the + batch size and max_num_obj is the maximum number of objects. + gt_bboxes (Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4). + target_gt_idx (Tensor): Indices of the assigned ground truth objects for positive + anchor points, with shape (b, h*w), where h*w is the total + number of anchor points. + fg_mask (Tensor): A boolean tensor of shape (b, h*w) indicating the positive + (foreground) anchor points. + + Returns: + (Tuple[Tensor, Tensor, Tensor]): A tuple containing the following tensors: + - target_labels (Tensor): Shape (b, h*w), containing the target labels for + positive anchor points. + - target_bboxes (Tensor): Shape (b, h*w, 4), containing the target bounding boxes + for positive anchor points. + - target_scores (Tensor): Shape (b, h*w, num_classes), containing the target scores + for positive anchor points, where num_classes is the number + of object classes. 
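+
+        Examples:
+            An isolated sketch; `bs` and `n_max_boxes` are normally set inside forward(), so they are
+            assigned manually here and all tensor values are synthetic:
+            >>> import torch
+            >>> assigner = TaskAlignedAssigner(topk=10, num_classes=80)
+            >>> assigner.bs, assigner.n_max_boxes = 1, 2
+            >>> gt_labels = torch.tensor([[[0], [3]]])
+            >>> gt_bboxes = torch.rand(1, 2, 4)
+            >>> target_gt_idx = torch.zeros(1, 64, dtype=torch.long)
+            >>> fg_mask = torch.zeros(1, 64, dtype=torch.bool)
+            >>> labels, boxes, scores = assigner.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)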
+ """ + # Assigned target labels, (b, 1) + batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None] + target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes # (b, h*w) + target_labels = gt_labels.long().flatten()[target_gt_idx] # (b, h*w) + + # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4) + target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx] + + # Assigned target scores + target_labels.clamp_(0) + + # 10x faster than F.one_hot() + target_scores = torch.zeros( + (target_labels.shape[0], target_labels.shape[1], self.num_classes), + dtype=torch.int64, + device=target_labels.device, + ) # (b, h*w, 80) + target_scores.scatter_(2, target_labels.unsqueeze(-1), 1) + + fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) # (b, h*w, 80) + target_scores = torch.where(fg_scores_mask > 0, target_scores, 0) + + return target_labels, target_bboxes, target_scores + + @staticmethod + def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9): + """ + Select positive anchor centers within ground truth bounding boxes. + + Args: + xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2). + gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4). + eps (float, optional): Small value for numerical stability. Defaults to 1e-9. + + Returns: + (torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w). + + Note: + b: batch size, n_boxes: number of ground truth boxes, h: height, w: width. + Bounding box format: [x_min, y_min, x_max, y_max]. + """ + n_anchors = xy_centers.shape[0] + bs, n_boxes, _ = gt_bboxes.shape + lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom + bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1) + # return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype) + return bbox_deltas.amin(3).gt_(eps) + + @staticmethod + def select_highest_overlaps(mask_pos, overlaps, n_max_boxes): + """ + Select anchor boxes with highest IoU when assigned to multiple ground truths. + + Args: + mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w). + overlaps (torch.Tensor): IoU overlaps, shape (b, n_max_boxes, h*w). + n_max_boxes (int): Maximum number of ground truth boxes. + + Returns: + target_gt_idx (torch.Tensor): Indices of assigned ground truths, shape (b, h*w). + fg_mask (torch.Tensor): Foreground mask, shape (b, h*w). + mask_pos (torch.Tensor): Updated positive mask, shape (b, n_max_boxes, h*w). + + Note: + b: batch size, h: height, w: width. 
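+
+        Example:
+            An illustrative resolution of one anchor matched by two ground truths (hypothetical values;
+            callable directly since this is a static method):
+
+            ```python
+            import torch
+
+            mask_pos = torch.tensor([[[1.0, 0.0, 1.0], [1.0, 0.0, 0.0]]])  # (b=1, n_max_boxes=2, h*w=3)
+            overlaps = torch.tensor([[[0.7, 0.0, 0.6], [0.9, 0.0, 0.0]]])
+            idx, fg, mask = TaskAlignedAssigner.select_highest_overlaps(mask_pos, overlaps, n_max_boxes=2)
+            # idx -> [[1, 0, 0]], fg -> [[1., 0., 1.]]: anchor 0 is kept only for GT 1 (higher IoU)
+            ```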
+ """ + # Convert (b, n_max_boxes, h*w) -> (b, h*w) + fg_mask = mask_pos.sum(-2) + if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes + mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1) # (b, n_max_boxes, h*w) + max_overlaps_idx = overlaps.argmax(1) # (b, h*w) + + is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device) + is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1) + + mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float() # (b, n_max_boxes, h*w) + fg_mask = mask_pos.sum(-2) + # Find each grid serve which gt(index) + target_gt_idx = mask_pos.argmax(-2) # (b, h*w) + return target_gt_idx, fg_mask, mask_pos + + +class RotatedTaskAlignedAssigner(TaskAlignedAssigner): + """Assigns ground-truth objects to rotated bounding boxes using a task-aligned metric.""" + + def iou_calculation(self, gt_bboxes, pd_bboxes): + """IoU calculation for rotated bounding boxes.""" + return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0) + + @staticmethod + def select_candidates_in_gts(xy_centers, gt_bboxes): + """ + Select the positive anchor center in gt for rotated bounding boxes. + + Args: + xy_centers (Tensor): shape(h*w, 2) + gt_bboxes (Tensor): shape(b, n_boxes, 5) + + Returns: + (Tensor): shape(b, n_boxes, h*w) + """ + # (b, n_boxes, 5) --> (b, n_boxes, 4, 2) + corners = xywhr2xyxyxyxy(gt_bboxes) + # (b, n_boxes, 1, 2) + a, b, _, d = corners.split(1, dim=-2) + ab = b - a + ad = d - a + + # (b, n_boxes, h*w, 2) + ap = xy_centers - a + norm_ab = (ab * ab).sum(dim=-1) + norm_ad = (ad * ad).sum(dim=-1) + ap_dot_ab = (ap * ab).sum(dim=-1) + ap_dot_ad = (ap * ad).sum(dim=-1) + return (ap_dot_ab >= 0) & (ap_dot_ab <= norm_ab) & (ap_dot_ad >= 0) & (ap_dot_ad <= norm_ad) # is_in_box + + +def make_anchors(feats, strides, grid_cell_offset=0.5): + """Generate anchors from features.""" + anchor_points, stride_tensor = [], [] + assert feats is not None + dtype, device = feats[0].dtype, feats[0].device + for i, stride in enumerate(strides): + h, w = feats[i].shape[2:] if isinstance(feats, list) else (int(feats[i][0]), int(feats[i][1])) + sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x + sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y + sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx) + anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2)) + stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device)) + return torch.cat(anchor_points), torch.cat(stride_tensor) + + +def dist2bbox(distance, anchor_points, xywh=True, dim=-1): + """Transform distance(ltrb) to box(xywh or xyxy).""" + lt, rb = distance.chunk(2, dim) + x1y1 = anchor_points - lt + x2y2 = anchor_points + rb + if xywh: + c_xy = (x1y1 + x2y2) / 2 + wh = x2y2 - x1y1 + return torch.cat((c_xy, wh), dim) # xywh bbox + return torch.cat((x1y1, x2y2), dim) # xyxy bbox + + +def bbox2dist(anchor_points, bbox, reg_max): + """Transform bbox(xyxy) to dist(ltrb).""" + x1y1, x2y2 = bbox.chunk(2, -1) + return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01) # dist (lt, rb) + + +def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1): + """ + Decode predicted rotated bounding box coordinates from anchor points and distribution. + + Args: + pred_dist (torch.Tensor): Predicted rotated distance, shape (bs, h*w, 4). + pred_angle (torch.Tensor): Predicted angle, shape (bs, h*w, 1). 
+ anchor_points (torch.Tensor): Anchor points, shape (h*w, 2). + dim (int, optional): Dimension along which to split. Defaults to -1. + + Returns: + (torch.Tensor): Predicted rotated bounding boxes, shape (bs, h*w, 4). + """ + lt, rb = pred_dist.split(2, dim=dim) + cos, sin = torch.cos(pred_angle), torch.sin(pred_angle) + # (bs, h*w, 1) + xf, yf = ((rb - lt) / 2).split(1, dim=dim) + x, y = xf * cos - yf * sin, xf * sin + yf * cos + xy = torch.cat([x, y], dim=dim) + anchor_points + return torch.cat([xy, lt + rb], dim=dim) diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e1cd9de8d9bf5413d359a4bcae1cc7eb76aca3d2 --- /dev/null +++ b/ultralytics/utils/torch_utils.py @@ -0,0 +1,801 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +import gc +import math +import os +import random +import time +from contextlib import contextmanager +from copy import deepcopy +from datetime import datetime +from pathlib import Path +from typing import Union + +import numpy as np +import thop +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + +from ultralytics.utils import ( + DEFAULT_CFG_DICT, + DEFAULT_CFG_KEYS, + LOGGER, + NUM_THREADS, + PYTHON_VERSION, + TORCHVISION_VERSION, + WINDOWS, + __version__, + colorstr, +) +from ultralytics.utils.checks import check_version + +# Version checks (all default to version>=min_version) +TORCH_1_9 = check_version(torch.__version__, "1.9.0") +TORCH_1_13 = check_version(torch.__version__, "1.13.0") +TORCH_2_0 = check_version(torch.__version__, "2.0.0") +TORCH_2_4 = check_version(torch.__version__, "2.4.0") +TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0") +TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0") +TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0") +TORCHVISION_0_18 = check_version(TORCHVISION_VERSION, "0.18.0") +if WINDOWS and check_version(torch.__version__, "==2.4.0"): # reject version 2.4.0 on Windows + LOGGER.warning( + "WARNING ⚠️ Known issue with torch==2.4.0 on Windows with CPU, recommend upgrading to torch>=2.4.1 to resolve " + "https://github.com/ultralytics/ultralytics/issues/15049" + ) + + +@contextmanager +def torch_distributed_zero_first(local_rank: int): + """Ensures all processes in distributed training wait for the local master (rank 0) to complete a task first.""" + initialized = dist.is_available() and dist.is_initialized() + + if initialized and local_rank not in {-1, 0}: + dist.barrier(device_ids=[local_rank]) + yield + if initialized and local_rank == 0: + dist.barrier(device_ids=[local_rank]) + + +def smart_inference_mode(): + """Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator.""" + + def decorate(fn): + """Applies appropriate torch decorator for inference mode based on torch version.""" + if TORCH_1_9 and torch.is_inference_mode_enabled(): + return fn # already in inference_mode, act as a pass-through + else: + return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn) + + return decorate + + +def autocast(enabled: bool, device: str = "cuda"): + """ + Get the appropriate autocast context manager based on PyTorch version and AMP setting. + + This function returns a context manager for automatic mixed precision (AMP) training that is compatible with both + older and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions. 
+ + Args: + enabled (bool): Whether to enable automatic mixed precision. + device (str, optional): The device to use for autocast. Defaults to 'cuda'. + + Returns: + (torch.amp.autocast): The appropriate autocast context manager. + + Note: + - For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`. + - For older versions, it uses `torch.cuda.autocast`. + + Example: + ```python + with autocast(amp=True): + # Your mixed precision operations here + pass + ``` + """ + if TORCH_1_13: + return torch.amp.autocast(device, enabled=enabled) + else: + return torch.cuda.amp.autocast(enabled) + + +def get_cpu_info(): + """Return a string with system CPU information, i.e. 'Apple M2'.""" + from ultralytics.utils import PERSISTENT_CACHE # avoid circular import error + + if "cpu_info" not in PERSISTENT_CACHE: + try: + import cpuinfo # pip install py-cpuinfo + + k = "brand_raw", "hardware_raw", "arch_string_raw" # keys sorted by preference + info = cpuinfo.get_cpu_info() # info dict + string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown") + PERSISTENT_CACHE["cpu_info"] = string.replace("(R)", "").replace("CPU ", "").replace("@ ", "") + except Exception: + pass + return PERSISTENT_CACHE.get("cpu_info", "unknown") + + +def get_gpu_info(index): + """Return a string with system GPU information, i.e. 'Tesla T4, 15102MiB'.""" + properties = torch.cuda.get_device_properties(index) + return f"{properties.name}, {properties.total_memory / (1 << 20):.0f}MiB" + + +def select_device(device="", batch=0, newline=False, verbose=True): + """ + Selects the appropriate PyTorch device based on the provided arguments. + + The function takes a string specifying the device or a torch.device object and returns a torch.device object + representing the selected device. The function also validates the number of available devices and raises an + exception if the requested device(s) are not available. + + Args: + device (str | torch.device, optional): Device string or torch.device object. + Options are 'None', 'cpu', or 'cuda', or '0' or '0,1,2,3'. Defaults to an empty string, which auto-selects + the first available GPU, or CPU if no GPU is available. + batch (int, optional): Batch size being used in your model. Defaults to 0. + newline (bool, optional): If True, adds a newline at the end of the log string. Defaults to False. + verbose (bool, optional): If True, logs the device information. Defaults to True. + + Returns: + (torch.device): Selected device. + + Raises: + ValueError: If the specified device is not available or if the batch size is not a multiple of the number of + devices when using multiple GPUs. + + Examples: + >>> select_device("cuda:0") + device(type='cuda', index=0) + + >>> select_device("cpu") + device(type='cpu') + + Note: + Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use. 
+ """ + if isinstance(device, torch.device) or str(device).startswith("tpu"): + return device + + s = f"Ultralytics {__version__} 🚀 Python-{PYTHON_VERSION} torch-{torch.__version__} " + device = str(device).lower() + for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ": + device = device.replace(remove, "") # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1' + cpu = device == "cpu" + mps = device in {"mps", "mps:0"} # Apple Metal Performance Shaders (MPS) + if cpu or mps: + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # force torch.cuda.is_available() = False + elif device: # non-cpu device requested + if device == "cuda": + device = "0" + if "," in device: + device = ",".join([x for x in device.split(",") if x]) # remove sequential commas, i.e. "0,,1" -> "0,1" + visible = os.environ.get("CUDA_VISIBLE_DEVICES", None) + os.environ["CUDA_VISIBLE_DEVICES"] = device # set environment variable - must be before assert is_available() + if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(","))): + LOGGER.info(s) + install = ( + "See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no " + "CUDA devices are seen by torch.\n" + if torch.cuda.device_count() == 0 + else "" + ) + raise ValueError( + f"Invalid CUDA 'device={device}' requested." + f" Use 'device=cpu' or pass valid CUDA device(s) if available," + f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n" + f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}" + f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}" + f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n" + f"{install}" + ) + + if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available + devices = device.split(",") if device else "0" # i.e. "0,1" -> ["0", "1"] + n = len(devices) # device count + if n > 1: # multi-GPU + if batch < 1: + raise ValueError( + "AutoBatch with batch<1 not supported for Multi-GPU training, " + "please specify a valid batch size, i.e. batch=16." + ) + if batch >= 0 and batch % n != 0: # check batch_size is divisible by device_count + raise ValueError( + f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or " + f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}." 
+ ) + space = " " * (len(s) + 1) + for i, d in enumerate(devices): + s += f"{'' if i == 0 else space}CUDA:{d} ({get_gpu_info(i)})\n" # bytes to MB + arg = "cuda:0" + elif mps and TORCH_2_0 and torch.backends.mps.is_available(): + # Prefer MPS if available + s += f"MPS ({get_cpu_info()})\n" + arg = "mps" + else: # revert to CPU + s += f"CPU ({get_cpu_info()})\n" + arg = "cpu" + + if arg in {"cpu", "mps"}: + torch.set_num_threads(NUM_THREADS) # reset OMP_NUM_THREADS for cpu training + if verbose: + LOGGER.info(s if newline else s.rstrip()) + return torch.device(arg) + + +def time_sync(): + """PyTorch-accurate time.""" + if torch.cuda.is_available(): + torch.cuda.synchronize() + return time.time() + + +def fuse_conv_and_bn(conv, bn): + """Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/.""" + fusedconv = ( + nn.Conv2d( + conv.in_channels, + conv.out_channels, + kernel_size=conv.kernel_size, + stride=conv.stride, + padding=conv.padding, + dilation=conv.dilation, + groups=conv.groups, + bias=True, + ) + .requires_grad_(False) + .to(conv.weight.device) + ) + + # Prepare filters + w_conv = conv.weight.view(conv.out_channels, -1) + w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) + fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape)) + + # Prepare spatial bias + b_conv = torch.zeros(conv.weight.shape[0], device=conv.weight.device) if conv.bias is None else conv.bias + b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) + fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) + + return fusedconv + + +def fuse_deconv_and_bn(deconv, bn): + """Fuse ConvTranspose2d() and BatchNorm2d() layers.""" + fuseddconv = ( + nn.ConvTranspose2d( + deconv.in_channels, + deconv.out_channels, + kernel_size=deconv.kernel_size, + stride=deconv.stride, + padding=deconv.padding, + output_padding=deconv.output_padding, + dilation=deconv.dilation, + groups=deconv.groups, + bias=True, + ) + .requires_grad_(False) + .to(deconv.weight.device) + ) + + # Prepare filters + w_deconv = deconv.weight.view(deconv.out_channels, -1) + w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) + fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape)) + + # Prepare spatial bias + b_conv = torch.zeros(deconv.weight.shape[1], device=deconv.weight.device) if deconv.bias is None else deconv.bias + b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) + fuseddconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) + + return fuseddconv + + +def model_info(model, detailed=False, verbose=True, imgsz=640): + """Print and return detailed model information layer by layer.""" + if not verbose: + return + n_p = get_num_params(model) # number of parameters + n_g = get_num_gradients(model) # number of gradients + n_l = len(list(model.modules())) # number of layers + if detailed: + LOGGER.info(f"{'layer':>5}{'name':>40}{'gradient':>10}{'parameters':>12}{'shape':>20}{'mu':>10}{'sigma':>10}") + for i, (name, p) in enumerate(model.named_parameters()): + name = name.replace("module_list.", "") + LOGGER.info( + f"{i:>5g}{name:>40s}{p.requires_grad!r:>10}{p.numel():>12g}{str(list(p.shape)):>20s}" + f"{p.mean():>10.3g}{p.std():>10.3g}{str(p.dtype):>15s}" + ) + + flops = get_flops(model, imgsz) # imgsz may be int or list, i.e. 
imgsz=640 or imgsz=[640, 320] + fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else "" + fs = f", {flops:.1f} GFLOPs" if flops else "" + yaml_file = getattr(model, "yaml_file", "") or getattr(model, "yaml", {}).get("yaml_file", "") + model_name = Path(yaml_file).stem.replace("yolo", "YOLO") or "Model" + LOGGER.info(f"{model_name} summary{fused}: {n_l:,} layers, {n_p:,} parameters, {n_g:,} gradients{fs}") + return n_l, n_p, n_g, flops + + +def get_num_params(model): + """Return the total number of parameters in a YOLO model.""" + return sum(x.numel() for x in model.parameters()) + + +def get_num_gradients(model): + """Return the total number of parameters with gradients in a YOLO model.""" + return sum(x.numel() for x in model.parameters() if x.requires_grad) + + +def model_info_for_loggers(trainer): + """ + Return model info dict with useful model information. + + Example: + YOLOv8n info for loggers + ```python + results = { + "model/parameters": 3151904, + "model/GFLOPs": 8.746, + "model/speed_ONNX(ms)": 41.244, + "model/speed_TensorRT(ms)": 3.211, + "model/speed_PyTorch(ms)": 18.755, + } + ``` + """ + if trainer.args.profile: # profile ONNX and TensorRT times + from ultralytics.utils.benchmarks import ProfileModels + + results = ProfileModels([trainer.last], device=trainer.device).profile()[0] + results.pop("model/name") + else: # only return PyTorch times from most recent validation + results = { + "model/parameters": get_num_params(trainer.model), + "model/GFLOPs": round(get_flops(trainer.model), 3), + } + results["model/speed_PyTorch(ms)"] = round(trainer.validator.speed["inference"], 3) + return results + + +def get_flops(model, imgsz=640): + """Return a YOLO model's FLOPs.""" + try: + model = de_parallel(model) + p = next(model.parameters()) + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] # expand if int/float + try: + # Use stride size for input tensor + stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32 # max stride + im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format + flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # stride GFLOPs + return flops * imgsz[0] / stride * imgsz[1] / stride # imgsz GFLOPs + except Exception: + # Use actual image size for input tensor (i.e. required for RTDETR models) + im = torch.empty((1, p.shape[1], *imgsz), device=p.device) # input image in BCHW format + return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # imgsz GFLOPs + except Exception: + return 0.0 + + +def get_flops_with_torch_profiler(model, imgsz=640): + """Compute model FLOPs (thop package alternative, but 2-10x slower unfortunately).""" + if not TORCH_2_0: # torch profiler implemented in torch>=2.0 + return 0.0 + model = de_parallel(model) + p = next(model.parameters()) + if not isinstance(imgsz, list): + imgsz = [imgsz, imgsz] # expand if int/float + try: + # Use stride size for input tensor + stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2 # max stride + im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format + with torch.profiler.profile(with_flops=True) as prof: + model(im) + flops = sum(x.flops for x in prof.key_averages()) / 1e9 + flops = flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs + except Exception: + # Use actual image size for input tensor (i.e. 
required for RTDETR models) + im = torch.empty((1, p.shape[1], *imgsz), device=p.device) # input image in BCHW format + with torch.profiler.profile(with_flops=True) as prof: + model(im) + flops = sum(x.flops for x in prof.key_averages()) / 1e9 + return flops + + +def initialize_weights(model): + """Initialize model weights to random values.""" + for m in model.modules(): + t = type(m) + if t is nn.Conv2d: + pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif t is nn.BatchNorm2d: + m.eps = 1e-3 + m.momentum = 0.03 + elif t in {nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU}: + m.inplace = True + + +def scale_img(img, ratio=1.0, same_shape=False, gs=32): + """Scales and pads an image tensor, optionally maintaining aspect ratio and padding to gs multiple.""" + if ratio == 1.0: + return img + h, w = img.shape[2:] + s = (int(h * ratio), int(w * ratio)) # new size + img = F.interpolate(img, size=s, mode="bilinear", align_corners=False) # resize + if not same_shape: # pad/crop img + h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w)) + return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean + + +def copy_attr(a, b, include=(), exclude=()): + """Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes.""" + for k, v in b.__dict__.items(): + if (len(include) and k not in include) or k.startswith("_") or k in exclude: + continue + else: + setattr(a, k, v) + + +def get_latest_opset(): + """Return the second-most recent ONNX opset version supported by this version of PyTorch, adjusted for maturity.""" + if TORCH_1_13: + # If the PyTorch>=1.13, dynamically compute the latest opset minus one using 'symbolic_opset' + return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1 + # Otherwise for PyTorch<=1.12 return the corresponding predefined opset + version = torch.onnx.producer_version.rsplit(".", 1)[0] # i.e. 
'2.3' + return {"1.12": 15, "1.11": 14, "1.10": 13, "1.9": 12, "1.8": 12}.get(version, 12) + + +def intersect_dicts(da, db, exclude=()): + """Returns a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values.""" + return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape} + + +def is_parallel(model): + """Returns True if model is of type DP or DDP.""" + return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)) + + +def de_parallel(model): + """De-parallelize a model: returns single-GPU model if model is of type DP or DDP.""" + return model.module if is_parallel(model) else model + + +def one_cycle(y1=0.0, y2=1.0, steps=100): + """Returns a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf.""" + return lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + y1 + + +def init_seeds(seed=0, deterministic=False): + """Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe + # torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287 + if deterministic: + if TORCH_2_0: + torch.use_deterministic_algorithms(True, warn_only=True) # warn if deterministic is not possible + torch.backends.cudnn.deterministic = True + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + os.environ["PYTHONHASHSEED"] = str(seed) + else: + LOGGER.warning("WARNING ⚠️ Upgrade to torch>=2.0.0 for deterministic training.") + else: + torch.use_deterministic_algorithms(False) + torch.backends.cudnn.deterministic = False + + +class ModelEMA: + """ + Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models. Keeps a moving + average of everything in the model state_dict (parameters and buffers). + + For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage + + To disable EMA set the `enabled` attribute to `False`. + """ + + def __init__(self, model, decay=0.9999, tau=2000, updates=0): + """Initialize EMA for 'model' with given arguments.""" + self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA + self.updates = updates # number of EMA updates + self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs) + for p in self.ema.parameters(): + p.requires_grad_(False) + self.enabled = True + + def update(self, model): + """Update EMA parameters.""" + if self.enabled: + self.updates += 1 + d = self.decay(self.updates) + + msd = de_parallel(model).state_dict() # model state_dict + for k, v in self.ema.state_dict().items(): + if v.dtype.is_floating_point: # true for FP16 and FP32 + v *= d + v += (1 - d) * msd[k].detach() + # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}' + + def update_attr(self, model, include=(), exclude=("process_group", "reducer")): + """Updates attributes and saves stripped model with optimizer removed.""" + if self.enabled: + copy_attr(self.ema, model, include, exclude) + + +def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict = None) -> dict: + """ + Strip optimizer from 'f' to finalize training, optionally save as 's'. 
+ + Args: + f (str): file path to model to strip the optimizer from. Default is 'best.pt'. + s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten. + updates (dict): a dictionary of updates to overlay onto the checkpoint before saving. + + Returns: + (dict): The combined checkpoint dictionary. + + Example: + ```python + from pathlib import Path + from ultralytics.utils.torch_utils import strip_optimizer + + for f in Path("path/to/model/checkpoints").rglob("*.pt"): + strip_optimizer(f) + ``` + + Note: + Use `ultralytics.nn.torch_safe_load` for missing modules with `x = torch_safe_load(f)[0]` + """ + try: + x = torch.load(f, map_location=torch.device("cpu")) + assert isinstance(x, dict), "checkpoint is not a Python dictionary" + assert "model" in x, "'model' missing from checkpoint" + except Exception as e: + LOGGER.warning(f"WARNING ⚠️ Skipping {f}, not a valid Ultralytics model: {e}") + return {} + + metadata = { + "date": datetime.now().isoformat(), + "version": __version__, + "license": "AGPL-3.0 License (https://ultralytics.com/license)", + "docs": "https://docs.ultralytics.com", + } + + # Update model + if x.get("ema"): + x["model"] = x["ema"] # replace model with EMA + if hasattr(x["model"], "args"): + x["model"].args = dict(x["model"].args) # convert from IterableSimpleNamespace to dict + if hasattr(x["model"], "criterion"): + x["model"].criterion = None # strip loss criterion + x["model"].half() # to FP16 + for p in x["model"].parameters(): + p.requires_grad = False + + # Update other keys + args = {**DEFAULT_CFG_DICT, **x.get("train_args", {})} # combine args + for k in "optimizer", "best_fitness", "ema", "updates": # keys + x[k] = None + x["epoch"] = -1 + x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys + # x['model'].args = x['train_args'] + + # Save + combined = {**metadata, **x, **(updates or {})} + torch.save(combined, s or f) # combine dicts (prefer to the right) + mb = os.path.getsize(s or f) / 1e6 # file size + LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB") + return combined + + +def convert_optimizer_state_dict_to_fp16(state_dict): + """ + Converts the state_dict of a given optimizer to FP16, focusing on the 'state' key for tensor conversions. + + This method aims to reduce storage size without altering 'param_groups' as they contain non-tensor data. + """ + for state in state_dict["state"].values(): + for k, v in state.items(): + if k != "step" and isinstance(v, torch.Tensor) and v.dtype is torch.float32: + state[k] = v.half() + + return state_dict + + +@contextmanager +def cuda_memory_usage(device=None): + """ + Monitor and manage CUDA memory usage. + + This function checks if CUDA is available and, if so, empties the CUDA cache to free up unused memory. + It then yields a dictionary containing memory usage information, which can be updated by the caller. + Finally, it updates the dictionary with the amount of memory reserved by CUDA on the specified device. + + Args: + device (torch.device, optional): The CUDA device to query memory usage for. Defaults to None. + + Yields: + (dict): A dictionary with a key 'memory' initialized to 0, which will be updated with the reserved memory. 
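+
+    Example:
+        A minimal sketch of the intended usage (mirrors how `profile` below wraps forward passes;
+        `model` and `x` are hypothetical):
+
+        ```python
+        with cuda_memory_usage() as cuda_info:
+            y = model(x)
+        print(f"reserved: {cuda_info['memory'] / 1e9:.3f} GB")  # stays 0 on CPU-only machines
+        ```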
+ """ + cuda_info = dict(memory=0) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + try: + yield cuda_info + finally: + cuda_info["memory"] = torch.cuda.memory_reserved(device) + else: + yield cuda_info + + +def profile(input, ops, n=10, device=None, max_num_obj=0): + """ + Ultralytics speed, memory and FLOPs profiler. + + Example: + ```python + from ultralytics.utils.torch_utils import profile + + input = torch.randn(16, 3, 640, 640) + m1 = lambda x: x * torch.sigmoid(x) + m2 = nn.SiLU() + profile(input, [m1, m2], n=100) # profile over 100 iterations + ``` + """ + results = [] + if not isinstance(device, torch.device): + device = select_device(device) + LOGGER.info( + f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}" + f"{'input':>24s}{'output':>24s}" + ) + gc.collect() # attempt to free unused memory + torch.cuda.empty_cache() + for x in input if isinstance(input, list) else [input]: + x = x.to(device) + x.requires_grad = True + for m in ops if isinstance(ops, list) else [ops]: + m = m.to(device) if hasattr(m, "to") else m # device + m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m + tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward + try: + flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2 # GFLOPs + except Exception: + flops = 0 + + try: + mem = 0 + for _ in range(n): + with cuda_memory_usage(device) as cuda_info: + t[0] = time_sync() + y = m(x) + t[1] = time_sync() + try: + (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward() + t[2] = time_sync() + except Exception: # no backward method + # print(e) # for debug + t[2] = float("nan") + mem += cuda_info["memory"] / 1e9 # (GB) + tf += (t[1] - t[0]) * 1000 / n # ms per op forward + tb += (t[2] - t[1]) * 1000 / n # ms per op backward + if max_num_obj: # simulate training with predictions per image grid (for AutoBatch) + with cuda_memory_usage(device) as cuda_info: + torch.randn( + x.shape[0], + max_num_obj, + int(sum((x.shape[-1] / s) * (x.shape[-2] / s) for s in m.stride.tolist())), + device=device, + dtype=torch.float32, + ) + mem += cuda_info["memory"] / 1e9 # (GB) + s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y)) # shapes + p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters + LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}") + results.append([p, flops, mem, tf, tb, s_in, s_out]) + except Exception as e: + LOGGER.info(e) + results.append(None) + finally: + gc.collect() # attempt to free unused memory + torch.cuda.empty_cache() + return results + + +class EarlyStopping: + """Early stopping class that stops training when a specified number of epochs have passed without improvement.""" + + def __init__(self, patience=50): + """ + Initialize early stopping object. + + Args: + patience (int, optional): Number of epochs to wait after fitness stops improving before stopping. + """ + self.best_fitness = 0.0 # i.e. mAP + self.best_epoch = 0 + self.patience = patience or float("inf") # epochs to wait after fitness stops improving to stop + self.possible_stop = False # possible stop may occur next epoch + + def __call__(self, epoch, fitness): + """ + Check whether to stop training. 
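+
+        Example:
+            A minimal sketch of the intended training-loop usage (`validate` is a hypothetical
+            function returning a fitness value such as mAP):
+
+            ```python
+            stopper = EarlyStopping(patience=50)
+            for epoch in range(300):
+                fitness = validate()
+                if stopper(epoch, fitness):
+                    break
+            ```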
+ + Args: + epoch (int): Current epoch of training + fitness (float): Fitness value of current epoch + + Returns: + (bool): True if training should stop, False otherwise + """ + if fitness is None: # check if fitness=None (happens when val=False) + return False + + if fitness >= self.best_fitness: # >= 0 to allow for early zero-fitness stage of training + self.best_epoch = epoch + self.best_fitness = fitness + delta = epoch - self.best_epoch # epochs without improvement + self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch + stop = delta >= self.patience # stop training if patience exceeded + if stop: + prefix = colorstr("EarlyStopping: ") + LOGGER.info( + f"{prefix}Training stopped early as no improvement observed in last {self.patience} epochs. " + f"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n" + f"To update EarlyStopping(patience={self.patience}) pass a new patience value, " + f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping." + ) + return stop + + +class FXModel(nn.Module): + """ + A custom model class for torch.fx compatibility. + + This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph manipulation. + It copies attributes from an existing model and explicitly sets the model attribute to ensure proper copying. + + Args: + model (torch.nn.Module): The original model to wrap for torch.fx compatibility. + """ + + def __init__(self, model): + """ + Initialize the FXModel. + + Args: + model (torch.nn.Module): The original model to wrap for torch.fx compatibility. + """ + super().__init__() + copy_attr(self, model) + # Explicitly set `model` since `copy_attr` somehow does not copy it. + self.model = model.model + + def forward(self, x): + """ + Forward pass through the model. + + This method performs the forward pass through the model, handling the dependencies between layers and saving intermediate outputs. + + Args: + x (torch.Tensor): The input tensor to the model. + + Returns: + (torch.Tensor): The output tensor from the model. + """ + y = [] # outputs + for m in self.model: + if m.f != -1: # if not from previous layer + # from earlier layers + x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] + x = m(x) # run + y.append(x) # save output + return x diff --git a/ultralytics/utils/triton.py b/ultralytics/utils/triton.py new file mode 100644 index 0000000000000000000000000000000000000000..e8b97d89f07210e1892e6127cbd34d84e61047ad --- /dev/null +++ b/ultralytics/utils/triton.py @@ -0,0 +1,93 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from typing import List +from urllib.parse import urlsplit + +import numpy as np + + +class TritonRemoteModel: + """ + Client for interacting with a remote Triton Inference Server model. + + Attributes: + endpoint (str): The name of the model on the Triton server. + url (str): The URL of the Triton server. + triton_client: The Triton client (either HTTP or gRPC). + InferInput: The input class for the Triton client. + InferRequestedOutput: The output request class for the Triton client. + input_formats (List[str]): The data types of the model inputs. + np_input_formats (List[type]): The numpy data types of the model inputs. + input_names (List[str]): The names of the model inputs. + output_names (List[str]): The names of the model outputs. + """ + + def __init__(self, url: str, endpoint: str = "", scheme: str = ""): + """ + Initialize the TritonRemoteModel. 
+ + Arguments may be provided individually or parsed from a collective 'url' argument of the form + ://// + + Args: + url (str): The URL of the Triton server. + endpoint (str): The name of the model on the Triton server. + scheme (str): The communication scheme ('http' or 'grpc'). + """ + if not endpoint and not scheme: # Parse all args from URL string + splits = urlsplit(url) + endpoint = splits.path.strip("/").split("/")[0] + scheme = splits.scheme + url = splits.netloc + + self.endpoint = endpoint + self.url = url + + # Choose the Triton client based on the communication scheme + if scheme == "http": + import tritonclient.http as client # noqa + + self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False) + config = self.triton_client.get_model_config(endpoint) + else: + import tritonclient.grpc as client # noqa + + self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False) + config = self.triton_client.get_model_config(endpoint, as_json=True)["config"] + + # Sort output names alphabetically, i.e. 'output0', 'output1', etc. + config["output"] = sorted(config["output"], key=lambda x: x.get("name")) + + # Define model attributes + type_map = {"TYPE_FP32": np.float32, "TYPE_FP16": np.float16, "TYPE_UINT8": np.uint8} + self.InferRequestedOutput = client.InferRequestedOutput + self.InferInput = client.InferInput + self.input_formats = [x["data_type"] for x in config["input"]] + self.np_input_formats = [type_map[x] for x in self.input_formats] + self.input_names = [x["name"] for x in config["input"]] + self.output_names = [x["name"] for x in config["output"]] + self.metadata = eval(config.get("parameters", {}).get("metadata", {}).get("string_value", "None")) + + def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]: + """ + Call the model with the given inputs. + + Args: + *inputs (List[np.ndarray]): Input data to the model. + + Returns: + (List[np.ndarray]): Model outputs. + """ + infer_inputs = [] + input_format = inputs[0].dtype + for i, x in enumerate(inputs): + if x.dtype != self.np_input_formats[i]: + x = x.astype(self.np_input_formats[i]) + infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace("TYPE_", "")) + infer_input.set_data_from_numpy(x) + infer_inputs.append(infer_input) + + infer_outputs = [self.InferRequestedOutput(output_name) for output_name in self.output_names] + outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs) + + return [outputs.as_numpy(output_name).astype(input_format) for output_name in self.output_names] diff --git a/ultralytics/utils/tuner.py b/ultralytics/utils/tuner.py new file mode 100644 index 0000000000000000000000000000000000000000..10a77d56ea8d4b4edbbee7e1362a38d1d9140b2d --- /dev/null +++ b/ultralytics/utils/tuner.py @@ -0,0 +1,157 @@ +# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license + +from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_cfg, get_save_dir +from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks + + +def run_ray_tune( + model, + space: dict = None, + grace_period: int = 10, + gpu_per_trial: int = None, + max_samples: int = 10, + **train_args, +): + """ + Runs hyperparameter tuning using Ray Tune. + + Args: + model (YOLO): Model to run the tuner on. + space (dict, optional): The hyperparameter search space. Defaults to None. + grace_period (int, optional): The grace period in epochs of the ASHA scheduler. Defaults to 10. 
+ gpu_per_trial (int, optional): The number of GPUs to allocate per trial. Defaults to None. + max_samples (int, optional): The maximum number of trials to run. Defaults to 10. + train_args (dict, optional): Additional arguments to pass to the `train()` method. Defaults to {}. + + Returns: + (dict): A dictionary containing the results of the hyperparameter search. + + Example: + ```python + from ultralytics import YOLO + + # Load a YOLOv8n model + model = YOLO("yolo11n.pt") + + # Start tuning hyperparameters for YOLOv8n training on the COCO8 dataset + result_grid = model.tune(data="coco8.yaml", use_ray=True) + ``` + """ + LOGGER.info("💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune") + if train_args is None: + train_args = {} + + try: + checks.check_requirements("ray[tune]") + + import ray + from ray import tune + from ray.air import RunConfig + from ray.air.integrations.wandb import WandbLoggerCallback + from ray.tune.schedulers import ASHAScheduler + except ImportError: + raise ModuleNotFoundError('Ray Tune required but not found. To install run: pip install "ray[tune]"') + + try: + import wandb + + assert hasattr(wandb, "__version__") + except (ImportError, AssertionError): + wandb = False + + checks.check_version(ray.__version__, ">=2.0.0", "ray") + default_space = { + # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']), + "lr0": tune.uniform(1e-5, 1e-1), + "lrf": tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf) + "momentum": tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1 + "weight_decay": tune.uniform(0.0, 0.001), # optimizer weight decay 5e-4 + "warmup_epochs": tune.uniform(0.0, 5.0), # warmup epochs (fractions ok) + "warmup_momentum": tune.uniform(0.0, 0.95), # warmup initial momentum + "box": tune.uniform(0.02, 0.2), # box loss gain + "cls": tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels) + "hsv_h": tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction) + "hsv_s": tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction) + "hsv_v": tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction) + "degrees": tune.uniform(0.0, 45.0), # image rotation (+/- deg) + "translate": tune.uniform(0.0, 0.9), # image translation (+/- fraction) + "scale": tune.uniform(0.0, 0.9), # image scale (+/- gain) + "shear": tune.uniform(0.0, 10.0), # image shear (+/- deg) + "perspective": tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001 + "flipud": tune.uniform(0.0, 1.0), # image flip up-down (probability) + "fliplr": tune.uniform(0.0, 1.0), # image flip left-right (probability) + "bgr": tune.uniform(0.0, 1.0), # image channel BGR (probability) + "mosaic": tune.uniform(0.0, 1.0), # image mixup (probability) + "mixup": tune.uniform(0.0, 1.0), # image mixup (probability) + "copy_paste": tune.uniform(0.0, 1.0), # segment copy-paste (probability) + } + + # Put the model in ray store + task = model.task + model_in_store = ray.put(model) + + def _tune(config): + """ + Trains the YOLO model with the specified hyperparameters and additional arguments. + + Args: + config (dict): A dictionary of hyperparameters to use for training. 
+ + Returns: + None + """ + model_to_train = ray.get(model_in_store) # get the model from ray store for tuning + model_to_train.reset_callbacks() + config.update(train_args) + results = model_to_train.train(**config) + return results.results_dict + + # Get search space + if not space: + space = default_space + LOGGER.warning("WARNING ⚠️ search space not provided, using default search space.") + + # Get dataset + data = train_args.get("data", TASK2DATA[task]) + space["data"] = data + if "data" not in train_args: + LOGGER.warning(f'WARNING ⚠️ data not provided, using default "data={data}".') + + # Define the trainable function with allocated resources + trainable_with_resources = tune.with_resources(_tune, {"cpu": NUM_THREADS, "gpu": gpu_per_trial or 0}) + + # Define the ASHA scheduler for hyperparameter search + asha_scheduler = ASHAScheduler( + time_attr="epoch", + metric=TASK2METRIC[task], + mode="max", + max_t=train_args.get("epochs") or DEFAULT_CFG_DICT["epochs"] or 100, + grace_period=grace_period, + reduction_factor=3, + ) + + # Define the callbacks for the hyperparameter search + tuner_callbacks = [WandbLoggerCallback(project="YOLOv8-tune")] if wandb else [] + + # Create the Ray Tune hyperparameter search tuner + tune_dir = get_save_dir( + get_cfg(DEFAULT_CFG, train_args), name=train_args.pop("name", "tune") + ).resolve() # must be absolute dir + tune_dir.mkdir(parents=True, exist_ok=True) + tuner = tune.Tuner( + trainable_with_resources, + param_space=space, + tune_config=tune.TuneConfig(scheduler=asha_scheduler, num_samples=max_samples), + run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir), + ) + + # Run the hyperparameter search + tuner.fit() + + # Get the results of the hyperparameter search + results = tuner.get_results() + + # Shut down Ray to clean up workers + ray.shutdown() + + return results