Upload folder using huggingface_hub
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- .DS_Store +0 -0
- .gitattributes +4 -0
- __pycache__/SegTracker.cpython-310.pyc +0 -0
- __pycache__/aot_tracker.cpython-310.pyc +0 -0
- __pycache__/model_args.cpython-310.pyc +0 -0
- __pycache__/seg_track_anything.cpython-310.pyc +0 -0
- aot/.DS_Store +0 -0
- aot/LICENSE +29 -0
- aot/MODEL_ZOO.md +115 -0
- aot/Pytorch-Correlation-extension/.gitignore +1 -0
- aot/Pytorch-Correlation-extension/Correlation_Module/correlation.cpp +178 -0
- aot/Pytorch-Correlation-extension/Correlation_Module/correlation_cuda_kernel.cu +327 -0
- aot/Pytorch-Correlation-extension/Correlation_Module/correlation_sampler.cpp +138 -0
- aot/Pytorch-Correlation-extension/Correlation_Module/spatial_correlation_sampler/__init__.py +1 -0
- aot/Pytorch-Correlation-extension/Correlation_Module/spatial_correlation_sampler/spatial_correlation_sampler.py +107 -0
- aot/Pytorch-Correlation-extension/LICENSE +21 -0
- aot/Pytorch-Correlation-extension/README.md +155 -0
- aot/Pytorch-Correlation-extension/benchmark.py +90 -0
- aot/Pytorch-Correlation-extension/check.py +119 -0
- aot/Pytorch-Correlation-extension/grad_check.py +47 -0
- aot/Pytorch-Correlation-extension/requirements.txt +2 -0
- aot/Pytorch-Correlation-extension/setup.py +69 -0
- aot/Pytorch-Correlation-extension/setup_cpu.py +4 -0
- aot/README.md +152 -0
- aot/__init__.py +0 -0
- aot/__pycache__/__init__.cpython-310.pyc +0 -0
- aot/configs/__pycache__/default.cpython-310.pyc +0 -0
- aot/configs/__pycache__/pre_ytb_dav.cpython-310.pyc +0 -0
- aot/configs/default.py +138 -0
- aot/configs/models/__pycache__/default.cpython-310.pyc +0 -0
- aot/configs/models/__pycache__/r50_aotl.cpython-310.pyc +0 -0
- aot/configs/models/aotb.py +9 -0
- aot/configs/models/aotl.py +13 -0
- aot/configs/models/aots.py +9 -0
- aot/configs/models/aott.py +7 -0
- aot/configs/models/deaotb.py +9 -0
- aot/configs/models/deaotl.py +13 -0
- aot/configs/models/deaots.py +9 -0
- aot/configs/models/deaott.py +7 -0
- aot/configs/models/default.py +27 -0
- aot/configs/models/default_deaot.py +17 -0
- aot/configs/models/r101_aotl.py +16 -0
- aot/configs/models/r50_aotl.py +16 -0
- aot/configs/models/r50_deaotl.py +16 -0
- aot/configs/models/rs101_aotl.py +16 -0
- aot/configs/models/swinb_aotl.py +17 -0
- aot/configs/models/swinb_deaotl.py +17 -0
- aot/configs/pre.py +19 -0
- aot/configs/pre_dav.py +21 -0
- aot/configs/pre_ytb.py +17 -0
.DS_Store
ADDED
Binary file (10.2 kB)
.gitattributes
CHANGED
```diff
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/cars.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/cell.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/demo_3x2.gif filter=lfs diff=lfs merge=lfs -text
+assets/top.gif filter=lfs diff=lfs merge=lfs -text
```
__pycache__/SegTracker.cpython-310.pyc
ADDED
Binary file (8.05 kB)

__pycache__/aot_tracker.cpython-310.pyc
ADDED
Binary file (5.71 kB)

__pycache__/model_args.cpython-310.pyc
ADDED
Binary file (772 Bytes)

__pycache__/seg_track_anything.cpython-310.pyc
ADDED
Binary file (6.5 kB)

aot/.DS_Store
ADDED
Binary file (8.2 kB)
aot/LICENSE
ADDED
```text
BSD 3-Clause License

Copyright (c) 2020, z-x-yang
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```
aot/MODEL_ZOO.md
ADDED

## Model Zoo and Results

### Environment and Settings
- 4/1 NVIDIA V100 GPUs for training/evaluation.
- Auto-mixed precision was enabled in training but disabled in evaluation.
- Test-time augmentations were not used.
- The inference resolution of DAVIS/YouTube-VOS was 480p/1.3x480p, as in [CFBI](https://github.com/z-x-yang/CFBI).
- Fully online inference: all modules process the video frame by frame.
- Multi-object FPS was recorded instead of single-object FPS.

### Pre-trained Models
Stages:

- `PRE`: the pre-training stage with static images.
- `PRE_YTB_DAV`: the main-training stage with YouTube-VOS and DAVIS. All kinds of evaluation share an **identical** model and the **same** parameters.

| Model | Param (M) | PRE | PRE_YTB_DAV |
|:----------|:---------:|:---:|:-----------:|
| AOTT | 5.7 | [gdrive](https://drive.google.com/file/d/1_513h8Hok9ySQPMs_dHgX5sPexUhyCmy/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1owPmwV4owd_ll6GuilzklqTyAd0ZvbCu/view?usp=sharing) |
| AOTS | 7.0 | [gdrive](https://drive.google.com/file/d/1QUP0-VED-lOF1oX_ppYWnXyBjvUzJJB7/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1beU5E6Mdnr_pPrgjWvdWurKAIwJSz1xf/view?usp=sharing) |
| AOTB | 8.3 | [gdrive](https://drive.google.com/file/d/11Bx8n_INAha1IdpHjueGpf7BrKmCJDvK/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1hH-GOn4GAxHkV8ARcQzsUy8Ax6ndot-A/view?usp=sharing) |
| AOTL | 8.3 | [gdrive](https://drive.google.com/file/d/1WL6QCsYeT7Bt-Gain9ZIrNNXpR2Hgh29/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1L1N2hkSPqrwGgnW9GyFHuG59_EYYfTG4/view?usp=sharing) |
| R50-AOTL | 14.9 | [gdrive](https://drive.google.com/file/d/1hS4JIvOXeqvbs-CokwV6PwZV-EvzE6x8/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1qJDYn3Ibpquu4ffYoQmVjg1YCbr2JQep/view?usp=sharing) |
| SwinB-AOTL | 65.4 | [gdrive](https://drive.google.com/file/d/1LlhKQiXD8JyZGGs3hZiNzcaCLqyvL9tj/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/192jCGQZdnuTsvX-CVra-KVZl2q1ZR0vW/view?usp=sharing) |

| Model | Param (M) | PRE | PRE_YTB_DAV |
|:----------|:---------:|:---:|:-----------:|
| DeAOTT | 7.2 | [gdrive](https://drive.google.com/file/d/11C1ZBoFpL3ztKtINS8qqwPSldfYXexFK/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1ThWIZQS03cYWx1EKNN8MIMnJS5eRowzr/view?usp=sharing) |
| DeAOTS | 10.2 | [gdrive](https://drive.google.com/file/d/1uUidrWVoaP9A5B5-EzQLbielUnRLRF3j/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1YwIAV5tBtn5spSFxKLBQBEQGwPHyQlHi/view?usp=sharing) |
| DeAOTB | 13.2 | [gdrive](https://drive.google.com/file/d/1bEQr6vIgQMVITrSOtxWTMgycKpS0cor9/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1BHxsonnvJXylqHlZ1zJHHc-ymKyq-CFf/view?usp=sharing) |
| DeAOTL | 13.2 | [gdrive](https://drive.google.com/file/d/1_vBL4KJlmBy0oBE4YFDOvsYL1ZtpEL32/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/18elNz_wi9JyVBcIUYKhRdL08MA-FqHD5/view?usp=sharing) |
| R50-DeAOTL | 19.8 | [gdrive](https://drive.google.com/file/d/1sTRQ1g0WCpqVCdavv7uJiZNkXunBt3-R/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1QoChMkTVxdYZ_eBlZhK2acq9KMQZccPJ/view?usp=sharing) |
| SwinB-DeAOTL | 70.3 | [gdrive](https://drive.google.com/file/d/16BZEE53no8CxT-pPLDC2q1d6Xlg8mWPU/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1g4E-F0RPOx9Nd6J7tU9AE1TjsouL4oZq/view?usp=sharing) |

To run inference with a pre-trained model, a simple way is to set `--model` and `--ckpt_path` to your downloaded checkpoint's model type and file path when running `eval.py`.
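For illustration, such an invocation might look like the sketch below. Only the `--model` and `--ckpt_path` flag names come from the sentence above; the script location, model identifier, and checkpoint filename are hypothetical placeholders.

```python
# Hypothetical sketch: --model / --ckpt_path are documented above, but the
# model identifier and checkpoint path here are assumptions, not repo facts.
import subprocess

subprocess.run([
    "python", "eval.py",
    "--model", "r50_aotl",                                       # model type matching the checkpoint
    "--ckpt_path", "pretrain_models/R50_AOTL_PRE_YTB_DAV.pth",   # downloaded checkpoint file
], check=True)
```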
### YouTube-VOS 2018 val
`ALL-F`: all frames. The default evaluation setting of YouTube-VOS is 6 fps, but 30 fps sequences (all the frames) are also supplied by the dataset organizers. We noticed that many VOS methods prefer to evaluate with 30 fps videos, so we also supply those results here. Denser video sequences can significantly improve VOS performance when using a memory-reading strategy (as in AOTL, R50-AOTL, and SwinB-AOTL), but efficiency suffers since more memorized frames are stored for object matching.

| Model | Stage | FPS | All-F | Mean | J Seen | F Seen | J Unseen | F Unseen | Predictions |
|:------------|:-----------:|:--------:|:-----:|:--------:|:--------:|:--------:|:--------:|:--------:|:-----------:|
| AOTT | PRE_YTB_DAV | 41.0 | | 80.2 | 80.4 | 85.0 | 73.6 | 81.7 | [gdrive](https://drive.google.com/file/d/1u8mvPRT08ENZHsw9Xf_4C6Sv9BoCzENR/view?usp=sharing) |
| AOTT | PRE_YTB_DAV | 41.0 | √ | 80.9 | 80.0 | 84.7 | 75.2 | 83.5 | [gdrive](https://drive.google.com/file/d/1RGMI5-29Z0odq73rt26eCxOUYUd-fvVv/view?usp=sharing) |
| DeAOTT | PRE_YTB_DAV | **53.4** | | **82.0** | **81.6** | **86.3** | **75.8** | **84.2** | - |
| AOTS | PRE_YTB_DAV | 27.1 | | 82.9 | 82.3 | 87.0 | 77.1 | 85.1 | [gdrive](https://drive.google.com/file/d/1a4-rNnxjMuPBq21IKo31WDYZXMPgS7r2/view?usp=sharing) |
| AOTS | PRE_YTB_DAV | 27.1 | √ | 83.0 | 82.2 | 87.0 | 77.3 | 85.7 | [gdrive](https://drive.google.com/file/d/1Z0cndyoCw5Na6u-VFRE8CyiIG2RbMIUO/view?usp=sharing) |
| DeAOTS | PRE_YTB_DAV | **38.7** | | **84.0** | **83.3** | **88.3** | **77.9** | **86.6** | - |
| AOTB | PRE_YTB_DAV | 20.5 | | 84.0 | 83.2 | 88.1 | 78.0 | 86.5 | [gdrive](https://drive.google.com/file/d/1J5nhuQbbjVLYNXViBIgo21ddQy-MiOLG/view?usp=sharing) |
| AOTB | PRE_YTB_DAV | 20.5 | √ | 84.1 | 83.6 | 88.5 | 78.0 | 86.5 | [gdrive](https://drive.google.com/file/d/1gFaweB_GTJjHzSD61v_ZsY9K7UEND30O/view?usp=sharing) |
| DeAOTB | PRE_YTB_DAV | **30.4** | | **84.6** | **83.9** | **88.9** | **78.5** | **87.0** | - |
| AOTL | PRE_YTB_DAV | 16.0 | | 84.1 | 83.2 | 88.2 | 78.2 | 86.8 | [gdrive](https://drive.google.com/file/d/1kS8KWQ2L3wzxt44ROLTxwZOT7ZpT8Igc/view?usp=sharing) |
| AOTL | PRE_YTB_DAV | 6.5 | √ | 84.5 | 83.7 | 88.8 | 78.4 | **87.1** | [gdrive](https://drive.google.com/file/d/1Rpm3e215kJOUvb562lJ2kYg2I3hkrxiM/view?usp=sharing) |
| DeAOTL | PRE_YTB_DAV | **24.7** | | **84.8** | **84.2** | **89.4** | **78.6** | 87.0 | - |
| R50-AOTL | PRE_YTB_DAV | 14.9 | | 84.6 | 83.7 | 88.5 | 78.8 | 87.3 | [gdrive](https://drive.google.com/file/d/1nbJZ1bbmEgyK-bg6HQ8LwCz5gVJ6wzIZ/view?usp=sharing) |
| R50-AOTL | PRE_YTB_DAV | 6.4 | √ | 85.5 | 84.5 | 89.5 | 79.6 | 88.2 | [gdrive](https://drive.google.com/file/d/1NbB54ZhYvfJh38KFOgovYYPjWopd-2TE/view?usp=sharing) |
| R50-DeAOTL | PRE_YTB_DAV | **22.4** | | **86.0** | **84.9** | **89.9** | **80.4** | **88.7** | - |
| SwinB-AOTL | PRE_YTB_DAV | 9.3 | | 84.7 | 84.5 | 89.5 | 78.1 | 86.7 | [gdrive](https://drive.google.com/file/d/1QFowulSY0LHfpsjUV8ZE9rYc55L9DOC7/view?usp=sharing) |
| SwinB-AOTL | PRE_YTB_DAV | 5.2 | √ | 85.1 | 85.1 | 90.1 | 78.4 | 86.9 | [gdrive](https://drive.google.com/file/d/1TulhVOhh01rkssNYbOQASeWKu7CQ5Azx/view?usp=sharing) |
| SwinB-DeAOTL | PRE_YTB_DAV | **11.9** | | **86.2** | **85.6** | **90.6** | **80.0** | **88.4** | - |

### YouTube-VOS 2019 val
| Model | Stage | FPS | All-F | Mean | J Seen | F Seen | J Unseen | F Unseen | Predictions |
|:------------|:-----------:|:--------:|:-----:|:--------:|:--------:|:--------:|:--------:|:--------:|:-----------:|
| AOTT | PRE_YTB_DAV | 41.0 | | 80.0 | 79.8 | 84.2 | 74.1 | 82.1 | [gdrive](https://drive.google.com/file/d/1zzyhN1XYtajte5nbZ7opOdfXeDJgCxC5/view?usp=sharing) |
| AOTT | PRE_YTB_DAV | 41.0 | √ | 80.9 | 79.9 | 84.4 | 75.6 | 83.8 | [gdrive](https://drive.google.com/file/d/1V_5vi9dAXOis_WrDieacSESm7OX20Bv-/view?usp=sharing) |
| DeAOTT | PRE_YTB_DAV | **53.4** | | **82.0** | **81.2** | **85.6** | **76.4** | **84.7** | - |
| AOTS | PRE_YTB_DAV | 27.1 | | 82.7 | 81.9 | 86.5 | 77.3 | 85.2 | [gdrive](https://drive.google.com/file/d/11YdkUeyjkTv8Uw7xMgPCBzJs6v5SDt6n/view?usp=sharing) |
| AOTS | PRE_YTB_DAV | 27.1 | √ | 82.8 | 81.9 | 86.5 | 77.3 | 85.6 | [gdrive](https://drive.google.com/file/d/1UhyurGTJeAw412czU3_ebzNwF8xQ4QG_/view?usp=sharing) |
| DeAOTS | PRE_YTB_DAV | **38.7** | | **83.8** | **82.8** | **87.5** | **78.1** | **86.8** | - |
| AOTB | PRE_YTB_DAV | 20.5 | | 84.0 | 83.1 | 87.7 | 78.5 | 86.8 | [gdrive](https://drive.google.com/file/d/1NeI8cT4kVqTqVWAwtwiga1rkrvksNWaO/view?usp=sharing) |
| AOTB | PRE_YTB_DAV | 20.5 | √ | 84.1 | 83.3 | 88.0 | 78.2 | 86.7 | [gdrive](https://drive.google.com/file/d/1kpYV2XFR0sOfLWD-wMhd-nUO6CFiLjlL/view?usp=sharing) |
| DeAOTB | PRE_YTB_DAV | **30.4** | | **84.6** | **83.5** | **88.3** | **79.1** | **87.5** | - |
| AOTL | PRE_YTB_DAV | 16.0 | | 84.0 | 82.8 | 87.6 | 78.6 | 87.1 | [gdrive](https://drive.google.com/file/d/1qKLlNXxmT31bW0weEHI_zAf4QwU8Lhou/view?usp=sharing) |
| AOTL | PRE_YTB_DAV | 6.5 | √ | 84.2 | 83.0 | 87.8 | 78.7 | 87.3 | [gdrive](https://drive.google.com/file/d/1o3fwZ0cH71bqHSA3bYNjhP4GGv9Vyuwa/view?usp=sharing) |
| DeAOTL | PRE_YTB_DAV | **24.7** | | **84.7** | **83.8** | **88.8** | **79.0** | **87.2** | - |
| R50-AOTL | PRE_YTB_DAV | 14.9 | | 84.4 | 83.4 | 88.1 | 78.7 | 87.2 | [gdrive](https://drive.google.com/file/d/1I7ooSp8EYfU6fvkP6QcCMaxeencA68AH/view?usp=sharing) |
| R50-AOTL | PRE_YTB_DAV | 6.4 | √ | 85.3 | 83.9 | 88.8 | 79.9 | 88.5 | [gdrive](https://drive.google.com/file/d/1OGqlkEu0uXa8QVWIVz_M5pmXXiYR2sh3/view?usp=sharing) |
| R50-DeAOTL | PRE_YTB_DAV | **22.4** | | **85.9** | **84.6** | **89.4** | **80.8** | **88.9** | - |
| SwinB-AOTL | PRE_YTB_DAV | 9.3 | | 84.7 | 84.0 | 88.8 | 78.7 | 87.1 | [gdrive](https://drive.google.com/file/d/1fPzCxi5GM7N2sLKkhoTC2yoY_oTQCHp1/view?usp=sharing) |
| SwinB-AOTL | PRE_YTB_DAV | 5.2 | √ | 85.3 | 84.6 | 89.5 | 79.3 | 87.7 | [gdrive](https://drive.google.com/file/d/1e3D22s_rJ7Y2X2MHo7x5lcNtwmHFlwYB/view?usp=sharing) |
| SwinB-DeAOTL | PRE_YTB_DAV | **11.9** | | **86.1** | **85.3** | **90.2** | **80.4** | **88.6** | - |

### DAVIS-2017 test

| Model | Stage | FPS | Mean | J Score | F Score | Predictions |
|:----------|:-----------:|:--------:|:--------:|:--------:|:--------:|:-----------:|
| AOTT | PRE_YTB_DAV | **51.4** | 73.7 | 70.0 | 77.3 | [gdrive](https://drive.google.com/file/d/14Pu-6Uz4rfmJ_WyL2yl57KTx_pSSUNAf/view?usp=sharing) |
| AOTS | PRE_YTB_DAV | 40.0 | 75.2 | 71.4 | 78.9 | [gdrive](https://drive.google.com/file/d/1zzAPZCRLgnBWuAXqejPPEYLqBxu67Rj1/view?usp=sharing) |
| AOTB | PRE_YTB_DAV | 29.6 | 77.4 | 73.7 | 81.1 | [gdrive](https://drive.google.com/file/d/1WpQ-_Jrs7Ssfw0oekrejM2OVWEx_tBN1/view?usp=sharing) |
| AOTL | PRE_YTB_DAV | 18.7 | 79.3 | 75.5 | 83.2 | [gdrive](https://drive.google.com/file/d/1rP1Zdgc0N1d8RR2EaXMz3F-o5zqcNVe8/view?usp=sharing) |
| R50-AOTL | PRE_YTB_DAV | 18.0 | 79.5 | 76.0 | 83.0 | [gdrive](https://drive.google.com/file/d/1iQ5iNlvlS-In586ZNc4LIZMSdNIWDvle/view?usp=sharing) |
| SwinB-AOTL | PRE_YTB_DAV | 12.1 | **82.1** | **78.2** | **85.9** | [gdrive](https://drive.google.com/file/d/1oVt4FPcZdfVHiOxjYYKef0q7Ovy4f5Q_/view?usp=sharing) |

### DAVIS-2017 val

| Model | Stage | FPS | Mean | J Score | F Score | Predictions |
|:----------|:-----------:|:--------:|:--------:|:--------:|:--------:|:-----------:|
| AOTT | PRE_YTB_DAV | **51.4** | 79.2 | 76.5 | 81.9 | [gdrive](https://drive.google.com/file/d/10OUFhK2Sz-hOJrTDoTI0mA45KO1qodZt/view?usp=sharing) |
| AOTS | PRE_YTB_DAV | 40.0 | 82.1 | 79.3 | 84.8 | [gdrive](https://drive.google.com/file/d/1T-JTYyksWlq45jxcLjnRaBvvYUhWgHFH/view?usp=sharing) |
| AOTB | PRE_YTB_DAV | 29.6 | 83.3 | 80.6 | 85.9 | [gdrive](https://drive.google.com/file/d/1EVUnxQm9TLBTuwK82QyiSKk9R9V8NwRL/view?usp=sharing) |
| AOTL | PRE_YTB_DAV | 18.7 | 83.6 | 80.8 | 86.3 | [gdrive](https://drive.google.com/file/d/1CFauSni2BxAe_fcl8W_6bFByuwJRbDYm/view?usp=sharing) |
| R50-AOTL | PRE_YTB_DAV | 18.0 | 85.2 | 82.5 | 87.9 | [gdrive](https://drive.google.com/file/d/1vjloxnP8R4PZdsH2DDizfU2CrkdRHHyo/view?usp=sharing) |
| SwinB-AOTL | PRE_YTB_DAV | 12.1 | **85.9** | **82.9** | **88.9** | [gdrive](https://drive.google.com/file/d/1tYCbKOas0i7Et2iyUAyDwaXnaD9YWxLr/view?usp=sharing) |

### DAVIS-2016 val

| Model | Stage | FPS | Mean | J Score | F Score | Predictions |
|:----------|:-----------:|:--------:|:--------:|:--------:|:--------:|:-----------:|
| AOTT | PRE_YTB_DAV | **51.4** | 87.5 | 86.5 | 88.4 | [gdrive](https://drive.google.com/file/d/1LeW8WQhnylZ3umT7E379KdII92uUsGA9/view?usp=sharing) |
| AOTS | PRE_YTB_DAV | 40.0 | 89.6 | 88.6 | 90.5 | [gdrive](https://drive.google.com/file/d/1vqGei5tLu1FPVrTi5bwRAsaGy3Upf7B1/view?usp=sharing) |
| AOTB | PRE_YTB_DAV | 29.6 | 90.9 | 89.6 | 92.1 | [gdrive](https://drive.google.com/file/d/1qAppo2uOVu0FbE9t1FBUpymC3yWgw1LM/view?usp=sharing) |
| AOTL | PRE_YTB_DAV | 18.7 | 91.1 | 89.5 | 92.7 | [gdrive](https://drive.google.com/file/d/1g6cjYhgBWjMaY3RGAm31qm3SPEF3QcKV/view?usp=sharing) |
| R50-AOTL | PRE_YTB_DAV | 18.0 | 91.7 | 90.4 | 93.0 | [gdrive](https://drive.google.com/file/d/1QzxojqWKsvRf53K2AgKsK523ZVuYU4O-/view?usp=sharing) |
| SwinB-AOTL | PRE_YTB_DAV | 12.1 | **92.2** | **90.6** | **93.8** | [gdrive](https://drive.google.com/file/d/1RIqUtAyVnopeogfT520d7a0yiULg1obp/view?usp=sharing) |
aot/Pytorch-Correlation-extension/.gitignore
ADDED
```text
*.egg*
```
aot/Pytorch-Correlation-extension/Correlation_Module/correlation.cpp
ADDED
```cpp
#include <torch/extension.h>
using namespace torch;

#include <vector>

#define WITHIN_BOUNDS(x, y, H, W) (x >= 0 && x < H && y >= 0 && y < W)

template <typename scalar_t>
static void correlate_patch(
    TensorAccessor<scalar_t,3> input1,
    TensorAccessor<scalar_t,3> input2,
    scalar_t *dst,
    int kH, int kW,
    int dilationH, int dilationW,
    int u, int v,
    int shiftU, int shiftV){
  const int C = input1.size(0);
  const int iH = input1.size(1);
  const int iW = input1.size(2);
  for (int c=0; c<C; ++c){
    for (int i=0; i<kH; ++i){
      int i1 = u + i * dilationH;
      int i2 = i1 + shiftU;
      if WITHIN_BOUNDS(i1, i2, iH, iH){
        for (int j=0; j<kW; ++j){
          int j1 = v + j * dilationW;
          int j2 = j1 + shiftV;
          if WITHIN_BOUNDS(j1, j2, iW, iW){
            scalar_t v1 = input1[c][i1][j1];
            scalar_t v2 = input2[c][i2][j2];
            *dst += v1 * v2;
          }
        }
      }
    }
  }
}

template <typename scalar_t>
static void correlate_patch_grad(
    TensorAccessor<scalar_t,3> input1,
    TensorAccessor<scalar_t,3> gradInput1,
    TensorAccessor<scalar_t,3> input2,
    TensorAccessor<scalar_t,3> gradInput2,
    scalar_t gradOutput,
    int kH, int kW,
    int dilationH, int dilationW,
    int u, int v,
    int shiftU, int shiftV){

  const int C = input1.size(0);
  const int iH = input1.size(1);
  const int iW = input1.size(2);

  for (int c=0; c<C; ++c){
    for (int i=0; i<kH; ++i){
      int i1 = u + i * dilationH;
      int i2 = i1 + shiftU;
      if WITHIN_BOUNDS(i1, i2, iH, iH){
        for (int j=0; j<kW; ++j){
          int j1 = v + j * dilationW;
          int j2 = j1 + shiftV;
          if WITHIN_BOUNDS(j1, j2, iW, iW){
            scalar_t v1 = input1[c][i1][j1];
            scalar_t v2 = input2[c][i2][j2];
            gradInput2[c][i2][j2] += gradOutput * v1;
            gradInput1[c][i1][j1] += gradOutput * v2;
          }
        }
      }
    }
  }
}

torch::Tensor correlation_cpp_forward(
    torch::Tensor input1,
    torch::Tensor input2,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW) {

  const auto batch_size = input1.size(0);
  const auto iH = input1.size(2);
  const auto iW = input1.size(3);
  const int patchRadH = (patchH - 1) / 2;
  const int patchRadW = (patchW - 1) / 2;
  const int dilatedKH = (kH - 1) * dilationH + 1;
  const int dilatedKW = (kW - 1) * dilationW + 1;

  const auto oH = (iH + 2 * padH - dilatedKH) / dH + 1;
  const auto oW = (iW + 2 * padW - dilatedKW) / dW + 1;
  auto output = at::zeros({batch_size, patchH, patchW, oH, oW}, input1.options());

  int n, ph, pw, h, w;
  #pragma omp parallel for private(n, ph, pw, h, w) collapse(2)
  for (n = 0; n < batch_size; ++n) {
    for(ph = 0; ph < patchH; ++ph){
      for(pw = 0; pw < patchW; ++pw){
        AT_DISPATCH_FLOATING_TYPES(input1.scalar_type(), "correlation_forward_cpp", ([&] {
          auto input1_acc = input1.accessor<scalar_t, 4>();
          auto input2_acc = input2.accessor<scalar_t, 4>();
          auto output_acc = output.accessor<scalar_t, 5>();
          for (h = 0; h < oH; ++h) {
            for (w = 0; w < oW; ++w) {
              correlate_patch(input1_acc[n],
                              input2_acc[n],
                              &output_acc[n][ph][pw][h][w],
                              kH, kW,
                              dilationH, dilationW,
                              -padH + h * dH,
                              -padW + w * dW,
                              (ph - patchRadH) * dilation_patchH,
                              (pw - patchRadW) * dilation_patchW);
            }
          }
        }));
      }
    }
  }
  return output;
}

std::vector<torch::Tensor> correlation_cpp_backward(
    torch::Tensor input1,
    torch::Tensor input2,
    torch::Tensor gradOutput,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW) {

  const int batch_size = input1.size(0);
  const int patchRadH = (patchH - 1) / 2;
  const int patchRadW = (patchW - 1) / 2;
  const int oH = gradOutput.size(3);
  const int oW = gradOutput.size(4);

  auto gradInput1 = torch::zeros_like(input1);

  auto gradInput2 = torch::zeros_like(input2);

  int n, ph, pw, h, w;
  #pragma omp parallel for private(n, ph, pw, h, w)
  for (n = 0; n < batch_size; ++n) {
    AT_DISPATCH_FLOATING_TYPES(input1.scalar_type(), "correlation_backward_cpp", ([&] {
      auto input1_acc = input1.accessor<scalar_t, 4>();
      auto gradInput1_acc = gradInput1.accessor<scalar_t, 4>();
      auto input2_acc = input2.accessor<scalar_t, 4>();
      auto gradInput2_acc = gradInput2.accessor<scalar_t, 4>();
      auto gradOutput_acc = gradOutput.accessor<scalar_t, 5>();

      for(ph = 0; ph < patchH; ++ph){
        for(pw = 0; pw < patchW; ++pw){
          for (h = 0; h < oH; ++h) {
            for (w = 0; w < oW; ++w) {
              correlate_patch_grad(input1_acc[n], gradInput1_acc[n],
                                   input2_acc[n], gradInput2_acc[n],
                                   gradOutput_acc[n][ph][pw][h][w],
                                   kH, kW,
                                   dilationH, dilationW,
                                   -padH + h * dH,
                                   -padW + w * dW,
                                   (ph - patchRadH) * dilation_patchH,
                                   (pw - patchRadW) * dilation_patchW);
            }
          }
        }
      }
    }));
  }

  return {gradInput1, gradInput2};
}
```
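To make the loop structure above concrete: for each patch shift `(ph, pw)` and output location `(h, w)`, the forward pass accumulates a channel-summed dot product between a (possibly dilated) kernel window of `input1` and a shifted window of `input2`. Below is a slow, pure-PyTorch sketch of that computation, written from the C++ above for reference only; `correlation_forward_ref` is not part of the extension's API.

```python
import torch

def correlation_forward_ref(input1, input2, kH=1, kW=1, patchH=1, patchW=1,
                            padH=0, padW=0, dilationH=1, dilationW=1,
                            dilation_patchH=1, dilation_patchW=1, dH=1, dW=1):
    """Reference (slow) mirror of correlation_cpp_forward; illustration only."""
    B, C, iH, iW = input1.shape
    patchRadH, patchRadW = (patchH - 1) // 2, (patchW - 1) // 2
    oH = (iH + 2 * padH - ((kH - 1) * dilationH + 1)) // dH + 1
    oW = (iW + 2 * padW - ((kW - 1) * dilationW + 1)) // dW + 1
    out = torch.zeros(B, patchH, patchW, oH, oW, dtype=input1.dtype)
    for n in range(B):
        for ph in range(patchH):
            su = (ph - patchRadH) * dilation_patchH      # vertical patch shift
            for pw in range(patchW):
                sv = (pw - patchRadW) * dilation_patchW  # horizontal patch shift
                for h in range(oH):
                    for w in range(oW):
                        acc = 0.0
                        for i in range(kH):
                            i1 = -padH + h * dH + i * dilationH   # row in input1
                            i2 = i1 + su                          # shifted row in input2
                            if not (0 <= i1 < iH and 0 <= i2 < iH):
                                continue
                            for j in range(kW):
                                j1 = -padW + w * dW + j * dilationW
                                j2 = j1 + sv
                                if 0 <= j1 < iW and 0 <= j2 < iW:
                                    # channel-summed dot product, as in correlate_patch
                                    acc += (input1[n, :, i1, j1] *
                                            input2[n, :, i2, j2]).sum()
                        out[n, ph, pw, h, w] = acc
    return out
```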
aot/Pytorch-Correlation-extension/Correlation_Module/correlation_cuda_kernel.cu
ADDED
```cuda
#include <torch/types.h>
using namespace torch;

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>
#include <iostream>

// Cuda tensor accessor definitions
// restrict pointer traits prioritize speed over memory consumption
#define TensorAcc4R PackedTensorAccessor32<scalar_t,4,RestrictPtrTraits>
#define TensorAcc5R PackedTensorAccessor32<scalar_t,5,RestrictPtrTraits>
#define WITHIN_BOUNDS(x, y, H, W) (x >= 0 && x < H && y >= 0 && y < W)

#define THREADS_FORWARD 32
#define THREADS_BACKWARD 5


namespace corr {
template <typename scalar_t>
__global__ void correlation_cuda_forward_kernel(
    const TensorAcc4R rInput1,
    const TensorAcc4R rInput2,
    TensorAcc5R output,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW) {

  const int iH = rInput1.size(1);
  const int iW = rInput1.size(2);
  const int C = rInput1.size(3);

  const int n = blockIdx.x;
  const int h = blockIdx.y;
  const int w = blockIdx.z;
  const int thread = threadIdx.x;

  const int start_i = -padH + h * dH;
  const int start_j = -padW + w * dW;

  const int patchRadH = dilation_patchH * (patchH - 1) / 2;
  const int patchRadW = dilation_patchW * (patchW - 1) / 2;

  __shared__ scalar_t prod_sum[THREADS_FORWARD];

  for(int ph = 0; ph < patchH; ++ph){
    int ph_dilated = ph * dilation_patchH - patchRadH;
    for(int pw = 0; pw < patchW; ++pw){
      int pw_dilated = pw * dilation_patchW - patchRadW;
      prod_sum[thread] = 0;
      for (int i=0; i<kH; ++i){
        int i1 = start_i + i * dilationH;
        int i2 = i1 + ph_dilated;
        if WITHIN_BOUNDS(i1, i2, iH, iH){
          for (int j=0; j<kW; ++j){
            int j1 = start_j + j * dilationW;
            int j2 = j1 + pw_dilated;
            if WITHIN_BOUNDS(j1, j2, iW, iW){
              for (int c=thread; c<C; c += THREADS_FORWARD){
                scalar_t v1 = rInput1[n][i1][j1][c];
                scalar_t v2 = rInput2[n][i2][j2][c];
                prod_sum[thread] += v1 * v2;
              }
            }
          }
        }
      }
      // accumulate
      __syncthreads();
      if (thread == 0) {
        scalar_t reduce_sum = 0;
        for (int index = 0; index < THREADS_FORWARD; ++index) {
          reduce_sum += prod_sum[index];
        }
        output[n][ph][pw][h][w] = reduce_sum;
      }
    }
  }
}


template <typename scalar_t>
__global__ void correlation_cuda_backward_kernel_input1(
    const TensorAcc5R gradOutput,
    const TensorAcc4R input2,
    TensorAcc4R gradInput1,
    const int kH, const int kW,
    const int patchH, const int patchW,
    const int padH, const int padW,
    const int dilationH, const int dilationW,
    const int dilation_patchH, const int dilation_patchW,
    const int dH, const int dW,
    const int batch) {
  const int iH = input2.size(2);
  const int iW = input2.size(3);

  const int H = gradOutput.size(3);
  const int W = gradOutput.size(4);

  const int patchRadH = (patchH - 1) / 2;
  const int patchRadW = (patchW - 1) / 2;

  const int n = batch;
  const int c = blockIdx.x;
  const int h = blockIdx.y;
  const int w = blockIdx.z;
  const int ph_off = threadIdx.x;
  const int pw_off = threadIdx.y;

  const int h_2 = h + padH;
  const int w_2 = w + padW;
  const int min_h = h_2 - kH * dilationH;
  const int min_w = w_2 - kW * dilationW;

  __shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD];
  prod_sum[ph_off][pw_off] = 0;

  for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) {
    int i1 = h + dilation_patchH * (ph - patchRadH);
    for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) {
      int j1 = w + dilation_patchW * (pw - patchRadW);
      if (WITHIN_BOUNDS(i1, j1, iH, iW)){
        scalar_t val = input2[n][c][i1][j1];
        for(int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
          int i2 = (h_3)/dH;
          if (i2 * dH != h_3)
            continue;
          for(int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
            int j2 = (w_3) / dW;
            if(j2 * dW != w_3)
              continue;
            if WITHIN_BOUNDS(i2, j2, H, W) {
              prod_sum[ph_off][pw_off] += gradOutput[n][ph][pw][i2][j2] * val;
            }
          }
        }
      }
    }
  }

  __syncthreads();

  if (ph_off == 0 && pw_off == 0){
    scalar_t reduce_sum = 0;
    for (int ph = 0; ph < THREADS_BACKWARD; ++ph){
      for (int pw = 0; pw < THREADS_BACKWARD; ++pw){
        reduce_sum += prod_sum[ph][pw];
      }
    }
    gradInput1[n][c][h][w] = reduce_sum;
  }
}


template <typename scalar_t>
__global__ void correlation_cuda_backward_kernel_input2(
    const TensorAcc5R gradOutput,
    const TensorAcc4R input1,
    TensorAcc4R gradInput2,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW,
    int batch) {
  const int iH = input1.size(2);
  const int iW = input1.size(3);

  const int patchRadH = (patchH - 1) / 2;
  const int patchRadW = (patchW - 1) / 2;

  const int H = gradOutput.size(3);
  const int W = gradOutput.size(4);

  const int dilatedKH = kH * dilationH;
  const int dilatedKW = kW * dilationW;

  const int n = batch;
  const int c = blockIdx.x;
  const int h = blockIdx.y;
  const int w = blockIdx.z;
  const int ph_off = threadIdx.x;
  const int pw_off = threadIdx.y;

  __shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD];
  prod_sum[ph_off][pw_off] = 0;

  for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) {
    int i1 = h - dilation_patchH * (ph - patchRadH);
    for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) {
      int j1 = w - dilation_patchW * (pw - patchRadW);
      if WITHIN_BOUNDS(i1, j1, iH, iW) {
        scalar_t val = input1[n][c][i1][j1];

        const int h_2 = i1 + padH;
        const int w_2 = j1 + padW;
        const int min_h = h_2 - dilatedKH;
        const int min_w = w_2 - dilatedKW;

        for(int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
          int i2 = (h_3)/dH;
          if (i2 * dH != h_3)
            continue;
          for(int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
            int j2 = (w_3) / dW;
            if(j2 * dW != w_3)
              continue;
            if WITHIN_BOUNDS(i2, j2, H, W) {
              prod_sum[ph_off][pw_off] += gradOutput[n][ph][pw][i2][j2] * val;
            }
          }
        }
      }
    }
  }

  __syncthreads();

  if (ph_off == 0 && pw_off == 0){
    scalar_t reduce_sum = 0;
    for (int ph = 0; ph < THREADS_BACKWARD; ++ph){
      for (int pw = 0; pw < THREADS_BACKWARD; ++pw){
        reduce_sum += prod_sum[ph][pw];
      }
    }
    gradInput2[n][c][h][w] = reduce_sum;
  }
}
} // namespace corr

torch::Tensor correlation_cuda_forward(
    torch::Tensor input1,
    torch::Tensor input2,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW) {

  const int batch_size = input1.size(0);
  const int iH = input1.size(2);
  const int iW = input1.size(3);
  const int dilatedKH = (kH - 1) * dilationH + 1;
  const int dilatedKW = (kW - 1) * dilationW + 1;

  const auto oH = (iH + 2 * padH - dilatedKH) / dH + 1;
  const auto oW = (iW + 2 * padW - dilatedKW) / dW + 1;
  auto output = torch::zeros({batch_size, patchH, patchW, oH, oW}, input1.options());

  auto trInput1 = input1.permute({0, 2, 3, 1}).contiguous();
  auto trInput2 = input2.permute({0, 2, 3, 1}).contiguous();

  const int threads = THREADS_FORWARD;
  const dim3 blocks(batch_size, oH, oW);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.scalar_type(), "correlation_forward_cuda", ([&] {
    TensorAcc4R trInput1_acc = trInput1.packed_accessor32<scalar_t,4,RestrictPtrTraits>();
    TensorAcc4R trInput2_acc = trInput2.packed_accessor32<scalar_t,4,RestrictPtrTraits>();
    TensorAcc5R output_acc = output.packed_accessor32<scalar_t,5,RestrictPtrTraits>();
    corr::correlation_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
        trInput1_acc, trInput2_acc, output_acc,
        kH, kW, patchH, patchW, padH, padW, dilationH, dilationW,
        dilation_patchH, dilation_patchW, dH, dW);
  }));

  return output;
}

std::vector<torch::Tensor> correlation_cuda_backward(
    torch::Tensor input1,
    torch::Tensor input2,
    torch::Tensor gradOutput,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW) {

  auto gradInput1 = torch::zeros_like(input1);
  auto gradInput2 = torch::zeros_like(input2);

  const int batch_size = input1.size(0);
  const int iH = input1.size(2);
  const int iW = input1.size(3);
  const int C = input1.size(1);

  const dim3 blocks(C, iH, iW);
  const dim3 threads(THREADS_BACKWARD, THREADS_BACKWARD);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.scalar_type(), "correlation_backward_cuda", ([&] {
    TensorAcc4R input1_acc = input1.packed_accessor32<scalar_t,4,RestrictPtrTraits>();
    TensorAcc4R input2_acc = input2.packed_accessor32<scalar_t,4,RestrictPtrTraits>();
    TensorAcc4R gradInput1_acc = gradInput1.packed_accessor32<scalar_t,4,RestrictPtrTraits>();
    TensorAcc4R gradInput2_acc = gradInput2.packed_accessor32<scalar_t,4,RestrictPtrTraits>();
    TensorAcc5R gradOutput_acc = gradOutput.packed_accessor32<scalar_t,5,RestrictPtrTraits>();

    for (int n = 0; n < batch_size; ++n){
      corr::correlation_cuda_backward_kernel_input1<scalar_t><<<blocks, threads>>>(
          gradOutput_acc, input2_acc, gradInput1_acc,
          kH, kW, patchH, patchW, padH, padW,
          dilationH, dilationW,
          dilation_patchH, dilation_patchW,
          dH, dW,
          n);
    }

    for (int n = 0; n < batch_size; ++n){
      corr::correlation_cuda_backward_kernel_input2<scalar_t><<<blocks, threads>>>(
          gradOutput_acc, input1_acc, gradInput2_acc,
          kH, kW, patchH, patchW, padH, padW,
          dilationH, dilationW,
          dilation_patchH, dilation_patchW,
          dH, dW,
          n);
    }
  }));

  return {gradInput1, gradInput2};
}
```
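One detail worth noting in the host code above: inputs are permuted to a channels-last layout before launching the forward kernel, because the threads in a block stride over channels (`for (int c=thread; c<C; c += THREADS_FORWARD)`), so putting `C` innermost makes those loads coalesced. A small PyTorch illustration of that re-layout (illustration only, not part of the extension):

```python
import torch

x = torch.randn(2, 256, 48, 64)              # B, C, H, W
x_nhwc = x.permute(0, 2, 3, 1).contiguous()  # B, H, W, C (channels-last copy)
# Adjacent channels are now adjacent in memory, so consecutive threads
# reading consecutive c indices touch consecutive floats.
assert x_nhwc.stride(3) == 1
```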
aot/Pytorch-Correlation-extension/Correlation_Module/correlation_sampler.cpp
ADDED
```cpp
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <vector>
#include <iostream>

// declarations

torch::Tensor correlation_cpp_forward(
    torch::Tensor input1,
    torch::Tensor input2,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW);

std::vector<torch::Tensor> correlation_cpp_backward(
    torch::Tensor grad_output,
    torch::Tensor input1,
    torch::Tensor input2,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW);

#ifdef USE_CUDA

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x, " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define CHECK_SAME_DEVICE(x, y) TORCH_CHECK(x.device() == y.device(), #x " is not on same device as " #y)

torch::Tensor correlation_cuda_forward(
    torch::Tensor input1,
    torch::Tensor input2,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW);

std::vector<torch::Tensor> correlation_cuda_backward(
    torch::Tensor grad_output,
    torch::Tensor input1,
    torch::Tensor input2,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW);

// C++ interface

torch::Tensor correlation_sample_forward(
    torch::Tensor input1,
    torch::Tensor input2,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW) {
  if (input1.device().is_cuda()){
    CHECK_INPUT(input1);
    CHECK_INPUT(input2);

    // set device of input1 as default CUDA device
    // https://pytorch.org/cppdocs/api/structc10_1_1cuda_1_1_optional_c_u_d_a_guard.html
    const at::cuda::OptionalCUDAGuard guard_input1(device_of(input1));
    CHECK_SAME_DEVICE(input1, input2);

    return correlation_cuda_forward(input1, input2, kH, kW, patchH, patchW,
                                    padH, padW, dilationH, dilationW,
                                    dilation_patchH, dilation_patchW,
                                    dH, dW);
  }else{
    return correlation_cpp_forward(input1, input2, kH, kW, patchH, patchW,
                                   padH, padW, dilationH, dilationW,
                                   dilation_patchH, dilation_patchW,
                                   dH, dW);
  }
}

std::vector<torch::Tensor> correlation_sample_backward(
    torch::Tensor input1,
    torch::Tensor input2,
    torch::Tensor grad_output,
    int kH, int kW,
    int patchH, int patchW,
    int padH, int padW,
    int dilationH, int dilationW,
    int dilation_patchH, int dilation_patchW,
    int dH, int dW) {

  if(grad_output.device().is_cuda()){
    CHECK_INPUT(input1);
    CHECK_INPUT(input2);

    // set device of input1 as default CUDA device
    const at::cuda::OptionalCUDAGuard guard_input1(device_of(input1));
    CHECK_SAME_DEVICE(input1, input2);
    CHECK_SAME_DEVICE(input1, grad_output);

    return correlation_cuda_backward(input1, input2, grad_output,
                                     kH, kW, patchH, patchW,
                                     padH, padW,
                                     dilationH, dilationW,
                                     dilation_patchH, dilation_patchW,
                                     dH, dW);
  }else{
    return correlation_cpp_backward(
        input1, input2, grad_output,
        kH, kW, patchH, patchW,
        padH, padW,
        dilationH, dilationW,
        dilation_patchH, dilation_patchW,
        dH, dW);
  }
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &correlation_sample_forward, "Spatial Correlation Sampler Forward");
  m.def("backward", &correlation_sample_backward, "Spatial Correlation Sampler backward");
}

#else

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &correlation_cpp_forward, "Spatial Correlation Sampler Forward");
  m.def("backward", &correlation_cpp_backward, "Spatial Correlation Sampler backward");
}

#endif
```
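The effect of this dispatch is that a single Python entry point transparently runs on either backend, chosen by the device of the inputs. A usage sketch (it assumes the extension was built with CUDA support via `setup.py`; the tolerance is arbitrary):

```python
import torch
from spatial_correlation_sampler import spatial_correlation_sample

a = torch.randn(1, 3, 8, 8)
b = torch.randn(1, 3, 8, 8)

out_cpu = spatial_correlation_sample(a, b, patch_size=3)  # takes the CPU branch above

if torch.cuda.is_available():
    # Same call with GPU inputs takes the CUDA branch.
    out_gpu = spatial_correlation_sample(a.cuda(), b.cuda(), patch_size=3)
    print(torch.allclose(out_cpu, out_gpu.cpu(), atol=1e-5))
```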
aot/Pytorch-Correlation-extension/Correlation_Module/spatial_correlation_sampler/__init__.py
ADDED
```python
from .spatial_correlation_sampler import SpatialCorrelationSampler, spatial_correlation_sample
```
aot/Pytorch-Correlation-extension/Correlation_Module/spatial_correlation_sampler/spatial_correlation_sampler.py
ADDED
```python
from torch import nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair

import spatial_correlation_sampler_backend as correlation


def spatial_correlation_sample(input1,
                               input2,
                               kernel_size=1,
                               patch_size=1,
                               stride=1,
                               padding=0,
                               dilation=1,
                               dilation_patch=1):
    """Apply spatial correlation sampling from input1 to input2.

    Every parameter except input1 and input2 can be either a single int
    or a pair of ints. For more information about Spatial Correlation
    Sampling, see this page.
    https://lmb.informatik.uni-freiburg.de/Publications/2015/DFIB15/

    Args:
        input1 : The first parameter.
        input2 : The second parameter.
        kernel_size : total size of your correlation kernel, in pixels
        patch_size : total size of your patch, determining how many
            different shifts will be applied
        stride : stride of the spatial sampler, will modify output
            height and width
        padding : padding applied to input1 and input2 before applying
            the correlation sampling, will modify output height and width
        dilation_patch : step for every shift in patch

    Returns:
        Tensor: Result of correlation sampling

    """
    return SpatialCorrelationSamplerFunction.apply(input1, input2,
                                                   kernel_size, patch_size,
                                                   stride, padding, dilation, dilation_patch)


class SpatialCorrelationSamplerFunction(Function):

    @staticmethod
    def forward(ctx,
                input1,
                input2,
                kernel_size=1,
                patch_size=1,
                stride=1,
                padding=0,
                dilation=1,
                dilation_patch=1):

        ctx.save_for_backward(input1, input2)
        kH, kW = ctx.kernel_size = _pair(kernel_size)
        patchH, patchW = ctx.patch_size = _pair(patch_size)
        padH, padW = ctx.padding = _pair(padding)
        dilationH, dilationW = ctx.dilation = _pair(dilation)
        dilation_patchH, dilation_patchW = ctx.dilation_patch = _pair(dilation_patch)
        dH, dW = ctx.stride = _pair(stride)

        output = correlation.forward(input1, input2,
                                     kH, kW, patchH, patchW,
                                     padH, padW, dilationH, dilationW,
                                     dilation_patchH, dilation_patchW,
                                     dH, dW)

        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        input1, input2 = ctx.saved_variables

        kH, kW = ctx.kernel_size
        patchH, patchW = ctx.patch_size
        padH, padW = ctx.padding
        dilationH, dilationW = ctx.dilation
        dilation_patchH, dilation_patchW = ctx.dilation_patch
        dH, dW = ctx.stride

        grad_input1, grad_input2 = correlation.backward(input1, input2, grad_output,
                                                        kH, kW, patchH, patchW,
                                                        padH, padW, dilationH, dilationW,
                                                        dilation_patchH, dilation_patchW,
                                                        dH, dW)
        return grad_input1, grad_input2, None, None, None, None, None, None


class SpatialCorrelationSampler(nn.Module):
    def __init__(self, kernel_size=1, patch_size=1, stride=1, padding=0, dilation=1, dilation_patch=1):
        super(SpatialCorrelationSampler, self).__init__()
        self.kernel_size = kernel_size
        self.patch_size = patch_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.dilation_patch = dilation_patch

    def forward(self, input1, input2):
        return SpatialCorrelationSamplerFunction.apply(input1, input2, self.kernel_size,
                                                       self.patch_size, self.stride,
                                                       self.padding, self.dilation, self.dilation_patch)
```
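Since the backward pass above is hand-written, it can be verified numerically with `torch.autograd.gradcheck`, in the spirit of the repo's `grad_check.py`. A minimal sketch (sizes and parameters are arbitrary; `gradcheck` requires double-precision inputs):

```python
import torch
from torch.autograd import gradcheck
from spatial_correlation_sampler import spatial_correlation_sample

# gradcheck wants float64 inputs with requires_grad=True.
input1 = torch.randn(1, 2, 6, 6, dtype=torch.float64, requires_grad=True)
input2 = torch.randn(1, 2, 6, 6, dtype=torch.float64, requires_grad=True)

ok = gradcheck(
    lambda a, b: spatial_correlation_sample(a, b, kernel_size=1, patch_size=3),
    (input1, input2),
)
print(ok)  # True if the analytic gradients match numeric estimates
```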
aot/Pytorch-Correlation-extension/LICENSE
ADDED
```text
MIT License

Copyright (c) [year] [fullname]

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```
aot/Pytorch-Correlation-extension/README.md
ADDED

[![PyPI](https://img.shields.io/pypi/v/spatial-correlation-sampler.svg)](https://pypi.org/project/spatial-correlation-sampler/)

# Pytorch Correlation module

This is a custom C++/CUDA implementation of the correlation module used, e.g., in [FlowNetC](https://arxiv.org/abs/1504.06852).

This [tutorial](http://pytorch.org/tutorials/advanced/cpp_extension.html) was used as a basis for the implementation, as well as
[NVIDIA's cuda code](https://github.com/NVIDIA/flownet2-pytorch/tree/master/networks/correlation_package).

- Build and install the C++ and CUDA extensions by executing `python setup.py install`.
- Benchmark C++ vs. CUDA by running `python benchmark.py {cpu, cuda}`.
- Run gradient checks on the code by running `python grad_check.py --backend {cpu, cuda}`.

# Requirements

This module is expected to compile for Pytorch `2.1.0`.

Before installation, please check the compatibility of your GPU and CUDA (_Compute Capability_) in the [nvidia docs](https://developer.nvidia.com/cuda-gpus).
E.g., the RTX 6000 uses CC=8.9, so we set the environment variable to

`export TORCH_CUDA_ARCH_LIST="8.9+PTX"`

# Installation

Be reminded that this module requires `python3-dev` to compile C++ code; e.g., on Ubuntu run:

`apt install python3-dev`

This module is available on pip:

`pip install spatial-correlation-sampler`

For a CPU-only version, you can install from source with

`python setup_cpu.py install`

# Known Problems

This module needs a compatible gcc version and CUDA to be compiled.
Namely, CUDA 9.1 and below need gcc5, while CUDA 9.2 and 10.0 need gcc7.
See [this issue](https://github.com/ClementPinard/Pytorch-Correlation-extension/issues/1) for more information.

# Usage

The API has a few differences from NVIDIA's module:
* The output is now a 5D tensor, which reflects the horizontal and vertical shifts.
```
input (B x C x H x W) -> output (B x PatchH x PatchW x oH x oW)
```
* Output sizes `oH` and `oW` no longer depend on patch size, but only on kernel size and padding.
* Patch size `patch_size` is now the whole patch, not only the radii.
* `stride1` is now `stride` and `stride2` is `dilation_patch`, which behaves like a dilated convolution.
* The equivalent `max_displacement` is then `dilation_patch * (patch_size - 1) / 2` (see the check after this list).
* `dilation` is a new parameter; it acts on the correlation kernel the same way as in a dilated convolution.
* To get the right parameters for FlowNetC, you would use
```
kernel_size=1
patch_size=21,
stride=1,
padding=0,
dilation=1
dilation_patch=2
```
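As a quick check of the `max_displacement` equivalence in the list above, FlowNetC's displacement of 20 is recovered from these parameters (arithmetic only, using the formula given above):

```python
patch_size = 21
dilation_patch = 2

# Equivalent NVIDIA-style max_displacement per the formula above.
print(dilation_patch * (patch_size - 1) // 2)  # 20 == FlowNetC's max_displacement
```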
66 |
+
|
67 |
+
|
68 |
+
## Example
|
69 |
+
```python
|
70 |
+
import torch
|
71 |
+
from spatial_correlation_sampler import SpatialCorrelationSampler,
|
72 |
+
|
73 |
+
device = "cuda"
|
74 |
+
batch_size = 1
|
75 |
+
channel = 1
|
76 |
+
H = 10
|
77 |
+
W = 10
|
78 |
+
dtype = torch.float32
|
79 |
+
|
80 |
+
input1 = torch.randint(1, 4, (batch_size, channel, H, W), dtype=dtype, device=device, requires_grad=True)
|
81 |
+
input2 = torch.randint_like(input1, 1, 4).requires_grad_(True)
|
82 |
+
|
83 |
+
#You can either use the function or the module. Note that the module doesn't contain any parameter tensor.
|
84 |
+
|
85 |
+
#function
|
86 |
+
|
87 |
+
out = spatial_correlation_sample(input1,
|
88 |
+
input2,
|
89 |
+
kernel_size=3,
|
90 |
+
patch_size=1,
|
91 |
+
stride=2,
|
92 |
+
padding=0,
|
93 |
+
dilation=2,
|
94 |
+
dilation_patch=1)
|
95 |
+
|
96 |
+
#module
|
97 |
+
|
98 |
+
correlation_sampler = SpatialCorrelationSampler(
|
99 |
+
kernel_size=3,
|
100 |
+
patch_size=1,
|
101 |
+
stride=2,
|
102 |
+
padding=0,
|
103 |
+
dilation=2,
|
104 |
+
dilation_patch=1)
|
105 |
+
out = correlation_sampler(input1, input2)
|
106 |
+
|
107 |
+
```
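
In both cases `out` is the 5D tensor described in the Usage section; a quick sanity check, assuming the example above was just run:

```python
# (B, PatchH, PatchW, oH, oW); patch_size=1 makes both patch dimensions singleton.
print(out.shape)
assert out.shape[:3] == (batch_size, 1, 1)
```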

# Benchmark

* Default parameters are those of `benchmark.py`; FlowNetC parameters are the same as used in FlowNetC with a batch size of 4, as described in [this paper](https://arxiv.org/abs/1504.06852) and implemented [here](https://github.com/lmb-freiburg/flownet2) and [here](https://github.com/NVIDIA/flownet2-pytorch/blob/master/networks/FlowNetC.py).
* Feel free to file an issue to add entries for your hardware!

## CUDA Benchmark

* See [here](https://gist.github.com/ClementPinard/270e910147119831014932f67fb1b5ea) for a benchmark script working with [NVIDIA](https://github.com/NVIDIA/flownet2-pytorch/tree/master/networks/correlation_package)'s code and PyTorch.
* Benchmarks are launched with the environment variable `CUDA_LAUNCH_BLOCKING` set to `1`.
* Only `float32` is benchmarked.
* FlowNetC correlation parameters were launched with the following commands:

```bash
CUDA_LAUNCH_BLOCKING=1 python benchmark.py --scale ms -k1 --patch 21 -s1 -p0 --patch_dilation 2 -b4 --height 48 --width 64 -c256 cuda -d float

CUDA_LAUNCH_BLOCKING=1 python NV_correlation_benchmark.py --scale ms -k1 --patch 21 -s1 -p0 --patch_dilation 2 -b4 --height 48 --width 64 -c256
```

| implementation | Correlation parameters | device  | pass     |       min time |       avg time |
| -------------- | ---------------------- | ------- | -------- | -------------: | -------------: |
| ours           | default                | 980 GTX | forward  |   **5.745 ms** |   **5.851 ms** |
| ours           | default                | 980 GTX | backward |      77.694 ms |      77.957 ms |
| NVIDIA         | default                | 980 GTX | forward  |      13.779 ms |      13.853 ms |
| NVIDIA         | default                | 980 GTX | backward |  **73.383 ms** |  **73.708 ms** |
|                |                        |         |          |                |                |
| ours           | FlowNetC               | 980 GTX | forward  |  **26.102 ms** |  **26.179 ms** |
| ours           | FlowNetC               | 980 GTX | backward | **208.091 ms** | **208.510 ms** |
| NVIDIA         | FlowNetC               | 980 GTX | forward  |      35.363 ms |      35.550 ms |
| NVIDIA         | FlowNetC               | 980 GTX | backward |     283.748 ms |     284.346 ms |

### Notes
* The overhead of our implementation for `kernel_size` > 1 during the backward pass needs some investigation; feel free to dive into the code to improve it!
* NVIDIA's backward pass is not entirely correct when `stride1` > 1 and `kernel_size` > 1, because not everything is computed; see [here](https://github.com/NVIDIA/flownet2-pytorch/blob/master/networks/correlation_package/src/correlation_cuda_kernel.cu#L120).

## CPU Benchmark

* No other implementation is available on CPU.
* It is obviously not recommended to run it on CPU if you have a GPU.

| Correlation parameters | device               | pass     |   min time |   avg time |
| ---------------------- | -------------------- | -------- | ---------: | ---------: |
| default                | E5-2630 v3 @ 2.40GHz | forward  | 159.616 ms | 188.727 ms |
| default                | E5-2630 v3 @ 2.40GHz | backward | 282.641 ms | 294.194 ms |
| FlowNetC               | E5-2630 v3 @ 2.40GHz | forward  |    2.138 s |    2.144 s |
| FlowNetC               | E5-2630 v3 @ 2.40GHz | backward |    7.006 s |    7.075 s |

aot/Pytorch-Correlation-extension/benchmark.py
ADDED
@@ -0,0 +1,90 @@
from __future__ import division
from __future__ import print_function

import argparse
import time

import torch
from spatial_correlation_sampler import SpatialCorrelationSampler
from tqdm import trange

TIME_SCALES = {'s': 1, 'ms': 1000, 'us': 1000000}

parser = argparse.ArgumentParser()
parser.add_argument('backend', choices=['cpu', 'cuda'], default='cuda')
parser.add_argument('-b', '--batch-size', type=int, default=16)
parser.add_argument('-k', '--kernel-size', type=int, default=3)
parser.add_argument('--patch', type=int, default=3)
parser.add_argument('--patch_dilation', type=int, default=2)
parser.add_argument('-c', '--channel', type=int, default=64)
parser.add_argument('--height', type=int, default=100)
parser.add_argument('-w', '--width', type=int, default=100)
parser.add_argument('-s', '--stride', type=int, default=2)
parser.add_argument('-p', '--pad', type=int, default=1)
parser.add_argument('--scale', choices=['s', 'ms', 'us'], default='us')
parser.add_argument('-r', '--runs', type=int, default=100)
parser.add_argument('--dilation', type=int, default=2)
parser.add_argument('-d', '--dtype', choices=['half', 'float', 'double'])

args = parser.parse_args()

device = torch.device(args.backend)

# Defaults to double precision when -d is not given.
if args.dtype == 'half':
    dtype = torch.float16
elif args.dtype == 'float':
    dtype = torch.float32
else:
    dtype = torch.float64


input1 = torch.randn(args.batch_size,
                     args.channel,
                     args.height,
                     args.width,
                     dtype=dtype,
                     device=device,
                     requires_grad=True)
input2 = torch.randn_like(input1)

correlation_sampler = SpatialCorrelationSampler(
    args.kernel_size,
    args.patch,
    args.stride,
    args.pad,
    args.dilation,
    args.patch_dilation)

# Force CUDA initialization before timing starts.
output = correlation_sampler(input1, input2)
print(output.size())
output.mean().backward()

forward_min = float('inf')
forward_time = 0
backward_min = float('inf')
backward_time = 0
for _ in trange(args.runs):
    correlation_sampler.zero_grad()

    start = time.time()
    output = correlation_sampler(input1, input2)
    elapsed = time.time() - start
    forward_min = min(forward_min, elapsed)
    forward_time += elapsed
    output = output.mean()

    start = time.time()
    output.backward()
    elapsed = time.time() - start
    backward_min = min(backward_min, elapsed)
    backward_time += elapsed

scale = TIME_SCALES[args.scale]
forward_min *= scale
backward_min *= scale
forward_average = forward_time / args.runs * scale
backward_average = backward_time / args.runs * scale

print('Forward: {0:.3f}/{1:.3f} {4} | Backward {2:.3f}/{3:.3f} {4}'.format(
    forward_min, forward_average, backward_min, backward_average,
    args.scale))
aot/Pytorch-Correlation-extension/check.py
ADDED
@@ -0,0 +1,119 @@
from __future__ import division
from __future__ import print_function

import argparse
import numpy as np
import torch

from spatial_correlation_sampler import SpatialCorrelationSampler


def check_equal(first, second, verbose):
    if verbose:
        print()
    for i, (x, y) in enumerate(zip(first, second)):
        x = x.cpu().detach().numpy()
        y = y.cpu().detach().numpy()
        if verbose:
            print("x = {}".format(x.flatten()))
            print("y = {}".format(y.flatten()))
            print('-' * 80)
        np.testing.assert_allclose(x, y, err_msg="Index: {}".format(i))


def zero_grad(variables):
    for variable in variables:
        if variable.grad is not None:
            variable.grad.zero_()


def get_grads(variables):
    return [var.grad.clone() for var in variables]


def check_forward(input1, input2, correlation_sampler, verbose, gpu_index=0):
    device = torch.device(f"cuda:{gpu_index}")

    cpu_values = correlation_sampler(input1, input2)
    cuda_values = correlation_sampler(input1.to(device), input2.to(device))

    print(f"Forward: CPU vs. CUDA device:{gpu_index} ... ", end='')
    check_equal(cpu_values, cuda_values, verbose)
    print('Ok')


def check_backward(input1, input2, correlation_sampler, verbose, gpu_index=0):
    device = torch.device(f"cuda:{gpu_index}")

    zero_grad([input1, input2])

    cpu_values = correlation_sampler(input1, input2)
    cpu_values.sum().backward()
    grad_cpu = get_grads([input1, input2])

    zero_grad([input1, input2])

    cuda_values = correlation_sampler(input1.to(device), input2.to(device))
    cuda_values.sum().backward()
    grad_cuda = get_grads([input1, input2])

    print(f"Backward: CPU vs. CUDA device:{gpu_index} ... ", end='')
    check_equal(grad_cpu, grad_cuda, verbose)
    print('Ok')


def check_multi_gpu_forward(correlation_sampler, verbose):
    print("Multi-GPU forward")
    total_gpus = torch.cuda.device_count()
    for gpu in range(total_gpus):
        check_forward(input1, input2, correlation_sampler, verbose, gpu_index=gpu)


def check_multi_gpu_backward(correlation_sampler, verbose):
    print("Multi-GPU backward")
    total_gpus = torch.cuda.device_count()
    for gpu in range(total_gpus):
        check_backward(input1, input2, correlation_sampler, verbose, gpu_index=gpu)


parser = argparse.ArgumentParser()
parser.add_argument('direction', choices=['forward', 'backward'], nargs='+')
parser.add_argument('-b', '--batch-size', type=int, default=1)
parser.add_argument('-k', '--kernel-size', type=int, default=3)
parser.add_argument('--patch', type=int, default=3)
parser.add_argument('--patch_dilation', type=int, default=2)
parser.add_argument('-c', '--channel', type=int, default=10)
parser.add_argument('--height', type=int, default=10)
parser.add_argument('-w', '--width', type=int, default=10)
parser.add_argument('-s', '--stride', type=int, default=2)
parser.add_argument('-p', '--pad', type=int, default=5)
parser.add_argument('-v', '--verbose', action='store_true', default=False)
parser.add_argument('-d', '--dilation', type=int, default=2)
args = parser.parse_args()
print(args)

assert torch.cuda.is_available(), "no comparison to make"
input1 = torch.randn(args.batch_size,
                     args.channel,
                     args.height,
                     args.width).double()
input2 = torch.randn(args.batch_size,
                     args.channel,
                     args.height,
                     args.width).double()
input1.requires_grad = True
input2.requires_grad = True

correlation_sampler = SpatialCorrelationSampler(
    args.kernel_size,
    args.patch,
    args.stride,
    args.pad,
    args.dilation,
    args.patch_dilation)

if 'forward' in args.direction:
    check_forward(input1, input2, correlation_sampler, args.verbose)
    if torch.cuda.device_count() > 1:
        check_multi_gpu_forward(correlation_sampler, args.verbose)

if 'backward' in args.direction:
    check_backward(input1, input2, correlation_sampler, args.verbose)
    if torch.cuda.device_count() > 1:
        check_multi_gpu_backward(correlation_sampler, args.verbose)
aot/Pytorch-Correlation-extension/grad_check.py
ADDED
@@ -0,0 +1,47 @@
import argparse

import torch
# torch.set_printoptions(precision=1, threshold=10000)
from torch.autograd import gradcheck
from spatial_correlation_sampler import SpatialCorrelationSampler

parser = argparse.ArgumentParser()
parser.add_argument('backend', choices=['cpu', 'cuda'], default='cuda')
parser.add_argument('-b', '--batch-size', type=int, default=2)
parser.add_argument('-k', '--kernel-size', type=int, default=3)
parser.add_argument('--patch', type=int, default=3)
parser.add_argument('--patch_dilation', type=int, default=2)
parser.add_argument('-c', '--channel', type=int, default=2)
parser.add_argument('--height', type=int, default=10)
parser.add_argument('-w', '--width', type=int, default=10)
parser.add_argument('-s', '--stride', type=int, default=2)
parser.add_argument('-p', '--pad', type=int, default=1)
parser.add_argument('-d', '--dilation', type=int, default=2)

args = parser.parse_args()

input1 = torch.randn(args.batch_size,
                     args.channel,
                     args.height,
                     args.width,
                     dtype=torch.float64,
                     device=torch.device(args.backend))
input2 = torch.randn(args.batch_size,
                     args.channel,
                     args.height,
                     args.width,
                     dtype=torch.float64,
                     device=torch.device(args.backend))

input1.requires_grad = True
input2.requires_grad = True

correlation_sampler = SpatialCorrelationSampler(args.kernel_size,
                                                args.patch,
                                                args.stride,
                                                args.pad,
                                                args.dilation,
                                                args.patch_dilation)


if gradcheck(correlation_sampler, [input1, input2]):
    print('Ok')
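# gradcheck compares analytical gradients against finite differences, which is
# why the inputs above are created in float64; float32 would typically fail the
# default tolerances.
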
aot/Pytorch-Correlation-extension/requirements.txt
ADDED
@@ -0,0 +1,2 @@
torch>=1.0.1
numpy

aot/Pytorch-Correlation-extension/setup.py
ADDED
@@ -0,0 +1,69 @@
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension
from os.path import join

CPU_ONLY = False
project_root = 'Correlation_Module'

source_files = ['correlation.cpp', 'correlation_sampler.cpp']

cxx_args = ['-std=c++17', '-fopenmp']


def generate_nvcc_args(gpu_archs):
    nvcc_args = []
    for arch in gpu_archs:
        nvcc_args.extend(['-gencode', f'arch=compute_{arch},code=sm_{arch}'])
    return nvcc_args


gpu_arch = os.environ.get('GPU_ARCH', '').split()
nvcc_args = generate_nvcc_args(gpu_arch)

with open("README.md", "r") as fh:
    long_description = fh.read()


def launch_setup():
    if CPU_ONLY:
        Extension = CppExtension
        macro = []
    else:
        Extension = CUDAExtension
        source_files.append('correlation_cuda_kernel.cu')
        macro = [("USE_CUDA", None)]

    sources = [join(project_root, file) for file in source_files]

    setup(
        name='spatial_correlation_sampler',
        version="0.4.0",
        author="Clément Pinard",
        author_email="[email protected]",
        description="Correlation module for pytorch",
        long_description=long_description,
        long_description_content_type="text/markdown",
        url="https://github.com/ClementPinard/Pytorch-Correlation-extension",
        install_requires=['torch>=1.1', 'numpy'],
        ext_modules=[
            Extension('spatial_correlation_sampler_backend',
                      sources,
                      define_macros=macro,
                      extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args},
                      extra_link_args=['-lgomp'])
        ],
        package_dir={'': project_root},
        packages=['spatial_correlation_sampler'],
        cmdclass={
            'build_ext': BuildExtension
        },
        classifiers=[
            "Programming Language :: Python :: 3",
            "License :: OSI Approved :: MIT License",
            "Operating System :: POSIX :: Linux",
            "Intended Audience :: Science/Research",
            "Topic :: Scientific/Engineering :: Artificial Intelligence"
        ])


if __name__ == '__main__':
    launch_setup()

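# Usage note: generate_nvcc_args reads the GPU_ARCH environment variable, so e.g.
#   GPU_ARCH="75 86" python setup.py install
# adds -gencode arch=compute_75,code=sm_75 and -gencode arch=compute_86,code=sm_86.
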
aot/Pytorch-Correlation-extension/setup_cpu.py
ADDED
@@ -0,0 +1,4 @@
# Reuses setup.py with the CPU-only flag flipped before launching the build.
import setup

setup.CPU_ONLY = True
setup.launch_setup()

aot/README.md
ADDED
@@ -0,0 +1,152 @@
# AOT Series Frameworks in PyTorch

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/decoupling-features-in-hierarchical/semi-supervised-video-object-segmentation-on-15)](https://paperswithcode.com/sota/semi-supervised-video-object-segmentation-on-15?p=decoupling-features-in-hierarchical)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/associating-objects-with-scalable/video-object-segmentation-on-youtube-vos)](https://paperswithcode.com/sota/video-object-segmentation-on-youtube-vos?p=associating-objects-with-scalable)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/associating-objects-with-scalable/semi-supervised-video-object-segmentation-on-18)](https://paperswithcode.com/sota/semi-supervised-video-object-segmentation-on-18?p=associating-objects-with-scalable)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/associating-objects-with-scalable/semi-supervised-video-object-segmentation-on-1)](https://paperswithcode.com/sota/semi-supervised-video-object-segmentation-on-1?p=associating-objects-with-scalable)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/associating-objects-with-scalable/visual-object-tracking-on-davis-2017)](https://paperswithcode.com/sota/visual-object-tracking-on-davis-2017?p=associating-objects-with-scalable)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/associating-objects-with-scalable/visual-object-tracking-on-davis-2016)](https://paperswithcode.com/sota/visual-object-tracking-on-davis-2016?p=associating-objects-with-scalable)

A modular reference PyTorch implementation of the AOT series frameworks:
- **DeAOT**: Decoupling Features in Hierarchical Propagation for Video Object Segmentation (NeurIPS 2022, Spotlight) [[OpenReview](https://openreview.net/forum?id=DgM7-7eMkq0)][[PDF](https://arxiv.org/pdf/2210.09782.pdf)]
<img src="source/overview_deaot.png" width="90%"/>

- **AOT**: Associating Objects with Transformers for Video Object Segmentation (NeurIPS 2021, Score 8/8/7/8) [[OpenReview](https://openreview.net/forum?id=hl3v8io3ZYt)][[PDF](https://arxiv.org/abs/2106.02638)]
<img src="source/overview.png" width="90%"/>

An extension of AOT, [AOST](https://arxiv.org/abs/2203.11442) (under review), is available now. AOST is a more robust and flexible framework, supporting run-time speed-accuracy trade-offs.

## Examples
Benchmark examples:

<img src="source/some_results.png" width="81%"/>

General examples (Messi and Kobe):

<img src="source/messi.gif" width="45%"/> <img src="source/kobe.gif" width="45%"/>

## Highlights
- **High performance:** up to **85.5%** ([R50-AOTL](MODEL_ZOO.md#youtube-vos-2018-val)) on YouTube-VOS 2018 and **82.1%** ([SwinB-AOTL](MODEL_ZOO.md#youtube-vos-2018-val)) on DAVIS-2017 Test-dev under standard settings (without any test-time augmentation or post-processing).
- **High efficiency:** up to **51fps** ([AOTT](MODEL_ZOO.md#davis-2017-test)) on DAVIS-2017 (480p), even with **10** objects, and **41fps** on YouTube-VOS (1.3x480p). AOT can process multiple objects (up to a pre-defined number, 10 by default) as efficiently as a single object. This project also supports inferring any number of objects together within a video by automatic separation and aggregation.
- **Multi-GPU training and inference**
- **Mixed precision training and inference**
- **Test-time augmentation:** multi-scale and flipping augmentations are supported.

## Requirements
* Python3
* pytorch >= 1.7.0 and torchvision
* opencv-python
* Pillow
* Pytorch Correlation (we recommend installing from [source](https://github.com/ClementPinard/Pytorch-Correlation-extension) instead of using `pip`. **The project can also work without this module but will lose some efficiency of the short-term attention**.)

Optional:
* scikit-image (required if you want to run our **Demo**)

## Model Zoo and Results
Pre-trained models, benchmark scores, and pre-computed results reproduced by this project can be found in [MODEL_ZOO.md](MODEL_ZOO.md).

## Demo - Panoptic Propagation
We provide a simple demo to demonstrate AOT's effectiveness. The demo propagates more than **40** objects, including semantic regions (like sky) and instances (like person), together within a single complex scenario, and predicts its video panoptic segmentation.

To run the demo, download the [checkpoint](https://drive.google.com/file/d/1qJDYn3Ibpquu4ffYoQmVjg1YCbr2JQep/view?usp=sharing) of R50-AOTL into [pretrain_models](pretrain_models), and then run:
```bash
python tools/demo.py
```
which will predict the given scenarios at a resolution of 1.3x480p. You can also run this demo with other AOT models ([MODEL_ZOO.md](MODEL_ZOO.md)) by setting `--model` (model type) and `--ckpt_path` (checkpoint path).

Two scenarios from [VSPW](https://www.vspwdataset.com/home) are supplied in [datasets/Demo](datasets/Demo):

- 1001_3iEIq5HBY1s: 44 objects. 1080P.
- 1007_YCTBBdbKSSg: 43 objects. 1080P.

Results:

<img src="source/1001_3iEIq5HBY1s.gif" width="45%"/> <img src="source/1007_YCTBBdbKSSg.gif" width="45%"/>


## Getting Started
0. Prepare a valid environment following the [requirements](#requirements).

1. Prepare datasets:

    Please follow the instructions below to prepare the datasets in their corresponding folders.
    * **Static**

        [datasets/Static](datasets/Static): pre-training dataset with static images. Guidance can be found in [AFB-URR](https://github.com/xmlyqing00/AFB-URR), which we referred to in the implementation of the pre-training.
    * **YouTube-VOS**

        A commonly-used large-scale VOS dataset.

        [datasets/YTB/2019](datasets/YTB/2019): version 2019, download [link](https://drive.google.com/drive/folders/1BWzrCWyPEmBEKm0lOHe5KLuBuQxUSwqz?usp=sharing). `train` is required for training. `valid` (6fps) and `valid_all_frames` (30fps, optional) are used for evaluation.

        [datasets/YTB/2018](datasets/YTB/2018): version 2018, download [link](https://drive.google.com/drive/folders/1bI5J1H3mxsIGo7Kp-pPZU8i6rnykOw7f?usp=sharing). Only `valid` (6fps) and `valid_all_frames` (30fps, optional) are required for this project and used for evaluation.

    * **DAVIS**

        A commonly-used small-scale VOS dataset.

        [datasets/DAVIS](datasets/DAVIS): [TrainVal](https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip) (480p) contains both the training and validation splits. [Test-Dev](https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-test-dev-480p.zip) (480p) contains the Test-dev split. The [full-resolution version](https://davischallenge.org/davis2017/code.html) is also supported for training and evaluation but is not required.


2. Prepare ImageNet pre-trained encoders

    Select and download the checkpoints below into [pretrain_models](pretrain_models):

    - [MobileNet-V2](https://download.pytorch.org/models/mobilenet_v2-b0353104.pth) (default encoder)
    - [MobileNet-V3](https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth)
    - [ResNet-50](https://download.pytorch.org/models/resnet50-0676ba61.pth)
    - [ResNet-101](https://download.pytorch.org/models/resnet101-63fe2227.pth)
    - [ResNeSt-50](https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/resnest50-528c19ca.pth)
    - [ResNeSt-101](https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/resnest101-22405ba7.pth)
    - [Swin-Base](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth)

    The current default training configs are not optimized for encoders larger than ResNet-50. If you want to use larger encoders, we recommend early-stopping the main-training stage at 80,000 iterations (100,000 by default) to avoid over-fitting on the seen classes of YouTube-VOS.



3. Training and Evaluation

    The [example script](train_eval.sh) will train AOTT in 2 stages using 4 GPUs and automatic mixed precision (`--amp`). The first stage is a pre-training stage using the `Static` dataset, and the second stage is a main-training stage, which uses both `YouTube-VOS 2019 train` and `DAVIS-2017 train` for training, resulting in a model that can generalize to different domains (YouTube-VOS and DAVIS) and different frame rates (6fps, 24fps, and 30fps).

    Notably, you can use only the `YouTube-VOS 2019 train` split in the second stage by changing `pre_ytb_dav` to `pre_ytb`, which leads to better YouTube-VOS performance on unseen classes. Besides, if you don't want to run the first stage, you can start training from the `ytb` stage, but the performance will drop by about 1-2% (absolute). The sketch below shows how these stage names map to config classes.
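
    A minimal sketch of that mapping, assuming it is run from the `aot/` directory added in this commit (the stage and model names come from the `configs/` files shown later in this diff):

    ```python
    # Each stage config subclasses DefaultEngineConfig, which imports a
    # ModelConfig by module name, so changing the stage keeps the model
    # definition unchanged.
    from configs.pre_ytb import EngineConfig  # the 'PRE_YTB' stage

    cfg = EngineConfig(exp_name='debug', model='r50_aotl')
    # Caution: the constructor calls init_dir(), which creates result directories.
    print(cfg.STAGE_NAME, cfg.MODEL_NAME, cfg.DATASETS)  # PRE_YTB R50_AOTL ['youtubevos']
    ```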

    After the training is finished (about 0.6 days for each stage with 4 Tesla V100 GPUs), the [example script](train_eval.sh) will evaluate the model on YouTube-VOS and DAVIS, and the results will be packed into Zip files. For calculating scores, please use the official YouTube-VOS servers ([2018 server](https://competitions.codalab.org/competitions/19544) and [2019 server](https://competitions.codalab.org/competitions/20127)), the official [DAVIS toolkit](https://github.com/davisvideochallenge/davis-2017) (for Val), and the official [DAVIS server](https://competitions.codalab.org/competitions/20516#learn_the_details) (for Test-dev).

## Adding your own dataset
Coming soon.

## Troubleshooting
Coming soon.

## TODO
- [ ] Code documentation
- [ ] Adding your own dataset
- [ ] Results with test-time augmentations in Model Zoo
- [ ] Support gradient accumulation
- [x] Demo tool

## Citations
Please consider citing the related paper(s) in your publications if they help your research.
```
@inproceedings{yang2022deaot,
  title={Decoupling Features in Hierarchical Propagation for Video Object Segmentation},
  author={Yang, Zongxin and Yang, Yi},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2022}
}
@article{yang2021aost,
  title={Scalable Multi-object Identification for Video Object Segmentation},
  author={Yang, Zongxin and Miao, Jiaxu and Wang, Xiaohan and Wei, Yunchao and Yang, Yi},
  journal={arXiv preprint arXiv:2203.11442},
  year={2022}
}
@inproceedings{yang2021aot,
  title={Associating Objects with Transformers for Video Object Segmentation},
  author={Yang, Zongxin and Wei, Yunchao and Yang, Yi},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2021}
}
```

## License
This project is released under the BSD-3-Clause license. See [LICENSE](LICENSE) for additional details.

aot/__init__.py
ADDED
File without changes

aot/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (181 Bytes)

aot/configs/__pycache__/default.cpython-310.pyc
ADDED
Binary file (4.29 kB)

aot/configs/__pycache__/pre_ytb_dav.cpython-310.pyc
ADDED
Binary file (943 Bytes)

aot/configs/default.py
ADDED
@@ -0,0 +1,138 @@
import os
import importlib


class DefaultEngineConfig():
    def __init__(self, exp_name='default', model='aott'):
        model_cfg = importlib.import_module('configs.models.' +
                                            model).ModelConfig()
        self.__dict__.update(model_cfg.__dict__)  # add model config

        self.EXP_NAME = exp_name + '_' + self.MODEL_NAME

        self.STAGE_NAME = 'YTB'

        self.DATASETS = ['youtubevos']
        self.DATA_WORKERS = 8
        self.DATA_RANDOMCROP = (465, 465) if self.MODEL_ALIGN_CORNERS else (464, 464)
        self.DATA_RANDOMFLIP = 0.5
        self.DATA_MAX_CROP_STEPS = 10
        self.DATA_SHORT_EDGE_LEN = 480
        self.DATA_MIN_SCALE_FACTOR = 0.7
        self.DATA_MAX_SCALE_FACTOR = 1.3
        self.DATA_RANDOM_REVERSE_SEQ = True
        self.DATA_SEQ_LEN = 5
        self.DATA_DAVIS_REPEAT = 5
        self.DATA_RANDOM_GAP_DAVIS = 12  # max frame interval between two sampled frames for DAVIS (24fps)
        self.DATA_RANDOM_GAP_YTB = 3  # max frame interval between two sampled frames for YouTube-VOS (6fps)
        self.DATA_DYNAMIC_MERGE_PROB = 0.3

        self.PRETRAIN = True
        self.PRETRAIN_FULL = False  # if False, load encoder only
        self.PRETRAIN_MODEL = './data_wd/pretrain_model/mobilenet_v2.pth'
        # self.PRETRAIN_MODEL = './pretrain_models/mobilenet_v2-b0353104.pth'

        self.TRAIN_TOTAL_STEPS = 100000
        self.TRAIN_START_STEP = 0
        self.TRAIN_WEIGHT_DECAY = 0.07
        self.TRAIN_WEIGHT_DECAY_EXCLUSIVE = {
            # 'encoder.': 0.01
        }
        self.TRAIN_WEIGHT_DECAY_EXEMPTION = [
            'absolute_pos_embed', 'relative_position_bias_table',
            'relative_emb_v', 'conv_out'
        ]
        self.TRAIN_LR = 2e-4
        self.TRAIN_LR_MIN = 2e-5 if 'mobilenetv2' in self.MODEL_ENCODER else 1e-5
        self.TRAIN_LR_POWER = 0.9
        self.TRAIN_LR_ENCODER_RATIO = 0.1
        self.TRAIN_LR_WARM_UP_RATIO = 0.05
        self.TRAIN_LR_COSINE_DECAY = False
        self.TRAIN_LR_RESTART = 1
        self.TRAIN_LR_UPDATE_STEP = 1
        self.TRAIN_AUX_LOSS_WEIGHT = 1.0
        self.TRAIN_AUX_LOSS_RATIO = 1.0
        self.TRAIN_OPT = 'adamw'
        self.TRAIN_SGD_MOMENTUM = 0.9
        self.TRAIN_GPUS = 4
        self.TRAIN_BATCH_SIZE = 16
        self.TRAIN_TBLOG = False
        self.TRAIN_TBLOG_STEP = 50
        self.TRAIN_LOG_STEP = 20
        self.TRAIN_IMG_LOG = True
        self.TRAIN_TOP_K_PERCENT_PIXELS = 0.15
        self.TRAIN_SEQ_TRAINING_FREEZE_PARAMS = ['patch_wise_id_bank']
        self.TRAIN_SEQ_TRAINING_START_RATIO = 0.5
        self.TRAIN_HARD_MINING_RATIO = 0.5
        self.TRAIN_EMA_RATIO = 0.1
        self.TRAIN_CLIP_GRAD_NORM = 5.
        self.TRAIN_SAVE_STEP = 5000
        self.TRAIN_MAX_KEEP_CKPT = 8
        self.TRAIN_RESUME = False
        self.TRAIN_RESUME_CKPT = None
        self.TRAIN_RESUME_STEP = 0
        self.TRAIN_AUTO_RESUME = True
        self.TRAIN_DATASET_FULL_RESOLUTION = False
        self.TRAIN_ENABLE_PREV_FRAME = False
        self.TRAIN_ENCODER_FREEZE_AT = 2
        self.TRAIN_LSTT_EMB_DROPOUT = 0.
        self.TRAIN_LSTT_ID_DROPOUT = 0.
        self.TRAIN_LSTT_DROPPATH = 0.1
        self.TRAIN_LSTT_DROPPATH_SCALING = False
        self.TRAIN_LSTT_DROPPATH_LST = False
        self.TRAIN_LSTT_LT_DROPOUT = 0.
        self.TRAIN_LSTT_ST_DROPOUT = 0.

        self.TEST_GPU_ID = 0
        self.TEST_GPU_NUM = 1
        self.TEST_FRAME_LOG = False
        self.TEST_DATASET = 'youtubevos'
        self.TEST_DATASET_FULL_RESOLUTION = False
        self.TEST_DATASET_SPLIT = 'val'
        self.TEST_CKPT_PATH = None
        # if None, evaluate the latest checkpoint.
        self.TEST_CKPT_STEP = None
        self.TEST_FLIP = False
        self.TEST_MULTISCALE = [1]
        self.TEST_MAX_SHORT_EDGE = None
        self.TEST_MAX_LONG_EDGE = 800 * 1.3
        self.TEST_WORKERS = 4

        # GPU distribution
        self.DIST_ENABLE = True
        self.DIST_BACKEND = "nccl"  # or "gloo"
        self.DIST_URL = "tcp://127.0.0.1:13241"
        self.DIST_START_GPU = 0

    def init_dir(self):
        self.DIR_DATA = '../VOS02/datasets'  # './datasets'
        self.DIR_DAVIS = os.path.join(self.DIR_DATA, 'DAVIS')
        self.DIR_YTB = os.path.join(self.DIR_DATA, 'YTB')
        self.DIR_STATIC = os.path.join(self.DIR_DATA, 'Static')

        self.DIR_ROOT = './'  # './data_wd/youtube_vos_jobs'

        self.DIR_RESULT = os.path.join(self.DIR_ROOT, 'result', self.EXP_NAME,
                                       self.STAGE_NAME)
        self.DIR_CKPT = os.path.join(self.DIR_RESULT, 'ckpt')
        self.DIR_EMA_CKPT = os.path.join(self.DIR_RESULT, 'ema_ckpt')
        self.DIR_LOG = os.path.join(self.DIR_RESULT, 'log')
        self.DIR_TB_LOG = os.path.join(self.DIR_RESULT, 'log', 'tensorboard')
        # self.DIR_IMG_LOG = os.path.join(self.DIR_RESULT, 'log', 'img')
        # self.DIR_EVALUATION = os.path.join(self.DIR_RESULT, 'eval')
        self.DIR_IMG_LOG = './img_logs'
        self.DIR_EVALUATION = './results'

        for path in [
                self.DIR_RESULT, self.DIR_CKPT, self.DIR_EMA_CKPT,
                self.DIR_LOG, self.DIR_EVALUATION, self.DIR_IMG_LOG,
                self.DIR_TB_LOG
        ]:
            if not os.path.isdir(path):
                try:
                    os.makedirs(path)
                except Exception as inst:
                    print(inst)
                    print('Failed to make dir: {}.'.format(path))

aot/configs/models/__pycache__/default.cpython-310.pyc
ADDED
Binary file (1.22 kB)

aot/configs/models/__pycache__/r50_aotl.cpython-310.pyc
ADDED
Binary file (873 Bytes)

aot/configs/models/aotb.py
ADDED
@@ -0,0 +1,9 @@
from .default import DefaultModelConfig


class ModelConfig(DefaultModelConfig):
    def __init__(self):
        super().__init__()
        self.MODEL_NAME = 'AOTB'

        self.MODEL_LSTT_NUM = 3

aot/configs/models/aotl.py
ADDED
@@ -0,0 +1,13 @@
from .default import DefaultModelConfig


class ModelConfig(DefaultModelConfig):
    def __init__(self):
        super().__init__()
        self.MODEL_NAME = 'AOTL'

        self.MODEL_LSTT_NUM = 3

        self.TRAIN_LONG_TERM_MEM_GAP = 2

        self.TEST_LONG_TERM_MEM_GAP = 5

aot/configs/models/aots.py
ADDED
@@ -0,0 +1,9 @@
from .default import DefaultModelConfig


class ModelConfig(DefaultModelConfig):
    def __init__(self):
        super().__init__()
        self.MODEL_NAME = 'AOTS'

        self.MODEL_LSTT_NUM = 2

aot/configs/models/aott.py
ADDED
@@ -0,0 +1,7 @@
from .default import DefaultModelConfig


class ModelConfig(DefaultModelConfig):
    def __init__(self):
        super().__init__()
        self.MODEL_NAME = 'AOTT'

aot/configs/models/deaotb.py
ADDED
@@ -0,0 +1,9 @@
from .default_deaot import DefaultModelConfig


class ModelConfig(DefaultModelConfig):
    def __init__(self):
        super().__init__()
        self.MODEL_NAME = 'DeAOTB'

        self.MODEL_LSTT_NUM = 3

aot/configs/models/deaotl.py
ADDED
@@ -0,0 +1,13 @@
from .default_deaot import DefaultModelConfig


class ModelConfig(DefaultModelConfig):
    def __init__(self):
        super().__init__()
        self.MODEL_NAME = 'DeAOTL'

        self.MODEL_LSTT_NUM = 3

        self.TRAIN_LONG_TERM_MEM_GAP = 2

        self.TEST_LONG_TERM_MEM_GAP = 5

aot/configs/models/deaots.py
ADDED
@@ -0,0 +1,9 @@
from .default_deaot import DefaultModelConfig


class ModelConfig(DefaultModelConfig):
    def __init__(self):
        super().__init__()
        self.MODEL_NAME = 'DeAOTS'

        self.MODEL_LSTT_NUM = 2

aot/configs/models/deaott.py
ADDED
@@ -0,0 +1,7 @@
from .default_deaot import DefaultModelConfig


class ModelConfig(DefaultModelConfig):
    def __init__(self):
        super().__init__()
        self.MODEL_NAME = 'DeAOTT'

aot/configs/models/default.py
ADDED
@@ -0,0 +1,27 @@
class DefaultModelConfig():
    def __init__(self):
        self.MODEL_NAME = 'AOTDefault'

        self.MODEL_VOS = 'aot'
        self.MODEL_ENGINE = 'aotengine'
        self.MODEL_ALIGN_CORNERS = True
        self.MODEL_ENCODER = 'mobilenetv2'
        self.MODEL_ENCODER_PRETRAIN = './pretrain_models/mobilenet_v2-b0353104.pth'
        self.MODEL_ENCODER_DIM = [24, 32, 96, 1280]  # 4x, 8x, 16x, 16x
        self.MODEL_ENCODER_EMBEDDING_DIM = 256
        self.MODEL_DECODER_INTERMEDIATE_LSTT = True
        self.MODEL_FREEZE_BN = True
        self.MODEL_FREEZE_BACKBONE = False
        self.MODEL_MAX_OBJ_NUM = 10
        self.MODEL_SELF_HEADS = 8
        self.MODEL_ATT_HEADS = 8
        self.MODEL_LSTT_NUM = 1
        self.MODEL_EPSILON = 1e-5
        self.MODEL_USE_PREV_PROB = False

        self.TRAIN_LONG_TERM_MEM_GAP = 9999
        self.TRAIN_AUG_TYPE = 'v1'

        self.TEST_LONG_TERM_MEM_GAP = 9999

        self.TEST_SHORT_TERM_MEM_SKIP = 1

aot/configs/models/default_deaot.py
ADDED
@@ -0,0 +1,17 @@
from .default import DefaultModelConfig as BaseConfig


class DefaultModelConfig(BaseConfig):
    def __init__(self):
        super().__init__()
        self.MODEL_NAME = 'DeAOTDefault'

        self.MODEL_VOS = 'deaot'
        self.MODEL_ENGINE = 'deaotengine'

        self.MODEL_DECODER_INTERMEDIATE_LSTT = False

        self.MODEL_SELF_HEADS = 1
        self.MODEL_ATT_HEADS = 1

        self.TRAIN_AUG_TYPE = 'v2'

aot/configs/models/r101_aotl.py
ADDED
@@ -0,0 +1,16 @@
from .default import DefaultModelConfig


class ModelConfig(DefaultModelConfig):
    def __init__(self):
        super().__init__()
        self.MODEL_NAME = 'R101_AOTL'

        self.MODEL_ENCODER = 'resnet101'
        self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnet101-63fe2227.pth'  # https://download.pytorch.org/models/resnet101-63fe2227.pth
        self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024]  # 4x, 8x, 16x, 16x
        self.MODEL_LSTT_NUM = 3

        self.TRAIN_LONG_TERM_MEM_GAP = 2

        self.TEST_LONG_TERM_MEM_GAP = 5

aot/configs/models/r50_aotl.py
ADDED
@@ -0,0 +1,16 @@
from .default import DefaultModelConfig


class ModelConfig(DefaultModelConfig):
    def __init__(self):
        super().__init__()
        self.MODEL_NAME = 'R50_AOTL'

        self.MODEL_ENCODER = 'resnet50'
        self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnet50-0676ba61.pth'  # https://download.pytorch.org/models/resnet50-0676ba61.pth
        self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024]  # 4x, 8x, 16x, 16x
        self.MODEL_LSTT_NUM = 3

        self.TRAIN_LONG_TERM_MEM_GAP = 2

        self.TEST_LONG_TERM_MEM_GAP = 5

aot/configs/models/r50_deaotl.py
ADDED
@@ -0,0 +1,16 @@
from .default_deaot import DefaultModelConfig


class ModelConfig(DefaultModelConfig):
    def __init__(self):
        super().__init__()
        self.MODEL_NAME = 'R50_DeAOTL'

        self.MODEL_ENCODER = 'resnet50'
        self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024]  # 4x, 8x, 16x, 16x

        self.MODEL_LSTT_NUM = 3

        self.TRAIN_LONG_TERM_MEM_GAP = 2

        self.TEST_LONG_TERM_MEM_GAP = 5

|
aot/configs/models/rs101_aotl.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from .default import DefaultModelConfig


class ModelConfig(DefaultModelConfig):
    def __init__(self):
        super().__init__()
        self.MODEL_NAME = 'R101_AOTL'

        self.MODEL_ENCODER = 'resnest101'
        self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnest101-22405ba7.pth'  # https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/resnest101-22405ba7.pth
        self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024]  # 4x, 8x, 16x, 16x
        self.MODEL_LSTT_NUM = 3

        self.TRAIN_LONG_TERM_MEM_GAP = 2

        self.TEST_LONG_TERM_MEM_GAP = 5

|
aot/configs/models/swinb_aotl.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from .default import DefaultModelConfig


class ModelConfig(DefaultModelConfig):
    def __init__(self):
        super().__init__()
        self.MODEL_NAME = 'SwinB_AOTL'

        self.MODEL_ENCODER = 'swin_base'
        self.MODEL_ENCODER_PRETRAIN = './pretrain_models/swin_base_patch4_window7_224_22k.pth'  # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth
        self.MODEL_ALIGN_CORNERS = False
        self.MODEL_ENCODER_DIM = [128, 256, 512, 512]  # 4x, 8x, 16x, 16x
        self.MODEL_LSTT_NUM = 3

        self.TRAIN_LONG_TERM_MEM_GAP = 2

        self.TEST_LONG_TERM_MEM_GAP = 5

|
aot/configs/models/swinb_deaotl.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from .default_deaot import DefaultModelConfig


class ModelConfig(DefaultModelConfig):
    def __init__(self):
        super().__init__()
        self.MODEL_NAME = 'SwinB_DeAOTL'

        self.MODEL_ENCODER = 'swin_base'
        self.MODEL_ALIGN_CORNERS = False
        self.MODEL_ENCODER_DIM = [128, 256, 512, 512]  # 4x, 8x, 16x, 16x

        self.MODEL_LSTT_NUM = 3

        self.TRAIN_LONG_TERM_MEM_GAP = 2

        self.TEST_LONG_TERM_MEM_GAP = 5

|
aot/configs/pre.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from .default import DefaultEngineConfig


class EngineConfig(DefaultEngineConfig):
    def __init__(self, exp_name='default', model='AOTT'):
        super().__init__(exp_name, model)
        self.STAGE_NAME = 'PRE'

        self.init_dir()

        self.DATASETS = ['static']

        self.DATA_DYNAMIC_MERGE_PROB = 1.0

        self.TRAIN_LR = 4e-4
        self.TRAIN_LR_MIN = 2e-5
        self.TRAIN_WEIGHT_DECAY = 0.03
        self.TRAIN_SEQ_TRAINING_START_RATIO = 1.0
        self.TRAIN_AUX_LOSS_RATIO = 0.1

aot/configs/pre_dav.py
ADDED
@@ -0,0 +1,21 @@
import os
from .default import DefaultEngineConfig


class EngineConfig(DefaultEngineConfig):
    def __init__(self, exp_name='default', model='AOTT'):
        super().__init__(exp_name, model)
        self.STAGE_NAME = 'PRE_DAV'

        self.init_dir()

        self.DATASETS = ['davis2017']

        self.TRAIN_TOTAL_STEPS = 50000

        pretrain_stage = 'PRE'
        pretrain_ckpt = 'save_step_100000.pth'
        self.PRETRAIN_FULL = True  # if False, load encoder only
        self.PRETRAIN_MODEL = os.path.join(self.DIR_ROOT, 'result',
                                           self.EXP_NAME, pretrain_stage,
                                           'ema_ckpt', pretrain_ckpt)

aot/configs/pre_ytb.py
ADDED
@@ -0,0 +1,17 @@
import os
from .default import DefaultEngineConfig


class EngineConfig(DefaultEngineConfig):
    def __init__(self, exp_name='default', model='AOTT'):
        super().__init__(exp_name, model)
        self.STAGE_NAME = 'PRE_YTB'

        self.init_dir()

        pretrain_stage = 'PRE'
        pretrain_ckpt = 'save_step_100000.pth'
        self.PRETRAIN_FULL = True  # if False, load encoder only
        self.PRETRAIN_MODEL = os.path.join(self.DIR_ROOT, 'result',
                                           self.EXP_NAME, pretrain_stage,
                                           'ema_ckpt', pretrain_ckpt)