Upload 8 files
Browse files- AESKConv_240_100.bin +3 -0
- __init__.py +0 -0
- decoders.py +56 -0
- mean_vel_smplxflame_30.npy +3 -0
- mertic.py +357 -0
- motion_encoder.py +193 -0
- skeleton.py +298 -0
- skeleton_DME.py +473 -0
AESKConv_240_100.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5cd9566b24264f34d44003b3de62cdfd50aa85b7cdde2d369214599023c40f55
|
3 |
+
size 17558653
|
__init__.py
ADDED
File without changes
|
decoders.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This script is modified from https://github.com/EricGuo5513/TM2T
|
2 |
+
# Licensed under: https://github.com/EricGuo5513/TM2T/blob/main/LICENSE
|
3 |
+
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
|
7 |
+
class VQDecoderV3(nn.Module):
    """1-D convolutional decoder that upsamples a latent sequence.

    Layer stack, in order:
      * two ``ResBlock`` modules at ``args.vae_length`` channels,
      * ``args.vae_layer`` stages of (nearest x2 upsample, Conv1d, LeakyReLU);
        the last stage maps to ``args.vae_test_dim`` channels,
      * one final Conv1d at ``args.vae_test_dim`` channels.
    """

    def __init__(self, args):
        super().__init__()
        n_up = args.vae_layer
        input_size = args.vae_length
        n_resblk = 2

        # n_up entries of vae_length followed by the output width.
        channels = [args.vae_length] * max(n_up - 1, 0)
        channels += [args.vae_length, args.vae_test_dim]
        assert len(channels) == n_up + 1

        # A projection conv is only needed if the input width differs from
        # the first stage width.
        if input_size != channels[0]:
            layers = [nn.Conv1d(input_size, channels[0], kernel_size=3, stride=1, padding=1)]
        else:
            layers = []

        layers.extend(ResBlock(channels[0]) for _ in range(n_resblk))

        for c_in, c_out in zip(channels[:n_up], channels[1:]):
            layers.append(nn.Upsample(scale_factor=2, mode="nearest"))
            layers.append(nn.Conv1d(c_in, c_out, kernel_size=3, stride=1, padding=1))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        layers.append(nn.Conv1d(channels[-1], channels[-1], kernel_size=3, stride=1, padding=1))
        self.main = nn.Sequential(*layers)

    def forward(self, inputs):
        """Decode ``inputs``.

        Dims 1 and 2 are swapped before and after the conv stack, so the
        caller passes channels last (presumably (batch, time, channels)).
        """
        channels_first = inputs.permute(0, 2, 1)
        decoded = self.main(channels_first)
        return decoded.permute(0, 2, 1)
|
41 |
+
|
42 |
+
|
43 |
+
class ResBlock(nn.Module):
    """Residual block: Conv1d -> LeakyReLU(0.2) -> Conv1d, plus the input."""

    def __init__(self, channel):
        super().__init__()
        conv_cfg = dict(kernel_size=3, stride=1, padding=1)
        self.model = nn.Sequential(
            nn.Conv1d(channel, channel, **conv_cfg),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(channel, channel, **conv_cfg),
        )

    def forward(self, x):
        """Return ``model(x) + x``; the sum is accumulated into the conv output."""
        return self.model(x).add_(x)
|
mean_vel_smplxflame_30.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:53b5e48f2a7bf78c41a6de6395d6bb4f29018465ca5d0ee2820a2be3eebb7137
|
3 |
+
size 348
|
mertic.py
ADDED
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import wget
|
3 |
+
import math
|
4 |
+
import numpy as np
|
5 |
+
import librosa
|
6 |
+
import librosa.display
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
from scipy.signal import argrelextrema
|
9 |
+
from scipy import linalg
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from .motion_encoder import VAESKConv
|
13 |
+
|
14 |
+
class LVDFace(object):
    """Running L1 distance between frame-to-frame differences of two sequences.

    ``compute`` accumulates the summed absolute difference between the
    temporal deltas of prediction and target; ``avg`` returns the accumulated
    sum divided by the accumulated element count.
    """

    def __init__(self):
        self.counter = 0
        self.sum = 0

    def compute(self, pred_vertices, target_vertices):
        """Accumulate one sequence pair.

        Args:
            pred_vertices: 2-D array; axis 0 is differenced (frames).
            target_vertices: array with the same shape as ``pred_vertices``.
        """
        t, c = pred_vertices.shape
        diff_pred = pred_vertices[1:, :] - pred_vertices[:-1, :]
        diff_target = target_vertices[1:, :] - target_vertices[:-1, :]
        self.sum += np.sum(np.abs(diff_pred - diff_target))
        # NOTE: the normaliser is t * c although only (t - 1) * c deltas are
        # summed; kept as is so reported numbers stay comparable.
        self.counter += t * c

    def avg(self):
        """Return the running average, or 0 when nothing has been accumulated."""
        if self.counter == 0:
            return 0
        return self.sum / self.counter

    def reset(self):
        """Clear the accumulated sum and count."""
        self.counter = 0
        self.sum = 0
|
34 |
+
|
35 |
+
|
36 |
+
class MSEFace(object):
    """Running mean squared error between predicted and target arrays."""

    def __init__(self):
        self.counter = 0
        self.sum = 0

    def compute(self, pred_vertices, target_vertices):
        """Add the squared error of one 2-D prediction/target pair."""
        n_rows, n_cols = pred_vertices.shape
        squared_error = np.square(pred_vertices - target_vertices)
        self.sum += squared_error.sum()
        self.counter += n_rows * n_cols

    def avg(self):
        """Return sum / count, or 0 when no elements were accumulated."""
        return self.sum / self.counter if self.counter != 0 else 0

    def reset(self):
        """Clear the accumulated sum and count."""
        self.counter = 0
        self.sum = 0
|
55 |
+
|
56 |
+
|
57 |
+
class L1div(object):
    """Running L1 spread of samples around their per-call mean."""

    def __init__(self):
        self.counter = 0
        self.sum = 0

    def compute(self, results):
        """Add the summed absolute deviation of ``results`` from its axis-0 mean.

        The count grows by ``results.shape[0]``.
        """
        centre = results.mean(axis=0) if hasattr(results, "mean") else np.mean(results, axis=0)
        deviation = np.abs(results - centre)
        self.sum += np.sum(deviation, axis=None)
        self.counter += results.shape[0]

    def avg(self):
        """Return sum / count, or 0 when nothing was accumulated."""
        return self.sum / self.counter if self.counter != 0 else 0

    def reset(self):
        """Clear the accumulated sum and count."""
        self.counter = 0
        self.sum = 0
|
76 |
+
|
77 |
+
|
78 |
+
class SRGR(object):
    """Running, optionally semantic-weighted, joint success rate.

    A joint in a frame counts as a success when the Euclidean distance between
    prediction and target is below ``threshold``.
    """

    def __init__(self, threshold=0.1, joints=47, joint_dim=3):
        self.threshold = threshold
        self.pose_dimes = joints
        self.joint_dim = joint_dim
        self.counter = 0
        self.sum = 0

    def run(self, results, targets, semantic=None, verbose=False):
        """Score one prediction/target pair and fold it into the running mean.

        Args:
            results: predicted poses, reshaped to (-1, joints, joint_dim).
            targets: target poses, reshaped the same way.
            semantic: optional per-frame weights (flattened); when omitted
                every frame has weight 1.
            verbose: print the per-frame, per-joint distances.

        Returns:
            The (weighted) success rate for this call.
        """
        results = results.reshape(-1, self.pose_dimes, self.joint_dim)
        targets = targets.reshape(-1, self.pose_dimes, self.joint_dim)
        if semantic is None:
            # Sized after the reshape so batched inputs get one weight per frame.
            semantic = np.ones(results.shape[0])
            avg_weight = 1.0
        else:
            # srgr == 0.165 when all success, scale range to [0, 1]
            avg_weight = 0.165
        semantic = semantic.reshape(-1)

        diff = np.linalg.norm(results - targets, axis=2)  # T, J
        if verbose:
            print(diff)
        success = np.where(diff < self.threshold, 1.0, 0.0)
        # Scale every frame's row by its semantic weight in one broadcast.
        success *= (semantic[: success.shape[0]] * (1 / avg_weight))[:, None]
        rate = np.sum(success) / (success.shape[0] * success.shape[1])
        self.counter += success.shape[0]
        self.sum += rate * success.shape[0]
        return rate

    def avg(self):
        """Return the frame-weighted mean rate, or 0 when nothing was scored."""
        if self.counter == 0:
            return 0
        return self.sum / self.counter

    def reset(self):
        """Clear the accumulated sum and count."""
        self.counter = 0
        self.sum = 0
