rosettafold2 / app.py
Simon Duerr
fix default options
a82b6a2
import os, time, sys
import subprocess

# Start the RoseTTAFold2 weight download in the background so it overlaps with the
# installation steps below. The weights are (re)fetched when the file is missing, or
# when aria2's ".aria2" control file shows that an earlier download never finished
# (aria2c resumes from the control file in that case).
_params_download = None
if not os.path.isfile("RF2_apr23.pt") or os.path.isfile("RF2_apr23.pt.aria2"):
    # Keep the process handle so startup can wait on the download itself below.
    _params_download = subprocess.Popen(
        "apt-get install -y aria2; aria2c -q -x 16 https://colabfold.steineggerlab.workers.dev/RF2_apr23.pt",
        shell=True,
    )
if not os.path.isdir("RoseTTAFold2"):
    print("install RoseTTAFold2")
    os.system("git clone https://github.com/sokrypton/RoseTTAFold2.git")
    print(os.listdir("RoseTTAFold2"))
    os.system(
        "cd RoseTTAFold2/SE3Transformer; pip -q install --no-cache-dir -r requirements.txt; pip -q install ."
    )
    # provides run_mmseqs2, imported below as `from api import run_mmseqs2`
    os.system(
        "wget https://raw.githubusercontent.com/sokrypton/ColabFold/beta/colabfold/mmseqs/api.py"
    )
    # install hhsuite
    print("install hhsuite")
    os.makedirs("hhsuite", exist_ok=True)
    os.system(
        "curl -fsSL https://github.com/soedinglab/hh-suite/releases/download/v3.3.0/hhsuite-3.3.0-SSE2-Linux.tar.gz | tar xz -C hhsuite/"
    )
    print(os.listdir("hhsuite"))
if _params_download is not None:
    # Wait on the download process rather than polling for the ".aria2" control
    # file: that file may not exist yet (e.g. while apt-get is still running),
    # which let startup continue without the weights.
    if _params_download.poll() is None:
        print("downloading RoseTTAFold2 params")
    _params_download.wait()
if not os.path.isfile("RF2_apr23.pt") or os.path.isfile("RF2_apr23.pt.aria2"):
    # Fail fast with a clear message instead of a confusing error when the model loads.
    raise RuntimeError("RoseTTAFold2 parameters (RF2_apr23.pt) could not be downloaded")
os.environ["DGLBACKEND"] = "pytorch"
# make RoseTTAFold2's network modules (parsers, predict) importable
sys.path.append("RoseTTAFold2/network")
if "hhsuite" not in os.environ["PATH"]:
    # relative entries: hhfilter is resolved from the hhsuite/ dir in the working directory
    os.environ["PATH"] += ":hhsuite/bin:hhsuite/scripts"
import matplotlib.pyplot as plt
import numpy as np
from parsers import parse_a3m
from api import run_mmseqs2
import torch
from string import ascii_uppercase, ascii_lowercase
import hashlib, re, os
import random
from Bio.PDB import *
def get_hash(x):
    """Return the hexadecimal SHA-1 digest of the string ``x`` (encoded as UTF-8)."""
    digest = hashlib.sha1(x.encode())
    return digest.hexdigest()
alphabet_list = [*ascii_uppercase, *ascii_lowercase]  # 52 one-letter labels: "A".."Z" then "a".."z"
from collections import OrderedDict, Counter
import gradio as gr
# Build the RoseTTAFold2 predictor only once: if this module's code is executed
# again in the same namespace, the already-loaded `pred` is reused.
if "pred" not in globals():
    from predict import Predictor

    print("compile RoseTTAFold2")
    model_params = "RF2_apr23.pt"
    has_gpu = torch.cuda.is_available()
    if not has_gpu:
        print("WARNING: using CPU")
    pred = Predictor(model_params, torch.device("cuda:0" if has_gpu else "cpu"))
def get_unique_sequences(seq_list):
    """Return the distinct entries of ``seq_list`` in order of first appearance."""
    # dicts preserve insertion order, so fromkeys() deduplicates stably
    return list(dict.fromkeys(seq_list))
