File size: 7,691 Bytes
514d010 aec17f8 514d010 a785154 dd64b2b 514d010 1ca5378 000e1c0 227b864 366aa69 227b864 514d010 227b864 e6f430b 03d5987 e6f430b 1685117 227b864 1685117 227b864 1685117 227b864 514d010 227b864 8d62ec9 227b864 8d62ec9 514d010 227b864 514d010 227b864 8d62ec9 227b864 be6385e 227b864 8d62ec9 227b864 fd8953a 227b864 8d62ec9 3ef0664 227b864 a8ee13b 227b864 514d010 227b864 514d010 4d0c7ae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 |
import os
import sys
from functools import lru_cache

import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
from rdkit.Chem import Descriptors
from rdkit.Chem import Draw
import selfies as sf
from rdkit.Chem import RDConfig

# sascorer is not imported from an installed package here: it is loaded from
# RDKit's Contrib/SA_Score directory, which must be put on sys.path first.
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
import sascorer
def get_largest_ring_size(mol):
    """Return the size of the largest ring in ``mol``, or 0 when it has none.

    Each ring's size is the length of its entry in
    ``mol.GetRingInfo().AtomRings()``.
    """
    ring_sizes = [len(ring) for ring in mol.GetRingInfo().AtomRings()]
    return max(ring_sizes, default=0)
def plogp(smile):
    """Return the penalized logP of a SMILES string, or -100 if it is unusable.

    penalized logP = MolLogP - SA score - ring penalty, where the ring
    penalty is how many atoms the largest ring has beyond six (0 otherwise).

    Args:
        smile: SMILES string; may be empty (e.g. from a failed SELFIES decode).

    Returns:
        The penalized logP as a float, or the sentinel -100 when ``smile`` is
        empty or RDKit cannot parse it.
    """
    if not smile:
        return -100
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        return -100
    log_p = Descriptors.MolLogP(mol)
    sas_score = sascorer.calculateScore(mol)
    cycle_score = max(get_largest_ring_size(mol) - 6, 0)
    # Previously a truthiness check on log_p / ring size sent every acyclic
    # molecule (ring size 0) and any logP of exactly 0.0 to the -100 sentinel;
    # those are valid molecules, so the score is now always computed.
    return log_p - sas_score - cycle_score
def sf_decode(selfies):
    """Decode a SELFIES string to SMILES, returning '' if it cannot be decoded.

    Unlike ``sf.decoder``, this never raises ``sf.DecoderError``; callers
    such as ``plogp`` and ``sim`` treat the empty string as "no molecule".
    """
    try:
        smiles = sf.decoder(selfies)
    except sf.DecoderError:
        return ''
    return smiles
def sim(input_smile, output_smile):
    """Return the Tanimoto similarity of two SMILES strings' Morgan fingerprints.

    Fingerprints use radius 2. Returns None when either string is empty or
    RDKit cannot parse it.
    """
    if not (input_smile and output_smile):
        return None
    input_mol = Chem.MolFromSmiles(input_smile)
    output_mol = Chem.MolFromSmiles(output_smile)
    if not (input_mol and output_mol):
        return None
    # NOTE(review): AllChem.GetMorganFingerprint is deprecated in recent RDKit
    # releases in favour of rdFingerprintGenerator; confirm before upgrading.
    input_fp, output_fp = (AllChem.GetMorganFingerprint(m, 2) for m in (input_mol, output_mol))
    return DataStructs.TanimotoSimilarity(input_fp, output_fp)
@lru_cache(maxsize=1)
def _load_gen_model():
    """Load the MolGen-large tokenizer and model once and reuse them across requests."""
    tokenizer = AutoTokenizer.from_pretrained("zjunlp/MolGen-large")
    model = AutoModelForSeq2SeqLM.from_pretrained("zjunlp/MolGen-large")
    return tokenizer, model