|
113 |
+
|
114 |
+
|
115 |
+
class BC(object):
    """Running alignment score between audio onsets and motion velocity minima.

    Audio onsets come from ``librosa.onset.onset_detect``; motion beats are
    local minima of the normalised per-joint speed. Each onset contributes
    ``exp(-d**2 / (2 * sigma**2))`` where ``d`` is the distance to the nearest
    motion beat (see ``GAHR``).
    """

    def __init__(self, download_path=None, sigma=0.3, order=7, upper_body=(3, 6, 9, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)):
        """
        Args:
            download_path: directory holding (or receiving a download of)
                ``mean_vel_smplxflame_30.npy``; ``None`` skips loading and the
                normaliser is derived from the first motion seen.
            sigma: width of the Gaussian used in ``GAHR``.
            order: neighbourhood size passed to ``argrelextrema``.
            upper_body: joint indices that ``compute`` scores.
        """
        self.sigma = sigma
        self.order = order
        # Copied so instances never share (or mutate) the default sequence.
        self.upper_body = list(upper_body)
        self.pose_data = []
        self.mmae = None
        if download_path is not None:
            os.makedirs(download_path, exist_ok=True)
            model_file_path = os.path.join(download_path, "mean_vel_smplxflame_30.npy")
            if not os.path.exists(model_file_path):
                print(f"Downloading {model_file_path}")
                wget.download(
                    "https://huggingface.co/spaces/H-Liu1997/EMAGE/resolve/main/EMAGE/test_sequences/weights/mean_vel_smplxflame_30.npy",
                    model_file_path,
                )
            self.mmae = np.load(model_file_path)
        self.threshold = 0.10
        self.counter = 0
        self.sum = 0

    def load_audio(self, wave, t_start=None, t_end=None, without_file=False, sr_audio=16000):
        """Return onset times (seconds) detected in ``wave``.

        Args:
            wave: a path loaded with librosa, or the samples themselves when
                ``without_file`` is true.
            t_start, t_end: optional sample-index slice applied before detection.
            without_file: treat ``wave`` as already-loaded samples.
            sr_audio: sample rate used for loading and onset detection.
        """
        hop_length = 512
        if without_file:
            y = wave
        else:
            y, _ = librosa.load(wave, sr=sr_audio)

        short_y = y[t_start:t_end] if t_start is not None else y
        short_y = short_y.astype(np.float32)
        return librosa.onset.onset_detect(y=short_y, sr=sr_audio, hop_length=hop_length, units="time")

    def load_motion(self, pose, t_start, t_end, pose_fps, without_file=False):
        """Return, per joint, the frame indices of speed minima in a window.

        Args:
            pose: (frames, 3 * joints) positions when ``without_file`` is
                true, otherwise the path of a whitespace-separated text file.
            t_start, t_end: frame window; returned indices are relative to
                ``t_start``.
            pose_fps: frame rate used to turn differences into velocities.
            without_file: treat ``pose`` as already-loaded data.

        Returns:
            A list with one integer array of beat frames per joint.
        """
        data_each_file = []
        if without_file:
            data_each_file = pose
        else:
            with open(pose, "r") as f:
                for i, line_data in enumerate(f):
                    # Lines before 432 are skipped; at 15 fps every even line
                    # is dropped as well.
                    if i < 432 or (pose_fps == 15 and i % 2 == 0):
                        continue
                    line_data_np = np.fromstring(line_data, sep=" ")
                    data_each_file.append(np.concatenate([line_data_np[30:39], line_data_np[112:121]], 0))
        data_each_file = np.array(data_each_file)

        # Forward / central / backward differences give one velocity per frame.
        joints = data_each_file.transpose(1, 0)
        dt = 1 / pose_fps
        init_vel = (joints[:, 1:2] - joints[:, :1]) / dt
        middle_vel = (joints[:, 2:] - joints[:, 0:-2]) / (2 * dt)
        final_vel = (joints[:, -1:] - joints[:, -2:-1]) / dt
        vel = np.concatenate([init_vel, middle_vel, final_vel], 1).transpose(1, 0).reshape(data_each_file.shape[0], -1, 3)

        speed = np.linalg.norm(vel, axis=2)  # T, J
        if self.mmae is None:
            print("Warning: mmae is not provided, using max value of vel as mmae")
            # Stored on the instance, so later calls reuse this normaliser.
            self.mmae = speed.max()
        vel = speed / self.mmae

        # Minima and the threshold test are both evaluated on the same window
        # so their indices refer to the same frames.
        window = vel[t_start:t_end]
        beat_vel_all = []
        for i in range(window.shape[1]):
            minima = argrelextrema(window[:, i], np.less, order=self.order)[0]
            beat_vel_all.append(minima[window[minima, i] > self.threshold])
        return beat_vel_all

    def eval_random_pose(self, wave, pose, t_start, t_end, pose_fps, num_random=60):
        """Print the score of the audio window against shifted motion windows."""
        onset_raw = self.load_audio(wave, t_start, t_end)
        dur = t_end - t_start
        for i in range(num_random):
            beat_vel_all = self.load_motion(pose, i, i + dur, pose_fps)
            dis_all_b2a = self.compute(onset_raw, beat_vel_all)
            print(f"{i}s: ", dis_all_b2a)

    @staticmethod
    def plot_onsets(audio, sr, onset_times_1, onset_times_2):
        """Save ``./onset.png`` with two onset sets drawn over the waveform."""
        fig, axarr = plt.subplots(2, 1, figsize=(10, 10), sharex=True)
        librosa.display.waveshow(audio, sr=sr, alpha=0.7, ax=axarr[0])
        librosa.display.waveshow(audio, sr=sr, alpha=0.7, ax=axarr[1])

        for onset in onset_times_1:
            axarr[0].axvline(onset, color="r", linestyle="--", alpha=0.9, label="Onset Method 1")
        axarr[0].legend()
        axarr[0].set(title="Onset Method 1", xlabel="", ylabel="Amplitude")

        for onset in onset_times_2:
            axarr[1].axvline(onset, color="b", linestyle="-", alpha=0.7, label="Onset Method 2")
        axarr[1].legend()
        axarr[1].set(title="Onset Method 2", xlabel="Time (s)", ylabel="Amplitude")

        # Collapse the per-line duplicate labels into one legend entry each.
        handles, labels = plt.gca().get_legend_handles_labels()
        by_label = dict(zip(labels, handles))
        plt.legend(by_label.values(), by_label.keys())
        plt.title("Audio waveform with Onsets")
        plt.savefig("./onset.png", dpi=500)

    def audio_beat_vis(self, onset_raw, onset_bt, onset_bt_rms):
        """Save ``./onset.png`` comparing raw and backtracked onsets.

        NOTE(review): reads ``self.S``, ``self.times`` and ``self.oenv``, which
        no method of this class assigns; the caller must set them first.
        """
        fig, ax = plt.subplots(nrows=4, sharex=True)
        librosa.display.specshow(librosa.amplitude_to_db(self.S, ref=np.max), y_axis="log", x_axis="time", ax=ax[0])
        ax[1].plot(self.times, self.oenv, label="Onset strength")
        ax[1].vlines(librosa.frames_to_time(onset_raw), 0, self.oenv.max(), label="Raw onsets", color="r")
        ax[1].legend()
        ax[2].vlines(librosa.frames_to_time(onset_bt), 0, self.oenv.max(), label="Backtracked", color="r")
        ax[2].legend()
        ax[3].vlines(librosa.frames_to_time(onset_bt_rms), 0, self.oenv.max(), label="Backtracked (RMS)", color="r")
        ax[3].legend()
        fig.savefig("./onset.png", dpi=500)

    @staticmethod
    def motion_frames2time(vel, offset, pose_fps):
        """Convert frame indices to seconds: ``vel / pose_fps + offset``."""
        return vel / pose_fps + offset

    @staticmethod
    def GAHR(a, b, sigma):
        """Mean Gaussian affinity of every element of ``b`` to its nearest ``a``.

        Args:
            a: reference times.
            b: query times; the result is averaged over these.
            sigma: Gaussian width.

        Returns:
            A float; 0.0 when either sequence is empty (nothing can match).
        """
        a = np.asarray(a, dtype=float).reshape(-1)
        b = np.asarray(b, dtype=float).reshape(-1)
        if a.size == 0 or b.size == 0:
            return 0.0
        nearest = np.abs(a[:, None] - b[None, :]).min(axis=0)
        return float(np.mean(np.exp(-(nearest**2) / (2 * sigma**2))))

    @staticmethod
    def fix_directed_GAHR(a, b, sigma):
        """``GAHR`` on frame indices (30 fps) padded with boundary entries.

        Each sequence is converted to seconds, then 0 is prepended and
        ``len(sequence) / 30`` appended before scoring.
        """
        a = BC.motion_frames2time(np.asarray(a, dtype=float).reshape(-1), 0, 30)
        b = BC.motion_frames2time(np.asarray(b, dtype=float).reshape(-1), 0, 30)
        # Explicit concatenation: ``list + ndarray`` would add element-wise.
        a = np.concatenate(([0.0], a, [len(a) / 30]))
        b = np.concatenate(([0.0], b, [len(b) / 30]))
        return BC.GAHR(a, b, sigma)

    def compute(self, onset_bt_rms, beat_vel, length=1, pose_fps=30):
        """Score one clip and fold it into the running average.

        Args:
            onset_bt_rms: audio onset times in seconds.
            beat_vel: per-joint beat frame arrays (as from ``load_motion``).
            length: weight of this clip in the running average.
            pose_fps: frame rate used to convert beat frames to seconds.

        Returns:
            The clip score: per-joint ``GAHR`` summed over the joints listed
            in ``upper_body`` and divided by ``len(upper_body)``.
        """
        avg_dis_all_b2a_list = []
        for its, beat_vel_each in enumerate(beat_vel):
            if its not in self.upper_body:
                continue
            if beat_vel_each.size == 0:
                # A joint without beats contributes a zero score.
                avg_dis_all_b2a_list.append(0)
                continue
            pose_bt = self.motion_frames2time(beat_vel_each, 0, pose_fps)
            avg_dis_all_b2a_list.append(self.GAHR(pose_bt, onset_bt_rms, self.sigma))
        clip_score = sum(avg_dis_all_b2a_list) / len(self.upper_body)
        self.sum += clip_score * length
        self.counter += length
        return clip_score

    def avg(self):
        """Return the length-weighted mean score, or 0 when nothing was scored."""
        if self.counter == 0:
            return 0
        return self.sum / self.counter

    def reset(self):
        """Clear the accumulated sum and count."""
        self.counter = 0
        self.sum = 0
