from transforms import prot_graph_transform


class GNNTransformMD:
    """Transform a dataset item dict into a graph using ``prot_graph_transform``.

    The original intent is to convert the dict returned by the ProtDataset
    class into a PyTorch Geometric graph. NOTE(review): the exact output type
    is decided by ``prot_graph_transform``, which is not visible here.
    """

    def __init__(self, edge_dist_cutoff=4.5, atom_key='atoms_protein',
                 label_key='scores'):
        """
        Args:
            edge_dist_cutoff (float, optional): distance cutoff forwarded to
                ``prot_graph_transform`` as ``edge_dist_cutoff`` when building
                graph edges (presumably in Angstroms - confirm against the
                helper). Defaults to 4.5.
            atom_key (str, optional): key of the item entry that is converted
                to a graph and returned. Defaults to ``'atoms_protein'``.
            label_key (str, optional): key of the item entry forwarded to
                ``prot_graph_transform`` as ``label_key``. Defaults to
                ``'scores'``.
        """
        self.edge_dist_cutoff = edge_dist_cutoff
        self.atom_key = atom_key
        self.label_key = label_key

    def __call__(self, item):
        """Convert ``item`` and return the graph stored under ``atom_key``.

        Args:
            item (dict): a single dataset item. It must contain ``atom_key``
                and whatever else ``prot_graph_transform`` requires.

        Returns:
            The value ``prot_graph_transform`` leaves at ``item[atom_key]``.
        """
        item = prot_graph_transform(item, atom_keys=[self.atom_key],
                                    label_key=self.label_key,
                                    edge_dist_cutoff=self.edge_dist_cutoff)
        return item[self.atom_key]