Spaces:
Runtime error
Runtime error
init test project
Browse files- .DS_Store +0 -0
- GPS.py +324 -0
- NN/losses.py +51 -0
- NN/utils.py +103 -0
- README.md +4 -4
- app.py +54 -0
- configs/random_synthesis.yaml +36 -0
- dataset/.DS_Store +0 -0
- dataset/tracks_motion.py +183 -0
- requirements.txt +15 -0
- utils/.DS_Store +0 -0
- utils/base.py +148 -0
- utils/contact.py +103 -0
- utils/kinematics.py +203 -0
- utils/skeleton.py +347 -0
- utils/transforms.py +399 -0
.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
GPS.py
ADDED
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import os.path as osp
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import numpy as np
|
6 |
+
import itertools
|
7 |
+
from tensorboardX import SummaryWriter
|
8 |
+
|
9 |
+
from NN.losses import make_criteria
|
10 |
+
from utils.base import logger
|
11 |
+
|
12 |
+
class GPS:
|
13 |
+
def __init__(self,
|
14 |
+
init_mode: str = 'random_synthesis',
|
15 |
+
noise_sigma: float = 1.0,
|
16 |
+
coarse_ratio: float = 0.2,
|
17 |
+
coarse_ratio_factor: float = 6,
|
18 |
+
pyr_factor: float = 0.75,
|
19 |
+
num_stages_limit: int = -1,
|
20 |
+
device: str = 'cuda:0',
|
21 |
+
silent: bool = False
|
22 |
+
):
|
23 |
+
'''
|
24 |
+
Args:
|
25 |
+
init_mode:
|
26 |
+
- 'random_synthesis': init with random seed
|
27 |
+
- 'random': init with random seed
|
28 |
+
noise_sigma: float = 1.0, random noise.
|
29 |
+
coarse_ratio: float = 0.2, ratio at the coarse level.
|
30 |
+
pyr_factor: float = 0.75, pyramid factor.
|
31 |
+
num_stages_limit: int = -1, no limit.
|
32 |
+
device: str = 'cuda:0', default device.
|
33 |
+
silent: bool = False, mute the output.
|
34 |
+
'''
|
35 |
+
self.init_mode = init_mode
|
36 |
+
self.noise_sigma = noise_sigma
|
37 |
+
self.coarse_ratio = coarse_ratio
|
38 |
+
self.coarse_ratio_factor = coarse_ratio_factor
|
39 |
+
self.pyr_factor = pyr_factor
|
40 |
+
self.num_stages_limit = num_stages_limit
|
41 |
+
self.device = torch.device(device)
|
42 |
+
self.silent = silent
|
43 |
+
|
44 |
+
def _get_pyramid_lengths(self, dest, ext=None):
|
45 |
+
"""Get a list of pyramid lengths"""
|
46 |
+
if self.coarse_ratio == -1:
|
47 |
+
self.coarse_ratio = np.around(ext['criteria']['patch_size'] * self.coarse_ratio_factor / dest, 2)
|
48 |
+
|
49 |
+
lengths = [int(np.round(dest * self.coarse_ratio))]
|
50 |
+
while lengths[-1] < dest:
|
51 |
+
lengths.append(int(np.round(lengths[-1] / self.pyr_factor)))
|
52 |
+
if lengths[-1] == lengths[-2]:
|
53 |
+
lengths[-1] += 1
|
54 |
+
lengths[-1] = dest
|
55 |
+
|
56 |
+
return lengths
|
57 |
+
|
58 |
+
def _get_target_pyramid(self, target, ext=None):
|
59 |
+
"""Reads a target motion(s) and create a pyraimd out of it. Ordered in increatorch.sing size"""
|
60 |
+
self._num_target = len(target)
|
61 |
+
lengths = []
|
62 |
+
min_len = 10000
|
63 |
+
for i in range(len(target)):
|
64 |
+
new_length = self._get_pyramid_lengths(len(target[i]), ext)
|
65 |
+
min_len = min(min_len, len(new_length))
|
66 |
+
if self.num_stages_limit != -1:
|
67 |
+
new_length = new_length[:self.num_stages_limit]
|
68 |
+
lengths.append(new_length)
|
69 |
+
for i in range(len(target)):
|
70 |
+
lengths[i] = lengths[i][-min_len:]
|
71 |
+
self.pyraimd_lengths = lengths
|
72 |
+
|
73 |
+
target_pyramid = [[] for _ in range(len(lengths[0]))]
|
74 |
+
for step in range(len(lengths[0])):
|
75 |
+
for i in range(len(target)):
|
76 |
+
length = lengths[i][step]
|
77 |
+
motion = target[i]
|
78 |
+
target_pyramid[step].append(motion.sample(size=length).to(self.device))
|
79 |
+
# target_pyramid[step].append(motion.pos2velo(motion.sample(size=length)))
|
80 |
+
# motion.motion_data = motion.pos2velo(motion.motion_data)
|
81 |
+
# target_pyramid[step].append(motion.sample(size=length))
|
82 |
+
# motion.motion_data = motion.velo2pos(motion.motion_data)
|
83 |
+
|
84 |
+
if not self.silent:
|
85 |
+
print('Levels:', lengths)
|
86 |
+
for i in range(len(target_pyramid)):
|
87 |
+
print(f'Number of clips in target pyramid {i} is {len(target_pyramid[i])}: {[[tgt.min(), tgt.max()] for tgt in target_pyramid[i]]}')
|
88 |
+
|
89 |
+
return target_pyramid
|
90 |
+
|
91 |
+
def _get_initial_motion(self):
|
92 |
+
"""Prepare the initial motion for optimization"""
|
93 |
+
if 'random_synthesis' in str(self.init_mode):
|
94 |
+
m = self.init_mode.split('/')[-1]
|
95 |
+
if m =='random_synthesis':
|
96 |
+
final_length = sum([i[-1] for i in self.pyraimd_lengths])
|
97 |
+
elif 'x' in m:
|
98 |
+
final_length = int(m.replace('x', '')) * sum([i[-1] for i in self.pyraimd_lengths])
|
99 |
+
elif (self.init_mode.split('/')[-1]).isdigit():
|
100 |
+
final_length = int(self.init_mode.split('/')[-1])
|
101 |
+
else:
|
102 |
+
raise ValueError(f'incorrect init_mode: {self.init_mode}')
|
103 |
+
|
104 |
+
self.synthesized_lengths = self._get_pyramid_lengths(final_length)
|
105 |
+
|
106 |
+
else:
|
107 |
+
raise ValueError(f'Unsupported init_mode {self.init_mode}')
|
108 |
+
|
109 |
+
initial_motion = F.interpolate(torch.cat([self.target_pyramid[0][i] for i in range(self._num_target)], dim=-1),
|
110 |
+
size=self.synthesized_lengths[0], mode='linear', align_corners=True)
|
111 |
+
if self.noise_sigma > 0:
|
112 |
+
initial_motion_w_noise = initial_motion + torch.randn_like(initial_motion) * self.noise_sigma
|
113 |
+
initial_motion_w_noise = torch.fmod(initial_motion_w_noise, 1.0)
|
114 |
+
else:
|
115 |
+
initial_motion_w_noise = initial_motion
|
116 |
+
|
117 |
+
if not self.silent:
|
118 |
+
print('Synthesized lengths:', self.synthesized_lengths)
|
119 |
+
print('Initial motion:', initial_motion.min(), initial_motion.max())
|
120 |
+
print('Initial motion with noise:', initial_motion_w_noise.min(), initial_motion_w_noise.max())
|
121 |
+
|
122 |
+
return initial_motion_w_noise
|
123 |
+
|
124 |
+
def run(self, target, mode="backpropagate", ext=None, debug_dir=None):
|
125 |
+
'''
|
126 |
+
Run the patch-based motion synthesis.
|
127 |
+
|
128 |
+
Args:
|
129 |
+
target (torch.Tensor): Target data.
|
130 |
+
mode (str): Optimization mode. Support ['backpropagate', 'match_and_blend']
|
131 |
+
ext (dict): extra data or constrain.
|
132 |
+
debug_dir (str): Debug directory.
|
133 |
+
'''
|
134 |
+
# preprare data
|
135 |
+
self.target_pyramid = self._get_target_pyramid(target, ext)
|
136 |
+
self.synthesized = self._get_initial_motion()
|
137 |
+
if debug_dir is not None:
|
138 |
+
writer = SummaryWriter(log_dir=debug_dir)
|
139 |
+
|
140 |
+
# prepare configuration
|
141 |
+
if mode == "backpropagate":
|
142 |
+
self.synthesized.requires_grad_(True)
|
143 |
+
assert 'criteria' in ext.keys(), 'Please specify a criteria for synthsis.'
|
144 |
+
criteria = make_criteria(ext['criteria']).to(self.device)
|
145 |
+
elif mode == "match_and_blend":
|
146 |
+
self.synthesized.requires_grad_(False)
|
147 |
+
assert 'criteria' in ext.keys(), 'Please specify a criteria for synthsis.'
|
148 |
+
criteria = make_criteria(ext['criteria']).to(self.device)
|
149 |
+
else:
|
150 |
+
raise ValueError(f'Unsupported mode: {mode}')
|
151 |
+
|
152 |
+
# perform synthsis
|
153 |
+
self.pbar = logger(ext['num_itrs'], len(self.target_pyramid))
|
154 |
+
ext['pbar'] = self.pbar
|
155 |
+
for lvl, lvl_target in enumerate(self.target_pyramid):
|
156 |
+
self.pbar.new_lvl()
|
157 |
+
if lvl > 0:
|
158 |
+
with torch.no_grad():
|
159 |
+
self.synthesized = F.interpolate(self.synthesized.detach(), size=self.synthesized_lengths[lvl], mode='linear')
|
160 |
+
if mode == "backpropagate":
|
161 |
+
self.synthesized.requires_grad_(True)
|
162 |
+
|
163 |
+
if mode == "backpropagate": # direct optimize the synthesized motion
|
164 |
+
self.synthesized, losses = GPS.backpropagate(self.synthesized, lvl_target, criteria, ext=ext)
|
165 |
+
elif mode == "match_and_blend":
|
166 |
+
self.synthesized, losses = GPS.match_and_blend(self.synthesized, lvl_target, criteria, ext=ext)
|
167 |
+
|
168 |
+
criteria.clean_cache()
|
169 |
+
if debug_dir:
|
170 |
+
for itr in range(len(losses)):
|
171 |
+
writer.add_scalar(f'optimize/losses_lvl{lvl}', losses[itr], itr)
|
172 |
+
self.pbar.pbar.close()
|
173 |
+
|
174 |
+
|
175 |
+
return self.synthesized.detach()
|
176 |
+
|
177 |
+
@staticmethod
|
178 |
+
def backpropagate(synthesized, targets, criteria=None, ext=None):
|
179 |
+
"""
|
180 |
+
Minimizes criteria(synthesized, target) for num_steps SGD steps
|
181 |
+
Args:
|
182 |
+
targets (torch.Tensor): Target data.
|
183 |
+
ext (dict): extra configurations.
|
184 |
+
"""
|
185 |
+
if criteria is None:
|
186 |
+
assert 'criteria' in ext.keys(), 'Criteria is not set'
|
187 |
+
criteria = make_criteria(ext['criteria']).to(synthesized.device)
|
188 |
+
|
189 |
+
optim = None
|
190 |
+
if 'optimizer' in ext.keys():
|
191 |
+
if ext['optimizer'] == 'Adam':
|
192 |
+
optim = torch.optim.Adam([synthesized], lr=ext['lr'])
|
193 |
+
elif ext['optimizer'] == 'SGD':
|
194 |
+
optim = torch.optim.SGD([synthesized], lr=ext['lr'])
|
195 |
+
elif ext['optimizer'] == 'RMSprop':
|
196 |
+
optim = torch.optim.RMSprop([synthesized], lr=ext['lr'])
|
197 |
+
else:
|
198 |
+
print(f'use default RMSprop optimizer')
|
199 |
+
optim = torch.optim.RMSprop([synthesized], lr=ext['lr']) if optim is None else optim
|
200 |
+
# optim = torch.optim.Adam([synthesized], lr=ext['lr']) if optim is None else optim
|
201 |
+
lr_decay = np.exp(np.log(0.333) / ext['num_itrs'])
|
202 |
+
|
203 |
+
# other constraints
|
204 |
+
trajectory = ext['trajectory'] if 'trajectory' in ext.keys() else None
|
205 |
+
|
206 |
+
losses = []
|
207 |
+
for _i in range(ext['num_itrs']):
|
208 |
+
optim.zero_grad()
|
209 |
+
|
210 |
+
loss = criteria(synthesized, targets)
|
211 |
+
|
212 |
+
if trajectory is not None: ## velo constrain
|
213 |
+
target_traj = F.interpolate(trajectory, size=synthesized.shape[-1], mode='linear')
|
214 |
+
# target_traj = F.interpolate(trajectory, size=synthesized.shape[-1], mode='linear', align_corners=False)
|
215 |
+
target_velo = ext['pos2velo'](target_traj)
|
216 |
+
|
217 |
+
velo_mask = [-3, -1]
|
218 |
+
loss += 1 * F.l1_loss(synthesized[:, velo_mask, :], target_velo[:, velo_mask, :])
|
219 |
+
|
220 |
+
loss.backward()
|
221 |
+
optim.step()
|
222 |
+
|
223 |
+
# Update staus
|
224 |
+
losses.append(loss.item())
|
225 |
+
if 'pbar' in ext.keys():
|
226 |
+
ext['pbar'].step()
|
227 |
+
ext['pbar'].print()
|
228 |
+
|
229 |
+
return synthesized, losses
|
230 |
+
|
231 |
+
@staticmethod
|
232 |
+
@torch.no_grad()
|
233 |
+
def match_and_blend(synthesized, targets, criteria, ext):
|
234 |
+
"""
|
235 |
+
Minimizes criteria(synthesized, target)
|
236 |
+
Args:
|
237 |
+
targets (torch.Tensor): Target data.
|
238 |
+
ext (dict): extra configurations.
|
239 |
+
"""
|
240 |
+
losses = []
|
241 |
+
for _i in range(ext['num_itrs']):
|
242 |
+
if 'parts_list' in ext.keys():
|
243 |
+
def extract_part_motions(motion, parts_list):
|
244 |
+
part_motions = []
|
245 |
+
n_frames = motion.shape[-1]
|
246 |
+
rot, pos = motion[:, :-3, :].reshape(-1, 6, n_frames), motion[:, -3:, :]
|
247 |
+
|
248 |
+
for part in parts_list:
|
249 |
+
# part -= 1
|
250 |
+
part = [i -1 for i in part]
|
251 |
+
|
252 |
+
# print(part)
|
253 |
+
if 0 in part:
|
254 |
+
part_motions += [torch.cat([rot[part].view(1, -1, n_frames), pos.view(1, -1, n_frames)], dim=1)]
|
255 |
+
else:
|
256 |
+
part_motions += [rot[part].view(1, -1, n_frames)]
|
257 |
+
|
258 |
+
return part_motions
|
259 |
+
def combine_part_motions(part_motions, parts_list):
|
260 |
+
assert len(part_motions) == len(parts_list)
|
261 |
+
n_frames = part_motions[0].shape[-1]
|
262 |
+
l = max(list(itertools.chain(*parts_list)))
|
263 |
+
# print(l, n_frames)
|
264 |
+
# motion = torch.zeros((1, (l+1)*6 + 3, n_frames), device=part_motions[0].device)
|
265 |
+
rot = torch.zeros(((l+1), 6, n_frames), device=part_motions[0].device)
|
266 |
+
pos = torch.zeros((1, 3, n_frames), device=part_motions[0].device)
|
267 |
+
div_rot = torch.zeros((l+1), device=part_motions[0].device)
|
268 |
+
div_pos = torch.zeros(1, device=part_motions[0].device)
|
269 |
+
|
270 |
+
for part_motion, part in zip(part_motions, parts_list):
|
271 |
+
part = [i -1 for i in part]
|
272 |
+
|
273 |
+
if 0 in part:
|
274 |
+
# print(part_motion.shape)
|
275 |
+
pos += part_motion[:, -3:, :]
|
276 |
+
div_pos += 1
|
277 |
+
rot[part] += part_motion[:, :-3, :].view(-1, 6, n_frames)
|
278 |
+
div_rot[part] += 1
|
279 |
+
else:
|
280 |
+
rot[part] += part_motion.view(-1, 6, n_frames)
|
281 |
+
div_rot[part] += 1
|
282 |
+
|
283 |
+
# print(div_rot, div_pos)
|
284 |
+
# print(rot.shape)
|
285 |
+
rot = (rot.permute(1, 2, 0) / div_rot).permute(2, 0, 1)
|
286 |
+
pos = pos / div_pos
|
287 |
+
|
288 |
+
return torch.cat([rot.view(1, -1, n_frames), pos.view(1, 3, n_frames)], dim=1)
|
289 |
+
|
290 |
+
# raw_synthesized = synthesized
|
291 |
+
# print(synthesized, synthesized.shape)
|
292 |
+
synthesized_part_motions = extract_part_motions(synthesized, ext['parts_list'])
|
293 |
+
targets_part_motions = [extract_part_motions(target, ext['parts_list']) for target in targets]
|
294 |
+
|
295 |
+
synthesized = []
|
296 |
+
for _j in range(len(synthesized_part_motions)):
|
297 |
+
synthesized_part_motion = synthesized_part_motions[_j]
|
298 |
+
# synthesized += [synthesized_part_motion]
|
299 |
+
targets_part_motion = [target[_j] for target in targets_part_motions]
|
300 |
+
# # print(synthesized_part_motion.shape, targets_part_motion[0].shape)
|
301 |
+
synthesized += [criteria(synthesized_part_motion, targets_part_motion, ext=ext, return_blended_results=True)[0]]
|
302 |
+
|
303 |
+
# print(len(synthesized))
|
304 |
+
|
305 |
+
synthesized = combine_part_motions(synthesized, ext['parts_list'])
|
306 |
+
# print(synthesized, synthesized.shape)
|
307 |
+
# print((raw_synthesized-synthesized > 0.00001).sum())
|
308 |
+
# exit()
|
309 |
+
# print(synthesized.shape)
|
310 |
+
losses = 0
|
311 |
+
|
312 |
+
# exit()
|
313 |
+
|
314 |
+
else:
|
315 |
+
synthesized, loss = criteria(synthesized, targets, ext=ext, return_blended_results=True)
|
316 |
+
|
317 |
+
# Update staus
|
318 |
+
losses.append(loss.item())
|
319 |
+
if 'pbar' in ext.keys():
|
320 |
+
ext['pbar'].step()
|
321 |
+
ext['pbar'].print()
|
322 |
+
|
323 |
+
return synthesized, losses
|
324 |
+
|
NN/losses.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from .utils import extract_patches, combine_patches, efficient_cdist, get_NNs_Dists
|
5 |
+
|
6 |
+
def make_criteria(conf):
|
7 |
+
if conf['type'] == 'PatchCoherentLoss':
|
8 |
+
return PatchCoherentLoss(conf['patch_size'], stride=conf['stride'], loop=conf['loop'], coherent_alpha=conf['coherent_alpha'])
|
9 |
+
elif conf['type'] == 'SWDLoss':
|
10 |
+
raise NotImplementedError('SWDLoss is not implemented')
|
11 |
+
else:
|
12 |
+
raise ValueError('Invalid criteria: {}'.format(conf['criteria']))
|
13 |
+
|
14 |
+
class PatchCoherentLoss(torch.nn.Module):
|
15 |
+
def __init__(self, patch_size=7, stride=1, loop=False, coherent_alpha=None, cache=False):
|
16 |
+
super(PatchCoherentLoss, self).__init__()
|
17 |
+
self.patch_size = patch_size
|
18 |
+
self.stride = stride
|
19 |
+
self.loop = loop
|
20 |
+
self.coherent_alpha = coherent_alpha
|
21 |
+
assert self.stride == 1, "Only support stride of 1"
|
22 |
+
# assert self.patch_size % 2 == 1, "Only support odd patch size"
|
23 |
+
self.cache = cache
|
24 |
+
if cache:
|
25 |
+
self.cached_data = None
|
26 |
+
|
27 |
+
def forward(self, X, Ys, dist_wrapper=None, ext=None, return_blended_results=False):
|
28 |
+
"""For each patch in input X find its NN in target Y and sum the their distances"""
|
29 |
+
assert X.shape[0] == 1, "Only support batch size of 1"
|
30 |
+
dist_fn = lambda X, Y: dist_wrapper(efficient_cdist, X, Y) if dist_wrapper is not None else efficient_cdist(X, Y)
|
31 |
+
|
32 |
+
x_patches = extract_patches(X, self.patch_size, self.stride, loop=self.loop)
|
33 |
+
|
34 |
+
if not self.cache or self.cached_data is None:
|
35 |
+
y_patches = []
|
36 |
+
for y in Ys:
|
37 |
+
y_patches += [extract_patches(y, self.patch_size, self.stride, loop=False)]
|
38 |
+
y_patches = torch.cat(y_patches, dim=1)
|
39 |
+
self.cached_data = y_patches
|
40 |
+
else:
|
41 |
+
y_patches = self.cached_data
|
42 |
+
|
43 |
+
nnf, dist = get_NNs_Dists(dist_fn, x_patches.squeeze(0), y_patches.squeeze(0), self.coherent_alpha)
|
44 |
+
|
45 |
+
if return_blended_results:
|
46 |
+
return combine_patches(X.shape, y_patches[:, nnf, :], self.patch_size, self.stride, loop=self.loop), dist.mean()
|
47 |
+
else:
|
48 |
+
return dist.mean()
|
49 |
+
|
50 |
+
def clean_cache(self):
|
51 |
+
self.cached_data = None
|
NN/utils.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import unfoldNd
|
4 |
+
|
5 |
+
def extract_patches(x, patch_size, stride, loop=False):
|
6 |
+
"""Extract patches from a motion sequence"""
|
7 |
+
b, c, _t = x.shape
|
8 |
+
|
9 |
+
# manually padding to loop
|
10 |
+
if loop:
|
11 |
+
half = patch_size // 2
|
12 |
+
front, tail = x[:,:,:half], x[:,:,-half:]
|
13 |
+
x = torch.concat([tail, x, front], dim=-1)
|
14 |
+
|
15 |
+
x_patches = unfoldNd.unfoldNd(x, kernel_size=patch_size, stride=stride).transpose(1, 2).reshape(b, -1, c, patch_size)
|
16 |
+
|
17 |
+
return x_patches.view(b, -1, c * patch_size)
|
18 |
+
|
19 |
+
def combine_patches(x_shape, ys, patch_size, stride, loop=False):
|
20 |
+
"""Combine motion patches"""
|
21 |
+
# manually handle to loop
|
22 |
+
out_shape = [*x_shape]
|
23 |
+
if loop:
|
24 |
+
padding = patch_size // 2
|
25 |
+
out_shape[-1] = out_shape[-1] + padding * 2
|
26 |
+
|
27 |
+
combined = unfoldNd.foldNd(ys.permute(0, 2, 1), output_size=tuple(out_shape[-1:]), kernel_size=patch_size, stride=stride)
|
28 |
+
|
29 |
+
# normal fold matrix
|
30 |
+
input_ones = torch.ones(tuple(out_shape), dtype=ys.dtype, device=ys.device)
|
31 |
+
divisor = unfoldNd.unfoldNd(input_ones, kernel_size=patch_size, stride=stride)
|
32 |
+
divisor = unfoldNd.foldNd(divisor, output_size=tuple(out_shape[-1:]), kernel_size=patch_size, stride=stride)
|
33 |
+
combined = (combined / divisor).squeeze(dim=0).unsqueeze(0)
|
34 |
+
|
35 |
+
if loop:
|
36 |
+
half = patch_size // 2
|
37 |
+
front, tail = combined[:,:,:half], combined[:,:,-half:]
|
38 |
+
combined[:, :, half:2 * half] = (combined[:, :, half:2 * half] + tail) / 2
|
39 |
+
combined[:, :, - 2 * half:-half] = (front + combined[:, :, - 2 * half:-half]) / 2
|
40 |
+
combined = combined[:, :, half:-half]
|
41 |
+
|
42 |
+
return combined
|
43 |
+
|
44 |
+
|
45 |
+
def efficient_cdist(X, Y):
|
46 |
+
"""
|
47 |
+
Pytorch efficient way of computing distances between all vectors in X and Y, i.e (X[:, None] - Y[None, :])**2
|
48 |
+
Get the nearest neighbor index from Y for each X
|
49 |
+
:param X: (n1, d) tensor
|
50 |
+
:param Y: (n2, d) tensor
|
51 |
+
Returns a n2 n1 of indices
|
52 |
+
"""
|
53 |
+
dist = (X * X).sum(1)[:, None] + (Y * Y).sum(1)[None, :] - 2.0 * torch.mm(X, torch.transpose(Y, 0, 1))
|
54 |
+
d = X.shape[1]
|
55 |
+
dist /= d # normalize by size of vector to make dists independent of the size of d ( use same alpha for all patche-sizes)
|
56 |
+
return dist # DO NOT use torch.sqrt
|
57 |
+
|
58 |
+
|
59 |
+
def get_col_mins_efficient(dist_fn, X, Y, b=1024):
|
60 |
+
"""
|
61 |
+
Computes the l2 distance to the closest x or each y.
|
62 |
+
:param X: (n1, d) tensor
|
63 |
+
:param Y: (n2, d) tensor
|
64 |
+
Returns n1 long array of L2 distances
|
65 |
+
"""
|
66 |
+
n_batches = len(Y) // b
|
67 |
+
mins = torch.zeros(Y.shape[0], dtype=X.dtype, device=X.device)
|
68 |
+
for i in range(n_batches):
|
69 |
+
mins[i * b:(i + 1) * b] = dist_fn(X, Y[i * b:(i + 1) * b]).min(0)[0]
|
70 |
+
if len(Y) % b != 0:
|
71 |
+
mins[n_batches * b:] = dist_fn(X, Y[n_batches * b:]).min(0)[0]
|
72 |
+
|
73 |
+
return mins
|
74 |
+
|
75 |
+
|
76 |
+
def get_NNs_Dists(dist_fn, X, Y, alpha=None, b=1024):
|
77 |
+
"""
|
78 |
+
Get the nearest neighbor index from Y for each X.
|
79 |
+
Avoids holding a (n1 * n2) amtrix in order to reducing memory footprint to (b * max(n1,n2)).
|
80 |
+
:param X: (n1, d) tensor
|
81 |
+
:param Y: (n2, d) tensor
|
82 |
+
Returns a n2 n1 of indices amd distances
|
83 |
+
"""
|
84 |
+
if alpha is not None:
|
85 |
+
normalizing_row = get_col_mins_efficient(dist_fn, X, Y, b=b)
|
86 |
+
normalizing_row = alpha + normalizing_row[None, :]
|
87 |
+
else:
|
88 |
+
normalizing_row = 1
|
89 |
+
|
90 |
+
NNs = torch.zeros(X.shape[0], dtype=torch.long, device=X.device)
|
91 |
+
Dists = torch.zeros(X.shape[0], dtype=torch.float, device=X.device)
|
92 |
+
|
93 |
+
n_batches = len(X) // b
|
94 |
+
for i in range(n_batches):
|
95 |
+
dists = dist_fn(X[i * b:(i + 1) * b], Y) / normalizing_row
|
96 |
+
NNs[i * b:(i + 1) * b] = dists.min(1)[1]
|
97 |
+
Dists[i * b:(i + 1) * b] = dists.min(1)[0]
|
98 |
+
if len(X) % b != 0:
|
99 |
+
dists = dist_fn(X[n_batches * b:], Y) / normalizing_row
|
100 |
+
NNs[n_batches * b:] = dists.min(1)[1]
|
101 |
+
Dists[n_batches * b: ] = dists.min(1)[0]
|
102 |
+
|
103 |
+
return NNs, Dists
|
README.md
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
---
|
2 |
-
title: GenMM
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.33.1
|
8 |
app_file: app.py
|
|
|
1 |
---
|
2 |
+
title: GenMM
|
3 |
+
emoji: 🌍
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: red
|
6 |
sdk: gradio
|
7 |
sdk_version: 3.33.1
|
8 |
app_file: app.py
|
app.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import time
|
3 |
+
|
4 |
+
from dataset.tracks_motion import TracksMotion
|
5 |
+
from GPS import GPS
|
6 |
+
import gradio as gr
|
7 |
+
|
8 |
+
def _synthesis(synthesis_setting, motion_data):
|
9 |
+
model = GPS(
|
10 |
+
init_mode = f"random_synthesis/{synthesis_setting['frames']}",
|
11 |
+
noise_sigma = synthesis_setting['noise_sigma'],
|
12 |
+
coarse_ratio = 0.2,
|
13 |
+
pyr_factor = synthesis_setting['pyr_factor'],
|
14 |
+
num_stages_limit = -1,
|
15 |
+
silent=True,
|
16 |
+
device='cpu'
|
17 |
+
)
|
18 |
+
|
19 |
+
synthesized_motion = model.run(
|
20 |
+
motion_data,
|
21 |
+
mode="match_and_blend",
|
22 |
+
ext={
|
23 |
+
'criteria': {
|
24 |
+
'type': 'PatchCoherentLoss',
|
25 |
+
'patch_size': synthesis_setting['patch_size'],
|
26 |
+
'stride': synthesis_setting['stride'] if 'stride' in synthesis_setting.keys() else 1,
|
27 |
+
'loop': synthesis_setting['loop'],
|
28 |
+
'coherent_alpha': synthesis_setting['alpha'] if synthesis_setting['completeness'] else None,
|
29 |
+
},
|
30 |
+
'optimizer': "match_and_blend",
|
31 |
+
'num_itrs': synthesis_setting['num_steps'],
|
32 |
+
}
|
33 |
+
)
|
34 |
+
|
35 |
+
return synthesized_motion
|
36 |
+
|
37 |
+
def synthesis(data):
|
38 |
+
data = json.loads(data)
|
39 |
+
# create track object
|
40 |
+
data['setting']['coarse_ratio'] = -1
|
41 |
+
motion_data = TracksMotion(data['tracks'], scale=data['scale'])
|
42 |
+
start = time.time()
|
43 |
+
synthesized_motion = _synthesis(
|
44 |
+
data['setting'],
|
45 |
+
[motion_data]
|
46 |
+
)
|
47 |
+
end = time.time()
|
48 |
+
data['time'] = end - start
|
49 |
+
data['tracks'] = motion_data.parse(synthesized_motion)
|
50 |
+
|
51 |
+
return data
|
52 |
+
|
53 |
+
demo = gr.Interface(fn=synthesis, inputs="json", outputs="json")
|
54 |
+
demo.launch()
|
configs/random_synthesis.yaml
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
outout_dir: './output/random_synthesis'
|
2 |
+
|
3 |
+
# for GANimator BVH data
|
4 |
+
skeleton_aware: true
|
5 |
+
use_velo: true
|
6 |
+
repr: 'repr6d'
|
7 |
+
contact: true
|
8 |
+
keep_y_pos: true
|
9 |
+
joint_reduction: true
|
10 |
+
|
11 |
+
|
12 |
+
# for synthesis
|
13 |
+
coarse_ratio: -1
|
14 |
+
coarse_ratio_factor: 10
|
15 |
+
pyr_factor: 0.75
|
16 |
+
num_stages_limit: -1
|
17 |
+
noise_sigma: 10.0
|
18 |
+
patch_size: 11
|
19 |
+
loop: false
|
20 |
+
loss_type: 'PatchCoherent'
|
21 |
+
coherent_alpha: 0.01
|
22 |
+
optimizer: 'RMSprop'
|
23 |
+
lr: 0.01
|
24 |
+
num_steps: 3
|
25 |
+
decay_rate: 0.9
|
26 |
+
decay_steps: 0.9
|
27 |
+
|
28 |
+
# for visualization (only for blender render)
|
29 |
+
visualization: true
|
30 |
+
fbx_path: null
|
31 |
+
reso: '[1920, 1080]'
|
32 |
+
samples: 64
|
33 |
+
fps: 30
|
34 |
+
frame_end: -1
|
35 |
+
camera_pos: '[0, -8, 2.5]'
|
36 |
+
target_pos: '[0, 2, 0.5]'
|
dataset/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
dataset/tracks_motion.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from os.path import join as pjoin
|
3 |
+
import numpy as np
|
4 |
+
import copy
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from utils.transforms import quat2repr6d, quat2euler, repr6d2quat
|
8 |
+
|
9 |
+
class TracksParser():
|
10 |
+
def __init__(self, tracks_json, scale=1.0, requires_contact=False, joint_reduction=False):
|
11 |
+
assert requires_contact==False, 'contact is not implemented for tracks data yet!!!'
|
12 |
+
|
13 |
+
self.tracks_json = tracks_json
|
14 |
+
self.scale = scale
|
15 |
+
self.requires_contact = requires_contact
|
16 |
+
self.joint_reduction = joint_reduction
|
17 |
+
|
18 |
+
self.skeleton_names = []
|
19 |
+
self.rotations = []
|
20 |
+
for i, track in enumerate(self.tracks_json):
|
21 |
+
# print(i, track['name'])
|
22 |
+
self.skeleton_names.append(track['name'])
|
23 |
+
if i == 0:
|
24 |
+
assert track['type'] == 'vector'
|
25 |
+
self.position = np.array(track['values']).reshape(-1, 3) * self.scale
|
26 |
+
self.num_frames = self.position.shape[0]
|
27 |
+
else:
|
28 |
+
assert track['type'] == 'quaternion' # DEAFULT: quaternion
|
29 |
+
rotation = np.array(track['values']).reshape(-1, 4)
|
30 |
+
if rotation.shape[0] == 0:
|
31 |
+
rotation = np.zeros((self.num_frames, 4))
|
32 |
+
elif rotation.shape[0] < self.num_frames:
|
33 |
+
rotation = np.repeat(rotation, self.num_frames // rotation.shape[0], axis=0)
|
34 |
+
elif rotation.shape[0] > self.num_frames:
|
35 |
+
rotation = rotation[:self.num_frames]
|
36 |
+
self.rotations += [rotation]
|
37 |
+
self.rotations = np.array(self.rotations, dtype=np.float32)
|
38 |
+
|
39 |
+
def to_tensor(self, repr='euler', rot_only=False):
|
40 |
+
if repr not in ['euler', 'quat', 'quaternion', 'repr6d']:
|
41 |
+
raise Exception('Unknown rotation representation')
|
42 |
+
rotations = self.get_rotation(repr=repr)
|
43 |
+
positions = self.get_position()
|
44 |
+
|
45 |
+
if rot_only:
|
46 |
+
return rotations.reshape(rotations.shape[0], -1)
|
47 |
+
|
48 |
+
if self.requires_contact:
|
49 |
+
virtual_contact = torch.zeros_like(rotations[:, :len(self.skeleton.contact_id)])
|
50 |
+
virtual_contact[..., 0] = self.contact_label
|
51 |
+
rotations = torch.cat([rotations, virtual_contact], dim=1)
|
52 |
+
|
53 |
+
rotations = rotations.reshape(rotations.shape[0], -1)
|
54 |
+
return torch.cat((rotations, positions), dim=-1)
|
55 |
+
|
56 |
+
def get_rotation(self, repr='quat'):
|
57 |
+
if repr == 'quaternion' or repr == 'quat' or repr == 'repr6d':
|
58 |
+
rotations = torch.tensor(self.rotations, dtype=torch.float).transpose(0, 1)
|
59 |
+
if repr == 'repr6d':
|
60 |
+
rotations = quat2repr6d(rotations)
|
61 |
+
if repr == 'euler':
|
62 |
+
rotations = quat2euler(rotations)
|
63 |
+
return rotations
|
64 |
+
|
65 |
+
def get_position(self):
|
66 |
+
return torch.tensor(self.position, dtype=torch.float32)
|
67 |
+
|
68 |
+
class TracksMotion:
|
69 |
+
def __init__(self, tracks_json, scale=1.0, repr='repr6d', padding=False,
|
70 |
+
use_velo=True, contact=False, keep_y_pos=True, joint_reduction=False):
|
71 |
+
self.scale = scale
|
72 |
+
self.tracks = TracksParser(tracks_json, scale, requires_contact=contact, joint_reduction=joint_reduction)
|
73 |
+
self.raw_motion = self.tracks.to_tensor(repr=repr)
|
74 |
+
self.extra = {
|
75 |
+
|
76 |
+
}
|
77 |
+
|
78 |
+
self.repr = repr
|
79 |
+
if repr == 'quat':
|
80 |
+
self.n_rot = 4
|
81 |
+
elif repr == 'repr6d':
|
82 |
+
self.n_rot = 6
|
83 |
+
elif repr == 'euler':
|
84 |
+
self.n_rot = 3
|
85 |
+
self.padding = padding
|
86 |
+
self.use_velo = use_velo
|
87 |
+
self.contact = contact
|
88 |
+
self.keep_y_pos = keep_y_pos
|
89 |
+
self.joint_reduction = joint_reduction
|
90 |
+
|
91 |
+
self.raw_motion = self.raw_motion.permute(1, 0).unsqueeze_(0) # Shape = (1, n_channel, n_frames)
|
92 |
+
self.extra['global_pos'] = self.raw_motion[:, -3:, :]
|
93 |
+
|
94 |
+
if padding:
|
95 |
+
self.n_pad = self.n_rot - 3 # pad position channels
|
96 |
+
paddings = torch.zeros_like(self.raw_motion[:, :self.n_pad])
|
97 |
+
self.raw_motion = torch.cat((self.raw_motion, paddings), dim=1)
|
98 |
+
else:
|
99 |
+
self.n_pad = 0
|
100 |
+
self.raw_motion = torch.cat((self.raw_motion[:, :-3-self.n_pad], self.raw_motion[:, -3-self.n_pad:]), dim=1)
|
101 |
+
|
102 |
+
if self.use_velo:
|
103 |
+
self.msk = [-3, -2, -1] if not keep_y_pos else [-3, -1]
|
104 |
+
self.raw_motion = self.pos2velo(self.raw_motion)
|
105 |
+
|
106 |
+
self.n_contact = len(self.tracks.skeleton.contact_id) if contact else 0
|
107 |
+
|
108 |
+
@property
|
109 |
+
def n_channels(self):
|
110 |
+
return self.raw_motion.shape[1]
|
111 |
+
|
112 |
+
def __len__(self):
|
113 |
+
return self.raw_motion.shape[-1]
|
114 |
+
|
115 |
+
def pos2velo(self, pos):
|
116 |
+
msk = [i - self.n_pad for i in self.msk]
|
117 |
+
velo = pos.detach().clone().to(pos.device)
|
118 |
+
velo[:, msk, 1:] = pos[:, msk, 1:] - pos[:, msk, :-1]
|
119 |
+
self.begin_pos = pos[:, msk, 0].clone()
|
120 |
+
velo[:, msk, 0] = pos[:, msk, 1]
|
121 |
+
return velo
|
122 |
+
|
123 |
+
def velo2pos(self, velo):
|
124 |
+
msk = [i - self.n_pad for i in self.msk]
|
125 |
+
pos = velo.detach().clone().to(velo.device)
|
126 |
+
pos[:, msk, 0] = self.begin_pos.to(velo.device)
|
127 |
+
pos[:, msk] = torch.cumsum(velo[:, msk], dim=-1)
|
128 |
+
return pos
|
129 |
+
|
130 |
+
def motion2pos(self, motion):
|
131 |
+
if not self.use_velo:
|
132 |
+
return motion
|
133 |
+
else:
|
134 |
+
self.velo2pos(motion.clone())
|
135 |
+
|
136 |
+
def sample(self, size=None, slerp=False, align_corners=False):
|
137 |
+
if size is None:
|
138 |
+
return {'motion': self.raw_motion, 'extra': self.extra}
|
139 |
+
else:
|
140 |
+
if slerp:
|
141 |
+
raise NotImplementedError('slerp is not not implemented yet!!!')
|
142 |
+
else:
|
143 |
+
motion = F.interpolate(self.raw_motion, size=size, mode='linear', align_corners=align_corners)
|
144 |
+
extra = {}
|
145 |
+
if 'global_pos' in self.extra.keys():
|
146 |
+
extra['global_pos'] = F.interpolate(self.extra['global_pos'], size=size, mode='linear', align_corners=align_corners)
|
147 |
+
|
148 |
+
return motion
|
149 |
+
# return {'motion': motion, 'extra': extra}
|
150 |
+
|
151 |
+
def parse(self, motion, keep_velo=False,):
|
152 |
+
"""
|
153 |
+
No batch support here!!!
|
154 |
+
:returns tracks_json
|
155 |
+
"""
|
156 |
+
motion = motion.clone()
|
157 |
+
|
158 |
+
if self.use_velo and not keep_velo:
|
159 |
+
motion = self.velo2pos(motion)
|
160 |
+
if self.n_pad:
|
161 |
+
motion = motion[:, :-self.n_pad]
|
162 |
+
if self.contact:
|
163 |
+
raise NotImplementedError('contact is not implemented yet!!!')
|
164 |
+
|
165 |
+
motion = motion.squeeze().permute(1, 0)
|
166 |
+
pos = motion[..., -3:] / self.scale
|
167 |
+
rot = motion[..., :-3].reshape(motion.shape[0], -1, self.n_rot)
|
168 |
+
if self.repr == 'repr6d':
|
169 |
+
rot = repr6d2quat(rot)
|
170 |
+
elif self.repr == 'euler':
|
171 |
+
raise NotImplementedError('parse "euler is not implemented yet!!!')
|
172 |
+
|
173 |
+
times = []
|
174 |
+
out_tracks_json = copy.deepcopy(self.tracks.tracks_json)
|
175 |
+
for i, _track in enumerate(out_tracks_json):
|
176 |
+
if i == 0:
|
177 |
+
times = [ j * out_tracks_json[i]['times'][1] for j in range(motion.shape[0])]
|
178 |
+
out_tracks_json[i]['values'] = pos.flatten().detach().cpu().numpy().tolist()
|
179 |
+
else:
|
180 |
+
out_tracks_json[i]['values'] = rot[:, i-1, :].flatten().detach().cpu().numpy().tolist()
|
181 |
+
out_tracks_json[i]['times'] = times
|
182 |
+
|
183 |
+
return out_tracks_json
|
requirements.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
filterpy==1.4.5
|
2 |
+
torchvision==0.12.0
|
3 |
+
tensorboardX==2.5
|
4 |
+
protobuf==3.20.1
|
5 |
+
scipy==1.7.3
|
6 |
+
tqdm==4.62.3
|
7 |
+
unfoldNd
|
8 |
+
flask==2.1.3
|
9 |
+
flask-cors==3.0.10
|
10 |
+
pyyaml>=5.3.1
|
11 |
+
requests
|
12 |
+
tensorboard
|
13 |
+
transforms3d
|
14 |
+
imageio
|
15 |
+
imageio-ffmpeg
|
utils/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
utils/base.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import os.path as osp
|
3 |
+
import sys
|
4 |
+
import time
|
5 |
+
import yaml
|
6 |
+
import imageio
|
7 |
+
import random
|
8 |
+
import shutil
|
9 |
+
import random
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
from tqdm import tqdm
|
13 |
+
import matplotlib.pyplot as plt
|
14 |
+
|
15 |
+
class ConfigParser():
|
16 |
+
def __init__(self, args):
|
17 |
+
"""
|
18 |
+
class to parse configuration.
|
19 |
+
"""
|
20 |
+
args = args.parse_args()
|
21 |
+
self.cfg = self.merge_config_file(args)
|
22 |
+
|
23 |
+
# set random seed
|
24 |
+
self.set_seed()
|
25 |
+
|
26 |
+
def __str__(self):
|
27 |
+
return str(self.cfg.__dict__)
|
28 |
+
|
29 |
+
def __getattr__(self, name):
|
30 |
+
"""
|
31 |
+
Access items use dot.notation.
|
32 |
+
"""
|
33 |
+
return self.cfg.__dict__[name]
|
34 |
+
|
35 |
+
def __getitem__(self, name):
|
36 |
+
"""
|
37 |
+
Access items like ordinary dict.
|
38 |
+
"""
|
39 |
+
return self.cfg.__dict__[name]
|
40 |
+
|
41 |
+
def merge_config_file(self, args, allow_invalid=True):
|
42 |
+
"""
|
43 |
+
Load json config file and merge the arguments
|
44 |
+
"""
|
45 |
+
assert args.config is not None
|
46 |
+
with open(args.config, 'r') as f:
|
47 |
+
cfg = yaml.safe_load(f)
|
48 |
+
if 'config' in cfg.keys():
|
49 |
+
del cfg['config']
|
50 |
+
f.close()
|
51 |
+
invalid_args = list(set(cfg.keys()) - set(dir(args)))
|
52 |
+
if invalid_args and not allow_invalid:
|
53 |
+
raise ValueError(f"Invalid args {invalid_args} in {args.config}.")
|
54 |
+
|
55 |
+
for k in list(cfg.keys()):
|
56 |
+
if k in args.__dict__.keys() and args.__dict__[k] is not None:
|
57 |
+
print('=========> overwrite config: {} = {}'.format(k, args.__dict__[k]))
|
58 |
+
del cfg[k]
|
59 |
+
|
60 |
+
args.__dict__.update(cfg)
|
61 |
+
|
62 |
+
return args
|
63 |
+
|
64 |
+
def set_seed(self):
|
65 |
+
''' set random seed for random, numpy and torch. '''
|
66 |
+
if 'seed' not in self.cfg.__dict__.keys():
|
67 |
+
return
|
68 |
+
if self.cfg.seed is None:
|
69 |
+
self.cfg.seed = int(time.time()) % 1000000
|
70 |
+
print('=========> set random seed: {}'.format(self.cfg.seed))
|
71 |
+
# fix random seeds for reproducibility
|
72 |
+
random.seed(self.cfg.seed)
|
73 |
+
np.random.seed(self.cfg.seed)
|
74 |
+
torch.manual_seed(self.cfg.seed)
|
75 |
+
torch.cuda.manual_seed(self.cfg.seed)
|
76 |
+
|
77 |
+
def save_codes_and_config(self, save_path):
|
78 |
+
"""
|
79 |
+
save codes and config to $save_path.
|
80 |
+
"""
|
81 |
+
cur_codes_path = osp.dirname(osp.dirname(os.path.abspath(__file__)))
|
82 |
+
if os.path.exists(save_path):
|
83 |
+
shutil.rmtree(save_path)
|
84 |
+
shutil.copytree(cur_codes_path, osp.join(save_path, 'codes'), \
|
85 |
+
ignore=shutil.ignore_patterns('*debug*', '*data*', '*output*', '*exps*', '*.txt', '*.json', '*.mp4', '*.png', '*.jpg', '*.bvh', '*.csv', '*.pth', '*.tar', '*.npz'))
|
86 |
+
|
87 |
+
with open(osp.join(save_path, 'config.yaml'), 'w') as f:
|
88 |
+
f.write(yaml.dump(self.cfg.__dict__))
|
89 |
+
f.close()
|
90 |
+
|
91 |
+
|
92 |
+
# other utils
|
93 |
+
class logger:
|
94 |
+
"""Keeps track of the levels and steps of optimization. Logs it via TQDM"""
|
95 |
+
def __init__(self, n_steps, n_lvls):
|
96 |
+
self.n_steps = n_steps
|
97 |
+
self.n_lvls = n_lvls
|
98 |
+
self.lvl = -1
|
99 |
+
self.lvl_step = 0
|
100 |
+
self.steps = 0
|
101 |
+
self.pbar = tqdm(total=self.n_lvls * self.n_steps, desc='Starting')
|
102 |
+
|
103 |
+
def step(self):
|
104 |
+
self.pbar.update(1)
|
105 |
+
self.steps += 1
|
106 |
+
self.lvl_step += 1
|
107 |
+
|
108 |
+
def new_lvl(self):
|
109 |
+
self.lvl += 1
|
110 |
+
self.lvl_step = 0
|
111 |
+
|
112 |
+
def print(self):
|
113 |
+
self.pbar.set_description(f'Lvl {self.lvl}/{self.n_lvls-1}, step {self.lvl_step}/{self.n_steps}')
|
114 |
+
|
115 |
+
|
116 |
+
def set_seed(seed):
|
117 |
+
if seed is not None:
|
118 |
+
random.seed(seed)
|
119 |
+
np.random.seed(seed)
|
120 |
+
torch.manual_seed(seed)
|
121 |
+
torch.cuda.manual_seed(seed)
|
122 |
+
|
123 |
+
|
124 |
+
# debug utils
|
125 |
+
def draw_trajectory(trajectory, save_path=None, anim=True):
|
126 |
+
r = max(abs(trajectory.min()), trajectory.max())
|
127 |
+
if anim:
|
128 |
+
imgs = []
|
129 |
+
for i in tqdm(range(1, trajectory.shape[0])):
|
130 |
+
plt.plot(trajectory[:i, 0], trajectory[:i, 2], color='red')
|
131 |
+
plt.xlim(-r-1, r+1)
|
132 |
+
plt.ylim(-r-1, r+1)
|
133 |
+
plt.savefig(save_path + '.png')
|
134 |
+
imgs += [imageio.imread(save_path + '.png')]
|
135 |
+
imageio.mimwrite(save_path + '.mp4', imgs)
|
136 |
+
plt.close()
|
137 |
+
else:
|
138 |
+
# plt.scatter(trajectory[:, 0], trajectory[:, 1], trajectory[:, 2])
|
139 |
+
plt.plot(trajectory[:, 0], trajectory[:, 2], color='red')
|
140 |
+
plt.xlim(-r*1.5, r*1.5)
|
141 |
+
plt.ylim(-r*1.5, r*1.5)
|
142 |
+
if save_path is not None:
|
143 |
+
plt.savefig(save_path + '.png')
|
144 |
+
plt.close()
|
145 |
+
|
146 |
+
# velo = self.raw_motion[0, self.mask, :].numpy()
|
147 |
+
# print(velo.shape)
|
148 |
+
# imgs = []
|
utils/contact.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def foot_contact_by_height(pos):
|
5 |
+
eps = 0.25
|
6 |
+
return (-eps < pos[..., 1]) * (pos[..., 1] < eps)
|
7 |
+
|
8 |
+
|
9 |
+
def velocity(pos, padding=False):
|
10 |
+
velo = pos[1:, ...] - pos[:-1, ...]
|
11 |
+
velo_norm = torch.norm(velo, dim=-1)
|
12 |
+
if padding:
|
13 |
+
pad = torch.zeros_like(velo_norm[:1, :])
|
14 |
+
velo_norm = torch.cat([pad, velo_norm], dim=0)
|
15 |
+
return velo_norm
|
16 |
+
|
17 |
+
|
18 |
+
def foot_contact(pos, ref_height=1., threshold=0.018):
|
19 |
+
velo_norm = velocity(pos)
|
20 |
+
contact = velo_norm < threshold
|
21 |
+
contact = contact.int()
|
22 |
+
padding = torch.zeros_like(contact)
|
23 |
+
contact = torch.cat([padding[:1, :], contact], dim=0)
|
24 |
+
return contact
|
25 |
+
|
26 |
+
|
27 |
+
def alpha(t):
|
28 |
+
return 2.0 * t * t * t - 3.0 * t * t + 1
|
29 |
+
|
30 |
+
|
31 |
+
def lerp(a, l, r):
|
32 |
+
return (1 - a) * l + a * r
|
33 |
+
|
34 |
+
|
35 |
+
def constrain_from_contact(contact, glb, fid='TBD', L=5):
|
36 |
+
"""
|
37 |
+
:param contact: contact label
|
38 |
+
:param glb: original global position
|
39 |
+
:param fid: joint id to fix, corresponding to the order in contact
|
40 |
+
:param L: frame to look forward/backward
|
41 |
+
:return:
|
42 |
+
"""
|
43 |
+
T = glb.shape[0]
|
44 |
+
|
45 |
+
for i, fidx in enumerate(fid): # fidx: index of the foot joint
|
46 |
+
fixed = contact[:, i] # [T]
|
47 |
+
s = 0
|
48 |
+
while s < T:
|
49 |
+
while s < T and fixed[s] == 0:
|
50 |
+
s += 1
|
51 |
+
if s >= T:
|
52 |
+
break
|
53 |
+
t = s
|
54 |
+
avg = glb[t, fidx].clone()
|
55 |
+
while t + 1 < T and fixed[t + 1] == 1:
|
56 |
+
t += 1
|
57 |
+
avg += glb[t, fidx].clone()
|
58 |
+
avg /= (t - s + 1)
|
59 |
+
|
60 |
+
for j in range(s, t + 1):
|
61 |
+
glb[j, fidx] = avg.clone()
|
62 |
+
s = t + 1
|
63 |
+
|
64 |
+
for s in range(T):
|
65 |
+
if fixed[s] == 1:
|
66 |
+
continue
|
67 |
+
l, r = None, None
|
68 |
+
consl, consr = False, False
|
69 |
+
for k in range(L):
|
70 |
+
if s - k - 1 < 0:
|
71 |
+
break
|
72 |
+
if fixed[s - k - 1]:
|
73 |
+
l = s - k - 1
|
74 |
+
consl = True
|
75 |
+
break
|
76 |
+
for k in range(L):
|
77 |
+
if s + k + 1 >= T:
|
78 |
+
break
|
79 |
+
if fixed[s + k + 1]:
|
80 |
+
r = s + k + 1
|
81 |
+
consr = True
|
82 |
+
break
|
83 |
+
if not consl and not consr:
|
84 |
+
continue
|
85 |
+
if consl and consr:
|
86 |
+
litp = lerp(alpha(1.0 * (s - l + 1) / (L + 1)),
|
87 |
+
glb[s, fidx], glb[l, fidx])
|
88 |
+
ritp = lerp(alpha(1.0 * (r - s + 1) / (L + 1)),
|
89 |
+
glb[s, fidx], glb[r, fidx])
|
90 |
+
itp = lerp(alpha(1.0 * (s - l + 1) / (r - l + 1)),
|
91 |
+
ritp, litp)
|
92 |
+
glb[s, fidx] = itp.clone()
|
93 |
+
continue
|
94 |
+
if consl:
|
95 |
+
litp = lerp(alpha(1.0 * (s - l + 1) / (L + 1)),
|
96 |
+
glb[s, fidx], glb[l, fidx])
|
97 |
+
glb[s, fidx] = litp.clone()
|
98 |
+
continue
|
99 |
+
if consr:
|
100 |
+
ritp = lerp(alpha(1.0 * (r - s + 1) / (L + 1)),
|
101 |
+
glb[s, fidx], glb[r, fidx])
|
102 |
+
glb[s, fidx] = ritp.clone()
|
103 |
+
return glb
|
utils/kinematics.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from utils.transforms import quat2mat, repr6d2mat, euler2mat
|
3 |
+
|
4 |
+
|
5 |
+
class ForwardKinematics:
|
6 |
+
def __init__(self, parents, offsets=None):
|
7 |
+
self.parents = parents
|
8 |
+
if offsets is not None and len(offsets.shape) == 2:
|
9 |
+
offsets = offsets.unsqueeze(0)
|
10 |
+
self.offsets = offsets
|
11 |
+
|
12 |
+
def forward(self, rots, offsets=None, global_pos=None):
|
13 |
+
"""
|
14 |
+
Forward Kinematics: returns a per-bone transformation
|
15 |
+
@param rots: local joint rotations (batch_size, bone_num, 3, 3)
|
16 |
+
@param offsets: (batch_size, bone_num, 3) or None
|
17 |
+
@param global_pos: global_position: (batch_size, 3) or keep it as in offsets (default)
|
18 |
+
@return: (batch_szie, bone_num, 3, 4)
|
19 |
+
"""
|
20 |
+
rots = rots.clone()
|
21 |
+
if offsets is None:
|
22 |
+
offsets = self.offsets.to(rots.device)
|
23 |
+
if global_pos is None:
|
24 |
+
global_pos = offsets[:, 0]
|
25 |
+
|
26 |
+
pos = torch.zeros((rots.shape[0], rots.shape[1], 3), device=rots.device)
|
27 |
+
rest_pos = torch.zeros_like(pos)
|
28 |
+
res = torch.zeros((rots.shape[0], rots.shape[1], 3, 4), device=rots.device)
|
29 |
+
|
30 |
+
pos[:, 0] = global_pos
|
31 |
+
rest_pos[:, 0] = offsets[:, 0]
|
32 |
+
|
33 |
+
for i, p in enumerate(self.parents):
|
34 |
+
if i != 0:
|
35 |
+
rots[:, i] = torch.matmul(rots[:, p], rots[:, i])
|
36 |
+
pos[:, i] = torch.matmul(rots[:, p], offsets[:, i].unsqueeze(-1)).squeeze(-1) + pos[:, p]
|
37 |
+
rest_pos[:, i] = rest_pos[:, p] + offsets[:, i]
|
38 |
+
|
39 |
+
res[:, i, :3, :3] = rots[:, i]
|
40 |
+
res[:, i, :, 3] = torch.matmul(rots[:, i], -rest_pos[:, i].unsqueeze(-1)).squeeze(-1) + pos[:, i]
|
41 |
+
|
42 |
+
return res
|
43 |
+
|
44 |
+
def accumulate(self, local_rots):
|
45 |
+
"""
|
46 |
+
Get global joint rotation from local rotations
|
47 |
+
@param local_rots: (batch_size, n_bone, 3, 3)
|
48 |
+
@return: global_rotations
|
49 |
+
"""
|
50 |
+
res = torch.empty_like(local_rots)
|
51 |
+
for i, p in enumerate(self.parents):
|
52 |
+
if i == 0:
|
53 |
+
res[:, i] = local_rots[:, i]
|
54 |
+
else:
|
55 |
+
res[:, i] = torch.matmul(res[:, p], local_rots[:, i])
|
56 |
+
return res
|
57 |
+
|
58 |
+
def unaccumulate(self, global_rots):
|
59 |
+
"""
|
60 |
+
Get local joint rotation from global rotations
|
61 |
+
@param global_rots: (batch_size, n_bone, 3, 3)
|
62 |
+
@return: local_rotations
|
63 |
+
"""
|
64 |
+
res = torch.empty_like(global_rots)
|
65 |
+
inv = torch.empty_like(global_rots)
|
66 |
+
|
67 |
+
for i, p in enumerate(self.parents):
|
68 |
+
if i == 0:
|
69 |
+
inv[:, i] = global_rots[:, i].transpose(-2, -1)
|
70 |
+
res[:, i] = global_rots[:, i]
|
71 |
+
continue
|
72 |
+
res[:, i] = torch.matmul(inv[:, p], global_rots[:, i])
|
73 |
+
inv[:, i] = torch.matmul(res[:, i].transpose(-2, -1), inv[:, p])
|
74 |
+
|
75 |
+
return res
|
76 |
+
|
77 |
+
|
78 |
+
class ForwardKinematicsJoint:
|
79 |
+
def __init__(self, parents, offset):
|
80 |
+
self.parents = parents
|
81 |
+
self.offset = offset
|
82 |
+
|
83 |
+
'''
|
84 |
+
rotation should have shape batch_size * Joint_num * (3/4) * Time
|
85 |
+
position should have shape batch_size * 3 * Time
|
86 |
+
offset should have shape batch_size * Joint_num * 3
|
87 |
+
output have shape batch_size * Time * Joint_num * 3
|
88 |
+
'''
|
89 |
+
|
90 |
+
def forward(self, rotation: torch.Tensor, position: torch.Tensor, offset=None,
|
91 |
+
world=True):
|
92 |
+
'''
|
93 |
+
if not quater and rotation.shape[-2] != 3: raise Exception('Unexpected shape of rotation')
|
94 |
+
if quater and rotation.shape[-2] != 4: raise Exception('Unexpected shape of rotation')
|
95 |
+
rotation = rotation.permute(0, 3, 1, 2)
|
96 |
+
position = position.permute(0, 2, 1)
|
97 |
+
'''
|
98 |
+
if rotation.shape[-1] == 6:
|
99 |
+
transform = repr6d2mat(rotation)
|
100 |
+
elif rotation.shape[-1] == 4:
|
101 |
+
norm = torch.norm(rotation, dim=-1, keepdim=True)
|
102 |
+
rotation = rotation / norm
|
103 |
+
transform = quat2mat(rotation)
|
104 |
+
elif rotation.shape[-1] == 3:
|
105 |
+
transform = euler2mat(rotation)
|
106 |
+
else:
|
107 |
+
raise Exception('Only accept quaternion rotation input')
|
108 |
+
result = torch.empty(transform.shape[:-2] + (3,), device=position.device)
|
109 |
+
|
110 |
+
if offset is None:
|
111 |
+
offset = self.offset
|
112 |
+
offset = offset.reshape((-1, 1, offset.shape[-2], offset.shape[-1], 1))
|
113 |
+
|
114 |
+
result[..., 0, :] = position
|
115 |
+
for i, pi in enumerate(self.parents):
|
116 |
+
if pi == -1:
|
117 |
+
assert i == 0
|
118 |
+
continue
|
119 |
+
|
120 |
+
result[..., i, :] = torch.matmul(transform[..., pi, :, :], offset[..., i, :, :]).squeeze()
|
121 |
+
transform[..., i, :, :] = torch.matmul(transform[..., pi, :, :].clone(), transform[..., i, :, :].clone())
|
122 |
+
if world: result[..., i, :] += result[..., pi, :]
|
123 |
+
return result
|
124 |
+
|
125 |
+
|
126 |
+
class InverseKinematicsJoint:
|
127 |
+
def __init__(self, rotations: torch.Tensor, positions: torch.Tensor, offset, parents, constrains):
|
128 |
+
self.rotations = rotations.detach().clone()
|
129 |
+
self.rotations.requires_grad_(True)
|
130 |
+
self.position = positions.detach().clone()
|
131 |
+
self.position.requires_grad_(True)
|
132 |
+
|
133 |
+
self.parents = parents
|
134 |
+
self.offset = offset
|
135 |
+
self.constrains = constrains
|
136 |
+
|
137 |
+
self.optimizer = torch.optim.Adam([self.position, self.rotations], lr=1e-3, betas=(0.9, 0.999))
|
138 |
+
self.criteria = torch.nn.MSELoss()
|
139 |
+
|
140 |
+
self.fk = ForwardKinematicsJoint(parents, offset)
|
141 |
+
|
142 |
+
self.glb = None
|
143 |
+
|
144 |
+
def step(self):
|
145 |
+
self.optimizer.zero_grad()
|
146 |
+
glb = self.fk.forward(self.rotations, self.position)
|
147 |
+
loss = self.criteria(glb, self.constrains)
|
148 |
+
loss.backward()
|
149 |
+
self.optimizer.step()
|
150 |
+
self.glb = glb
|
151 |
+
return loss.item()
|
152 |
+
|
153 |
+
|
154 |
+
class InverseKinematicsJoint2:
|
155 |
+
def __init__(self, rotations: torch.Tensor, positions: torch.Tensor, offset, parents, constrains, cid,
|
156 |
+
lambda_rec_rot=1., lambda_rec_pos=1., use_velo=False):
|
157 |
+
self.use_velo = use_velo
|
158 |
+
self.rotations_ori = rotations.detach().clone()
|
159 |
+
self.rotations = rotations.detach().clone()
|
160 |
+
self.rotations.requires_grad_(True)
|
161 |
+
self.position_ori = positions.detach().clone()
|
162 |
+
self.position = positions.detach().clone()
|
163 |
+
if self.use_velo:
|
164 |
+
self.position[1:] = self.position[1:] - self.position[:-1]
|
165 |
+
self.position.requires_grad_(True)
|
166 |
+
|
167 |
+
self.parents = parents
|
168 |
+
self.offset = offset
|
169 |
+
self.constrains = constrains.detach().clone()
|
170 |
+
self.cid = cid
|
171 |
+
|
172 |
+
self.lambda_rec_rot = lambda_rec_rot
|
173 |
+
self.lambda_rec_pos = lambda_rec_pos
|
174 |
+
|
175 |
+
self.optimizer = torch.optim.Adam([self.position, self.rotations], lr=1e-3, betas=(0.9, 0.999))
|
176 |
+
self.criteria = torch.nn.MSELoss()
|
177 |
+
|
178 |
+
self.fk = ForwardKinematicsJoint(parents, offset)
|
179 |
+
|
180 |
+
self.glb = None
|
181 |
+
|
182 |
+
def step(self):
|
183 |
+
self.optimizer.zero_grad()
|
184 |
+
if self.use_velo:
|
185 |
+
position = torch.cumsum(self.position, dim=0)
|
186 |
+
else:
|
187 |
+
position = self.position
|
188 |
+
glb = self.fk.forward(self.rotations, position)
|
189 |
+
self.constrain_loss = self.criteria(glb[:, self.cid], self.constrains)
|
190 |
+
self.rec_loss_rot = self.criteria(self.rotations, self.rotations_ori)
|
191 |
+
self.rec_loss_pos = self.criteria(self.position, self.position_ori)
|
192 |
+
loss = self.constrain_loss + self.rec_loss_rot * self.lambda_rec_rot + self.rec_loss_pos * self.lambda_rec_pos
|
193 |
+
loss.backward()
|
194 |
+
self.optimizer.step()
|
195 |
+
self.glb = glb
|
196 |
+
return loss.item()
|
197 |
+
|
198 |
+
def get_position(self):
|
199 |
+
if self.use_velo:
|
200 |
+
position = torch.cumsum(self.position.detach(), dim=0)
|
201 |
+
else:
|
202 |
+
position = self.position.detach()
|
203 |
+
return position
|
utils/skeleton.py
ADDED
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import math
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
class SkeletonConv(nn.Module):
|
9 |
+
def __init__(self, neighbour_list, in_channels, out_channels, kernel_size, joint_num, stride=1, padding=0,
|
10 |
+
bias=True, padding_mode='zeros', add_offset=False, in_offset_channel=0):
|
11 |
+
super(SkeletonConv, self).__init__()
|
12 |
+
|
13 |
+
if in_channels % joint_num != 0 or out_channels % joint_num != 0:
|
14 |
+
raise Exception('in/out channels should be divided by joint_num')
|
15 |
+
self.in_channels_per_joint = in_channels // joint_num
|
16 |
+
self.out_channels_per_joint = out_channels // joint_num
|
17 |
+
|
18 |
+
if padding_mode == 'zeros': padding_mode = 'constant'
|
19 |
+
|
20 |
+
self.expanded_neighbour_list = []
|
21 |
+
self.expanded_neighbour_list_offset = []
|
22 |
+
self.neighbour_list = neighbour_list
|
23 |
+
self.add_offset = add_offset
|
24 |
+
self.joint_num = joint_num
|
25 |
+
|
26 |
+
self.stride = stride
|
27 |
+
self.dilation = 1
|
28 |
+
self.groups = 1
|
29 |
+
self.padding = padding
|
30 |
+
self.padding_mode = padding_mode
|
31 |
+
self._padding_repeated_twice = (padding, padding)
|
32 |
+
|
33 |
+
for neighbour in neighbour_list:
|
34 |
+
expanded = []
|
35 |
+
for k in neighbour:
|
36 |
+
for i in range(self.in_channels_per_joint):
|
37 |
+
expanded.append(k * self.in_channels_per_joint + i)
|
38 |
+
self.expanded_neighbour_list.append(expanded)
|
39 |
+
|
40 |
+
if self.add_offset:
|
41 |
+
self.offset_enc = SkeletonLinear(neighbour_list, in_offset_channel * len(neighbour_list), out_channels)
|
42 |
+
|
43 |
+
for neighbour in neighbour_list:
|
44 |
+
expanded = []
|
45 |
+
for k in neighbour:
|
46 |
+
for i in range(add_offset):
|
47 |
+
expanded.append(k * in_offset_channel + i)
|
48 |
+
self.expanded_neighbour_list_offset.append(expanded)
|
49 |
+
|
50 |
+
self.weight = torch.zeros(out_channels, in_channels, kernel_size)
|
51 |
+
if bias:
|
52 |
+
self.bias = torch.zeros(out_channels)
|
53 |
+
else:
|
54 |
+
self.register_parameter('bias', None)
|
55 |
+
|
56 |
+
self.mask = torch.zeros_like(self.weight)
|
57 |
+
for i, neighbour in enumerate(self.expanded_neighbour_list):
|
58 |
+
self.mask[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), neighbour, ...] = 1
|
59 |
+
self.mask = nn.Parameter(self.mask, requires_grad=False)
|
60 |
+
|
61 |
+
self.description = 'SkeletonConv(in_channels_per_armature={}, out_channels_per_armature={}, kernel_size={}, ' \
|
62 |
+
'joint_num={}, stride={}, padding={}, bias={})'.format(
|
63 |
+
in_channels // joint_num, out_channels // joint_num, kernel_size, joint_num, stride, padding, bias
|
64 |
+
)
|
65 |
+
|
66 |
+
self.reset_parameters()
|
67 |
+
|
68 |
+
def reset_parameters(self):
|
69 |
+
for i, neighbour in enumerate(self.expanded_neighbour_list):
|
70 |
+
""" Use temporary variable to avoid assign to copy of slice, which might lead to un expected result """
|
71 |
+
tmp = torch.zeros_like(self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1),
|
72 |
+
neighbour, ...])
|
73 |
+
nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))
|
74 |
+
self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1),
|
75 |
+
neighbour, ...] = tmp
|
76 |
+
if self.bias is not None:
|
77 |
+
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(
|
78 |
+
self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), neighbour, ...])
|
79 |
+
bound = 1 / math.sqrt(fan_in)
|
80 |
+
tmp = torch.zeros_like(
|
81 |
+
self.bias[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1)])
|
82 |
+
nn.init.uniform_(tmp, -bound, bound)
|
83 |
+
self.bias[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1)] = tmp
|
84 |
+
|
85 |
+
self.weight = nn.Parameter(self.weight)
|
86 |
+
if self.bias is not None:
|
87 |
+
self.bias = nn.Parameter(self.bias)
|
88 |
+
|
89 |
+
def set_offset(self, offset):
|
90 |
+
if not self.add_offset: raise Exception('Wrong Combination of Parameters')
|
91 |
+
self.offset = offset.reshape(offset.shape[0], -1)
|
92 |
+
|
93 |
+
def forward(self, input):
|
94 |
+
weight_masked = self.weight * self.mask
|
95 |
+
res = F.conv1d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
|
96 |
+
weight_masked, self.bias, self.stride,
|
97 |
+
0, self.dilation, self.groups)
|
98 |
+
|
99 |
+
if self.add_offset:
|
100 |
+
offset_res = self.offset_enc(self.offset)
|
101 |
+
offset_res = offset_res.reshape(offset_res.shape + (1, ))
|
102 |
+
res += offset_res / 100
|
103 |
+
return res
|
104 |
+
|
105 |
+
def __repr__(self):
|
106 |
+
return self.description
|
107 |
+
|
108 |
+
|
109 |
+
class SkeletonLinear(nn.Module):
|
110 |
+
def __init__(self, neighbour_list, in_channels, out_channels, extra_dim1=False):
|
111 |
+
super(SkeletonLinear, self).__init__()
|
112 |
+
self.neighbour_list = neighbour_list
|
113 |
+
self.in_channels = in_channels
|
114 |
+
self.out_channels = out_channels
|
115 |
+
self.in_channels_per_joint = in_channels // len(neighbour_list)
|
116 |
+
self.out_channels_per_joint = out_channels // len(neighbour_list)
|
117 |
+
self.extra_dim1 = extra_dim1
|
118 |
+
self.expanded_neighbour_list = []
|
119 |
+
|
120 |
+
for neighbour in neighbour_list:
|
121 |
+
expanded = []
|
122 |
+
for k in neighbour:
|
123 |
+
for i in range(self.in_channels_per_joint):
|
124 |
+
expanded.append(k * self.in_channels_per_joint + i)
|
125 |
+
self.expanded_neighbour_list.append(expanded)
|
126 |
+
|
127 |
+
self.weight = torch.zeros(out_channels, in_channels)
|
128 |
+
self.mask = torch.zeros(out_channels, in_channels)
|
129 |
+
self.bias = nn.Parameter(torch.Tensor(out_channels))
|
130 |
+
|
131 |
+
self.reset_parameters()
|
132 |
+
|
133 |
+
def reset_parameters(self):
|
134 |
+
for i, neighbour in enumerate(self.expanded_neighbour_list):
|
135 |
+
tmp = torch.zeros_like(
|
136 |
+
self.weight[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour]
|
137 |
+
)
|
138 |
+
self.mask[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour] = 1
|
139 |
+
nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))
|
140 |
+
self.weight[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour] = tmp
|
141 |
+
|
142 |
+
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
|
143 |
+
bound = 1 / math.sqrt(fan_in)
|
144 |
+
nn.init.uniform_(self.bias, -bound, bound)
|
145 |
+
|
146 |
+
self.weight = nn.Parameter(self.weight)
|
147 |
+
self.mask = nn.Parameter(self.mask, requires_grad=False)
|
148 |
+
|
149 |
+
def forward(self, input):
|
150 |
+
input = input.reshape(input.shape[0], -1)
|
151 |
+
weight_masked = self.weight * self.mask
|
152 |
+
res = F.linear(input, weight_masked, self.bias)
|
153 |
+
if self.extra_dim1: res = res.reshape(res.shape + (1,))
|
154 |
+
return res
|
155 |
+
|
156 |
+
|
157 |
+
class SkeletonPoolJoint(nn.Module):
|
158 |
+
def __init__(self, topology, pooling_mode, channels_per_joint, last_pool=False):
|
159 |
+
super(SkeletonPoolJoint, self).__init__()
|
160 |
+
|
161 |
+
if pooling_mode != 'mean':
|
162 |
+
raise Exception('Unimplemented pooling mode in matrix_implementation')
|
163 |
+
|
164 |
+
self.joint_num = len(topology)
|
165 |
+
self.parent = topology
|
166 |
+
self.pooling_list = []
|
167 |
+
self.pooling_mode = pooling_mode
|
168 |
+
|
169 |
+
self.pooling_map = [-1 for _ in range(len(self.parent))]
|
170 |
+
self.child = [-1 for _ in range(len(self.parent))]
|
171 |
+
children_cnt = [0 for _ in range(len(self.parent))]
|
172 |
+
for x, pa in enumerate(self.parent):
|
173 |
+
if pa < 0: continue
|
174 |
+
children_cnt[pa] += 1
|
175 |
+
self.child[pa] = x
|
176 |
+
self.pooling_map[0] = 0
|
177 |
+
for x in range(len(self.parent)):
|
178 |
+
if children_cnt[x] == 0 or (children_cnt[x] == 1 and children_cnt[self.child[x]] > 1):
|
179 |
+
while children_cnt[x] <= 1:
|
180 |
+
pa = self.parent[x]
|
181 |
+
if last_pool:
|
182 |
+
seq = [x]
|
183 |
+
while pa != -1 and children_cnt[pa] == 1:
|
184 |
+
seq = [pa] + seq
|
185 |
+
x = pa
|
186 |
+
pa = self.parent[x]
|
187 |
+
self.pooling_list.append(seq)
|
188 |
+
break
|
189 |
+
else:
|
190 |
+
if pa != -1 and children_cnt[pa] == 1:
|
191 |
+
self.pooling_list.append([pa, x])
|
192 |
+
x = self.parent[pa]
|
193 |
+
else:
|
194 |
+
self.pooling_list.append([x, ])
|
195 |
+
break
|
196 |
+
elif children_cnt[x] > 1:
|
197 |
+
self.pooling_list.append([x, ])
|
198 |
+
|
199 |
+
self.description = 'SkeletonPool(in_joint_num={}, out_joint_num={})'.format(
|
200 |
+
len(topology), len(self.pooling_list),
|
201 |
+
)
|
202 |
+
|
203 |
+
self.pooling_list.sort(key=lambda x:x[0])
|
204 |
+
for i, a in enumerate(self.pooling_list):
|
205 |
+
for j in a:
|
206 |
+
self.pooling_map[j] = i
|
207 |
+
|
208 |
+
self.output_joint_num = len(self.pooling_list)
|
209 |
+
self.new_topology = [-1 for _ in range(len(self.pooling_list))]
|
210 |
+
for i, x in enumerate(self.pooling_list):
|
211 |
+
if i < 1: continue
|
212 |
+
self.new_topology[i] = self.pooling_map[self.parent[x[0]]]
|
213 |
+
|
214 |
+
self.weight = torch.zeros(len(self.pooling_list) * channels_per_joint, self.joint_num * channels_per_joint)
|
215 |
+
|
216 |
+
for i, pair in enumerate(self.pooling_list):
|
217 |
+
for j in pair:
|
218 |
+
for c in range(channels_per_joint):
|
219 |
+
self.weight[i * channels_per_joint + c, j * channels_per_joint + c] = 1.0 / len(pair)
|
220 |
+
|
221 |
+
self.weight = nn.Parameter(self.weight, requires_grad=False)
|
222 |
+
|
223 |
+
def forward(self, input: torch.Tensor):
|
224 |
+
return torch.matmul(self.weight, input.unsqueeze(-1)).squeeze(-1)
|
225 |
+
|
226 |
+
|
227 |
+
class SkeletonPool(nn.Module):
|
228 |
+
def __init__(self, edges, pooling_mode, channels_per_edge, last_pool=False):
|
229 |
+
super(SkeletonPool, self).__init__()
|
230 |
+
|
231 |
+
if pooling_mode != 'mean':
|
232 |
+
raise Exception('Unimplemented pooling mode in matrix_implementation')
|
233 |
+
|
234 |
+
self.channels_per_edge = channels_per_edge
|
235 |
+
self.pooling_mode = pooling_mode
|
236 |
+
self.edge_num = len(edges) + 1
|
237 |
+
self.seq_list = []
|
238 |
+
self.pooling_list = []
|
239 |
+
self.new_edges = []
|
240 |
+
degree = [0] * 100
|
241 |
+
|
242 |
+
for edge in edges:
|
243 |
+
degree[edge[0]] += 1
|
244 |
+
degree[edge[1]] += 1
|
245 |
+
|
246 |
+
def find_seq(j, seq):
|
247 |
+
nonlocal self, degree, edges
|
248 |
+
|
249 |
+
if degree[j] > 2 and j != 0:
|
250 |
+
self.seq_list.append(seq)
|
251 |
+
seq = []
|
252 |
+
|
253 |
+
if degree[j] == 1:
|
254 |
+
self.seq_list.append(seq)
|
255 |
+
return
|
256 |
+
|
257 |
+
for idx, edge in enumerate(edges):
|
258 |
+
if edge[0] == j:
|
259 |
+
find_seq(edge[1], seq + [idx])
|
260 |
+
|
261 |
+
find_seq(0, [])
|
262 |
+
for seq in self.seq_list:
|
263 |
+
if last_pool:
|
264 |
+
self.pooling_list.append(seq)
|
265 |
+
continue
|
266 |
+
if len(seq) % 2 == 1:
|
267 |
+
self.pooling_list.append([seq[0]])
|
268 |
+
self.new_edges.append(edges[seq[0]])
|
269 |
+
seq = seq[1:]
|
270 |
+
for i in range(0, len(seq), 2):
|
271 |
+
self.pooling_list.append([seq[i], seq[i + 1]])
|
272 |
+
self.new_edges.append([edges[seq[i]][0], edges[seq[i + 1]][1]])
|
273 |
+
|
274 |
+
# add global position
|
275 |
+
self.pooling_list.append([self.edge_num - 1])
|
276 |
+
|
277 |
+
self.description = 'SkeletonPool(in_edge_num={}, out_edge_num={})'.format(
|
278 |
+
len(edges), len(self.pooling_list)
|
279 |
+
)
|
280 |
+
|
281 |
+
self.weight = torch.zeros(len(self.pooling_list) * channels_per_edge, self.edge_num * channels_per_edge)
|
282 |
+
|
283 |
+
for i, pair in enumerate(self.pooling_list):
|
284 |
+
for j in pair:
|
285 |
+
for c in range(channels_per_edge):
|
286 |
+
self.weight[i * channels_per_edge + c, j * channels_per_edge + c] = 1.0 / len(pair)
|
287 |
+
|
288 |
+
self.weight = nn.Parameter(self.weight, requires_grad=False)
|
289 |
+
|
290 |
+
def forward(self, input: torch.Tensor):
|
291 |
+
return torch.matmul(self.weight, input)
|
292 |
+
|
293 |
+
|
294 |
+
class SkeletonUnpool(nn.Module):
|
295 |
+
def __init__(self, pooling_list, channels_per_edge):
|
296 |
+
super(SkeletonUnpool, self).__init__()
|
297 |
+
self.pooling_list = pooling_list
|
298 |
+
self.input_joint_num = len(pooling_list)
|
299 |
+
self.output_joint_num = 0
|
300 |
+
self.channels_per_edge = channels_per_edge
|
301 |
+
for t in self.pooling_list:
|
302 |
+
self.output_joint_num += len(t)
|
303 |
+
|
304 |
+
self.description = 'SkeletonUnpool(in_joint_num={}, out_joint_num={})'.format(
|
305 |
+
self.input_joint_num, self.output_joint_num,
|
306 |
+
)
|
307 |
+
|
308 |
+
self.weight = torch.zeros(self.output_joint_num * channels_per_edge, self.input_joint_num * channels_per_edge)
|
309 |
+
|
310 |
+
for i, pair in enumerate(self.pooling_list):
|
311 |
+
for j in pair:
|
312 |
+
for c in range(channels_per_edge):
|
313 |
+
self.weight[j * channels_per_edge + c, i * channels_per_edge + c] = 1
|
314 |
+
|
315 |
+
self.weight = nn.Parameter(self.weight)
|
316 |
+
self.weight.requires_grad_(False)
|
317 |
+
|
318 |
+
def forward(self, input: torch.Tensor):
|
319 |
+
return torch.matmul(self.weight, input.unsqueeze(-1)).squeeze(-1)
|
320 |
+
|
321 |
+
|
322 |
+
def find_neighbor_joint(parents, threshold):
|
323 |
+
n_joint = len(parents)
|
324 |
+
dist_mat = np.empty((n_joint, n_joint), dtype=np.int)
|
325 |
+
dist_mat[:, :] = 100000
|
326 |
+
for i, p in enumerate(parents):
|
327 |
+
dist_mat[i, i] = 0
|
328 |
+
if i != 0:
|
329 |
+
dist_mat[i, p] = dist_mat[p, i] = 1
|
330 |
+
|
331 |
+
"""
|
332 |
+
Floyd's algorithm
|
333 |
+
"""
|
334 |
+
for k in range(n_joint):
|
335 |
+
for i in range(n_joint):
|
336 |
+
for j in range(n_joint):
|
337 |
+
dist_mat[i, j] = min(dist_mat[i, j], dist_mat[i, k] + dist_mat[k, j])
|
338 |
+
|
339 |
+
neighbor_list = []
|
340 |
+
for i in range(n_joint):
|
341 |
+
neighbor = []
|
342 |
+
for j in range(n_joint):
|
343 |
+
if dist_mat[i, j] <= threshold:
|
344 |
+
neighbor.append(j)
|
345 |
+
neighbor_list.append(neighbor)
|
346 |
+
|
347 |
+
return neighbor_list
|
utils/transforms.py
ADDED
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
def batch_mm(matrix, matrix_batch):
|
6 |
+
"""
|
7 |
+
https://github.com/pytorch/pytorch/issues/14489#issuecomment-607730242
|
8 |
+
:param matrix: Sparse or dense matrix, size (m, n).
|
9 |
+
:param matrix_batch: Batched dense matrices, size (b, n, k).
|
10 |
+
:return: The batched matrix-matrix product, size (m, n) x (b, n, k) = (b, m, k).
|
11 |
+
"""
|
12 |
+
batch_size = matrix_batch.shape[0]
|
13 |
+
# Stack the vector batch into columns. (b, n, k) -> (n, b, k) -> (n, b*k)
|
14 |
+
vectors = matrix_batch.transpose(0, 1).reshape(matrix.shape[1], -1)
|
15 |
+
|
16 |
+
# A matrix-matrix product is a batched matrix-vector product of the columns.
|
17 |
+
# And then reverse the reshaping. (m, n) x (n, b*k) = (m, b*k) -> (m, b, k) -> (b, m, k)
|
18 |
+
return matrix.mm(vectors).reshape(matrix.shape[0], batch_size, -1).transpose(1, 0)
|
19 |
+
|
20 |
+
|
21 |
+
def aa2quat(rots, form='wxyz', unified_orient=True):
|
22 |
+
"""
|
23 |
+
Convert angle-axis representation to wxyz quaternion and to the half plan (w >= 0)
|
24 |
+
@param rots: angle-axis rotations, (*, 3)
|
25 |
+
@param form: quaternion format, either 'wxyz' or 'xyzw'
|
26 |
+
@param unified_orient: Use unified orientation for quaternion (quaternion is dual cover of SO3)
|
27 |
+
:return:
|
28 |
+
"""
|
29 |
+
angles = rots.norm(dim=-1, keepdim=True)
|
30 |
+
norm = angles.clone()
|
31 |
+
norm[norm < 1e-8] = 1
|
32 |
+
axis = rots / norm
|
33 |
+
quats = torch.empty(rots.shape[:-1] + (4,), device=rots.device, dtype=rots.dtype)
|
34 |
+
angles = angles * 0.5
|
35 |
+
if form == 'wxyz':
|
36 |
+
quats[..., 0] = torch.cos(angles.squeeze(-1))
|
37 |
+
quats[..., 1:] = torch.sin(angles) * axis
|
38 |
+
elif form == 'xyzw':
|
39 |
+
quats[..., :3] = torch.sin(angles) * axis
|
40 |
+
quats[..., 3] = torch.cos(angles.squeeze(-1))
|
41 |
+
|
42 |
+
if unified_orient:
|
43 |
+
idx = quats[..., 0] < 0
|
44 |
+
quats[idx, :] *= -1
|
45 |
+
|
46 |
+
return quats
|
47 |
+
|
48 |
+
|
49 |
+
def quat2aa(quats):
|
50 |
+
"""
|
51 |
+
Convert wxyz quaternions to angle-axis representation
|
52 |
+
:param quats:
|
53 |
+
:return:
|
54 |
+
"""
|
55 |
+
_cos = quats[..., 0]
|
56 |
+
xyz = quats[..., 1:]
|
57 |
+
_sin = xyz.norm(dim=-1)
|
58 |
+
norm = _sin.clone()
|
59 |
+
norm[norm < 1e-7] = 1
|
60 |
+
axis = xyz / norm.unsqueeze(-1)
|
61 |
+
angle = torch.atan2(_sin, _cos) * 2
|
62 |
+
return axis * angle.unsqueeze(-1)
|
63 |
+
|
64 |
+
|
65 |
+
def quat2mat(quats: torch.Tensor):
|
66 |
+
"""
|
67 |
+
Convert (w, x, y, z) quaternions to 3x3 rotation matrix
|
68 |
+
:param quats: quaternions of shape (..., 4)
|
69 |
+
:return: rotation matrices of shape (..., 3, 3)
|
70 |
+
"""
|
71 |
+
qw = quats[..., 0]
|
72 |
+
qx = quats[..., 1]
|
73 |
+
qy = quats[..., 2]
|
74 |
+
qz = quats[..., 3]
|
75 |
+
|
76 |
+
x2 = qx + qx
|
77 |
+
y2 = qy + qy
|
78 |
+
z2 = qz + qz
|
79 |
+
xx = qx * x2
|
80 |
+
yy = qy * y2
|
81 |
+
wx = qw * x2
|
82 |
+
xy = qx * y2
|
83 |
+
yz = qy * z2
|
84 |
+
wy = qw * y2
|
85 |
+
xz = qx * z2
|
86 |
+
zz = qz * z2
|
87 |
+
wz = qw * z2
|
88 |
+
|
89 |
+
m = torch.empty(quats.shape[:-1] + (3, 3), device=quats.device, dtype=quats.dtype)
|
90 |
+
m[..., 0, 0] = 1.0 - (yy + zz)
|
91 |
+
m[..., 0, 1] = xy - wz
|
92 |
+
m[..., 0, 2] = xz + wy
|
93 |
+
m[..., 1, 0] = xy + wz
|
94 |
+
m[..., 1, 1] = 1.0 - (xx + zz)
|
95 |
+
m[..., 1, 2] = yz - wx
|
96 |
+
m[..., 2, 0] = xz - wy
|
97 |
+
m[..., 2, 1] = yz + wx
|
98 |
+
m[..., 2, 2] = 1.0 - (xx + yy)
|
99 |
+
|
100 |
+
return m
|
101 |
+
|
102 |
+
|
103 |
+
def quat2euler(q, order='xyz', degrees=True):
|
104 |
+
"""
|
105 |
+
Convert (w, x, y, z) quaternions to xyz euler angles. This is used for bvh output.
|
106 |
+
"""
|
107 |
+
q0 = q[..., 0]
|
108 |
+
q1 = q[..., 1]
|
109 |
+
q2 = q[..., 2]
|
110 |
+
q3 = q[..., 3]
|
111 |
+
es = torch.empty(q0.shape + (3,), device=q.device, dtype=q.dtype)
|
112 |
+
|
113 |
+
if order == 'xyz':
|
114 |
+
es[..., 2] = torch.atan2(2 * (q0 * q3 - q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
|
115 |
+
es[..., 1] = torch.asin((2 * (q1 * q3 + q0 * q2)).clip(-1, 1))
|
116 |
+
es[..., 0] = torch.atan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
|
117 |
+
else:
|
118 |
+
raise NotImplementedError('Cannot convert to ordering %s' % order)
|
119 |
+
|
120 |
+
if degrees:
|
121 |
+
es = es * 180 / np.pi
|
122 |
+
|
123 |
+
return es
|
124 |
+
|
125 |
+
|
126 |
+
def euler2mat(rots, order='xyz'):
|
127 |
+
axis = {'x': torch.tensor((1, 0, 0), device=rots.device),
|
128 |
+
'y': torch.tensor((0, 1, 0), device=rots.device),
|
129 |
+
'z': torch.tensor((0, 0, 1), device=rots.device)}
|
130 |
+
|
131 |
+
rots = rots / 180 * np.pi
|
132 |
+
mats = []
|
133 |
+
for i in range(3):
|
134 |
+
aa = axis[order[i]] * rots[..., i].unsqueeze(-1)
|
135 |
+
mats.append(aa2mat(aa))
|
136 |
+
return mats[0] @ (mats[1] @ mats[2])
|
137 |
+
|
138 |
+
|
139 |
+
def aa2mat(rots):
|
140 |
+
"""
|
141 |
+
Convert angle-axis representation to rotation matrix
|
142 |
+
:param rots: angle-axis representation
|
143 |
+
:return:
|
144 |
+
"""
|
145 |
+
quat = aa2quat(rots)
|
146 |
+
mat = quat2mat(quat)
|
147 |
+
return mat
|
148 |
+
|
149 |
+
|
150 |
+
def mat2quat(R) -> torch.Tensor:
|
151 |
+
'''
|
152 |
+
https://github.com/duolu/pyrotation/blob/master/pyrotation/pyrotation.py
|
153 |
+
Convert a rotation matrix to a unit quaternion.
|
154 |
+
|
155 |
+
This uses the Shepperd’s method for numerical stability.
|
156 |
+
'''
|
157 |
+
|
158 |
+
# The rotation matrix must be orthonormal
|
159 |
+
|
160 |
+
w2 = (1 + R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2])
|
161 |
+
x2 = (1 + R[..., 0, 0] - R[..., 1, 1] - R[..., 2, 2])
|
162 |
+
y2 = (1 - R[..., 0, 0] + R[..., 1, 1] - R[..., 2, 2])
|
163 |
+
z2 = (1 - R[..., 0, 0] - R[..., 1, 1] + R[..., 2, 2])
|
164 |
+
|
165 |
+
yz = (R[..., 1, 2] + R[..., 2, 1])
|
166 |
+
xz = (R[..., 2, 0] + R[..., 0, 2])
|
167 |
+
xy = (R[..., 0, 1] + R[..., 1, 0])
|
168 |
+
|
169 |
+
wx = (R[..., 2, 1] - R[..., 1, 2])
|
170 |
+
wy = (R[..., 0, 2] - R[..., 2, 0])
|
171 |
+
wz = (R[..., 1, 0] - R[..., 0, 1])
|
172 |
+
|
173 |
+
w = torch.empty_like(x2)
|
174 |
+
x = torch.empty_like(x2)
|
175 |
+
y = torch.empty_like(x2)
|
176 |
+
z = torch.empty_like(x2)
|
177 |
+
|
178 |
+
flagA = (R[..., 2, 2] < 0) * (R[..., 0, 0] > R[..., 1, 1])
|
179 |
+
flagB = (R[..., 2, 2] < 0) * (R[..., 0, 0] <= R[..., 1, 1])
|
180 |
+
flagC = (R[..., 2, 2] >= 0) * (R[..., 0, 0] < -R[..., 1, 1])
|
181 |
+
flagD = (R[..., 2, 2] >= 0) * (R[..., 0, 0] >= -R[..., 1, 1])
|
182 |
+
|
183 |
+
x[flagA] = torch.sqrt(x2[flagA])
|
184 |
+
w[flagA] = wx[flagA] / x[flagA]
|
185 |
+
y[flagA] = xy[flagA] / x[flagA]
|
186 |
+
z[flagA] = xz[flagA] / x[flagA]
|
187 |
+
|
188 |
+
y[flagB] = torch.sqrt(y2[flagB])
|
189 |
+
w[flagB] = wy[flagB] / y[flagB]
|
190 |
+
x[flagB] = xy[flagB] / y[flagB]
|
191 |
+
z[flagB] = yz[flagB] / y[flagB]
|
192 |
+
|
193 |
+
z[flagC] = torch.sqrt(z2[flagC])
|
194 |
+
w[flagC] = wz[flagC] / z[flagC]
|
195 |
+
x[flagC] = xz[flagC] / z[flagC]
|
196 |
+
y[flagC] = yz[flagC] / z[flagC]
|
197 |
+
|
198 |
+
w[flagD] = torch.sqrt(w2[flagD])
|
199 |
+
x[flagD] = wx[flagD] / w[flagD]
|
200 |
+
y[flagD] = wy[flagD] / w[flagD]
|
201 |
+
z[flagD] = wz[flagD] / w[flagD]
|
202 |
+
|
203 |
+
# if R[..., 2, 2] < 0:
|
204 |
+
#
|
205 |
+
# if R[..., 0, 0] > R[..., 1, 1]:
|
206 |
+
#
|
207 |
+
# x = torch.sqrt(x2)
|
208 |
+
# w = wx / x
|
209 |
+
# y = xy / x
|
210 |
+
# z = xz / x
|
211 |
+
#
|
212 |
+
# else:
|
213 |
+
#
|
214 |
+
# y = torch.sqrt(y2)
|
215 |
+
# w = wy / y
|
216 |
+
# x = xy / y
|
217 |
+
# z = yz / y
|
218 |
+
#
|
219 |
+
# else:
|
220 |
+
#
|
221 |
+
# if R[..., 0, 0] < -R[..., 1, 1]:
|
222 |
+
#
|
223 |
+
# z = torch.sqrt(z2)
|
224 |
+
# w = wz / z
|
225 |
+
# x = xz / z
|
226 |
+
# y = yz / z
|
227 |
+
#
|
228 |
+
# else:
|
229 |
+
#
|
230 |
+
# w = torch.sqrt(w2)
|
231 |
+
# x = wx / w
|
232 |
+
# y = wy / w
|
233 |
+
# z = wz / w
|
234 |
+
|
235 |
+
res = [w, x, y, z]
|
236 |
+
res = [z.unsqueeze(-1) for z in res]
|
237 |
+
|
238 |
+
return torch.cat(res, dim=-1) / 2
|
239 |
+
|
240 |
+
|
241 |
+
def quat2repr6d(quat):
|
242 |
+
mat = quat2mat(quat)
|
243 |
+
res = mat[..., :2, :]
|
244 |
+
res = res.reshape(res.shape[:-2] + (6, ))
|
245 |
+
return res
|
246 |
+
|
247 |
+
|
248 |
+
def repr6d2mat(repr):
|
249 |
+
x = repr[..., :3]
|
250 |
+
y = repr[..., 3:]
|
251 |
+
x = x / x.norm(dim=-1, keepdim=True)
|
252 |
+
z = torch.cross(x, y)
|
253 |
+
z = z / z.norm(dim=-1, keepdim=True)
|
254 |
+
y = torch.cross(z, x)
|
255 |
+
res = [x, y, z]
|
256 |
+
res = [v.unsqueeze(-2) for v in res]
|
257 |
+
mat = torch.cat(res, dim=-2)
|
258 |
+
return mat
|
259 |
+
|
260 |
+
|
261 |
+
def repr6d2quat(repr) -> torch.Tensor:
|
262 |
+
x = repr[..., :3]
|
263 |
+
y = repr[..., 3:]
|
264 |
+
x = x / x.norm(dim=-1, keepdim=True)
|
265 |
+
z = torch.cross(x, y)
|
266 |
+
z = z / z.norm(dim=-1, keepdim=True)
|
267 |
+
y = torch.cross(z, x)
|
268 |
+
res = [x, y, z]
|
269 |
+
res = [v.unsqueeze(-2) for v in res]
|
270 |
+
mat = torch.cat(res, dim=-2)
|
271 |
+
return mat2quat(mat)
|
272 |
+
|
273 |
+
|
274 |
+
def inv_affine(mat):
|
275 |
+
"""
|
276 |
+
Calculate the inverse of any affine transformation
|
277 |
+
"""
|
278 |
+
affine = torch.zeros((mat.shape[:2] + (1, 4)))
|
279 |
+
affine[..., 3] = 1
|
280 |
+
vert_mat = torch.cat((mat, affine), dim=2)
|
281 |
+
vert_mat_inv = torch.inverse(vert_mat)
|
282 |
+
return vert_mat_inv[..., :3, :]
|
283 |
+
|
284 |
+
|
285 |
+
def inv_rigid_affine(mat):
|
286 |
+
"""
|
287 |
+
Calculate the inverse of a rigid affine transformation
|
288 |
+
"""
|
289 |
+
res = mat.clone()
|
290 |
+
res[..., :3] = mat[..., :3].transpose(-2, -1)
|
291 |
+
res[..., 3] = -torch.matmul(res[..., :3], mat[..., 3].unsqueeze(-1)).squeeze(-1)
|
292 |
+
return res
|
293 |
+
|
294 |
+
|
295 |
+
def generate_pose(batch_size, device, uniform=False, factor=1, root_rot=False, n_bone=None, ee=None):
|
296 |
+
if n_bone is None: n_bone = 24
|
297 |
+
if ee is not None:
|
298 |
+
if root_rot:
|
299 |
+
ee.append(0)
|
300 |
+
n_bone_ = n_bone
|
301 |
+
n_bone = len(ee)
|
302 |
+
axis = torch.randn((batch_size, n_bone, 3), device=device)
|
303 |
+
axis /= axis.norm(dim=-1, keepdim=True)
|
304 |
+
if uniform:
|
305 |
+
angle = torch.rand((batch_size, n_bone, 1), device=device) * np.pi
|
306 |
+
else:
|
307 |
+
angle = torch.randn((batch_size, n_bone, 1), device=device) * np.pi / 6 * factor
|
308 |
+
angle.clamp(-np.pi, np.pi)
|
309 |
+
poses = axis * angle
|
310 |
+
if ee is not None:
|
311 |
+
res = torch.zeros((batch_size, n_bone_, 3), device=device)
|
312 |
+
for i, id in enumerate(ee):
|
313 |
+
res[:, id] = poses[:, i]
|
314 |
+
poses = res
|
315 |
+
poses = poses.reshape(batch_size, -1)
|
316 |
+
if not root_rot:
|
317 |
+
poses[..., :3] = 0
|
318 |
+
return poses
|
319 |
+
|
320 |
+
|
321 |
+
def slerp(l, r, t, unit=True):
|
322 |
+
"""
|
323 |
+
:param l: shape = (*, n)
|
324 |
+
:param r: shape = (*, n)
|
325 |
+
:param t: shape = (*)
|
326 |
+
:param unit: If l and h are unit vectors
|
327 |
+
:return:
|
328 |
+
"""
|
329 |
+
eps = 1e-8
|
330 |
+
if not unit:
|
331 |
+
l_n = l / torch.norm(l, dim=-1, keepdim=True)
|
332 |
+
r_n = r / torch.norm(r, dim=-1, keepdim=True)
|
333 |
+
else:
|
334 |
+
l_n = l
|
335 |
+
r_n = r
|
336 |
+
omega = torch.acos((l_n * r_n).sum(dim=-1).clamp(-1, 1))
|
337 |
+
dom = torch.sin(omega)
|
338 |
+
|
339 |
+
flag = dom < eps
|
340 |
+
|
341 |
+
res = torch.empty_like(l_n)
|
342 |
+
t_t = t[flag].unsqueeze(-1)
|
343 |
+
res[flag] = (1 - t_t) * l_n[flag] + t_t * r_n[flag]
|
344 |
+
|
345 |
+
flag = ~ flag
|
346 |
+
|
347 |
+
t_t = t[flag]
|
348 |
+
d_t = dom[flag]
|
349 |
+
va = torch.sin((1 - t_t) * omega[flag]) / d_t
|
350 |
+
vb = torch.sin(t_t * omega[flag]) / d_t
|
351 |
+
res[flag] = (va.unsqueeze(-1) * l_n[flag] + vb.unsqueeze(-1) * r_n[flag])
|
352 |
+
return res
|
353 |
+
|
354 |
+
|
355 |
+
def slerp_quat(l, r, t):
|
356 |
+
"""
|
357 |
+
slerp for unit quaternions
|
358 |
+
:param l: (*, 4) unit quaternion
|
359 |
+
:param r: (*, 4) unit quaternion
|
360 |
+
:param t: (*) scalar between 0 and 1
|
361 |
+
"""
|
362 |
+
t = t.expand(l.shape[:-1])
|
363 |
+
flag = (l * r).sum(dim=-1) >= 0
|
364 |
+
res = torch.empty_like(l)
|
365 |
+
res[flag] = slerp(l[flag], r[flag], t[flag])
|
366 |
+
flag = ~ flag
|
367 |
+
res[flag] = slerp(-l[flag], r[flag], t[flag])
|
368 |
+
return res
|
369 |
+
|
370 |
+
|
371 |
+
# def slerp_6d(l, r, t):
|
372 |
+
# l_q = repr6d2quat(l)
|
373 |
+
# r_q = repr6d2quat(r)
|
374 |
+
# res_q = slerp_quat(l_q, r_q, t)
|
375 |
+
# return quat2repr6d(res_q)
|
376 |
+
|
377 |
+
|
378 |
+
def interpolate_6d(input, size):
|
379 |
+
"""
|
380 |
+
:param input: (batch_size, n_channels, length)
|
381 |
+
:param size: required output size for temporal axis
|
382 |
+
:return:
|
383 |
+
"""
|
384 |
+
batch = input.shape[0]
|
385 |
+
length = input.shape[-1]
|
386 |
+
input = input.reshape((batch, -1, 6, length))
|
387 |
+
input = input.permute(0, 1, 3, 2) # (batch_size, n_joint, length, 6)
|
388 |
+
input_q = repr6d2quat(input)
|
389 |
+
idx = torch.tensor(list(range(size)), device=input_q.device, dtype=torch.float) / size * (length - 1)
|
390 |
+
idx_l = torch.floor(idx)
|
391 |
+
t = idx - idx_l
|
392 |
+
idx_l = idx_l.long()
|
393 |
+
idx_r = idx_l + 1
|
394 |
+
t = t.reshape((1, 1, -1))
|
395 |
+
res_q = slerp_quat(input_q[..., idx_l, :], input_q[..., idx_r, :], t)
|
396 |
+
res = quat2repr6d(res_q) # shape = (batch_size, n_joint, t, 6)
|
397 |
+
res = res.permute(0, 1, 3, 2)
|
398 |
+
res = res.reshape((batch, -1, size))
|
399 |
+
return res
|