gnina-torch / app.py
Rocco Meli
better UI
aed59e3
raw
history blame
7.38 kB
import os
def load_html(html_file: str) -> str:
    """
    Load file from HTML directory.

    The file is looked up inside the ``html`` directory, relative to the
    current working directory, and decoded as UTF-8.

    Parameters
    ----------
    html_file: str
        HTML file name

    Returns
    -------
    str
        HTML file content
    """
    # Explicit encoding: without it open() falls back to the locale's
    # preferred encoding, which makes templates containing non-ASCII
    # characters load differently (or fail) depending on the host.
    with open(os.path.join("html", html_file), "r", encoding="utf-8") as f:
        return f.read()
def load_protein_from_file(protein_file) -> str:
    """
    Load protein from file.

    The file is located through the ``name`` attribute of the given object
    and read in full, in text mode.

    Parameters
    ----------
    protein_file: _TemporaryFileWrapper
        GradIO file object

    Returns
    -------
    str
        Protein PDB file content
    """
    pdb_path = protein_file.name
    with open(pdb_path, "r") as pdb_stream:
        pdb_content = pdb_stream.read()
    return pdb_content
def load_ligand_from_file(ligand_file) -> str:
    """
    Load ligand from file.

    The file is located through the ``name`` attribute of the given object
    and read in full, in text mode.

    Parameters
    ----------
    ligand_file: _TemporaryFileWrapper
        GradIO file object

    Returns
    -------
    str
        Ligand SDF file content
    """
    sdf_path = ligand_file.name
    with open(sdf_path, "r") as sdf_stream:
        sdf_content = sdf_stream.read()
    return sdf_content
def protein_html_from_file(protein_file) -> str:
    """
    Wrap 3Dmol.js code around protein PDB file.

    The PDB content replaces the ``%%%PDB%%%`` placeholder of the
    ``protein.html`` template; the result then replaces the ``%%%HTML%%%``
    placeholder of the ``wrapper.html`` template.

    Parameters
    ----------
    protein_file: _TemporaryFileWrapper
        GradIO file object

    Returns
    -------
    str
        3Dmol.js HTML code for displaying a PDB file
    """
    pdb_content = load_protein_from_file(protein_file)

    # Fill the viewer template first, then embed it in the outer wrapper.
    viewer = load_html("protein.html").replace("%%%PDB%%%", pdb_content)
    return load_html("wrapper.html").replace("%%%HTML%%%", viewer)
def ligand_html_from_file(ligand_file) -> str:
    """
    Wrap 3Dmol.js code around ligand SDF file.

    The SDF content replaces the ``%%%SDF%%%`` placeholder of the
    ``ligand.html`` template; the result then replaces the ``%%%HTML%%%``
    placeholder of the ``wrapper.html`` template.

    Parameters
    ----------
    ligand_file: _TemporaryFileWrapper
        GradIO file object

    Returns
    -------
    str
        3Dmol.js HTML code for displaying a SDF file
    """
    sdf_content = load_ligand_from_file(ligand_file)

    # Fill the viewer template first, then embed it in the outer wrapper.
    viewer = load_html("ligand.html").replace("%%%SDF%%%", sdf_content)
    return load_html("wrapper.html").replace("%%%HTML%%%", viewer)
def protein_ligand_html_from_file(protein_file, ligand_file):
    """
    Wrap HTML template code around protein PDB and ligand SDF files.

    The PDB and SDF contents replace the ``%%%PDB%%%`` and ``%%%SDF%%%``
    placeholders of the ``pl.html`` template (in that order); the result then
    replaces the ``%%%HTML%%%`` placeholder of the ``wrapper.html`` template.

    Parameters
    ----------
    protein_file: _TemporaryFileWrapper
        GradIO file object
    ligand_file: _TemporaryFileWrapper
        GradIO file object

    Returns
    -------
    str
        HTML code with both file contents embedded
    """
    pdb_content = load_protein_from_file(protein_file)
    sdf_content = load_ligand_from_file(ligand_file)

    # Protein placeholder is substituted before the ligand one.
    viewer = (
        load_html("pl.html")
        .replace("%%%PDB%%%", pdb_content)
        .replace("%%%SDF%%%", sdf_content)
    )
    return load_html("wrapper.html").replace("%%%HTML%%%", viewer)
