from font_dataset.fontlabel import FontLabel
from font_dataset.font import DSFont, load_font_with_exclusion
from . import config
import math
import os
import pickle
import torch
import torchvision.transforms as transforms
from typing import List, Dict, Tuple
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningDataModule
from PIL import Image


class FontDataset(Dataset):
    def __init__(
        self,
        path: str,
        config_path: str = "configs/font.yml",
        regression_use_tanh: bool = False,
    ):
        self.path = path
        # Mapping from font file path to its class index; fonts excluded by the
        # config are absent from this mapping.
        self.fonts = load_font_with_exclusion(config_path)
        self.regression_use_tanh = regression_use_tanh

        self.images = [
            os.path.join(path, f) for f in os.listdir(path) if f.endswith(".jpg")
        ]
        self.images.sort()

    def __len__(self):
        return len(self.images)

    def fontlabel2tensor(self, label: FontLabel, label_path) -> torch.Tensor:
        # Encode a FontLabel into a 12-dim float tensor:
        #   [0]    font class index
        #   [1]    text direction (0 = ltr, 1 = otherwise)
        #   [2:5]  text color (RGB)
        #   [5]    text size (relative to image width)
        #   [6]    stroke width (relative to image width)
        #   [7:10] stroke color (RGB), falls back to text color
        #   [10]   line spacing (relative to image width)
        #   [11]   rotation angle
        out = torch.zeros(12, dtype=torch.float)
        try:
            out[0] = self.fonts[label.font.path]
        except KeyError:
            print(f"Unqualified font: {label.font.path}")
            print(f"Label path: {label_path}")
            raise
        out[1] = 0 if label.text_direction == "ltr" else 1
        # The remaining regression targets are normalized to [0, 1]
        out[2] = label.text_color[0] / 255.0
        out[3] = label.text_color[1] / 255.0
        out[4] = label.text_color[2] / 255.0
        out[5] = label.text_size / label.image_width
        out[6] = label.stroke_width / label.image_width
        if label.stroke_color:
            out[7] = label.stroke_color[0] / 255.0
            out[8] = label.stroke_color[1] / 255.0
            out[9] = label.stroke_color[2] / 255.0
        else:
            out[7:10] = out[2:5]
        out[10] = label.line_spacing / label.image_width
        out[11] = label.angle / 180.0 + 0.5
        if self.regression_use_tanh:
            # Rescale regression targets from [0, 1] to [-1, 1] for tanh outputs.
            out[2:12] = out[2:12] * 2 - 1
        return out

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Load image
        image_path = self.images[index]
        image = Image.open(image_path).convert("RGB")

        transform = transforms.Compose(
            [
                transforms.Resize((config.INPUT_SIZE, config.INPUT_SIZE)),
                transforms.ToTensor(),
            ]
        )
        image = transform(image)

        # Load label
        label_path = image_path.replace(".jpg", ".bin")
        with open(label_path, "rb") as f:
            label: FontLabel = pickle.load(f)

        # Encode label
        label = self.fontlabel2tensor(label, label_path)

        return image, label
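

# Usage sketch (illustrative only, kept as a comment so nothing runs on import):
# FontDataset can be used on its own to inspect a single sample. The path below
# is the default train split path used by FontDataModule.
#
#   dataset = FontDataset("./dataset/font_img/train")
#   image, label = dataset[0]
#   # image: (3, config.INPUT_SIZE, config.INPUT_SIZE) float tensor in [0, 1]
#   # label: (12,) float tensor, see FontDataset.fontlabel2tensor for the layout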


class FontDataModule(LightningDataModule):
    def __init__(
        self,
        config_path: str = "configs/font.yml",
        train_path: str = "./dataset/font_img/train",
        val_path: str = "./dataset/font_img/val",
        test_path: str = "./dataset/font_img/test",
        train_shuffle: bool = True,
        val_shuffle: bool = False,
        test_shuffle: bool = False,
        regression_use_tanh: bool = False,
        **kwargs,
    ):
        super().__init__()
        # Remaining keyword arguments (batch_size, num_workers, ...) are
        # forwarded verbatim to every DataLoader.
        self.dataloader_args = kwargs
        self.train_shuffle = train_shuffle
        self.val_shuffle = val_shuffle
        self.test_shuffle = test_shuffle
        self.train_dataset = FontDataset(train_path, config_path, regression_use_tanh)
        self.val_dataset = FontDataset(val_path, config_path, regression_use_tanh)
        self.test_dataset = FontDataset(test_path, config_path, regression_use_tanh)

    def get_train_num_iter(self, num_device: int) -> int:
        # Number of training iterations per epoch when batches are split
        # across `num_device` devices.
        return math.ceil(
            len(self.train_dataset) / (self.dataloader_args["batch_size"] * num_device)
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            shuffle=self.train_shuffle,
            **self.dataloader_args,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            shuffle=self.val_shuffle,
            **self.dataloader_args,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            shuffle=self.test_shuffle,
            **self.dataloader_args,
        )
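

if __name__ == "__main__":
    # Minimal smoke test, a sketch rather than part of the training pipeline.
    # It assumes the default dataset layout (./dataset/font_img/{train,val,test})
    # and configs/font.yml exist on disk; batch_size / num_workers below are
    # arbitrary illustrative values passed through to the DataLoaders.
    datamodule = FontDataModule(batch_size=4, num_workers=0)
    loader = datamodule.train_dataloader()

    images, labels = next(iter(loader))
    # images: (batch, 3, config.INPUT_SIZE, config.INPUT_SIZE)
    # labels: (batch, 12), encoded by FontDataset.fontlabel2tensor
    print(images.shape, labels.shape)
    print("train iterations per epoch (1 device):", datamodule.get_train_num_iter(1))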