|
import argparse |
|
import os |
|
import torch |
|
import pytorch_lightning as ptl |
|
from pytorch_lightning.loggers import TensorBoardLogger |
|
|
|
from detector.data import FontDataModule |
|
from detector.model import * |
|
from utils import get_current_tag |
|
|
|
|
|
# "high" lets float32 matmuls use faster reduced-precision internals where
# the hardware supports them (see torch.set_float32_matmul_precision docs).
torch.set_float32_matmul_precision("high")

# Command-line interface of the training script.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-d",
    "--devices",
    # "+" (not "*"): a bare `-d` used to produce an empty list, which later
    # crashed with ZeroDivisionError when workers are split across devices.
    nargs="+",
    type=int,
    default=[0],
    help="GPU devices to use (default: [0])",
)
parser.add_argument(
    "-b",
    "--single-batch-size",
    type=int,
    default=64,
    help="Batch size of single device (default: 64)",
)
parser.add_argument(
    "-c",
    "--checkpoint",
    type=str,
    default=None,
    help="Trainer checkpoint path (default: None)",
)
parser.add_argument(
    "-m",
    "--model",
    type=str,
    default="resnet18",
    choices=["resnet18", "resnet34", "resnet50", "resnet101"],
    help="Model to use (default: resnet18)",
)
parser.add_argument(
    "-p",
    "--pretrained",
    action="store_true",
    help="Use pretrained model for ResNet (default: False)",
)
parser.add_argument(
    "-i",
    "--crop-roi-bbox",
    action="store_true",
    help="Crop ROI bounding box (default: False)",
)
parser.add_argument(
    "-a",
    "--augmentation",
    type=str,
    default=None,
    choices=["v1", "v2"],
    help="Augmentation strategy to use (default: None)",
)
parser.add_argument(
    "-l",
    "--lr",
    type=float,
    default=0.0001,
    help="Learning rate (default: 0.0001)",
)
parser.add_argument(
    "-s",
    "--datasets",
    nargs="*",
    type=str,
    default=["./dataset/font_img"],
    help="Datasets paths, separated by space (default: ['./dataset/font_img'])",
)
|
|
|
# Parse the command line and derive the per-device loader settings.
args = parser.parse_args()

devices = args.devices
single_batch_size = args.single_batch_size
lr = args.lr

# An empty device list cannot be trained on and would divide by zero below;
# report it as a usage error instead of a traceback.
if not devices:
    parser.error("at least one device index is required for -d/--devices")

# os.cpu_count() returns None when the count cannot be determined.
total_num_workers = os.cpu_count() or 1

if os.name == "nt":
    # On Windows the data loaders are run with zero workers.
    single_device_num_workers = 0
else:
    # Split the available CPUs evenly across the selected devices.
    single_device_num_workers = total_num_workers // len(devices)
|
# Fixed (non-CLI) hyperparameters of the training run.

# Forwarded to FontDetector as betas=(b1, b2).
b1, b2 = 0.9, 0.999

# Lambda coefficients forwarded to FontDetector under the same names.
lambda_font, lambda_direction, lambda_regression = 2.0, 0.5, 1.0

# Shared by FontDataModule and the regressor constructors.
regression_use_tanh = False

# Epoch budget; both values are converted to iteration counts later on.
num_warmup_epochs, num_epochs = 5, 100

# Passed to the Lightning Trainer as log_every_n_steps.
log_every_n_steps = 100
|
|
|
num_device = len(devices)

# Every dataset root is expected to contain "train", "val" and "test"
# sub-directories; build the three path lists in one pass.
train_dirs, val_dirs, test_dirs = (
    [os.path.join(dataset_root, split) for dataset_root in args.datasets]
    for split in ("train", "val", "test")
)

# Data module feeding the trainer; only the training split is shuffled.
data_module = FontDataModule(
    train_paths=train_dirs,
    val_paths=val_dirs,
    test_paths=test_dirs,
    batch_size=single_batch_size,
    num_workers=single_device_num_workers,
    pin_memory=True,
    train_shuffle=True,
    val_shuffle=False,
    test_shuffle=False,
    regression_use_tanh=regression_use_tanh,
    train_transforms=args.augmentation,
    crop_roi_bbox=args.crop_roi_bbox,
)

# Convert the epoch budgets into iteration counts for the given device count.
num_iters = data_module.get_train_num_iter(num_device) * num_epochs
num_warmup_iter = (
    data_module.get_train_num_iter(num_device) * num_warmup_epochs
)
|
|
|
# Run identifier, used as the TensorBoard version directory name.
model_name = f"{get_current_tag()}"

# TensorBoard logs are written under <cwd>/tensorboard/<model_name>.
logger_unconditioned = TensorBoardLogger(
    save_dir=os.getcwd(),
    name="tensorboard",
    version=model_name,
)

# DDP is only requested for multi-device runs.
strategy = None if num_device == 1 else "ddp"

trainer_kwargs = dict(
    max_epochs=num_epochs,
    logger=logger_unconditioned,
    devices=devices,
    accelerator="gpu",
    enable_checkpointing=True,
    log_every_n_steps=log_every_n_steps,
    deterministic=True,
)

# Pass `strategy` only when one was chosen. Older Lightning releases treat
# strategy=None as "use the default", but newer ones (2.x) default to "auto"
# and reject an explicit None; omitting the argument works on both.
if strategy is not None:
    trainer_kwargs["strategy"] = strategy

trainer = ptl.Trainer(**trainer_kwargs)
|
|
|
# Resolve the regressor class for the requested backbone first, then
# instantiate it once with the shared constructor arguments.
if args.model == "resnet18":
    regressor_cls = ResNet18Regressor
elif args.model == "resnet34":
    regressor_cls = ResNet34Regressor
elif args.model == "resnet50":
    regressor_cls = ResNet50Regressor
elif args.model == "resnet101":
    regressor_cls = ResNet101Regressor
else:
    raise NotImplementedError()

model = regressor_cls(
    pretrained=args.pretrained,
    regression_use_tanh=regression_use_tanh,
)

# Compile the model when running on PyTorch 2.0 or newer.
# NOTE(review): torch.__version__ is a TorchVersion in recent PyTorch
# releases, which compares against strings by version semantics rather
# than lexicographically - confirm for the oldest supported torch.
if torch.__version__ >= "2.0":
    model = torch.compile(model)
|
|
|
# Collect the detector's keyword arguments: the wrapped model, the lambda
# coefficients, optimizer settings and the schedule lengths.
detector_options = dict(
    model=model,
    lambda_font=lambda_font,
    lambda_direction=lambda_direction,
    lambda_regression=lambda_regression,
    lr=lr,
    betas=(b1, b2),
    num_warmup_iters=num_warmup_iter,
    num_iters=num_iters,
    num_epochs=num_epochs,
)
detector = FontDetector(**detector_options)

# Train (ckpt_path is the -c/--checkpoint value, None by default), then
# evaluate on the test split.
trainer.fit(detector, datamodule=data_module, ckpt_path=args.checkpoint)
trainer.test(detector, datamodule=data_module)
|
|