gyrojeff commited on
Commit
86f4fb9
·
1 Parent(s): 8526729

feat: add preserve aspect ratio & fix typo

Browse files
Files changed (2) hide show
  1. detector/data.py +128 -49
  2. train.py +7 -0
detector/data.py CHANGED
@@ -8,7 +8,7 @@ import os
8
  import random
9
  import pickle
10
  import torch
11
- import torchvision.transforms as transforms
12
  import torchvision.transforms.functional as TF
13
  from typing import List, Dict, Tuple
14
  from torch.utils.data import Dataset, DataLoader, ConcatDataset
@@ -106,6 +106,50 @@ class RandomRotate(object):
106
  return image, label
107
 
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  class FontDataset(Dataset):
110
  def __init__(
111
  self,
@@ -114,6 +158,7 @@ class FontDataset(Dataset):
114
  regression_use_tanh: bool = False,
115
  transforms: str = None,
116
  crop_roi_bbox: bool = False,
 
117
  ):
118
  """Font dataset
119
 
@@ -121,8 +166,9 @@ class FontDataset(Dataset):
121
  path (str): path to the dataset
122
  config_path (str, optional): path to font config file. Defaults to "configs/font.yml".
123
  regression_use_tanh (bool, optional): whether use tanh as regression normalization. Defaults to False.
124
- transforms (str, optional): choose from None, 'v1', 'v2'. Defaults to None.
125
- crop_roi_bbox (bool, optional): whether to crop text roi bbox, must be true when transform='v2'. Defaults to False.
 
126
  """
127
  self.path = path
128
  self.fonts = load_font_with_exclusion(config_path)
@@ -135,8 +181,72 @@ class FontDataset(Dataset):
135
  ]
136
  self.images.sort()
137
 
138
- if transforms == "v2":
139
- assert crop_roi_bbox, "crop_roi_bbox must be true when transform='v2'"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
  def __len__(self):
142
  return len(self.images)
@@ -177,26 +287,14 @@ class FontDataset(Dataset):
177
  with open(label_path, "rb") as f:
178
  label: FontLabel = pickle.load(f)
179
 
 
180
  if (self.transforms == "v1") or (self.transforms is None):
181
  if self.crop_roi_bbox:
182
  left, top, width, height = label.bbox
183
  image = TF.crop(image, top, left, height, width)
184
  label.image_width = width
185
  label.image_height = height
186
-
187
- # encode label
188
- label = self.fontlabel2tensor(label, label_path)
189
-
190
- # data augmentation
191
- if self.transforms is not None:
192
- transform = transforms.Compose(
193
- [
194
- RandomColorJitter(preserve=0.2),
195
- RandomCrop(preserve=0.2),
196
- ]
197
- )
198
- image, label = transform((image, label))
199
- elif self.transforms == "v2":
200
  # crop from 30% to 130% of bbox
201
  left, top, width, height = label.bbox
202
 
@@ -219,37 +317,14 @@ class FontDataset(Dataset):
219
  label.image_width = width
220
  label.image_height = height
221
 
222
- # encode label
223
- label = self.fontlabel2tensor(label, label_path)
224
-
225
- transform = transforms.Compose(
226
- [
227
- RandomColorJitter(preserve=0.2),
228
- RandomCrop(crop_factor=0.54, preserve=0),
229
- RandomRotate(preserve=0.2),
230
- ]
231
- )
232
- image, label = transform((image, label))
233
-
234
- transform = transforms.GaussianBlur(
235
- random.randint(1, 3) * 2 - 1, sigma=(0.1, 5.0)
236
- )
237
-
238
- image = transform(image)
239
-
240
- # resize and to tensor
241
- transform = transforms.Compose(
242
- [
243
- transforms.Resize((config.INPUT_SIZE, config.INPUT_SIZE)),
244
- transforms.ToTensor(),
245
- ]
246
- )
247
- image = transform(image)
248
 
249
- if self.transforms == "v2":
250
- # noise
251
- if random.random() < 0.9:
252
- image = image + torch.randn_like(image) * random.random() * 0.05
 
253
 
254
  # normalize label
255
  if self.regression_use_tanh:
@@ -272,6 +347,7 @@ class FontDataModule(LightningDataModule):
272
  val_transforms: bool = None,
273
  test_transforms: bool = None,
274
  crop_roi_bbox: bool = False,
 
275
  regression_use_tanh: bool = False,
276
  **kwargs,
277
  ):
@@ -288,6 +364,7 @@ class FontDataModule(LightningDataModule):
288
  regression_use_tanh,
289
  train_transforms,
290
  crop_roi_bbox,
 
291
  )
292
  for train_path in train_paths
293
  ]
@@ -300,6 +377,7 @@ class FontDataModule(LightningDataModule):
300
  regression_use_tanh,
301
  val_transforms,
302
  crop_roi_bbox,
 
303
  )
304
  for val_path in val_paths
