|
import logging
import random
import re
from typing import List, Optional, Tuple

from rdkit import Chem, RDLogger
import torch

# Silence RDKit's own log channels (rdApp.*); parse/sanitize failures of
# generated graphs are handled and logged by this module instead.
RDLogger.DisableLog("rdApp.*")

# NOTE(review): seeds the process-wide `random` RNG at import time (the RNG used
# by select_atom_with_available_valency's shuffle). Kept for reproducibility.
random.seed(0)

# Bond-type code -> RDKit bond type. Index 0 means "no bond". For indices 1..3
# the index equals int(BondType) (SINGLE=1, DOUBLE=2, TRIPLE=3); correct_mol
# relies on this when it lowers a bond order by one via bond_dict[order - 1].
bond_dict = [
    None,
    Chem.rdchem.BondType.SINGLE,
    Chem.rdchem.BondType.DOUBLE,
    Chem.rdchem.BondType.TRIPLE,
    Chem.rdchem.BondType.AROMATIC,
]

# Atomic number -> valence that build_molecule_with_partial_charges compares
# against the valence reported by check_valency (C, N, O, F, P, S, Cl, Br, I).
ATOM_VALENCY = {6: 4, 7: 3, 8: 2, 9: 1, 15: 3, 16: 2, 17: 1, 35: 1, 53: 1}

logger = logging.getLogger(__name__)
|
|
|
# Bracketed dummy atoms such as "[*]", "[1*]" or "[*:2]". A plain "*" -> "[H]"
# substitution would turn these into invalid SMILES like "[[H]]".
_BRACKET_DUMMY_RE = re.compile(r"\[[^\]]*\*[^\]]*\]")


def cap_polymer_points(smiles):
    """Return *smiles* with every polymerization point ('*') replaced by '[H]'.

    Bracketed dummy atoms (isotope-labelled, atom-mapped, ...) are replaced as a
    whole; remaining bare '*' atoms are then replaced individually.
    """
    return _BRACKET_DUMMY_RE.sub("[H]", smiles).replace("*", "[H]")


def check_polymer(smiles):
    """Check that the polymerization points of *smiles* can be capped validly.

    SMILES without '*' are accepted as-is. Otherwise every attachment point is
    capped with an explicit hydrogen (see cap_polymer_points), and the capped
    molecule must parse with get_mol and survive mol2smiles.

    Args:
        smiles: SMILES string, possibly containing '*' dummy atoms.

    Returns:
        True if *smiles* has no '*' or the capped molecule is valid, else False.
    """
    if "*" not in smiles:
        return True
    if mol2smiles(get_mol(cap_polymer_points(smiles))) is None:
        logger.warning("Invalid polymerization point")
        return False
    return True
|
|
|
def _correct_with_fallback(mol_init, index):
    """Valence-correct *mol_init*, first joining fragments, then without.

    correct_mol edits its argument in place, so every strategy gets its own
    copy; that keeps the attempts independent and *mol_init* untouched. If both
    strategies fail, *mol_init* itself is returned.
    """
    for connection in (True, False):
        mol_conn, _ = correct_mol(Chem.RWMol(mol_init), connection=connection)
        if mol_conn is not None:
            return mol_conn
    logger.warning("Failed to correct molecule %s", index)
    return mol_init


def _final_smiles(mol_conn, index):
    """Return the SMILES to report for a corrected molecule, or None.

    The molecule is round-tripped through SMILES and reduced to its largest
    fragment. The full SMILES is used instead when that fragment's SMILES is
    empty or a single character. The result must pass check_polymer.
    """
    smiles = mol2smiles(mol_conn)
    if not smiles:
        logger.warning(
            "Failed to convert molecule %s to SMILES, falling back to RDKit MolToSmiles",
            index,
        )
        smiles = Chem.MolToSmiles(mol_conn)
    if not smiles:
        logger.warning("Failed to generate SMILES for molecule %s, appending None", index)
        return None

    mol = get_mol(smiles)
    if mol is None:
        logger.warning("Failed to convert SMILES back to molecule for index %s", index)
        return None

    mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
    largest_smiles = mol2smiles(max(mol_frags, key=lambda m: m.GetNumAtoms()))
    candidate = largest_smiles if largest_smiles and len(largest_smiles) > 1 else smiles
    return candidate if check_polymer(candidate) else None


