Simon Duerr committed on
Commit
a82b6a2
1 Parent(s): 5cc9777

fix default options

Browse files
Files changed (2) hide show
  1. app.py +17 -16
  2. rosettafold_pymol.py +19 -13
app.py CHANGED
@@ -240,11 +240,12 @@ def predict(
240
  order = int(order)
241
 
242
  max_extra_msa = max_msa * 8
 
243
  sequence = re.sub("[^A-Z:]", "", sequence.replace("/", ":").upper())
244
  sequence = re.sub(":+", ":", sequence)
245
  sequence = re.sub("^[:]+", "", sequence)
246
  sequence = re.sub("[:]+$", "", sequence)
247
-
248
  if sym in ["X", "C"]:
249
  copies = int(order)
250
  elif sym in ["D"]:
@@ -270,7 +271,7 @@ def predict(
270
 
271
  print(f"jobname: {jobname}")
272
  print(f"lengths: {lengths}")
273
-
274
  os.makedirs(jobname, exist_ok=True)
275
  if msa_method == "mmseqs2":
276
  get_msa(u_sequences, jobname, mode=pair_mode, max_msa=max_extra_msa)
@@ -279,19 +280,17 @@ def predict(
279
  u_sequence = "/".join(u_sequences)
280
  with open(f"{jobname}/msa.a3m", "w") as a3m:
281
  a3m.write(f">{jobname}\n{u_sequence}\n")
282
-
283
- elif msa_method == "custom_a3m":
284
- print("upload custom a3m")
285
- # msa_dict = files.upload()
286
- lines = msa_dict[list(msa_dict.keys())[0]].decode().splitlines()
287
- a3m_lines = []
288
- for line in lines:
289
- line = line.replace("\x00", "")
290
- if len(line) > 0 and not line.startswith("#"):
291
- a3m_lines.append(line)
292
-
293
- with open(f"{jobname}/msa.a3m", "w") as a3m:
294
- a3m.write("\n".join(a3m_lines))
295
 
296
  best_plddt = None
297
  best_seed = None
@@ -300,13 +299,15 @@ def predict(
300
  random.seed(seed)
301
  np.random.seed(seed)
302
  npz = f"{jobname}/rf2_seed{seed}_00.npz"
 
 
303
  pred.predict(
304
  inputs=[f"{jobname}/msa.a3m"],
305
  out_prefix=f"{jobname}/rf2_seed{seed}",
306
  symm=symm,
307
  ffdb=None, # TODO (templates),
308
  n_recycles=num_recycles,
309
- msa_mask=0.15 if use_mlm else 0.0,
310
  msa_concat_mode=msa_concat_mode,
311
  nseqs=max_msa,
312
  nseqs_full=max_extra_msa,
 
240
  order = int(order)
241
 
242
  max_extra_msa = max_msa * 8
243
+ print("sequence", sequence)
244
  sequence = re.sub("[^A-Z:]", "", sequence.replace("/", ":").upper())
245
  sequence = re.sub(":+", ":", sequence)
246
  sequence = re.sub("^[:]+", "", sequence)
247
  sequence = re.sub("[:]+$", "", sequence)
248
+ print("sequence", sequence)
249
  if sym in ["X", "C"]:
250
  copies = int(order)
251
  elif sym in ["D"]:
 
271
 
272
  print(f"jobname: {jobname}")
273
  print(f"lengths: {lengths}")
274
+ print("final_sequence", u_sequences)
275
  os.makedirs(jobname, exist_ok=True)
276
  if msa_method == "mmseqs2":
277
  get_msa(u_sequences, jobname, mode=pair_mode, max_msa=max_extra_msa)
 
280
  u_sequence = "/".join(u_sequences)
281
  with open(f"{jobname}/msa.a3m", "w") as a3m:
282
  a3m.write(f">{jobname}\n{u_sequence}\n")
283
+ # elif msa_method == "custom_a3m":
284
+ # print("upload custom a3m")
285
+ # # msa_dict = files.upload()
286
+ # lines = msa_dict[list(msa_dict.keys())[0]].decode().splitlines()
287
+ # a3m_lines = []
288
+ # for line in lines:
289
+ # line = line.replace("\x00", "")
290
+ # if len(line) > 0 and not line.startswith("#"):
291
+ # a3m_lines.append(line)
292
+ # with open(f"{jobname}/msa.a3m", "w") as a3m:
293
+ # a3m.write("\n".join(a3m_lines))
 
 
294
 
295
  best_plddt = None
296
  best_seed = None
 
299
  random.seed(seed)
300
  np.random.seed(seed)
301
  npz = f"{jobname}/rf2_seed{seed}_00.npz"
302
+ mlm = 0.15 if use_mlm else 0
303
+ print("MLM", mlm, use_mlm)
304
  pred.predict(
305
  inputs=[f"{jobname}/msa.a3m"],
306
  out_prefix=f"{jobname}/rf2_seed{seed}",
307
  symm=symm,
308
  ffdb=None, # TODO (templates),
309
  n_recycles=num_recycles,
310
+ msa_mask=0.15 if use_mlm else 0,
311
  msa_concat_mode=msa_concat_mode,
312
  nseqs=max_msa,
313
  nseqs_full=max_extra_msa,
rosettafold_pymol.py CHANGED
@@ -1,6 +1,7 @@
1
  from pymol import cmd
2
  import requests
3
 
 
4
 
5
  # from gradio_client import Client
6
 
@@ -68,13 +69,13 @@ def query_rosettafold2(
68
  msa_concat_mode: str = "diag",
69
  msa_method: str = "single_sequence",
70
  pair_mode: str = "unpaired_paired",
71
- collapse_identical: bool = True,
72
- num_recycles: int = 0,
73
- use_mlm: bool = True,
74
- use_dropout: bool = True,
75
  max_msa: int = 16,
76
  random_seed: int = 0,
77
- num_models: int = 0,
78
  ):
79
  """
80
  AUTHOR
@@ -115,13 +116,13 @@ def query_rosettafold2(
115
  Collapse identical sequences Default:True
116
 
117
  num_recycles:
118
- Number of recycles Default:0 Options: 0, 1, 3, 6, 12, 24
119
 
120
  use_mlm:
121
- Use MLM Default:True
122
 
123
  use_dropout:
124
- Use dropout Default:True
125
 
126
  max_msa:
127
  Max MSA Default:16
@@ -132,8 +133,13 @@ def query_rosettafold2(
132
  num_models:
133
  Number of models Default:0
134
  """
 
 
 
 
 
135
  response = requests.post(
136
- "https://simonduerr-rosettafold2.hf.space/run/rosettafold2/",
137
  json={
138
  "data": [
139
  sequence, # str in 'sequence' Textbox component
@@ -143,10 +149,10 @@ def query_rosettafold2(
143
  "diag", # str (Option from: ['diag', 'repeat', 'default']) in 'msa_concat_mode' Dropdown component
144
  "single_sequence", # str (Option from: ['mmseqs2', 'single_sequence', 'custom_a3m']) in 'msa_method' Dropdown component
145
  "unpaired_paired", # str (Option from: ['unpaired_paired', 'paired', 'unpaired']) in 'pair_mode' Dropdown component
146
- True, # bool in 'collapse_identical' Checkbox component
147
- 0, # int (Option from: ['0', '1', '3', '6', '12', '24']) in 'num_recycles' Dropdown component
148
- True, # bool in 'use_mlm' Checkbox component
149
- True, # bool in 'use_dropout' Checkbox component
150
  16, # int (Option from: ['16', '32', '64', '128', '256', '512']) in 'max_msa' Dropdown component
151
  0, # int in 'random_seed' Textbox component
152
  1, # int (Option from: ['1', '2', '4', '8', '16', '32']) in 'num_models' Dropdown component
 
1
  from pymol import cmd
2
  import requests
3
 
4
+ import os
5
 
6
  # from gradio_client import Client
7
 
 
69
  msa_concat_mode: str = "diag",
70
  msa_method: str = "single_sequence",
71
  pair_mode: str = "unpaired_paired",
72
+ collapse_identical: bool = False,
73
+ num_recycles: int = 1,
74
+ use_mlm: bool = False,
75
+ use_dropout: bool = False,
76
  max_msa: int = 16,
77
  random_seed: int = 0,
78
+ num_models: int = 1,
79
  ):
80
  """
81
  AUTHOR
 
116
  Collapse identical sequences Default:True
117
 
118
  num_recycles:
119
+ Number of recycles Default:6 Options: 0, 1, 3, 6, 12, 24
120
 
121
  use_mlm:
122
+ Use MLM Default:False
123
 
124
  use_dropout:
125
+ Use dropout Default:False
126
 
127
  max_msa:
128
  Max MSA Default:16
 
133
  num_models:
134
  Number of models Default:0
135
  """
136
+ if os.path.exists("/home/user/app/"):
137
+ url = "https://simonduerr-rosettafold2.hf.space"
138
+ else:
139
+ url = "http://localhost:7860"
140
+
141
  response = requests.post(
142
+ url + "/run/rosettafold2/",
143
  json={
144
  "data": [
145
  sequence, # str in 'sequence' Textbox component
 
149
  "diag", # str (Option from: ['diag', 'repeat', 'default']) in 'msa_concat_mode' Dropdown component
150
  "single_sequence", # str (Option from: ['mmseqs2', 'single_sequence', 'custom_a3m']) in 'msa_method' Dropdown component
151
  "unpaired_paired", # str (Option from: ['unpaired_paired', 'paired', 'unpaired']) in 'pair_mode' Dropdown component
152
+ False, # bool in 'collapse_identical' Checkbox component
153
+ 6, # int (Option from: ['0', '1', '3', '6', '12', '24']) in 'num_recycles' Dropdown component
154
+ False, # bool in 'use_mlm' Checkbox component
155
+ False, # bool in 'use_dropout' Checkbox component
156
  16, # int (Option from: ['16', '32', '64', '128', '256', '512']) in 'max_msa' Dropdown component
157
  0, # int in 'random_seed' Textbox component
158
  1, # int (Option from: ['1', '2', '4', '8', '16', '32']) in 'num_models' Dropdown component