liuganghuggingface commited on
Commit
d7a9f21
·
verified ·
1 Parent(s): 50e54f4

Update graph_decoder/diffusion_model.py

Browse files
Files changed (1) hide show
  1. graph_decoder/diffusion_model.py +1 -1
graph_decoder/diffusion_model.py CHANGED
@@ -181,11 +181,11 @@ class GraphDiT(nn.Module):
181
  def generate(
182
  self,
183
  properties,
184
- device,
185
  guide_scale=1.,
186
  num_nodes=None,
187
  number_chain_steps=50,
188
  ):
 
189
  properties = [float('nan') if x is None else x for x in properties]
190
  properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
191
  batch_size = properties.size(0)
 
181
  def generate(
182
  self,
183
  properties,
 
184
  guide_scale=1.,
185
  num_nodes=None,
186
  number_chain_steps=50,
187
  ):
188
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
189
  properties = [float('nan') if x is None else x for x in properties]
190
  properties = torch.tensor(properties, dtype=torch.float).reshape(1, -1).to(device)
191
  batch_size = properties.size(0)