liuganghuggingface's picture
Update graph_decoder/molecule_utils.py
4bfe982 verified
from rdkit import Chem, RDLogger
RDLogger.DisableLog("rdApp.*")
import re
import random
import logging
from rdkit import Chem
from typing import List, Tuple, Optional
random.seed(0)
import torch
# Lookup from the integer edge label stored in a graph's edge tensor to the
# RDKit bond type to create. Index 0 maps to None (no bond type).
bond_dict = [
    None,
    Chem.rdchem.BondType.SINGLE,
    Chem.rdchem.BondType.DOUBLE,
    Chem.rdchem.BondType.TRIPLE,
    Chem.rdchem.BondType.AROMATIC,
]
# Reference valence per atomic number (C, N, O, F, P, S, Cl, Br, I). Used as
# the baseline when deciding whether an over-valent atom gets a +1 charge.
ATOM_VALENCY = {6: 4, 7: 3, 8: 2, 9: 1, 15: 3, 16: 2, 17: 1, 35: 1, 53: 1}
# Module-level logger used for all warnings/errors emitted by this file.
logger = logging.getLogger(__name__)
def check_polymer(smiles):
    """Check that the "*" attachment points of a SMILES string are valid.

    Every "*" is replaced by "[H]" and the result is round-tripped through
    ``get_mol`` and ``mol2smiles``. A string without any "*" is accepted
    without further checks.

    Args:
        smiles: SMILES string, possibly containing "*" markers.

    Returns:
        True when there is no "*" or the hydrogen-capped string converts
        successfully; False (after logging a warning) otherwise.
    """
    if "*" not in smiles:
        return True

    capped = smiles.replace("*", "[H]")
    if mol2smiles(get_mol(capped)) is not None:
        return True

    logger.warning("Invalid polymerization point")
    return False
def graph_to_smiles(molecule_list: List[Tuple], atom_decoder: list) -> List[Optional[str]]:
    """Convert a batch of molecular graphs into SMILES strings.

    For each graph the molecule is built, valence-corrected (first while
    also connecting fragments, then without), converted to SMILES, reduced
    to its largest fragment and finally checked with ``check_polymer``.

    Args:
        molecule_list: Sequence of ``(atom_types, edge_types)`` pairs as
            accepted by ``build_molecule_with_partial_charges``.
        atom_decoder: Lookup from atom-type index to the value passed to
            ``Chem.Atom``.

    Returns:
        A list with exactly one entry per input graph: a SMILES string, or
        None when no acceptable SMILES could be produced.
    """
    smiles_list = []
    for index, graph in enumerate(molecule_list):
        # Reset for every graph. Without this, a failure while building the
        # current molecule would make the fallback below serialise the
        # molecule left over from the previous iteration.
        mol_init = None
        try:
            atom_types, edge_types = graph
            mol_init = build_molecule_with_partial_charges(atom_types, edge_types, atom_decoder)

            # Try to correct the molecule with connection=True, then False if needed
            for connection in (True, False):
                mol_conn, _ = correct_mol(mol_init, connection=connection)
                if mol_conn is not None:
                    break
            else:
                logger.warning(f"Failed to correct molecule {index}")
                mol_conn = mol_init  # Fallback to initial molecule

            # Convert to SMILES
            smiles = mol2smiles(mol_conn)
            if not smiles:
                logger.warning(f"Failed to convert molecule {index} to SMILES, falling back to RDKit MolToSmiles")
                smiles = Chem.MolToSmiles(mol_conn)

            if not smiles:
                logger.warning(f"Failed to generate SMILES for molecule {index}, appending None")
                smiles_list.append(None)
                continue

            mol = get_mol(smiles)
            if mol is None:
                logger.warning(f"Failed to convert SMILES back to molecule for index {index}")
                smiles_list.append(None)
                continue

            # Get the largest fragment
            mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
            largest_mol = max(mol_frags, key=lambda m: m.GetNumAtoms())
            largest_smiles = mol2smiles(largest_mol)

            # Prefer the largest fragment; fall back to the full SMILES when
            # the fragment could not be serialised or is a single character.
            if largest_smiles and len(largest_smiles) > 1:
                candidate = largest_smiles
            else:
                candidate = smiles
            smiles_list.append(candidate if check_polymer(candidate) else None)
        except Exception as e:
            logger.error(f"Error processing molecule {index}: {str(e)}")
            if mol_init is None:
                # The molecule was never built, so there is nothing to fall back to.
                logger.error(f"All attempts failed for molecule {index}: molecule could not be built")
                smiles_list.append(None)
                continue
            try:
                # Fallback to RDKit's MolToSmiles if everything else fails
                fallback_smiles = Chem.MolToSmiles(mol_init)
            except Exception as e2:
                logger.error(f"All attempts failed for molecule {index}: {str(e2)}")
                smiles_list.append(None)
                continue
            if fallback_smiles:
                smiles_list.append(fallback_smiles)
                logger.warning(f"Used RDKit MolToSmiles fallback for molecule {index}")
            else:
                smiles_list.append(None)
                logger.warning(f"RDKit MolToSmiles fallback failed for molecule {index}, appending None")
    return smiles_list
