Spaces:
Sleeping
Sleeping
import torch | |
import os | |
import imageio | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import glob | |
import random | |
from sklearn.decomposition import PCA | |
from src import const | |
from src.molecule_builder import get_bond_order | |
def save_xyz_file(path, one_hot, positions, node_mask, names, is_geom, suffix=''):
    """Write one XYZ file per batch element.

    Each file is created at ``{path}/{names[batch_i]}_{suffix}.xyz`` and holds
    the atom count, an empty comment line, and one ``symbol x y z`` line
    (coordinates with nine decimals) for every atom selected by ``node_mask``.

    Args:
        path: Directory the files are written into.
        one_hot: Tensor whose first axis is the batch; the argmax over
            ``dim=1`` of ``one_hot[batch_i]`` is used as the atom-type index.
        positions: Tensor indexed as ``positions[batch_i, atom_i, 0..2]``.
        node_mask: Per-batch mask; after ``squeeze()`` its non-zero entries
            select the atoms that are written.
        names: File-name stems, indexed by batch position.
        is_geom: Use ``const.GEOM_IDX2ATOM`` instead of ``const.IDX2ATOM``
            to translate atom-type indices into symbols.
        suffix: Text appended to each stem after an underscore.
    """
    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM

    for batch_i in range(one_hot.size(0)):
        mask = node_mask[batch_i].squeeze()
        n_atoms = mask.sum()
        atom_idx = torch.where(mask)[0]
        atoms = torch.argmax(one_hot[batch_i], dim=1)

        file_path = os.path.join(path, f'{names[batch_i]}_{suffix}.xyz')
        # The context manager closes the handle even when a symbol lookup or
        # a write raises; the encoding matches the one used when reading back.
        with open(file_path, "w", encoding='utf8') as f:
            f.write("%d\n\n" % n_atoms)
            for atom_i in atom_idx:
                atom = idx2atom[atoms[atom_i].item()]
                f.write("%s %.9f %.9f %.9f\n" % (
                    atom,
                    positions[batch_i, atom_i, 0],
                    positions[batch_i, atom_i, 1],
                    positions[batch_i, atom_i, 2],
                ))
def load_xyz_files(path, suffix=''):
    """List the ``*_{suffix}.xyz`` files in ``path``, highest index first.

    The index of a file is the integer found after the last underscore once
    the ``_{suffix}.xyz`` ending has been stripped from its name.

    Args:
        path: Directory to scan (non-recursive).
        suffix: Suffix that precedes the ``.xyz`` extension.

    Returns:
        Paths (``path`` joined with each file name) sorted by descending index.
    """
    ending = f'_{suffix}.xyz'

    def frame_index(fname):
        # Integer that follows the last underscore of the stripped name
        return int(fname.replace(ending, '').split('_')[-1])

    matching = [fname for fname in os.listdir(path) if fname.endswith(ending)]
    matching.sort(key=lambda fname: -frame_index(fname))
    return [os.path.join(path, fname) for fname in matching]
def load_molecule_xyz(file, is_geom):
    """Parse a single-molecule XYZ file into tensors.

    The first line must hold the atom count, the second line is skipped, and
    each following line is read as ``symbol x y z``. Fields may be separated
    by any run of whitespace; columns after the third coordinate are ignored.

    Args:
        file: Path of the XYZ file (read as UTF-8).
        is_geom: Use the ``const.GEOM_*`` atom tables instead of the default
            ``const.ATOM2IDX`` / ``const.IDX2ATOM``.

    Returns:
        Tuple ``(positions, one_hot, charges)`` where ``positions`` has shape
        ``(n_atoms, 3)``, ``one_hot`` has shape ``(n_atoms, len(idx2atom))``
        and ``charges`` is an all-zero ``(n_atoms, 1)`` tensor (no charge is
        read from the file).

    Raises:
        ValueError: If the atom count or a coordinate is not numeric.
    """
    atom2idx = const.GEOM_ATOM2IDX if is_geom else const.ATOM2IDX
    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM

    with open(file, encoding='utf8') as f:
        n_atoms = int(f.readline())
        f.readline()  # second line of an XYZ file is a free-form comment
        atom_lines = f.readlines()

    one_hot = torch.zeros(n_atoms, len(idx2atom))
    charges = torch.zeros(n_atoms, 1)
    positions = torch.zeros(n_atoms, 3)

    for i in range(n_atoms):
        # split() without an argument tolerates tabs and repeated spaces,
        # which split(' ') turned into empty fields that float() rejects.
        fields = atom_lines[i].split()
        one_hot[i, atom2idx[fields[0]]] = 1
        positions[i, :] = torch.Tensor([float(e) for e in fields[1:4]])

    return positions, one_hot, charges
