"""Gradio app showcasing ESM3 prompting examples.

Three tabs are exposed: motif scaffolding, secondary-structure editing
(helix shortening) and SASA editing (exposing a buried helix). Each tab fetches
a chain from RCSB, builds an ESM3 prompt from it and shows the prompt next to
what the model generated.
"""

from functools import lru_cache

import gradio as gr
import numpy as np
import py3Dmol
import torch
from huggingface_hub import login

from esm.models.esm3 import ESM3
from esm.sdk.api import (
    ESMProtein,
    GenerationConfig,
)
from esm.utils.structure.protein_chain import ProteinChain

theme = gr.themes.Monochrome(
    primary_hue="gray",
)

# Model checkpoints offered in every tab's dropdown (custom names are also allowed).
MODEL_CHOICES = ("esm3_sm_open_v1",)

APP_DESCRIPTION = """
# ESM3: A frontier language model for biology.

Model Created By: [EvolutionaryScale](https://www.evolutionaryscale.ai)

- Press Release: https://www.evolutionaryscale.ai/blog/esm3-release
- GitHub: https://github.com/evolutionaryscale/esm
- HuggingFace Model: https://huggingface.co/EvolutionaryScale/esm3-sm-open-v1

Spaces App By: [Tuple, The Cloud Genomics Company](https://tuple.xyz) [[Colby T. Ford](https://colbyford.com)]
"""


## Function to get model from Hugging Face using token
@lru_cache(maxsize=1)
def get_model(model_name, token):
    """Load an ESM3 model, logging in to the Hugging Face Hub first if a token is given.

    The most recently loaded model is cached (keyed on ``model_name`` and
    ``token``) so each request does not re-run ``login`` and reload weights.

    Args:
        model_name: Checkpoint name passed to ``ESM3.from_pretrained``.
        token: Hugging Face access token. When empty or None no explicit login
            is performed, and huggingface_hub falls back to whatever credentials
            it already has (e.g. a cached login or the ``HF_TOKEN`` env var).

    Returns:
        The ESM3 model, on CUDA when available and on CPU otherwise.
    """
    if token:
        login(token=token)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return ESM3.from_pretrained(model_name, device=device)


## Function to render 3D structure using py3Dmol
def render_pdb(pdb_string, motif_start=None, motif_end=None):
    """Build an 800x800 py3Dmol cartoon view of a PDB-format string.

    The cartoon is spectrum-coloured. When both ``motif_start`` and
    ``motif_end`` are given, the chain is drawn light grey and the 0-based
    half-open range ``[motif_start, motif_end)`` is highlighted in cyan.
    """
    view = py3Dmol.view(width=800, height=800)
    view.addModel(pdb_string, "pdb")
    view.setStyle({"cartoon": {"color": "spectrum"}})
    if motif_start is not None and motif_end is not None:
        # Shift 0-based indices by +1 for py3Dmol's ``resi`` selector.
        # NOTE(review): assumes residues in ``pdb_string`` are numbered from 1.
        motif_res_inds = list(range(int(motif_start) + 1, int(motif_end) + 1))
        view.setStyle({"cartoon": {"color": "lightgrey"}})
        view.addStyle({"resi": motif_res_inds}, {"cartoon": {"color": "cyan"}})
    view.zoomTo()
    return view


## Function to get PDB data
def get_pdb(pdb_id, chain_id):
    """Fetch chain ``chain_id`` of entry ``pdb_id`` from RCSB as a ProteinChain."""
    return ProteinChain.from_rcsb(pdb_id, chain_id)


def _require_span(start, end, limit, what):
    """Raise a user-facing ``gr.Error`` unless ``0 <= start < end <= limit``."""
    if not 0 <= start < end <= limit:
        raise gr.Error(
            f"{what} must satisfy 0 <= start < end <= {limit}; got start={start}, end={end}."
        )


