igashov commited on
Commit
d5b42eb
1 Parent(s): f97fe8b

update device

Browse files
Files changed (2) hide show
  1. src/lightning.py +2 -2
  2. src/linker_size_lightning.py +1 -1
src/lightning.py CHANGED
@@ -55,7 +55,7 @@ class DDPM(pl.LightningModule):
55
  self.val_data_prefix = val_data_prefix
56
  self.batch_size = batch_size
57
  self.lr = lr
58
- self.torch_device = torch_device
59
  self.include_charges = include_charges
60
  self.test_epochs = test_epochs
61
  self.n_stability_samples = n_stability_samples
@@ -81,7 +81,7 @@ class DDPM(pl.LightningModule):
81
  in_node_nf=in_node_nf,
82
  n_dims=n_dims,
83
  context_node_nf=context_node_nf,
84
- device=torch_device,
85
  hidden_nf=hidden_nf,
86
  activation=activation,
87
  n_layers=n_layers,
 
55
  self.val_data_prefix = val_data_prefix
56
  self.batch_size = batch_size
57
  self.lr = lr
58
+ self.torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
59
  self.include_charges = include_charges
60
  self.test_epochs = test_epochs
61
  self.n_stability_samples = n_stability_samples
 
81
  in_node_nf=in_node_nf,
82
  n_dims=n_dims,
83
  context_node_nf=context_node_nf,
84
+ device=self.torch_device,
85
  hidden_nf=hidden_nf,
86
  activation=activation,
87
  n_layers=n_layers,
src/linker_size_lightning.py CHANGED
@@ -45,7 +45,7 @@ class SizeClassifier(pl.LightningModule):
45
  hidden_nf=hidden_nf,
46
  out_node_nf=out_node_nf,
47
  n_layers=n_layers,
48
- device=torch_device,
49
  normalization=normalization,
50
  )
51
 
 
45
  hidden_nf=hidden_nf,
46
  out_node_nf=out_node_nf,
47
  n_layers=n_layers,
48
+ device=self.torch_device,
49
  normalization=normalization,
50
  )
51