# Spaces: Runtime error (status text repeated twice in the original capture; not part of the program)
import json | |
import cv2 | |
import numpy as np | |
import os | |
from torch.utils.data import Dataset | |
import pycocotools.mask as maskUtils | |
from torchvision import transforms | |
import utils.transforms as custom_transforms | |
from PIL import Image | |
class SAMDataset(Dataset):
    """Dataset of (target image, prompt, segmentation hint) training samples.

    Each record read from the index file must provide the keys ``'source'``
    (path of an annotation JSON file), ``'target'`` (path of an image file)
    and ``'prompt'``. Both paths are resolved relative to ``data_path``.
    """

    def __init__(self, data_path: str = '../data/files', txt_path: str = '../data/data_85616.txt'):
        """Read the sample index and build the augmentation pipeline.

        Args:
            data_path: Directory that the ``'source'`` / ``'target'`` paths of
                every record are joined onto.
            txt_path: Text file holding one Python-literal record per line.
        """
        self.data = []
        with open(txt_path, 'rt') as f:
            for line in f:
                # A blank line would make eval() raise SyntaxError; skip it
                # instead of aborting the whole load.
                if not line.strip():
                    continue
                # NOTE(review): eval() executes arbitrary expressions found in
                # txt_path. If the lines are plain dict literals,
                # ast.literal_eval would be the safe replacement -- confirm
                # the file format before switching.
                self.data.append(eval(line))
        self.data_path = data_path

        randomresizedcrop = custom_transforms.RandomResizedCrop(
            512,
            scale=(0.9, 1),
        )
        # The pipeline is called with (target, source) in __getitem__.
        # NOTE(review): presumably every step applies the same geometric
        # change to both inputs -- confirm in utils.transforms.
        self.transform = custom_transforms.Compose([
            randomresizedcrop,
            custom_transforms.RandomHorizontalFlip(p=0.5),
            custom_transforms.ToTensor(),
            transforms.Normalize(mean=0.5, std=0.5),
        ])

    def __len__(self) -> int:
        """Return the number of records loaded from the index file."""
        return len(self.data)

    def load_rle_annotations_from_json(self, json_file_path, return_pil=True):
        """Decode the masks of one annotation file into a single label map.

        The JSON document must contain ``'annotations'`` (a list whose items
        carry a ``'segmentation'`` entry accepted by
        ``pycocotools.mask.decode``) and ``'image'`` with ``'height'`` and
        ``'width'``.

        Pixels covered by annotation ``i`` receive the label ``i + 1``; 0 is
        left for uncovered pixels. Where masks overlap, the annotation that
        comes later in the list wins.

        Args:
            json_file_path: Path of the annotation JSON file.
            return_pil: When true, return the label map packed into a
                3-channel 8-bit PIL image; otherwise return the raw array.

        Returns:
            A PIL image whose channel 0 holds ``label % 256``, channel 1 holds
            ``label // 256`` and channel 2 is zero, or the ``uint16`` array of
            shape ``(height, width)`` when ``return_pil`` is false.
        """
        with open(json_file_path, 'r', encoding='utf-8') as f:
            anno_data = json.load(f)

        annotations = anno_data['annotations']
        height = int(anno_data['image']['height'])
        width = int(anno_data['image']['width'])

        label_map = np.zeros((height, width), dtype=np.uint16)
        for index, ann in enumerate(annotations):
            mask = maskUtils.decode(ann['segmentation'])
            label_map[mask != 0] = index + 1

        if not return_pil:
            return label_map

        # Split the 16-bit label into low/high bytes so it fits an 8-bit image.
        # Both values are always in 0..255, so storing them straight into a
        # uint8 buffer is lossless.
        packed = np.zeros((height, width, 3), dtype=np.uint8)
        packed[:, :, 0] = label_map % 256
        packed[:, :, 1] = label_map // 256
        return Image.fromarray(packed)

    def __getitem__(self, idx: int) -> dict:
        """Load, augment and return the sample at position ``idx``.

        Returns:
            A dict with ``jpg`` (the transformed target image), ``txt`` (the
            prompt) and ``hint`` (the transformed packed label image).
        """
        item = self.data[idx]
        source_filename = item['source']
        target_filename = item['target']
        prompt = item['prompt']

        source = self.load_rle_annotations_from_json(
            os.path.join(self.data_path, source_filename)
        )
        target = Image.open(os.path.join(self.data_path, target_filename))

        target, source = self.transform(target, source)

        # NOTE(review): presumably converts the image tensor from
        # channel-first to channel-last layout -- confirm against the consumer.
        target = target.permute(1, 2, 0)

        return dict(jpg=target, txt=prompt, hint=source)