fix: transform label when horizontal flip
Browse files- detector/data.py +16 -1
detector/data.py
CHANGED
@@ -152,6 +152,21 @@ class RandomCropPreserveAspectRatio(object):
|
|
152 |
return image, label
|
153 |
|
154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
class FontDataset(Dataset):
|
156 |
def __init__(
|
157 |
self,
|
@@ -222,6 +237,7 @@ class FontDataset(Dataset):
|
|
222 |
RandomColorJitter(preserve=0.2),
|
223 |
RandomCrop(crop_factor=0.54, preserve=0),
|
224 |
RandomRotate(preserve=0.2),
|
|
|
225 |
]
|
226 |
image_transforms = [
|
227 |
torchvision.transforms.GaussianBlur(
|
@@ -231,7 +247,6 @@ class FontDataset(Dataset):
|
|
231 |
torchvision.transforms.Resize((config.INPUT_SIZE, config.INPUT_SIZE)),
|
232 |
torchvision.transforms.ToTensor(),
|
233 |
RandomNoise(max_noise=0.05, preserve=0.1),
|
234 |
-
torchvision.transforms.RandomHorizontalFlip(p=0.5),
|
235 |
]
|
236 |
else:
|
237 |
raise ValueError(f"Unknown transform: {transforms}")
|
|
|
152 |
return image, label
|
153 |
|
154 |
|
155 |
+
class RandomHorizontalFlip(object):
|
156 |
+
def __init__(self, preserve: float = 0.5):
|
157 |
+
self.preserve = preserve
|
158 |
+
|
159 |
+
def __call__(self, batch):
|
160 |
+
if random.random() < self.preserve:
|
161 |
+
return batch
|
162 |
+
|
163 |
+
image, label = batch
|
164 |
+
image = TF.hflip(image)
|
165 |
+
label[11] = 1 - label[11]
|
166 |
+
|
167 |
+
return image, label
|
168 |
+
|
169 |
+
|
170 |
class FontDataset(Dataset):
|
171 |
def __init__(
|
172 |
self,
|
|
|
237 |
RandomColorJitter(preserve=0.2),
|
238 |
RandomCrop(crop_factor=0.54, preserve=0),
|
239 |
RandomRotate(preserve=0.2),
|
240 |
+
RandomHorizontalFlip(preserve=0.5),
|
241 |
]
|
242 |
image_transforms = [
|
243 |
torchvision.transforms.GaussianBlur(
|
|
|
247 |
torchvision.transforms.Resize((config.INPUT_SIZE, config.INPUT_SIZE)),
|
248 |
torchvision.transforms.ToTensor(),
|
249 |
RandomNoise(max_noise=0.05, preserve=0.1),
|
|
|
250 |
]
|
251 |
else:
|
252 |
raise ValueError(f"Unknown transform: {transforms}")
|