import os

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as tf
import torchvision.transforms.functional as TF

import clip

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    # Older torchvision versions take PIL resampling constants directly.
    BICUBIC = Image.BICUBIC


class ExtractFeaturesDataset(Dataset):
    def __init__(self, annotations, img_path, image_transforms=None,
                 question_transforms=None, tta=False):
        self.img_path = img_path
        self.image_transforms = image_transforms
        self.question_transforms = question_transforms
        self.img_ids = annotations["image_id"].values
        self.split = annotations["split"].values
        self.questions = annotations["question"].values
        self.tta = tta

    def __getitem__(self, index):
        image_id = self.img_ids[index]
        split = self.split[index]

        # Image input. The transforms run inside the `with` block so the
        # lazily loaded PIL image is still backed by an open file handle.
        with open(os.path.join(self.img_path, split, image_id), "rb") as f:
            img = Image.open(f)
            if self.tta:
                # Test-time augmentation: apply every transform in the list
                # and stack the views along a new leading dimension.
                image_augmentations = [transform(img) for transform in self.image_transforms]
                img = torch.stack(image_augmentations, dim=0)
            else:
                img = self.image_transforms(img)

        # Question input.
        question = self.questions[index]
        if self.question_transforms:
            question = self.question_transforms(question)
        question = clip.tokenize(question, truncate=True)
        question = question.squeeze()

        return img, question, image_id

    def __len__(self):
        return len(self.img_ids)


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def Sharpen(sharpness_factor=1.0):
    def wrapper(x):
        return TF.adjust_sharpness(x, sharpness_factor)
    return wrapper


def Rotate(angle=0.0):
    def wrapper(x):
        return TF.rotate(x, angle)
    return wrapper


def transform_crop(n_px):
    # Original CLIP preprocessing: resize the short side, then center crop.
    return tf.Compose([
        tf.Resize(n_px, interpolation=BICUBIC),
        tf.CenterCrop(n_px),
        _convert_image_to_rgb,
        tf.ToTensor(),
        tf.Normalize((0.48145466, 0.4578275, 0.40821073),
                     (0.26862954, 0.26130258, 0.27577711)),
    ])


def transform_crop_rotate(n_px, rotation_angle=0.0):
    return tf.Compose([
        Rotate(angle=rotation_angle),
        tf.Resize(n_px, interpolation=BICUBIC),
        tf.CenterCrop(n_px),
        _convert_image_to_rgb,
        tf.ToTensor(),
        tf.Normalize((0.48145466, 0.4578275, 0.40821073),
                     (0.26862954, 0.26130258, 0.27577711)),
    ])


def transform_resize(n_px):
    # Resize the whole image to n_px x n_px without cropping.
    return tf.Compose([
        tf.Resize((n_px, n_px), interpolation=BICUBIC),
        _convert_image_to_rgb,
        tf.ToTensor(),
        tf.Normalize((0.48145466, 0.4578275, 0.40821073),
                     (0.26862954, 0.26130258, 0.27577711)),
    ])


def transform_resize_rotate(n_px, rotation_angle=0.0):
    return tf.Compose([
        Rotate(angle=rotation_angle),
        tf.Resize((n_px, n_px), interpolation=BICUBIC),
        _convert_image_to_rgb,
        tf.ToTensor(),
        tf.Normalize((0.48145466, 0.4578275, 0.40821073),
                     (0.26862954, 0.26130258, 0.27577711)),
    ])


def get_tta_preprocess(img_size):
    # Six views per image: center crop and full-image resize, each at
    # 0, 90 and 270 degrees of rotation.
    img_preprocess = [
        transform_crop(img_size),
        transform_crop_rotate(img_size, rotation_angle=90.0),
        transform_crop_rotate(img_size, rotation_angle=270.0),
        transform_resize(img_size),
        transform_resize_rotate(img_size, rotation_angle=90.0),
        transform_resize_rotate(img_size, rotation_angle=270.0),
    ]
    return img_preprocess


def question_preprocess(question, debug=False):
    # Rewrite the question as a statement: replace "?" with ".", drop any
    # trailing whitespace, and make sure the result ends with a period.
    question = question.replace("?", ".")
    question = question.rstrip()
    if question[-1] != ".":
        question = question + "."
    if debug:
        print("Question:", question)
    return question
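
# Hedged sketch: how downstream code might consume a TTA batch. The `model`
# handle is an assumption of this example (a loaded CLIP model, not defined
# in this module). With tta=True each image in a batch carries an extra
# "views" dimension, which is flattened before encoding and restored after.
@torch.no_grad()
def encode_tta_batch(model, imgs):
    """Encode a [B, V, 3, H, W] stack of TTA views into [B, V, D] features."""
    b, v, c, h, w = imgs.shape
    feats = model.encode_image(imgs.view(b * v, c, h, w))  # [B * V, D]
    return feats.view(b, v, -1)
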
def get_dataloader_extraction(config):
    if config.use_question_preprocess:
        print("Using custom preprocessing: Question")
        question_transforms = question_preprocess
    else:
        question_transforms = None

    if config.tta:
        print("Using augmentation transforms")
        img_preprocess = get_tta_preprocess(config.img_size)
    else:
        print("Using original CLIP transforms")
        img_preprocess = transform_crop(config.img_size)

    train_data = pd.read_csv(config.train_annotations_path)
    train_dataset = ExtractFeaturesDataset(annotations=train_data,
                                           img_path=config.img_path,
                                           image_transforms=img_preprocess,
                                           question_transforms=question_transforms,
                                           tta=config.tta)
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=config.batch_size,
                              shuffle=False,
                              num_workers=config.num_workers)

    test_data = pd.read_csv(config.test_annotations_path)
    test_dataset = ExtractFeaturesDataset(annotations=test_data,
                                          img_path=config.img_path,
                                          image_transforms=img_preprocess,
                                          question_transforms=question_transforms,
                                          tta=config.tta)
    # Bug fix: the test loader must be a DataLoader, not a second dataset.
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=config.batch_size,
                             shuffle=False,
                             num_workers=config.num_workers)

    return train_loader, test_loader


def get_dataloader_inference(config):
    if config.use_question_preprocess:
        print("Using custom preprocessing: Question")
        question_transforms = question_preprocess
    else:
        question_transforms = None

    if config.tta:
        print("Using augmentation transforms")
        # At inference a single resized view is used. It is wrapped in a list
        # so the dataset's TTA path, which iterates over its transforms,
        # still works and yields a [1, 3, H, W] stack per image.
        img_preprocess = [transform_resize(config.img_size)]
    else:
        print("Using original CLIP transforms")
        img_preprocess = transform_crop(config.img_size)

    train_data = pd.read_csv(config.train_annotations_path)
    train_dataset = ExtractFeaturesDataset(annotations=train_data,
                                           img_path=config.img_path,
                                           image_transforms=img_preprocess,
                                           question_transforms=question_transforms,
                                           tta=config.tta)
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=config.batch_size,
                              shuffle=False,
                              num_workers=config.num_workers)

    test_data = pd.read_csv(config.test_annotations_path)
    test_dataset = ExtractFeaturesDataset(annotations=test_data,
                                          img_path=config.img_path,
                                          image_transforms=img_preprocess,
                                          question_transforms=question_transforms,
                                          tta=config.tta)
    # Bug fix: the test loader must be a DataLoader, not a second dataset.
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=config.batch_size,
                             shuffle=False,
                             num_workers=config.num_workers)

    return train_loader, test_loader
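

if __name__ == "__main__":
    # Minimal usage sketch, assuming a config object with the attributes the
    # functions above access. The paths, image size, and batch settings below
    # are illustrative placeholders, not project defaults.
    from types import SimpleNamespace

    config = SimpleNamespace(
        use_question_preprocess=True,
        tta=False,
        img_size=224,                              # assumption: CLIP ViT-B input size
        img_path="data/images",                    # placeholder path
        train_annotations_path="data/train.csv",   # placeholder path
        test_annotations_path="data/test.csv",     # placeholder path
        batch_size=32,
        num_workers=4,
    )
    train_loader, test_loader = get_dataloader_extraction(config)
    imgs, questions, image_ids = next(iter(train_loader))
    # With tta=False: imgs is [B, 3, 224, 224]; with tta=True: [B, 6, 3, 224, 224].
    print("images:", tuple(imgs.shape), "questions:", tuple(questions.shape))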