|
267 |
+
|
268 |
+
|
269 |
+
class Arg(object):
    """Fixed hyper-parameter bundle passed to the evaluation auto-encoder."""

    def __init__(self):
        # Set in one go; the list literal is rebuilt per instance.
        self.__dict__.update(
            vae_length=240,
            vae_test_dim=330,
            vae_test_len=32,
            vae_layer=4,
            vae_test_stride=20,
            vae_grow=[1, 1, 2, 1],
            variational=False,
        )
|
278 |
+
|
279 |
+
|
280 |
+
class FGD(object):
    """Frechet distance between latent features of predicted and target motion.

    Features come from ``VAESKConv.map2latent`` using weights loaded from
    ``AESKConv_240_100.bin``; ``update`` collects features and ``compute``
    returns the Frechet distance between the two collections.
    """

    def __init__(self, download_path="./emage/", device="cuda"):
        """
        Args:
            download_path: directory holding (or receiving downloads of) the
                encoder checkpoint and the SMPL-X model file.
            device: device the model and inputs are moved to when CUDA is
                available.

        Raises:
            ValueError: if ``download_path`` is ``None`` (the weights are
                always read from that directory).
        """
        if download_path is None:
            raise ValueError("FGD requires a download_path to locate the model weights")

        os.makedirs(download_path, exist_ok=True)
        model_file_path = os.path.join(download_path, "AESKConv_240_100.bin")
        smplx_model_dir = os.path.join(download_path, "smplx_models", "smplx")
        smplx_model_file_path = os.path.join(smplx_model_dir, "SMPLX_NEUTRAL_2020.npz")
        if not os.path.exists(model_file_path):
            print(f"Downloading {model_file_path}")
            wget.download(
                "https://huggingface.co/spaces/H-Liu1997/EMAGE/resolve/main/EMAGE/test_sequences/weights/AESKConv_240_100.bin",
                model_file_path,
            )

        os.makedirs(smplx_model_dir, exist_ok=True)
        if not os.path.exists(smplx_model_file_path):
            print(f"Downloading {smplx_model_file_path}")
            wget.download(
                "https://huggingface.co/spaces/H-Liu1997/EMAGE/resolve/main/EMAGE/smplx_models/smplx/SMPLX_NEUTRAL_2020.npz",
                smplx_model_file_path,
            )

        args = Arg()
        # join(path, "") guarantees a trailing separator, so the directory is
        # usable whether or not the caller supplied one.
        self.eval_model = VAESKConv(args, model_save_path=os.path.join(download_path, ""))
        # Load onto the CPU first so CPU-only hosts can read the checkpoint;
        # the model is moved to ``device`` below when CUDA is available.
        old_stat = torch.load(model_file_path, map_location="cpu")["model_state"]
        # Drop the "module." prefix found in the checkpoint keys.
        new_stat = {k.replace("module.", ""): v for k, v in old_stat.items()}
        self.eval_model.load_state_dict(new_stat)

        self.eval_model.eval()
        if torch.cuda.is_available():
            self.eval_model.to(device)

        self.pred_features = []
        self.target_features = []
        self.device = device

    def reset(self):
        """Forget all collected features."""
        self.pred_features = []
        self.target_features = []

    def get_feature(self, data):
        """Encode a 3-D tensor into latent features (NumPy array).

        Dim 1 is truncated to a multiple of 32 before encoding.
        """
        assert len(data.shape) == 3
        remainder = data.shape[1] % 32
        if remainder != 0:
            data = data[:, :-remainder]
        with torch.no_grad():
            if torch.cuda.is_available():
                data = data.to(self.device)
            feature = self.eval_model.map2latent(data).cpu().numpy()
        return feature

    def update(self, pred, target):
        """Encode and store one prediction/target pair."""
        self.pred_features.append(self.get_feature(pred))
        self.target_features.append(self.get_feature(target))

    def compute(self):
        """Return the Frechet distance between all stored feature sets."""
        pred_features = np.concatenate([x.reshape(-1, x.shape[-1]) for x in self.pred_features], axis=0)
        target_features = np.concatenate([x.reshape(-1, x.shape[-1]) for x in self.target_features], axis=0)
        return self.frechet_distance(pred_features, target_features)

    @staticmethod
    def frechet_distance(samples_A, samples_B, eps=1e-6):
        """Frechet distance between Gaussians fitted to two sample sets.

        Args:
            samples_A, samples_B: (n_samples, n_features) arrays.
            eps: added to both covariance diagonals before the matrix square
                root.
        """
        mu1 = np.mean(samples_A, axis=0)
        sigma1 = np.cov(samples_A, rowvar=False)
        mu2 = np.mean(samples_B, axis=0)
        sigma2 = np.cov(samples_B, rowvar=False)
        diff = mu1 - mu2
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
        if np.iscomplexobj(covmean):
            # Discard the imaginary part that sqrtm can return.
            covmean = covmean.real
        return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