def build_molecule_with_partial_charges(
    atom_types, edge_types, atom_decoder, verbose=False
):
    """Build an editable RDKit molecule from atom and edge type tensors.

    Atoms are added in order, then one bond is added for every non-zero
    entry in the upper triangle of ``edge_types`` (diagonal entries are
    skipped). After each added bond the valences are checked; if the
    offending atom is N, O or S and exceeds its ``ATOM_VALENCY`` entry by
    exactly one, it is given a formal charge of +1.

    Args:
        atom_types: Iterable of scalar tensors; each ``.item()`` indexes
            ``atom_decoder``.
        edge_types: Tensor of bond labels indexing ``bond_dict``
            (presumably a square adjacency matrix; confirm with callers).
        atom_decoder: Lookup from atom-type index to the value passed to
            ``Chem.Atom``.
        verbose: When True, print progress for each atom and bond.

    Returns:
        The constructed ``Chem.RWMol``.
    """
    if verbose:
        print("\nbuilding new molecule")

    mol = Chem.RWMol()
    for atom_type in atom_types:
        type_index = atom_type.item()
        mol.AddAtom(Chem.Atom(atom_decoder[type_index]))
        if verbose:
            print("Atom added: ", type_index, atom_decoder[type_index])

    # Only the upper triangle is used so each atom pair is visited once.
    edge_types = torch.triu(edge_types)
    for bond in torch.nonzero(edge_types):
        begin, end = bond[0].item(), bond[1].item()
        if begin == end:
            # Diagonal entry: not a bond between two distinct atoms.
            continue

        bond_code = edge_types[bond[0], bond[1]].item()
        mol.AddBond(begin, end, bond_dict[bond_code])
        if verbose:
            print("bond added:", begin, end, bond_code, bond_dict[bond_code])

        # add formal charge to atom: e.g. [O+], [N+], [S+]
        # not support [O-], [N-], [S-], [NH+] etc.
        flag, atomid_valence = check_valency(mol)
        if verbose:
            print("flag, valence", flag, atomid_valence)
        if flag or len(atomid_valence) != 2:
            continue

        idx, v = atomid_valence
        an = mol.GetAtomWithIdx(idx).GetAtomicNum()
        if verbose:
            print("atomic num of atom with a large valence", an)
        if an in (7, 8, 16) and (v - ATOM_VALENCY[an]) == 1:
            mol.GetAtomWithIdx(idx).SetFormalCharge(1)
    return mol
