modeling script
Browse files- .gitignore +15 -0
- modeling.py +66 -0
- ultra/__init__.py +0 -0
- ultra/base_nbfnet.py +336 -0
- ultra/datasets.py +1095 -0
- ultra/eval.py +153 -0
- ultra/layers.py +234 -0
- ultra/models.py +214 -0
- ultra/rspmm/rspmm.py +204 -0
- ultra/rspmm/source/operator.cuh +82 -0
- ultra/rspmm/source/rspmm.cpp +283 -0
- ultra/rspmm/source/rspmm.cu +386 -0
- ultra/rspmm/source/rspmm.h +108 -0
- ultra/rspmm/source/util.cuh +28 -0
- ultra/tasks.py +201 -0
- ultra/util.py +172 -0
.gitignore
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
output/
|
10 |
+
.vscode/
|
11 |
+
.DS_Store
|
12 |
+
datasets/
|
13 |
+
ckpts/
|
14 |
+
*.csv
|
15 |
+
*.txt
|
modeling.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
from transformers import PretrainedConfig, PreTrainedModel
|
4 |
+
#sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
5 |
+
from ultra.models import Ultra
|
6 |
+
from ultra.datasets import WN18RR, CoDExSmall, FB15k237, FB15k237Inductive
|
7 |
+
from ultra.eval import test
|
8 |
+
|
9 |
+
|
10 |
+
class UltraConfig(PretrainedConfig):
|
11 |
+
|
12 |
+
model_type = "ultra"
|
13 |
+
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
relation_model_layers: int = 6,
|
17 |
+
relation_model_dim: int = 64,
|
18 |
+
entity_model_layers: int = 6,
|
19 |
+
entity_model_dim: int = 64,
|
20 |
+
**kwargs):
|
21 |
+
|
22 |
+
self.relation_model_cfg = dict(
|
23 |
+
input_dim=relation_model_dim,
|
24 |
+
hidden_dims=[relation_model_dim]*relation_model_layers,
|
25 |
+
message_func="distmult",
|
26 |
+
aggregate_func="sum",
|
27 |
+
short_cut=True,
|
28 |
+
layer_norm=True
|
29 |
+
)
|
30 |
+
|
31 |
+
self.entity_model_cfg = dict(
|
32 |
+
input_dim=entity_model_dim,
|
33 |
+
hidden_dims=[entity_model_dim]*entity_model_layers,
|
34 |
+
message_func="distmult",
|
35 |
+
aggregate_func="sum",
|
36 |
+
short_cut=True,
|
37 |
+
layer_norm=True
|
38 |
+
)
|
39 |
+
|
40 |
+
super().__init__(**kwargs)
|
41 |
+
|
42 |
+
class UltraLinkPrediction(PreTrainedModel):
|
43 |
+
|
44 |
+
config_class = UltraConfig
|
45 |
+
|
46 |
+
def __init__(self, config):
|
47 |
+
super().__init__(config)
|
48 |
+
|
49 |
+
self.model = Ultra(
|
50 |
+
rel_model_cfg=config.relation_model_cfg,
|
51 |
+
entity_model_cfg=config.entity_model_cfg,
|
52 |
+
)
|
53 |
+
|
54 |
+
def forward(self, data, batch):
|
55 |
+
# data: PyG data object
|
56 |
+
# batch shape: (bs, 1+num_negs, 3)
|
57 |
+
return self.model.forward(data, batch)
|
58 |
+
|
59 |
+
|
60 |
+
if __name__ == "__main__":
|
61 |
+
|
62 |
+
model = UltraLinkPrediction.from_pretrained("mgalkin/ultra_50g")
|
63 |
+
dataset = CoDExSmall(root="./datasets/")
|
64 |
+
test(model, mode="test", dataset=dataset, gpus=None)
|
65 |
+
# mrr: 0.497697
|
66 |
+
# hits@10: 0.685175
|
ultra/__init__.py
ADDED
File without changes
|
ultra/base_nbfnet.py
ADDED
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from collections.abc import Sequence
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import nn, autograd
|
6 |
+
|
7 |
+
from torch_scatter import scatter_add
|
8 |
+
from . import tasks, layers
|
9 |
+
|
10 |
+
|
11 |
+
class BaseNBFNet(nn.Module):
|
12 |
+
|
13 |
+
def __init__(self, input_dim, hidden_dims, num_relation, message_func="distmult", aggregate_func="sum",
|
14 |
+
short_cut=False, layer_norm=False, activation="relu", concat_hidden=False, num_mlp_layer=2,
|
15 |
+
dependent=False, remove_one_hop=False, num_beam=10, path_topk=10, **kwargs):
|
16 |
+
super(BaseNBFNet, self).__init__()
|
17 |
+
|
18 |
+
if not isinstance(hidden_dims, Sequence):
|
19 |
+
hidden_dims = [hidden_dims]
|
20 |
+
|
21 |
+
self.dims = [input_dim] + list(hidden_dims)
|
22 |
+
self.num_relation = num_relation
|
23 |
+
self.short_cut = short_cut # whether to use residual connections between GNN layers
|
24 |
+
self.concat_hidden = concat_hidden # whether to compute final states as a function of all layer outputs or last
|
25 |
+
self.remove_one_hop = remove_one_hop # whether to dynamically remove one-hop edges from edge_index
|
26 |
+
self.num_beam = num_beam
|
27 |
+
self.path_topk = path_topk
|
28 |
+
|
29 |
+
self.message_func = message_func
|
30 |
+
self.aggregate_func = aggregate_func
|
31 |
+
self.layer_norm = layer_norm
|
32 |
+
self.activation = activation
|
33 |
+
self.num_mlp_layers = num_mlp_layer
|
34 |
+
|
35 |
+
# self.layers = nn.ModuleList()
|
36 |
+
# for i in range(len(self.dims) - 1):
|
37 |
+
# self.layers.append(layers.GeneralizedRelationalConv(self.dims[i], self.dims[i + 1], num_relation,
|
38 |
+
# self.dims[0], message_func, aggregate_func, layer_norm,
|
39 |
+
# activation, dependent))
|
40 |
+
|
41 |
+
# feature_dim = (sum(hidden_dims) if concat_hidden else hidden_dims[-1]) + input_dim
|
42 |
+
|
43 |
+
# # additional relation embedding which serves as an initial 'query' for the NBFNet forward pass
|
44 |
+
# # each layer has its own learnable relations matrix, so we send the total number of relations, too
|
45 |
+
# self.query = nn.Embedding(num_relation, input_dim)
|
46 |
+
# self.mlp = nn.Sequential()
|
47 |
+
# mlp = []
|
48 |
+
# for i in range(num_mlp_layer - 1):
|
49 |
+
# mlp.append(nn.Linear(feature_dim, feature_dim))
|
50 |
+
# mlp.append(nn.ReLU())
|
51 |
+
# mlp.append(nn.Linear(feature_dim, 1))
|
52 |
+
# self.mlp = nn.Sequential(*mlp)
|
53 |
+
|
54 |
+
def remove_easy_edges(self, data, h_index, t_index, r_index=None):
|
55 |
+
# we remove training edges (we need to predict them at training time) from the edge index
|
56 |
+
# think of it as a dynamic edge dropout
|
57 |
+
h_index_ext = torch.cat([h_index, t_index], dim=-1)
|
58 |
+
t_index_ext = torch.cat([t_index, h_index], dim=-1)
|
59 |
+
r_index_ext = torch.cat([r_index, r_index + data.num_relations // 2], dim=-1)
|
60 |
+
if self.remove_one_hop:
|
61 |
+
# we remove all existing immediate edges between heads and tails in the batch
|
62 |
+
edge_index = data.edge_index
|
63 |
+
easy_edge = torch.stack([h_index_ext, t_index_ext]).flatten(1)
|
64 |
+
index = tasks.edge_match(edge_index, easy_edge)[0]
|
65 |
+
mask = ~index_to_mask(index, data.num_edges)
|
66 |
+
else:
|
67 |
+
# we remove existing immediate edges between heads and tails in the batch with the given relation
|
68 |
+
edge_index = torch.cat([data.edge_index, data.edge_type.unsqueeze(0)])
|
69 |
+
# note that here we add relation types r_index_ext to the matching query
|
70 |
+
easy_edge = torch.stack([h_index_ext, t_index_ext, r_index_ext]).flatten(1)
|
71 |
+
index = tasks.edge_match(edge_index, easy_edge)[0]
|
72 |
+
mask = ~index_to_mask(index, data.num_edges)
|
73 |
+
|
74 |
+
data = copy.copy(data)
|
75 |
+
data.edge_index = data.edge_index[:, mask]
|
76 |
+
data.edge_type = data.edge_type[mask]
|
77 |
+
return data
|
78 |
+
|
79 |
+
def negative_sample_to_tail(self, h_index, t_index, r_index, num_direct_rel):
|
80 |
+
# convert p(h | t, r) to p(t' | h', r')
|
81 |
+
# h' = t, r' = r^{-1}, t' = h
|
82 |
+
is_t_neg = (h_index == h_index[:, [0]]).all(dim=-1, keepdim=True)
|
83 |
+
new_h_index = torch.where(is_t_neg, h_index, t_index)
|
84 |
+
new_t_index = torch.where(is_t_neg, t_index, h_index)
|
85 |
+
new_r_index = torch.where(is_t_neg, r_index, r_index + num_direct_rel)
|
86 |
+
return new_h_index, new_t_index, new_r_index
|
87 |
+
|
88 |
+
def bellmanford(self, data, h_index, r_index, separate_grad=False):
|
89 |
+
batch_size = len(r_index)
|
90 |
+
|
91 |
+
# initialize queries (relation types of the given triples)
|
92 |
+
query = self.query(r_index)
|
93 |
+
index = h_index.unsqueeze(-1).expand_as(query)
|
94 |
+
|
95 |
+
# initial (boundary) condition - initialize all node states as zeros
|
96 |
+
boundary = torch.zeros(batch_size, data.num_nodes, self.dims[0], device=h_index.device)
|
97 |
+
# by the scatter operation we put query (relation) embeddings as init features of source (index) nodes
|
98 |
+
boundary.scatter_add_(1, index.unsqueeze(1), query.unsqueeze(1))
|
99 |
+
size = (data.num_nodes, data.num_nodes)
|
100 |
+
edge_weight = torch.ones(data.num_edges, device=h_index.device)
|
101 |
+
|
102 |
+
hiddens = []
|
103 |
+
edge_weights = []
|
104 |
+
layer_input = boundary
|
105 |
+
|
106 |
+
for layer in self.layers:
|
107 |
+
if separate_grad:
|
108 |
+
edge_weight = edge_weight.clone().requires_grad_()
|
109 |
+
# Bellman-Ford iteration, we send the original boundary condition in addition to the updated node states
|
110 |
+
hidden = layer(layer_input, query, boundary, data.edge_index, data.edge_type, size, edge_weight)
|
111 |
+
if self.short_cut and hidden.shape == layer_input.shape:
|
112 |
+
# residual connection here
|
113 |
+
hidden = hidden + layer_input
|
114 |
+
hiddens.append(hidden)
|
115 |
+
edge_weights.append(edge_weight)
|
116 |
+
layer_input = hidden
|
117 |
+
|
118 |
+
# original query (relation type) embeddings
|
119 |
+
node_query = query.unsqueeze(1).expand(-1, data.num_nodes, -1) # (batch_size, num_nodes, input_dim)
|
120 |
+
if self.concat_hidden:
|
121 |
+
output = torch.cat(hiddens + [node_query], dim=-1)
|
122 |
+
else:
|
123 |
+
output = torch.cat([hiddens[-1], node_query], dim=-1)
|
124 |
+
|
125 |
+
return {
|
126 |
+
"node_feature": output,
|
127 |
+
"edge_weights": edge_weights,
|
128 |
+
}
|
129 |
+
|
130 |
+
def forward(self, data, batch):
|
131 |
+
h_index, t_index, r_index = batch.unbind(-1)
|
132 |
+
if self.training:
|
133 |
+
# Edge dropout in the training mode
|
134 |
+
# here we want to remove immediate edges (head, relation, tail) from the edge_index and edge_types
|
135 |
+
# to make NBFNet iteration learn non-trivial paths
|
136 |
+
data = self.remove_easy_edges(data, h_index, t_index, r_index, data.num_relations // 2)
|
137 |
+
|
138 |
+
shape = h_index.shape
|
139 |
+
# turn all triples in a batch into a tail prediction mode
|
140 |
+
h_index, t_index, r_index = self.negative_sample_to_tail(h_index, t_index, r_index, num_direct_rel=data.num_relations // 2)
|
141 |
+
assert (h_index[:, [0]] == h_index).all()
|
142 |
+
assert (r_index[:, [0]] == r_index).all()
|
143 |
+
|
144 |
+
# message passing and updated node representations
|
145 |
+
output = self.bellmanford(data, h_index[:, 0], r_index[:, 0]) # (num_nodes, batch_size, feature_dim)
|
146 |
+
feature = output["node_feature"]
|
147 |
+
index = t_index.unsqueeze(-1).expand(-1, -1, feature.shape[-1])
|
148 |
+
# extract representations of tail entities from the updated node states
|
149 |
+
feature = feature.gather(1, index) # (batch_size, num_negative + 1, feature_dim)
|
150 |
+
|
151 |
+
# probability logit for each tail node in the batch
|
152 |
+
# (batch_size, num_negative + 1, dim) -> (batch_size, num_negative + 1)
|
153 |
+
score = self.mlp(feature).squeeze(-1)
|
154 |
+
return score.view(shape)
|
155 |
+
|
156 |
+
def visualize(self, data, batch):
|
157 |
+
assert batch.shape == (1, 3)
|
158 |
+
h_index, t_index, r_index = batch.unbind(-1)
|
159 |
+
|
160 |
+
output = self.bellmanford(data, h_index, r_index, separate_grad=True)
|
161 |
+
feature = output["node_feature"]
|
162 |
+
edge_weights = output["edge_weights"]
|
163 |
+
|
164 |
+
index = t_index.unsqueeze(0).unsqueeze(-1).expand(-1, -1, feature.shape[-1])
|
165 |
+
feature = feature.gather(1, index).squeeze(0)
|
166 |
+
score = self.mlp(feature).squeeze(-1)
|
167 |
+
|
168 |
+
edge_grads = autograd.grad(score, edge_weights)
|
169 |
+
distances, back_edges = self.beam_search_distance(data, edge_grads, h_index, t_index, self.num_beam)
|
170 |
+
paths, weights = self.topk_average_length(distances, back_edges, t_index, self.path_topk)
|
171 |
+
|
172 |
+
return paths, weights
|
173 |
+
|
174 |
+
@torch.no_grad()
|
175 |
+
def beam_search_distance(self, data, edge_grads, h_index, t_index, num_beam=10):
|
176 |
+
# beam search the top-k distance from h to t (and to every other node)
|
177 |
+
num_nodes = data.num_nodes
|
178 |
+
input = torch.full((num_nodes, num_beam), float("-inf"), device=h_index.device)
|
179 |
+
input[h_index, 0] = 0
|
180 |
+
edge_mask = data.edge_index[0, :] != t_index
|
181 |
+
|
182 |
+
distances = []
|
183 |
+
back_edges = []
|
184 |
+
for edge_grad in edge_grads:
|
185 |
+
# we don't allow any path goes out of t once it arrives at t
|
186 |
+
node_in, node_out = data.edge_index[:, edge_mask]
|
187 |
+
relation = data.edge_type[edge_mask]
|
188 |
+
edge_grad = edge_grad[edge_mask]
|
189 |
+
|
190 |
+
message = input[node_in] + edge_grad.unsqueeze(-1) # (num_edges, num_beam)
|
191 |
+
# (num_edges, num_beam, 3)
|
192 |
+
msg_source = torch.stack([node_in, node_out, relation], dim=-1).unsqueeze(1).expand(-1, num_beam, -1)
|
193 |
+
|
194 |
+
# (num_edges, num_beam)
|
195 |
+
is_duplicate = torch.isclose(message.unsqueeze(-1), message.unsqueeze(-2)) & \
|
196 |
+
(msg_source.unsqueeze(-2) == msg_source.unsqueeze(-3)).all(dim=-1)
|
197 |
+
# pick the first occurrence as the ranking in the previous node's beam
|
198 |
+
# this makes deduplication easier later
|
199 |
+
# and store it in msg_source
|
200 |
+
is_duplicate = is_duplicate.float() - \
|
201 |
+
torch.arange(num_beam, dtype=torch.float, device=message.device) / (num_beam + 1)
|
202 |
+
prev_rank = is_duplicate.argmax(dim=-1, keepdim=True)
|
203 |
+
msg_source = torch.cat([msg_source, prev_rank], dim=-1) # (num_edges, num_beam, 4)
|
204 |
+
|
205 |
+
node_out, order = node_out.sort()
|
206 |
+
node_out_set = torch.unique(node_out)
|
207 |
+
# sort messages w.r.t. node_out
|
208 |
+
message = message[order].flatten() # (num_edges * num_beam)
|
209 |
+
msg_source = msg_source[order].flatten(0, -2) # (num_edges * num_beam, 4)
|
210 |
+
size = node_out.bincount(minlength=num_nodes)
|
211 |
+
msg2out = size_to_index(size[node_out_set] * num_beam)
|
212 |
+
# deduplicate messages that are from the same source and the same beam
|
213 |
+
is_duplicate = (msg_source[1:] == msg_source[:-1]).all(dim=-1)
|
214 |
+
is_duplicate = torch.cat([torch.zeros(1, dtype=torch.bool, device=message.device), is_duplicate])
|
215 |
+
message = message[~is_duplicate]
|
216 |
+
msg_source = msg_source[~is_duplicate]
|
217 |
+
msg2out = msg2out[~is_duplicate]
|
218 |
+
size = msg2out.bincount(minlength=len(node_out_set))
|
219 |
+
|
220 |
+
if not torch.isinf(message).all():
|
221 |
+
# take the topk messages from the neighborhood
|
222 |
+
# distance: (len(node_out_set) * num_beam)
|
223 |
+
distance, rel_index = scatter_topk(message, size, k=num_beam)
|
224 |
+
abs_index = rel_index + (size.cumsum(0) - size).unsqueeze(-1)
|
225 |
+
# store msg_source for backtracking
|
226 |
+
back_edge = msg_source[abs_index] # (len(node_out_set) * num_beam, 4)
|
227 |
+
distance = distance.view(len(node_out_set), num_beam)
|
228 |
+
back_edge = back_edge.view(len(node_out_set), num_beam, 4)
|
229 |
+
# scatter distance / back_edge back to all nodes
|
230 |
+
distance = scatter_add(distance, node_out_set, dim=0, dim_size=num_nodes) # (num_nodes, num_beam)
|
231 |
+
back_edge = scatter_add(back_edge, node_out_set, dim=0, dim_size=num_nodes) # (num_nodes, num_beam, 4)
|
232 |
+
else:
|
233 |
+
distance = torch.full((num_nodes, num_beam), float("-inf"), device=message.device)
|
234 |
+
back_edge = torch.zeros(num_nodes, num_beam, 4, dtype=torch.long, device=message.device)
|
235 |
+
|
236 |
+
distances.append(distance)
|
237 |
+
back_edges.append(back_edge)
|
238 |
+
input = distance
|
239 |
+
|
240 |
+
return distances, back_edges
|
241 |
+
|
242 |
+
def topk_average_length(self, distances, back_edges, t_index, k=10):
|
243 |
+
# backtrack distances and back_edges to generate the paths
|
244 |
+
paths = []
|
245 |
+
average_lengths = []
|
246 |
+
|
247 |
+
for i in range(len(distances)):
|
248 |
+
distance, order = distances[i][t_index].flatten(0, -1).sort(descending=True)
|
249 |
+
back_edge = back_edges[i][t_index].flatten(0, -2)[order]
|
250 |
+
for d, (h, t, r, prev_rank) in zip(distance[:k].tolist(), back_edge[:k].tolist()):
|
251 |
+
if d == float("-inf"):
|
252 |
+
break
|
253 |
+
path = [(h, t, r)]
|
254 |
+
for j in range(i - 1, -1, -1):
|
255 |
+
h, t, r, prev_rank = back_edges[j][h, prev_rank].tolist()
|
256 |
+
path.append((h, t, r))
|
257 |
+
paths.append(path[::-1])
|
258 |
+
average_lengths.append(d / len(path))
|
259 |
+
|
260 |
+
if paths:
|
261 |
+
average_lengths, paths = zip(*sorted(zip(average_lengths, paths), reverse=True)[:k])
|
262 |
+
|
263 |
+
return paths, average_lengths
|
264 |
+
|
265 |
+
|
266 |
+
def index_to_mask(index, size):
|
267 |
+
index = index.view(-1)
|
268 |
+
size = int(index.max()) + 1 if size is None else size
|
269 |
+
mask = index.new_zeros(size, dtype=torch.bool)
|
270 |
+
mask[index] = True
|
271 |
+
return mask
|
272 |
+
|
273 |
+
|
274 |
+
def size_to_index(size):
|
275 |
+
range = torch.arange(len(size), device=size.device)
|
276 |
+
index2sample = range.repeat_interleave(size)
|
277 |
+
return index2sample
|
278 |
+
|
279 |
+
|
280 |
+
def multi_slice_mask(starts, ends, length):
|
281 |
+
values = torch.cat([torch.ones_like(starts), -torch.ones_like(ends)])
|
282 |
+
slices = torch.cat([starts, ends])
|
283 |
+
mask = scatter_add(values, slices, dim=0, dim_size=length + 1)[:-1]
|
284 |
+
mask = mask.cumsum(0).bool()
|
285 |
+
return mask
|
286 |
+
|
287 |
+
|
288 |
+
def scatter_extend(data, size, input, input_size):
|
289 |
+
new_size = size + input_size
|
290 |
+
new_cum_size = new_size.cumsum(0)
|
291 |
+
new_data = torch.zeros(new_cum_size[-1], *data.shape[1:], dtype=data.dtype, device=data.device)
|
292 |
+
starts = new_cum_size - new_size
|
293 |
+
ends = starts + size
|
294 |
+
index = multi_slice_mask(starts, ends, new_cum_size[-1])
|
295 |
+
new_data[index] = data
|
296 |
+
new_data[~index] = input
|
297 |
+
return new_data, new_size
|
298 |
+
|
299 |
+
|
300 |
+
def scatter_topk(input, size, k, largest=True):
|
301 |
+
index2graph = size_to_index(size)
|
302 |
+
index2graph = index2graph.view([-1] + [1] * (input.ndim - 1))
|
303 |
+
|
304 |
+
mask = ~torch.isinf(input)
|
305 |
+
max = input[mask].max().item()
|
306 |
+
min = input[mask].min().item()
|
307 |
+
safe_input = input.clamp(2 * min - max, 2 * max - min)
|
308 |
+
offset = (max - min) * 4
|
309 |
+
if largest:
|
310 |
+
offset = -offset
|
311 |
+
input_ext = safe_input + offset * index2graph
|
312 |
+
index_ext = input_ext.argsort(dim=0, descending=largest)
|
313 |
+
num_actual = size.clamp(max=k)
|
314 |
+
num_padding = k - num_actual
|
315 |
+
starts = size.cumsum(0) - size
|
316 |
+
ends = starts + num_actual
|
317 |
+
mask = multi_slice_mask(starts, ends, len(index_ext)).nonzero().flatten()
|
318 |
+
|
319 |
+
if (num_padding > 0).any():
|
320 |
+
# special case: size < k, pad with the last valid index
|
321 |
+
padding = ends - 1
|
322 |
+
padding2graph = size_to_index(num_padding)
|
323 |
+
mask = scatter_extend(mask, num_actual, padding[padding2graph], num_padding)[0]
|
324 |
+
|
325 |
+
index = index_ext[mask] # (N * k, ...)
|
326 |
+
value = input.gather(0, index)
|
327 |
+
if isinstance(k, torch.Tensor) and k.shape == size.shape:
|
328 |
+
value = value.view(-1, *input.shape[1:])
|
329 |
+
index = index.view(-1, *input.shape[1:])
|
330 |
+
index = index - (size.cumsum(0) - size).repeat_interleave(k).view([-1] + [1] * (index.ndim - 1))
|
331 |
+
else:
|
332 |
+
value = value.view(-1, k, *input.shape[1:])
|
333 |
+
index = index.view(-1, k, *input.shape[1:])
|
334 |
+
index = index - (size.cumsum(0) - size).view([-1] + [1] * (index.ndim - 1))
|
335 |
+
|
336 |
+
return value, index
|
ultra/datasets.py
ADDED
@@ -0,0 +1,1095 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import csv
|
3 |
+
import shutil
|
4 |
+
import torch
|
5 |
+
from torch_geometric.data import Data, InMemoryDataset, download_url, extract_zip
|
6 |
+
from torch_geometric.datasets import RelLinkPredDataset, WordNet18RR
|
7 |
+
|
8 |
+
from ultra.tasks import build_relation_graph
|
9 |
+
|
10 |
+
|
11 |
+
class GrailInductiveDataset(InMemoryDataset):
|
12 |
+
|
13 |
+
def __init__(self, root, version, transform=None, pre_transform=build_relation_graph, merge_valid_test=True):
|
14 |
+
self.version = version
|
15 |
+
assert version in ["v1", "v2", "v3", "v4"]
|
16 |
+
|
17 |
+
# by default, most models on Grail datasets merge inductive valid and test splits as the final test split
|
18 |
+
# with this choice, the validation set is that of the transductive train (on the seen graph)
|
19 |
+
# by default it's turned on but you can experiment with turning this option off
|
20 |
+
# you'll need to delete the processed datasets then and re-run to cache a new dataset
|
21 |
+
self.merge_valid_test = merge_valid_test
|
22 |
+
super().__init__(root, transform, pre_transform)
|
23 |
+
self.data, self.slices = torch.load(self.processed_paths[0])
|
24 |
+
|
25 |
+
@property
|
26 |
+
def num_relations(self):
|
27 |
+
return int(self.data.edge_type.max()) + 1
|
28 |
+
|
29 |
+
@property
|
30 |
+
def raw_dir(self):
|
31 |
+
return os.path.join(self.root, "grail", self.name, self.version, "raw")
|
32 |
+
|
33 |
+
@property
|
34 |
+
def processed_dir(self):
|
35 |
+
return os.path.join(self.root, "grail", self.name, self.version, "processed")
|
36 |
+
|
37 |
+
@property
|
38 |
+
def processed_file_names(self):
|
39 |
+
return "data.pt"
|
40 |
+
|
41 |
+
@property
|
42 |
+
def raw_file_names(self):
|
43 |
+
return [
|
44 |
+
"train_ind.txt", "valid_ind.txt", "test_ind.txt", "train.txt", "valid.txt"
|
45 |
+
]
|
46 |
+
|
47 |
+
def download(self):
|
48 |
+
for url, path in zip(self.urls, self.raw_paths):
|
49 |
+
download_path = download_url(url % self.version, self.raw_dir)
|
50 |
+
os.rename(download_path, path)
|
51 |
+
|
52 |
+
def process(self):
|
53 |
+
test_files = self.raw_paths[:3]
|
54 |
+
train_files = self.raw_paths[3:]
|
55 |
+
|
56 |
+
inv_train_entity_vocab = {}
|
57 |
+
inv_test_entity_vocab = {}
|
58 |
+
inv_relation_vocab = {}
|
59 |
+
triplets = []
|
60 |
+
num_samples = []
|
61 |
+
|
62 |
+
for txt_file in train_files:
|
63 |
+
with open(txt_file, "r") as fin:
|
64 |
+
num_sample = 0
|
65 |
+
for line in fin:
|
66 |
+
h_token, r_token, t_token = line.strip().split("\t")
|
67 |
+
if h_token not in inv_train_entity_vocab:
|
68 |
+
inv_train_entity_vocab[h_token] = len(inv_train_entity_vocab)
|
69 |
+
h = inv_train_entity_vocab[h_token]
|
70 |
+
if r_token not in inv_relation_vocab:
|
71 |
+
inv_relation_vocab[r_token] = len(inv_relation_vocab)
|
72 |
+
r = inv_relation_vocab[r_token]
|
73 |
+
if t_token not in inv_train_entity_vocab:
|
74 |
+
inv_train_entity_vocab[t_token] = len(inv_train_entity_vocab)
|
75 |
+
t = inv_train_entity_vocab[t_token]
|
76 |
+
triplets.append((h, t, r))
|
77 |
+
num_sample += 1
|
78 |
+
num_samples.append(num_sample)
|
79 |
+
|
80 |
+
for txt_file in test_files:
|
81 |
+
with open(txt_file, "r") as fin:
|
82 |
+
num_sample = 0
|
83 |
+
for line in fin:
|
84 |
+
h_token, r_token, t_token = line.strip().split("\t")
|
85 |
+
if h_token not in inv_test_entity_vocab:
|
86 |
+
inv_test_entity_vocab[h_token] = len(inv_test_entity_vocab)
|
87 |
+
h = inv_test_entity_vocab[h_token]
|
88 |
+
assert r_token in inv_relation_vocab
|
89 |
+
r = inv_relation_vocab[r_token]
|
90 |
+
if t_token not in inv_test_entity_vocab:
|
91 |
+
inv_test_entity_vocab[t_token] = len(inv_test_entity_vocab)
|
92 |
+
t = inv_test_entity_vocab[t_token]
|
93 |
+
triplets.append((h, t, r))
|
94 |
+
num_sample += 1
|
95 |
+
num_samples.append(num_sample)
|
96 |
+
triplets = torch.tensor(triplets)
|
97 |
+
|
98 |
+
edge_index = triplets[:, :2].t()
|
99 |
+
edge_type = triplets[:, 2]
|
100 |
+
num_relations = int(edge_type.max()) + 1
|
101 |
+
|
102 |
+
# creating fact graphs - those are graphs sent to a model, based on which we'll predict missing facts
|
103 |
+
# also, those fact graphs will be used for filtered evaluation
|
104 |
+
train_fact_slice = slice(None, sum(num_samples[:1]))
|
105 |
+
test_fact_slice = slice(sum(num_samples[:2]), sum(num_samples[:3]))
|
106 |
+
train_fact_index = edge_index[:, train_fact_slice]
|
107 |
+
train_fact_type = edge_type[train_fact_slice]
|
108 |
+
test_fact_index = edge_index[:, test_fact_slice]
|
109 |
+
test_fact_type = edge_type[test_fact_slice]
|
110 |
+
|
111 |
+
# add flipped triplets for the fact graphs
|
112 |
+
train_fact_index = torch.cat([train_fact_index, train_fact_index.flip(0)], dim=-1)
|
113 |
+
train_fact_type = torch.cat([train_fact_type, train_fact_type + num_relations])
|
114 |
+
test_fact_index = torch.cat([test_fact_index, test_fact_index.flip(0)], dim=-1)
|
115 |
+
test_fact_type = torch.cat([test_fact_type, test_fact_type + num_relations])
|
116 |
+
|
117 |
+
train_slice = slice(None, sum(num_samples[:1]))
|
118 |
+
valid_slice = slice(sum(num_samples[:1]), sum(num_samples[:2]))
|
119 |
+
# by default, SOTA models on Grail datasets merge inductive valid and test splits as the final test split
|
120 |
+
# with this choice, the validation set is that of the transductive train (on the seen graph)
|
121 |
+
# by default it's turned on but you can experiment with turning this option off
|
122 |
+
test_slice = slice(sum(num_samples[:3]), sum(num_samples)) if self.merge_valid_test else slice(sum(num_samples[:4]), sum(num_samples))
|
123 |
+
|
124 |
+
train_data = Data(edge_index=train_fact_index, edge_type=train_fact_type, num_nodes=len(inv_train_entity_vocab),
|
125 |
+
target_edge_index=edge_index[:, train_slice], target_edge_type=edge_type[train_slice], num_relations=num_relations*2)
|
126 |
+
valid_data = Data(edge_index=train_fact_index, edge_type=train_fact_type, num_nodes=len(inv_train_entity_vocab),
|
127 |
+
target_edge_index=edge_index[:, valid_slice], target_edge_type=edge_type[valid_slice], num_relations=num_relations*2)
|
128 |
+
test_data = Data(edge_index=test_fact_index, edge_type=test_fact_type, num_nodes=len(inv_test_entity_vocab),
|
129 |
+
target_edge_index=edge_index[:, test_slice], target_edge_type=edge_type[test_slice], num_relations=num_relations*2)
|
130 |
+
|
131 |
+
if self.pre_transform is not None:
|
132 |
+
train_data = self.pre_transform(train_data)
|
133 |
+
valid_data = self.pre_transform(valid_data)
|
134 |
+
test_data = self.pre_transform(test_data)
|
135 |
+
|
136 |
+
torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])
|
137 |
+
|
138 |
+
def __repr__(self):
|
139 |
+
return "%s(%s)" % (self.name, self.version)
|
140 |
+
|
141 |
+
|
142 |
+
class FB15k237Inductive(GrailInductiveDataset):
|
143 |
+
|
144 |
+
urls = [
|
145 |
+
"https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s_ind/train.txt",
|
146 |
+
"https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s_ind/valid.txt",
|
147 |
+
"https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s_ind/test.txt",
|
148 |
+
"https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s/train.txt",
|
149 |
+
"https://raw.githubusercontent.com/kkteru/grail/master/data/fb237_%s/valid.txt"
|
150 |
+
]
|
151 |
+
|
152 |
+
name = "IndFB15k237"
|
153 |
+
|
154 |
+
def __init__(self, root, version):
|
155 |
+
super().__init__(root, version)
|
156 |
+
|
157 |
+
class WN18RRInductive(GrailInductiveDataset):
|
158 |
+
|
159 |
+
urls = [
|
160 |
+
"https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s_ind/train.txt",
|
161 |
+
"https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s_ind/valid.txt",
|
162 |
+
"https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s_ind/test.txt",
|
163 |
+
"https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s/train.txt",
|
164 |
+
"https://raw.githubusercontent.com/kkteru/grail/master/data/WN18RR_%s/valid.txt"
|
165 |
+
]
|
166 |
+
|
167 |
+
name = "IndWN18RR"
|
168 |
+
|
169 |
+
def __init__(self, root, version):
|
170 |
+
super().__init__(root, version)
|
171 |
+
|
172 |
+
class NELLInductive(GrailInductiveDataset):
|
173 |
+
urls = [
|
174 |
+
"https://raw.githubusercontent.com/kkteru/grail/master/data/nell_%s_ind/train.txt",
|
175 |
+
"https://raw.githubusercontent.com/kkteru/grail/master/data/nell_%s_ind/valid.txt",
|
176 |
+
"https://raw.githubusercontent.com/kkteru/grail/master/data/nell_%s_ind/test.txt",
|
177 |
+
"https://raw.githubusercontent.com/kkteru/grail/master/data/nell_%s/train.txt",
|
178 |
+
"https://raw.githubusercontent.com/kkteru/grail/master/data/nell_%s/valid.txt"
|
179 |
+
]
|
180 |
+
name = "IndNELL"
|
181 |
+
|
182 |
+
def __init__(self, root, version):
|
183 |
+
super().__init__(root, version)
|
184 |
+
|
185 |
+
|
186 |
+
def FB15k237(root):
|
187 |
+
dataset = RelLinkPredDataset(name="FB15k-237", root=root+"/fb15k237/")
|
188 |
+
data = dataset.data
|
189 |
+
train_data = Data(edge_index=data.edge_index, edge_type=data.edge_type, num_nodes=data.num_nodes,
|
190 |
+
target_edge_index=data.train_edge_index, target_edge_type=data.train_edge_type,
|
191 |
+
num_relations=dataset.num_relations)
|
192 |
+
valid_data = Data(edge_index=data.edge_index, edge_type=data.edge_type, num_nodes=data.num_nodes,
|
193 |
+
target_edge_index=data.valid_edge_index, target_edge_type=data.valid_edge_type,
|
194 |
+
num_relations=dataset.num_relations)
|
195 |
+
test_data = Data(edge_index=data.edge_index, edge_type=data.edge_type, num_nodes=data.num_nodes,
|
196 |
+
target_edge_index=data.test_edge_index, target_edge_type=data.test_edge_type,
|
197 |
+
num_relations=dataset.num_relations)
|
198 |
+
|
199 |
+
# build relation graphs
|
200 |
+
train_data = build_relation_graph(train_data)
|
201 |
+
valid_data = build_relation_graph(valid_data)
|
202 |
+
test_data = build_relation_graph(test_data)
|
203 |
+
|
204 |
+
dataset.data, dataset.slices = dataset.collate([train_data, valid_data, test_data])
|
205 |
+
return dataset
|
206 |
+
|
207 |
+
def WN18RR(root):
|
208 |
+
dataset = WordNet18RR(root=root+"/wn18rr/")
|
209 |
+
# convert wn18rr into the same format as fb15k-237
|
210 |
+
data = dataset.data
|
211 |
+
num_nodes = int(data.edge_index.max()) + 1
|
212 |
+
num_relations = int(data.edge_type.max()) + 1
|
213 |
+
edge_index = data.edge_index[:, data.train_mask]
|
214 |
+
edge_type = data.edge_type[data.train_mask]
|
215 |
+
edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=-1)
|
216 |
+
edge_type = torch.cat([edge_type, edge_type + num_relations])
|
217 |
+
train_data = Data(edge_index=edge_index, edge_type=edge_type, num_nodes=num_nodes,
|
218 |
+
target_edge_index=data.edge_index[:, data.train_mask],
|
219 |
+
target_edge_type=data.edge_type[data.train_mask],
|
220 |
+
num_relations=num_relations*2)
|
221 |
+
valid_data = Data(edge_index=edge_index, edge_type=edge_type, num_nodes=num_nodes,
|
222 |
+
target_edge_index=data.edge_index[:, data.val_mask],
|
223 |
+
target_edge_type=data.edge_type[data.val_mask],
|
224 |
+
num_relations=num_relations*2)
|
225 |
+
test_data = Data(edge_index=edge_index, edge_type=edge_type, num_nodes=num_nodes,
|
226 |
+
target_edge_index=data.edge_index[:, data.test_mask],
|
227 |
+
target_edge_type=data.edge_type[data.test_mask],
|
228 |
+
num_relations=num_relations*2)
|
229 |
+
|
230 |
+
# build relation graphs
|
231 |
+
train_data = build_relation_graph(train_data)
|
232 |
+
valid_data = build_relation_graph(valid_data)
|
233 |
+
test_data = build_relation_graph(test_data)
|
234 |
+
|
235 |
+
dataset.data, dataset.slices = dataset.collate([train_data, valid_data, test_data])
|
236 |
+
dataset.num_relations = num_relations * 2
|
237 |
+
return dataset
|
238 |
+
|
239 |
+
|
240 |
+
class TransductiveDataset(InMemoryDataset):
|
241 |
+
|
242 |
+
delimiter = None
|
243 |
+
|
244 |
+
def __init__(self, root, transform=None, pre_transform=build_relation_graph, **kwargs):
|
245 |
+
|
246 |
+
super().__init__(root, transform, pre_transform)
|
247 |
+
self.data, self.slices = torch.load(self.processed_paths[0])
|
248 |
+
|
249 |
+
@property
|
250 |
+
def raw_file_names(self):
|
251 |
+
return ["train.txt", "valid.txt", "test.txt"]
|
252 |
+
|
253 |
+
def download(self):
|
254 |
+
for url, path in zip(self.urls, self.raw_paths):
|
255 |
+
download_path = download_url(url, self.raw_dir)
|
256 |
+
os.rename(download_path, path)
|
257 |
+
|
258 |
+
def load_file(self, triplet_file, inv_entity_vocab={}, inv_rel_vocab={}):
|
259 |
+
|
260 |
+
triplets = []
|
261 |
+
entity_cnt, rel_cnt = len(inv_entity_vocab), len(inv_rel_vocab)
|
262 |
+
|
263 |
+
with open(triplet_file, "r", encoding="utf-8") as fin:
|
264 |
+
for l in fin:
|
265 |
+
u, r, v = l.split() if self.delimiter is None else l.strip().split(self.delimiter)
|
266 |
+
if u not in inv_entity_vocab:
|
267 |
+
inv_entity_vocab[u] = entity_cnt
|
268 |
+
entity_cnt += 1
|
269 |
+
if v not in inv_entity_vocab:
|
270 |
+
inv_entity_vocab[v] = entity_cnt
|
271 |
+
entity_cnt += 1
|
272 |
+
if r not in inv_rel_vocab:
|
273 |
+
inv_rel_vocab[r] = rel_cnt
|
274 |
+
rel_cnt += 1
|
275 |
+
u, r, v = inv_entity_vocab[u], inv_rel_vocab[r], inv_entity_vocab[v]
|
276 |
+
|
277 |
+
triplets.append((u, v, r))
|
278 |
+
|
279 |
+
return {
|
280 |
+
"triplets": triplets,
|
281 |
+
"num_node": len(inv_entity_vocab), #entity_cnt,
|
282 |
+
"num_relation": rel_cnt,
|
283 |
+
"inv_entity_vocab": inv_entity_vocab,
|
284 |
+
"inv_rel_vocab": inv_rel_vocab
|
285 |
+
}
|
286 |
+
|
287 |
+
# default loading procedure: process train/valid/test files, create graphs from them
|
288 |
+
def process(self):
|
289 |
+
|
290 |
+
train_files = self.raw_paths[:3]
|
291 |
+
|
292 |
+
train_results = self.load_file(train_files[0], inv_entity_vocab={}, inv_rel_vocab={})
|
293 |
+
valid_results = self.load_file(train_files[1],
|
294 |
+
train_results["inv_entity_vocab"], train_results["inv_rel_vocab"])
|
295 |
+
test_results = self.load_file(train_files[2],
|
296 |
+
train_results["inv_entity_vocab"], train_results["inv_rel_vocab"])
|
297 |
+
|
298 |
+
# in some datasets, there are several new nodes in the test set, eg 123,143 YAGO train adn 123,182 in YAGO test
|
299 |
+
# for consistency with other experimental results, we'll include those in the full vocab and num nodes
|
300 |
+
num_node = test_results["num_node"]
|
301 |
+
# the same for rels: in most cases train == test for transductive
|
302 |
+
# for AristoV4 train rels 1593, test 1604
|
303 |
+
num_relations = test_results["num_relation"]
|
304 |
+
|
305 |
+
train_triplets = train_results["triplets"]
|
306 |
+
valid_triplets = valid_results["triplets"]
|
307 |
+
test_triplets = test_results["triplets"]
|
308 |
+
|
309 |
+
train_target_edges = torch.tensor([[t[0], t[1]] for t in train_triplets], dtype=torch.long).t()
|
310 |
+
train_target_etypes = torch.tensor([t[2] for t in train_triplets])
|
311 |
+
|
312 |
+
valid_edges = torch.tensor([[t[0], t[1]] for t in valid_triplets], dtype=torch.long).t()
|
313 |
+
valid_etypes = torch.tensor([t[2] for t in valid_triplets])
|
314 |
+
|
315 |
+
test_edges = torch.tensor([[t[0], t[1]] for t in test_triplets], dtype=torch.long).t()
|
316 |
+
test_etypes = torch.tensor([t[2] for t in test_triplets])
|
317 |
+
|
318 |
+
train_edges = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
|
319 |
+
train_etypes = torch.cat([train_target_etypes, train_target_etypes+num_relations])
|
320 |
+
|
321 |
+
train_data = Data(edge_index=train_edges, edge_type=train_etypes, num_nodes=num_node,
|
322 |
+
target_edge_index=train_target_edges, target_edge_type=train_target_etypes, num_relations=num_relations*2)
|
323 |
+
valid_data = Data(edge_index=train_edges, edge_type=train_etypes, num_nodes=num_node,
|
324 |
+
target_edge_index=valid_edges, target_edge_type=valid_etypes, num_relations=num_relations*2)
|
325 |
+
test_data = Data(edge_index=train_edges, edge_type=train_etypes, num_nodes=num_node,
|
326 |
+
target_edge_index=test_edges, target_edge_type=test_etypes, num_relations=num_relations*2)
|
327 |
+
|
328 |
+
# build graphs of relations
|
329 |
+
if self.pre_transform is not None:
|
330 |
+
train_data = self.pre_transform(train_data)
|
331 |
+
valid_data = self.pre_transform(valid_data)
|
332 |
+
test_data = self.pre_transform(test_data)
|
333 |
+
|
334 |
+
torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])
|
335 |
+
|
336 |
+
def __repr__(self):
|
337 |
+
return "%s()" % (self.name)
|
338 |
+
|
339 |
+
@property
|
340 |
+
def num_relations(self):
|
341 |
+
return int(self.data.edge_type.max()) + 1
|
342 |
+
|
343 |
+
@property
|
344 |
+
def raw_dir(self):
|
345 |
+
return os.path.join(self.root, self.name, "raw")
|
346 |
+
|
347 |
+
@property
|
348 |
+
def processed_dir(self):
|
349 |
+
return os.path.join(self.root, self.name, "processed")
|
350 |
+
|
351 |
+
@property
|
352 |
+
def processed_file_names(self):
|
353 |
+
return "data.pt"
|
354 |
+
|
355 |
+
|
356 |
+
|
357 |
+
class CoDEx(TransductiveDataset):
|
358 |
+
|
359 |
+
name = "codex"
|
360 |
+
urls = [
|
361 |
+
"https://raw.githubusercontent.com/tsafavi/codex/master/data/triples/%s/train.txt",
|
362 |
+
"https://raw.githubusercontent.com/tsafavi/codex/master/data/triples/%s/valid.txt",
|
363 |
+
"https://raw.githubusercontent.com/tsafavi/codex/master/data/triples/%s/test.txt",
|
364 |
+
]
|
365 |
+
|
366 |
+
def download(self):
|
367 |
+
for url, path in zip(self.urls, self.raw_paths):
|
368 |
+
download_path = download_url(url % self.name, self.raw_dir)
|
369 |
+
os.rename(download_path, path)
|
370 |
+
|
371 |
+
|
372 |
+
class CoDExSmall(CoDEx):
|
373 |
+
"""
|
374 |
+
#node: 2034
|
375 |
+
#edge: 36543
|
376 |
+
#relation: 42
|
377 |
+
"""
|
378 |
+
url = "https://zenodo.org/record/4281094/files/codex-s.tar.gz"
|
379 |
+
md5 = "63cd8186fc2aeddc154e20cf4a10087e"
|
380 |
+
name = "codex-s"
|
381 |
+
|
382 |
+
def __init__(self, root):
|
383 |
+
super(CoDExSmall, self).__init__(root=root, size='s')
|
384 |
+
|
385 |
+
|
386 |
+
class CoDExMedium(CoDEx):
|
387 |
+
"""
|
388 |
+
#node: 17050
|
389 |
+
#edge: 206205
|
390 |
+
#relation: 51
|
391 |
+
"""
|
392 |
+
url = "https://zenodo.org/record/4281094/files/codex-m.tar.gz"
|
393 |
+
md5 = "43e561cfdca1c6ad9cc2f5b1ca4add76"
|
394 |
+
name = "codex-m"
|
395 |
+
def __init__(self, root):
|
396 |
+
super(CoDExMedium, self).__init__(root=root, size='m')
|
397 |
+
|
398 |
+
|
399 |
+
class CoDExLarge(CoDEx):
|
400 |
+
"""
|
401 |
+
#node: 77951
|
402 |
+
#edge: 612437
|
403 |
+
#relation: 69
|
404 |
+
"""
|
405 |
+
url = "https://zenodo.org/record/4281094/files/codex-l.tar.gz"
|
406 |
+
md5 = "9a10f4458c4bd2b16ef9b92b677e0d71"
|
407 |
+
name = "codex-l"
|
408 |
+
def __init__(self, root):
|
409 |
+
super(CoDExLarge, self).__init__(root=root, size='l')
|
410 |
+
|
411 |
+
|
412 |
+
class NELL995(TransductiveDataset):
|
413 |
+
|
414 |
+
# from the RED-GNN paper https://github.com/LARS-research/RED-GNN/tree/main/transductive/data/nell
|
415 |
+
# the OG dumps were found to have test set leakages
|
416 |
+
# training set is made out of facts+train files, so we sum up their samples to build one training graph
|
417 |
+
|
418 |
+
urls = [
|
419 |
+
"https://raw.githubusercontent.com/LARS-research/RED-GNN/main/transductive/data/nell/facts.txt",
|
420 |
+
"https://raw.githubusercontent.com/LARS-research/RED-GNN/main/transductive/data/nell/train.txt",
|
421 |
+
"https://raw.githubusercontent.com/LARS-research/RED-GNN/main/transductive/data/nell/valid.txt",
|
422 |
+
"https://raw.githubusercontent.com/LARS-research/RED-GNN/main/transductive/data/nell/test.txt",
|
423 |
+
]
|
424 |
+
name = "nell995"
|
425 |
+
|
426 |
+
@property
|
427 |
+
def raw_file_names(self):
|
428 |
+
return ["facts.txt", "train.txt", "valid.txt", "test.txt"]
|
429 |
+
|
430 |
+
|
431 |
+
def process(self):
|
432 |
+
train_files = self.raw_paths[:4]
|
433 |
+
|
434 |
+
facts_results = self.load_file(train_files[0], inv_entity_vocab={}, inv_rel_vocab={})
|
435 |
+
train_results = self.load_file(train_files[1], facts_results["inv_entity_vocab"], facts_results["inv_rel_vocab"])
|
436 |
+
valid_results = self.load_file(train_files[2], train_results["inv_entity_vocab"], train_results["inv_rel_vocab"])
|
437 |
+
test_results = self.load_file(train_files[3], train_results["inv_entity_vocab"], train_results["inv_rel_vocab"])
|
438 |
+
|
439 |
+
num_node = valid_results["num_node"]
|
440 |
+
num_relations = train_results["num_relation"]
|
441 |
+
|
442 |
+
train_triplets = facts_results["triplets"] + train_results["triplets"]
|
443 |
+
valid_triplets = valid_results["triplets"]
|
444 |
+
test_triplets = test_results["triplets"]
|
445 |
+
|
446 |
+
train_target_edges = torch.tensor([[t[0], t[1]] for t in train_triplets], dtype=torch.long).t()
|
447 |
+
train_target_etypes = torch.tensor([t[2] for t in train_triplets])
|
448 |
+
|
449 |
+
valid_edges = torch.tensor([[t[0], t[1]] for t in valid_triplets], dtype=torch.long).t()
|
450 |
+
valid_etypes = torch.tensor([t[2] for t in valid_triplets])
|
451 |
+
|
452 |
+
test_edges = torch.tensor([[t[0], t[1]] for t in test_triplets], dtype=torch.long).t()
|
453 |
+
test_etypes = torch.tensor([t[2] for t in test_triplets])
|
454 |
+
|
455 |
+
train_edges = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
|
456 |
+
train_etypes = torch.cat([train_target_etypes, train_target_etypes+num_relations])
|
457 |
+
|
458 |
+
train_data = Data(edge_index=train_edges, edge_type=train_etypes, num_nodes=num_node,
|
459 |
+
target_edge_index=train_target_edges, target_edge_type=train_target_etypes, num_relations=num_relations*2)
|
460 |
+
valid_data = Data(edge_index=train_edges, edge_type=train_etypes, num_nodes=num_node,
|
461 |
+
target_edge_index=valid_edges, target_edge_type=valid_etypes, num_relations=num_relations*2)
|
462 |
+
test_data = Data(edge_index=train_edges, edge_type=train_etypes, num_nodes=num_node,
|
463 |
+
target_edge_index=test_edges, target_edge_type=test_etypes, num_relations=num_relations*2)
|
464 |
+
|
465 |
+
# build graphs of relations
|
466 |
+
if self.pre_transform is not None:
|
467 |
+
train_data = self.pre_transform(train_data)
|
468 |
+
valid_data = self.pre_transform(valid_data)
|
469 |
+
test_data = self.pre_transform(test_data)
|
470 |
+
|
471 |
+
torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])
|
472 |
+
|
473 |
+
|
474 |
+
class ConceptNet100k(TransductiveDataset):
|
475 |
+
|
476 |
+
urls = [
|
477 |
+
"https://raw.githubusercontent.com/guojiapub/BiQUE/master/src_data/conceptnet-100k/train",
|
478 |
+
"https://raw.githubusercontent.com/guojiapub/BiQUE/master/src_data/conceptnet-100k/valid",
|
479 |
+
"https://raw.githubusercontent.com/guojiapub/BiQUE/master/src_data/conceptnet-100k/test",
|
480 |
+
]
|
481 |
+
name = "cnet100k"
|
482 |
+
delimiter = "\t"
|
483 |
+
|
484 |
+
|
485 |
+
class DBpedia100k(TransductiveDataset):
|
486 |
+
urls = [
|
487 |
+
"https://raw.githubusercontent.com/iieir-km/ComplEx-NNE_AER/master/datasets/DB100K/_train.txt",
|
488 |
+
"https://raw.githubusercontent.com/iieir-km/ComplEx-NNE_AER/master/datasets/DB100K/_valid.txt",
|
489 |
+
"https://raw.githubusercontent.com/iieir-km/ComplEx-NNE_AER/master/datasets/DB100K/_test.txt",
|
490 |
+
]
|
491 |
+
name = "dbp100k"
|
492 |
+
|
493 |
+
|
494 |
+
class YAGO310(TransductiveDataset):
|
495 |
+
|
496 |
+
urls = [
|
497 |
+
"https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/YAGO3-10/train.txt",
|
498 |
+
"https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/YAGO3-10/valid.txt",
|
499 |
+
"https://raw.githubusercontent.com/DeepGraphLearning/KnowledgeGraphEmbedding/master/data/YAGO3-10/test.txt",
|
500 |
+
]
|
501 |
+
name = "yago310"
|
502 |
+
|
503 |
+
|
504 |
+
class Hetionet(TransductiveDataset):
|
505 |
+
|
506 |
+
urls = [
|
507 |
+
"https://www.dropbox.com/s/y47bt9oq57h6l5k/train.txt?dl=1",
|
508 |
+
"https://www.dropbox.com/s/a0pbrx9tz3dgsff/valid.txt?dl=1",
|
509 |
+
"https://www.dropbox.com/s/4dhrvg3fyq5tnu4/test.txt?dl=1",
|
510 |
+
]
|
511 |
+
name = "hetionet"
|
512 |
+
|
513 |
+
|
514 |
+
class AristoV4(TransductiveDataset):
|
515 |
+
|
516 |
+
url = "https://zenodo.org/record/5942560/files/aristo-v4.zip"
|
517 |
+
|
518 |
+
name = "aristov4"
|
519 |
+
delimiter = "\t"
|
520 |
+
|
521 |
+
def download(self):
|
522 |
+
download_path = download_url(self.url, self.raw_dir)
|
523 |
+
extract_zip(download_path, self.raw_dir)
|
524 |
+
os.unlink(download_path)
|
525 |
+
for oldname, newname in zip(['train', 'valid', 'test'], self.raw_paths):
|
526 |
+
os.rename(os.path.join(self.raw_dir, oldname), newname)
|
527 |
+
|
528 |
+
|
529 |
+
class SparserKG(TransductiveDataset):
|
530 |
+
|
531 |
+
# 5 datasets based on FB/NELL/WD, introduced in https://github.com/THU-KEG/DacKGR
|
532 |
+
# re-writing the loading function because dumps are in the format (h, t, r) while the standard is (h, r, t)
|
533 |
+
|
534 |
+
url = "https://raw.githubusercontent.com/THU-KEG/DacKGR/master/data.zip"
|
535 |
+
delimiter = "\t"
|
536 |
+
base_name = "SparseKG"
|
537 |
+
|
538 |
+
@property
|
539 |
+
def raw_dir(self):
|
540 |
+
return os.path.join(self.root, self.base_name, self.name, "raw")
|
541 |
+
|
542 |
+
@property
|
543 |
+
def processed_dir(self):
|
544 |
+
return os.path.join(self.root, self.base_name, self.name, "processed")
|
545 |
+
|
546 |
+
def download(self):
|
547 |
+
base_path = os.path.join(self.root, self.base_name)
|
548 |
+
download_path = download_url(self.url, base_path)
|
549 |
+
extract_zip(download_path, base_path)
|
550 |
+
for dsname in ['NELL23K', 'WD-singer', 'FB15K-237-10', 'FB15K-237-20', 'FB15K-237-50']:
|
551 |
+
for oldname, newname in zip(['train.triples', 'dev.triples', 'test.triples'], self.raw_file_names):
|
552 |
+
os.renames(os.path.join(base_path, "data", dsname, oldname), os.path.join(base_path, dsname, "raw", newname))
|
553 |
+
shutil.rmtree(os.path.join(base_path, "data"))
|
554 |
+
|
555 |
+
def load_file(self, triplet_file, inv_entity_vocab={}, inv_rel_vocab={}):
|
556 |
+
|
557 |
+
triplets = []
|
558 |
+
entity_cnt, rel_cnt = len(inv_entity_vocab), len(inv_rel_vocab)
|
559 |
+
|
560 |
+
with open(triplet_file, "r", encoding="utf-8") as fin:
|
561 |
+
for l in fin:
|
562 |
+
u, v, r = l.split() if self.delimiter is None else l.strip().split(self.delimiter)
|
563 |
+
if u not in inv_entity_vocab:
|
564 |
+
inv_entity_vocab[u] = entity_cnt
|
565 |
+
entity_cnt += 1
|
566 |
+
if v not in inv_entity_vocab:
|
567 |
+
inv_entity_vocab[v] = entity_cnt
|
568 |
+
entity_cnt += 1
|
569 |
+
if r not in inv_rel_vocab:
|
570 |
+
inv_rel_vocab[r] = rel_cnt
|
571 |
+
rel_cnt += 1
|
572 |
+
u, r, v = inv_entity_vocab[u], inv_rel_vocab[r], inv_entity_vocab[v]
|
573 |
+
|
574 |
+
triplets.append((u, v, r))
|
575 |
+
|
576 |
+
return {
|
577 |
+
"triplets": triplets,
|
578 |
+
"num_node": len(inv_entity_vocab), #entity_cnt,
|
579 |
+
"num_relation": rel_cnt,
|
580 |
+
"inv_entity_vocab": inv_entity_vocab,
|
581 |
+
"inv_rel_vocab": inv_rel_vocab
|
582 |
+
}
|
583 |
+
|
584 |
+
class WDsinger(SparserKG):
|
585 |
+
name = "WD-singer"
|
586 |
+
|
587 |
+
class NELL23k(SparserKG):
|
588 |
+
name = "NELL23K"
|
589 |
+
|
590 |
+
class FB15k237_10(SparserKG):
|
591 |
+
name = "FB15K-237-10"
|
592 |
+
|
593 |
+
class FB15k237_20(SparserKG):
|
594 |
+
name = "FB15K-237-20"
|
595 |
+
|
596 |
+
class FB15k237_50(SparserKG):
|
597 |
+
name = "FB15K-237-50"
|
598 |
+
|
599 |
+
|
600 |
+
class InductiveDataset(InMemoryDataset):
|
601 |
+
|
602 |
+
delimiter = None
|
603 |
+
# some datasets (4 from Hamaguchi et al and Indigo) have validation set based off the train graph, not inference
|
604 |
+
valid_on_inf = True #
|
605 |
+
|
606 |
+
def __init__(self, root, version, transform=None, pre_transform=build_relation_graph, **kwargs):
|
607 |
+
|
608 |
+
self.version = str(version)
|
609 |
+
super().__init__(root, transform, pre_transform)
|
610 |
+
self.data, self.slices = torch.load(self.processed_paths[0])
|
611 |
+
|
612 |
+
def download(self):
|
613 |
+
for url, path in zip(self.urls, self.raw_paths):
|
614 |
+
download_path = download_url(url % self.version, self.raw_dir)
|
615 |
+
os.rename(download_path, path)
|
616 |
+
|
617 |
+
def load_file(self, triplet_file, inv_entity_vocab={}, inv_rel_vocab={}):
|
618 |
+
|
619 |
+
triplets = []
|
620 |
+
entity_cnt, rel_cnt = len(inv_entity_vocab), len(inv_rel_vocab)
|
621 |
+
|
622 |
+
with open(triplet_file, "r", encoding="utf-8") as fin:
|
623 |
+
for l in fin:
|
624 |
+
u, r, v = l.split() if self.delimiter is None else l.strip().split(self.delimiter)
|
625 |
+
if u not in inv_entity_vocab:
|
626 |
+
inv_entity_vocab[u] = entity_cnt
|
627 |
+
entity_cnt += 1
|
628 |
+
if v not in inv_entity_vocab:
|
629 |
+
inv_entity_vocab[v] = entity_cnt
|
630 |
+
entity_cnt += 1
|
631 |
+
if r not in inv_rel_vocab:
|
632 |
+
inv_rel_vocab[r] = rel_cnt
|
633 |
+
rel_cnt += 1
|
634 |
+
u, r, v = inv_entity_vocab[u], inv_rel_vocab[r], inv_entity_vocab[v]
|
635 |
+
|
636 |
+
triplets.append((u, v, r))
|
637 |
+
|
638 |
+
return {
|
639 |
+
"triplets": triplets,
|
640 |
+
"num_node": len(inv_entity_vocab), #entity_cnt,
|
641 |
+
"num_relation": rel_cnt,
|
642 |
+
"inv_entity_vocab": inv_entity_vocab,
|
643 |
+
"inv_rel_vocab": inv_rel_vocab
|
644 |
+
}
|
645 |
+
|
646 |
+
def process(self):
|
647 |
+
|
648 |
+
train_files = self.raw_paths[:4]
|
649 |
+
|
650 |
+
train_res = self.load_file(train_files[0], inv_entity_vocab={}, inv_rel_vocab={})
|
651 |
+
inference_res = self.load_file(train_files[1], inv_entity_vocab={}, inv_rel_vocab={})
|
652 |
+
valid_res = self.load_file(
|
653 |
+
train_files[2],
|
654 |
+
inference_res["inv_entity_vocab"] if self.valid_on_inf else train_res["inv_entity_vocab"],
|
655 |
+
inference_res["inv_rel_vocab"] if self.valid_on_inf else train_res["inv_rel_vocab"]
|
656 |
+
)
|
657 |
+
test_res = self.load_file(train_files[3], inference_res["inv_entity_vocab"], inference_res["inv_rel_vocab"])
|
658 |
+
|
659 |
+
num_train_nodes, num_train_rels = train_res["num_node"], train_res["num_relation"]
|
660 |
+
inference_num_nodes, inference_num_rels = test_res["num_node"], test_res["num_relation"]
|
661 |
+
|
662 |
+
train_edges, inf_graph, inf_valid_edges, inf_test_edges = train_res["triplets"], inference_res["triplets"], valid_res["triplets"], test_res["triplets"]
|
663 |
+
|
664 |
+
train_target_edges = torch.tensor([[t[0], t[1]] for t in train_edges], dtype=torch.long).t()
|
665 |
+
train_target_etypes = torch.tensor([t[2] for t in train_edges])
|
666 |
+
|
667 |
+
train_fact_index = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
|
668 |
+
train_fact_type = torch.cat([train_target_etypes, train_target_etypes + num_train_rels])
|
669 |
+
|
670 |
+
inf_edges = torch.tensor([[t[0], t[1]] for t in inf_graph], dtype=torch.long).t()
|
671 |
+
inf_edges = torch.cat([inf_edges, inf_edges.flip(0)], dim=1)
|
672 |
+
inf_etypes = torch.tensor([t[2] for t in inf_graph])
|
673 |
+
inf_etypes = torch.cat([inf_etypes, inf_etypes + inference_num_rels])
|
674 |
+
|
675 |
+
inf_valid_edges = torch.tensor(inf_valid_edges, dtype=torch.long)
|
676 |
+
inf_test_edges = torch.tensor(inf_test_edges, dtype=torch.long)
|
677 |
+
|
678 |
+
train_data = Data(edge_index=train_fact_index, edge_type=train_fact_type, num_nodes=num_train_nodes,
|
679 |
+
target_edge_index=train_target_edges, target_edge_type=train_target_etypes, num_relations=num_train_rels*2)
|
680 |
+
valid_data = Data(edge_index=inf_edges if self.valid_on_inf else train_fact_index,
|
681 |
+
edge_type=inf_etypes if self.valid_on_inf else train_fact_type,
|
682 |
+
num_nodes=inference_num_nodes if self.valid_on_inf else num_train_nodes,
|
683 |
+
target_edge_index=inf_valid_edges[:, :2].T,
|
684 |
+
target_edge_type=inf_valid_edges[:, 2],
|
685 |
+
num_relations=inference_num_rels*2 if self.valid_on_inf else num_train_rels*2)
|
686 |
+
test_data = Data(edge_index=inf_edges, edge_type=inf_etypes, num_nodes=inference_num_nodes,
|
687 |
+
target_edge_index=inf_test_edges[:, :2].T, target_edge_type=inf_test_edges[:, 2], num_relations=inference_num_rels*2)
|
688 |
+
|
689 |
+
if self.pre_transform is not None:
|
690 |
+
train_data = self.pre_transform(train_data)
|
691 |
+
valid_data = self.pre_transform(valid_data)
|
692 |
+
test_data = self.pre_transform(test_data)
|
693 |
+
|
694 |
+
torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])
|
695 |
+
|
696 |
+
@property
|
697 |
+
def num_relations(self):
|
698 |
+
return int(self.data.edge_type.max()) + 1
|
699 |
+
|
700 |
+
@property
|
701 |
+
def raw_dir(self):
|
702 |
+
return os.path.join(self.root, self.name, self.version, "raw")
|
703 |
+
|
704 |
+
@property
|
705 |
+
def processed_dir(self):
|
706 |
+
return os.path.join(self.root, self.name, self.version, "processed")
|
707 |
+
|
708 |
+
@property
|
709 |
+
def raw_file_names(self):
|
710 |
+
return [
|
711 |
+
"transductive_train.txt", "inference_graph.txt", "inf_valid.txt", "inf_test.txt"
|
712 |
+
]
|
713 |
+
|
714 |
+
@property
|
715 |
+
def processed_file_names(self):
|
716 |
+
return "data.pt"
|
717 |
+
|
718 |
+
def __repr__(self):
|
719 |
+
return "%s(%s)" % (self.name, self.version)
|
720 |
+
|
721 |
+
|
722 |
+
class IngramInductive(InductiveDataset):
|
723 |
+
|
724 |
+
@property
|
725 |
+
def raw_dir(self):
|
726 |
+
return os.path.join(self.root, "ingram", self.name, self.version, "raw")
|
727 |
+
|
728 |
+
@property
|
729 |
+
def processed_dir(self):
|
730 |
+
return os.path.join(self.root, "ingram", self.name, self.version, "processed")
|
731 |
+
|
732 |
+
|
733 |
+
class FBIngram(IngramInductive):
|
734 |
+
|
735 |
+
urls = [
|
736 |
+
"https://raw.githubusercontent.com/bdi-lab/InGram/master/data/FB-%s/train.txt",
|
737 |
+
"https://raw.githubusercontent.com/bdi-lab/InGram/master/data/FB-%s/msg.txt",
|
738 |
+
"https://raw.githubusercontent.com/bdi-lab/InGram/master/data/FB-%s/valid.txt",
|
739 |
+
"https://raw.githubusercontent.com/bdi-lab/InGram/master/data/FB-%s/test.txt",
|
740 |
+
]
|
741 |
+
name = "fb"
|
742 |
+
|
743 |
+
|
744 |
+
class WKIngram(IngramInductive):
|
745 |
+
|
746 |
+
urls = [
|
747 |
+
"https://raw.githubusercontent.com/bdi-lab/InGram/master/data/WK-%s/train.txt",
|
748 |
+
"https://raw.githubusercontent.com/bdi-lab/InGram/master/data/WK-%s/msg.txt",
|
749 |
+
"https://raw.githubusercontent.com/bdi-lab/InGram/master/data/WK-%s/valid.txt",
|
750 |
+
"https://raw.githubusercontent.com/bdi-lab/InGram/master/data/WK-%s/test.txt",
|
751 |
+
]
|
752 |
+
name = "wk"
|
753 |
+
|
754 |
+
class NLIngram(IngramInductive):
|
755 |
+
|
756 |
+
urls = [
|
757 |
+
"https://raw.githubusercontent.com/bdi-lab/InGram/master/data/NL-%s/train.txt",
|
758 |
+
"https://raw.githubusercontent.com/bdi-lab/InGram/master/data/NL-%s/msg.txt",
|
759 |
+
"https://raw.githubusercontent.com/bdi-lab/InGram/master/data/NL-%s/valid.txt",
|
760 |
+
"https://raw.githubusercontent.com/bdi-lab/InGram/master/data/NL-%s/test.txt",
|
761 |
+
]
|
762 |
+
name = "nl"
|
763 |
+
|
764 |
+
|
765 |
+
class ILPC2022(InductiveDataset):
|
766 |
+
|
767 |
+
urls = [
|
768 |
+
"https://raw.githubusercontent.com/pykeen/ilpc2022/master/data/%s/train.txt",
|
769 |
+
"https://raw.githubusercontent.com/pykeen/ilpc2022/master/data/%s/inference.txt",
|
770 |
+
"https://raw.githubusercontent.com/pykeen/ilpc2022/master/data/%s/inference_validation.txt",
|
771 |
+
"https://raw.githubusercontent.com/pykeen/ilpc2022/master/data/%s/inference_test.txt",
|
772 |
+
]
|
773 |
+
|
774 |
+
name = "ilpc2022"
|
775 |
+
|
776 |
+
|
777 |
+
class HM(InductiveDataset):
|
778 |
+
# benchmarks from Hamaguchi et al and Indigo BM
|
779 |
+
|
780 |
+
urls = [
|
781 |
+
"https://raw.githubusercontent.com/shuwen-liu-ox/INDIGO/master/data/%s/train/train.txt",
|
782 |
+
"https://raw.githubusercontent.com/shuwen-liu-ox/INDIGO/master/data/%s/test/test-graph.txt",
|
783 |
+
"https://raw.githubusercontent.com/shuwen-liu-ox/INDIGO/master/data/%s/train/valid.txt",
|
784 |
+
"https://raw.githubusercontent.com/shuwen-liu-ox/INDIGO/master/data/%s/test/test-fact.txt",
|
785 |
+
]
|
786 |
+
|
787 |
+
name = "hm"
|
788 |
+
versions = {
|
789 |
+
'1k': "Hamaguchi-BM_both-1000",
|
790 |
+
'3k': "Hamaguchi-BM_both-3000",
|
791 |
+
'5k': "Hamaguchi-BM_both-5000",
|
792 |
+
'indigo': "INDIGO-BM"
|
793 |
+
}
|
794 |
+
# in 4 HM graphs, the validation set is based off the training graph, so we'll adjust the dataset creation accordingly
|
795 |
+
valid_on_inf = False
|
796 |
+
|
797 |
+
def __init__(self, root, version, **kwargs):
|
798 |
+
version = self.versions[version]
|
799 |
+
super().__init__(root, version, **kwargs)
|
800 |
+
|
801 |
+
# HM datasets are a bit weird: validation set (based off the train graph) has a few hundred new nodes, so we need a custom processing
|
802 |
+
def process(self):
|
803 |
+
|
804 |
+
train_files = self.raw_paths[:4]
|
805 |
+
|
806 |
+
train_res = self.load_file(train_files[0], inv_entity_vocab={}, inv_rel_vocab={})
|
807 |
+
inference_res = self.load_file(train_files[1], inv_entity_vocab={}, inv_rel_vocab={})
|
808 |
+
valid_res = self.load_file(
|
809 |
+
train_files[2],
|
810 |
+
inference_res["inv_entity_vocab"] if self.valid_on_inf else train_res["inv_entity_vocab"],
|
811 |
+
inference_res["inv_rel_vocab"] if self.valid_on_inf else train_res["inv_rel_vocab"]
|
812 |
+
)
|
813 |
+
test_res = self.load_file(train_files[3], inference_res["inv_entity_vocab"], inference_res["inv_rel_vocab"])
|
814 |
+
|
815 |
+
num_train_nodes, num_train_rels = train_res["num_node"], train_res["num_relation"]
|
816 |
+
inference_num_nodes, inference_num_rels = test_res["num_node"], test_res["num_relation"]
|
817 |
+
|
818 |
+
train_edges, inf_graph, inf_valid_edges, inf_test_edges = train_res["triplets"], inference_res["triplets"], valid_res["triplets"], test_res["triplets"]
|
819 |
+
|
820 |
+
train_target_edges = torch.tensor([[t[0], t[1]] for t in train_edges], dtype=torch.long).t()
|
821 |
+
train_target_etypes = torch.tensor([t[2] for t in train_edges])
|
822 |
+
|
823 |
+
train_fact_index = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
|
824 |
+
train_fact_type = torch.cat([train_target_etypes, train_target_etypes + num_train_rels])
|
825 |
+
|
826 |
+
inf_edges = torch.tensor([[t[0], t[1]] for t in inf_graph], dtype=torch.long).t()
|
827 |
+
inf_edges = torch.cat([inf_edges, inf_edges.flip(0)], dim=1)
|
828 |
+
inf_etypes = torch.tensor([t[2] for t in inf_graph])
|
829 |
+
inf_etypes = torch.cat([inf_etypes, inf_etypes + inference_num_rels])
|
830 |
+
|
831 |
+
inf_valid_edges = torch.tensor(inf_valid_edges, dtype=torch.long)
|
832 |
+
inf_test_edges = torch.tensor(inf_test_edges, dtype=torch.long)
|
833 |
+
|
834 |
+
train_data = Data(edge_index=train_fact_index, edge_type=train_fact_type, num_nodes=num_train_nodes,
|
835 |
+
target_edge_index=train_target_edges, target_edge_type=train_target_etypes, num_relations=num_train_rels*2)
|
836 |
+
valid_data = Data(edge_index=train_fact_index,
|
837 |
+
edge_type=train_fact_type,
|
838 |
+
num_nodes=valid_res["num_node"], # the only fix in this function
|
839 |
+
target_edge_index=inf_valid_edges[:, :2].T,
|
840 |
+
target_edge_type=inf_valid_edges[:, 2],
|
841 |
+
num_relations=inference_num_rels*2 if self.valid_on_inf else num_train_rels*2)
|
842 |
+
test_data = Data(edge_index=inf_edges, edge_type=inf_etypes, num_nodes=inference_num_nodes,
|
843 |
+
target_edge_index=inf_test_edges[:, :2].T, target_edge_type=inf_test_edges[:, 2], num_relations=inference_num_rels*2)
|
844 |
+
|
845 |
+
if self.pre_transform is not None:
|
846 |
+
train_data = self.pre_transform(train_data)
|
847 |
+
valid_data = self.pre_transform(valid_data)
|
848 |
+
test_data = self.pre_transform(test_data)
|
849 |
+
|
850 |
+
torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])
|
851 |
+
|
852 |
+
|
853 |
+
class MTDEAInductive(InductiveDataset):
|
854 |
+
|
855 |
+
valid_on_inf = False
|
856 |
+
url = "https://reltrans.s3.us-east-2.amazonaws.com/MTDEA_data.zip"
|
857 |
+
base_name = "mtdea"
|
858 |
+
|
859 |
+
def __init__(self, root, version, **kwargs):
|
860 |
+
|
861 |
+
assert version in self.versions, f"unknown version {version} for {self.name}, available: {self.versions}"
|
862 |
+
super().__init__(root, version, **kwargs)
|
863 |
+
|
864 |
+
@property
|
865 |
+
def raw_dir(self):
|
866 |
+
return os.path.join(self.root, self.base_name, self.name, self.version, "raw")
|
867 |
+
|
868 |
+
@property
|
869 |
+
def processed_dir(self):
|
870 |
+
return os.path.join(self.root, self.base_name, self.name, self.version, "processed")
|
871 |
+
|
872 |
+
@property
|
873 |
+
def raw_file_names(self):
|
874 |
+
return [
|
875 |
+
"transductive_train.txt", "inference_graph.txt", "transductive_valid.txt", "inf_test.txt"
|
876 |
+
]
|
877 |
+
|
878 |
+
def download(self):
|
879 |
+
base_path = os.path.join(self.root, self.base_name)
|
880 |
+
download_path = download_url(self.url, base_path)
|
881 |
+
extract_zip(download_path, base_path)
|
882 |
+
# unzip all datasets at once
|
883 |
+
for dsname in ['FBNELL', 'Metafam', 'WikiTopics-MT1', 'WikiTopics-MT2', 'WikiTopics-MT3', 'WikiTopics-MT4']:
|
884 |
+
cl = globals()[dsname.replace("-","")]
|
885 |
+
versions = cl.versions
|
886 |
+
for version in versions:
|
887 |
+
for oldname, newname in zip(['train.txt', 'observe.txt', 'valid.txt', 'test.txt'], self.raw_file_names):
|
888 |
+
foldername = cl.prefix % version + "-trans" if "transductive" in newname else cl.prefix % version + "-ind"
|
889 |
+
os.renames(
|
890 |
+
os.path.join(base_path, "MTDEA_datasets", dsname, foldername, oldname),
|
891 |
+
os.path.join(base_path, dsname, version, "raw", newname)
|
892 |
+
)
|
893 |
+
shutil.rmtree(os.path.join(base_path, "MTDEA_datasets"))
|
894 |
+
|
895 |
+
def load_file(self, triplet_file, inv_entity_vocab={}, inv_rel_vocab={}, limit_vocab=False):
|
896 |
+
|
897 |
+
triplets = []
|
898 |
+
entity_cnt, rel_cnt = len(inv_entity_vocab), len(inv_rel_vocab)
|
899 |
+
|
900 |
+
# limit_vocab is for dropping triples with unseen head/tail not seen in the main entity_vocab
|
901 |
+
# can be used for FBNELL and MT3:art, other datasets seem to be ok and share num_nodes/num_relations in the train/inference graph
|
902 |
+
with open(triplet_file, "r", encoding="utf-8") as fin:
|
903 |
+
for l in fin:
|
904 |
+
u, r, v = l.split() if self.delimiter is None else l.strip().split(self.delimiter)
|
905 |
+
if u not in inv_entity_vocab:
|
906 |
+
if limit_vocab:
|
907 |
+
continue
|
908 |
+
inv_entity_vocab[u] = entity_cnt
|
909 |
+
entity_cnt += 1
|
910 |
+
if v not in inv_entity_vocab:
|
911 |
+
if limit_vocab:
|
912 |
+
continue
|
913 |
+
inv_entity_vocab[v] = entity_cnt
|
914 |
+
entity_cnt += 1
|
915 |
+
if r not in inv_rel_vocab:
|
916 |
+
if limit_vocab:
|
917 |
+
continue
|
918 |
+
inv_rel_vocab[r] = rel_cnt
|
919 |
+
rel_cnt += 1
|
920 |
+
u, r, v = inv_entity_vocab[u], inv_rel_vocab[r], inv_entity_vocab[v]
|
921 |
+
|
922 |
+
triplets.append((u, v, r))
|
923 |
+
|
924 |
+
return {
|
925 |
+
"triplets": triplets,
|
926 |
+
"num_node": entity_cnt,
|
927 |
+
"num_relation": rel_cnt,
|
928 |
+
"inv_entity_vocab": inv_entity_vocab,
|
929 |
+
"inv_rel_vocab": inv_rel_vocab
|
930 |
+
}
|
931 |
+
|
932 |
+
# special processes for MTDEA datasets for one particular fix in the validation set loading
|
933 |
+
def process(self):
|
934 |
+
|
935 |
+
train_files = self.raw_paths[:4]
|
936 |
+
|
937 |
+
train_res = self.load_file(train_files[0], inv_entity_vocab={}, inv_rel_vocab={})
|
938 |
+
inference_res = self.load_file(train_files[1], inv_entity_vocab={}, inv_rel_vocab={})
|
939 |
+
valid_res = self.load_file(
|
940 |
+
train_files[2],
|
941 |
+
inference_res["inv_entity_vocab"] if self.valid_on_inf else train_res["inv_entity_vocab"],
|
942 |
+
inference_res["inv_rel_vocab"] if self.valid_on_inf else train_res["inv_rel_vocab"],
|
943 |
+
limit_vocab=True, # the 1st fix in this function compared to the superclass processor
|
944 |
+
)
|
945 |
+
test_res = self.load_file(train_files[3], inference_res["inv_entity_vocab"], inference_res["inv_rel_vocab"])
|
946 |
+
|
947 |
+
num_train_nodes, num_train_rels = train_res["num_node"], train_res["num_relation"]
|
948 |
+
inference_num_nodes, inference_num_rels = test_res["num_node"], test_res["num_relation"]
|
949 |
+
|
950 |
+
train_edges, inf_graph, inf_valid_edges, inf_test_edges = train_res["triplets"], inference_res["triplets"], valid_res["triplets"], test_res["triplets"]
|
951 |
+
|
952 |
+
train_target_edges = torch.tensor([[t[0], t[1]] for t in train_edges], dtype=torch.long).t()
|
953 |
+
train_target_etypes = torch.tensor([t[2] for t in train_edges])
|
954 |
+
|
955 |
+
train_fact_index = torch.cat([train_target_edges, train_target_edges.flip(0)], dim=1)
|
956 |
+
train_fact_type = torch.cat([train_target_etypes, train_target_etypes + num_train_rels])
|
957 |
+
|
958 |
+
inf_edges = torch.tensor([[t[0], t[1]] for t in inf_graph], dtype=torch.long).t()
|
959 |
+
inf_edges = torch.cat([inf_edges, inf_edges.flip(0)], dim=1)
|
960 |
+
inf_etypes = torch.tensor([t[2] for t in inf_graph])
|
961 |
+
inf_etypes = torch.cat([inf_etypes, inf_etypes + inference_num_rels])
|
962 |
+
|
963 |
+
inf_valid_edges = torch.tensor(inf_valid_edges, dtype=torch.long)
|
964 |
+
inf_test_edges = torch.tensor(inf_test_edges, dtype=torch.long)
|
965 |
+
|
966 |
+
train_data = Data(edge_index=train_fact_index, edge_type=train_fact_type, num_nodes=num_train_nodes,
|
967 |
+
target_edge_index=train_target_edges, target_edge_type=train_target_etypes, num_relations=num_train_rels*2)
|
968 |
+
valid_data = Data(edge_index=train_fact_index,
|
969 |
+
edge_type=train_fact_type,
|
970 |
+
num_nodes=valid_res["num_node"], # the 2nd fix in this function
|
971 |
+
target_edge_index=inf_valid_edges[:, :2].T,
|
972 |
+
target_edge_type=inf_valid_edges[:, 2],
|
973 |
+
num_relations=inference_num_rels*2 if self.valid_on_inf else num_train_rels*2)
|
974 |
+
test_data = Data(edge_index=inf_edges, edge_type=inf_etypes, num_nodes=inference_num_nodes,
|
975 |
+
target_edge_index=inf_test_edges[:, :2].T, target_edge_type=inf_test_edges[:, 2], num_relations=inference_num_rels*2)
|
976 |
+
|
977 |
+
if self.pre_transform is not None:
|
978 |
+
train_data = self.pre_transform(train_data)
|
979 |
+
valid_data = self.pre_transform(valid_data)
|
980 |
+
test_data = self.pre_transform(test_data)
|
981 |
+
|
982 |
+
torch.save((self.collate([train_data, valid_data, test_data])), self.processed_paths[0])
|
983 |
+
|
984 |
+
|
985 |
+
class FBNELL(MTDEAInductive):
|
986 |
+
|
987 |
+
name = "FBNELL"
|
988 |
+
prefix = "%s"
|
989 |
+
versions = ["FBNELL_v1"]
|
990 |
+
|
991 |
+
def __init__(self, **kwargs):
|
992 |
+
kwargs.pop("version")
|
993 |
+
kwargs['version'] = self.versions[0]
|
994 |
+
super(FBNELL, self).__init__(**kwargs)
|
995 |
+
|
996 |
+
|
997 |
+
class Metafam(MTDEAInductive):
|
998 |
+
|
999 |
+
name = "Metafam"
|
1000 |
+
prefix = "%s"
|
1001 |
+
versions = ["Metafam"]
|
1002 |
+
|
1003 |
+
def __init__(self, **kwargs):
|
1004 |
+
kwargs.pop("version")
|
1005 |
+
kwargs['version'] = self.versions[0]
|
1006 |
+
super(Metafam, self).__init__(**kwargs)
|
1007 |
+
|
1008 |
+
|
1009 |
+
class WikiTopicsMT1(MTDEAInductive):
|
1010 |
+
|
1011 |
+
name = "WikiTopics-MT1"
|
1012 |
+
prefix = "wikidata_%sv1"
|
1013 |
+
versions = ['mt', 'health', 'tax']
|
1014 |
+
|
1015 |
+
def __init__(self, **kwargs):
|
1016 |
+
assert kwargs['version'] in self.versions, f"unknown version {kwargs['version']}, available: {self.versions}"
|
1017 |
+
super(WikiTopicsMT1, self).__init__(**kwargs)
|
1018 |
+
|
1019 |
+
|
1020 |
+
class WikiTopicsMT2(MTDEAInductive):
|
1021 |
+
|
1022 |
+
name = "WikiTopics-MT2"
|
1023 |
+
prefix = "wikidata_%sv1"
|
1024 |
+
versions = ['mt2', 'org', 'sci']
|
1025 |
+
|
1026 |
+
def __init__(self, **kwargs):
|
1027 |
+
super(WikiTopicsMT2, self).__init__(**kwargs)
|
1028 |
+
|
1029 |
+
|
1030 |
+
class WikiTopicsMT3(MTDEAInductive):
|
1031 |
+
|
1032 |
+
name = "WikiTopics-MT3"
|
1033 |
+
prefix = "wikidata_%sv2"
|
1034 |
+
versions = ['mt3', 'art', 'infra']
|
1035 |
+
|
1036 |
+
def __init__(self, **kwargs):
|
1037 |
+
super(WikiTopicsMT3, self).__init__(**kwargs)
|
1038 |
+
|
1039 |
+
|
1040 |
+
class WikiTopicsMT4(MTDEAInductive):
|
1041 |
+
|
1042 |
+
name = "WikiTopics-MT4"
|
1043 |
+
prefix = "wikidata_%sv2"
|
1044 |
+
versions = ['mt4', 'sci', 'health']
|
1045 |
+
|
1046 |
+
def __init__(self, **kwargs):
|
1047 |
+
super(WikiTopicsMT4, self).__init__(**kwargs)
|
1048 |
+
|
1049 |
+
|
1050 |
+
# a joint dataset for pre-training ULTRA on several graphs
|
1051 |
+
class JointDataset(InMemoryDataset):
|
1052 |
+
|
1053 |
+
datasets_map = {
|
1054 |
+
'FB15k237': FB15k237,
|
1055 |
+
'WN18RR': WN18RR,
|
1056 |
+
'CoDExSmall': CoDExSmall,
|
1057 |
+
'CoDExMedium': CoDExMedium,
|
1058 |
+
'CoDExLarge': CoDExLarge,
|
1059 |
+
'NELL995': NELL995,
|
1060 |
+
'ConceptNet100k': ConceptNet100k,
|
1061 |
+
'DBpedia100k': DBpedia100k,
|
1062 |
+
'YAGO310': YAGO310,
|
1063 |
+
'AristoV4': AristoV4,
|
1064 |
+
}
|
1065 |
+
|
1066 |
+
def __init__(self, root, graphs, transform=None, pre_transform=None):
|
1067 |
+
|
1068 |
+
|
1069 |
+
self.graphs = [self.datasets_map[ds](root=root) for ds in graphs]
|
1070 |
+
self.num_graphs = len(graphs)
|
1071 |
+
super().__init__(root, transform, pre_transform)
|
1072 |
+
self.data = torch.load(self.processed_paths[0])
|
1073 |
+
|
1074 |
+
@property
|
1075 |
+
def raw_dir(self):
|
1076 |
+
return os.path.join(self.root, "joint", f'{self.num_graphs}g', "raw")
|
1077 |
+
|
1078 |
+
@property
|
1079 |
+
def processed_dir(self):
|
1080 |
+
return os.path.join(self.root, "joint", f'{self.num_graphs}g', "processed")
|
1081 |
+
|
1082 |
+
@property
|
1083 |
+
def processed_file_names(self):
|
1084 |
+
return "data.pt"
|
1085 |
+
|
1086 |
+
def process(self):
|
1087 |
+
|
1088 |
+
train_data = [g[0] for g in self.graphs]
|
1089 |
+
valid_data = [g[1] for g in self.graphs]
|
1090 |
+
test_data = [g[2] for g in self.graphs]
|
1091 |
+
# filter_data = [
|
1092 |
+
# Data(edge_index=g.data.target_edge_index, edge_type=g.data.target_edge_type, num_nodes=g[0].num_nodes) for g in self.graphs
|
1093 |
+
# ]
|
1094 |
+
|
1095 |
+
torch.save((train_data, valid_data, test_data), self.processed_paths[0])
|
ultra/eval.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import distributed as dist
|
5 |
+
from torch.utils import data as torch_data
|
6 |
+
from torch_geometric.data import Data
|
7 |
+
|
8 |
+
from ultra import tasks, util
|
9 |
+
|
10 |
+
|
11 |
+
TRANSDUCTIVE = ("WordNet18RR", "RelLinkPredDataset", "CoDExSmall", "CoDExMedium", "CoDExLarge",
|
12 |
+
"YAGO310", "NELL995", "ConceptNet100k", "DBpedia100k", "Hetionet", "AristoV4",
|
13 |
+
"WDsinger", "NELL23k", "FB15k237_10", "FB15k237_20", "FB15k237_50")
|
14 |
+
|
15 |
+
|
16 |
+
def get_filtered_data(dataset, mode):
|
17 |
+
train_data, valid_data, test_data = dataset[0], dataset[1], dataset[2]
|
18 |
+
ds_name = dataset.__class__.__name__
|
19 |
+
|
20 |
+
if ds_name in TRANSDUCTIVE:
|
21 |
+
filtered_data = Data(edge_index=dataset._data.target_edge_index, edge_type=dataset._data.target_edge_type, num_nodes=dataset[0].num_nodes)
|
22 |
+
else:
|
23 |
+
if "ILPC" in ds_name or "Ingram" in ds_name:
|
24 |
+
full_inference_edges = torch.cat([valid_data.edge_index, valid_data.target_edge_index, test_data.target_edge_index], dim=1)
|
25 |
+
full_inference_etypes = torch.cat([valid_data.edge_type, valid_data.target_edge_type, test_data.target_edge_type])
|
26 |
+
filtered_data = Data(edge_index=full_inference_edges, edge_type=full_inference_etypes, num_nodes=test_data.num_nodes)
|
27 |
+
else:
|
28 |
+
# test filtering graph: inference edges + test edges
|
29 |
+
full_inference_edges = torch.cat([test_data.edge_index, test_data.target_edge_index], dim=1)
|
30 |
+
full_inference_etypes = torch.cat([test_data.edge_type, test_data.target_edge_type])
|
31 |
+
if mode == "test":
|
32 |
+
filtered_data = Data(edge_index=full_inference_edges, edge_type=full_inference_etypes, num_nodes=test_data.num_nodes)
|
33 |
+
else:
|
34 |
+
# validation filtering graph: train edges + validation edges
|
35 |
+
filtered_data = Data(
|
36 |
+
edge_index=torch.cat([train_data.edge_index, valid_data.target_edge_index], dim=1),
|
37 |
+
edge_type=torch.cat([train_data.edge_type, valid_data.target_edge_type])
|
38 |
+
)
|
39 |
+
|
40 |
+
return filtered_data
|
41 |
+
|
42 |
+
|
43 |
+
@torch.no_grad()
|
44 |
+
def test(model, mode, dataset, batch_size=32, eval_metrics=["mrr", "hits@10"], gpus=None, return_metrics=False):
|
45 |
+
logger = util.get_root_logger()
|
46 |
+
test_data = dataset[1] if mode == "valid" else dataset[2]
|
47 |
+
filtered_data = get_filtered_data(dataset, mode)
|
48 |
+
|
49 |
+
device = util.get_devices(gpus)
|
50 |
+
world_size = util.get_world_size()
|
51 |
+
rank = util.get_rank()
|
52 |
+
|
53 |
+
test_triplets = torch.cat([test_data.target_edge_index, test_data.target_edge_type.unsqueeze(0)]).t()
|
54 |
+
sampler = torch_data.DistributedSampler(test_triplets, world_size, rank)
|
55 |
+
test_loader = torch_data.DataLoader(test_triplets, batch_size, sampler=sampler)
|
56 |
+
|
57 |
+
model.eval()
|
58 |
+
rankings = []
|
59 |
+
num_negatives = []
|
60 |
+
tail_rankings, num_tail_negs = [], [] # for explicit tail-only evaluation needed for 5 datasets
|
61 |
+
for batch in test_loader:
|
62 |
+
t_batch, h_batch = tasks.all_negative(test_data, batch)
|
63 |
+
t_pred = model(test_data, t_batch)
|
64 |
+
h_pred = model(test_data, h_batch)
|
65 |
+
|
66 |
+
if filtered_data is None:
|
67 |
+
t_mask, h_mask = tasks.strict_negative_mask(test_data, batch)
|
68 |
+
else:
|
69 |
+
t_mask, h_mask = tasks.strict_negative_mask(filtered_data, batch)
|
70 |
+
pos_h_index, pos_t_index, pos_r_index = batch.t()
|
71 |
+
t_ranking = tasks.compute_ranking(t_pred, pos_t_index, t_mask)
|
72 |
+
h_ranking = tasks.compute_ranking(h_pred, pos_h_index, h_mask)
|
73 |
+
num_t_negative = t_mask.sum(dim=-1)
|
74 |
+
num_h_negative = h_mask.sum(dim=-1)
|
75 |
+
|
76 |
+
rankings += [t_ranking, h_ranking]
|
77 |
+
num_negatives += [num_t_negative, num_h_negative]
|
78 |
+
|
79 |
+
tail_rankings += [t_ranking]
|
80 |
+
num_tail_negs += [num_t_negative]
|
81 |
+
|
82 |
+
ranking = torch.cat(rankings)
|
83 |
+
num_negative = torch.cat(num_negatives)
|
84 |
+
all_size = torch.zeros(world_size, dtype=torch.long, device=device)
|
85 |
+
all_size[rank] = len(ranking)
|
86 |
+
|
87 |
+
# ugly repetitive code for tail-only ranks processing
|
88 |
+
tail_ranking = torch.cat(tail_rankings)
|
89 |
+
num_tail_neg = torch.cat(num_tail_negs)
|
90 |
+
all_size_t = torch.zeros(world_size, dtype=torch.long, device=device)
|
91 |
+
all_size_t[rank] = len(tail_ranking)
|
92 |
+
if world_size > 1:
|
93 |
+
dist.all_reduce(all_size, op=dist.ReduceOp.SUM)
|
94 |
+
dist.all_reduce(all_size_t, op=dist.ReduceOp.SUM)
|
95 |
+
|
96 |
+
# obtaining all ranks
|
97 |
+
cum_size = all_size.cumsum(0)
|
98 |
+
all_ranking = torch.zeros(all_size.sum(), dtype=torch.long, device=device)
|
99 |
+
all_ranking[cum_size[rank] - all_size[rank]: cum_size[rank]] = ranking
|
100 |
+
all_num_negative = torch.zeros(all_size.sum(), dtype=torch.long, device=device)
|
101 |
+
all_num_negative[cum_size[rank] - all_size[rank]: cum_size[rank]] = num_negative
|
102 |
+
|
103 |
+
# the same for tails-only ranks
|
104 |
+
cum_size_t = all_size_t.cumsum(0)
|
105 |
+
all_ranking_t = torch.zeros(all_size_t.sum(), dtype=torch.long, device=device)
|
106 |
+
all_ranking_t[cum_size_t[rank] - all_size_t[rank]: cum_size_t[rank]] = tail_ranking
|
107 |
+
all_num_negative_t = torch.zeros(all_size_t.sum(), dtype=torch.long, device=device)
|
108 |
+
all_num_negative_t[cum_size_t[rank] - all_size_t[rank]: cum_size_t[rank]] = num_tail_neg
|
109 |
+
if world_size > 1:
|
110 |
+
dist.all_reduce(all_ranking, op=dist.ReduceOp.SUM)
|
111 |
+
dist.all_reduce(all_num_negative, op=dist.ReduceOp.SUM)
|
112 |
+
dist.all_reduce(all_ranking_t, op=dist.ReduceOp.SUM)
|
113 |
+
dist.all_reduce(all_num_negative_t, op=dist.ReduceOp.SUM)
|
114 |
+
|
115 |
+
metrics = {}
|
116 |
+
if rank == 0:
|
117 |
+
for metric in eval_metrics:
|
118 |
+
if "-tail" in metric:
|
119 |
+
_metric_name, direction = metric.split("-")
|
120 |
+
if direction != "tail":
|
121 |
+
raise ValueError("Only tail metric is supported in this mode")
|
122 |
+
_ranking = all_ranking_t
|
123 |
+
_num_neg = all_num_negative_t
|
124 |
+
else:
|
125 |
+
_ranking = all_ranking
|
126 |
+
_num_neg = all_num_negative
|
127 |
+
_metric_name = metric
|
128 |
+
|
129 |
+
if _metric_name == "mr":
|
130 |
+
score = _ranking.float().mean()
|
131 |
+
elif _metric_name == "mrr":
|
132 |
+
score = (1 / _ranking.float()).mean()
|
133 |
+
elif _metric_name.startswith("hits@"):
|
134 |
+
values = _metric_name[5:].split("_")
|
135 |
+
threshold = int(values[0])
|
136 |
+
if len(values) > 1:
|
137 |
+
num_sample = int(values[1])
|
138 |
+
# unbiased estimation
|
139 |
+
fp_rate = (_ranking - 1).float() / _num_neg
|
140 |
+
score = 0
|
141 |
+
for i in range(threshold):
|
142 |
+
# choose i false positive from num_sample - 1 negatives
|
143 |
+
num_comb = math.factorial(num_sample - 1) / \
|
144 |
+
math.factorial(i) / math.factorial(num_sample - i - 1)
|
145 |
+
score += num_comb * (fp_rate ** i) * ((1 - fp_rate) ** (num_sample - i - 1))
|
146 |
+
score = score.mean()
|
147 |
+
else:
|
148 |
+
score = (_ranking <= threshold).float().mean()
|
149 |
+
logger.warning("%s: %g" % (metric, score))
|
150 |
+
metrics[metric] = score
|
151 |
+
mrr = (1 / all_ranking.float()).mean()
|
152 |
+
|
153 |
+
return mrr if not return_metrics else metrics
|
ultra/layers.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
from torch_scatter import scatter
|
5 |
+
|
6 |
+
from torch_geometric.nn.conv import MessagePassing
|
7 |
+
from torch_geometric.utils import degree
|
8 |
+
from typing import Tuple
|
9 |
+
|
10 |
+
|
11 |
+
class GeneralizedRelationalConv(MessagePassing):
|
12 |
+
|
13 |
+
eps = 1e-6
|
14 |
+
|
15 |
+
message2mul = {
|
16 |
+
"transe": "add",
|
17 |
+
"distmult": "mul",
|
18 |
+
}
|
19 |
+
|
20 |
+
# TODO for compile() - doesn't work currently
|
21 |
+
# propagate_type = {"edge_index": torch.LongTensor, "size": Tuple[int, int]}
|
22 |
+
|
23 |
+
def __init__(self, input_dim, output_dim, num_relation, query_input_dim, message_func="distmult",
|
24 |
+
aggregate_func="pna", layer_norm=False, activation="relu", dependent=False, project_relations=False):
|
25 |
+
super(GeneralizedRelationalConv, self).__init__()
|
26 |
+
self.input_dim = input_dim
|
27 |
+
self.output_dim = output_dim
|
28 |
+
self.num_relation = num_relation
|
29 |
+
self.query_input_dim = query_input_dim
|
30 |
+
self.message_func = message_func
|
31 |
+
self.aggregate_func = aggregate_func
|
32 |
+
self.dependent = dependent
|
33 |
+
self.project_relations = project_relations
|
34 |
+
|
35 |
+
if layer_norm:
|
36 |
+
self.layer_norm = nn.LayerNorm(output_dim)
|
37 |
+
else:
|
38 |
+
self.layer_norm = None
|
39 |
+
if isinstance(activation, str):
|
40 |
+
self.activation = getattr(F, activation)
|
41 |
+
else:
|
42 |
+
self.activation = activation
|
43 |
+
|
44 |
+
if self.aggregate_func == "pna":
|
45 |
+
self.linear = nn.Linear(input_dim * 13, output_dim)
|
46 |
+
else:
|
47 |
+
self.linear = nn.Linear(input_dim * 2, output_dim)
|
48 |
+
|
49 |
+
if dependent:
|
50 |
+
# obtain relation embeddings as a projection of the query relation
|
51 |
+
self.relation_linear = nn.Linear(query_input_dim, num_relation * input_dim)
|
52 |
+
else:
|
53 |
+
if not self.project_relations:
|
54 |
+
# relation embeddings as an independent embedding matrix per each layer
|
55 |
+
self.relation = nn.Embedding(num_relation, input_dim)
|
56 |
+
else:
|
57 |
+
# will be initialized after the pass over relation graph
|
58 |
+
self.relation = None
|
59 |
+
self.relation_projection = nn.Sequential(
|
60 |
+
nn.Linear(input_dim, input_dim),
|
61 |
+
nn.ReLU(),
|
62 |
+
nn.Linear(input_dim, input_dim)
|
63 |
+
)
|
64 |
+
|
65 |
+
|
66 |
+
def forward(self, input, query, boundary, edge_index, edge_type, size, edge_weight=None):
|
67 |
+
batch_size = len(query)
|
68 |
+
|
69 |
+
if self.dependent:
|
70 |
+
# layer-specific relation features as a projection of input "query" (relation) embeddings
|
71 |
+
relation = self.relation_linear(query).view(batch_size, self.num_relation, self.input_dim)
|
72 |
+
else:
|
73 |
+
if not self.project_relations:
|
74 |
+
# layer-specific relation features as a special embedding matrix unique to each layer
|
75 |
+
relation = self.relation.weight.expand(batch_size, -1, -1)
|
76 |
+
else:
|
77 |
+
# NEW and only change:
|
78 |
+
# projecting relation features to unique features for this layer, then resizing for the current batch
|
79 |
+
relation = self.relation_projection(self.relation)
|
80 |
+
if edge_weight is None:
|
81 |
+
edge_weight = torch.ones(len(edge_type), device=input.device)
|
82 |
+
|
83 |
+
# note that we send the initial boundary condition (node states at layer0) to the message passing
|
84 |
+
# correspond to Eq.6 on p5 in https://arxiv.org/pdf/2106.06935.pdf
|
85 |
+
output = self.propagate(input=input, relation=relation, boundary=boundary, edge_index=edge_index,
|
86 |
+
edge_type=edge_type, size=size, edge_weight=edge_weight)
|
87 |
+
return output
|
88 |
+
|
89 |
+
def propagate(self, edge_index, size=None, **kwargs):
|
90 |
+
if kwargs["edge_weight"].requires_grad or self.message_func == "rotate":
|
91 |
+
# the rspmm cuda kernel only works for TransE and DistMult message functions
|
92 |
+
# otherwise we invoke separate message & aggregate functions
|
93 |
+
return super(GeneralizedRelationalConv, self).propagate(edge_index, size, **kwargs)
|
94 |
+
|
95 |
+
for hook in self._propagate_forward_pre_hooks.values():
|
96 |
+
res = hook(self, (edge_index, size, kwargs))
|
97 |
+
if res is not None:
|
98 |
+
edge_index, size, kwargs = res
|
99 |
+
|
100 |
+
# in newer PyG,
|
101 |
+
# __check_input__ -> _check_input()
|
102 |
+
# __collect__ -> _collect()
|
103 |
+
# __fused_user_args__ -> _fuser_user_args
|
104 |
+
size = self._check_input(edge_index, size)
|
105 |
+
coll_dict = self._collect(self._fused_user_args, edge_index, size, kwargs)
|
106 |
+
|
107 |
+
msg_aggr_kwargs = self.inspector.distribute("message_and_aggregate", coll_dict)
|
108 |
+
for hook in self._message_and_aggregate_forward_pre_hooks.values():
|
109 |
+
res = hook(self, (edge_index, msg_aggr_kwargs))
|
110 |
+
if res is not None:
|
111 |
+
edge_index, msg_aggr_kwargs = res
|
112 |
+
out = self.message_and_aggregate(edge_index, **msg_aggr_kwargs)
|
113 |
+
for hook in self._message_and_aggregate_forward_hooks.values():
|
114 |
+
res = hook(self, (edge_index, msg_aggr_kwargs), out)
|
115 |
+
if res is not None:
|
116 |
+
out = res
|
117 |
+
|
118 |
+
update_kwargs = self.inspector.distribute("update", coll_dict)
|
119 |
+
out = self.update(out, **update_kwargs)
|
120 |
+
|
121 |
+
for hook in self._propagate_forward_hooks.values():
|
122 |
+
res = hook(self, (edge_index, size, kwargs), out)
|
123 |
+
if res is not None:
|
124 |
+
out = res
|
125 |
+
|
126 |
+
return out
|
127 |
+
|
128 |
+
def message(self, input_j, relation, boundary, edge_type):
|
129 |
+
relation_j = relation.index_select(self.node_dim, edge_type)
|
130 |
+
|
131 |
+
if self.message_func == "transe":
|
132 |
+
message = input_j + relation_j
|
133 |
+
elif self.message_func == "distmult":
|
134 |
+
message = input_j * relation_j
|
135 |
+
elif self.message_func == "rotate":
|
136 |
+
x_j_re, x_j_im = input_j.chunk(2, dim=-1)
|
137 |
+
r_j_re, r_j_im = relation_j.chunk(2, dim=-1)
|
138 |
+
message_re = x_j_re * r_j_re - x_j_im * r_j_im
|
139 |
+
message_im = x_j_re * r_j_im + x_j_im * r_j_re
|
140 |
+
message = torch.cat([message_re, message_im], dim=-1)
|
141 |
+
else:
|
142 |
+
raise ValueError("Unknown message function `%s`" % self.message_func)
|
143 |
+
|
144 |
+
# augment messages with the boundary condition
|
145 |
+
message = torch.cat([message, boundary], dim=self.node_dim) # (num_edges + num_nodes, batch_size, input_dim)
|
146 |
+
|
147 |
+
return message
|
148 |
+
|
149 |
+
def aggregate(self, input, edge_weight, index, dim_size):
|
150 |
+
# augment aggregation index with self-loops for the boundary condition
|
151 |
+
index = torch.cat([index, torch.arange(dim_size, device=input.device)]) # (num_edges + num_nodes,)
|
152 |
+
edge_weight = torch.cat([edge_weight, torch.ones(dim_size, device=input.device)])
|
153 |
+
shape = [1] * input.ndim
|
154 |
+
shape[self.node_dim] = -1
|
155 |
+
edge_weight = edge_weight.view(shape)
|
156 |
+
|
157 |
+
if self.aggregate_func == "pna":
|
158 |
+
mean = scatter(input * edge_weight, index, dim=self.node_dim, dim_size=dim_size, reduce="mean")
|
159 |
+
sq_mean = scatter(input ** 2 * edge_weight, index, dim=self.node_dim, dim_size=dim_size, reduce="mean")
|
160 |
+
max = scatter(input * edge_weight, index, dim=self.node_dim, dim_size=dim_size, reduce="max")
|
161 |
+
min = scatter(input * edge_weight, index, dim=self.node_dim, dim_size=dim_size, reduce="min")
|
162 |
+
std = (sq_mean - mean ** 2).clamp(min=self.eps).sqrt()
|
163 |
+
features = torch.cat([mean.unsqueeze(-1), max.unsqueeze(-1), min.unsqueeze(-1), std.unsqueeze(-1)], dim=-1)
|
164 |
+
features = features.flatten(-2)
|
165 |
+
degree_out = degree(index, dim_size).unsqueeze(0).unsqueeze(-1)
|
166 |
+
scale = degree_out.log()
|
167 |
+
scale = scale / scale.mean()
|
168 |
+
scales = torch.cat([torch.ones_like(scale), scale, 1 / scale.clamp(min=1e-2)], dim=-1)
|
169 |
+
output = (features.unsqueeze(-1) * scales.unsqueeze(-2)).flatten(-2)
|
170 |
+
else:
|
171 |
+
output = scatter(input * edge_weight, index, dim=self.node_dim, dim_size=dim_size,
|
172 |
+
reduce=self.aggregate_func)
|
173 |
+
|
174 |
+
return output
|
175 |
+
|
176 |
+
def message_and_aggregate(self, edge_index, input, relation, boundary, edge_type, edge_weight, index, dim_size):
|
177 |
+
# fused computation of message and aggregate steps with the custom rspmm cuda kernel
|
178 |
+
# speed up computation by several times
|
179 |
+
# reduce memory complexity from O(|E|d) to O(|V|d), so we can apply it to larger graphs
|
180 |
+
from ultra.rspmm.rspmm import generalized_rspmm
|
181 |
+
|
182 |
+
batch_size, num_node = input.shape[:2]
|
183 |
+
input = input.transpose(0, 1).flatten(1)
|
184 |
+
relation = relation.transpose(0, 1).flatten(1)
|
185 |
+
boundary = boundary.transpose(0, 1).flatten(1)
|
186 |
+
degree_out = degree(index, dim_size).unsqueeze(-1) + 1
|
187 |
+
|
188 |
+
if self.message_func in self.message2mul:
|
189 |
+
mul = self.message2mul[self.message_func]
|
190 |
+
else:
|
191 |
+
raise ValueError("Unknown message function `%s`" % self.message_func)
|
192 |
+
if self.aggregate_func == "sum":
|
193 |
+
update = generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="add", mul=mul)
|
194 |
+
update = update + boundary
|
195 |
+
elif self.aggregate_func == "mean":
|
196 |
+
update = generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="add", mul=mul)
|
197 |
+
update = (update + boundary) / degree_out
|
198 |
+
elif self.aggregate_func == "max":
|
199 |
+
update = generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="max", mul=mul)
|
200 |
+
update = torch.max(update, boundary)
|
201 |
+
elif self.aggregate_func == "pna":
|
202 |
+
# we use PNA with 4 aggregators (mean / max / min / std)
|
203 |
+
# and 3 scalars (identity / log degree / reciprocal of log degree)
|
204 |
+
sum = generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="add", mul=mul)
|
205 |
+
sq_sum = generalized_rspmm(edge_index, edge_type, edge_weight, relation ** 2, input ** 2, sum="add",
|
206 |
+
mul=mul)
|
207 |
+
max = generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="max", mul=mul)
|
208 |
+
min = generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="min", mul=mul)
|
209 |
+
mean = (sum + boundary) / degree_out
|
210 |
+
sq_mean = (sq_sum + boundary ** 2) / degree_out
|
211 |
+
max = torch.max(max, boundary)
|
212 |
+
min = torch.min(min, boundary) # (node, batch_size * input_dim)
|
213 |
+
std = (sq_mean - mean ** 2).clamp(min=self.eps).sqrt()
|
214 |
+
features = torch.cat([mean.unsqueeze(-1), max.unsqueeze(-1), min.unsqueeze(-1), std.unsqueeze(-1)], dim=-1)
|
215 |
+
features = features.flatten(-2) # (node, batch_size * input_dim * 4)
|
216 |
+
scale = degree_out.log()
|
217 |
+
scale = scale / scale.mean()
|
218 |
+
scales = torch.cat([torch.ones_like(scale), scale, 1 / scale.clamp(min=1e-2)], dim=-1) # (node, 3)
|
219 |
+
update = (features.unsqueeze(-1) * scales.unsqueeze(-2)).flatten(-2) # (node, batch_size * input_dim * 4 * 3)
|
220 |
+
else:
|
221 |
+
raise ValueError("Unknown aggregation function `%s`" % self.aggregate_func)
|
222 |
+
|
223 |
+
update = update.view(num_node, batch_size, -1).transpose(0, 1)
|
224 |
+
return update
|
225 |
+
|
226 |
+
def update(self, update, input):
|
227 |
+
# node update as a function of old states (input) and this layer output (update)
|
228 |
+
output = self.linear(torch.cat([input, update], dim=-1))
|
229 |
+
if self.layer_norm:
|
230 |
+
output = self.layer_norm(output)
|
231 |
+
if self.activation:
|
232 |
+
output = self.activation(output)
|
233 |
+
return output
|
234 |
+
|
ultra/models.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
from . import tasks, layers
|
5 |
+
from ultra.base_nbfnet import BaseNBFNet
|
6 |
+
|
7 |
+
class Ultra(nn.Module):
|
8 |
+
|
9 |
+
def __init__(self, rel_model_cfg, entity_model_cfg):
|
10 |
+
# kept that because super Ultra sounds cool
|
11 |
+
super(Ultra, self).__init__()
|
12 |
+
|
13 |
+
self.relation_model = RelNBFNet(**rel_model_cfg)
|
14 |
+
self.entity_model = EntityNBFNet(**entity_model_cfg)
|
15 |
+
|
16 |
+
|
17 |
+
def forward(self, data, batch):
|
18 |
+
|
19 |
+
# batch shape: (bs, 1+num_negs, 3)
|
20 |
+
# relations are the same all positive and negative triples, so we can extract only one from the first triple among 1+nug_negs
|
21 |
+
query_rels = batch[:, 0, 2]
|
22 |
+
relation_representations = self.relation_model(data.relation_graph, query=query_rels)
|
23 |
+
score = self.entity_model(data, relation_representations, batch)
|
24 |
+
|
25 |
+
return score
|
26 |
+
|
27 |
+
|
28 |
+
# NBFNet to work on the graph of relations with 4 fundamental interactions
|
29 |
+
# Doesn't have the final projection MLP from hidden dim -> 1, returns all node representations
|
30 |
+
# of shape [bs, num_rel, hidden]
|
31 |
+
class RelNBFNet(BaseNBFNet):
|
32 |
+
|
33 |
+
def __init__(self, input_dim, hidden_dims, num_relation=4, **kwargs):
|
34 |
+
super().__init__(input_dim, hidden_dims, num_relation, **kwargs)
|
35 |
+
|
36 |
+
self.layers = nn.ModuleList()
|
37 |
+
for i in range(len(self.dims) - 1):
|
38 |
+
self.layers.append(
|
39 |
+
layers.GeneralizedRelationalConv(
|
40 |
+
self.dims[i], self.dims[i + 1], num_relation,
|
41 |
+
self.dims[0], self.message_func, self.aggregate_func, self.layer_norm,
|
42 |
+
self.activation, dependent=False)
|
43 |
+
)
|
44 |
+
|
45 |
+
if self.concat_hidden:
|
46 |
+
feature_dim = sum(hidden_dims) + input_dim
|
47 |
+
self.mlp = nn.Sequential(
|
48 |
+
nn.Linear(feature_dim, feature_dim),
|
49 |
+
nn.ReLU(),
|
50 |
+
nn.Linear(feature_dim, input_dim)
|
51 |
+
)
|
52 |
+
|
53 |
+
|
54 |
+
def bellmanford(self, data, h_index, separate_grad=False):
|
55 |
+
batch_size = len(h_index)
|
56 |
+
|
57 |
+
# initialize initial nodes (relations of interest in the batcj) with all ones
|
58 |
+
query = torch.ones(h_index.shape[0], self.dims[0], device=h_index.device, dtype=torch.float)
|
59 |
+
index = h_index.unsqueeze(-1).expand_as(query)
|
60 |
+
|
61 |
+
# initial (boundary) condition - initialize all node states as zeros
|
62 |
+
boundary = torch.zeros(batch_size, data.num_nodes, self.dims[0], device=h_index.device)
|
63 |
+
#boundary = torch.zeros(data.num_nodes, *query.shape, device=h_index.device)
|
64 |
+
# Indicator function: by the scatter operation we put ones as init features of source (index) nodes
|
65 |
+
boundary.scatter_add_(1, index.unsqueeze(1), query.unsqueeze(1))
|
66 |
+
size = (data.num_nodes, data.num_nodes)
|
67 |
+
edge_weight = torch.ones(data.num_edges, device=h_index.device)
|
68 |
+
|
69 |
+
hiddens = []
|
70 |
+
edge_weights = []
|
71 |
+
layer_input = boundary
|
72 |
+
|
73 |
+
for layer in self.layers:
|
74 |
+
# Bellman-Ford iteration, we send the original boundary condition in addition to the updated node states
|
75 |
+
hidden = layer(layer_input, query, boundary, data.edge_index, data.edge_type, size, edge_weight)
|
76 |
+
if self.short_cut and hidden.shape == layer_input.shape:
|
77 |
+
# residual connection here
|
78 |
+
hidden = hidden + layer_input
|
79 |
+
hiddens.append(hidden)
|
80 |
+
edge_weights.append(edge_weight)
|
81 |
+
layer_input = hidden
|
82 |
+
|
83 |
+
# original query (relation type) embeddings
|
84 |
+
node_query = query.unsqueeze(1).expand(-1, data.num_nodes, -1) # (batch_size, num_nodes, input_dim)
|
85 |
+
if self.concat_hidden:
|
86 |
+
output = torch.cat(hiddens + [node_query], dim=-1)
|
87 |
+
output = self.mlp(output)
|
88 |
+
else:
|
89 |
+
output = hiddens[-1]
|
90 |
+
|
91 |
+
return {
|
92 |
+
"node_feature": output,
|
93 |
+
"edge_weights": edge_weights,
|
94 |
+
}
|
95 |
+
|
96 |
+
def forward(self, rel_graph, query):
|
97 |
+
|
98 |
+
# message passing and updated node representations (that are in fact relations)
|
99 |
+
output = self.bellmanford(rel_graph, h_index=query)["node_feature"] # (batch_size, num_nodes, hidden_dim)
|
100 |
+
|
101 |
+
return output
|
102 |
+
|
103 |
+
|
104 |
+
class EntityNBFNet(BaseNBFNet):
|
105 |
+
|
106 |
+
def __init__(self, input_dim, hidden_dims, num_relation=1, **kwargs):
|
107 |
+
|
108 |
+
# dummy num_relation = 1 as we won't use it in the NBFNet layer
|
109 |
+
super().__init__(input_dim, hidden_dims, num_relation, **kwargs)
|
110 |
+
|
111 |
+
self.layers = nn.ModuleList()
|
112 |
+
for i in range(len(self.dims) - 1):
|
113 |
+
self.layers.append(
|
114 |
+
layers.GeneralizedRelationalConv(
|
115 |
+
self.dims[i], self.dims[i + 1], num_relation,
|
116 |
+
self.dims[0], self.message_func, self.aggregate_func, self.layer_norm,
|
117 |
+
self.activation, dependent=False, project_relations=True)
|
118 |
+
)
|
119 |
+
|
120 |
+
feature_dim = (sum(hidden_dims) if self.concat_hidden else hidden_dims[-1]) + input_dim
|
121 |
+
self.mlp = nn.Sequential()
|
122 |
+
mlp = []
|
123 |
+
for i in range(self.num_mlp_layers - 1):
|
124 |
+
mlp.append(nn.Linear(feature_dim, feature_dim))
|
125 |
+
mlp.append(nn.ReLU())
|
126 |
+
mlp.append(nn.Linear(feature_dim, 1))
|
127 |
+
self.mlp = nn.Sequential(*mlp)
|
128 |
+
|
129 |
+
|
130 |
+
def bellmanford(self, data, h_index, r_index, separate_grad=False):
|
131 |
+
batch_size = len(r_index)
|
132 |
+
|
133 |
+
# initialize queries (relation types of the given triples)
|
134 |
+
query = self.query[torch.arange(batch_size, device=r_index.device), r_index]
|
135 |
+
index = h_index.unsqueeze(-1).expand_as(query)
|
136 |
+
|
137 |
+
# initial (boundary) condition - initialize all node states as zeros
|
138 |
+
boundary = torch.zeros(batch_size, data.num_nodes, self.dims[0], device=h_index.device)
|
139 |
+
# by the scatter operation we put query (relation) embeddings as init features of source (index) nodes
|
140 |
+
boundary.scatter_add_(1, index.unsqueeze(1), query.unsqueeze(1))
|
141 |
+
|
142 |
+
size = (data.num_nodes, data.num_nodes)
|
143 |
+
edge_weight = torch.ones(data.num_edges, device=h_index.device)
|
144 |
+
|
145 |
+
hiddens = []
|
146 |
+
edge_weights = []
|
147 |
+
layer_input = boundary
|
148 |
+
|
149 |
+
for layer in self.layers:
|
150 |
+
|
151 |
+
# for visualization
|
152 |
+
if separate_grad:
|
153 |
+
edge_weight = edge_weight.clone().requires_grad_()
|
154 |
+
|
155 |
+
# Bellman-Ford iteration, we send the original boundary condition in addition to the updated node states
|
156 |
+
hidden = layer(layer_input, query, boundary, data.edge_index, data.edge_type, size, edge_weight)
|
157 |
+
if self.short_cut and hidden.shape == layer_input.shape:
|
158 |
+
# residual connection here
|
159 |
+
hidden = hidden + layer_input
|
160 |
+
hiddens.append(hidden)
|
161 |
+
edge_weights.append(edge_weight)
|
162 |
+
layer_input = hidden
|
163 |
+
|
164 |
+
# original query (relation type) embeddings
|
165 |
+
node_query = query.unsqueeze(1).expand(-1, data.num_nodes, -1) # (batch_size, num_nodes, input_dim)
|
166 |
+
if self.concat_hidden:
|
167 |
+
output = torch.cat(hiddens + [node_query], dim=-1)
|
168 |
+
else:
|
169 |
+
output = torch.cat([hiddens[-1], node_query], dim=-1)
|
170 |
+
|
171 |
+
return {
|
172 |
+
"node_feature": output,
|
173 |
+
"edge_weights": edge_weights,
|
174 |
+
}
|
175 |
+
|
176 |
+
def forward(self, data, relation_representations, batch):
|
177 |
+
h_index, t_index, r_index = batch.unbind(-1)
|
178 |
+
|
179 |
+
# initial query representations are those from the relation graph
|
180 |
+
self.query = relation_representations
|
181 |
+
|
182 |
+
# initialize relations in each NBFNet layer (with uinque projection internally)
|
183 |
+
for layer in self.layers:
|
184 |
+
layer.relation = relation_representations
|
185 |
+
|
186 |
+
if self.training:
|
187 |
+
# Edge dropout in the training mode
|
188 |
+
# here we want to remove immediate edges (head, relation, tail) from the edge_index and edge_types
|
189 |
+
# to make NBFNet iteration learn non-trivial paths
|
190 |
+
data = self.remove_easy_edges(data, h_index, t_index, r_index)
|
191 |
+
|
192 |
+
shape = h_index.shape
|
193 |
+
# turn all triples in a batch into a tail prediction mode
|
194 |
+
h_index, t_index, r_index = self.negative_sample_to_tail(h_index, t_index, r_index, num_direct_rel=data.num_relations // 2)
|
195 |
+
assert (h_index[:, [0]] == h_index).all()
|
196 |
+
assert (r_index[:, [0]] == r_index).all()
|
197 |
+
|
198 |
+
# message passing and updated node representations
|
199 |
+
output = self.bellmanford(data, h_index[:, 0], r_index[:, 0]) # (num_nodes, batch_size, feature_dim)
|
200 |
+
feature = output["node_feature"]
|
201 |
+
index = t_index.unsqueeze(-1).expand(-1, -1, feature.shape[-1])
|
202 |
+
# extract representations of tail entities from the updated node states
|
203 |
+
feature = feature.gather(1, index) # (batch_size, num_negative + 1, feature_dim)
|
204 |
+
|
205 |
+
# probability logit for each tail node in the batch
|
206 |
+
# (batch_size, num_negative + 1, dim) -> (batch_size, num_negative + 1)
|
207 |
+
score = self.mlp(feature).squeeze(-1)
|
208 |
+
return score.view(shape)
|
209 |
+
|
210 |
+
|
211 |
+
|
212 |
+
|
213 |
+
|
214 |
+
|
ultra/rspmm/rspmm.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
|
4 |
+
import torch.backends.openmp
|
5 |
+
from torch import autograd
|
6 |
+
from torch.utils import cpp_extension
|
7 |
+
|
8 |
+
module = sys.modules[__name__]
|
9 |
+
|
10 |
+
|
11 |
+
class RSPMMAddMulFunction(autograd.Function):
|
12 |
+
|
13 |
+
@staticmethod
|
14 |
+
def forward(ctx, edge_index, edge_type, edge_weight, relation, input):
|
15 |
+
node_in, node_out = edge_index
|
16 |
+
key = node_in * (node_out.max() + 1) + node_out
|
17 |
+
assert (key.diff() >= 0).all(), "Expect sorted `edge_index`"
|
18 |
+
|
19 |
+
if input.device.type == "cuda":
|
20 |
+
forward = rspmm.rspmm_add_mul_forward_cuda
|
21 |
+
else:
|
22 |
+
forward = rspmm.rspmm_add_mul_forward_cpu
|
23 |
+
output = forward(edge_index, edge_type, edge_weight, relation, input)
|
24 |
+
ctx.save_for_backward(edge_index, edge_type, edge_weight, relation, input, output)
|
25 |
+
return output
|
26 |
+
|
27 |
+
@staticmethod
|
28 |
+
def backward(ctx, output_grad):
|
29 |
+
if output_grad.device.type == "cuda":
|
30 |
+
backward = rspmm.rspmm_add_mul_backward_cuda
|
31 |
+
else:
|
32 |
+
backward = rspmm.rspmm_add_mul_backward_cpu
|
33 |
+
weight_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
|
34 |
+
return None, None, weight_grad, relation_grad, input_grad
|
35 |
+
|
36 |
+
|
37 |
+
class RSPMMMinMulFunction(autograd.Function):
|
38 |
+
|
39 |
+
@staticmethod
|
40 |
+
def forward(ctx, edge_index, edge_type, edge_weight, relation, input):
|
41 |
+
node_in, node_out = edge_index
|
42 |
+
key = node_in * (node_out.max() + 1) + node_out
|
43 |
+
assert (key.diff() >= 0).all(), "Expect sorted `edge_index`"
|
44 |
+
|
45 |
+
if input.device.type == "cuda":
|
46 |
+
forward = rspmm.rspmm_min_mul_forward_cuda
|
47 |
+
else:
|
48 |
+
forward = rspmm.rspmm_min_mul_forward_cpu
|
49 |
+
output = forward(edge_index, edge_type, edge_weight, relation, input)
|
50 |
+
ctx.save_for_backward(edge_index, edge_type, edge_weight, relation, input, output)
|
51 |
+
return output
|
52 |
+
|
53 |
+
@staticmethod
|
54 |
+
def backward(ctx, output_grad):
|
55 |
+
if output_grad.device.type == "cuda":
|
56 |
+
backward = rspmm.rspmm_min_mul_backward_cuda
|
57 |
+
else:
|
58 |
+
backward = rspmm.rspmm_min_mul_backward_cpu
|
59 |
+
weight_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
|
60 |
+
return None, None, weight_grad, relation_grad, input_grad
|
61 |
+
|
62 |
+
|
63 |
+
class RSPMMMaxMulFunction(autograd.Function):
|
64 |
+
|
65 |
+
@staticmethod
|
66 |
+
def forward(ctx, edge_index, edge_type, edge_weight, relation, input):
|
67 |
+
node_in, node_out = edge_index
|
68 |
+
key = node_in * (node_out.max() + 1) + node_out
|
69 |
+
assert (key.diff() >= 0).all(), "Expect sorted `edge_index`"
|
70 |
+
|
71 |
+
if input.device.type == "cuda":
|
72 |
+
forward = rspmm.rspmm_max_mul_forward_cuda
|
73 |
+
else:
|
74 |
+
forward = rspmm.rspmm_max_mul_forward_cpu
|
75 |
+
output = forward(edge_index, edge_type, edge_weight, relation, input)
|
76 |
+
ctx.save_for_backward(edge_index, edge_type, edge_weight, relation, input, output)
|
77 |
+
return output
|
78 |
+
|
79 |
+
@staticmethod
|
80 |
+
def backward(ctx, output_grad):
|
81 |
+
if output_grad.device.type == "cuda":
|
82 |
+
backward = rspmm.rspmm_max_mul_backward_cuda
|
83 |
+
else:
|
84 |
+
backward = rspmm.rspmm_max_mul_backward_cpu
|
85 |
+
weight_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
|
86 |
+
return None, None, weight_grad, relation_grad, input_grad
|
87 |
+
|
88 |
+
|
89 |
+
class RSPMMAddAddFunction(autograd.Function):
|
90 |
+
|
91 |
+
@staticmethod
|
92 |
+
def forward(ctx, edge_index, edge_type, edge_weight, relation, input):
|
93 |
+
node_in, node_out = edge_index
|
94 |
+
key = node_in * (node_out.max() + 1) + node_out
|
95 |
+
assert (key.diff() >= 0).all(), "Expect sorted `edge_index`"
|
96 |
+
|
97 |
+
if input.device.type == "cuda":
|
98 |
+
forward = rspmm.rspmm_add_add_forward_cuda
|
99 |
+
else:
|
100 |
+
forward = rspmm.rspmm_add_add_forward_cpu
|
101 |
+
output = forward(edge_index, edge_type, edge_weight, relation, input)
|
102 |
+
ctx.save_for_backward(edge_index, edge_type, edge_weight, relation, input, output)
|
103 |
+
return output
|
104 |
+
|
105 |
+
@staticmethod
|
106 |
+
def backward(ctx, output_grad):
|
107 |
+
if output_grad.device.type == "cuda":
|
108 |
+
backward = rspmm.rspmm_add_add_backward_cuda
|
109 |
+
else:
|
110 |
+
backward = rspmm.rspmm_add_add_backward_cpu
|
111 |
+
weight_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
|
112 |
+
return None, None, weight_grad, relation_grad, input_grad
|
113 |
+
|
114 |
+
|
115 |
+
class RSPMMMinAddFunction(autograd.Function):
|
116 |
+
|
117 |
+
@staticmethod
|
118 |
+
def forward(ctx, edge_index, edge_type, edge_weight, relation, input):
|
119 |
+
node_in, node_out = edge_index
|
120 |
+
key = node_in * (node_out.max() + 1) + node_out
|
121 |
+
assert (key.diff() >= 0).all(), "Expect sorted `edge_index`"
|
122 |
+
|
123 |
+
if input.device.type == "cuda":
|
124 |
+
forward = rspmm.rspmm_min_add_forward_cuda
|
125 |
+
else:
|
126 |
+
forward = rspmm.rspmm_min_add_forward_cpu
|
127 |
+
output = forward(edge_index, edge_type, edge_weight, relation, input)
|
128 |
+
ctx.save_for_backward(edge_index, edge_type, edge_weight, relation, input, output)
|
129 |
+
return output
|
130 |
+
|
131 |
+
@staticmethod
|
132 |
+
def backward(ctx, output_grad):
|
133 |
+
if output_grad.device.type == "cuda":
|
134 |
+
backward = rspmm.rspmm_min_add_backward_cuda
|
135 |
+
else:
|
136 |
+
backward = rspmm.rspmm_min_add_backward_cpu
|
137 |
+
weight_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
|
138 |
+
return None, None, weight_grad, relation_grad, input_grad
|
139 |
+
|
140 |
+
|
141 |
+
class RSPMMMaxAddFunction(autograd.Function):
|
142 |
+
|
143 |
+
@staticmethod
|
144 |
+
def forward(ctx, edge_index, edge_type, edge_weight, relation, input):
|
145 |
+
node_in, node_out = edge_index
|
146 |
+
key = node_in * (node_out.max() + 1) + node_out
|
147 |
+
assert (key.diff() >= 0).all(), "Expect sorted `edge_index`"
|
148 |
+
|
149 |
+
if input.device.type == "cuda":
|
150 |
+
forward = rspmm.rspmm_max_add_forward_cuda
|
151 |
+
else:
|
152 |
+
forward = rspmm.rspmm_max_add_forward_cpu
|
153 |
+
output = forward(edge_index, edge_type, edge_weight, relation, input)
|
154 |
+
ctx.save_for_backward(edge_index, edge_type, edge_weight, relation, input, output)
|
155 |
+
return output
|
156 |
+
|
157 |
+
@staticmethod
|
158 |
+
def backward(ctx, output_grad):
|
159 |
+
if output_grad.device.type == "cuda":
|
160 |
+
backward = rspmm.rspmm_max_add_backward_cuda
|
161 |
+
else:
|
162 |
+
backward = rspmm.rspmm_max_add_backward_cpu
|
163 |
+
weight_grad, relation_grad, input_grad = backward(*ctx.saved_tensors, output_grad)
|
164 |
+
return None, None, weight_grad, relation_grad, input_grad
|
165 |
+
|
166 |
+
|
167 |
+
def generalized_rspmm(edge_index, edge_type, edge_weight, relation, input, sum="add", mul="mul"):
|
168 |
+
name = "RSPMM%s%sFunction" % (sum.capitalize(), mul.capitalize())
|
169 |
+
if not hasattr(module, name):
|
170 |
+
raise ValueError("No generalized rspmm implementation found for summation `%s` and multiplication `%s`"
|
171 |
+
% (sum, mul))
|
172 |
+
Function = getattr(module, name)
|
173 |
+
|
174 |
+
node_in, node_out = edge_index
|
175 |
+
key = node_in * (node_out.max() + 1) + node_out
|
176 |
+
order = key.argsort()
|
177 |
+
|
178 |
+
return Function.apply(edge_index[:, order], edge_type[order], edge_weight[order], relation, input)
|
179 |
+
|
180 |
+
|
181 |
+
def load_extension(name, sources, extra_cflags=None, extra_cuda_cflags=None, **kwargs):
|
182 |
+
if extra_cflags is None:
|
183 |
+
extra_cflags = ["-Ofast"]
|
184 |
+
if torch.backends.openmp.is_available():
|
185 |
+
extra_cflags += ["-fopenmp", "-DAT_PARALLEL_OPENMP"]
|
186 |
+
else:
|
187 |
+
extra_cflags.append("-DAT_PARALLEL_NATIVE")
|
188 |
+
if extra_cuda_cflags is None:
|
189 |
+
if torch.cuda.is_available():
|
190 |
+
extra_cuda_cflags = ["-O3"]
|
191 |
+
extra_cflags.append("-DCUDA_OP")
|
192 |
+
else:
|
193 |
+
new_sources = []
|
194 |
+
for source in sources:
|
195 |
+
if not cpp_extension._is_cuda_file(source):
|
196 |
+
new_sources.append(source)
|
197 |
+
sources = new_sources
|
198 |
+
|
199 |
+
return cpp_extension.load(name, sources, extra_cflags, extra_cuda_cflags, **kwargs)
|
200 |
+
|
201 |
+
|
202 |
+
print("Load rspmm extension. This may take a while...")
|
203 |
+
path = os.path.join(os.path.dirname(__file__), "source")
|
204 |
+
rspmm = load_extension("rspmm", [os.path.join(path, "rspmm.cpp"), os.path.join(path, "rspmm.cu")])
|
ultra/rspmm/source/operator.cuh
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include <limits>
|
4 |
+
|
5 |
+
#ifdef __CUDA_ARCH__
|
6 |
+
#define HOST_DEVICE __host__ __device__
|
7 |
+
#else
|
8 |
+
#define HOST_DEVICE
|
9 |
+
#endif
|
10 |
+
|
11 |
+
namespace at {
|
12 |
+
|
13 |
+
template <class scalar_t>
|
14 |
+
struct BinaryAdd {
|
15 |
+
HOST_DEVICE static scalar_t forward(scalar_t x, scalar_t y) {
|
16 |
+
return x + y;
|
17 |
+
}
|
18 |
+
|
19 |
+
HOST_DEVICE static scalar_t backward_lhs(scalar_t x, scalar_t y) {
|
20 |
+
return 1;
|
21 |
+
}
|
22 |
+
|
23 |
+
HOST_DEVICE static scalar_t backward_rhs(scalar_t x, scalar_t y) {
|
24 |
+
return 1;
|
25 |
+
}
|
26 |
+
};
|
27 |
+
|
28 |
+
template <class scalar_t>
|
29 |
+
struct BinaryMul {
|
30 |
+
HOST_DEVICE static scalar_t forward(scalar_t x, scalar_t y) {
|
31 |
+
return x * y;
|
32 |
+
}
|
33 |
+
|
34 |
+
HOST_DEVICE static scalar_t backward_lhs(scalar_t x, scalar_t y) {
|
35 |
+
return y;
|
36 |
+
}
|
37 |
+
|
38 |
+
HOST_DEVICE static scalar_t backward_rhs(scalar_t x, scalar_t y) {
|
39 |
+
return x;
|
40 |
+
}
|
41 |
+
};
|
42 |
+
|
43 |
+
template <class scalar_t>
|
44 |
+
struct NaryAdd {
|
45 |
+
HOST_DEVICE static scalar_t forward(scalar_t result, scalar_t x) {
|
46 |
+
return result + x;
|
47 |
+
}
|
48 |
+
|
49 |
+
HOST_DEVICE static scalar_t backward(scalar_t result, scalar_t x) {
|
50 |
+
return 1;
|
51 |
+
}
|
52 |
+
|
53 |
+
static constexpr scalar_t zero = 0;
|
54 |
+
};
|
55 |
+
|
56 |
+
template <class scalar_t>
|
57 |
+
struct NaryMin {
|
58 |
+
HOST_DEVICE static scalar_t forward(scalar_t result, scalar_t x) {
|
59 |
+
return result < x ? result : x;
|
60 |
+
}
|
61 |
+
|
62 |
+
HOST_DEVICE static scalar_t backward(scalar_t result, scalar_t x) {
|
63 |
+
return result == x ? 1 : 0;
|
64 |
+
}
|
65 |
+
|
66 |
+
static constexpr scalar_t zero = std::numeric_limits<scalar_t>::max();
|
67 |
+
};
|
68 |
+
|
69 |
+
template <class scalar_t>
|
70 |
+
struct NaryMax {
|
71 |
+
HOST_DEVICE static scalar_t forward(scalar_t result, scalar_t x) {
|
72 |
+
return result > x ? result : x;
|
73 |
+
}
|
74 |
+
|
75 |
+
HOST_DEVICE static scalar_t backward(scalar_t result, scalar_t x) {
|
76 |
+
return result == x ? 1 : 0;
|
77 |
+
}
|
78 |
+
|
79 |
+
static constexpr scalar_t zero = std::numeric_limits<scalar_t>::lowest();
|
80 |
+
};
|
81 |
+
|
82 |
+
} // namespace at
|
ultra/rspmm/source/rspmm.cpp
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <mutex>
|
2 |
+
|
3 |
+
#include <ATen/Parallel.h>
|
4 |
+
|
5 |
+
#include "operator.cuh"
|
6 |
+
#include "rspmm.h"
|
7 |
+
|
8 |
+
namespace at {
|
9 |
+
|
10 |
+
// In PyTorch 1.4.0, parallel_for depends on some functions from at::internal in ATen/Parallel.h
|
11 |
+
// which are not explicitly included
|
12 |
+
// This is fixed in some new PyTorch release
|
13 |
+
using namespace at::internal;
|
14 |
+
|
15 |
+
void rspmm_forward_check(CheckedFrom c, const TensorArg &edge_index_arg, const TensorArg &edge_type_arg,
|
16 |
+
const TensorArg &edge_weight_arg, const TensorArg &relation_arg, const TensorArg &input_arg) {
|
17 |
+
checkDim(c, edge_index_arg, 2);
|
18 |
+
checkDim(c, edge_type_arg, 1);
|
19 |
+
checkDim(c, edge_weight_arg, 1);
|
20 |
+
checkDim(c, relation_arg, 2);
|
21 |
+
checkDim(c, input_arg, 2);
|
22 |
+
checkSameType(c, edge_index_arg, edge_type_arg);
|
23 |
+
checkAllSameType(c, {edge_weight_arg, relation_arg, input_arg});
|
24 |
+
checkSize(c, edge_index_arg, 0, 2);
|
25 |
+
checkSize(c, edge_type_arg, {edge_index_arg->size(1)});
|
26 |
+
checkSize(c, edge_weight_arg, {edge_index_arg->size(1)});
|
27 |
+
checkSize(c, relation_arg, 1, input_arg->size(1));
|
28 |
+
}
|
29 |
+
|
30 |
+
void rspmm_backward_check(CheckedFrom c, const TensorArg &edge_index_arg, const TensorArg &edge_type_arg,
|
31 |
+
const TensorArg &edge_weight_arg, const TensorArg &relation_arg, const TensorArg &input_arg,
|
32 |
+
const TensorArg &output_arg, const TensorArg &output_grad_arg) {
|
33 |
+
rspmm_forward_check(c, edge_index_arg, edge_type_arg, edge_weight_arg, relation_arg, input_arg);
|
34 |
+
checkDim(c, output_arg, 2);
|
35 |
+
checkSameSize(c, output_arg, output_grad_arg);
|
36 |
+
checkAllSameType(c, {input_arg, output_arg, output_grad_arg});
|
37 |
+
checkSize(c, output_arg, 1, input_arg->size(1));
|
38 |
+
}
|
39 |
+
|
40 |
+
Tensor ind2ptr(const Tensor &index, int size) {
|
41 |
+
// scatter_add is super slow for int64, due to non-hardware atomic operations
|
42 |
+
// use int32 instead
|
43 |
+
Tensor num_per_index = at::zeros({size}, index.options().dtype(at::ScalarType::Int));
|
44 |
+
num_per_index.scatter_add_(0, index, at::ones(index.sizes(), num_per_index.options()));
|
45 |
+
num_per_index = num_per_index.toType(at::ScalarType::Long);
|
46 |
+
Tensor pointer = num_per_index.cumsum(0) - num_per_index;
|
47 |
+
return pointer;
|
48 |
+
}
|
49 |
+
|
50 |
+
template <class scalar_t, class NaryOp, class BinaryOp>
|
51 |
+
void rspmm_forward_out_cpu(const int64_t *row_ptr, const int64_t *col_ind, const int64_t *layer_ind,
|
52 |
+
const scalar_t *weight, const scalar_t *relation, const scalar_t *input,
|
53 |
+
scalar_t *output,
|
54 |
+
int64_t num_row, int64_t nnz, int64_t dim) {
|
55 |
+
parallel_for(0, num_row, 0, [&](int64_t row_start, int64_t row_end) {
|
56 |
+
for (int64_t row = row_start; row < row_end; row++) {
|
57 |
+
for (int64_t d = 0; d < dim; d++)
|
58 |
+
output[row * dim + d] = NaryOp::zero;
|
59 |
+
|
60 |
+
int64_t ptr_start = row_ptr[row];
|
61 |
+
int64_t ptr_end = row + 1 < num_row ? row_ptr[row + 1] : nnz;
|
62 |
+
for (int64_t ptr = ptr_start; ptr < ptr_end; ptr++) {
|
63 |
+
int64_t col = col_ind[ptr];
|
64 |
+
int64_t layer = layer_ind[ptr];
|
65 |
+
scalar_t w = weight[ptr];
|
66 |
+
for (int64_t d = 0; d < dim; d++) {
|
67 |
+
scalar_t x = BinaryOp::forward(relation[layer * dim + d], input[col * dim + d]);
|
68 |
+
scalar_t y = w * x;
|
69 |
+
scalar_t &out = output[row * dim + d];
|
70 |
+
out = NaryOp::forward(out, y);
|
71 |
+
}
|
72 |
+
}
|
73 |
+
}
|
74 |
+
});
|
75 |
+
}
|
76 |
+
|
77 |
+
template <class scalar_t, class NaryOp, class BinaryOp>
|
78 |
+
void rspmm_backward_out_cpu(const int64_t *row_ptr, const int64_t *col_ind, const int64_t *layer_ind,
|
79 |
+
const scalar_t *weight, const scalar_t *relation, const scalar_t *input,
|
80 |
+
const scalar_t *output, const scalar_t *output_grad,
|
81 |
+
scalar_t *weight_grad, scalar_t *relation_grad, scalar_t *input_grad,
|
82 |
+
int64_t num_row, int64_t nnz, int64_t dim,
|
83 |
+
std::vector<std::mutex> &relation_mutex, std::vector<std::mutex> &input_mutex) {
|
84 |
+
parallel_for(0, num_row, 0, [&](int64_t row_start, int64_t row_end) {
|
85 |
+
for (int64_t row = row_start; row < row_end; row++) {
|
86 |
+
int64_t ptr_start = row_ptr[row];
|
87 |
+
int64_t ptr_end = row + 1 < num_row ? row_ptr[row + 1] : nnz;
|
88 |
+
for (int64_t ptr = ptr_start; ptr < ptr_end; ptr++) {
|
89 |
+
int64_t col = col_ind[ptr];
|
90 |
+
int64_t layer = layer_ind[ptr];
|
91 |
+
scalar_t w = weight[ptr];
|
92 |
+
scalar_t w_grad = 0;
|
93 |
+
for (int64_t d = 0; d < dim; d++) {
|
94 |
+
scalar_t rel = relation[layer * dim + d];
|
95 |
+
scalar_t in = input[col * dim + d];
|
96 |
+
scalar_t out = output[row * dim + d];
|
97 |
+
scalar_t out_grad = output_grad[row * dim + d];
|
98 |
+
scalar_t x = BinaryOp::forward(rel, in);
|
99 |
+
scalar_t y = w * x;
|
100 |
+
scalar_t dx_drel = BinaryOp::backward_lhs(rel, in);
|
101 |
+
scalar_t dx_din = BinaryOp::backward_rhs(rel, in);
|
102 |
+
scalar_t dout_dy = NaryOp::backward(out, y);
|
103 |
+
scalar_t dy_dw = x;
|
104 |
+
scalar_t dy_dx = w;
|
105 |
+
w_grad += out_grad * dout_dy * dy_dw;
|
106 |
+
{
|
107 |
+
std::lock_guard<std::mutex> lock(relation_mutex[layer * dim + d]);
|
108 |
+
relation_grad[layer * dim + d] += out_grad * dout_dy * dy_dx * dx_drel;
|
109 |
+
}
|
110 |
+
{
|
111 |
+
std::lock_guard<std::mutex> lock(input_mutex[col * dim + d]);
|
112 |
+
input_grad[col * dim + d] += out_grad * dout_dy * dy_dx * dx_din;
|
113 |
+
}
|
114 |
+
}
|
115 |
+
weight_grad[ptr] = w_grad;
|
116 |
+
}
|
117 |
+
}
|
118 |
+
});
|
119 |
+
}
|
120 |
+
|
121 |
+
template <template<class> class NaryOp, template<class> class BinaryOp>
|
122 |
+
Tensor rspmm_forward_cpu(const Tensor &edge_index_, const Tensor &edge_type_, const Tensor &edge_weight_,
|
123 |
+
const Tensor &relation_, const Tensor &input_) {
|
124 |
+
constexpr const char *fn_name = "rspmm_forward_cpu";
|
125 |
+
TensorArg edge_index_arg(edge_index_, "edge_index", 1), edge_type_arg(edge_type_, "edge_type", 2),
|
126 |
+
edge_weight_arg(edge_weight_, "edge_weight", 3), relation_arg(relation_, "relation", 4),
|
127 |
+
input_arg(input_, "input", 5);
|
128 |
+
|
129 |
+
rspmm_forward_check(fn_name, edge_index_arg, edge_type_arg, edge_weight_arg, relation_arg, input_arg);
|
130 |
+
checkDeviceType(fn_name, {edge_index_, edge_type_, edge_weight_, relation_, input_}, kCPU);
|
131 |
+
|
132 |
+
const Tensor edge_index = edge_index_.contiguous();
|
133 |
+
const Tensor edge_type = edge_type_.contiguous();
|
134 |
+
const Tensor edge_weight = edge_weight_.contiguous();
|
135 |
+
const Tensor relation = relation_.contiguous();
|
136 |
+
const Tensor input = input_.contiguous();
|
137 |
+
|
138 |
+
int64_t nnz = edge_index.size(0);
|
139 |
+
int64_t num_row = input.size(0);
|
140 |
+
int64_t dim = input.size(1);
|
141 |
+
Tensor output = at::empty({num_row, dim}, input.options());
|
142 |
+
|
143 |
+
Tensor row_ind = edge_index.select(0, 0);
|
144 |
+
Tensor row_ptr = ind2ptr(row_ind, num_row);
|
145 |
+
Tensor col_ind = edge_index.select(0, 1);
|
146 |
+
Tensor layer_ind = edge_type;
|
147 |
+
|
148 |
+
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rspmm_forward_cpu", [&] {
|
149 |
+
rspmm_forward_out_cpu<scalar_t, NaryOp<scalar_t>, BinaryOp<scalar_t>>(
|
150 |
+
row_ptr.data_ptr<int64_t>(),
|
151 |
+
col_ind.data_ptr<int64_t>(),
|
152 |
+
layer_ind.data_ptr<int64_t>(),
|
153 |
+
edge_weight.data_ptr<scalar_t>(),
|
154 |
+
relation.data_ptr<scalar_t>(),
|
155 |
+
input.data_ptr<scalar_t>(),
|
156 |
+
output.data_ptr<scalar_t>(),
|
157 |
+
num_row, nnz, dim
|
158 |
+
);
|
159 |
+
});
|
160 |
+
|
161 |
+
return output;
|
162 |
+
}
|
163 |
+
|
164 |
+
template <template<class> class NaryOp, template<class> class BinaryOp>
|
165 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_backward_cpu(
|
166 |
+
const Tensor &edge_index_, const Tensor &edge_type_, const Tensor &edge_weight_,
|
167 |
+
const Tensor &relation_, const Tensor &input_, const Tensor &output_, const Tensor &output_grad_) {
|
168 |
+
constexpr const char *fn_name = "rspmm_backward_cpu";
|
169 |
+
TensorArg edge_index_arg(edge_index_, "edge_index", 1), edge_type_arg(edge_type_, "edge_type", 2),
|
170 |
+
edge_weight_arg(edge_weight_, "edge_weight", 3), relation_arg(relation_, "relation", 4),
|
171 |
+
input_arg(input_, "input", 5), output_arg(output_, "output", 6),
|
172 |
+
output_grad_arg(output_grad_, "output_grad", 7);
|
173 |
+
|
174 |
+
rspmm_backward_check(fn_name, edge_index_arg, edge_type_arg, edge_weight_arg, relation_arg, input_arg,
|
175 |
+
output_arg, output_grad_arg);
|
176 |
+
checkDeviceType(fn_name, {edge_index_, edge_type_, edge_weight_, relation_, input_, output_, output_grad_}, kCPU);
|
177 |
+
|
178 |
+
const Tensor edge_index = edge_index_.contiguous();
|
179 |
+
const Tensor edge_type = edge_type_.contiguous();
|
180 |
+
const Tensor edge_weight = edge_weight_.contiguous();
|
181 |
+
const Tensor relation = relation_.contiguous();
|
182 |
+
const Tensor input = input_.contiguous();
|
183 |
+
const Tensor output = output_.contiguous();
|
184 |
+
const Tensor output_grad = output_grad_.contiguous();
|
185 |
+
|
186 |
+
int64_t nnz = edge_index.size(0);
|
187 |
+
int64_t num_row = input.size(0);
|
188 |
+
int64_t dim = input.size(1);
|
189 |
+
Tensor weight_grad = at::zeros_like(edge_weight);
|
190 |
+
Tensor relation_grad = at::zeros_like(relation);
|
191 |
+
Tensor input_grad = at::zeros_like(input);
|
192 |
+
|
193 |
+
Tensor row_ind = edge_index.select(0, 0);
|
194 |
+
Tensor row_ptr = ind2ptr(row_ind, num_row);
|
195 |
+
Tensor col_ind = edge_index.select(0, 1);
|
196 |
+
Tensor layer_ind = edge_type;
|
197 |
+
std::vector<std::mutex> relation_mutex(relation.numel());
|
198 |
+
std::vector<std::mutex> input_mutex(input.numel());
|
199 |
+
|
200 |
+
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rspmm_backward_cpu", [&] {
|
201 |
+
rspmm_backward_out_cpu<scalar_t, NaryOp<scalar_t>, BinaryOp<scalar_t>>(
|
202 |
+
row_ptr.data_ptr<int64_t>(),
|
203 |
+
col_ind.data_ptr<int64_t>(),
|
204 |
+
layer_ind.data_ptr<int64_t>(),
|
205 |
+
edge_weight.data_ptr<scalar_t>(),
|
206 |
+
relation.data_ptr<scalar_t>(),
|
207 |
+
input.data_ptr<scalar_t>(),
|
208 |
+
output.data_ptr<scalar_t>(),
|
209 |
+
output_grad.data_ptr<scalar_t>(),
|
210 |
+
weight_grad.data_ptr<scalar_t>(),
|
211 |
+
relation_grad.data_ptr<scalar_t>(),
|
212 |
+
input_grad.data_ptr<scalar_t>(),
|
213 |
+
num_row, nnz, dim,
|
214 |
+
relation_mutex, input_mutex
|
215 |
+
);
|
216 |
+
});
|
217 |
+
|
218 |
+
return std::make_tuple(weight_grad, relation_grad, input_grad);
|
219 |
+
}
|
220 |
+
|
221 |
+
#define DECLARE_FORWARD_IMPL(ADD, MUL, NARYOP, BINARYOP) \
|
222 |
+
Tensor rspmm_##ADD##_##MUL##_forward_cpu( \
|
223 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, \
|
224 |
+
const Tensor &relation, const Tensor &input) { \
|
225 |
+
return rspmm_forward_cpu<NARYOP, BINARYOP>(edge_index, edge_type, edge_weight, relation, input); \
|
226 |
+
}
|
227 |
+
|
228 |
+
#define DECLARE_BACKWARD_IMPL(ADD, MUL, NARYOP, BINARYOP) \
|
229 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_##ADD##_##MUL##_backward_cpu( \
|
230 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, \
|
231 |
+
const Tensor &relation, const Tensor &input, const Tensor &output, const Tensor &output_grad) { \
|
232 |
+
return rspmm_backward_cpu<NARYOP, BINARYOP>(edge_index, edge_type, edge_weight, relation, input, \
|
233 |
+
output, output_grad); \
|
234 |
+
}
|
235 |
+
|
236 |
+
DECLARE_FORWARD_IMPL(add, mul, NaryAdd, BinaryMul)
|
237 |
+
DECLARE_BACKWARD_IMPL(add, mul, NaryAdd, BinaryMul)
|
238 |
+
|
239 |
+
DECLARE_FORWARD_IMPL(min, mul, NaryMin, BinaryMul)
|
240 |
+
DECLARE_BACKWARD_IMPL(min, mul, NaryMin, BinaryMul)
|
241 |
+
|
242 |
+
DECLARE_FORWARD_IMPL(max, mul, NaryMax, BinaryMul)
|
243 |
+
DECLARE_BACKWARD_IMPL(max, mul, NaryMax, BinaryMul)
|
244 |
+
|
245 |
+
DECLARE_FORWARD_IMPL(add, add, NaryAdd, BinaryAdd)
|
246 |
+
DECLARE_BACKWARD_IMPL(add, add, NaryAdd, BinaryAdd)
|
247 |
+
|
248 |
+
DECLARE_FORWARD_IMPL(min, add, NaryMin, BinaryAdd)
|
249 |
+
DECLARE_BACKWARD_IMPL(min, add, NaryMin, BinaryAdd)
|
250 |
+
|
251 |
+
DECLARE_FORWARD_IMPL(max, add, NaryMax, BinaryAdd)
|
252 |
+
DECLARE_BACKWARD_IMPL(max, add, NaryMax, BinaryAdd)
|
253 |
+
|
254 |
+
} // namespace at
|
255 |
+
|
256 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
257 |
+
m.def("rspmm_add_mul_forward_cpu", &at::rspmm_add_mul_forward_cpu);
|
258 |
+
m.def("rspmm_add_mul_backward_cpu", &at::rspmm_add_mul_backward_cpu);
|
259 |
+
m.def("rspmm_min_mul_forward_cpu", &at::rspmm_min_mul_forward_cpu);
|
260 |
+
m.def("rspmm_min_mul_backward_cpu", &at::rspmm_min_mul_backward_cpu);
|
261 |
+
m.def("rspmm_max_mul_forward_cpu", &at::rspmm_max_mul_forward_cpu);
|
262 |
+
m.def("rspmm_max_mul_backward_cpu", &at::rspmm_max_mul_backward_cpu);
|
263 |
+
m.def("rspmm_add_add_forward_cpu", &at::rspmm_add_add_forward_cpu);
|
264 |
+
m.def("rspmm_add_add_backward_cpu", &at::rspmm_add_add_backward_cpu);
|
265 |
+
m.def("rspmm_min_add_forward_cpu", &at::rspmm_min_add_forward_cpu);
|
266 |
+
m.def("rspmm_min_add_backward_cpu", &at::rspmm_min_add_backward_cpu);
|
267 |
+
m.def("rspmm_max_add_forward_cpu", &at::rspmm_max_add_forward_cpu);
|
268 |
+
m.def("rspmm_max_add_backward_cpu", &at::rspmm_max_add_backward_cpu);
|
269 |
+
#ifdef CUDA_OP
|
270 |
+
m.def("rspmm_add_mul_forward_cuda", &at::rspmm_add_mul_forward_cuda);
|
271 |
+
m.def("rspmm_add_mul_backward_cuda", &at::rspmm_add_mul_backward_cuda);
|
272 |
+
m.def("rspmm_min_mul_forward_cuda", &at::rspmm_min_mul_forward_cuda);
|
273 |
+
m.def("rspmm_min_mul_backward_cuda", &at::rspmm_min_mul_backward_cuda);
|
274 |
+
m.def("rspmm_max_mul_forward_cuda", &at::rspmm_max_mul_forward_cuda);
|
275 |
+
m.def("rspmm_max_mul_backward_cuda", &at::rspmm_max_mul_backward_cuda);
|
276 |
+
m.def("rspmm_add_add_forward_cuda", &at::rspmm_add_add_forward_cuda);
|
277 |
+
m.def("rspmm_add_add_backward_cuda", &at::rspmm_add_add_backward_cuda);
|
278 |
+
m.def("rspmm_min_add_forward_cuda", &at::rspmm_min_add_forward_cuda);
|
279 |
+
m.def("rspmm_min_add_backward_cuda", &at::rspmm_min_add_backward_cuda);
|
280 |
+
m.def("rspmm_max_add_forward_cuda", &at::rspmm_max_add_forward_cuda);
|
281 |
+
m.def("rspmm_max_add_backward_cuda", &at::rspmm_max_add_backward_cuda);
|
282 |
+
#endif
|
283 |
+
}
|
ultra/rspmm/source/rspmm.cu
ADDED
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <ATen/cuda/CUDAContext.h>
|
2 |
+
#include <THC/THCAtomics.cuh>
|
3 |
+
|
4 |
+
#include "util.cuh"
|
5 |
+
#include "operator.cuh"
|
6 |
+
#include "rspmm.h"
|
7 |
+
|
8 |
+
namespace at {
|
9 |
+
|
10 |
+
// Memory & time efficient implementation of generalized spmm
|
11 |
+
// Much of the code is inspired by GE-SpMM
|
12 |
+
// https://github.com/hgyhungry/ge-spmm
|
13 |
+
|
14 |
+
namespace {
|
15 |
+
|
16 |
+
const int kCoarseningFactor = 2;
|
17 |
+
const int kThreadPerBlock = 256;
|
18 |
+
|
19 |
+
} // namespace anonymous
|
20 |
+
|
21 |
+
template <class scalar_t, class NaryOp, class BinaryOp>
|
22 |
+
__global__
|
23 |
+
void rspmm_forward_out_cuda(const int64_t *row_ptr, const int64_t *col_ind, const int64_t *layer_ind,
|
24 |
+
const scalar_t *weight, const scalar_t *relation, const scalar_t *input,
|
25 |
+
scalar_t *output,
|
26 |
+
int64_t num_row, int64_t nnz, int64_t dim) {
|
27 |
+
// for best optimization, the following code is compiled with constant warpSize
|
28 |
+
assert(blockDim.x == warpSize);
|
29 |
+
|
30 |
+
extern __shared__ int64_t buffer[];
|
31 |
+
int64_t *col_ind_buf = buffer;
|
32 |
+
int64_t *layer_ind_buf = buffer + blockDim.y * warpSize;
|
33 |
+
scalar_t *weight_buf = reinterpret_cast<scalar_t *>(layer_ind_buf + blockDim.y * warpSize);
|
34 |
+
col_ind_buf += threadIdx.y * warpSize;
|
35 |
+
layer_ind_buf += threadIdx.y * warpSize;
|
36 |
+
weight_buf += threadIdx.y * warpSize;
|
37 |
+
|
38 |
+
int64_t row = blockIdx.x * blockDim.y + threadIdx.y;
|
39 |
+
if (row >= num_row)
|
40 |
+
return;
|
41 |
+
int64_t d_start = blockIdx.y * warpSize * kCoarseningFactor + threadIdx.x;
|
42 |
+
int64_t ptr_start = row_ptr[row];
|
43 |
+
int64_t ptr_end = row + 1 < num_row ? row_ptr[row + 1] : nnz;
|
44 |
+
scalar_t out[kCoarseningFactor];
|
45 |
+
#pragma unroll
|
46 |
+
for (int64_t i = 0; i < kCoarseningFactor; i++)
|
47 |
+
out[i] = NaryOp::zero;
|
48 |
+
|
49 |
+
for (int64_t block_ptr = ptr_start; block_ptr < ptr_end; block_ptr += warpSize) {
|
50 |
+
int64_t ptr = block_ptr + threadIdx.x;
|
51 |
+
if (ptr < ptr_end) {
|
52 |
+
col_ind_buf[threadIdx.x] = col_ind[ptr];
|
53 |
+
layer_ind_buf[threadIdx.x] = layer_ind[ptr];
|
54 |
+
weight_buf[threadIdx.x] = weight[ptr];
|
55 |
+
}
|
56 |
+
__syncwarp();
|
57 |
+
|
58 |
+
int64_t max_offset = warpSize < ptr_end - block_ptr ? warpSize : ptr_end - block_ptr;
|
59 |
+
for (int64_t offset_ptr = 0; offset_ptr < max_offset; offset_ptr++) {
|
60 |
+
int64_t col = col_ind_buf[offset_ptr];
|
61 |
+
int64_t layer = layer_ind_buf[offset_ptr];
|
62 |
+
scalar_t w = weight_buf[offset_ptr];
|
63 |
+
#pragma unroll
|
64 |
+
for (int64_t i = 0; i < kCoarseningFactor; i++) {
|
65 |
+
int64_t d = d_start + i * warpSize;
|
66 |
+
if (d >= dim)
|
67 |
+
break;
|
68 |
+
scalar_t x = BinaryOp::forward(relation[layer * dim + d], input[col * dim + d]);
|
69 |
+
scalar_t y = w * x;
|
70 |
+
out[i] = NaryOp::forward(out[i], y);
|
71 |
+
}
|
72 |
+
}
|
73 |
+
__syncwarp();
|
74 |
+
}
|
75 |
+
|
76 |
+
#pragma unroll
|
77 |
+
for (int64_t i = 0; i < kCoarseningFactor; i++) {
|
78 |
+
int64_t d = d_start + i * warpSize;
|
79 |
+
if (d >= dim)
|
80 |
+
break;
|
81 |
+
output[row * dim + d] = out[i];
|
82 |
+
}
|
83 |
+
}
|
84 |
+
|
85 |
+
template <class scalar_t, class NaryOp, class BinaryOp>
|
86 |
+
__global__
|
87 |
+
void rspmm_backward_out_cuda(const int64_t *row_ptr, const int64_t *col_ind, const int64_t *layer_ind,
|
88 |
+
const scalar_t *weight, const scalar_t *relation, const scalar_t *input,
|
89 |
+
const scalar_t *output, const scalar_t *output_grad,
|
90 |
+
scalar_t *weight_grad, scalar_t *relation_grad, scalar_t *input_grad,
|
91 |
+
int64_t num_row, int64_t nnz, int64_t dim) {
|
92 |
+
// for best optimization, the following code is compiled with constant warpSize
|
93 |
+
assert(blockDim.x == warpSize);
|
94 |
+
|
95 |
+
extern __shared__ int64_t buffer[];
|
96 |
+
int64_t *col_ind_buf = buffer;
|
97 |
+
int64_t *layer_ind_buf = col_ind_buf + blockDim.y * warpSize;
|
98 |
+
scalar_t *weight_buf = reinterpret_cast<scalar_t *>(layer_ind_buf + blockDim.y * warpSize);
|
99 |
+
col_ind_buf += threadIdx.y * warpSize;
|
100 |
+
layer_ind_buf += threadIdx.y * warpSize;
|
101 |
+
weight_buf += threadIdx.y * warpSize;
|
102 |
+
|
103 |
+
int64_t row = blockIdx.x * blockDim.y + threadIdx.y;
|
104 |
+
if (row >= num_row)
|
105 |
+
return;
|
106 |
+
int64_t d_start = blockIdx.y * warpSize * kCoarseningFactor + threadIdx.x;
|
107 |
+
int64_t ptr_start = row_ptr[row];
|
108 |
+
int64_t ptr_end = row + 1 < num_row ? row_ptr[row + 1] : nnz;
|
109 |
+
|
110 |
+
for (int64_t block_ptr = ptr_start; block_ptr < ptr_end; block_ptr += warpSize) {
|
111 |
+
int64_t ptr = block_ptr + threadIdx.x;
|
112 |
+
if (ptr < ptr_end) {
|
113 |
+
col_ind_buf[threadIdx.x] = col_ind[ptr];
|
114 |
+
layer_ind_buf[threadIdx.x] = layer_ind[ptr];
|
115 |
+
weight_buf[threadIdx.x] = weight[ptr];
|
116 |
+
}
|
117 |
+
__syncwarp();
|
118 |
+
|
119 |
+
int64_t max_offset = warpSize < ptr_end - block_ptr ? warpSize : ptr_end - block_ptr;
|
120 |
+
for (int64_t offset_ptr = 0; offset_ptr < max_offset; offset_ptr++) {
|
121 |
+
int64_t col = col_ind_buf[offset_ptr];
|
122 |
+
int64_t layer = layer_ind_buf[offset_ptr];
|
123 |
+
scalar_t w = weight_buf[offset_ptr];
|
124 |
+
scalar_t w_grad = 0;
|
125 |
+
#pragma unroll
|
126 |
+
for (int64_t i = 0; i < kCoarseningFactor; i++) {
|
127 |
+
int64_t d = d_start + i * warpSize;
|
128 |
+
if (d >= dim)
|
129 |
+
break;
|
130 |
+
scalar_t rel = relation[layer * dim + d];
|
131 |
+
scalar_t in = input[col * dim + d];
|
132 |
+
scalar_t out = output[row * dim + d];
|
133 |
+
scalar_t out_grad = output_grad[row * dim + d];
|
134 |
+
scalar_t x = BinaryOp::forward(rel, in);
|
135 |
+
scalar_t y = w * x;
|
136 |
+
scalar_t dx_drel = BinaryOp::backward_lhs(rel, in);
|
137 |
+
scalar_t dx_din = BinaryOp::backward_rhs(rel, in);
|
138 |
+
scalar_t dout_dy = NaryOp::backward(out, y);
|
139 |
+
scalar_t dy_dw = x;
|
140 |
+
scalar_t dy_dx = w;
|
141 |
+
w_grad += out_grad * dout_dy * dy_dw;
|
142 |
+
atomicAdd(&relation_grad[layer * dim + d], out_grad * dout_dy * dy_dx * dx_drel);
|
143 |
+
atomicAdd(&input_grad[col * dim + d], out_grad * dout_dy * dy_dx * dx_din);
|
144 |
+
}
|
145 |
+
w_grad = warp_reduce(w_grad);
|
146 |
+
if (threadIdx.x == 0)
|
147 |
+
atomicAdd(&weight_grad[block_ptr + offset_ptr], w_grad);
|
148 |
+
}
|
149 |
+
__syncwarp();
|
150 |
+
}
|
151 |
+
}
|
152 |
+
|
153 |
+
// only relation & input require gradients
|
154 |
+
template <class scalar_t, class NaryOp, class BinaryOp>
|
155 |
+
__global__
|
156 |
+
void rspmm_backward_out_cuda(const int64_t *row_ptr, const int64_t *col_ind, const int64_t *layer_ind,
|
157 |
+
const scalar_t *weight, const scalar_t *relation, const scalar_t *input,
|
158 |
+
const scalar_t *output, const scalar_t *output_grad,
|
159 |
+
scalar_t *relation_grad, scalar_t *input_grad,
|
160 |
+
int64_t num_row, int64_t nnz, int64_t dim) {
|
161 |
+
// for best optimization, the following code is compiled with constant warpSize
|
162 |
+
assert(blockDim.x == warpSize);
|
163 |
+
|
164 |
+
extern __shared__ int64_t buffer[];
|
165 |
+
int64_t *col_ind_buf = buffer;
|
166 |
+
int64_t *layer_ind_buf = col_ind_buf + blockDim.y * warpSize;
|
167 |
+
scalar_t *weight_buf = reinterpret_cast<scalar_t *>(layer_ind_buf + blockDim.y * warpSize);
|
168 |
+
col_ind_buf += threadIdx.y * warpSize;
|
169 |
+
layer_ind_buf += threadIdx.y * warpSize;
|
170 |
+
weight_buf += threadIdx.y * warpSize;
|
171 |
+
|
172 |
+
int64_t row = blockIdx.x * blockDim.y + threadIdx.y;
|
173 |
+
if (row >= num_row)
|
174 |
+
return;
|
175 |
+
int64_t d_start = blockIdx.y * warpSize * kCoarseningFactor + threadIdx.x;
|
176 |
+
int64_t ptr_start = row_ptr[row];
|
177 |
+
int64_t ptr_end = row + 1 < num_row ? row_ptr[row + 1] : nnz;
|
178 |
+
|
179 |
+
for (int64_t block_ptr = ptr_start; block_ptr < ptr_end; block_ptr += warpSize) {
|
180 |
+
int64_t ptr = block_ptr + threadIdx.x;
|
181 |
+
if (ptr < ptr_end) {
|
182 |
+
col_ind_buf[threadIdx.x] = col_ind[ptr];
|
183 |
+
layer_ind_buf[threadIdx.x] = layer_ind[ptr];
|
184 |
+
weight_buf[threadIdx.x] = weight[ptr];
|
185 |
+
}
|
186 |
+
__syncwarp();
|
187 |
+
|
188 |
+
int64_t max_offset = warpSize < ptr_end - block_ptr ? warpSize : ptr_end - block_ptr;
|
189 |
+
for (int64_t offset_ptr = 0; offset_ptr < max_offset; offset_ptr++) {
|
190 |
+
int64_t col = col_ind_buf[offset_ptr];
|
191 |
+
int64_t layer = layer_ind_buf[offset_ptr];
|
192 |
+
scalar_t w = weight_buf[offset_ptr];
|
193 |
+
#pragma unroll
|
194 |
+
for (int64_t i = 0; i < kCoarseningFactor; i++) {
|
195 |
+
int64_t d = d_start + i * warpSize;
|
196 |
+
if (d >= dim)
|
197 |
+
break;
|
198 |
+
scalar_t rel = relation[layer * dim + d];
|
199 |
+
scalar_t in = input[col * dim + d];
|
200 |
+
scalar_t out = output[row * dim + d];
|
201 |
+
scalar_t out_grad = output_grad[row * dim + d];
|
202 |
+
scalar_t x = BinaryOp::forward(rel, in);
|
203 |
+
scalar_t y = w * x;
|
204 |
+
scalar_t dx_drel = BinaryOp::backward_lhs(rel, in);
|
205 |
+
scalar_t dx_din = BinaryOp::backward_rhs(rel, in);
|
206 |
+
scalar_t dout_dy = NaryOp::backward(out, y);
|
207 |
+
scalar_t dy_dx = w;
|
208 |
+
atomicAdd(&relation_grad[layer * dim + d], out_grad * dout_dy * dy_dx * dx_drel);
|
209 |
+
atomicAdd(&input_grad[col * dim + d], out_grad * dout_dy * dy_dx * dx_din);
|
210 |
+
}
|
211 |
+
}
|
212 |
+
__syncwarp();
|
213 |
+
}
|
214 |
+
}
|
215 |
+
|
216 |
+
template <template<class> class NaryOp, template<class> class BinaryOp>
|
217 |
+
Tensor rspmm_forward_cuda(const Tensor &edge_index_, const Tensor &edge_type_, const Tensor &edge_weight_,
|
218 |
+
const Tensor &relation_, const Tensor &input_) {
|
219 |
+
constexpr const char *fn_name = "rspmm_forward_cuda";
|
220 |
+
TensorArg edge_index_arg(edge_index_, "edge_index", 1), edge_type_arg(edge_type_, "edge_type", 2),
|
221 |
+
edge_weight_arg(edge_weight_, "edge_weight", 3), relation_arg(relation_, "relation", 4),
|
222 |
+
input_arg(input_, "input", 5);
|
223 |
+
|
224 |
+
rspmm_forward_check(fn_name, edge_index_arg, edge_type_arg, edge_weight_arg, relation_arg, input_arg);
|
225 |
+
checkAllSameGPU(fn_name, {edge_index_arg, edge_type_arg, edge_weight_arg, relation_arg, input_arg});
|
226 |
+
|
227 |
+
const Tensor edge_index = edge_index_.contiguous();
|
228 |
+
const Tensor edge_type = edge_type_.contiguous();
|
229 |
+
const Tensor edge_weight = edge_weight_.contiguous();
|
230 |
+
const Tensor relation = relation_.contiguous();
|
231 |
+
const Tensor input = input_.contiguous();
|
232 |
+
|
233 |
+
int64_t nnz = edge_index.size(0);
|
234 |
+
int64_t num_row = input.size(0);
|
235 |
+
int64_t dim = input.size(1);
|
236 |
+
Tensor output = at::empty({num_row, dim}, input.options());
|
237 |
+
|
238 |
+
Tensor row_ind = edge_index.select(0, 0);
|
239 |
+
Tensor row_ptr = ind2ptr(row_ind, num_row);
|
240 |
+
Tensor col_ind = edge_index.select(0, 1);
|
241 |
+
Tensor layer_ind = edge_type;
|
242 |
+
|
243 |
+
cudaSetDevice(input.get_device());
|
244 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
245 |
+
|
246 |
+
const int dim_per_block = 32; // warpSize
|
247 |
+
const int num_dim_block = (dim + dim_per_block * kCoarseningFactor - 1) / (dim_per_block * kCoarseningFactor);
|
248 |
+
const int row_per_block = kThreadPerBlock / dim_per_block;
|
249 |
+
const int num_row_block = (num_row + row_per_block - 1) / row_per_block;
|
250 |
+
|
251 |
+
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rspmm_forward_cuda", [&] {
|
252 |
+
const int memory_size = kThreadPerBlock * (sizeof(int64_t) * 2 + sizeof(scalar_t));
|
253 |
+
rspmm_forward_out_cuda<scalar_t, NaryOp<scalar_t>, BinaryOp<scalar_t>>
|
254 |
+
<<<dim3(num_row_block, num_dim_block), dim3(dim_per_block, row_per_block), memory_size, stream>>>(
|
255 |
+
row_ptr.data_ptr<int64_t>(),
|
256 |
+
col_ind.data_ptr<int64_t>(),
|
257 |
+
layer_ind.data_ptr<int64_t>(),
|
258 |
+
edge_weight.data_ptr<scalar_t>(),
|
259 |
+
relation.data_ptr<scalar_t>(),
|
260 |
+
input.data_ptr<scalar_t>(),
|
261 |
+
output.data_ptr<scalar_t>(),
|
262 |
+
num_row, nnz, dim
|
263 |
+
);
|
264 |
+
});
|
265 |
+
|
266 |
+
return output;
|
267 |
+
}
|
268 |
+
|
269 |
+
template <template<class> class NaryOp, template<class> class BinaryOp>
|
270 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_backward_cuda(
|
271 |
+
const Tensor &edge_index_, const Tensor &edge_type_, const Tensor &edge_weight_,
|
272 |
+
const Tensor &relation_, const Tensor &input_, const Tensor &output_, const Tensor &output_grad_) {
|
273 |
+
constexpr const char *fn_name = "rspmm_backward_cuda";
|
274 |
+
TensorArg edge_index_arg(edge_index_, "edge_index", 1), edge_type_arg(edge_type_, "edge_type", 2),
|
275 |
+
edge_weight_arg(edge_weight_, "edge_weight", 3), relation_arg(relation_, "relation", 4),
|
276 |
+
input_arg(input_, "input", 5), output_arg(output_, "output", 6),
|
277 |
+
output_grad_arg(output_grad_, "output_grad", 7);
|
278 |
+
|
279 |
+
rspmm_backward_check(fn_name, edge_index_arg, edge_type_arg, edge_weight_arg, relation_arg, input_arg,
|
280 |
+
output_arg, output_grad_arg);
|
281 |
+
checkAllSameGPU(fn_name, {edge_index_arg, edge_type_arg, edge_weight_arg, relation_arg, input_arg, output_arg,
|
282 |
+
output_grad_arg});
|
283 |
+
|
284 |
+
const Tensor edge_index = edge_index_.contiguous();
|
285 |
+
const Tensor edge_type = edge_type_.contiguous();
|
286 |
+
const Tensor edge_weight = edge_weight_.contiguous();
|
287 |
+
const Tensor relation = relation_.contiguous();
|
288 |
+
const Tensor input = input_.contiguous();
|
289 |
+
const Tensor output = output_.contiguous();
|
290 |
+
const Tensor output_grad = output_grad_.contiguous();
|
291 |
+
|
292 |
+
int64_t nnz = edge_index.size(0);
|
293 |
+
int64_t num_row = input.size(0);
|
294 |
+
int64_t dim = input.size(1);
|
295 |
+
Tensor weight_grad = at::zeros_like(edge_weight);
|
296 |
+
Tensor relation_grad = at::zeros_like(relation);
|
297 |
+
Tensor input_grad = at::zeros_like(input);
|
298 |
+
|
299 |
+
Tensor row_ind = edge_index.select(0, 0);
|
300 |
+
Tensor row_ptr = ind2ptr(row_ind, num_row);
|
301 |
+
Tensor col_ind = edge_index.select(0, 1);
|
302 |
+
Tensor layer_ind = edge_type;
|
303 |
+
|
304 |
+
cudaSetDevice(input.get_device());
|
305 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
306 |
+
|
307 |
+
const int dim_per_block = 32; // warpSize
|
308 |
+
const int num_dim_block = (dim + dim_per_block * kCoarseningFactor - 1) / (dim_per_block * kCoarseningFactor);
|
309 |
+
const int row_per_block = kThreadPerBlock / dim_per_block;
|
310 |
+
const int num_row_block = (num_row + row_per_block - 1) / row_per_block;
|
311 |
+
|
312 |
+
if (edge_weight.requires_grad())
|
313 |
+
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rspmm_backward_cuda", [&] {
|
314 |
+
const int memory_size = kThreadPerBlock * (sizeof(int64_t) * 2 + sizeof(scalar_t));
|
315 |
+
rspmm_backward_out_cuda<scalar_t, NaryOp<scalar_t>, BinaryOp<scalar_t>>
|
316 |
+
<<<dim3(num_row_block, num_dim_block), dim3(dim_per_block, row_per_block), memory_size, stream>>>(
|
317 |
+
row_ptr.data_ptr<int64_t>(),
|
318 |
+
col_ind.data_ptr<int64_t>(),
|
319 |
+
layer_ind.data_ptr<int64_t>(),
|
320 |
+
edge_weight.data_ptr<scalar_t>(),
|
321 |
+
relation.data_ptr<scalar_t>(),
|
322 |
+
input.data_ptr<scalar_t>(),
|
323 |
+
output.data_ptr<scalar_t>(),
|
324 |
+
output_grad.data_ptr<scalar_t>(),
|
325 |
+
weight_grad.data_ptr<scalar_t>(),
|
326 |
+
relation_grad.data_ptr<scalar_t>(),
|
327 |
+
input_grad.data_ptr<scalar_t>(),
|
328 |
+
num_row, nnz, dim
|
329 |
+
);
|
330 |
+
});
|
331 |
+
else
|
332 |
+
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rspmm_backward_cuda", [&] {
|
333 |
+
const int memory_size = kThreadPerBlock * (sizeof(int64_t) * 2 + sizeof(scalar_t));
|
334 |
+
rspmm_backward_out_cuda<scalar_t, NaryOp<scalar_t>, BinaryOp<scalar_t>>
|
335 |
+
<<<dim3(num_row_block, num_dim_block), dim3(dim_per_block, row_per_block), memory_size, stream>>>(
|
336 |
+
row_ptr.data_ptr<int64_t>(),
|
337 |
+
col_ind.data_ptr<int64_t>(),
|
338 |
+
layer_ind.data_ptr<int64_t>(),
|
339 |
+
edge_weight.data_ptr<scalar_t>(),
|
340 |
+
relation.data_ptr<scalar_t>(),
|
341 |
+
input.data_ptr<scalar_t>(),
|
342 |
+
output.data_ptr<scalar_t>(),
|
343 |
+
output_grad.data_ptr<scalar_t>(),
|
344 |
+
relation_grad.data_ptr<scalar_t>(),
|
345 |
+
input_grad.data_ptr<scalar_t>(),
|
346 |
+
num_row, nnz, dim
|
347 |
+
);
|
348 |
+
});
|
349 |
+
|
350 |
+
return std::make_tuple(weight_grad, relation_grad, input_grad);
|
351 |
+
}
|
352 |
+
|
353 |
+
#define DECLARE_FORWARD_IMPL(ADD, MUL, NARYOP, BINARYOP) \
|
354 |
+
Tensor rspmm_##ADD##_##MUL##_forward_cuda( \
|
355 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, \
|
356 |
+
const Tensor &relation, const Tensor &input) { \
|
357 |
+
return rspmm_forward_cuda<NARYOP, BINARYOP>(edge_index, edge_type, edge_weight, relation, input); \
|
358 |
+
}
|
359 |
+
|
360 |
+
#define DECLARE_BACKWARD_IMPL(ADD, MUL, NARYOP, BINARYOP) \
|
361 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_##ADD##_##MUL##_backward_cuda( \
|
362 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, \
|
363 |
+
const Tensor &relation, const Tensor &input, const Tensor &output, const Tensor &output_grad) { \
|
364 |
+
return rspmm_backward_cuda<NARYOP, BINARYOP>(edge_index, edge_type, edge_weight, relation, input, \
|
365 |
+
output, output_grad); \
|
366 |
+
}
|
367 |
+
|
368 |
+
DECLARE_FORWARD_IMPL(add, mul, NaryAdd, BinaryMul)
|
369 |
+
DECLARE_BACKWARD_IMPL(add, mul, NaryAdd, BinaryMul)
|
370 |
+
|
371 |
+
DECLARE_FORWARD_IMPL(min, mul, NaryMin, BinaryMul)
|
372 |
+
DECLARE_BACKWARD_IMPL(min, mul, NaryMin, BinaryMul)
|
373 |
+
|
374 |
+
DECLARE_FORWARD_IMPL(max, mul, NaryMax, BinaryMul)
|
375 |
+
DECLARE_BACKWARD_IMPL(max, mul, NaryMax, BinaryMul)
|
376 |
+
|
377 |
+
DECLARE_FORWARD_IMPL(add, add, NaryAdd, BinaryAdd)
|
378 |
+
DECLARE_BACKWARD_IMPL(add, add, NaryAdd, BinaryAdd)
|
379 |
+
|
380 |
+
DECLARE_FORWARD_IMPL(min, add, NaryMin, BinaryAdd)
|
381 |
+
DECLARE_BACKWARD_IMPL(min, add, NaryMin, BinaryAdd)
|
382 |
+
|
383 |
+
DECLARE_FORWARD_IMPL(max, add, NaryMax, BinaryAdd)
|
384 |
+
DECLARE_BACKWARD_IMPL(max, add, NaryMax, BinaryAdd)
|
385 |
+
|
386 |
+
} // namespace at
|
ultra/rspmm/source/rspmm.h
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include <tuple>
|
4 |
+
|
5 |
+
#include <torch/extension.h>
|
6 |
+
//#include <ATen/SparseTensorUtils.h>
|
7 |
+
#include <ATen/native/SparseTensorUtils.h>
|
8 |
+
|
9 |
+
namespace at {
|
10 |
+
|
11 |
+
using namespace at::sparse;
|
12 |
+
|
13 |
+
void rspmm_forward_check(CheckedFrom c, const TensorArg &edge_index_arg, const TensorArg &edge_type_arg,
|
14 |
+
const TensorArg &edge_weight_arg, const TensorArg &relation_arg, const TensorArg &input_arg);
|
15 |
+
|
16 |
+
void rspmm_backward_check(CheckedFrom c, const TensorArg &edge_index_arg, const TensorArg &edge_type_arg,
|
17 |
+
const TensorArg &edge_weight_arg, const TensorArg &relation_arg, const TensorArg &input_arg,
|
18 |
+
const TensorArg &output_arg, const TensorArg &output_grad_arg);
|
19 |
+
|
20 |
+
Tensor ind2ptr(const Tensor &index, int size);
|
21 |
+
|
22 |
+
Tensor rspmm_add_mul_forward_cpu(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
|
23 |
+
const Tensor &relation, const Tensor &input);
|
24 |
+
|
25 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_add_mul_backward_cpu(
|
26 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
|
27 |
+
const Tensor &input, const Tensor &output, const Tensor &output_grad);
|
28 |
+
|
29 |
+
Tensor rspmm_min_mul_forward_cpu(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
|
30 |
+
const Tensor &relation, const Tensor &input);
|
31 |
+
|
32 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_min_mul_backward_cpu(
|
33 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
|
34 |
+
const Tensor &input, const Tensor &output, const Tensor &output_grad);
|
35 |
+
|
36 |
+
Tensor rspmm_max_mul_forward_cpu(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
|
37 |
+
const Tensor &relation, const Tensor &input);
|
38 |
+
|
39 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_max_mul_backward_cpu(
|
40 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
|
41 |
+
const Tensor &input, const Tensor &output, const Tensor &output_grad);
|
42 |
+
|
43 |
+
Tensor rspmm_add_add_forward_cpu(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
|
44 |
+
const Tensor &relation, const Tensor &input);
|
45 |
+
|
46 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_add_add_backward_cpu(
|
47 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
|
48 |
+
const Tensor &input, const Tensor &output, const Tensor &output_grad);
|
49 |
+
|
50 |
+
Tensor rspmm_min_add_forward_cpu(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
|
51 |
+
const Tensor &relation, const Tensor &input);
|
52 |
+
|
53 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_min_add_backward_cpu(
|
54 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
|
55 |
+
const Tensor &input, const Tensor &output, const Tensor &output_grad);
|
56 |
+
|
57 |
+
Tensor rspmm_max_add_forward_cpu(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
|
58 |
+
const Tensor &relation, const Tensor &input);
|
59 |
+
|
60 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_max_add_backward_cpu(
|
61 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
|
62 |
+
const Tensor &input, const Tensor &output, const Tensor &output_grad);
|
63 |
+
|
64 |
+
#ifdef CUDA_OP
|
65 |
+
Tensor rspmm_add_mul_forward_cuda(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
|
66 |
+
const Tensor &relation, const Tensor &input);
|
67 |
+
|
68 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_add_mul_backward_cuda(
|
69 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
|
70 |
+
const Tensor &input, const Tensor &output, const Tensor &output_grad);
|
71 |
+
|
72 |
+
Tensor rspmm_min_mul_forward_cuda(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
|
73 |
+
const Tensor &relation, const Tensor &input);
|
74 |
+
|
75 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_min_mul_backward_cuda(
|
76 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
|
77 |
+
const Tensor &input, const Tensor &output, const Tensor &output_grad);
|
78 |
+
|
79 |
+
Tensor rspmm_max_mul_forward_cuda(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
|
80 |
+
const Tensor &relation, const Tensor &input);
|
81 |
+
|
82 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_max_mul_backward_cuda(
|
83 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
|
84 |
+
const Tensor &input, const Tensor &output, const Tensor &output_grad);
|
85 |
+
|
86 |
+
Tensor rspmm_add_add_forward_cuda(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
|
87 |
+
const Tensor &relation, const Tensor &input);
|
88 |
+
|
89 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_add_add_backward_cuda(
|
90 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
|
91 |
+
const Tensor &input, const Tensor &output, const Tensor &output_grad);
|
92 |
+
|
93 |
+
Tensor rspmm_min_add_forward_cuda(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
|
94 |
+
const Tensor &relation, const Tensor &input);
|
95 |
+
|
96 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_min_add_backward_cuda(
|
97 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
|
98 |
+
const Tensor &input, const Tensor &output, const Tensor &output_grad);
|
99 |
+
|
100 |
+
Tensor rspmm_max_add_forward_cuda(const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight,
|
101 |
+
const Tensor &relation, const Tensor &input);
|
102 |
+
|
103 |
+
std::tuple<Tensor, Tensor, Tensor> rspmm_max_add_backward_cuda(
|
104 |
+
const Tensor &edge_index, const Tensor &edge_type, const Tensor &edge_weight, const Tensor &relation,
|
105 |
+
const Tensor &input, const Tensor &output, const Tensor &output_grad);
|
106 |
+
#endif
|
107 |
+
|
108 |
+
} // namespace at
|
ultra/rspmm/source/util.cuh
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
namespace at {
|
4 |
+
|
5 |
+
const unsigned kFullMask = 0xFFFFFFFF;
|
6 |
+
|
7 |
+
template <class scalar_t>
|
8 |
+
__device__ scalar_t warp_reduce(scalar_t value) {
|
9 |
+
#pragma unroll
|
10 |
+
for (int delta = 1; delta < warpSize; delta *= 2)
|
11 |
+
#if __CUDACC_VER_MAJOR__ >= 9
|
12 |
+
value += __shfl_down_sync(kFullMask, value, delta);
|
13 |
+
#else
|
14 |
+
value += __shfl_down(value, delta);
|
15 |
+
#endif
|
16 |
+
return value;
|
17 |
+
}
|
18 |
+
|
19 |
+
template<class scalar_t>
|
20 |
+
__device__ scalar_t warp_broadcast(scalar_t value, int lane_id) {
|
21 |
+
#if __CUDACC_VER_MAJOR__ >= 9
|
22 |
+
return __shfl_sync(kFullMask, value, lane_id);
|
23 |
+
#else
|
24 |
+
return __shfl(value, lane_id);
|
25 |
+
#endif
|
26 |
+
}
|
27 |
+
|
28 |
+
} // namespace at
|
ultra/tasks.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import reduce
|
2 |
+
from torch_scatter import scatter_add
|
3 |
+
from torch_geometric.data import Data
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
def edge_match(edge_index, query_index):
|
8 |
+
# O((n + q)logn) time
|
9 |
+
# O(n) memory
|
10 |
+
# edge_index: big underlying graph
|
11 |
+
# query_index: edges to match
|
12 |
+
|
13 |
+
# preparing unique hashing of edges, base: (max_node, max_relation) + 1
|
14 |
+
base = edge_index.max(dim=1)[0] + 1
|
15 |
+
# we will map edges to long ints, so we need to make sure the maximum product is less than MAX_LONG_INT
|
16 |
+
# idea: max number of edges = num_nodes * num_relations
|
17 |
+
# e.g. for a graph of 10 nodes / 5 relations, edge IDs 0...9 mean all possible outgoing edge types from node 0
|
18 |
+
# given a tuple (h, r), we will search for all other existing edges starting from head h
|
19 |
+
assert reduce(int.__mul__, base.tolist()) < torch.iinfo(torch.long).max
|
20 |
+
scale = base.cumprod(0)
|
21 |
+
scale = scale[-1] // scale
|
22 |
+
|
23 |
+
# hash both the original edge index and the query index to unique integers
|
24 |
+
edge_hash = (edge_index * scale.unsqueeze(-1)).sum(dim=0)
|
25 |
+
edge_hash, order = edge_hash.sort()
|
26 |
+
query_hash = (query_index * scale.unsqueeze(-1)).sum(dim=0)
|
27 |
+
|
28 |
+
# matched ranges: [start[i], end[i])
|
29 |
+
start = torch.bucketize(query_hash, edge_hash)
|
30 |
+
end = torch.bucketize(query_hash, edge_hash, right=True)
|
31 |
+
# num_match shows how many edges satisfy the (h, r) pattern for each query in the batch
|
32 |
+
num_match = end - start
|
33 |
+
|
34 |
+
# generate the corresponding ranges
|
35 |
+
offset = num_match.cumsum(0) - num_match
|
36 |
+
range = torch.arange(num_match.sum(), device=edge_index.device)
|
37 |
+
range = range + (start - offset).repeat_interleave(num_match)
|
38 |
+
|
39 |
+
return order[range], num_match
|
40 |
+
|
41 |
+
|
42 |
+
def negative_sampling(data, batch, num_negative, strict=True):
|
43 |
+
batch_size = len(batch)
|
44 |
+
pos_h_index, pos_t_index, pos_r_index = batch.t()
|
45 |
+
|
46 |
+
# strict negative sampling vs random negative sampling
|
47 |
+
if strict:
|
48 |
+
t_mask, h_mask = strict_negative_mask(data, batch)
|
49 |
+
t_mask = t_mask[:batch_size // 2]
|
50 |
+
neg_t_candidate = t_mask.nonzero()[:, 1]
|
51 |
+
num_t_candidate = t_mask.sum(dim=-1)
|
52 |
+
# draw samples for negative tails
|
53 |
+
rand = torch.rand(len(t_mask), num_negative, device=batch.device)
|
54 |
+
index = (rand * num_t_candidate.unsqueeze(-1)).long()
|
55 |
+
index = index + (num_t_candidate.cumsum(0) - num_t_candidate).unsqueeze(-1)
|
56 |
+
neg_t_index = neg_t_candidate[index]
|
57 |
+
|
58 |
+
h_mask = h_mask[batch_size // 2:]
|
59 |
+
neg_h_candidate = h_mask.nonzero()[:, 1]
|
60 |
+
num_h_candidate = h_mask.sum(dim=-1)
|
61 |
+
# draw samples for negative heads
|
62 |
+
rand = torch.rand(len(h_mask), num_negative, device=batch.device)
|
63 |
+
index = (rand * num_h_candidate.unsqueeze(-1)).long()
|
64 |
+
index = index + (num_h_candidate.cumsum(0) - num_h_candidate).unsqueeze(-1)
|
65 |
+
neg_h_index = neg_h_candidate[index]
|
66 |
+
else:
|
67 |
+
neg_index = torch.randint(data.num_nodes, (batch_size, num_negative), device=batch.device)
|
68 |
+
neg_t_index, neg_h_index = neg_index[:batch_size // 2], neg_index[batch_size // 2:]
|
69 |
+
|
70 |
+
h_index = pos_h_index.unsqueeze(-1).repeat(1, num_negative + 1)
|
71 |
+
t_index = pos_t_index.unsqueeze(-1).repeat(1, num_negative + 1)
|
72 |
+
r_index = pos_r_index.unsqueeze(-1).repeat(1, num_negative + 1)
|
73 |
+
t_index[:batch_size // 2, 1:] = neg_t_index
|
74 |
+
h_index[batch_size // 2:, 1:] = neg_h_index
|
75 |
+
|
76 |
+
return torch.stack([h_index, t_index, r_index], dim=-1)
|
77 |
+
|
78 |
+
|
79 |
+
def all_negative(data, batch):
|
80 |
+
pos_h_index, pos_t_index, pos_r_index = batch.t()
|
81 |
+
r_index = pos_r_index.unsqueeze(-1).expand(-1, data.num_nodes)
|
82 |
+
# generate all negative tails for this batch
|
83 |
+
all_index = torch.arange(data.num_nodes, device=batch.device)
|
84 |
+
h_index, t_index = torch.meshgrid(pos_h_index, all_index, indexing="ij") # indexing "xy" would return transposed
|
85 |
+
t_batch = torch.stack([h_index, t_index, r_index], dim=-1)
|
86 |
+
# generate all negative heads for this batch
|
87 |
+
all_index = torch.arange(data.num_nodes, device=batch.device)
|
88 |
+
t_index, h_index = torch.meshgrid(pos_t_index, all_index, indexing="ij")
|
89 |
+
h_batch = torch.stack([h_index, t_index, r_index], dim=-1)
|
90 |
+
|
91 |
+
return t_batch, h_batch
|
92 |
+
|
93 |
+
|
94 |
+
def strict_negative_mask(data, batch):
|
95 |
+
# this function makes sure that for a given (h, r) batch we will NOT sample true tails as random negatives
|
96 |
+
# similarly, for a given (t, r) we will NOT sample existing true heads as random negatives
|
97 |
+
|
98 |
+
pos_h_index, pos_t_index, pos_r_index = batch.t()
|
99 |
+
|
100 |
+
# part I: sample hard negative tails
|
101 |
+
# edge index of all (head, relation) edges from the underlying graph
|
102 |
+
edge_index = torch.stack([data.edge_index[0], data.edge_type])
|
103 |
+
# edge index of current batch (head, relation) for which we will sample negatives
|
104 |
+
query_index = torch.stack([pos_h_index, pos_r_index])
|
105 |
+
# search for all true tails for the given (h, r) batch
|
106 |
+
edge_id, num_t_truth = edge_match(edge_index, query_index)
|
107 |
+
# build an index from the found edges
|
108 |
+
t_truth_index = data.edge_index[1, edge_id]
|
109 |
+
sample_id = torch.arange(len(num_t_truth), device=batch.device).repeat_interleave(num_t_truth)
|
110 |
+
t_mask = torch.ones(len(num_t_truth), data.num_nodes, dtype=torch.bool, device=batch.device)
|
111 |
+
# assign 0s to the mask with the found true tails
|
112 |
+
t_mask[sample_id, t_truth_index] = 0
|
113 |
+
t_mask.scatter_(1, pos_t_index.unsqueeze(-1), 0)
|
114 |
+
|
115 |
+
# part II: sample hard negative heads
|
116 |
+
# edge_index[1] denotes tails, so the edge index becomes (t, r)
|
117 |
+
edge_index = torch.stack([data.edge_index[1], data.edge_type])
|
118 |
+
# edge index of current batch (tail, relation) for which we will sample heads
|
119 |
+
query_index = torch.stack([pos_t_index, pos_r_index])
|
120 |
+
# search for all true heads for the given (t, r) batch
|
121 |
+
edge_id, num_h_truth = edge_match(edge_index, query_index)
|
122 |
+
# build an index from the found edges
|
123 |
+
h_truth_index = data.edge_index[0, edge_id]
|
124 |
+
sample_id = torch.arange(len(num_h_truth), device=batch.device).repeat_interleave(num_h_truth)
|
125 |
+
h_mask = torch.ones(len(num_h_truth), data.num_nodes, dtype=torch.bool, device=batch.device)
|
126 |
+
# assign 0s to the mask with the found true heads
|
127 |
+
h_mask[sample_id, h_truth_index] = 0
|
128 |
+
h_mask.scatter_(1, pos_h_index.unsqueeze(-1), 0)
|
129 |
+
|
130 |
+
return t_mask, h_mask
|
131 |
+
|
132 |
+
|
133 |
+
def compute_ranking(pred, target, mask=None):
|
134 |
+
pos_pred = pred.gather(-1, target.unsqueeze(-1))
|
135 |
+
if mask is not None:
|
136 |
+
# filtered ranking
|
137 |
+
ranking = torch.sum((pos_pred <= pred) & mask, dim=-1) + 1
|
138 |
+
else:
|
139 |
+
# unfiltered ranking
|
140 |
+
ranking = torch.sum(pos_pred <= pred, dim=-1) + 1
|
141 |
+
return ranking
|
142 |
+
|
143 |
+
|
144 |
+
def build_relation_graph(graph):
|
145 |
+
|
146 |
+
# expect the graph is already with inverse edges
|
147 |
+
|
148 |
+
edge_index, edge_type = graph.edge_index, graph.edge_type
|
149 |
+
num_nodes, num_rels = graph.num_nodes, graph.num_relations
|
150 |
+
device = edge_index.device
|
151 |
+
|
152 |
+
Eh = torch.vstack([edge_index[0], edge_type]).T.unique(dim=0) # (num_edges, 2)
|
153 |
+
Dh = scatter_add(torch.ones_like(Eh[:, 1]), Eh[:, 0])
|
154 |
+
|
155 |
+
EhT = torch.sparse_coo_tensor(
|
156 |
+
torch.flip(Eh, dims=[1]).T,
|
157 |
+
torch.ones(Eh.shape[0], device=device) / Dh[Eh[:, 0]],
|
158 |
+
(num_rels, num_nodes)
|
159 |
+
)
|
160 |
+
Eh = torch.sparse_coo_tensor(
|
161 |
+
Eh.T,
|
162 |
+
torch.ones(Eh.shape[0], device=device),
|
163 |
+
(num_nodes, num_rels)
|
164 |
+
)
|
165 |
+
Et = torch.vstack([edge_index[1], edge_type]).T.unique(dim=0) # (num_edges, 2)
|
166 |
+
|
167 |
+
Dt = scatter_add(torch.ones_like(Et[:, 1]), Et[:, 0])
|
168 |
+
assert not (Dt[Et[:, 0]] == 0).any()
|
169 |
+
|
170 |
+
EtT = torch.sparse_coo_tensor(
|
171 |
+
torch.flip(Et, dims=[1]).T,
|
172 |
+
torch.ones(Et.shape[0], device=device) / Dt[Et[:, 0]],
|
173 |
+
(num_rels, num_nodes)
|
174 |
+
)
|
175 |
+
Et = torch.sparse_coo_tensor(
|
176 |
+
Et.T,
|
177 |
+
torch.ones(Et.shape[0], device=device),
|
178 |
+
(num_nodes, num_rels)
|
179 |
+
)
|
180 |
+
|
181 |
+
Ahh = torch.sparse.mm(EhT, Eh).coalesce()
|
182 |
+
Att = torch.sparse.mm(EtT, Et).coalesce()
|
183 |
+
Aht = torch.sparse.mm(EhT, Et).coalesce()
|
184 |
+
Ath = torch.sparse.mm(EtT, Eh).coalesce()
|
185 |
+
|
186 |
+
hh_edges = torch.cat([Ahh.indices().T, torch.zeros(Ahh.indices().T.shape[0], 1, dtype=torch.long).fill_(0)], dim=1) # head to head
|
187 |
+
tt_edges = torch.cat([Att.indices().T, torch.zeros(Att.indices().T.shape[0], 1, dtype=torch.long).fill_(1)], dim=1) # tail to tail
|
188 |
+
ht_edges = torch.cat([Aht.indices().T, torch.zeros(Aht.indices().T.shape[0], 1, dtype=torch.long).fill_(2)], dim=1) # head to tail
|
189 |
+
th_edges = torch.cat([Ath.indices().T, torch.zeros(Ath.indices().T.shape[0], 1, dtype=torch.long).fill_(3)], dim=1) # tail to head
|
190 |
+
|
191 |
+
rel_graph = Data(
|
192 |
+
edge_index=torch.cat([hh_edges[:, [0, 1]].T, tt_edges[:, [0, 1]].T, ht_edges[:, [0, 1]].T, th_edges[:, [0, 1]].T], dim=1),
|
193 |
+
edge_type=torch.cat([hh_edges[:, 2], tt_edges[:, 2], ht_edges[:, 2], th_edges[:, 2]], dim=0),
|
194 |
+
num_nodes=num_rels,
|
195 |
+
num_relations=4
|
196 |
+
)
|
197 |
+
|
198 |
+
graph.relation_graph = rel_graph
|
199 |
+
return graph
|
200 |
+
|
201 |
+
|
ultra/util.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import ast
|
4 |
+
import copy
|
5 |
+
import time
|
6 |
+
import logging
|
7 |
+
import argparse
|
8 |
+
|
9 |
+
import yaml
|
10 |
+
import jinja2
|
11 |
+
from jinja2 import meta
|
12 |
+
import easydict
|
13 |
+
|
14 |
+
import torch
|
15 |
+
from torch import distributed as dist
|
16 |
+
from torch_geometric.data import Data
|
17 |
+
from torch_geometric.datasets import RelLinkPredDataset, WordNet18RR
|
18 |
+
|
19 |
+
from ultra import models, datasets
|
20 |
+
|
21 |
+
|
22 |
+
logger = logging.getLogger(__file__)
|
23 |
+
|
24 |
+
|
25 |
+
def detect_variables(cfg_file):
|
26 |
+
with open(cfg_file, "r") as fin:
|
27 |
+
raw = fin.read()
|
28 |
+
env = jinja2.Environment()
|
29 |
+
tree = env.parse(raw)
|
30 |
+
vars = meta.find_undeclared_variables(tree)
|
31 |
+
return vars
|
32 |
+
|
33 |
+
|
34 |
+
def load_config(cfg_file, context=None):
|
35 |
+
with open(cfg_file, "r") as fin:
|
36 |
+
raw = fin.read()
|
37 |
+
template = jinja2.Template(raw)
|
38 |
+
instance = template.render(context)
|
39 |
+
cfg = yaml.safe_load(instance)
|
40 |
+
cfg = easydict.EasyDict(cfg)
|
41 |
+
return cfg
|
42 |
+
|
43 |
+
|
44 |
+
def literal_eval(string):
|
45 |
+
try:
|
46 |
+
return ast.literal_eval(string)
|
47 |
+
except (ValueError, SyntaxError):
|
48 |
+
return string
|
49 |
+
|
50 |
+
|
51 |
+
def parse_args():
|
52 |
+
parser = argparse.ArgumentParser()
|
53 |
+
parser.add_argument("-c", "--config", help="yaml configuration file", required=True)
|
54 |
+
parser.add_argument("-s", "--seed", help="random seed for PyTorch", type=int, default=1024)
|
55 |
+
|
56 |
+
args, unparsed = parser.parse_known_args()
|
57 |
+
# get dynamic arguments defined in the config file
|
58 |
+
vars = detect_variables(args.config)
|
59 |
+
parser = argparse.ArgumentParser()
|
60 |
+
for var in vars:
|
61 |
+
parser.add_argument("--%s" % var, required=True)
|
62 |
+
vars = parser.parse_known_args(unparsed)[0]
|
63 |
+
vars = {k: literal_eval(v) for k, v in vars._get_kwargs()}
|
64 |
+
|
65 |
+
return args, vars
|
66 |
+
|
67 |
+
|
68 |
+
def get_root_logger(file=True):
|
69 |
+
format = "%(asctime)-10s %(message)s"
|
70 |
+
datefmt = "%H:%M:%S"
|
71 |
+
logging.basicConfig(format=format, datefmt=datefmt)
|
72 |
+
logger = logging.getLogger("")
|
73 |
+
logger.setLevel(logging.INFO)
|
74 |
+
|
75 |
+
if file:
|
76 |
+
handler = logging.FileHandler("log.txt")
|
77 |
+
format = logging.Formatter(format, datefmt)
|
78 |
+
handler.setFormatter(format)
|
79 |
+
logger.addHandler(handler)
|
80 |
+
|
81 |
+
return logger
|
82 |
+
|
83 |
+
|
84 |
+
def get_rank():
|
85 |
+
if dist.is_initialized():
|
86 |
+
return dist.get_rank()
|
87 |
+
if "RANK" in os.environ:
|
88 |
+
return int(os.environ["RANK"])
|
89 |
+
return 0
|
90 |
+
|
91 |
+
|
92 |
+
def get_world_size():
|
93 |
+
if dist.is_initialized():
|
94 |
+
return dist.get_world_size()
|
95 |
+
if "WORLD_SIZE" in os.environ:
|
96 |
+
return int(os.environ["WORLD_SIZE"])
|
97 |
+
return 1
|
98 |
+
|
99 |
+
|
100 |
+
def synchronize():
|
101 |
+
if get_world_size() > 1:
|
102 |
+
dist.barrier()
|
103 |
+
|
104 |
+
|
105 |
+
def get_device(cfg):
|
106 |
+
if cfg.train.gpus:
|
107 |
+
device = torch.device(cfg.train.gpus[get_rank()])
|
108 |
+
else:
|
109 |
+
device = torch.device("cpu")
|
110 |
+
return device
|
111 |
+
|
112 |
+
def get_devices(gpus):
|
113 |
+
if gpus is not None:
|
114 |
+
device = torch.device(gpus[get_rank()])
|
115 |
+
else:
|
116 |
+
device = torch.device("cpu")
|
117 |
+
return device
|
118 |
+
|
119 |
+
|
120 |
+
def create_working_directory(cfg):
|
121 |
+
file_name = "working_dir.tmp"
|
122 |
+
world_size = get_world_size()
|
123 |
+
if cfg.train.gpus is not None and len(cfg.train.gpus) != world_size:
|
124 |
+
error_msg = "World size is %d but found %d GPUs in the argument"
|
125 |
+
if world_size == 1:
|
126 |
+
error_msg += ". Did you launch with `python -m torch.distributed.launch`?"
|
127 |
+
raise ValueError(error_msg % (world_size, len(cfg.train.gpus)))
|
128 |
+
if world_size > 1 and not dist.is_initialized():
|
129 |
+
dist.init_process_group("nccl", init_method="env://")
|
130 |
+
|
131 |
+
working_dir = os.path.join(os.path.expanduser(cfg.output_dir),
|
132 |
+
cfg.model["class"], cfg.dataset["class"], time.strftime("%Y-%m-%d-%H-%M-%S"))
|
133 |
+
|
134 |
+
# synchronize working directory
|
135 |
+
if get_rank() == 0:
|
136 |
+
with open(file_name, "w") as fout:
|
137 |
+
fout.write(working_dir)
|
138 |
+
os.makedirs(working_dir)
|
139 |
+
synchronize()
|
140 |
+
if get_rank() != 0:
|
141 |
+
with open(file_name, "r") as fin:
|
142 |
+
working_dir = fin.read()
|
143 |
+
synchronize()
|
144 |
+
if get_rank() == 0:
|
145 |
+
os.remove(file_name)
|
146 |
+
|
147 |
+
os.chdir(working_dir)
|
148 |
+
return working_dir
|
149 |
+
|
150 |
+
|
151 |
+
def build_dataset(cfg):
|
152 |
+
data_config = copy.deepcopy(cfg.dataset)
|
153 |
+
cls = data_config.pop("class")
|
154 |
+
|
155 |
+
ds_cls = getattr(datasets, cls)
|
156 |
+
dataset = ds_cls(**data_config)
|
157 |
+
|
158 |
+
if get_rank() == 0:
|
159 |
+
logger.warning("%s dataset" % (cls if "version" not in cfg.dataset else f'{cls}({cfg.dataset.version})'))
|
160 |
+
if cls != "JointDataset":
|
161 |
+
logger.warning("#train: %d, #valid: %d, #test: %d" %
|
162 |
+
(dataset[0].target_edge_index.shape[1], dataset[1].target_edge_index.shape[1],
|
163 |
+
dataset[2].target_edge_index.shape[1]))
|
164 |
+
else:
|
165 |
+
logger.warning("#train: %d, #valid: %d, #test: %d" %
|
166 |
+
(sum(d.target_edge_index.shape[1] for d in dataset._data[0]),
|
167 |
+
sum(d.target_edge_index.shape[1] for d in dataset._data[1]),
|
168 |
+
sum(d.target_edge_index.shape[1] for d in dataset._data[2]),
|
169 |
+
))
|
170 |
+
|
171 |
+
return dataset
|
172 |
+
|