File size: 2,296 Bytes
ff82fe6
3163344
10f217b
3163344
 
 
 
8e068be
3163344
 
 
00a4b21
10f217b
ff82fe6
00a4b21
fd9442f
8d9c0ef
ff82fe6
 
 
 
fd9442f
3163344
fd9442f
 
3163344
 
8364103
3163344
 
 
8364103
3163344
 
 
8d2e833
a976004
68dd12a
8364103
3163344
 
 
 
 
 
 
fd9442f
3163344
 
 
 
 
68dd12a
a976004
3163344
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c681b80
3163344
 
8e068be
3163344
 
 
 
 
 
 
 
 
 
912d566
3163344
 
8d9c0ef
5c43f60
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import argparse
import os
import torch
import pytorch_lightning as ptl
from pytorch_lightning.loggers import TensorBoardLogger

from detector.data import FontDataModule
from detector.model import *
from utils import get_current_tag


# Allow TF32 matmuls on Ampere+ GPUs: faster training at slightly reduced
# float32 precision.
torch.set_float32_matmul_precision("high")

parser = argparse.ArgumentParser()
parser.add_argument("-d", "--devices", nargs="*", type=int, default=[0])
parser.add_argument("-b", "--single-batch-size", type=int, default=64)
parser.add_argument("-c", "--checkpoint", type=str, default=None)

args = parser.parse_args()

devices = args.devices
single_batch_size = args.single_batch_size

# os.cpu_count() is documented to return None when the count cannot be
# determined; fall back to 1 so the floor division below cannot raise
# a TypeError.
total_num_workers = os.cpu_count() or 1
# Split the available DataLoader workers evenly across training devices.
single_device_num_workers = total_num_workers // len(devices)


# Optimizer hyperparameters: learning rate and Adam-style beta coefficients
# (passed to FontDetector as lr / betas below).
lr = 0.0001
b1 = 0.9
b2 = 0.999

# Per-task loss weights for the multi-task objective.
lambda_font = 2.0
lambda_direction = 0.5
lambda_regression = 1.0

# Whether regression outputs/targets go through tanh (forwarded to both the
# datamodule and the model so the two stay consistent).
regression_use_tanh = False
# Enables training-time data augmentation in the datamodule.
augmentation = True

# LR schedule: warm up over the first 5 epochs of iterations, 100 epochs total.
num_warmup_epochs = 5
num_epochs = 100

# How often (in steps) the trainer logs metrics to TensorBoard.
log_every_n_steps = 100

num_device = len(devices)

# Datamodule wraps the font dataset splits; train_transforms toggles
# augmentation on the training split only.
data_module = FontDataModule(
    batch_size=single_batch_size,
    num_workers=single_device_num_workers,
    pin_memory=True,
    train_shuffle=True,
    val_shuffle=False,
    test_shuffle=False,
    regression_use_tanh=regression_use_tanh,
    train_transforms=augmentation,
)

# Total and warmup iteration counts for the LR scheduler. get_train_num_iter
# takes the device count, presumably because DDP shards each epoch across
# processes — confirm against FontDataModule.
num_iters = data_module.get_train_num_iter(num_device) * num_epochs
num_warmup_iter = data_module.get_train_num_iter(num_device) * num_warmup_epochs

# Version the TensorBoard run by the current git tag. str() makes the
# coercion explicit instead of going through a redundant f-string.
model_name = str(get_current_tag())

logger_unconditioned = TensorBoardLogger(
    save_dir=os.getcwd(), name="tensorboard", version=model_name
)

# Single-device runs need no distributed strategy; multi-device uses DDP.
strategy = None if num_device == 1 else "ddp"

# deterministic=True forces reproducible (but somewhat slower) GPU kernels.
trainer = ptl.Trainer(
    max_epochs=num_epochs,
    logger=logger_unconditioned,
    devices=devices,
    accelerator="gpu",
    enable_checkpointing=True,
    log_every_n_steps=log_every_n_steps,
    strategy=strategy,
    deterministic=True,
)

# Backbone network; regression_use_tanh must match the datamodule setting
# so targets and outputs agree.
model = ResNet34Regressor(regression_use_tanh=regression_use_tanh)

# LightningModule combining the backbone with the weighted multi-task losses
# and the warmup learning-rate schedule.
detector = FontDetector(
    model=model,
    lambda_font=lambda_font,
    lambda_direction=lambda_direction,
    lambda_regression=lambda_regression,
    lr=lr,
    betas=(b1, b2),
    num_warmup_iters=num_warmup_iter,
    num_iters=num_iters,
    num_epochs=num_epochs,
)

# Resume from --checkpoint when given (ckpt_path=None starts fresh), then
# evaluate on the test split with the fitted weights.
trainer.fit(detector, datamodule=data_module, ckpt_path=args.checkpoint)
trainer.test(detector, datamodule=data_module)