def gen_process(gen_input):
    """Generate four SELFIES molecules from a SELFIES prompt and render them.

    Args:
        gen_input: SELFIES string typed into the "Molecular Generation" tab.

    Returns:
        A tuple of (newline-joined generated SELFIES, grid image of the
        decoded molecules, four per row).
    """
    # Loading the model on every click was the dominant per-request cost;
    # the cached loader does it once per process.
    tokenizer, model = _load_gen_model()
    sf_input = tokenizer(gen_input, return_tensors="pt")
    # Beam search with 5 beams, keeping the 4 best sequences
    # (min_length=5, max_length=15).
    molecules = model.generate(
        input_ids=sf_input["input_ids"],
        attention_mask=sf_input["attention_mask"],
        max_length=15,
        min_length=5,
        num_return_sequences=4,
        num_beams=5,
    )
    gen_output = [
        tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True).replace(" ", "")
        for g in molecules
    ]
    # sf_decode returns '' instead of raising on malformed SELFIES, so a
    # single bad beam no longer aborts the whole request.
    mols = [Chem.MolFromSmiles(sf_decode(s)) for s in gen_output]
    gen_output_image = Draw.MolsToGridImage(
        mols,
        molsPerRow=4,
        subImgSize=(200, 200),
        legends=['' for _ in mols],
    )
    return "\n".join(gen_output), gen_output_image
@lru_cache(maxsize=1)
def _load_opt_model():
    """Load the MolGen-large-opt tokenizer and model once and reuse them across requests."""
    tokenizer = AutoTokenizer.from_pretrained("zjunlp/MolGen-large-opt")
    model = AutoModelForSeq2SeqLM.from_pretrained("zjunlp/MolGen-large-opt")
    return tokenizer, model


def opt_process(opt_input):
    """Propose property-improving, structurally similar variants of a molecule.

    Samples 10 SELFIES candidates from MolGen-large-opt and removes
    duplicates. It keeps those whose penalized logP (``plogp``) is higher
    than the input's and whose Tanimoto similarity (``sim``) to the input
    exceeds 0.4.

    Args:
        opt_input: SELFIES string typed into the optimization tab.

    Returns:
        A tuple of (input image, newline-joined surviving SELFIES,
        newline-joined pLogP improvements, newline-joined similarities,
        grid image of the surviving molecules).

    Raises:
        gr.Error: if ``opt_input`` cannot be decoded as SELFIES.
    """
    tokenizer, model = _load_opt_model()

    # Decode the input once (it used to be decoded twice) and report a bad
    # string to the user instead of surfacing a raw DecoderError.
    try:
        input_sm = sf.decoder(opt_input)
    except sf.DecoderError as err:
        raise gr.Error(f"Invalid SELFIES input: {err}") from err

    opt_input_img = Draw.MolsToGridImage(
        [Chem.MolFromSmiles(input_sm)],
        molsPerRow=4,
        subImgSize=(200, 200),
        legends=[''],
    )

    sf_input = tokenizer(opt_input, return_tensors="pt")
    # Stochastic top-k sampling (k=30); top_p=1 leaves nucleus filtering off.
    molecules = model.generate(
        input_ids=sf_input["input_ids"],
        attention_mask=sf_input["attention_mask"],
        do_sample=True,
        max_length=100,
        min_length=5,
        top_k=30,
        top_p=1,
        num_return_sequences=10,
    )
    decoded = [
        tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True).replace(" ", "")
        for g in molecules
    ]
    # Sampling can repeat candidates. dict.fromkeys drops duplicates but keeps
    # generation order; list(set(...)) reordered them from run to run.
    sf_output = list(dict.fromkeys(decoded))
    sm_output = [sf_decode(s) for s in sf_output]

    input_plogp = plogp(input_sm)
    data = pd.DataFrame({
        "candidates": sf_output,
        "smiles": sm_output,
        "improvement": [plogp(s) - input_plogp for s in sm_output],
        # sim() yields None for unusable SMILES; such rows compare False
        # against 0.4 below and are dropped.
        "sim": [sim(s, input_sm) for s in sm_output],
    })
    results = data[(data["improvement"] > 0) & (data["sim"] > 0.4)]

    opt_output = results["candidates"].tolist()
    opt_output_imp = [str(v) for v in results["improvement"].tolist()]
    opt_output_sim = [str(v) for v in results["sim"].tolist()]
    # Reuse the SMILES decoded above rather than decoding each SELFIES again.
    mols = [Chem.MolFromSmiles(s) for s in results["smiles"].tolist()]
    opt_output_img = Draw.MolsToGridImage(
        mols,
        molsPerRow=4,
        subImgSize=(200, 200),
        legends=['' for _ in mols],
    )
    return opt_input_img, "\n".join(opt_output), "\n".join(opt_output_imp), "\n".join(opt_output_sim), opt_output_img
