Spaces:
Build error
Build error
import pandas as pd | |
import os | |
import torch | |
from PIL import Image | |
from torch.utils.data import Dataset | |
import clip | |
from torch.utils.data import DataLoader | |
import torchvision.transforms as tf | |
import torchvision.transforms.functional as TF | |
try: | |
from torchvision.transforms import InterpolationMode | |
BICUBIC = InterpolationMode.BICUBIC | |
except ImportError: | |
BICUBIC = Image.BICUBIC | |
class ExtractFeaturesDataset(Dataset):
    """Dataset yielding ``(image, question_tokens, image_id)`` triples.

    Each row of ``annotations`` supplies an ``image_id``, a ``split`` and a
    ``question``. The image is read from ``<img_path>/<split>/<image_id>``.

    Args:
        annotations: Table indexable by column name whose columns expose
            ``.values`` (for example a pandas DataFrame) with the columns
            ``image_id``, ``split`` and ``question``.
        img_path: Root directory that contains one sub-directory per split.
        image_transforms: A single callable applied to the opened image, or,
            when ``tta`` is true, an iterable of such callables.
        question_transforms: Optional callable applied to the raw question
            before it is tokenized.
        tta: When true, every transform in ``image_transforms`` is applied
            to the image and the results are stacked along a new leading
            dimension.
    """

    def __init__(self,
                 annotations,
                 img_path,
                 image_transforms=None,
                 question_transforms=None,
                 tta=False):
        self.tta = tta
        self.img_path = img_path
        self.image_transforms = image_transforms
        self.question_transforms = question_transforms

        # Keep plain per-column arrays so item lookup is a simple index.
        self.img_ids = annotations["image_id"].values
        self.split = annotations["split"].values
        self.questions = annotations["question"].values

    def __len__(self):
        """Return the number of annotation rows."""
        return len(self.img_ids)

    def __getitem__(self, index):
        """Load, transform and tokenize the sample at ``index``.

        Returns:
            A tuple ``(img, question, image_id)``: the transformed image
            (stacked views when ``tta`` is set), the squeezed output of
            ``clip.tokenize`` for the question, and the raw image id.
        """
        image_id = self.img_ids[index]
        image_file = os.path.join(self.img_path, self.split[index], image_id)

        # The transforms run while the file is still open because the image
        # object is backed by this file handle.
        with open(image_file, "rb") as handle:
            picture = Image.open(handle)
            if not self.tta:
                pixels = self.image_transforms(picture)
            else:
                views = [augment(picture) for augment in self.image_transforms]
                pixels = torch.stack(views, dim=0)

        text = self.questions[index]
        if self.question_transforms:
            text = self.question_transforms(text)

        # Tokenize a single string, then drop the size-1 dimensions.
        tokens = clip.tokenize(text, truncate=True).squeeze()
        return pixels, tokens, image_id
def _convert_image_to_rgb(image):
    """Return ``image`` converted to ``"RGB"`` mode via its ``convert`` method."""
    rgb_image = image.convert("RGB")
    return rgb_image
def Sharpen(sharpness_factor=1.0):
    """Return a transform that adjusts the sharpness of its input.

    Args:
        sharpness_factor: Factor forwarded to ``TF.adjust_sharpness``.

    Returns:
        A callable taking one image ``x`` and returning
        ``TF.adjust_sharpness(x, sharpness_factor)``.
    """
    # Local import keeps this block self-contained.
    from functools import partial

    # A partial over a module-level function can be pickled, whereas a nested
    # closure cannot. Pickling matters when a dataset holding this transform
    # is handed to DataLoader worker processes that are started via "spawn".
    return partial(TF.adjust_sharpness, sharpness_factor=sharpness_factor)
def Rotate(angle=0.0):
    """Return a transform that rotates its input by a fixed angle.

    Args:
        angle: Rotation angle forwarded to ``TF.rotate``.

    Returns:
        A callable taking one image ``x`` and returning
        ``TF.rotate(x, angle)``.
    """
    # Local import keeps this block self-contained.
    from functools import partial

    # A partial over a module-level function can be pickled, whereas a nested
    # closure cannot. Pickling matters when a dataset holding this transform
    # is handed to DataLoader worker processes that are started via "spawn".
    return partial(TF.rotate, angle=angle)
