Spaces:
Sleeping
Sleeping
igashov
commited on
Commit
•
d1da608
1
Parent(s):
b0ab0d5
fix size_nn
Browse files- app.py +1 -1
- src/linker_size_lightning.py +6 -3
app.py
CHANGED
@@ -72,7 +72,7 @@ print('Loaded diffusion model')
|
|
72 |
|
73 |
|
74 |
def sample_fn(_data):
|
75 |
-
output, _ = size_nn.forward(_data)
|
76 |
probabilities = torch.softmax(output, dim=1)
|
77 |
distribution = torch.distributions.Categorical(probs=probabilities)
|
78 |
samples = distribution.sample()
|
|
|
72 |
|
73 |
|
74 |
def sample_fn(_data):
|
75 |
+
output, _ = size_nn.forward(_data, return_loss=False)
|
76 |
probabilities = torch.softmax(output, dim=1)
|
77 |
distribution = torch.distributions.Categorical(probs=probabilities)
|
78 |
samples = distribution.sample()
|
src/linker_size_lightning.py
CHANGED
@@ -79,7 +79,7 @@ class SizeClassifier(pl.LightningModule):
|
|
79 |
def test_dataloader(self):
|
80 |
return get_dataloader(self.test_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)
|
81 |
|
82 |
-
def forward(self, data):
|
83 |
h = data['one_hot']
|
84 |
x = data['positions']
|
85 |
fragment_mask = data['fragment_mask']
|
@@ -103,8 +103,11 @@ class SizeClassifier(pl.LightningModule):
|
|
103 |
output = self.gnn.forward(h, edges, distances, fragment_mask, distance_edge_mask)
|
104 |
output = output.view(bs, n_nodes, -1).mean(1)
|
105 |
|
106 |
-
|
107 |
-
|
|
|
|
|
|
|
108 |
|
109 |
return output, loss
|
110 |
|
|
|
79 |
def test_dataloader(self):
|
80 |
return get_dataloader(self.test_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)
|
81 |
|
82 |
+
def forward(self, data, return_loss=True):
|
83 |
h = data['one_hot']
|
84 |
x = data['positions']
|
85 |
fragment_mask = data['fragment_mask']
|
|
|
103 |
output = self.gnn.forward(h, edges, distances, fragment_mask, distance_edge_mask)
|
104 |
output = output.view(bs, n_nodes, -1).mean(1)
|
105 |
|
106 |
+
if return_loss:
|
107 |
+
true = self.get_true_labels(linker_mask)
|
108 |
+
loss = cross_entropy(output, true, weight=self.loss_weights)
|
109 |
+
else:
|
110 |
+
loss = None
|
111 |
|
112 |
return output, loss
|
113 |
|