|
from transforms import prot_graph_transform |
|
|
|
class GNNTransformMD(object): |
|
""" |
|
Transform the dict returned by the ProtDataset class to a pyTorch Geometric graph |
|
""" |
|
|
|
def __init__(self, edge_dist_cutoff=4.5): |
|
""" |
|
|
|
Args: |
|
edge_dist_cutoff (float, optional): distence between the edges. Defaults to 4.5. |
|
""" |
|
self.edge_dist_cutoff = edge_dist_cutoff |
|
|
|
def __call__(self, item): |
|
item = prot_graph_transform(item, atom_keys=['atoms_protein'], label_key='scores', edge_dist_cutoff=self.edge_dist_cutoff) |
|
return item['atoms_protein'] |
|
|
|
|
|
|