esm3 / app.py
colbyford's picture
Reduce need for scaffold functions
620584c
raw
history blame
10.6 kB
import gradio as gr
import numpy as np
import torch
import py3Dmol
from huggingface_hub import login
from esm.utils.structure.protein_chain import ProteinChain
from esm.models.esm3 import ESM3
from esm.sdk.api import (
ESMProtein,
GenerationConfig,
)
# Gray monochrome theme shared by the whole app (passed to gr.Blocks below).
theme = gr.themes.Monochrome(primary_hue="gray")
## Function to get model from Hugging Face using token
def get_model(model_name, token):
    """Log in to the Hugging Face Hub (when a token is given) and load an ESM3 model.

    Args:
        model_name: Model identifier passed to ``ESM3.from_pretrained``
            (the UI offers ``"esm3_sm_open_v1"``).
        token: Hugging Face access token passed to ``huggingface_hub.login``.
            When empty or None the login call is skipped, so
            ``ESM3.from_pretrained`` relies on whatever Hugging Face
            credentials are already configured in the environment (if any).

    Returns:
        The loaded ESM3 model, placed on CUDA when
        ``torch.cuda.is_available()`` is true, otherwise on CPU.
    """
    if token:
        login(token=token)
    # Choose the device once instead of duplicating the load call per branch.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return ESM3.from_pretrained(model_name, device=device)
## Function to render 3D structure using py3Dmol
def render_pdb(pdb_string, motif_start=None, motif_end=None):
    """Build an 800x800 py3Dmol cartoon view of a PDB-format string.

    Without a motif the cartoon is coloured with the "spectrum" scheme. When
    both ``motif_start`` and ``motif_end`` are given, the whole model is
    restyled light grey and residue numbers ``motif_start + 1`` through
    ``motif_end`` (the 0-based, end-exclusive range shifted by one) are
    drawn in cyan.

    NOTE(review): the +1 shift assumes residue numbering in ``pdb_string``
    starts at 1 and is contiguous — confirm for the structures passed in.

    Returns:
        The configured ``py3Dmol.view`` object, zoomed to the model.
    """
    viewer = py3Dmol.view(width=800, height=800)
    viewer.addModel(pdb_string, "pdb")
    viewer.setStyle({"cartoon": {"color": "spectrum"}})

    highlight_motif = motif_start is not None and motif_end is not None
    if highlight_motif:
        # 0-based, end-exclusive indices -> 1-based residue numbers.
        residue_numbers = (np.arange(motif_start, motif_end) + 1).tolist()
        viewer.setStyle({"cartoon": {"color": "lightgrey"}})
        viewer.addStyle({"resi": residue_numbers}, {"cartoon": {"color": "cyan"}})

    viewer.zoomTo()
    return viewer
## Function to get PDB data
def get_pdb(pdb_id, chain_id):
    """Load one chain of an RCSB PDB entry as an ESM ``ProteinChain``.

    Args:
        pdb_id: RCSB PDB entry ID (e.g. ``"1ITU"``).
        chain_id: Chain identifier within the entry (e.g. ``"A"``).

    Returns:
        The ``ProteinChain`` produced by ``ProteinChain.from_rcsb``.
    """
    chain = ProteinChain.from_rcsb(pdb_id, chain_id)
    return chain