def draw_sphere(ax, x, y, z, size, color, alpha):
    """Draw a sphere of radius ``size`` centred at ``(x, y, z)`` on ``ax``.

    The surface is sampled on a 100 x 100 grid of spherical angles and handed
    to ``ax.plot_surface`` with a row/column stride of 2.

    Args:
        ax: Object providing ``plot_surface`` (a 3D matplotlib axes).
        x, y, z: Coordinates of the sphere centre.
        size: Sphere radius.
        color: Surface colour passed through to ``plot_surface``.
        alpha: Surface opacity passed through to ``plot_surface``.
    """
    # Column vector of azimuth angles, row vector of polar angles: their
    # broadcast products are the same grids an outer product would give.
    azimuth = np.linspace(0, 2 * np.pi, 100)[:, np.newaxis]
    polar = np.linspace(0, np.pi, 100)[np.newaxis, :]
    sin_polar = np.sin(polar)

    offset_x = size * (np.cos(azimuth) * sin_polar)
    offset_y = size * (np.sin(azimuth) * sin_polar)
    offset_z = size * (np.ones_like(azimuth) * np.cos(polar))

    ax.plot_surface(
        x + offset_x,
        y + offset_y,
        z + offset_z,
        rstride=2,
        cstride=2,
        color=color,
        alpha=alpha,
    )
def plot_molecule(ax, positions, atom_type, alpha, spheres_3d, hex_bg_color, is_geom, fragment_mask=None):
    """Draw the bonds and atoms of one molecule on a 3D axes.

    Bonds are drawn between every atom pair for which ``get_bond_order``
    returns a positive order. Atoms are drawn either as shaded spheres or as
    a single scatter plot.

    Args:
        ax: 3D matplotlib axes to draw on.
        positions: Coordinates indexed as ``positions[:, 0..2]``.
        atom_type: Integer atom-type index per atom; used to index
            ``const.COLORS`` / ``const.RADII`` and the idx-to-symbol table.
        alpha: Opacity of bonds and atoms.
        spheres_3d: Draw each atom with ``draw_sphere`` instead of ``scatter``.
        hex_bg_color: Colour used for the bond lines.
        is_geom: Use ``const.GEOM_IDX2ATOM`` instead of ``const.IDX2ATOM``.
        fragment_mask: Optional per-atom mask. In sphere mode, atoms whose
            mask value equals 1 are drawn fully opaque and all others use
            ``alpha``. Defaults to all ones.
    """
    x = positions[:, 0]
    y = positions[:, 1]
    z = positions[:, 2]

    idx2atom = const.GEOM_IDX2ATOM if is_geom else const.IDX2ATOM

    radius_dic = np.array(const.RADII)
    radii = radius_dic[atom_type]
    colors = np.array(const.COLORS)[atom_type]

    n_atoms = len(x)
    if fragment_mask is None:
        fragment_mask = torch.ones(n_atoms)

    # Loop-invariant base width (the original spelled it as (3 - 2) * 2 * 2)
    line_width = 4

    for i in range(n_atoms):
        p1 = np.array([x[i], y[i], z[i]])
        for j in range(i + 1, n_atoms):
            p2 = np.array([x[j], y[j], z[j]])
            dist = np.sqrt(np.sum((p1 - p2) ** 2))
            atom1, atom2 = idx2atom[atom_type[i]], idx2atom[atom_type[j]]
            bond_order = get_bond_order(atom1, atom2, dist)
            if bond_order <= 0:
                continue

            # Bonds of order 4 are drawn 1.5x thicker than the others
            linewidth_factor = (1.5 if bond_order == 4 else 1) * 0.5
            ax.plot(
                [x[i], x[j]], [y[i], y[j]], [z[i], z[j]],
                linewidth=line_width * linewidth_factor * 2,
                c=hex_bg_color,
                alpha=alpha,
            )

    if spheres_3d:
        for xi, yi, zi, radius, color, in_fragment in zip(x, y, z, radii, colors, fragment_mask):
            # Opacity is decided per atom. Overwriting ``alpha`` here would
            # make every atom after the first fragment atom opaque as well.
            atom_alpha = 1.0 if in_fragment == 1 else alpha
            draw_sphere(ax, xi.item(), yi.item(), zi.item(), 0.5 * radius, color, atom_alpha)
    else:
        areas = (1500 * radius_dic ** 2)[atom_type]
        ax.scatter(x, y, z, s=areas, alpha=0.9 * alpha, c=colors)
