import torch
from torch_geometric.data import Data

from graph import prot_df_to_graph, mol_df_to_graph_for_qm


def prot_graph_transform(item, atom_keys, label_key, edge_dist_cutoff):
    """Transform for converting dataframes to Pytorch Geometric graphs, to be
    applied when defining a :mod:`Dataset`.

    Operates on Dataset items in place: every entry listed in ``atom_keys`` is
    replaced by a :class:`torch_geometric.data.Data` object built from the
    output of :func:`prot_df_to_graph`. The item must contain every key in
    ``atom_keys``, the ``label_key`` entry and an ``"id"`` entry.

    :param item: Dataset item to transform
    :type item: dict
    :param atom_keys: keys to transform, where each key contains a dataframe
        of atoms
    :type atom_keys: list
    :param label_key: name of key containing labels
    :type label_key: str
    :param edge_dist_cutoff: forwarded unchanged to :func:`prot_df_to_graph`
        (presumably the distance threshold used to create edges -- confirm
        against that function)
    :return: Transformed Dataset item (the same object that was passed in)
    :rtype: dict
    """
    for key in atom_keys:
        node_feats, edge_index, edge_feats, pos = prot_df_to_graph(
            item, item[key], edge_dist_cutoff
        )
        # torch.as_tensor is used instead of the legacy torch.FloatTensor(...)
        # constructor: the legacy form interprets a bare int as a *size*
        # (returning uninitialised memory) and rejects a bare float, whereas
        # as_tensor converts scalars and sequences alike to float32 values.
        labels = torch.as_tensor(item[label_key], dtype=torch.float32)
        item[key] = Data(
            node_feats,
            edge_index,
            edge_feats,
            y=labels,
            pos=pos,
            ids=item["id"],
        )
    return item


def mol_graph_transform_for_qm(item, atom_key, label_key, allowable_atoms,
                               use_bonds, onehot_edges, edge_dist_cutoff):
    """Transform for converting dataframes to Pytorch Geometric graphs, to be
    applied when defining a :mod:`Dataset`.

    Operates on Dataset items in place: the ``atom_key`` entry is replaced by
    a :class:`torch_geometric.data.Data` object built from the output of
    :func:`mol_df_to_graph_for_qm`. The label stored under ``label_key`` is
    attached as ``y`` without any conversion.

    :param item: Dataset item to transform
    :type item: dict
    :param atom_key: name of key containing molecule structure as a dataframe
    :type atom_key: str
    :param label_key: name of key containing labels
    :type label_key: str
    :param allowable_atoms: forwarded unchanged to
        :func:`mol_df_to_graph_for_qm` (presumably the atom types used for
        node featurization -- confirm against that function)
    :param use_bonds: whether to use molecular bond information for edges
        instead of distance. When true, bonds are read from the ``'bonds'``
        key of the item; otherwise ``None`` is passed as ``bonds``
    :type use_bonds: bool
    :param onehot_edges: forwarded unchanged to
        :func:`mol_df_to_graph_for_qm`
    :param edge_dist_cutoff: forwarded unchanged to
        :func:`mol_df_to_graph_for_qm`
    :return: Transformed Dataset item (the same object that was passed in)
    :rtype: dict
    """
    # Only touch item['bonds'] when requested, so items without bond
    # information remain usable with distance-based edges.
    bonds = item['bonds'] if use_bonds else None
    node_feats, edge_index, edge_feats, pos = mol_df_to_graph_for_qm(
        item[atom_key],
        bonds=bonds,
        onehot_edges=onehot_edges,
        allowable_atoms=allowable_atoms,
        edge_dist_cutoff=edge_dist_cutoff,
    )
    item[atom_key] = Data(
        node_feats, edge_index, edge_feats, y=item[label_key], pos=pos
    )
    return item