Spaces: Runtime error

meow committed
Commit • d6d3a5b
1 Parent(s): f9fd2fa
init

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
- LICENSE +21 -0
- README.md +462 -12
- app.py +52 -0
- cog.yaml +38 -0
- common/.gitignore +1 -0
- common/___init___.py +0 -0
- common/abstract_pl.py +180 -0
- common/args_utils.py +15 -0
- common/body_models.py +146 -0
- common/camera.py +474 -0
- common/comet_utils.py +158 -0
- common/data_utils.py +371 -0
- common/ld_utils.py +116 -0
- common/list_utils.py +52 -0
- common/mesh.py +94 -0
- common/metrics.py +51 -0
- common/np_utils.py +7 -0
- common/object_tensors.py +293 -0
- common/pl_utils.py +63 -0
- common/rend_utils.py +139 -0
- common/rot.py +782 -0
- common/sys_utils.py +44 -0
- common/thing.py +66 -0
- common/torch_utils.py +212 -0
- common/transforms.py +356 -0
- common/viewer.py +287 -0
- common/vis_utils.py +129 -0
- common/xdict.py +288 -0
- data_loaders/.DS_Store +0 -0
- data_loaders/__pycache__/get_data.cpython-38.pyc +0 -0
- data_loaders/__pycache__/tensors.cpython-38.pyc +0 -0
- data_loaders/get_data.py +178 -0
- data_loaders/humanml/.DS_Store +0 -0
- data_loaders/humanml/README.md +1 -0
- data_loaders/humanml/common/__pycache__/quaternion.cpython-38.pyc +0 -0
- data_loaders/humanml/common/__pycache__/skeleton.cpython-38.pyc +0 -0
- data_loaders/humanml/common/quaternion.py +423 -0
- data_loaders/humanml/common/skeleton.py +199 -0
- data_loaders/humanml/data/__init__.py +0 -0
- data_loaders/humanml/data/__pycache__/__init__.cpython-38.pyc +0 -0
- data_loaders/humanml/data/__pycache__/dataset.cpython-38.pyc +0 -0
- data_loaders/humanml/data/__pycache__/dataset_ours.cpython-38.pyc +0 -0
- data_loaders/humanml/data/__pycache__/dataset_ours_single_seq.cpython-38.pyc +0 -0
- data_loaders/humanml/data/__pycache__/utils.cpython-38.pyc +0 -0
- data_loaders/humanml/data/dataset.py +795 -0
- data_loaders/humanml/data/dataset_ours.py +0 -0
- data_loaders/humanml/data/dataset_ours_single_seq.py +0 -0
- data_loaders/humanml/data/utils.py +507 -0
- data_loaders/humanml/motion_loaders/__init__.py +0 -0
- data_loaders/humanml/motion_loaders/__pycache__/__init__.cpython-38.pyc +0 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 Guy Tevet

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
CHANGED
@@ -1,12 +1,462 @@
# MDM: Human Motion Diffusion Model

TODO: data in what format, and data in this format.

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/human-motion-diffusion-model/motion-synthesis-on-humanact12)](https://paperswithcode.com/sota/motion-synthesis-on-humanact12?p=human-motion-diffusion-model)
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/human-motion-diffusion-model/motion-synthesis-on-humanml3d)](https://paperswithcode.com/sota/motion-synthesis-on-humanml3d?p=human-motion-diffusion-model)
[![arXiv](https://img.shields.io/badge/arXiv-<2209.14916>-<COLOR>.svg)](https://arxiv.org/abs/2209.14916)

<a href="https://replicate.com/arielreplicate/motion_diffusion_model"><img src="https://replicate.com/arielreplicate/motion_diffusion_model/badge"></a>

The official PyTorch implementation of the paper [**"Human Motion Diffusion Model"**](https://arxiv.org/abs/2209.14916).

Please visit our [**webpage**](https://guytevet.github.io/mdm-page/) for more details.

![teaser](https://github.com/GuyTevet/mdm-page/raw/main/static/figures/github.gif)

#### Bibtex
If you find this code useful in your research, please cite:

```
@article{tevet2022human,
  title={Human Motion Diffusion Model},
  author={Tevet, Guy and Raab, Sigal and Gordon, Brian and Shafir, Yonatan and Bermano, Amit H and Cohen-Or, Daniel},
  journal={arXiv preprint arXiv:2209.14916},
  year={2022}
}
```

## News

📢 **23/Nov/22** - Fixed evaluation issue (#42) - Please pull and run `bash prepare/download_t2m_evaluators.sh` from the top of the repo to adapt.

📢 **4/Nov/22** - Added sampling, training and evaluation of unconstrained tasks.
Note slight env changes adapting to the new code. If you already have an installed environment, run `bash prepare/download_unconstrained_assets.sh; conda install -y -c anaconda scikit-learn` to adapt.

📢 **3/Nov/22** - Added in-between and upper-body editing.

📢 **31/Oct/22** - Added sampling, training and evaluation of action-to-motion tasks.

📢 **9/Oct/22** - Added training and evaluation scripts.
Note slight env changes adapting to the new code. If you already have an installed environment, run `bash prepare/download_glove.sh; pip install clearml` to adapt.

📢 **6/Oct/22** - First release - sampling and rendering using pre-trained models.


## Getting started

This code was tested on `Ubuntu 18.04.5 LTS` and requires:

* Python 3.7
* conda3 or miniconda3
* CUDA capable GPU (one is enough)

### 1. Setup environment

Install ffmpeg (if not already installed):

```shell
sudo apt update
sudo apt install ffmpeg
```
For Windows, use [this guide](https://www.geeksforgeeks.org/how-to-install-ffmpeg-on-windows/) instead.

Setup conda env:
```shell
conda env create -f environment.yml
conda activate mdm
python -m spacy download en_core_web_sm
pip install git+https://github.com/openai/CLIP.git
```

Download dependencies:

<details>
<summary><b>Text to Motion</b></summary>

```bash
bash prepare/download_smpl_files.sh
bash prepare/download_glove.sh
bash prepare/download_t2m_evaluators.sh
```
</details>

<details>
<summary><b>Action to Motion</b></summary>

```bash
bash prepare/download_smpl_files.sh
bash prepare/download_recognition_models.sh
```
</details>

<details>
<summary><b>Unconstrained</b></summary>

```bash
bash prepare/download_smpl_files.sh
bash prepare/download_recognition_models.sh
bash prepare/download_recognition_unconstrained_models.sh
```
</details>

### 2. Get data

<details>
<summary><b>Text to Motion</b></summary>

There are two paths to get the data:

(a) **Go the easy way if** you just want to generate text-to-motion (excluding editing, which does require motion capture data).

(b) **Get full data** to train and evaluate the model.


#### a. The easy way (text only)

**HumanML3D** - Clone HumanML3D, then copy the data dir to our repository:

```shell
cd ..
git clone https://github.com/EricGuo5513/HumanML3D.git
unzip ./HumanML3D/HumanML3D/texts.zip -d ./HumanML3D/HumanML3D/
cp -r HumanML3D/HumanML3D motion-diffusion-model/dataset/HumanML3D
cd motion-diffusion-model
```


#### b. Full data (text + motion capture)

**HumanML3D** - Follow the instructions in [HumanML3D](https://github.com/EricGuo5513/HumanML3D.git),
then copy the resulting dataset to our repository:

```shell
cp -r ../HumanML3D/HumanML3D ./dataset/HumanML3D
```

**KIT** - Download from [HumanML3D](https://github.com/EricGuo5513/HumanML3D.git) (no processing needed this time) and then place the result in `./dataset/KIT-ML`.
</details>

<details>
<summary><b>Action to Motion</b></summary>

**UESTC, HumanAct12**
```bash
bash prepare/download_a2m_datasets.sh
```
</details>

<details>
<summary><b>Unconstrained</b></summary>

**HumanAct12**
```bash
bash prepare/download_unconstrained_datasets.sh
```
</details>

### 3. Download the pretrained models

Download the model(s) you wish to use, then unzip and place them in `./save/`.

<details>
<summary><b>Text to Motion</b></summary>

**You need only the first one.**

**HumanML3D**

[humanml-encoder-512](https://drive.google.com/file/d/1PE0PK8e5a5j-7-Xhs5YET5U5pGh0c821/view?usp=sharing) (best model)

[humanml-decoder-512](https://drive.google.com/file/d/1q3soLadvVh7kJuJPd2cegMNY2xVuVudj/view?usp=sharing)

[humanml-decoder-with-emb-512](https://drive.google.com/file/d/1GnsW0K3UjuOkNkAWmjrGIUmeDDZrmPE5/view?usp=sharing)

**KIT**

[kit-encoder-512](https://drive.google.com/file/d/1SHCRcE0es31vkJMLGf9dyLe7YsWj7pNL/view?usp=sharing)

</details>

<details>
<summary><b>Action to Motion</b></summary>

**UESTC**

[uestc](https://drive.google.com/file/d/1goB2DJK4B-fLu2QmqGWKAqWGMTAO6wQ6/view?usp=sharing)

[uestc_no_fc](https://drive.google.com/file/d/1fpv3mR-qP9CYCsi9CrQhFqlLavcSQky6/view?usp=sharing)

**HumanAct12**

[humanact12](https://drive.google.com/file/d/154X8_Lgpec6Xj0glEGql7FVKqPYCdBFO/view?usp=sharing)

[humanact12_no_fc](https://drive.google.com/file/d/1frKVMBYNiN5Mlq7zsnhDBzs9vGJvFeiQ/view?usp=sharing)

</details>

<details>
<summary><b>Unconstrained</b></summary>

**HumanAct12**

[humanact12_unconstrained](https://drive.google.com/file/d/1uG68m200pZK3pD-zTmPXu5XkgNpx_mEx/view?usp=share_link)

</details>


## Example Usage

Example usage and results on the TACO dataset:

| Input | Result | Overlayed |
| :----------------------: | :---------------------: | :-----------------------: |
| ![](assets/taco-20231104_017-src-a.gif) | ![](assets/taco-20231104_017-res-a.gif) | ![](assets/taco-20231104_017-overlayed-a.gif) |

Follow the steps below to reproduce the above result.

1. **Denoising**
```bash
bash scripts/val_examples/predict_taco_rndseed_spatial_20231104_017.sh
```
Ten random seeds will be used for prediction. The predicted results will be saved in the folder `./data/taco/result`.
2. **Mesh reconstruction**
```bash
bash scripts/val_examples/reconstruct_taco_20231104_017.sh
```
Results will be saved in the same folder as in the previous step.
3. **Extracting results and visualization**



<details>
<summary><b>Text to Motion</b></summary>

### Generate from test set prompts

```shell
python -m sample.generate --model_path ./save/humanml_trans_enc_512/model000200000.pt --num_samples 10 --num_repetitions 3
```

### Generate from your text file

```shell
python -m sample.generate --model_path ./save/humanml_trans_enc_512/model000200000.pt --input_text ./assets/example_text_prompts.txt
```

### Generate a single prompt

```shell
python -m sample.generate --model_path ./save/humanml_trans_enc_512/model000200000.pt --text_prompt "the person walked forward and is picking up his toolbox."
```
</details>

<details>
<summary><b>Action to Motion</b></summary>

### Generate from test set actions

```shell
python -m sample.generate --model_path ./save/humanact12/model000350000.pt --num_samples 10 --num_repetitions 3
```

### Generate from your actions file

```shell
python -m sample.generate --model_path ./save/humanact12/model000350000.pt --action_file ./assets/example_action_names_humanact12.txt
```

### Generate a single action

```shell
python -m sample.generate --model_path ./save/humanact12/model000350000.pt --text_prompt "drink"
```
</details>

<details>
<summary><b>Unconstrained</b></summary>

```shell
python -m sample.generate --model_path ./save/unconstrained/model000450000.pt --num_samples 10 --num_repetitions 3
```

In total, (num_samples * num_repetitions) samples are created and are visually organized in a display of num_samples rows and num_repetitions columns.

</details>

**You may also define:**
* `--device` id.
* `--seed` to sample different prompts.
* `--motion_length` (text-to-motion only) in seconds (maximum is 9.8 sec).

**Running those will get you:**

* `results.npy` file with text prompts and xyz positions of the generated animation (see the loading sketch below).
* `sample##_rep##.mp4` - a stick figure animation for each generated motion.

It will look something like this:

![example](assets/example_stick_fig.gif)
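To work with `results.npy` programmatically, here is a minimal loading sketch. The file stores a pickled dictionary; the key names below are assumptions based on typical MDM outputs, so inspect `data.keys()` on your own file first:

```python
import numpy as np

# results.npy holds a pickled dict, so allow_pickle is required.
data = np.load("results.npy", allow_pickle=True).item()
print(data.keys())  # check what the file actually contains

# Assumed keys (verify against your own file):
motions = data.get("motion")   # xyz joint positions per generated sample
texts = data.get("text")       # the prompts used for generation
lengths = data.get("lengths")  # number of valid frames per sample
```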

You can stop here, or render the SMPL mesh using the following script.

### Render SMPL mesh

To create a SMPL mesh per frame, run:

```shell
python -m visualize.render_mesh --input_path /path/to/mp4/stick/figure/file
```

**This script outputs:**
* `sample##_rep##_smpl_params.npy` - SMPL parameters (thetas, root translations, vertices and faces)
* `sample##_rep##_obj` - Mesh per frame in `.obj` format.

**Notes:**
* The `.obj` files can be imported into Blender/Maya/3DS-MAX and rendered there.
* This script runs [SMPLify](https://smplify.is.tue.mpg.de/) and needs a GPU as well (can be specified with the `--device` flag).
* **Important** - Do not change the original `.mp4` path before running the script.

**Notes for 3D artists:**
* You have two ways to animate the sequence:
  1. Use the [SMPL add-on](https://smpl.is.tue.mpg.de/index.html) and the theta parameters saved to `sample##_rep##_smpl_params.npy` (we always use beta=0 and the gender-neutral model).
  1. A more straightforward way is using the mesh data itself. All meshes have the same topology (SMPL), so you just need to keyframe vertex locations.
     Since the OBJs do not preserve vertex order, we also save this data to the `sample##_rep##_smpl_params.npy` file for your convenience (see the loading sketch below).
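For reference, a minimal sketch for inspecting one of these parameter files from Python (`sample00_rep00` is just an example of the `sample##_rep##` pattern; the key names inside the saved dictionary are assumptions, so print them first):

```python
import numpy as np

params = np.load("sample00_rep00_smpl_params.npy", allow_pickle=True).item()
print(params.keys())  # verify the actual key names in your file

# Assumed entries (hypothetical names): per-frame pose (thetas),
# root translations, and the mesh vertices/faces mentioned above.
thetas = params.get("thetas")
root_translation = params.get("root_translation")
vertices = params.get("vertices")
faces = params.get("faces")
```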

## Motion Editing

* This feature is available for text-to-motion datasets (HumanML3D and KIT).
* In order to use it, you need to acquire the full data (not just the texts).
* We support the two modes presented in the paper: `in_between` and `upper_body`.

### Unconditioned editing

```shell
python -m sample.edit --model_path ./save/humanml_trans_enc_512/model000200000.pt --edit_mode in_between
```

**You may also define:**
* `--num_samples` (default is 10) / `--num_repetitions` (default is 3).
* `--device` id.
* `--seed` to sample different prompts.
* `--edit_mode upper_body` for upper-body editing (the lower body is fixed).


The output will look like this (blue frames are from the input motion; orange were generated by the model):

![example](assets/in_between_edit.gif)

* As in *Motion Synthesis*, you may follow the **Render SMPL mesh** section to obtain meshes for your edited motions.

### Text conditioned editing

Just add the text conditioning using `--text_condition`. For example:

```shell
python -m sample.edit --model_path ./save/humanml_trans_enc_512/model000200000.pt --edit_mode upper_body --text_condition "A person throws a ball"
```

The output will look like this (blue joints are from the input motion; orange were generated by the model):

![example](assets/upper_body_edit.gif)

## Train your own MDM

<details>
<summary><b>Text to Motion</b></summary>

**HumanML3D**
```shell
python -m train.train_mdm --save_dir save/my_humanml_trans_enc_512 --dataset humanml
```

**KIT**
```shell
python -m train.train_mdm --save_dir save/my_kit_trans_enc_512 --dataset kit
```
</details>
<details>
<summary><b>Action to Motion</b></summary>

```shell
python -m train.train_mdm --save_dir save/my_name --dataset {humanact12,uestc} --cond_mask_prob 0 --lambda_rcxyz 1 --lambda_vel 1 --lambda_fc 1
```
</details>

<details>
<summary><b>Unconstrained</b></summary>

```shell
python -m train.train_mdm --save_dir save/my_name --dataset humanact12 --cond_mask_prob 0 --lambda_rcxyz 1 --lambda_vel 1 --lambda_fc 1 --unconstrained
```
</details>

* Use `--device` to define the GPU id.
* Use `--arch` to choose one of the architectures reported in the paper `{trans_enc, trans_dec, gru}` (`trans_enc` is the default).
* Add `--train_platform_type {ClearmlPlatform, TensorboardPlatform}` to track results with either [ClearML](https://clear.ml/) or [Tensorboard](https://www.tensorflow.org/tensorboard).
* Add `--eval_during_training` to run a short (90 minutes) evaluation for each saved checkpoint.
  This will slow down training but will give you better monitoring.

## Evaluate

<details>
<summary><b>Text to Motion</b></summary>

* Takes about 20 hours (on a single GPU).
* The output of this script for the pre-trained models (as reported in the paper) is provided in the checkpoints zip file.

**HumanML3D**
```shell
python -m eval.eval_humanml --model_path ./save/humanml_trans_enc_512/model000475000.pt
```

**KIT**
```shell
python -m eval.eval_humanml --model_path ./save/kit_trans_enc_512/model000400000.pt
```
</details>

<details>
<summary><b>Action to Motion</b></summary>

* Takes about 7 hours for UESTC and 2 hours for HumanAct12 (on a single GPU).
* The output of this script for the pre-trained models (as reported in the paper) is provided in the checkpoints zip file.

```shell
python -m eval.eval_humanact12_uestc --model <path-to-model-ckpt> --eval_mode full
```
where `path-to-model-ckpt` can be a path to any of the pretrained action-to-motion models listed above, or to a checkpoint trained by the user.

</details>


<details>
<summary><b>Unconstrained</b></summary>

* Takes about 3 hours (on a single GPU).

```shell
python -m eval.eval_humanact12_uestc --model ./save/unconstrained/model000450000.pt --eval_mode full
```

Precision and recall are not computed to save computing time. If you wish to compute them, edit the file `eval/a2m/gru_eval.py` and change `fast=True` to `fast=False`.

</details>

## Acknowledgments

This code is standing on the shoulders of giants. We want to thank the following contributors
that our code is based on:

[guided-diffusion](https://github.com/openai/guided-diffusion), [MotionCLIP](https://github.com/GuyTevet/MotionCLIP), [text-to-motion](https://github.com/EricGuo5513/text-to-motion), [actor](https://github.com/Mathux/ACTOR), [joints2smpl](https://github.com/wangsen1312/joints2smpl), [MoDi](https://github.com/sigal-raab/MoDi).

## License
This code is distributed under an [MIT LICENSE](LICENSE).

Note that our code depends on other libraries, including CLIP, SMPL, SMPL-X, and PyTorch3D, and uses datasets that each have their own respective licenses that must also be followed.
app.py
ADDED
@@ -0,0 +1,52 @@
import numpy as np

import gradio as gr

import os
import tempfile
import shutil

# from gradio_inter.predict_from_file import predict_from_file
from gradio_inter.create_bash_file import create_bash_file


def create_temp_file(path: str) -> str:
    temp_dir = tempfile.gettempdir()
    temp_folder = os.path.join(temp_dir, "denoising")
    os.makedirs(temp_folder, exist_ok=True)
    # Clean up directory
    # for i in os.listdir(temp_folder):
    #     print("Removing", i)
    #     os.remove(os.path.join(temp_folder, i))

    temp_path = os.path.join(temp_folder, path.split("/")[-1])
    shutil.copy2(path, temp_path)
    return temp_path

# from gradio_inter.predict import predict_from_data


def transpose(matrix):
    return matrix.T


def predict(file_path: str):
    temp_file_path = create_temp_file(file_path)
    temp_bash_file = create_bash_file(temp_file_path)

    os.system(f"bash {temp_bash_file}")


demo = gr.Interface(
    predict,
    # gr.Dataframe(type="numpy", datatype="number", row_count=5, col_count=3),
    gr.File(type="filepath"),
    "dict",
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch()
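A note on `app.py`: `predict` shells out to a generated bash script and returns `None`, while the Gradio output component is declared as `"dict"`, so the web UI will not display the denoising results directly; the outputs land wherever the bash script writes them. A hedged sketch of the same flow run locally, bypassing the UI (the `gradio_inter` helpers belong to this Space and their exact behavior is assumed; the input path is hypothetical):

```python
# Hypothetical local smoke test of the flow used by the Space.
import os
from gradio_inter.create_bash_file import create_bash_file

input_path = "./data/taco/example_input.pkl"      # hypothetical input file
bash_script = create_bash_file(create_temp_file(input_path))
os.system(f"bash {bash_script}")                  # results are written by the script itself
```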
cog.yaml
ADDED
@@ -0,0 +1,38 @@
build:
  gpu: true
  cuda: "11.3"
  python_version: 3.8
  system_packages:
    - libgl1-mesa-glx
    - libglib2.0-0

  python_packages:
    - imageio==2.22.2
    - matplotlib==3.1.3
    - spacy==3.3.1
    - smplx==0.1.28
    - chumpy==0.70
    - blis==0.7.8
    - click==8.1.3
    - confection==0.0.2
    - ftfy==6.1.1
    - importlib-metadata==5.0.0
    - lxml==4.9.1
    - murmurhash==1.0.8
    - preshed==3.0.7
    - pycryptodomex==3.15.0
    - regex==2022.9.13
    - srsly==2.4.4
    - thinc==8.0.17
    - typing-extensions==4.1.1
    - urllib3==1.26.12
    - wasabi==0.10.1
    - wcwidth==0.2.5

  run:
    - apt update -y && apt-get install ffmpeg -y
    # - python -m spacy download en_core_web_sm
    - git clone https://github.com/openai/CLIP.git sub_modules/CLIP
    - pip install -e sub_modules/CLIP

predict: "sample/predict.py:Predictor"
common/.gitignore
ADDED
@@ -0,0 +1 @@
__pycache__
common/___init___.py
ADDED
File without changes
common/abstract_pl.py
ADDED
@@ -0,0 +1,180 @@
import time

import numpy as np
import pytorch_lightning as pl
import torch
import torch.optim as optim

import common.pl_utils as pl_utils
from common.comet_utils import log_dict
from common.pl_utils import avg_losses_cpu, push_checkpoint_metric
from common.xdict import xdict


class AbstractPL(pl.LightningModule):
    def __init__(
        self,
        args,
        push_images_fn,
        tracked_metric,
        metric_init_val,
        high_loss_val,
    ):
        super().__init__()
        self.experiment = args.experiment
        self.args = args
        self.tracked_metric = tracked_metric
        self.metric_init_val = metric_init_val

        self.started_training = False
        self.loss_dict_vec = []
        self.push_images = push_images_fn
        self.vis_train_batches = []
        self.vis_val_batches = []
        self.high_loss_val = high_loss_val
        self.max_vis_examples = 20
        self.val_step_outputs = []
        self.test_step_outputs = []

    def set_training_flags(self):
        self.started_training = True

    def load_from_ckpt(self, ckpt_path):
        sd = torch.load(ckpt_path)["state_dict"]
        print(self.load_state_dict(sd))

    def training_step(self, batch, batch_idx):
        self.set_training_flags()
        if len(self.vis_train_batches) < self.num_vis_train:
            self.vis_train_batches.append(batch)
        inputs, targets, meta_info = batch

        out = self.forward(inputs, targets, meta_info, "train")
        loss = out["loss"]

        loss = {k: loss[k].mean().view(-1) for k in loss}
        total_loss = sum(loss[k] for k in loss)

        loss_dict = {"total_loss": total_loss, "loss": total_loss}
        loss_dict.update(loss)

        for k, v in loss_dict.items():
            if k != "loss":
                loss_dict[k] = v.detach()

        log_every = self.args.log_every
        self.loss_dict_vec.append(loss_dict)
        self.loss_dict_vec = self.loss_dict_vec[len(self.loss_dict_vec) - log_every :]
        if batch_idx % log_every == 0 and batch_idx != 0:
            running_loss_dict = avg_losses_cpu(self.loss_dict_vec)
            running_loss_dict = xdict(running_loss_dict).postfix("__train")
            log_dict(self.experiment, running_loss_dict, step=self.global_step)
        return loss_dict

    def on_train_epoch_end(self):
        self.experiment.log_epoch_end(self.current_epoch)

    def validation_step(self, batch, batch_idx):
        if len(self.vis_val_batches) < self.num_vis_val:
            self.vis_val_batches.append(batch)
        out = self.inference_step(batch, batch_idx)
        self.val_step_outputs.append(out)
        return out

    def on_validation_epoch_end(self):
        outputs = self.val_step_outputs
        outputs = self.inference_epoch_end(outputs, postfix="__val")
        self.log("loss__val", outputs["loss__val"])
        self.val_step_outputs.clear()  # free memory
        return outputs

    def inference_step(self, batch, batch_idx):
        if self.training:
            self.eval()
        with torch.no_grad():
            inputs, targets, meta_info = batch
            out, loss = self.forward(inputs, targets, meta_info, "test")
            return {"out_dict": out, "loss": loss}

    def inference_epoch_end(self, out_list, postfix):
        if not self.started_training:
            self.started_training = True
            result = push_checkpoint_metric(self.tracked_metric, self.metric_init_val)
            return result

        # unpack
        outputs, loss_dict = pl_utils.reform_outputs(out_list)

        if "test" in postfix:
            per_img_metric_dict = {}
            for k, v in outputs.items():
                if "metric." in k:
                    per_img_metric_dict[k] = np.array(v)

        metric_dict = {}
        for k, v in outputs.items():
            if "metric." in k:
                metric_dict[k] = np.nanmean(np.array(v))

        loss_metric_dict = {}
        loss_metric_dict.update(metric_dict)
        loss_metric_dict.update(loss_dict)
        loss_metric_dict = xdict(loss_metric_dict).postfix(postfix)

        log_dict(
            self.experiment,
            loss_metric_dict,
            step=self.global_step,
        )

        if self.args.interface_p is None and "test" not in postfix:
            result = push_checkpoint_metric(
                self.tracked_metric, loss_metric_dict[self.tracked_metric]
            )
            self.log(self.tracked_metric, result[self.tracked_metric])

        if not self.args.no_vis:
            print("Rendering train images")
            self.visualize_batches(self.vis_train_batches, "_train", False)
            print("Rendering val images")
            self.visualize_batches(self.vis_val_batches, "_val", False)

        if "test" in postfix:
            return (
                outputs,
                {"per_img_metric_dict": per_img_metric_dict},
                metric_dict,
            )
        return loss_metric_dict

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.args.lr)
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, self.args.lr_dec_epoch, gamma=self.args.lr_decay, verbose=True
        )
        return [optimizer], [scheduler]

    def visualize_batches(self, batches, postfix, no_tqdm=True):
        im_list = []
        if self.training:
            self.eval()

        tic = time.time()
        for batch_idx, batch in enumerate(batches):
            with torch.no_grad():
                inputs, targets, meta_info = batch
                vis_dict = self.forward(inputs, targets, meta_info, "vis")
                for vis_fn in self.vis_fns:
                    curr_im_list = vis_fn(
                        vis_dict,
                        self.max_vis_examples,
                        self.renderer,
                        postfix=postfix,
                        no_tqdm=no_tqdm,
                    )
                    im_list += curr_im_list
                print("Rendering: %d/%d" % (batch_idx + 1, len(batches)))

        self.push_images(self.experiment, im_list, self.global_step)
        print("Done rendering (%.1fs)" % (time.time() - tic))
        return im_list
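`AbstractPL` expects subclasses to define `forward(inputs, targets, meta_info, mode)` and a few attributes it references (`num_vis_train`, `num_vis_val`, `vis_fns`, `renderer`). A purely hypothetical, minimal subclass sketch showing the contract (the loss keys and network are placeholders, not part of this repo):

```python
import torch


class ToyModelPL(AbstractPL):
    def __init__(self, args, push_images_fn):
        super().__init__(
            args,
            push_images_fn,
            tracked_metric="loss__val",
            metric_init_val=float("inf"),
            high_loss_val=float("inf"),
        )
        self.num_vis_train = 1
        self.num_vis_val = 1
        self.vis_fns = []        # visualization callbacks used by visualize_batches
        self.renderer = None     # renderer handed to those callbacks
        self.net = torch.nn.Linear(10, 1)  # placeholder network

    def forward(self, inputs, targets, meta_info, mode):
        pred = self.net(inputs)
        loss = {"mse": ((pred - targets) ** 2).mean()}
        if mode == "train":
            # training_step reads out["loss"] and sums its entries
            return {"loss": loss}
        # inference_step unpacks (out, loss) for the "test" mode
        return {"pred": pred}, loss
```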
common/args_utils.py
ADDED
@@ -0,0 +1,15 @@
from loguru import logger


def set_default_params(args, default_args):
    # if a val is not set on argparse, use default val
    # else, use the one in the argparse
    custom_dict = {}
    for key, val in args.items():
        if val is None:
            args[key] = default_args[key]
        else:
            custom_dict[key] = val

    logger.info(f"Using custom values: {custom_dict}")
    return args
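A quick illustration of how `set_default_params` behaves, assuming `args` is a dict-like object (it is iterated with `.items()`); the keys and values below are made up for the example:

```python
# Values left as None fall back to the defaults; explicit values are kept.
args = {"lr": None, "batch_size": 16}
defaults = {"lr": 1e-4, "batch_size": 32}

args = set_default_params(args, defaults)
# args == {"lr": 1e-4, "batch_size": 16}
# and the logger reports: Using custom values: {'batch_size': 16}
```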
common/body_models.py
ADDED
@@ -0,0 +1,146 @@
import json

import numpy as np
import torch
from smplx import MANO

from common.mesh import Mesh


class MANODecimator:
    def __init__(self):
        data = np.load(
            "./data/arctic_data/data/meta/mano_decimator_195.npy", allow_pickle=True
        ).item()
        mydata = {}
        for key, val in data.items():
            # only consider decimation matrix so far
            if "D" in key:
                mydata[key] = torch.FloatTensor(val)
        self.data = mydata

    def downsample(self, verts, is_right):
        dev = verts.device
        flag = "right" if is_right else "left"
        if self.data[f"D_{flag}"].device != dev:
            self.data[f"D_{flag}"] = self.data[f"D_{flag}"].to(dev)
        D = self.data[f"D_{flag}"]
        batch_size = verts.shape[0]
        D_batch = D[None, :, :].repeat(batch_size, 1, 1)
        verts_sub = torch.bmm(D_batch, verts)
        return verts_sub


MODEL_DIR = "./data/body_models/mano"

SEAL_FACES_R = [
    [120, 108, 778],
    [108, 79, 778],
    [79, 78, 778],
    [78, 121, 778],
    [121, 214, 778],
    [214, 215, 778],
    [215, 279, 778],
    [279, 239, 778],
    [239, 234, 778],
    [234, 92, 778],
    [92, 38, 778],
    [38, 122, 778],
    [122, 118, 778],
    [118, 117, 778],
    [117, 119, 778],
    [119, 120, 778],
]

# vertex ids around the ring of the wrist
CIRCLE_V_ID = np.array(
    [108, 79, 78, 121, 214, 215, 279, 239, 234, 92, 38, 122, 118, 117, 119, 120],
    dtype=np.int64,
)


def seal_mano_mesh(v3d, faces, is_rhand):
    # v3d: B, 778, 3
    # faces: 1538, 3
    # output: v3d(B, 779, 3); faces (1554, 3)

    seal_faces = torch.LongTensor(np.array(SEAL_FACES_R)).to(faces.device)
    if not is_rhand:
        # left hand
        seal_faces = seal_faces[:, np.array([1, 0, 2])]  # invert face normal
    centers = v3d[:, CIRCLE_V_ID].mean(dim=1)[:, None, :]
    sealed_vertices = torch.cat((v3d, centers), dim=1)
    faces = torch.cat((faces, seal_faces), dim=0)
    return sealed_vertices, faces


def build_layers(device=None):
    from common.object_tensors import ObjectTensors

    layers = {
        "right": build_mano_aa(True),
        "left": build_mano_aa(False),
        "object_tensors": ObjectTensors(),
    }

    if device is not None:
        layers["right"] = layers["right"].to(device)
        layers["left"] = layers["left"].to(device)
        layers["object_tensors"].to(device)
    return layers


MANO_MODEL_DIR = "./data/body_models/mano"
SMPLX_MODEL_P = {
    "male": "./data/body_models/smplx/SMPLX_MALE.npz",
    "female": "./data/body_models/smplx/SMPLX_FEMALE.npz",
    "neutral": "./data/body_models/smplx/SMPLX_NEUTRAL.npz",
}


def build_smplx(batch_size, gender, vtemplate):
    import smplx

    subj_m = smplx.create(
        model_path=SMPLX_MODEL_P[gender],
        model_type="smplx",
        gender=gender,
        num_pca_comps=45,
        v_template=vtemplate,
        flat_hand_mean=True,
        use_pca=False,
        batch_size=batch_size,
        # batch_size=320,
    )
    return subj_m


def build_subject_smplx(batch_size, subject_id):
    with open("./data/arctic_data/data/meta/misc.json", "r") as f:
        misc = json.load(f)
    vtemplate_p = f"./data/arctic_data/data/meta/subject_vtemplates/{subject_id}.obj"
    mesh = Mesh(filename=vtemplate_p)
    vtemplate = mesh.v
    gender = misc[subject_id]["gender"]
    return build_smplx(batch_size, gender, vtemplate)


def build_mano_aa(is_rhand, create_transl=False, flat_hand=False):
    return MANO(
        MODEL_DIR,
        create_transl=create_transl,
        use_pca=False,
        flat_hand_mean=flat_hand,
        is_rhand=is_rhand,
    )


def construct_layers(dev):
    mano_layers = {
        "right": build_mano_aa(True, create_transl=True, flat_hand=False),
        "left": build_mano_aa(False, create_transl=True, flat_hand=False),
        "smplx": build_smplx(1, "neutral", None),
    }
    for layer in mano_layers.values():
        layer.to(dev)
    return mano_layers
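A minimal sketch of how `seal_mano_mesh` is typically used with a layer built by `build_mano_aa`. It assumes the MANO model files are present under `./data/body_models/mano` and reuses the imports from `common/body_models.py` above:

```python
import numpy as np
import torch

# Right-hand MANO layer with zero pose/shape (rest pose).
mano = build_mano_aa(is_rhand=True)
out = mano(
    betas=torch.zeros(1, 10),
    global_orient=torch.zeros(1, 3),
    hand_pose=torch.zeros(1, 45),
)

faces = torch.from_numpy(mano.faces.astype(np.int64))       # (1538, 3)
v_sealed, f_sealed = seal_mano_mesh(out.vertices, faces, is_rhand=True)
print(v_sealed.shape, f_sealed.shape)                        # (1, 779, 3), (1554, 3)
```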
common/camera.py
ADDED
@@ -0,0 +1,474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
|
4 |
+
"""
|
5 |
+
Useful geometric operations, e.g. Perspective projection and a differentiable Rodrigues formula
|
6 |
+
Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR
|
7 |
+
"""
|
8 |
+
|
9 |
+
|
10 |
+
def perspective_to_weak_perspective_torch(
|
11 |
+
perspective_camera,
|
12 |
+
focal_length,
|
13 |
+
img_res,
|
14 |
+
):
|
15 |
+
# Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz]
|
16 |
+
# in 3D given the bounding box size
|
17 |
+
# This camera translation can be used in a full perspective projection
|
18 |
+
# if isinstance(focal_length, torch.Tensor):
|
19 |
+
# focal_length = focal_length[:, 0]
|
20 |
+
|
21 |
+
tx = perspective_camera[:, 0]
|
22 |
+
ty = perspective_camera[:, 1]
|
23 |
+
tz = perspective_camera[:, 2]
|
24 |
+
|
25 |
+
weak_perspective_camera = torch.stack(
|
26 |
+
[2 * focal_length / (img_res * tz + 1e-9), tx, ty],
|
27 |
+
dim=-1,
|
28 |
+
)
|
29 |
+
return weak_perspective_camera
|
30 |
+
|
31 |
+
|
32 |
+
def convert_perspective_to_weak_perspective(
|
33 |
+
perspective_camera,
|
34 |
+
focal_length,
|
35 |
+
img_res,
|
36 |
+
):
|
37 |
+
# Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz]
|
38 |
+
# in 3D given the bounding box size
|
39 |
+
# This camera translation can be used in a full perspective projection
|
40 |
+
# if isinstance(focal_length, torch.Tensor):
|
41 |
+
# focal_length = focal_length[:, 0]
|
42 |
+
|
43 |
+
weak_perspective_camera = torch.stack(
|
44 |
+
[
|
45 |
+
2 * focal_length / (img_res * perspective_camera[:, 2] + 1e-9),
|
46 |
+
perspective_camera[:, 0],
|
47 |
+
perspective_camera[:, 1],
|
48 |
+
],
|
49 |
+
dim=-1,
|
50 |
+
)
|
51 |
+
return weak_perspective_camera
|
52 |
+
|
53 |
+
|
54 |
+
def convert_weak_perspective_to_perspective(
|
55 |
+
weak_perspective_camera, focal_length, img_res
|
56 |
+
):
|
57 |
+
# Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz]
|
58 |
+
# in 3D given the bounding box size
|
59 |
+
# This camera translation can be used in a full perspective projection
|
60 |
+
# if isinstance(focal_length, torch.Tensor):
|
61 |
+
# focal_length = focal_length[:, 0]
|
62 |
+
|
63 |
+
perspective_camera = torch.stack(
|
64 |
+
[
|
65 |
+
weak_perspective_camera[:, 1],
|
66 |
+
weak_perspective_camera[:, 2],
|
67 |
+
2 * focal_length / (img_res * weak_perspective_camera[:, 0] + 1e-9),
|
68 |
+
],
|
69 |
+
dim=-1,
|
70 |
+
)
|
71 |
+
return perspective_camera
|
72 |
+
|
73 |
+
|
74 |
+
def get_default_cam_t(f, img_res):
|
75 |
+
cam = torch.tensor([[5.0, 0.0, 0.0]])
|
76 |
+
return convert_weak_perspective_to_perspective(cam, f, img_res)
|
77 |
+
|
78 |
+
|
79 |
+
def estimate_translation_np(S, joints_2d, joints_conf, focal_length, img_size):
|
80 |
+
"""Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d.
|
81 |
+
Input:
|
82 |
+
S: (25, 3) 3D joint locations
|
83 |
+
joints: (25, 3) 2D joint locations and confidence
|
84 |
+
Returns:
|
85 |
+
(3,) camera translation vector
|
86 |
+
"""
|
87 |
+
num_joints = S.shape[0]
|
88 |
+
# focal length
|
89 |
+
|
90 |
+
f = np.array([focal_length[0], focal_length[1]])
|
91 |
+
# optical center
|
92 |
+
center = np.array([img_size[1] / 2.0, img_size[0] / 2.0])
|
93 |
+
|
94 |
+
# transformations
|
95 |
+
Z = np.reshape(np.tile(S[:, 2], (2, 1)).T, -1)
|
96 |
+
XY = np.reshape(S[:, 0:2], -1)
|
97 |
+
O = np.tile(center, num_joints)
|
98 |
+
F = np.tile(f, num_joints)
|
99 |
+
weight2 = np.reshape(np.tile(np.sqrt(joints_conf), (2, 1)).T, -1)
|
100 |
+
|
101 |
+
# least squares
|
102 |
+
Q = np.array(
|
103 |
+
[
|
104 |
+
F * np.tile(np.array([1, 0]), num_joints),
|
105 |
+
F * np.tile(np.array([0, 1]), num_joints),
|
106 |
+
O - np.reshape(joints_2d, -1),
|
107 |
+
]
|
108 |
+
).T
|
109 |
+
c = (np.reshape(joints_2d, -1) - O) * Z - F * XY
|
110 |
+
|
111 |
+
# weighted least squares
|
112 |
+
W = np.diagflat(weight2)
|
113 |
+
Q = np.dot(W, Q)
|
114 |
+
c = np.dot(W, c)
|
115 |
+
|
116 |
+
# square matrix
|
117 |
+
A = np.dot(Q.T, Q)
|
118 |
+
b = np.dot(Q.T, c)
|
119 |
+
|
120 |
+
# solution
|
121 |
+
trans = np.linalg.solve(A, b)
|
122 |
+
|
123 |
+
return trans
|
124 |
+
|
125 |
+
|
126 |
+
def estimate_translation(
|
127 |
+
S,
|
128 |
+
joints_2d,
|
129 |
+
focal_length,
|
130 |
+
img_size,
|
131 |
+
use_all_joints=False,
|
132 |
+
rotation=None,
|
133 |
+
pad_2d=False,
|
134 |
+
):
|
135 |
+
"""Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d.
|
136 |
+
Input:
|
137 |
+
S: (B, 49, 3) 3D joint locations
|
138 |
+
joints: (B, 49, 3) 2D joint locations and confidence
|
139 |
+
Returns:
|
140 |
+
(B, 3) camera translation vectors
|
141 |
+
"""
|
142 |
+
if pad_2d:
|
143 |
+
batch, num_pts = joints_2d.shape[:2]
|
144 |
+
joints_2d_pad = torch.ones((batch, num_pts, 3))
|
145 |
+
joints_2d_pad[:, :, :2] = joints_2d
|
146 |
+
joints_2d_pad = joints_2d_pad.to(joints_2d.device)
|
147 |
+
joints_2d = joints_2d_pad
|
148 |
+
|
149 |
+
device = S.device
|
150 |
+
|
151 |
+
if rotation is not None:
|
152 |
+
S = torch.einsum("bij,bkj->bki", rotation, S)
|
153 |
+
|
154 |
+
# Use only joints 25:49 (GT joints)
|
155 |
+
if use_all_joints:
|
156 |
+
S = S.cpu().numpy()
|
157 |
+
joints_2d = joints_2d.cpu().numpy()
|
158 |
+
else:
|
159 |
+
S = S[:, 25:, :].cpu().numpy()
|
160 |
+
joints_2d = joints_2d[:, 25:, :].cpu().numpy()
|
161 |
+
|
162 |
+
joints_conf = joints_2d[:, :, -1]
|
163 |
+
joints_2d = joints_2d[:, :, :-1]
|
164 |
+
trans = np.zeros((S.shape[0], 3), dtype=np.float32)
|
165 |
+
# Find the translation for each example in the batch
|
166 |
+
for i in range(S.shape[0]):
|
167 |
+
S_i = S[i]
|
168 |
+
joints_i = joints_2d[i]
|
169 |
+
conf_i = joints_conf[i]
|
170 |
+
trans[i] = estimate_translation_np(
|
171 |
+
S_i, joints_i, conf_i, focal_length=focal_length, img_size=img_size
|
172 |
+
)
|
173 |
+
return torch.from_numpy(trans).to(device)
|
174 |
+
|
175 |
+
|
176 |
+
def estimate_translation_cam(
|
177 |
+
S, joints_2d, focal_length, img_size, use_all_joints=False, rotation=None
|
178 |
+
):
|
179 |
+
"""Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d.
|
180 |
+
Input:
|
181 |
+
S: (B, 49, 3) 3D joint locations
|
182 |
+
joints: (B, 49, 3) 2D joint locations and confidence
|
183 |
+
Returns:
|
184 |
+
(B, 3) camera translation vectors
|
185 |
+
"""
|
186 |
+
|
187 |
+
def estimate_translation_np(S, joints_2d, joints_conf, focal_length, img_size):
|
188 |
+
"""Find camera translation that brings 3D joints S closest to 2D the corresponding joints_2d.
|
189 |
+
Input:
|
190 |
+
S: (25, 3) 3D joint locations
|
191 |
+
joints: (25, 3) 2D joint locations and confidence
|
192 |
+
Returns:
|
193 |
+
(3,) camera translation vector
|
194 |
+
"""
|
195 |
+
|
196 |
+
num_joints = S.shape[0]
|
197 |
+
# focal length
|
198 |
+
f = np.array([focal_length[0], focal_length[1]])
|
199 |
+
# optical center
|
200 |
+
center = np.array([img_size[0] / 2.0, img_size[1] / 2.0])
|
201 |
+
|
202 |
+
# transformations
|
203 |
+
Z = np.reshape(np.tile(S[:, 2], (2, 1)).T, -1)
|
204 |
+
XY = np.reshape(S[:, 0:2], -1)
|
205 |
+
O = np.tile(center, num_joints)
|
206 |
+
F = np.tile(f, num_joints)
|
207 |
+
weight2 = np.reshape(np.tile(np.sqrt(joints_conf), (2, 1)).T, -1)
|
208 |
+
|
209 |
+
# least squares
|
210 |
+
Q = np.array(
|
211 |
+
[
|
212 |
+
F * np.tile(np.array([1, 0]), num_joints),
|
213 |
+
F * np.tile(np.array([0, 1]), num_joints),
|
214 |
+
O - np.reshape(joints_2d, -1),
|
215 |
+
]
|
216 |
+
).T
|
217 |
+
c = (np.reshape(joints_2d, -1) - O) * Z - F * XY
|
218 |
+
|
219 |
+
# weighted least squares
|
220 |
+
W = np.diagflat(weight2)
|
221 |
+
Q = np.dot(W, Q)
|
222 |
+
c = np.dot(W, c)
|
223 |
+
|
224 |
+
# square matrix
|
225 |
+
A = np.dot(Q.T, Q)
|
226 |
+
b = np.dot(Q.T, c)
|
227 |
+
|
228 |
+
# solution
|
229 |
+
trans = np.linalg.solve(A, b)
|
230 |
+
|
231 |
+
return trans
|
232 |
+
|
233 |
+
device = S.device
|
234 |
+
|
235 |
+
if rotation is not None:
|
236 |
+
S = torch.einsum("bij,bkj->bki", rotation, S)
|
237 |
+
|
238 |
+
# Use only joints 25:49 (GT joints)
|
239 |
+
if use_all_joints:
|
240 |
+
S = S.cpu().numpy()
|
241 |
+
joints_2d = joints_2d.cpu().numpy()
|
242 |
+
else:
|
243 |
+
S = S[:, 25:, :].cpu().numpy()
|
244 |
+
joints_2d = joints_2d[:, 25:, :].cpu().numpy()
|
245 |
+
|
246 |
+
joints_conf = joints_2d[:, :, -1]
|
247 |
+
joints_2d = joints_2d[:, :, :-1]
|
248 |
+
trans = np.zeros((S.shape[0], 3), dtype=np.float32)
|
249 |
+
# Find the translation for each example in the batch
|
250 |
+
for i in range(S.shape[0]):
|
251 |
+
S_i = S[i]
|
252 |
+
joints_i = joints_2d[i]
|
253 |
+
conf_i = joints_conf[i]
|
254 |
+
trans[i] = estimate_translation_np(
|
255 |
+
S_i, joints_i, conf_i, focal_length=focal_length, img_size=img_size
|
256 |
+
)
|
257 |
+
return torch.from_numpy(trans).to(device)
|
258 |
+
|
259 |
+
|
260 |
+
def get_coord_maps(size=56):
|
261 |
+
xx_ones = torch.ones([1, size], dtype=torch.int32)
|
262 |
+
xx_ones = xx_ones.unsqueeze(-1)
|
263 |
+
|
264 |
+
xx_range = torch.arange(size, dtype=torch.int32).unsqueeze(0)
|
265 |
+
xx_range = xx_range.unsqueeze(1)
|
266 |
+
|
267 |
+
xx_channel = torch.matmul(xx_ones, xx_range)
|
268 |
+
xx_channel = xx_channel.unsqueeze(-1)
|
269 |
+
|
270 |
+
yy_ones = torch.ones([1, size], dtype=torch.int32)
|
271 |
+
yy_ones = yy_ones.unsqueeze(1)
|
272 |
+
|
273 |
+
yy_range = torch.arange(size, dtype=torch.int32).unsqueeze(0)
|
274 |
+
yy_range = yy_range.unsqueeze(-1)
|
275 |
+
|
276 |
+
yy_channel = torch.matmul(yy_range, yy_ones)
|
277 |
+
yy_channel = yy_channel.unsqueeze(-1)
|
278 |
+
|
279 |
+
xx_channel = xx_channel.permute(0, 3, 1, 2)
|
280 |
+
yy_channel = yy_channel.permute(0, 3, 1, 2)
|
281 |
+
|
282 |
+
xx_channel = xx_channel.float() / (size - 1)
|
283 |
+
yy_channel = yy_channel.float() / (size - 1)
|
284 |
+
|
285 |
+
xx_channel = xx_channel * 2 - 1
|
286 |
+
yy_channel = yy_channel * 2 - 1
|
287 |
+
|
288 |
+
out = torch.cat([xx_channel, yy_channel], dim=1)
|
289 |
+
return out
|
290 |
+
|
291 |
+
|
292 |
+
def look_at(eye, at=np.array([0, 0, 0]), up=np.array([0, 0, 1]), eps=1e-5):
|
293 |
+
    at = at.astype(float).reshape(1, 3)
    up = up.astype(float).reshape(1, 3)

    eye = eye.reshape(-1, 3)
    up = up.repeat(eye.shape[0] // up.shape[0], axis=0)
    eps = np.array([eps]).reshape(1, 1).repeat(up.shape[0], axis=0)

    z_axis = eye - at
    z_axis /= np.max(np.stack([np.linalg.norm(z_axis, axis=1, keepdims=True), eps]))

    x_axis = np.cross(up, z_axis)
    x_axis /= np.max(np.stack([np.linalg.norm(x_axis, axis=1, keepdims=True), eps]))

    y_axis = np.cross(z_axis, x_axis)
    y_axis /= np.max(np.stack([np.linalg.norm(y_axis, axis=1, keepdims=True), eps]))

    r_mat = np.concatenate(
        (x_axis.reshape(-1, 3, 1), y_axis.reshape(-1, 3, 1), z_axis.reshape(-1, 3, 1)),
        axis=2,
    )

    return r_mat


def to_sphere(u, v):
    theta = 2 * np.pi * u
    phi = np.arccos(1 - 2 * v)
    cx = np.sin(phi) * np.cos(theta)
    cy = np.sin(phi) * np.sin(theta)
    cz = np.cos(phi)
    s = np.stack([cx, cy, cz])
    return s


def sample_on_sphere(range_u=(0, 1), range_v=(0, 1)):
    u = np.random.uniform(*range_u)
    v = np.random.uniform(*range_v)
    return to_sphere(u, v)


def sample_pose_on_sphere(range_v=(0, 1), range_u=(0, 1), radius=1, up=[0, 1, 0]):
    # sample location on unit sphere
    loc = sample_on_sphere(range_u, range_v)

    # sample radius if necessary
    if isinstance(radius, tuple):
        radius = np.random.uniform(*radius)

    loc = loc * radius
    R = look_at(loc, up=np.array(up))[0]

    RT = np.concatenate([R, loc.reshape(3, 1)], axis=1)
    RT = torch.Tensor(RT.astype(np.float32))
    return RT


def rectify_pose(camera_r, body_aa, rotate_x=False):
    body_r = batch_rodrigues(body_aa).reshape(-1, 3, 3)

    if rotate_x:
        rotate_x = torch.tensor([[[1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, -1.0]]])
        body_r = body_r @ rotate_x

    final_r = camera_r @ body_r
    body_aa = batch_rot2aa(final_r)
    return body_aa


def estimate_translation_k_np(S, joints_2d, joints_conf, K):
    """Find the camera translation that brings the 3D joints S closest to the corresponding 2D joints joints_2d.
    Input:
        S: (25, 3) 3D joint locations
        joints: (25, 3) 2D joint locations and confidence
    Returns:
        (3,) camera translation vector
    """
    num_joints = S.shape[0]
    # focal length
    focal = np.array([K[0, 0], K[1, 1]])
    # optical center
    center = np.array([K[0, 2], K[1, 2]])

    # transformations
    Z = np.reshape(np.tile(S[:, 2], (2, 1)).T, -1)
    XY = np.reshape(S[:, 0:2], -1)
    O = np.tile(center, num_joints)
    F = np.tile(focal, num_joints)
    weight2 = np.reshape(np.tile(np.sqrt(joints_conf), (2, 1)).T, -1)

    # least squares
    Q = np.array(
        [
            F * np.tile(np.array([1, 0]), num_joints),
            F * np.tile(np.array([0, 1]), num_joints),
            O - np.reshape(joints_2d, -1),
        ]
    ).T
    c = (np.reshape(joints_2d, -1) - O) * Z - F * XY

    # weighted least squares
    W = np.diagflat(weight2)
    Q = np.dot(W, Q)
    c = np.dot(W, c)

    # square matrix
    A = np.dot(Q.T, Q)
    b = np.dot(Q.T, c)

    # solution
    trans = np.linalg.solve(A, b)

    return trans


def estimate_translation_k(
    S,
    joints_2d,
    K,
    use_all_joints=False,
    rotation=None,
    pad_2d=False,
):
    """Find the camera translation that brings the 3D joints S closest to the corresponding 2D joints joints_2d.
    Input:
        S: (B, 49, 3) 3D joint locations
        joints: (B, 49, 3) 2D joint locations and confidence
    Returns:
        (B, 3) camera translation vectors
    """
    if pad_2d:
        batch, num_pts = joints_2d.shape[:2]
        joints_2d_pad = torch.ones((batch, num_pts, 3))
        joints_2d_pad[:, :, :2] = joints_2d
        joints_2d_pad = joints_2d_pad.to(joints_2d.device)
        joints_2d = joints_2d_pad

    device = S.device

    if rotation is not None:
        S = torch.einsum("bij,bkj->bki", rotation, S)

    # Use only joints 25:49 (GT joints)
    if use_all_joints:
        S = S.cpu().numpy()
        joints_2d = joints_2d.cpu().numpy()
    else:
        S = S[:, 25:, :].cpu().numpy()
        joints_2d = joints_2d[:, 25:, :].cpu().numpy()

    joints_conf = joints_2d[:, :, -1]
    joints_2d = joints_2d[:, :, :-1]
    trans = np.zeros((S.shape[0], 3), dtype=np.float32)
    # Find the translation for each example in the batch
    for i in range(S.shape[0]):
        S_i = S[i]
        joints_i = joints_2d[i]
        conf_i = joints_conf[i]
        K_i = K[i]
        trans[i] = estimate_translation_k_np(S_i, joints_i, conf_i, K_i)
    return torch.from_numpy(trans).to(device)


def weak_perspective_to_perspective_torch(
    weak_perspective_camera, focal_length, img_res, min_s
):
    # Convert Weak Perspective Camera [s, tx, ty] to camera translation [tx, ty, tz]
    # in 3D given the bounding box size
    # This camera translation can be used in a full perspective projection
    s = weak_perspective_camera[:, 0]
    s = torch.clamp(s, min_s)
    tx = weak_perspective_camera[:, 1]
    ty = weak_perspective_camera[:, 2]
    perspective_camera = torch.stack(
        [
            tx,
            ty,
            2 * focal_length / (img_res * s + 1e-9),
        ],
        dim=-1,
    )
    return perspective_camera
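Usage sketch (not part of the commit): a minimal sanity check for estimate_translation_k_np above. It assumes the repository root is on PYTHONPATH so that common.camera is importable; synthetic joints are projected with a known translation, which the weighted least-squares solve should recover up to numerical precision.

import numpy as np
from common.camera import estimate_translation_k_np

K = np.array([[1000.0, 0.0, 112.0], [0.0, 1000.0, 112.0], [0.0, 0.0, 1.0]])
S = np.random.randn(25, 3) * 0.1            # 3D joints near the origin (meter)
t_gt = np.array([0.05, -0.02, 2.0])          # ground-truth camera translation
pts = (S + t_gt) @ K.T                       # pinhole projection with intrinsics K
joints_2d = pts[:, :2] / pts[:, 2:3]
conf = np.ones(25)                           # all joints fully confident
t_est = estimate_translation_k_np(S, joints_2d, conf, K)
print(np.allclose(t_est, t_gt))              # should print True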
common/comet_utils.py
ADDED
@@ -0,0 +1,158 @@
import json
import os
import os.path as op
import time

import comet_ml
import numpy as np
import torch
from loguru import logger
from tqdm import tqdm

from src.datasets.dataset_utils import copy_repo_arctic

# folder used for debugging
DUMMY_EXP = "xxxxxxxxx"


def add_paths(args):
    exp_key = args.exp_key
    args_p = f"./logs/{exp_key}/args.json"
    ckpt_p = f"./logs/{exp_key}/checkpoints/last.ckpt"
    if not op.exists(ckpt_p) or DUMMY_EXP in ckpt_p:
        ckpt_p = ""
    if args.resume_ckpt != "":
        ckpt_p = args.resume_ckpt
    args.ckpt_p = ckpt_p
    args.log_dir = f"./logs/{exp_key}"

    if args.infer_ckpt != "":
        basedir = "/".join(args.infer_ckpt.split("/")[:2])
        basename = op.basename(args.infer_ckpt).replace(".ckpt", ".params.pt")
        args.interface_p = op.join(basedir, basename)
    args.args_p = args_p
    if args.cluster:
        args.run_p = op.join(args.log_dir, "condor", "run.sh")
        args.submit_p = op.join(args.log_dir, "condor", "submit.sub")
        args.repo_p = op.join(args.log_dir, "repo")

    return args


def save_args(args, save_keys):
    args_save = {}
    for key, val in args.items():
        if key in save_keys:
            args_save[key] = val
    with open(args.args_p, "w") as f:
        json.dump(args_save, f, indent=4)
    logger.info(f"Saved args at {args.args_p}")


def create_files(args):
    os.makedirs(args.log_dir, exist_ok=True)
    if args.cluster:
        os.makedirs(op.dirname(args.run_p), exist_ok=True)
        copy_repo_arctic(args.exp_key)


def log_exp_meta(args):
    tags = [args.method]
    logger.info(f"Experiment tags: {tags}")
    args.experiment.set_name(args.exp_key)
    args.experiment.add_tags(tags)
    args.experiment.log_parameters(args)


def init_experiment(args):
    if args.resume_ckpt != "":
        args.exp_key = args.resume_ckpt.split("/")[1]
    if args.fast_dev_run:
        args.exp_key = DUMMY_EXP
    if args.exp_key == "":
        args.exp_key = generate_exp_key()
    args = add_paths(args)
    if op.exists(args.args_p) and args.exp_key not in [DUMMY_EXP]:
        with open(args.args_p, "r") as f:
            args_disk = json.load(f)
        if "comet_key" in args_disk.keys():
            args.comet_key = args_disk["comet_key"]

    create_files(args)

    project_name = args.project
    disabled = args.mute
    comet_url = args["comet_key"] if "comet_key" in args.keys() else None

    api_key = os.environ["COMET_API_KEY"]
    workspace = os.environ["COMET_WORKSPACE"]
    if not args.cluster:
        if comet_url is None:
            experiment = comet_ml.Experiment(
                api_key=api_key,
                workspace=workspace,
                project_name=project_name,
                disabled=disabled,
                display_summary_level=0,
            )
            args.comet_key = experiment.get_key()
        else:
            experiment = comet_ml.ExistingExperiment(
                previous_experiment=comet_url,
                api_key=api_key,
                project_name=project_name,
                workspace=workspace,
                disabled=disabled,
                display_summary_level=0,
            )

        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.add(
            os.path.join(args.log_dir, "train.log"),
            level="INFO",
            colorize=True,
        )
        logger.info(torch.cuda.get_device_properties(device))
        args.gpu = torch.cuda.get_device_properties(device).name
    else:
        experiment = None
    args.experiment = experiment
    return experiment, args


def log_dict(experiment, metric_dict, step, postfix=None):
    if experiment is None:
        return
    for key, value in metric_dict.items():
        if postfix is not None:
            key = key + postfix
        if isinstance(value, torch.Tensor) and len(value.view(-1)) == 1:
            value = value.item()

        if isinstance(value, (int, float, np.float32)):
            experiment.log_metric(key, value, step=step)


def generate_exp_key():
    import random

    hash = random.getrandbits(128)
    key = "%032x" % (hash)
    key = key[:9]
    return key


def push_images(experiment, all_im_list, global_step=None, no_tqdm=False, verbose=True):
    if verbose:
        print("Pushing PIL images")
        tic = time.time()
    iterator = all_im_list if no_tqdm else tqdm(all_im_list)
    for im in iterator:
        im_np = np.array(im["im"])
        if "fig_name" in im.keys():
            experiment.log_image(im_np, im["fig_name"], step=global_step)
        else:
            experiment.log_image(im_np, "unnamed", step=global_step)
    if verbose:
        toc = time.time()
        print("Done pushing PIL images (%.1fs)" % (toc - tic))
common/data_utils.py
ADDED
@@ -0,0 +1,371 @@
"""
This file contains functions that are used to perform data augmentation.
"""
import cv2
import numpy as np
import torch
from loguru import logger


def get_transform(center, scale, res, rot=0):
    """Generate transformation matrix."""
    h = 200 * scale
    t = np.zeros((3, 3))
    t[0, 0] = float(res[1]) / h
    t[1, 1] = float(res[0]) / h
    t[0, 2] = res[1] * (-float(center[0]) / h + 0.5)
    t[1, 2] = res[0] * (-float(center[1]) / h + 0.5)
    t[2, 2] = 1
    if not rot == 0:
        rot = -rot  # To match direction of rotation from cropping
        rot_mat = np.zeros((3, 3))
        rot_rad = rot * np.pi / 180
        sn, cs = np.sin(rot_rad), np.cos(rot_rad)
        rot_mat[0, :2] = [cs, -sn]
        rot_mat[1, :2] = [sn, cs]
        rot_mat[2, 2] = 1
        # Need to rotate around center
        t_mat = np.eye(3)
        t_mat[0, 2] = -res[1] / 2
        t_mat[1, 2] = -res[0] / 2
        t_inv = t_mat.copy()
        t_inv[:2, 2] *= -1
        t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
    return t


def transform(pt, center, scale, res, invert=0, rot=0):
    """Transform pixel location to different reference."""
    t = get_transform(center, scale, res, rot=rot)
    if invert:
        t = np.linalg.inv(t)
    new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.0]).T
    new_pt = np.dot(t, new_pt)
    return new_pt[:2].astype(int) + 1


def rotate_2d(pt_2d, rot_rad):
    x = pt_2d[0]
    y = pt_2d[1]
    sn, cs = np.sin(rot_rad), np.cos(rot_rad)
    xx = x * cs - y * sn
    yy = x * sn + y * cs
    return np.array([xx, yy], dtype=np.float32)


def gen_trans_from_patch_cv(
    c_x, c_y, src_width, src_height, dst_width, dst_height, scale, rot, inv=False
):
    # augment size with scale
    src_w = src_width * scale
    src_h = src_height * scale
    src_center = np.array([c_x, c_y], dtype=np.float32)

    # augment rotation
    rot_rad = np.pi * rot / 180
    src_downdir = rotate_2d(np.array([0, src_h * 0.5], dtype=np.float32), rot_rad)
    src_rightdir = rotate_2d(np.array([src_w * 0.5, 0], dtype=np.float32), rot_rad)

    dst_w = dst_width
    dst_h = dst_height
    dst_center = np.array([dst_w * 0.5, dst_h * 0.5], dtype=np.float32)
    dst_downdir = np.array([0, dst_h * 0.5], dtype=np.float32)
    dst_rightdir = np.array([dst_w * 0.5, 0], dtype=np.float32)

    src = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = src_center
    src[1, :] = src_center + src_downdir
    src[2, :] = src_center + src_rightdir

    dst = np.zeros((3, 2), dtype=np.float32)
    dst[0, :] = dst_center
    dst[1, :] = dst_center + dst_downdir
    dst[2, :] = dst_center + dst_rightdir

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    trans = trans.astype(np.float32)
    return trans


def generate_patch_image(
    cvimg,
    bbox,
    scale,
    rot,
    out_shape,
    interpl_strategy,
    gauss_kernel=5,
    gauss_sigma=8.0,
):
    img = cvimg.copy()

    bb_c_x = float(bbox[0])
    bb_c_y = float(bbox[1])
    bb_width = float(bbox[2])
    bb_height = float(bbox[3])

    trans = gen_trans_from_patch_cv(
        bb_c_x, bb_c_y, bb_width, bb_height, out_shape[1], out_shape[0], scale, rot
    )

    # anti-aliasing
    blur = cv2.GaussianBlur(img, (gauss_kernel, gauss_kernel), gauss_sigma)
    img_patch = cv2.warpAffine(
        blur, trans, (int(out_shape[1]), int(out_shape[0])), flags=interpl_strategy
    )
    img_patch = img_patch.astype(np.float32)
    inv_trans = gen_trans_from_patch_cv(
        bb_c_x,
        bb_c_y,
        bb_width,
        bb_height,
        out_shape[1],
        out_shape[0],
        scale,
        rot,
        inv=True,
    )

    return img_patch, trans, inv_trans


def augm_params(is_train, flip_prob, noise_factor, rot_factor, scale_factor):
    """Get augmentation parameters."""
    flip = 0  # flipping
    pn = np.ones(3)  # per channel pixel-noise
    rot = 0  # rotation
    sc = 1  # scaling
    if is_train:
        # We flip with probability 1/2
        if np.random.uniform() <= flip_prob:
            flip = 1
            assert False, "Flipping not supported"

        # Each channel is multiplied with a number
        # in the area [1-opt.noiseFactor,1+opt.noiseFactor]
        pn = np.random.uniform(1 - noise_factor, 1 + noise_factor, 3)

        # The rotation is a number in the area [-2*rotFactor, 2*rotFactor]
        rot = min(
            2 * rot_factor,
            max(
                -2 * rot_factor,
                np.random.randn() * rot_factor,
            ),
        )

        # The scale is multiplied with a number
        # in the area [1-scaleFactor,1+scaleFactor]
        sc = min(
            1 + scale_factor,
            max(
                1 - scale_factor,
                np.random.randn() * scale_factor + 1,
            ),
        )
        # but it is zero with probability 3/5
        if np.random.uniform() <= 0.6:
            rot = 0

    augm_dict = {}
    augm_dict["flip"] = flip
    augm_dict["pn"] = pn
    augm_dict["rot"] = rot
    augm_dict["sc"] = sc
    return augm_dict


def rgb_processing(is_train, rgb_img, center, bbox_dim, augm_dict, img_res):
    rot = augm_dict["rot"]
    sc = augm_dict["sc"]
    pn = augm_dict["pn"]
    scale = sc * bbox_dim

    crop_dim = int(scale * 200)
    # faster cropping!!
    rgb_img = generate_patch_image(
        rgb_img,
        [center[0], center[1], crop_dim, crop_dim],
        1.0,
        rot,
        [img_res, img_res],
        cv2.INTER_CUBIC,
    )[0]

    # in the rgb image we add pixel noise in a channel-wise manner
    rgb_img[:, :, 0] = np.minimum(255.0, np.maximum(0.0, rgb_img[:, :, 0] * pn[0]))
    rgb_img[:, :, 1] = np.minimum(255.0, np.maximum(0.0, rgb_img[:, :, 1] * pn[1]))
    rgb_img[:, :, 2] = np.minimum(255.0, np.maximum(0.0, rgb_img[:, :, 2] * pn[2]))
    rgb_img = np.transpose(rgb_img.astype("float32"), (2, 0, 1)) / 255.0
    return rgb_img


def transform_kp2d(kp2d, bbox):
    # bbox: (cx, cy, scale) in the original image space
    # scale is normalized
    assert isinstance(kp2d, np.ndarray)
    assert len(kp2d.shape) == 2
    cx, cy, scale = bbox
    s = 200 * scale  # to px
    cap_dim = 1000  # px
    factor = cap_dim / (1.5 * s)
    kp2d_cropped = np.copy(kp2d)
    kp2d_cropped[:, 0] -= cx - 1.5 / 2 * s
    kp2d_cropped[:, 1] -= cy - 1.5 / 2 * s
    kp2d_cropped[:, 0] *= factor
    kp2d_cropped[:, 1] *= factor
    return kp2d_cropped


def j2d_processing(kp, center, bbox_dim, augm_dict, img_res):
    """Process gt 2D keypoints and apply all augmentation transforms."""
    scale = augm_dict["sc"] * bbox_dim
    rot = augm_dict["rot"]

    nparts = kp.shape[0]
    for i in range(nparts):
        kp[i, 0:2] = transform(
            kp[i, 0:2] + 1,
            center,
            scale,
            [img_res, img_res],
            rot=rot,
        )
    # convert to normalized coordinates
    kp = normalize_kp2d_np(kp, img_res)
    kp = kp.astype("float32")
    return kp


def pose_processing(pose, augm_dict):
    """Process SMPL theta parameters and apply all augmentation transforms."""
    rot = augm_dict["rot"]
    # rotation of the pose parameters
    pose[:3] = rot_aa(pose[:3], rot)
    # flip the pose parameters
    # (72),float
    pose = pose.astype("float32")
    return pose


def rot_aa(aa, rot):
    """Rotate axis angle parameters."""
    # pose parameters
    R = np.array(
        [
            [np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
            [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
            [0, 0, 1],
        ]
    )
    # find the rotation of the body in camera frame
    per_rdg, _ = cv2.Rodrigues(aa)
    # apply the global rotation to the global orientation
    resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg))
    aa = (resrot.T)[0]
    return aa


def denormalize_images(images):
    images = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(
        1, 3, 1, 1
    )
    images = images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(
        1, 3, 1, 1
    )
    return images


def read_img(img_fn, dummy_shape):
    try:
        cv_img = _read_img(img_fn)
    except:
        logger.warning(f"Unable to load {img_fn}")
        cv_img = np.zeros(dummy_shape, dtype=np.float32)
        return cv_img, False
    return cv_img, True


def _read_img(img_fn):
    img = cv2.cvtColor(cv2.imread(img_fn), cv2.COLOR_BGR2RGB)
    return img.astype(np.float32)


def normalize_kp2d_np(kp2d: np.ndarray, img_res):
    assert kp2d.shape[1] == 3
    kp2d_normalized = kp2d.copy()
    kp2d_normalized[:, :2] = 2.0 * kp2d[:, :2] / img_res - 1.0
    return kp2d_normalized


def unnormalize_2d_kp(kp_2d_np: np.ndarray, res):
    assert kp_2d_np.shape[1] == 3
    kp_2d = np.copy(kp_2d_np)
    kp_2d[:, :2] = 0.5 * res * (kp_2d[:, :2] + 1)
    return kp_2d


def normalize_kp2d(kp2d: torch.Tensor, img_res):
    assert len(kp2d.shape) == 3
    kp2d_normalized = kp2d.clone()
    kp2d_normalized[:, :, :2] = 2.0 * kp2d[:, :, :2] / img_res - 1.0
    return kp2d_normalized


def unormalize_kp2d(kp2d_normalized: torch.Tensor, img_res):
    assert len(kp2d_normalized.shape) == 3
    assert kp2d_normalized.shape[2] == 2
    kp2d = kp2d_normalized.clone()
    kp2d = 0.5 * img_res * (kp2d + 1)
    return kp2d


def get_wp_intrix(fixed_focal: float, img_res):
    # construct weak perspective on patch
    camera_center = np.array([img_res // 2, img_res // 2])
    intrx = torch.zeros([3, 3])
    intrx[0, 0] = fixed_focal
    intrx[1, 1] = fixed_focal
    intrx[2, 2] = 1.0
    intrx[0, -1] = camera_center[0]
    intrx[1, -1] = camera_center[1]
    return intrx


def get_aug_intrix(
    intrx, fixed_focal: float, img_res, use_gt_k, bbox_cx, bbox_cy, scale
):
    """
    This function returns camera intrinsics under scaling.
    If use_gt_k, the GT K is used, but scaled based on the amount of scaling in the patch.
    Else, we construct an intrinsic camera with a fixed focal length and fixed camera center.
    """

    if not use_gt_k:
        # construct weak perspective on patch
        intrx = get_wp_intrix(fixed_focal, img_res)
    else:
        # update the GT intrinsics (full image space)
        # such that it matches the scale of the patch

        dim = scale * 200.0  # bbox size
        k_scale = float(img_res) / dim  # resized_dim / bbox_size in full image space
        """
        # x1 and y1: top-left corner of bbox
        intrinsics after data augmentation
        fx' = k*fx
        fy' = k*fy
        cx' = k*(cx - x1)
        cy' = k*(cy - y1)
        """
        intrx[0, 0] *= k_scale  # k*fx
        intrx[1, 1] *= k_scale  # k*fy
        intrx[0, 2] -= bbox_cx - dim / 2.0
        intrx[1, 2] -= bbox_cy - dim / 2.0
        intrx[0, 2] *= k_scale
        intrx[1, 2] *= k_scale
    return intrx
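Usage sketch (not part of the commit): a small numeric illustration of the intrinsics update in get_aug_intrix above with use_gt_k=True, assuming common.data_utils is importable from the repo root. A camera with 1000 px focal length is cropped to a 300 px box centred at (500, 400) and resized to a 224 px patch, so k = 224/300 and the principal point is shifted to the patch frame before being scaled.

import torch
from common.data_utils import get_aug_intrix

K = torch.tensor([[1000.0, 0.0, 480.0], [0.0, 1000.0, 270.0], [0.0, 0.0, 1.0]])
img_res = 224
scale = 1.5                      # bbox size = scale * 200 = 300 px
bbox_cx, bbox_cy = 500.0, 400.0  # bbox centre in the full image
K_patch = get_aug_intrix(K.clone(), 1000.0, img_res, True, bbox_cx, bbox_cy, scale)
# fx' = fy' = 1000 * 224/300 ~ 746.7
# cx' = (480 - (500 - 150)) * 224/300 ~ 97.1, cy' = (270 - (400 - 150)) * 224/300 ~ 14.9
print(K_patch)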
common/ld_utils.py
ADDED
@@ -0,0 +1,116 @@
import itertools

import numpy as np
import torch


def sort_dict(disordered):
    sorted_dict = {k: disordered[k] for k in sorted(disordered)}
    return sorted_dict


def prefix_dict(mydict, prefix):
    out = {prefix + k: v for k, v in mydict.items()}
    return out


def postfix_dict(mydict, postfix):
    out = {k + postfix: v for k, v in mydict.items()}
    return out


def unsort(L, sort_idx):
    assert isinstance(sort_idx, list)
    assert isinstance(L, list)
    LL = zip(sort_idx, L)
    LL = sorted(LL, key=lambda x: x[0])
    _, L = zip(*LL)
    return list(L)


def cat_dl(out_list, dim, verbose=True, squeeze=True):
    out = {}
    for key, val in out_list.items():
        if isinstance(val[0], torch.Tensor):
            out[key] = torch.cat(val, dim=dim)
            if squeeze:
                out[key] = out[key].squeeze()
        elif isinstance(val[0], np.ndarray):
            out[key] = np.concatenate(val, axis=dim)
            if squeeze:
                out[key] = np.squeeze(out[key])
        elif isinstance(val[0], list):
            out[key] = sum(val, [])
        else:
            if verbose:
                print(f"Ignoring {key} undefined type {type(val[0])}")
    return out


def stack_dl(out_list, dim, verbose=True, squeeze=True):
    out = {}
    for key, val in out_list.items():
        if isinstance(val[0], torch.Tensor):
            out[key] = torch.stack(val, dim=dim)
            if squeeze:
                out[key] = out[key].squeeze()
        elif isinstance(val[0], np.ndarray):
            out[key] = np.stack(val, axis=dim)
            if squeeze:
                out[key] = np.squeeze(out[key])
        elif isinstance(val[0], list):
            out[key] = sum(val, [])
        else:
            out[key] = val
            if verbose:
                print(f"Processing {key} undefined type {type(val[0])}")
    return out


def add_prefix_postfix(mydict, prefix="", postfix=""):
    assert isinstance(mydict, dict)
    return dict((prefix + key + postfix, value) for (key, value) in mydict.items())


def ld2dl(LD):
    assert isinstance(LD, list)
    assert isinstance(LD[0], dict)
    """
    A list of dict (same keys) to a dict of lists
    """
    dict_list = {k: [dic[k] for dic in LD] for k in LD[0]}
    return dict_list


class NameSpace(object):
    def __init__(self, adict):
        self.__dict__.update(adict)


def dict2ns(mydict):
    """
    Convert dict object to namespace
    """
    return NameSpace(mydict)


def ld2dev(ld, dev):
    """
    Convert tensors in a list or dict to a device recursively
    """
    if isinstance(ld, torch.Tensor):
        return ld.to(dev)
    if isinstance(ld, dict):
        for k, v in ld.items():
            ld[k] = ld2dev(v, dev)
        return ld
    if isinstance(ld, list):
        return [ld2dev(x, dev) for x in ld]
    return ld


def all_comb_dict(hyper_dict):
    assert isinstance(hyper_dict, dict)
    keys, values = zip(*hyper_dict.items())
    permute_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)]
    return permute_dicts
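Usage sketch (not part of the commit): how ld2dl and stack_dl above are meant to compose, assuming common.ld_utils is importable from the repo root. A list of per-batch output dicts becomes a dict of lists, then a dict of stacked tensors (lists of lists are concatenated instead).

import torch
from common.ld_utils import ld2dl, stack_dl

outs = [{"loss": torch.tensor([0.5]), "name": ["a"]},
        {"loss": torch.tensor([0.3]), "name": ["b"]}]
dl = ld2dl(outs)               # {"loss": [tensor, tensor], "name": [["a"], ["b"]]}
merged = stack_dl(dl, dim=0)   # {"loss": tensor([0.5, 0.3]), "name": ["a", "b"]}
print(merged["loss"], merged["name"])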
common/list_utils.py
ADDED
@@ -0,0 +1,52 @@
import math


def chunks_by_len(L, n):
    """
    Split a list into n chunks
    """
    num_chunks = int(math.ceil(float(len(L)) / n))
    splits = [L[x : x + num_chunks] for x in range(0, len(L), num_chunks)]
    return splits


def chunks_by_size(L, n):
    """Yield successive n-sized chunks from lst."""
    seqs = []
    for i in range(0, len(L), n):
        seqs.append(L[i : i + n])
    return seqs


def unsort(L, sort_idx):
    assert isinstance(sort_idx, list)
    assert isinstance(L, list)
    LL = zip(sort_idx, L)
    LL = sorted(LL, key=lambda x: x[0])
    _, L = zip(*LL)
    return list(L)


def add_prefix_postfix(mydict, prefix="", postfix=""):
    assert isinstance(mydict, dict)
    return dict((prefix + key + postfix, value) for (key, value) in mydict.items())


def ld2dl(LD):
    assert isinstance(LD, list)
    assert isinstance(LD[0], dict)
    """
    A list of dict (same keys) to a dict of lists
    """
    dict_list = {k: [dic[k] for dic in LD] for k in LD[0]}
    return dict_list


def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    seqs = []
    for i in range(0, len(lst), n):
        seqs.append(lst[i : i + n])
    seqs_chunked = sum(seqs, [])
    assert set(seqs_chunked) == set(lst)
    return seqs
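Usage sketch (not part of the commit): the two chunking helpers above interpret n differently, assuming common.list_utils is importable; chunks_by_len targets roughly n chunks, while chunks_by_size targets chunks of size n.

from common.list_utils import chunks_by_len, chunks_by_size

L = list(range(10))
print(chunks_by_len(L, 3))   # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]  -> ~3 chunks
print(chunks_by_size(L, 3))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] -> size-3 chunks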
common/mesh.py
ADDED
@@ -0,0 +1,94 @@
import numpy as np
import trimesh

colors = {
    "pink": [1.00, 0.75, 0.80],
    "purple": [0.63, 0.13, 0.94],
    "red": [1.0, 0.0, 0.0],
    "green": [0.0, 1.0, 0.0],
    "yellow": [1.0, 1.0, 0],
    "brown": [1.00, 0.25, 0.25],
    "blue": [0.0, 0.0, 1.0],
    "white": [1.0, 1.0, 1.0],
    "orange": [1.00, 0.65, 0.00],
    "grey": [0.75, 0.75, 0.75],
    "black": [0.0, 0.0, 0.0],
}


class Mesh(trimesh.Trimesh):
    def __init__(
        self,
        filename=None,
        v=None,
        f=None,
        vc=None,
        fc=None,
        process=False,
        visual=None,
        **kwargs
    ):
        if filename is not None:
            mesh = trimesh.load(filename, process=process)
            v = mesh.vertices
            f = mesh.faces
            visual = mesh.visual

        super(Mesh, self).__init__(
            vertices=v, faces=f, visual=visual, process=process, **kwargs
        )

        self.v = self.vertices
        self.f = self.faces
        assert self.v is self.vertices
        assert self.f is self.faces

        if vc is not None:
            self.set_vc(vc)
            self.vc = self.visual.vertex_colors
            assert self.vc is self.visual.vertex_colors
        if fc is not None:
            self.set_fc(fc)
            self.fc = self.visual.face_colors
            assert self.fc is self.visual.face_colors

    def rot_verts(self, vertices, rxyz):
        return np.array(vertices * rxyz.T)

    def colors_like(self, color, array, ids):
        color = np.array(color)

        if color.max() <= 1.0:
            color = color * 255
        color = color.astype(np.int8)

        n_color = color.shape[0]
        n_ids = ids.shape[0]

        new_color = np.array(array)
        if n_color <= 4:
            new_color[ids, :n_color] = np.repeat(color[np.newaxis], n_ids, axis=0)
        else:
            new_color[ids, :] = color

        return new_color

    def set_vc(self, vc, vertex_ids=None):
        all_ids = np.arange(self.vertices.shape[0])
        if vertex_ids is None:
            vertex_ids = all_ids

        vertex_ids = all_ids[vertex_ids]
        new_vc = self.colors_like(vc, self.visual.vertex_colors, vertex_ids)
        self.visual.vertex_colors[:] = new_vc

    def set_fc(self, fc, face_ids=None):
        if face_ids is None:
            face_ids = np.arange(self.faces.shape[0])

        new_fc = self.colors_like(fc, self.visual.face_colors, face_ids)
        self.visual.face_colors[:] = new_fc

    @staticmethod
    def cat(meshes):
        return trimesh.util.concatenate(meshes)
common/metrics.py
ADDED
@@ -0,0 +1,51 @@
import math

import numpy as np
import torch


def compute_v2v_dist_no_reduce(v3d_cam_gt, v3d_cam_pred, is_valid):
    assert isinstance(v3d_cam_gt, list)
    assert isinstance(v3d_cam_pred, list)
    assert len(v3d_cam_gt) == len(v3d_cam_pred)
    assert len(v3d_cam_gt) == len(is_valid)
    v2v = []
    for v_gt, v_pred, valid in zip(v3d_cam_gt, v3d_cam_pred, is_valid):
        if valid:
            dist = ((v_gt - v_pred) ** 2).sum(dim=1).sqrt().cpu().numpy()  # meter
        else:
            dist = None
        v2v.append(dist)
    return v2v


def compute_joint3d_error(joints3d_cam_gt, joints3d_cam_pred, valid_jts):
    valid_jts = valid_jts.view(-1)
    assert joints3d_cam_gt.shape == joints3d_cam_pred.shape
    assert joints3d_cam_gt.shape[0] == valid_jts.shape[0]
    dist = ((joints3d_cam_gt - joints3d_cam_pred) ** 2).sum(dim=2).sqrt()
    invalid_idx = torch.nonzero((1 - valid_jts).long()).view(-1)
    dist[invalid_idx, :] = float("nan")
    dist = dist.cpu().numpy()
    return dist


def compute_mrrpe(root_r_gt, root_l_gt, root_r_pred, root_l_pred, is_valid):
    rel_vec_gt = root_l_gt - root_r_gt
    rel_vec_pred = root_l_pred - root_r_pred

    invalid_idx = torch.nonzero((1 - is_valid).long()).view(-1)
    mrrpe = ((rel_vec_pred - rel_vec_gt) ** 2).sum(dim=1).sqrt()
    mrrpe[invalid_idx] = float("nan")
    mrrpe = mrrpe.cpu().numpy()
    return mrrpe


def compute_arti_deg_error(pred_radian, gt_radian):
    assert pred_radian.shape == gt_radian.shape

    # articulation error in degree
    pred_degree = pred_radian / math.pi * 180  # degree
    gt_degree = gt_radian / math.pi * 180  # degree
    err_deg = torch.abs(pred_degree - gt_degree).tolist()
    return np.array(err_deg, dtype=np.float32)
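Usage sketch (not part of the commit): the NaN masking in compute_joint3d_error above, assuming common.metrics is importable from the repo root. Frames flagged invalid come back as NaN so that downstream averaging can use np.nanmean.

import numpy as np
import torch
from common.metrics import compute_joint3d_error

gt = torch.zeros(2, 21, 3)                     # 2 frames, 21 joints
pred = gt + 0.01                               # constant 1 cm error per axis
valid = torch.tensor([1.0, 0.0])               # second frame is invalid
err = compute_joint3d_error(gt, pred, valid)   # (2, 21) numpy array
print(err[0].mean())                           # ~0.0173 m (sqrt(3) * 0.01)
print(np.isnan(err[1]).all())                  # True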
common/np_utils.py
ADDED
@@ -0,0 +1,7 @@
import numpy as np


def permute_np(x, idx):
    original_perm = tuple(range(len(x.shape)))
    x = np.moveaxis(x, original_perm, idx)
    return x
common/object_tensors.py
ADDED
@@ -0,0 +1,293 @@
import json
import os.path as op
import sys

import numpy as np
import torch
import torch.nn as nn
import trimesh
from easydict import EasyDict
from scipy.spatial.distance import cdist

sys.path = [".."] + sys.path
import common.thing as thing
from common.rot import axis_angle_to_quaternion, quaternion_apply
from common.torch_utils import pad_tensor_list
from common.xdict import xdict

# objects to consider for training so far
OBJECTS = [
    "capsulemachine",
    "box",
    "ketchup",
    "laptop",
    "microwave",
    "mixer",
    "notebook",
    "espressomachine",
    "waffleiron",
    "scissors",
    "phone",
]


class ObjectTensors(nn.Module):
    def __init__(self):
        super(ObjectTensors, self).__init__()
        self.obj_tensors = thing.thing2dev(construct_obj_tensors(OBJECTS), "cpu")
        self.dev = None

    def forward_7d_batch(
        self,
        angles: (None, torch.Tensor),
        global_orient: (None, torch.Tensor),
        transl: (None, torch.Tensor),
        query_names: list,
        fwd_template: bool,
    ):
        self._sanity_check(angles, global_orient, transl, query_names, fwd_template)

        # store output
        out = xdict()

        # meta info
        obj_idx = np.array(
            [self.obj_tensors["names"].index(name) for name in query_names]
        )
        out["diameter"] = self.obj_tensors["diameter"][obj_idx]
        out["f"] = self.obj_tensors["f"][obj_idx]
        out["f_len"] = self.obj_tensors["f_len"][obj_idx]
        out["v_len"] = self.obj_tensors["v_len"][obj_idx]

        max_len = out["v_len"].max()
        out["v"] = self.obj_tensors["v"][obj_idx][:, :max_len]
        out["mask"] = self.obj_tensors["mask"][obj_idx][:, :max_len]
        out["v_sub"] = self.obj_tensors["v_sub"][obj_idx]
        out["parts_ids"] = self.obj_tensors["parts_ids"][obj_idx][:, :max_len]
        out["parts_sub_ids"] = self.obj_tensors["parts_sub_ids"][obj_idx]

        if fwd_template:
            return out

        # articulation + global rotation
        quat_arti = axis_angle_to_quaternion(self.obj_tensors["z_axis"] * angles)
        quat_global = axis_angle_to_quaternion(global_orient.view(-1, 3))

        # mm
        # collect entities to be transformed
        tf_dict = xdict()
        tf_dict["v_top"] = out["v"].clone()
        tf_dict["v_sub_top"] = out["v_sub"].clone()
        tf_dict["v_bottom"] = out["v"].clone()
        tf_dict["v_sub_bottom"] = out["v_sub"].clone()
        tf_dict["bbox_top"] = self.obj_tensors["bbox_top"][obj_idx]
        tf_dict["bbox_bottom"] = self.obj_tensors["bbox_bottom"][obj_idx]
        tf_dict["kp_top"] = self.obj_tensors["kp_top"][obj_idx]
        tf_dict["kp_bottom"] = self.obj_tensors["kp_bottom"][obj_idx]

        # articulate top parts
        for key, val in tf_dict.items():
            if "top" in key:
                val_rot = quaternion_apply(quat_arti[:, None, :], val)
                tf_dict.overwrite(key, val_rot)

        # global rotation for all
        for key, val in tf_dict.items():
            val_rot = quaternion_apply(quat_global[:, None, :], val)
            if transl is not None:
                val_rot = val_rot + transl[:, None, :]
            tf_dict.overwrite(key, val_rot)

        # prep output
        top_idx = out["parts_ids"] == 1
        v_tensor = tf_dict["v_bottom"].clone()
        v_tensor[top_idx, :] = tf_dict["v_top"][top_idx, :]

        top_idx = out["parts_sub_ids"] == 1
        v_sub_tensor = tf_dict["v_sub_bottom"].clone()
        v_sub_tensor[top_idx, :] = tf_dict["v_sub_top"][top_idx, :]

        bbox = torch.cat((tf_dict["bbox_top"], tf_dict["bbox_bottom"]), dim=1)
        kp3d = torch.cat((tf_dict["kp_top"], tf_dict["kp_bottom"]), dim=1)

        out.overwrite("v", v_tensor)
        out.overwrite("v_sub", v_sub_tensor)
        out.overwrite("bbox3d", bbox)
        out.overwrite("kp3d", kp3d)
        return out

    def forward(self, angles, global_orient, transl, query_names):
        out = self.forward_7d_batch(
            angles, global_orient, transl, query_names, fwd_template=False
        )
        return out

    def forward_template(self, query_names):
        out = self.forward_7d_batch(
            angles=None,
            global_orient=None,
            transl=None,
            query_names=query_names,
            fwd_template=True,
        )
        return out

    def to(self, dev):
        self.obj_tensors = thing.thing2dev(self.obj_tensors, dev)
        self.dev = dev

    def _sanity_check(self, angles, global_orient, transl, query_names, fwd_template):
        # sanity check
        if not fwd_template:
            # assume transl is in meter
            if transl is not None:
                transl = transl * 1000  # mm

            batch_size = angles.shape[0]
            assert angles.shape == (batch_size, 1)
            assert global_orient.shape == (batch_size, 3)
            if transl is not None:
                assert isinstance(transl, torch.Tensor)
                assert transl.shape == (batch_size, 3)
            assert len(query_names) == batch_size


def construct_obj(object_model_p):
    # load vtemplate
    mesh_p = op.join(object_model_p, "mesh.obj")
    parts_p = op.join(object_model_p, f"parts.json")
    json_p = op.join(object_model_p, "object_params.json")
    obj_name = op.basename(object_model_p)

    top_sub_p = f"./data/arctic_data/data/meta/object_vtemplates/{obj_name}/top_keypoints_300.json"
    bottom_sub_p = top_sub_p.replace("top_", "bottom_")
    with open(top_sub_p, "r") as f:
        sub_top = np.array(json.load(f)["keypoints"])

    with open(bottom_sub_p, "r") as f:
        sub_bottom = np.array(json.load(f)["keypoints"])
    sub_v = np.concatenate((sub_top, sub_bottom), axis=0)

    with open(parts_p, "r") as f:
        parts = np.array(json.load(f), dtype=np.bool)

    assert op.exists(mesh_p), f"Not found: {mesh_p}"

    mesh = trimesh.exchange.load.load_mesh(mesh_p, process=False)
    mesh_v = mesh.vertices

    mesh_f = torch.LongTensor(mesh.faces)
    vidx = np.argmin(cdist(sub_v, mesh_v, metric="euclidean"), axis=1)
    parts_sub = parts[vidx]

    vsk = object_model_p.split("/")[-1]

    with open(json_p, "r") as f:
        params = json.load(f)
    rest = EasyDict()
    rest.top = np.array(params["mocap_top"])
    rest.bottom = np.array(params["mocap_bottom"])
    bbox_top = np.array(params["bbox_top"])
    bbox_bottom = np.array(params["bbox_bottom"])
    kp_top = np.array(params["keypoints_top"])
    kp_bottom = np.array(params["keypoints_bottom"])

    np.random.seed(1)

    obj = EasyDict()
    obj.name = vsk
    obj.obj_name = "".join([i for i in vsk if not i.isdigit()])
    obj.v = torch.FloatTensor(mesh_v)
    obj.v_sub = torch.FloatTensor(sub_v)
    obj.f = torch.LongTensor(mesh_f)
    obj.parts = torch.LongTensor(parts)
    obj.parts_sub = torch.LongTensor(parts_sub)

    with open("./data/arctic_data/data/meta/object_meta.json", "r") as f:
        object_meta = json.load(f)
    obj.diameter = torch.FloatTensor(np.array(object_meta[obj.obj_name]["diameter"]))
    obj.bbox_top = torch.FloatTensor(bbox_top)
    obj.bbox_bottom = torch.FloatTensor(bbox_bottom)
    obj.kp_top = torch.FloatTensor(kp_top)
    obj.kp_bottom = torch.FloatTensor(kp_bottom)
    obj.mocap_top = torch.FloatTensor(np.array(params["mocap_top"]))
    obj.mocap_bottom = torch.FloatTensor(np.array(params["mocap_bottom"]))
    return obj


def construct_obj_tensors(object_names):
    obj_list = []
    for k in object_names:
        object_model_p = f"./data/arctic_data/data/meta/object_vtemplates/%s" % (k)
        obj = construct_obj(object_model_p)
        obj_list.append(obj)

    bbox_top_list = []
    bbox_bottom_list = []
    mocap_top_list = []
    mocap_bottom_list = []
    kp_top_list = []
    kp_bottom_list = []
    v_list = []
    v_sub_list = []
    f_list = []
    parts_list = []
    parts_sub_list = []
    diameter_list = []
    for obj in obj_list:
        v_list.append(obj.v)
        v_sub_list.append(obj.v_sub)
        f_list.append(obj.f)

        # root_list.append(obj.root)
        bbox_top_list.append(obj.bbox_top)
        bbox_bottom_list.append(obj.bbox_bottom)
        kp_top_list.append(obj.kp_top)
        kp_bottom_list.append(obj.kp_bottom)
        mocap_top_list.append(obj.mocap_top / 1000)
        mocap_bottom_list.append(obj.mocap_bottom / 1000)
        parts_list.append(obj.parts + 1)
        parts_sub_list.append(obj.parts_sub + 1)
        diameter_list.append(obj.diameter)

    v_list, v_len_list = pad_tensor_list(v_list)
    p_list, p_len_list = pad_tensor_list(parts_list)
    ps_list = torch.stack(parts_sub_list, dim=0)
    assert (p_len_list - v_len_list).sum() == 0

    max_len = v_len_list.max()
    mask = torch.zeros(len(obj_list), max_len)
    for idx, vlen in enumerate(v_len_list):
        mask[idx, :vlen] = 1.0

    v_sub_list = torch.stack(v_sub_list, dim=0)
    diameter_list = torch.stack(diameter_list, dim=0)

    f_list, f_len_list = pad_tensor_list(f_list)

    bbox_top_list = torch.stack(bbox_top_list, dim=0)
    bbox_bottom_list = torch.stack(bbox_bottom_list, dim=0)
    kp_top_list = torch.stack(kp_top_list, dim=0)
    kp_bottom_list = torch.stack(kp_bottom_list, dim=0)

    obj_tensors = {}
    obj_tensors["names"] = object_names
    obj_tensors["parts_ids"] = p_list
    obj_tensors["parts_sub_ids"] = ps_list

    obj_tensors["v"] = v_list.float() / 1000
    obj_tensors["v_sub"] = v_sub_list.float() / 1000
    obj_tensors["v_len"] = v_len_list
    obj_tensors["f"] = f_list
    obj_tensors["f_len"] = f_len_list
    obj_tensors["diameter"] = diameter_list.float()

    obj_tensors["mask"] = mask
    obj_tensors["bbox_top"] = bbox_top_list.float() / 1000
    obj_tensors["bbox_bottom"] = bbox_bottom_list.float() / 1000
    obj_tensors["kp_top"] = kp_top_list.float() / 1000
    obj_tensors["kp_bottom"] = kp_bottom_list.float() / 1000
    obj_tensors["mocap_top"] = mocap_top_list
    obj_tensors["mocap_bottom"] = mocap_bottom_list
    obj_tensors["z_axis"] = torch.FloatTensor(np.array([0, 0, -1])).view(1, 3)
    return obj_tensors
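Usage sketch (not part of the commit): a toy version of the articulation step in forward_7d_batch above, assuming axis_angle_to_quaternion and quaternion_apply from common/rot.py behave like their pytorch3d counterparts. The "top" part is spun about the object's -z hinge axis before the global rotation and translation are applied.

import math
import torch
from common.rot import axis_angle_to_quaternion, quaternion_apply

z_axis = torch.tensor([[0.0, 0.0, -1.0]])
angles = torch.tensor([[math.pi / 2]])                 # 90 degree articulation
quat_arti = axis_angle_to_quaternion(z_axis * angles)  # (1, 4), real part first
point = torch.tensor([[[1.0, 0.0, 0.0]]])              # one point on the top part
print(quaternion_apply(quat_arti[:, None, :], point))  # approx. [0, -1, 0]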
common/pl_utils.py
ADDED
@@ -0,0 +1,63 @@
import random
import time

import torch

import common.thing as thing
from common.ld_utils import ld2dl


def reweight_loss_by_keys(loss_dict, keys, alpha):
    for key in keys:
        val, weight = loss_dict[key]
        weight_new = weight * alpha
        loss_dict[key] = (val, weight_new)
    return loss_dict


def select_loss_group(groups, agent_id, alphas):
    random.seed(1)
    random.shuffle(groups)

    keys = groups[agent_id % len(groups)]

    random.seed(time.time())
    alpha = random.choice(alphas)
    random.seed(1)
    return keys, alpha


def push_checkpoint_metric(key, val):
    val = float(val)
    checkpt_metric = torch.FloatTensor([val])
    result = {key: checkpt_metric}
    return result


def avg_losses_cpu(outputs):
    outputs = ld2dl(outputs)
    for key, val in outputs.items():
        val = [v.cpu() for v in val]
        val = torch.cat(val, dim=0).view(-1)
        outputs[key] = val.mean()
    return outputs


def reform_outputs(out_list):
    out_list_dict = ld2dl(out_list)
    outputs = ld2dl(out_list_dict["out_dict"])
    losses = ld2dl(out_list_dict["loss"])

    for k, tensor in outputs.items():
        if isinstance(tensor[0], list):
            outputs[k] = sum(tensor, [])
        else:
            outputs[k] = torch.cat(tensor)

    for k, tensor in losses.items():
        tensor = [ten.view(-1) for ten in tensor]
        losses[k] = torch.cat(tensor)

    outputs = {k: thing.thing2np(v) for k, v in outputs.items()}
    loss_dict = {k: v.mean().item() for k, v in losses.items()}
    return outputs, loss_dict
common/rend_utils.py
ADDED
@@ -0,0 +1,139 @@
import copy
import os

import numpy as np
import pyrender
import trimesh

# offline rendering
os.environ["PYOPENGL_PLATFORM"] = "egl"


def flip_meshes(meshes):
    rot = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0])
    for mesh in meshes:
        mesh.apply_transform(rot)
    return meshes


def color2material(mesh_color: list):
    material = pyrender.MetallicRoughnessMaterial(
        metallicFactor=0.1,
        alphaMode="OPAQUE",
        baseColorFactor=(
            mesh_color[0] / 255.0,
            mesh_color[1] / 255.0,
            mesh_color[2] / 255.0,
            0.5,
        ),
    )
    return material


class Renderer:
    def __init__(self, img_res: int) -> None:
        self.renderer = pyrender.OffscreenRenderer(
            viewport_width=img_res, viewport_height=img_res, point_size=1.0
        )

        self.img_res = img_res

    def render_meshes_pose(
        self,
        meshes,
        image=None,
        cam_transl=None,
        cam_center=None,
        K=None,
        materials=None,
        sideview_angle=None,
    ):
        # unpack
        if cam_transl is not None:
            cam_trans = np.copy(cam_transl)
            cam_trans[0] *= -1.0
        else:
            cam_trans = None
        meshes = copy.deepcopy(meshes)
        meshes = flip_meshes(meshes)

        if sideview_angle is not None:
            # center around the final mesh
            anchor_mesh = meshes[-1]
            center = anchor_mesh.vertices.mean(axis=0)

            rot = trimesh.transformations.rotation_matrix(
                np.radians(sideview_angle), [0, 1, 0]
            )
            out_meshes = []
            for mesh in copy.deepcopy(meshes):
                mesh.vertices -= center
                mesh.apply_transform(rot)
                mesh.vertices += center
                # further away to see more
                mesh.vertices += np.array([0, 0, -0.10])
                out_meshes.append(mesh)
            meshes = out_meshes

        # setting up
        self.create_scene()
        self.setup_light()
        self.position_camera(cam_trans, K)
        if materials is not None:
            meshes = [
                pyrender.Mesh.from_trimesh(mesh, material=material)
                for mesh, material in zip(meshes, materials)
            ]
        else:
            meshes = [pyrender.Mesh.from_trimesh(mesh) for mesh in meshes]

        for mesh in meshes:
            self.scene.add(mesh)

        color, valid_mask = self.render_rgb()
        if image is None:
            output_img = color[:, :, :3]
        else:
            output_img = self.overlay_image(color, valid_mask, image)
        rend_img = (output_img * 255).astype(np.uint8)
        return rend_img

    def render_rgb(self):
        color, rend_depth = self.renderer.render(
            self.scene, flags=pyrender.RenderFlags.RGBA
        )
        color = color.astype(np.float32) / 255.0
        valid_mask = (rend_depth > 0)[:, :, None]
        return color, valid_mask

    def overlay_image(self, color, valid_mask, image):
        output_img = color[:, :, :3] * valid_mask + (1 - valid_mask) * image
        return output_img

    def position_camera(self, cam_transl, K):
        camera_pose = np.eye(4)
        if cam_transl is not None:
            camera_pose[:3, 3] = cam_transl

        fx = K[0, 0]
        fy = K[1, 1]
        cx = K[0, 2]
        cy = K[1, 2]
        camera = pyrender.IntrinsicsCamera(fx=fx, fy=fy, cx=cx, cy=cy)
        self.scene.add(camera, pose=camera_pose)

    def setup_light(self):
        light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=1)
        light_pose = np.eye(4)

        light_pose[:3, 3] = np.array([0, -1, 1])
        self.scene.add(light, pose=light_pose)

        light_pose[:3, 3] = np.array([0, 1, 1])
        self.scene.add(light, pose=light_pose)

        light_pose[:3, 3] = np.array([1, 1, 2])
        self.scene.add(light, pose=light_pose)

    def create_scene(self):
        self.scene = pyrender.Scene(ambient_light=(0.5, 0.5, 0.5))
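Usage sketch (not part of the commit): a minimal offscreen render with the Renderer class above, assuming an EGL-capable machine and that common.rend_utils is importable from the repo root. A small box is placed in front of the camera and rendered with an intrinsics matrix K.

import numpy as np
import trimesh
from common.rend_utils import Renderer, color2material

renderer = Renderer(img_res=224)
K = np.array([[1000.0, 0.0, 112.0], [0.0, 1000.0, 112.0], [0.0, 0.0, 1.0]])
box = trimesh.creation.box(extents=(0.1, 0.1, 0.1))
box.vertices += np.array([0.0, 0.0, 0.5])       # place the box in front of the camera
img = renderer.render_meshes_pose(
    [box], cam_transl=np.zeros(3), K=K, materials=[color2material([0, 255, 0])]
)
print(img.shape)                                 # (224, 224, 3), uint8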
common/rot.py
ADDED
@@ -0,0 +1,782 @@
import math  # added: euler_angles_from_rotmat below uses math.pi but math was not imported

import cv2
import numpy as np
import torch
from torch.nn import functional as F

"""
Taken from https://pytorch3d.readthedocs.io/en/latest/_modules/pytorch3d/transforms/rotation_conversions.html
Just to avoid installing pytorch3d at times
"""


def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert a unit quaternion to a standard form: one in which the real
    part is non negative.

    Args:
        quaternions: Quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)


def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Multiply two quaternions representing rotations, returning the quaternion
    representing their composition, i.e. the versor with nonnegative real part.
    Usual torch rules for broadcasting apply.

    Args:
        a: Quaternions as tensor of shape (..., 4), real part first.
        b: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        The product of a and b, a tensor of quaternions of shape (..., 4).
    """
    ab = quaternion_raw_multiply(a, b)
    return standardize_quaternion(ab)


def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """
    Returns torch.sqrt(torch.max(0, x))
    but with a zero subgradient where x is 0.
    """
    ret = torch.zeros_like(x)
    positive_mask = x > 0
    ret[positive_mask] = torch.sqrt(x[positive_mask])
    return ret


def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as quaternions to axis/angle.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotations given as a vector in axis angle form, as a tensor
            of shape (..., 3), where the magnitude is the angle
            turned anticlockwise in radians around the vector's
            direction.
    """
    norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
    half_angles = torch.atan2(norms, quaternions[..., :1])
    angles = 2 * half_angles
    eps = 1e-6
    small_angles = angles.abs() < eps
    sin_half_angles_over_angles = torch.empty_like(angles)
    sin_half_angles_over_angles[~small_angles] = (
        torch.sin(half_angles[~small_angles]) / angles[~small_angles]
    )
    # for x small, sin(x/2) is about x/2 - (x/2)^3/6
    # so sin(x/2)/x is about 1/2 - (x*x)/48
    sin_half_angles_over_angles[small_angles] = (
        0.5 - (angles[small_angles] * angles[small_angles]) / 48
    )
    return quaternions[..., 1:] / sin_half_angles_over_angles


def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as quaternions to rotation matrices.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    r, i, j, k = torch.unbind(quaternions, -1)
    # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return o.reshape(quaternions.shape[:-1] + (3, 3))


def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to quaternions.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(batch_dim + (9,)), dim=-1
    )

    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # we produce the desired quaternion multiplied by each of r, i, j, k
    quat_by_rijk = torch.stack(
        [
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            #  `int`.
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            #  `int`.
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            #  `int`.
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
            #  `int`.
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # We floor here at 0.1 but the exact level is not important; if q_abs is small,
    # the candidate won't be picked.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
    # forall i; we pick the best-conditioned one (with the largest denominator)

    return quat_candidates[
        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
    ].reshape(batch_dim + (4,))


def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to axis/angle.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        Rotations given as a vector in axis angle form, as a tensor
            of shape (..., 3), where the magnitude is the angle
            turned anticlockwise in radians around the vector's
            direction.
    """
    return quaternion_to_axis_angle(matrix_to_quaternion(matrix))


def rot_aa(aa, rot):
    """Rotate axis angle parameters."""
    # pose parameters
    R = np.array(
        [
            [np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
            [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
            [0, 0, 1],
        ]
    )
    # find the rotation of the body in camera frame
    per_rdg, _ = cv2.Rodrigues(aa)
    # apply the global rotation to the global orientation
    resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg))
    aa = (resrot.T)[0]
    return aa


def quat2mat(quat):
    """
    This function is borrowed from https://github.com/MandyMo/pytorch_HMR/blob/master/src/util.py#L50
    Convert quaternion coefficients to rotation matrix.
    Args:
        quat: size = [batch_size, 4] 4 <===>(w, x, y, z)
    Returns:
        Rotation matrix corresponding to the quaternion -- size = [batch_size, 3, 3]
    """
    norm_quat = quat
    norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
    w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3]

    batch_size = quat.size(0)

    w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w * x, w * y, w * z
    xy, xz, yz = x * y, x * z, y * z

    rotMat = torch.stack(
        [
            w2 + x2 - y2 - z2,
            2 * xy - 2 * wz,
            2 * wy + 2 * xz,
            2 * wz + 2 * xy,
            w2 - x2 + y2 - z2,
            2 * yz - 2 * wx,
            2 * xz - 2 * wy,
            2 * wx + 2 * yz,
            w2 - x2 - y2 + z2,
        ],
        dim=1,
    ).view(batch_size, 3, 3)
    return rotMat


def batch_aa2rot(axisang):
    # This function is borrowed from https://github.com/MandyMo/pytorch_HMR/blob/master/src/util.py#L37
    assert len(axisang.shape) == 2
    assert axisang.shape[1] == 3
    # axisang N x 3
    axisang_norm = torch.norm(axisang + 1e-8, p=2, dim=1)
    angle = torch.unsqueeze(axisang_norm, -1)
    axisang_normalized = torch.div(axisang, angle)
    angle = angle * 0.5
    v_cos = torch.cos(angle)
    v_sin = torch.sin(angle)
    quat = torch.cat([v_cos, v_sin * axisang_normalized], dim=1)
    rot_mat = quat2mat(quat)
    rot_mat = rot_mat.view(rot_mat.shape[0], 9)
    return rot_mat


def batch_rot2aa(Rs):
    assert len(Rs.shape) == 3
    assert Rs.shape[1] == Rs.shape[2]
    assert Rs.shape[1] == 3

    """
    Rs is B x 3 x 3
    void cMathUtil::RotMatToAxisAngle(const tMatrix& mat, tVector& out_axis,
                                      double& out_theta)
    {
        double c = 0.5 * (mat(0, 0) + mat(1, 1) + mat(2, 2) - 1);
        c = cMathUtil::Clamp(c, -1.0, 1.0);

        out_theta = std::acos(c);

        if (std::abs(out_theta) < 0.00001)
        {
            out_axis = tVector(0, 0, 1, 0);
        }
        else
        {
            double m21 = mat(2, 1) - mat(1, 2);
            double m02 = mat(0, 2) - mat(2, 0);
            double m10 = mat(1, 0) - mat(0, 1);
            double denom = std::sqrt(m21 * m21 + m02 * m02 + m10 * m10);
            out_axis[0] = m21 / denom;
            out_axis[1] = m02 / denom;
            out_axis[2] = m10 / denom;
            out_axis[3] = 0;
        }
    }
    """
    cos = 0.5 * (torch.stack([torch.trace(x) for x in Rs]) - 1)
    cos = torch.clamp(cos, -1, 1)

    theta = torch.acos(cos)

    m21 = Rs[:, 2, 1] - Rs[:, 1, 2]
    m02 = Rs[:, 0, 2] - Rs[:, 2, 0]
    m10 = Rs[:, 1, 0] - Rs[:, 0, 1]
    denom = torch.sqrt(m21 * m21 + m02 * m02 + m10 * m10)

    axis0 = torch.where(torch.abs(theta) < 0.00001, m21, m21 / denom)
    axis1 = torch.where(torch.abs(theta) < 0.00001, m02, m02 / denom)
    axis2 = torch.where(torch.abs(theta) < 0.00001, m10, m10 / denom)

    return theta.unsqueeze(1) * torch.stack([axis0, axis1, axis2], 1)


def batch_rodrigues(theta):
    """Convert axis-angle representation to rotation matrix.
    Args:
        theta: size = [B, 3]
    Returns:
        Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
    """
    l1norm = torch.norm(theta + 1e-8, p=2, dim=1)
    angle = torch.unsqueeze(l1norm, -1)
    normalized = torch.div(theta, angle)
    angle = angle * 0.5
    v_cos = torch.cos(angle)
    v_sin = torch.sin(angle)
    quat = torch.cat([v_cos, v_sin * normalized], dim=1)
    return quat_to_rotmat(quat)


def quat_to_rotmat(quat):
    """Convert quaternion coefficients to rotation matrix.
    Args:
        quat: size = [B, 4] 4 <===>(w, x, y, z)
    Returns:
        Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
    """
    norm_quat = quat
    norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
    w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3]

    B = quat.size(0)

    w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w * x, w * y, w * z
    xy, xz, yz = x * y, x * z, y * z

    rotMat = torch.stack(
        [
            w2 + x2 - y2 - z2,
            2 * xy - 2 * wz,
            2 * wy + 2 * xz,
            2 * wz + 2 * xy,
            w2 - x2 + y2 - z2,
            2 * yz - 2 * wx,
            2 * xz - 2 * wy,
            2 * wx + 2 * yz,
            w2 - x2 - y2 + z2,
        ],
        dim=1,
    ).view(B, 3, 3)
    return rotMat


def rot6d_to_rotmat(x):
    """Convert 6D rotation representation to 3x3 rotation matrix.
    Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
    Input:
        (B,6) Batch of 6-D rotation representations
    Output:
        (B,3,3) Batch of corresponding rotation matrices
    """
    x = x.reshape(-1, 3, 2)
    a1 = x[:, :, 0]
    a2 = x[:, :, 1]
    b1 = F.normalize(a1)
    b2 = F.normalize(a2 - torch.einsum("bi,bi->b", b1, a2).unsqueeze(-1) * b1)
    b3 = torch.cross(b1, b2)
    return torch.stack((b1, b2, b3), dim=-1)


def rotmat_to_rot6d(x):
    rotmat = x.reshape(-1, 3, 3)
    rot6d = rotmat[:, :, :2].reshape(x.shape[0], -1)
    return rot6d


def rotation_matrix_to_angle_axis(rotation_matrix):
    """
    This function is borrowed from https://github.com/kornia/kornia

    Convert 3x4 rotation matrix to Rodrigues vector

    Args:
        rotation_matrix (Tensor): rotation matrix.

    Returns:
        Tensor: Rodrigues vector transformation.

    Shape:
        - Input: :math:`(N, 3, 4)`
        - Output: :math:`(N, 3)`

    Example:
        >>> input = torch.rand(2, 3, 4)  # Nx4x4
        >>> output = tgm.rotation_matrix_to_angle_axis(input)  # Nx3
    """
    if rotation_matrix.shape[1:] == (3, 3):
        rot_mat = rotation_matrix.reshape(-1, 3, 3)
        hom = (
            torch.tensor([0, 0, 1], dtype=torch.float32, device=rotation_matrix.device)
            .reshape(1, 3, 1)
            .expand(rot_mat.shape[0], -1, -1)
        )
        rotation_matrix = torch.cat([rot_mat, hom], dim=-1)

    quaternion = rotation_matrix_to_quaternion(rotation_matrix)
    aa = quaternion_to_angle_axis(quaternion)
    aa[torch.isnan(aa)] = 0.0
    return aa


def quaternion_to_angle_axis(quaternion: torch.Tensor) -> torch.Tensor:
    """
    This function is borrowed from https://github.com/kornia/kornia

    Convert quaternion vector to angle axis of rotation.

    Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h

    Args:
        quaternion (torch.Tensor): tensor with quaternions.

    Return:
        torch.Tensor: tensor with angle axis of rotation.

    Shape:
        - Input: :math:`(*, 4)` where `*` means, any number of dimensions
        - Output: :math:`(*, 3)`

    Example:
        >>> quaternion = torch.rand(2, 4)  # Nx4
        >>> angle_axis = tgm.quaternion_to_angle_axis(quaternion)  # Nx3
    """
    if not torch.is_tensor(quaternion):
        raise TypeError(
            "Input type is not a torch.Tensor. Got {}".format(type(quaternion))
        )

    if not quaternion.shape[-1] == 4:
        raise ValueError(
            "Input must be a tensor of shape Nx4 or 4. Got {}".format(quaternion.shape)
        )
    # unpack input and compute conversion
    q1: torch.Tensor = quaternion[..., 1]
    q2: torch.Tensor = quaternion[..., 2]
    q3: torch.Tensor = quaternion[..., 3]
    sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3

    sin_theta: torch.Tensor = torch.sqrt(sin_squared_theta)
    cos_theta: torch.Tensor = quaternion[..., 0]
    two_theta: torch.Tensor = 2.0 * torch.where(
        cos_theta < 0.0,
        torch.atan2(-sin_theta, -cos_theta),
        torch.atan2(sin_theta, cos_theta),
    )

    k_pos: torch.Tensor = two_theta / sin_theta
    k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta)
    k: torch.Tensor = torch.where(sin_squared_theta > 0.0, k_pos, k_neg)

    angle_axis: torch.Tensor = torch.zeros_like(quaternion)[..., :3]
    angle_axis[..., 0] += q1 * k
    angle_axis[..., 1] += q2 * k
    angle_axis[..., 2] += q3 * k
    return angle_axis


def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
    """
    This function is borrowed from https://github.com/kornia/kornia

    Convert 3x4 rotation matrix to 4d quaternion vector

    This algorithm is based on algorithm described in
    https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201

    Args:
        rotation_matrix (Tensor): the rotation matrix to convert.

    Return:
        Tensor: the rotation in quaternion

    Shape:
        - Input: :math:`(N, 3, 4)`
        - Output: :math:`(N, 4)`

    Example:
        >>> input = torch.rand(4, 3, 4)  # Nx3x4
        >>> output = tgm.rotation_matrix_to_quaternion(input)  # Nx4
    """
    if not torch.is_tensor(rotation_matrix):
        raise TypeError(
            "Input type is not a torch.Tensor. Got {}".format(type(rotation_matrix))
        )

    if len(rotation_matrix.shape) > 3:
        raise ValueError(
            "Input size must be a three dimensional tensor. Got {}".format(
                rotation_matrix.shape
            )
        )
    if not rotation_matrix.shape[-2:] == (3, 4):
        raise ValueError(
            "Input size must be a N x 3 x 4 tensor. Got {}".format(
                rotation_matrix.shape
            )
        )

    rmat_t = torch.transpose(rotation_matrix, 1, 2)

    mask_d2 = rmat_t[:, 2, 2] < eps

    mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
    mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]

    t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
    q0 = torch.stack(
        [
            rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
            t0,
            rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
            rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
        ],
        -1,
    )
    t0_rep = t0.repeat(4, 1).t()

    t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
    q1 = torch.stack(
        [
            rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
            rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
            t1,
            rmat_t[:, 1, 2] + rmat_t[:, 2, 1],
        ],
        -1,
    )
    t1_rep = t1.repeat(4, 1).t()

    t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
    q2 = torch.stack(
        [
            rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
            rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
            rmat_t[:, 1, 2] + rmat_t[:, 2, 1],
            t2,
        ],
        -1,
    )
    t2_rep = t2.repeat(4, 1).t()

    t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
    q3 = torch.stack(
        [
            t3,
            rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
            rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
            rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
        ],
        -1,
    )
    t3_rep = t3.repeat(4, 1).t()

    mask_c0 = mask_d2 * mask_d0_d1
    mask_c1 = mask_d2 * ~mask_d0_d1
    mask_c2 = ~mask_d2 * mask_d0_nd1
    mask_c3 = ~mask_d2 * ~mask_d0_nd1
    mask_c0 = mask_c0.view(-1, 1).type_as(q0)
    mask_c1 = mask_c1.view(-1, 1).type_as(q1)
    mask_c2 = mask_c2.view(-1, 1).type_as(q2)
    mask_c3 = mask_c3.view(-1, 1).type_as(q3)

    q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
    q /= torch.sqrt(
        t0_rep * mask_c0
        + t1_rep * mask_c1
        + t2_rep * mask_c2  # noqa
        + t3_rep * mask_c3
    )  # noqa
    q *= 0.5
    return q


def batch_euler2matrix(r):
    return quaternion_to_rotation_matrix(euler_to_quaternion(r))


def euler_to_quaternion(r):
    x = r[..., 0]
    y = r[..., 1]
    z = r[..., 2]

    z = z / 2.0
    y = y / 2.0
    x = x / 2.0
    cz = torch.cos(z)
    sz = torch.sin(z)
    cy = torch.cos(y)
    sy = torch.sin(y)
    cx = torch.cos(x)
    sx = torch.sin(x)
    quaternion = torch.zeros_like(r.repeat(1, 2))[..., :4].to(r.device)
    quaternion[..., 0] += cx * cy * cz - sx * sy * sz
    quaternion[..., 1] += cx * sy * sz + cy * cz * sx
    quaternion[..., 2] += cx * cz * sy - sx * cy * sz
    quaternion[..., 3] += cx * cy * sz + sx * cz * sy
    return quaternion


def quaternion_to_rotation_matrix(quat):
    """Convert quaternion coefficients to rotation matrix.
    Args:
        quat: size = [B, 4] 4 <===>(w, x, y, z)
    Returns:
        Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
    """
    norm_quat = quat
    norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
    w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3]

    B = quat.size(0)

    w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w * x, w * y, w * z
    xy, xz, yz = x * y, x * z, y * z

    rotMat = torch.stack(
        [
            w2 + x2 - y2 - z2,
            2 * xy - 2 * wz,
            2 * wy + 2 * xz,
            2 * wz + 2 * xy,
            w2 - x2 + y2 - z2,
            2 * yz - 2 * wx,
            2 * xz - 2 * wy,
            2 * wx + 2 * yz,
            w2 - x2 - y2 + z2,
        ],
        dim=1,
    ).view(B, 3, 3)
    return rotMat


def euler_angles_from_rotmat(R):
    """
    computer euler angles for rotation around x, y, z axis
    from rotation amtrix
    R: 4x4 rotation matrix
    https://www.gregslabaugh.net/publications/euler.pdf
    """
    r21 = np.round(R[:, 2, 0].item(), 4)
    if abs(r21) != 1:
        y_angle1 = -1 * torch.asin(R[:, 2, 0])
        y_angle2 = math.pi + torch.asin(R[:, 2, 0])
        cy1, cy2 = torch.cos(y_angle1), torch.cos(y_angle2)

        x_angle1 = torch.atan2(R[:, 2, 1] / cy1, R[:, 2, 2] / cy1)
        x_angle2 = torch.atan2(R[:, 2, 1] / cy2, R[:, 2, 2] / cy2)
        z_angle1 = torch.atan2(R[:, 1, 0] / cy1, R[:, 0, 0] / cy1)
        z_angle2 = torch.atan2(R[:, 1, 0] / cy2, R[:, 0, 0] / cy2)

        s1 = (x_angle1, y_angle1, z_angle1)
        s2 = (x_angle2, y_angle2, z_angle2)
        s = (s1, s2)

    else:
        z_angle = torch.tensor([0], device=R.device).float()
        if r21 == -1:
            y_angle = torch.tensor([math.pi / 2], device=R.device).float()
            x_angle = z_angle + torch.atan2(R[:, 0, 1], R[:, 0, 2])
        else:
            y_angle = -torch.tensor([math.pi / 2], device=R.device).float()
            x_angle = -z_angle + torch.atan2(-R[:, 0, 1], R[:, 0, 2])
        s = ((x_angle, y_angle, z_angle),)
    return s


def quaternion_raw_multiply(a, b):
    """
    Source: https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py
    Multiply two quaternions.
    Usual torch rules for broadcasting apply.

    Args:
        a: Quaternions as tensor of shape (..., 4), real part first.
        b: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        The product of a and b, a tensor of quaternions shape (..., 4).
    """
    aw, ax, ay, az = torch.unbind(a, -1)
    bw, bx, by, bz = torch.unbind(b, -1)
    ow = aw * bw - ax * bx - ay * by - az * bz
    ox = aw * bx + ax * bw + ay * bz - az * by
    oy = aw * by - ax * bz + ay * bw + az * bx
    oz = aw * bz + ax * by - ay * bx + az * bw
    return torch.stack((ow, ox, oy, oz), -1)


def quaternion_invert(quaternion):
    """
    Source: https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py
    Given a quaternion representing rotation, get the quaternion representing
    its inverse.

    Args:
        quaternion: Quaternions as tensor of shape (..., 4), with real part
            first, which must be versors (unit quaternions).

    Returns:
        The inverse, a tensor of quaternions of shape (..., 4).
    """

    return quaternion * quaternion.new_tensor([1, -1, -1, -1])


def quaternion_apply(quaternion, point):
    """
    Source: https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py
    Apply the rotation given by a quaternion to a 3D point.
    Usual torch rules for broadcasting apply.

    Args:
        quaternion: Tensor of quaternions, real part first, of shape (..., 4).
        point: Tensor of 3D points of shape (..., 3).

    Returns:
        Tensor of rotated points of shape (..., 3).
    """
    if point.size(-1) != 3:
        raise ValueError(f"Points are not in 3D, f{point.shape}.")
    real_parts = point.new_zeros(point.shape[:-1] + (1,))
    point_as_quaternion = torch.cat((real_parts, point), -1)
    out = quaternion_raw_multiply(
        quaternion_raw_multiply(quaternion, point_as_quaternion),
        quaternion_invert(quaternion),
    )
    return out[..., 1:]


def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
    """
    Source: https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py
    Convert rotations given as axis/angle to quaternions.
    Args:
        axis_angle: Rotations given as a vector in axis angle form,
            as a tensor of shape (..., 3), where the magnitude is
            the angle turned anticlockwise in radians around the
            vector's direction.
    Returns:
        quaternions with real part first, as tensor of shape (..., 4).
    """
    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
    half_angles = angles * 0.5
    eps = 1e-6
    small_angles = angles.abs() < eps
    sin_half_angles_over_angles = torch.empty_like(angles)
    sin_half_angles_over_angles[~small_angles] = (
        torch.sin(half_angles[~small_angles]) / angles[~small_angles]
    )
    # for x small, sin(x/2) is about x/2 - (x/2)^3/6
    # so sin(x/2)/x is about 1/2 - (x*x)/48
    sin_half_angles_over_angles[small_angles] = (
        0.5 - (angles[small_angles] * angles[small_angles]) / 48
    )
    quaternions = torch.cat(
        [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
    )
    return quaternions
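Illustrative round-trip sketch (not from the commit itself): it assumes the repository root is on PYTHONPATH so common/rot.py imports as common.rot, and that the input angles stay below pi so the axis-angle recovery is unambiguous.

import torch
from common.rot import (
    axis_angle_to_quaternion,
    quaternion_to_matrix,
    matrix_to_axis_angle,
    rotmat_to_rot6d,
    rot6d_to_rotmat,
)

aa = torch.randn(8, 3) * 0.5                  # batch of small axis-angle rotations
R = quaternion_to_matrix(axis_angle_to_quaternion(aa))
aa_back = matrix_to_axis_angle(R)             # should match aa up to numerics
print(torch.allclose(aa, aa_back, atol=1e-4))

R6 = rot6d_to_rotmat(rotmat_to_rot6d(R))      # 6D round trip, shape (8, 3, 3)
print(torch.allclose(R, R6, atol=1e-4))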
common/sys_utils.py
ADDED
@@ -0,0 +1,44 @@
import os
import os.path as op
import shutil
from glob import glob

from loguru import logger


def copy(src, dst):
    if os.path.islink(src):
        linkto = os.readlink(src)
        os.symlink(linkto, dst)
    else:
        if os.path.isdir(src):
            shutil.copytree(src, dst)
        else:
            shutil.copy(src, dst)


def copy_repo(src_files, dst_folder, filter_keywords):
    src_files = [
        f for f in src_files if not any(keyword in f for keyword in filter_keywords)
    ]
    dst_files = [op.join(dst_folder, op.basename(f)) for f in src_files]
    for src_f, dst_f in zip(src_files, dst_files):
        logger.info(f"FROM: {src_f}\nTO:{dst_f}")
        copy(src_f, dst_f)


def mkdir(directory):
    if not os.path.exists(directory):
        os.makedirs(directory)


def mkdir_p(exp_path):
    os.makedirs(exp_path, exist_ok=True)


def count_files(path):
    """
    Non-recursively count number of files in a folder.
    """
    files = glob(path)
    return len(files)
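Illustrative usage sketch (not from the commit itself): note that count_files takes a glob pattern rather than a bare directory path; the paths below are hypothetical.

from common.sys_utils import mkdir_p, count_files

mkdir_p("./tmp_example")                    # creates the folder if missing
print(count_files("./tmp_example/*.txt"))   # 0 until files matching the pattern exist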
common/thing.py
ADDED
@@ -0,0 +1,66 @@
import numpy as np
import torch

"""
This file stores functions for conversion between numpy and torch, torch, list, etc.
Also deal with general operations such as to(dev), detach, etc.
"""


def thing2list(thing):
    if isinstance(thing, torch.Tensor):
        return thing.tolist()
    if isinstance(thing, np.ndarray):
        return thing.tolist()
    if isinstance(thing, dict):
        # fixed: original read `md.items()`, an undefined name
        return {k: thing2list(v) for k, v in thing.items()}
    if isinstance(thing, list):
        return [thing2list(ten) for ten in thing]
    return thing


def thing2dev(thing, dev):
    if hasattr(thing, "to"):
        thing = thing.to(dev)
        return thing
    if isinstance(thing, list):
        return [thing2dev(ten, dev) for ten in thing]
    if isinstance(thing, tuple):
        return tuple(thing2dev(list(thing), dev))
    if isinstance(thing, dict):
        return {k: thing2dev(v, dev) for k, v in thing.items()}
    if isinstance(thing, torch.Tensor):
        return thing.to(dev)
    return thing


def thing2np(thing):
    if isinstance(thing, list):
        return np.array(thing)
    if isinstance(thing, torch.Tensor):
        return thing.cpu().detach().numpy()
    if isinstance(thing, dict):
        return {k: thing2np(v) for k, v in thing.items()}
    return thing


def thing2torch(thing):
    if isinstance(thing, list):
        return torch.tensor(np.array(thing))
    if isinstance(thing, np.ndarray):
        return torch.from_numpy(thing)
    if isinstance(thing, dict):
        return {k: thing2torch(v) for k, v in thing.items()}
    return thing


def detach_thing(thing):
    if isinstance(thing, torch.Tensor):
        return thing.cpu().detach()
    if isinstance(thing, list):
        return [detach_thing(ten) for ten in thing]
    if isinstance(thing, tuple):
        return tuple(detach_thing(list(thing)))
    if isinstance(thing, dict):
        return {k: detach_thing(v) for k, v in thing.items()}
    return thing
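Illustrative usage sketch (not from the commit itself): thing2dev walks nested containers and moves every tensor leaf to the requested device; "cuda:0" would work the same way on a GPU machine.

import torch
from common.thing import thing2dev, thing2np

batch = {"verts": torch.zeros(2, 3), "meta": {"ids": [torch.arange(4)]}}
batch_cpu = thing2dev(batch, "cpu")           # nested dict/list structure preserved
as_np = thing2np({"verts": batch_cpu["verts"]})
print(type(as_np["verts"]))                   # <class 'numpy.ndarray'>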
common/torch_utils.py
ADDED
@@ -0,0 +1,212 @@
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from common.ld_utils import unsort as unsort_list


# pytorch implementation for np.nanmean
# https://github.com/pytorch/pytorch/issues/21987#issuecomment-539402619
def nanmean(v, *args, inplace=False, **kwargs):
    if not inplace:
        v = v.clone()
    is_nan = torch.isnan(v)
    v[is_nan] = 0
    return v.sum(*args, **kwargs) / (~is_nan).float().sum(*args, **kwargs)


def grad_norm(model):
    # compute norm of gradient for a model
    total_norm = None
    for p in model.parameters():
        if p.grad is not None:
            if total_norm is None:
                total_norm = 0
            param_norm = p.grad.detach().data.norm(2)
            total_norm += param_norm.item() ** 2

    if total_norm is not None:
        total_norm = total_norm ** (1.0 / 2)
    else:
        total_norm = 0.0
    return total_norm


def pad_tensor_list(v_list: list):
    dev = v_list[0].device
    num_meshes = len(v_list)
    num_dim = 1 if len(v_list[0].shape) == 1 else v_list[0].shape[1]
    v_len_list = []
    for verts in v_list:
        v_len_list.append(verts.shape[0])

    pad_len = max(v_len_list)
    dtype = v_list[0].dtype
    if num_dim == 1:
        padded_tensor = torch.zeros(num_meshes, pad_len, dtype=dtype)
    else:
        padded_tensor = torch.zeros(num_meshes, pad_len, num_dim, dtype=dtype)
    for idx, (verts, v_len) in enumerate(zip(v_list, v_len_list)):
        padded_tensor[idx, :v_len] = verts
    padded_tensor = padded_tensor.to(dev)
    v_len_list = torch.LongTensor(v_len_list).to(dev)
    return padded_tensor, v_len_list


def unpad_vtensor(
    vtensor: (torch.Tensor), lens: (torch.LongTensor, torch.cuda.LongTensor)
):
    tensors_list = []
    for verts, vlen in zip(vtensor, lens):
        tensors_list.append(verts[:vlen])
    return tensors_list


def one_hot_embedding(labels, num_classes):
    """Embedding labels to one-hot form.
    Args:
        labels: (LongTensor) class labels, sized [N, D1, D2, ..].
        num_classes: (int) number of classes.
    Returns:
        (tensor) encoded labels, sized [N, D1, D2, .., Dk, #classes].
    """
    y = torch.eye(num_classes).float()
    return y[labels]


def unsort(ten, sort_idx):
    """
    Unsort a tensor of shape (N, *) using the sort_idx list(N).
    Return a tensor of the pre-sorting order in shape (N, *)
    """
    assert isinstance(ten, torch.Tensor)
    assert isinstance(sort_idx, list)
    assert ten.shape[0] == len(sort_idx)

    out_list = list(torch.chunk(ten, ten.size(0), dim=0))
    out_list = unsort_list(out_list, sort_idx)
    out_list = torch.cat(out_list, dim=0)
    return out_list


def all_comb(X, Y):
    """
    Returns all possible combinations of elements in X and Y.
    X: (n_x, d_x)
    Y: (n_y, d_y)
    Output: Z: (n_x*x_y, d_x+d_y)
    Example:
    X = tensor([[8, 8, 8],
                [7, 5, 9]])
    Y = tensor([[3, 8, 7, 7],
                [3, 7, 9, 9],
                [6, 4, 3, 7]])
    Z = tensor([[8, 8, 8, 3, 8, 7, 7],
                [8, 8, 8, 3, 7, 9, 9],
                [8, 8, 8, 6, 4, 3, 7],
                [7, 5, 9, 3, 8, 7, 7],
                [7, 5, 9, 3, 7, 9, 9],
                [7, 5, 9, 6, 4, 3, 7]])
    """
    assert len(X.size()) == 2
    assert len(Y.size()) == 2
    X1 = X.unsqueeze(1)
    Y1 = Y.unsqueeze(0)
    X2 = X1.repeat(1, Y.shape[0], 1)
    Y2 = Y1.repeat(X.shape[0], 1, 1)
    Z = torch.cat([X2, Y2], -1)
    Z = Z.view(-1, Z.shape[-1])
    return Z


def toggle_parameters(model, requires_grad):
    """
    Set all weights to requires_grad or not.
    """
    for param in model.parameters():
        param.requires_grad = requires_grad


def detach_tensor(ten):
    """This function move tensor to cpu and convert to numpy"""
    if isinstance(ten, torch.Tensor):
        return ten.cpu().detach().numpy()
    return ten


def count_model_parameters(model):
    """
    Return the amount of parameters that requries gradients.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def reset_all_seeds(seed):
    """Reset all seeds for reproduciability."""
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)


def get_activation(name):
    """This function return an activation constructor by name."""
    if name == "tanh":
        return nn.Tanh()
    elif name == "sigmoid":
        return nn.Sigmoid()
    elif name == "relu":
        return nn.ReLU()
    elif name == "selu":
        return nn.SELU()
    elif name == "relu6":
        return nn.ReLU6()
    elif name == "softplus":
        return nn.Softplus()
    elif name == "softshrink":
        return nn.Softshrink()
    else:
        print("Undefined activation: %s" % (name))
        assert False


def stack_ll_tensors(tensor_list_list):
    """
    Recursively stack a list of lists of lists .. whose elements are tensors with the same shape
    """
    if isinstance(tensor_list_list, torch.Tensor):
        return tensor_list_list
    assert isinstance(tensor_list_list, list)
    if isinstance(tensor_list_list[0], torch.Tensor):
        return torch.stack(tensor_list_list)

    stacked_tensor = []
    for tensor_list in tensor_list_list:
        stacked_tensor.append(stack_ll_tensors(tensor_list))
    stacked_tensor = torch.stack(stacked_tensor)
    return stacked_tensor


def get_optim(name):
    """This function return an optimizer constructor by name."""
    if name == "adam":
        return optim.Adam
    elif name == "rmsprop":
        return optim.RMSprop
    elif name == "sgd":
        return optim.SGD
    else:
        print("Undefined optim: %s" % (name))
        assert False


def decay_lr(optimizer, gamma):
    """
    Decay the learning rate by gamma
    """
    assert isinstance(gamma, float)
    assert 0 <= gamma and gamma <= 1.0
    for param_group in optimizer.param_groups:
        param_group["lr"] *= gamma
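Illustrative usage sketch (not from the commit itself): pad variable-length vertex tensors into one batch with pad_tensor_list and recover the original lists with unpad_vtensor; the shapes below are arbitrary.

import torch
from common.torch_utils import pad_tensor_list, unpad_vtensor

v_list = [torch.rand(5, 3), torch.rand(2, 3), torch.rand(7, 3)]
padded, lens = pad_tensor_list(v_list)      # padded: (3, 7, 3), lens: (3,)
recovered = unpad_vtensor(padded, lens)     # list of (5,3), (2,3), (7,3)
print(all(torch.equal(a, b) for a, b in zip(v_list, recovered)))  # True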
common/transforms.py
ADDED
@@ -0,0 +1,356 @@
import numpy as np
import torch

import common.data_utils as data_utils
from common.np_utils import permute_np

"""
Useful geometric operations, e.g. Perspective projection and a differentiable Rodrigues formula
Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR
"""


def to_xy(x_homo):
    assert isinstance(x_homo, (torch.FloatTensor, torch.cuda.FloatTensor))
    assert x_homo.shape[1] == 3
    assert len(x_homo.shape) == 2
    batch_size = x_homo.shape[0]
    x = torch.ones(batch_size, 2, device=x_homo.device)
    x = x_homo[:, :2] / x_homo[:, 2:3]
    return x


def to_xyz(x_homo):
    assert isinstance(x_homo, (torch.FloatTensor, torch.cuda.FloatTensor))
    assert x_homo.shape[1] == 4
    assert len(x_homo.shape) == 2
    batch_size = x_homo.shape[0]
    x = torch.ones(batch_size, 3, device=x_homo.device)
    x = x_homo[:, :3] / x_homo[:, 3:4]
    return x


def to_homo(x):
    assert isinstance(x, (torch.FloatTensor, torch.cuda.FloatTensor))
    assert x.shape[1] == 3
    assert len(x.shape) == 2
    batch_size = x.shape[0]
    x_homo = torch.ones(batch_size, 4, device=x.device)
    x_homo[:, :3] = x.clone()
    return x_homo


def to_homo_batch(x):
    assert isinstance(x, (torch.FloatTensor, torch.cuda.FloatTensor))
    assert x.shape[2] == 3
    assert len(x.shape) == 3
    batch_size = x.shape[0]
    num_pts = x.shape[1]
    x_homo = torch.ones(batch_size, num_pts, 4, device=x.device)
    x_homo[:, :, :3] = x.clone()
    return x_homo


def to_xyz_batch(x_homo):
    """
    Input: (B, N, 4)
    Ouput: (B, N, 3)
    """
    assert isinstance(x_homo, (torch.FloatTensor, torch.cuda.FloatTensor))
    assert x_homo.shape[2] == 4
    assert len(x_homo.shape) == 3
    batch_size = x_homo.shape[0]
    num_pts = x_homo.shape[1]
    x = torch.ones(batch_size, num_pts, 3, device=x_homo.device)
    x = x_homo[:, :, :3] / x_homo[:, :, 3:4]
    return x


def to_xy_batch(x_homo):
    assert isinstance(x_homo, (torch.FloatTensor, torch.cuda.FloatTensor))
    assert x_homo.shape[2] == 3
    assert len(x_homo.shape) == 3
    batch_size = x_homo.shape[0]
    num_pts = x_homo.shape[1]
    x = torch.ones(batch_size, num_pts, 2, device=x_homo.device)
    x = x_homo[:, :, :2] / x_homo[:, :, 2:3]
    return x


# VR Distortion Correction Using Vertex Displacement
# https://stackoverflow.com/questions/44489686/camera-lens-distortion-in-opengl
def distort_pts3d_all(_pts_cam, dist_coeffs):
    # egocentric cameras commonly has heavy distortion
    # this function transform points in the undistorted camera coord
    # to distorted camera coord such that the 2d projection can match the pixels.
    pts_cam = _pts_cam.clone().double()
    z = pts_cam[:, :, 2]

    z_inv = 1 / z

    x1 = pts_cam[:, :, 0] * z_inv
    y1 = pts_cam[:, :, 1] * z_inv

    # precalculations
    x1_2 = x1 * x1
    y1_2 = y1 * y1
    x1_y1 = x1 * y1
    r2 = x1_2 + y1_2
    r4 = r2 * r2
    r6 = r4 * r2

    r_dist = (1 + dist_coeffs[0] * r2 + dist_coeffs[1] * r4 + dist_coeffs[4] * r6) / (
        1 + dist_coeffs[5] * r2 + dist_coeffs[6] * r4 + dist_coeffs[7] * r6
    )

    # full (rational + tangential) distortion
    x2 = x1 * r_dist + 2 * dist_coeffs[2] * x1_y1 + dist_coeffs[3] * (r2 + 2 * x1_2)
    y2 = y1 * r_dist + 2 * dist_coeffs[3] * x1_y1 + dist_coeffs[2] * (r2 + 2 * y1_2)
    # denormalize for projection (which is a linear operation)
    cam_pts_dist = torch.stack([x2 * z, y2 * z, z], dim=2).float()
    return cam_pts_dist


def rigid_tf_torch_batch(points, R, T):
    """
    Performs rigid transformation to incoming points but batched
    Q = (points*R.T) + T
    points: (batch, num, 3)
    R: (batch, 3, 3)
    T: (batch, 3, 1)
    out: (batch, num, 3)
    """
    points_out = torch.bmm(R, points.permute(0, 2, 1)) + T
    points_out = points_out.permute(0, 2, 1)
    return points_out


def solve_rigid_tf_np(A: np.ndarray, B: np.ndarray):
    """
    “Least-Squares Fitting of Two 3-D Point Sets”, Arun, K. S. , May 1987
    Input: expects Nx3 matrix of points
    Returns R,t
    R = 3x3 rotation matrix
    t = 3x1 column vector

    This function should be a fix for compute_rigid_tf when the det == -1
    """

    assert A.shape == B.shape
    A = A.T
    B = B.T

    num_rows, num_cols = A.shape
    if num_rows != 3:
        raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}")

    num_rows, num_cols = B.shape
    if num_rows != 3:
        raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}")

    # find mean column wise
    centroid_A = np.mean(A, axis=1)
    centroid_B = np.mean(B, axis=1)

    # ensure centroids are 3x1
    centroid_A = centroid_A.reshape(-1, 1)
    centroid_B = centroid_B.reshape(-1, 1)

    # subtract mean
    Am = A - centroid_A
    Bm = B - centroid_B

    H = Am @ np.transpose(Bm)

    # find rotation
    U, S, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T

    # special reflection case
    if np.linalg.det(R) < 0:
        Vt[2, :] *= -1
        R = Vt.T @ U.T

    t = -R @ centroid_A + centroid_B

    return R, t


def batch_solve_rigid_tf(A, B):
    """
    “Least-Squares Fitting of Two 3-D Point Sets”, Arun, K. S. , May 1987
    Input: expects BxNx3 matrix of points
    Returns R,t
    R = Bx3x3 rotation matrix
    t = Bx3x1 column vector
    """

    assert A.shape == B.shape
    dev = A.device
    A = A.cpu().numpy()
    B = B.cpu().numpy()
    A = permute_np(A, (0, 2, 1))
    B = permute_np(B, (0, 2, 1))

    batch, num_rows, num_cols = A.shape
    if num_rows != 3:
        raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}")

    _, num_rows, num_cols = B.shape
    if num_rows != 3:
        raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}")

    # find mean column wise
    centroid_A = np.mean(A, axis=2)
    centroid_B = np.mean(B, axis=2)

    # ensure centroids are 3x1
    centroid_A = centroid_A.reshape(batch, -1, 1)
    centroid_B = centroid_B.reshape(batch, -1, 1)

    # subtract mean
    Am = A - centroid_A
    Bm = B - centroid_B

    H = np.matmul(Am, permute_np(Bm, (0, 2, 1)))

    # find rotation
    U, S, Vt = np.linalg.svd(H)
    R = np.matmul(permute_np(Vt, (0, 2, 1)), permute_np(U, (0, 2, 1)))

    # special reflection case
    neg_idx = np.linalg.det(R) < 0
    if neg_idx.sum() > 0:
        raise Exception(
            f"some rotation matrices are not orthogonal; make sure implementation is correct for such case: {neg_idx}"
        )
    Vt[neg_idx, 2, :] *= -1
    R[neg_idx, :, :] = np.matmul(
        permute_np(Vt[neg_idx], (0, 2, 1)), permute_np(U[neg_idx], (0, 2, 1))
    )

    t = np.matmul(-R, centroid_A) + centroid_B

    R = torch.FloatTensor(R).to(dev)
    t = torch.FloatTensor(t).to(dev)
    return R, t


def rigid_tf_np(points, R, T):
    """
    Performs rigid transformation to incoming points
    Q = (points*R.T) + T
    points: (num, 3)
    R: (3, 3)
    T: (1, 3)

    out: (num, 3)
    """

    assert isinstance(points, np.ndarray)
    assert isinstance(R, np.ndarray)
    assert isinstance(T, np.ndarray)
    assert len(points.shape) == 2
    assert points.shape[1] == 3
    assert R.shape == (3, 3)
    assert T.shape == (1, 3)
    points_new = np.matmul(R, points.T).T + T
    return points_new


def transform_points(world2cam_mat, pts):
    """
    Map points from one coord to another based on the 4x4 matrix.
    e.g., map points from world to camera coord.
    pts: (N, 3), in METERS!!
    world2cam_mat: (4, 4)
    Output: points in cam coord (N, 3)
    We follow this convention:
    | R T | |pt|
    | 0 1 | * | 1|
    i.e. we rotate first then translate as T is the camera translation not position.
    """
    assert isinstance(pts, (torch.FloatTensor, torch.cuda.FloatTensor))
    assert isinstance(world2cam_mat, (torch.FloatTensor, torch.cuda.FloatTensor))
    assert world2cam_mat.shape == (4, 4)
    assert len(pts.shape) == 2
    assert pts.shape[1] == 3
    pts_homo = to_homo(pts)

    # mocap to cam
    pts_cam_homo = torch.matmul(world2cam_mat, pts_homo.T).T
    pts_cam = to_xyz(pts_cam_homo)

    assert pts_cam.shape[1] == 3
    return pts_cam


def transform_points_batch(world2cam_mat, pts):
    """
    Map points from one coord to another based on the 4x4 matrix.
    e.g., map points from world to camera coord.
    pts: (B, N, 3), in METERS!!
    world2cam_mat: (B, 4, 4)
    Output: points in cam coord (B, N, 3)
    We follow this convention:
    | R T | |pt|
    | 0 1 | * | 1|
    i.e. we rotate first then translate as T is the camera translation not position.
    """
    assert isinstance(pts, (torch.FloatTensor, torch.cuda.FloatTensor))
    assert isinstance(world2cam_mat, (torch.FloatTensor, torch.cuda.FloatTensor))
    assert world2cam_mat.shape[1:] == (4, 4)
    assert len(pts.shape) == 3
    assert pts.shape[2] == 3
    batch_size = pts.shape[0]
    pts_homo = to_homo_batch(pts)

    # mocap to cam
    pts_cam_homo = torch.bmm(world2cam_mat, pts_homo.permute(0, 2, 1)).permute(0, 2, 1)
    pts_cam = to_xyz_batch(pts_cam_homo)

    assert pts_cam.shape[2] == 3
    return pts_cam


def project2d_batch(K, pts_cam):
    """
    K: (B, 3, 3)
    pts_cam: (B, N, 3)
    """

    assert isinstance(K, (torch.FloatTensor, torch.cuda.FloatTensor))
    assert isinstance(pts_cam, (torch.FloatTensor, torch.cuda.FloatTensor))
    assert K.shape[1:] == (3, 3)
    assert pts_cam.shape[2] == 3
    assert len(pts_cam.shape) == 3
    pts2d_homo = torch.bmm(K, pts_cam.permute(0, 2, 1)).permute(0, 2, 1)
    pts2d = to_xy_batch(pts2d_homo)
    return pts2d


def project2d_norm_batch(K, pts_cam, patch_width):
    """
    K: (B, 3, 3)
    pts_cam: (B, N, 3)
    """

    assert isinstance(K, (torch.FloatTensor, torch.cuda.FloatTensor))
    assert isinstance(pts_cam, (torch.FloatTensor, torch.cuda.FloatTensor))
    assert K.shape[1:] == (3, 3)
    assert pts_cam.shape[2] == 3
    assert len(pts_cam.shape) == 3
    v2d = project2d_batch(K, pts_cam)
    v2d_norm = data_utils.normalize_kp2d(v2d, patch_width)
    return v2d_norm


def project2d(K, pts_cam):
    assert isinstance(K, (torch.FloatTensor, torch.cuda.FloatTensor))
    assert isinstance(pts_cam, (torch.FloatTensor, torch.cuda.FloatTensor))
    assert K.shape == (3, 3)
    assert pts_cam.shape[1] == 3
    assert len(pts_cam.shape) == 2
    pts2d_homo = torch.matmul(K, pts_cam.T).T
    pts2d = to_xy(pts2d_homo)
    return pts2d
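Illustrative pipeline sketch (not from the commit itself): world-space points are mapped into the camera with transform_points_batch and projected to pixels with project2d_batch; the identity extrinsics and intrinsics values are stand-ins chosen only for demonstration.

import torch
from common.transforms import transform_points_batch, project2d_batch

B, N = 2, 10
pts_w = torch.rand(B, N, 3)                        # world-space points in meters
world2cam = torch.eye(4).expand(B, 4, 4).clone()   # identity rotation as a stand-in
world2cam[:, 2, 3] = 2.0                           # push points 2 m in front of the camera
K = torch.tensor([[500.0, 0.0, 112.0],
                  [0.0, 500.0, 112.0],
                  [0.0, 0.0, 1.0]]).expand(B, 3, 3).contiguous()

pts_cam = transform_points_batch(world2cam, pts_w)  # (B, N, 3) camera coordinates
pts2d = project2d_batch(K, pts_cam)                 # (B, N, 2) pixel coordinates
print(pts2d.shape)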
common/viewer.py
ADDED
@@ -0,0 +1,287 @@
1 |
+
import os
|
2 |
+
import os.path as op
|
3 |
+
import re
|
4 |
+
from abc import abstractmethod
|
5 |
+
|
6 |
+
import matplotlib.cm as cm
|
7 |
+
import numpy as np
|
8 |
+
from aitviewer.headless import HeadlessRenderer
|
9 |
+
from aitviewer.renderables.billboard import Billboard
|
10 |
+
from aitviewer.renderables.meshes import Meshes
|
11 |
+
from aitviewer.scene.camera import OpenCVCamera
|
12 |
+
from aitviewer.scene.material import Material
|
13 |
+
from aitviewer.utils.so3 import aa2rot_numpy
|
14 |
+
from aitviewer.viewer import Viewer
|
15 |
+
from easydict import EasyDict as edict
|
16 |
+
from loguru import logger
|
17 |
+
from PIL import Image
|
18 |
+
from tqdm import tqdm
|
19 |
+
|
20 |
+
OBJ_ID = 100
|
21 |
+
SMPLX_ID = 150
|
22 |
+
LEFT_ID = 200
|
23 |
+
RIGHT_ID = 250
|
24 |
+
SEGM_IDS = {"object": OBJ_ID, "smplx": SMPLX_ID, "left": LEFT_ID, "right": RIGHT_ID}
|
25 |
+
|
26 |
+
cmap = cm.get_cmap("plasma")
|
27 |
+
materials = {
|
28 |
+
"none": None,
|
29 |
+
"white": Material(color=(1.0, 1.0, 1.0, 1.0), ambient=0.2),
|
30 |
+
"red": Material(color=(0.969, 0.106, 0.059, 1.0), ambient=0.2),
|
31 |
+
"blue": Material(color=(0.0, 0.0, 1.0, 1.0), ambient=0.2),
|
32 |
+
"green": Material(color=(1.0, 0.0, 0.0, 1.0), ambient=0.2),
|
33 |
+
"cyan": Material(color=(0.051, 0.659, 0.051, 1.0), ambient=0.2),
|
34 |
+
"light-blue": Material(color=(0.588, 0.5647, 0.9725, 1.0), ambient=0.2),
|
35 |
+
"cyan-light": Material(color=(0.051, 0.659, 0.051, 1.0), ambient=0.2),
|
36 |
+
"dark-light": Material(color=(0.404, 0.278, 0.278, 1.0), ambient=0.2),
|
37 |
+
"rice": Material(color=(0.922, 0.922, 0.102, 1.0), ambient=0.2),
|
38 |
+
}
|
39 |
+
|
40 |
+
|
41 |
+
class ViewerData(edict):
|
42 |
+
"""
|
43 |
+
Interface to standardize viewer data.
|
44 |
+
"""
|
45 |
+
|
46 |
+
def __init__(self, Rt, K, cols, rows, imgnames=None):
|
47 |
+
self.imgnames = imgnames
|
48 |
+
self.Rt = Rt
|
49 |
+
self.K = K
|
50 |
+
self.num_frames = Rt.shape[0]
|
51 |
+
self.cols = cols
|
52 |
+
self.rows = rows
|
53 |
+
self.validate_format()
|
54 |
+
|
55 |
+
def validate_format(self):
|
56 |
+
assert len(self.Rt.shape) == 3
|
57 |
+
assert self.Rt.shape[0] == self.num_frames
|
58 |
+
assert self.Rt.shape[1] == 3
|
59 |
+
assert self.Rt.shape[2] == 4
|
60 |
+
|
61 |
+
assert len(self.K.shape) == 2
|
62 |
+
assert self.K.shape[0] == 3
|
63 |
+
assert self.K.shape[1] == 3
|
64 |
+
if self.imgnames is not None:
|
65 |
+
assert self.num_frames == len(self.imgnames)
|
66 |
+
assert self.num_frames > 0
|
67 |
+
im_p = self.imgnames[0]
|
68 |
+
assert op.exists(im_p), f"Image path {im_p} does not exist"
|
69 |
+
|
70 |
+
|
71 |
+
class ARCTICViewer:
|
72 |
+
def __init__(
|
73 |
+
self,
|
74 |
+
render_types=["rgb", "depth", "mask"],
|
75 |
+
interactive=True,
|
76 |
+
size=(2024, 2024),
|
77 |
+
):
|
78 |
+
if not interactive:
|
79 |
+
v = HeadlessRenderer()
|
80 |
+
else:
|
81 |
+
v = Viewer(size=size)
|
82 |
+
|
83 |
+
self.v = v
|
84 |
+
self.interactive = interactive
|
85 |
+
# self.layers = layers
|
86 |
+
self.render_types = render_types
|
87 |
+
|
88 |
+
def view_interactive(self):
|
89 |
+
self.v.run()
|
90 |
+
|
91 |
+
def view_fn_headless(self, num_iter, out_folder):
|
92 |
+
v = self.v
|
93 |
+
|
94 |
+
v._init_scene()
|
95 |
+
|
96 |
+
logger.info("Rendering to video")
|
97 |
+
if "video" in self.render_types:
|
98 |
+
vid_p = op.join(out_folder, "video.mp4")
|
99 |
+
v.save_video(video_dir=vid_p)
|
100 |
+
|
101 |
+
pbar = tqdm(range(num_iter))
|
102 |
+
for fidx in pbar:
|
103 |
+
out_rgb = op.join(out_folder, "images", f"rgb/{fidx:04d}.png")
|
104 |
+
out_mask = op.join(out_folder, "images", f"mask/{fidx:04d}.png")
|
105 |
+
out_depth = op.join(out_folder, "images", f"depth/{fidx:04d}.npy")
|
106 |
+
|
107 |
+
# render RGB, depth, segmentation masks
|
108 |
+
if "rgb" in self.render_types:
|
109 |
+
v.export_frame(out_rgb)
|
110 |
+
if "depth" in self.render_types:
|
111 |
+
os.makedirs(op.dirname(out_depth), exist_ok=True)
|
112 |
+
render_depth(v, out_depth)
|
113 |
+
if "mask" in self.render_types:
|
114 |
+
os.makedirs(op.dirname(out_mask), exist_ok=True)
|
115 |
+
render_mask(v, out_mask)
|
116 |
+
v.scene.next_frame()
|
117 |
+
logger.info(f"Exported to {out_folder}")
|
118 |
+
|
119 |
+
@abstractmethod
|
120 |
+
def load_data(self):
|
121 |
+
pass
|
122 |
+
|
123 |
+
def check_format(self, batch):
|
124 |
+
meshes_all, data = batch
|
125 |
+
assert isinstance(meshes_all, dict)
|
126 |
+
assert len(meshes_all) > 0
|
127 |
+
for mesh in meshes_all.values():
|
128 |
+
assert isinstance(mesh, Meshes)
|
129 |
+
assert isinstance(data, ViewerData)
|
130 |
+
|
131 |
+
def render_seq(self, batch, out_folder="./render_out"):
|
132 |
+
meshes_all, data = batch
|
133 |
+
self.setup_viewer(data)
|
134 |
+
for mesh in meshes_all.values():
|
135 |
+
self.v.scene.add(mesh)
|
136 |
+
if self.interactive:
|
137 |
+
self.view_interactive()
|
138 |
+
else:
|
139 |
+
num_iter = data["num_frames"]
|
140 |
+
self.view_fn_headless(num_iter, out_folder)
|
141 |
+
|
142 |
+
def setup_viewer(self, data):
|
143 |
+
v = self.v
|
144 |
+
fps = 30
|
145 |
+
if "imgnames" in data:
|
146 |
+
setup_billboard(data, v)
|
147 |
+
|
148 |
+
# camera.show_path()
|
149 |
+
v.run_animations = True # autoplay
|
150 |
+
v.run_animations = False # autoplay
|
151 |
+
v.playback_fps = fps
|
152 |
+
v.scene.fps = fps
|
153 |
+
v.scene.origin.enabled = False
|
154 |
+
v.scene.floor.enabled = False
|
155 |
+
v.auto_set_floor = False
|
156 |
+
v.scene.floor.position[1] = -3
|
157 |
+
# v.scene.camera.position = np.array((0.0, 0.0, 0))
|
158 |
+
self.v = v
|
159 |
+
|
160 |
+
|
161 |
+
def dist2vc(dist_ro, dist_lo, dist_o, _cmap, tf_fn=None):
|
162 |
+
if tf_fn is not None:
|
163 |
+
exp_map = tf_fn
|
164 |
+
else:
|
165 |
+
exp_map = small_exp_map
|
166 |
+
dist_ro = exp_map(dist_ro)
|
167 |
+
dist_lo = exp_map(dist_lo)
|
168 |
+
dist_o = exp_map(dist_o)
|
169 |
+
|
170 |
+
vc_ro = _cmap(dist_ro)
|
171 |
+
vc_lo = _cmap(dist_lo)
|
172 |
+
vc_o = _cmap(dist_o)
|
173 |
+
return vc_ro, vc_lo, vc_o
|
174 |
+
|
175 |
+
|
176 |
+
def small_exp_map(_dist):
|
177 |
+
dist = np.copy(_dist)
|
178 |
+
# dist = 1.0 - np.clip(dist, 0, 0.1) / 0.1
|
179 |
+
dist = np.exp(-20.0 * dist)
|
180 |
+
return dist
|
181 |
+
|
182 |
+
|
183 |
+
def construct_viewer_meshes(data, draw_edges=False, flat_shading=True):
|
184 |
+
rotation_flip = aa2rot_numpy(np.array([1, 0, 0]) * np.pi)
|
185 |
+
meshes = {}
|
186 |
+
for key, val in data.items():
|
187 |
+
if "object" in key:
|
188 |
+
flat_shading = False
|
189 |
+
else:
|
190 |
+
flat_shading = flat_shading
|
191 |
+
v3d = val["v3d"]
|
192 |
+
meshes[key] = Meshes(
|
193 |
+
v3d,
|
194 |
+
val["f3d"],
|
195 |
+
vertex_colors=val["vc"],
|
196 |
+
name=val["name"],
|
197 |
+
flat_shading=flat_shading,
|
198 |
+
draw_edges=draw_edges,
|
199 |
+
material=materials[val["color"]],
|
200 |
+
rotation=rotation_flip,
|
201 |
+
)
|
202 |
+
return meshes
|
203 |
+
|
204 |
+
|
205 |
+
def setup_viewer(
|
206 |
+
v, shared_folder_p, video, images_path, data, flag, seq_name, side_angle
|
207 |
+
):
|
208 |
+
fps = 10
|
209 |
+
cols, rows = 224, 224
|
210 |
+
focal = 1000.0
|
211 |
+
|
212 |
+
# setup image paths
|
213 |
+
regex = re.compile(r"(\d*)$")
|
214 |
+
|
215 |
+
def sort_key(x):
|
216 |
+
name = os.path.splitext(x)[0]
|
217 |
+
return int(regex.search(name).group(0))
|
218 |
+
|
219 |
+
# setup billboard
|
220 |
+
images_path = op.join(shared_folder_p, "images")
|
221 |
+
images_paths = [
|
222 |
+
os.path.join(images_path, f)
|
223 |
+
for f in sorted(os.listdir(images_path), key=sort_key)
|
224 |
+
]
|
225 |
+
assert len(images_paths) > 0
|
226 |
+
|
227 |
+
cam_t = data[f"{flag}.object.cam_t"]
|
228 |
+
num_frames = min(cam_t.shape[0], len(images_paths))
|
229 |
+
cam_t = cam_t[:num_frames]
|
230 |
+
# setup camera
|
231 |
+
K = np.array([[focal, 0, rows / 2.0], [0, focal, cols / 2.0], [0, 0, 1]])
|
232 |
+
Rt = np.zeros((num_frames, 3, 4))
|
233 |
+
Rt[:, :, 3] = cam_t
|
234 |
+
Rt[:, :3, :3] = np.eye(3)
|
235 |
+
Rt[:, 1:3, :3] *= -1.0
|
236 |
+
|
237 |
+
camera = OpenCVCamera(K, Rt, cols, rows, viewer=v)
|
238 |
+
if side_angle is None:
|
239 |
+
billboard = Billboard.from_camera_and_distance(
|
240 |
+
camera, 10.0, cols, rows, images_paths
|
241 |
+
)
|
242 |
+
v.scene.add(billboard)
|
243 |
+
v.scene.add(camera)
|
244 |
+
v.run_animations = True # autoplay
|
245 |
+
v.playback_fps = fps
|
246 |
+
v.scene.fps = fps
|
247 |
+
v.scene.origin.enabled = False
|
248 |
+
v.scene.floor.enabled = False
|
249 |
+
v.auto_set_floor = False
|
250 |
+
v.scene.floor.position[1] = -3
|
251 |
+
v.set_temp_camera(camera)
|
252 |
+
# v.scene.camera.position = np.array((0.0, 0.0, 0))
|
253 |
+
return v
|
254 |
+
|
255 |
+
|
256 |
+
def render_depth(v, depth_p):
|
257 |
+
depth = np.array(v.get_depth()).astype(np.float16)
|
258 |
+
np.save(depth_p, depth)
|
259 |
+
|
260 |
+
|
261 |
+
def render_mask(v, mask_p):
|
262 |
+
nodes_uid = {node.name: node.uid for node in v.scene.collect_nodes()}
|
263 |
+
my_cmap = {
|
264 |
+
uid: [SEGM_IDS[name], SEGM_IDS[name], SEGM_IDS[name]]
|
265 |
+
for name, uid in nodes_uid.items()
|
266 |
+
if name in SEGM_IDS.keys()
|
267 |
+
}
|
268 |
+
mask = np.array(v.get_mask(color_map=my_cmap)).astype(np.uint8)
|
269 |
+
mask = Image.fromarray(mask)
|
270 |
+
mask.save(mask_p)
|
271 |
+
|
272 |
+
|
273 |
+
def setup_billboard(data, v):
|
274 |
+
images_paths = data.imgnames
|
275 |
+
K = data.K
|
276 |
+
Rt = data.Rt
|
277 |
+
rows = data.rows
|
278 |
+
cols = data.cols
|
279 |
+
camera = OpenCVCamera(K, Rt, cols, rows, viewer=v)
|
280 |
+
if images_paths is not None:
|
281 |
+
billboard = Billboard.from_camera_and_distance(
|
282 |
+
camera, 10.0, cols, rows, images_paths
|
283 |
+
)
|
284 |
+
v.scene.add(billboard)
|
285 |
+
v.scene.add(camera)
|
286 |
+
v.scene.camera.load_cam()
|
287 |
+
v.set_temp_camera(camera)
|
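A minimal end-to-end sketch of how this viewer module is meant to be used (illustrative only, not part of the commit; it assumes aitviewer is installed with a display available, and the mesh and camera values are random toy data):

import numpy as np

from common.viewer import ARCTICViewer, ViewerData, construct_viewer_meshes

num_frames = 10
data_dict = {
    "right": {
        "v3d": np.random.rand(num_frames, 100, 3),  # per-frame vertices (toy values)
        "f3d": np.random.randint(0, 100, (50, 3)),  # triangle faces
        "vc": None,                                 # or per-vertex colors
        "name": "right",
        "color": "white",                           # picks a Material from the materials dict above
    }
}
meshes = construct_viewer_meshes(data_dict, draw_edges=False, flat_shading=True)

# Static camera track: identity rotation, 2m in front of the content, simple pinhole intrinsics.
Rt = np.zeros((num_frames, 3, 4))
Rt[:, :3, :3] = np.eye(3)
Rt[:, 2, 3] = 2.0
K = np.array([[1000.0, 0.0, 112.0], [0.0, 1000.0, 112.0], [0.0, 0.0, 1.0]])
data = ViewerData(Rt=Rt, K=K, cols=224, rows=224, imgnames=None)

viewer = ARCTICViewer(interactive=True)  # interactive window; False renders headlessly to out_folder
viewer.render_seq((meshes, data), out_folder="./render_out")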
common/vis_utils.py
ADDED
@@ -0,0 +1,129 @@
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# connection between the 8 points of 3d bbox
BONES_3D_BBOX = [
    (0, 1),
    (1, 2),
    (2, 3),
    (3, 0),
    (0, 4),
    (1, 5),
    (2, 6),
    (3, 7),
    (4, 5),
    (5, 6),
    (6, 7),
    (7, 4),
]


def plot_2d_bbox(bbox_2d, bones, color, ax):
    if ax is None:
        axx = plt
    else:
        axx = ax
    colors = cm.rainbow(np.linspace(0, 1, len(bbox_2d)))
    for pt, c in zip(bbox_2d, colors):
        axx.scatter(pt[0], pt[1], color=c, s=50)

    if bones is None:
        bones = BONES_3D_BBOX
    for bone in bones:
        sidx, eidx = bone
        # bottom of bbox is white
        if min(sidx, eidx) >= 4:
            color = "w"
        axx.plot(
            [bbox_2d[sidx][0], bbox_2d[eidx][0]],
            [bbox_2d[sidx][1], bbox_2d[eidx][1]],
            color,
        )
    return axx


# http://www.icare.univ-lille1.fr/tutorials/convert_a_matplotlib_figure
def fig2data(fig):
    """
    @brief Convert a Matplotlib figure to a 4D
    numpy array with RGBA channels and return it
    @param fig a matplotlib figure
    @return a numpy 3D array of RGBA values
    """
    # draw the renderer
    fig.canvas.draw()

    # Get the RGBA buffer from the figure
    w, h = fig.canvas.get_width_height()
    buf = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8)
    buf.shape = (w, h, 4)

    # canvas.tostring_argb gives a pixmap in ARGB mode.
    # Roll the ALPHA channel to have it in RGBA mode
    buf = np.roll(buf, 3, axis=2)
    return buf


# http://www.icare.univ-lille1.fr/tutorials/convert_a_matplotlib_figure
def fig2img(fig):
    """
    @brief Convert a Matplotlib figure to a PIL Image
    in RGBA format and return it
    @param fig a matplotlib figure
    @return a Python Imaging Library ( PIL ) image
    """
    # put the figure pixmap into a numpy array
    buf = fig2data(fig)
    w, h, _ = buf.shape
    return Image.frombytes("RGBA", (w, h), buf.tobytes())


def concat_pil_images(images):
    """
    Put a list of PIL images next to each other
    """
    assert isinstance(images, list)
    widths, heights = zip(*(i.size for i in images))

    total_width = sum(widths)
    max_height = max(heights)

    new_im = Image.new("RGB", (total_width, max_height))

    x_offset = 0
    for im in images:
        new_im.paste(im, (x_offset, 0))
        x_offset += im.size[0]
    return new_im


def stack_pil_images(images):
    """
    Stack a list of PIL images on top of each other
    """
    assert isinstance(images, list)
    widths, heights = zip(*(i.size for i in images))

    total_height = sum(heights)
    max_width = max(widths)

    new_im = Image.new("RGB", (max_width, total_height))

    y_offset = 0
    for im in images:
        new_im.paste(im, (0, y_offset))
        y_offset += im.size[1]
    return new_im


def im_list_to_plt(image_list, figsize, title_list=None):
    fig, axes = plt.subplots(nrows=1, ncols=len(image_list), figsize=figsize)
    for idx, (ax, im) in enumerate(zip(axes, image_list)):
        ax.imshow(im)
        ax.set_title(title_list[idx])
    fig.tight_layout()
    im = fig2img(fig)
    plt.close()
    return im
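A short usage sketch for these helpers (illustrative only, not part of the commit): render two arrays side by side, convert a matplotlib figure to a PIL image, and concatenate everything into one debug image.

import matplotlib.pyplot as plt
import numpy as np

from common.vis_utils import concat_pil_images, fig2img, im_list_to_plt

# Two random images rendered side by side with titles, returned as a single PIL image.
imgs = [np.random.rand(64, 64, 3), np.random.rand(64, 64, 3)]
panel = im_list_to_plt(imgs, figsize=(6, 3), title_list=["left", "right"])

# Any matplotlib figure can be converted to a PIL image via fig2img.
fig = plt.figure()
plt.plot([0, 1], [0, 1])
curve = fig2img(fig).convert("RGB")
plt.close(fig)

combined = concat_pil_images([panel, curve])  # paste images left to right
combined.save("debug_vis.png")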
common/xdict.py
ADDED
@@ -0,0 +1,288 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
|
4 |
+
import common.thing as thing
|
5 |
+
|
6 |
+
|
7 |
+
def _print_stat(key, thing):
|
8 |
+
"""
|
9 |
+
Helper function for printing statistics about a key-value pair in an xdict.
|
10 |
+
"""
|
11 |
+
mytype = type(thing)
|
12 |
+
if isinstance(thing, (list, tuple)):
|
13 |
+
print("{:<20}: {:<30}\t{:}".format(key, len(thing), mytype))
|
14 |
+
elif isinstance(thing, (torch.Tensor)):
|
15 |
+
dev = thing.device
|
16 |
+
shape = str(thing.shape).replace(" ", "")
|
17 |
+
print("{:<20}: {:<30}\t{:}\t{}".format(key, shape, mytype, dev))
|
18 |
+
elif isinstance(thing, (np.ndarray)):
|
19 |
+
dev = ""
|
20 |
+
shape = str(thing.shape).replace(" ", "")
|
21 |
+
print("{:<20}: {:<30}\t{:}".format(key, shape, mytype))
|
22 |
+
else:
|
23 |
+
print("{:<20}: {:}".format(key, mytype))
|
24 |
+
|
25 |
+
|
26 |
+
class xdict(dict):
|
27 |
+
"""
|
28 |
+
A subclass of Python's built-in dict class, which provides additional methods for manipulating and operating on dictionaries.
|
29 |
+
"""
|
30 |
+
|
31 |
+
def __init__(self, mydict=None):
|
32 |
+
"""
|
33 |
+
Constructor for the xdict class. Creates a new xdict object and optionally initializes it with key-value pairs from the provided dictionary mydict. If mydict is not provided, an empty xdict is created.
|
34 |
+
"""
|
35 |
+
if mydict is None:
|
36 |
+
return
|
37 |
+
|
38 |
+
for k, v in mydict.items():
|
39 |
+
super().__setitem__(k, v)
|
40 |
+
|
41 |
+
def subset(self, keys):
|
42 |
+
"""
|
43 |
+
Returns a new xdict object containing only the key-value pairs with keys in the provided list 'keys'.
|
44 |
+
"""
|
45 |
+
out_dict = {}
|
46 |
+
for k in keys:
|
47 |
+
out_dict[k] = self[k]
|
48 |
+
return xdict(out_dict)
|
49 |
+
|
50 |
+
def __setitem__(self, key, val):
|
51 |
+
"""
|
52 |
+
Overrides the dict.__setitem__ method to raise an assertion error if a key already exists.
|
53 |
+
"""
|
54 |
+
assert key not in self.keys(), f"Key already exists {key}"
|
55 |
+
super().__setitem__(key, val)
|
56 |
+
|
57 |
+
def search(self, keyword, replace_to=None):
|
58 |
+
"""
|
59 |
+
Returns a new xdict object containing only the key-value pairs with keys that contain the provided keyword.
|
60 |
+
"""
|
61 |
+
out_dict = {}
|
62 |
+
for k in self.keys():
|
63 |
+
if keyword in k:
|
64 |
+
if replace_to is None:
|
65 |
+
out_dict[k] = self[k]
|
66 |
+
else:
|
67 |
+
out_dict[k.replace(keyword, replace_to)] = self[k]
|
68 |
+
return xdict(out_dict)
|
69 |
+
|
70 |
+
def rm(self, keyword, keep_list=[], verbose=False):
|
71 |
+
"""
|
72 |
+
Returns a new xdict object with keys that contain keyword removed. Keys in keep_list are excluded from the removal.
|
73 |
+
"""
|
74 |
+
out_dict = {}
|
75 |
+
for k in self.keys():
|
76 |
+
if keyword not in k or k in keep_list:
|
77 |
+
out_dict[k] = self[k]
|
78 |
+
else:
|
79 |
+
if verbose:
|
80 |
+
print(f"Removing: {k}")
|
81 |
+
return xdict(out_dict)
|
82 |
+
|
83 |
+
def overwrite(self, k, v):
|
84 |
+
"""
|
85 |
+
The original assignment operation of Python dict
|
86 |
+
"""
|
87 |
+
super().__setitem__(k, v)
|
88 |
+
|
89 |
+
def merge(self, dict2):
|
90 |
+
"""
|
91 |
+
Same as dict.update(), but raises an assertion error if there are duplicate keys between the two dictionaries.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
dict2 (dict or xdict): The dictionary or xdict instance to merge with.
|
95 |
+
|
96 |
+
Raises:
|
97 |
+
AssertionError: If dict2 is not a dictionary or xdict instance.
|
98 |
+
AssertionError: If there are duplicate keys between the two instances.
|
99 |
+
"""
|
100 |
+
assert isinstance(dict2, (dict, xdict))
|
101 |
+
mykeys = set(self.keys())
|
102 |
+
intersect = mykeys.intersection(set(dict2.keys()))
|
103 |
+
assert len(intersect) == 0, f"Merge failed: duplicate keys ({intersect})"
|
104 |
+
self.update(dict2)
|
105 |
+
|
106 |
+
def mul(self, scalar):
|
107 |
+
"""
|
108 |
+
Multiplies each value (could be tensor, np.array, list) in the xdict instance by the provided scalar.
|
109 |
+
|
110 |
+
Args:
|
111 |
+
scalar (float): The scalar to multiply the values by.
|
112 |
+
|
113 |
+
Raises:
|
114 |
+
AssertionError: If scalar is not a float.
|
115 |
+
"""
|
116 |
+
if isinstance(scalar, int):
|
117 |
+
scalar = 1.0 * scalar
|
118 |
+
assert isinstance(scalar, float)
|
119 |
+
out_dict = {}
|
120 |
+
for k in self.keys():
|
121 |
+
if isinstance(self[k], list):
|
122 |
+
out_dict[k] = [v * scalar for v in self[k]]
|
123 |
+
else:
|
124 |
+
out_dict[k] = self[k] * scalar
|
125 |
+
return xdict(out_dict)
|
126 |
+
|
127 |
+
def prefix(self, text):
|
128 |
+
"""
|
129 |
+
Adds a prefix to each key in the xdict instance.
|
130 |
+
|
131 |
+
Args:
|
132 |
+
text (str): The prefix to add.
|
133 |
+
|
134 |
+
Returns:
|
135 |
+
xdict: The xdict instance with the added prefix.
|
136 |
+
"""
|
137 |
+
out_dict = {}
|
138 |
+
for k in self.keys():
|
139 |
+
out_dict[text + k] = self[k]
|
140 |
+
return xdict(out_dict)
|
141 |
+
|
142 |
+
def replace_keys(self, str_src, str_tar):
|
143 |
+
"""
|
144 |
+
Replaces a substring in all keys of the xdict instance.
|
145 |
+
|
146 |
+
Args:
|
147 |
+
str_src (str): The substring to replace.
|
148 |
+
str_tar (str): The replacement string.
|
149 |
+
|
150 |
+
Returns:
|
151 |
+
xdict: The xdict instance with the replaced keys.
|
152 |
+
"""
|
153 |
+
out_dict = {}
|
154 |
+
for k in self.keys():
|
155 |
+
old_key = k
|
156 |
+
new_key = old_key.replace(str_src, str_tar)
|
157 |
+
out_dict[new_key] = self[k]
|
158 |
+
return xdict(out_dict)
|
159 |
+
|
160 |
+
def postfix(self, text):
|
161 |
+
"""
|
162 |
+
Adds a postfix to each key in the xdict instance.
|
163 |
+
|
164 |
+
Args:
|
165 |
+
text (str): The postfix to add.
|
166 |
+
|
167 |
+
Returns:
|
168 |
+
xdict: The xdict instance with the added postfix.
|
169 |
+
"""
|
170 |
+
out_dict = {}
|
171 |
+
for k in self.keys():
|
172 |
+
out_dict[k + text] = self[k]
|
173 |
+
return xdict(out_dict)
|
174 |
+
|
175 |
+
def sorted_keys(self):
|
176 |
+
"""
|
177 |
+
Returns a sorted list of the keys in the xdict instance.
|
178 |
+
|
179 |
+
Returns:
|
180 |
+
list: A sorted list of keys in the xdict instance.
|
181 |
+
"""
|
182 |
+
return sorted(list(self.keys()))
|
183 |
+
|
184 |
+
def to(self, dev):
|
185 |
+
"""
|
186 |
+
Moves the xdict instance to a specific device.
|
187 |
+
|
188 |
+
Args:
|
189 |
+
dev (torch.device): The device to move the instance to.
|
190 |
+
|
191 |
+
Returns:
|
192 |
+
xdict: The xdict instance moved to the specified device.
|
193 |
+
"""
|
194 |
+
if dev is None:
|
195 |
+
return self
|
196 |
+
raw_dict = dict(self)
|
197 |
+
return xdict(thing.thing2dev(raw_dict, dev))
|
198 |
+
|
199 |
+
def to_torch(self):
|
200 |
+
"""
|
201 |
+
Converts elements in the xdict to Torch tensors and returns a new xdict.
|
202 |
+
|
203 |
+
Returns:
|
204 |
+
xdict: A new xdict with Torch tensors as values.
|
205 |
+
"""
|
206 |
+
return xdict(thing.thing2torch(self))
|
207 |
+
|
208 |
+
def to_np(self):
|
209 |
+
"""
|
210 |
+
Converts elements in the xdict to numpy arrays and returns a new xdict.
|
211 |
+
|
212 |
+
Returns:
|
213 |
+
xdict: A new xdict with numpy arrays as values.
|
214 |
+
"""
|
215 |
+
return xdict(thing.thing2np(self))
|
216 |
+
|
217 |
+
def tolist(self):
|
218 |
+
"""
|
219 |
+
Converts elements in the xdict to Python lists and returns a new xdict.
|
220 |
+
|
221 |
+
Returns:
|
222 |
+
xdict: A new xdict with Python lists as values.
|
223 |
+
"""
|
224 |
+
return xdict(thing.thing2list(self))
|
225 |
+
|
226 |
+
def print_stat(self):
|
227 |
+
"""
|
228 |
+
Prints statistics for each item in the xdict.
|
229 |
+
"""
|
230 |
+
for k, v in self.items():
|
231 |
+
_print_stat(k, v)
|
232 |
+
|
233 |
+
def detach(self):
|
234 |
+
"""
|
235 |
+
Detaches all Torch tensors in the xdict from the computational graph and moves them to the CPU.
|
236 |
+
Non-tensor objects are ignored.
|
237 |
+
|
238 |
+
Returns:
|
239 |
+
xdict: A new xdict with detached Torch tensors as values.
|
240 |
+
"""
|
241 |
+
return xdict(thing.detach_thing(self))
|
242 |
+
|
243 |
+
def has_invalid(self):
|
244 |
+
"""
|
245 |
+
Checks if any of the Torch tensors in the xdict contain NaN or Inf values.
|
246 |
+
|
247 |
+
Returns:
|
248 |
+
bool: True if at least one tensor contains NaN or Inf values, False otherwise.
|
249 |
+
"""
|
250 |
+
for k, v in self.items():
|
251 |
+
if isinstance(v, torch.Tensor):
|
252 |
+
if torch.isnan(v).any():
|
253 |
+
print(f"{k} contains nan values")
|
254 |
+
return True
|
255 |
+
if torch.isinf(v).any():
|
256 |
+
print(f"{k} contains inf values")
|
257 |
+
return True
|
258 |
+
return False
|
259 |
+
|
260 |
+
def apply(self, operation, criterion=None):
|
261 |
+
"""
|
262 |
+
Applies an operation to the values in the xdict, based on an optional criterion.
|
263 |
+
|
264 |
+
Args:
|
265 |
+
operation (callable): A callable object that takes a single argument and returns a value.
|
266 |
+
criterion (callable, optional): A callable object that takes two arguments (key and value) and returns a boolean.
|
267 |
+
|
268 |
+
Returns:
|
269 |
+
xdict: A new xdict with the same keys as the original, but with the values modified by the operation.
|
270 |
+
"""
|
271 |
+
out = {}
|
272 |
+
for k, v in self.items():
|
273 |
+
if criterion is None or criterion(k, v):
|
274 |
+
out[k] = operation(v)
|
275 |
+
return xdict(out)
|
276 |
+
|
277 |
+
def save(self, path, dev=None, verbose=True):
|
278 |
+
"""
|
279 |
+
Saves the xdict to disk as a Torch tensor.
|
280 |
+
|
281 |
+
Args:
|
282 |
+
path (str): The path to save the xdict.
|
283 |
+
dev (torch.device, optional): The device to use for saving the tensor (default is CPU).
|
284 |
+
verbose (bool, optional): Whether to print a message indicating that the xdict has been saved (default is True).
|
285 |
+
"""
|
286 |
+
if verbose:
|
287 |
+
print(f"Saving to {path}")
|
288 |
+
torch.save(self.to(dev), path)
|
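A small usage sketch (illustrative only, not part of the commit) showing the intended flow of the xdict helpers above; the key names and shapes are made up:

import torch

from common.xdict import xdict

batch = xdict(
    {
        "pred.mano.beta": torch.rand(4, 10),
        "targets.mano.beta": torch.rand(4, 10),
    }
)

preds = batch.search("pred.", replace_to="")  # keys containing "pred.", with the prefix stripped
scaled = preds.mul(2.0)                       # scale every value by a float
renamed = preds.prefix("again.")              # "again.mano.beta"
batch.merge(renamed)                          # asserts there are no duplicate keys

batch.print_stat()                            # per-key shape / type / device summary
assert not batch.has_invalid()                # no NaN or Inf anywhere

if torch.cuda.is_available():                 # move every tensor to the GPU via common.thing
    batch = batch.to("cuda:0")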
data_loaders/.DS_Store
ADDED
Binary file (6.15 kB).
data_loaders/__pycache__/get_data.cpython-38.pyc
ADDED
Binary file (4.42 kB).
data_loaders/__pycache__/tensors.cpython-38.pyc
ADDED
Binary file (6.98 kB).
data_loaders/get_data.py
ADDED
@@ -0,0 +1,178 @@
1 |
+
from torch.utils.data import DataLoader
|
2 |
+
from data_loaders.tensors import collate as all_collate
|
3 |
+
from data_loaders.tensors import t2m_collate, motion_ours_collate, motion_ours_singe_seq_collate, motion_ours_obj_base_rel_dist_collate
|
4 |
+
# from data_loaders.humanml.data.dataset import HumanML3D
|
5 |
+
import torch
|
6 |
+
|
7 |
+
def get_dataset_class(name, args=None):
|
8 |
+
if name == "amass":
|
9 |
+
from .amass import AMASS
|
10 |
+
return AMASS
|
11 |
+
elif name == "uestc":
|
12 |
+
from .a2m.uestc import UESTC
|
13 |
+
return UESTC
|
14 |
+
elif name == "humanact12":
|
15 |
+
from .a2m.humanact12poses import HumanAct12Poses
|
16 |
+
return HumanAct12Poses ## to pose ##
|
17 |
+
elif name == "humanml":
|
18 |
+
from data_loaders.humanml.data.dataset import HumanML3D
|
19 |
+
return HumanML3D
|
20 |
+
elif name == "kit":
|
21 |
+
from data_loaders.humanml.data.dataset import KIT
|
22 |
+
return KIT
|
23 |
+
elif name == "motion_ours": # motion ours
|
24 |
+
if len(args.single_seq_path) > 0 and not args.use_predicted_infos and not args.use_interpolated_infos:
|
25 |
+
print(f"Using single frame dataset for evaluation purpose...")
|
26 |
+
# from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V16
|
27 |
+
if args.rep_type == "obj_base_rel_dist":
|
28 |
+
from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V17 as my_data
|
29 |
+
elif args.rep_type == "ambient_obj_base_rel_dist":
|
30 |
+
from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V18 as my_data
|
31 |
+
elif args.rep_type in[ "obj_base_rel_dist_we", "obj_base_rel_dist_we_wj", "obj_base_rel_dist_we_wj_latents"]:
|
32 |
+
if args.use_arctic and args.use_pose_pred:
|
33 |
+
from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V19_Arctic_from_Pred as my_data
|
34 |
+
elif args.use_hho:
|
35 |
+
from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V19_HHO as my_data
|
36 |
+
elif args.use_arctic:
|
37 |
+
from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V19_Arctic as my_data
|
38 |
+
elif len(args.cad_model_fn) > 0:
|
39 |
+
from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V19_Ours as my_data
|
40 |
+
elif len(args.predicted_info_fn) > 0:
|
41 |
+
from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V19_From_Evaluated_Info as my_data
|
42 |
+
else:
|
43 |
+
from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V19 as my_data
|
44 |
+
else:
|
45 |
+
from data_loaders.humanml.data.dataset_ours_single_seq import GRAB_Dataset_V16 as my_data
|
46 |
+
return my_data
|
47 |
+
else:
|
48 |
+
if args.rep_type == "obj_base_rel_dist":
|
49 |
+
from data_loaders.humanml.data.dataset_ours import GRAB_Dataset_V17 as my_data
|
50 |
+
elif args.rep_type == "ambient_obj_base_rel_dist":
|
51 |
+
from data_loaders.humanml.data.dataset_ours import GRAB_Dataset_V18 as my_data
|
52 |
+
elif args.rep_type in ["obj_base_rel_dist_we", "obj_base_rel_dist_we_wj", "obj_base_rel_dist_we_wj_latents"]:
|
53 |
+
if args.use_arctic:
|
54 |
+
from data_loaders.humanml.data.dataset_ours import GRAB_Dataset_V19_ARCTIC as my_data
|
55 |
+
elif args.use_vox_data: # use vox data here #
|
56 |
+
from data_loaders.humanml.data.dataset_ours import GRAB_Dataset_V20 as my_data
|
57 |
+
elif args.use_predicted_infos: # train with predicted infos for test tim adaptation #
|
58 |
+
from data_loaders.humanml.data.dataset_ours import GRAB_Dataset_V21 as my_data
|
59 |
+
elif args.use_interpolated_infos:
|
60 |
+
# GRAB_Dataset_V22
|
61 |
+
from data_loaders.humanml.data.dataset_ours import GRAB_Dataset_V22 as my_data
|
62 |
+
else:
|
63 |
+
from data_loaders.humanml.data.dataset_ours import GRAB_Dataset_V19 as my_data
|
64 |
+
else:
|
65 |
+
from data_loaders.humanml.data.dataset_ours import GRAB_Dataset_V16 as my_data
|
66 |
+
return my_data
|
67 |
+
# from data_loaders.humanml.data.dataset_ours import GRAB_Dataset_V16
|
68 |
+
# return GRAB_Dataset_V16
|
69 |
+
else:
|
70 |
+
raise ValueError(f'Unsupported dataset name [{name}]')
|
71 |
+
|
72 |
+
def get_collate_fn(name, hml_mode='train', args=None):
|
73 |
+
print(f"name: {name}, hml_mode: {hml_mode}")
|
74 |
+
if hml_mode == 'gt':
|
75 |
+
from data_loaders.humanml.data.dataset import collate_fn as t2m_eval_collate
|
76 |
+
return t2m_eval_collate
|
77 |
+
if name in ["humanml", "kit"]:
|
78 |
+
return t2m_collate
|
79 |
+
elif name in ["motion_ours"]:
|
80 |
+
## === single seq path === ##
|
81 |
+
print(f"single_seq_path: {args.single_seq_path}, rep_type: {args.rep_type}")
|
82 |
+
# motion_ours_obj_base_rel_dist_collate
|
83 |
+
## rep_type of the obj_base_pts rel_dist; ambient obj base rel dist ##
|
84 |
+
if args.rep_type in ["obj_base_rel_dist", "ambient_obj_base_rel_dist", "obj_base_rel_dist_we", "obj_base_rel_dist_we_wj", "obj_base_rel_dist_we_wj_latents"]:
|
85 |
+
return motion_ours_obj_base_rel_dist_collate
|
86 |
+
else: # single_seq_path #
|
87 |
+
if len(args.single_seq_path) > 0:
|
88 |
+
return motion_ours_singe_seq_collate
|
89 |
+
else:
|
90 |
+
return motion_ours_collate
|
91 |
+
# if len(args.single_seq_path) > 0:
|
92 |
+
# return motion_ours_singe_seq_collate
|
93 |
+
# else:
|
94 |
+
# if args.rep_type == "obj_base_rel_dist":
|
95 |
+
# return motion_ours_obj_base_rel_dist_collate
|
96 |
+
# else:
|
97 |
+
# return motion_ours_collate
|
98 |
+
else:
|
99 |
+
return all_collate
|
100 |
+
|
101 |
+
## get dataset and datasset ###
|
102 |
+
def get_dataset(name, num_frames, split='train', hml_mode='train', args=None):
|
103 |
+
DATA = get_dataset_class(name, args=args)
|
104 |
+
if name in ["humanml", "kit"]:
|
105 |
+
dataset = DATA(split=split, num_frames=num_frames, mode=hml_mode)
|
106 |
+
elif name in ["motion_ours"]:
|
107 |
+
# humanml_datawarper = HumanML3D(split=split, num_frames=num_frames, mode=hml_mode, load_vectorizer=True)
|
108 |
+
# w_vectorizer = humanml_datawarper.w_vectorizer
|
109 |
+
|
110 |
+
w_vectorizer = None
|
111 |
+
# split = "val" ## add split, split here --> split --> split and split ##
|
112 |
+
data_path = "/data1/sim/GRAB_processed"
|
113 |
+
# split, w_vectorizer, window_size=30, step_size=15, num_points=8000, args=None
|
114 |
+
window_size = args.window_size
|
115 |
+
# split= "val"
|
116 |
+
dataset = DATA(data_path, split=split, w_vectorizer=w_vectorizer, window_size=window_size, step_size=15, num_points=8000, args=args)
|
117 |
+
else:
|
118 |
+
dataset = DATA(split=split, num_frames=num_frames)
|
119 |
+
return dataset
|
120 |
+
|
121 |
+
|
122 |
+
def get_dataset_only(name, batch_size, num_frames, split='train', hml_mode='train', args=None):
|
123 |
+
dataset = get_dataset(name, num_frames, split, hml_mode, args=args)
|
124 |
+
return dataset
|
125 |
+
|
126 |
+
# python -m train.train_mdm --save_dir save/my_humanml_trans_enc_512 --dataset motion_ours
|
127 |
+
def get_dataset_loader(name, batch_size, num_frames, split='train', hml_mode='train', args=None):
|
128 |
+
dataset = get_dataset(name, num_frames, split, hml_mode, args=args)
|
129 |
+
collate = get_collate_fn(name, hml_mode, args=args)
|
130 |
+
|
131 |
+
if args is not None and name in ["motion_ours"] and len(args.single_seq_path) > 0:
|
132 |
+
shuffle_loader = False
|
133 |
+
drop_last = False
|
134 |
+
else:
|
135 |
+
shuffle_loader = True
|
136 |
+
drop_last = True
|
137 |
+
|
138 |
+
num_workers = 8 ## get data; get data loader ##
|
139 |
+
num_workers = 16 # num_workers # ## num_workders #
|
140 |
+
### ==== create dataloader here ==== ###
|
141 |
+
### ==== create dataloader here ==== ###
|
142 |
+
loader = DataLoader( # tag for each sequence
|
143 |
+
dataset, batch_size=batch_size, shuffle=shuffle_loader,
|
144 |
+
num_workers=num_workers, drop_last=drop_last, collate_fn=collate
|
145 |
+
)
|
146 |
+
|
147 |
+
return loader
|
148 |
+
|
149 |
+
|
150 |
+
# python -m train.train_mdm --save_dir save/my_humanml_trans_enc_512 --dataset motion_ours
|
151 |
+
def get_dataset_loader_dist(name, batch_size, num_frames, split='train', hml_mode='train', args=None):
|
152 |
+
dataset = get_dataset(name, num_frames, split, hml_mode, args=args)
|
153 |
+
collate = get_collate_fn(name, hml_mode, args=args)
|
154 |
+
|
155 |
+
if args is not None and name in ["motion_ours"] and len(args.single_seq_path) > 0:
|
156 |
+
# shuffle_loader = False
|
157 |
+
drop_last = False
|
158 |
+
else:
|
159 |
+
# shuffle_loader = True
|
160 |
+
drop_last = True
|
161 |
+
|
162 |
+
num_workers = 8 ## get data; get data loader ##
|
163 |
+
num_workers = 16 # num_workers # ## num_workders #
|
164 |
+
### ==== create dataloader here ==== ###
|
165 |
+
### ==== create dataloader here ==== ###
|
166 |
+
|
167 |
+
''' dist sampler and loader '''
|
168 |
+
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
|
169 |
+
loader = DataLoader(dataset, batch_size=batch_size,
|
170 |
+
sampler=sampler, num_workers=num_workers, drop_last=drop_last, collate_fn=collate)
|
171 |
+
|
172 |
+
|
173 |
+
# loader = DataLoader( # tag for each sequence
|
174 |
+
# dataset, batch_size=batch_size, shuffle=shuffle_loader,
|
175 |
+
# num_workers=num_workers, drop_last=drop_last, collate_fn=collate
|
176 |
+
# )
|
177 |
+
|
178 |
+
return loader
|
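Call-shape illustration only (not part of the commit): get_dataset_loader expects an argparse-style namespace with many more fields than shown, reads GRAB data from the hard-coded /data1/sim/GRAB_processed path, and the selected GRAB_Dataset_* class consumes additional args, so this sketch will not run without the full training setup.

from argparse import Namespace

from data_loaders.get_data import get_dataset_loader

# Only the fields inspected by get_dataset_class/get_collate_fn are spelled out here;
# the real entry point (train.train_mdm) builds the full namespace via argparse.
args = Namespace(
    single_seq_path="",
    rep_type="obj_base_rel_dist_we",
    use_arctic=False,
    use_vox_data=False,
    use_predicted_infos=False,
    use_interpolated_infos=False,
    window_size=30,
    # ... plus whatever fields GRAB_Dataset_V19 itself reads from args
)

loader = get_dataset_loader(
    "motion_ours", batch_size=32, num_frames=60, split="train", args=args
)
for batch in loader:  # batches come out of motion_ours_obj_base_rel_dist_collate
    break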
data_loaders/humanml/.DS_Store
ADDED
Binary file (6.15 kB).
data_loaders/humanml/README.md
ADDED
@@ -0,0 +1 @@
This code is based on https://github.com/EricGuo5513/text-to-motion.git
data_loaders/humanml/common/__pycache__/quaternion.cpython-38.pyc
ADDED
Binary file (11.6 kB).
data_loaders/humanml/common/__pycache__/skeleton.cpython-38.pyc
ADDED
Binary file (6.15 kB).
data_loaders/humanml/common/quaternion.py
ADDED
@@ -0,0 +1,423 @@
1 |
+
# Copyright (c) 2018-present, Facebook, Inc.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
#
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
_EPS4 = np.finfo(float).eps * 4.0
|
12 |
+
|
13 |
+
_FLOAT_EPS = np.finfo(np.float).eps
|
14 |
+
|
15 |
+
# PyTorch-backed implementations
|
16 |
+
def qinv(q):
|
17 |
+
assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
|
18 |
+
mask = torch.ones_like(q)
|
19 |
+
mask[..., 1:] = -mask[..., 1:]
|
20 |
+
return q * mask
|
21 |
+
|
22 |
+
|
23 |
+
def qinv_np(q):
|
24 |
+
assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
|
25 |
+
return qinv(torch.from_numpy(q).float()).numpy()
|
26 |
+
|
27 |
+
|
28 |
+
def qnormalize(q):
|
29 |
+
assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
|
30 |
+
return q / torch.norm(q, dim=-1, keepdim=True)
|
31 |
+
|
32 |
+
|
33 |
+
def qmul(q, r):
|
34 |
+
"""
|
35 |
+
Multiply quaternion(s) q with quaternion(s) r.
|
36 |
+
Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
|
37 |
+
Returns q*r as a tensor of shape (*, 4).
|
38 |
+
"""
|
39 |
+
assert q.shape[-1] == 4
|
40 |
+
assert r.shape[-1] == 4
|
41 |
+
|
42 |
+
original_shape = q.shape
|
43 |
+
|
44 |
+
# Compute outer product
|
45 |
+
terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))
|
46 |
+
|
47 |
+
w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
|
48 |
+
x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
|
49 |
+
y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
|
50 |
+
z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
|
51 |
+
return torch.stack((w, x, y, z), dim=1).view(original_shape)
|
52 |
+
|
53 |
+
|
54 |
+
def qrot(q, v):
|
55 |
+
"""
|
56 |
+
Rotate vector(s) v about the rotation described by quaternion(s) q.
|
57 |
+
Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
|
58 |
+
where * denotes any number of dimensions.
|
59 |
+
Returns a tensor of shape (*, 3).
|
60 |
+
"""
|
61 |
+
assert q.shape[-1] == 4
|
62 |
+
assert v.shape[-1] == 3
|
63 |
+
assert q.shape[:-1] == v.shape[:-1]
|
64 |
+
|
65 |
+
original_shape = list(v.shape)
|
66 |
+
# print(q.shape)
|
67 |
+
q = q.contiguous().view(-1, 4)
|
68 |
+
v = v.contiguous().view(-1, 3)
|
69 |
+
|
70 |
+
qvec = q[:, 1:]
|
71 |
+
uv = torch.cross(qvec, v, dim=1)
|
72 |
+
uuv = torch.cross(qvec, uv, dim=1)
|
73 |
+
return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
|
74 |
+
|
75 |
+
|
76 |
+
def qeuler(q, order, epsilon=0, deg=True):
|
77 |
+
"""
|
78 |
+
Convert quaternion(s) q to Euler angles.
|
79 |
+
Expects a tensor of shape (*, 4), where * denotes any number of dimensions.
|
80 |
+
Returns a tensor of shape (*, 3).
|
81 |
+
"""
|
82 |
+
assert q.shape[-1] == 4
|
83 |
+
|
84 |
+
original_shape = list(q.shape)
|
85 |
+
original_shape[-1] = 3
|
86 |
+
q = q.view(-1, 4)
|
87 |
+
|
88 |
+
q0 = q[:, 0]
|
89 |
+
q1 = q[:, 1]
|
90 |
+
q2 = q[:, 2]
|
91 |
+
q3 = q[:, 3]
|
92 |
+
|
93 |
+
if order == 'xyz':
|
94 |
+
x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
|
95 |
+
y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))
|
96 |
+
z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
|
97 |
+
elif order == 'yzx':
|
98 |
+
x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
|
99 |
+
y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
|
100 |
+
z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))
|
101 |
+
elif order == 'zxy':
|
102 |
+
x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))
|
103 |
+
y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
|
104 |
+
z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))
|
105 |
+
elif order == 'xzy':
|
106 |
+
x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
|
107 |
+
y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
|
108 |
+
z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))
|
109 |
+
elif order == 'yxz':
|
110 |
+
x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))
|
111 |
+
y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))
|
112 |
+
z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
|
113 |
+
elif order == 'zyx':
|
114 |
+
x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
|
115 |
+
y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))
|
116 |
+
z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
|
117 |
+
else:
|
118 |
+
raise
|
119 |
+
|
120 |
+
if deg:
|
121 |
+
return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi
|
122 |
+
else:
|
123 |
+
return torch.stack((x, y, z), dim=1).view(original_shape)
|
124 |
+
|
125 |
+
|
126 |
+
# Numpy-backed implementations
|
127 |
+
|
128 |
+
def qmul_np(q, r):
|
129 |
+
q = torch.from_numpy(q).contiguous().float()
|
130 |
+
r = torch.from_numpy(r).contiguous().float()
|
131 |
+
return qmul(q, r).numpy()
|
132 |
+
|
133 |
+
|
134 |
+
def qrot_np(q, v):
|
135 |
+
q = torch.from_numpy(q).contiguous().float()
|
136 |
+
v = torch.from_numpy(v).contiguous().float()
|
137 |
+
return qrot(q, v).numpy()
|
138 |
+
|
139 |
+
|
140 |
+
def qeuler_np(q, order, epsilon=0, use_gpu=False):
|
141 |
+
if use_gpu:
|
142 |
+
q = torch.from_numpy(q).cuda().float()
|
143 |
+
return qeuler(q, order, epsilon).cpu().numpy()
|
144 |
+
else:
|
145 |
+
q = torch.from_numpy(q).contiguous().float()
|
146 |
+
return qeuler(q, order, epsilon).numpy()
|
147 |
+
|
148 |
+
|
149 |
+
def qfix(q):
|
150 |
+
"""
|
151 |
+
Enforce quaternion continuity across the time dimension by selecting
|
152 |
+
the representation (q or -q) with minimal distance (or, equivalently, maximal dot product)
|
153 |
+
between two consecutive frames.
|
154 |
+
|
155 |
+
Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints.
|
156 |
+
Returns a tensor of the same shape.
|
157 |
+
"""
|
158 |
+
assert len(q.shape) == 3
|
159 |
+
assert q.shape[-1] == 4
|
160 |
+
|
161 |
+
result = q.copy()
|
162 |
+
dot_products = np.sum(q[1:] * q[:-1], axis=2)
|
163 |
+
mask = dot_products < 0
|
164 |
+
mask = (np.cumsum(mask, axis=0) % 2).astype(bool)
|
165 |
+
result[1:][mask] *= -1
|
166 |
+
return result
|
167 |
+
|
168 |
+
|
169 |
+
def euler2quat(e, order, deg=True):
|
170 |
+
"""
|
171 |
+
Convert Euler angles to quaternions.
|
172 |
+
"""
|
173 |
+
assert e.shape[-1] == 3
|
174 |
+
|
175 |
+
original_shape = list(e.shape)
|
176 |
+
original_shape[-1] = 4
|
177 |
+
|
178 |
+
e = e.view(-1, 3)
|
179 |
+
|
180 |
+
## if euler angles in degrees
|
181 |
+
if deg:
|
182 |
+
e = e * np.pi / 180.
|
183 |
+
|
184 |
+
x = e[:, 0]
|
185 |
+
y = e[:, 1]
|
186 |
+
z = e[:, 2]
|
187 |
+
|
188 |
+
rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1)
|
189 |
+
ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1)
|
190 |
+
rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1)
|
191 |
+
|
192 |
+
result = None
|
193 |
+
for coord in order:
|
194 |
+
if coord == 'x':
|
195 |
+
r = rx
|
196 |
+
elif coord == 'y':
|
197 |
+
r = ry
|
198 |
+
elif coord == 'z':
|
199 |
+
r = rz
|
200 |
+
else:
|
201 |
+
raise
|
202 |
+
if result is None:
|
203 |
+
result = r
|
204 |
+
else:
|
205 |
+
result = qmul(result, r)
|
206 |
+
|
207 |
+
# Reverse antipodal representation to have a non-negative "w"
|
208 |
+
if order in ['xyz', 'yzx', 'zxy']:
|
209 |
+
result *= -1
|
210 |
+
|
211 |
+
return result.view(original_shape)
|
212 |
+
|
213 |
+
|
214 |
+
def expmap_to_quaternion(e):
|
215 |
+
"""
|
216 |
+
Convert axis-angle rotations (aka exponential maps) to quaternions.
|
217 |
+
Stable formula from "Practical Parameterization of Rotations Using the Exponential Map".
|
218 |
+
Expects a tensor of shape (*, 3), where * denotes any number of dimensions.
|
219 |
+
Returns a tensor of shape (*, 4).
|
220 |
+
"""
|
221 |
+
assert e.shape[-1] == 3
|
222 |
+
|
223 |
+
original_shape = list(e.shape)
|
224 |
+
original_shape[-1] = 4
|
225 |
+
e = e.reshape(-1, 3)
|
226 |
+
|
227 |
+
theta = np.linalg.norm(e, axis=1).reshape(-1, 1)
|
228 |
+
w = np.cos(0.5 * theta).reshape(-1, 1)
|
229 |
+
xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e
|
230 |
+
return np.concatenate((w, xyz), axis=1).reshape(original_shape)
|
231 |
+
|
232 |
+
|
233 |
+
def euler_to_quaternion(e, order):
|
234 |
+
"""
|
235 |
+
Convert Euler angles to quaternions.
|
236 |
+
"""
|
237 |
+
assert e.shape[-1] == 3
|
238 |
+
|
239 |
+
original_shape = list(e.shape)
|
240 |
+
original_shape[-1] = 4
|
241 |
+
|
242 |
+
e = e.reshape(-1, 3)
|
243 |
+
|
244 |
+
x = e[:, 0]
|
245 |
+
y = e[:, 1]
|
246 |
+
z = e[:, 2]
|
247 |
+
|
248 |
+
rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1)
|
249 |
+
ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1)
|
250 |
+
rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1)
|
251 |
+
|
252 |
+
result = None
|
253 |
+
for coord in order:
|
254 |
+
if coord == 'x':
|
255 |
+
r = rx
|
256 |
+
elif coord == 'y':
|
257 |
+
r = ry
|
258 |
+
elif coord == 'z':
|
259 |
+
r = rz
|
260 |
+
else:
|
261 |
+
raise
|
262 |
+
if result is None:
|
263 |
+
result = r
|
264 |
+
else:
|
265 |
+
result = qmul_np(result, r)
|
266 |
+
|
267 |
+
# Reverse antipodal representation to have a non-negative "w"
|
268 |
+
if order in ['xyz', 'yzx', 'zxy']:
|
269 |
+
result *= -1
|
270 |
+
|
271 |
+
return result.reshape(original_shape)
|
272 |
+
|
273 |
+
|
274 |
+
def quaternion_to_matrix(quaternions):
|
275 |
+
"""
|
276 |
+
Convert rotations given as quaternions to rotation matrices.
|
277 |
+
Args:
|
278 |
+
quaternions: quaternions with real part first,
|
279 |
+
as tensor of shape (..., 4).
|
280 |
+
Returns:
|
281 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
282 |
+
"""
|
283 |
+
r, i, j, k = torch.unbind(quaternions, -1)
|
284 |
+
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
285 |
+
|
286 |
+
o = torch.stack(
|
287 |
+
(
|
288 |
+
1 - two_s * (j * j + k * k),
|
289 |
+
two_s * (i * j - k * r),
|
290 |
+
two_s * (i * k + j * r),
|
291 |
+
two_s * (i * j + k * r),
|
292 |
+
1 - two_s * (i * i + k * k),
|
293 |
+
two_s * (j * k - i * r),
|
294 |
+
two_s * (i * k - j * r),
|
295 |
+
two_s * (j * k + i * r),
|
296 |
+
1 - two_s * (i * i + j * j),
|
297 |
+
),
|
298 |
+
-1,
|
299 |
+
)
|
300 |
+
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
301 |
+
|
302 |
+
|
303 |
+
def quaternion_to_matrix_np(quaternions):
|
304 |
+
q = torch.from_numpy(quaternions).contiguous().float()
|
305 |
+
return quaternion_to_matrix(q).numpy()
|
306 |
+
|
307 |
+
|
308 |
+
def quaternion_to_cont6d_np(quaternions):
|
309 |
+
rotation_mat = quaternion_to_matrix_np(quaternions)
|
310 |
+
cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1)
|
311 |
+
return cont_6d
|
312 |
+
|
313 |
+
|
314 |
+
def quaternion_to_cont6d(quaternions):
|
315 |
+
rotation_mat = quaternion_to_matrix(quaternions)
|
316 |
+
cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1)
|
317 |
+
return cont_6d
|
318 |
+
|
319 |
+
|
320 |
+
def cont6d_to_matrix(cont6d):
|
321 |
+
assert cont6d.shape[-1] == 6, "The last dimension must be 6"
|
322 |
+
x_raw = cont6d[..., 0:3]
|
323 |
+
y_raw = cont6d[..., 3:6]
|
324 |
+
|
325 |
+
x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True)
|
326 |
+
z = torch.cross(x, y_raw, dim=-1)
|
327 |
+
z = z / torch.norm(z, dim=-1, keepdim=True)
|
328 |
+
|
329 |
+
y = torch.cross(z, x, dim=-1)
|
330 |
+
|
331 |
+
x = x[..., None]
|
332 |
+
y = y[..., None]
|
333 |
+
z = z[..., None]
|
334 |
+
|
335 |
+
mat = torch.cat([x, y, z], dim=-1)
|
336 |
+
return mat
|
337 |
+
|
338 |
+
|
339 |
+
def cont6d_to_matrix_np(cont6d):
|
340 |
+
q = torch.from_numpy(cont6d).contiguous().float()
|
341 |
+
return cont6d_to_matrix(q).numpy()
|
342 |
+
|
343 |
+
|
344 |
+
def qpow(q0, t, dtype=torch.float):
|
345 |
+
''' q0 : tensor of quaternions
|
346 |
+
t: tensor of powers
|
347 |
+
'''
|
348 |
+
q0 = qnormalize(q0)
|
349 |
+
theta0 = torch.acos(q0[..., 0])
|
350 |
+
|
351 |
+
## if theta0 is close to zero, add epsilon to avoid NaNs
|
352 |
+
mask = (theta0 <= 10e-10) * (theta0 >= -10e-10)
|
353 |
+
theta0 = (1 - mask) * theta0 + mask * 10e-10
|
354 |
+
v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1)
|
355 |
+
|
356 |
+
if isinstance(t, torch.Tensor):
|
357 |
+
q = torch.zeros(t.shape + q0.shape)
|
358 |
+
theta = t.view(-1, 1) * theta0.view(1, -1)
|
359 |
+
else: ## if t is a number
|
360 |
+
q = torch.zeros(q0.shape)
|
361 |
+
theta = t * theta0
|
362 |
+
|
363 |
+
q[..., 0] = torch.cos(theta)
|
364 |
+
q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1)
|
365 |
+
|
366 |
+
return q.to(dtype)
|
367 |
+
|
368 |
+
|
369 |
+
def qslerp(q0, q1, t):
|
370 |
+
'''
|
371 |
+
q0: starting quaternion
|
372 |
+
q1: ending quaternion
|
373 |
+
t: array of points along the way
|
374 |
+
|
375 |
+
Returns:
|
376 |
+
Tensor of Slerps: t.shape + q0.shape
|
377 |
+
'''
|
378 |
+
|
379 |
+
q0 = qnormalize(q0)
|
380 |
+
q1 = qnormalize(q1)
|
381 |
+
q_ = qpow(qmul(q1, qinv(q0)), t)
|
382 |
+
|
383 |
+
return qmul(q_,
|
384 |
+
q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous())
|
385 |
+
|
386 |
+
|
387 |
+
def qbetween(v0, v1):
|
388 |
+
'''
|
389 |
+
find the quaternion used to rotate v0 to v1
|
390 |
+
'''
|
391 |
+
assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
|
392 |
+
assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
|
393 |
+
|
394 |
+
v = torch.cross(v0, v1)
|
395 |
+
w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1,
|
396 |
+
keepdim=True)
|
397 |
+
return qnormalize(torch.cat([w, v], dim=-1))
|
398 |
+
|
399 |
+
|
400 |
+
def qbetween_np(v0, v1):
|
401 |
+
'''
|
402 |
+
find the quaternion used to rotate v0 to v1
|
403 |
+
'''
|
404 |
+
assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
|
405 |
+
assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
|
406 |
+
|
407 |
+
v0 = torch.from_numpy(v0).float()
|
408 |
+
v1 = torch.from_numpy(v1).float()
|
409 |
+
return qbetween(v0, v1).numpy()
|
410 |
+
|
411 |
+
|
412 |
+
def lerp(p0, p1, t):
|
413 |
+
if not isinstance(t, torch.Tensor):
|
414 |
+
t = torch.Tensor([t])
|
415 |
+
|
416 |
+
new_shape = t.shape + p0.shape
|
417 |
+
new_view_t = t.shape + torch.Size([1] * len(p0.shape))
|
418 |
+
new_view_p = torch.Size([1] * len(t.shape)) + p0.shape
|
419 |
+
p0 = p0.view(new_view_p).expand(new_shape)
|
420 |
+
p1 = p1.view(new_view_p).expand(new_shape)
|
421 |
+
t = t.view(new_view_t).expand(new_shape)
|
422 |
+
|
423 |
+
return p0 + t * (p1 - p0)
|
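A few self-contained sanity checks for the quaternion helpers above (illustrative only, not part of the commit; quaternions are stored with the real part first):

import torch

from data_loaders.humanml.common.quaternion import qinv, qmul, qnormalize, qrot, qslerp

q = qnormalize(torch.rand(8, 4))  # batch of unit quaternions, real part first
v = torch.rand(8, 3)

# Rotating by q and then by its inverse recovers the original vectors.
v_back = qrot(qinv(q), qrot(q, v))
assert torch.allclose(v, v_back, atol=1e-5)

# q * q^{-1} is the identity quaternion (1, 0, 0, 0) for unit quaternions.
ident = qmul(q, qinv(q))
assert torch.allclose(ident[:, 0], torch.ones(8), atol=1e-5)

# Slerp between two quaternions at t = 0, 0.5, 1; output shape is t.shape + q0.shape.
q0 = qnormalize(torch.rand(1, 4))
q1 = qnormalize(torch.rand(1, 4))
path = qslerp(q0, q1, torch.tensor([0.0, 0.5, 1.0]))  # (3, 1, 4)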
data_loaders/humanml/common/skeleton.py
ADDED
@@ -0,0 +1,199 @@
1 |
+
from data_loaders.humanml.common.quaternion import *
|
2 |
+
import scipy.ndimage.filters as filters
|
3 |
+
|
4 |
+
class Skeleton(object):
|
5 |
+
def __init__(self, offset, kinematic_tree, device):
|
6 |
+
self.device = device
|
7 |
+
self._raw_offset_np = offset.numpy()
|
8 |
+
self._raw_offset = offset.clone().detach().to(device).float()
|
9 |
+
self._kinematic_tree = kinematic_tree
|
10 |
+
self._offset = None
|
11 |
+
self._parents = [0] * len(self._raw_offset)
|
12 |
+
self._parents[0] = -1
|
13 |
+
for chain in self._kinematic_tree:
|
14 |
+
for j in range(1, len(chain)):
|
15 |
+
self._parents[chain[j]] = chain[j-1]
|
16 |
+
|
17 |
+
def njoints(self):
|
18 |
+
return len(self._raw_offset)
|
19 |
+
|
20 |
+
def offset(self):
|
21 |
+
return self._offset
|
22 |
+
|
23 |
+
def set_offset(self, offsets):
|
24 |
+
self._offset = offsets.clone().detach().to(self.device).float()
|
25 |
+
|
26 |
+
def kinematic_tree(self):
|
27 |
+
return self._kinematic_tree
|
28 |
+
|
29 |
+
def parents(self):
|
30 |
+
return self._parents
|
31 |
+
|
32 |
+
# joints (batch_size, joints_num, 3)
|
33 |
+
def get_offsets_joints_batch(self, joints):
|
34 |
+
assert len(joints.shape) == 3
|
35 |
+
_offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone()
|
36 |
+
for i in range(1, self._raw_offset.shape[0]):
|
37 |
+
_offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i]
|
38 |
+
|
39 |
+
self._offset = _offsets.detach()
|
40 |
+
return _offsets
|
41 |
+
|
42 |
+
# joints (joints_num, 3)
|
43 |
+
def get_offsets_joints(self, joints):
|
44 |
+
assert len(joints.shape) == 2
|
45 |
+
_offsets = self._raw_offset.clone()
|
46 |
+
for i in range(1, self._raw_offset.shape[0]):
|
47 |
+
# print(joints.shape)
|
48 |
+
_offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i]
|
49 |
+
|
50 |
+
self._offset = _offsets.detach()
|
51 |
+
return _offsets
|
52 |
+
|
53 |
+
# face_joint_idx should follow the order of right hip, left hip, right shoulder, left shoulder
|
54 |
+
# joints (batch_size, joints_num, 3)
|
55 |
+
def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False):
|
56 |
+
assert len(face_joint_idx) == 4
|
57 |
+
'''Get Forward Direction'''
|
58 |
+
l_hip, r_hip, sdr_r, sdr_l = face_joint_idx
|
59 |
+
across1 = joints[:, r_hip] - joints[:, l_hip]
|
60 |
+
across2 = joints[:, sdr_r] - joints[:, sdr_l]
|
61 |
+
across = across1 + across2
|
62 |
+
across = across / np.sqrt((across**2).sum(axis=-1))[:, np.newaxis]
|
63 |
+
# print(across1.shape, across2.shape)
|
64 |
+
|
65 |
+
# forward (batch_size, 3)
|
66 |
+
forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1)
|
67 |
+
if smooth_forward:
|
68 |
+
forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest')
|
69 |
+
# forward (batch_size, 3)
|
70 |
+
forward = forward / np.sqrt((forward**2).sum(axis=-1))[..., np.newaxis]
|
71 |
+
|
72 |
+
'''Get Root Rotation'''
|
73 |
+
target = np.array([[0,0,1]]).repeat(len(forward), axis=0)
|
74 |
+
root_quat = qbetween_np(forward, target)
|
75 |
+
|
76 |
+
'''Inverse Kinematics'''
|
77 |
+
# quat_params (batch_size, joints_num, 4)
|
78 |
+
# print(joints.shape[:-1])
|
79 |
+
quat_params = np.zeros(joints.shape[:-1] + (4,))
|
80 |
+
# print(quat_params.shape)
|
81 |
+
root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]])
|
82 |
+
quat_params[:, 0] = root_quat
|
83 |
+
# quat_params[0, 0] = np.array([[1.0, 0.0, 0.0, 0.0]])
|
84 |
+
for chain in self._kinematic_tree:
|
85 |
+
R = root_quat
|
86 |
+
for j in range(len(chain) - 1):
|
87 |
+
# (batch, 3)
|
88 |
+
u = self._raw_offset_np[chain[j+1]][np.newaxis,...].repeat(len(joints), axis=0)
|
89 |
+
# print(u.shape)
|
90 |
+
# (batch, 3)
|
91 |
+
v = joints[:, chain[j+1]] - joints[:, chain[j]]
|
92 |
+
v = v / np.sqrt((v**2).sum(axis=-1))[:, np.newaxis]
|
93 |
+
# print(u.shape, v.shape)
|
94 |
+
rot_u_v = qbetween_np(u, v)
|
95 |
+
|
96 |
+
R_loc = qmul_np(qinv_np(R), rot_u_v)
|
97 |
+
|
98 |
+
quat_params[:,chain[j + 1], :] = R_loc
|
99 |
+
R = qmul_np(R, R_loc)
|
100 |
+
|
101 |
+
return quat_params
|
102 |
+
|
103 |
+
# Be sure root joint is at the beginning of kinematic chains
|
104 |
+
def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
|
105 |
+
# quat_params (batch_size, joints_num, 4)
|
106 |
+
# joints (batch_size, joints_num, 3)
|
107 |
+
# root_pos (batch_size, 3)
|
108 |
+
if skel_joints is not None:
|
109 |
+
offsets = self.get_offsets_joints_batch(skel_joints)
|
110 |
+
if len(self._offset.shape) == 2:
|
111 |
+
offsets = self._offset.expand(quat_params.shape[0], -1, -1)
|
112 |
+
joints = torch.zeros(quat_params.shape[:-1] + (3,)).to(self.device)
|
113 |
+
joints[:, 0] = root_pos
|
114 |
+
for chain in self._kinematic_tree:
|
115 |
+
if do_root_R:
|
116 |
+
R = quat_params[:, 0]
|
117 |
+
else:
|
118 |
+
R = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(len(quat_params), -1).detach().to(self.device)
|
119 |
+
for i in range(1, len(chain)):
|
120 |
+
R = qmul(R, quat_params[:, chain[i]])
|
121 |
+
offset_vec = offsets[:, chain[i]]
|
122 |
+
joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i-1]]
|
123 |
+
return joints
|
124 |
+
|
125 |
+
# Be sure root joint is at the beginning of kinematic chains
|
126 |
+
def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
|
127 |
+
# quat_params (batch_size, joints_num, 4)
|
128 |
+
# joints (batch_size, joints_num, 3)
|
129 |
+
# root_pos (batch_size, 3)
|
130 |
+
if skel_joints is not None:
|
131 |
+
skel_joints = torch.from_numpy(skel_joints)
|
132 |
+
offsets = self.get_offsets_joints_batch(skel_joints)
|
133 |
+
if len(self._offset.shape) == 2:
|
134 |
+
offsets = self._offset.expand(quat_params.shape[0], -1, -1)
|
135 |
+
offsets = offsets.numpy()
|
136 |
+
joints = np.zeros(quat_params.shape[:-1] + (3,))
|
137 |
+
joints[:, 0] = root_pos
|
138 |
+
for chain in self._kinematic_tree:
|
139 |
+
if do_root_R:
|
140 |
+
R = quat_params[:, 0]
|
141 |
+
else:
|
142 |
+
R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0)
|
143 |
+
for i in range(1, len(chain)):
|
144 |
+
R = qmul_np(R, quat_params[:, chain[i]])
|
145 |
+
offset_vec = offsets[:, chain[i]]
|
146 |
+
joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]]
|
147 |
+
return joints
|
148 |
+
|
149 |
+
def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
|
150 |
+
# cont6d_params (batch_size, joints_num, 6)
|
151 |
+
# joints (batch_size, joints_num, 3)
|
152 |
+
# root_pos (batch_size, 3)
|
153 |
+
if skel_joints is not None:
|
154 |
+
skel_joints = torch.from_numpy(skel_joints)
|
155 |
+
offsets = self.get_offsets_joints_batch(skel_joints)
|
156 |
+
if len(self._offset.shape) == 2:
|
157 |
+
offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
|
158 |
+
offsets = offsets.numpy()
|
159 |
+
joints = np.zeros(cont6d_params.shape[:-1] + (3,))
|
160 |
+
joints[:, 0] = root_pos
|
161 |
+
for chain in self._kinematic_tree:
|
162 |
+
if do_root_R:
|
163 |
+
matR = cont6d_to_matrix_np(cont6d_params[:, 0])
|
164 |
+
else:
|
165 |
+
matR = np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0)
|
166 |
+
for i in range(1, len(chain)):
|
167 |
+
matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]]))
|
168 |
+
offset_vec = offsets[:, chain[i]][..., np.newaxis]
|
169 |
+
# print(matR.shape, offset_vec.shape)
|
170 |
+
joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]]
|
171 |
+
return joints
|
172 |
+
|
173 |
+
def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
|
174 |
+
# cont6d_params (batch_size, joints_num, 6)
|
175 |
+
# joints (batch_size, joints_num, 3)
|
176 |
+
# root_pos (batch_size, 3)
|
177 |
+
if skel_joints is not None:
|
178 |
+
# skel_joints = torch.from_numpy(skel_joints)
|
179 |
+
offsets = self.get_offsets_joints_batch(skel_joints)
|
180 |
+
if len(self._offset.shape) == 2:
|
181 |
+
offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
|
182 |
+
joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device)
|
183 |
+
joints[..., 0, :] = root_pos
|
184 |
+
for chain in self._kinematic_tree:
|
185 |
+
if do_root_R:
|
186 |
+
matR = cont6d_to_matrix(cont6d_params[:, 0])
|
187 |
+
else:
|
188 |
+
matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach().to(cont6d_params.device)
|
189 |
+
for i in range(1, len(chain)):
|
190 |
+
matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, chain[i]]))
|
191 |
+
offset_vec = offsets[:, chain[i]].unsqueeze(-1)
|
192 |
+
# print(matR.shape, offset_vec.shape)
|
193 |
+
joints[:, chain[i]] = torch.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]]
|
194 |
+
return joints
|
195 |
+
|
196 |
+
|
197 |
+
|
198 |
+
|
199 |
+
|
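A minimal usage sketch for the Skeleton class added above. It uses a toy three-joint chain so the snippet is self-contained; a real run would pass the HumanML3D raw offsets and kinematic chains instead (those constants are assumed to live elsewhere in the repository and are not part of this file).

import numpy as np
import torch
from data_loaders.humanml.common.skeleton import Skeleton

# Toy chain: root -> joint 1 -> joint 2, with unit raw offsets along +y.
raw_offsets = torch.tensor([[0.0, 0.0, 0.0],
                            [0.0, 1.0, 0.0],
                            [0.0, 1.0, 0.0]])
kinematic_chain = [[0, 1, 2]]
skel = Skeleton(raw_offsets, kinematic_chain, device=torch.device('cpu'))

# Scale the raw offsets by the bone lengths of a reference pose (sets skel._offset).
ref_joints = torch.tensor([[0.0, 0.0, 0.0],
                           [0.0, 0.5, 0.0],
                           [0.0, 1.0, 0.0]])
skel.get_offsets_joints(ref_joints)

# Forward kinematics with identity rotations reproduces the reference pose.
quat_params = np.tile(np.array([1.0, 0.0, 0.0, 0.0]), (1, 3, 1))  # (batch, joints, 4)
root_pos = np.zeros((1, 3))
rec_joints = skel.forward_kinematics_np(quat_params, root_pos)
print(rec_joints)  # rows approximately [0,0,0], [0,0.5,0], [0,1,0]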
data_loaders/humanml/data/__init__.py
ADDED
File without changes
|
data_loaders/humanml/data/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (172 Bytes). View file
|
|
data_loaders/humanml/data/__pycache__/dataset.cpython-38.pyc
ADDED
Binary file (19.1 kB). View file
|
|
data_loaders/humanml/data/__pycache__/dataset_ours.cpython-38.pyc
ADDED
Binary file (73.1 kB). View file
|
|
data_loaders/humanml/data/__pycache__/dataset_ours_single_seq.cpython-38.pyc
ADDED
Binary file (87.9 kB). View file
|
|
data_loaders/humanml/data/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (15.3 kB). View file
|
|
data_loaders/humanml/data/dataset.py
ADDED
@@ -0,0 +1,795 @@
1 |
+
import torch
|
2 |
+
from torch.utils import data
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
from os.path import join as pjoin
|
6 |
+
import random
|
7 |
+
import codecs as cs
|
8 |
+
from tqdm import tqdm
|
9 |
+
import spacy
|
10 |
+
|
11 |
+
from torch.utils.data._utils.collate import default_collate
|
12 |
+
from data_loaders.humanml.utils.word_vectorizer import WordVectorizer
|
13 |
+
from data_loaders.humanml.utils.get_opt import get_opt
|
14 |
+
|
15 |
+
# import spacy
|
16 |
+
|
17 |
+
def collate_fn(batch):
|
18 |
+
batch.sort(key=lambda x: x[3], reverse=True)
|
19 |
+
return default_collate(batch)
|
20 |
+
|
21 |
+
|
22 |
+
'''For use of training text-2-motion generative model'''
|
23 |
+
class Text2MotionDataset(data.Dataset):
|
24 |
+
def __init__(self, opt, mean, std, split_file, w_vectorizer):
|
25 |
+
self.opt = opt
|
26 |
+
self.w_vectorizer = w_vectorizer
|
27 |
+
self.max_length = 20
|
28 |
+
self.pointer = 0
|
29 |
+
min_motion_len = 40 if self.opt.dataset_name =='t2m' else 24
|
30 |
+
|
31 |
+
joints_num = opt.joints_num
|
32 |
+
|
33 |
+
data_dict = {}
|
34 |
+
id_list = []
|
35 |
+
with cs.open(split_file, 'r') as f:
|
36 |
+
for line in f.readlines():
|
37 |
+
id_list.append(line.strip())
|
38 |
+
|
39 |
+
new_name_list = []
|
40 |
+
length_list = []
|
41 |
+
for name in tqdm(id_list):
|
42 |
+
try:
|
43 |
+
motion = np.load(pjoin(opt.motion_dir, name + '.npy'))
|
44 |
+
if (len(motion)) < min_motion_len or (len(motion) >= 200):
|
45 |
+
continue
|
46 |
+
text_data = []
|
47 |
+
flag = False
|
48 |
+
with cs.open(pjoin(opt.text_dir, name + '.txt')) as f:
|
49 |
+
for line in f.readlines():
|
50 |
+
text_dict = {}
|
51 |
+
line_split = line.strip().split('#')
|
52 |
+
caption = line_split[0]
|
53 |
+
tokens = line_split[1].split(' ')
|
54 |
+
f_tag = float(line_split[2])
|
55 |
+
to_tag = float(line_split[3])
|
56 |
+
f_tag = 0.0 if np.isnan(f_tag) else f_tag
|
57 |
+
to_tag = 0.0 if np.isnan(to_tag) else to_tag
|
58 |
+
|
59 |
+
text_dict['caption'] = caption
|
60 |
+
text_dict['tokens'] = tokens
|
61 |
+
if f_tag == 0.0 and to_tag == 0.0:
|
62 |
+
flag = True
|
63 |
+
text_data.append(text_dict)
|
64 |
+
else:
|
65 |
+
try:
|
66 |
+
n_motion = motion[int(f_tag*20) : int(to_tag*20)]
|
67 |
+
if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200):
|
68 |
+
continue
|
69 |
+
new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
|
70 |
+
while new_name in data_dict:
|
71 |
+
new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
|
72 |
+
data_dict[new_name] = {'motion': n_motion,
|
73 |
+
'length': len(n_motion),
|
74 |
+
'text':[text_dict]}
|
75 |
+
new_name_list.append(new_name)
|
76 |
+
length_list.append(len(n_motion))
|
77 |
+
except:
|
78 |
+
print(line_split)
|
79 |
+
print(line_split[2], line_split[3], f_tag, to_tag, name)
|
80 |
+
# break
|
81 |
+
|
82 |
+
if flag:
|
83 |
+
data_dict[name] = {'motion': motion,
|
84 |
+
'length': len(motion),
|
85 |
+
'text':text_data}
|
86 |
+
new_name_list.append(name)
|
87 |
+
length_list.append(len(motion))
|
88 |
+
except:
|
89 |
+
# Some motion may not exist in KIT dataset
|
90 |
+
pass
|
91 |
+
|
92 |
+
|
93 |
+
name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))
|
94 |
+
|
95 |
+
if opt.is_train:
|
96 |
+
# root_rot_velocity (B, seq_len, 1)
|
97 |
+
std[0:1] = std[0:1] / opt.feat_bias
|
98 |
+
# root_linear_velocity (B, seq_len, 2)
|
99 |
+
std[1:3] = std[1:3] / opt.feat_bias
|
100 |
+
# root_y (B, seq_len, 1)
|
101 |
+
std[3:4] = std[3:4] / opt.feat_bias
|
102 |
+
# ric_data (B, seq_len, (joint_num - 1)*3)
|
103 |
+
std[4: 4 + (joints_num - 1) * 3] = std[4: 4 + (joints_num - 1) * 3] / 1.0
|
104 |
+
# rot_data (B, seq_len, (joint_num - 1)*6)
|
105 |
+
std[4 + (joints_num - 1) * 3: 4 + (joints_num - 1) * 9] = std[4 + (joints_num - 1) * 3: 4 + (
|
106 |
+
joints_num - 1) * 9] / 1.0
|
107 |
+
# local_velocity (B, seq_len, joint_num*3)
|
108 |
+
std[4 + (joints_num - 1) * 9: 4 + (joints_num - 1) * 9 + joints_num * 3] = std[
|
109 |
+
4 + (joints_num - 1) * 9: 4 + (
|
110 |
+
joints_num - 1) * 9 + joints_num * 3] / 1.0
|
111 |
+
# foot contact (B, seq_len, 4)
|
112 |
+
std[4 + (joints_num - 1) * 9 + joints_num * 3:] = std[
|
113 |
+
4 + (joints_num - 1) * 9 + joints_num * 3:] / opt.feat_bias
|
114 |
+
|
115 |
+
assert 4 + (joints_num - 1) * 9 + joints_num * 3 + 4 == mean.shape[-1]
|
116 |
+
np.save(pjoin(opt.meta_dir, 'mean.npy'), mean)
|
117 |
+
np.save(pjoin(opt.meta_dir, 'std.npy'), std)
|
118 |
+
|
119 |
+
self.mean = mean
|
120 |
+
self.std = std
|
121 |
+
self.length_arr = np.array(length_list)
|
122 |
+
self.data_dict = data_dict
|
123 |
+
self.name_list = name_list
|
124 |
+
self.reset_max_len(self.max_length)
|
125 |
+
|
126 |
+
def reset_max_len(self, length):
|
127 |
+
assert length <= self.opt.max_motion_length
|
128 |
+
self.pointer = np.searchsorted(self.length_arr, length)
|
129 |
+
print("Pointer Pointing at %d"%self.pointer)
|
130 |
+
self.max_length = length
|
131 |
+
|
132 |
+
def inv_transform(self, data):
|
133 |
+
return data * self.std + self.mean
|
134 |
+
|
135 |
+
def __len__(self):
|
136 |
+
return len(self.data_dict) - self.pointer
|
137 |
+
|
138 |
+
def __getitem__(self, item):
|
139 |
+
idx = self.pointer + item
|
140 |
+
data = self.data_dict[self.name_list[idx]]
|
141 |
+
motion, m_length, text_list = data['motion'], data['length'], data['text']
|
142 |
+
# Randomly select a caption
|
143 |
+
text_data = random.choice(text_list)
|
144 |
+
caption, tokens = text_data['caption'], text_data['tokens']
|
145 |
+
|
146 |
+
if len(tokens) < self.opt.max_text_len:
|
147 |
+
# pad with "unk"
|
148 |
+
tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
|
149 |
+
sent_len = len(tokens)
|
150 |
+
tokens = tokens + ['unk/OTHER'] * (self.opt.max_text_len + 2 - sent_len)
|
151 |
+
else:
|
152 |
+
# crop
|
153 |
+
tokens = tokens[:self.opt.max_text_len]
|
154 |
+
tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
|
155 |
+
sent_len = len(tokens)
|
156 |
+
pos_one_hots = []
|
157 |
+
word_embeddings = []
|
158 |
+
for token in tokens:
|
159 |
+
word_emb, pos_oh = self.w_vectorizer[token]
|
160 |
+
pos_one_hots.append(pos_oh[None, :])
|
161 |
+
word_embeddings.append(word_emb[None, :])
|
162 |
+
pos_one_hots = np.concatenate(pos_one_hots, axis=0)
|
163 |
+
word_embeddings = np.concatenate(word_embeddings, axis=0)
|
164 |
+
|
165 |
+
len_gap = (m_length - self.max_length) // self.opt.unit_length
|
166 |
+
|
167 |
+
if self.opt.is_train:
|
168 |
+
if m_length != self.max_length:
|
169 |
+
# print("Motion original length:%d_%d"%(m_length, len(motion)))
|
170 |
+
if self.opt.unit_length < 10:
|
171 |
+
coin2 = np.random.choice(['single', 'single', 'double'])
|
172 |
+
else:
|
173 |
+
coin2 = 'single'
|
174 |
+
if len_gap == 0 or (len_gap == 1 and coin2 == 'double'):
|
175 |
+
m_length = self.max_length
|
176 |
+
idx = random.randint(0, m_length - self.max_length)
|
177 |
+
motion = motion[idx:idx+self.max_length]
|
178 |
+
else:
|
179 |
+
if coin2 == 'single':
|
180 |
+
n_m_length = self.max_length + self.opt.unit_length * len_gap
|
181 |
+
else:
|
182 |
+
n_m_length = self.max_length + self.opt.unit_length * (len_gap - 1)
|
183 |
+
idx = random.randint(0, m_length - n_m_length)
|
184 |
+
motion = motion[idx:idx + self.max_length]
|
185 |
+
m_length = n_m_length
|
186 |
+
# print(len_gap, idx, coin2)
|
187 |
+
else:
|
188 |
+
if self.opt.unit_length < 10:
|
189 |
+
coin2 = np.random.choice(['single', 'single', 'double'])
|
190 |
+
else:
|
191 |
+
coin2 = 'single'
|
192 |
+
|
193 |
+
if coin2 == 'double':
|
194 |
+
m_length = (m_length // self.opt.unit_length - 1) * self.opt.unit_length
|
195 |
+
elif coin2 == 'single':
|
196 |
+
m_length = (m_length // self.opt.unit_length) * self.opt.unit_length
|
197 |
+
idx = random.randint(0, len(motion) - m_length)
|
198 |
+
motion = motion[idx:idx+m_length]
|
199 |
+
|
200 |
+
"Z Normalization"
|
201 |
+
motion = (motion - self.mean) / self.std
|
202 |
+
|
203 |
+
return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length
|
204 |
+
|
205 |
+
|
206 |
+
'''For use of training text motion matching model, and evaluations'''
|
207 |
+
## text2motions dataset v2 ##
|
208 |
+
class Text2MotionDatasetV2(data.Dataset): # text2motion dataset
|
209 |
+
def __init__(self, opt, mean, std, split_file, w_vectorizer):
|
210 |
+
self.opt = opt
|
211 |
+
self.w_vectorizer = w_vectorizer
|
212 |
+
self.max_length = 20
|
213 |
+
self.pointer = 0
|
214 |
+
self.max_motion_length = opt.max_motion_length
|
215 |
+
min_motion_len = 40 if self.opt.dataset_name =='t2m' else 24
|
216 |
+
|
217 |
+
data_dict = {}
|
218 |
+
id_list = []
|
219 |
+
with cs.open(split_file, 'r') as f:
|
220 |
+
for line in f.readlines():
|
221 |
+
id_list.append(line.strip()) ## id list ##
|
222 |
+
# id_list = id_list[:200]
|
223 |
+
|
224 |
+
new_name_list = []
|
225 |
+
length_list = []
|
226 |
+
for name in tqdm(id_list):
|
227 |
+
try:
|
228 |
+
## motion_dir ##
|
229 |
+
motion = np.load(pjoin(opt.motion_dir, name + '.npy'))
|
230 |
+
if (len(motion)) < min_motion_len or (len(motion) >= 200):
|
231 |
+
continue
|
232 |
+
text_data = []
|
233 |
+
flag = False
|
234 |
+
## motion ##
|
235 |
+
with cs.open(pjoin(opt.text_dir, name + '.txt')) as f:
|
236 |
+
for line in f.readlines():
|
237 |
+
text_dict = {}
|
238 |
+
line_split = line.strip().split('#')
|
239 |
+
caption = line_split[0]
|
240 |
+
tokens = line_split[1].split(' ')
|
241 |
+
f_tag = float(line_split[2])
|
242 |
+
to_tag = float(line_split[3])
|
243 |
+
f_tag = 0.0 if np.isnan(f_tag) else f_tag
|
244 |
+
to_tag = 0.0 if np.isnan(to_tag) else to_tag
|
245 |
+
|
246 |
+
text_dict['caption'] = caption ## caption, motion, ##
|
247 |
+
text_dict['tokens'] = tokens
|
248 |
+
if f_tag == 0.0 and to_tag == 0.0:
|
249 |
+
flag = True
|
250 |
+
text_data.append(text_dict)
|
251 |
+
else:
|
252 |
+
try:
|
253 |
+
n_motion = motion[int(f_tag*20) : int(to_tag*20)]
|
254 |
+
if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200):
|
255 |
+
continue
|
256 |
+
# new name for indexing #
|
257 |
+
new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
|
258 |
+
while new_name in data_dict:
|
259 |
+
new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
|
260 |
+
data_dict[new_name] = {'motion': n_motion,
|
261 |
+
'length': len(n_motion), ## length of motion ##
|
262 |
+
'text':[text_dict]}
|
263 |
+
new_name_list.append(new_name)
|
264 |
+
length_list.append(len(n_motion))
|
265 |
+
except:
|
266 |
+
print(line_split)
|
267 |
+
print(line_split[2], line_split[3], f_tag, to_tag, name)
|
268 |
+
# break
|
269 |
+
|
270 |
+
if flag:
|
271 |
+
## motion, length, text ##
|
272 |
+
data_dict[name] = {'motion': motion, ## motion, length of the motion, text data
|
273 |
+
'length': len(motion), ## motion, length, text ##
|
274 |
+
'text': text_data}
|
275 |
+
new_name_list.append(name)
|
276 |
+
length_list.append(len(motion))
|
277 |
+
except:
|
278 |
+
pass
|
279 |
+
|
280 |
+
name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))
|
281 |
+
|
282 |
+
self.mean = mean
|
283 |
+
self.std = std
|
284 |
+
self.length_arr = np.array(length_list)
|
285 |
+
self.data_dict = data_dict
|
286 |
+
self.name_list = name_list
|
287 |
+
self.reset_max_len(self.max_length)
|
288 |
+
|
289 |
+
def reset_max_len(self, length):
|
290 |
+
assert length <= self.max_motion_length
|
291 |
+
self.pointer = np.searchsorted(self.length_arr, length)
|
292 |
+
print("Pointer Pointing at %d"%self.pointer)
|
293 |
+
self.max_length = length
|
294 |
+
|
295 |
+
def inv_transform(self, data):
|
296 |
+
return data * self.std + self.mean
|
297 |
+
|
298 |
+
def __len__(self):
|
299 |
+
return len(self.data_dict) - self.pointer
|
300 |
+
|
301 |
+
def __getitem__(self, item):
|
302 |
+
idx = self.pointer + item
|
303 |
+
data = self.data_dict[self.name_list[idx]] # data
|
304 |
+
motion, m_length, text_list = data['motion'], data['length'], data['text']
|
305 |
+
# Randomly select a caption
|
306 |
+
text_data = random.choice(text_list)
|
307 |
+
caption, tokens = text_data['caption'], text_data['tokens']
|
308 |
+
|
309 |
+
if len(tokens) < self.opt.max_text_len:
|
310 |
+
# pad with "unk"
|
311 |
+
tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
|
312 |
+
sent_len = len(tokens)
|
313 |
+
tokens = tokens + ['unk/OTHER'] * (self.opt.max_text_len + 2 - sent_len)
|
314 |
+
else:
|
315 |
+
# crop
|
316 |
+
tokens = tokens[:self.opt.max_text_len]
|
317 |
+
tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
|
318 |
+
sent_len = len(tokens)
|
319 |
+
pos_one_hots = [] ## pose one hots ##
|
320 |
+
word_embeddings = []
|
321 |
+
for token in tokens:
|
322 |
+
word_emb, pos_oh = self.w_vectorizer[token]
|
323 |
+
pos_one_hots.append(pos_oh[None, :])
|
324 |
+
word_embeddings.append(word_emb[None, :])
|
325 |
+
pos_one_hots = np.concatenate(pos_one_hots, axis=0)
|
326 |
+
word_embeddings = np.concatenate(word_embeddings, axis=0)
|
327 |
+
|
328 |
+
# Crop the motions in to times of 4, and introduce small variations
|
329 |
+
if self.opt.unit_length < 10:
|
330 |
+
coin2 = np.random.choice(['single', 'single', 'double'])
|
331 |
+
else:
|
332 |
+
coin2 = 'single'
|
333 |
+
|
334 |
+
if coin2 == 'double':
|
335 |
+
m_length = (m_length // self.opt.unit_length - 1) * self.opt.unit_length
|
336 |
+
elif coin2 == 'single':
|
337 |
+
m_length = (m_length // self.opt.unit_length) * self.opt.unit_length
|
338 |
+
idx = random.randint(0, len(motion) - m_length)
|
339 |
+
motion = motion[idx:idx+m_length]
|
340 |
+
|
341 |
+
"Z Normalization"
|
342 |
+
motion = (motion - self.mean) / self.std
|
343 |
+
|
344 |
+
if m_length < self.max_motion_length:
|
345 |
+
motion = np.concatenate([motion, # positions # right? #
|
346 |
+
np.zeros((self.max_motion_length - m_length, motion.shape[1]))
|
347 |
+
], axis=0)
|
348 |
+
# print(word_embeddings.shape, motion.shape)
|
349 |
+
# print(tokens)
|
350 |
+
return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)
|
351 |
+
|
352 |
+
|
353 |
+
## and
|
354 |
+
'''For use of training baseline'''
|
355 |
+
class Text2MotionDatasetBaseline(data.Dataset):
|
356 |
+
def __init__(self, opt, mean, std, split_file, w_vectorizer):
|
357 |
+
self.opt = opt
|
358 |
+
self.w_vectorizer = w_vectorizer
|
359 |
+
self.max_length = 20
|
360 |
+
self.pointer = 0
|
361 |
+
self.max_motion_length = opt.max_motion_length
|
362 |
+
min_motion_len = 40 if self.opt.dataset_name =='t2m' else 24
|
363 |
+
|
364 |
+
data_dict = {}
|
365 |
+
id_list = []
|
366 |
+
with cs.open(split_file, 'r') as f:
|
367 |
+
for line in f.readlines():
|
368 |
+
id_list.append(line.strip())
|
369 |
+
# id_list = id_list[:200]
|
370 |
+
|
371 |
+
new_name_list = []
|
372 |
+
length_list = []
|
373 |
+
for name in tqdm(id_list):
|
374 |
+
try:
|
375 |
+
motion = np.load(pjoin(opt.motion_dir, name + '.npy'))
|
376 |
+
if (len(motion)) < min_motion_len or (len(motion) >= 200):
|
377 |
+
continue
|
378 |
+
text_data = []
|
379 |
+
flag = False
|
380 |
+
with cs.open(pjoin(opt.text_dir, name + '.txt')) as f:
|
381 |
+
for line in f.readlines():
|
382 |
+
text_dict = {}
|
383 |
+
line_split = line.strip().split('#')
|
384 |
+
caption = line_split[0]
|
385 |
+
tokens = line_split[1].split(' ')
|
386 |
+
f_tag = float(line_split[2])
|
387 |
+
to_tag = float(line_split[3])
|
388 |
+
f_tag = 0.0 if np.isnan(f_tag) else f_tag
|
389 |
+
to_tag = 0.0 if np.isnan(to_tag) else to_tag
|
390 |
+
|
391 |
+
text_dict['caption'] = caption
|
392 |
+
text_dict['tokens'] = tokens
|
393 |
+
if f_tag == 0.0 and to_tag == 0.0:
|
394 |
+
flag = True
|
395 |
+
text_data.append(text_dict)
|
396 |
+
else:
|
397 |
+
try:
|
398 |
+
n_motion = motion[int(f_tag*20) : int(to_tag*20)]
|
399 |
+
if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200):
|
400 |
+
continue
|
401 |
+
new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
|
402 |
+
while new_name in data_dict:
|
403 |
+
new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
|
404 |
+
data_dict[new_name] = {'motion': n_motion,
|
405 |
+
'length': len(n_motion),
|
406 |
+
'text':[text_dict]}
|
407 |
+
new_name_list.append(new_name)
|
408 |
+
length_list.append(len(n_motion))
|
409 |
+
except:
|
410 |
+
print(line_split)
|
411 |
+
print(line_split[2], line_split[3], f_tag, to_tag, name)
|
412 |
+
# break
|
413 |
+
|
414 |
+
if flag:
|
415 |
+
data_dict[name] = {'motion': motion,
|
416 |
+
'length': len(motion),
|
417 |
+
'text': text_data}
|
418 |
+
new_name_list.append(name)
|
419 |
+
length_list.append(len(motion))
|
420 |
+
except:
|
421 |
+
pass
|
422 |
+
|
423 |
+
name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))
|
424 |
+
|
425 |
+
self.mean = mean
|
426 |
+
self.std = std
|
427 |
+
self.length_arr = np.array(length_list)
|
428 |
+
self.data_dict = data_dict
|
429 |
+
self.name_list = name_list
|
430 |
+
self.reset_max_len(self.max_length)
|
431 |
+
|
432 |
+
def reset_max_len(self, length):
|
433 |
+
assert length <= self.max_motion_length
|
434 |
+
self.pointer = np.searchsorted(self.length_arr, length)
|
435 |
+
print("Pointer Pointing at %d"%self.pointer)
|
436 |
+
self.max_length = length
|
437 |
+
|
438 |
+
def inv_transform(self, data):
|
439 |
+
return data * self.std + self.mean
|
440 |
+
|
441 |
+
def __len__(self):
|
442 |
+
return len(self.data_dict) - self.pointer
|
443 |
+
|
444 |
+
def __getitem__(self, item):
|
445 |
+
idx = self.pointer + item
|
446 |
+
data = self.data_dict[self.name_list[idx]]
|
447 |
+
motion, m_length, text_list = data['motion'], data['length'], data['text']
|
448 |
+
# Randomly select a caption
|
449 |
+
text_data = random.choice(text_list)
|
450 |
+
caption, tokens = text_data['caption'], text_data['tokens']
|
451 |
+
|
452 |
+
if len(tokens) < self.opt.max_text_len:
|
453 |
+
# pad with "unk"
|
454 |
+
tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
|
455 |
+
sent_len = len(tokens)
|
456 |
+
tokens = tokens + ['unk/OTHER'] * (self.opt.max_text_len + 2 - sent_len)
|
457 |
+
else:
|
458 |
+
# crop
|
459 |
+
tokens = tokens[:self.opt.max_text_len]
|
460 |
+
tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
|
461 |
+
sent_len = len(tokens)
|
462 |
+
pos_one_hots = []
|
463 |
+
word_embeddings = []
|
464 |
+
for token in tokens:
|
465 |
+
word_emb, pos_oh = self.w_vectorizer[token]
|
466 |
+
pos_one_hots.append(pos_oh[None, :])
|
467 |
+
word_embeddings.append(word_emb[None, :])
|
468 |
+
pos_one_hots = np.concatenate(pos_one_hots, axis=0)
|
469 |
+
word_embeddings = np.concatenate(word_embeddings, axis=0)
|
470 |
+
|
471 |
+
len_gap = (m_length - self.max_length) // self.opt.unit_length
|
472 |
+
|
473 |
+
if m_length != self.max_length:
|
474 |
+
# print("Motion original length:%d_%d"%(m_length, len(motion)))
|
475 |
+
if self.opt.unit_length < 10:
|
476 |
+
coin2 = np.random.choice(['single', 'single', 'double'])
|
477 |
+
else:
|
478 |
+
coin2 = 'single'
|
479 |
+
if len_gap == 0 or (len_gap == 1 and coin2 == 'double'):
|
480 |
+
m_length = self.max_length
|
481 |
+
s_idx = random.randint(0, m_length - self.max_length)
|
482 |
+
else:
|
483 |
+
if coin2 == 'single':
|
484 |
+
n_m_length = self.max_length + self.opt.unit_length * len_gap
|
485 |
+
else:
|
486 |
+
n_m_length = self.max_length + self.opt.unit_length * (len_gap - 1)
|
487 |
+
s_idx = random.randint(0, m_length - n_m_length)
|
488 |
+
m_length = n_m_length
|
489 |
+
else:
|
490 |
+
s_idx = 0
|
491 |
+
|
492 |
+
src_motion = motion[s_idx: s_idx + m_length]
|
493 |
+
tgt_motion = motion[s_idx: s_idx + self.max_length]
|
494 |
+
|
495 |
+
"Z Normalization"
|
496 |
+
src_motion = (src_motion - self.mean) / self.std
|
497 |
+
tgt_motion = (tgt_motion - self.mean) / self.std
|
498 |
+
|
499 |
+
if m_length < self.max_motion_length:
|
500 |
+
src_motion = np.concatenate([src_motion,
|
501 |
+
np.zeros((self.max_motion_length - m_length, motion.shape[1]))
|
502 |
+
], axis=0)
|
503 |
+
# print(m_length, src_motion.shape, tgt_motion.shape)
|
504 |
+
# print(word_embeddings.shape, motion.shape)
|
505 |
+
# print(tokens)
|
506 |
+
return word_embeddings, caption, sent_len, src_motion, tgt_motion, m_length
|
507 |
+
|
508 |
+
|
509 |
+
class MotionDatasetV2(data.Dataset):
|
510 |
+
def __init__(self, opt, mean, std, split_file):
|
511 |
+
self.opt = opt
|
512 |
+
joints_num = opt.joints_num
|
513 |
+
|
514 |
+
self.data = []
|
515 |
+
self.lengths = []
|
516 |
+
id_list = []
|
517 |
+
with cs.open(split_file, 'r') as f:
|
518 |
+
for line in f.readlines():
|
519 |
+
id_list.append(line.strip())
|
520 |
+
|
521 |
+
for name in tqdm(id_list):
|
522 |
+
try:
|
523 |
+
motion = np.load(pjoin(opt.motion_dir, name + '.npy'))
|
524 |
+
if motion.shape[0] < opt.window_size:
|
525 |
+
continue
|
526 |
+
self.lengths.append(motion.shape[0] - opt.window_size)
|
527 |
+
self.data.append(motion)
|
528 |
+
except:
|
529 |
+
# Some motion may not exist in KIT dataset
|
530 |
+
pass
|
531 |
+
|
532 |
+
self.cumsum = np.cumsum([0] + self.lengths)
|
533 |
+
|
534 |
+
if opt.is_train:
|
535 |
+
# root_rot_velocity (B, seq_len, 1)
|
536 |
+
std[0:1] = std[0:1] / opt.feat_bias
|
537 |
+
# root_linear_velocity (B, seq_len, 2)
|
538 |
+
std[1:3] = std[1:3] / opt.feat_bias
|
539 |
+
# root_y (B, seq_len, 1)
|
540 |
+
std[3:4] = std[3:4] / opt.feat_bias
|
541 |
+
# ric_data (B, seq_len, (joint_num - 1)*3)
|
542 |
+
std[4: 4 + (joints_num - 1) * 3] = std[4: 4 + (joints_num - 1) * 3] / 1.0
|
543 |
+
# rot_data (B, seq_len, (joint_num - 1)*6)
|
544 |
+
std[4 + (joints_num - 1) * 3: 4 + (joints_num - 1) * 9] = std[4 + (joints_num - 1) * 3: 4 + (
|
545 |
+
joints_num - 1) * 9] / 1.0
|
546 |
+
# local_velocity (B, seq_len, joint_num*3)
|
547 |
+
std[4 + (joints_num - 1) * 9: 4 + (joints_num - 1) * 9 + joints_num * 3] = std[
|
548 |
+
4 + (joints_num - 1) * 9: 4 + (
|
549 |
+
joints_num - 1) * 9 + joints_num * 3] / 1.0
|
550 |
+
# foot contact (B, seq_len, 4)
|
551 |
+
std[4 + (joints_num - 1) * 9 + joints_num * 3:] = std[
|
552 |
+
4 + (joints_num - 1) * 9 + joints_num * 3:] / opt.feat_bias
|
553 |
+
|
554 |
+
assert 4 + (joints_num - 1) * 9 + joints_num * 3 + 4 == mean.shape[-1]
|
555 |
+
np.save(pjoin(opt.meta_dir, 'mean.npy'), mean)
|
556 |
+
np.save(pjoin(opt.meta_dir, 'std.npy'), std)
|
557 |
+
|
558 |
+
self.mean = mean
|
559 |
+
self.std = std
|
560 |
+
print("Total number of motions {}, snippets {}".format(len(self.data), self.cumsum[-1]))
|
561 |
+
|
562 |
+
def inv_transform(self, data):
|
563 |
+
return data * self.std + self.mean
|
564 |
+
|
565 |
+
def __len__(self):
|
566 |
+
return self.cumsum[-1]
|
567 |
+
|
568 |
+
def __getitem__(self, item):
|
569 |
+
if item != 0:
|
570 |
+
motion_id = np.searchsorted(self.cumsum, item) - 1
|
571 |
+
idx = item - self.cumsum[motion_id] - 1
|
572 |
+
else:
|
573 |
+
motion_id = 0
|
574 |
+
idx = 0
|
575 |
+
# idx + j
|
576 |
+
motion = self.data[motion_id][idx:idx+self.opt.window_size]
|
577 |
+
"Z Normalization"
|
578 |
+
motion = (motion - self.mean) / self.std
|
579 |
+
|
580 |
+
return motion
|
581 |
+
|
582 |
+
|
583 |
+
class RawTextDataset(data.Dataset):
|
584 |
+
def __init__(self, opt, mean, std, text_file, w_vectorizer):
|
585 |
+
self.mean = mean
|
586 |
+
self.std = std
|
587 |
+
self.opt = opt
|
588 |
+
self.data_dict = []
|
589 |
+
self.nlp = spacy.load('en_core_web_sm')
|
590 |
+
|
591 |
+
with cs.open(text_file) as f:
|
592 |
+
for line in f.readlines():
|
593 |
+
word_list, pos_list = self.process_text(line.strip())
|
594 |
+
tokens = ['%s/%s'%(word_list[i], pos_list[i]) for i in range(len(word_list))]
|
595 |
+
self.data_dict.append({'caption':line.strip(), "tokens":tokens})
|
596 |
+
|
597 |
+
self.w_vectorizer = w_vectorizer
|
598 |
+
print("Total number of descriptions {}".format(len(self.data_dict)))
|
599 |
+
|
600 |
+
|
601 |
+
def process_text(self, sentence):
|
602 |
+
sentence = sentence.replace('-', '')
|
603 |
+
doc = self.nlp(sentence)
|
604 |
+
word_list = []
|
605 |
+
pos_list = []
|
606 |
+
for token in doc:
|
607 |
+
word = token.text
|
608 |
+
if not word.isalpha():
|
609 |
+
continue
|
610 |
+
if (token.pos_ == 'NOUN' or token.pos_ == 'VERB') and (word != 'left'):
|
611 |
+
word_list.append(token.lemma_)
|
612 |
+
else:
|
613 |
+
word_list.append(word)
|
614 |
+
pos_list.append(token.pos_)
|
615 |
+
return word_list, pos_list
|
616 |
+
|
617 |
+
def inv_transform(self, data):
|
618 |
+
return data * self.std + self.mean
|
619 |
+
|
620 |
+
def __len__(self):
|
621 |
+
return len(self.data_dict)
|
622 |
+
|
623 |
+
def __getitem__(self, item):
|
624 |
+
data = self.data_dict[item]
|
625 |
+
caption, tokens = data['caption'], data['tokens']
|
626 |
+
|
627 |
+
if len(tokens) < self.opt.max_text_len:
|
628 |
+
# pad with "unk"
|
629 |
+
tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
|
630 |
+
sent_len = len(tokens)
|
631 |
+
tokens = tokens + ['unk/OTHER'] * (self.opt.max_text_len + 2 - sent_len)
|
632 |
+
else:
|
633 |
+
# crop
|
634 |
+
tokens = tokens[:self.opt.max_text_len]
|
635 |
+
tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
|
636 |
+
sent_len = len(tokens)
|
637 |
+
pos_one_hots = []
|
638 |
+
word_embeddings = []
|
639 |
+
for token in tokens:
|
640 |
+
word_emb, pos_oh = self.w_vectorizer[token]
|
641 |
+
pos_one_hots.append(pos_oh[None, :])
|
642 |
+
word_embeddings.append(word_emb[None, :])
|
643 |
+
pos_one_hots = np.concatenate(pos_one_hots, axis=0)
|
644 |
+
word_embeddings = np.concatenate(word_embeddings, axis=0)
|
645 |
+
|
646 |
+
return word_embeddings, pos_one_hots, caption, sent_len
|
647 |
+
|
648 |
+
class TextOnlyDataset(data.Dataset):
|
649 |
+
def __init__(self, opt, mean, std, split_file):
|
650 |
+
self.mean = mean
|
651 |
+
self.std = std
|
652 |
+
self.opt = opt
|
653 |
+
self.data_dict = []
|
654 |
+
self.max_length = 20
|
655 |
+
self.pointer = 0
|
656 |
+
self.fixed_length = 120
|
657 |
+
|
658 |
+
|
659 |
+
data_dict = {}
|
660 |
+
id_list = []
|
661 |
+
with cs.open(split_file, 'r') as f:
|
662 |
+
for line in f.readlines():
|
663 |
+
id_list.append(line.strip())
|
664 |
+
# id_list = id_list[:200]
|
665 |
+
|
666 |
+
new_name_list = []
|
667 |
+
length_list = []
|
668 |
+
for name in tqdm(id_list):
|
669 |
+
try:
|
670 |
+
text_data = []
|
671 |
+
flag = False
|
672 |
+
with cs.open(pjoin(opt.text_dir, name + '.txt')) as f:
|
673 |
+
for line in f.readlines():
|
674 |
+
text_dict = {}
|
675 |
+
line_split = line.strip().split('#')
|
676 |
+
caption = line_split[0]
|
677 |
+
tokens = line_split[1].split(' ')
|
678 |
+
f_tag = float(line_split[2])
|
679 |
+
to_tag = float(line_split[3])
|
680 |
+
f_tag = 0.0 if np.isnan(f_tag) else f_tag
|
681 |
+
to_tag = 0.0 if np.isnan(to_tag) else to_tag
|
682 |
+
|
683 |
+
text_dict['caption'] = caption
|
684 |
+
text_dict['tokens'] = tokens
|
685 |
+
if f_tag == 0.0 and to_tag == 0.0:
|
686 |
+
flag = True
|
687 |
+
text_data.append(text_dict)
|
688 |
+
else:
|
689 |
+
try:
|
690 |
+
new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
|
691 |
+
while new_name in data_dict:
|
692 |
+
new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
|
693 |
+
data_dict[new_name] = {'text':[text_dict]}
|
694 |
+
new_name_list.append(new_name)
|
695 |
+
except:
|
696 |
+
print(line_split)
|
697 |
+
print(line_split[2], line_split[3], f_tag, to_tag, name)
|
698 |
+
# break
|
699 |
+
|
700 |
+
if flag:
|
701 |
+
data_dict[name] = {'text': text_data}
|
702 |
+
new_name_list.append(name)
|
703 |
+
except:
|
704 |
+
pass
|
705 |
+
|
706 |
+
self.length_arr = np.array(length_list)
|
707 |
+
self.data_dict = data_dict
|
708 |
+
self.name_list = new_name_list
|
709 |
+
|
710 |
+
def inv_transform(self, data):
|
711 |
+
return data * self.std + self.mean
|
712 |
+
|
713 |
+
def __len__(self):
|
714 |
+
return len(self.data_dict)
|
715 |
+
|
716 |
+
def __getitem__(self, item):
|
717 |
+
idx = self.pointer + item
|
718 |
+
data = self.data_dict[self.name_list[idx]]
|
719 |
+
text_list = data['text']
|
720 |
+
|
721 |
+
# Randomly select a caption
|
722 |
+
text_data = random.choice(text_list)
|
723 |
+
caption, tokens = text_data['caption'], text_data['tokens']
|
724 |
+
return None, None, caption, None, np.array([0]), self.fixed_length, None
|
725 |
+
# fixed_length can be set from outside before sampling
|
726 |
+
|
727 |
+
## t2m original dataset
|
728 |
+
# A wrapper class for t2m original dataset for MDM purposes
|
729 |
+
# humanml 3D
|
730 |
+
class HumanML3D(data.Dataset): ## humanml dataset ## ## human ml dataset text2motion ##
|
731 |
+
def __init__(self, mode, datapath='./dataset/humanml_opt.txt', split="train", load_vectorizer=False, **kwargs):
|
732 |
+
self.mode = mode
|
733 |
+
|
734 |
+
self.dataset_name = 't2m'
|
735 |
+
self.dataname = 't2m'
|
736 |
+
|
737 |
+
### humanml3d --> humanml3d,
|
738 |
+
# Configurations of T2M dataset and KIT dataset is almost the same
|
739 |
+
abs_base_path = f'.'
|
740 |
+
dataset_opt_path = pjoin(abs_base_path, datapath) ## pjoin, pjoin, getopt, # abs
|
741 |
+
device = None # torch.device('cuda:4') # This param is not in use in this context
|
742 |
+
opt = get_opt(dataset_opt_path, device)
|
743 |
+
opt.meta_dir = pjoin(abs_base_path, opt.meta_dir)
|
744 |
+
opt.motion_dir = pjoin(abs_base_path, opt.motion_dir)
|
745 |
+
opt.text_dir = pjoin(abs_base_path, opt.text_dir)
|
746 |
+
opt.model_dir = pjoin(abs_base_path, opt.model_dir)
|
747 |
+
opt.checkpoints_dir = pjoin(abs_base_path, opt.checkpoints_dir)
|
748 |
+
opt.data_root = pjoin(abs_base_path, opt.data_root) ## data_root --> data root;
|
749 |
+
opt.save_root = pjoin(abs_base_path, opt.save_root)
|
750 |
+
opt.meta_dir = './dataset'
|
751 |
+
self.opt = opt
|
752 |
+
print('Loading dataset %s ...' % opt.dataset_name)
|
753 |
+
|
754 |
+
if mode == 'gt':
|
755 |
+
# used by T2M models (including evaluators)
|
756 |
+
self.mean = np.load(pjoin(opt.meta_dir, f'{opt.dataset_name}_mean.npy'))
|
757 |
+
self.std = np.load(pjoin(opt.meta_dir, f'{opt.dataset_name}_std.npy'))
|
758 |
+
elif mode in ['train', 'eval', 'text_only']:
|
759 |
+
# used by our models
|
760 |
+
self.mean = np.load(pjoin(opt.data_root, 'Mean.npy'))
|
761 |
+
self.std = np.load(pjoin(opt.data_root, 'Std.npy'))
|
762 |
+
|
763 |
+
if mode == 'eval':
|
764 |
+
# used by T2M models (including evaluators)
|
765 |
+
# this is to translate their norms to ours
|
766 |
+
self.mean_for_eval = np.load(pjoin(opt.meta_dir, f'{opt.dataset_name}_mean.npy'))
|
767 |
+
self.std_for_eval = np.load(pjoin(opt.meta_dir, f'{opt.dataset_name}_std.npy'))
|
768 |
+
print(f"dataset_name: {opt.dataset_name}")
|
769 |
+
if load_vectorizer:
|
770 |
+
self.split_file = pjoin(opt.data_root, f'train.txt')
|
771 |
+
else:
|
772 |
+
self.split_file = pjoin(opt.data_root, f'{split}.txt')
|
773 |
+
if mode == 'text_only' and (not load_vectorizer):
|
774 |
+
self.t2m_dataset = TextOnlyDataset(self.opt, self.mean, self.std, self.split_file)
|
775 |
+
else:
|
776 |
+
self.w_vectorizer = WordVectorizer(pjoin(abs_base_path, 'glove'), 'our_vab')
|
777 |
+
### text to
|
778 |
+
self.t2m_dataset = Text2MotionDatasetV2(self.opt, self.mean, self.std, self.split_file, self.w_vectorizer)
|
779 |
+
self.num_actions = 1 # dummy placeholder
|
780 |
+
|
781 |
+
# assert len(self.t2m_dataset) > 1, 'You loaded an empty dataset, ' \
|
782 |
+
# 'it is probably because your data dir has only texts and no motions.\n' \
|
783 |
+
# 'To train and evaluate MDM you should get the FULL data as described ' \
|
784 |
+
# 'in the README file.'
|
785 |
+
|
786 |
+
def __getitem__(self, item):
|
787 |
+
return self.t2m_dataset.__getitem__(item)
|
788 |
+
|
789 |
+
def __len__(self):
|
790 |
+
return self.t2m_dataset.__len__()
|
791 |
+
|
792 |
+
# A wrapper class for t2m original dataset for MDM purposes
|
793 |
+
class KIT(HumanML3D):
|
794 |
+
def __init__(self, mode, datapath='./dataset/kit_opt.txt', split="train", **kwargs):
|
795 |
+
super(KIT, self).__init__(mode, datapath, split, **kwargs)
|
data_loaders/humanml/data/dataset_ours.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data_loaders/humanml/data/dataset_ours_single_seq.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data_loaders/humanml/data/utils.py
ADDED
@@ -0,0 +1,507 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import time
|
4 |
+
from scipy.spatial.transform import Rotation as R
|
5 |
+
|
6 |
+
try:
|
7 |
+
from torch_cluster import fps
|
8 |
+
except:
|
9 |
+
pass
|
10 |
+
from collections import OrderedDict
|
11 |
+
import os, argparse, copy, json
|
12 |
+
import math
|
13 |
+
|
14 |
+
def sample_pcd_from_mesh(vertices, triangles, npoints=512):
|
15 |
+
arears = []
|
16 |
+
for i in range(triangles.shape[0]):
|
17 |
+
v_a, v_b, v_c = int(triangles[i, 0].item()), int(triangles[i, 1].item()), int(triangles[i, 2].item())
|
18 |
+
v_a, v_b, v_c = vertices[v_a], vertices[v_b], vertices[v_c]
|
19 |
+
ab, ac = v_b - v_a, v_c - v_a
|
20 |
+
cos_ab_ac = (np.sum(ab * ac) / np.clip(np.sqrt(np.sum(ab ** 2)) * np.sqrt(np.sum(ac ** 2)), a_min=1e-9, a_max=9999999.0)).item()
|
21 |
+
sin_ab_ac = math.sqrt(1. - cos_ab_ac ** 2)
|
22 |
+
cur_area = 0.5 * sin_ab_ac * np.sqrt(np.sum(ab ** 2)).item() * np.sqrt(np.sum(ac ** 2)).item()
|
23 |
+
arears.append(cur_area)
|
24 |
+
tot_area = sum(arears)
|
25 |
+
|
26 |
+
sampled_pcts = []
|
27 |
+
tot_indices = []
|
28 |
+
tot_factors = []
|
29 |
+
for i in range(triangles.shape[0]):
|
30 |
+
|
31 |
+
v_a, v_b, v_c = int(triangles[i, 0].item()), int(triangles[i, 1].item()), int(
|
32 |
+
triangles[i, 2].item())
|
33 |
+
v_a, v_b, v_c = vertices[v_a], vertices[v_b], vertices[v_c]
|
34 |
+
# ab, ac = v_b - v_a, v_c - v_a
|
35 |
+
# cur_sampled_pts = int(npoints * (arears[i] / tot_area))
|
36 |
+
cur_sampled_pts = math.ceil(npoints * (arears[i] / tot_area))
|
37 |
+
# if cur_sampled_pts == 0:
|
38 |
+
|
39 |
+
cur_sampled_pts = int(arears[i] * npoints)
|
40 |
+
cur_sampled_pts = 1 if cur_sampled_pts == 0 else cur_sampled_pts
|
41 |
+
|
42 |
+
tmp_x, tmp_y = np.random.uniform(0, 1., (cur_sampled_pts,)).tolist(), np.random.uniform(0., 1., (cur_sampled_pts,)).tolist()
|
43 |
+
|
44 |
+
for xx, yy in zip(tmp_x, tmp_y):
|
45 |
+
sqrt_xx, sqrt_yy = math.sqrt(xx), math.sqrt(yy)
|
46 |
+
aa = 1. - sqrt_xx
|
47 |
+
bb = sqrt_xx * (1. - yy)
|
48 |
+
cc = yy * sqrt_xx
|
49 |
+
cur_pos = v_a * aa + v_b * bb + v_c * cc
|
50 |
+
sampled_pcts.append(cur_pos)
|
51 |
+
|
52 |
+
tot_indices.append(triangles[i]) # tot_indices for triangles # # vertices indices
|
53 |
+
tot_factors.append([aa, bb, cc])
|
54 |
+
|
55 |
+
tot_indices = np.array(tot_indices, dtype=np.int64)  # np.long was removed in recent NumPy; int64 keeps the same meaning
|
56 |
+
tot_factors = np.array(tot_factors, dtype=np.float32)
|
57 |
+
|
58 |
+
sampled_ptcs = np.array(sampled_pcts)
|
59 |
+
print("sampled points from surface:", sampled_ptcs.shape)
|
60 |
+
# sampled_pcts = np.concatenate([sampled_pcts, vertices], axis=0)
|
61 |
+
return sampled_ptcs, tot_indices, tot_factors
|
62 |
+
|
63 |
+
|
64 |
+
def read_obj_file_ours(obj_fn, sub_one=False):
|
65 |
+
vertices = []
|
66 |
+
faces = []
|
67 |
+
with open(obj_fn, "r") as rf:
|
68 |
+
for line in rf:
|
69 |
+
items = line.strip().split(" ")
|
70 |
+
if items[0] == 'v':
|
71 |
+
cur_verts = items[1:]
|
72 |
+
cur_verts = [float(vv) for vv in cur_verts]
|
73 |
+
vertices.append(cur_verts)
|
74 |
+
elif items[0] == 'f':
|
75 |
+
cur_faces = items[1:] # faces
|
76 |
+
cur_face_idxes = []
|
77 |
+
for cur_f in cur_faces:
|
78 |
+
try:
|
79 |
+
cur_f_idx = int(cur_f.split("/")[0])
|
80 |
+
except:
|
81 |
+
cur_f_idx = int(cur_f.split("//")[0])
|
82 |
+
cur_face_idxes.append(cur_f_idx if not sub_one else cur_f_idx - 1)
|
83 |
+
faces.append(cur_face_idxes)
|
84 |
+
rf.close()
|
85 |
+
vertices = np.array(vertices, dtype=np.float64)  # np.float was removed in recent NumPy; float64 keeps the same meaning
|
86 |
+
return vertices, faces
|
87 |
+
|
88 |
+
def clamp_gradient(model, clip):
|
89 |
+
for p in model.parameters():
|
90 |
+
torch.nn.utils.clip_grad_value_(p, clip)
|
91 |
+
|
92 |
+
def clamp_gradient_norm(model, max_norm, norm_type=2):
|
93 |
+
for p in model.parameters():
|
94 |
+
torch.nn.utils.clip_grad_norm_(p, max_norm, norm_type=norm_type)
|
95 |
+
|
96 |
+
|
97 |
+
def save_network(net, directory, network_label, epoch_label=None, **kwargs):
|
98 |
+
"""
|
99 |
+
save model to directory with name {network_label}_{epoch_label}.pth
|
100 |
+
Args:
|
101 |
+
net: pytorch model
|
102 |
+
directory: output directory
|
103 |
+
network_label: str
|
104 |
+
epoch_label: convertible to str
|
105 |
+
kwargs: additional value to be included
|
106 |
+
"""
|
107 |
+
save_filename = "_".join((network_label, str(epoch_label))) + ".pth"
|
108 |
+
save_path = os.path.join(directory, save_filename)
|
109 |
+
merge_states = OrderedDict()
|
110 |
+
merge_states["states"] = net.cpu().state_dict()
|
111 |
+
for k in kwargs:
|
112 |
+
merge_states[k] = kwargs[k]
|
113 |
+
torch.save(merge_states, save_path)
|
114 |
+
net = net.cuda()
|
115 |
+
|
116 |
+
|
117 |
+
def load_network(net, path):
|
118 |
+
"""
|
119 |
+
load network parameters whose name exists in the pth file.
|
120 |
+
return:
|
121 |
+
INT trained step
|
122 |
+
"""
|
123 |
+
# warnings.DeprecationWarning("load_network is deprecated. Use module.load_state_dict(strict=False) instead.")
|
124 |
+
if isinstance(path, str):
|
125 |
+
logger.info("loading network from {}".format(path))
|
126 |
+
if path[-3:] == "pth":
|
127 |
+
loaded_state = torch.load(path)
|
128 |
+
if "states" in loaded_state:
|
129 |
+
loaded_state = loaded_state["states"]
|
130 |
+
else:
|
131 |
+
loaded_state = np.load(path).item()
|
132 |
+
if "states" in loaded_state:
|
133 |
+
loaded_state = loaded_state["states"]
|
134 |
+
elif isinstance(path, dict):
|
135 |
+
loaded_state = path
|
136 |
+
|
137 |
+
network = net.module if isinstance(
|
138 |
+
net, torch.nn.DataParallel) else net
|
139 |
+
|
140 |
+
missingkeys, unexpectedkeys = network.load_state_dict(loaded_state, strict=False)
|
141 |
+
if len(missingkeys)>0:
|
142 |
+
logger.warn("load_network {} missing keys".format(len(missingkeys)), "\n".join(missingkeys))
|
143 |
+
if len(unexpectedkeys)>0:
|
144 |
+
logger.warn("load_network {} unexpected keys".format(len(unexpectedkeys)), "\n".join(unexpectedkeys))
|
145 |
+
|
146 |
+
|
147 |
+
|
148 |
+
def weights_init(m):
|
149 |
+
"""
|
150 |
+
initialize the weighs of the network for Convolutional layers and batchnorm layers
|
151 |
+
"""
|
152 |
+
if isinstance(m, (torch.nn.modules.conv._ConvNd, torch.nn.Linear)):
|
153 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
154 |
+
if m.bias is not None:
|
155 |
+
torch.nn.init.constant_(m.bias, 0.0)
|
156 |
+
elif isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
|
157 |
+
torch.nn.init.constant_(m.bias, 0.0)
|
158 |
+
torch.nn.init.constant_(m.weight, 1.0)
|
159 |
+
|
160 |
+
def seal(mesh_to_seal):
|
161 |
+
circle_v_id = np.array([108, 79, 78, 121, 214, 215, 279, 239, 234, 92, 38, 122, 118, 117, 119, 120], dtype = np.int32)
|
162 |
+
center = (mesh_to_seal.v[circle_v_id, :]).mean(0)
|
163 |
+
|
164 |
+
sealed_mesh = copy.copy(mesh_to_seal)
|
165 |
+
sealed_mesh.v = np.vstack([mesh_to_seal.v, center])
|
166 |
+
center_v_id = sealed_mesh.v.shape[0] - 1
|
167 |
+
|
168 |
+
for i in range(circle_v_id.shape[0]):
|
169 |
+
new_faces = [circle_v_id[i-1], circle_v_id[i], center_v_id]
|
170 |
+
sealed_mesh.f = np.vstack([sealed_mesh.f, new_faces])
|
171 |
+
return sealed_mesh
|
172 |
+
|
173 |
+
def read_pos_fr_txt(txt_fn):
|
174 |
+
pos_data = []
|
175 |
+
with open(txt_fn, "r") as rf:
|
176 |
+
for line in rf:
|
177 |
+
cur_pos = line.strip().split(" ")
|
178 |
+
cur_pos = [float(p) for p in cur_pos]
|
179 |
+
pos_data.append(cur_pos)
|
180 |
+
rf.close()
|
181 |
+
pos_data = np.array(pos_data, dtype=np.float32)
|
182 |
+
print(f"pos_data: {pos_data.shape}")
|
183 |
+
return pos_data
|
184 |
+
|
185 |
+
def read_field_data_fr_txt(field_fn):
|
186 |
+
field_data = []
|
187 |
+
with open(field_fn, "r") as rf:
|
188 |
+
for line in rf:
|
189 |
+
cur_field = line.strip().split(" ")
|
190 |
+
cur_field = [float(p) for p in cur_field]
|
191 |
+
field_data.append(cur_field)
|
192 |
+
rf.close()
|
193 |
+
field_data = np.array(field_data, dtype=np.float32)
|
194 |
+
print(f"filed_data: {field_data.shape}")
|
195 |
+
return field_data
|
196 |
+
|
197 |
+
def farthest_point_sampling(pos: torch.FloatTensor, n_sampling: int):
|
198 |
+
bz, N = pos.size(0), pos.size(1)
|
199 |
+
feat_dim = pos.size(-1)
|
200 |
+
device = pos.device
|
201 |
+
sampling_ratio = float(n_sampling / N)
|
202 |
+
pos_float = pos.float()
|
203 |
+
|
204 |
+
batch = torch.arange(bz, dtype=torch.long).view(bz, 1).to(device)
|
205 |
+
mult_one = torch.ones((N,), dtype=torch.long).view(1, N).to(device)
|
206 |
+
|
207 |
+
batch = batch * mult_one
|
208 |
+
batch = batch.view(-1)
|
209 |
+
pos_float = pos_float.contiguous().view(-1, feat_dim).contiguous() # (bz x N, 3)
|
210 |
+
# sampling_ratio = torch.tensor([sampling_ratio for _ in range(bz)], dtype=torch.float).to(device)
|
211 |
+
# batch = torch.zeros((N, ), dtype=torch.long, device=device)
|
212 |
+
sampled_idx = fps(pos_float, batch, ratio=sampling_ratio, random_start=False)
|
213 |
+
# shape of sampled_idx?
|
214 |
+
return sampled_idx
|
215 |
+
|
216 |
+
|
217 |
+
def batched_index_select_ours(values, indices, dim = 1):
|
218 |
+
value_dims = values.shape[(dim + 1):]
|
219 |
+
values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
|
220 |
+
indices = indices[(..., *((None,) * len(value_dims)))]
|
221 |
+
indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
|
222 |
+
value_expand_len = len(indices_shape) - (dim + 1)
|
223 |
+
values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]
|
224 |
+
|
225 |
+
value_expand_shape = [-1] * len(values.shape)
|
226 |
+
expand_slice = slice(dim, (dim + value_expand_len))
|
227 |
+
value_expand_shape[expand_slice] = indices.shape[expand_slice]
|
228 |
+
values = values.expand(*value_expand_shape)
|
229 |
+
|
230 |
+
dim += value_expand_len
|
231 |
+
return values.gather(dim, indices)
|
232 |
+
|
233 |
+
def compute_nearest(query, verts):
|
234 |
+
# query: bsz x nn_q x 3
|
235 |
+
# verts: bsz x nn_q x 3
|
236 |
+
dists = torch.sum((query.unsqueeze(2) - verts.unsqueeze(1)) ** 2, dim=-1)
|
237 |
+
minn_dists, minn_dists_idx = torch.min(dists, dim=-1) # bsz x nn_q
|
238 |
+
minn_pts_pos = batched_index_select_ours(values=verts, indices=minn_dists_idx, dim=1)
|
239 |
+
minn_pts_pos = minn_pts_pos.unsqueeze(2)
|
240 |
+
minn_dists_idx = minn_dists_idx.unsqueeze(2)
|
241 |
+
return minn_dists, minn_dists_idx, minn_pts_pos
|
242 |
+
|
243 |
+
|
244 |
+
def batched_index_select(t, dim, inds):
|
245 |
+
"""
|
246 |
+
Helper function to extract batch-varying indicies along array
|
247 |
+
:param t: array to select from
|
248 |
+
:param dim: dimension to select along
|
249 |
+
:param inds: batch-vary indicies
|
250 |
+
:return:
|
251 |
+
"""
|
252 |
+
dummy = inds.unsqueeze(2).expand(inds.size(0), inds.size(1), t.size(2))
|
253 |
+
out = t.gather(dim, dummy) # b x e x f
|
254 |
+
return out
|
255 |
+
|
256 |
+
|
257 |
+
def batched_get_rot_mtx_fr_vecs(normal_vecs):
|
258 |
+
# normal_vecs: nn_pts x 3 #
|
259 |
+
#
|
260 |
+
normal_vecs = normal_vecs / torch.clamp(torch.norm(normal_vecs, p=2, dim=-1, keepdim=True), min=1e-5)
|
261 |
+
sin_theta = normal_vecs[..., 0]
|
262 |
+
cos_theta = torch.sqrt(1. - sin_theta ** 2)
|
263 |
+
sin_phi = normal_vecs[..., 1] / torch.clamp(cos_theta, min=1e-5)
|
264 |
+
# cos_phi = torch.sqrt(1. - sin_phi ** 2)
|
265 |
+
cos_phi = normal_vecs[..., 2] / torch.clamp(cos_theta, min=1e-5)
|
266 |
+
|
267 |
+
sin_phi[cos_theta < 1e-5] = 1.
|
268 |
+
cos_phi[cos_theta < 1e-5] = 0.
|
269 |
+
|
270 |
+
#
|
271 |
+
y_rot_mtx = torch.stack(
|
272 |
+
[
|
273 |
+
torch.stack([cos_theta, torch.zeros_like(cos_theta), -sin_theta], dim=-1),
|
274 |
+
torch.stack([torch.zeros_like(cos_theta), torch.ones_like(cos_theta), torch.zeros_like(cos_theta)], dim=-1),
|
275 |
+
torch.stack([sin_theta, torch.zeros_like(cos_theta), cos_theta], dim=-1)
|
276 |
+
], dim=-1
|
277 |
+
)
|
278 |
+
x_rot_mtx = torch.stack(
|
279 |
+
[
|
280 |
+
torch.stack([torch.ones_like(cos_theta), torch.zeros_like(cos_theta), torch.zeros_like(cos_theta)], dim=-1),
|
281 |
+
torch.stack([torch.zeros_like(cos_phi), cos_phi, -sin_phi], dim=-1),
|
282 |
+
torch.stack([torch.zeros_like(cos_phi), sin_phi, cos_phi], dim=-1)
|
283 |
+
], dim=-1
|
284 |
+
)
|
285 |
+
rot_mtx = torch.matmul(x_rot_mtx, y_rot_mtx)
|
286 |
+
return rot_mtx
|
287 |
+
|
288 |
+
|
def batched_get_rot_mtx_fr_vecs_v2(normal_vecs):
    # normal_vecs: nn_pts x 3
    normal_vecs = normal_vecs / torch.clamp(torch.norm(normal_vecs, p=2, dim=-1, keepdim=True), min=1e-5)
    sin_theta = normal_vecs[..., 0]
    cos_theta = torch.sqrt(1. - sin_theta ** 2)
    sin_phi = normal_vecs[..., 1] / torch.clamp(cos_theta, min=1e-5)
    # cos_phi = torch.sqrt(1. - sin_phi ** 2)
    cos_phi = normal_vecs[..., 2] / torch.clamp(cos_theta, min=1e-5)

    # handle the degenerate case where the normal is (nearly) aligned with the x-axis
    sin_phi[cos_theta < 1e-5] = 1.
    cos_phi[cos_theta < 1e-5] = 0.

    # o: nn_pts x 3, a unit vector orthogonal to the normal
    o = torch.stack(
        [torch.zeros_like(cos_phi), cos_phi, -sin_phi], dim=-1
    )
    nxo = torch.cross(o, normal_vecs)
    # rot_mtx: nn_pts x 3 x 3, columns are [o x n, o, n]
    rot_mtx = torch.stack(
        [nxo, o, normal_vecs], dim=-1
    )
    return rot_mtx

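# --- Illustrative sanity check (added for exposition; not part of the original commit). ---
# Away from the degenerate case |x-component| ~ 1, the frame built by
# batched_get_rot_mtx_fr_vecs_v2 should be (numerically) orthonormal, with the input
# direction as its third column. The tolerance below is an arbitrary choice.
def _demo_rot_fr_vecs_v2():
    import torch
    torch.manual_seed(0)
    vecs = torch.randn(8, 3)
    vecs = vecs / vecs.norm(dim=-1, keepdim=True)
    rot = batched_get_rot_mtx_fr_vecs_v2(vecs)  # (8, 3, 3)
    eye = torch.eye(3).expand(8, 3, 3)
    assert torch.allclose(rot.transpose(1, 2) @ rot, eye, atol=1e-4)
    assert torch.allclose(rot[..., 2], vecs, atol=1e-4)
    return rot
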
def batched_get_orientation_matrices(rot_vec):
    # rot_vec: nn_frames x 3 axis-angle vectors; `R` is presumably
    # scipy.spatial.transform.Rotation, imported elsewhere in this file.
    rot_matrices = []
    for i_w in range(rot_vec.shape[0]):
        cur_rot_vec = rot_vec[i_w]
        cur_rot_mtx = R.from_rotvec(cur_rot_vec).as_matrix()
        rot_matrices.append(cur_rot_mtx)
    rot_matrices = np.stack(rot_matrices, axis=0)
    return rot_matrices

def batched_get_minn_dist_corresponding_pts(tips, obj_pcs):
    # tips: bsz x nn_tips x 3, obj_pcs: bsz x nn_obj x 3
    dist_tips_to_obj_pc_minn_idx = np.argmin(
        ((tips.reshape(tips.shape[0], tips.shape[1], 1, 3) - obj_pcs.reshape(obj_pcs.shape[0], 1, obj_pcs.shape[1], 3)) ** 2).sum(axis=-1), axis=-1
    )
    obj_pcs_th = torch.from_numpy(obj_pcs).float()
    dist_tips_to_obj_pc_minn_idx_th = torch.from_numpy(dist_tips_to_obj_pc_minn_idx).long()
    nearest_pc_th = batched_index_select(obj_pcs_th, 1, dist_tips_to_obj_pc_minn_idx_th)
    return nearest_pc_th, dist_tips_to_obj_pc_minn_idx_th

def get_affinity_fr_dist(dist, s=0.02):
    ### affinity scores: 1 at dist = 0, falling to 0 at |dist| = s via a half-cosine ###
    k = 0.5 * torch.cos(torch.pi / s * torch.abs(dist)) + 0.5
    return k

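# --- Illustrative example (added for exposition; not part of the original commit). ---
# get_affinity_fr_dist maps |dist| = 0 to an affinity of 1 and |dist| = s to 0 through a
# half-cosine window. How distances beyond s are handled is up to the caller; masking them
# out first is an assumption made here, not something this file enforces.
def _demo_affinity():
    import torch
    d = torch.tensor([0.0, 0.005, 0.01, 0.02])
    k = get_affinity_fr_dist(d, s=0.02)
    # k is approximately [1.0, 0.854, 0.5, 0.0]
    return k
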
def batched_reverse_transform(rot, transl, t_pc, trans=True):
    # t_pc: ws x nn_obj x 3
    # rot: ws x 3 x 3
    # transl: ws x 1 x 3
    # Undo a rigid transform: optionally subtract the translation, then apply the inverse rotation.
    if trans:
        reverse_trans_pc = t_pc - transl
    else:
        reverse_trans_pc = t_pc
    reverse_trans_pc = np.matmul(np.transpose(rot, (0, 2, 1)), np.transpose(reverse_trans_pc, (0, 2, 1)))
    reverse_trans_pc = np.transpose(reverse_trans_pc, (0, 2, 1))
    return reverse_trans_pc

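# --- Illustrative round-trip check (added for exposition; not part of the original commit). ---
# Assumes `numpy` and scipy's Rotation (as `R`) are imported at the top of this module,
# since it reuses batched_get_orientation_matrices. Applying a rotation + translation and
# then batched_reverse_transform should recover the original point cloud.
def _demo_reverse_transform():
    import numpy as np
    rng = np.random.default_rng(0)
    pc = rng.standard_normal((4, 128, 3))                                # ws x nn_obj x 3
    rot = batched_get_orientation_matrices(rng.standard_normal((4, 3)))  # ws x 3 x 3
    transl = rng.standard_normal((4, 1, 3))                              # ws x 1 x 3
    transformed = np.matmul(rot, pc.transpose(0, 2, 1)).transpose(0, 2, 1) + transl
    recovered = batched_reverse_transform(rot, transl, transformed, trans=True)
    assert np.allclose(recovered, pc, atol=1e-6)
    return recovered
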
def capsule_sdf(mesh_verts, mesh_normals, query_points, query_normals, caps_rad, caps_top, caps_bot, foreach_on_mesh):
    # if caps on hand: mesh_verts = hand verts
    """
    Find the SDF of query points to mesh verts.
    Capsule SDF formulation from https://iquilezles.org/www/articles/distfunctions/distfunctions.htm

    :param mesh_verts: (batch, V, 3)
    :param mesh_normals: (batch, V, 3)
    :param query_points: (batch, Q, 3)
    :param query_normals: (batch, Q, 3)
    :param caps_rad: scalar, radius of capsules
    :param caps_top: scalar, distance from mesh to top of capsule
    :param caps_bot: scalar, distance from mesh to bottom of capsule
    :param foreach_on_mesh: boolean, foreach point on mesh find closest query (V), or foreach query find closest mesh (Q)
    :return: normalized SDF + 1 (batch, V or Q), normal dot products, nearest verts, nearest normals
    """
    # TODO implement normal check?
    if foreach_on_mesh:  # Foreach mesh vert, find closest query point
        # knn_dists, nearest_idx, nearest_pos = pytorch3d.ops.knn_points(mesh_verts, query_points, K=1, return_nn=True)  # TODO should attract capsule middle?
        knn_dists, nearest_idx, nearest_pos = compute_nearest(mesh_verts, query_points)

        capsule_tops = mesh_verts + mesh_normals * caps_top
        capsule_bots = mesh_verts + mesh_normals * caps_bot
        delta_top = nearest_pos[:, :, 0, :] - capsule_tops
        normal_dot = torch.sum(mesh_normals * batched_index_select(query_normals, 1, nearest_idx.squeeze(2)), dim=2)

        rt_nearest_verts = mesh_verts
        rt_nearest_normals = mesh_normals

    else:  # Foreach query vert, find closest mesh point
        # knn_dists, nearest_idx, nearest_pos = pytorch3d.ops.knn_points(query_points, mesh_verts, K=1, return_nn=True)  # TODO should attract capsule middle?
        st_time = time.time()
        knn_dists, nearest_idx, nearest_pos = compute_nearest(query_points, mesh_verts)
        ed_time = time.time()
        # print(f"Time for computing nearest: {ed_time - st_time}")

        closest_mesh_verts = batched_index_select(mesh_verts, 1, nearest_idx.squeeze(2))  # Shape (batch, Q, 3)
        closest_mesh_normals = batched_index_select(mesh_normals, 1, nearest_idx.squeeze(2))  # Shape (batch, Q, 3)

        capsule_tops = closest_mesh_verts + closest_mesh_normals * caps_top  # Coordinates of the top foci of the capsules (batch, Q, 3)
        capsule_bots = closest_mesh_verts + closest_mesh_normals * caps_bot
        delta_top = query_points - capsule_tops
        # normal_dot = torch.sum(query_normals * closest_mesh_normals, dim=2)
        normal_dot = None

        rt_nearest_verts = closest_mesh_verts
        rt_nearest_normals = closest_mesh_normals

    bot_to_top = capsule_bots - capsule_tops  # Vector from capsule top to capsule bottom
    along_axis = torch.sum(delta_top * bot_to_top, dim=2)  # Dot product
    top_to_bot_square = torch.sum(bot_to_top * bot_to_top, dim=2)

    # print(f"top_to_bot_square: {top_to_bot_square[..., :10]}")
    h = torch.clamp(along_axis / top_to_bot_square, 0, 1)  # Could avoid NaNs with an offset in the division here
    dist_to_axis = torch.norm(delta_top - bot_to_top * h.unsqueeze(2), dim=2)  # Distance to the capsule centerline

    return dist_to_axis / caps_rad, normal_dot, rt_nearest_verts, rt_nearest_normals  # Normalized SDF + 1: 0 on the capsule axis, 1 on the capsule surface

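# --- Illustrative usage sketch (added for exposition; not part of the original commit). ---
# The shapes and capsule parameters below are arbitrary example values. With
# foreach_on_mesh=False, each query point is compared against the capsule erected around
# its closest mesh vertex; a normalized value below 1 means the query lies inside that capsule.
def _demo_capsule_sdf():
    import torch
    torch.manual_seed(0)
    mesh_verts = torch.rand(1, 778, 3)  # e.g. a MANO-sized hand mesh (778 vertices)
    mesh_normals = torch.nn.functional.normalize(torch.randn(1, 778, 3), dim=-1)
    query_points = torch.rand(1, 512, 3)
    query_normals = torch.nn.functional.normalize(torch.randn(1, 512, 3), dim=-1)
    norm_sdf, normal_dot, nearest_v, nearest_n = capsule_sdf(
        mesh_verts, mesh_normals, query_points, query_normals,
        caps_rad=0.01, caps_top=0.005, caps_bot=-0.005, foreach_on_mesh=False)
    inside = norm_sdf < 1.0             # (1, 512) boolean mask of queries inside a capsule
    return inside
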
def reparameterize_gaussian(mean, logvar):
    std = torch.exp(0.5 * logvar)
    eps = torch.randn(std.size()).to(mean.device)
    return mean + std * eps


def gaussian_entropy(logvar):
    const = 0.5 * float(logvar.size(1)) * (1. + np.log(np.pi * 2))
    ent = 0.5 * logvar.sum(dim=1, keepdim=False) + const
    return ent


def standard_normal_logprob(z):  # feature dim
    dim = z.size(-1)
    log_z = -0.5 * dim * np.log(2 * np.pi)
    return log_z - z.pow(2) / 2

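# --- Illustrative sketch (added for exposition; not part of the original commit). ---
# Shows how the three Gaussian helpers above typically fit together in a VAE-style latent
# regularizer; the particular combination and sign convention are assumptions for the
# example, not this repository's actual loss.
def _demo_gaussian_utils():
    import torch
    torch.manual_seed(0)
    mean = torch.zeros(4, 16)
    logvar = torch.zeros(4, 16)                      # log-variances of a diagonal Gaussian q(z|x)
    z = reparameterize_gaussian(mean, logvar)        # differentiable sample: mean + std * eps
    log_pz = standard_normal_logprob(z).sum(dim=-1)  # prior terms (constant + quadratic), summed over latents
    h_q = gaussian_entropy(logvar)                   # entropy of q(z|x)
    loss_prior = (-log_pz - h_q).mean()              # one common prior + entropy regularizer
    return loss_prior
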
def truncated_normal_(tensor, mean=0, std=1, trunc_std=2):
    """
    Taken from https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
    """
    size = tensor.shape
    tmp = tensor.new_empty(size + (4,)).normal_()
    valid = (tmp < trunc_std) & (tmp > -trunc_std)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)
    return tensor

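# --- Illustrative usage (added for exposition; not part of the original commit). ---
# truncated_normal_ fills a tensor in place so that (before scaling by std) almost all
# entries fall within +/- trunc_std standard deviations, e.g. for weight initialization.
def _demo_truncated_normal():
    import torch
    torch.manual_seed(0)
    w = torch.empty(256, 128)
    truncated_normal_(w, mean=0.0, std=0.02, trunc_std=2)
    assert (w.abs() <= 0.02 * 2 + 1e-6).float().mean() > 0.99
    return w
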
def makepath(desired_path, isfile=False):
    '''
    If the path does not exist, create it.

    :param desired_path: can be a path to a file or a folder name
    :return: the input path
    '''
    import os
    if isfile:
        if not os.path.exists(os.path.dirname(desired_path)):
            os.makedirs(os.path.dirname(desired_path))
    else:
        if not os.path.exists(desired_path):
            os.makedirs(desired_path)
    return desired_path

def batch_gather(arr, ind):
    """
    :param arr: B x N x D
    :param ind: B x M
    :return: B x M x D
    """
    dummy = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), arr.size(2))
    out = torch.gather(arr, 1, dummy)
    return out

def random_rotate_np(x):
    # Sample a random axis-angle vector and rotate x by it (Rodrigues' rotation formula).
    aa = np.random.randn(3)
    theta = np.sqrt(np.sum(aa ** 2))
    k = aa / np.maximum(theta, 1e-6)
    K = np.array([[0, -k[2], k[1]],
                  [k[2], 0, -k[0]],
                  [-k[1], k[0], 0]])
    R = np.eye(3) + np.sin(theta) * K + (1 - np.cos(theta)) * np.matmul(K, K)
    R = R.astype(np.float32)
    return np.matmul(x, R), R

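# --- Illustrative sanity check (added for exposition; not part of the original commit). ---
# random_rotate_np builds the rotation with Rodrigues' formula, so the returned matrix
# should be orthonormal with determinant +1, and rotating points preserves their norms.
def _demo_random_rotate():
    import numpy as np
    np.random.seed(0)
    x = np.random.randn(100, 3).astype(np.float32)
    x_rot, rot = random_rotate_np(x)
    assert np.allclose(rot.T @ rot, np.eye(3), atol=1e-5)
    assert np.isclose(np.linalg.det(rot), 1.0, atol=1e-5)
    assert np.allclose(np.linalg.norm(x_rot, axis=-1), np.linalg.norm(x, axis=-1), atol=1e-4)
    return x_rot, rot
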
def rotate_x(x, rad):
    rad = -rad
    rotmat = np.array([
        [1, 0, 0],
        [0, np.cos(rad), -np.sin(rad)],
        [0, np.sin(rad), np.cos(rad)]
    ])
    return np.dot(x, rotmat)


def rotate_y(x, rad):
    rad = -rad
    rotmat = np.array([
        [np.cos(rad), 0, np.sin(rad)],
        [0, 1, 0],
        [-np.sin(rad), 0, np.cos(rad)]
    ])
    return np.dot(x, rotmat)


def rotate_z(x, rad):
    rad = -rad
    rotmat = np.array([
        [np.cos(rad), -np.sin(rad), 0],
        [np.sin(rad), np.cos(rad), 0],
        [0, 0, 1]
    ])
    return np.dot(x, rotmat)

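# --- Illustrative usage (added for exposition; not part of the original commit). ---
# rotate_x / rotate_y / rotate_z right-multiply an (N, 3) array of row-vector points by a
# single-axis rotation (note the internal sign flip of `rad`); e.g. a quarter turn with
# rotate_z sends the +x axis to +y. The checks below verify that, and that norms are preserved.
def _demo_axis_rotations():
    import numpy as np
    assert np.allclose(rotate_z(np.array([[1.0, 0.0, 0.0]]), np.pi / 2), np.array([[0.0, 1.0, 0.0]]))
    pts = np.random.default_rng(0).standard_normal((64, 3))
    out = rotate_z(rotate_y(rotate_x(pts, 0.3), -1.2), 2.0)
    assert np.allclose(np.linalg.norm(out, axis=-1), np.linalg.norm(pts, axis=-1))
    assert np.allclose(rotate_x(pts, 2 * np.pi), pts, atol=1e-12)
    return out
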
data_loaders/humanml/motion_loaders/__init__.py
ADDED
File without changes

data_loaders/humanml/motion_loaders/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (182 Bytes)