def correct_mol(mol, connection=False):
    """Repair valence violations by repeatedly lowering bond orders.

    While ``check_valency`` reports a violation, the highest-order
    non-aromatic bond of the offending atom is reduced by one order (a
    single bond is removed entirely). When ``connection`` is True,
    ``connect_fragments`` is applied before every valence check.

    Args:
        mol: Editable RDKit molecule; it may be modified in place.
        connection: Whether to join disconnected fragments on each pass.

    Returns:
        ``(mol, no_correct)``. ``mol`` is the corrected molecule, or None
        if correction failed. ``no_correct`` is True when the input already
        passed the valence check before any change was made.
    """
    flag, _ = check_valency(mol)
    no_correct = bool(flag)

    while True:
        if connection:
            mol = connect_fragments(mol)
            if mol is None:
                return None, no_correct

        flag, atomid_valence = check_valency(mol)
        if flag:
            return mol, no_correct

        # Explicit check rather than `assert`, which is stripped under -O.
        if len(atomid_valence) != 2:
            return None, no_correct

        try:
            idx = atomid_valence[0]
            # (bond index, bond order as int, begin atom index, end atom index)
            bonds = [
                (b.GetIdx(), int(b.GetBondType()), b.GetBeginAtomIdx(), b.GetEndAtomIdx())
                for b in mol.GetAtomWithIdx(idx).GetBonds()
            ]
            if not bonds:
                return None, no_correct

            # 12 is the integer value of RDKit's AROMATIC bond type. After
            # the descending sort all aromatic bonds come first, so their
            # count is the position of the highest-order non-aromatic bond.
            aromatic_count = sum(1 for entry in bonds if entry[1] == 12)
            bonds.sort(key=lambda tup: tup[1], reverse=True)
            if bonds[-1][1] == 12:
                # Every bond on this atom is aromatic: nothing to lower.
                return None, no_correct

            _, order, start, end = bonds[aromatic_count]
            lowered = order - 1
            mol.RemoveBond(start, end)
            if lowered >= 1:
                mol.AddBond(start, end, bond_dict[lowered])
        except Exception:
            # Best effort: any failure while editing means "cannot correct".
            return None, no_correct
def check_valid(smiles):
    """Return True when ``smiles`` loads via ``get_mol`` and re-serialises.

    Args:
        smiles: SMILES string (or molecule object) accepted by ``get_mol``.

    Returns:
        False if ``get_mol`` or ``mol2smiles`` yields None, True otherwise.
    """
    mol = get_mol(smiles)
    return mol is not None and mol2smiles(mol) is not None
def get_mol(smiles_or_mol):
    """
    Loads SMILES/molecule into RDKit's object

    A string is parsed with ``Chem.MolFromSmiles`` and sanitized; an empty
    string, an unparsable string or a sanitization ``ValueError`` gives
    None. Any non-string argument is returned unchanged.
    """
    if not isinstance(smiles_or_mol, str):
        return smiles_or_mol
    if not smiles_or_mol:
        return None

    mol = Chem.MolFromSmiles(smiles_or_mol)
    if mol is not None:
        try:
            Chem.SanitizeMol(mol)
        except ValueError:
            mol = None
    return mol
def mol2smiles(mol):
    """Sanitize ``mol`` and return its SMILES string.

    Args:
        mol: RDKit molecule, or None.

    Returns:
        The result of ``Chem.MolToSmiles``, or None when ``mol`` is None or
        sanitization raises ``ValueError``.
    """
    if mol is not None:
        try:
            Chem.SanitizeMol(mol)
        except ValueError:
            pass
        else:
            return Chem.MolToSmiles(mol)
    return None
def check_valency(mol):
    """Check atom valences with RDKit's property sanitization step.

    Args:
        mol: RDKit molecule to check.

    Returns:
        ``(True, None)`` when the check passes. On a ``ValueError`` returns
        ``(False, numbers)``, where ``numbers`` are all integers found in
        the error message from the first "#" onwards (callers read them as
        atom index followed by valence). Returns ``(False, [])`` when the
        message has no "#" or on any other exception.
    """
    try:
        # First attempt to sanitize with specific properties
        Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
    except ValueError as err:
        message = str(err)
        marker = message.find("#")
        if marker == -1:
            # No "#" marker: slicing from -1 would parse only the last
            # character and could report a bogus number.
            return False, []
        return False, [int(token) for token in re.findall(r"\d+", message[marker:])]
    except Exception:
        # Unexpected failure: report as invalid with no details.
        return False, []
    return True, None
##### connect fragments
def select_atom_with_available_valency(frag):
    """Pick one random heavy atom of ``frag`` that has implicit valence left.

    The atoms are shuffled with the module-level ``random`` generator and
    the first one with atomic number > 1 and implicit valence > 0 wins.

    Args:
        frag: Object exposing ``GetAtoms()``.

    Returns:
        The selected atom, or None if no atom qualifies.
    """
    candidates = list(frag.GetAtoms())
    random.shuffle(candidates)
    return next(
        (
            atom
            for atom in candidates
            if atom.GetAtomicNum() > 1 and atom.GetImplicitValence() > 0
        ),
        None,
    )