def scaffold(model_name, token, pdb_id, chain_id, motif_start, motif_end, prompt_length, insert_size):
    """Scaffold a motif from a PDB chain into an otherwise-masked prompt and generate a sequence.

    The motif is residues ``[motif_start, motif_end)`` of the chain. Its sequence
    and atom37 coordinates are placed at offset ``insert_size`` inside a prompt
    of ``prompt_length`` positions; every other position is masked (``"_"`` in
    the sequence, NaN in the coordinates) for ESM3 to fill in.

    Returns:
        ``[chain sequence, motif sequence, sequence prompt, generated sequence]``.

    Raises:
        gr.Error: if the motif range lies outside the chain, or the motif does
            not fit inside the prompt at ``insert_size``.
    """
    # gr.Number yields floats by default; these are used as lengths and indices.
    motif_start, motif_end = int(motif_start), int(motif_end)
    prompt_length, insert_size = int(prompt_length), int(insert_size)

    pdb = get_pdb(pdb_id, chain_id)
    _require_span(motif_start, motif_end, len(pdb), "Motif range")

    ## Get motif sequence and atom37 positions
    motif_inds = np.arange(motif_start, motif_end)
    motif_sequence = pdb[motif_inds].sequence
    motif_atom37_positions = pdb[motif_inds].atom37_positions

    # Without this check, slice assignment below would silently grow the prompt
    # past ``prompt_length`` and the sequence/structure prompts would disagree.
    if insert_size < 0 or insert_size + len(motif_sequence) > prompt_length:
        raise gr.Error(
            f"A {len(motif_sequence)}-residue motif inserted at {insert_size} "
            f"does not fit in a prompt of length {prompt_length}."
        )

    ## Create sequence prompt
    sequence_prompt = ["_"] * prompt_length
    sequence_prompt[insert_size:insert_size + len(motif_sequence)] = list(motif_sequence)
    sequence_prompt = "".join(sequence_prompt)

    ## Create structure prompt
    structure_prompt = torch.full((prompt_length, 37, 3), np.nan)
    structure_prompt[insert_size:insert_size + len(motif_atom37_positions)] = torch.tensor(
        motif_atom37_positions, dtype=torch.float32
    )

    ## Create protein prompt and sequence generation config
    protein_prompt = ESMProtein(sequence=sequence_prompt, coordinates=structure_prompt)
    sequence_generation_config = GenerationConfig(
        track="sequence",
        num_steps=sequence_prompt.count("_") // 2,
        temperature=0.5,
    )

    ## Generate sequence
    model = get_model(model_name, token)
    sequence_generation = model.generate(protein_prompt, sequence_generation_config)

    return [
        pdb.sequence,
        motif_sequence,
        sequence_prompt,
        sequence_generation.sequence,
    ]