# Sample SELFIES prompts; each inner list is one example row for the tab's single input Textbox.
_gen_examples = [
    ["[C][=C][C][=C][C][=C][Ring1][=Branch1]"],
    ["[C]"],
]
_opt_examples = [
    ["[C][C][=Branch1][C][=O][N][C][C][O][C][C][O][C][C][O][C][C][Ring1][N]"],
    ["[C][C][S][C][C][S][C][C][C][S][C][C][S][C][Ring1][=C]"],
    ["[N][#C][C][C][C@@H1][C][C][C][C][C][C][C][C][C][C][C][Ring1][N][=O]"],
]

with gr.Blocks() as demo:
    gr.Markdown("# MolGen: Domain-Agnostic Molecular Generation with Self-feedback")
    with gr.Tabs():
        # Tab 1: free generation from a SELFIES prompt, served by gen_process.
        with gr.TabItem("Molecular Generation"):
            with gr.Row():
                with gr.Column():
                    gen_input = gr.Textbox(label="Input", lines=1, placeholder="SELFIES Input")
                    gen_button = gr.Button("Generate")
                with gr.Column():
                    gen_output = gr.Textbox(label="Generation Results", lines=5, placeholder="")
                    gen_output_image = gr.Image(label="Visualization")
            # Order must match gen_process's return tuple; shared by Examples and the button.
            gen_outputs = [gen_output, gen_output_image]
            gr.Examples(
                examples=_gen_examples,
                inputs=[gen_input],
                outputs=gen_outputs,
                fn=gen_process,
                cache_examples=True,
            )
        # Tab 2: similarity-constrained pLogP optimization, served by opt_process.
        with gr.TabItem("Constrained Molecular Property Optimization"):
            with gr.Row():
                with gr.Column():
                    opt_input = gr.Textbox(label="Input", lines=1, placeholder="SELFIES Input")
                    opt_button = gr.Button("Optimize")
                with gr.Column():
                    opt_input_img = gr.Image(label="Input Visualization")
                    opt_output = gr.Textbox(label="Optimization Results", lines=3, placeholder="")
                    opt_output_imp = gr.Textbox(label="Optimization Property Improvements", lines=3, placeholder="")
                    opt_output_sim = gr.Textbox(label="Similarity", lines=3, placeholder="")
                    opt_output_img = gr.Image(label="Output Visualization")
            # Order must match opt_process's return tuple; shared by Examples and the button.
            opt_outputs = [opt_input_img, opt_output, opt_output_imp, opt_output_sim, opt_output_img]
            gr.Examples(
                examples=_opt_examples,
                inputs=[opt_input],
                outputs=opt_outputs,
                fn=opt_process,
                cache_examples=True,
            )
    # Event wiring stays inside the Blocks context.
    gen_button.click(fn=gen_process, inputs=[gen_input], outputs=gen_outputs)
    opt_button.click(fn=opt_process, inputs=[opt_input], outputs=opt_outputs)

demo.launch()
|