def transform_crop(n_px):
    """Build the resize + center-crop preprocessing pipeline.

    Steps, in order: bicubic ``Resize(n_px)``, ``CenterCrop(n_px)``,
    conversion to RGB, ``ToTensor`` and per-channel ``Normalize`` with fixed
    mean/std constants.

    Args:
        n_px: Target size passed to both ``Resize`` and ``CenterCrop``.

    Returns:
        A ``tf.Compose`` of the steps above.
    """
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)
    steps = [
        tf.Resize(n_px, interpolation=BICUBIC),
        tf.CenterCrop(n_px),
        _convert_image_to_rgb,
        tf.ToTensor(),
        tf.Normalize(channel_mean, channel_std),
    ]
    return tf.Compose(steps)
def transform_crop_rotate(n_px, rotation_angle=0.0):
    """Build the rotate + resize + center-crop preprocessing pipeline.

    Identical to the crop pipeline except that the image is first rotated by
    ``rotation_angle`` through ``Rotate``.

    Args:
        n_px: Target size passed to both ``Resize`` and ``CenterCrop``.
        rotation_angle: Angle handed to ``Rotate``.

    Returns:
        A ``tf.Compose`` of: rotate, bicubic resize, center crop, RGB
        conversion, ``ToTensor`` and ``Normalize``.
    """
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)
    steps = [
        Rotate(angle=rotation_angle),
        tf.Resize(n_px, interpolation=BICUBIC),
        tf.CenterCrop(n_px),
        _convert_image_to_rgb,
        tf.ToTensor(),
        tf.Normalize(channel_mean, channel_std),
    ]
    return tf.Compose(steps)
def transform_resize(n_px):
    """Build the direct-resize preprocessing pipeline (no cropping).

    The image is resized to exactly ``(n_px, n_px)`` with bicubic
    interpolation, converted to RGB, turned into a tensor and normalized with
    fixed per-channel mean/std constants.

    Args:
        n_px: Output height and width.

    Returns:
        A ``tf.Compose`` of the steps above.
    """
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)
    target_size = (n_px, n_px)
    steps = [
        tf.Resize(target_size, interpolation=BICUBIC),
        _convert_image_to_rgb,
        tf.ToTensor(),
        tf.Normalize(channel_mean, channel_std),
    ]
    return tf.Compose(steps)
def transform_resize_rotate(n_px, rotation_angle=0.0):
    """Build the rotate + direct-resize preprocessing pipeline.

    Identical to the resize pipeline except that the image is first rotated
    by ``rotation_angle`` through ``Rotate``.

    Args:
        n_px: Output height and width.
        rotation_angle: Angle handed to ``Rotate``.

    Returns:
        A ``tf.Compose`` of: rotate, bicubic resize to ``(n_px, n_px)``, RGB
        conversion, ``ToTensor`` and ``Normalize``.
    """
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)
    target_size = (n_px, n_px)
    steps = [
        Rotate(angle=rotation_angle),
        tf.Resize(target_size, interpolation=BICUBIC),
        _convert_image_to_rgb,
        tf.ToTensor(),
        tf.Normalize(channel_mean, channel_std),
    ]
    return tf.Compose(steps)
def get_tta_preprocess(img_size):
    """Return the list of six augmentation pipelines used when TTA is on.

    The order is: crop, crop rotated by 90, crop rotated by 270, resize,
    resize rotated by 90, resize rotated by 270.

    Args:
        img_size: Size forwarded to every pipeline builder.

    Returns:
        A list of six transform pipelines in the order given above.
    """
    builder_pairs = (
        (transform_crop, transform_crop_rotate),
        (transform_resize, transform_resize_rotate),
    )
    pipelines = []
    for build_plain, build_rotated in builder_pairs:
        # The unrotated view comes first, followed by its two rotations.
        pipelines.append(build_plain(img_size))
        for angle in (90.0, 270.0):
            pipelines.append(build_rotated(img_size, rotation_angle=angle))
    return pipelines