def predict(protein_file, ligand_file, cnn: str = "default"):
    """
    Run gnina-torch on protein-ligand complex.

    Parameters
    ----------
    protein_file: _TemporaryFileWrapper
        GradIO file object
    ligand_file: _TemporaryFileWrapper
        GradIO file object
    cnn: str
        CNN model to use

    Returns
    -------
    pandas.DataFrame
        Single-row table with columns CNNscore, CNNaffinity, and CNNvariance
        (values rounded to 6 decimal places)

    Raises
    ------
    ValueError
        If the protein file or the ligand file is missing
    """
    # Fail early (before the heavy imports) with an explicit message instead
    # of an AttributeError on ``None.name`` further down.
    if protein_file is None or ligand_file is None:
        raise ValueError(
            "Both a protein (PDB) file and a ligand (SDF) file are required."
        )

    import tempfile

    import molgrid
    from gninatorch import gnina, dataloaders
    import torch
    import pandas as pd

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # Second return value is not needed here.
    model, _ = gnina.setup_gnina_model(cnn, 23.5, 0.5)
    model.eval()
    model.to(device)

    example_provider = molgrid.ExampleProvider(
        data_root="",
        balanced=False,
        shuffle=False,
        default_batch_size=1,
        iteration_scheme=molgrid.IterationScheme.SmallEpoch,
    )

    # FIXME: Do this properly... =( [Might require light gnina-torch refactoring]
    # The listing file is unique per call, so concurrent requests cannot
    # overwrite each other's input; it is removed once the batch is loaded.
    with tempfile.NamedTemporaryFile("w", suffix=".in", delete=False) as f:
        f.write(protein_file.name)
        f.write(" ")
        f.write(ligand_file.name)
        listing_path = f.name

    try:
        print("Populating example provider... ", end="")
        example_provider.populate(listing_path)
        print("done")

        grid_maker = molgrid.GridMaker(resolution=0.5, dimension=23.5)

        # TODO: Allow average over different rotations
        loader = dataloaders.GriddedExamplesLoader(
            example_provider=example_provider,
            grid_maker=grid_maker,
            random_translation=0.0,  # No random translations for inference
            random_rotation=False,  # No random rotations for inference
            grids_only=True,
            device=device,
        )

        print("Loading and gridding data... ", end="")
        batch = next(loader)
        print("done")
    finally:
        os.remove(listing_path)

    print("Predicting... ", end="")
    with torch.no_grad():
        log_pose, affinity, affinity_var = model(batch)
    print("done")

    return pd.DataFrame(
        {
            "CNNscore": [torch.exp(log_pose[:, -1]).item()],
            "CNNaffinity": [affinity.item()],
            "CNNvariance": [affinity_var.item()],
        }
    ).round(6)
# Script entry point: build the Gradio UI and start the app.
if __name__ == "__main__":
    import gradio as gr

    demo = gr.Blocks()

    with demo:
        gr.Markdown("# Gnina-Torch")
        gr.Markdown(
            "Score your protein-ligand complex and predict the binding affinity with [Gnina]"
            + "(https://github.com/gnina/gnina)'s scoring function. Powered by [gnina-torch]"
            + "(https://github.com/RMeli/gnina-torch), a PyTorch implementation of Gnina's"
            + " scoring function."
        )

        gr.Markdown("## Protein and Ligand")
        gr.Markdown(
            "Upload your protein and ligand files in PDB and SDF format, respectively. "
            + "Optionally, you can visualise the protein, the ligand, and the "
            + "protein-ligand complex."
        )

        # Upload widgets with individual viewers for protein and ligand.
        with gr.Row():
            with gr.Box():
                pfile = gr.File(file_count="single", label="Protein file (PDB)")
                pbtn = gr.Button("View Protein")
                protein = gr.HTML()
                pbtn.click(fn=protein_html_from_file, inputs=[pfile], outputs=protein)

            with gr.Box():
                lfile = gr.File(file_count="single", label="Ligand file (SDF)")
                lbtn = gr.Button("View Ligand")
                ligand = gr.HTML()
                lbtn.click(fn=ligand_html_from_file, inputs=[lfile], outputs=ligand)

        # Combined protein-ligand viewer.
        with gr.Box():
            with gr.Column():
                plcomplex = gr.HTML()

                # TODO: Automatically display complex when both files are uploaded
                plbtn = gr.Button("View Protein-Ligand Complex")
                plbtn.click(
                    fn=protein_ligand_html_from_file,
                    inputs=[pfile, lfile],
                    outputs=plcomplex,
                )

        gr.Markdown("## Gnina-Torch")
        gr.Markdown(
            "Score your protein-ligand complex with [gnina-torch]"
            + "(https://github.com/RMeli/gnina-torch). You can choose between different "
            + "pre-trained Gnina models. See [McNutt et al. (2021)]"
            + "(https://doi.org/10.1186/s13321-021-00522-2) for more details."
        )

        # Scoring output table next to the model selector.
        with gr.Row():
            df = gr.Dataframe()

            with gr.Column():
                dd = gr.Dropdown(
                    choices=[
                        "default",
                        "redock_default2018_ensemble",
                        "general_default2018_ensemble",
                        "crossdock_default2018_ensemble",
                    ],
                    value="default",
                    label="CNN model",
                )

        with gr.Row():
            btn = gr.Button("Score!")
            btn.click(fn=predict, inputs=[pfile, lfile, dd], outputs=df)

    demo.launch()