|
import os |
|
from rdkit import Chem |
|
from rdkit.Chem import Draw, AllChem |
|
from rdkit.Geometry import Point3D |
|
from rdkit import RDLogger |
|
import numpy as np |
|
import rdkit.Chem |
|
|
|
class MolecularVisualization:
    """Convert graph representations of molecules into RDKit molecules and images."""

    def __init__(self, atom_decoder):
        """Store the lookup used to turn integer node labels into atoms.

        Args:
            atom_decoder: indexable object mapping an integer node label to a
                value accepted by ``Chem.Atom`` (presumably an element symbol;
                confirm against caller).
        """
        self.atom_decoder = atom_decoder

    @staticmethod
    def _iter_bonds(node_to_idx, adjacency_matrix):
        """Yield ``(mol_idx_a, mol_idx_b, order)`` for every bond to add.

        Only the strict upper triangle (``iy > ix``) is read, so a symmetric
        matrix contributes each bond once. Entries touching a graph node that
        is absent from ``node_to_idx`` (a skipped node) are ignored, as are
        entries whose value is not one of the bond codes 1-4. ``order`` is
        always a plain Python int in 1..4.
        """
        for ix, row in enumerate(adjacency_matrix):
            if ix not in node_to_idx:
                continue
            for iy, bond in enumerate(row):
                if iy <= ix or iy not in node_to_idx:
                    continue
                for order in (1, 2, 3, 4):
                    if bond == order:
                        yield node_to_idx[ix], node_to_idx[iy], order
                        break

    def mol_from_graphs(self, node_list, adjacency_matrix):
        """Convert a single graph into an RDKit molecule.

        Args:
            node_list: sequence of n integer node labels, each decoded through
                ``self.atom_decoder``. A label of -1 marks an absent node, which
                gets no atom.
            adjacency_matrix: n x n bond codes (1 single, 2 double, 3 triple,
                4 aromatic). Any other value means "no bond". Bonds attached to
                an absent node are dropped.

        Returns:
            The resulting ``Chem.Mol``, or ``None`` if ``GetMol`` raises
            ``KekulizeException``. No sanitization is performed here.
        """
        # Built at call time so that defining the class does not touch RDKit.
        bond_types = {
            1: Chem.rdchem.BondType.SINGLE,
            2: Chem.rdchem.BondType.DOUBLE,
            3: Chem.rdchem.BondType.TRIPLE,
            4: Chem.rdchem.BondType.AROMATIC,
        }

        mol = Chem.RWMol()

        # Graph position -> RDKit atom index. The two diverge as soon as any
        # node is skipped, so bonds must be translated through this map.
        node_to_idx = {}
        for i, label in enumerate(node_list):
            if label == -1:
                continue
            atom = Chem.Atom(self.atom_decoder[int(label)])
            node_to_idx[i] = mol.AddAtom(atom)

        for idx_a, idx_b, order in self._iter_bonds(node_to_idx, adjacency_matrix):
            mol.AddBond(idx_a, idx_b, bond_types[order])

        try:
            mol = mol.GetMol()
        except rdkit.Chem.KekulizeException:
            print("Can't kekulize molecule")
            mol = None
        return mol

    def visualize_chain(self, nodes_list, adjacency_matrix):
        """Render a sequence of graphs, reusing the last frame's 2D layout.

        Side effect: disables all ``rdApp.*`` RDKit logging and does not
        restore it.

        Args:
            nodes_list: sequence of per-frame node-label sequences, one per
                frame. Each frame is passed to ``mol_from_graphs``.
            adjacency_matrix: sequence of per-frame adjacency matrices, indexed
                in parallel with ``nodes_list``.

        Returns:
            A list of images produced by ``Draw.MolToImage`` (300x300, legend
            ``"Frame {i}"``), one per frame whose molecule could be built.
            Frames for which ``mol_from_graphs`` returned ``None`` are skipped.
            Empty input yields an empty list.
        """
        RDLogger.DisableLog('rdApp.*')

        mols = [self.mol_from_graphs(nodes_list[i], adjacency_matrix[i])
                for i in range(len(nodes_list))]
        if not mols:
            return []

        # Lay out the final molecule once, then copy its coordinates onto the
        # earlier frames so that atoms keep their positions between frames.
        # NOTE(review): atom index j only refers to the same graph node in
        # every frame if no earlier node is skipped (-1) in some frames. Confirm
        # this matches the intended alignment.
        coords = []
        final_molecule = mols[-1]
        if final_molecule is not None:
            AllChem.Compute2DCoords(final_molecule)
            final_conf = final_molecule.GetConformer()
            for j in range(final_molecule.GetNumAtoms()):
                pos = final_conf.GetAtomPosition(j)
                coords.append((pos.x, pos.y, pos.z))

        mol_images = []
        for frame, mol in enumerate(mols):
            if mol is None:
                # mol_from_graphs failed for this frame; there is nothing to draw.
                continue
            # Compute2DCoords creates the conformer. Its positions are then
            # overwritten wherever the final frame supplies one. Atoms beyond
            # the final molecule's atom count keep their own computed layout.
            AllChem.Compute2DCoords(mol)
            conf = mol.GetConformer()
            for j in range(min(mol.GetNumAtoms(), len(coords))):
                x, y, z = coords[j]
                conf.SetAtomPosition(j, Point3D(x, y, z))
            mol_images.append(
                Draw.MolToImage(mol, size=(300, 300), legend=f"Frame {frame}")
            )
        return mol_images