305
  ]
@@ -312,6 +390,7 @@ class FontDataModule(LightningDataModule):
312
  regression_use_tanh,
313
  test_transforms,
314
  crop_roi_bbox,
 
315
  )
316
  for test_path in test_paths
317
  ]
 
8
  import random
9
  import pickle
10
  import torch
11
+ import torchvision
12
  import torchvision.transforms.functional as TF
13
  from typing import List, Dict, Tuple
14
  from torch.utils.data import Dataset, DataLoader, ConcatDataset
 
106
  return image, label
107
 
108
 
109
+ class RandomNoise(object):
110
+ def __init__(self, max_noise: float = 0.05, preserve: float = 0.1):
111
+ self.max_noise = max_noise
112
+ self.preserve = preserve
113
+
114
+ def __call__(self, image):
115
+ if random.random() < self.preserve:
116
+ return image
117
+ return image + torch.randn_like(image) * random.random() * self.max_noise
118
+
119
+
120
+ class RandomDownSample(object):
121
+ def __init__(self, max_ratio: float = 2, preserve: float = 0.5):
122
+ self.max_ratio = max_ratio
123
+ self.preserve = preserve
124
+
125
+ def __call__(self, image):
126
+ if random.random() < self.preserve:
127
+ return image
128
+ ratio = random.uniform(1, self.max_ratio)
129
+ return TF.resize(
130
+ image, (int(image.size[1] / ratio), int(image.size[0] / ratio))
131
+ )
132
+
133
+
134
+ class RandomCropPreserveAspectRatio(object):
135
+ def __call__(self, batch):
136
+ image, label = batch
137
+ width, height = image.size
138
+
139
+ if width == height:
140
+ return batch
141
+
142
+ if width > height:
143
+ x = random.randint(0, width - height)
144
+ image = TF.crop(image, 0, x, height, height)
145
+ label[[5, 6, 10]] = label[[5, 6, 10]] / height * width
146
+ else:
147
+ y = random.randint(0, height - width)
148
+ image = TF.crop(image, y, 0, width, width)
149
+ label[[5, 6, 10]] = label[[5, 6, 10]] / width * height
150
+ return image, label
151
+
152
+
153
  class FontDataset(Dataset):
154
  def __init__(
155
  self,
 
158
  regression_use_tanh: bool = False,
159
  transforms: str = None,
160
  crop_roi_bbox: bool = False,
161
+ preserve_aspect_ratio_by_random_crop: bool = False,
162
  ):
163
  """Font dataset
164
 
 
166
  path (str): path to the dataset
167
  config_path (str, optional): path to font config file. Defaults to "configs/font.yml".
168
  regression_use_tanh (bool, optional): whether use tanh as regression normalization. Defaults to False.
169
+ transforms (str, optional): choose from None, 'v1', 'v2', 'v3'. Defaults to None.
170
+ crop_roi_bbox (bool, optional): whether to crop text roi bbox, must be true when transform='v2' or 'v3'. Defaults to False.
171
+ preserve_aspect_ratio_by_random_crop (bool, optional): whether to preserve aspect ratio by random cropping maximum squares. Defaults to False.
172
  """
173
  self.path = path
174
  self.fonts = load_font_with_exclusion(config_path)
 
181
  ]
182
  self.images.sort()
183
 
184
+ if transforms == "v2" or transforms == "v3":
185
+ assert (
186
+ crop_roi_bbox
187
+ ), "crop_roi_bbox must be true when transform='v2' or 'v3'"
188
+
189
+ if transforms is None:
190
+ label_image_transforms = []
191
+ image_transforms = [
192
+ torchvision.transforms.Resize((config.INPUT_SIZE, config.INPUT_SIZE)),
193
+ torchvision.transforms.ToTensor(),
194
+ ]
195
+ elif transforms == "v1":
196
+ label_image_transforms = [
197
+ RandomColorJitter(preserve=0.2),
198
+ RandomCrop(preserve=0.2),
199
+ ]
200
+ image_transforms = [
201
+ torchvision.transforms.Resize((config.INPUT_SIZE, config.INPUT_SIZE)),
202
+ torchvision.transforms.ToTensor(),
203
+ ]
204
+ elif transforms == "v2":
205
+ label_image_transforms = [
206
+ RandomColorJitter(preserve=0.2),
207
+ RandomCrop(crop_factor=0.54, preserve=0),
208
+ RandomRotate(preserve=0.2),
209
+ ]
210
+ image_transforms = [
211
+ torchvision.transforms.GaussianBlur(
212
+ random.randint(1, 3) * 2 - 1, sigma=(0.1, 5.0)
213
+ ),
214
+ torchvision.transforms.Resize((config.INPUT_SIZE, config.INPUT_SIZE)),
215
+ torchvision.transforms.ToTensor(),
216
+ RandomNoise(max_noise=0.05, preserve=0.1),
217
+ ]
218
+ elif transforms == "v3":
219
+ label_image_transforms = [
220
+ RandomColorJitter(preserve=0.2),
221
+ RandomCrop(crop_factor=0.54, preserve=0),
222
+ RandomRotate(preserve=0.2),
223
+ ]
224
+ image_transforms = [
225
+ RandomDownSample(max_ratio=2, preserve=0.5),
226
+ torchvision.transforms.GaussianBlur(
227
+ random.randint(1, 3) * 2 - 1, sigma=(0.1, 5.0)
228
+ ),
229
+ torchvision.transforms.Resize((config.INPUT_SIZE, config.INPUT_SIZE)),
230
+ torchvision.transforms.ToTensor(),
231
+ RandomNoise(max_noise=0.05, preserve=0.1),
232
+ torchvision.transforms.RandomHorizontalFlip(p=0.5),
233
+ ]
234
+ else:
235
+ raise ValueError(f"Unknown transform: {transforms}")
236
+
237
+ if preserve_aspect_ratio_by_random_crop:
238
+ label_image_transforms.append(RandomCropPreserveAspectRatio())
239
+
240
+ if len(label_image_transforms) == 0:
241
+ self.transform_label_image = None
242
+ else:
243
+ self.transform_label_image = torchvision.transforms.Compose(
244
+ label_image_transforms
245
+ )
246
+ if len(image_transforms) == 0:
247
+ self.transform_image = None
248
+ else:
249
+ self.transform_image = torchvision.transforms.Compose(image_transforms)
250
 