def _run_hhfilter(in_path, out_path, seq_id, cov):
    """Run HH-suite's ``hhfilter`` on ``in_path`` and write the filtered a3m to ``out_path``.

    The command is passed as an argument list (no shell) because the paths contain
    the user-supplied job name, which must never be interpreted by a shell.

    Raises:
        FileNotFoundError: if ``hhfilter`` is not on PATH.
        subprocess.CalledProcessError: if ``hhfilter`` exits with a non-zero status.
    """
    import subprocess

    subprocess.run(
        ["hhfilter", "-i", in_path, "-id", str(seq_id), "-cov", str(cov), "-o", out_path],
        check=True,
    )


def _group_paired_rows(a3ms):
    """Regroup paired a3m outputs (one string per unique chain) by entry index.

    Element ``n`` of the result collects the sequence line(s) of the n-th entry of
    every chain's a3m, in chain order (one line per chain when each a3m keeps a
    sequence on a single line).
    """
    rows = []
    for a3m_lines in a3ms:
        n = -1
        for line in a3m_lines.split("\n"):
            if not line:
                continue
            if line.startswith(">"):
                n += 1
                if len(rows) < n + 1:
                    rows.append([])
            else:
                rows[n].append(line)
    return rows


def get_msa(seq, jobname, cov=50, id=90, max_msa=2048, mode="unpaired_paired"):
    """Build an MSA with the MMseqs2 server, filter it with hhfilter, write ``{jobname}/msa.a3m``.

    Args:
        seq: one sequence or a list of chain sequences; identical chains are
            collapsed for the search and expanded again (joined with "/") in
            every MSA row.
        jobname: job directory; intermediate files go to ``{jobname}/msa/``.
        cov: minimum coverage passed to ``hhfilter -cov``.
        id: maximum pairwise identity passed to ``hhfilter -id``.
        max_msa: the unpaired rows are only added while the MSA has fewer rows
            than this (paired rows are not capped).
        mode: "unpaired", "paired" or "unpaired_paired". Pairing needs at least
            two distinct chains; with a single distinct chain an unpaired MSA is
            always built.
    """
    assert mode in ["unpaired", "paired", "unpaired_paired"]
    seqs = [seq] if isinstance(seq, str) else seq

    # collapse homooligomeric sequences
    counts = Counter(seqs)
    u_seqs = list(counts.keys())
    u_nums = list(counts.values())

    def expand(chain_rows):
        # expand homooligomeric sequences: repeat each chain's row by its copy number
        return "/".join("/".join([x] * num) for x, num in zip(chain_rows, u_nums))

    msa = [expand(u_seqs)]
    path = os.path.join(jobname, "msa")
    os.makedirs(path, exist_ok=True)

    if mode in ["paired", "unpaired_paired"] and len(u_seqs) > 1:
        print("getting paired MSA")
        out_paired = run_mmseqs2(u_seqs, f"{path}/", use_pairing=True)
        sequences = _group_paired_rows(out_paired)
        # filter MSA: hhfilter sees each paired row as one concatenated sequence
        with open(f"{path}/paired_in.a3m", "w") as handle:
            for n, sequence in enumerate(sequences):
                handle.write(f">n{n}\n{''.join(sequence)}\n")
        _run_hhfilter(f"{path}/paired_in.a3m", f"{path}/paired_out.a3m", id, cov)
        with open(f"{path}/paired_out.a3m", "r") as handle:
            for line in handle:
                if line.startswith(">"):
                    # surviving headers are ">n{row}", pointing back to `sequences`
                    msa.append(expand(sequences[int(line[2:])]))

    if len(msa) < max_msa and (
        mode in ["unpaired", "unpaired_paired"] or len(u_seqs) == 1
    ):
        print("getting unpaired MSA")
        out = run_mmseqs2(u_seqs, f"{path}/")
        lengths = [len(s) for s in u_seqs]
        sub_msa = []
        for n, a3m_lines in enumerate(out):
            with open(f"{path}/in_{n}.a3m", "w") as handle:
                handle.write(a3m_lines)
            _run_hhfilter(f"{path}/in_{n}.a3m", f"{path}/out_{n}.a3m", id, cov)
            chain_rows = []
            with open(f"{path}/out_{n}.a3m", "r") as handle:
                for line in handle:
                    if not line.startswith(">"):
                        # this chain's hit, with every other chain filled by gaps
                        xs = ["-" * length for length in lengths]
                        xs[n] = line.rstrip()
                        chain_rows.append(expand(xs))
            sub_msa.append(chain_rows)

        # Interleave the chains round-robin (row 0 of every chain, then row 1, ...)
        # until max_msa rows are reached; index-based, so no O(n) list.pop(0).
        depth = max((len(rows) for rows in sub_msa), default=0)
        interleaved = (
            rows[i] for i in range(depth) for rows in sub_msa if i < len(rows)
        )
        for row in interleaved:
            if len(msa) >= max_msa:
                break
            msa.append(row)

    with open(f"{jobname}/msa.a3m", "w") as handle:
        handle.writelines(f">n{n}\n{sequence}\n" for n, sequence in enumerate(msa))
