import numpy as np
import torch
import torch.nn as nn

from torch.distributions.categorical import Categorical
from src.egnn import GCL


class DistributionNodes:
    """Categorical distribution over the number of nodes, estimated from a
    histogram mapping node count -> frequency in the training data."""

    def __init__(self, histogram):
        self.n_nodes = []
        prob = []
        self.keys = {}
        for i, nodes in enumerate(histogram):
            self.n_nodes.append(nodes)
            self.keys[nodes] = i
            prob.append(histogram[nodes])
        self.n_nodes = torch.tensor(self.n_nodes)
        prob = np.array(prob)
        prob = prob / np.sum(prob)

        self.prob = torch.from_numpy(prob).float()

        # H[N] = -sum_n p(n) log p(n); the small epsilon guards against log(0).
        entropy = -torch.sum(self.prob * torch.log(self.prob + 1e-30))
        print("Entropy of n_nodes: H[N]", entropy.item())

        self.m = Categorical(self.prob)

    def sample(self, n_samples=1):
        """Draw n_samples node counts from the empirical distribution."""
        idx = self.m.sample((n_samples,))
        return self.n_nodes[idx]

    def log_prob(self, batch_n_nodes):
        """Log-probability of each node count in a 1D batch tensor."""
        assert len(batch_n_nodes.size()) == 1

        # Map each node count back to its index in the histogram.
        idcs = [self.keys[i.item()] for i in batch_n_nodes]
        idcs = torch.tensor(idcs).to(batch_n_nodes.device)

        log_p = torch.log(self.prob + 1e-30).to(batch_n_nodes.device)
        return log_p[idcs]


class SizeGNN(nn.Module):
    """Graph neural network that maps node features and pairwise distances to
    per-node output features via a stack of GCL message-passing blocks."""

    def __init__(self, in_node_nf, hidden_nf, out_node_nf, n_layers, normalization, device='cpu'):
        super(SizeGNN, self).__init__()
        self.hidden_nf = hidden_nf
        self.out_node_nf = out_node_nf
        self.device = device

        self.embedding_in = nn.Linear(in_node_nf, self.hidden_nf)
        # n_layers identical GCL blocks over the hidden representation.
        self.gcl_layers = nn.ModuleList([
            GCL(
                input_nf=self.hidden_nf,
                output_nf=self.hidden_nf,
                hidden_nf=self.hidden_nf,
                normalization_factor=1,
                aggregation_method='sum',
                edges_in_d=1,
                activation=nn.ReLU(),
                attention=False,
                normalization=normalization,
            )
            for _ in range(n_layers)
        ])
        self.embedding_out = nn.Linear(self.hidden_nf, self.out_node_nf)
        self.to(self.device)

    def forward(self, h, edges, distances, node_mask, edge_mask):
        h = self.embedding_in(h)
        # Message passing over all GCL blocks; distances act as edge features.
        for gcl in self.gcl_layers:
            h, _ = gcl(h, edges, edge_attr=distances, node_mask=node_mask, edge_mask=edge_mask)

        h = self.embedding_out(h)
        return h
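

# --- Minimal usage sketch for DistributionNodes (illustrative only) ---
# Assumes a toy histogram {num_nodes: count}; in the real pipeline the
# histogram would come from the training dataset statistics. A SizeGNN demo
# is omitted here because it depends on the edge/mask conventions of
# src.egnn.GCL.
if __name__ == '__main__':
    histogram = {10: 5, 12: 20, 15: 8}  # hypothetical node-count frequencies
    dist = DistributionNodes(histogram)
    sizes = dist.sample(n_samples=4)   # e.g. tensor([12, 12, 10, 15])
    log_p = dist.log_prob(sizes)       # log-probability of each sampled size
    print(sizes, log_p)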