def _fallback_smiles(mol_init, index):
    """Last-resort SMILES (raw MolToSmiles of the uncorrected molecule) after an error."""
    if mol_init is None:
        # The molecule itself could not be built for this graph.
        logger.error("All attempts failed for molecule %s: molecule could not be built", index)
        return None
    try:
        fallback_smiles = Chem.MolToSmiles(mol_init)
    except Exception as e2:
        logger.error("All attempts failed for molecule %s: %s", index, e2)
        return None
    if fallback_smiles:
        logger.warning("Used RDKit MolToSmiles fallback for molecule %s", index)
        return fallback_smiles
    logger.warning("RDKit MolToSmiles fallback failed for molecule %s, appending None", index)
    return None


def graph_to_smiles(molecule_list: List[Tuple], atom_decoder: list) -> List[Optional[str]]:
    """Convert generated (atom_types, edge_types) graphs into SMILES strings.

    Each graph is processed in five steps:

    1. Build the molecule with build_molecule_with_partial_charges.
    2. Valence-correct it with correct_mol, with fragment connection first and
       then without.
    3. Convert it to SMILES.
    4. Reduce it to its largest fragment.
    5. Screen it with check_polymer.

    Any exception falls back to RDKit's MolToSmiles of the uncorrected molecule.

    Args:
        molecule_list: sequence of (atom_types, edge_types) pairs.
        atom_decoder: maps an atom-type index to an element for Chem.Atom.

    Returns:
        One entry per input graph: a SMILES string, or None on failure.
    """
    smiles_list: List[Optional[str]] = []
    for index, graph in enumerate(molecule_list):
        # Reset per graph. Otherwise an exception raised before the molecule is
        # built would make the fallback reuse the previous graph's molecule.
        mol_init = None
        try:
            atom_types, edge_types = graph
            mol_init = build_molecule_with_partial_charges(atom_types, edge_types, atom_decoder)
            mol_conn = _correct_with_fallback(mol_init, index)
            smiles_list.append(_final_smiles(mol_conn, index))
        except Exception as e:
            logger.error("Error processing molecule %s: %s", index, e)
            smiles_list.append(_fallback_smiles(mol_init, index))
    return smiles_list
|
|
|
def build_molecule_with_partial_charges(
    atom_types, edge_types, atom_decoder, verbose=False
):
    """Assemble an RDKit RWMol from atom-type and edge-type tensors.

    Atoms are created in order from atom_decoder[atom.item()]. Bonds come from
    the non-zero entries of the upper triangle of *edge_types*, mapped through
    bond_dict; diagonal entries (self-loops) are skipped. After every added bond
    the molecule is valence-checked. If the offending atom is N, O or S and its
    reported valence exceeds ATOM_VALENCY by exactly one, that atom is given a
    +1 formal charge.

    Args:
        atom_types: iterable of scalar tensors indexing *atom_decoder*.
        edge_types: tensor of bond-type codes indexing bond_dict (0 = no bond);
            presumably square (n_atoms, n_atoms) -- TODO confirm with caller.
        atom_decoder: maps an atom-type index to an element for Chem.Atom.
        verbose: print a trace of every construction step.

    Returns:
        The constructed Chem.RWMol.
    """
    if verbose:
        print("\nbuilding new molecule")

    mol = Chem.RWMol()
    for atom_tensor in atom_types:
        code = atom_tensor.item()
        mol.AddAtom(Chem.Atom(atom_decoder[code]))
        if verbose:
            print("Atom added: ", code, atom_decoder[code])

    upper = torch.triu(edge_types)
    for begin_t, end_t in torch.nonzero(upper):
        begin, end = begin_t.item(), end_t.item()
        if begin == end:
            continue  # ignore self-loops on the diagonal

        bond_code = upper[begin, end].item()
        mol.AddBond(begin, end, bond_dict[bond_code])
        if verbose:
            print("bond added:", begin, end, bond_code, bond_dict[bond_code])

        # Try to explain a valence violation with a +1 charge ([N+], [O+], [S+]).
        valid, atomid_valence = check_valency(mol)
        if verbose:
            print("flag, valence", valid, atomid_valence)
        if valid or len(atomid_valence) != 2:
            continue
        atom_idx, valence = atomid_valence
        atomic_num = mol.GetAtomWithIdx(atom_idx).GetAtomicNum()
        if verbose:
            print("atomic num of atom with a large valence", atomic_num)
        if atomic_num in (7, 8, 16) and valence - ATOM_VALENCY[atomic_num] == 1:
            mol.GetAtomWithIdx(atom_idx).SetFormalCharge(1)
    return mol
