gyrojeff committed on
Commit
01d9a57
·
1 Parent(s): d47b024

feat: update v2 data augmentation

Browse files
Files changed (2) hide show
  1. detector/data.py +86 -20
  2. train.py +9 -2
detector/data.py CHANGED
@@ -17,19 +17,13 @@ from PIL import Image
17
 
18
 
19
  class RandomColorJitter(object):
20
- def __init__(
21
- self, brightness=0.5, contrast=0.5, saturation=0.5, hue=0.05, preserve=0.2
22
- ):
23
  self.brightness = brightness
24
  self.contrast = contrast
25
  self.saturation = saturation
26
  self.hue = hue
27
- self.preserve = preserve
28
 
29
  def __call__(self, batch):
30
- if random.random() < self.preserve:
31
- return batch
32
-
33
  image, label = batch
34
  text_color = label[2:5].clone().view(3, 1, 1)
35
  stroke_color = label[7:10].clone().view(3, 1, 1)
@@ -60,14 +54,10 @@ class RandomColorJitter(object):
60
 
61
 
62
  class RandomCrop(object):
63
- def __init__(self, crop_factor: float = 0.1, preserve: float = 0.2):
64
  self.crop_factor = crop_factor
65
- self.preserve = preserve
66
 
67
  def __call__(self, batch):
68
- if random.random() < self.preserve:
69
- return batch
70
-
71
  image, label = batch
72
  width, height = image.size
73
 
@@ -89,15 +79,37 @@ class RandomCrop(object):
89
  return image, label
90
 
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  class FontDataset(Dataset):
93
  def __init__(
94
  self,
95
  path: str,
96
  config_path: str = "configs/font.yml",
97
  regression_use_tanh: bool = False,
98
- transforms: bool = False,
99
  crop_roi_bbox: bool = False,
100
  ):
 
 
 
 
 
 
 
 
 
101
  self.path = path
102
  self.fonts = load_font_with_exclusion(config_path)
103
  self.regression_use_tanh = regression_use_tanh
@@ -109,6 +121,9 @@ class FontDataset(Dataset):
109
  ]
110
  self.images.sort()
111
 
 
 
 
112
  def __len__(self):
113
  return len(self.images)
114
 
@@ -148,25 +163,71 @@ class FontDataset(Dataset):
148
  with open(label_path, "rb") as f:
149
  label: FontLabel = pickle.load(f)
150
 
151
- if self.crop_roi_bbox:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  left, top, width, height = label.bbox
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  image = TF.crop(image, top, left, height, width)
154
  label.image_width = width
155
  label.image_height = height
156
 
157
- # encode label
158
- label = self.fontlabel2tensor(label, label_path)
159
 
160
- # data augmentation
161
- if self.transforms:
162
  transform = transforms.Compose(
163
  [
164
- RandomColorJitter(),
165
- RandomCrop(),
 
166
  ]
167
  )
168
  image, label = transform((image, label))
169
 
 
 
 
 
 
 
 
 
 
 
 
170
  # resize and to tensor
171
  transform = transforms.Compose(
172
  [
@@ -176,6 +237,11 @@ class FontDataset(Dataset):
176
  )
177
  image = transform(image)
178
 
 
 
 
 
 
179
  # normalize label
180
  if self.regression_use_tanh:
181
  label[2:12] = label[2:12] * 2 - 1
 
17
 
18
 
19
  class RandomColorJitter(object):
20
+ def __init__(self, brightness=0.5, contrast=0.5, saturation=0.5, hue=0.05):
 
 
21
  self.brightness = brightness
22
  self.contrast = contrast
23
  self.saturation = saturation
24
  self.hue = hue
 
25
 
26
  def __call__(self, batch):
 
 
 
27
  image, label = batch
28
  text_color = label[2:5].clone().view(3, 1, 1)
29
  stroke_color = label[7:10].clone().view(3, 1, 1)
 
54
 
55
 
56
  class RandomCrop(object):
57
+ def __init__(self, crop_factor: float = 0.1):
58
  self.crop_factor = crop_factor
 
59
 
60
  def __call__(self, batch):
 
 
 
61
  image, label = batch
62
  width, height = image.size
63
 
 
79
  return image, label
80
 
81
 
82
+ class RandomRotate(object):
83
+ def __init__(self, max_angle: int = 15):
84
+ self.max_angle = max_angle
85
+
86
+ def __call__(self, batch):
87
+ image, label = batch
88
+
89
+ angle = random.uniform(-self.max_angle, self.max_angle)
90
+ image = TF.rotate(image, angle)
91
+ label[11] = label[11] + angle / 180
92
+ return image, label
93
+
94
+
95
  class FontDataset(Dataset):
96
  def __init__(
97
  self,
98
  path: str,
99
  config_path: str = "configs/font.yml",
100
  regression_use_tanh: bool = False,
101
+ transforms: str = None,
102
  crop_roi_bbox: bool = False,
103
  ):
104
+ """Font dataset
105
+
106
+ Args:
107
+ path (str): path to the dataset
108
+ config_path (str, optional): path to font config file. Defaults to "configs/font.yml".
109
+ regression_use_tanh (bool, optional): whether use tanh as regression normalization. Defaults to False.
110
+ transforms (str, optional): choose from None, 'v1', 'v2'. Defaults to None.
111
+ crop_roi_bbox (bool, optional): whether to crop text roi bbox, must be true when transform='v2'. Defaults to False.
112
+ """
113
  self.path = path
114
  self.fonts = load_font_with_exclusion(config_path)
115
  self.regression_use_tanh = regression_use_tanh
 
121
  ]