|
motion_encoder.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from .skeleton_DME import SkeletonConv, SkeletonPool, find_neighbor, build_edge_topology
|
5 |
+
from .skeleton import SkeletonResidual
|
6 |
+
from .decoders import VQDecoderV3
|
7 |
+
|
8 |
+
|
9 |
+
class LocalEncoder(nn.Module):
    """Skeleton-aware encoder built from ``args.vae_layer`` downsampling stages.

    Every stage appends the pooled topology, pooling list and edge count to
    ``self.topologies``, ``self.pooling_list`` and ``self.edge_num``.
    """

    def __init__(self, args, topology):
        super().__init__()
        # NOTE: these architecture settings are written onto the caller's
        # ``args`` object, not a copy.
        args.channel_base = 6
        args.activation = "tanh"
        args.use_residual_blocks = True
        args.z_dim = 1024
        args.temporal_scale = 8
        args.kernel_size = 4
        args.num_layers = args.vae_layer
        args.skeleton_dist = 2
        args.extra_conv = 0
        # check how to reflect in 1d
        args.padding_mode = "constant"
        args.skeleton_pool = "mean"
        args.upsampling = "linear"

        self.topologies = [topology]
        self.channel_base = [args.channel_base]

        self.channel_list = []
        self.edge_num = [len(topology)]
        self.pooling_list = []
        self.layers = nn.ModuleList()
        self.args = args

        kernel_size = args.kernel_size
        kernel_even = kernel_size % 2 == 0
        padding = (kernel_size - 1) // 2
        bias = True
        self.grow = args.vae_grow

        def make_activation():
            return nn.PReLU() if args.activation == "relu" else nn.Tanh()

        # Per-edge channel width of every stage.
        for level in range(args.num_layers):
            self.channel_base.append(self.channel_base[-1] * self.grow[level])

        for level in range(args.num_layers):
            stage = []
            edges = self.topologies[level]
            n_edges = self.edge_num[level]
            neighbour_list = find_neighbor(edges, args.skeleton_dist)
            in_channels = self.channel_base[level] * n_edges
            out_channels = self.channel_base[level + 1] * n_edges
            if level == 0:
                self.channel_list.append(in_channels)
            self.channel_list.append(out_channels)
            last_pool = level == args.num_layers - 1

            # (T, J, D) => (T, J', D)
            pool = SkeletonPool(
                edges=edges,
                pooling_mode=args.skeleton_pool,
                channels_per_edge=out_channels // len(neighbour_list),
                last_pool=last_pool,
            )

            if args.use_residual_blocks:
                # (T, J, D) => (T/2, J', 2D)
                stage.append(
                    SkeletonResidual(
                        edges,
                        neighbour_list,
                        joint_num=n_edges,
                        in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=kernel_size,
                        stride=2,
                        padding=padding,
                        padding_mode=args.padding_mode,
                        bias=bias,
                        extra_conv=args.extra_conv,
                        pooling_mode=args.skeleton_pool,
                        activation=args.activation,
                        last_pool=last_pool,
                    )
                )
            else:
                for _ in range(args.extra_conv):
                    # (T, J, D) => (T, J, D)
                    stage.append(
                        SkeletonConv(
                            neighbour_list,
                            in_channels=in_channels,
                            out_channels=in_channels,
                            joint_num=n_edges,
                            kernel_size=kernel_size - 1 if kernel_even else kernel_size,
                            stride=1,
                            padding=padding,
                            padding_mode=args.padding_mode,
                            bias=bias,
                        )
                    )
                    stage.append(make_activation())
                # (T, J, D) => (T/2, J, 2D)
                stage.append(
                    SkeletonConv(
                        neighbour_list,
                        in_channels=in_channels,
                        out_channels=out_channels,
                        joint_num=n_edges,
                        kernel_size=kernel_size,
                        stride=2,
                        padding=padding,
                        padding_mode=args.padding_mode,
                        bias=bias,
                        add_offset=False,
                        in_offset_channel=3 * self.channel_base[level] // self.channel_base[0],
                    )
                )

                # The residual stage receives its own pooling arguments, so
                # the explicit pool + activation belong to this branch only.
                stage.append(pool)
                stage.append(make_activation())
            self.layers.append(nn.Sequential(*stage))

            self.topologies.append(pool.new_edges)
            self.pooling_list.append(pool.pooling_list)
            self.edge_num.append(len(self.topologies[-1]))

    def forward(self, input):
        """Run all stages; dims 1 and 2 are swapped on the way in and out."""
        hidden = input.permute(0, 2, 1)
        for stage in self.layers:
            hidden = stage(hidden)
        return hidden.permute(0, 2, 1)
|
139 |
+
|
140 |
+
|
141 |
+
def reparameterize(mu, logvar):
    """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps`` drawn from N(0, 1)."""
    std = (logvar * 0.5).exp()
    noise = torch.randn_like(std)
    return mu + noise * std
|
145 |
+
|
146 |
+
|
147 |
+
class VAEConv(nn.Module):
    """Shared (variational) auto-encoder logic.

    Only the ``fc_mu`` / ``fc_logvar`` heads are created here; ``self.encoder``
    and ``self.decoder`` must be assigned before any method is called.
    """

    def __init__(self, args):
        super().__init__()
        width = args.vae_length
        self.fc_mu = nn.Linear(width, width)
        self.fc_logvar = nn.Linear(width, width)
        self.variational = args.variational

    def _encode(self, inputs):
        """Return ``(latent, mu, logvar)``; the stats are None unless variational."""
        latent = self.encoder(inputs)
        if not self.variational:
            return latent, None, None
        mu = self.fc_mu(latent)
        logvar = self.fc_logvar(latent)
        return reparameterize(mu, logvar), mu, logvar

    def forward(self, inputs):
        """Encode then decode ``inputs``; returns latent, reconstruction and stats."""
        latent, mu, logvar = self._encode(inputs)
        return {
            "poses_feat": latent,
            "rec_pose": self.decoder(latent),
            "pose_mu": mu,
            "pose_logvar": logvar,
        }

    def map2latent(self, inputs):
        """Return the latent for ``inputs`` (sampled when variational)."""
        latent, _, _ = self._encode(inputs)
        return latent

    def decode(self, pre_latent):
        """Decode a latent with ``self.decoder``."""
        return self.decoder(pre_latent)
|
182 |
+
|
183 |
+
|
184 |
+
class VAESKConv(VAEConv):
    """Auto-encoder with a skeleton-aware encoder and a convolutional decoder.

    The skeleton topology is built from the ``kintree_table`` stored in the
    SMPL-X model file under ``model_save_path``.
    """

    def __init__(self, args, model_save_path="./emage/"):
        """
        Args:
            args: hyper-parameter object forwarded to the encoder and decoder.
            model_save_path: directory containing
                ``smplx_models/smplx/SMPLX_NEUTRAL_2020.npz``; a trailing
                separator is optional.
        """
        # Imported here because this module has no top-level ``os`` import.
        import os

        super(VAESKConv, self).__init__(args)
        smpl_fname = os.path.join(model_save_path, "smplx_models", "smplx", "SMPLX_NEUTRAL_2020.npz")
        # Close the archive once the parent table is copied out.
        with np.load(smpl_fname, encoding="latin1") as smpl_data:
            parents = smpl_data["kintree_table"][0].astype(np.int32)
        edges = build_edge_topology(parents)
        self.encoder = LocalEncoder(args, edges)
        self.decoder = VQDecoderV3(args)
|
skeleton.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from .skeleton_DME import SkeletonConv, SkeletonPool, SkeletonUnpool
|
5 |
+
|
6 |
+
|
7 |
+
def calc_node_depth(topology):
    """Return each node's depth in a parent-index tree.

    ``topology[i]`` is the parent of node ``i``; a negative parent marks a
    root, which has depth 0.
    """
    depths = []
    for node in range(len(topology)):
        steps = 0
        parent = topology[node]
        # Climb towards the root, counting the hops.
        while not parent < 0:
            steps += 1
            parent = topology[parent]
        depths.append(steps)
    return depths
|
18 |
+
|
19 |
+
|
20 |
+
def residual_ratio(k):
    """Return ``1 / (k + 1)``."""
    denominator = k + 1
    return 1 / denominator
|
22 |
+
|
23 |
+
|
24 |
+
class Affine(nn.Module):
    """Learnable scale and/or bias applied along dim 1 of the input."""

    def __init__(self, num_parameters, scale=True, bias=True, scale_init=1.0):
        """
        Args:
            num_parameters: length of the scale and bias vectors.
            scale: create a learnable scale (initialised to ``scale_init``).
            bias: create a learnable bias (initialised to zero).
            scale_init: initial value of every scale entry.
        """
        super(Affine, self).__init__()
        if scale:
            self.scale = nn.Parameter(torch.ones(num_parameters) * scale_init)
        else:
            self.register_parameter("scale", None)

        if bias:
            self.bias = nn.Parameter(torch.zeros(num_parameters))
        else:
            self.register_parameter("bias", None)

    @staticmethod
    def _expand(param, ndim):
        """Reshape a 1-D parameter to (1, C, 1, ...) for an ``ndim``-D input."""
        return param.view((1, -1) + (1,) * max(ndim - 2, 0))

    def forward(self, input):
        """Return ``input * scale + bias`` without modifying ``input``."""
        output = input
        if self.scale is not None:
            output = output.mul(self._expand(self.scale, input.dim()))

        if self.bias is not None:
            # Out-of-place: with scale disabled ``output`` is still the
            # caller's tensor, which an in-place add would overwrite.
            output = output + self._expand(self.bias, input.dim())

        return output