from Bio.PDB.PDBExceptions import PDBConstructionWarning
import warnings
from Bio.PDB import *
import numpy as np
def add_plddt_to_cif(best_plddts, best_plddt, best_seed, jobname):
    """Convert the best model to mmCIF and append ModelCIF pLDDT (``_ma_qa_metric*``) tables.

    Reads ``{jobname}/rf2_seed{best_seed}_00_pred.pdb``, writes
    ``{jobname}/rf2_seed{best_seed}_00_pred.cif`` with Biopython's MMCIFIO, then
    appends the QA-metric categories that Mol*'s AlphaFold view reads.

    Args:
        best_plddts: per-residue lDDT values, presumably on a 0-1 scale since
            they are multiplied by 100 here.
        best_plddt: mean lDDT (same scale), written as the global metric.
        best_seed: seed whose prediction files are converted.
        jobname: directory holding the prediction files.

    Returns:
        None; the cif file is written/appended in place.
    """
    pdb_path = f"{jobname}/rf2_seed{best_seed}_00_pred.pdb"
    cif_path = f"{jobname}/rf2_seed{best_seed}_00_pred.cif"

    # QUIET suppresses PDBConstructionWarning for this parse only, instead of
    # installing a process-wide warnings filter on every call.
    structure = PDBParser(QUIET=True).get_structure("pdb", pdb_path)
    io = MMCIFIO()
    io.set_structure(structure)
    io.save(cif_path)

    # Global and local values use the same 0-100 scale.
    header = f"""#
loop_
_ma_qa_metric.id
_ma_qa_metric.mode
_ma_qa_metric.name
_ma_qa_metric.software_group_id
_ma_qa_metric.type
1 global pLDDT 1 pLDDT
2 local pLDDT 1 pLDDT
#
_ma_qa_metric_global.metric_id 1
_ma_qa_metric_global.metric_value {best_plddt * 100:.2f}
_ma_qa_metric_global.model_id 1
_ma_qa_metric_global.ordinal_id 1
#
loop_
_ma_qa_metric_local.label_asym_id
_ma_qa_metric_local.label_comp_id
_ma_qa_metric_local.label_seq_id
_ma_qa_metric_local.metric_id
_ma_qa_metric_local.metric_value
_ma_qa_metric_local.model_id
_ma_qa_metric_local.ordinal_id"""

    # Index residues with one running counter across all chains: restarting at 0
    # per chain gave every chain the first chain's values. The modulo wraps around
    # in case the lDDT array covers only one asymmetric unit of a symmetric
    # assembly. TODO confirm the npz "lddt" layout in RoseTTAFold2's predict.py.
    # The running counter also keeps ordinal_id (the category key) unique.
    residues = [(chain.id, residue) for chain in structure[0] for residue in chain]
    n_values = len(best_plddts)
    rows = [
        f"{chain_id} {residue.resname} {residue.id[1]} 2 "
        f"{best_plddts[k % n_values] * 100:.2f} 1 {k + 1}"
        for k, (chain_id, residue) in enumerate(residues)
    ]
    with open(cif_path, "a") as handle:
        handle.write(header + "".join(f"\n{row}" for row in rows) + "\n#")