251
  def __len__(self):
252
  return len(self.images)
 
287
  with open(label_path, "rb") as f:
288
  label: FontLabel = pickle.load(f)
289
 
290
+ # preparation
291
  if (self.transforms == "v1") or (self.transforms is None):
292
  if self.crop_roi_bbox:
293
  left, top, width, height = label.bbox
294
  image = TF.crop(image, top, left, height, width)
295
  label.image_width = width
296
  label.image_height = height
297
+ elif self.transforms == "v2" or self.transforms == "v3":
 
 
 
 
 
 
 
 
 
 
 
 
 
298
  # crop from 30% to 130% of bbox
299
  left, top, width, height = label.bbox
300
 
 
317
  label.image_width = width
318
  label.image_height = height
319
 
320
+ # encode label
321
+ label = self.fontlabel2tensor(label, label_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
 
323
+ # transform
324
+ if self.transform_label_image is not None:
325
+ image, label = self.transform_label_image((image, label))
326
+ if self.transform_image is not None:
327
+ image = self.transform_image(image)
328
 
329
  # normalize label
330
  if self.regression_use_tanh:
 
347
  val_transforms: bool = None,
348
  test_transforms: bool = None,
349
  crop_roi_bbox: bool = False,
350
+ preserve_aspect_ratio_by_random_crop: bool = False,
351
  regression_use_tanh: bool = False,
352
  **kwargs,
353
  ):
 
364
  regression_use_tanh,
365
  train_transforms,
366
  crop_roi_bbox,
367
+ preserve_aspect_ratio_by_random_crop,
368
  )
369
  for train_path in train_paths
370
  ]
 
377
  regression_use_tanh,
378
  val_transforms,
379
  crop_roi_bbox,
380
+ preserve_aspect_ratio_by_random_crop,
381
  )
382
  for val_path in val_paths
383
  ]
 
390
  regression_use_tanh,
391
  test_transforms,
392
  crop_roi_bbox,
393
+ preserve_aspect_ratio_by_random_crop,
394
  )
395
  for test_path in test_paths
396
  ]
train.py CHANGED
@@ -103,6 +103,12 @@ parser.add_argument(
103
  default="high",
104
  help="Tensor core precision (default: high)",
105
  )
 
 
 
 
 
 
106
 
107
  args = parser.parse_args()
108
 
@@ -149,6 +155,7 @@ data_module = FontDataModule(
149
  regression_use_tanh=regression_use_tanh,
150
  train_transforms=args.augmentation,
151
  crop_roi_bbox=args.crop_roi_bbox,
 
152
  )
153
 
154
  num_iters = data_module.get_train_num_iter(num_device) * num_epochs
 
103
  default="high",
104
  help="Tensor core precision (default: high)",
105
  )
106
+ parser.add_argument(
107
+ "-r",
108
+ "--preserve-aspect-ratio",
109
+ action="store_true",
110
+ help="Preserve aspect ratio (default: False)",
111
+ )
112
 
113
  args = parser.parse_args()
114
 
 
155
  regression_use_tanh=regression_use_tanh,
156
  train_transforms=args.augmentation,
157
  crop_roi_bbox=args.crop_roi_bbox,
158
+ preserve_aspect_ratio_by_random_crop=args.preserve_aspect_ratio,
159
  )
160
 
161
  num_iters = data_module.get_train_num_iter(num_device) * num_epochs