|
52 |
+
|
53 |
+
|
54 |
+
class BatchStatistics(nn.Module):
    """Pass-through layer that records a statistics penalty in ``self.loss``.

    With ``affine == -1`` the input is returned unchanged; otherwise it goes
    through ``Affine(affine)``.
    """

    def __init__(self, affine=-1):
        super().__init__()
        if affine == -1:
            self.affine = nn.Sequential()
        else:
            self.affine = Affine(affine)
        self.loss = 0

    def clear_loss(self):
        """Reset the stored loss to 0."""
        self.loss = 0

    def compute_loss(self, input):
        """Store mean(mu^2) + mean(log(std)^2) of the flattened input.

        NOTE(review): the input is viewed (not permuted) as
        (size(1), numel / size(1)), so rows only line up with dim 1 when
        dim 0 has size 1 - confirm this is intended.
        """
        n_rows = input.size(1)
        flat = input.view(n_rows, input.numel() // n_rows)
        mu = flat.mean(1)
        second_moment = flat.pow(2).mean(1)
        logvar = (second_moment - mu.pow(2)).sqrt().log()

        self.loss = mu.pow(2).mean() + logvar.pow(2).mean()

    def forward(self, input):
        """Update ``self.loss`` from ``input`` and return the affine output."""
        self.compute_loss(input)
        return self.affine(input)
|
73 |
+
|
74 |
+
|
75 |
+
class ResidualBlock(nn.Module):
    """Blend of a strided conv branch and a pooled 1x1-conv shortcut.

    Output is ``residual(x) * residual_ratio + shortcut(x) * (1 - residual_ratio)``.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, stride, padding, residual_ratio, activation, batch_statistics=False, last_layer=False
    ):
        super().__init__()

        self.residual_ratio = residual_ratio
        self.shortcut_ratio = 1 - residual_ratio

        # Main branch: conv, optional statistics layer, optional activation.
        main_branch = [nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)]
        if batch_statistics:
            main_branch.append(BatchStatistics(out_channels))
        if not last_layer:
            main_branch.append(nn.PReLU() if activation == "relu" else nn.Tanh())
        self.residual = nn.Sequential(*main_branch)

        # Shortcut: empty Sequentials keep the 1x1 conv at index 1.
        if stride == 2:
            downsample = nn.AvgPool1d(kernel_size=2)
        else:
            downsample = nn.Sequential()
        projection = nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        if in_channels != out_channels and batch_statistics is True:
            shortcut_stats = BatchStatistics(out_channels)
        else:
            shortcut_stats = nn.Sequential()
        self.shortcut = nn.Sequential(downsample, projection, shortcut_stats)

    def forward(self, input):
        """Return the weighted sum of the residual and shortcut branches."""
        main = self.residual(input).mul(self.residual_ratio)
        skip = self.shortcut(input).mul(self.shortcut_ratio)
        return main + skip
|
100 |
+
|
101 |
+
|
102 |
+
class ResidualBlockTranspose(nn.Module):
    """Upsampling counterpart: transposed-conv branch blended with a shortcut.

    Output is ``residual(x) * residual_ratio + shortcut(x) * (1 - residual_ratio)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, residual_ratio, activation):
        super().__init__()

        self.residual_ratio = residual_ratio
        self.shortcut_ratio = 1 - residual_ratio

        deconv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, padding)
        nonlinearity = nn.PReLU() if activation == "relu" else nn.Tanh()
        self.residual = nn.Sequential(deconv, nonlinearity)

        # Linear x2 upsampling only when the main branch also doubles length.
        if stride == 2:
            upsample = nn.Upsample(scale_factor=2, mode="linear", align_corners=False)
        else:
            upsample = nn.Sequential()
        projection = nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.shortcut = nn.Sequential(upsample, projection)

    def forward(self, input):
        """Return the weighted sum of the residual and shortcut branches."""
        main = self.residual(input).mul(self.residual_ratio)
        skip = self.shortcut(input).mul(self.shortcut_ratio)
        return main + skip
|
120 |
+
|
121 |
+
|
122 |
+
class SkeletonResidual(nn.Module):
    """Skeleton-aware residual block followed by edge pooling.

    ``forward`` adds a stack of ``SkeletonConv`` layers (the residual branch)
    to a kernel-size-1 ``SkeletonConv`` (the shortcut), then applies the
    optional ``SkeletonPool`` and an activation.
    """

    def __init__(
        self,
        topology,
        neighbour_list,
        joint_num,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        padding_mode,
        bias,
        extra_conv,
        pooling_mode,
        activation,
        last_pool,
    ):
        """Build the residual, shortcut and pooling stages.

        Args:
            topology: Edge list handed to ``SkeletonPool`` as ``edges``.
            neighbour_list: Neighbour lists handed to every ``SkeletonConv``.
            joint_num: ``joint_num`` handed to every ``SkeletonConv``.
            in_channels: Input channel count.
            out_channels: Output channel count of both branches.
            kernel_size: Kernel size of the strided convolution; the extra
                convolutions use ``kernel_size - 1`` when it is even.
            stride: Stride of the final residual conv and of the shortcut.
            padding: Padding used by the residual convolutions.
            padding_mode: Padding mode used by the residual convolutions.
            bias: Whether the residual convolutions have a bias (the shortcut
                always does).
            extra_conv: Number of stride-1 conv + activation pairs placed
                before the strided convolution.
            pooling_mode: Pooling mode handed to ``SkeletonPool``.
            activation: ``"relu"`` selects ``nn.PReLU``; anything else
                selects ``nn.Tanh``.
            last_pool: ``last_pool`` flag handed to ``SkeletonPool``.
        """
        super(SkeletonResidual, self).__init__()

        def make_activation():
            return nn.PReLU() if activation == "relu" else nn.Tanh()

        # The stride-1 convolutions use kernel_size - 1 when kernel_size is even.
        extra_kernel = kernel_size if kernel_size % 2 else kernel_size - 1

        layers = []
        for _ in range(extra_conv):
            # (T, J, D) => (T, J, D)
            layers.append(
                SkeletonConv(
                    neighbour_list,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    joint_num=joint_num,
                    kernel_size=extra_kernel,
                    stride=1,
                    padding=padding,
                    padding_mode=padding_mode,
                    bias=bias,
                )
            )
            layers.append(make_activation())
        # (T, J, D) => (T/2, J, 2D)
        layers.append(
            SkeletonConv(
                neighbour_list,
                in_channels=in_channels,
                out_channels=out_channels,
                joint_num=joint_num,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                padding_mode=padding_mode,
                bias=bias,
                add_offset=False,
            )
        )
        # FIXME (carried over from the original author): the group count is
        # hard-coded to 10 -- "REMEMBER TO CHANGE BACK !!!".
        layers.append(nn.GroupNorm(10, out_channels))
        self.residual = nn.Sequential(*layers)

        # (T, J, D) => (T/2, J, 2D)
        self.shortcut = SkeletonConv(
            neighbour_list,
            in_channels=in_channels,
            out_channels=out_channels,
            joint_num=joint_num,
            kernel_size=1,
            stride=stride,
            padding=0,
            bias=True,
            add_offset=False,
        )

        # (T/2, J, 2D) => (T/2, J', 2D)
        pool = SkeletonPool(
            edges=topology,
            pooling_mode=pooling_mode,
            channels_per_edge=out_channels // len(neighbour_list),
            last_pool=last_pool,
        )
        tail = []
        # Only keep the pooling layer when it actually merges edges.
        if len(pool.pooling_list) != pool.edge_num:
            tail.append(pool)
        tail.append(make_activation())
        self.common = nn.Sequential(*tail)

    def forward(self, input):
        """Sum the residual and shortcut branches, then pool and activate."""
        merged = self.residual(input) + self.shortcut(input)
        return self.common(merged)
