aikenml committed
Commit a69d385 · 1 Parent(s): e4ff10b

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .DS_Store +0 -0
  2. .gitattributes +4 -0
  3. __pycache__/SegTracker.cpython-310.pyc +0 -0
  4. __pycache__/aot_tracker.cpython-310.pyc +0 -0
  5. __pycache__/model_args.cpython-310.pyc +0 -0
  6. __pycache__/seg_track_anything.cpython-310.pyc +0 -0
  7. aot/.DS_Store +0 -0
  8. aot/LICENSE +29 -0
  9. aot/MODEL_ZOO.md +115 -0
  10. aot/Pytorch-Correlation-extension/.gitignore +1 -0
  11. aot/Pytorch-Correlation-extension/Correlation_Module/correlation.cpp +178 -0
  12. aot/Pytorch-Correlation-extension/Correlation_Module/correlation_cuda_kernel.cu +327 -0
  13. aot/Pytorch-Correlation-extension/Correlation_Module/correlation_sampler.cpp +138 -0
  14. aot/Pytorch-Correlation-extension/Correlation_Module/spatial_correlation_sampler/__init__.py +1 -0
  15. aot/Pytorch-Correlation-extension/Correlation_Module/spatial_correlation_sampler/spatial_correlation_sampler.py +107 -0
  16. aot/Pytorch-Correlation-extension/LICENSE +21 -0
  17. aot/Pytorch-Correlation-extension/README.md +155 -0
  18. aot/Pytorch-Correlation-extension/benchmark.py +90 -0
  19. aot/Pytorch-Correlation-extension/check.py +119 -0
  20. aot/Pytorch-Correlation-extension/grad_check.py +47 -0
  21. aot/Pytorch-Correlation-extension/requirements.txt +2 -0
  22. aot/Pytorch-Correlation-extension/setup.py +69 -0
  23. aot/Pytorch-Correlation-extension/setup_cpu.py +4 -0
  24. aot/README.md +152 -0
  25. aot/__init__.py +0 -0
  26. aot/__pycache__/__init__.cpython-310.pyc +0 -0
  27. aot/configs/__pycache__/default.cpython-310.pyc +0 -0
  28. aot/configs/__pycache__/pre_ytb_dav.cpython-310.pyc +0 -0
  29. aot/configs/default.py +138 -0
  30. aot/configs/models/__pycache__/default.cpython-310.pyc +0 -0
  31. aot/configs/models/__pycache__/r50_aotl.cpython-310.pyc +0 -0
  32. aot/configs/models/aotb.py +9 -0
  33. aot/configs/models/aotl.py +13 -0
  34. aot/configs/models/aots.py +9 -0
  35. aot/configs/models/aott.py +7 -0
  36. aot/configs/models/deaotb.py +9 -0
  37. aot/configs/models/deaotl.py +13 -0
  38. aot/configs/models/deaots.py +9 -0
  39. aot/configs/models/deaott.py +7 -0
  40. aot/configs/models/default.py +27 -0
  41. aot/configs/models/default_deaot.py +17 -0
  42. aot/configs/models/r101_aotl.py +16 -0
  43. aot/configs/models/r50_aotl.py +16 -0
  44. aot/configs/models/r50_deaotl.py +16 -0
  45. aot/configs/models/rs101_aotl.py +16 -0
  46. aot/configs/models/swinb_aotl.py +17 -0
  47. aot/configs/models/swinb_deaotl.py +17 -0
  48. aot/configs/pre.py +19 -0
  49. aot/configs/pre_dav.py +21 -0
  50. aot/configs/pre_ytb.py +17 -0
.DS_Store ADDED
Binary file (10.2 kB)
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/cars.mp4 filter=lfs diff=lfs merge=lfs -text
+ assets/cell.mp4 filter=lfs diff=lfs merge=lfs -text
+ assets/demo_3x2.gif filter=lfs diff=lfs merge=lfs -text
+ assets/top.gif filter=lfs diff=lfs merge=lfs -text
__pycache__/SegTracker.cpython-310.pyc ADDED
Binary file (8.05 kB)

__pycache__/aot_tracker.cpython-310.pyc ADDED
Binary file (5.71 kB)

__pycache__/model_args.cpython-310.pyc ADDED
Binary file (772 Bytes)

__pycache__/seg_track_anything.cpython-310.pyc ADDED
Binary file (6.5 kB)

aot/.DS_Store ADDED
Binary file (8.2 kB)
aot/LICENSE ADDED
@@ -0,0 +1,29 @@
+ BSD 3-Clause License
+
+ Copyright (c) 2020, z-x-yang
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ 3. Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
aot/MODEL_ZOO.md ADDED
@@ -0,0 +1,115 @@
+ ## Model Zoo and Results
+
+ ### Environment and Settings
+ - 4/1 NVIDIA V100 GPUs for training/evaluation.
+ - Auto-mixed precision was enabled in training but disabled in evaluation.
+ - Test-time augmentations were not used.
+ - The inference resolution of DAVIS/YouTube-VOS was 480p/1.3x480p, following [CFBI](https://github.com/z-x-yang/CFBI).
+ - Fully online inference: all modules are applied frame by frame.
+ - Multi-object FPS was recorded instead of single-object FPS.
+
+ ### Pre-trained Models
+ Stages:
+
+ - `PRE`: the pre-training stage with static images.
+
+ - `PRE_YTB_DAV`: the main-training stage with YouTube-VOS and DAVIS. All evaluations share an **identical** model with the **same** parameters.
+
+
+ | Model | Param (M) | PRE | PRE_YTB_DAV |
+ |:---------- |:---------:|:--------------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------:|
+ | AOTT | 5.7 | [gdrive](https://drive.google.com/file/d/1_513h8Hok9ySQPMs_dHgX5sPexUhyCmy/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1owPmwV4owd_ll6GuilzklqTyAd0ZvbCu/view?usp=sharing) |
+ | AOTS | 7.0 | [gdrive](https://drive.google.com/file/d/1QUP0-VED-lOF1oX_ppYWnXyBjvUzJJB7/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1beU5E6Mdnr_pPrgjWvdWurKAIwJSz1xf/view?usp=sharing) |
+ | AOTB | 8.3 | [gdrive](https://drive.google.com/file/d/11Bx8n_INAha1IdpHjueGpf7BrKmCJDvK/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1hH-GOn4GAxHkV8ARcQzsUy8Ax6ndot-A/view?usp=sharing) |
+ | AOTL | 8.3 | [gdrive](https://drive.google.com/file/d/1WL6QCsYeT7Bt-Gain9ZIrNNXpR2Hgh29/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1L1N2hkSPqrwGgnW9GyFHuG59_EYYfTG4/view?usp=sharing) |
+ | R50-AOTL | 14.9 | [gdrive](https://drive.google.com/file/d/1hS4JIvOXeqvbs-CokwV6PwZV-EvzE6x8/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1qJDYn3Ibpquu4ffYoQmVjg1YCbr2JQep/view?usp=sharing) |
+ | SwinB-AOTL | 65.4 | [gdrive](https://drive.google.com/file/d/1LlhKQiXD8JyZGGs3hZiNzcaCLqyvL9tj/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/192jCGQZdnuTsvX-CVra-KVZl2q1ZR0vW/view?usp=sharing) |
+
+ | Model | Param (M) | PRE | PRE_YTB_DAV |
+ |:---------- |:---------:|:--------------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------:|
+ | DeAOTT | 7.2 | [gdrive](https://drive.google.com/file/d/11C1ZBoFpL3ztKtINS8qqwPSldfYXexFK/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1ThWIZQS03cYWx1EKNN8MIMnJS5eRowzr/view?usp=sharing) |
+ | DeAOTS | 10.2 | [gdrive](https://drive.google.com/file/d/1uUidrWVoaP9A5B5-EzQLbielUnRLRF3j/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1YwIAV5tBtn5spSFxKLBQBEQGwPHyQlHi/view?usp=sharing) |
+ | DeAOTB | 13.2 | [gdrive](https://drive.google.com/file/d/1bEQr6vIgQMVITrSOtxWTMgycKpS0cor9/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1BHxsonnvJXylqHlZ1zJHHc-ymKyq-CFf/view?usp=sharing) |
+ | DeAOTL | 13.2 | [gdrive](https://drive.google.com/file/d/1_vBL4KJlmBy0oBE4YFDOvsYL1ZtpEL32/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/18elNz_wi9JyVBcIUYKhRdL08MA-FqHD5/view?usp=sharing) |
+ | R50-DeAOTL | 19.8 | [gdrive](https://drive.google.com/file/d/1sTRQ1g0WCpqVCdavv7uJiZNkXunBt3-R/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1QoChMkTVxdYZ_eBlZhK2acq9KMQZccPJ/view?usp=sharing) |
+ | SwinB-DeAOTL | 70.3 | [gdrive](https://drive.google.com/file/d/16BZEE53no8CxT-pPLDC2q1d6Xlg8mWPU/view?usp=sharing) | [gdrive](https://drive.google.com/file/d/1g4E-F0RPOx9Nd6J7tU9AE1TjsouL4oZq/view?usp=sharing) |
+
+ To run inference with a pre-trained model, set `--model` and `--ckpt_path` to the downloaded checkpoint's model type and file path when running `eval.py`.
+
+ ### YouTube-VOS 2018 val
+ `ALL-F`: all frames. The default evaluation setting of YouTube-VOS is 6fps, but 30fps sequences (all the frames) are also supplied by the dataset organizers. We noticed that many VOS methods prefer to evaluate with 30fps videos, so we also supply those results here. Denser video sequences can significantly improve VOS performance when using a memory reading strategy (as in AOTL, R50-AOTL, and SwinB-AOTL), but efficiency drops because more memorized frames are stored for object matching.
+ | Model | Stage | FPS | All-F | Mean | J Seen | F Seen | J Unseen | F Unseen | Predictions |
+ |:------------ |:-----------:|:--------:|:-----:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------------------------------------------------------------------------------------------:|
+ | AOTT | PRE_YTB_DAV | 41.0 | | 80.2 | 80.4 | 85.0 | 73.6 | 81.7 | [gdrive](https://drive.google.com/file/d/1u8mvPRT08ENZHsw9Xf_4C6Sv9BoCzENR/view?usp=sharing) |
+ | AOTT | PRE_YTB_DAV | 41.0 | √ | 80.9 | 80.0 | 84.7 | 75.2 | 83.5 | [gdrive](https://drive.google.com/file/d/1RGMI5-29Z0odq73rt26eCxOUYUd-fvVv/view?usp=sharing) |
+ | DeAOTT | PRE_YTB_DAV | **53.4** | | **82.0** | **81.6** | **86.3** | **75.8** | **84.2** | - |
+ | AOTS | PRE_YTB_DAV | 27.1 | | 82.9 | 82.3 | 87.0 | 77.1 | 85.1 | [gdrive](https://drive.google.com/file/d/1a4-rNnxjMuPBq21IKo31WDYZXMPgS7r2/view?usp=sharing) |
+ | AOTS | PRE_YTB_DAV | 27.1 | √ | 83.0 | 82.2 | 87.0 | 77.3 | 85.7 | [gdrive](https://drive.google.com/file/d/1Z0cndyoCw5Na6u-VFRE8CyiIG2RbMIUO/view?usp=sharing) |
+ | DeAOTS | PRE_YTB_DAV | **38.7** | | **84.0** | **83.3** | **88.3** | **77.9** | **86.6** | - |
+ | AOTB | PRE_YTB_DAV | 20.5 | | 84.0 | 83.2 | 88.1 | 78.0 | 86.5 | [gdrive](https://drive.google.com/file/d/1J5nhuQbbjVLYNXViBIgo21ddQy-MiOLG/view?usp=sharing) |
+ | AOTB | PRE_YTB_DAV | 20.5 | √ | 84.1 | 83.6 | 88.5 | 78.0 | 86.5 | [gdrive](https://drive.google.com/file/d/1gFaweB_GTJjHzSD61v_ZsY9K7UEND30O/view?usp=sharing) |
+ | DeAOTB | PRE_YTB_DAV | **30.4** | | **84.6** | **83.9** | **88.9** | **78.5** | **87.0** | - |
+ | AOTL | PRE_YTB_DAV | 16.0 | | 84.1 | 83.2 | 88.2 | 78.2 | 86.8 | [gdrive](https://drive.google.com/file/d/1kS8KWQ2L3wzxt44ROLTxwZOT7ZpT8Igc/view?usp=sharing) |
+ | AOTL | PRE_YTB_DAV | 6.5 | √ | 84.5 | 83.7 | 88.8 | 78.4 | **87.1** | [gdrive](https://drive.google.com/file/d/1Rpm3e215kJOUvb562lJ2kYg2I3hkrxiM/view?usp=sharing) |
+ | DeAOTL | PRE_YTB_DAV | **24.7** | | **84.8** | **84.2** | **89.4** | **78.6** | 87.0 | - |
+ | R50-AOTL | PRE_YTB_DAV | 14.9 | | 84.6 | 83.7 | 88.5 | 78.8 | 87.3 | [gdrive](https://drive.google.com/file/d/1nbJZ1bbmEgyK-bg6HQ8LwCz5gVJ6wzIZ/view?usp=sharing) |
+ | R50-AOTL | PRE_YTB_DAV | 6.4 | √ | 85.5 | 84.5 | 89.5 | 79.6 | 88.2 | [gdrive](https://drive.google.com/file/d/1NbB54ZhYvfJh38KFOgovYYPjWopd-2TE/view?usp=sharing) |
+ | R50-DeAOTL | PRE_YTB_DAV | **22.4** | | **86.0** | **84.9** | **89.9** | **80.4** | **88.7** | - |
+ | SwinB-AOTL | PRE_YTB_DAV | 9.3 | | 84.7 | 84.5 | 89.5 | 78.1 | 86.7 | [gdrive](https://drive.google.com/file/d/1QFowulSY0LHfpsjUV8ZE9rYc55L9DOC7/view?usp=sharing) |
+ | SwinB-AOTL | PRE_YTB_DAV | 5.2 | √ | 85.1 | 85.1 | 90.1 | 78.4 | 86.9 | [gdrive](https://drive.google.com/file/d/1TulhVOhh01rkssNYbOQASeWKu7CQ5Azx/view?usp=sharing) |
+ | SwinB-DeAOTL | PRE_YTB_DAV | **11.9** | | **86.2** | **85.6** | **90.6** | **80.0** | **88.4** | - |
+
+ ### YouTube-VOS 2019 val
+ | Model | Stage | FPS | All-F | Mean | J Seen | F Seen | J Unseen | F Unseen | Predictions |
+ |:------------ |:-----------:|:--------:|:-----:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------------------------------------------------------------------------------------------:|
+ | AOTT | PRE_YTB_DAV | 41.0 | | 80.0 | 79.8 | 84.2 | 74.1 | 82.1 | [gdrive](https://drive.google.com/file/d/1zzyhN1XYtajte5nbZ7opOdfXeDJgCxC5/view?usp=sharing) |
+ | AOTT | PRE_YTB_DAV | 41.0 | √ | 80.9 | 79.9 | 84.4 | 75.6 | 83.8 | [gdrive](https://drive.google.com/file/d/1V_5vi9dAXOis_WrDieacSESm7OX20Bv-/view?usp=sharing) |
+ | DeAOTT | PRE_YTB_DAV | **53.4** | | **82.0** | **81.2** | **85.6** | **76.4** | **84.7** | - |
+ | AOTS | PRE_YTB_DAV | 27.1 | | 82.7 | 81.9 | 86.5 | 77.3 | 85.2 | [gdrive](https://drive.google.com/file/d/11YdkUeyjkTv8Uw7xMgPCBzJs6v5SDt6n/view?usp=sharing) |
+ | AOTS | PRE_YTB_DAV | 27.1 | √ | 82.8 | 81.9 | 86.5 | 77.3 | 85.6 | [gdrive](https://drive.google.com/file/d/1UhyurGTJeAw412czU3_ebzNwF8xQ4QG_/view?usp=sharing) |
+ | DeAOTS | PRE_YTB_DAV | **38.7** | | **83.8** | **82.8** | **87.5** | **78.1** | **86.8** | - |
+ | AOTB | PRE_YTB_DAV | 20.5 | | 84.0 | 83.1 | 87.7 | 78.5 | 86.8 | [gdrive](https://drive.google.com/file/d/1NeI8cT4kVqTqVWAwtwiga1rkrvksNWaO/view?usp=sharing) |
+ | AOTB | PRE_YTB_DAV | 20.5 | √ | 84.1 | 83.3 | 88.0 | 78.2 | 86.7 | [gdrive](https://drive.google.com/file/d/1kpYV2XFR0sOfLWD-wMhd-nUO6CFiLjlL/view?usp=sharing) |
+ | DeAOTB | PRE_YTB_DAV | **30.4** | | **84.6** | **83.5** | **88.3** | **79.1** | **87.5** | - |
+ | AOTL | PRE_YTB_DAV | 16.0 | | 84.0 | 82.8 | 87.6 | 78.6 | 87.1 | [gdrive](https://drive.google.com/file/d/1qKLlNXxmT31bW0weEHI_zAf4QwU8Lhou/view?usp=sharing) |
+ | AOTL | PRE_YTB_DAV | 6.5 | √ | 84.2 | 83.0 | 87.8 | 78.7 | 87.3 | [gdrive](https://drive.google.com/file/d/1o3fwZ0cH71bqHSA3bYNjhP4GGv9Vyuwa/view?usp=sharing) |
+ | DeAOTL | PRE_YTB_DAV | **24.7** | | **84.7** | **83.8** | **88.8** | **79.0** | **87.2** | - |
+ | R50-AOTL | PRE_YTB_DAV | 14.9 | | 84.4 | 83.4 | 88.1 | 78.7 | 87.2 | [gdrive](https://drive.google.com/file/d/1I7ooSp8EYfU6fvkP6QcCMaxeencA68AH/view?usp=sharing) |
+ | R50-AOTL | PRE_YTB_DAV | 6.4 | √ | 85.3 | 83.9 | 88.8 | 79.9 | 88.5 | [gdrive](https://drive.google.com/file/d/1OGqlkEu0uXa8QVWIVz_M5pmXXiYR2sh3/view?usp=sharing) |
+ | R50-DeAOTL | PRE_YTB_DAV | **22.4** | | **85.9** | **84.6** | **89.4** | **80.8** | **88.9** | - |
+ | SwinB-AOTL | PRE_YTB_DAV | 9.3 | | 84.7 | 84.0 | 88.8 | 78.7 | 87.1 | [gdrive](https://drive.google.com/file/d/1fPzCxi5GM7N2sLKkhoTC2yoY_oTQCHp1/view?usp=sharing) |
+ | SwinB-AOTL | PRE_YTB_DAV | 5.2 | √ | 85.3 | 84.6 | 89.5 | 79.3 | 87.7 | [gdrive](https://drive.google.com/file/d/1e3D22s_rJ7Y2X2MHo7x5lcNtwmHFlwYB/view?usp=sharing) |
+ | SwinB-DeAOTL | PRE_YTB_DAV | **11.9** | | **86.1** | **85.3** | **90.2** | **80.4** | **88.6** | - |
+
+ ### DAVIS-2017 test
+
+ | Model | Stage | FPS | Mean | J Score | F Score | Predictions |
+ | ---------- |:-----------:|:--------:|:--------:|:--------:|:--------:|:--------------------------------------------------------------------------------------------:|
+ | AOTT | PRE_YTB_DAV | **51.4** | 73.7 | 70.0 | 77.3 | [gdrive](https://drive.google.com/file/d/14Pu-6Uz4rfmJ_WyL2yl57KTx_pSSUNAf/view?usp=sharing) |
+ | AOTS | PRE_YTB_DAV | 40.0 | 75.2 | 71.4 | 78.9 | [gdrive](https://drive.google.com/file/d/1zzAPZCRLgnBWuAXqejPPEYLqBxu67Rj1/view?usp=sharing) |
+ | AOTB | PRE_YTB_DAV | 29.6 | 77.4 | 73.7 | 81.1 | [gdrive](https://drive.google.com/file/d/1WpQ-_Jrs7Ssfw0oekrejM2OVWEx_tBN1/view?usp=sharing) |
+ | AOTL | PRE_YTB_DAV | 18.7 | 79.3 | 75.5 | 83.2 | [gdrive](https://drive.google.com/file/d/1rP1Zdgc0N1d8RR2EaXMz3F-o5zqcNVe8/view?usp=sharing) |
+ | R50-AOTL | PRE_YTB_DAV | 18.0 | 79.5 | 76.0 | 83.0 | [gdrive](https://drive.google.com/file/d/1iQ5iNlvlS-In586ZNc4LIZMSdNIWDvle/view?usp=sharing) |
+ | SwinB-AOTL | PRE_YTB_DAV | 12.1 | **82.1** | **78.2** | **85.9** | [gdrive](https://drive.google.com/file/d/1oVt4FPcZdfVHiOxjYYKef0q7Ovy4f5Q_/view?usp=sharing) |
+
+ ### DAVIS-2017 val
+
+ | Model | Stage | FPS | Mean | J Score | F Score | Predictions |
+ | ---------- |:-----------:|:--------:|:--------:|:--------:|:--------:|:--------------------------------------------------------------------------------------------:|
+ | AOTT | PRE_YTB_DAV | **51.4** | 79.2 | 76.5 | 81.9 | [gdrive](https://drive.google.com/file/d/10OUFhK2Sz-hOJrTDoTI0mA45KO1qodZt/view?usp=sharing) |
+ | AOTS | PRE_YTB_DAV | 40.0 | 82.1 | 79.3 | 84.8 | [gdrive](https://drive.google.com/file/d/1T-JTYyksWlq45jxcLjnRaBvvYUhWgHFH/view?usp=sharing) |
+ | AOTB | PRE_YTB_DAV | 29.6 | 83.3 | 80.6 | 85.9 | [gdrive](https://drive.google.com/file/d/1EVUnxQm9TLBTuwK82QyiSKk9R9V8NwRL/view?usp=sharing) |
+ | AOTL | PRE_YTB_DAV | 18.7 | 83.6 | 80.8 | 86.3 | [gdrive](https://drive.google.com/file/d/1CFauSni2BxAe_fcl8W_6bFByuwJRbDYm/view?usp=sharing) |
+ | R50-AOTL | PRE_YTB_DAV | 18.0 | 85.2 | 82.5 | 87.9 | [gdrive](https://drive.google.com/file/d/1vjloxnP8R4PZdsH2DDizfU2CrkdRHHyo/view?usp=sharing) |
+ | SwinB-AOTL | PRE_YTB_DAV | 12.1 | **85.9** | **82.9** | **88.9** | [gdrive](https://drive.google.com/file/d/1tYCbKOas0i7Et2iyUAyDwaXnaD9YWxLr/view?usp=sharing) |
+
+ ### DAVIS-2016 val
+
+ | Model | Stage | FPS | Mean | J Score | F Score | Predictions |
+ | ---------- |:-----------:|:--------:|:--------:|:--------:|:--------:|:--------------------------------------------------------------------------------------------:|
+ | AOTT | PRE_YTB_DAV | **51.4** | 87.5 | 86.5 | 88.4 | [gdrive](https://drive.google.com/file/d/1LeW8WQhnylZ3umT7E379KdII92uUsGA9/view?usp=sharing) |
+ | AOTS | PRE_YTB_DAV | 40.0 | 89.6 | 88.6 | 90.5 | [gdrive](https://drive.google.com/file/d/1vqGei5tLu1FPVrTi5bwRAsaGy3Upf7B1/view?usp=sharing) |
+ | AOTB | PRE_YTB_DAV | 29.6 | 90.9 | 89.6 | 92.1 | [gdrive](https://drive.google.com/file/d/1qAppo2uOVu0FbE9t1FBUpymC3yWgw1LM/view?usp=sharing) |
+ | AOTL | PRE_YTB_DAV | 18.7 | 91.1 | 89.5 | 92.7 | [gdrive](https://drive.google.com/file/d/1g6cjYhgBWjMaY3RGAm31qm3SPEF3QcKV/view?usp=sharing) |
+ | R50-AOTL | PRE_YTB_DAV | 18.0 | 91.7 | 90.4 | 93.0 | [gdrive](https://drive.google.com/file/d/1QzxojqWKsvRf53K2AgKsK523ZVuYU4O-/view?usp=sharing) |
+ | SwinB-AOTL | PRE_YTB_DAV | 12.1 | **92.2** | **90.6** | **93.8** | [gdrive](https://drive.google.com/file/d/1RIqUtAyVnopeogfT520d7a0yiULg1obp/view?usp=sharing) |
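A minimal sketch of the inference call described in the model-zoo note above. Only the `--model` and `--ckpt_path` flags come from that note; the model name string, checkpoint filename, and script location below are illustrative assumptions.

```python
# Hypothetical invocation of eval.py with a downloaded checkpoint.
import subprocess

subprocess.run([
    "python", "eval.py",
    "--model", "r50_aotl",                                        # model type matching the checkpoint (assumed name)
    "--ckpt_path", "pretrain_models/R50_AOTL_PRE_YTB_DAV.pth",    # downloaded file path (assumed name)
], check=True)
```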
aot/Pytorch-Correlation-extension/.gitignore ADDED
@@ -0,0 +1 @@
+ *.egg*
aot/Pytorch-Correlation-extension/Correlation_Module/correlation.cpp ADDED
@@ -0,0 +1,178 @@
+ #include <torch/extension.h>
+ using namespace torch;
+
+ #include <vector>
+
+ #define WITHIN_BOUNDS(x, y, H, W) (x >= 0 && x < H && y >= 0 && y < W)
+
+ template <typename scalar_t>
+ static void correlate_patch(
+     TensorAccessor<scalar_t,3> input1,
+     TensorAccessor<scalar_t,3> input2,
+     scalar_t *dst,
+     int kH, int kW,
+     int dilationH, int dilationW,
+     int u, int v,
+     int shiftU, int shiftV){
+   const int C = input1.size(0);
+   const int iH = input1.size(1);
+   const int iW = input1.size(2);
+   for (int c=0; c<C; ++c){
+     for (int i=0; i<kH; ++i){
+       int i1 = u + i * dilationH;
+       int i2 = i1 + shiftU;
+       if WITHIN_BOUNDS(i1, i2, iH, iH){
+         for (int j=0; j<kW; ++j){
+           int j1 = v + j * dilationW;
+           int j2 = j1 + shiftV;
+           if WITHIN_BOUNDS(j1, j2, iW, iW){
+             scalar_t v1 = input1[c][i1][j1];
+             scalar_t v2 = input2[c][i2][j2];
+             *dst += v1 * v2;
+           }
+         }
+       }
+     }
+   }
+ }
+
+ template <typename scalar_t>
+ static void correlate_patch_grad(
+     TensorAccessor<scalar_t,3> input1,
+     TensorAccessor<scalar_t,3> gradInput1,
+     TensorAccessor<scalar_t,3> input2,
+     TensorAccessor<scalar_t,3> gradInput2,
+     scalar_t gradOutput,
+     int kH, int kW,
+     int dilationH, int dilationW,
+     int u, int v,
+     int shiftU, int shiftV){
+
+   const int C = input1.size(0);
+   const int iH = input1.size(1);
+   const int iW = input1.size(2);
+
+   for (int c=0; c<C; ++c){
+     for (int i=0; i<kH; ++i){
+       int i1 = u + i * dilationH;
+       int i2 = i1 + shiftU;
+       if WITHIN_BOUNDS(i1, i2, iH, iH){
+         for (int j=0; j<kW; ++j){
+           int j1 = v + j * dilationW;
+           int j2 = j1 + shiftV;
+           if WITHIN_BOUNDS(j1, j2, iW, iW){
+             scalar_t v1 = input1[c][i1][j1];
+             scalar_t v2 = input2[c][i2][j2];
+             gradInput2[c][i2][j2] += gradOutput * v1;
+             gradInput1[c][i1][j1] += gradOutput * v2;
+           }
+         }
+       }
+     }
+   }
+ }
+
+ torch::Tensor correlation_cpp_forward(
+     torch::Tensor input1,
+     torch::Tensor input2,
+     int kH, int kW,
+     int patchH, int patchW,
+     int padH, int padW,
+     int dilationH, int dilationW,
+     int dilation_patchH, int dilation_patchW,
+     int dH, int dW) {
+
+   const auto batch_size = input1.size(0);
+   const auto iH = input1.size(2);
+   const auto iW = input1.size(3);
+   const int patchRadH = (patchH - 1) / 2;
+   const int patchRadW = (patchW - 1) / 2;
+   const int dilatedKH = (kH - 1) * dilationH + 1;
+   const int dilatedKW = (kW - 1) * dilationW + 1;
+
+   const auto oH = (iH + 2 * padH - dilatedKH) / dH + 1;
+   const auto oW = (iW + 2 * padW - dilatedKW) / dW + 1;
+   auto output = at::zeros({batch_size, patchH, patchW, oH, oW}, input1.options());
+
+   int n, ph, pw, h, w;
+   #pragma omp parallel for private(n, ph, pw, h, w) collapse(2)
+   for (n = 0; n < batch_size; ++n) {
+     for(ph = 0; ph < patchH; ++ph){
+       for(pw = 0; pw < patchW; ++pw){
+         AT_DISPATCH_FLOATING_TYPES(input1.scalar_type(), "correlation_forward_cpp", ([&] {
+           auto input1_acc = input1.accessor<scalar_t, 4>();
+           auto input2_acc = input2.accessor<scalar_t, 4>();
+           auto output_acc = output.accessor<scalar_t, 5>();
+           for (h = 0; h < oH; ++h) {
+             for (w = 0; w < oW; ++w) {
+               correlate_patch(input1_acc[n],
+                               input2_acc[n],
+                               &output_acc[n][ph][pw][h][w],
+                               kH, kW,
+                               dilationH, dilationW,
+                               -padH + h * dH,
+                               -padW + w * dW,
+                               (ph - patchRadH) * dilation_patchH,
+                               (pw - patchRadW) * dilation_patchW);
+             }
+           }
+         }));
+       }
+     }
+   }
+   return output;
+ }
+
+ std::vector<torch::Tensor> correlation_cpp_backward(
+     torch::Tensor input1,
+     torch::Tensor input2,
+     torch::Tensor gradOutput,
+     int kH, int kW,
+     int patchH, int patchW,
+     int padH, int padW,
+     int dilationH, int dilationW,
+     int dilation_patchH, int dilation_patchW,
+     int dH, int dW) {
+
+   const int batch_size = input1.size(0);
+   const int patchRadH = (patchH - 1) / 2;
+   const int patchRadW = (patchW - 1) / 2;
+   const int oH = gradOutput.size(3);
+   const int oW = gradOutput.size(4);
+
+   auto gradInput1 = torch::zeros_like(input1);
+
+   auto gradInput2 = torch::zeros_like(input2);
+
+   int n, ph, pw, h, w;
+   #pragma omp parallel for private(n, ph, pw, h, w)
+   for (n = 0; n < batch_size; ++n) {
+     AT_DISPATCH_FLOATING_TYPES(input1.scalar_type(), "correlation_backward_cpp", ([&] {
+       auto input1_acc = input1.accessor<scalar_t, 4>();
+       auto gradInput1_acc = gradInput1.accessor<scalar_t, 4>();
+       auto input2_acc = input2.accessor<scalar_t, 4>();
+       auto gradInput2_acc = gradInput2.accessor<scalar_t, 4>();
+       auto gradOutput_acc = gradOutput.accessor<scalar_t, 5>();
+
+       for(ph = 0; ph < patchH; ++ph){
+         for(pw = 0; pw < patchW; ++pw){
+           for (h = 0; h < oH; ++h) {
+             for (w = 0; w < oW; ++w) {
+               correlate_patch_grad(input1_acc[n], gradInput1_acc[n],
+                                    input2_acc[n], gradInput2_acc[n],
+                                    gradOutput_acc[n][ph][pw][h][w],
+                                    kH, kW,
+                                    dilationH, dilationW,
+                                    -padH + h * dH,
+                                    -padW + w * dW,
+                                    (ph - patchRadH) * dilation_patchH,
+                                    (pw - patchRadW) * dilation_patchW);
+             }
+           }
+         }
+       }
+     }));
+   }
+
+   return {gradInput1, gradInput2};
+ }
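For orientation, here is a small NumPy re-implementation of what `correlate_patch` above computes for a single output location: a channel-wise dot product between a (dilated) kH×kW window of `input1` and the same window shifted by `(shiftU, shiftV)` in `input2`, skipping out-of-bounds positions. This is an illustrative sketch only, not part of the extension.

```python
import numpy as np

def correlate_patch_ref(input1, input2, kH, kW, dilationH, dilationW,
                        u, v, shiftU, shiftV):
    """NumPy sketch of correlate_patch(): input1/input2 are (C, H, W) arrays."""
    C, iH, iW = input1.shape
    acc = 0.0
    for i in range(kH):
        i1 = u + i * dilationH
        i2 = i1 + shiftU
        if 0 <= i1 < iH and 0 <= i2 < iH:
            for j in range(kW):
                j1 = v + j * dilationW
                j2 = j1 + shiftV
                if 0 <= j1 < iW and 0 <= j2 < iW:
                    # dot product over channels at this kernel position
                    acc += float(np.dot(input1[:, i1, j1], input2[:, i2, j2]))
    return acc
```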
aot/Pytorch-Correlation-extension/Correlation_Module/correlation_cuda_kernel.cu ADDED
@@ -0,0 +1,327 @@
+ #include <torch/types.h>
+ using namespace torch;
+
+ #include <cuda.h>
+ #include <cuda_runtime.h>
+
+ #include <vector>
+ #include <iostream>
+
+ // Cuda tensor accessor definitions
+ // restrict pointer traits prioritize speed over memory consumption
+ #define TensorAcc4R PackedTensorAccessor32<scalar_t,4,RestrictPtrTraits>
+ #define TensorAcc5R PackedTensorAccessor32<scalar_t,5,RestrictPtrTraits>
+ #define WITHIN_BOUNDS(x, y, H, W) (x >= 0 && x < H && y >= 0 && y < W)
+
+ #define THREADS_FORWARD 32
+ #define THREADS_BACKWARD 5
+
+
+ namespace corr {
+ template <typename scalar_t>
+ __global__ void correlation_cuda_forward_kernel(
+     const TensorAcc4R rInput1,
+     const TensorAcc4R rInput2,
+     TensorAcc5R output,
+     int kH, int kW,
+     int patchH, int patchW,
+     int padH, int padW,
+     int dilationH, int dilationW,
+     int dilation_patchH, int dilation_patchW,
+     int dH, int dW) {
+
+   const int iH = rInput1.size(1);
+   const int iW = rInput1.size(2);
+   const int C = rInput1.size(3);
+
+   const int n = blockIdx.x;
+   const int h = blockIdx.y;
+   const int w = blockIdx.z;
+   const int thread = threadIdx.x;
+
+   const int start_i = -padH + h * dH;
+   const int start_j = -padW + w * dW;
+
+   const int patchRadH = dilation_patchH * (patchH - 1) / 2;
+   const int patchRadW = dilation_patchW * (patchW - 1) / 2;
+
+   __shared__ scalar_t prod_sum[THREADS_FORWARD];
+
+   for(int ph = 0; ph < patchH; ++ph){
+     int ph_dilated = ph * dilation_patchH - patchRadH;
+     for(int pw = 0; pw < patchW; ++pw){
+       int pw_dilated = pw * dilation_patchW - patchRadW;
+       prod_sum[thread] = 0;
+       for (int i=0; i<kH; ++i){
+         int i1 = start_i + i * dilationH;
+         int i2 = i1 + ph_dilated;
+         if WITHIN_BOUNDS(i1, i2, iH, iH){
+           for (int j=0; j<kW; ++j){
+             int j1 = start_j + j * dilationW;
+             int j2 = j1 + pw_dilated;
+             if WITHIN_BOUNDS(j1, j2, iW, iW){
+               for (int c=thread; c<C; c += THREADS_FORWARD){
+                 scalar_t v1 = rInput1[n][i1][j1][c];
+                 scalar_t v2 = rInput2[n][i2][j2][c];
+                 prod_sum[thread] += v1 * v2;
+               }
+             }
+           }
+         }
+       }
+       // accumulate
+       __syncthreads();
+       if (thread == 0) {
+         scalar_t reduce_sum = 0;
+         for (int index = 0; index < THREADS_FORWARD; ++index) {
+           reduce_sum += prod_sum[index];
+         }
+         output[n][ph][pw][h][w] = reduce_sum;
+       }
+     }
+   }
+ }
+
+
+ template <typename scalar_t>
+ __global__ void correlation_cuda_backward_kernel_input1(
+     const TensorAcc5R gradOutput,
+     const TensorAcc4R input2,
+     TensorAcc4R gradInput1,
+     const int kH, const int kW,
+     const int patchH, const int patchW,
+     const int padH, const int padW,
+     const int dilationH, const int dilationW,
+     const int dilation_patchH, const int dilation_patchW,
+     const int dH, const int dW,
+     const int batch) {
+   const int iH = input2.size(2);
+   const int iW = input2.size(3);
+
+   const int H = gradOutput.size(3);
+   const int W = gradOutput.size(4);
+
+   const int patchRadH = (patchH - 1) / 2;
+   const int patchRadW = (patchW - 1) / 2;
+
+   const int n = batch;
+   const int c = blockIdx.x;
+   const int h = blockIdx.y;
+   const int w = blockIdx.z;
+   const int ph_off = threadIdx.x;
+   const int pw_off = threadIdx.y;
+
+   const int h_2 = h + padH;
+   const int w_2 = w + padW;
+   const int min_h = h_2 - kH * dilationH;
+   const int min_w = w_2 - kW * dilationW;
+
+   __shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD];
+   prod_sum[ph_off][pw_off] = 0;
+
+   for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) {
+     int i1 = h + dilation_patchH * (ph - patchRadH);
+     for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) {
+       int j1 = w + dilation_patchW * (pw - patchRadW);
+       if (WITHIN_BOUNDS(i1, j1, iH, iW)){
+         scalar_t val = input2[n][c][i1][j1];
+         for(int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
+           int i2 = (h_3)/dH;
+           if (i2 * dH != h_3)
+             continue;
+           for(int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
+             int j2 = (w_3) / dW;
+             if(j2 * dW != w_3)
+               continue;
+             if WITHIN_BOUNDS(i2, j2, H, W) {
+               prod_sum[ph_off][pw_off] += gradOutput[n][ph][pw][i2][j2] * val;
+             }
+           }
+         }
+       }
+     }
+   }
+
+   __syncthreads();
+
+   if (ph_off == 0 && pw_off == 0){
+     scalar_t reduce_sum =0;
+     for (int ph = 0; ph < THREADS_BACKWARD; ++ph){
+       for (int pw = 0; pw < THREADS_BACKWARD; ++pw){
+         reduce_sum += prod_sum[ph][pw];
+       }
+     }
+     gradInput1[n][c][h][w] = reduce_sum;
+   }
+ }
+
+
+ template <typename scalar_t>
+ __global__ void correlation_cuda_backward_kernel_input2(
+     const TensorAcc5R gradOutput,
+     const TensorAcc4R input1,
+     TensorAcc4R gradInput2,
+     int kH, int kW,
+     int patchH, int patchW,
+     int padH, int padW,
+     int dilationH, int dilationW,
+     int dilation_patchH, int dilation_patchW,
+     int dH, int dW,
+     int batch) {
+   const int iH = input1.size(2);
+   const int iW = input1.size(3);
+
+   const int patchRadH = (patchH - 1) / 2;
+   const int patchRadW = (patchW - 1) / 2;
+
+   const int H = gradOutput.size(3);
+   const int W = gradOutput.size(4);
+
+   const int dilatedKH = kH * dilationH;
+   const int dilatedKW = kW * dilationW;
+
+   const int n = batch;
+   const int c = blockIdx.x;
+   const int h = blockIdx.y;
+   const int w = blockIdx.z;
+   const int ph_off = threadIdx.x;
+   const int pw_off = threadIdx.y;
+
+   __shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD];
+   prod_sum[ph_off][pw_off] = 0;
+
+   for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) {
+     int i1 = h - dilation_patchH * (ph - patchRadH);
+     for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) {
+       int j1 = w - dilation_patchW * (pw - patchRadW);
+       if WITHIN_BOUNDS(i1, j1, iH, iW) {
+         scalar_t val = input1[n][c][i1][j1];
+
+         const int h_2 = i1 + padH;
+         const int w_2 = j1 + padW;
+         const int min_h = h_2 - dilatedKH;
+         const int min_w = w_2 - dilatedKW;
+
+         for(int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
+           int i2 = (h_3)/dH;
+           if (i2 * dH != h_3)
+             continue;
+           for(int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
+             int j2 = (w_3) / dW;
+             if(j2 * dW != w_3)
+               continue;
+             if WITHIN_BOUNDS(i2, j2, H, W) {
+               prod_sum[ph_off][pw_off] += gradOutput[n][ph][pw][i2][j2] * val;
+             }
+           }
+         }
+       }
+     }
+   }
+
+   __syncthreads();
+
+   if (ph_off == 0 && pw_off == 0){
+     scalar_t reduce_sum =0;
+     for (int ph = 0; ph < THREADS_BACKWARD; ++ph){
+       for (int pw = 0; pw < THREADS_BACKWARD; ++pw){
+         reduce_sum += prod_sum[ph][pw];
+       }
+     }
+     gradInput2[n][c][h][w] = reduce_sum;
+   }
+ }
+ } // namespace corr
+
+ torch::Tensor correlation_cuda_forward(
+     torch::Tensor input1,
+     torch::Tensor input2,
+     int kH, int kW,
+     int patchH, int patchW,
+     int padH, int padW,
+     int dilationH, int dilationW,
+     int dilation_patchH, int dilation_patchW,
+     int dH, int dW) {
+
+   const int batch_size = input1.size(0);
+   const int iH = input1.size(2);
+   const int iW = input1.size(3);
+   const int dilatedKH = (kH - 1) * dilationH + 1;
+   const int dilatedKW = (kW - 1) * dilationW + 1;
+
+   const auto oH = (iH + 2 * padH - dilatedKH) / dH + 1;
+   const auto oW = (iW + 2 * padW - dilatedKW) / dW + 1;
+   auto output = torch::zeros({batch_size, patchH, patchW, oH, oW}, input1.options());
+
+   auto trInput1 = input1.permute({0, 2, 3, 1}).contiguous();
+   auto trInput2 = input2.permute({0, 2, 3, 1}).contiguous();
+
+   const int threads = THREADS_FORWARD;
+   const dim3 blocks(batch_size, oH, oW);
+
+   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.scalar_type(), "correlation_forward_cuda", ([&] {
+     TensorAcc4R trInput1_acc = trInput1.packed_accessor32<scalar_t,4,RestrictPtrTraits>();
+     TensorAcc4R trInput2_acc = trInput2.packed_accessor32<scalar_t,4,RestrictPtrTraits>();
+     TensorAcc5R output_acc = output.packed_accessor32<scalar_t,5,RestrictPtrTraits>();
+     corr::correlation_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
+         trInput1_acc, trInput2_acc, output_acc,
+         kH, kW, patchH, patchW, padH, padW, dilationH, dilationW,
+         dilation_patchH, dilation_patchW, dH, dW);
+   }));
+
+   return output;
+ }
+
+ std::vector<torch::Tensor> correlation_cuda_backward(
+     torch::Tensor input1,
+     torch::Tensor input2,
+     torch::Tensor gradOutput,
+     int kH, int kW,
+     int patchH, int patchW,
+     int padH, int padW,
+     int dilationH, int dilationW,
+     int dilation_patchH, int dilation_patchW,
+     int dH, int dW) {
+
+   auto gradInput1 = torch::zeros_like(input1);
+   auto gradInput2 = torch::zeros_like(input2);
+
+   const int batch_size = input1.size(0);
+   const int iH = input1.size(2);
+   const int iW = input1.size(3);
+   const int C = input1.size(1);
+
+   const dim3 blocks(C, iH, iW);
+   const dim3 threads(THREADS_BACKWARD, THREADS_BACKWARD);
+
+   AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.scalar_type(), "correlation_backward_cuda", ([&] {
+     TensorAcc4R input1_acc = input1.packed_accessor32<scalar_t,4,RestrictPtrTraits>();
+     TensorAcc4R input2_acc = input2.packed_accessor32<scalar_t,4,RestrictPtrTraits>();
+     TensorAcc4R gradInput1_acc = gradInput1.packed_accessor32<scalar_t,4,RestrictPtrTraits>();
+     TensorAcc4R gradInput2_acc = gradInput2.packed_accessor32<scalar_t,4,RestrictPtrTraits>();
+     TensorAcc5R gradOutput_acc = gradOutput.packed_accessor32<scalar_t,5,RestrictPtrTraits>();
+
+     for (int n = 0; n < batch_size; ++n){
+       corr::correlation_cuda_backward_kernel_input1<scalar_t><<<blocks, threads>>>(
+           gradOutput_acc, input2_acc, gradInput1_acc,
+           kH, kW, patchH, patchW, padH, padW,
+           dilationH, dilationW,
+           dilation_patchH, dilation_patchW,
+           dH, dW,
+           n);
+     }
+
+     for (int n = 0; n < batch_size; ++n){
+       corr::correlation_cuda_backward_kernel_input2<scalar_t><<<blocks, threads>>>(
+           gradOutput_acc, input1_acc, gradInput2_acc,
+           kH, kW, patchH, patchW, padH, padW,
+           dilationH, dilationW,
+           dilation_patchH, dilation_patchW,
+           dH, dW,
+           n);
+     }
+   }));
+
+   return {gradInput1, gradInput2};
+ }
aot/Pytorch-Correlation-extension/Correlation_Module/correlation_sampler.cpp ADDED
@@ -0,0 +1,138 @@
+ #include <torch/extension.h>
+ #include <c10/cuda/CUDAGuard.h>
+ #include <vector>
+ #include <iostream>
+
+ // declarations
+
+ torch::Tensor correlation_cpp_forward(
+     torch::Tensor input1,
+     torch::Tensor input2,
+     int kH, int kW,
+     int patchH, int patchW,
+     int padH, int padW,
+     int dilationH, int dilationW,
+     int dilation_patchH, int dilation_patchW,
+     int dH, int dW);
+
+ std::vector<torch::Tensor> correlation_cpp_backward(
+     torch::Tensor grad_output,
+     torch::Tensor input1,
+     torch::Tensor input2,
+     int kH, int kW,
+     int patchH, int patchW,
+     int padH, int padW,
+     int dilationH, int dilationW,
+     int dilation_patchH, int dilation_patchW,
+     int dH, int dW);
+
+ #ifdef USE_CUDA
+
+ #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x, " must be a CUDA tensor")
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous")
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+ #define CHECK_SAME_DEVICE(x, y) TORCH_CHECK(x.device() == y.device(), #x " is not on same device as " #y)
+
+ torch::Tensor correlation_cuda_forward(
+     torch::Tensor input1,
+     torch::Tensor input2,
+     int kH, int kW,
+     int patchH, int patchW,
+     int padH, int padW,
+     int dilationH, int dilationW,
+     int dilation_patchH, int dilation_patchW,
+     int dH, int dW);
+
+ std::vector<torch::Tensor> correlation_cuda_backward(
+     torch::Tensor grad_output,
+     torch::Tensor input1,
+     torch::Tensor input2,
+     int kH, int kW,
+     int patchH, int patchW,
+     int padH, int padW,
+     int dilationH, int dilationW,
+     int dilation_patchH, int dilation_patchW,
+     int dH, int dW);
+
+ // C++ interface
+
+ torch::Tensor correlation_sample_forward(
+     torch::Tensor input1,
+     torch::Tensor input2,
+     int kH, int kW,
+     int patchH, int patchW,
+     int padH, int padW,
+     int dilationH, int dilationW,
+     int dilation_patchH, int dilation_patchW,
+     int dH, int dW) {
+   if (input1.device().is_cuda()){
+     CHECK_INPUT(input1);
+     CHECK_INPUT(input2);
+
+     // set device of input1 as default CUDA device
+     // https://pytorch.org/cppdocs/api/structc10_1_1cuda_1_1_optional_c_u_d_a_guard.html
+     const at::cuda::OptionalCUDAGuard guard_input1(device_of(input1));
+     CHECK_SAME_DEVICE(input1, input2);
+
+     return correlation_cuda_forward(input1, input2, kH, kW, patchH, patchW,
+                                     padH, padW, dilationH, dilationW,
+                                     dilation_patchH, dilation_patchW,
+                                     dH, dW);
+   }else{
+     return correlation_cpp_forward(input1, input2, kH, kW, patchH, patchW,
+                                    padH, padW, dilationH, dilationW,
+                                    dilation_patchH, dilation_patchW,
+                                    dH, dW);
+   }
+ }
+
+ std::vector<torch::Tensor> correlation_sample_backward(
+     torch::Tensor input1,
+     torch::Tensor input2,
+     torch::Tensor grad_output,
+     int kH, int kW,
+     int patchH, int patchW,
+     int padH, int padW,
+     int dilationH, int dilationW,
+     int dilation_patchH, int dilation_patchW,
+     int dH, int dW) {
+
+   if(grad_output.device().is_cuda()){
+     CHECK_INPUT(input1);
+     CHECK_INPUT(input2);
+
+     // set device of input1 as default CUDA device
+     const at::cuda::OptionalCUDAGuard guard_input1(device_of(input1));
+     CHECK_SAME_DEVICE(input1, input2);
+     CHECK_SAME_DEVICE(input1, grad_output);
+
+     return correlation_cuda_backward(input1, input2, grad_output,
+                                      kH, kW, patchH, patchW,
+                                      padH, padW,
+                                      dilationH, dilationW,
+                                      dilation_patchH, dilation_patchW,
+                                      dH, dW);
+   }else{
+     return correlation_cpp_backward(
+         input1, input2, grad_output,
+         kH, kW, patchH, patchW,
+         padH, padW,
+         dilationH, dilationW,
+         dilation_patchH, dilation_patchW,
+         dH, dW);
+   }
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+   m.def("forward", &correlation_sample_forward, "Spatial Correlation Sampler Forward");
+   m.def("backward", &correlation_sample_backward, "Spatial Correlation Sampler backward");
+ }
+
+ #else
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+   m.def("forward", &correlation_cpp_forward, "Spatial Correlation Sampler Forward");
+   m.def("backward", &correlation_cpp_backward, "Spatial Correlation Sampler backward");
+ }
+
+ #endif
aot/Pytorch-Correlation-extension/Correlation_Module/spatial_correlation_sampler/__init__.py ADDED
@@ -0,0 +1 @@
+ from .spatial_correlation_sampler import SpatialCorrelationSampler, spatial_correlation_sample
aot/Pytorch-Correlation-extension/Correlation_Module/spatial_correlation_sampler/spatial_correlation_sampler.py ADDED
@@ -0,0 +1,107 @@
+ from torch import nn
+ from torch.autograd import Function
+ from torch.autograd.function import once_differentiable
+ from torch.nn.modules.utils import _pair
+
+ import spatial_correlation_sampler_backend as correlation
+
+
+ def spatial_correlation_sample(input1,
+                                input2,
+                                kernel_size=1,
+                                patch_size=1,
+                                stride=1,
+                                padding=0,
+                                dilation=1,
+                                dilation_patch=1):
+     """Apply spatial correlation sampling from input1 to input2.
+
+     Every parameter except input1 and input2 can be either a single int
+     or a pair of ints. For more information about Spatial Correlation
+     Sampling, see this page.
+     https://lmb.informatik.uni-freiburg.de/Publications/2015/DFIB15/
+
+     Args:
+         input1 : The first parameter.
+         input2 : The second parameter.
+         kernel_size : total size of your correlation kernel, in pixels
+         patch_size : total size of your patch, determining how many
+             different shifts will be applied
+         stride : stride of the spatial sampler, will modify output
+             height and width
+         padding : padding applied to input1 and input2 before applying
+             the correlation sampling, will modify output height and width
+         dilation_patch : step for every shift in patch
+
+     Returns:
+         Tensor: Result of correlation sampling
+
+     """
+     return SpatialCorrelationSamplerFunction.apply(input1, input2,
+                                                    kernel_size, patch_size,
+                                                    stride, padding, dilation, dilation_patch)
+
+
+ class SpatialCorrelationSamplerFunction(Function):
+
+     @staticmethod
+     def forward(ctx,
+                 input1,
+                 input2,
+                 kernel_size=1,
+                 patch_size=1,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 dilation_patch=1):
+
+         ctx.save_for_backward(input1, input2)
+         kH, kW = ctx.kernel_size = _pair(kernel_size)
+         patchH, patchW = ctx.patch_size = _pair(patch_size)
+         padH, padW = ctx.padding = _pair(padding)
+         dilationH, dilationW = ctx.dilation = _pair(dilation)
+         dilation_patchH, dilation_patchW = ctx.dilation_patch = _pair(dilation_patch)
+         dH, dW = ctx.stride = _pair(stride)
+
+         output = correlation.forward(input1, input2,
+                                      kH, kW, patchH, patchW,
+                                      padH, padW, dilationH, dilationW,
+                                      dilation_patchH, dilation_patchW,
+                                      dH, dW)
+
+         return output
+
+     @staticmethod
+     @once_differentiable
+     def backward(ctx, grad_output):
+         input1, input2 = ctx.saved_variables
+
+         kH, kW = ctx.kernel_size
+         patchH, patchW = ctx.patch_size
+         padH, padW = ctx.padding
+         dilationH, dilationW = ctx.dilation
+         dilation_patchH, dilation_patchW = ctx.dilation_patch
+         dH, dW = ctx.stride
+
+         grad_input1, grad_input2 = correlation.backward(input1, input2, grad_output,
+                                                         kH, kW, patchH, patchW,
+                                                         padH, padW, dilationH, dilationW,
+                                                         dilation_patchH, dilation_patchW,
+                                                         dH, dW)
+         return grad_input1, grad_input2, None, None, None, None, None, None
+
+
+ class SpatialCorrelationSampler(nn.Module):
+     def __init__(self, kernel_size=1, patch_size=1, stride=1, padding=0, dilation=1, dilation_patch=1):
+         super(SpatialCorrelationSampler, self).__init__()
+         self.kernel_size = kernel_size
+         self.patch_size = patch_size
+         self.stride = stride
+         self.padding = padding
+         self.dilation = dilation
+         self.dilation_patch = dilation_patch
+
+     def forward(self, input1, input2):
+         return SpatialCorrelationSamplerFunction.apply(input1, input2, self.kernel_size,
+                                                        self.patch_size, self.stride,
+                                                        self.padding, self.dilation, self.dilation_patch)
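The sampler returns a 5D tensor of shape `(B, patchH, patchW, oH, oW)`, where `oH`/`oW` follow the convolution-style arithmetic used in `correlation_cpp_forward`. A small sketch of that arithmetic (the helper name below is ours, not part of the package):

```python
from torch.nn.modules.utils import _pair

def correlation_output_shape(B, iH, iW, kernel_size=1, patch_size=1,
                             stride=1, padding=0, dilation=1):
    """Sketch of the expected SpatialCorrelationSampler output shape,
    mirroring the size computation in correlation_cpp_forward()."""
    kH, kW = _pair(kernel_size)
    patchH, patchW = _pair(patch_size)
    dH, dW = _pair(stride)
    padH, padW = _pair(padding)
    dilationH, dilationW = _pair(dilation)
    dilatedKH = (kH - 1) * dilationH + 1
    dilatedKW = (kW - 1) * dilationW + 1
    oH = (iH + 2 * padH - dilatedKH) // dH + 1
    oW = (iW + 2 * padW - dilatedKW) // dW + 1
    return (B, patchH, patchW, oH, oW)

# e.g. with kernel_size=1, patch_size=21 on a 4 x C x 48 x 64 input:
# correlation_output_shape(4, 48, 64, kernel_size=1, patch_size=21) == (4, 21, 21, 48, 64)
```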
aot/Pytorch-Correlation-extension/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) [year] [fullname]
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
aot/Pytorch-Correlation-extension/README.md ADDED
@@ -0,0 +1,155 @@
+
+ [![PyPI](https://img.shields.io/pypi/v/spatial-correlation-sampler.svg)](https://pypi.org/project/spatial-correlation-sampler/)
+
+
+ # Pytorch Correlation module
+
+ This is a custom C++/CUDA implementation of the Correlation module, used e.g. in [FlowNetC](https://arxiv.org/abs/1504.06852).
+
+ This [tutorial](http://pytorch.org/tutorials/advanced/cpp_extension.html) was used as a basis for the implementation, as well as
+ [NVIDIA's cuda code](https://github.com/NVIDIA/flownet2-pytorch/tree/master/networks/correlation_package).
+
+ - Build and install the C++ and CUDA extensions by executing `python setup.py install`,
+ - Benchmark C++ vs. CUDA by running `python benchmark.py {cpu, cuda}`,
+ - Run gradient checks on the code by running `python grad_check.py --backend {cpu, cuda}`.
+
+ # Requirements
+
+ This module is expected to compile for Pytorch `2.1.0`.
+
+ Before installation please check the compatibility of your GPU and CUDA (_Compute Capability_) in the [nvidia docs](https://developer.nvidia.com/cuda-gpus).
+ e.g. an RTX 6000 uses CC=8.9, so we set the environment variable to
+
+ `export TORCH_CUDA_ARCH_LIST="8.9+PTX"`
+
+ # Installation
+
+ Note that this module requires `python3-dev` to compile C++ code, e.g. on Ubuntu run:
+
+ `apt install python3-dev`
+
+ This module is available on pip:
+
+ `pip install spatial-correlation-sampler`
+
+ For a cpu-only version, you can install from source with
+
+ `python setup_cpu.py install`
+
+ # Known Problems
+
+ This module needs a compatible gcc version and CUDA to be compiled.
+ Namely, CUDA 9.1 and below need gcc5, while CUDA 9.2 and 10.0 need gcc7.
+ See [this issue](https://github.com/ClementPinard/Pytorch-Correlation-extension/issues/1) for more information.
+
+ # Usage
+
+ The API has a few differences from NVIDIA's module:
+ * output is now a 5D tensor, which reflects the horizontal and vertical shifts.
+ ```
+ input (B x C x H x W) -> output (B x PatchH x PatchW x oH x oW)
+ ```
+ * Output sizes `oH` and `oW` no longer depend on patch size, but only on kernel size and padding
+ * Patch size `patch_size` is now the whole patch, and not only the radii.
+ * `stride1` is now `stride` and `stride2` is `dilation_patch`, which behave like dilated convolutions
+ * the equivalent `max_displacement` is then `dilation_patch * (patch_size - 1) / 2`.
+ * `dilation` is a new parameter; it acts the same way as dilated convolution regarding the correlation kernel
+ * to get the right parameters for FlowNetC, you would have
+ ```
+ kernel_size=1
+ patch_size=21,
+ stride=1,
+ padding=0,
+ dilation=1
+ dilation_patch=2
+ ```
+
+
+ ## Example
+ ```python
+ import torch
+ from spatial_correlation_sampler import SpatialCorrelationSampler, spatial_correlation_sample
+
+ device = "cuda"
+ batch_size = 1
+ channel = 1
+ H = 10
+ W = 10
+ dtype = torch.float32
+
+ input1 = torch.randint(1, 4, (batch_size, channel, H, W), dtype=dtype, device=device, requires_grad=True)
+ input2 = torch.randint_like(input1, 1, 4).requires_grad_(True)
+
+ # You can either use the function or the module. Note that the module doesn't contain any parameter tensor.
+
+ # function
+
+ out = spatial_correlation_sample(input1,
+                                  input2,
+                                  kernel_size=3,
+                                  patch_size=1,
+                                  stride=2,
+                                  padding=0,
+                                  dilation=2,
+                                  dilation_patch=1)
+
+ # module
+
+ correlation_sampler = SpatialCorrelationSampler(
+     kernel_size=3,
+     patch_size=1,
+     stride=2,
+     padding=0,
+     dilation=2,
+     dilation_patch=1)
+ out = correlation_sampler(input1, input2)
+
+ ```
+
+ # Benchmark
+
+ * Default parameters are from `benchmark.py`; FlowNetC parameters are the same as used in `FlowNetC` with a batch size of 4, described in [this paper](https://arxiv.org/abs/1504.06852), implemented [here](https://github.com/lmb-freiburg/flownet2) and [here](https://github.com/NVIDIA/flownet2-pytorch/blob/master/networks/FlowNetC.py).
+ * Feel free to file an issue to add entries to this with your hardware!
+
+ ## CUDA Benchmark
+
+ * See [here](https://gist.github.com/ClementPinard/270e910147119831014932f67fb1b5ea) for a benchmark script working with [NVIDIA](https://github.com/NVIDIA/flownet2-pytorch/tree/master/networks/correlation_package)'s code, and Pytorch.
+ * Benchmarks are launched with the environment variable `CUDA_LAUNCH_BLOCKING` set to `1`.
+ * Only `float32` is benchmarked.
+ * FlowNetC correlation parameters were launched with the following commands:
+
+ ```bash
+ CUDA_LAUNCH_BLOCKING=1 python benchmark.py --scale ms -k1 --patch 21 -s1 -p0 --patch_dilation 2 -b4 --height 48 --width 64 -c256 cuda -d float
+
+ CUDA_LAUNCH_BLOCKING=1 python NV_correlation_benchmark.py --scale ms -k1 --patch 21 -s1 -p0 --patch_dilation 2 -b4 --height 48 --width 64 -c256
+ ```
+
+ | implementation | Correlation parameters | device | pass | min time | avg time |
+ | -------------- | ---------------------- | ------- | -------- | ------------: | ------------: |
+ | ours | default | 980 GTX | forward | **5.745 ms** | **5.851 ms** |
+ | ours | default | 980 GTX | backward | 77.694 ms | 77.957 ms |
+ | NVIDIA | default | 980 GTX | forward | 13.779 ms | 13.853 ms |
+ | NVIDIA | default | 980 GTX | backward | **73.383 ms** | **73.708 ms** |
+ | | | | | | |
+ | ours | FlowNetC | 980 GTX | forward | **26.102 ms** | **26.179 ms** |
+ | ours | FlowNetC | 980 GTX | backward | **208.091 ms** | **208.510 ms** |
+ | NVIDIA | FlowNetC | 980 GTX | forward | 35.363 ms | 35.550 ms |
+ | NVIDIA | FlowNetC | 980 GTX | backward | 283.748 ms | 284.346 ms |
+
+ ### Notes
+ * The overhead of our implementation regarding `kernel_size` > 1 during backward needs some investigation; feel free to
+ dive into the code to improve it!
+ * The backward pass of NVIDIA is not entirely correct when stride1 > 1 and kernel_size > 1, because not everything
+ is computed, see [here](https://github.com/NVIDIA/flownet2-pytorch/blob/master/networks/correlation_package/src/correlation_cuda_kernel.cu#L120).
+
+ ## CPU Benchmark
+
+ * No other implementation is available on CPU.
+ * It is obviously not recommended to run it on CPU if you have a GPU.
+
+ | Correlation parameters | device | pass | min time | avg time |
+ | ---------------------- | -------------------- | -------- | ----------: | ----------: |
+ | default | E5-2630 v3 @ 2.40GHz | forward | 159.616 ms | 188.727 ms |
+ | default | E5-2630 v3 @ 2.40GHz | backward | 282.641 ms | 294.194 ms |
+ | FlowNetC | E5-2630 v3 @ 2.40GHz | forward | 2.138 s | 2.144 s |
+ | FlowNetC | E5-2630 v3 @ 2.40GHz | backward | 7.006 s | 7.075 s |
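A quick worked check of the `max_displacement` equivalence stated in the usage notes above (the helper name is ours): with the FlowNetC settings `patch_size=21` and `dilation_patch=2`, the equivalent `max_displacement` is 20.

```python
def max_displacement(patch_size, dilation_patch):
    """Equivalent NVIDIA-style max_displacement, per the formula in the README above."""
    return dilation_patch * (patch_size - 1) // 2

assert max_displacement(21, 2) == 20  # FlowNetC-equivalent setting
```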
aot/Pytorch-Correlation-extension/benchmark.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import division
2
+ from __future__ import print_function
3
+
4
+ import argparse
5
+ import time
6
+
7
+ import torch
8
+ from spatial_correlation_sampler import SpatialCorrelationSampler
9
+ from tqdm import trange
10
+
11
+ TIME_SCALES = {'s': 1, 'ms': 1000, 'us': 1000000}
12
+
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument('backend', choices=['cpu', 'cuda'], default='cuda')
15
+ parser.add_argument('-b', '--batch-size', type=int, default=16)
16
+ parser.add_argument('-k', '--kernel-size', type=int, default=3)
17
+ parser.add_argument('--patch', type=int, default=3)
18
+ parser.add_argument('--patch_dilation', type=int, default=2)
19
+ parser.add_argument('-c', '--channel', type=int, default=64)
20
+ parser.add_argument('--height', type=int, default=100)
21
+ parser.add_argument('-w', '--width', type=int, default=100)
22
+ parser.add_argument('-s', '--stride', type=int, default=2)
23
+ parser.add_argument('-p', '--pad', type=int, default=1)
24
+ parser.add_argument('--scale', choices=['s', 'ms', 'us'], default='us')
25
+ parser.add_argument('-r', '--runs', type=int, default=100)
26
+ parser.add_argument('--dilation', type=int, default=2)
27
+ parser.add_argument('-d', '--dtype', choices=['half', 'float', 'double'])
28
+
29
+ args = parser.parse_args()
30
+
31
+ device = torch.device(args.backend)
32
+
33
+ if args.dtype == 'half':
34
+ dtype = torch.float16
35
+ elif args.dtype == 'float':
36
+ dtype = torch.float32
37
+ else:
38
+ dtype = torch.float64
39
+
40
+
41
+ input1 = torch.randn(args.batch_size,
42
+ args.channel,
43
+ args.height,
44
+ args.width,
45
+ dtype=dtype,
46
+ device=device,
47
+ requires_grad=True)
48
+ input2 = torch.randn_like(input1)
49
+
50
+ correlation_sampler = SpatialCorrelationSampler(
51
+ args.kernel_size,
52
+ args.patch,
53
+ args.stride,
54
+ args.pad,
55
+ args.dilation,
56
+ args.patch_dilation)
57
+
58
+ # Force CUDA initialization
59
+ output = correlation_sampler(input1, input2)
60
+ print(output.size())
61
+ output.mean().backward()
62
+ forward_min = float('inf')
63
+ forward_time = 0
64
+ backward_min = float('inf')
65
+ backward_time = 0
66
+ for _ in trange(args.runs):
67
+ correlation_sampler.zero_grad()
68
+
69
+ start = time.time()
70
+ output = correlation_sampler(input1, input2)
71
+ elapsed = time.time() - start
72
+ forward_min = min(forward_min, elapsed)
73
+ forward_time += elapsed
74
+ output = output.mean()
75
+
76
+ start = time.time()
77
+ output.backward()
78
+ elapsed = time.time() - start
79
+ backward_min = min(backward_min, elapsed)
80
+ backward_time += elapsed
81
+
82
+ scale = TIME_SCALES[args.scale]
83
+ forward_min *= scale
84
+ backward_min *= scale
85
+ forward_average = forward_time / args.runs * scale
86
+ backward_average = backward_time / args.runs * scale
87
+
88
+ print('Forward: {0:.3f}/{1:.3f} {4} | Backward {2:.3f}/{3:.3f} {4}'.format(
89
+ forward_min, forward_average, backward_min, backward_average,
90
+ args.scale))
aot/Pytorch-Correlation-extension/check.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import division
2
+ from __future__ import print_function
3
+
4
+ import argparse
5
+ import numpy as np
6
+ import torch
7
+
8
+ from spatial_correlation_sampler import SpatialCorrelationSampler
9
+
10
+
11
+ def check_equal(first, second, verbose):
12
+ if verbose:
13
+ print()
14
+ for i, (x, y) in enumerate(zip(first, second)):
15
+ x = x.cpu().detach().numpy()
16
+ y = y.cpu().detach().numpy()
17
+ if verbose:
18
+ print("x = {}".format(x.flatten()))
19
+ print("y = {}".format(y.flatten()))
20
+ print('-' * 80)
21
+ np.testing.assert_allclose(x, y, err_msg="Index: {}".format(i))
22
+
23
+
24
+ def zero_grad(variables):
25
+ for variable in variables:
26
+ if variable.grad is not None: variable.grad.zero_()
27
+
28
+
29
+ def get_grads(variables):
30
+ return [var.grad.clone() for var in variables]
31
+
32
+
33
+ def check_forward(input1, input2, correlation_sampler, verbose, gpu_index=0):
34
+ device = torch.device(f"cuda:{gpu_index}")
35
+
36
+ cpu_values = correlation_sampler(input1, input2)
37
+ cuda_values = correlation_sampler(input1.to(device), input2.to(device))
38
+
39
+ print(f"Forward: CPU vs. CUDA device:{gpu_index} ... ", end='')
40
+ check_equal(cpu_values, cuda_values, verbose)
41
+ print('Ok')
42
+
43
+
44
+ def check_backward(input1, input2, correlation_sampler, verbose, gpu_index=0):
45
+ device = torch.device(f"cuda:{gpu_index}")
46
+
47
+ zero_grad([input1, input2])
48
+
49
+ cpu_values = correlation_sampler(input1, input2)
50
+ cpu_values.sum().backward()
51
+ grad_cpu = get_grads([input1, input2])
52
+
53
+ zero_grad([input1, input2])
54
+
55
+ cuda_values = correlation_sampler(input1.to(device), input2.to(device))
56
+ cuda_values.sum().backward()
57
+ grad_cuda = get_grads([input1, input2])
58
+
59
+ print(f"Backward: CPU vs. CUDA device:{gpu_index} ... ", end='')
60
+ check_equal(grad_cpu, grad_cuda, verbose)
61
+ print('Ok')
62
+
63
+
64
+ def check_multi_gpu_forward(correlation_sampler, verbose):
65
+ print("Multi-GPU forward")
66
+ total_gpus = torch.cuda.device_count()
67
+ for gpu in range(total_gpus):
68
+ check_forward(input1, input2, correlation_sampler, verbose, gpu_index=gpu)
69
+
70
+ def check_multi_gpu_backward(correlation_sampler, verbose):
71
+ print("Multi-GPU backward")
72
+ total_gpus = torch.cuda.device_count()
73
+ for gpu in range(total_gpus):
74
+ check_backward(input1, input2, correlation_sampler, verbose, gpu_index=gpu)
75
+
76
+
77
+ parser = argparse.ArgumentParser()
78
+ parser.add_argument('direction', choices=['forward', 'backward'], nargs='+')
79
+ parser.add_argument('-b', '--batch-size', type=int, default=1)
80
+ parser.add_argument('-k', '--kernel-size', type=int, default=3)
81
+ parser.add_argument('--patch', type=int, default=3)
82
+ parser.add_argument('--patch_dilation', type=int, default=2)
83
+ parser.add_argument('-c', '--channel', type=int, default=10)
84
+ parser.add_argument('--height', type=int, default=10)
85
+ parser.add_argument('-w', '--width', type=int, default=10)
86
+ parser.add_argument('-s', '--stride', type=int, default=2)
87
+ parser.add_argument('-p', '--pad', type=int, default=5)
88
+ parser.add_argument('-v', '--verbose', action='store_true', default=False)
89
+ parser.add_argument('-d', '--dilation', type=int, default=2)
90
+ args = parser.parse_args()
91
+ print(args)
92
+
93
+ assert(torch.cuda.is_available()), "no comparison to make"
94
+ input1 = torch.randn(args.batch_size,
95
+ args.channel,
96
+ args.height,
97
+ args.width).double()
98
+ input2 = torch.randn(args.batch_size,
99
+ args.channel,
100
+ args.height,
101
+ args.width).double()
102
+ input1.requires_grad = True
103
+ input2.requires_grad = True
104
+
105
+ correlation_sampler = SpatialCorrelationSampler(
106
+ args.kernel_size,
107
+ args.patch,
108
+ args.stride,
109
+ args.pad,
110
+ args.dilation,
111
+ args.patch_dilation)
112
+
113
+ if 'forward' in args.direction:
114
+ check_forward(input1, input2, correlation_sampler, args.verbose)
115
+ if torch.cuda.device_count() > 1: check_multi_gpu_forward(correlation_sampler, args.verbose)
116
+
117
+ if 'backward' in args.direction:
118
+ check_backward(input1, input2, correlation_sampler, args.verbose)
119
+ if torch.cuda.device_count() > 1: check_multi_gpu_backward(correlation_sampler, args.verbose)
aot/Pytorch-Correlation-extension/grad_check.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ # torch.set_printoptions(precision=1, threshold=10000)
4
+ from torch.autograd import gradcheck
5
+ from spatial_correlation_sampler import SpatialCorrelationSampler
6
+
7
+ parser = argparse.ArgumentParser()
8
+ parser.add_argument('backend', choices=['cpu', 'cuda'], default='cuda')
9
+ parser.add_argument('-b', '--batch-size', type=int, default=2)
10
+ parser.add_argument('-k', '--kernel-size', type=int, default=3)
11
+ parser.add_argument('--patch', type=int, default=3)
12
+ parser.add_argument('--patch_dilation', type=int, default=2)
13
+ parser.add_argument('-c', '--channel', type=int, default=2)
14
+ parser.add_argument('--height', type=int, default=10)
15
+ parser.add_argument('-w', '--width', type=int, default=10)
16
+ parser.add_argument('-s', '--stride', type=int, default=2)
17
+ parser.add_argument('-p', '--pad', type=int, default=1)
18
+ parser.add_argument('-d', '--dilation', type=int, default=2)
19
+
20
+ args = parser.parse_args()
21
+
22
+ input1 = torch.randn(args.batch_size,
23
+ args.channel,
24
+ args.height,
25
+ args.width,
26
+ dtype=torch.float64,
27
+ device=torch.device(args.backend))
28
+ input2 = torch.randn(args.batch_size,
29
+ args.channel,
30
+ args.height,
31
+ args.width,
32
+ dtype=torch.float64,
33
+ device=torch.device(args.backend))
34
+
35
+ input1.requires_grad = True
36
+ input2.requires_grad = True
37
+
38
+ correlation_sampler = SpatialCorrelationSampler(args.kernel_size,
39
+ args.patch,
40
+ args.stride,
41
+ args.pad,
42
+ args.dilation,
43
+ args.patch_dilation)
44
+
45
+
46
+ if gradcheck(correlation_sampler, [input1, input2]):
47
+ print('Ok')
aot/Pytorch-Correlation-extension/requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch>=1.0.1
2
+ numpy
aot/Pytorch-Correlation-extension/setup.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from setuptools import setup
3
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension
4
+ from os.path import join
5
+
6
+ CPU_ONLY = False
7
+ project_root = 'Correlation_Module'
8
+
9
+ source_files = ['correlation.cpp', 'correlation_sampler.cpp']
10
+
11
+ cxx_args = ['-std=c++17', '-fopenmp']
12
+
13
+ def generate_nvcc_args(gpu_archs):
14
+ nvcc_args = []
15
+ for arch in gpu_archs:
16
+ nvcc_args.extend(['-gencode', f'arch=compute_{arch},code=sm_{arch}'])
17
+ return nvcc_args
18
+
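+ # GPU_ARCH is a space-separated list of compute capabilities (e.g. "70 86"); each entry becomes a -gencode arch=compute_XX,code=sm_XX flag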
19
+ gpu_arch = os.environ.get('GPU_ARCH', '').split()
20
+ nvcc_args = generate_nvcc_args(gpu_arch)
21
+
22
+ with open("README.md", "r") as fh:
23
+ long_description = fh.read()
24
+
25
+
26
+ def launch_setup():
27
+ if CPU_ONLY:
28
+ Extension = CppExtension
29
+ macro = []
30
+ else:
31
+ Extension = CUDAExtension
32
+ source_files.append('correlation_cuda_kernel.cu')
33
+ macro = [("USE_CUDA", None)]
34
+
35
+ sources = [join(project_root, file) for file in source_files]
36
+
37
+ setup(
38
+ name='spatial_correlation_sampler',
39
+ version="0.4.0",
40
+ author="Clément Pinard",
41
+ author_email="[email protected]",
42
+ description="Correlation module for pytorch",
43
+ long_description=long_description,
44
+ long_description_content_type="text/markdown",
45
+ url="https://github.com/ClementPinard/Pytorch-Correlation-extension",
46
+ install_requires=['torch>=1.1', 'numpy'],
47
+ ext_modules=[
48
+ Extension('spatial_correlation_sampler_backend',
49
+ sources,
50
+ define_macros=macro,
51
+ extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args},
52
+ extra_link_args=['-lgomp'])
53
+ ],
54
+ package_dir={'': project_root},
55
+ packages=['spatial_correlation_sampler'],
56
+ cmdclass={
57
+ 'build_ext': BuildExtension
58
+ },
59
+ classifiers=[
60
+ "Programming Language :: Python :: 3",
61
+ "License :: OSI Approved :: MIT License",
62
+ "Operating System :: POSIX :: Linux",
63
+ "Intended Audience :: Science/Research",
64
+ "Topic :: Scientific/Engineering :: Artificial Intelligence"
65
+ ])
66
+
67
+
68
+ if __name__ == '__main__':
69
+ launch_setup()
aot/Pytorch-Correlation-extension/setup_cpu.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ import setup
2
+
3
+ setup.CPU_ONLY = True
4
+ setup.launch_setup()
aot/README.md ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AOT Series Frameworks in PyTorch
2
+
3
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/decoupling-features-in-hierarchical/semi-supervised-video-object-segmentation-on-15)](https://paperswithcode.com/sota/semi-supervised-video-object-segmentation-on-15?p=decoupling-features-in-hierarchical)
4
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/associating-objects-with-scalable/video-object-segmentation-on-youtube-vos)](https://paperswithcode.com/sota/video-object-segmentation-on-youtube-vos?p=associating-objects-with-scalable)
5
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/associating-objects-with-scalable/semi-supervised-video-object-segmentation-on-18)](https://paperswithcode.com/sota/semi-supervised-video-object-segmentation-on-18?p=associating-objects-with-scalable)
6
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/associating-objects-with-scalable/semi-supervised-video-object-segmentation-on-1)](https://paperswithcode.com/sota/semi-supervised-video-object-segmentation-on-1?p=associating-objects-with-scalable)
7
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/associating-objects-with-scalable/visual-object-tracking-on-davis-2017)](https://paperswithcode.com/sota/visual-object-tracking-on-davis-2017?p=associating-objects-with-scalable)
8
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/associating-objects-with-scalable/visual-object-tracking-on-davis-2016)](https://paperswithcode.com/sota/visual-object-tracking-on-davis-2016?p=associating-objects-with-scalable)
9
+
10
+ A modular reference PyTorch implementation of AOT series frameworks:
11
+ - **DeAOT**: Decoupling Features in Hierarchical Propagation for Video Object Segmentation (NeurIPS 2022, Spotlight) [[OpenReview](https://openreview.net/forum?id=DgM7-7eMkq0)][[PDF](https://arxiv.org/pdf/2210.09782.pdf)]
12
+ <img src="source/overview_deaot.png" width="90%"/>
13
+
14
+ - **AOT**: Associating Objects with Transformers for Video Object Segmentation (NeurIPS 2021, Score 8/8/7/8) [[OpenReview](https://openreview.net/forum?id=hl3v8io3ZYt)][[PDF](https://arxiv.org/abs/2106.02638)]
15
+ <img src="source/overview.png" width="90%"/>
16
+
17
+ An extension of AOT, [AOST](https://arxiv.org/abs/2203.11442) (under review), is available now. AOST is a more robust and flexible framework, supporting run-time speed-accuracy trade-offs.
18
+
19
+ ## Examples
20
+ Benchmark examples:
21
+
22
+ <img src="source/some_results.png" width="81%"/>
23
+
24
+ General examples (Messi and Kobe):
25
+
26
+ <img src="source/messi.gif" width="45%"/> <img src="source/kobe.gif" width="45%"/>
27
+
28
+ ## Highlights
29
+ - **High performance:** up to **85.5%** ([R50-AOTL](MODEL_ZOO.md#youtube-vos-2018-val)) on YouTube-VOS 2018 and **82.1%** ([SwinB-AOTL](MODEL_ZOO.md#youtube-vos-2018-val)) on DAVIS-2017 Test-dev under standard settings (without any test-time augmentation or post-processing).
30
+ - **High efficiency:** up to **51fps** ([AOTT](MODEL_ZOO.md#davis-2017-test)) on DAVIS-2017 (480p) even with **10** objects and **41fps** on YouTube-VOS (1.3x480p). AOT can process multiple objects (no more than a pre-defined number, 10 by default) as efficiently as processing a single object. This project also supports inferring any number of objects together within a video by automatic separation and aggregation.
31
+ - **Multi-GPU training and inference**
32
+ - **Mixed precision training and inference**
33
+ - **Test-time augmentation:** multi-scale and flipping augmentations are supported.
34
+
35
+ ## Requirements
36
+ * Python3
37
+ * pytorch >= 1.7.0 and torchvision
38
+ * opencv-python
39
+ * Pillow
40
+ * Pytorch Correlation (we recommend installing from [source](https://github.com/ClementPinard/Pytorch-Correlation-extension) rather than via `pip`. **The project also works without this module but loses some efficiency in the short-term attention**; a quick import check is sketched after this list.)
41
+
42
+ Optional:
43
+ * scikit-image (required to run our **Demo**)
44
+
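Since the correlation module is optional, a quick way to check whether it is importable in your environment (a minimal sketch; the project itself falls back automatically when the module is missing):

```python
# Check whether the optional Pytorch-Correlation extension is importable.
# Without it, AOT still runs, but the short-term attention is less efficient.
try:
    from spatial_correlation_sampler import SpatialCorrelationSampler  # noqa: F401
    print('spatial_correlation_sampler found: efficient short-term attention available')
except ImportError:
    print('spatial_correlation_sampler not found: the slower fallback path will be used')
```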
45
+ ## Model Zoo and Results
46
+ Pre-trained models, benchmark scores, and pre-computed results reproduced by this project can be found in [MODEL_ZOO.md](MODEL_ZOO.md).
47
+
48
+ ## Demo - Panoptic Propagation
49
+ We provide a simple demo to demonstrate AOT's effectiveness. The demo will propagate more than **40** objects, including semantic regions (like sky) and instances (like person), together within a single complex scenario and predict its video panoptic segmentation.
50
+
51
+ To run the demo, download the [checkpoint](https://drive.google.com/file/d/1qJDYn3Ibpquu4ffYoQmVjg1YCbr2JQep/view?usp=sharing) of R50-AOTL into [pretrain_models](pretrain_models), and then run:
52
+ ```bash
53
+ python tools/demo.py
54
+ ```
55
+ which will predict the given scenarios at a resolution of 1.3x480p. You can also run this demo with other AOT models ([MODEL_ZOO.md](MODEL_ZOO.md)) by setting `--model` (model type) and `--ckpt_path` (checkpoint path).
56
+
57
+ Two scenarios from [VSPW](https://www.vspwdataset.com/home) are supplied in [datasets/Demo](datasets/Demo):
58
+
59
+ - 1001_3iEIq5HBY1s: 44 objects. 1080P.
60
+ - 1007_YCTBBdbKSSg: 43 objects. 1080P.
61
+
62
+ Results:
63
+
64
+ <img src="source/1001_3iEIq5HBY1s.gif" width="45%"/> <img src="source/1007_YCTBBdbKSSg.gif" width="45%"/>
65
+
66
+
67
+ ## Getting Started
68
+ 0. Prepare a valid environment following the [requirements](#requirements).
69
+
70
+ 1. Prepare datasets:
71
+
72
+ Please follow the instructions below to prepare each dataset in its corresponding folder.
73
+ * **Static**
74
+
75
+ [datasets/Static](datasets/Static): pre-training dataset with static images. Guidance can be found in [AFB-URR](https://github.com/xmlyqing00/AFB-URR), which we referred to in the implementation of the pre-training.
76
+ * **YouTube-VOS**
77
+
78
+ A commonly-used large-scale VOS dataset.
79
+
80
+ [datasets/YTB/2019](datasets/YTB/2019): version 2019, download [link](https://drive.google.com/drive/folders/1BWzrCWyPEmBEKm0lOHe5KLuBuQxUSwqz?usp=sharing). `train` is required for training. `valid` (6fps) and `valid_all_frames` (30fps, optional) are used for evaluation.
81
+
82
+ [datasets/YTB/2018](datasets/YTB/2018): version 2018, download [link](https://drive.google.com/drive/folders/1bI5J1H3mxsIGo7Kp-pPZU8i6rnykOw7f?usp=sharing). Only `valid` (6fps) and `valid_all_frames` (30fps, optional) are required for this project and used for evaluation.
83
+
84
+ * **DAVIS**
85
+
86
+ A commonly-used small-scale VOS dataset.
87
+
88
+ [datasets/DAVIS](datasets/DAVIS): [TrainVal](https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip) (480p) contains both the training and validation split. [Test-Dev](https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-test-dev-480p.zip) (480p) contains the Test-dev split. The [full-resolution version](https://davischallenge.org/davis2017/code.html) is also supported for training and evaluation but not required.
89
+
90
+
91
+ 2. Prepare ImageNet pre-trained encoders
92
+
93
+ Select and download the checkpoints below into [pretrain_models](pretrain_models):
94
+
95
+ - [MobileNet-V2](https://download.pytorch.org/models/mobilenet_v2-b0353104.pth) (default encoder)
96
+ - [MobileNet-V3](https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth)
97
+ - [ResNet-50](https://download.pytorch.org/models/resnet50-0676ba61.pth)
98
+ - [ResNet-101](https://download.pytorch.org/models/resnet101-63fe2227.pth)
99
+ - [ResNeSt-50](https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/resnest50-528c19ca.pth)
100
+ - [ResNeSt-101](https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/resnest101-22405ba7.pth)
101
+ - [Swin-Base](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth)
102
+
103
+ The current default training configs are not optimized for encoders larger than ResNet-50. If you want to use a larger encoder, we recommend early-stopping the main-training stage at 80,000 iterations (100,000 by default) to avoid over-fitting on the seen classes of YouTube-VOS (see the config sketch after this section).
104
+
105
+
106
+
107
+ 3. Training and Evaluation
108
+
109
+ The [example script](train_eval.sh) will train AOTT in 2 stages using 4 GPUs and auto-mixed precision (`--amp`). The first stage is a pre-training stage using the `Static` dataset, and the second stage is a main-training stage, which uses both `YouTube-VOS 2019 train` and `DAVIS-2017 train` for training, resulting in a model that can generalize to different domains (YouTube-VOS and DAVIS) and different frame rates (6fps, 24fps, and 30fps).
110
+
111
+ Notably, you can use only the `YouTube-VOS 2019 train` split in the second stage by changing `pre_ytb_dav` to `pre_ytb` (as sketched after this section), which leads to better YouTube-VOS performance on unseen classes. In addition, if you want to skip the first stage, you can start training from the `ytb` stage, but performance will drop by about 1~2% in absolute terms.
112
+
113
+ After the training is finished (about 0.6 days for each stage with 4 Tesla V100 GPUs), the [example script](train_eval.sh) will evaluate the model on YouTube-VOS and DAVIS, and the results will be packed into Zip files. For calculating scores, please use official YouTube-VOS servers ([2018 server](https://competitions.codalab.org/competitions/19544) and [2019 server](https://competitions.codalab.org/competitions/20127)), official [DAVIS toolkit](https://github.com/davisvideochallenge/davis-2017) (for Val), and official [DAVIS server](https://competitions.codalab.org/competitions/20516#learn_the_details) (for Test-dev).
114
+
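The stage names above map onto config modules under `configs/` (see `pre.py`, `pre_dav.py`, and `pre_ytb.py` later in this upload). A minimal sketch of selecting a stage config and applying the early-stopping suggestion from step 2, assuming it is run from the `aot/` directory; the actual entry point wired up by `train_eval.sh` may differ:

```python
# Sketch: pick a stage config by module name and override a field, following
# the pattern of configs/pre_dav.py (which overrides TRAIN_TOTAL_STEPS itself).
import importlib

stage = 'pre_ytb'                 # e.g. 'pre_ytb_dav', 'pre_ytb', 'pre', or 'ytb'
exp_name, model = 'default', 'swinb_aotl'

EngineConfig = importlib.import_module('configs.' + stage).EngineConfig
cfg = EngineConfig(exp_name=exp_name, model=model)  # also creates result/ckpt/log dirs via init_dir()

cfg.TRAIN_TOTAL_STEPS = 80000     # early stop for encoders larger than ResNet-50
print(cfg.EXP_NAME, cfg.STAGE_NAME, cfg.TRAIN_TOTAL_STEPS)
```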
115
+ ## Adding your own dataset
116
+ Coming
117
+
118
+ ## Troubleshooting
119
+ Waiting
120
+
121
+ ## TODO
122
+ - [ ] Code documentation
123
+ - [ ] Adding your own dataset
124
+ - [ ] Results with test-time augmentations in Model Zoo
125
+ - [ ] Support gradient accumulation
126
+ - [x] Demo tool
127
+
128
+ ## Citations
129
+ Please consider citing the related paper(s) in your publications if it helps your research.
130
+ ```
131
+ @inproceedings{yang2022deaot,
132
+ title={Decoupling Features in Hierarchical Propagation for Video Object Segmentation},
133
+ author={Yang, Zongxin and Yang, Yi},
134
+ booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
135
+ year={2022}
136
+ }
137
+ @article{yang2021aost,
138
+ title={Scalable Multi-object Identification for Video Object Segmentation},
139
+ author={Yang, Zongxin and Miao, Jiaxu and Wang, Xiaohan and Wei, Yunchao and Yang, Yi},
140
+ journal={arXiv preprint arXiv:2203.11442},
141
+ year={2022}
142
+ }
143
+ @inproceedings{yang2021aot,
144
+ title={Associating Objects with Transformers for Video Object Segmentation},
145
+ author={Yang, Zongxin and Wei, Yunchao and Yang, Yi},
146
+ booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
147
+ year={2021}
148
+ }
149
+ ```
150
+
151
+ ## License
152
+ This project is released under the BSD-3-Clause license. See [LICENSE](LICENSE) for additional details.
aot/__init__.py ADDED
File without changes
aot/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (181 Bytes). View file
 
aot/configs/__pycache__/default.cpython-310.pyc ADDED
Binary file (4.29 kB). View file
 
aot/configs/__pycache__/pre_ytb_dav.cpython-310.pyc ADDED
Binary file (943 Bytes). View file
 
aot/configs/default.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import importlib
3
+
4
+
5
+ class DefaultEngineConfig():
6
+ def __init__(self, exp_name='default', model='aott'):
7
+ model_cfg = importlib.import_module('configs.models.' +
8
+ model).ModelConfig()
9
+ self.__dict__.update(model_cfg.__dict__) # add model config
10
+
11
+ self.EXP_NAME = exp_name + '_' + self.MODEL_NAME
12
+
13
+ self.STAGE_NAME = 'YTB'
14
+
15
+ self.DATASETS = ['youtubevos']
16
+ self.DATA_WORKERS = 8
17
+ self.DATA_RANDOMCROP = (465,
18
+ 465) if self.MODEL_ALIGN_CORNERS else (464,
19
+ 464)
20
+ self.DATA_RANDOMFLIP = 0.5
21
+ self.DATA_MAX_CROP_STEPS = 10
22
+ self.DATA_SHORT_EDGE_LEN = 480
23
+ self.DATA_MIN_SCALE_FACTOR = 0.7
24
+ self.DATA_MAX_SCALE_FACTOR = 1.3
25
+ self.DATA_RANDOM_REVERSE_SEQ = True
26
+ self.DATA_SEQ_LEN = 5
27
+ self.DATA_DAVIS_REPEAT = 5
28
+ self.DATA_RANDOM_GAP_DAVIS = 12 # max frame interval between two sampled frames for DAVIS (24fps)
29
+ self.DATA_RANDOM_GAP_YTB = 3 # max frame interval between two sampled frames for YouTube-VOS (6fps)
30
+ self.DATA_DYNAMIC_MERGE_PROB = 0.3
31
+
32
+ self.PRETRAIN = True
33
+ self.PRETRAIN_FULL = False # if False, load encoder only
34
+ self.PRETRAIN_MODEL = './data_wd/pretrain_model/mobilenet_v2.pth'
35
+ # self.PRETRAIN_MODEL = './pretrain_models/mobilenet_v2-b0353104.pth'
36
+
37
+ self.TRAIN_TOTAL_STEPS = 100000
38
+ self.TRAIN_START_STEP = 0
39
+ self.TRAIN_WEIGHT_DECAY = 0.07
40
+ self.TRAIN_WEIGHT_DECAY_EXCLUSIVE = {
41
+ # 'encoder.': 0.01
42
+ }
43
+ self.TRAIN_WEIGHT_DECAY_EXEMPTION = [
44
+ 'absolute_pos_embed', 'relative_position_bias_table',
45
+ 'relative_emb_v', 'conv_out'
46
+ ]
47
+ self.TRAIN_LR = 2e-4
48
+ self.TRAIN_LR_MIN = 2e-5 if 'mobilenetv2' in self.MODEL_ENCODER else 1e-5
49
+ self.TRAIN_LR_POWER = 0.9
50
+ self.TRAIN_LR_ENCODER_RATIO = 0.1
51
+ self.TRAIN_LR_WARM_UP_RATIO = 0.05
52
+ self.TRAIN_LR_COSINE_DECAY = False
53
+ self.TRAIN_LR_RESTART = 1
54
+ self.TRAIN_LR_UPDATE_STEP = 1
55
+ self.TRAIN_AUX_LOSS_WEIGHT = 1.0
56
+ self.TRAIN_AUX_LOSS_RATIO = 1.0
57
+ self.TRAIN_OPT = 'adamw'
58
+ self.TRAIN_SGD_MOMENTUM = 0.9
59
+ self.TRAIN_GPUS = 4
60
+ self.TRAIN_BATCH_SIZE = 16
61
+ self.TRAIN_TBLOG = False
62
+ self.TRAIN_TBLOG_STEP = 50
63
+ self.TRAIN_LOG_STEP = 20
64
+ self.TRAIN_IMG_LOG = True
65
+ self.TRAIN_TOP_K_PERCENT_PIXELS = 0.15
66
+ self.TRAIN_SEQ_TRAINING_FREEZE_PARAMS = ['patch_wise_id_bank']
67
+ self.TRAIN_SEQ_TRAINING_START_RATIO = 0.5
68
+ self.TRAIN_HARD_MINING_RATIO = 0.5
69
+ self.TRAIN_EMA_RATIO = 0.1
70
+ self.TRAIN_CLIP_GRAD_NORM = 5.
71
+ self.TRAIN_SAVE_STEP = 5000
72
+ self.TRAIN_MAX_KEEP_CKPT = 8
73
+ self.TRAIN_RESUME = False
74
+ self.TRAIN_RESUME_CKPT = None
75
+ self.TRAIN_RESUME_STEP = 0
76
+ self.TRAIN_AUTO_RESUME = True
77
+ self.TRAIN_DATASET_FULL_RESOLUTION = False
78
+ self.TRAIN_ENABLE_PREV_FRAME = False
79
+ self.TRAIN_ENCODER_FREEZE_AT = 2
80
+ self.TRAIN_LSTT_EMB_DROPOUT = 0.
81
+ self.TRAIN_LSTT_ID_DROPOUT = 0.
82
+ self.TRAIN_LSTT_DROPPATH = 0.1
83
+ self.TRAIN_LSTT_DROPPATH_SCALING = False
84
+ self.TRAIN_LSTT_DROPPATH_LST = False
85
+ self.TRAIN_LSTT_LT_DROPOUT = 0.
86
+ self.TRAIN_LSTT_ST_DROPOUT = 0.
87
+
88
+ self.TEST_GPU_ID = 0
89
+ self.TEST_GPU_NUM = 1
90
+ self.TEST_FRAME_LOG = False
91
+ self.TEST_DATASET = 'youtubevos'
92
+ self.TEST_DATASET_FULL_RESOLUTION = False
93
+ self.TEST_DATASET_SPLIT = 'val'
94
+ self.TEST_CKPT_PATH = None
95
+ # if "None", evaluate the latest checkpoint.
96
+ self.TEST_CKPT_STEP = None
97
+ self.TEST_FLIP = False
98
+ self.TEST_MULTISCALE = [1]
99
+ self.TEST_MAX_SHORT_EDGE = None
100
+ self.TEST_MAX_LONG_EDGE = 800 * 1.3
101
+ self.TEST_WORKERS = 4
102
+
103
+ # GPU distribution
104
+ self.DIST_ENABLE = True
105
+ self.DIST_BACKEND = "nccl" # "gloo"
106
+ self.DIST_URL = "tcp://127.0.0.1:13241"
107
+ self.DIST_START_GPU = 0
108
+
109
+ def init_dir(self):
110
+ self.DIR_DATA = '../VOS02/datasets'#'./datasets'
111
+ self.DIR_DAVIS = os.path.join(self.DIR_DATA, 'DAVIS')
112
+ self.DIR_YTB = os.path.join(self.DIR_DATA, 'YTB')
113
+ self.DIR_STATIC = os.path.join(self.DIR_DATA, 'Static')
114
+
115
+ self.DIR_ROOT = './'#'./data_wd/youtube_vos_jobs'
116
+
117
+ self.DIR_RESULT = os.path.join(self.DIR_ROOT, 'result', self.EXP_NAME,
118
+ self.STAGE_NAME)
119
+ self.DIR_CKPT = os.path.join(self.DIR_RESULT, 'ckpt')
120
+ self.DIR_EMA_CKPT = os.path.join(self.DIR_RESULT, 'ema_ckpt')
121
+ self.DIR_LOG = os.path.join(self.DIR_RESULT, 'log')
122
+ self.DIR_TB_LOG = os.path.join(self.DIR_RESULT, 'log', 'tensorboard')
123
+ # self.DIR_IMG_LOG = os.path.join(self.DIR_RESULT, 'log', 'img')
124
+ # self.DIR_EVALUATION = os.path.join(self.DIR_RESULT, 'eval')
125
+ self.DIR_IMG_LOG = './img_logs'
126
+ self.DIR_EVALUATION = './results'
127
+
128
+ for path in [
129
+ self.DIR_RESULT, self.DIR_CKPT, self.DIR_EMA_CKPT,
130
+ self.DIR_LOG, self.DIR_EVALUATION, self.DIR_IMG_LOG,
131
+ self.DIR_TB_LOG
132
+ ]:
133
+ if not os.path.isdir(path):
134
+ try:
135
+ os.makedirs(path)
136
+ except Exception as inst:
137
+ print(inst)
138
+ print('Failed to make dir: {}.'.format(path))
aot/configs/models/__pycache__/default.cpython-310.pyc ADDED
Binary file (1.22 kB). View file
 
aot/configs/models/__pycache__/r50_aotl.cpython-310.pyc ADDED
Binary file (873 Bytes). View file
 
aot/configs/models/aotb.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from .default import DefaultModelConfig
3
+
4
+ class ModelConfig(DefaultModelConfig):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.MODEL_NAME = 'AOTB'
8
+
9
+ self.MODEL_LSTT_NUM = 3
aot/configs/models/aotl.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from .default import DefaultModelConfig
3
+
4
+ class ModelConfig(DefaultModelConfig):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.MODEL_NAME = 'AOTL'
8
+
9
+ self.MODEL_LSTT_NUM = 3
10
+
11
+ self.TRAIN_LONG_TERM_MEM_GAP = 2
12
+
13
+ self.TEST_LONG_TERM_MEM_GAP = 5
aot/configs/models/aots.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from .default import DefaultModelConfig
3
+
4
+ class ModelConfig(DefaultModelConfig):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.MODEL_NAME = 'AOTS'
8
+
9
+ self.MODEL_LSTT_NUM = 2
aot/configs/models/aott.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import os
2
+ from .default import DefaultModelConfig
3
+
4
+ class ModelConfig(DefaultModelConfig):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.MODEL_NAME = 'AOTT'
aot/configs/models/deaotb.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .default_deaot import DefaultModelConfig
2
+
3
+
4
+ class ModelConfig(DefaultModelConfig):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.MODEL_NAME = 'DeAOTB'
8
+
9
+ self.MODEL_LSTT_NUM = 3
aot/configs/models/deaotl.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .default_deaot import DefaultModelConfig
2
+
3
+
4
+ class ModelConfig(DefaultModelConfig):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.MODEL_NAME = 'DeAOTL'
8
+
9
+ self.MODEL_LSTT_NUM = 3
10
+
11
+ self.TRAIN_LONG_TERM_MEM_GAP = 2
12
+
13
+ self.TEST_LONG_TERM_MEM_GAP = 5
aot/configs/models/deaots.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .default_deaot import DefaultModelConfig
2
+
3
+
4
+ class ModelConfig(DefaultModelConfig):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.MODEL_NAME = 'DeAOTS'
8
+
9
+ self.MODEL_LSTT_NUM = 2
aot/configs/models/deaott.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .default_deaot import DefaultModelConfig
2
+
3
+
4
+ class ModelConfig(DefaultModelConfig):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.MODEL_NAME = 'DeAOTT'
aot/configs/models/default.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class DefaultModelConfig():
2
+ def __init__(self):
3
+ self.MODEL_NAME = 'AOTDefault'
4
+
5
+ self.MODEL_VOS = 'aot'
6
+ self.MODEL_ENGINE = 'aotengine'
7
+ self.MODEL_ALIGN_CORNERS = True
8
+ self.MODEL_ENCODER = 'mobilenetv2'
9
+ self.MODEL_ENCODER_PRETRAIN = './pretrain_models/mobilenet_v2-b0353104.pth'
10
+ self.MODEL_ENCODER_DIM = [24, 32, 96, 1280] # 4x, 8x, 16x, 16x
11
+ self.MODEL_ENCODER_EMBEDDING_DIM = 256
12
+ self.MODEL_DECODER_INTERMEDIATE_LSTT = True
13
+ self.MODEL_FREEZE_BN = True
14
+ self.MODEL_FREEZE_BACKBONE = False
15
+ self.MODEL_MAX_OBJ_NUM = 10
16
+ self.MODEL_SELF_HEADS = 8
17
+ self.MODEL_ATT_HEADS = 8
18
+ self.MODEL_LSTT_NUM = 1
19
+ self.MODEL_EPSILON = 1e-5
20
+ self.MODEL_USE_PREV_PROB = False
21
+
22
+ self.TRAIN_LONG_TERM_MEM_GAP = 9999
23
+ self.TRAIN_AUG_TYPE = 'v1'
24
+
25
+ self.TEST_LONG_TERM_MEM_GAP = 9999
26
+
27
+ self.TEST_SHORT_TERM_MEM_SKIP = 1
aot/configs/models/default_deaot.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .default import DefaultModelConfig as BaseConfig
2
+
3
+
4
+ class DefaultModelConfig(BaseConfig):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.MODEL_NAME = 'DeAOTDefault'
8
+
9
+ self.MODEL_VOS = 'deaot'
10
+ self.MODEL_ENGINE = 'deaotengine'
11
+
12
+ self.MODEL_DECODER_INTERMEDIATE_LSTT = False
13
+
14
+ self.MODEL_SELF_HEADS = 1
15
+ self.MODEL_ATT_HEADS = 1
16
+
17
+ self.TRAIN_AUG_TYPE = 'v2'
aot/configs/models/r101_aotl.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .default import DefaultModelConfig
2
+
3
+
4
+ class ModelConfig(DefaultModelConfig):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.MODEL_NAME = 'R101_AOTL'
8
+
9
+ self.MODEL_ENCODER = 'resnet101'
10
+ self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnet101-63fe2227.pth' # https://download.pytorch.org/models/resnet101-63fe2227.pth
11
+ self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x
12
+ self.MODEL_LSTT_NUM = 3
13
+
14
+ self.TRAIN_LONG_TERM_MEM_GAP = 2
15
+
16
+ self.TEST_LONG_TERM_MEM_GAP = 5
aot/configs/models/r50_aotl.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .default import DefaultModelConfig
2
+
3
+
4
+ class ModelConfig(DefaultModelConfig):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.MODEL_NAME = 'R50_AOTL'
8
+
9
+ self.MODEL_ENCODER = 'resnet50'
10
+ self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnet50-0676ba61.pth' # https://download.pytorch.org/models/resnet50-0676ba61.pth
11
+ self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x
12
+ self.MODEL_LSTT_NUM = 3
13
+
14
+ self.TRAIN_LONG_TERM_MEM_GAP = 2
15
+
16
+ self.TEST_LONG_TERM_MEM_GAP = 5
aot/configs/models/r50_deaotl.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .default_deaot import DefaultModelConfig
2
+
3
+
4
+ class ModelConfig(DefaultModelConfig):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.MODEL_NAME = 'R50_DeAOTL'
8
+
9
+ self.MODEL_ENCODER = 'resnet50'
10
+ self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x
11
+
12
+ self.MODEL_LSTT_NUM = 3
13
+
14
+ self.TRAIN_LONG_TERM_MEM_GAP = 2
15
+
16
+ self.TEST_LONG_TERM_MEM_GAP = 5
aot/configs/models/rs101_aotl.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .default import DefaultModelConfig
2
+
3
+
4
+ class ModelConfig(DefaultModelConfig):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.MODEL_NAME = 'R101_AOTL'
8
+
9
+ self.MODEL_ENCODER = 'resnest101'
10
+ self.MODEL_ENCODER_PRETRAIN = './pretrain_models/resnest101-22405ba7.pth' # https://github.com/zhanghang1989/ResNeSt/releases/download/weights_step1/resnest101-22405ba7.pth
11
+ self.MODEL_ENCODER_DIM = [256, 512, 1024, 1024] # 4x, 8x, 16x, 16x
12
+ self.MODEL_LSTT_NUM = 3
13
+
14
+ self.TRAIN_LONG_TERM_MEM_GAP = 2
15
+
16
+ self.TEST_LONG_TERM_MEM_GAP = 5
aot/configs/models/swinb_aotl.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .default import DefaultModelConfig
2
+
3
+
4
+ class ModelConfig(DefaultModelConfig):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.MODEL_NAME = 'SwinB_AOTL'
8
+
9
+ self.MODEL_ENCODER = 'swin_base'
10
+ self.MODEL_ENCODER_PRETRAIN = './pretrain_models/swin_base_patch4_window7_224_22k.pth' # https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth
11
+ self.MODEL_ALIGN_CORNERS = False
12
+ self.MODEL_ENCODER_DIM = [128, 256, 512, 512] # 4x, 8x, 16x, 16x
13
+ self.MODEL_LSTT_NUM = 3
14
+
15
+ self.TRAIN_LONG_TERM_MEM_GAP = 2
16
+
17
+ self.TEST_LONG_TERM_MEM_GAP = 5
aot/configs/models/swinb_deaotl.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .default_deaot import DefaultModelConfig
2
+
3
+
4
+ class ModelConfig(DefaultModelConfig):
5
+ def __init__(self):
6
+ super().__init__()
7
+ self.MODEL_NAME = 'SwinB_DeAOTL'
8
+
9
+ self.MODEL_ENCODER = 'swin_base'
10
+ self.MODEL_ALIGN_CORNERS = False
11
+ self.MODEL_ENCODER_DIM = [128, 256, 512, 512] # 4x, 8x, 16x, 16x
12
+
13
+ self.MODEL_LSTT_NUM = 3
14
+
15
+ self.TRAIN_LONG_TERM_MEM_GAP = 2
16
+
17
+ self.TEST_LONG_TERM_MEM_GAP = 5
aot/configs/pre.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .default import DefaultEngineConfig
2
+
3
+
4
+ class EngineConfig(DefaultEngineConfig):
5
+ def __init__(self, exp_name='default', model='AOTT'):
6
+ super().__init__(exp_name, model)
7
+ self.STAGE_NAME = 'PRE'
8
+
9
+ self.init_dir()
10
+
11
+ self.DATASETS = ['static']
12
+
13
+ self.DATA_DYNAMIC_MERGE_PROB = 1.0
14
+
15
+ self.TRAIN_LR = 4e-4
16
+ self.TRAIN_LR_MIN = 2e-5
17
+ self.TRAIN_WEIGHT_DECAY = 0.03
18
+ self.TRAIN_SEQ_TRAINING_START_RATIO = 1.0
19
+ self.TRAIN_AUX_LOSS_RATIO = 0.1
aot/configs/pre_dav.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from .default import DefaultEngineConfig
3
+
4
+
5
+ class EngineConfig(DefaultEngineConfig):
6
+ def __init__(self, exp_name='default', model='AOTT'):
7
+ super().__init__(exp_name, model)
8
+ self.STAGE_NAME = 'PRE_DAV'
9
+
10
+ self.init_dir()
11
+
12
+ self.DATASETS = ['davis2017']
13
+
14
+ self.TRAIN_TOTAL_STEPS = 50000
15
+
16
+ pretrain_stage = 'PRE'
17
+ pretrain_ckpt = 'save_step_100000.pth'
18
+ self.PRETRAIN_FULL = True # if False, load encoder only
19
+ self.PRETRAIN_MODEL = os.path.join(self.DIR_ROOT, 'result',
20
+ self.EXP_NAME, pretrain_stage,
21
+ 'ema_ckpt', pretrain_ckpt)
aot/configs/pre_ytb.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from .default import DefaultEngineConfig
3
+
4
+
5
+ class EngineConfig(DefaultEngineConfig):
6
+ def __init__(self, exp_name='default', model='AOTT'):
7
+ super().__init__(exp_name, model)
8
+ self.STAGE_NAME = 'PRE_YTB'
9
+
10
+ self.init_dir()
11
+
12
+ pretrain_stage = 'PRE'
13
+ pretrain_ckpt = 'save_step_100000.pth'
14
+ self.PRETRAIN_FULL = True # if False, load encoder only
15
+ self.PRETRAIN_MODEL = os.path.join(self.DIR_ROOT, 'result',
16
+ self.EXP_NAME, pretrain_stage,
17
+ 'ema_ckpt', pretrain_ckpt)