Spaces:
Sleeping
Sleeping
igashov
commited on
Commit
•
c438a2a
1
Parent(s):
bec2844
update COM
Browse files
app.py
CHANGED
@@ -160,6 +160,12 @@ def generate(input_file, n_steps):
|
|
160 |
print('Generated linker')
|
161 |
x = chain[0][:, :, :ddpm.n_dims]
|
162 |
h = chain[0][:, :, ddpm.n_dims:]
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
names = [f'output_{i+1}_{name}' for i in range(N_SAMPLES)]
|
164 |
save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
|
165 |
print('Saved XYZ files')
|
|
|
160 |
print('Generated linker')
|
161 |
x = chain[0][:, :, :ddpm.n_dims]
|
162 |
h = chain[0][:, :, ddpm.n_dims:]
|
163 |
+
|
164 |
+
pos_masked = data['positions'] * data['fragment_mask']
|
165 |
+
N = data['fragment_mask'].sum(1, keepdims=True)
|
166 |
+
mean = torch.sum(pos_masked, dim=1, keepdim=True) / N
|
167 |
+
x = x + mean * node_mask
|
168 |
+
|
169 |
names = [f'output_{i+1}_{name}' for i in range(N_SAMPLES)]
|
170 |
save_xyz_file('results', h, x, node_mask, names=names, is_geom=True, suffix='')
|
171 |
print('Saved XYZ files')
|