def ss_edit(model_name, token, pdb_id, chain_id, region_start, region_end, shortened_region_length, shortening_ss8):
    """Shorten a helix-coil-helix region by prompting ESM3 with an edited SS8 string.

    Residues ``[region_start, region_end)`` are replaced by
    ``shortened_region_length`` masked positions whose SS8 prompt is a helix,
    a 3-residue coil, then a helix. The flanking sequence (from the chain) and
    flanking SS8 (from ``shortening_ss8``) are kept unmasked.

    Returns:
        ``[original sequence, original SS8, original SS8 of the edit region
        (space-padded to its offset), sequence prompt, SS8 prompt, proposed SS8
        of the edit region (space-padded), generated sequence]``.

    Raises:
        gr.Error: if ``shortened_region_length`` is below 3 or the edit region
            lies outside the chain.
    """
    # gr.Number yields floats by default; these are used as lengths and indices.
    region_start, region_end = int(region_start), int(region_end)
    shortened_region_length = int(shortened_region_length)
    if shortened_region_length < 3:
        raise gr.Error("Shortened Region Length must be at least 3 (the coil alone is 3 residues).")

    pdb = get_pdb(pdb_id, chain_id)
    original_sequence = pdb.sequence
    _require_span(region_start, region_end, len(original_sequence), "Edit region")

    ## Mask the (shortened) helix-coil-helix region, but leave the flanking regions unmasked
    sequence_prompt = (
        original_sequence[:region_start]
        + "_" * shortened_region_length
        + original_sequence[region_end:]
    )

    ## Keep the flanking SS8 and replace the region with helix + 3-residue coil + helix.
    # The second helix takes any leftover residue so the region is exactly
    # ``shortened_region_length`` long and stays aligned with the masked
    # sequence (``helix_total // 2`` on both sides is one short for even lengths).
    helix_total = shortened_region_length - 3
    first_helix = helix_total // 2
    shortened_ss8 = "H" * first_helix + "C" * 3 + "H" * (helix_total - first_helix)
    ss8_prompt = shortening_ss8[:region_start] + shortened_ss8 + shortening_ss8[region_end:]

    ## Save original and proposed secondary structure for display
    original_ss8 = shortening_ss8
    original_ss8_region = " " * region_start + shortening_ss8[region_start:region_end]
    proposed_ss8_region = (
        " " * region_start + ss8_prompt[region_start:region_start + shortened_region_length]
    )

    ## Create protein prompt
    protein_prompt = ESMProtein(sequence=sequence_prompt, secondary_structure=ss8_prompt)

    ## Generate sequence
    model = get_model(model_name, token)
    sequence_generation = model.generate(
        protein_prompt,
        GenerationConfig(
            track="sequence",
            num_steps=sequence_prompt.count("_") // 2,
            temperature=0.5,
        ),
    )

    return [
        original_sequence,
        original_ss8,
        original_ss8_region,
        sequence_prompt,
        ss8_prompt,
        proposed_ss8_region,
        sequence_generation.sequence,
    ]


