gyrojeff committed on
Commit
b06784f
·
1 Parent(s): eb2d25d

fix: data augmentation

Browse files
Files changed (1) hide show
  1. detector/data.py +27 -18
detector/data.py CHANGED
@@ -17,13 +17,19 @@ from PIL import Image
17
 
18
 
19
  class RandomColorJitter(object):
20
- def __init__(self, brightness=0.5, contrast=0.5, saturation=0.5, hue=0.05):
 
 
21
  self.brightness = brightness
22
  self.contrast = contrast
23
  self.saturation = saturation
24
  self.hue = hue
 
25
 
26
  def __call__(self, batch):
 
 
 
27
  image, label = batch
28
  text_color = label[2:5].clone().view(3, 1, 1)
29
  stroke_color = label[7:10].clone().view(3, 1, 1)
@@ -54,10 +60,14 @@ class RandomColorJitter(object):
54
 
55
 
56
  class RandomCrop(object):
57
- def __init__(self, crop_factor: float = 0.1):
58
  self.crop_factor = crop_factor
 
59
 
60
  def __call__(self, batch):
 
 
 
61
  image, label = batch
62
  width, height = image.size
63
 
@@ -80,10 +90,14 @@ class RandomCrop(object):
80
 
81
 
82
  class RandomRotate(object):
83
- def __init__(self, max_angle: int = 15):
84
  self.max_angle = max_angle
 
85
 
86
  def __call__(self, batch):
 
 
 
87
  image, label = batch
88
 
89
  angle = random.uniform(-self.max_angle, self.max_angle)
@@ -177,8 +191,8 @@ class FontDataset(Dataset):
177
  if self.transforms is not None:
178
  transform = transforms.Compose(
179
  [
180
- transforms.RandomApply(RandomColorJitter(), p=0.8),
181
- transforms.RandomApply(RandomCrop(), p=0.8),
182
  ]
183
  )
184
  image, label = transform((image, label))
@@ -210,20 +224,15 @@ class FontDataset(Dataset):
210
 
211
  transform = transforms.Compose(
212
  [
213
- transforms.RandomApply(RandomColorJitter(), p=0.8),
214
- RandomCrop(crop_factor=0.54),
215
- transforms.RandomApply(RandomRotate(), p=0.8),
216
  ]
217
  )
218
  image, label = transform((image, label))
219
 
220
- transform = transforms.Compose(
221
- [
222
- transforms.RandomApply(
223
- transforms.GaussianBlur(random.randint(2, 5), sigma=(0.1, 5.0)),
224
- p=0.8,
225
- ),
226
- ]
227
  )
228
 
229
  image = transform(image)
@@ -259,9 +268,9 @@ class FontDataModule(LightningDataModule):
259
  train_shuffle: bool = True,
260
  val_shuffle: bool = False,
261
  test_shuffle: bool = False,
262
- train_transforms: bool = False,
263
- val_transforms: bool = False,
264
- test_transforms: bool = False,
265
  crop_roi_bbox: bool = False,
266
  regression_use_tanh: bool = False,
267
  **kwargs,
 
17
 
18
 
19
  class RandomColorJitter(object):
20
+ def __init__(
21
+ self, brightness=0.5, contrast=0.5, saturation=0.5, hue=0.05, preserve=0.2
22
+ ):
23
  self.brightness = brightness
24
  self.contrast = contrast
25
  self.saturation = saturation
26
  self.hue = hue
27
+ self.preserve = preserve
28
 
29
  def __call__(self, batch):
30
+ if random.random() < self.preserve:
31
+ return batch
32
+
33
  image, label = batch
34
  text_color = label[2:5].clone().view(3, 1, 1)
35
  stroke_color = label[7:10].clone().view(3, 1, 1)
 
60
 
61
 
62
  class RandomCrop(object):
63
+ def __init__(self, crop_factor: float = 0.1, preserve: float = 0.2):
64
  self.crop_factor = crop_factor
65
+ self.preserve = preserve
66
 
67
  def __call__(self, batch):
68
+ if random.random() < self.preserve:
69
+ return batch
70
+
71
  image, label = batch
72
  width, height = image.size
73
 
 
90
 
91
 
92
  class RandomRotate(object):
93
+ def __init__(self, max_angle: int = 15, preserve: float = 0.2):
94
  self.max_angle = max_angle
95
+ self.preserve = preserve
96
 
97
  def __call__(self, batch):
98
+ if random.random() < self.preserve:
99
+ return batch
100
+
101
  image, label = batch
102
 
103
  angle = random.uniform(-self.max_angle, self.max_angle)
 
191
  if self.transforms is not None:
192
  transform = transforms.Compose(
193
  [
194
+ RandomColorJitter(preserve=0.2),
195
+ RandomCrop(preserve=0.2),
196
  ]
197
  )
198
  image, label = transform((image, label))
 
224
 
225
  transform = transforms.Compose(
226
  [
227
+ RandomColorJitter(preserve=0.2),
228
+ RandomCrop(crop_factor=0.54, preserve=0),
229
+ RandomRotate(preserve=0.2),
230
  ]
231
  )
232
  image, label = transform((image, label))
233
 
234
+ transform = transforms.GaussianBlur(
235
+ random.randint(1, 3) * 2 - 1, sigma=(0.1, 5.0)
 
 
 
 
 
236
  )
237
 
238
  image = transform(image)
 
268
  train_shuffle: bool = True,
269
  val_shuffle: bool = False,
270
  test_shuffle: bool = False,
271
+ train_transforms: bool = None,
272
+ val_transforms: bool = None,
273
+ test_transforms: bool = None,
274
  crop_roi_bbox: bool = False,
275
  regression_use_tanh: bool = False,
276
  **kwargs,