|
import numpy as np |
|
import scipy.spatial as ss |
|
import torch |
|
import torch.nn.functional as F |
|
from torch_geometric.utils import to_undirected |
|
from torch_sparse import coalesce |
|
|
|
# Slot index -> element symbol. prot_df_to_graph uses this dict's keys (and its
# length) as the one-hot vocabulary; the final slot is the 'UNK' bin.
atom_mapping = dict(enumerate(
    ['H', 'C', 'N', 'O', 'F', 'P', 'S', 'CL', 'BR', 'I', 'UNK']
))

# Slot index -> residue name; the final slot is the 'UNK' bin.
residue_mapping = dict(enumerate([
    'ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'CYX', 'GLN', 'GLU', 'GLY', 'HIE', 'ILE',
    'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL', 'UNK',
]))

# Ligand element code -> one-hot slot index (default ``allowable_atoms`` of
# mol_df_to_graph_for_qm). The keys look like atomic numbers (8 = O, 16 = S,
# 6 = C, ...) -- presumably matching the ``element`` column; confirm upstream.
ligand_atoms_mapping = {
    element_code: slot
    for slot, element_code in enumerate((
        8, 16, 6, 7, 1, 15, 17, 9, 53, 35, 5, 33,
        26, 14, 34, 44, 12, 23, 77, 27, 52, 30, 4, 45,
    ))
}
|
|
|
|
|
def prot_df_to_graph(item, df, edge_dist_cutoff, feat_col='element'):
    r"""
    Converts protein in dataframe representation to a graph compatible with Pytorch-Geometric, where each node is an atom.

    :param item: Record describing the structure. Only ``item['id']`` is read, to name the structure in the error message.
    :type item: dict-like
    :param df: Protein structure in dataframe format. Must have ``x``, ``y``, ``z`` columns and the ``feat_col`` column.
    :type df: pandas.DataFrame
    :param edge_dist_cutoff: Maximum distance cutoff (in Angstroms) to define an edge between two atoms.
    :type edge_dist_cutoff: float
    :param feat_col: Column of dataframe holding node feature values, defaults to ``"element"``.
        Each value ``e`` is one-hot encoded at slot ``e - 1`` over the keys of ``atom_mapping``.
        Values whose ``e - 1`` is not a key of ``atom_mapping`` go to the last ('UNK') slot.
        Presumably the column holds 1-based integer codes -- confirm against the data pipeline.
    :type feat_col: str, optional

    :return: tuple containing

        - node_feats (torch.FloatTensor): One-hot node features, ``len(atom_mapping)`` columns per node.
        - edges (torch.LongTensor): Undirected edges in COO format, shape ``[2, num_edges]``.
          The shape is ``[2, 0]`` when no pair of atoms lies within the cutoff.
        - edge_weights (torch.FloatTensor): Edge weights, defined as a function of distance between atoms given by
          :math:`w_{i,j} = \frac{1}{d(i,j) + 10^{-5}}`, where :math:`d(i, j)` is the Euclidean distance between
          node :math:`i` and node :math:`j`.
        - node_pos (torch.FloatTensor): x-y-z coordinates of each node.

    :rtype: Tuple
    :raises Exception: Any error raised while building positions/edges is re-raised after printing the PDB id.
    """
    allowable_feats = atom_mapping

    try:
        node_pos = torch.FloatTensor(df[['x', 'y', 'z']].to_numpy())
        kd_tree = ss.KDTree(node_pos)
        pairs = list(kd_tree.query_pairs(edge_dist_cutoff))
        if pairs:
            edges = to_undirected(torch.LongTensor(pairs).t().contiguous())
        else:
            # No atom pair within the cutoff (e.g. a single atom).
            # torch.LongTensor([]).t() would be 1-D and break to_undirected,
            # so build an explicit, well-shaped empty edge index instead.
            edges = torch.empty((2, 0), dtype=torch.long)
    except Exception:
        # Identify the failing structure, then propagate the real error.
        # Previously execution continued with unbound variables and failed
        # later with a misleading UnboundLocalError.
        print(f"Problem with PDB Id is {item['id']}")
        raise

    node_feats = torch.FloatTensor(
        [one_of_k_encoding_unk_indices(e - 1, allowable_feats) for e in df[feat_col]]
    )

    # Inverse-distance weights for every directed edge, vectorised over all
    # edges; the 1e-5 term guards against division by zero.
    src, dst = edges
    edge_weights = 1.0 / (torch.linalg.norm(node_pos[src] - node_pos[dst], dim=1) + 1e-5)

    return node_feats, edges, edge_weights, node_pos
