gyrojeff commited on
Commit
716ee53
·
1 Parent(s): eafaf77

feat: add generation script

Browse files
Files changed (1) hide show
  1. font_ds_generate_script.py +104 -0
font_ds_generate_script.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import traceback
3
+ import pickle
4
+ import os
5
+ import concurrent.futures
6
+ from tqdm import tqdm
7
+ from font_dataset.font import load_fonts
8
+ from font_dataset.layout import generate_font_image
9
+ from font_dataset.text import CorpusGeneratorManager
10
+ from font_dataset.background import background_image_generator
11
+
12
+
13
+ global_script_index = int(sys.argv[1])
14
+ global_script_index_total = int(sys.argv[2])
15
+
16
+ print(f"Mission {global_script_index} / {global_script_index_total}")
17
+
18
+ num_workers = 32
19
+
20
+ cjk_ratio = 3
21
+
22
+ train_cnt = 100
23
+ val_cnt = 10
24
+ test_cnt = 30
25
+
26
+ train_cnt_cjk = int(train_cnt * cjk_ratio)
27
+ val_cnt_cjk = int(val_cnt * cjk_ratio)
28
+ test_cnt_cjk = int(test_cnt * cjk_ratio)
29
+
30
+ dataset_path = "./dataset/font_img"
31
+ os.makedirs(dataset_path, exist_ok=True)
32
+
33
+
34
+ fonts = load_fonts()
35
+ corpus_manager = CorpusGeneratorManager()
36
+ images = background_image_generator()
37
+
38
+
39
+ def generate_dataset(dataset_type: str, cnt: int):
40
+ dataset_bath_dir = os.path.join(dataset_path, dataset_type)
41
+ os.makedirs(dataset_bath_dir, exist_ok=True)
42
+
43
+ def _generate_single(args):
44
+ while True:
45
+ try:
46
+ i, j, font = args
47
+
48
+ image_file_name = f"font_{i}_img_{j}.png"
49
+ label_file_name = f"font_{i}_img_{j}.bin"
50
+
51
+ image_file_path = os.path.join(dataset_bath_dir, image_file_name)
52
+ label_file_path = os.path.join(dataset_bath_dir, label_file_name)
53
+
54
+ # detect cache
55
+ if os.path.exists(image_file_path) and os.path.exists(label_file_path):
56
+ return
57
+
58
+ im = next(images)
59
+ im, label = generate_font_image(
60
+ im,
61
+ font,
62
+ corpus_manager,
63
+ )
64
+
65
+ im.save(image_file_path)
66
+ pickle.dump(label, open(label_file_path, "wb"))
67
+ return
68
+ except Exception as e:
69
+ traceback.print_exc()
70
+ continue
71
+
72
+ work_list = []
73
+
74
+ # divide len(fonts) into 64 parts and choose the third part for this script
75
+ for i in range(
76
+ (global_script_index - 1) * len(fonts) // global_script_index_total,
77
+ global_script_index * len(fonts) // global_script_index_total,
78
+ ):
79
+ font = fonts[i]
80
+ if font.language == "CJK":
81
+ true_cnt = cnt * cjk_ratio
82
+ else:
83
+ true_cnt = cnt
84
+ for j in range(true_cnt):
85
+ work_list.append((i, j, font))
86
+
87
+ # with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
88
+ # _ = list(
89
+ # tqdm(
90
+ # executor.map(_generate_single, work_list),
91
+ # total=len(work_list),
92
+ # leave=True,
93
+ # desc=dataset_type,
94
+ # miniters=1,
95
+ # )
96
+ # )
97
+
98
+ for i in tqdm(range(len(work_list))):
99
+ _generate_single(work_list[i])
100
+
101
+
102
+ generate_dataset("train", train_cnt)
103
+ generate_dataset("val", val_cnt)
104
+ generate_dataset("test", test_cnt)