def sasa_edit(model_name, token, pdb_id, chain_id, span_start, span_end, n_samples):
    """Generate proteins that keep a span's structure while prompting it to be solvent-exposed.

    The prompt has a fully masked sequence, keeps only the atom37 coordinates
    of residues ``[span_start, span_end)``, and requests SASA 40.0 on that span
    (None elsewhere). For each of ``n_samples`` samples a sequence is generated
    and then folded with the structure track.

    Returns:
        ``[protein prompt, generated sequences (one per line, in generation
        order), folded proteins sorted by pTM, highest first]``.

    Raises:
        gr.Error: if the span lies outside the chain or ``n_samples < 1``.
    """
    # gr.Number yields floats by default; these are used as indices and counts.
    span_start, span_end, n_samples = int(span_start), int(span_end), int(n_samples)
    if n_samples < 1:
        raise gr.Error("Number of Samples must be at least 1.")

    pdb = get_pdb(pdb_id, chain_id)
    n_residues = len(pdb)
    # Out-of-range spans would otherwise make the SASA list longer than the chain.
    _require_span(span_start, span_end, n_residues, "Span")

    structure_prompt = torch.full((n_residues, 37, 3), torch.nan)
    structure_prompt[span_start:span_end] = torch.tensor(
        pdb[span_start:span_end].atom37_positions, dtype=torch.float32
    )

    sasa_prompt = [None] * n_residues
    sasa_prompt[span_start:span_end] = [40.0] * (span_end - span_start)

    protein_prompt = ESMProtein(sequence="_" * n_residues, coordinates=structure_prompt, sasa=sasa_prompt)

    model = get_model(model_name, token)
    generated_sequences = []
    generated_proteins = []
    for _ in range(n_samples):
        ## Generate sequence
        sequence_generation = model.generate(
            protein_prompt,
            GenerationConfig(track="sequence", num_steps=len(protein_prompt) // 8, temperature=0.7),
        )
        generated_sequences.append(sequence_generation.sequence)

        ## Fold protein
        structure_prediction = model.generate(
            ESMProtein(sequence=sequence_generation.sequence),
            GenerationConfig(track="structure", num_steps=len(protein_prompt) // 32),
        )
        generated_proteins.append(structure_prediction)

    ## Sort generations by pTM, best first
    generated_proteins.sort(key=lambda protein: protein.ptm.item(), reverse=True)

    return [
        protein_prompt,
        "\n".join(generated_sequences),
        generated_proteins,
    ]


def _model_input():
    """Create the model-name dropdown shared by every tab."""
    return gr.Dropdown(
        label="Model Name",
        choices=list(MODEL_CHOICES),
        value=MODEL_CHOICES[0],
        allow_custom_value=True,
    )


def _token_input():
    """Create the (optional) Hugging Face token box shared by every tab.

    Deliberately has no default value: component defaults are part of the app
    config served to every visitor's browser, so a token must never be baked in.
    """
    return gr.Textbox(
        value="",
        label="Hugging Face Token",
        type="password",
        placeholder="hf_... (leave blank to use existing Hugging Face credentials)",
    )


## Interface for main Scaffolding Example
scaffold_app = gr.Interface(
    fn=scaffold,
    inputs=[
        _model_input(),
        _token_input(),
        gr.Textbox(value="1ITU", label="PDB Code"),
        gr.Textbox(value="A", label="Chain"),
        gr.Number(value=123, label="Motif Start"),
        gr.Number(value=146, label="Motif End"),
        gr.Number(value=200, label="Prompt Length"),
        gr.Number(value=72, label="Insert Size"),
    ],
    outputs=[
        gr.Textbox(label="Sequence"),
        gr.Textbox(label="Motif Sequence"),
        gr.Textbox(label="Sequence Prompt"),
        gr.Textbox(label="Generated Sequence"),
    ],
)

## Interface for "Secondary Structure Editing Example: Helix Shortening"
ss_app = gr.Interface(
    fn=ss_edit,
    inputs=[
        _model_input(),
        _token_input(),
        gr.Textbox(value="7XBQ", label="PDB ID"),
        gr.Textbox(value="A", label="Chain ID"),
        gr.Number(value=38, label="Edit Region Start"),
        gr.Number(value=111, label="Edit Region End"),
        gr.Number(value=45, label="Shortened Region Length"),
        gr.Textbox(
            value="CCCSHHHHHHHHHHHTTCHHHHHHHHHHHHHTCSSCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHHHHHHHHHHHHHHHIIIIIGGGCCSHHHHHHHHHHHHHHHHHHHHHCCHHHHHHHHHHHHHHHHHHHHHHHHHSCTTCHHHHHHHHHHHHHIIIIICCHHHHHHHHHHHHHHHHTTCTTCCSSHHHHHHHHHHHHHHHHHHHC",
            label="SS8 Shortening",
        ),
    ],
    outputs=[
        gr.Textbox(label="Original Sequence"),
        gr.Textbox(label="Original SS8"),
        gr.Textbox(label="Original SS8 Edit Region"),
        gr.Textbox(label="Sequence Prompt"),
        gr.Textbox(label="Edited SS8 Prompt"),
        gr.Textbox(label="Proposed SS8 of Edit Region"),
        gr.Textbox(label="Generated Sequence"),
    ],
)

## Interface for "SASA Editing Example: Exposing a buried helix"
sasa_app = gr.Interface(
    fn=sasa_edit,
    inputs=[
        _model_input(),
        _token_input(),
        gr.Textbox(value="1LBS", label="PDB ID"),
        gr.Textbox(value="A", label="Chain ID"),
        gr.Number(value=105, label="Span Start"),
        gr.Number(value=116, label="Span End"),
        gr.Number(value=4, label="Number of Samples"),
    ],
    outputs=[
        gr.Textbox(label="Protein Prompt"),
        gr.Textbox(label="Generated Sequences"),
        gr.Textbox(label="Generated Proteins"),
    ],
)

## Main Interface
with gr.Blocks(theme=theme) as esm_app:
    with gr.Row():
        gr.Markdown(APP_DESCRIPTION)
    with gr.Row():
        gr.TabbedInterface(
            [scaffold_app, ss_app, sasa_app],
            ["Scaffolding Example", "Secondary Structure Editing Example", "SASA Editing Example"],
        )

if __name__ == "__main__":
    esm_app.launch()