# number of copies generated by the polyhedral point groups (no order parameter)
_POLYHEDRAL_COPIES = {"T": 12, "O": 24, "I": 60}


def _clean_sequence(sequence):
    """Normalise user input to uppercase letters with ':' between chains.

    '/' is accepted as an alternative chain separator. Every other character
    (whitespace, digits, line breaks, ...) is dropped, runs of separators are
    collapsed, and leading/trailing separators are removed.
    """
    sequence = re.sub("[^A-Z:]", "", sequence.replace("/", ":").upper())
    sequence = re.sub(":+", ":", sequence)
    return sequence.strip(":")


def _symmetry(sym, order):
    """Return ``(copies, symm)`` for symmetry letter ``sym`` and integer ``order``.

    Raises:
        gr.Error: for a symmetry letter other than X, C, D, T, O or I.
    """
    if sym in ("X", "C"):
        return order, f"{sym}{order}"
    if sym == "D":
        return order * 2, f"{sym}{order}"
    if sym in _POLYHEDRAL_COPIES:
        return _POLYHEDRAL_COPIES[sym], sym
    raise gr.Error(f"Unknown symmetry '{sym}'. Use one of X, C, D, T, O or I.")


def predict(
    sequence,
    jobname,
    sym,
    order,
    msa_concat_mode,
    msa_method,
    pair_mode,
    collapse_identical,
    num_recycles,
    use_mlm,
    use_dropout,
    max_msa,
    random_seed,
    num_models,
    mode="web",
):
    """Run RoseTTAFold2 for ``num_models`` seeds and return the best model's file path.

    Form values may arrive as strings and are converted with ``int``. Chains are
    separated by ':' or '/'. With ``collapse_identical`` duplicate chains are
    removed before replication by the symmetry copy number. The seed with the
    highest mean lDDT wins.

    Returns:
        ``{jobname}_{symm}_{hash}/rf2_seed{seed}_00_pred.cif`` (with pLDDT tables
        appended) when ``mode == "web"``, otherwise the matching ``.pdb`` path.

    Raises:
        gr.Error: for an empty or (on Spaces) too long sequence, an unknown
            symmetry or MSA method, or ``num_models < 1``.
    """
    print("sequence", sequence)
    sequence = _clean_sequence(sequence)
    print("sequence", sequence)
    if not sequence:
        raise gr.Error("Please enter a protein sequence.")
    # Count residues after cleaning, so whitespace and line breaks in pasted
    # sequences do not count toward the limit.
    n_residues = len(sequence.replace(":", ""))
    if os.path.exists("/home/user/app") and n_residues > 600:  # crude check if on spaces
        raise gr.Error(
            f"Your sequence is too long ({n_residues}). "
            "Please use the full version of RoseTTAfold2 directly from GitHub."
        )

    random_seed = int(random_seed)
    num_models = int(num_models)
    max_msa = int(max_msa)
    num_recycles = int(num_recycles)
    order = int(order)
    if num_models < 1:
        raise gr.Error("num_models must be at least 1.")
    max_extra_msa = max_msa * 8

    copies, symm = _symmetry(str(sym).strip().upper(), order)

    sequences = sequence.split(":")
    u_sequences = get_unique_sequences(sequences) if collapse_identical else sequences
    sequences = u_sequences * copies
    lengths = [len(s) for s in sequences]
    # TODO
    subcrop = 1000 if sum(lengths) > 1400 else -1
    sequence = "/".join(sequences)

    # jobname comes straight from the web form/API and becomes a directory and
    # file-path component; keep only word characters so it cannot traverse
    # directories or carry shell/HTML metacharacters into later steps.
    jobname = re.sub(r"\W+", "", jobname)
    jobname = jobname + "_" + symm + "_" + get_hash(sequence)[:5]
    print(f"jobname: {jobname}")
    print(f"lengths: {lengths}")
    print("final_sequence", u_sequences)
    os.makedirs(jobname, exist_ok=True)

    if msa_method == "mmseqs2":
        get_msa(u_sequences, jobname, mode=pair_mode, max_msa=max_extra_msa)
    elif msa_method == "single_sequence":
        u_sequence = "/".join(u_sequences)
        with open(f"{jobname}/msa.a3m", "w") as a3m:
            a3m.write(f">{jobname}\n{u_sequence}\n")
    else:
        # custom a3m upload is disabled in the web UI (see the msa_method dropdown)
        raise gr.Error(f"Unsupported msa_method '{msa_method}'.")

    msa_mask = 0.15 if use_mlm else 0
    best_plddt = best_plddts = best_seed = None
    for seed in range(random_seed, random_seed + num_models):
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        print("MLM", msa_mask, use_mlm)
        pred.predict(
            inputs=[f"{jobname}/msa.a3m"],
            out_prefix=f"{jobname}/rf2_seed{seed}",
            symm=symm,
            ffdb=None,  # TODO (templates),
            n_recycles=num_recycles,
            msa_mask=msa_mask,
            msa_concat_mode=msa_concat_mode,
            nseqs=max_msa,
            nseqs_full=max_extra_msa,
            subcrop=subcrop,
            is_training=use_dropout,
        )
        # Read the lDDT array once and close the npz archive afterwards.
        with np.load(f"{jobname}/rf2_seed{seed}_00.npz") as npz:
            plddts = npz["lddt"]
        plddt = plddts.mean()
        if best_plddt is None or plddt > best_plddt:
            best_plddt, best_plddts, best_seed = plddt, plddts, seed

    if mode == "web":
        # Mol* only displays AlphaFold plDDT if they are in a cif.
        add_plddt_to_cif(best_plddts, best_plddt, best_seed, jobname)
        return f"{jobname}/rf2_seed{best_seed}_00_pred.cif"
    # for api just return a pdb file
    return f"{jobname}/rf2_seed{best_seed}_00_pred.pdb"