def question_preprocess(question, debug=False):
    """Normalize a question so that it ends with a period.

    Every ``"?"`` is replaced by ``"."``, one trailing space (if present) is
    removed, and a ``"."`` is appended when the text does not already end
    with one.

    Args:
        question: Raw question text. May be empty.
        debug: When true, print the normalized question.

    Returns:
        The normalized question string. An empty or single-space input
        yields ``"."``.
    """
    question = question.replace("?", ".")

    # endswith() is safe on an empty string, unlike indexing with [-1],
    # which raises IndexError for "" and for " " once the space is removed.
    if question.endswith(" "):
        question = question[:-1]
    if not question.endswith("."):
        question = question + "."

    if debug:
        print("Question:", question)
    return question
def get_dataloader_extraction(config):
    """Build the train and test dataloaders used for feature extraction.

    Args:
        config: Object providing the attributes ``use_question_preprocess``,
            ``tta``, ``img_size``, ``train_annotations_path``,
            ``test_annotations_path``, ``img_path``, ``batch_size`` and
            ``num_workers``.

    Returns:
        A tuple ``(train_loader, test_loader)`` of unshuffled ``DataLoader``
        objects over ``ExtractFeaturesDataset`` instances.
    """
    if config.use_question_preprocess:
        print("Using custom preprocessing: Question")
        question_transforms = question_preprocess
    else:
        question_transforms = None

    # These messages were bare string expressions (no-ops); print them.
    if config.tta:
        print("Using augmentation transforms:")
        img_preprocess = get_tta_preprocess(config.img_size)
    else:
        print("Using original CLIP transforms:")
        img_preprocess = transform_crop(config.img_size)

    loaders = []
    for annotations_path in (config.train_annotations_path,
                             config.test_annotations_path):
        annotations = pd.read_csv(annotations_path)
        dataset = ExtractFeaturesDataset(annotations=annotations,
                                         img_path=config.img_path,
                                         image_transforms=img_preprocess,
                                         question_transforms=question_transforms,
                                         tta=config.tta)
        # Both splits are wrapped in a DataLoader. The test split was
        # previously wrapped in ExtractFeaturesDataset by mistake, which
        # fails on the unexpected ``dataset``/``batch_size`` keywords.
        loaders.append(DataLoader(dataset=dataset,
                                  batch_size=config.batch_size,
                                  shuffle=False,
                                  num_workers=config.num_workers))

    train_loader, test_loader = loaders
    return train_loader, test_loader
def get_dataloader_inference(config):
    """Build the train and test dataloaders used for inference.

    Args:
        config: Object providing the attributes ``use_question_preprocess``,
            ``tta``, ``img_size``, ``train_annotations_path``,
            ``test_annotations_path``, ``img_path``, ``batch_size`` and
            ``num_workers``.

    Returns:
        A tuple ``(train_loader, test_loader)`` of unshuffled ``DataLoader``
        objects over ``ExtractFeaturesDataset`` instances.
    """
    if config.use_question_preprocess:
        print("Using custom preprocessing: Question")
        question_transforms = question_preprocess
    else:
        question_transforms = None

    # These messages were bare string expressions (no-ops); print them.
    if config.tta:
        print("Using augmentation transforms:")
        # NOTE(review): this is a single pipeline, yet the dataset is built
        # with tta=config.tta and iterates over image_transforms when tta is
        # true. Confirm whether a list of pipelines or tta=False is intended.
        img_preprocess = transform_resize(config.img_size)
    else:
        print("Using original CLIP transforms:")
        img_preprocess = transform_crop(config.img_size)

    loaders = []
    for annotations_path in (config.train_annotations_path,
                             config.test_annotations_path):
        annotations = pd.read_csv(annotations_path)
        dataset = ExtractFeaturesDataset(annotations=annotations,
                                         img_path=config.img_path,
                                         image_transforms=img_preprocess,
                                         question_transforms=question_transforms,
                                         tta=config.tta)
        # Both splits are wrapped in a DataLoader. The test split was
        # previously wrapped in ExtractFeaturesDataset by mistake, which
        # fails on the unexpected ``dataset``/``batch_size`` keywords.
        loaders.append(DataLoader(dataset=dataset,
                                  batch_size=config.batch_size,
                                  shuffle=False,
                                  num_workers=config.num_workers))

    train_loader, test_loader = loaders
    return train_loader, test_loader