gyrojeff commited on
Commit
8063aac
·
1 Parent(s): 3a87836

fix: transform label when horizontal flip

Browse files
Files changed (1) hide show
  1. detector/data.py +16 -1
detector/data.py CHANGED
@@ -152,6 +152,21 @@ class RandomCropPreserveAspectRatio(object):
152
  return image, label
153
 
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  class FontDataset(Dataset):
156
  def __init__(
157
  self,
@@ -222,6 +237,7 @@ class FontDataset(Dataset):
222
  RandomColorJitter(preserve=0.2),
223
  RandomCrop(crop_factor=0.54, preserve=0),
224
  RandomRotate(preserve=0.2),
 
225
  ]
226
  image_transforms = [
227
  torchvision.transforms.GaussianBlur(
@@ -231,7 +247,6 @@ class FontDataset(Dataset):
231
  torchvision.transforms.Resize((config.INPUT_SIZE, config.INPUT_SIZE)),
232
  torchvision.transforms.ToTensor(),
233
  RandomNoise(max_noise=0.05, preserve=0.1),
234
- torchvision.transforms.RandomHorizontalFlip(p=0.5),
235
  ]
236
  else:
237
  raise ValueError(f"Unknown transform: {transforms}")
 
152
  return image, label
153
 
154
 
155
+ class RandomHorizontalFlip(object):
156
+ def __init__(self, preserve: float = 0.5):
157
+ self.preserve = preserve
158
+
159
+ def __call__(self, batch):
160
+ if random.random() < self.preserve:
161
+ return batch
162
+
163
+ image, label = batch
164
+ image = TF.hflip(image)
165
+ label[11] = 1 - label[11]
166
+
167
+ return image, label
168
+
169
+
170
  class FontDataset(Dataset):
171
  def __init__(
172
  self,
 
237
  RandomColorJitter(preserve=0.2),
238
  RandomCrop(crop_factor=0.54, preserve=0),
239
  RandomRotate(preserve=0.2),
240
+ RandomHorizontalFlip(preserve=0.5),
241
  ]
242
  image_transforms = [
243
  torchvision.transforms.GaussianBlur(
 
247
  torchvision.transforms.Resize((config.INPUT_SIZE, config.INPUT_SIZE)),
248
  torchvision.transforms.ToTensor(),
249
  RandomNoise(max_noise=0.05, preserve=0.1),
 
250
  ]
251
  else:
252
  raise ValueError(f"Unknown transform: {transforms}")