gyrojeff commited on
Commit
ff82fe6
·
1 Parent(s): 68dd12a

feat: add device config into cli arg

Browse files
Files changed (1) hide show
  1. train.py +7 -1
train.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import torch
3
  import pytorch_lightning as ptl
@@ -10,7 +11,12 @@ from utils import get_current_tag
10
 
11
  torch.set_float32_matmul_precision('high')
12
 
13
- devices = [6, 7]
 
 
 
 
 
14
 
15
  final_batch_size = 128
16
  single_device_num_workers = 24
 
1
+ import argparse
2
  import os
3
  import torch
4
  import pytorch_lightning as ptl
 
11
 
12
  torch.set_float32_matmul_precision('high')
13
 
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument('-d', '--devices', nargs='*', type=int, default=[0])
16
+
17
+ args = parser.parse_args()
18
+
19
+ devices = args.devices
20
 
21
  final_batch_size = 128
22
  single_device_num_workers = 24