File size: 2,997 Bytes
716ee53 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import sys
import traceback
import pickle
import os
import concurrent.futures
from tqdm import tqdm
from font_dataset.font import load_fonts
from font_dataset.layout import generate_font_image
from font_dataset.text import CorpusGeneratorManager
from font_dataset.background import background_image_generator
# --- Shard selection ---------------------------------------------------------
# Usage: <script> <script_index> <script_index_total>
# The font list is split into `script_index_total` contiguous slices and this
# process handles slice number `script_index` (1-based).
if len(sys.argv) < 3:
    raise SystemExit(
        f"usage: {sys.argv[0]} <script_index> <script_index_total>"
    )

global_script_index = int(sys.argv[1])
global_script_index_total = int(sys.argv[2])

# An index outside 1..total would yield a negative or out-of-range font slice
# (negative indices silently wrap around the font list), and a total of 0
# would divide by zero, so reject those up front.
if not 1 <= global_script_index <= global_script_index_total:
    raise SystemExit(
        "script_index must satisfy 1 <= script_index <= script_index_total, "
        f"got {global_script_index} / {global_script_index_total}"
    )

print(f"Mission {global_script_index} / {global_script_index_total}")

# --- Generation parameters ---------------------------------------------------
# Intended thread-pool size; generation currently runs sequentially.
num_workers = 32

# Fonts whose language is "CJK" get `cjk_ratio` times as many samples per font.
cjk_ratio = 3

# Samples per (non-CJK) font for each split.
train_cnt = 100
val_cnt = 10
test_cnt = 30

# Samples per CJK font for each split.
train_cnt_cjk = int(train_cnt * cjk_ratio)
val_cnt_cjk = int(val_cnt * cjk_ratio)
test_cnt_cjk = int(test_cnt * cjk_ratio)

# --- Output location and shared resources ------------------------------------
dataset_path = "./dataset/font_img"
os.makedirs(dataset_path, exist_ok=True)

fonts = load_fonts()
corpus_manager = CorpusGeneratorManager()
# Iterator of background images; one is pulled with next() per sample.
images = background_image_generator()
def generate_dataset(dataset_type: str, cnt: int) -> None:
    """Generate one dataset split for the fonts assigned to this script.

    Output goes to ``<dataset_path>/<dataset_type>/``. Each sample produces
    ``font_{i}_img_{j}.png`` (the rendered image) and ``font_{i}_img_{j}.bin``
    (the pickled label), where ``i`` is the font's index in ``fonts`` and
    ``j`` is the sample index for that font. Samples whose two files already
    exist are skipped, so an interrupted run can be resumed.

    Args:
        dataset_type: Name of the split; used as the output sub-directory.
        cnt: Number of samples per font. Fonts whose ``language`` is
            ``"CJK"`` get ``cnt * cjk_ratio`` samples instead.
    """
    dataset_base_dir = os.path.join(dataset_path, dataset_type)
    os.makedirs(dataset_base_dir, exist_ok=True)

    def _generate_single(args):
        """Render and save one ``(i, j, font)`` sample, retrying on failure."""
        i, j, font = args
        image_file_path = os.path.join(dataset_base_dir, f"font_{i}_img_{j}.png")
        label_file_path = os.path.join(dataset_base_dir, f"font_{i}_img_{j}.bin")
        # The label is written here first and renamed into place, so a
        # truncated label is never mistaken for a finished sample.
        tmp_label_path = label_file_path + ".tmp"

        # Generation is retried until it succeeds: any failure is logged and
        # a new attempt is made with the next background image.
        while True:
            try:
                # detect cache
                if os.path.exists(image_file_path) and os.path.exists(
                    label_file_path
                ):
                    return

                im = next(images)
                im, label = generate_font_image(im, font, corpus_manager)

                im.save(image_file_path)
                with open(tmp_label_path, "wb") as label_file:
                    pickle.dump(label, label_file)
                os.replace(tmp_label_path, label_file_path)
                return
            except Exception:
                traceback.print_exc()

    # This script handles slice number `global_script_index` (1-based) out of
    # `global_script_index_total` contiguous slices of `fonts`.
    font_count = len(fonts)
    shard_start = (global_script_index - 1) * font_count // global_script_index_total
    shard_stop = global_script_index * font_count // global_script_index_total

    work_list = []
    for i in range(shard_start, shard_stop):
        font = fonts[i]
        # int() keeps range() valid even if cjk_ratio is a non-integer ratio.
        true_cnt = int(cnt * cjk_ratio) if font.language == "CJK" else cnt
        work_list.extend((i, j, font) for j in range(true_cnt))

    # Samples are generated one after another in this process.
    for work_item in tqdm(work_list):
        _generate_single(work_item)
# Build the three splits in order, each with its own per-font sample count.
for split_name, split_cnt in (
    ("train", train_cnt),
    ("val", val_cnt),
    ("test", test_cnt),
):
    generate_dataset(split_name, split_cnt)
|