|
|
|
|
|
def correct_mol(mol, connection=False):
    """Repair valence violations in *mol* by repeatedly lowering bond orders.

    On each pass, check_valency names the atom with a valence error. That atom's
    highest-order non-aromatic bond is downgraded by one step; a single bond is
    removed entirely. When *connection* is True, disconnected fragments are first
    joined with connect_fragments on every pass.

    Note: the correction step edits the molecule being corrected in place
    (RemoveBond/AddBond). That is *mol* itself whenever connect_fragments is not
    used or returns its input, so pass a copy if the original must be kept.

    Args:
        mol: molecule to repair; the correction step requires an editable RWMol.
        connection: if True, join fragments before each valence check.

    Returns:
        (mol, no_correct), where mol is the corrected molecule or None if it
        could not be repaired. no_correct is True when the input already passed
        check_valency.
    """
    aromatic = 12  # int(Chem.BondType.AROMATIC)
    no_correct, _ = check_valency(mol)

    while True:
        if connection:
            mol = connect_fragments(mol)
            if mol is None:
                return None, no_correct

        valid, atomid_valence = check_valency(mol)
        if valid:
            return mol, no_correct

        # Only an atom-specific error parsed as [atom index, valence] can be
        # repaired. Explicit check rather than `assert`, which `python -O` strips.
        if len(atomid_valence) != 2:
            return None, no_correct
        idx = atomid_valence[0]

        try:
            bonds = [
                (int(b.GetBondType()), b.GetBeginAtomIdx(), b.GetEndAtomIdx())
                for b in mol.GetAtomWithIdx(idx).GetBonds()
            ]
            # Stable descending sort by bond type: aromatic (12) first, then
            # triple, double, single.
            bonds.sort(key=lambda bond: bond[0], reverse=True)
            if not bonds or bonds[-1][0] == aromatic:
                # No bonds, or only aromatic ones: nothing can be downgraded.
                return None, no_correct
            n_aromatic = sum(1 for bond in bonds if bond[0] == aromatic)
            order, start, end = bonds[n_aromatic]
            mol.RemoveBond(start, end)
            if order - 1 >= 1:
                mol.AddBond(start, end, bond_dict[order - 1])
        except Exception:
            return None, no_correct
|
|
|
def check_valid(smiles):
    """Report whether *smiles* loads via get_mol and then survives mol2smiles.

    Args:
        smiles: SMILES string (or a molecule, which get_mol passes through).

    Returns:
        True when both steps succeed, False otherwise.
    """
    parsed = get_mol(smiles)
    return parsed is not None and mol2smiles(parsed) is not None
|
|
|
def get_mol(smiles_or_mol):
    """Load a SMILES string (or pass through a molecule) as an RDKit object.

    Strings are parsed with Chem.MolFromSmiles and sanitized again with
    Chem.SanitizeMol. None is returned for an empty string, an unparsable
    SMILES, or a sanitization ValueError. Any non-string argument is returned
    unchanged.
    """
    if not isinstance(smiles_or_mol, str):
        return smiles_or_mol
    if not smiles_or_mol:
        return None

    parsed = Chem.MolFromSmiles(smiles_or_mol)
    if parsed is None:
        return None
    try:
        Chem.SanitizeMol(parsed)
    except ValueError:
        return None
    return parsed
|
|
|
|
|
def mol2smiles(mol):
    """Sanitize *mol* in place and return Chem.MolToSmiles of it.

    Returns None when *mol* is None or when Chem.SanitizeMol raises ValueError.
    """
    if mol is not None:
        try:
            Chem.SanitizeMol(mol)
        except ValueError:
            return None
        return Chem.MolToSmiles(mol)
    return None
|
|
|
|
|
def check_valency(mol):
    """Run RDKit property sanitization on *mol* to detect valence violations.

    Returns:
        (True, None) if Chem.SanitizeMol with SANITIZE_PROPERTIES succeeds.

        (False, numbers) on ValueError. numbers holds the integers found in the
        error message from its first '#' onwards. For an explicit-valence error
        this is expected to be [atom_index, valence]; callers check that the
        length is 2.

        (False, []) for any other exception.
    """
    try:
        Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
    except ValueError as err:
        message = str(err)
        # find() returns -1 when there is no '#', which keeps only the last char.
        tail = message[message.find("#"):]
        return False, [int(num) for num in re.findall(r"\d+", tail)]
    except Exception:
        return False, []
    return True, None