122
  self.images.sort()
123
 
124
+ if transforms == "v2":
125
+ assert crop_roi_bbox, "crop_roi_bbox must be true when transform='v2'"
126
+
127
  def __len__(self):
128
  return len(self.images)
129
 
 
163
  with open(label_path, "rb") as f:
164
  label: FontLabel = pickle.load(f)
165
 
166
+ if (self.transforms == "v1") or (self.transforms is None):
167
+ if self.crop_roi_bbox:
168
+ left, top, width, height = label.bbox
169
+ image = TF.crop(image, top, left, height, width)
170
+ label.image_width = width
171
+ label.image_height = height
172
+
173
+ # encode label
174
+ label = self.fontlabel2tensor(label, label_path)
175
+
176
+ # data augmentation
177
+ if self.transforms is not None:
178
+ transform = transforms.Compose(
179
+ [
180
+ transforms.RandomApply(RandomColorJitter(), p=0.8),
181
+ transforms.RandomApply(RandomCrop(), p=0.8),
182
+ ]
183
+ )
184
+ image, label = transform((image, label))
185
+ elif self.transforms == "v2":
186
+ # crop from 30% to 130% of bbox
187
  left, top, width, height = label.bbox
188
+
189
+ right = left + width
190
+ bottom = top + height
191
+
192
+ width_delta = width * 0.07
193
+ height_delta = height * 0.07
194
+
195
+ left = max(0, int(left - width_delta))
196
+ top = max(0, int(top - height_delta))
197
+
198
+ right = min(image.width, int(right + width_delta))
199
+ bottom = min(image.height, int(bottom + height_delta))
200
+
201
+ width = right - left
202
+ height = bottom - top
203
+
204
  image = TF.crop(image, top, left, height, width)
205
  label.image_width = width
206
  label.image_height = height
207
 
208
+ # encode label
209
+ label = self.fontlabel2tensor(label, label_path)
210
 
 
 
211
  transform = transforms.Compose(
212
  [
213
+ transforms.RandomApply(RandomColorJitter(), p=0.8),
214
+ RandomCrop(crop_factor=0.54),
215
+ transforms.RandomApply(RandomRotate(), p=0.8),
216
  ]
217
  )
218
  image, label = transform((image, label))
219
 
220
+ transform = transforms.Compose(
221
+ [
222
+ transforms.RandomApply(
223
+ transforms.GaussianBlur(random.randint(2, 5), sigma=(0.1, 5.0)),
224
+ p=0.8,
225
+ ),
226
+ ]
227
+ )
228
+
229
+ image = transform(image)
230
+
231
  # resize and to tensor
232
  transform = transforms.Compose(
233
  [
 
237
  )
238
  image = transform(image)
239
 
240
+ if self.transforms == "v2":
241
+ # noise
242
+ if random.random() < 0.9:
243
+ image = image + torch.randn_like(image) * random.random() * 0.05
244
+
245
  # normalize label
246
  if self.regression_use_tanh:
247
  label[2:12] = label[2:12] * 2 - 1
train.py CHANGED
@@ -54,6 +54,14 @@ parser.add_argument(
54
  action="store_true",
55
  help="Crop ROI bounding box (default: False)",
56
  )
 
 
 
 
 
 
 
 
57
 
58
  args = parser.parse_args()
59
 
@@ -73,7 +81,6 @@ lambda_direction = 0.5
73
  lambda_regression = 1.0
74
 
75
  regression_use_tanh = False
76
- augmentation = True
77
 
78
  num_warmup_epochs = 5
79
  num_epochs = 100
@@ -90,7 +97,7 @@ data_module = FontDataModule(
90
  val_shuffle=False,
91
  test_shuffle=False,
92
  regression_use_tanh=regression_use_tanh,
93
- train_transforms=augmentation,
94
  crop_roi_bbox=args.crop_roi_bbox,
95
  )
96
 
 
54
  action="store_true",
55
  help="Crop ROI bounding box (default: False)",
56
  )
57
+ parser.add_argument(
58
+ "-a",
59
+ "--augmentation",
60
+ type=str,
61
+ default=None,
62
+ choices=["v1", "v2"],
63
+ help="Augmentation strategy to use (default: None)",
64
+ )
65
 
66
  args = parser.parse_args()
67
 
 
81
  lambda_regression = 1.0
82
 
83
  regression_use_tanh = False
 
84
 
85
  num_warmup_epochs = 5
86
  num_epochs = 100
 
97
  val_shuffle=False,
98
  test_shuffle=False,
99
  regression_use_tanh=regression_use_tanh,
100
+ train_transforms=args.augmentation,
101
  crop_roi_bbox=args.crop_roi_bbox,
102
  )
103