Spaces:
Runtime error
Runtime error
import os | |
import glob | |
import numpy as np | |
import math | |
import torch | |
import torch.utils.data as data | |
import sys | |
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) | |
from configs.anipose_data_info import COMPLETE_DATA_INFO | |
from stacked_hourglass.utils.imutils import load_image, im_to_torch | |
from stacked_hourglass.utils.transforms import crop, color_normalize | |
from stacked_hourglass.utils.pilutil import imresize | |
from stacked_hourglass.utils.imutils import im_to_torch | |
from configs.data_info import COMPLETE_DATA_INFO_24 | |
class ImgCrops(data.Dataset):
    """Dataset over a list of already-loaded images, with optional bounding boxes.

    Each item is turned into a square ``inp_res`` x ``inp_res`` network input:
    if bounding boxes were supplied the image is cropped around its box,
    otherwise it is zero-padded to a square and resized. The result is
    normalized with ``DATA_INFO.rgb_mean`` / ``DATA_INFO.rgb_stddev``.

    No annotations are available for these images, so ``__getitem__`` also
    returns a dictionary of placeholder target fields (see there).
    """

    DATA_INFO = COMPLETE_DATA_INFO_24
    ACC_JOINTS = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16]

    def __init__(self, image_list, bbox_list=None, inp_res=256, dataset_mode='keyp_only'):
        """Store the images and build the list of sample names.

        Args:
            image_list: sequence holding the images themselves (not their
                paths). Each entry is passed to ``im_to_torch`` unchanged;
                presumably an HxWxC array -- TODO confirm against
                ``im_to_torch``.
            bbox_list: optional sequence with one bounding box per image,
                indexed as ``bbox[0][0], bbox[0][1]`` (first corner x, y) and
                ``bbox[1][0], bbox[1][1]`` (second corner x, y). ``None``
                selects the pad-and-resize path instead of cropping.
            inp_res: side length of the square output image.
            dataset_mode: accepted for signature compatibility only; this
                class never reads it.
        """
        # the list contains the images directly, not only their paths
        self.image_list = image_list
        self.bbox_list = bbox_list
        self.inp_res = inp_res
        # Samples have no file names, so they are named after their index.
        self.test_name_list = [str(ind) for ind in range(len(self.image_list))]
        print('len(dataset): ' + str(len(self)))

    def __getitem__(self, index):
        """Return ``(inp, target_dict)`` for the image at ``index``.

        ``inp`` is the preprocessed image after ``color_normalize``.
        ``target_dict`` holds ``index`` plus placeholder values: ``-2`` for
        ``center``, ``scale``, ``breed_index`` and ``sim_breed_index``, ``1``
        for ``ind_dataset``, and all-zero arrays for ``pts`` / ``tpts``
        (``n_keyp`` x 3), ``target_weight`` (``n_keyp`` x 1) and ``silh``
        (``inp_res`` x ``inp_res``).
        """
        img = im_to_torch(self.image_list[index])

        if self.bbox_list is not None:
            img_prep = self._crop_to_bbox(img, self.bbox_list[index])
        else:
            img_prep = self._pad_and_resize(img)

        inp = color_normalize(img_prep, self.DATA_INFO.rgb_mean, self.DATA_INFO.rgb_stddev)

        # add the following fields to make it compatible with stanext, most of them are fake
        n_keyp = self.DATA_INFO.n_keyp
        target_dict = {
            'index': index,
            'center': -2,
            'scale': -2,
            'breed_index': -2,
            'sim_breed_index': -2,
            'ind_dataset': 1,
        }
        target_dict['pts'] = np.zeros((n_keyp, 3))
        target_dict['tpts'] = np.zeros((n_keyp, 3))
        target_dict['target_weight'] = np.zeros((n_keyp, 1))
        target_dict['silh'] = np.zeros((self.inp_res, self.inp_res))
        return inp, target_dict

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_list)

    def _crop_to_bbox(self, img, bbox):
        """Crop ``img`` to a square ``inp_res`` window centred on ``bbox``."""
        x0, y0 = bbox[0][0], bbox[0][1]
        width = bbox[1][0] - bbox[0][0]
        height = bbox[1][1] - bbox[0][1]
        center = torch.Tensor([x0 + 0.5 * width, y0 + 0.5 * height])
        # maximum side of the bbox will be 200
        # NOTE(review): the exact meaning of the scale value is defined by
        # ``crop`` -- confirm there before changing these constants.
        scale = max(width, height) / 200. * 256. / 200.
        return crop(img, center, scale, [self.inp_res, self.inp_res], rot=0)

    def _pad_and_resize(self, img):
        """Zero-pad ``img`` to a centred square, then resize it to ``inp_res``."""
        n_channels, height, width = img.shape[0], img.shape[1], img.shape[2]
        side = max(height, width)
        img_padded = torch.zeros((n_channels, side, side))
        if side == width:
            # Width is the longer (or equal) side: pad along axis 1.
            start = (side - height) // 2
            img_padded[:, start:start + height, :] = img
        else:
            # Height is the longer side: pad along axis 2.
            start = (side - width) // 2
            img_padded[:, :, start:start + width] = img
        return im_to_torch(imresize(img_padded, [self.inp_res, self.inp_res], interp='bilinear'))