File size: 2,007 Bytes
ff82fe6
3163344
10f217b
3163344
 
 
 
 
 
 
 
00a4b21
10f217b
ff82fe6
00a4b21
ff82fe6
 
 
 
3163344
 
 
 
 
5a85fd3
3163344
 
 
5a85fd3
3163344
 
 
68dd12a
 
f7a9f34
3163344
 
 
 
 
 
 
 
 
 
 
 
 
68dd12a
3163344
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c681b80
3163344
 
68dd12a
3163344
 
 
 
 
 
 
 
 
 
 
 
 
5c43f60
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import argparse
import os
import torch
import pytorch_lightning as ptl
from pytorch_lightning.loggers import TensorBoardLogger

from detector.data import FontDataModule
from detector.model import FontDetector, ResNet18Regressor
from utils import get_current_tag


# Allow faster, reduced-precision float32 matmuls (e.g. TF32) where the GPU
# supports them.
torch.set_float32_matmul_precision("high")

parser = argparse.ArgumentParser()
# GPU ids to train on. nargs="+" (instead of "*") makes argparse reject a bare
# ``-d`` with no ids; previously that produced an empty list and a
# ZeroDivisionError when the batch size is split across devices.
parser.add_argument("-d", "--devices", nargs="+", type=int, default=[0])
# Global RNG seed; Trainer(deterministic=True) alone does not make runs
# repeatable without one.
parser.add_argument("--seed", type=int, default=42)

args = parser.parse_args()

devices = args.devices

# Seed Python/NumPy/torch RNGs (and DataLoader workers) before any data or
# model object is built, so weight init and shuffling are reproducible.
ptl.seed_everything(args.seed, workers=True)

# Global batch size across all devices; divided per device for the data module.
final_batch_size = 128
# Passed to FontDataModule as num_workers (presumably per process — confirm).
single_device_num_workers = 24


# Optimizer settings handed to FontDetector as lr and betas=(b1, b2).
lr = 0.00005
b1 = 0.9
b2 = 0.999

# Loss weights handed to FontDetector.
lambda_font = 4.0
lambda_direction = 0.5
lambda_regression = 1.0

# Shared by FontDataModule and ResNet18Regressor.
regression_use_tanh = True

# Schedule lengths in epochs; turned into iteration counts once the data
# module is available.
num_warmup_epochs = 1
num_epochs = 100

log_every_n_steps = 100

num_device = len(devices)

# Each process gets an equal share of the global batch.
# NOTE(review): integer division means the effective global batch is smaller
# than final_batch_size when num_device does not divide it evenly.
data_module = FontDataModule(
    batch_size=final_batch_size // num_device,
    num_workers=single_device_num_workers,
    pin_memory=True,
    train_shuffle=True,
    val_shuffle=False,
    test_shuffle=False,
    regression_use_tanh=regression_use_tanh,
)

# Training iterations per epoch for this device count, queried once and
# reused for both totals (previously requested twice with the same argument).
train_iters_per_epoch = data_module.get_train_num_iter(num_device)
# Total and warmup iteration counts handed to FontDetector.
num_iters = train_iters_per_epoch * num_epochs
num_warmup_iter = train_iters_per_epoch * num_warmup_epochs

# Run name derived from the current tag; used as the TensorBoard version dir.
model_name = f"{get_current_tag()}"

# Logs go to <cwd>/tensorboard/<model_name>.
logger_unconditioned = TensorBoardLogger(
    save_dir=os.getcwd(), name="tensorboard", version=model_name
)

# Single device: let Lightning pick its own default; multiple devices: DDP.
strategy = None if num_device == 1 else "ddp"
# Only forward ``strategy`` when one is actually chosen. Lightning 2.x rejects
# an explicit ``strategy=None`` (its default is "auto"), whereas 1.x used None
# as the default, so omitting the argument works on both.
trainer_strategy_kwargs = {} if strategy is None else {"strategy": strategy}

trainer = ptl.Trainer(
    max_epochs=num_epochs,
    logger=logger_unconditioned,
    devices=devices,
    accelerator="gpu",
    enable_checkpointing=True,
    log_every_n_steps=log_every_n_steps,
    deterministic=True,
    **trainer_strategy_kwargs,
)

# Backbone network; receives the same regression_use_tanh flag as the data module.
model = ResNet18Regressor(regression_use_tanh=regression_use_tanh)

# LightningModule wrapping the backbone together with its loss weights,
# optimizer hyperparameters and iteration counts.
detector = FontDetector(
    model=model,
    # optimizer / schedule
    lr=lr,
    betas=(b1, b2),
    num_iters=num_iters,
    num_warmup_iters=num_warmup_iter,
    # loss weighting
    lambda_font=lambda_font,
    lambda_direction=lambda_direction,
    lambda_regression=lambda_regression,
)

# Train first, then evaluate on the test split of the same data module.
trainer.fit(detector, datamodule=data_module)
trainer.test(detector, datamodule=data_module)