liuganghuggingface commited on
Commit
7f99386
·
verified ·
1 Parent(s): 7c67898

Update graph_decoder/transformer.py

Browse files
Files changed (1) hide show
  1. graph_decoder/transformer.py +1 -33
graph_decoder/transformer.py CHANGED
@@ -2,39 +2,7 @@ import torch
2
  import torch.nn as nn
3
  from .layers import Attention, MLP
4
  from .conditions import TimestepEmbedder, ConditionEmbedder
5
- # from .diffusion_utils import PlaceHolder
6
-
7
- #### graph utils
8
- class PlaceHolder:
9
- def __init__(self, X, E, y):
10
- self.X = X
11
- self.E = E
12
- self.y = y
13
-
14
- def type_as(self, x: torch.Tensor, categorical: bool = False):
15
- """Changes the device and dtype of X, E, y."""
16
- self.X = self.X.type_as(x)
17
- self.E = self.E.type_as(x)
18
- if categorical:
19
- self.y = self.y.type_as(x)
20
- return self
21
-
22
- def mask(self, node_mask, collapse=False):
23
- x_mask = node_mask.unsqueeze(-1) # bs, n, 1
24
- e_mask1 = x_mask.unsqueeze(2) # bs, n, 1, 1
25
- e_mask2 = x_mask.unsqueeze(1) # bs, 1, n, 1
26
-
27
- if collapse:
28
- self.X = torch.argmax(self.X, dim=-1)
29
- self.E = torch.argmax(self.E, dim=-1)
30
-
31
- self.X[node_mask == 0] = -1
32
- self.E[(e_mask1 * e_mask2).squeeze(-1) == 0] = -1
33
- else:
34
- self.X = self.X * x_mask
35
- self.E = self.E * e_mask1 * e_mask2
36
- assert torch.allclose(self.E, torch.transpose(self.E, 1, 2))
37
- return self
38
 
39
  def modulate(x, shift, scale):
40
  return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
 
2
  import torch.nn as nn
3
  from .layers import Attention, MLP
4
  from .conditions import TimestepEmbedder, ConditionEmbedder
5
+ from .diffusion_utils import PlaceHolder
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  def modulate(x, shift, scale):
8
  return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)