# dataloader/extract_features_dataloader.py
import pandas as pd
import os
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import clip
import torchvision.transforms as tf
import torchvision.transforms.functional as TF
# torchvision added InterpolationMode in 0.9; fall back to the PIL constant
# on older versions.
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

class ExtractFeaturesDataset(Dataset):
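    """VizWiz (image, question) pairs for CLIP feature extraction.

    Each item is (image tensor, tokenized question, image_id). With tta=True,
    image_transforms is expected to be a list of transforms and the image
    tensor gains a leading view axis (one entry per augmentation).
    """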
def __init__(self,
annotations,
img_path,
image_transforms=None,
question_transforms=None,
tta=False):
self.img_path = img_path
self.image_transforms = image_transforms
self.question_transforms = question_transforms
self.img_ids = annotations["image_id"].values
self.split = annotations["split"].values
self.questions = annotations["question"].values
self.tta = tta
    def __getitem__(self, index):
        image_id = self.img_ids[index]
        split = self.split[index]
        # Image input. Image.open is lazy, so force the read with load()
        # while the file handle is still open.
        with open(os.path.join(self.img_path, split, image_id), "rb") as f:
            img = Image.open(f)
            img.load()
        if self.tta:
            # Test-time augmentation: apply every transform in the list and
            # stack the views into a single (n_views, C, H, W) tensor.
            image_augmentations = [transform(img) for transform in self.image_transforms]
            img = torch.stack(image_augmentations, dim=0)
        else:
            img = self.image_transforms(img)
        # Question input: optional text preprocessing, then CLIP tokenization
        # (squeeze drops the batch axis added by clip.tokenize).
        question = self.questions[index]
        if self.question_transforms:
            question = self.question_transforms(question)
        question = clip.tokenize(question, truncate=True).squeeze()
        return img, question, image_id
def __len__(self):
return len(self.img_ids)
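
# A hedged usage sketch (assumes a VizWiz-style CSV with image_id, split and
# question columns, and images stored under <img_path>/<split>/<image_id>;
# the paths are placeholders):
#
#   df = pd.read_csv("annotations/train.csv")
#   ds = ExtractFeaturesDataset(df, img_path="images",
#                               image_transforms=transform_crop(224))
#   img, tokens, image_id = ds[0]   # img: (3, 224, 224), tokens: (77,)
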
def _convert_image_to_rgb(image):
    # CLIP expects 3-channel input; converts grayscale/RGBA images.
    return image.convert("RGB")


def Sharpen(sharpness_factor=1.0):
    # Wraps the functional sharpness adjustment with a fixed factor so it
    # can sit inside tf.Compose. Currently unused by the pipelines below.
    def wrapper(x):
        return TF.adjust_sharpness(x, sharpness_factor)
    return wrapper


def Rotate(angle=0.0):
    # Wraps a fixed-angle rotation (degrees, counter-clockwise). TF.rotate
    # keeps the original canvas size (expand=False by default).
    def wrapper(x):
        return TF.rotate(x, angle)
    return wrapper

def transform_crop(n_px):
    # Original CLIP preprocessing: resize the shorter side to n_px, center
    # crop to n_px x n_px, and normalize with CLIP's published statistics.
    return tf.Compose([
        tf.Resize(n_px, interpolation=BICUBIC),
        tf.CenterCrop(n_px),
        _convert_image_to_rgb,
        tf.ToTensor(),
        tf.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])


def transform_crop_rotate(n_px, rotation_angle=0.0):
    # transform_crop with a fixed rotation applied to the PIL image first.
    return tf.Compose([
        Rotate(angle=rotation_angle),
        tf.Resize(n_px, interpolation=BICUBIC),
        tf.CenterCrop(n_px),
        _convert_image_to_rgb,
        tf.ToTensor(),
        tf.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])


def transform_resize(n_px):
    # Crop-free variant: resize directly to n_px x n_px. Distorts the aspect
    # ratio but keeps the full image content.
    return tf.Compose([
        tf.Resize((n_px, n_px), interpolation=BICUBIC),
        _convert_image_to_rgb,
        tf.ToTensor(),
        tf.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])


def transform_resize_rotate(n_px, rotation_angle=0.0):
    # transform_resize with a fixed rotation applied to the PIL image first.
    return tf.Compose([
        Rotate(angle=rotation_angle),
        tf.Resize((n_px, n_px), interpolation=BICUBIC),
        _convert_image_to_rgb,
        tf.ToTensor(),
        tf.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])

def get_tta_preprocess(img_size):
    # Six test-time-augmentation views per image:
    # {center crop, full resize} x {0, 90, 270} degree rotations.
    img_preprocess = [
        transform_crop(img_size),
        transform_crop_rotate(img_size, rotation_angle=90.0),
        transform_crop_rotate(img_size, rotation_angle=270.0),
        transform_resize(img_size),
        transform_resize_rotate(img_size, rotation_angle=90.0),
        transform_resize_rotate(img_size, rotation_angle=270.0),
    ]
    return img_preprocess
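
# A hedged sketch of consuming the stacked views (not part of this module;
# `model` and `device` are assumptions): with tta=True a batch has shape
# (B, 6, C, H, W), so flatten the view axis before encoding and average the
# per-view features afterwards.
#
#   b, v = img.shape[:2]
#   feats = model.encode_image(img.flatten(0, 1).to(device))  # (B*6, D)
#   feats = feats.view(b, v, -1).mean(dim=1)                  # (B, D)
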
def question_preprocess(question, debug=False):
    # CLIP's text encoder was trained on caption-style sentences, so turn
    # the question into a declarative one: replace "?" with ".", drop
    # trailing whitespace, and make sure it ends with a period.
    question = question.replace("?", ".").rstrip()
    if not question.endswith("."):
        question = question + "."
    if debug:
        print("Question:", question)
    return question
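
# Example: question_preprocess("What is in this fridge? ") returns
# "What is in this fridge.", ready for clip.tokenize in __getitem__.
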
def get_dataloader_extraction(config):
    # Builds train/test dataloaders for feature extraction. shuffle=False
    # keeps batch order aligned with the annotation CSVs.
    if config.use_question_preprocess:
        print("Using custom preprocessing: Question")
        question_transforms = question_preprocess
    else:
        question_transforms = None

    if config.tta:
        print("Using augmentation transforms:")
        img_preprocess = get_tta_preprocess(config.img_size)
    else:
        print("Using original CLIP transforms:")
        img_preprocess = transform_crop(config.img_size)

    train_data = pd.read_csv(config.train_annotations_path)
    train_dataset = ExtractFeaturesDataset(annotations=train_data,
                                           img_path=config.img_path,
                                           image_transforms=img_preprocess,
                                           question_transforms=question_transforms,
                                           tta=config.tta)
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=config.batch_size,
                              shuffle=False,
                              num_workers=config.num_workers)

    test_data = pd.read_csv(config.test_annotations_path)
    test_dataset = ExtractFeaturesDataset(annotations=test_data,
                                          img_path=config.img_path,
                                          image_transforms=img_preprocess,
                                          question_transforms=question_transforms,
                                          tta=config.tta)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=config.batch_size,
                             shuffle=False,
                             num_workers=config.num_workers)
    return train_loader, test_loader
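
# A hedged sketch of the downstream extraction loop (model/device are
# assumptions, not defined here); with tta=False each batch yields
# (B, 3, H, W) images, (B, 77) token ids, and B image ids:
#
#   model, _ = clip.load("ViT-B/32", device=device)
#   for img, tokens, image_ids in train_loader:
#       with torch.no_grad():
#           img_feats = model.encode_image(img.to(device))
#           txt_feats = model.encode_text(tokens.to(device))
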
def get_dataloader_inference(config):
    # Inference uses a single transform per image in both branches (full
    # resize if features were extracted with TTA, otherwise the original
    # CLIP center crop), so the datasets are built with tta=False; iterating
    # a single Compose as if it were a list of transforms would fail.
    if config.use_question_preprocess:
        print("Using custom preprocessing: Question")
        question_transforms = question_preprocess
    else:
        question_transforms = None

    if config.tta:
        print("Using augmentation transforms:")
        img_preprocess = transform_resize(config.img_size)
    else:
        print("Using original CLIP transforms:")
        img_preprocess = transform_crop(config.img_size)

    train_data = pd.read_csv(config.train_annotations_path)
    train_dataset = ExtractFeaturesDataset(annotations=train_data,
                                           img_path=config.img_path,
                                           image_transforms=img_preprocess,
                                           question_transforms=question_transforms,
                                           tta=False)
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=config.batch_size,
                              shuffle=False,
                              num_workers=config.num_workers)

    test_data = pd.read_csv(config.test_annotations_path)
    test_dataset = ExtractFeaturesDataset(annotations=test_data,
                                          img_path=config.img_path,
                                          image_transforms=img_preprocess,
                                          question_transforms=question_transforms,
                                          tta=False)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=config.batch_size,
                             shuffle=False,
                             num_workers=config.num_workers)
    return train_loader, test_loader
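
if __name__ == "__main__":
    # Minimal smoke-test sketch; every value below is a placeholder
    # assumption, not part of the original module.
    from types import SimpleNamespace

    config = SimpleNamespace(
        use_question_preprocess=True,
        tta=False,
        img_size=224,
        train_annotations_path="annotations/train.csv",  # hypothetical
        test_annotations_path="annotations/test.csv",    # hypothetical
        img_path="images",                               # hypothetical
        batch_size=32,
        num_workers=2,
    )
    train_loader, test_loader = get_dataloader_extraction(config)
    img, tokens, image_ids = next(iter(train_loader))
    print(img.shape, tokens.shape, len(image_ids))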