def scaffold(model_name, token, pdb_id, chain_id, motif_start, motif_end, prompt_length, insert_size):
    """Generate a sequence that scaffolds a structural motif taken from a PDB chain.

    Residues ``motif_start`` (inclusive) to ``motif_end`` (exclusive), 0-based,
    are taken from chain ``chain_id`` of ``pdb_id``. Their sequence and atom37
    coordinates are placed at offset ``insert_size`` in a prompt of
    ``prompt_length`` positions. Every other position is masked: ``"_"`` in the
    sequence and NaN coordinates. ESM3 then fills in the masked sequence.

    Args:
        model_name: ESM3 model identifier (forwarded to ``get_model``).
        token: Hugging Face access token (forwarded to ``get_model``).
        pdb_id: RCSB PDB entry ID, e.g. ``"1ITU"``.
        chain_id: Chain identifier within the entry, e.g. ``"A"``.
        motif_start: First motif residue index (0-based, inclusive).
        motif_end: End of the motif (0-based, exclusive).
        prompt_length: Total number of positions in the designed protein.
        insert_size: Prompt position at which the motif is placed.

        Numeric arguments may arrive as floats from ``gr.Number``; they are
        converted with ``int()``.

    Returns:
        ``[chain sequence, motif sequence, sequence prompt, generated sequence]``.

    Raises:
        gr.Error: If the motif range is empty or negative, or if the motif does
            not fit inside the prompt at ``insert_size``.
    """
    # gr.Number can deliver floats; list repetition and slicing need ints.
    motif_start, motif_end = int(motif_start), int(motif_end)
    prompt_length, insert_size = int(prompt_length), int(insert_size)
    motif_length = motif_end - motif_start
    if motif_start < 0 or motif_length <= 0:
        raise gr.Error("Motif End must be greater than Motif Start, and Motif Start must be non-negative.")
    # Without this check the list slice assignment below would silently grow the
    # sequence prompt, while the tensor assignment would fail on a shape mismatch.
    if insert_size < 0 or insert_size + motif_length > prompt_length:
        raise gr.Error("The motif does not fit in the prompt: need 0 <= Insert Size and Insert Size + motif length <= Prompt Length.")

    pdb = get_pdb(pdb_id, chain_id)

    ## Get motif sequence and atom37 positions (index the chain once)
    motif = pdb[np.arange(motif_start, motif_end)]
    motif_sequence = motif.sequence
    motif_atom37_positions = motif.atom37_positions

    ## Create sequence prompt: masked everywhere except the motif
    sequence_prompt = ["_"] * prompt_length
    sequence_prompt[insert_size:insert_size + len(motif_sequence)] = list(motif_sequence)
    sequence_prompt = "".join(sequence_prompt)

    ## Create structure prompt: NaN (unknown) coordinates everywhere except the motif
    structure_prompt = torch.full((prompt_length, 37, 3), torch.nan)
    structure_prompt[insert_size:insert_size + len(motif_atom37_positions)] = torch.tensor(
        motif_atom37_positions, dtype=torch.float32
    )

    ## Create protein prompt and sequence generation config
    protein_prompt = ESMProtein(sequence=sequence_prompt, coordinates=structure_prompt)
    sequence_generation_config = GenerationConfig(
        track="sequence",
        num_steps=sequence_prompt.count("_") // 2,
        temperature=0.5,
    )

    ## Generate sequence
    model = get_model(model_name, token)
    generated_sequence = model.generate(protein_prompt, sequence_generation_config).sequence

    return [
        pdb.sequence,
        motif_sequence,
        sequence_prompt,
        generated_sequence,
    ]
