igashov commited on
Commit
f97fe8b
1 Parent(s): 49021fb

update device

Browse files
Files changed (1) hide show
  1. src/linker_size_lightning.py +2 -2
src/linker_size_lightning.py CHANGED
@@ -38,8 +38,8 @@ class SizeClassifier(pl.LightningModule):
38
  self.linker_id2size = linker_id2size
39
  self.batch_size = batch_size
40
  self.lr = lr
41
- self.torch_device = torch_device
42
- self.loss_weights = None if loss_weights is None else torch.tensor(loss_weights, device=torch_device)
43
  self.gnn = SizeGNN(
44
  in_node_nf=in_node_nf,
45
  hidden_nf=hidden_nf,
 
38
  self.linker_id2size = linker_id2size
39
  self.batch_size = batch_size
40
  self.lr = lr
41
+ self.torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
+ self.loss_weights = None if loss_weights is None else torch.tensor(loss_weights, device=self.torch_device)
43
  self.gnn = SizeGNN(
44
  in_node_nf=in_node_nf,
45
  hidden_nf=hidden_nf,