feat: add device config into cli arg
Browse files
train.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import os
|
2 |
import torch
|
3 |
import pytorch_lightning as ptl
|
@@ -10,7 +11,12 @@ from utils import get_current_tag
|
|
10 |
|
11 |
torch.set_float32_matmul_precision('high')
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
final_batch_size = 128
|
16 |
single_device_num_workers = 24
|
|
|
1 |
+
import argparse
|
2 |
import os
|
3 |
import torch
|
4 |
import pytorch_lightning as ptl
|
|
|
11 |
|
12 |
torch.set_float32_matmul_precision('high')
|
13 |
|
14 |
+
parser = argparse.ArgumentParser()
|
15 |
+
parser.add_argument('-d', '--devices', nargs='*', type=int, default=[0])
|
16 |
+
|
17 |
+
args = parser.parse_args()
|
18 |
+
|
19 |
+
devices = args.devices
|
20 |
|
21 |
final_batch_size = 128
|
22 |
single_device_num_workers = 24
|