def predict_api(
    sequence,
    jobname,
    sym,
    order,
    msa_concat_mode,
    msa_method,
    pair_mode,
    collapse_identical,
    num_recycles,
    use_mlm,
    use_dropout,
    max_msa,
    random_seed,
    num_models,
):
    """API endpoint: run a prediction and return the best model as PDB text.

    Takes the same form fields as ``predict``, which is called with ``mode="api"``
    so that it returns the path of the ``.pdb`` file.
    """
    pdb_path = predict(
        sequence, jobname, sym, order, msa_concat_mode, msa_method, pair_mode,
        collapse_identical, num_recycles, use_mlm, use_dropout, max_msa,
        random_seed, num_models, mode="api",
    )
    with open(pdb_path) as pdb_file:
        pdb_text = pdb_file.read()
    return pdb_text
def molecule(input_pdb, public_link):
    """Return an ``<iframe>`` embedding a PDBe Molstar viewer for a predicted model.

    Args:
        input_pdb: server-relative path of the model (normally the ``.cif``).
        public_link: base URL of this app; files are served at ``{public_link}/file=...``.

    Returns:
        HTML with download buttons for the cif and matching pdb file, plus a Molstar
        viewer (AlphaFold-style pLDDT colouring) loading the cif.
    """
    import html
    import json

    print(input_pdb)
    print(public_link + "/file=" + input_pdb)
    link = public_link + "/file=" + input_pdb
    # swap only the file suffix, not every ".cif" substring in the URL
    pdb_link = link[: -len(".cif")] + ".pdb" if link.endswith(".cif") else link

    # The link contains a user-chosen job name, so escape it for each context it
    # lands in: HTML attribute values and a JavaScript string literal (with "</"
    # broken up so it cannot close the <script> element).
    href_cif = html.escape(link, quote=True)
    href_pdb = html.escape(pdb_link, quote=True)
    js_url = json.dumps(link).replace("</", "<\\/")

    viewer_html = (
        """<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0">
<title>PDBe Molstar - Helper functions</title>
<!-- Molstar CSS & JS -->
<link rel="stylesheet" type="text/css" href="https://www.ebi.ac.uk/pdbe/pdb-component-library/css/pdbe-molstar-light-3.1.0.css">
<script type="text/javascript" src="https://www.ebi.ac.uk/pdbe/pdb-component-library/js/pdbe-molstar-plugin-3.1.0.js"></script>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
.msp-plugin ::-webkit-scrollbar-thumb {
background-color: #474748 !important;
}
.viewerSection {
margin: 120px 0 0 0px;
}
#myViewer{
float:left;
width:100%;
height: 800px;
position:relative;
}
.btn{
font-family: "Open Sans", sans-serif;
display: inline-block;
outline: none;
cursor: pointer;
font-weight: 600;
border-radius: 3px;
padding: 12px 24px;
border: 0;
margin:0 10px;
line-height: 1.15;
font-size: 16px;
text-decoration: none;
}
.btn-orange{
background: #ff5000;
color: #fff;
}
.btn-gray{
color: #3a4149;
background: #e7ebee;
}
.btn:hover{
transition: all .1s ease;
box-shadow: 0 0 0 0 #fff, 0 0 0 3px #ddd;}
.text-center{
display: flex;
align-items: center;
justify-content: center;
padding: 20px 0;
}
.flex{
padding: 10px;
display: flex;
align-items: center;
justify-content: center;
width:fit-content;
}
.flex svg{
margin-right: 10px;
width:16px;
height:16px;
}
.flex a{
margin:0 10px;
}
</style>
</head>
<body>
<div class="text-center">
<a class="btn btn-orange flex" href=\""""
        + href_cif
        + """\" target="_blank"> <svg fill="none" stroke="currentColor" stroke-width="1.5" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg" aria-hidden="true">
<path stroke-linecap="round" stroke-linejoin="round" d="M19.5 13.5L12 21m0 0l-7.5-7.5M12 21V3"></path>
</svg> <span>CIF File</span></a>
<a class="btn btn-gray flex" href=\""""
        + href_pdb
        + """\" target="_blank"> <svg fill="none" stroke="currentColor" stroke-width="1.5" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg" aria-hidden="true">
<path stroke-linecap="round" stroke-linejoin="round" d="M19.5 13.5L12 21m0 0l-7.5-7.5M12 21V3"></path>
</svg> <span>PDB File</span></a>
</div>
<div class="viewerSection">
<!-- Molstar container -->
<div id="myViewer"></div>
</div>
<script>
//Create plugin instance
var viewerInstance = new PDBeMolstarPlugin();
//Set options (Checkout available options list in the documentation)
var options = {
customData: {
url: """
        + js_url
        + """,
format: "cif"
},
alphafoldView: true,
bgColor: {r:255, g:255, b:255},
//hideCanvasControls: ["selection", "animation", "controlToggle", "controlInfo"]
}
//Get element from HTML/Template to place the viewer
var viewerContainer = document.getElementById("myViewer");
//Call render method to display the 3D view
viewerInstance.render(viewerContainer, options);
</script>
</body>
</html>"""
    )
    # srcdoc holds a whole HTML document as an attribute value; entity-escape it
    # so quotes in the document (or the link) cannot terminate the attribute.
    srcdoc = html.escape(viewer_html, quote=True)
    return f"""<iframe style="width: 100%; height: 1000px" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc="{srcdoc}"></iframe>"""
