colbyford commited on
Commit
6a5443d
·
1 Parent(s): 66b44f0

Add initial app code

Browse files
Files changed (3) hide show
  1. README.md +5 -5
  2. app.py +281 -0
  3. requirements.txt +5 -0
README.md CHANGED
@@ -1,12 +1,12 @@
1
  ---
2
- title: Esm3
3
- emoji:
4
- colorFrom: red
5
- colorTo: red
6
  sdk: gradio
7
  sdk_version: 4.37.1
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: ESM3
3
+ emoji: 🧬
4
+ colorFrom: gray
5
+ colorTo: green
6
  sdk: gradio
7
  sdk_version: 4.37.1
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
+ # ESM3 HF Spaces Application
app.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import numpy as np
4
+ import torch
5
+ import py3Dmol
6
+ from huggingface_hub import login
7
+
8
+ from esm.utils.structure.protein_chain import ProteinChain
9
+ from esm.models.esm3 import ESM3
10
+ from esm.sdk.api import (
11
+ ESMProtein,
12
+ GenerationConfig,
13
+ )
14
+
15
+ theme = gr.themes.Monochrome(
16
+ primary_hue="gray",
17
+ )
18
+
19
+ ## Function to get model from Hugging Face using token
20
+ def get_model(model_name, token):
21
+ login(token=token)
22
+
23
+ # if torch.cuda.is_available():
24
+ # model = ESM3.from_pretrained(model_name, device=torch.device("cuda"))
25
+ # else:
26
+ # model = ESM3.from_pretrained(model_name, device=torch.device("cpu"))
27
+
28
+ model = ESM3.from_pretrained(model_name, device=torch.device("cpu"))
29
+ return model
30
+
31
+ ## Function to render 3D structure using py3Dmol
32
+ def render_pdb(pdb_string, motif_start=None, motif_end=None):
33
+ view = py3Dmol.view(width=800, height=800)
34
+ view.addModel(pdb_string, "pdb")
35
+ view.setStyle({"cartoon": {"color": "spectrum"}})
36
+ if motif_start is not None and motif_end is not None:
37
+ motif_inds = np.arange(motif_start, motif_end)
38
+ view.setStyle({"cartoon": {"color": "lightgrey"}})
39
+ motif_res_inds = (motif_inds + 1).tolist()
40
+ view.addStyle({"resi": motif_res_inds}, {"cartoon": {"color": "cyan"}})
41
+ view.zoomTo()
42
+ return view
43
+
44
+ ## Function to get PDB data
45
+ def get_pdb(pdb_id, chain_id):
46
+ pdb = ProteinChain.from_rcsb(pdb_id, chain_id)
47
+ # return [pdb.sequence, render_pdb(pdb.to_pdb_string())]
48
+ return pdb
49
+
50
+
51
+ # def select_motif(pdb, motif_start, motif_end):
52
+ # motif_inds = np.arange(motif_start, motif_end)
53
+ # motif_sequence = pdb[motif_inds].sequence
54
+ # motif_atom37_positions = pdb[motif_inds].atom37_positions
55
+ # return [motif_sequence, motif_atom37_positions]
56
+
57
+ # def setup_prompt(prompt_length, motif_sequence, motif_atom37_positions, insert_size):
58
+ # prompt_length = 200
59
+
60
+ # sequence_prompt = ["_"]*prompt_length
61
+ # sequence_prompt[insert_size:insert_size+len(motif_sequence)] = list(motif_sequence)
62
+ # sequence_prompt = "".join(sequence_prompt)
63
+
64
+ # structure_prompt = torch.full((prompt_length, 37, 3), np.nan)
65
+ # structure_prompt[insert_size:insert_size+len(motif_atom37_positions)] = torch.tensor(motif_atom37_positions)
66
+
67
+ # protein_prompt = ESMProtein(sequence=sequence_prompt, coordinates=structure_prompt)
68
+
69
+ # return [sequence_prompt, structure_prompt, protein_prompt]
70
+
71
+
72
+ # def generate_scaffold_sequence(model_name, token, sequence_prompt, protein_prompt):
73
+ # sequence_generation_config = GenerationConfig(track="sequence",
74
+ # num_steps=sequence_prompt.count("_") // 2,
75
+ # temperature=0.5)
76
+ # model = get_model(model_name, token)
77
+ # sequence_generation = model.generate(protein_prompt, sequence_generation_config)
78
+ # return sequence_generation
79
+
80
+
81
+ def scaffold(model_name, token, pdb_id, chain_id, motif_start, motif_end, prompt_length, insert_size):
82
+ pdb = get_pdb(pdb_id, chain_id)
83
+ # motif_sequence, motif_atom37_positions = select_motif(pdb, motif_start, motif_end)
84
+
85
+ motif_inds = np.arange(motif_start, motif_end)
86
+ motif_sequence = pdb[motif_inds].sequence
87
+ motif_atom37_positions = pdb[motif_inds].atom37_positions
88
+
89
+ # sequence_prompt, structure_prompt, protein_prompt = setup_prompt(prompt_length, motif_sequence, motif_atom37_positions, insert_size)
90
+
91
+ ## Create sequence prompt
92
+ sequence_prompt = ["_"]*prompt_length
93
+ sequence_prompt[insert_size:insert_size+len(motif_sequence)] = list(motif_sequence)
94
+ sequence_prompt = "".join(sequence_prompt)
95
+
96
+ ## Create structure prompt
97
+ structure_prompt = torch.full((prompt_length, 37, 3), np.nan)
98
+ structure_prompt[insert_size:insert_size+len(motif_atom37_positions)] = torch.tensor(motif_atom37_positions)
99
+
100
+ ## Create protein prompt
101
+ protein_prompt = ESMProtein(sequence=sequence_prompt, coordinates=structure_prompt)
102
+
103
+ # sequence_generation = generate_scaffold_sequence(model_name, token, sequence_prompt, protein_prompt)
104
+ sequence_generation_config = GenerationConfig(track="sequence",
105
+ num_steps=sequence_prompt.count("_") // 2,
106
+ temperature=0.5)
107
+ ## Generate sequence
108
+ model = get_model(model_name, token)
109
+ sequence_generation = model.generate(protein_prompt, sequence_generation_config)
110
+ generated_sequence = sequence_generation.sequence
111
+
112
+ return [
113
+ pdb.sequence,
114
+ motif_sequence,
115
+ # motif_atom37_positions,
116
+ sequence_prompt,
117
+ # structure_prompt,
118
+ # protein_prompt
119
+ generated_sequence
120
+ ]
121
+
122
+ def ss_edit(model_name, token, pdb_id, chain_id, region_start, region_end, shortened_region_length, shortening_ss8):
123
+ pdb = get_pdb(pdb_id, chain_id)
124
+ edit_region = np.arange(region_start, region_end)
125
+
126
+ ## Construct a sequence prompt that masks the (shortened) helix-coil-helix region, but leaves the flanking regions unmasked
127
+ sequence_prompt = pdb.sequence[:edit_region[0]] + "_" * shortened_region_length + pdb.sequence[edit_region[-1] + 1:]
128
+
129
+ ## Construct a secondary structure prompt that retains the secondary structure of the flanking regions, and shortens the lengths of helices in the helix-coil-helix region
130
+ ss8_prompt = shortening_ss8[:edit_region[0]] + (((shortened_region_length - 3) // 2) * "H" + "C"*3 + ((shortened_region_length - 3) // 2) * "H") + shortening_ss8[edit_region[-1] + 1:]
131
+
132
+ ## Save original sequence and secondary structure
133
+ original_sequence = pdb.sequence
134
+ original_ss8 = shortening_ss8
135
+ original_ss8_region = " "*edit_region[0] + shortening_ss8[edit_region[0]:edit_region[-1]+1]
136
+
137
+ proposed_ss8_region = " "*edit_region[0] + ss8_prompt[edit_region[0]:edit_region[0]+shortened_region_length]
138
+
139
+ ## Create protein prompt
140
+ protein_prompt = ESMProtein(sequence=sequence_prompt, secondary_structure=ss8_prompt)
141
+
142
+ ## Generatre sequence
143
+ model = get_model(model_name, token)
144
+ sequence_generation = model.generate(protein_prompt, GenerationConfig(track="sequence", num_steps=protein_prompt.sequence.count("_") // 2, temperature=0.5))
145
+
146
+ return [
147
+ original_sequence,
148
+ original_ss8,
149
+ original_ss8_region,
150
+ sequence_prompt,
151
+ ss8_prompt,
152
+ proposed_ss8_region,
153
+ # protein_prompt,
154
+ sequence_generation
155
+ ]
156
+
157
+ def sasa_edit(model_name, token, pdb_id, chain_id, span_start, span_end, n_samples):
158
+ pdb = get_pdb(pdb_id, chain_id)
159
+
160
+ structure_prompt = torch.full((len(pdb), 37, 3), torch.nan)
161
+ structure_prompt[span_start:span_end] = torch.tensor(pdb[span_start:span_end].atom37_positions, dtype=torch.float32)
162
+
163
+ sasa_prompt = [None]*len(pdb)
164
+ sasa_prompt[span_start:span_end] = [40.0]*(span_end - span_start)
165
+
166
+ protein_prompt = ESMProtein(sequence="_"*len(pdb), coordinates=structure_prompt, sasa=sasa_prompt)
167
+
168
+ model = get_model(model_name, token)
169
+
170
+ generated_proteins = []
171
+ for i in range(n_samples):
172
+ ## Generate sequence
173
+ sequence_generation = model.generate(protein_prompt, GenerationConfig(track="sequence", num_steps=len(protein_prompt) // 8, temperature=0.7))
174
+ ## Fold Protein
175
+ structure_prediction = model.generate(ESMProtein(sequence=sequence_generation.sequence), GenerationConfig(track="structure", num_steps=len(protein_prompt) // 32))
176
+ generated_proteins.append(structure_prediction)
177
+
178
+ ## Sort generations by ptm
179
+ generated_proteins = sorted(generated_proteins, key=lambda x: x.ptm.item(), reverse=True)
180
+
181
+ return [
182
+ protein_prompt,
183
+ sequence_generation,
184
+ generated_proteins
185
+ ]
186
+
187
+
188
+ ## Interface for main Scaffolding Example
189
+ scaffold_app = gr.Interface(
190
+ fn=scaffold,
191
+ inputs=[
192
+ gr.Dropdown(label="Model Name", choices=["esm3_sm_open_v1"], value="esm3_sm_open_v1", allow_custom_value=True),
193
+ gr.Textbox(value = "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", label="Hugging Face Token", type="password"),
194
+ gr.Textbox(value="1ITU", label = "PDB Code"),
195
+ gr.Textbox(value="A", label = "Chain"),
196
+ gr.Number(value=123, label="Motif Start"),
197
+ gr.Number(value=146, label="Motif End"),
198
+ gr.Number(value=200, label="Prompt Length"),
199
+ gr.Number(value=72, label="Insert Size")
200
+ ],
201
+ outputs=[
202
+ gr.Textbox(label="Sequence"),
203
+ # gr.Plot(label="3D Structure")
204
+ gr.Textbox(label="Motif Sequence"),
205
+ # gr.Textbox(label="Motif Positions")
206
+ gr.Textbox(label="Sequence Prompt"),
207
+ # gr.Textbox(label="Structure Prompt"),
208
+ # gr.Textbox(label="Protein Prompt"),
209
+ gr.Textbox(label="Generated Sequence")
210
+ ]
211
+ )
212
+
213
+ ## Interface for "Secondary Structure Editing Example: Helix Shortening"
214
+ ss_app = gr.Interface(
215
+ fn=ss_edit,
216
+ inputs=[
217
+ gr.Dropdown(label="Model Name", choices=["esm3_sm_open_v1"], value="esm3_sm_open_v1", allow_custom_value=True),
218
+ gr.Textbox(value = "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", label="Hugging Face Token", type="password"),
219
+ gr.Textbox(value = "7XBQ", label="PDB ID"),
220
+ gr.Textbox(value = "A", label="Chain ID"),
221
+ gr.Number(value=38, label="Edit Region Start"),
222
+ gr.Number(value=111, label="Edit Region End"),
223
+ gr.Number(value=45, label="Shortened Region Length"),
224
+ gr.Textbox(value="CCCSHHHHHHHHHHHTTCHHHHHHHHHHHHHTCSSCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHHHHHHHHHHHHHHHIIIIIGGGCCSHHHHHHHHHHHHHHHHHHHHHCCHHHHHHHHHHHHHHHHHHHHHHHHHSCTTCHHHHHHHHHHHHHIIIIICCHHHHHHHHHHHHHHHHTTCTTCCSSHHHHHHHHHHHHHHHHHHHC", label="SS8 Shortening")
225
+ ],
226
+ outputs=[
227
+ gr.Textbox(label="Original Sequence"),
228
+ gr.Textbox(label="Original SS8"),
229
+ gr.Textbox(label="Original SS8 Edit Region"),
230
+ gr.Textbox(label="Sequence Prompt"),
231
+ gr.Textbox(label="Edited SS8 Prompt"),
232
+ gr.Textbox(label="Proposed SS8 of Edit Region"),
233
+ # gr.Textbox(label="Protein Prompt"),
234
+ gr.Textbox(label="Generated Sequence")
235
+ ]
236
+ )
237
+
238
+ ## Interface for "SASA Editing Example: Exposing a buried helix"
239
+ sasa_app = gr.Interface(
240
+ fn=sasa_edit,
241
+ inputs=[
242
+ gr.Dropdown(label="Model Name", choices=["esm3_sm_open_v1"], value="esm3_sm_open_v1", allow_custom_value=True),
243
+ gr.Textbox(value = "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", label="Hugging Face Token", type="password"),
244
+ gr.Textbox(value = "1LBS", label="PDB ID"),
245
+ gr.Textbox(value = "A", label="Chain ID"),
246
+ gr.Number(value=105, label="Span Start"),
247
+ gr.Number(value=116, label="Span End"),
248
+ # gr.Textbox(value="CCSSCCCCSSCHHHHHHTEEETTBBTTBCSSEEEEECCTTCCHHHHHTTTHHHHHHHTTCEEEEECCTTTTCSCHHHHHHHHHHHHHHHHHHTTSCCEEEEEETHHHHHHHHHHHHCGGGGGTEEEEEEESCCTTCBGGGHHHHHTTCBCHHHHHTBTTCHHHHHHHHTTTTBCSSCEEEEECTTCSSSCCCCSSSTTSTTCCBTSEEEEHHHHHCTTCCCCSHHHHHBHHHHHHHHHHHHCTTSSCCGGGCCSTTCCCSBCTTSCHHHHHHHHSTHHHHHHHHHHSCCBSSCCCCCGGGGGGSTTCEETTEECCC", label="SS8 String")
249
+ gr.Number(value=4, label="Number of Samples")
250
+ ],
251
+ outputs = [
252
+ gr.Textbox(label="Protein Prompt"),
253
+ gr.Textbox(label="Generated Sequences"),
254
+ gr.Textbox(label="Generated Proteins")
255
+ ]
256
+ )
257
+
258
+ ## Main Interface
259
+ with gr.Blocks(theme=theme) as esm_app:
260
+ with gr.Row():
261
+ gr.Markdown(
262
+ """
263
+ # ESM3: A frontier language model for biology.
264
+ - Created By: [EvolutionaryScale](https://www.evolutionaryscale.ai/blog/esm3-release)
265
+ - Spaces App By: [Tuple, The Cloud Genomics Company](https://tuple.xyz) [[Colby T. Ford](https://colbyford.com)]
266
+ """
267
+ )
268
+ with gr.Row():
269
+ gr.TabbedInterface([
270
+ scaffold_app,
271
+ ss_app,
272
+ sasa_app
273
+ ],
274
+ [
275
+ "Scaffolding Example",
276
+ "Secondary Structure Editing Example",
277
+ "SASA Editing Example"
278
+ ])
279
+
280
+ if __name__ == "__main__":
281
+ esm_app.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ esm
2
+ numpy
3
+ torch>=2.3.0
4
+ py3Dmol
5
+ huggingface_hub