Commit 22e3abd by Simon Duerr
Parent(s): 8a361d8

fix: train path, update draw samples
- README.md +9 -5
- app.py +7 -4
- checkpoints/allatom.yml +1 -1
- checkpoints/backbone.yml +1 -1
- configs/allatom.yml +1 -1
- configs/backbone.yml +1 -1
- configs/seqdes.yml +1 -1
- core/protein_mpnn.py +3 -3
- draw_samples.py +12 -9
- protpardelle_pymol.py +12 -2
README.md
CHANGED
@@ -10,15 +10,15 @@ pinned: false
 license: mit
 ---
 
-# protpardelle
+# protpardelle
 
 Code for the paper: [An all-atom protein generative model](https://www.biorxiv.org/content/10.1101/2023.05.24.542194v1.full).
 
 The code is under active development and we welcome contributions, feature requests, issues, corrections, and any questions! Where we have used or adapted code from others we have tried to give proper attribution, but please let us know if anything should be corrected.
 
-## Environment
+## Environment and setup
 
-To set up the conda environment, run `conda env create -f configs/environment.yml`.
+To set up the conda environment, run `conda env create -f configs/environment.yml` then `conda activate delle`. You will also need to clone the [ProteinMPNN repository](https://github.com/dauparas/ProteinMPNN) to the same directory that contains the `protpardelle/` repository. You may also need to set the `home_dir` variable in the configs you use to the path to the directory containing the `protpardelle/` directory.
 
 ## Inference
 
@@ -28,12 +28,16 @@ To draw 8 samples per length for lengths in `range(70, 150, 5)` from the backbone
 
 `python draw_samples.py --type backbone --param n_steps --paramval 100 --minlen 70 --maxlen 150 --steplen 5 --perlen 8`
 
-We have also added the ability to provide an input PDB file and a list of (zero-indexed) indices to condition on from the PDB file. We can expect it to do better or worse depending on the problem (better on easier problems such as inpainting, worse on difficult problems such as discontiguous scaffolding).
+We have also added the ability to provide an input PDB file and a list of (zero-indexed) indices to condition on from the PDB file. Note also that current models are single-chain only, so multi-chain PDBs will be treated as single chains (we intend to release multi-chain models in a later update). We can expect it to do better or worse depending on the problem (better on easier problems such as inpainting, worse on difficult problems such as discontiguous scaffolding). Use this command to resample the first 25 and 71st to 80th residues of `my_pdb.pdb`.
 
-`python draw_samples.py --input_pdb --
+`python draw_samples.py --input_pdb my_pdb.pdb --resample_idxs 0-25,70-80`
+
+For more control over the sampling process, including tweaking the sampling hyperparameters and more specific methods of conditioning, you can directly interface with the `model.sample()` function; we have provided examples of how to configure and run these commands in `sampling.py`.
 
 ## Training
 
+Note (Sep 2023): the lab has decided to collect usage statistics on people interested in training their own versions of Protpardelle (for funding and other purposes). To obtain a copy of the repository with training code, please complete [this Google Form](https://docs.google.com/forms/d/1WKMVbydLh6LIegc3HfwMQhgL2_qnrY7ks9FM_ylo4ts) - you will receive a link to a Google Drive zip which contains the repository with training code. After publication, the plan is to include the full training code directly in this repository.
+
 Pretrained model weights are provided, but if you are interested in training your own models, we have provided training code together with some basic online evaluation. You will need to create a Weights & Biases account.
 
 The dataset can be downloaded from [CATH](http://download.cathdb.info/cath/releases/all-releases/v4_3_0/non-redundant-data-sets/), and the train/validation/test splits used can be downloaded with
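The README's `draw_samples.py` command maps its length flags onto a Python `range`, with 8 samples drawn at each length. A minimal sketch of that mapping (the helper name `sample_lengths` is hypothetical, not part of the repo):

```python
def sample_lengths(minlen, maxlen, steplen, perlen):
    """Expand the draw_samples.py length flags into one entry per sample.

    maxlen is exclusive, matching the --maxlen help text
    ("Maximum sequence length, not inclusive").
    """
    lengths = []
    for length in range(minlen, maxlen, steplen):
        # perlen samples are drawn at each sequence length
        lengths.extend([length] * perlen)
    return lengths

# The README command uses --minlen 70 --maxlen 150 --steplen 5 --perlen 8,
# i.e. 16 lengths (70, 75, ..., 145) with 8 samples each = 128 samples.
lengths = sample_lengths(70, 150, 5, 8)
```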
app.py
CHANGED
@@ -303,15 +303,15 @@ def protpardelle(path_to_file, m, resample_idx, modeltype, minlen, maxlen, step
     if args.type == "backbone":
         if args.model_checkpoint:
             checkpoint = f"{args.model_checkpoint}/backbone_state_dict.pth"
-            cfg_path = f"{args.model_checkpoint}/
+            cfg_path = f"{args.model_checkpoint}/backbone_pretrained.yml"
         else:
             checkpoint = (
                 f"{model_directory}/checkpoints/epoch{epoch}_training_state.pth"
             )
             cfg_path = f"{model_directory}/configs/backbone.yml"
-
+        config = utils.load_config(cfg_path)
         weights = torch.load(checkpoint, map_location=device)["model_state_dict"]
-        model = models.Protpardelle(
+        model = models.Protpardelle(config, device=device)
         model.load_state_dict(weights)
         model.to(device)
         model.eval()
@@ -319,7 +319,7 @@ def protpardelle(path_to_file, m, resample_idx, modeltype, minlen, maxlen, step
     elif args.type == "allatom":
         if args.model_checkpoint:
             checkpoint = f"{args.model_checkpoint}/allatom_state_dict.pth"
-            cfg_path = f"{args.model_checkpoint}/
+            cfg_path = f"{args.model_checkpoint}/allatom_pretrained.yml"
         else:
             checkpoint = (
                 f"{model_directory}/checkpoints/epoch{epoch}_training_state.pth"
@@ -345,6 +345,9 @@ def protpardelle(path_to_file, m, resample_idx, modeltype, minlen, maxlen, step
     for k, v in sampling_kwargs_readme:
         f.write(f"{k}\t{v}\n")
 
+    print(f"Model loaded from {checkpoint}")
+    print(f"Beginning sampling for {date_string}...")
+
     # Draw samples
     output_files = draw_and_save_samples(
         model,
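The app.py fix pairs each checkpoint with its matching config: a pretrained state dict now loads `<type>_pretrained.yml` from the same directory, while a training checkpoint keeps using `configs/<type>.yml`. A sketch of just that path logic, factored into a hypothetical helper for illustration:

```python
def resolve_paths(model_type, model_checkpoint, model_directory, epoch):
    """Mirror the patched path selection for model_type in {"backbone", "allatom"}.

    If a pretrained checkpoint directory is given, use its state dict and
    its matching <type>_pretrained.yml; otherwise fall back to a training
    checkpoint and the repo config for that model type.
    """
    if model_checkpoint:
        checkpoint = f"{model_checkpoint}/{model_type}_state_dict.pth"
        cfg_path = f"{model_checkpoint}/{model_type}_pretrained.yml"
    else:
        checkpoint = f"{model_directory}/checkpoints/epoch{epoch}_training_state.pth"
        cfg_path = f"{model_directory}/configs/{model_type}.yml"
    return checkpoint, cfg_path
```

In the real code the config is then loaded with `utils.load_config(cfg_path)` and passed to `models.Protpardelle(config, device=device)` before the weights are applied.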
checkpoints/allatom.yml
CHANGED
@@ -1,5 +1,5 @@
 train:
-  home_dir: '
+  home_dir: ''
   seed: 0
   checkpoint: ['', 0]
   batch_size: 32
checkpoints/backbone.yml
CHANGED
@@ -1,5 +1,5 @@
 train:
-  home_dir: '
+  home_dir: ''
   seed: 0
   checkpoint: ['', 0]
   batch_size: 32
configs/allatom.yml
CHANGED
@@ -1,5 +1,5 @@
 train:
-  home_dir: '
+  home_dir: ''
   seed: 0
   checkpoint: ['', 0]
   batch_size: 32
configs/backbone.yml
CHANGED
@@ -1,5 +1,5 @@
 train:
-  home_dir: '
+  home_dir: ''
   seed: 0
   checkpoint: ['', 0]
   batch_size: 32
configs/seqdes.yml
CHANGED
@@ -1,5 +1,5 @@
 train:
-  home_dir: '
+  home_dir: ''
   seed: 0
   checkpoint: ['', 0]
   batch_size: 32
core/protein_mpnn.py
CHANGED
@@ -55,10 +55,11 @@ def get_mpnn_model(model_name='v_48_020', path_to_model_weights='', ca_only=Fals
     else:
         file_path = os.path.realpath(__file__)
         k = file_path.rfind("/")
+        k = file_path[:k].rfind("/")
         if ca_only:
-            model_folder_path = file_path[:k] + '/ca_model_weights/'
+            model_folder_path = file_path[:k] + '/ProteinMPNN/ca_model_weights/'
         else:
-            model_folder_path = file_path[:k] + '/vanilla_model_weights/'
+            model_folder_path = file_path[:k] + '/ProteinMPNN/vanilla_model_weights/'
 
     checkpoint_path = model_folder_path + f'{model_name}.pt'
     checkpoint = torch.load(checkpoint_path, map_location=device)
@@ -450,7 +451,6 @@ def run_proteinmpnn(model=None, pdb_path='', pdb_path_chains='', path_to_model_w
     print(f'{num_seqs} sequences of length {total_length} generated in {dt} seconds')
     if write_output_files:
         f.close()
-
     return new_mpnn_seqs
 
 
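The weight-path fix above applies `rfind("/")` twice: once to strip the filename, once more to strip `core/`, so the ProteinMPNN weights are looked up relative to the package root rather than inside `core/`. A self-contained sketch of that resolution (the helper name `mpnn_weights_dir` is hypothetical):

```python
def mpnn_weights_dir(file_path, ca_only=False):
    """Reproduce the patched lookup from get_mpnn_model.

    file_path is the realpath of core/protein_mpnn.py; the two rfind calls
    walk up past the filename and past 'core', and the weights are then
    expected under a ProteinMPNN/ directory at that level.
    """
    k = file_path.rfind("/")        # drop 'protein_mpnn.py'
    k = file_path[:k].rfind("/")    # drop 'core'
    sub = 'ca_model_weights' if ca_only else 'vanilla_model_weights'
    return file_path[:k] + f'/ProteinMPNN/{sub}/'

# e.g. /home/user/protpardelle/core/protein_mpnn.py resolves its weights
# under /home/user/protpardelle/ProteinMPNN/
```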
draw_samples.py
CHANGED
@@ -122,18 +122,18 @@ class Manager(object):
             "--perlen", type=int, default=2, help="How many samples per sequence length"
         )
         self.parser.add_argument(
-            "--minlen", type=int,
+            "--minlen", type=int, default=50, help="Minimum sequence length"
         )
         self.parser.add_argument(
             "--maxlen",
             type=int,
-
+            default=60,
             help="Maximum sequence length, not inclusive",
         )
         self.parser.add_argument(
             "--steplen",
             type=int,
-
+            default=5,
             help="How frequently to select sequence length, for steplen 2, would be 50, 52, 54, etc",
         )
         self.parser.add_argument(
@@ -279,15 +279,15 @@ def main():
     if args.type == "backbone":
         if args.model_checkpoint:
             checkpoint = f"{args.model_checkpoint}/backbone_state_dict.pth"
-            cfg_path = f"{args.model_checkpoint}/
+            cfg_path = f"{args.model_checkpoint}/backbone_pretrained.yml"
         else:
             checkpoint = (
                 f"{model_directory}/checkpoints/epoch{epoch}_training_state.pth"
             )
             cfg_path = f"{model_directory}/configs/backbone.yml"
-
+        config = utils.load_config(cfg_path)
         weights = torch.load(checkpoint, map_location=device)["model_state_dict"]
-        model = models.Protpardelle(
+        model = models.Protpardelle(config, device=device)
         model.load_state_dict(weights)
         model.to(device)
         model.eval()
@@ -295,7 +295,7 @@ def main():
     elif args.type == "allatom":
         if args.model_checkpoint:
             checkpoint = f"{args.model_checkpoint}/allatom_state_dict.pth"
-            cfg_path = f"{args.model_checkpoint}/
+            cfg_path = f"{args.model_checkpoint}/allatom_pretrained.yml"
         else:
             checkpoint = (
                 f"{model_directory}/checkpoints/epoch{epoch}_training_state.pth"
@@ -310,8 +310,11 @@ def main():
     model.eval()
     model.device = device
 
+    if config.train.home_dir == '':
+        config.train.home_dir = os.getcwd()
+
     # Sampling
-    with open(
+    with open(save_dir + "/readme.txt", "w") as f:
         f.write(f"Sampling run for {date_string}\n")
         f.write(f"Random seed {seed}\n")
         f.write(f"Model checkpoint: {checkpoint}\n")
@@ -341,7 +344,7 @@ def main():
     print(f"Of this, {sampling_time} seconds were for actual sampling.")
     print(f"{total_num_samples} total samples were drawn.")
 
-    with open(
+    with open(save_dir + "/readme.txt", "a") as f:
         f.write(f"Total job time: {time_elapsed} seconds\n")
         f.write(f"Model run time: {sampling_time} seconds\n")
         f.write(f"Total samples drawn: {total_num_samples}\n")
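The README example `--resample_idxs 0-25,70-80` is described as resampling "the first 25 and 71st to 80th residues", which implies comma-separated, half-open, zero-indexed ranges. A sketch of a parser consistent with that description (the function name and exact parsing are assumptions; the repo's own handling may differ):

```python
def parse_resample_idxs(spec):
    """Expand a spec like '0-25,70-80' into zero-indexed positions.

    Ranges are treated as half-open, so '0-25' covers residues 0..24
    (the first 25) and '70-80' covers 70..79 (the 71st to 80th,
    counting from one), matching the README's wording.
    """
    idxs = []
    for chunk in spec.split(","):
        start, end = chunk.split("-")
        idxs.extend(range(int(start), int(end)))
    return idxs
```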
protpardelle_pymol.py
CHANGED
@@ -15,9 +15,9 @@ except ImportError:
 
 
 if os.environ.get("GRADIO_LOCAL") != None:
-    public_link = "http://127.0.0.1:
+    public_link = "http://127.0.0.1:7860"
 else:
-    public_link = "
+    public_link = "ProteinDesignLab/protpardelle"
 
 
 
@@ -140,6 +140,16 @@ def query_protpardelle_uncond(
 
 
 def setprotpardellelink(link:str):
+    """
+    AUTHOR
+    Simon Duerr
+    https://twitter.com/simonduerr
+    DESCRIPTION
+    Set a public link to use a locally hosted version of this space
+    USAGE
+    protpardelle_setlink link_or_username/spacename
+    """
+
     global public_link
     try:
         client = Client(link)
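The PyMOL plugin's patched default picks between a local Gradio server and the hosted Space id (both forms are accepted by `gradio_client.Client`). A sketch of that selection, factored into a hypothetical helper so the environment can be passed in explicitly:

```python
import os

def default_public_link(env=None):
    """Mirror the plugin's patched default link selection.

    When GRADIO_LOCAL is set in the environment, target a locally hosted
    Gradio server; otherwise fall back to the hosted Space id.
    """
    if env is None:
        env = os.environ
    if env.get("GRADIO_LOCAL") is not None:
        return "http://127.0.0.1:7860"
    return "ProteinDesignLab/protpardelle"
```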