File size: 7,691 Bytes
514d010
 
 
 
 
 
 
aec17f8
514d010
 
 
 
 
 
 
a785154
dd64b2b
514d010
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1ca5378
000e1c0
 
227b864
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366aa69
227b864
514d010
 
 
 
227b864
 
e6f430b
03d5987
e6f430b
1685117
227b864
 
1685117
227b864
 
1685117
227b864
514d010
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227b864
 
 
8d62ec9
227b864
8d62ec9
 
514d010
227b864
 
 
 
 
514d010
227b864
 
 
 
 
 
8d62ec9
227b864
 
be6385e
227b864
 
 
 
 
 
 
 
 
 
8d62ec9
227b864
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd8953a
227b864
 
 
8d62ec9
3ef0664
227b864
 
a8ee13b
227b864
 
 
 
 
 
 
 
514d010
227b864
 
514d010
4d0c7ae
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import functools
import os
import sys

import gradio as gr
import pandas as pd
import selfies as sf
from rdkit import Chem
from rdkit import DataStructs
from rdkit.Chem import AllChem
from rdkit.Chem import Descriptors
from rdkit.Chem import Draw
from rdkit.Chem import RDConfig
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# sascorer lives in RDKit's contrib directory, which is not on sys.path by default.
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
import sascorer



def get_largest_ring_size(mol):
    """Return the number of atoms in the largest ring of ``mol``.

    ``mol`` must provide ``GetRingInfo().AtomRings()``, which yields one
    sequence of atom indices per ring. Returns 0 when there are no rings.
    """
    rings = mol.GetRingInfo().AtomRings()
    if not rings:
        return 0
    return max(map(len, rings))

def plogp(smile):
    """Return the penalized logP of a SMILES string.

    The score is ``logP - SA score - cycle penalty``, where the cycle penalty
    is the number of atoms by which the largest ring exceeds six (0 for
    smaller rings and for acyclic molecules).

    Args:
        smile: SMILES string; may be empty or ``None``.

    Returns:
        The penalized logP, or the sentinel ``-100`` when ``smile`` is empty
        or cannot be parsed into a molecule.
    """
    if not smile:
        return -100
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        return -100

    log_p = Descriptors.MolLogP(mol)
    sas_score = sascorer.calculateScore(mol)
    cycle_score = max(get_largest_ring_size(mol) - 6, 0)
    # The terms are intentionally not truth-tested: 0 is a legitimate value
    # for logP and for the ring size (acyclic molecules), and treating it as
    # "missing" would wrongly score such molecules as -100.
    return log_p - sas_score - cycle_score
    
def sf_decode(selfies):
    """Decode a SELFIES string with ``sf.decoder``.

    Returns the decoded string, or ``''`` when the decoder raises
    ``sf.DecoderError``.
    """
    try:
        smiles = sf.decoder(selfies)
    except sf.DecoderError:
        smiles = ''
    return smiles
    
def sim(input_smile, output_smile):
    """Return the Tanimoto similarity between two SMILES strings.

    Both strings are parsed with RDKit and compared through their Morgan
    fingerprints (second argument 2). Returns ``None`` when either string is
    empty/``None`` or either parsed molecule is falsy.
    """
    if not (input_smile and output_smile):
        return None

    input_mol = Chem.MolFromSmiles(input_smile)
    output_mol = Chem.MolFromSmiles(output_smile)
    if not (input_mol and output_mol):
        return None

    fingerprints = [
        AllChem.GetMorganFingerprint(molecule, 2)
        for molecule in (input_mol, output_mol)
    ]
    return DataStructs.TanimotoSimilarity(*fingerprints)


@functools.lru_cache(maxsize=None)
def _load_gen_model():
    """Load the generation tokenizer and model once; later calls reuse them."""
    checkpoint = "zjunlp/MolGen-large"
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )


def gen_process(gen_input):
    """Generate molecules from a SELFIES input using beam search.

    Args:
        gen_input: SELFIES string fed to the model.

    Returns:
        A tuple ``(text, image)``: the generated SELFIES strings joined with
        newlines, and the grid image produced by ``Draw.MolsToGridImage`` for
        the decoded molecules.
    """
    # Loading the checkpoint on every request is slow; reuse the cached pair.
    tokenizer, model = _load_gen_model()

    sf_input = tokenizer(gen_input, return_tensors="pt")

    # beam search
    molecules = model.generate(
        input_ids=sf_input["input_ids"],
        attention_mask=sf_input["attention_mask"],
        max_length=15,
        min_length=5,
        num_return_sequences=4,
        num_beams=5,
    )

    gen_output = [
        tokenizer.decode(
            g, skip_special_tokens=True, clean_up_tokenization_spaces=True
        ).replace(" ", "")
        for g in molecules
    ]

    # sf_decode yields '' instead of raising on an undecodable sequence, so a
    # single malformed generation does not abort the whole request.
    mols = [Chem.MolFromSmiles(sf_decode(selfies)) for selfies in gen_output]

    gen_output_image = Draw.MolsToGridImage(
        mols,
        molsPerRow=4,
        subImgSize=(200, 200),
        legends=[''] * len(mols),
    )

    return "\n".join(gen_output), gen_output_image
 
@functools.lru_cache(maxsize=None)
def _load_opt_model():
    """Load the optimization tokenizer and model once; later calls reuse them."""
    checkpoint = "zjunlp/MolGen-large-opt"
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )


def opt_process(opt_input):
    """Sample candidate molecules and keep those that improve penalized logP.

    Candidates are sampled from the model, de-duplicated, and kept only when
    their penalized logP exceeds the input's and their Tanimoto similarity to
    the input is above 0.4.

    Args:
        opt_input: SELFIES string of the molecule to optimize.

    Returns:
        A 5-tuple ``(input_image, candidates, improvements, similarities,
        output_image)``. The three middle items are newline-joined strings
        with one line per kept candidate. ``output_image`` is ``None`` when no
        candidate is kept.
    """
    # Loading the checkpoint on every request is slow; reuse the cached pair.
    tokenizer, model = _load_opt_model()

    # Decode the input once; sf_decode returns '' instead of raising.
    input_sm = sf_decode(opt_input)
    opt_input_img = Draw.MolsToGridImage(
        [Chem.MolFromSmiles(input_sm)],
        molsPerRow=4,
        subImgSize=(200, 200),
        legends=[''],
    )

    sf_input = tokenizer(opt_input, return_tensors="pt")
    molecules = model.generate(
        input_ids=sf_input["input_ids"],
        attention_mask=sf_input["attention_mask"],
        do_sample=True,
        max_length=100,
        min_length=5,
        top_k=30,
        top_p=1,
        num_return_sequences=10,
    )
    decoded = [
        tokenizer.decode(
            g, skip_special_tokens=True, clean_up_tokenization_spaces=True
        ).replace(" ", "")
        for g in molecules
    ]
    # dict.fromkeys removes duplicates while keeping generation order; set
    # iteration order for strings is not stable across interpreter runs.
    sf_output = list(dict.fromkeys(decoded))
    sm_output = [sf_decode(candidate) for candidate in sf_output]

    input_plogp = plogp(input_sm)

    kept = []
    for candidate, smiles in zip(sf_output, sm_output):
        improvement = plogp(smiles) - input_plogp
        similarity = sim(smiles, input_sm)
        # sim() returns None when a molecule cannot be parsed; such candidates
        # never qualify, and None must not be compared against a float.
        if improvement > 0 and similarity is not None and similarity > 0.4:
            kept.append((candidate, smiles, improvement, similarity))

    opt_output = [row[0] for row in kept]
    opt_output_imp = [str(row[2]) for row in kept]
    opt_output_sim = [str(row[3]) for row in kept]

    if kept:
        mols = [Chem.MolFromSmiles(row[1]) for row in kept]
        opt_output_img = Draw.MolsToGridImage(
            mols,
            molsPerRow=4,
            subImgSize=(200, 200),
            legends=[''] * len(mols),
        )
    else:
        # Nothing to draw; a None value leaves the output image empty.
        opt_output_img = None

    return (
        opt_input_img,
        "\n".join(opt_output),
        "\n".join(opt_output_imp),
        "\n".join(opt_output_sim),
        opt_output_img,
    )

# Example SELFIES inputs shown under each tab (one single-element row per example).
gen_examples = [
    ["[C][=C][C][=C][C][=C][Ring1][=Branch1]"],
    ["[C]"],
]
opt_examples = [
    ["[C][C][=Branch1][C][=O][N][C][C][O][C][C][O][C][C][O][C][C][Ring1][N]"],
    ["[C][C][S][C][C][S][C][C][C][S][C][C][S][C][Ring1][=C]"],
    ["[N][#C][C][C][C@@H1][C][C][C][C][C][C][C][C][C][C][C][Ring1][N][=O]"],
]

# Two-tab UI: free generation and constrained property optimization.
with gr.Blocks() as demo:
    gr.Markdown("# MolGen: Domain-Agnostic Molecular Generation with Self-feedback")

    with gr.Tabs():
        with gr.TabItem("Molecular Generation"):
            with gr.Row():
                with gr.Column():
                    gen_input = gr.Textbox(label="Input", lines=1, placeholder="SELFIES Input")
                    gen_button = gr.Button("Generate")

                with gr.Column():
                    gen_output = gr.Textbox(label="Generation Results", lines=5, placeholder="")
                    gen_output_image = gr.Image(label="Visualization")

            # Output components of gen_process, in the order of its return tuple.
            gen_outputs = [gen_output, gen_output_image]

            gr.Examples(
                examples=gen_examples,
                inputs=[gen_input],
                outputs=gen_outputs,
                fn=gen_process,
                cache_examples=True,
            )

        with gr.TabItem("Constrained Molecular Property Optimization"):
            with gr.Row():
                with gr.Column():
                    opt_input = gr.Textbox(label="Input", lines=1, placeholder="SELFIES Input")
                    opt_button = gr.Button("Optimize")

                with gr.Column():
                    opt_input_img = gr.Image(label="Input Visualization")
                    opt_output = gr.Textbox(label="Optimization Results", lines=3, placeholder="")
                    opt_output_imp = gr.Textbox(label="Optimization Property Improvements", lines=3, placeholder="")
                    opt_output_sim = gr.Textbox(label="Similarity", lines=3, placeholder="")
                    opt_output_img = gr.Image(label="Output Visualization")

            # Output components of opt_process, in the order of its return tuple.
            opt_outputs = [
                opt_input_img,
                opt_output,
                opt_output_imp,
                opt_output_sim,
                opt_output_img,
            ]

            gr.Examples(
                examples=opt_examples,
                inputs=[opt_input],
                outputs=opt_outputs,
                fn=opt_process,
                cache_examples=True,
            )

    # Wire each button to its handler with the same inputs/outputs as the examples.
    gen_button.click(fn=gen_process, inputs=[gen_input], outputs=gen_outputs)
    opt_button.click(fn=opt_process, inputs=[opt_input], outputs=opt_outputs)

demo.launch()