Spaces:
Sleeping
Sleeping
import streamlit as st | |
import streamlit.components.v1 as components | |
import pandas as pd | |
import mols2grid | |
from ipywidgets import interact, widgets | |
import textwrap | |
import moses | |
from transformers import EncoderDecoderModel, RobertaTokenizer | |
from moses.metrics.utils import QED, SA, logP, NP, weight, get_n_rings | |
from moses.utils import mapper, get_mol | |
# @st.cache(allow_output_mutation=False, hash_funcs={Tokenizer: str}) | |
from typing import List | |
from util import filter_dataframe | |
def load_models():
    """Load the two pretrained WarmMolGen encoder-decoder checkpoints.

    Only the models are returned; no tokenizer is loaded here.

    Returns:
        A ``(model1, model2)`` tuple holding the ``WarmMolGenOne`` and
        ``WarmMolGenTwo`` checkpoints, in that order.
    """
    checkpoint_ids = (
        "gokceuludogan/WarmMolGenOne",
        "gokceuludogan/WarmMolGenTwo",
    )
    # Unpacking the generator loads the checkpoints one after the other,
    # in the order listed above.
    first_model, second_model = (
        EncoderDecoderModel.from_pretrained(checkpoint_id)
        for checkpoint_id in checkpoint_ids
    )
    return first_model, second_model
def count(smiles_list: List[str]) -> List[int]:
    """Return the character length of each SMILES string.

    Args:
        smiles_list: SMILES strings to measure.

    Returns:
        A new list with ``len(smiles)`` for every entry, in input order.
        An empty input yields an empty list.
    """
    return [len(smiles) for smiles in smiles_list]
def remove_none_elements(mol_list, smiles_list):
    """Drop ``None`` molecules together with the SMILES at the same positions.

    The two lists are treated as index-aligned: every position where
    ``mol_list`` holds ``None`` is removed from both outputs. Neither input
    list is modified.

    Args:
        mol_list: Molecule objects, with ``None`` marking an invalid entry.
        smiles_list: SMILES strings indexed like ``mol_list``.

    Returns:
        A tuple ``(filtered_mol_list, filtered_smiles_list, removed_len)``
        where ``removed_len`` is the number of ``None`` entries found in
        ``mol_list``.
    """
    # A set keeps the per-index membership test O(1) instead of scanning a
    # list for every SMILES entry.
    removed_indices = {i for i, mol in enumerate(mol_list) if mol is None}
    filtered_mol_list = [mol for mol in mol_list if mol is not None]
    filtered_smiles_list = [
        smiles for i, smiles in enumerate(smiles_list) if i not in removed_indices
    ]
    return filtered_mol_list, filtered_smiles_list, len(removed_indices)
def format_list_numbers(lst):
    """Round every number in ``lst`` to three decimal places, in place.

    Each element is rendered with the ``.3f`` format and parsed back to a
    ``float``, so integer entries become floats as well.

    Args:
        lst: Mutable sequence of numbers; it is modified directly.

    Returns:
        None. The result is visible through ``lst`` itself.
    """
    for position, number in enumerate(lst):
        three_decimals = format(number, ".3f")
        lst[position] = float(three_decimals)
def generate_molecules(model_name, num_mols, max_new_tokens, do_sample, num_beams, target, pool):
    """Generate molecules for a protein target and score the valid ones.

    The target text is tokenized and passed to the selected model; the
    decoded SMILES strings are converted with ``get_mol``. Entries for which
    that conversion yields ``None`` are dropped (a note is written to the
    page when this happens) and the remaining molecules are scored.

    Args:
        model_name: ``'WarmMolGenOne'`` selects the first loaded model; any
            other value selects the second one.
        num_mols: Number of sequences to return (cast to ``int``).
        max_new_tokens: Passed to ``generate`` as ``max_length`` (cast to
            ``int``).
        do_sample: Passed to ``generate`` unchanged.
        num_beams: Passed to ``generate`` unchanged.
        target: Protein sequence text given to the protein tokenizer.
        pool: Argument handed to ``mapper`` for every mapping step.

    Returns:
        A ``pandas.DataFrame`` with the columns ``SMILES``, ``QED``, ``SA``,
        ``logP``, ``NP`` and ``Weight``; the scores are rounded to three
        decimals by ``format_list_numbers``.
    """
    protein_tokenizer = RobertaTokenizer.from_pretrained("gokceuludogan/WarmMolGenTwo")
    mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/PubChem10M_SMILES_BPE_450k")
    model1, model2 = load_models()
    inputs = protein_tokenizer(target, return_tensors="pt")

    if model_name == 'WarmMolGenOne':
        model = model1
    else:
        model = model2

    # The molecule tokenizer supplies the decoder's special token ids; its
    # EOS id is used for padding as well.
    generation_options = {
        "decoder_start_token_id": mol_tokenizer.bos_token_id,
        "eos_token_id": mol_tokenizer.eos_token_id,
        "pad_token_id": mol_tokenizer.eos_token_id,
        "max_length": int(max_new_tokens),
        "num_return_sequences": int(num_mols),
        "do_sample": do_sample,
        "num_beams": num_beams,
    }
    outputs = model.generate(**inputs, **generation_options)
    output_smiles = mol_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    st.write("### Generated Molecules")

    mol_list = mapper(pool)(get_mol, output_smiles)
    mol_list, output_smiles, removed_len = remove_none_elements(mol_list, output_smiles)
    if removed_len != 0:
        st.write(f"#### Note that: {removed_len} numbers of generated invalid molecules are discarded.")

    # Column name -> scoring function, in the order the columns appear.
    metrics = (
        ("QED", QED),
        ("SA", SA),
        ("logP", logP),
        ("NP", NP),
        ("Weight", weight),
    )
    columns = {'SMILES': output_smiles}
    # A fresh mapper(pool) is requested for every metric, one call per
    # mapping step.
    for column_name, metric in metrics:
        columns[column_name] = mapper(pool)(metric, mol_list)
    # Rounding happens only after every metric has been computed.
    for column_name, _ in metrics:
        format_list_numbers(columns[column_name])

    return pd.DataFrame(columns)