|
|
|
|
|
|
|
def select_atom_with_available_valency(frag):
    """Pick a random heavy atom of *frag* that still has implicit valence.

    The atoms are shuffled with the module-level `random` RNG. The first one
    with atomic number > 1 and GetImplicitValence() > 0 is returned, or None if
    there is no such atom.
    """
    candidates = list(frag.GetAtoms())
    random.shuffle(candidates)
    return next(
        (
            cand
            for cand in candidates
            if cand.GetAtomicNum() > 1 and cand.GetImplicitValence() > 0
        ),
        None,
    )
|
|
|
|
|
def select_atoms_with_available_valency(frag):
    """List, in atom order, the heavy atoms of *frag* that have implicit valence left.

    An atom qualifies when its atomic number is > 1 and GetImplicitValence() > 0.
    """
    def _has_free_valence(atom):
        return atom.GetAtomicNum() > 1 and atom.GetImplicitValence() > 0

    return list(filter(_has_free_valence, frag.GetAtoms()))
|
|
|
|
|
def try_to_connect_fragments(combined_mol, frag, atom1, atom2):
    """Try to join *frag* onto *combined_mol* with a single bond atom1-atom2.

    *atom1* is an atom of *combined_mol* and *atom2* an atom of *frag*; only
    their indices are used. The merge is built on RWMol copies of both inputs.
    Each bonded atom's explicit H count is set to max(0, total H count - 1)
    before the fragment's own bonds are copied over.

    Returns:
        The merged Chem.Mol after a full Chem.SanitizeMol, or None if
        sanitization raises Chem.MolSanitizeException.
    """
    merged = Chem.RWMol(combined_mol)
    piece = Chem.RWMol(frag)

    # Copy the fragment's atoms, remembering fragment index -> merged index.
    index_map = {}
    for src_atom in piece.GetAtoms():
        index_map[src_atom.GetIdx()] = merged.AddAtom(src_atom)

    merged.AddBond(atom1.GetIdx(), index_map[atom2.GetIdx()], Chem.BondType.SINGLE)

    # The new bond takes the place of one hydrogen on each endpoint.
    for endpoint in (atom1.GetIdx(), index_map[atom2.GetIdx()]):
        endpoint_atom = merged.GetAtomWithIdx(endpoint)
        endpoint_atom.SetNumExplicitHs(max(0, endpoint_atom.GetTotalNumHs() - 1))

    for src_bond in piece.GetBonds():
        merged.AddBond(
            index_map[src_bond.GetBeginAtomIdx()],
            index_map[src_bond.GetEndAtomIdx()],
            src_bond.GetBondType(),
        )

    candidate = Chem.Mol(merged)
    try:
        Chem.SanitizeMol(candidate)
    except Chem.MolSanitizeException:
        return None
    return candidate
|
|
|
|
|
def connect_fragments(mol):
    """Join all fragments of *mol* into one molecule using single bonds.

    Fragments are attached one at a time, in Chem.GetMolFrags order, to a
    molecule grown from the first fragment. Candidate atoms on both sides come
    from select_atoms_with_available_valency. Pairs are tried in order, and the
    first merge that try_to_connect_fragments accepts is kept.

    Returns:
        *mol* itself when it has fewer than two fragments; otherwise the merged
        molecule, or None as soon as some fragment cannot be attached.
    """
    def _attach(base, piece):
        # First successful base-piece merge, scanning candidate pairs in order.
        base_atoms = select_atoms_with_available_valency(base)
        piece_atoms = select_atoms_with_available_valency(piece)
        for left in base_atoms:
            for right in piece_atoms:
                joined = try_to_connect_fragments(base, piece, left, right)
                if joined is not None:
                    return joined
        return None

    fragments = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
    if len(fragments) < 2:
        return mol

    merged = Chem.RWMol(fragments[0])
    for fragment in fragments[1:]:
        merged = _attach(merged, fragment)
        if merged is None:
            return None
    return merged
|
|
|
|
|
|
|
|