def plot_data3d(positions, atom_type, is_geom, camera_elev=0, camera_azim=0, save_path=None, spheres_3d=False,
                bg='black', alpha=1., fragment_mask=None):
    """Render one molecule in a 10 x 10 inch 3D figure and save or show it.

    Args:
        positions: Tensor of atom coordinates; its largest absolute value
            determines the symmetric axis limits (clamped to [3.2, 40]).
        atom_type: Per-atom type indices, forwarded to ``plot_molecule``.
        is_geom: Forwarded to ``plot_molecule``.
        camera_elev: Elevation angle passed to ``view_init``.
        camera_azim: Azimuth angle passed to ``view_init``.
        save_path: If given, the figure is saved there; otherwise it is
            shown with ``plt.show()``.
        spheres_3d: Draw atoms as spheres. Also raises the DPI from 50 to 120
            and brightens the saved image by a factor of 1.4.
        bg: ``'black'`` for a black background with white bonds; any other
            value gives a white background with black bonds.
        alpha: Opacity forwarded to ``plot_molecule``.
        fragment_mask: Forwarded to ``plot_molecule``.
    """
    dark_bg = bg == 'black'
    hex_bg_color = '#FFFFFF' if dark_bg else '#000000'
    face_color = (0, 0, 0) if dark_bg else (1, 1, 1)

    fig = plt.figure(figsize=(10, 10))
    try:
        ax = fig.add_subplot(projection='3d')
        ax.set_aspect('auto')
        ax.view_init(elev=camera_elev, azim=camera_azim)
        ax.set_facecolor(face_color)
        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis.pane.set_alpha(0)
        ax._axis3don = False  # private Axes3D flag that hides the 3D axes

        # ``ax.w_xaxis`` was an alias of ``ax.xaxis`` that Matplotlib 3.8
        # removed; going through ``xaxis`` works on old and new releases.
        ax.xaxis.line.set_color("black" if dark_bg else "white")

        plot_molecule(
            ax, positions, atom_type, alpha, spheres_3d, hex_bg_color, is_geom=is_geom, fragment_mask=fragment_mask
        )

        max_value = positions.abs().max().item()
        axis_lim = min(40, max(max_value / 1.5 + 0.3, 3.2))
        ax.set_xlim(-axis_lim, axis_lim)
        ax.set_ylim(-axis_lim, axis_lim)
        ax.set_zlim(-axis_lim, axis_lim)

        dpi = 120 if spheres_3d else 50

        if save_path is not None:
            fig.savefig(save_path, bbox_inches='tight', pad_inches=0.0, dpi=dpi)
            if spheres_3d:
                # Re-open the saved image and brighten it in place
                img = imageio.imread(save_path)
                img_brighter = np.clip(img * 1.4, 0, 255).astype('uint8')
                imageio.imsave(save_path, img_brighter)
        else:
            plt.show()
    finally:
        # Release the figure even when drawing or saving fails, so repeated
        # calls do not accumulate open figures.
        plt.close(fig)
def visualize_chain(
        path, spheres_3d=False, bg="black", alpha=1.0, wandb=None, mode="chain", is_geom=False, fragment_mask=None
):
    """Render every XYZ frame in ``path`` to PNG and assemble ``output.gif``.

    Frames are the files returned by ``load_xyz_files(path)`` (names ending in
    ``_.xyz``, highest index first). Each PNG is written next to its XYZ file
    and the GIF is written to the directory of the first frame.

    Args:
        path: Directory containing the XYZ frames.
        spheres_3d: Forwarded to ``plot_data3d``.
        bg: Forwarded to ``plot_data3d``.
        alpha: Forwarded to ``plot_data3d``.
        wandb: Optional object exposing ``log`` and ``Video``; when given,
            the GIF is logged under the key ``mode``.
        mode: Key used for the ``wandb.log`` entry.
        is_geom: Forwarded to ``load_molecule_xyz`` and ``plot_data3d``.
        fragment_mask: Forwarded to ``plot_data3d``.
    """
    files = load_xyz_files(path)

    # Fit PCA on the last file of the list (the lowest-index frame) so that
    # every frame is drawn in that molecule's principal-axis orientation.
    reference_positions, _, _ = load_molecule_xyz(files[-1], is_geom=is_geom)
    pca = PCA(n_components=3)
    pca.fit(reference_positions)

    save_paths = []
    for file in files:
        raw_positions, one_hot, _ = load_molecule_xyz(file, is_geom=is_geom)
        atom_type = torch.argmax(one_hot, dim=1).numpy()

        # Rotate this frame into the orientation fitted above
        positions = torch.tensor(pca.transform(raw_positions))

        fn = file[:-4] + '.png'
        plot_data3d(
            positions,
            atom_type,
            save_path=fn,
            spheres_3d=spheres_3d,
            alpha=alpha,
            bg=bg,
            camera_elev=90,
            camera_azim=90,
            is_geom=is_geom,
            fragment_mask=fragment_mask,
        )
        save_paths.append(fn)

    imgs = [imageio.imread(fn) for fn in save_paths]
    gif_path = os.path.dirname(save_paths[0]) + '/output.gif'
    imageio.mimsave(gif_path, imgs, subrectangles=True)

    if wandb is not None:
        wandb.log({mode: [wandb.Video(gif_path, caption=gif_path)]})