|
|
|
|
|
def mol_df_to_graph_for_qm(df, bonds=None, allowable_atoms=None, edge_dist_cutoff=4.5, onehot_edges=True):
    """
    Converts molecule in dataframe to a graph compatible with Pytorch-Geometric.

    :param df: Molecule structure in dataframe format. Must have ``x``, ``y``, ``z`` and ``element`` columns.
    :type df: pandas.DataFrame
    :param bonds: Bond table with one row per bond. The first two columns are the node indices of the
        bonded atoms and the last column is the bond type (1.0, 2.0, 3.0 or 1.5).
        If ``None``, edges are built from a distance cutoff instead.
    :type bonds: pandas.DataFrame, numpy.ndarray or nested list, optional
    :param allowable_atoms: Mapping from ``element`` value to one-hot slot index, defaults to ``ligand_atoms_mapping``.
        Elements not in the mapping go to an extra trailing "unknown" slot.
    :type allowable_atoms: dict, optional
    :param edge_dist_cutoff: Distance cutoff (Angstroms) between connected atoms.
        Only used when ``bonds`` is None. Defaults to 4.5.
    :type edge_dist_cutoff: float, optional
    :param onehot_edges: When ``bonds`` is given, encode bond types as 4-way one-hot vectors and coalesce
        the edge list. If False, keep raw bond-order scalars. Defaults to True.
    :type onehot_edges: bool, optional
    :return: Tuple containing

        - node_feats (torch.FloatTensor): One-hot node features with ``len(allowable_atoms) + 1`` columns.
        - edge_index (torch.LongTensor): Edges in COO format, shape ``[2, num_edges]``.
          With ``bonds``, every bond appears in both directions; without ``bonds``, the edges are undirected
          cutoff edges.
        - edge_attr (torch.FloatTensor): Depends on the mode:

          - ``onehot_edges``: one-hot ``[num_edges, 4]`` in the order single, double, triple, aromatic.
          - Otherwise with ``bonds``: raw bond orders ``[num_edges]``.
          - Without ``bonds``: ``[num_edges, 1]`` inverse distances ``1 / (d + 1e-5)``.
        - node_pos (torch.FloatTensor): x-y-z coordinates of each node.

    :raises KeyError: If ``onehot_edges`` is set and a bond type is not one of 1.0, 2.0, 3.0, 1.5.
    """
    if allowable_atoms is None:
        allowable_atoms = ligand_atoms_mapping
    node_pos = torch.FloatTensor(df[['x', 'y', 'z']].to_numpy())

    if bonds is not None:
        num_nodes = df.shape[0]
        bond_mapping = {1.0: 0, 2.0: 1, 3.0: 2, 1.5: 3}
        # np.asarray accepts a DataFrame as well as an ndarray or nested list;
        # torch.FloatTensor cannot take a DataFrame directly.
        bond_data = torch.as_tensor(np.asarray(bonds, dtype=np.float32))
        endpoints = bond_data[:, :2]
        bond_types = bond_data[:, -1]
        # Each bond contributes (i, j) followed, after all bonds, by (j, i).
        edge_index = torch.cat((endpoints, torch.flip(endpoints, dims=(1,))), dim=0).t().long().contiguous()

        if onehot_edges:
            # Same bond-type list for both directions, in the same order as edge_index.
            bond_idx = [bond_mapping[b] for b in bond_types.tolist()] * 2
            edge_attr = F.one_hot(torch.tensor(bond_idx), num_classes=4).to(torch.float)
            edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes, num_nodes)
        else:
            edge_attr = bond_types.repeat(2)
    else:
        kd_tree = ss.KDTree(node_pos)
        pairs = list(kd_tree.query_pairs(edge_dist_cutoff))
        if pairs:
            edge_index = to_undirected(torch.LongTensor(pairs).t().contiguous())
        else:
            # No pair within the cutoff: torch.LongTensor([]).t() would be 1-D
            # and break to_undirected, so use an explicit [2, 0] index.
            edge_index = torch.empty((2, 0), dtype=torch.long)
        # Inverse-distance weights, vectorised; 1e-5 avoids division by zero.
        src, dst = edge_index
        edge_attr = 1.0 / (torch.linalg.norm(node_pos[src] - node_pos[dst], dim=1) + 1e-5)
        edge_attr = edge_attr.unsqueeze(1)

    node_feats = torch.FloatTensor([one_of_k_encoding_unk_indices_qm(e, allowable_atoms) for e in df['element']])

    return node_feats, edge_index, edge_attr, node_pos
|
|
|
|
|
def one_of_k_encoding_unk_indices(x, allowable_set):
    """Return a one-hot list of length ``len(allowable_set)``.

    When ``x`` is a member of ``allowable_set``, ``x`` itself is used as the
    slot index, so members must be valid list indices. Any other value lights
    up the final slot, which acts as the "unknown" bin.
    """
    hot_slot = x if x in allowable_set else -1
    encoding = [0 for _ in range(len(allowable_set))]
    encoding[hot_slot] = 1
    return encoding
|
|
|
def one_of_k_encoding_unk_indices_qm(x, allowable_set):
    """Return a one-hot list with one slot per entry of ``allowable_set`` plus a trailing "unknown" slot.

    ``allowable_set`` maps each known value to its slot index. A known ``x``
    lights up slot ``allowable_set[x]``; anything else lights up the final slot.
    """
    width = len(allowable_set) + 1
    if x in allowable_set:
        hot_slot = allowable_set[x]
    else:
        hot_slot = width - 1
    encoding = [0 for _ in range(width)]
    encoding[hot_slot] = 1
    return encoding