def predict_web(
    sequence,
    jobname,
    sym,
    order,
    msa_concat_mode,
    msa_method,
    pair_mode,
    collapse_identical,
    num_recycles,
    use_mlm,
    use_dropout,
    max_msa,
    random_seed,
    num_models,
):
    """Web-UI handler: run a prediction and return the Molstar viewer HTML for it.

    Takes the same form fields as ``predict``. The viewer links point at the
    Hugging Face Space URL when ``/home/user/app`` exists (a crude "running on
    Spaces" check), otherwise at a local Gradio server on port 7860.
    """
    on_spaces = os.path.exists("/home/user/app")
    public_link = (
        "https://simonduerr-rosettafold2.hf.space" if on_spaces else "http://localhost:7860"
    )
    cif_path = predict(
        sequence, jobname, sym, order, msa_concat_mode, msa_method, pair_mode,
        collapse_identical, num_recycles, use_mlm, use_dropout, max_msa,
        random_seed, num_models, mode="web",
    )
    return molecule(cif_path, public_link)
# Gradio UI: one visible "Run" button for the web viewer and one hidden button
# that only exists to expose the "rosettafold2" API endpoint (PDB text output).
with gr.Blocks() as rosettafold:
    gr.Markdown("# RoseTTAFold2")
    gr.Markdown(
        """If using please cite: [manuscript](https://www.biorxiv.org/content/10.1101/2023.05.24.542179v1)
<br> Heavily based on [RoseTTAFold2 ColabFold notebook](https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/RoseTTAFold2.ipynb)"""
    )
    with gr.Accordion("How to use in PyMol", open=False):
        gr.HTML(
            """<code>os.system('wget https://huggingface.co/spaces/simonduerr/rosettafold2/raw/main/rosettafold_pymol.py') <br>
run rosettafold_pymol.py <br>
rosettafold2 sequence, jobname, [sym, order, msa_concat_mode, msa_method, pair_mode, collapse_identical, num_recycles, use_mlm, use_dropout, max_msa, random_seed, num_models] <br>
color_plddt jobname</code>
"""
        )
    sequence = gr.Textbox(
        label="sequence",
        value="PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK",
    )
    jobname = gr.Textbox(label="jobname", value="test")
    with gr.Accordion("Additional settings", open=False):
        sym = gr.Textbox(label="sym", value="X")
        order = gr.Slider(label="order", value=1, step=1, minimum=1, maximum=12)
        msa_concat_mode = gr.Dropdown(
            label="msa_concat_mode",
            value="default",
            choices=["diag", "repeat", "default"],
        )
        # custom a3m upload is intentionally not offered for now
        msa_method = gr.Dropdown(
            label="msa_method",
            value="single_sequence",
            choices=["mmseqs2", "single_sequence"],
        )
        pair_mode = gr.Dropdown(
            label="pair_mode",
            value="unpaired_paired",
            choices=["unpaired_paired", "paired", "unpaired"],
        )
        num_recycles = gr.Dropdown(
            label="num_recycles",
            value="6",
            choices=["0", "1", "3", "6", "12", "24"],
        )
        use_mlm = gr.Checkbox(label="use_mlm", value=False)
        use_dropout = gr.Checkbox(label="use_dropout", value=False)
        collapse_identical = gr.Checkbox(label="collapse_identical", value=False)
        max_msa = gr.Dropdown(
            label="max_msa",
            value="16",
            choices=["16", "32", "64", "128", "256", "512"],
        )
        random_seed = gr.Textbox(label="random_seed", value=0)
        num_models = gr.Dropdown(
            label="num_models",
            value="1",
            choices=["1", "2", "4", "8", "16", "32"],
        )
    btn = gr.Button("Run", visible=False)
    btn_web = gr.Button("Run")
    output_plain = gr.HTML()
    output = gr.HTML()

    # Both handlers take the form fields in predict()'s positional order.
    prediction_inputs = [
        sequence, jobname, sym, order, msa_concat_mode, msa_method, pair_mode,
        collapse_identical, num_recycles, use_mlm, use_dropout, max_msa,
        random_seed, num_models,
    ]
    btn.click(
        fn=predict_api,
        inputs=prediction_inputs,
        outputs=output_plain,
        api_name="rosettafold2",
    )
    btn_web.click(fn=predict_web, inputs=prediction_inputs, outputs=output)
rosettafold.launch()