def warm_molgen_demo():
    """Render the demo: inputs, generated molecules and a code snippet.

    Generation parameters are read from sidebar widgets and the target
    sequence from a text area. Molecules are generated when the submit
    button is pressed or when ``st.session_state`` holds no ``df`` yet; the
    result is kept in ``st.session_state.df``. The stored frame is passed
    through ``filter_dataframe`` and shown as a molecule grid, or replaced by
    a notice when the filtered frame is empty. Finally a code snippet filled
    in with the current parameter values is displayed.
    """
    # NOTE(review): the nesting of the `with` blocks and of the code after
    # the submit button was reconstructed from flattened source - confirm
    # against the deployed app.
    with st.form("my_form"):
        with st.sidebar:
            st.sidebar.subheader("Configurable parameters")
            model_name = st.sidebar.selectbox(
                "Model Selector",
                options=["WarmMolGenOne", "WarmMolGenTwo"],
                index=0,
            )
            num_mols = st.sidebar.number_input(
                "Number of generated molecules",
                min_value=0,
                max_value=20,
                value=10,
                help="The number of molecules to be generated.",
            )
            max_new_tokens = st.sidebar.number_input(
                "Maximum length",
                min_value=0,
                max_value=1024,
                value=128,
                help="The maximum length of the sequence to be generated.",
            )
            do_sample = st.sidebar.selectbox(
                "Sampling?",
                (True, False),
                help="Whether or not to use sampling; use beam decoding otherwise.",
            )
        target = st.text_area(
            "Target Sequence",
            "MENTENSVDSKSIKNLEPKIIHGSESMDSGISLDNSYKMDYPEMGLCIIINNKNFHKSTG",
        )
        generate_new_molecules = st.form_submit_button("Generate Molecules")

    # No beam count is passed when sampling; otherwise one beam per
    # requested molecule.
    if do_sample is True:
        num_beams = None
    else:
        num_beams = int(num_mols)
    pool = 1

    # Generate on an explicit request, or when no earlier result is stored.
    if generate_new_molecules or 'df' not in st.session_state:
        st.session_state.df = generate_molecules(
            model_name, num_mols, max_new_tokens, do_sample, num_beams, target, pool
        )

    filtered_df = filter_dataframe(st.session_state.df)
    if filtered_df.empty:
        # Literal text is kept flush-left so its content stays unchanged.
        st.markdown(
            """
<span style='color: blue; font-size: 30px;'>No molecules were found with specified properties.</span>
""",
            unsafe_allow_html=True
        )
    else:
        grid = mols2grid.display(filtered_df, height=1000)
        raw_html = grid._repr_html_()
        components.html(raw_html, width=900, height=450, scrolling=True)

    st.markdown("## How to Generate")
    # Literal text is kept flush-left so its content stays unchanged.
    generation_code = f"""
from transformers import EncoderDecoderModel, RobertaTokenizer
protein_tokenizer = RobertaTokenizer.from_pretrained("gokceuludogan/{model_name}")
mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/PubChem10M_SMILES_BPE_450k")
model = EncoderDecoderModel.from_pretrained("gokceuludogan/{model_name}")
inputs = protein_tokenizer("{target}", return_tensors="pt")
outputs = model.generate(**inputs, decoder_start_token_id=mol_tokenizer.bos_token_id,
eos_token_id=mol_tokenizer.eos_token_id, pad_token_id=mol_tokenizer.eos_token_id,
max_length={max_new_tokens}, num_return_sequences={num_mols}, do_sample={do_sample}, num_beams={num_beams})
mol_tokenizer.batch_decode(outputs, skip_special_tokens=True)
"""
    snippet = textwrap.dedent(generation_code)
    st.code(snippet)
# Page configuration, headings and introductory text, followed by the demo.
st.set_page_config(page_title="WarmMolGen Demo", page_icon="🔥", layout='wide')
st.markdown("# WarmMolGen Demo")
st.sidebar.header("WarmMolGen Demo")

intro_text = """
This demo illustrates WarmMolGen models' generation capabilities.
Given a target sequence and a set of parameters, the models generate molecules targeting the given protein sequence.
Please enter an input sequence below 👇 and configure parameters from the sidebar 👈 to generate molecules!
See below for saving the output molecules and the code snippet generating them!
"""
st.markdown(intro_text)

warm_molgen_demo()