|
206 |
+
|
207 |
+
|
208 |
+
class SkeletonResidualTranspose(nn.Module):
    """Decoder counterpart of ``SkeletonResidual``.

    ``forward`` first applies the optional temporal upsampling and edge
    unpooling, then adds a ``SkeletonConv`` residual stack to a
    kernel-size-1 ``SkeletonConv`` shortcut, and finally applies the
    activation unless this is the last layer.
    """

    def __init__(
        self,
        neighbour_list,
        joint_num,
        in_channels,
        out_channels,
        kernel_size,
        padding,
        padding_mode,
        bias,
        extra_conv,
        pooling_list,
        upsampling,
        activation,
        last_layer,
    ):
        """Build the upsample/unpool stage and the two branches.

        Args:
            neighbour_list: Neighbour lists handed to every ``SkeletonConv``.
            joint_num: ``joint_num`` handed to every ``SkeletonConv``.
            in_channels: Input channel count.
            out_channels: Output channel count of both branches.
            kernel_size: Requested kernel size; every residual convolution
                uses ``kernel_size - 1`` when it is even.
            padding: Padding used by the residual convolutions.
            padding_mode: Padding mode used by the residual convolutions.
            bias: Whether the residual convolutions have a bias (the shortcut
                always does).
            extra_conv: Number of conv + activation pairs placed before the
                channel-changing convolution.
            pooling_list: Pooling groups handed to ``SkeletonUnpool``.
            upsampling: ``nn.Upsample`` mode, or ``None`` to skip upsampling.
            activation: ``"relu"`` selects ``nn.PReLU``; anything else
                selects ``nn.Tanh``.
            last_layer: When true, no activation is applied to the output.
        """
        super(SkeletonResidualTranspose, self).__init__()

        # Every residual convolution uses kernel_size - 1 when kernel_size is even.
        conv_kernel = kernel_size if kernel_size % 2 else kernel_size - 1

        head = []
        # (T, J, D) => (2T, J, D)
        if upsampling is not None:
            head.append(nn.Upsample(scale_factor=2, mode=upsampling, align_corners=False))
        # (2T, J, D) => (2T, J', D)
        unpool = SkeletonUnpool(pooling_list, in_channels // len(neighbour_list))
        # Only keep the unpooling layer when it actually expands edges.
        if unpool.input_edge_num != unpool.output_edge_num:
            head.append(unpool)
        self.common = nn.Sequential(*head)

        layers = []
        for _ in range(extra_conv):
            # (2T, J', D) => (2T, J', D)
            layers.append(
                SkeletonConv(
                    neighbour_list,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    joint_num=joint_num,
                    kernel_size=conv_kernel,
                    stride=1,
                    padding=padding,
                    padding_mode=padding_mode,
                    bias=bias,
                )
            )
            layers.append(nn.PReLU() if activation == "relu" else nn.Tanh())
        # (2T, J', D) => (2T, J', D/2)
        layers.append(
            SkeletonConv(
                neighbour_list,
                in_channels=in_channels,
                out_channels=out_channels,
                joint_num=joint_num,
                kernel_size=conv_kernel,
                stride=1,
                padding=padding,
                padding_mode=padding_mode,
                bias=bias,
                add_offset=False,
            )
        )
        self.residual = nn.Sequential(*layers)

        # (2T, J', D) => (2T, J', D/2)
        self.shortcut = SkeletonConv(
            neighbour_list,
            in_channels=in_channels,
            out_channels=out_channels,
            joint_num=joint_num,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
            add_offset=False,
        )

        if last_layer:
            self.activation = None
        elif activation == "relu":
            self.activation = nn.PReLU()
        else:
            self.activation = nn.Tanh()

    def forward(self, input):
        """Upsample/unpool, sum both branches and optionally activate."""
        expanded = self.common(input)
        merged = self.residual(expanded) + self.shortcut(expanded)
        if self.activation is None:
            return merged
        return self.activation(merged)
|
skeleton_DME.py
ADDED
@@ -0,0 +1,473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This script is modified from https://github.com/DeepMotionEditing/deep-motion-editing
|
2 |
+
# Licensed under:
|
3 |
+
"""
|
4 |
+
Copyright (c) 2020, Kfir Aberman, Peizhuo Li, Yijia Weng, Dani Lischinski, Olga Sorkine-Hornung, Daniel Cohen-Or and Baoquan Chen.
|
5 |
+
All rights reserved.
|
6 |
+
|
7 |
+
Redistribution and use in source and binary forms, with or without
|
8 |
+
modification, are permitted provided that the following conditions are met:
|
9 |
+
|
10 |
+
* Redistributions of source code must retain the above copyright notice, this
|
11 |
+
list of conditions and the following disclaimer.
|
12 |
+
|
13 |
+
* Redistributions in binary form must reproduce the above copyright notice,
|
14 |
+
this list of conditions and the following disclaimer in the documentation
|
15 |
+
and/or other materials provided with the distribution.
|
16 |
+
|
17 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
18 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
19 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
20 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
21 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
22 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
23 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
24 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
25 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
26 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
27 |
+
"""
|
28 |
+
|
29 |
+
import math
|
30 |
+
import numpy as np
|
31 |
+
import torch
|
32 |
+
import torch.nn as nn
|
33 |
+
import torch.nn.functional as F
|
34 |
+
|
35 |
+
|
36 |
+
class SkeletonConv(nn.Module):
    """1D convolution whose weight is masked by a skeleton neighbourhood.

    Channels are grouped per joint (``in_channels // joint_num`` inputs and
    ``out_channels // joint_num`` outputs each). The output channels of joint
    ``i`` only receive input from the channels of the joints listed in
    ``neighbour_list[i]``; every other weight entry is multiplied by zero in
    ``forward``.
    """

    def __init__(
        self,
        neighbour_list,
        in_channels,
        out_channels,
        kernel_size,
        joint_num,
        stride=1,
        padding=0,
        bias=True,
        padding_mode="zeros",
        add_offset=False,
        in_offset_channel=0,
    ):
        """Create the masked convolution.

        Args:
            neighbour_list: One list of neighbour joint indices per joint.
            in_channels: Total input channels; must be divisible by ``joint_num``.
            out_channels: Total output channels; must be divisible by ``joint_num``.
            kernel_size: Temporal kernel size.
            joint_num: Number of joints the channels are split across.
            stride: Convolution stride.
            padding: Amount of padding applied to both ends of the last axis.
            bias: Whether to learn a bias.
            padding_mode: ``"zeros"`` and ``"reflection"`` are translated to the
                ``F.pad`` names ``"constant"`` and ``"reflect"``; other values
                are passed to ``F.pad`` unchanged.
            add_offset: When true, a ``SkeletonLinear`` encoder is created and
                its output (set via ``set_offset``) is added in ``forward``.
            in_offset_channel: Offset channels per joint, used when
                ``add_offset`` is true.

        Raises:
            ValueError: If a channel count is not divisible by ``joint_num``.
        """
        # Initialise nn.Module before touching any attribute.
        super(SkeletonConv, self).__init__()

        if in_channels % joint_num != 0 or out_channels % joint_num != 0:
            raise ValueError(
                "in_channels ({}) and out_channels ({}) must both be divisible by joint_num ({})".format(
                    in_channels, out_channels, joint_num
                )
            )
        self.in_channels_per_joint = in_channels // joint_num
        self.out_channels_per_joint = out_channels // joint_num

        # Translate nn.Conv-style names to the ones F.pad understands.
        if padding_mode == "zeros":
            padding_mode = "constant"
        if padding_mode == "reflection":
            padding_mode = "reflect"

        self.expanded_neighbour_list = []
        self.expanded_neighbour_list_offset = []
        self.neighbour_list = neighbour_list
        self.add_offset = add_offset
        self.joint_num = joint_num

        self.stride = stride
        self.dilation = 1
        self.groups = 1
        self.padding = padding
        self.padding_mode = padding_mode
        self._padding_repeated_twice = (padding, padding)

        # Expand joint indices into the input-channel indices they own.
        for neighbour in neighbour_list:
            self.expanded_neighbour_list.append(
                [k * self.in_channels_per_joint + i for k in neighbour for i in range(self.in_channels_per_joint)]
            )

        if self.add_offset:
            self.offset_enc = SkeletonLinear(neighbour_list, in_offset_channel * len(neighbour_list), out_channels)

            # Same expansion for the offset channels. The loop previously ran
            # over range(add_offset) -- a bool, i.e. a single iteration -- so
            # only the first offset channel of each joint was listed.
            for neighbour in neighbour_list:
                self.expanded_neighbour_list_offset.append(
                    [k * in_offset_channel + i for k in neighbour for i in range(in_offset_channel)]
                )

        # Registration order (mask, weight, bias) is kept stable on purpose so
        # that parameters() enumerates them in the same order as before.
        mask = torch.zeros(out_channels, in_channels, kernel_size)
        for i, neighbour in enumerate(self.expanded_neighbour_list):
            mask[self._joint_rows(i), neighbour, ...] = 1
        self.mask = nn.Parameter(mask, requires_grad=False)

        self.weight = nn.Parameter(torch.zeros(out_channels, in_channels, kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.register_parameter("bias", None)

        self.description = (
            "SkeletonConv(in_channels_per_armature={}, out_channels_per_armature={}, kernel_size={}, "
            "joint_num={}, stride={}, padding={}, bias={})".format(
                in_channels // joint_num, out_channels // joint_num, kernel_size, joint_num, stride, padding, bias
            )
        )

        self.reset_parameters()

    def _joint_rows(self, joint_idx):
        """Return the slice of output channels that belongs to ``joint_idx``."""
        start = self.out_channels_per_joint * joint_idx
        return slice(start, start + self.out_channels_per_joint)

    def reset_parameters(self):
        """(Re-)initialise weight and bias, one joint block at a time.

        The values are written in place under ``torch.no_grad()`` so the
        method can be called again after construction without replacing the
        ``nn.Parameter`` objects (the previous implementation raised on a
        second call because it indexed into a parameter that requires grad).
        """
        with torch.no_grad():
            for i, neighbour in enumerate(self.expanded_neighbour_list):
                rows = self._joint_rows(i)
                # Initialise a temporary block: advanced indexing returns a
                # copy, so initialising the indexed result would be lost.
                block = torch.zeros_like(self.weight[rows, neighbour, ...])
                nn.init.kaiming_uniform_(block, a=math.sqrt(5))
                self.weight[rows, neighbour, ...] = block
                if self.bias is not None:
                    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(block)
                    bound = 1 / math.sqrt(fan_in)
                    bias_block = torch.zeros_like(self.bias[rows])
                    nn.init.uniform_(bias_block, -bound, bound)
                    self.bias[rows] = bias_block

    def set_offset(self, offset):
        """Store the offset tensor, flattened to ``(offset.shape[0], -1)``.

        Raises:
            Exception: If the layer was built with ``add_offset=False``.
        """
        if not self.add_offset:
            raise Exception("Wrong Combination of Parameters")
        self.offset = offset.reshape(offset.shape[0], -1)

    def forward(self, input):
        """Pad the last axis, then convolve with the neighbour-masked weight.

        When ``add_offset`` is set, the encoded offset (scaled by 1/100) is
        broadcast along the last axis and added; ``set_offset`` must have
        been called beforehand.
        """
        weight_masked = self.weight * self.mask
        res = F.conv1d(
            F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
            weight_masked,
            self.bias,
            self.stride,
            0,  # padding was already applied explicitly above
            self.dilation,
            self.groups,
        )

        if self.add_offset:
            offset_res = self.offset_enc(self.offset)
            offset_res = offset_res.reshape(offset_res.shape + (1,))
            res += offset_res / 100
        return res
|
156 |
+
|
157 |
+
|
158 |
+
class SkeletonLinear(nn.Module):
    """Linear layer whose weight is masked by a skeleton neighbourhood.

    Channels are split evenly across ``len(neighbour_list)`` joints; the
    output channels of joint ``i`` only receive input from the channels of
    the joints in ``neighbour_list[i]``.
    """

    def __init__(self, neighbour_list, in_channels, out_channels, extra_dim1=False):
        """Create the masked linear layer.

        Args:
            neighbour_list: One list of neighbour joint indices per joint.
            in_channels: Total input features.
            out_channels: Total output features.
            extra_dim1: When true, ``forward`` appends a trailing axis of
                size 1 to its result.
        """
        super(SkeletonLinear, self).__init__()
        self.neighbour_list = neighbour_list
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_channels_per_joint = in_channels // len(neighbour_list)
        self.out_channels_per_joint = out_channels // len(neighbour_list)
        self.extra_dim1 = extra_dim1
        self.expanded_neighbour_list = []

        # Expand joint indices into the input-feature indices they own.
        for neighbour in neighbour_list:
            self.expanded_neighbour_list.append(
                [k * self.in_channels_per_joint + i for k in neighbour for i in range(self.in_channels_per_joint)]
            )

        # Registration order (bias, weight, mask) is kept stable on purpose so
        # that parameters() enumerates them in the same order as before.
        self.bias = nn.Parameter(torch.empty(out_channels))
        self.weight = nn.Parameter(torch.zeros(out_channels, in_channels))
        mask = torch.zeros(out_channels, in_channels)
        for i, neighbour in enumerate(self.expanded_neighbour_list):
            mask[self._joint_rows(i), neighbour] = 1
        self.mask = nn.Parameter(mask, requires_grad=False)

        self.reset_parameters()

    def _joint_rows(self, joint_idx):
        """Return the slice of output features that belongs to ``joint_idx``."""
        start = joint_idx * self.out_channels_per_joint
        return slice(start, start + self.out_channels_per_joint)

    def reset_parameters(self):
        """(Re-)initialise weight and bias in place.

        Values are written under ``torch.no_grad()`` so the method may be
        called again after construction (the previous implementation raised
        on a second call because it indexed into a parameter that requires
        grad).
        """
        with torch.no_grad():
            for i, neighbour in enumerate(self.expanded_neighbour_list):
                rows = self._joint_rows(i)
                # Advanced indexing returns a copy, so initialise a temporary
                # block and write it back.
                block = torch.zeros_like(self.weight[rows, neighbour])
                nn.init.kaiming_uniform_(block, a=math.sqrt(5))
                self.weight[rows, neighbour] = block

            # The bias bound uses the fan-in of the full (unmasked) weight.
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        """Flatten ``input`` to ``(input.shape[0], -1)`` and apply the masked linear map."""
        input = input.reshape(input.shape[0], -1)
        weight_masked = self.weight * self.mask
        res = F.linear(input, weight_masked, self.bias)
        if self.extra_dim1:
            res = res.reshape(res.shape + (1,))
        return res
|
203 |
+
|
204 |
+
|
205 |
+
class SkeletonPool(nn.Module):
    """Mean-pool chains of skeleton edges with a fixed (non-trainable) matrix.

    The edge list is walked from joint 0 and cut into chains that end at an
    end effector (degree 1) or at a branching joint (degree > 2). Within each
    chain, consecutive edges are merged in pairs; with ``last_pool`` each
    whole chain is merged into one edge.

    Attributes set in ``__init__``:
        seq_list: The edge-index chains found by the walk.
        pooling_list: For each output edge, the input edge indices it averages.
        new_edges: The pooled edge list (left empty when ``last_pool`` is true).
        weight: ``(len(pooling_list) * C, len(edges) * C)`` pooling matrix
            with ``C = channels_per_edge``.
    """

    def __init__(self, edges, pooling_mode, channels_per_edge, last_pool=False):
        """Build the pooling matrix.

        Args:
            edges: Sequence of edges; ``edge[0]`` is the parent joint and
                ``edge[1]`` the child joint.
            pooling_mode: Only ``"mean"`` is implemented.
            channels_per_edge: Number of channels each edge owns.
            last_pool: When true, merge every chain into a single edge.

        Raises:
            Exception: If ``pooling_mode`` is not ``"mean"``.
        """
        super(SkeletonPool, self).__init__()

        if pooling_mode != "mean":
            raise Exception("Unimplemented pooling mode in matrix_implementation")

        self.channels_per_edge = channels_per_edge
        self.pooling_mode = pooling_mode
        self.edge_num = len(edges)
        self.seq_list = []
        self.pooling_list = []
        self.new_edges = []

        # Degree of every joint. A dict replaces the former fixed 100-slot
        # list, which raised IndexError for joint indices >= 100.
        degree = {}
        for edge in edges:
            for joint in (edge[0], edge[1]):
                degree[joint] = degree.get(joint, 0) + 1

        def find_seq(j, seq):
            """Walk down from joint ``j`` carrying the edge chain ``seq``."""
            joint_degree = degree.get(j, 0)

            # A branching joint (other than the root) closes the current chain.
            if joint_degree > 2 and j != 0:
                self.seq_list.append(seq)
                seq = []

            # An end effector closes the chain and stops the walk.
            if joint_degree == 1:
                self.seq_list.append(seq)
                return

            for idx, edge in enumerate(edges):
                if edge[0] == j:
                    find_seq(edge[1], seq + [idx])

        find_seq(0, [])

        for seq in self.seq_list:
            if last_pool:
                self.pooling_list.append(seq)
                continue
            # Odd-length chain: the first edge is kept on its own so the
            # remainder can be merged in pairs.
            if len(seq) % 2 == 1:
                self.pooling_list.append([seq[0]])
                self.new_edges.append(edges[seq[0]])
                seq = seq[1:]
            for i in range(0, len(seq), 2):
                self.pooling_list.append([seq[i], seq[i + 1]])
                self.new_edges.append([edges[seq[i]][0], edges[seq[i + 1]][1]])

        self.description = "SkeletonPool(in_edge_num={}, out_edge_num={})".format(len(edges), len(self.pooling_list))

        weight = torch.zeros(len(self.pooling_list) * channels_per_edge, self.edge_num * channels_per_edge)
        for i, pair in enumerate(self.pooling_list):
            for j in pair:
                for c in range(channels_per_edge):
                    # Channel c of output edge i averages channel c of its members.
                    weight[i * channels_per_edge + c, j * channels_per_edge + c] = 1.0 / len(pair)

        self.weight = nn.Parameter(weight, requires_grad=False)

    def forward(self, input: torch.Tensor):
        """Left-multiply ``input`` by the pooling matrix (``torch.matmul``)."""
        return torch.matmul(self.weight, input)
|
277 |
+
|
278 |
+
|
279 |
+
class SkeletonUnpool(nn.Module):
    """Inverse of ``SkeletonPool``: copy each pooled edge back to its members.

    The fixed (non-trainable) ``weight`` matrix has one row per output
    channel and one column per input channel; entry ``(j*C + c, i*C + c)`` is
    1 whenever output edge ``j`` belongs to pooling group ``i``
    (``C = channels_per_edge``).
    """

    def __init__(self, pooling_list, channels_per_edge):
        """Build the unpooling matrix.

        Args:
            pooling_list: For each input (pooled) edge, the list of output
                edge indices it expands to.
            channels_per_edge: Number of channels each edge owns.
        """
        super(SkeletonUnpool, self).__init__()
        self.pooling_list = pooling_list
        self.channels_per_edge = channels_per_edge
        self.input_edge_num = len(pooling_list)
        self.output_edge_num = sum(len(group) for group in pooling_list)

        self.description = "SkeletonUnpool(in_edge_num={}, out_edge_num={})".format(
            self.input_edge_num,
            self.output_edge_num,
        )

        scatter = torch.zeros(self.output_edge_num * channels_per_edge, self.input_edge_num * channels_per_edge)
        for src_edge, group in enumerate(pooling_list):
            col_base = src_edge * channels_per_edge
            for dst_edge in group:
                row_base = dst_edge * channels_per_edge
                for channel in range(channels_per_edge):
                    scatter[row_base + channel, col_base + channel] = 1

        self.weight = nn.Parameter(scatter, requires_grad=False)

    def forward(self, input: torch.Tensor):
        """Left-multiply ``input`` by the unpooling matrix (``torch.matmul``)."""
        return torch.matmul(self.weight, input)
|
309 |
+
|
310 |
+
|
311 |
+
"""
|
312 |
+
Helper functions for skeleton operation
|
313 |
+
"""
|
314 |
+
|
315 |
+
|
316 |
+
def dfs(x, fa, vis, dist):
    """Depth-first walk over the joint tree, filling hop distances in place.

    Args:
        x: Joint to start (or continue) from.
        fa: Parent index of every joint; joints ``a`` and ``b`` are adjacent
            when ``fa[a] == b`` or ``fa[b] == a``.
        vis: Visited flags (0 = unvisited), updated in place.
        dist: ``dist[y]`` is set to ``dist[x] + 1`` for every joint ``y``
            first reached from ``x``; updated in place.
    """
    vis[x] = 1
    for y, parent_of_y in enumerate(fa):
        adjacent = parent_of_y == x or fa[x] == y
        if not adjacent or vis[y] != 0:
            continue
        dist[y] = dist[x] + 1
        dfs(y, fa, vis, dist)
|
322 |
+
|
323 |
+
|
324 |
+
"""
|
325 |
+
def find_neighbor_joint(fa, threshold):
|
326 |
+
neighbor_list = [[]]
|
327 |
+
for x in range(1, len(fa)):
|
328 |
+
vis = [0 for _ in range(len(fa))]
|
329 |
+
dist = [0 for _ in range(len(fa))]
|
330 |
+
dist[0] = 10000
|
331 |
+
dfs(x, fa, vis, dist)
|
332 |
+
neighbor = []
|
333 |
+
for j in range(1, len(fa)):
|
334 |
+
if dist[j] <= threshold:
|
335 |
+
neighbor.append(j)
|
336 |
+
neighbor_list.append(neighbor)
|
337 |
+
|
338 |
+
neighbor = [0]
|
339 |
+
for i, x in enumerate(neighbor_list):
|
340 |
+
if i == 0: continue
|
341 |
+
if 1 in x:
|
342 |
+
neighbor.append(i)
|
343 |
+
neighbor_list[i] = [0] + neighbor_list[i]
|
344 |
+
neighbor_list[0] = neighbor
|
345 |
+
return neighbor_list
|
346 |
+
|
347 |
+
|
348 |
+
def build_edge_topology(topology, offset):
|
349 |
+
# get all edges (pa, child, offset)
|
350 |
+
edges = []
|
351 |
+
joint_num = len(topology)
|
352 |
+
for i in range(1, joint_num):
|
353 |
+
edges.append((topology[i], i, offset[i]))
|
354 |
+
return edges
|
355 |
+
"""
|
356 |
+
|
357 |
+
|
358 |
+
def build_edge_topology(topology):
    """Convert a parent-index list into a list of ``(parent, child)`` edges.

    The first edge links the root joint (0) to a virtual joint whose index is
    ``len(topology)``; it is followed by one ``(topology[i], i)`` edge for
    every joint ``i >= 1``.
    """
    joint_num = len(topology)
    virtual_root_edge = (0, joint_num)
    skeleton_edges = [(topology[child], child) for child in range(1, joint_num)]
    return [virtual_root_edge] + skeleton_edges
|
366 |
+
|
367 |
+
|
368 |
+
def build_joint_topology(edges, origin_names):
    """Rebuild a joint hierarchy from an edge list.

    A root joint is created first. Every edge then contributes one joint
    (named after its child); when the edge's parent joint has more than one
    outgoing edge, a zero-offset ``"<name>_virtual"`` joint is inserted in
    front of it.

    Args:
        edges: Sequence of ``(parent, child, offset)`` triples. The offset at
            index 2 is required, so plain ``(parent, child)`` pairs raise
            ``IndexError``.
        origin_names: Joint names, indexed by joint id (index 0 names the root).

    Returns:
        A tuple ``(parent, offset, names, edge2joint)``: parent index, offset
        and name of every created joint (root first), plus, for every
        non-root joint, the index of the edge it came from (-1 for virtual
        joints). ``edge2joint`` therefore has one entry fewer than the others.
    """
    # Root joint.
    parent = [0]
    offset = [np.array([0, 0, 0])]
    names = [origin_names[0]]
    edge2joint = []

    # Outgoing-edge count per joint. A dict replaces the former list sized
    # len(edges) + 10, which overflowed for sparse or large joint ids.
    out_degree = {}
    for edge in edges:
        out_degree[edge[0]] = out_degree.get(edge[0], 0) + 1

    def make_topology(edge_idx, pa):
        """Append the joint(s) for ``edges[edge_idx]`` under joint ``pa``, then recurse."""
        edge = edges[edge_idx]
        if out_degree[edge[0]] > 1:
            # Branching parent: insert a virtual joint first.
            parent.append(pa)
            offset.append(np.array([0, 0, 0]))
            names.append(origin_names[edge[1]] + "_virtual")
            edge2joint.append(-1)
            pa = len(parent) - 1  # index of the joint just appended

        parent.append(pa)
        offset.append(edge[2])
        names.append(origin_names[edge[1]])
        edge2joint.append(edge_idx)
        pa = len(parent) - 1  # index of the joint just appended

        for idx, e in enumerate(edges):
            if e[0] == edge[1]:
                make_topology(idx, pa)

    for idx, e in enumerate(edges):
        if e[0] == 0:
            make_topology(idx, 0)

    return parent, offset, names, edge2joint
|
413 |
+
|
414 |
+
|
415 |
+
def calc_edge_mat(edges):
    """Return the all-pairs hop distance between edges.

    Two different edges are at distance 1 when they share a joint
    (``edge[0]`` or ``edge[1]``); longer distances are the shortest chain of
    such steps. The distance of an edge to itself is 0 and unreachable pairs
    keep the sentinel value 100000.

    Args:
        edges: Sequence of edges whose first two entries are joint indices.

    Returns:
        A square list of lists, ``edge_mat[i][j]`` being the distance between
        ``edges[i]`` and ``edges[j]``.
    """
    edge_num = len(edges)
    unreachable = 100000
    edge_mat = [[unreachable] * edge_num for _ in range(edge_num)]
    for i in range(edge_num):
        edge_mat[i][i] = 0

    # Direct neighbours: distinct edges that share at least one joint.
    # i == j is skipped on purpose: an edge trivially shares joints with
    # itself, which used to overwrite the zero diagonal with 1.
    for i, a in enumerate(edges):
        for j, b in enumerate(edges):
            if i == j:
                continue
            if any(a[x] == b[y] for x in range(2) for y in range(2)):
                edge_mat[i][j] = 1

    # Floyd-Warshall relaxation over all intermediate edges.
    for k in range(edge_num):
        for i in range(edge_num):
            for j in range(edge_num):
                edge_mat[i][j] = min(edge_mat[i][j], edge_mat[i][k] + edge_mat[k][j])
    return edge_mat
|
439 |
+
|
440 |
+
|
441 |
+
def find_neighbor(edges, d):
    """List, for every edge, the edges within distance ``d`` of it.

    Args:
        edges: The list contains N elements, each element represents (parent, child).
        d: Maximum edge-to-edge distance, measured by ``calc_edge_mat``.

    Returns:
        The list contains N elements, each element is a list of edge indices
        whose ``calc_edge_mat`` distance is <= d, in increasing index order.
    """
    edge_mat = calc_edge_mat(edges)

    # NOTE: the upstream project additionally appended a neighbourhood for a
    # "global part" here. That step is intentionally disabled; upstream
    # documents it as buggy (deep-motion-editing issue #30) but left it
    # untouched to keep pretrained checkpoints loadable.
    return [[j for j, distance in enumerate(row) if distance <= d] for row in edge_mat]
|