def ss_edit(model_name, token, pdb_id, chain_id, region_start, region_end, shortened_region_length, shortening_ss8):
    """Shorten a helix-coil-helix region by prompting ESM3 with an edited SS8 track.

    Residues ``[region_start, region_end)`` (0-based) of the chain are replaced
    by ``shortened_region_length`` masked positions in the sequence prompt.

    The secondary-structure prompt keeps ``shortening_ss8`` on both flanks. In
    the edited region it requests helix, then 3 coil residues, then helix. When
    the helix budget (``shortened_region_length - 3``) is odd, the second helix
    gets the extra residue, so the edited SS8 region has exactly
    ``shortened_region_length`` characters, like the masked sequence region.

    Args:
        model_name: ESM3 model identifier (forwarded to ``get_model``).
        token: Hugging Face access token (forwarded to ``get_model``).
        pdb_id: RCSB PDB entry ID, e.g. ``"7XBQ"``.
        chain_id: Chain identifier within the entry.
        region_start: First residue of the region to edit (0-based, inclusive).
        region_end: End of the region to edit (0-based, exclusive).
        shortened_region_length: Length of the replacement region (>= 3).
        shortening_ss8: SS8 string for the original chain. Presumably it has one
            character per residue of the chain sequence; this is not checked here.

        Numeric arguments may arrive as floats from ``gr.Number``; they are
        converted with ``int()``.

    Returns:
        ``[original sequence, original SS8, original SS8 of the edit region,
        sequence prompt, SS8 prompt, proposed SS8 of the edit region,
        generated sequence]``. The two edit-region strings are left-padded with
        spaces so they line up with the full-length strings.

    Raises:
        gr.Error: If the region is empty, negative or past the end of the chain,
            or if ``shortened_region_length`` < 3.
    """
    # gr.Number can deliver floats; string slicing and repetition need ints.
    region_start, region_end = int(region_start), int(region_end)
    shortened_region_length = int(shortened_region_length)
    if region_start < 0 or region_end <= region_start:
        raise gr.Error("Edit Region End must be greater than Edit Region Start, and Edit Region Start must be non-negative.")
    if shortened_region_length < 3:
        raise gr.Error("Shortened Region Length must be at least 3 (room for the 3-residue coil).")

    pdb = get_pdb(pdb_id, chain_id)

    ## Save original sequence and secondary structure
    original_sequence = pdb.sequence
    original_ss8 = shortening_ss8
    if region_end > len(original_sequence):
        raise gr.Error(f"Edit Region End ({region_end}) exceeds the chain length ({len(original_sequence)}).")

    ## Sequence prompt: mask the (shortened) helix-coil-helix region, keep the flanking regions
    sequence_prompt = original_sequence[:region_start] + "_" * shortened_region_length + original_sequence[region_end:]

    ## SS8 prompt: keep the flanks' secondary structure, request helix-coil-helix in the region
    helix_budget = shortened_region_length - 3
    first_helix = helix_budget // 2
    second_helix = helix_budget - first_helix  # takes the extra residue when the budget is odd
    edited_region_ss8 = "H" * first_helix + "C" * 3 + "H" * second_helix
    ss8_prompt = shortening_ss8[:region_start] + edited_region_ss8 + shortening_ss8[region_end:]

    ## Space-padded views of the edit region, aligned with the full-length strings
    original_ss8_region = " " * region_start + shortening_ss8[region_start:region_end]
    proposed_ss8_region = " " * region_start + ss8_prompt[region_start:region_start + shortened_region_length]

    ## Create protein prompt and generate sequence
    protein_prompt = ESMProtein(sequence=sequence_prompt, secondary_structure=ss8_prompt)
    model = get_model(model_name, token)
    sequence_generation = model.generate(
        protein_prompt,
        GenerationConfig(track="sequence", num_steps=protein_prompt.sequence.count("_") // 2, temperature=0.5),
    )

    return [
        original_sequence,
        original_ss8,
        original_ss8_region,
        sequence_prompt,
        ss8_prompt,
        proposed_ss8_region,
        # The "Generated Sequence" textbox should show the sequence string, not
        # the whole ESMProtein object (same as scaffold()).
        sequence_generation.sequence,
    ]
def sasa_edit(model_name, token, pdb_id, chain_id, span_start, span_end, n_samples):
    """Design sequences that keep a span's structure while prompting it to be solvent exposed.

    The span ``[span_start, span_end)`` (0-based) of the chain keeps its atom37
    coordinates and gets a SASA prompt value of 40.0 per residue. All other
    positions are masked: ``"_"`` in the sequence, NaN coordinates and ``None``
    SASA.

    For each of ``n_samples`` draws, ESM3 generates a sequence and then
    predicts its structure. The predicted proteins are sorted by pTM, best
    first.

    Args:
        model_name: ESM3 model identifier (forwarded to ``get_model``).
        token: Hugging Face access token (forwarded to ``get_model``).
        pdb_id: RCSB PDB entry ID, e.g. ``"1LBS"``.
        chain_id: Chain identifier within the entry.
        span_start: First residue of the span (0-based, inclusive).
        span_end: End of the span (0-based, exclusive).
        n_samples: Number of sequences to generate (>= 1).

        Numeric arguments may arrive as floats from ``gr.Number``; they are
        converted with ``int()``.

    Returns:
        ``[protein prompt, all generated sequences (one per line, in generation
        order), predicted proteins sorted by descending pTM]``.

    Raises:
        gr.Error: If the span is empty, negative or past the end of the chain,
            or if ``n_samples`` < 1.
    """
    # gr.Number can deliver floats; range() and slicing need ints.
    span_start, span_end, n_samples = int(span_start), int(span_end), int(n_samples)
    if span_start < 0 or span_end <= span_start:
        raise gr.Error("Span End must be greater than Span Start, and Span Start must be non-negative.")
    if n_samples < 1:
        raise gr.Error("Number of Samples must be at least 1.")

    pdb = get_pdb(pdb_id, chain_id)
    chain_length = len(pdb)
    # Otherwise the SASA list below would grow past the sequence length.
    if span_end > chain_length:
        raise gr.Error(f"Span End ({span_end}) exceeds the chain length ({chain_length}).")

    ## Structure prompt: coordinates only for the span, NaN elsewhere
    structure_prompt = torch.full((chain_length, 37, 3), torch.nan)
    structure_prompt[span_start:span_end] = torch.tensor(pdb[span_start:span_end].atom37_positions, dtype=torch.float32)

    ## SASA prompt: 40.0 on the span, unconstrained (None) elsewhere
    sasa_prompt = [None] * chain_length
    sasa_prompt[span_start:span_end] = [40.0] * (span_end - span_start)

    protein_prompt = ESMProtein(sequence="_" * chain_length, coordinates=structure_prompt, sasa=sasa_prompt)

    model = get_model(model_name, token)
    generated_sequences = []
    generated_proteins = []
    for _ in range(n_samples):
        ## Generate sequence
        sequence_generation = model.generate(
            protein_prompt,
            GenerationConfig(track="sequence", num_steps=len(protein_prompt) // 8, temperature=0.7),
        )
        generated_sequences.append(sequence_generation.sequence)
        ## Fold protein
        structure_prediction = model.generate(
            ESMProtein(sequence=sequence_generation.sequence),
            GenerationConfig(track="structure", num_steps=len(protein_prompt) // 32),
        )
        generated_proteins.append(structure_prediction)

    ## Sort generations by pTM, highest first
    generated_proteins.sort(key=lambda protein: protein.ptm.item(), reverse=True)

    return [
        protein_prompt,
        # Report every sample, not just the last one.
        "\n".join(generated_sequences),
        generated_proteins,
    ]
## Interface for main Scaffolding Example
scaffold_app = gr.Interface(
    fn=scaffold,
    inputs=[
        gr.Dropdown(label="Model Name", choices=["esm3_sm_open_v1"], value="esm3_sm_open_v1", allow_custom_value=True),
        # SECURITY: no default token. A credential committed in source is public,
        # and a component's default value is also shipped to every visitor's
        # browser (type="password" only masks it on screen). Revoke any token
        # that was previously hard-coded here.
        gr.Textbox(label="Hugging Face Token", type="password", placeholder="hf_..."),
        gr.Textbox(value="1ITU", label="PDB Code"),
        gr.Textbox(value="A", label="Chain"),
        # precision=0 makes gr.Number deliver ints, which scaffold() uses for slicing.
        gr.Number(value=123, label="Motif Start", precision=0),
        gr.Number(value=146, label="Motif End", precision=0),
        gr.Number(value=200, label="Prompt Length", precision=0),
        gr.Number(value=72, label="Insert Size", precision=0),
    ],
    outputs=[
        gr.Textbox(label="Sequence"),
        gr.Textbox(label="Motif Sequence"),
        gr.Textbox(label="Sequence Prompt"),
        gr.Textbox(label="Generated Sequence"),
    ],
)
## Interface for "Secondary Structure Editing Example: Helix Shortening"
ss_app = gr.Interface(
    fn=ss_edit,
    inputs=[
        gr.Dropdown(label="Model Name", choices=["esm3_sm_open_v1"], value="esm3_sm_open_v1", allow_custom_value=True),
        # SECURITY: no default token. A credential committed in source is public,
        # and a component's default value is also shipped to every visitor's
        # browser (type="password" only masks it on screen). Revoke any token
        # that was previously hard-coded here.
        gr.Textbox(label="Hugging Face Token", type="password", placeholder="hf_..."),
        gr.Textbox(value="7XBQ", label="PDB ID"),
        gr.Textbox(value="A", label="Chain ID"),
        # precision=0 makes gr.Number deliver ints, which ss_edit() uses for slicing.
        gr.Number(value=38, label="Edit Region Start", precision=0),
        gr.Number(value=111, label="Edit Region End", precision=0),
        gr.Number(value=45, label="Shortened Region Length", precision=0),
        gr.Textbox(value="CCCSHHHHHHHHHHHTTCHHHHHHHHHHHHHTCSSCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHHHHHHHHHHHHHHHIIIIIGGGCCSHHHHHHHHHHHHHHHHHHHHHCCHHHHHHHHHHHHHHHHHHHHHHHHHSCTTCHHHHHHHHHHHHHIIIIICCHHHHHHHHHHHHHHHHTTCTTCCSSHHHHHHHHHHHHHHHHHHHC", label="SS8 Shortening"),
    ],
    outputs=[
        gr.Textbox(label="Original Sequence"),
        gr.Textbox(label="Original SS8"),
        gr.Textbox(label="Original SS8 Edit Region"),
        gr.Textbox(label="Sequence Prompt"),
        gr.Textbox(label="Edited SS8 Prompt"),
        gr.Textbox(label="Proposed SS8 of Edit Region"),
        gr.Textbox(label="Generated Sequence"),
    ],
)
## Interface for "SASA Editing Example: Exposing a buried helix"
sasa_app = gr.Interface(
    fn=sasa_edit,
    inputs=[
        gr.Dropdown(label="Model Name", choices=["esm3_sm_open_v1"], value="esm3_sm_open_v1", allow_custom_value=True),
        # SECURITY: no default token. A credential committed in source is public,
        # and a component's default value is also shipped to every visitor's
        # browser (type="password" only masks it on screen). Revoke any token
        # that was previously hard-coded here.
        gr.Textbox(label="Hugging Face Token", type="password", placeholder="hf_..."),
        gr.Textbox(value="1LBS", label="PDB ID"),
        gr.Textbox(value="A", label="Chain ID"),
        # precision=0 makes gr.Number deliver ints, which sasa_edit() uses for
        # slicing and range().
        gr.Number(value=105, label="Span Start", precision=0),
        gr.Number(value=116, label="Span End", precision=0),
        gr.Number(value=4, label="Number of Samples", precision=0),
    ],
    outputs=[
        gr.Textbox(label="Protein Prompt"),
        gr.Textbox(label="Generated Sequences"),
        gr.Textbox(label="Generated Proteins"),
    ],
)
## Main Interface
# Top-level layout: an attribution header row, then one tab per example workflow.
with gr.Blocks(theme=theme) as esm_app:
    with gr.Row():
        gr.Markdown(
            """
# ESM3: A frontier language model for biology.
Model Created By: [EvolutionaryScale](https://www.evolutionaryscale.ai)
- Press Release: https://www.evolutionaryscale.ai/blog/esm3-release
- GitHub: https://github.com/evolutionaryscale/esm
- HuggingFace Model: https://huggingface.co/EvolutionaryScale/esm3-sm-open-v1
Spaces App By: [Tuple, The Cloud Genomics Company](https://tuple.xyz) [[Colby T. Ford](https://colbyford.com)]
"""
        )
    with gr.Row():
        # Tab titles are matched to interfaces by position.
        gr.TabbedInterface(
            interface_list=[scaffold_app, ss_app, sasa_app],
            tab_names=[
                "Scaffolding Example",
                "Secondary Structure Editing Example",
                "SASA Editing Example",
            ],
        )

if __name__ == "__main__":
    esm_app.launch()