def select_atoms_with_available_valency(frag):
    """List the heavy atoms of ``frag`` that still have implicit valence.

    Args:
        frag: Object exposing ``GetAtoms()``.

    Returns:
        Atoms with atomic number > 1 and implicit valence > 0, in the order
        ``GetAtoms()`` yields them.
    """
    available = []
    for atom in frag.GetAtoms():
        if atom.GetAtomicNum() > 1 and atom.GetImplicitValence() > 0:
            available.append(atom)
    return available
def try_to_connect_fragments(combined_mol, frag, atom1, atom2):
    """Try to attach ``frag`` to ``combined_mol`` with one single bond.

    Works on copies, so neither input molecule is modified. The fragment's
    atoms are appended to the copy of ``combined_mol``, ``atom1`` and
    ``atom2`` are joined by a single bond, the explicit hydrogen count of
    both bonded atoms is set to one less than their total hydrogen count
    (never below zero), and the fragment's own bonds are copied over.

    Args:
        combined_mol: Molecule built so far.
        frag: Fragment to attach.
        atom1: Atom of ``combined_mol`` to bond from.
        atom2: Atom of ``frag`` to bond to.

    Returns:
        The sanitized merged ``Chem.Mol``, or None if sanitization raises
        ``Chem.MolSanitizeException``.
    """
    # Make copies of the molecules to try the connection
    merged = Chem.RWMol(combined_mol)
    frag_copy = Chem.RWMol(frag)

    # Append the fragment's atoms, remembering their indices in the merged molecule
    index_map = {}
    for frag_atom in frag_copy.GetAtoms():
        index_map[frag_atom.GetIdx()] = merged.AddAtom(frag_atom)

    # Bond the two chosen atoms together
    left_idx = atom1.GetIdx()
    right_idx = index_map[atom2.GetIdx()]
    merged.AddBond(left_idx, right_idx, Chem.BondType.SINGLE)

    # Each bonded atom gives up one hydrogen for the new bond
    for endpoint_idx in (left_idx, right_idx):
        endpoint = merged.GetAtomWithIdx(endpoint_idx)
        total_hs = endpoint.GetTotalNumHs()
        endpoint.SetNumExplicitHs(max(0, total_hs - 1))

    # Recreate the fragment's internal bonds using the remapped indices
    for frag_bond in frag_copy.GetBonds():
        merged.AddBond(
            index_map[frag_bond.GetBeginAtomIdx()],
            index_map[frag_bond.GetEndAtomIdx()],
            frag_bond.GetBondType(),
        )

    # Only a molecule that sanitizes cleanly counts as a valid connection
    candidate = Chem.Mol(merged)
    try:
        Chem.SanitizeMol(candidate)
    except Chem.MolSanitizeException:
        return None
    return candidate
def connect_fragments(mol):
    """Join all disconnected fragments of ``mol`` into one molecule.

    Fragments are attached one at a time to a growing molecule that starts
    as the first fragment. For each fragment, atom pairs with available
    valence are tried in order until ``try_to_connect_fragments`` succeeds.

    Args:
        mol: RDKit molecule, possibly made of several fragments.

    Returns:
        ``mol`` itself when it has fewer than two fragments; otherwise the
        connected molecule, or None if some fragment could not be attached.
    """
    # Get the separate fragments
    fragments = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
    if len(fragments) < 2:
        return mol

    combined_mol = Chem.RWMol(fragments[0])
    for frag in fragments[1:]:
        # Atoms with available valency on each side of the prospective bond
        anchors = select_atoms_with_available_valency(combined_mol)
        partners = select_atoms_with_available_valency(frag)

        # Pairs are generated lazily, anchor-major, so the visiting order
        # matches a nested loop over anchors then partners.
        pairs = ((anchor, partner) for anchor in anchors for partner in partners)
        for anchor, partner in pairs:
            joined = try_to_connect_fragments(combined_mol, frag, anchor, partner)
            if joined is not None:
                # First valid connection wins; move on to the next fragment
                combined_mol = joined
                break
        else:
            # No atom pair produced a valid molecule for this fragment
            return None
    return combined_mol
#### connect fragments