TillCyrill
commited on
Commit
•
0da959e
1
Parent(s):
1d6ab1f
scripts from other repo
Browse files- 11GS.pdb +0 -0
- Dockerfile +9 -0
- MDmodel.py +42 -0
- app.py +150 -0
- best_weights_rep0.pt +3 -0
- graph.py +121 -0
- inference_for_md.hdf5 +0 -0
- transformMD.py +21 -0
- transforms.py +46 -0
11GS.pdb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
Dockerfile
CHANGED
@@ -41,6 +41,15 @@ ENV AMBERHOME="/usr/bin/amber22/"
|
|
41 |
ENV PATH="$AMBERHOME/bin:$PATH"
|
42 |
ENV PYTHONPATH="$AMBERHOME/lib/python3.10/site-packages"
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
RUN adduser -u 5678 --disabled-password --gecos "" appuser && chown -R appuser .
|
45 |
USER appuser
|
46 |
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
|
|
41 |
ENV PATH="$AMBERHOME/bin:$PATH"
|
42 |
ENV PYTHONPATH="$AMBERHOME/lib/python3.10/site-packages"
|
43 |
|
44 |
+
|
45 |
+
RUN useradd -m -u 1000 user
|
46 |
+
USER user
|
47 |
+
ENV HOME=/home/user \
|
48 |
+
PATH=/home/user/.local/bin:$PATH
|
49 |
+
WORKDIR $HOME/app
|
50 |
+
COPY --chown=user . $HOME/app
|
51 |
+
|
52 |
+
|
53 |
RUN adduser -u 5678 --disabled-password --gecos "" appuser && chown -R appuser .
|
54 |
USER appuser
|
55 |
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
MDmodel.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch_geometric.nn import GCNConv
|
5 |
+
|
6 |
+
class GNN_MD(torch.nn.Module):
|
7 |
+
def __init__(self, num_features, hidden_dim):
|
8 |
+
super(GNN_MD, self).__init__()
|
9 |
+
self.conv1 = GCNConv(num_features, hidden_dim)
|
10 |
+
self.bn1 = nn.BatchNorm1d(hidden_dim)
|
11 |
+
self.conv2 = GCNConv(hidden_dim, hidden_dim*2)
|
12 |
+
self.bn2 = nn.BatchNorm1d(hidden_dim*2)
|
13 |
+
self.conv3 = GCNConv(hidden_dim*2, hidden_dim*4)
|
14 |
+
self.bn3 = nn.BatchNorm1d(hidden_dim*4)
|
15 |
+
self.conv4 = GCNConv(hidden_dim*4, hidden_dim*4)
|
16 |
+
self.bn4 = nn.BatchNorm1d(hidden_dim*4)
|
17 |
+
self.conv5 = GCNConv(hidden_dim*4, hidden_dim*8)
|
18 |
+
self.bn5 = nn.BatchNorm1d(hidden_dim*8)
|
19 |
+
self.fc1 = nn.Linear(hidden_dim*8, hidden_dim*4)
|
20 |
+
self.fc2 = nn.Linear(hidden_dim*4, 1)
|
21 |
+
|
22 |
+
|
23 |
+
def forward(self, data):
|
24 |
+
x = self.conv1(data.x, data.edge_index, data.edge_attr.view(-1))
|
25 |
+
x = F.relu(x)
|
26 |
+
x = self.bn1(x)
|
27 |
+
x = self.conv2(x, data.edge_index, data.edge_attr.view(-1))
|
28 |
+
x = F.relu(x)
|
29 |
+
x = self.bn2(x)
|
30 |
+
x = self.conv3(x, data.edge_index, data.edge_attr.view(-1))
|
31 |
+
x = F.relu(x)
|
32 |
+
x = self.bn3(x)
|
33 |
+
x = self.conv4(x, data.edge_index, data.edge_attr.view(-1))
|
34 |
+
x = self.bn4(x)
|
35 |
+
x = F.relu(x)
|
36 |
+
x = self.conv5(x, data.edge_index, data.edge_attr.view(-1))
|
37 |
+
x = self.bn5(x)
|
38 |
+
#x = global_add_pool(x, x.batch)
|
39 |
+
x = F.relu(x)
|
40 |
+
x = F.relu(self.fc1(x))
|
41 |
+
x = F.dropout(x, p=0.25)
|
42 |
+
return self.fc2(x).view(-1)
|
app.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
import py3Dmol
|
4 |
+
|
5 |
+
from Bio.PDB import *
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
from Bio.PDB import PDBParser
|
9 |
+
import pandas as pd
|
10 |
+
import torch
|
11 |
+
import os
|
12 |
+
from MDmodel import GNN_MD
|
13 |
+
import h5py
|
14 |
+
from transformMD import GNNTransformMD
|
15 |
+
|
16 |
+
# JavaScript functions
|
17 |
+
resid_hover = """function(atom,viewer) {{
|
18 |
+
if(!atom.label) {{
|
19 |
+
atom.label = viewer.addLabel('{0}:'+atom.atom+atom.serial,
|
20 |
+
{{position: atom, backgroundColor: 'mintcream', fontColor:'black'}});
|
21 |
+
}}
|
22 |
+
}}"""
|
23 |
+
hover_func = """
|
24 |
+
function(atom,viewer) {
|
25 |
+
if(!atom.label) {
|
26 |
+
atom.label = viewer.addLabel(atom.interaction,
|
27 |
+
{position: atom, backgroundColor: 'black', fontColor:'white'});
|
28 |
+
}
|
29 |
+
}"""
|
30 |
+
unhover_func = """
|
31 |
+
function(atom,viewer) {
|
32 |
+
if(atom.label) {
|
33 |
+
viewer.removeLabel(atom.label);
|
34 |
+
delete atom.label;
|
35 |
+
}
|
36 |
+
}"""
|
37 |
+
atom_mapping = {0:'H', 1:'C', 2:'N', 3:'O', 4:'F', 5:'P', 6:'S', 7:'CL', 8:'BR', 9:'I', 10: 'UNK'}
|
38 |
+
|
39 |
+
model = GNN_MD(11, 64)
|
40 |
+
state_dict = torch.load(
|
41 |
+
"best_weights_rep0.pt",
|
42 |
+
map_location=torch.device("cpu"),
|
43 |
+
)["model_state_dict"]
|
44 |
+
model.load_state_dict(state_dict)
|
45 |
+
model = model.to('cpu')
|
46 |
+
model.eval()
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
def get_pdb(pdb_code="", filepath=""):
|
51 |
+
try:
|
52 |
+
return filepath.name
|
53 |
+
except AttributeError as e:
|
54 |
+
if pdb_code is None or pdb_code == "":
|
55 |
+
return None
|
56 |
+
else:
|
57 |
+
os.system(f"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb")
|
58 |
+
return f"{pdb_code}.pdb"
|
59 |
+
|
60 |
+
|
61 |
+
def get_offset(pdb):
|
62 |
+
pdb_multiline = pdb.split("\n")
|
63 |
+
for line in pdb_multiline:
|
64 |
+
if line.startswith("ATOM"):
|
65 |
+
return int(line[22:27])
|
66 |
+
|
67 |
+
|
68 |
+
def predict(pdb_code, pdb_file):
|
69 |
+
#path_to_pdb = get_pdb(pdb_code=pdb_code, filepath=pdb_file)
|
70 |
+
|
71 |
+
#pdb = open(path_to_pdb, "r").read()
|
72 |
+
# switch to misato env if not running from container
|
73 |
+
mdh5_file = "inference_for_md.hdf5"
|
74 |
+
md_H5File = h5py.File(mdh5_file)
|
75 |
+
|
76 |
+
column_names = ["x", "y", "z", "element"]
|
77 |
+
atoms_protein = pd.DataFrame(columns = column_names)
|
78 |
+
cutoff = md_H5File["11GS"]["molecules_begin_atom_index"][:][-1] # cutoff defines protein atoms
|
79 |
+
|
80 |
+
atoms_protein["x"] = md_H5File["11GS"]["atoms_coordinates_ref"][:][:cutoff, 0]
|
81 |
+
atoms_protein["y"] = md_H5File["11GS"]["atoms_coordinates_ref"][:][:cutoff, 1]
|
82 |
+
atoms_protein["z"] = md_H5File["11GS"]["atoms_coordinates_ref"][:][:cutoff, 2]
|
83 |
+
|
84 |
+
atoms_protein["element"] = md_H5File["11GS"]["atoms_element"][:][:cutoff]
|
85 |
+
|
86 |
+
item = {}
|
87 |
+
item["scores"] = 0
|
88 |
+
item["id"] = "11GS"
|
89 |
+
item["atoms_protein"] = atoms_protein
|
90 |
+
|
91 |
+
transform = GNNTransformMD()
|
92 |
+
data_item = transform(item)
|
93 |
+
adaptability = model(data_item)
|
94 |
+
adaptability = adaptability.detach().numpy()
|
95 |
+
|
96 |
+
data = []
|
97 |
+
|
98 |
+
|
99 |
+
for i in range(adaptability.shape[0]):
|
100 |
+
data.append([i, atom_mapping(atoms_protein.iloc[i, atoms_protein.columns.get_loc("element")] - 1), atoms_protein.iloc[i, atoms_protein.columns.get_loc("x")],atoms_protein.iloc[i, atoms_protein.columns.get_loc("y")],atoms_protein.iloc[i, atoms_protein.columns.get_loc("z")],adaptability[i]])
|
101 |
+
|
102 |
+
topN = 100
|
103 |
+
topN_ind = np.argsort(adaptability)[::-1][:topN]
|
104 |
+
|
105 |
+
pdb = open(pdb_file.name, "r").read()
|
106 |
+
|
107 |
+
view = py3Dmol.view(width=600, height=400)
|
108 |
+
view.setBackgroundColor('black')
|
109 |
+
view.addModel(pdb, "pdb")
|
110 |
+
view.setStyle({'stick': {'colorscheme': {'prop': 'resi', 'C': 'turquoise'}}})
|
111 |
+
|
112 |
+
for i in range(topN):
|
113 |
+
view.addSphere({'center':{'x':atoms_protein.iloc[topN_ind[i], atoms_protein.columns.get_loc("x")], 'y':atoms_protein.iloc[topN_ind[i], atoms_protein.columns.get_loc("y")],'z':atoms_protein.iloc[topN_ind[i], atoms_protein.columns.get_loc("z")]},'radius':adaptability[topN_ind[i]]/1.5,'color':'orange','alpha':0.75})
|
114 |
+
|
115 |
+
view.zoomTo()
|
116 |
+
|
117 |
+
output = view._make_html().replace("'", '"')
|
118 |
+
|
119 |
+
x = f"""<!DOCTYPE html><html> {output} </html>""" # do not use ' in this input
|
120 |
+
return f"""<iframe style="width: 100%; height:420px" name="result" allow="midi; geolocation; microphone; camera;
|
121 |
+
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
|
122 |
+
allow-scripts allow-same-origin allow-popups
|
123 |
+
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
|
124 |
+
allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>""", pd.DataFrame(data, columns=['index','element','x','y','z','Adaptability'])
|
125 |
+
|
126 |
+
|
127 |
+
callback = gr.CSVLogger()
|
128 |
+
|
129 |
+
with gr.Blocks() as demo:
|
130 |
+
gr.Markdown("# Protein Adaptability Prediction")
|
131 |
+
|
132 |
+
#text_input = gr.Textbox()
|
133 |
+
#text_output = gr.Textbox()
|
134 |
+
#text_button = gr.Button("Flip")
|
135 |
+
inp = gr.Textbox(placeholder="PDB Code or upload file below", label="Input structure")
|
136 |
+
pdb_file = gr.File(label="PDB File Upload")
|
137 |
+
#with gr.Row():
|
138 |
+
# helix = gr.ColorPicker(label="helix")
|
139 |
+
# sheet = gr.ColorPicker(label="sheet")
|
140 |
+
# loop = gr.ColorPicker(label="loop")
|
141 |
+
single_btn = gr.Button(label="Run")
|
142 |
+
with gr.Row():
|
143 |
+
html = gr.HTML()
|
144 |
+
with gr.Row():
|
145 |
+
dataframe = gr.Dataframe()
|
146 |
+
|
147 |
+
single_btn.click(fn=predict, inputs=[inp, pdb_file], outputs=[html, dataframe])
|
148 |
+
|
149 |
+
|
150 |
+
demo.launch(debug=True)
|
best_weights_rep0.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b6b9499f2dc7b16eb3c669d2c6c26e11a53e21047264d3eb0cdda6bbc1d17f91
|
3 |
+
size 4517600
|
graph.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import scipy.spatial as ss
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch_geometric.utils import to_undirected
|
6 |
+
from torch_sparse import coalesce
|
7 |
+
|
8 |
+
atom_mapping = {0:'H', 1:'C', 2:'N', 3:'O', 4:'F', 5:'P', 6:'S', 7:'CL', 8:'BR', 9:'I', 10: 'UNK'}
|
9 |
+
residue_mapping = {0:'ALA', 1:'ARG', 2:'ASN', 3:'ASP', 4:'CYS', 5:'CYX', 6:'GLN', 7:'GLU', 8:'GLY', 9:'HIE', 10:'ILE', 11:'LEU', 12:'LYS', 13:'MET', 14:'PHE', 15:'PRO', 16:'SER', 17:'THR', 18:'TRP', 19:'TYR', 20:'VAL', 21:'UNK'}
|
10 |
+
|
11 |
+
ligand_atoms_mapping = {8: 0, 16: 1, 6: 2, 7: 3, 1: 4, 15: 5, 17: 6, 9: 7, 53: 8, 35: 9, 5: 10, 33: 11, 26: 12, 14: 13, 34: 14, 44: 15, 12: 16, 23: 17, 77: 18, 27: 19, 52: 20, 30: 21, 4: 22, 45: 23}
|
12 |
+
|
13 |
+
|
14 |
+
def prot_df_to_graph(item, df, edge_dist_cutoff, feat_col='element'):
|
15 |
+
r"""
|
16 |
+
Converts protein in dataframe representation to a graph compatible with Pytorch-Geometric, where each node is an atom.
|
17 |
+
|
18 |
+
:param df: Protein structure in dataframe format.
|
19 |
+
:type df: pandas.DataFrame
|
20 |
+
:param node_col: Column of dataframe to find node feature values. For example, for atoms use ``feat_col="element"`` and for residues use ``feat_col="resname"``
|
21 |
+
:type node_col: str, optional
|
22 |
+
:param allowable_feats: List containing all possible values of node type, to be converted into 1-hot node features.
|
23 |
+
Any elements in ``feat_col`` that are not found in ``allowable_feats`` will be added to an appended "unknown" bin (see :func:`atom3d.util.graph.one_of_k_encoding_unk`).
|
24 |
+
:type allowable_feats: list, optional
|
25 |
+
:param edge_dist_cutoff: Maximum distance cutoff (in Angstroms) to define an edge between two atoms, defaults to 4.5.
|
26 |
+
:type edge_dist_cutoff: float, optional
|
27 |
+
|
28 |
+
:return: tuple containing
|
29 |
+
|
30 |
+
- node_feats (torch.FloatTensor): Features for each node, one-hot encoded by values in ``allowable_feats``.
|
31 |
+
|
32 |
+
- edges (torch.LongTensor): Edges in COO format
|
33 |
+
|
34 |
+
- edge_weights (torch.LongTensor): Edge weights, defined as a function of distance between atoms given by :math:`w_{i,j} = \frac{1}{d(i,j)}`, where :math:`d(i, j)` is the Euclidean distance between node :math:`i` and node :math:`j`.
|
35 |
+
|
36 |
+
- node_pos (torch.FloatTensor): x-y-z coordinates of each node
|
37 |
+
:rtype: Tuple
|
38 |
+
"""
|
39 |
+
|
40 |
+
allowable_feats = atom_mapping
|
41 |
+
|
42 |
+
try :
|
43 |
+
node_pos = torch.FloatTensor(df[['x', 'y', 'z']].to_numpy())
|
44 |
+
kd_tree = ss.KDTree(node_pos)
|
45 |
+
edge_tuples = list(kd_tree.query_pairs(edge_dist_cutoff))
|
46 |
+
edges = torch.LongTensor(edge_tuples).t().contiguous()
|
47 |
+
edges = to_undirected(edges)
|
48 |
+
except:
|
49 |
+
print(f"Problem with PDB Id is {item['id']}")
|
50 |
+
|
51 |
+
node_feats = torch.FloatTensor([one_of_k_encoding_unk_indices(e-1, allowable_feats) for e in df[feat_col]])
|
52 |
+
edge_weights = torch.FloatTensor(
|
53 |
+
[1.0 / (np.linalg.norm(node_pos[i] - node_pos[j]) + 1e-5) for i, j in edges.t()]).view(-1)
|
54 |
+
|
55 |
+
|
56 |
+
return node_feats, edges, edge_weights, node_pos
|
57 |
+
|
58 |
+
|
59 |
+
def mol_df_to_graph_for_qm(df, bonds=None, allowable_atoms=None, edge_dist_cutoff=4.5, onehot_edges=True):
|
60 |
+
"""
|
61 |
+
Converts molecule in dataframe to a graph compatible with Pytorch-Geometric
|
62 |
+
:param df: Molecule structure in dataframe format
|
63 |
+
:type mol: pandas.DataFrame
|
64 |
+
:param bonds: Molecule structure in dataframe format
|
65 |
+
:type bonds: pandas.DataFrame
|
66 |
+
:param allowable_atoms: List containing allowable atom types
|
67 |
+
:type allowable_atoms: list[str], optional
|
68 |
+
:return: Tuple containing \n
|
69 |
+
- node_feats (torch.FloatTensor): Features for each node, one-hot encoded by atom type in ``allowable_atoms``.
|
70 |
+
- edge_index (torch.LongTensor): Edges from chemical bond graph in COO format.
|
71 |
+
- edge_feats (torch.FloatTensor): Edge features given by bond type. Single = 1.0, Double = 2.0, Triple = 3.0, Aromatic = 1.5.
|
72 |
+
- node_pos (torch.FloatTensor): x-y-z coordinates of each node.
|
73 |
+
"""
|
74 |
+
if allowable_atoms is None:
|
75 |
+
allowable_atoms = ligand_atoms_mapping
|
76 |
+
node_pos = torch.FloatTensor(df[['x', 'y', 'z']].to_numpy())
|
77 |
+
|
78 |
+
if bonds is not None:
|
79 |
+
N = df.shape[0]
|
80 |
+
bond_mapping = {1.0: 0, 2.0: 1, 3.0: 2, 1.5: 3}
|
81 |
+
bond_data = torch.FloatTensor(bonds)
|
82 |
+
edge_tuples = torch.cat((bond_data[:, :2], torch.flip(bond_data[:, :2], dims=(1,))), dim=0)
|
83 |
+
edge_index = edge_tuples.t().long().contiguous()
|
84 |
+
|
85 |
+
if onehot_edges:
|
86 |
+
bond_idx = list(map(lambda x: bond_mapping[x], bond_data[:,-1].tolist())) + list(map(lambda x: bond_mapping[x], bond_data[:,-1].tolist()))
|
87 |
+
edge_attr = F.one_hot(torch.tensor(bond_idx), num_classes=4).to(torch.float)
|
88 |
+
edge_index, edge_attr = coalesce(edge_index, edge_attr, N, N)
|
89 |
+
|
90 |
+
else:
|
91 |
+
edge_attr = torch.cat((torch.FloatTensor(bond_data[:,-1]).view(-1), torch.FloatTensor(bond_data[:,-1]).view(-1)), dim=0)
|
92 |
+
else:
|
93 |
+
kd_tree = ss.KDTree(node_pos)
|
94 |
+
edge_tuples = list(kd_tree.query_pairs(edge_dist_cutoff))
|
95 |
+
edge_index = torch.LongTensor(edge_tuples).t().contiguous()
|
96 |
+
edge_index = to_undirected(edge_index)
|
97 |
+
edge_attr = torch.FloatTensor([1.0 / (np.linalg.norm(node_pos[i] - node_pos[j]) + 1e-5) for i, j in edge_index.t()]).view(-1)
|
98 |
+
edge_attr = edge_attr.unsqueeze(1)
|
99 |
+
|
100 |
+
node_feats = torch.FloatTensor([one_of_k_encoding_unk_indices_qm(e, allowable_atoms) for e in df['element']])
|
101 |
+
|
102 |
+
return node_feats, edge_index, edge_attr, node_pos
|
103 |
+
|
104 |
+
|
105 |
+
def one_of_k_encoding_unk_indices(x, allowable_set):
|
106 |
+
"""Converts input to 1-hot encoding given a set of allowable values. Additionally maps inputs not in the allowable set to the last element."""
|
107 |
+
one_hot_encoding = [0] * len(allowable_set)
|
108 |
+
if x in allowable_set:
|
109 |
+
one_hot_encoding[x] = 1
|
110 |
+
else:
|
111 |
+
one_hot_encoding[-1] = 1
|
112 |
+
return one_hot_encoding
|
113 |
+
|
114 |
+
def one_of_k_encoding_unk_indices_qm(x, allowable_set):
|
115 |
+
"""Converts input to 1-hot encoding given a set of allowable values. Additionally maps inputs not in the allowable set to the last element."""
|
116 |
+
one_hot_encoding = [0] * (len(allowable_set)+1)
|
117 |
+
if x in allowable_set:
|
118 |
+
one_hot_encoding[allowable_set[x]] = 1
|
119 |
+
else:
|
120 |
+
one_hot_encoding[-1] = 1
|
121 |
+
return one_hot_encoding
|
inference_for_md.hdf5
ADDED
Binary file (61.9 kB). View file
|
|
transformMD.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transforms import prot_graph_transform
|
2 |
+
|
3 |
+
class GNNTransformMD(object):
|
4 |
+
"""
|
5 |
+
Transform the dict returned by the ProtDataset class to a pyTorch Geometric graph
|
6 |
+
"""
|
7 |
+
|
8 |
+
def __init__(self, edge_dist_cutoff=4.5):
|
9 |
+
"""
|
10 |
+
|
11 |
+
Args:
|
12 |
+
edge_dist_cutoff (float, optional): distence between the edges. Defaults to 4.5.
|
13 |
+
"""
|
14 |
+
self.edge_dist_cutoff = edge_dist_cutoff
|
15 |
+
|
16 |
+
def __call__(self, item):
|
17 |
+
item = prot_graph_transform(item, atom_keys=['atoms_protein'], label_key='scores', edge_dist_cutoff=self.edge_dist_cutoff)
|
18 |
+
return item['atoms_protein']
|
19 |
+
|
20 |
+
|
21 |
+
|
transforms.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch_geometric.data import Data
|
3 |
+
from graph import prot_df_to_graph, mol_df_to_graph_for_qm
|
4 |
+
|
5 |
+
def prot_graph_transform(item, atom_keys, label_key, edge_dist_cutoff):
|
6 |
+
"""Transform for converting dataframes to Pytorch Geometric graphs, to be applied when defining a :mod:`Dataset <atom3d.datasets.datasets>`.
|
7 |
+
Operates on Dataset items, assumes that the item contains all keys specified in ``keys`` and ``labels`` arguments.
|
8 |
+
|
9 |
+
:param item: Dataset item to transform
|
10 |
+
:type item: dict
|
11 |
+
:param atom_keys: list of keys to transform, where each key contains a dataframe of atoms, defaults to ['atoms']
|
12 |
+
:type atom_keys: list, optional
|
13 |
+
:param label_key: name of key containing labels, defaults to ['scores']
|
14 |
+
:type label_key: str, optional
|
15 |
+
:return: Transformed Dataset item
|
16 |
+
:rtype: dict
|
17 |
+
"""
|
18 |
+
|
19 |
+
for key in atom_keys:
|
20 |
+
node_feats, edge_index, edge_feats, pos = prot_df_to_graph(item, item[key], edge_dist_cutoff)
|
21 |
+
item[key] = Data(node_feats, edge_index, edge_feats, y=torch.FloatTensor(item[label_key]), pos=pos, ids=item["id"])
|
22 |
+
|
23 |
+
return item
|
24 |
+
|
25 |
+
def mol_graph_transform_for_qm(item, atom_key, label_key, allowable_atoms, use_bonds, onehot_edges, edge_dist_cutoff):
|
26 |
+
"""Transform for converting dataframes to Pytorch Geometric graphs, to be applied when defining a :mod:`Dataset <atom3d.datasets.datasets>`.
|
27 |
+
Operates on Dataset items, assumes that the item contains all keys specified in ``keys`` and ``labels`` arguments.
|
28 |
+
|
29 |
+
:param item: Dataset item to transform
|
30 |
+
:type item: dict
|
31 |
+
:param atom_key: name of key containing molecule structure as a dataframe, defaults to 'atoms'
|
32 |
+
:type atom_keys: list, optional
|
33 |
+
:param label_key: name of key containing labels, defaults to 'scores'
|
34 |
+
:type label_key: str, optional
|
35 |
+
:param use_bonds: whether to use molecular bond information for edges instead of distance. Assumes bonds are stored under 'bonds' key, defaults to False
|
36 |
+
:type use_bonds: bool, optional
|
37 |
+
:return: Transformed Dataset item
|
38 |
+
:rtype: dict
|
39 |
+
"""
|
40 |
+
|
41 |
+
bonds = item['bonds'] if use_bonds else None
|
42 |
+
|
43 |
+
node_feats, edge_index, edge_feats, pos = mol_df_to_graph_for_qm(item[atom_key], bonds=bonds, onehot_edges=onehot_edges, allowable_atoms=allowable_atoms, edge_dist_cutoff=edge_dist_cutoff)
|
44 |
+
item[atom_key] = Data(node_feats, edge_index, edge_feats, y=item[label_key], pos=pos)
|
45 |
+
|
46 |
+
return item
|