liuganghuggingface
commited on
Update graph_decoder/diffusion_model.py
Browse files
graph_decoder/diffusion_model.py
CHANGED
@@ -181,11 +181,11 @@ class GraphDiT(nn.Module):
|
|
181 |
def generate(
|
182 |
self,
|
183 |
properties,
|
184 |
-
device,
|
185 |
guide_scale=1.,
|
186 |
num_nodes=None,
|
187 |
number_chain_steps=50,
|
188 |
):
|
|
|
189 |
properties = [float('nan') if x is None else x for x in properties]
|
190 |
properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
|
191 |
batch_size = properties.size(0)
|
|
|
181 |
def generate(
|
182 |
self,
|
183 |
properties,
|
|
|
184 |
guide_scale=1.,
|
185 |
num_nodes=None,
|
186 |
number_chain_steps=50,
|
187 |
):
|
188 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
189 |
properties = [float('nan') if x is None else x for x in properties]
|
190 |
properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
|
191 |
batch_size = properties.size(0)
|