Spaces:
Sleeping
Sleeping
import streamlit as st | |
from streamlit_ketcher import st_ketcher | |
from SynTool.mcts.tree import Tree, TreeConfig | |
from SynTool.mcts.expansion import PolicyFunction | |
from SynTool.mcts.search import extract_tree_stats | |
from SynTool.utils.config import PolicyNetworkConfig | |
from SynTool.interfaces.visualisation import to_table, extract_routes | |
import pickle | |
import uuid | |
import base64 | |
import pandas as pd | |
import json | |
import re | |
def download_button(object_to_download, download_filename, button_text, pickle_it=False):
    """
    Generate an HTML download link, styled as a button, for ``object_to_download``.

    Adapted from external code; the original docstring began with "Issued from"
    but the source reference was missing.
    NOTE(review): restore the attribution link.

    Params:
    ------
    object_to_download: The object to be downloaded. It is converted to bytes as follows:
        - if ``pickle_it`` is True, the pickled object is offered;
        - bytes-like objects (bytes, bytearray, memoryview) are offered as-is;
        - pandas DataFrames are written as CSV (without the index), UTF-8 encoded;
        - strings are UTF-8 encoded;
        - anything else is serialised with ``json.dumps`` and UTF-8 encoded.
    download_filename (str): Filename and extension of the downloaded file,
        e.g. mydata.csv or some_txt_output.txt.
    button_text (str): Text to display on the download button
        (e.g. 'click here to download file').
    pickle_it (bool): If True, pickle the object before offering it.

    Returns:
    -------
    (str | None): A ``<style>`` block followed by an ``<a download=...>`` anchor
    whose href is a base64 data URI of the payload. Returns None if pickling
    failed; the error is then written to the Streamlit page.

    Raises:
    ------
    Errors from ``json.dumps`` (e.g. TypeError) when the object is not
    bytes-like, a DataFrame or a string and cannot be JSON-serialised.

    Examples:
    --------
    download_button(your_df, 'YOUR_DF.csv', 'Click to download data!')
    download_button(your_str, 'YOUR_STRING.txt', 'Click to download text!')
    """
    if pickle_it:
        try:
            payload = pickle.dumps(object_to_download)
        except (pickle.PicklingError, TypeError, AttributeError) as e:
            # pickle reports unpicklable objects with any of these exception
            # types depending on the object; show the error instead of crashing.
            st.write(e)
            return None
    elif isinstance(object_to_download, (bytes, bytearray, memoryview)):
        payload = bytes(object_to_download)
    elif isinstance(object_to_download, pd.DataFrame):
        payload = object_to_download.to_csv(index=False).encode('utf-8')
    elif isinstance(object_to_download, str):
        payload = object_to_download.encode('utf-8')
    else:
        # Fallback for other Python objects (dicts, lists, numbers, ...).
        payload = json.dumps(object_to_download).encode('utf-8')

    b64 = base64.b64encode(payload).decode()

    # Random element id with every digit stripped, presumably because a CSS id
    # selector cannot start with a digit. A raw string keeps '\d' a regex
    # escape rather than an invalid Python string escape.
    button_uuid = str(uuid.uuid4()).replace('-', '')
    button_id = re.sub(r'\d+', '', button_uuid)

    custom_css = f"""
        <style>
            #{button_id} {{
                background-color: rgb(255, 255, 255);
                color: rgb(38, 39, 48);
                text-decoration: none;
                border-radius: 4px;
                border-width: 1px;
                border-style: solid;
                border-color: rgb(230, 234, 241);
                border-image: initial;
            }}
            #{button_id}:hover {{
                border-color: rgb(246, 51, 102);
                color: rgb(246, 51, 102);
            }}
            #{button_id}:active {{
                box-shadow: none;
                background-color: rgb(246, 51, 102);
                color: white;
            }}
        </style> """

    # NOTE(review): download_filename and button_text are interpolated into the
    # HTML without escaping; callers must pass trusted text.
    dl_link = custom_css + f'<a download="{download_filename}" id="{button_id}" href="data:file/txt;base64,{b64}">{button_text}</a><br></br>'
    return dl_link
# --- Page setup (set_page_config must be the first Streamlit call) ---------
st.set_page_config(  # layout="wide",
    page_title="SynTool GUI",
    page_icon="🧪",
)
st.title("`SynTool GUI`")
st.write("*{Introduction text to be inserted here}*")

