# DiffLinker: src/delinker.py
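"""DeLinker-style evaluation of generated molecules.

Given predicted molecules, the corresponding ground-truth molecules and the fixed
fragments, this module computes validity, uniqueness, novelty, recovery rate,
the 2D filters (SA, ring, PAINS) and the SC_RDKit score, reported under
'DeLinker/...' metric names.
"""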
import csv
import numpy as np
from rdkit import Chem
from rdkit.Chem import MolStandardize
from src import metrics
from src.delinker_utils import sascorer, calc_SC_RDKit
from tqdm import tqdm
from pdb import set_trace


def get_valid_as_in_delinker(data, progress=False):
    """Keep records whose largest predicted component sanitizes and contains the input fragments."""
    valid = []
    generator = tqdm(enumerate(data), total=len(data)) if progress else enumerate(data)
    for i, m in generator:
        pred_mol = Chem.MolFromSmiles(m['pred_mol_smi'], sanitize=False)
        true_mol = Chem.MolFromSmiles(m['true_mol_smi'], sanitize=False)
        frag = Chem.MolFromSmiles(m['frag_smi'], sanitize=False)

        # Keep only the largest connected component of the prediction
        pred_mol_frags = Chem.GetMolFrags(pred_mol, asMols=True, sanitizeFrags=False)
        pred_mol_filtered = max(pred_mol_frags, default=pred_mol, key=lambda mol: mol.GetNumAtoms())

        try:
            Chem.SanitizeMol(pred_mol_filtered)
            Chem.SanitizeMol(true_mol)
            Chem.SanitizeMol(frag)
        except Exception:
            continue

        if len(pred_mol_filtered.GetSubstructMatch(frag)) > 0:
            valid.append({
                'pred_mol': m['pred_mol'],
                'true_mol': m['true_mol'],
                'pred_mol_smi': Chem.MolToSmiles(pred_mol_filtered),
                'true_mol_smi': Chem.MolToSmiles(true_mol),
                'frag_smi': Chem.MolToSmiles(frag)
            })

    return valid


def extract_linker_smiles(molecule, fragments):
    """Return the SMILES of the linker obtained by deleting the fragment atoms from the molecule."""
    match = molecule.GetSubstructMatch(fragments)
    elinker = Chem.EditableMol(molecule)
    # Remove matched atoms from the highest index down so remaining indices stay valid
    for atom_id in sorted(match, reverse=True):
        elinker.RemoveAtom(atom_id)

    linker = elinker.GetMol()
    Chem.RemoveStereochemistry(linker)
    try:
        linker = MolStandardize.canonicalize_tautomer_smiles(Chem.MolToSmiles(linker))
    except Exception:
        linker = Chem.MolToSmiles(linker)

    return linker
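
# Example (hypothetical SMILES, not from the original file): for the molecule
# 'c1ccccc1CCOCCc1ccccc1' and the fragment pair 'c1ccccc1C.Cc1ccccc1',
# extract_linker_smiles returns 'COC', i.e. the -CH2-O-CH2- linker.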


def compute_and_add_linker_smiles(data, progress=False):
    """Add predicted and ground-truth linker SMILES to every record."""
    data_with_linkers = []
    generator = tqdm(data) if progress else data
    for m in generator:
        pred_mol = Chem.MolFromSmiles(m['pred_mol_smi'], sanitize=True)
        true_mol = Chem.MolFromSmiles(m['true_mol_smi'], sanitize=True)
        frag = Chem.MolFromSmiles(m['frag_smi'], sanitize=True)

        pred_linker = extract_linker_smiles(pred_mol, frag)
        true_linker = extract_linker_smiles(true_mol, frag)
        data_with_linkers.append({
            **m,
            'pred_linker': pred_linker,
            'true_linker': true_linker,
        })

    return data_with_linkers


def compute_uniqueness(data, progress=False):
    """Fraction of unique predicted molecules per (true molecule, fragment) pair."""
    mol_dictionary = {}
    generator = tqdm(data) if progress else data
    for m in generator:
        frag = m['frag_smi']
        pred_mol = m['pred_mol_smi']
        true_mol = m['true_mol_smi']

        key = f'{true_mol}.{frag}'
        mol_dictionary.setdefault(key, []).append(pred_mol)

    total_mol = 0
    unique_mol = 0
    for molecules in mol_dictionary.values():
        total_mol += len(molecules)
        unique_mol += len(set(molecules))

    return unique_mol / total_mol


def compute_novelty(data, progress=False):
    """Fraction of predicted linkers that do not occur among the ground-truth linkers."""
    novel = 0
    true_linkers = {m['true_linker'] for m in data}
    generator = tqdm(data) if progress else data
    for m in generator:
        if m['pred_linker'] not in true_linkers:
            novel += 1

    return novel / len(data)


def compute_recovery_rate(data, progress=False):
    """Fraction of (true molecule, true linker) pairs recovered exactly by at least one prediction."""
    total = set()
    recovered = set()
    generator = tqdm(data) if progress else data
    for m in generator:
        pred_mol = Chem.MolFromSmiles(m['pred_mol_smi'], sanitize=True)
        Chem.RemoveStereochemistry(pred_mol)
        pred_mol = Chem.MolToSmiles(Chem.RemoveHs(pred_mol))

        true_mol = Chem.MolFromSmiles(m['true_mol_smi'], sanitize=True)
        Chem.RemoveStereochemistry(true_mol)
        true_mol = Chem.MolToSmiles(Chem.RemoveHs(true_mol))

        true_link = m['true_linker']
        total.add(f'{true_mol}.{true_link}')
        if pred_mol == true_mol:
            recovered.add(f'{true_mol}.{true_link}')

    return len(recovered) / len(total)


def calc_sa_score_mol(mol):
    """Synthetic accessibility (SA) score of a molecule, or None if the molecule is missing."""
    if mol is None:
        return None
    return sascorer.calculateScore(mol)


def check_ring_filter(linker):
    """Return False if the linker contains a double bond between two atoms of the same ring."""
    check = True
    # Get linker rings
    ssr = Chem.GetSymmSSSR(linker)
    # Check each ring for a double bond whose both atoms lie in the ring
    for ring in ssr:
        for atom_idx in ring:
            for bond in linker.GetAtomWithIdx(atom_idx).GetBonds():
                if bond.GetBondType() == Chem.BondType.DOUBLE \
                        and bond.GetBeginAtomIdx() in ring and bond.GetEndAtomIdx() in ring:
                    check = False
    return check


def check_pains(mol, pains_smarts):
    """Return False if the molecule matches any of the PAINS SMARTS patterns."""
    for pain in pains_smarts:
        if mol.HasSubstructMatch(pain):
            return False
    return True


def calc_2d_filters(toks, pains_smarts):
    """Return [SA, ring, PAINS] filter flags for a single prediction record."""
    pred_mol = Chem.MolFromSmiles(toks['pred_mol_smi'])
    frag = Chem.MolFromSmiles(toks['frag_smi'])
    linker = Chem.MolFromSmiles(toks['pred_linker'])

    result = [False, False, False]
    if len(pred_mol.GetSubstructMatch(frag)) > 0:
        sa_score = False
        ra_score = False
        pains_score = False

        try:
            sa_score = calc_sa_score_mol(pred_mol) < calc_sa_score_mol(frag)
        except Exception as e:
            print(f'Could not compute SA score: {e}')
        try:
            ra_score = check_ring_filter(linker)
        except Exception as e:
            print(f'Could not compute RA score: {e}')
        try:
            pains_score = check_pains(pred_mol, pains_smarts)
        except Exception as e:
            print(f'Could not compute PAINS score: {e}')

        result = [sa_score, ra_score, pains_score]

    return result


def calc_filters_2d_dataset(data):
    """Fractions of records passing all 2D filters, and the SA, ring and PAINS filters individually."""
    with open('models/wehi_pains.csv', 'r') as f:
        pains_smarts = [Chem.MolFromSmarts(line[0], mergeHs=True) for line in csv.reader(f)]

    pass_all = pass_SA = pass_RA = pass_PAINS = 0
    for m in data:
        filters_2d = calc_2d_filters(m, pains_smarts)
        pass_all += filters_2d[0] & filters_2d[1] & filters_2d[2]
        pass_SA += filters_2d[0]
        pass_RA += filters_2d[1]
        pass_PAINS += filters_2d[2]

    return pass_all / len(data), pass_SA / len(data), pass_RA / len(data), pass_PAINS / len(data)


def calc_sc_rdkit_full_mol(gen_mol, ref_mol):
    """SC_RDKit 3D similarity between a generated and a reference molecule; -0.5 if it cannot be computed."""
    try:
        score = calc_SC_RDKit.calc_SC_RDKit_score(gen_mol, ref_mol)
        return score
    except Exception:
        return -0.5


def sc_rdkit_score(data):
    """Mean SC_RDKit score over the dataset."""
    scores = []
    for m in data:
        score = calc_sc_rdkit_full_mol(m['pred_mol'], m['true_mol'])
        scores.append(score)

    return np.mean(scores)


def get_delinker_metrics(pred_molecules, true_molecules, true_fragments):
    """Compute the full set of DeLinker evaluation metrics for lists of RDKit molecules."""
    default_values = {
        'DeLinker/validity': 0,
        'DeLinker/uniqueness': 0,
        'DeLinker/novelty': 0,
        'DeLinker/recovery': 0,
        'DeLinker/2D_filters': 0,
        'DeLinker/2D_filters_SA': 0,
        'DeLinker/2D_filters_RA': 0,
        'DeLinker/2D_filters_PAINS': 0,
        'DeLinker/SC_RDKit': 0,
    }
    if len(pred_molecules) == 0:
        return default_values

    data = []
    for pred_mol, true_mol, true_frag in zip(pred_molecules, true_molecules, true_fragments):
        data.append({
            'pred_mol': pred_mol,
            'true_mol': true_mol,
            'pred_mol_smi': Chem.MolToSmiles(pred_mol),
            'true_mol_smi': Chem.MolToSmiles(true_mol),
            'frag_smi': Chem.MolToSmiles(true_frag)
        })

    # Validity according to the DeLinker paper:
    # the largest predicted component passes RDKit sanitization and contains the input fragments
    valid_data = get_valid_as_in_delinker(data)
    validity_as_in_delinker = len(valid_data) / len(data)
    if len(valid_data) == 0:
        return default_values

    # Compute linkers and add them to the records
    valid_data = compute_and_add_linker_smiles(valid_data)

    # Compute uniqueness
    uniqueness = compute_uniqueness(valid_data)

    # Compute novelty
    novelty = compute_novelty(valid_data)

    # Compute recovered molecules
    recovery_rate = compute_recovery_rate(valid_data)

    # 2D filters
    pass_all, pass_SA, pass_RA, pass_PAINS = calc_filters_2d_dataset(valid_data)

    # 3D filters
    sc_rdkit = sc_rdkit_score(valid_data)

    return {
        'DeLinker/validity': validity_as_in_delinker,
        'DeLinker/uniqueness': uniqueness,
        'DeLinker/novelty': novelty,
        'DeLinker/recovery': recovery_rate,
        'DeLinker/2D_filters': pass_all,
        'DeLinker/2D_filters_SA': pass_SA,
        'DeLinker/2D_filters_RA': pass_RA,
        'DeLinker/2D_filters_PAINS': pass_PAINS,
        'DeLinker/SC_RDKit': sc_rdkit,
    }
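

# Illustrative usage sketch (not part of the original module). It assumes the script is run
# from the repository root so that 'models/wehi_pains.csv' and the `src` package are available,
# and uses hypothetical SMILES as stand-ins for real model outputs. Molecules built from
# SMILES carry no 3D conformers, so the SC_RDKit term falls back to its default value here.
if __name__ == '__main__':
    true_smi = 'c1ccccc1CCOCCc1ccccc1'   # full ground-truth molecule: two phenyl fragments plus linker
    frag_smi = 'c1ccccc1C.Cc1ccccc1'     # the two fixed fragments to be linked
    pred_smi = true_smi                  # pretend the model recovered the true molecule exactly

    pred_molecules = [Chem.MolFromSmiles(pred_smi)]
    true_molecules = [Chem.MolFromSmiles(true_smi)]
    true_fragments = [Chem.MolFromSmiles(frag_smi)]

    metrics_dict = get_delinker_metrics(pred_molecules, true_molecules, true_fragments)
    for name, value in metrics_dict.items():
        print(f'{name}: {value:.3f}')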