H-Liu1997 committed · Commit a92b3a0 · verified · 1 Parent(s): 5c3e860

Upload 8 files

Files changed (8)
  1. AESKConv_240_100.bin +3 -0
  2. __init__.py +0 -0
  3. decoders.py +56 -0
  4. mean_vel_smplxflame_30.npy +3 -0
  5. mertic.py +357 -0
  6. motion_encoder.py +193 -0
  7. skeleton.py +298 -0
  8. skeleton_DME.py +473 -0
AESKConv_240_100.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5cd9566b24264f34d44003b3de62cdfd50aa85b7cdde2d369214599023c40f55
+size 17558653
__init__.py ADDED
File without changes
decoders.py ADDED
@@ -0,0 +1,56 @@
+# This script is modified from https://github.com/EricGuo5513/TM2T
+# Licensed under: https://github.com/EricGuo5513/TM2T/blob/main/LICENSE
+
+import torch.nn as nn
+
+
+class VQDecoderV3(nn.Module):
+    def __init__(self, args):
+        super(VQDecoderV3, self).__init__()
+        n_up = args.vae_layer
+        channels = []
+        for i in range(n_up - 1):
+            channels.append(args.vae_length)
+        channels.append(args.vae_length)
+        channels.append(args.vae_test_dim)
+        input_size = args.vae_length
+        n_resblk = 2
+        assert len(channels) == n_up + 1
+        if input_size == channels[0]:
+            layers = []
+        else:
+            layers = [nn.Conv1d(input_size, channels[0], kernel_size=3, stride=1, padding=1)]
+
+        for i in range(n_resblk):
+            layers += [ResBlock(channels[0])]
+        # channels = channels
+        for i in range(n_up):
+            layers += [
+                nn.Upsample(scale_factor=2, mode="nearest"),
+                nn.Conv1d(channels[i], channels[i + 1], kernel_size=3, stride=1, padding=1),
+                nn.LeakyReLU(0.2, inplace=True),
+            ]
+        layers += [nn.Conv1d(channels[-1], channels[-1], kernel_size=3, stride=1, padding=1)]
+        self.main = nn.Sequential(*layers)
+        # self.main.apply(init_weight)
+
+    def forward(self, inputs):
+        inputs = inputs.permute(0, 2, 1)
+        outputs = self.main(inputs).permute(0, 2, 1)
+        return outputs
+
+
+class ResBlock(nn.Module):
+    def __init__(self, channel):
+        super(ResBlock, self).__init__()
+        self.model = nn.Sequential(
+            nn.Conv1d(channel, channel, kernel_size=3, stride=1, padding=1),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv1d(channel, channel, kernel_size=3, stride=1, padding=1),
+        )
+
+    def forward(self, x):
+        residual = x
+        out = self.model(x)
+        out += residual
+        return out
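A quick shape check of the decoder (a minimal sketch; the `SimpleNamespace` values mirror the `Arg` config defined in mertic.py below, and the flat import path is an assumption that depends on your package layout):

import torch
from types import SimpleNamespace
from decoders import VQDecoderV3  # assumed flat import; inside a package use the relative form

# hypothetical config mirroring the Arg class in mertic.py
args = SimpleNamespace(vae_layer=4, vae_length=240, vae_test_dim=330)

decoder = VQDecoderV3(args)
latent = torch.randn(2, 8, 240)  # (batch, frames, vae_length)
poses = decoder(latent)
print(poses.shape)  # torch.Size([2, 128, 330]): four 2x upsamples, so 8 * 2**4 frames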
mean_vel_smplxflame_30.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:53b5e48f2a7bf78c41a6de6395d6bb4f29018465ca5d0ee2820a2be3eebb7137
+size 348
mertic.py ADDED
@@ -0,0 +1,357 @@
+import os
+import wget
+import math
+import numpy as np
+import librosa
+import librosa.display
+import matplotlib.pyplot as plt
+from scipy.signal import argrelextrema
+from scipy import linalg
+import torch
+
+from .motion_encoder import VAESKConv
+
+class LVDFace(object):
+    def __init__(self):
+        self.counter = 0
+        self.sum = 0
+
+    def compute(self, pred_vertices, target_vertices):
+        t, c = pred_vertices.shape
+        diff_pred = pred_vertices[1:, :] - pred_vertices[:-1, :]
+        diff_target = target_vertices[1:, :] - target_vertices[:-1, :]
+        loss = np.abs(diff_pred - diff_target)
+        loss = np.sum(loss)
+        self.counter += t * c
+        self.sum += loss
+
+    def avg(self):
+        return self.sum / self.counter
+
+    def reset(self):
+        self.counter = 0
+        self.sum = 0
+
+
+class MSEFace(object):
+    def __init__(self):
+        self.counter = 0
+        self.sum = 0
+
+    def compute(self, pred_vertices, target_vertices):
+        t, c = pred_vertices.shape
+        loss = np.square(pred_vertices - target_vertices)
+        self.sum += np.sum(loss)
+        self.counter += t * c
+
+    def avg(self):
+        if self.counter == 0:
+            return 0
+        return self.sum / self.counter
+
+    def reset(self):
+        self.counter = 0
+        self.sum = 0
+
+
+class L1div(object):
+    def __init__(self):
+        self.counter = 0
+        self.sum = 0
+
+    def compute(self, results):
+        self.counter += results.shape[0]
+        mean = np.mean(results, axis=0)
+        sum_l1 = np.sum(np.abs(results - mean), axis=None)
+        self.sum += sum_l1
+
+    def avg(self):
+        if self.counter == 0:
+            return 0
+        return self.sum / self.counter
+
+    def reset(self):
+        self.counter = 0
+        self.sum = 0
+
+
+class SRGR(object):
+    def __init__(self, threshold=0.1, joints=47, joint_dim=3):
+        self.threshold = threshold
+        self.pose_dimes = joints
+        self.joint_dim = joint_dim
+        self.counter = 0
+        self.sum = 0
+
+    def run(self, results, targets, semantic=None, verbose=False):
+        if semantic is None:
+            semantic = np.ones(results.shape[0])
+            avg_weight = 1.0
+        else:
+            # srgr == 0.165 when all success, scale range to [0, 1]
+            avg_weight = 0.165
+        results = results.reshape(-1, self.pose_dimes, self.joint_dim)
+        targets = targets.reshape(-1, self.pose_dimes, self.joint_dim)
+        semantic = semantic.reshape(-1)
+        diff = np.linalg.norm(results - targets, axis=2)  # T, J
+        if verbose:
+            print(diff)
+        success = np.where(diff < self.threshold, 1.0, 0.0)
+        for i in range(success.shape[0]):
+            success[i, :] *= semantic[i] * (1 / avg_weight)
+        rate = np.sum(success) / (success.shape[0] * success.shape[1])
+        self.counter += success.shape[0]
+        self.sum += rate * success.shape[0]
+        return rate
+
+    def avg(self):
+        return self.sum / self.counter
+
+    def reset(self):
+        self.counter = 0
+        self.sum = 0
+
+
+class BC(object):
+    def __init__(self, download_path=None, sigma=0.3, order=7, upper_body=[3, 6, 9, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]):
+        self.sigma = sigma
+        self.order = order
+        self.upper_body = upper_body
+        self.pose_data = []
+        if download_path is not None:
+            os.makedirs(download_path, exist_ok=True)
+            model_file_path = os.path.join(download_path, "mean_vel_smplxflame_30.npy")
+            if not os.path.exists(model_file_path):
+                print(f"Downloading {model_file_path}")
+                wget.download(
+                    "https://huggingface.co/spaces/H-Liu1997/EMAGE/resolve/main/EMAGE/test_sequences/weights/mean_vel_smplxflame_30.npy",
+                    model_file_path,
+                )
+        self.mmae = np.load(os.path.join(download_path, "mean_vel_smplxflame_30.npy")) if download_path is not None else None
+        self.threshold = 0.10
+        self.counter = 0
+        self.sum = 0
+
+    def load_audio(self, wave, t_start=None, t_end=None, without_file=False, sr_audio=16000):
+        hop_length = 512
+        if without_file:
+            y = wave
+        else:
+            y, sr = librosa.load(wave, sr=sr_audio)
+
+        short_y = y[t_start:t_end] if t_start is not None else y
+        short_y = short_y.astype(np.float32)
+        onset_t = librosa.onset.onset_detect(y=short_y, sr=sr_audio, hop_length=hop_length, units="time")
+        return onset_t
+
+    def load_motion(self, pose, t_start, t_end, pose_fps, without_file=False):
+        data_each_file = []
+        if without_file:
+            data_each_file = pose
+        else:
+            with open(pose, "r") as f:
+                for i, line_data in enumerate(f.readlines()):
+                    if i < 432:
+                        continue
+                    line_data_np = np.fromstring(line_data, sep=" ")
+                    if pose_fps == 15 and i % 2 == 0:
+                        continue
+                    data_each_file.append(np.concatenate([line_data_np[30:39], line_data_np[112:121]], 0))
+            data_each_file = np.array(data_each_file)  # T*165
+        # print(data_each_file.shape)
+        joints = data_each_file.transpose(1, 0)
+        dt = 1 / pose_fps
+        init_vel = (joints[:, 1:2] - joints[:, :1]) / dt
+        middle_vel = (joints[:, 2:] - joints[:, 0:-2]) / (2 * dt)
+        final_vel = (joints[:, -1:] - joints[:, -2:-1]) / dt
+        vel = np.concatenate([init_vel, middle_vel, final_vel], 1).transpose(1, 0).reshape(data_each_file.shape[0], -1, 3)
+        # print(vel.shape)
+
+        if self.mmae is not None:
+            vel = np.linalg.norm(vel, axis=2) / self.mmae
+        else:
+            print("Warning: mmae is not provided, using max value of vel as mmae")
+            self.mmae = np.linalg.norm(vel, axis=2).max()
+            vel = np.linalg.norm(vel, axis=2) / self.mmae
+        # print(vel.shape)  # T*J
+
+        beat_vel_all = []
+        for i in range(vel.shape[1]):
+            vel_mask = np.where(vel[:, i] > self.threshold)
+            beat_vel = argrelextrema(vel[t_start:t_end, i], np.less, order=self.order)
+            beat_vel_list = [j for j in beat_vel[0] if j in vel_mask[0]]
+            beat_vel_all.append(np.array(beat_vel_list))
+        return beat_vel_all
+
+    def eval_random_pose(self, wave, pose, t_start, t_end, pose_fps, num_random=60):
+        onset_raw = self.load_audio(wave, t_start, t_end)
+        dur = t_end - t_start
+        for i in range(num_random):
+            beat_vel_all = self.load_motion(pose, i, i + dur, pose_fps)
+            dis_all_b2a = self.compute(onset_raw, beat_vel_all)
+            print(f"{i}s: ", dis_all_b2a)
+
+    @staticmethod
+    def plot_onsets(audio, sr, onset_times_1, onset_times_2):
+        fig, axarr = plt.subplots(2, 1, figsize=(10, 10), sharex=True)
+        librosa.display.waveshow(audio, sr=sr, alpha=0.7, ax=axarr[0])
+        librosa.display.waveshow(audio, sr=sr, alpha=0.7, ax=axarr[1])
+
+        for onset in onset_times_1:
+            axarr[0].axvline(onset, color="r", linestyle="--", alpha=0.9, label="Onset Method 1")
+        axarr[0].legend()
+        axarr[0].set(title="Onset Method 1", xlabel="", ylabel="Amplitude")
+
+        for onset in onset_times_2:
+            axarr[1].axvline(onset, color="b", linestyle="-", alpha=0.7, label="Onset Method 2")
+        axarr[1].legend()
+        axarr[1].set(title="Onset Method 2", xlabel="Time (s)", ylabel="Amplitude")
+
+        handles, labels = plt.gca().get_legend_handles_labels()
+        by_label = dict(zip(labels, handles))
+        plt.legend(by_label.values(), by_label.keys())
+        plt.title("Audio waveform with Onsets")
+        plt.savefig("./onset.png", dpi=500)
+
+    def audio_beat_vis(self, onset_raw, onset_bt, onset_bt_rms):
+        fig, ax = plt.subplots(nrows=4, sharex=True)
+        librosa.display.specshow(librosa.amplitude_to_db(self.S, ref=np.max), y_axis="log", x_axis="time", ax=ax[0])
+        ax[1].plot(self.times, self.oenv, label="Onset strength")
+        ax[1].vlines(librosa.frames_to_time(onset_raw), 0, self.oenv.max(), label="Raw onsets", color="r")
+        ax[1].legend()
+        ax[2].vlines(librosa.frames_to_time(onset_bt), 0, self.oenv.max(), label="Backtracked", color="r")
+        ax[2].legend()
+        ax[3].vlines(librosa.frames_to_time(onset_bt_rms), 0, self.oenv.max(), label="Backtracked (RMS)", color="r")
+        ax[3].legend()
+        fig.savefig("./onset.png", dpi=500)
+
+    @staticmethod
+    def motion_frames2time(vel, offset, pose_fps):
+        return vel / pose_fps + offset
+
+    @staticmethod
+    def GAHR(a, b, sigma):
+        dis_all_b2a = 0
+        for b_each in b:
+            l2_min = min(abs(a_each - b_each) for a_each in a)
+            dis_all_b2a += math.exp(-(l2_min**2) / (2 * sigma**2))
+        return dis_all_b2a / len(b)
+
+    @staticmethod
+    def fix_directed_GAHR(a, b, sigma):
+        a = BC.motion_frames2time(a, 0, 30)
+        b = BC.motion_frames2time(b, 0, 30)
+        a = [0] + list(a) + [len(a) / 30]  # list() so that + concatenates instead of NumPy-broadcasting
+        b = [0] + list(b) + [len(b) / 30]
+        return BC.GAHR(a, b, sigma)
+
+    def compute(self, onset_bt_rms, beat_vel, length=1, pose_fps=30):
+        avg_dis_all_b2a_list = []
+        for its, beat_vel_each in enumerate(beat_vel):
+            if its not in self.upper_body:
+                continue
+            if beat_vel_each.size == 0:
+                avg_dis_all_b2a_list.append(0)
+                continue
+            pose_bt = self.motion_frames2time(beat_vel_each, 0, pose_fps)
+            avg_dis_all_b2a_list.append(self.GAHR(pose_bt, onset_bt_rms, self.sigma))
+        self.sum += (sum(avg_dis_all_b2a_list) / len(self.upper_body)) * length
+        self.counter += length
+
+    def avg(self):
+        return self.sum / self.counter
+
+    def reset(self):
+        self.counter = 0
+        self.sum = 0
+
+
+class Arg(object):
+    def __init__(self):
+        self.vae_length = 240
+        self.vae_test_dim = 330
+        self.vae_test_len = 32
+        self.vae_layer = 4
+        self.vae_test_stride = 20
+        self.vae_grow = [1, 1, 2, 1]
+        self.variational = False
+
+
+class FGD(object):
+    def __init__(self, download_path="./emage/", device="cuda"):
+        if download_path is not None:
+            os.makedirs(download_path, exist_ok=True)
+            model_file_path = os.path.join(download_path, "AESKConv_240_100.bin")
+            smplx_model_dir = os.path.join(download_path, "smplx_models", "smplx")
+            smplx_model_file_path = os.path.join(smplx_model_dir, "SMPLX_NEUTRAL_2020.npz")
+            if not os.path.exists(model_file_path):
+                print(f"Downloading {model_file_path}")
+                wget.download(
+                    "https://huggingface.co/spaces/H-Liu1997/EMAGE/resolve/main/EMAGE/test_sequences/weights/AESKConv_240_100.bin",
+                    model_file_path,
+                )
+
+            os.makedirs(smplx_model_dir, exist_ok=True)
+            if not os.path.exists(smplx_model_file_path):
+                print(f"Downloading {smplx_model_file_path}")
+                wget.download(
+                    "https://huggingface.co/spaces/H-Liu1997/EMAGE/resolve/main/EMAGE/smplx_models/smplx/SMPLX_NEUTRAL_2020.npz",
+                    smplx_model_file_path,
+                )
+        args = Arg()
+        self.eval_model = VAESKConv(args, model_save_path=download_path)  # Assumes LocalEncoder is defined elsewhere
+        old_stat = torch.load(os.path.join(download_path, "AESKConv_240_100.bin"))["model_state"]
+        new_stat = {}
+        for k, v in old_stat.items():
+            # If 'module.' is in the key, remove it
+            new_key = k.replace("module.", "") if "module." in k else k
+            new_stat[new_key] = v
+        self.eval_model.load_state_dict(new_stat)
+
+        self.eval_model.eval()
+        if torch.cuda.is_available():
+            self.eval_model.to(device)
+
+        self.pred_features = []
+        self.target_features = []
+        self.device = device
+
+    def reset(self):
+        self.pred_features = []
+        self.target_features = []
+
+    def get_feature(self, data):
+        assert len(data.shape) == 3
+        if data.shape[1] % 32 != 0:
+            drop_len = data.shape[1] % 32
+            data = data[:, :-drop_len]
+        # print(data.shape)
+        with torch.no_grad():
+            if torch.cuda.is_available():
+                data = data.to(self.device)
+            feature = self.eval_model.map2latent(data).cpu().numpy()
+        # print(feature.shape)
+        return feature
+
+    def update(self, pred, target):
+        self.pred_features.append(self.get_feature(pred))
+        self.target_features.append(self.get_feature(target))
+
+    def compute(self):
+        pred_features = np.concatenate([x.reshape(-1, x.shape[-1]) for x in self.pred_features], axis=0)
+        target_features = np.concatenate([x.reshape(-1, x.shape[-1]) for x in self.target_features], axis=0)
+        # print(pred_features.shape, target_features.shape)
+        return self.frechet_distance(pred_features, target_features)
+
+    @staticmethod
+    def frechet_distance(samples_A, samples_B, eps=1e-6):
+        mu1 = np.mean(samples_A, axis=0)
+        sigma1 = np.cov(samples_A, rowvar=False)
+        mu2 = np.mean(samples_B, axis=0)
+        sigma2 = np.cov(samples_B, rowvar=False)
+        diff = mu1 - mu2
+        offset = np.eye(sigma1.shape[0]) * eps
+        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+        if np.iscomplexobj(covmean):
+            covmean = covmean.real
+        return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
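FGD.frechet_distance is the standard Fréchet (FID-style) distance between Gaussians fitted to the two feature sets, FD = ||mu_A - mu_B||^2 + Tr(Sigma_A + Sigma_B - 2 (Sigma_A Sigma_B)^{1/2}). A minimal sanity check on synthetic features (a sketch; the flat import path is an assumption, since the module itself uses package-relative imports):

import numpy as np
from mertic import FGD  # assumed flat import; adjust to your package layout

rng = np.random.default_rng(0)
a = rng.normal(0.0, 1.0, size=(1000, 4))  # stand-in "target" features
b = rng.normal(0.5, 1.0, size=(1000, 4))  # stand-in "predicted" features, mean shifted by 0.5

print(FGD.frechet_distance(a, a))  # ~0: identical sets
print(FGD.frechet_distance(a, b))  # ~1: ||mu_A - mu_B||^2 = 4 * 0.5**2, covariances match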
motion_encoder.py ADDED
@@ -0,0 +1,193 @@
+import torch.nn as nn
+import torch
+import numpy as np
+from .skeleton_DME import SkeletonConv, SkeletonPool, find_neighbor, build_edge_topology
+from .skeleton import SkeletonResidual
+from .decoders import VQDecoderV3
+
+
+class LocalEncoder(nn.Module):
+    def __init__(self, args, topology):
+        super(LocalEncoder, self).__init__()
+        args.channel_base = 6
+        args.activation = "tanh"
+        args.use_residual_blocks = True
+        args.z_dim = 1024
+        args.temporal_scale = 8
+        args.kernel_size = 4
+        args.num_layers = args.vae_layer
+        args.skeleton_dist = 2
+        args.extra_conv = 0
+        # check how to reflect in 1d
+        args.padding_mode = "constant"
+        args.skeleton_pool = "mean"
+        args.upsampling = "linear"
+
+        self.topologies = [topology]
+        self.channel_base = [args.channel_base]
+
+        self.channel_list = []
+        self.edge_num = [len(topology)]
+        self.pooling_list = []
+        self.layers = nn.ModuleList()
+        self.args = args
+        # self.convs = []
+
+        kernel_size = args.kernel_size
+        kernel_even = False if kernel_size % 2 else True
+        padding = (kernel_size - 1) // 2
+        bias = True
+        self.grow = args.vae_grow
+        for i in range(args.num_layers):
+            self.channel_base.append(self.channel_base[-1] * self.grow[i])
+
+        for i in range(args.num_layers):
+            seq = []
+            neighbour_list = find_neighbor(self.topologies[i], args.skeleton_dist)
+            in_channels = self.channel_base[i] * self.edge_num[i]
+            out_channels = self.channel_base[i + 1] * self.edge_num[i]
+            if i == 0:
+                self.channel_list.append(in_channels)
+            self.channel_list.append(out_channels)
+            last_pool = True if i == args.num_layers - 1 else False
+
+            # (T, J, D) => (T, J', D)
+            pool = SkeletonPool(
+                edges=self.topologies[i],
+                pooling_mode=args.skeleton_pool,
+                channels_per_edge=out_channels // len(neighbour_list),
+                last_pool=last_pool,
+            )
+
+            if args.use_residual_blocks:
+                # (T, J, D) => (T/2, J', 2D)
+                seq.append(
+                    SkeletonResidual(
+                        self.topologies[i],
+                        neighbour_list,
+                        joint_num=self.edge_num[i],
+                        in_channels=in_channels,
+                        out_channels=out_channels,
+                        kernel_size=kernel_size,
+                        stride=2,
+                        padding=padding,
+                        padding_mode=args.padding_mode,
+                        bias=bias,
+                        extra_conv=args.extra_conv,
+                        pooling_mode=args.skeleton_pool,
+                        activation=args.activation,
+                        last_pool=last_pool,
+                    )
+                )
+            else:
+                for _ in range(args.extra_conv):
+                    # (T, J, D) => (T, J, D)
+                    seq.append(
+                        SkeletonConv(
+                            neighbour_list,
+                            in_channels=in_channels,
+                            out_channels=in_channels,
+                            joint_num=self.edge_num[i],
+                            kernel_size=kernel_size - 1 if kernel_even else kernel_size,
+                            stride=1,
+                            padding=padding,
+                            padding_mode=args.padding_mode,
+                            bias=bias,
+                        )
+                    )
+                    seq.append(nn.PReLU() if args.activation == "relu" else nn.Tanh())
+                # (T, J, D) => (T/2, J, 2D)
+                seq.append(
+                    SkeletonConv(
+                        neighbour_list,
+                        in_channels=in_channels,
+                        out_channels=out_channels,
+                        joint_num=self.edge_num[i],
+                        kernel_size=kernel_size,
+                        stride=2,
+                        padding=padding,
+                        padding_mode=args.padding_mode,
+                        bias=bias,
+                        add_offset=False,
+                        in_offset_channel=3 * self.channel_base[i] // self.channel_base[0],
+                    )
+                )
+                # self.convs.append(seq[-1])
+
+                seq.append(pool)
+                seq.append(nn.PReLU() if args.activation == "relu" else nn.Tanh())
+            self.layers.append(nn.Sequential(*seq))
+
+            self.topologies.append(pool.new_edges)
+            self.pooling_list.append(pool.pooling_list)
+            self.edge_num.append(len(self.topologies[-1]))
+
+        # in_features = self.channel_base[-1] * len(self.pooling_list[-1])
+        # in_features *= int(args.temporal_scale / 2)
+        # self.reduce = nn.Linear(in_features, args.z_dim)
+        # self.mu = nn.Linear(in_features, args.z_dim)
+        # self.logvar = nn.Linear(in_features, args.z_dim)
+
+    def forward(self, input):
+        # bs, n, c = input.shape[0], input.shape[1], input.shape[2]
+        output = input.permute(0, 2, 1)  # input.reshape(bs, n, -1, 6)
+        for layer in self.layers:
+            output = layer(output)
+        # output = output.view(output.shape[0], -1)
+        output = output.permute(0, 2, 1)
+        return output
+
+
+def reparameterize(mu, logvar):
+    std = torch.exp(0.5 * logvar)
+    eps = torch.randn_like(std)
+    return mu + eps * std
+
+
+class VAEConv(nn.Module):
+    def __init__(self, args):
+        super(VAEConv, self).__init__()
+        # self.encoder = VQEncoderV3(args)
+        # self.decoder = VQDecoderV3(args)
+        self.fc_mu = nn.Linear(args.vae_length, args.vae_length)
+        self.fc_logvar = nn.Linear(args.vae_length, args.vae_length)
+        self.variational = args.variational
+
+    def forward(self, inputs):
+        pre_latent = self.encoder(inputs)
+        mu, logvar = None, None
+        if self.variational:
+            mu = self.fc_mu(pre_latent)
+            logvar = self.fc_logvar(pre_latent)
+            pre_latent = reparameterize(mu, logvar)
+        rec_pose = self.decoder(pre_latent)
+        return {
+            "poses_feat": pre_latent,
+            "rec_pose": rec_pose,
+            "pose_mu": mu,
+            "pose_logvar": logvar,
+        }
+
+    def map2latent(self, inputs):
+        pre_latent = self.encoder(inputs)
+        if self.variational:
+            mu = self.fc_mu(pre_latent)
+            logvar = self.fc_logvar(pre_latent)
+            pre_latent = reparameterize(mu, logvar)
+        return pre_latent
+
+    def decode(self, pre_latent):
+        rec_pose = self.decoder(pre_latent)
+        return rec_pose
+
+
+class VAESKConv(VAEConv):
+    def __init__(self, args, model_save_path="./emage/"):
+        # args = args()
+        super(VAESKConv, self).__init__(args)
+        smpl_fname = model_save_path + "smplx_models/smplx/SMPLX_NEUTRAL_2020.npz"
+        smpl_data = np.load(smpl_fname, encoding="latin1")
+        parents = smpl_data["kintree_table"][0].astype(np.int32)
+        edges = build_edge_topology(parents)
+        self.encoder = LocalEncoder(args, edges)
+        self.decoder = VQDecoderV3(args)
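Putting the pieces together, a hedged sketch of how VAESKConv might be driven for feature extraction. It assumes the SMPL-X npz already sits under ./emage/ (FGD.__init__ in mertic.py downloads it on first use) and that the Arg config from mertic.py is in scope; with that config the latent width works out to vae_length = 240 and the frame count shrinks by 2**vae_layer:

import torch
from mertic import Arg                # assumed flat import paths; adjust to your package layout
from motion_encoder import VAESKConv

# assumes ./emage/smplx_models/smplx/SMPLX_NEUTRAL_2020.npz already exists
model = VAESKConv(Arg(), model_save_path="./emage/")
model.eval()

motion = torch.randn(2, 64, 330)  # (batch, frames, channels); frames divisible by 2**vae_layer
with torch.no_grad():
    latent = model.map2latent(motion)  # expected (2, 64 / 2**4, 240) with this config
    recon = model.decode(latent)       # expected (2, 64, 330)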
skeleton.py ADDED
@@ -0,0 +1,298 @@
+import torch
+import torch.nn as nn
+
+from .skeleton_DME import SkeletonConv, SkeletonPool, SkeletonUnpool
+
+
+def calc_node_depth(topology):
+    def dfs(node, topology):
+        if topology[node] < 0:
+            return 0
+        return 1 + dfs(topology[node], topology)
+
+    depth = []
+    for i in range(len(topology)):
+        depth.append(dfs(i, topology))
+
+    return depth
+
+
+def residual_ratio(k):
+    return 1 / (k + 1)
+
+
+class Affine(nn.Module):
+    def __init__(self, num_parameters, scale=True, bias=True, scale_init=1.0):
+        super(Affine, self).__init__()
+        if scale:
+            self.scale = nn.Parameter(torch.ones(num_parameters) * scale_init)
+        else:
+            self.register_parameter("scale", None)
+
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(num_parameters))
+        else:
+            self.register_parameter("bias", None)
+
+    def forward(self, input):
+        output = input
+        if self.scale is not None:
+            scale = self.scale.unsqueeze(0)
+            while scale.dim() < input.dim():
+                scale = scale.unsqueeze(2)
+            output = output.mul(scale)
+
+        if self.bias is not None:
+            bias = self.bias.unsqueeze(0)
+            while bias.dim() < input.dim():
+                bias = bias.unsqueeze(2)
+            output += bias
+
+        return output
+
+
+class BatchStatistics(nn.Module):
+    def __init__(self, affine=-1):
+        super(BatchStatistics, self).__init__()
+        self.affine = nn.Sequential() if affine == -1 else Affine(affine)
+        self.loss = 0
+
+    def clear_loss(self):
+        self.loss = 0
+
+    def compute_loss(self, input):
+        input_flat = input.view(input.size(1), input.numel() // input.size(1))
+        mu = input_flat.mean(1)
+        logvar = (input_flat.pow(2).mean(1) - mu.pow(2)).sqrt().log()
+
+        self.loss = mu.pow(2).mean() + logvar.pow(2).mean()
+
+    def forward(self, input):
+        self.compute_loss(input)
+        return self.affine(input)
+
+
+class ResidualBlock(nn.Module):
+    def __init__(
+        self, in_channels, out_channels, kernel_size, stride, padding, residual_ratio, activation, batch_statistics=False, last_layer=False
+    ):
+        super(ResidualBlock, self).__init__()
+
+        self.residual_ratio = residual_ratio
+        self.shortcut_ratio = 1 - residual_ratio
+
+        residual = []
+        residual.append(nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding))
+        if batch_statistics:
+            residual.append(BatchStatistics(out_channels))
+        if not last_layer:
+            residual.append(nn.PReLU() if activation == "relu" else nn.Tanh())
+        self.residual = nn.Sequential(*residual)
+
+        self.shortcut = nn.Sequential(
+            nn.AvgPool1d(kernel_size=2) if stride == 2 else nn.Sequential(),
+            nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
+            BatchStatistics(out_channels) if (in_channels != out_channels and batch_statistics is True) else nn.Sequential(),
+        )
+
+    def forward(self, input):
+        return self.residual(input).mul(self.residual_ratio) + self.shortcut(input).mul(self.shortcut_ratio)
+
+
+class ResidualBlockTranspose(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, residual_ratio, activation):
+        super(ResidualBlockTranspose, self).__init__()
+
+        self.residual_ratio = residual_ratio
+        self.shortcut_ratio = 1 - residual_ratio
+
+        self.residual = nn.Sequential(
+            nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, padding), nn.PReLU() if activation == "relu" else nn.Tanh()
+        )
+
+        self.shortcut = nn.Sequential(
+            nn.Upsample(scale_factor=2, mode="linear", align_corners=False) if stride == 2 else nn.Sequential(),
+            nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
+        )
+
+    def forward(self, input):
+        return self.residual(input).mul(self.residual_ratio) + self.shortcut(input).mul(self.shortcut_ratio)
+
+
+class SkeletonResidual(nn.Module):
+    def __init__(
+        self,
+        topology,
+        neighbour_list,
+        joint_num,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride,
+        padding,
+        padding_mode,
+        bias,
+        extra_conv,
+        pooling_mode,
+        activation,
+        last_pool,
+    ):
+        super(SkeletonResidual, self).__init__()
+
+        kernel_even = False if kernel_size % 2 else True
+
+        seq = []
+        for _ in range(extra_conv):
+            # (T, J, D) => (T, J, D)
+            seq.append(
+                SkeletonConv(
+                    neighbour_list,
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    joint_num=joint_num,
+                    kernel_size=kernel_size - 1 if kernel_even else kernel_size,
+                    stride=1,
+                    padding=padding,
+                    padding_mode=padding_mode,
+                    bias=bias,
+                )
+            )
+            seq.append(nn.PReLU() if activation == "relu" else nn.Tanh())
+        # (T, J, D) => (T/2, J, 2D)
+        seq.append(
+            SkeletonConv(
+                neighbour_list,
+                in_channels=in_channels,
+                out_channels=out_channels,
+                joint_num=joint_num,
+                kernel_size=kernel_size,
+                stride=stride,
+                padding=padding,
+                padding_mode=padding_mode,
+                bias=bias,
+                add_offset=False,
+            )
+        )
+        seq.append(nn.GroupNorm(10, out_channels))  # FIXME: REMEMBER TO CHANGE BACK !!!
+        self.residual = nn.Sequential(*seq)
+
+        # (T, J, D) => (T/2, J, 2D)
+        self.shortcut = SkeletonConv(
+            neighbour_list,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            joint_num=joint_num,
+            kernel_size=1,
+            stride=stride,
+            padding=0,
+            bias=True,
+            add_offset=False,
+        )
+
+        seq = []
+        # (T/2, J, 2D) => (T/2, J', 2D)
+        pool = SkeletonPool(
+            edges=topology, pooling_mode=pooling_mode, channels_per_edge=out_channels // len(neighbour_list), last_pool=last_pool
+        )
+        if len(pool.pooling_list) != pool.edge_num:
+            seq.append(pool)
+        seq.append(nn.PReLU() if activation == "relu" else nn.Tanh())
+        self.common = nn.Sequential(*seq)
+
+    def forward(self, input):
+        output = self.residual(input) + self.shortcut(input)
+
+        return self.common(output)
+
+
+class SkeletonResidualTranspose(nn.Module):
+    def __init__(
+        self,
+        neighbour_list,
+        joint_num,
+        in_channels,
+        out_channels,
+        kernel_size,
+        padding,
+        padding_mode,
+        bias,
+        extra_conv,
+        pooling_list,
+        upsampling,
+        activation,
+        last_layer,
+    ):
+        super(SkeletonResidualTranspose, self).__init__()
+
+        kernel_even = False if kernel_size % 2 else True
+
+        seq = []
+        # (T, J, D) => (2T, J, D)
+        if upsampling is not None:
+            seq.append(nn.Upsample(scale_factor=2, mode=upsampling, align_corners=False))
+        # (2T, J, D) => (2T, J', D)
+        unpool = SkeletonUnpool(pooling_list, in_channels // len(neighbour_list))
+        if unpool.input_edge_num != unpool.output_edge_num:
+            seq.append(unpool)
+        self.common = nn.Sequential(*seq)
+
+        seq = []
+        for _ in range(extra_conv):
+            # (2T, J', D) => (2T, J', D)
+            seq.append(
+                SkeletonConv(
+                    neighbour_list,
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    joint_num=joint_num,
+                    kernel_size=kernel_size - 1 if kernel_even else kernel_size,
+                    stride=1,
+                    padding=padding,
+                    padding_mode=padding_mode,
+                    bias=bias,
+                )
+            )
+            seq.append(nn.PReLU() if activation == "relu" else nn.Tanh())
+        # (2T, J', D) => (2T, J', D/2)
+        seq.append(
+            SkeletonConv(
+                neighbour_list,
+                in_channels=in_channels,
+                out_channels=out_channels,
+                joint_num=joint_num,
+                kernel_size=kernel_size - 1 if kernel_even else kernel_size,
+                stride=1,
+                padding=padding,
+                padding_mode=padding_mode,
+                bias=bias,
+                add_offset=False,
+            )
+        )
+        self.residual = nn.Sequential(*seq)
+
+        # (2T, J', D) => (2T, J', D/2)
+        self.shortcut = SkeletonConv(
+            neighbour_list,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            joint_num=joint_num,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=True,
+            add_offset=False,
+        )
+
+        if activation == "relu":
+            self.activation = nn.PReLU() if not last_layer else None
+        else:
+            self.activation = nn.Tanh() if not last_layer else None
+
+    def forward(self, input):
+        output = self.common(input)
+        output = self.residual(output) + self.shortcut(output)
+
+        if self.activation is not None:
+            return self.activation(output)
+        else:
+            return output
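A note on the residual blocks above: rather than the usual unweighted sum, the two paths are blended as output = r * residual(x) + (1 - r) * shortcut(x), with residual_ratio(k) = 1/(k+1). A minimal shape check (a sketch with arbitrary sizes; it assumes the package-relative imports resolve, i.e. it runs inside the package):

import torch
from skeleton import ResidualBlock, residual_ratio  # assumed flat import; adjust as needed

block = ResidualBlock(
    in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1,
    residual_ratio=residual_ratio(1),  # 0.5: equal blend of the two paths
    activation="tanh",
)
x = torch.randn(4, 64, 32)  # (batch, channels, time)
print(block(x).shape)  # torch.Size([4, 128, 16]): stride 2 halves the time axis on both paths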
skeleton_DME.py ADDED
@@ -0,0 +1,473 @@
+# This script is modified from https://github.com/DeepMotionEditing/deep-motion-editing
+# Licensed under:
+"""
+Copyright (c) 2020, Kfir Aberman, Peizhuo Li, Yijia Weng, Dani Lischinski, Olga Sorkine-Hornung, Daniel Cohen-Or and Baoquan Chen.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class SkeletonConv(nn.Module):
+    def __init__(
+        self,
+        neighbour_list,
+        in_channels,
+        out_channels,
+        kernel_size,
+        joint_num,
+        stride=1,
+        padding=0,
+        bias=True,
+        padding_mode="zeros",
+        add_offset=False,
+        in_offset_channel=0,
+    ):
+        self.in_channels_per_joint = in_channels // joint_num
+        self.out_channels_per_joint = out_channels // joint_num
+        if in_channels % joint_num != 0 or out_channels % joint_num != 0:
+            raise Exception("in_channels and out_channels must be divisible by joint_num")
+        super(SkeletonConv, self).__init__()
+
+        if padding_mode == "zeros":
+            padding_mode = "constant"
+        if padding_mode == "reflection":
+            padding_mode = "reflect"
+
+        self.expanded_neighbour_list = []
+        self.expanded_neighbour_list_offset = []
+        self.neighbour_list = neighbour_list
+        self.add_offset = add_offset
+        self.joint_num = joint_num
+
+        self.stride = stride
+        self.dilation = 1
+        self.groups = 1
+        self.padding = padding
+        self.padding_mode = padding_mode
+        self._padding_repeated_twice = (padding, padding)
+
+        for neighbour in neighbour_list:
+            expanded = []
+            for k in neighbour:
+                for i in range(self.in_channels_per_joint):
+                    expanded.append(k * self.in_channels_per_joint + i)
+            self.expanded_neighbour_list.append(expanded)
+
+        if self.add_offset:
+            self.offset_enc = SkeletonLinear(neighbour_list, in_offset_channel * len(neighbour_list), out_channels)
+
+            for neighbour in neighbour_list:
+                expanded = []
+                for k in neighbour:
+                    for i in range(add_offset):
+                        expanded.append(k * in_offset_channel + i)
+                self.expanded_neighbour_list_offset.append(expanded)
+
+        self.weight = torch.zeros(out_channels, in_channels, kernel_size)
+        if bias:
+            self.bias = torch.zeros(out_channels)
+        else:
+            self.register_parameter("bias", None)
+
+        self.mask = torch.zeros_like(self.weight)
+        for i, neighbour in enumerate(self.expanded_neighbour_list):
+            self.mask[self.out_channels_per_joint * i : self.out_channels_per_joint * (i + 1), neighbour, ...] = 1
+        self.mask = nn.Parameter(self.mask, requires_grad=False)
+
+        self.description = (
+            "SkeletonConv(in_channels_per_armature={}, out_channels_per_armature={}, kernel_size={}, "
+            "joint_num={}, stride={}, padding={}, bias={})".format(
+                in_channels // joint_num, out_channels // joint_num, kernel_size, joint_num, stride, padding, bias
+            )
+        )
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        for i, neighbour in enumerate(self.expanded_neighbour_list):
+            # Use a temporary variable to avoid assigning to a copy of a slice, which might lead to unexpected results
+            tmp = torch.zeros_like(self.weight[self.out_channels_per_joint * i : self.out_channels_per_joint * (i + 1), neighbour, ...])
+            nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))
+            self.weight[self.out_channels_per_joint * i : self.out_channels_per_joint * (i + 1), neighbour, ...] = tmp
+            if self.bias is not None:
+                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(
+                    self.weight[self.out_channels_per_joint * i : self.out_channels_per_joint * (i + 1), neighbour, ...]
+                )
+                bound = 1 / math.sqrt(fan_in)
+                tmp = torch.zeros_like(self.bias[self.out_channels_per_joint * i : self.out_channels_per_joint * (i + 1)])
+                nn.init.uniform_(tmp, -bound, bound)
+                self.bias[self.out_channels_per_joint * i : self.out_channels_per_joint * (i + 1)] = tmp
+
+        self.weight = nn.Parameter(self.weight)
+        if self.bias is not None:
+            self.bias = nn.Parameter(self.bias)
+
+    def set_offset(self, offset):
+        if not self.add_offset:
+            raise Exception("Wrong Combination of Parameters")
+        self.offset = offset.reshape(offset.shape[0], -1)
+
+    def forward(self, input):
+        # print('SkeletonConv')
+        weight_masked = self.weight * self.mask
+        # print(f'input: {input.size()}')
+        res = F.conv1d(
+            F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
+            weight_masked,
+            self.bias,
+            self.stride,
+            0,
+            self.dilation,
+            self.groups,
+        )
+
+        if self.add_offset:
+            offset_res = self.offset_enc(self.offset)
+            offset_res = offset_res.reshape(offset_res.shape + (1,))
+            res += offset_res / 100
+        # print(f'res: {res.size()}')
+        return res
+
+
+class SkeletonLinear(nn.Module):
+    def __init__(self, neighbour_list, in_channels, out_channels, extra_dim1=False):
+        super(SkeletonLinear, self).__init__()
+        self.neighbour_list = neighbour_list
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.in_channels_per_joint = in_channels // len(neighbour_list)
+        self.out_channels_per_joint = out_channels // len(neighbour_list)
+        self.extra_dim1 = extra_dim1
+        self.expanded_neighbour_list = []
+
+        for neighbour in neighbour_list:
+            expanded = []
+            for k in neighbour:
+                for i in range(self.in_channels_per_joint):
+                    expanded.append(k * self.in_channels_per_joint + i)
+            self.expanded_neighbour_list.append(expanded)
+
+        self.weight = torch.zeros(out_channels, in_channels)
+        self.mask = torch.zeros(out_channels, in_channels)
+        self.bias = nn.Parameter(torch.Tensor(out_channels))
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        for i, neighbour in enumerate(self.expanded_neighbour_list):
+            tmp = torch.zeros_like(self.weight[i * self.out_channels_per_joint : (i + 1) * self.out_channels_per_joint, neighbour])
+            self.mask[i * self.out_channels_per_joint : (i + 1) * self.out_channels_per_joint, neighbour] = 1
+            nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))
+            self.weight[i * self.out_channels_per_joint : (i + 1) * self.out_channels_per_joint, neighbour] = tmp
+
+        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
+        bound = 1 / math.sqrt(fan_in)
+        nn.init.uniform_(self.bias, -bound, bound)
+
+        self.weight = nn.Parameter(self.weight)
+        self.mask = nn.Parameter(self.mask, requires_grad=False)
+
+    def forward(self, input):
+        input = input.reshape(input.shape[0], -1)
+        weight_masked = self.weight * self.mask
+        res = F.linear(input, weight_masked, self.bias)
+        if self.extra_dim1:
+            res = res.reshape(res.shape + (1,))
+        return res
+
+
+class SkeletonPool(nn.Module):
+    def __init__(self, edges, pooling_mode, channels_per_edge, last_pool=False):
+        super(SkeletonPool, self).__init__()
+
+        if pooling_mode != "mean":
+            raise Exception("Unimplemented pooling mode in matrix_implementation")
+
+        self.channels_per_edge = channels_per_edge
+        self.pooling_mode = pooling_mode
+        self.edge_num = len(edges)
+        # self.edge_num = len(edges) + 1
+        self.seq_list = []
+        self.pooling_list = []
+        self.new_edges = []
+        degree = [0] * 100  # each element represents the degree of the corresponding joint
+
+        for edge in edges:
+            degree[edge[0]] += 1
+            degree[edge[1]] += 1
+
+        # seq_list contains multiple sub-lists; each sub-list is an edge chain running from a joint with degree > 2 to an end effector or to the next joint with degree > 2.
+        def find_seq(j, seq):
+            nonlocal self, degree, edges
+
+            if degree[j] > 2 and j != 0:
+                self.seq_list.append(seq)
+                seq = []
+
+            if degree[j] == 1:
+                self.seq_list.append(seq)
+                return
+
+            for idx, edge in enumerate(edges):
+                if edge[0] == j:
+                    find_seq(edge[1], seq + [idx])
+
+        find_seq(0, [])
+        # print(f'self.seq_list: {self.seq_list}')
+
+        for seq in self.seq_list:
+            if last_pool:
+                self.pooling_list.append(seq)
+                continue
+            if len(seq) % 2 == 1:
+                self.pooling_list.append([seq[0]])
+                self.new_edges.append(edges[seq[0]])
+                seq = seq[1:]
+            for i in range(0, len(seq), 2):
+                self.pooling_list.append([seq[i], seq[i + 1]])
+                self.new_edges.append([edges[seq[i]][0], edges[seq[i + 1]][1]])
+        # print(f'self.pooling_list: {self.pooling_list}')
+        # print(f'self.new_edges: {self.new_edges}')
+
+        # add global position
+        # self.pooling_list.append([self.edge_num - 1])
+
+        self.description = "SkeletonPool(in_edge_num={}, out_edge_num={})".format(len(edges), len(self.pooling_list))
+
+        self.weight = torch.zeros(len(self.pooling_list) * channels_per_edge, self.edge_num * channels_per_edge)
+
+        for i, pair in enumerate(self.pooling_list):
+            for j in pair:
+                for c in range(channels_per_edge):
+                    self.weight[i * channels_per_edge + c, j * channels_per_edge + c] = 1.0 / len(pair)
+
+        self.weight = nn.Parameter(self.weight, requires_grad=False)
+
+    def forward(self, input: torch.Tensor):
+        # print('SkeletonPool')
+        # print(f'input: {input.size()}')
+        # print(f'self.weight: {self.weight.size()}')
+        return torch.matmul(self.weight, input)
+
+
+class SkeletonUnpool(nn.Module):
+    def __init__(self, pooling_list, channels_per_edge):
+        super(SkeletonUnpool, self).__init__()
+        self.pooling_list = pooling_list
+        self.input_edge_num = len(pooling_list)
+        self.output_edge_num = 0
+        self.channels_per_edge = channels_per_edge
+        for t in self.pooling_list:
+            self.output_edge_num += len(t)
+
+        self.description = "SkeletonUnpool(in_edge_num={}, out_edge_num={})".format(
+            self.input_edge_num,
+            self.output_edge_num,
+        )
+
+        self.weight = torch.zeros(self.output_edge_num * channels_per_edge, self.input_edge_num * channels_per_edge)
+
+        for i, pair in enumerate(self.pooling_list):
+            for j in pair:
+                for c in range(channels_per_edge):
+                    self.weight[j * channels_per_edge + c, i * channels_per_edge + c] = 1
+
+        self.weight = nn.Parameter(self.weight)
+        self.weight.requires_grad_(False)
+
+    def forward(self, input: torch.Tensor):
+        # print('SkeletonUnpool')
+        # print(f'input: {input.size()}')
+        # print(f'self.weight: {self.weight.size()}')
+        return torch.matmul(self.weight, input)
+
+
+"""
+Helper functions for skeleton operations
+"""
+
+
+def dfs(x, fa, vis, dist):
+    vis[x] = 1
+    for y in range(len(fa)):
+        if (fa[y] == x or fa[x] == y) and vis[y] == 0:
+            dist[y] = dist[x] + 1
+            dfs(y, fa, vis, dist)
+
+
+"""
+def find_neighbor_joint(fa, threshold):
+    neighbor_list = [[]]
+    for x in range(1, len(fa)):
+        vis = [0 for _ in range(len(fa))]
+        dist = [0 for _ in range(len(fa))]
+        dist[0] = 10000
+        dfs(x, fa, vis, dist)
+        neighbor = []
+        for j in range(1, len(fa)):
+            if dist[j] <= threshold:
+                neighbor.append(j)
+        neighbor_list.append(neighbor)
+
+    neighbor = [0]
+    for i, x in enumerate(neighbor_list):
+        if i == 0: continue
+        if 1 in x:
+            neighbor.append(i)
+            neighbor_list[i] = [0] + neighbor_list[i]
+    neighbor_list[0] = neighbor
+    return neighbor_list
+
+
+def build_edge_topology(topology, offset):
+    # get all edges (pa, child, offset)
+    edges = []
+    joint_num = len(topology)
+    for i in range(1, joint_num):
+        edges.append((topology[i], i, offset[i]))
+    return edges
+"""
+
+
+def build_edge_topology(topology):
+    # get all edges (pa, child)
+    edges = []
+    joint_num = len(topology)
+    edges.append((0, joint_num))  # add an edge between the root joint and a virtual joint
+    for i in range(1, joint_num):
+        edges.append((topology[i], i))
+    return edges
+
+
+def build_joint_topology(edges, origin_names):
+    parent = []
+    offset = []
+    names = []
+    edge2joint = []
+    joint_from_edge = []  # -1 means virtual joint
+    joint_cnt = 0
+    out_degree = [0] * (len(edges) + 10)
+    for edge in edges:
+        out_degree[edge[0]] += 1
+
+    # add root joint
+    joint_from_edge.append(-1)
+    parent.append(0)
+    offset.append(np.array([0, 0, 0]))
+    names.append(origin_names[0])
+    joint_cnt += 1
+
+    def make_topology(edge_idx, pa):
+        nonlocal edges, parent, offset, names, edge2joint, joint_from_edge, joint_cnt
+        edge = edges[edge_idx]
+        if out_degree[edge[0]] > 1:
+            parent.append(pa)
+            offset.append(np.array([0, 0, 0]))
+            names.append(origin_names[edge[1]] + "_virtual")
+            edge2joint.append(-1)
+            pa = joint_cnt
+            joint_cnt += 1
+
+        parent.append(pa)
+        offset.append(edge[2])
+        names.append(origin_names[edge[1]])
+        edge2joint.append(edge_idx)
+        pa = joint_cnt
+        joint_cnt += 1
+
+        for idx, e in enumerate(edges):
+            if e[0] == edge[1]:
+                make_topology(idx, pa)
+
+    for idx, e in enumerate(edges):
+        if e[0] == 0:
+            make_topology(idx, 0)
+
+    return parent, offset, names, edge2joint
+
+
+def calc_edge_mat(edges):
+    edge_num = len(edges)
+    # edge_mat[i][j] = distance between edge(i) and edge(j)
+    edge_mat = [[100000] * edge_num for _ in range(edge_num)]
+    for i in range(edge_num):
+        edge_mat[i][i] = 0
+
+    # initialize edge_mat with direct neighbours
+    for i, a in enumerate(edges):
+        for j, b in enumerate(edges):
+            link = 0
+            for x in range(2):
+                for y in range(2):
+                    if a[x] == b[y]:
+                        link = 1
+            if link:
+                edge_mat[i][j] = 1
+
+    # calculate all-pairs distances
+    for k in range(edge_num):
+        for i in range(edge_num):
+            for j in range(edge_num):
+                edge_mat[i][j] = min(edge_mat[i][j], edge_mat[i][k] + edge_mat[k][j])
+    return edge_mat
+
+
+def find_neighbor(edges, d):
+    """
+    Args:
+        edges: A list of N elements, each representing an edge (parent, child).
+        d: Distance threshold between edges (an edge is at distance 0 from itself and 1 from an adjacent edge).
+
+    Returns:
+        A list of N elements, where each element is the list of edge indices within distance d of that edge.
+    """
+    edge_mat = calc_edge_mat(edges)
+    neighbor_list = []
+    edge_num = len(edge_mat)
+    for i in range(edge_num):
+        neighbor = []
+        for j in range(edge_num):
+            if edge_mat[i][j] <= d:
+                neighbor.append(j)
+        neighbor_list.append(neighbor)
+
+    # # add neighbor for global part
+    # global_part_neighbor = neighbor_list[0].copy()
+    # """
+    # Line #373 is buggy. Thanks @crissallan!!
+    # See issue #30 (https://github.com/DeepMotionEditing/deep-motion-editing/issues/30)
+    # However, fixing this bug will make it unable to load the pretrained model and
+    # affect the reproducibility of the quantitative error reported in the paper.
+    # It is not a fatal bug so we didn't touch it and we are looking for possible solutions.
+    # """
+    # for i in global_part_neighbor:
+    #     neighbor_list[i].append(edge_num)
+    # neighbor_list.append(global_part_neighbor)

+    return neighbor_list
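To make the edge/neighbour machinery concrete, a toy run on a five-joint tree (a sketch; the flat import is an assumption):

from skeleton_DME import build_edge_topology, find_neighbor  # assumed flat import

# toy kinematic tree: joint i has parent parents[i]; -1 marks the root
#   0 -- 1 -- 2
#    \
#     3 -- 4
parents = [-1, 0, 1, 0, 3]
edges = build_edge_topology(parents)
print(edges)  # [(0, 5), (0, 1), (1, 2), (0, 3), (3, 4)]; edge 0 joins the root to a virtual joint

neighbours = find_neighbor(edges, 1)
print(neighbours[0])  # [0, 1, 3]: every edge sharing a joint with the root's virtual edge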