# --- Molecule input --------------------------------------------------------
st.header('Molecule input')
st.write("You can provide a molecular structure either by providing its SMILES string + Enter, or by drawing it + Apply.")

DEFAULT_MOL = 'NC(CCCCB(O)O)(CCN1CCC(CO)C1)C(=O)O'
molecule = st.text_input("Molecule", DEFAULT_MOL)
# The Ketcher editor is seeded with the typed SMILES; its returned SMILES
# (smile_code) is the target used for the search.
smile_code = st_ketcher(molecule)

# --- Search controls -------------------------------------------------------
st.header('Launch calculation')
st.write("If you modified the structure, please ensure you clicked on 'Apply' (bottom right of the molecular editor).")
st.markdown(f"The molecule SMILES is currently: ``{smile_code}``")
max_depth = st.slider('Maximal number of reaction steps', min_value=2, max_value=9, value=9)
run_default = st.button('Launch and search a reaction path')
# Pre-trained artefacts used by the search.
ranking_policy_weights_path = 'data/policy_network.ckpt'
reaction_rules_path = 'data/reaction_rules.pickle'
building_blocks_path = 'data/building_blocks.smi'


@st.cache_resource
def _load_policy(weights_path):
    """Build the policy-network config and policy function for ``weights_path``.

    Streamlit re-executes this script on every widget interaction. Caching the
    result with ``st.cache_resource`` builds the policy function (which
    presumably loads the network weights) once per weights path instead of on
    every rerun.
    NOTE(review): cached resources are shared across user sessions; confirm
    that PolicyFunction is safe to use from concurrent sessions.

    Returns:
        tuple: (PolicyNetworkConfig, PolicyFunction)
    """
    config = PolicyNetworkConfig(weights_path=weights_path)
    return config, PolicyFunction(policy_config=config)


policy_config, policy_function = _load_policy(ranking_policy_weights_path)
if run_default and not smile_code:
    # Nothing to search for (e.g. the SMILES field was cleared): ask the user
    # for a structure instead of starting a search on an empty target.
    st.error("Please provide a molecule (SMILES string or drawing) before launching the search.")
elif run_default:
    st.toast('Optimisation has started. Results will appear below when the search finishes.')

    # Fixed search settings; only the maximal depth comes from the slider.
    tree_config = TreeConfig(
        search_strategy="expansion_first",
        evaluation_type="rollout",
        max_iterations=100,
        max_depth=max_depth,
        min_mol_size=0,
        init_node_value=0.5,
        ucb_type="uct",
        c_ucb=0.1,
        silent=True,
    )

    with st.spinner(text="Running with default parameters..."):
        tree = Tree(
            target=smile_code,
            tree_config=tree_config,
            reaction_rules_path=reaction_rules_path,
            building_blocks_path=building_blocks_path,
            policy_function=policy_function,
            value_function=None,
        )
        # Exhaust the tree iterator to run the search; the yielded values are
        # not used, so they are not collected into a list.
        for _ in tree:
            pass
        res = extract_tree_stats(tree, smile_code)  # extract_routes(tree)

    st.header('Results')
    if res['found_paths']:
        st.write("Success!")
        st.subheader("Retrosynthetic Routes Report")
        st.markdown(to_table(tree, None, extended=True, integration=True), unsafe_allow_html=True)

        # Build the one-row statistics table once; it is both displayed and downloaded.
        stats_df = pd.DataFrame(res, index=[0])
        st.subheader("Statistics")
        st.write(stats_df)

        st.subheader("Downloads")
        dl_html = download_button(to_table(tree, None, extended=True, integration=False),
                                  'results_syntool.html',
                                  'Download results as a HTML file')
        dl_csv = download_button(stats_df,
                                 'results_syntool.csv',
                                 'Download statistics as an Excel csv file')
        st.markdown(dl_html + dl_csv, unsafe_allow_html=True)
    else:
        st.write("Found no reaction path.")
# --- Restart ---------------------------------------------------------------
st.divider()
st.header('Restart from the beginning?')
restart_requested = st.button("Restart")
if restart_requested:
    # Ask Streamlit to execute the script again from the top.
    st.rerun()