|
import os.path as osp |
|
import random |
|
|
|
import cv2 |
|
import decord |
|
import numpy as np |
|
import skvideo.io |
|
import torch |
|
import torchvision |
|
from decord import VideoReader, cpu, gpu |
|
from tqdm import tqdm |
|
|
|
random.seed(42) |
|
|
|
decord.bridge.set_bridge("torch") |
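# With the "torch" bridge, decord's VideoReader returns frames directly as
# torch uint8 tensors of shape (H, W, C), so no numpy conversion is needed.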
|
|
|
|
|
def get_spatial_fragments( |
|
video, |
|
fragments_h=7, |
|
fragments_w=7, |
|
fsize_h=32, |
|
fsize_w=32, |
|
aligned=32, |
|
nfrags=1, |
|
random=False, |
|
fallback_type="upsample", |
|
): |
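    """Sample a grid of spatial fragments from a video.

    ``video`` is a (C, T, H, W) tensor. For each cell of a
    ``fragments_h x fragments_w`` grid, an ``fsize_h x fsize_w`` patch is cut
    from the corresponding region of the frame (at a random offset inside the
    cell, re-drawn every ``aligned`` frames) and pasted into the output,
    giving a tensor of shape
    (C, T, fragments_h * fsize_h, fragments_w * fsize_w).
    With the defaults this turns a clip of any resolution into a 224x224
    mosaic. Videos smaller than the target size are bilinearly upsampled
    first when ``fallback_type`` is "upsample".
    """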
|
size_h = fragments_h * fsize_h |
|
size_w = fragments_w * fsize_w |
|
|
|
|
|
if video.shape[1] == 1: |
|
aligned = 1 |
|
|
|
dur_t, res_h, res_w = video.shape[-3:] |
|
ratio = min(res_h / size_h, res_w / size_w) |
|
if fallback_type == "upsample" and ratio < 1: |
|
|
|
ovideo = video |
|
        video = torch.nn.functional.interpolate(
            video / 255.0, scale_factor=1 / ratio, mode="bilinear"
        )
        video = (video * 255.0).type_as(ovideo)
        # Recompute the spatial resolution after upsampling so that the grid
        # offsets below are based on the actual frame size.
        res_h, res_w = video.shape[-2:]
|
|
|
    assert dur_t % aligned == 0, "The clip length must be divisible by the alignment size"
|
size = size_h, size_w |
|
|
|
|
|
hgrids = torch.LongTensor( |
|
[min(res_h // fragments_h * i, res_h - fsize_h) for i in range(fragments_h)] |
|
) |
|
wgrids = torch.LongTensor( |
|
[min(res_w // fragments_w * i, res_w - fsize_w) for i in range(fragments_w)] |
|
) |
|
hlength, wlength = res_h // fragments_h, res_w // fragments_w |
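    # hlength / wlength give the size of one grid cell; in the default branch
    # below, each patch is drawn from a random position inside its own cell.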
|
|
|
if random: |
|
print("This part is deprecated. Please remind that.") |
|
if res_h > fsize_h: |
|
rnd_h = torch.randint( |
|
res_h - fsize_h, (len(hgrids), len(wgrids), dur_t // aligned) |
|
) |
|
else: |
|
rnd_h = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() |
|
if res_w > fsize_w: |
|
rnd_w = torch.randint( |
|
res_w - fsize_w, (len(hgrids), len(wgrids), dur_t // aligned) |
|
) |
|
else: |
|
rnd_w = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() |
|
else: |
|
if hlength > fsize_h: |
|
rnd_h = torch.randint( |
|
hlength - fsize_h, (len(hgrids), len(wgrids), dur_t // aligned) |
|
) |
|
else: |
|
rnd_h = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() |
|
if wlength > fsize_w: |
|
rnd_w = torch.randint( |
|
wlength - fsize_w, (len(hgrids), len(wgrids), dur_t // aligned) |
|
) |
|
else: |
|
rnd_w = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() |
|
|
|
target_video = torch.zeros(video.shape[:-2] + size).to(video.device) |
|
|
|
|
|
for i, hs in enumerate(hgrids): |
|
for j, ws in enumerate(wgrids): |
|
for t in range(dur_t // aligned): |
|
t_s, t_e = t * aligned, (t + 1) * aligned |
|
h_s, h_e = i * fsize_h, (i + 1) * fsize_h |
|
w_s, w_e = j * fsize_w, (j + 1) * fsize_w |
|
if random: |
|
h_so, h_eo = rnd_h[i][j][t], rnd_h[i][j][t] + fsize_h |
|
w_so, w_eo = rnd_w[i][j][t], rnd_w[i][j][t] + fsize_w |
|
else: |
|
h_so, h_eo = hs + rnd_h[i][j][t], hs + rnd_h[i][j][t] + fsize_h |
|
w_so, w_eo = ws + rnd_w[i][j][t], ws + rnd_w[i][j][t] + fsize_w |
|
target_video[:, t_s:t_e, h_s:h_e, w_s:w_e] = video[ |
|
:, t_s:t_e, h_so:h_eo, w_so:w_eo |
|
] |
|
|
|
|
|
|
|
return target_video |
|
|
|
|
|
class FragmentSampleFrames: |
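    """Temporal fragment sampler: picks ``fragments_t`` fragments per clip,
    each consisting of ``fsize_t`` frames spaced ``frame_interval`` apart,
    drawn from evenly spaced segments of the video.

    Example (illustrative): ``FragmentSampleFrames(4, 8)(240)`` returns 32
    frame indices, 4 consecutive frames from each of 8 equal temporal
    segments of a 240-frame video.
    """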
|
def __init__(self, fsize_t, fragments_t, frame_interval=1, num_clips=1): |
|
|
|
self.fragments_t = fragments_t |
|
self.fsize_t = fsize_t |
|
self.size_t = fragments_t * fsize_t |
|
self.frame_interval = frame_interval |
|
self.num_clips = num_clips |
|
|
|
def get_frame_indices(self, num_frames): |
|
|
|
tgrids = np.array( |
|
[num_frames // self.fragments_t * i for i in range(self.fragments_t)], |
|
dtype=np.int32, |
|
) |
|
tlength = num_frames // self.fragments_t |
|
|
|
if tlength > self.fsize_t * self.frame_interval: |
|
rnd_t = np.random.randint( |
|
0, tlength - self.fsize_t * self.frame_interval, size=len(tgrids) |
|
) |
|
else: |
|
rnd_t = np.zeros(len(tgrids), dtype=np.int32) |
|
|
|
ranges_t = ( |
|
np.arange(self.fsize_t)[None, :] * self.frame_interval |
|
+ rnd_t[:, None] |
|
+ tgrids[:, None] |
|
) |
|
return np.concatenate(ranges_t) |
|
|
|
def __call__(self, total_frames, train=False, start_index=0): |
|
frame_inds = [] |
|
for i in range(self.num_clips): |
|
frame_inds += [self.get_frame_indices(total_frames)] |
|
frame_inds = np.concatenate(frame_inds) |
|
frame_inds = np.mod(frame_inds + start_index, total_frames) |
|
return frame_inds |
|
|
|
|
|
class SampleFrames: |
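    """Classic clip sampler: ``num_clips`` clips of ``clip_len`` frames taken
    every ``frame_interval``-th frame, spread evenly over the video in test
    mode and randomly jittered in train mode.

    Example (illustrative): ``SampleFrames(32, 2, 4)(300)`` returns 128 frame
    indices, i.e. 4 clips of 32 frames each, sampled every 2nd frame.
    """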
|
def __init__(self, clip_len, frame_interval=1, num_clips=1): |
|
|
|
self.clip_len = clip_len |
|
self.frame_interval = frame_interval |
|
self.num_clips = num_clips |
|
|
|
def _get_train_clips(self, num_frames): |
|
"""Get clip offsets in train mode. |
|
|
|
        It calculates the average interval between selected clips and randomly
        shifts each clip by an offset in [0, avg_interval). If the total number
        of frames is smaller than the number of clips or the original clip
        length, all-zero offsets are returned.
|
|
|
Args: |
|
            num_frames (int): Total number of frames in the video.

        Returns:
            np.ndarray: Sampled clip offsets in train mode.
|
""" |
|
ori_clip_len = self.clip_len * self.frame_interval |
|
avg_interval = (num_frames - ori_clip_len + 1) // self.num_clips |
|
|
|
if avg_interval > 0: |
|
base_offsets = np.arange(self.num_clips) * avg_interval |
|
clip_offsets = base_offsets + np.random.randint( |
|
avg_interval, size=self.num_clips |
|
) |
|
elif num_frames > max(self.num_clips, ori_clip_len): |
|
clip_offsets = np.sort( |
|
np.random.randint(num_frames - ori_clip_len + 1, size=self.num_clips) |
|
) |
|
elif avg_interval == 0: |
|
ratio = (num_frames - ori_clip_len + 1.0) / self.num_clips |
|
clip_offsets = np.around(np.arange(self.num_clips) * ratio) |
|
else: |
|
            clip_offsets = np.zeros((self.num_clips,), dtype=np.int32)
|
return clip_offsets |
|
|
|
def _get_test_clips(self, num_frames, start_index=0): |
|
"""Get clip offsets in test mode. |
|
|
|
        Calculate the average interval between selected clips and shift each
        clip by a fixed offset of avg_interval / 2.
|
|
|
Args: |
|
            num_frames (int): Total number of frames in the video.

        Returns:
            np.ndarray: Sampled clip offsets in test mode.
|
""" |
|
ori_clip_len = self.clip_len * self.frame_interval |
|
avg_interval = (num_frames - ori_clip_len + 1) / float(self.num_clips) |
|
if num_frames > ori_clip_len - 1: |
|
base_offsets = np.arange(self.num_clips) * avg_interval |
|
clip_offsets = (base_offsets + avg_interval / 2.0).astype(np.int32) |
|
else: |
|
clip_offsets = np.zeros((self.num_clips,), dtype=np.int32) |
|
return clip_offsets |
|
|
|
def __call__(self, total_frames, train=False, start_index=0): |
|
"""Perform the SampleFrames loading. |
|
|
|
Args: |
|
results (dict): The resulting dict to be modified and passed |
|
to the next transform in pipeline. |
|
""" |
|
if train: |
|
clip_offsets = self._get_train_clips(total_frames) |
|
else: |
|
clip_offsets = self._get_test_clips(total_frames) |
|
frame_inds = ( |
|
clip_offsets[:, None] |
|
+ np.arange(self.clip_len)[None, :] * self.frame_interval |
|
) |
|
frame_inds = np.concatenate(frame_inds) |
|
|
|
frame_inds = frame_inds.reshape((-1, self.clip_len)) |
|
frame_inds = np.mod(frame_inds, total_frames) |
|
frame_inds = np.concatenate(frame_inds) + start_index |
|
return frame_inds.astype(np.int32) |
|
|
|
|
|
class FastVQAPlusPlusDataset(torch.utils.data.Dataset): |
|
def __init__( |
|
self, |
|
ann_file, |
|
data_prefix, |
|
frame_interval=2, |
|
aligned=32, |
|
fragments=(8, 8, 8), |
|
fsize=(4, 32, 32), |
|
num_clips=1, |
|
nfrags=1, |
|
cache_in_memory=False, |
|
phase="test", |
|
fallback_type="oversample", |
|
): |
|
""" |
|
Fragments. |
|
args: |
|
fragments: G_f as in the paper. |
|
fsize: S_f as in the paper. |
|
nfrags: number of samples (spatially) as in the paper. |
|
num_clips: number of samples (temporally) as in the paper. |
|
""" |
|
self.ann_file = ann_file |
|
self.data_prefix = data_prefix |
|
self.frame_interval = frame_interval |
|
self.num_clips = num_clips |
|
self.fragments = fragments |
|
self.fsize = fsize |
|
self.nfrags = nfrags |
|
self.clip_len = fragments[0] * fsize[0] |
|
self.aligned = aligned |
|
self.fallback_type = fallback_type |
|
self.sampler = FragmentSampleFrames( |
|
fsize[0], fragments[0], frame_interval, num_clips |
|
) |
|
self.video_infos = [] |
|
self.phase = phase |
|
self.mean = torch.FloatTensor([123.675, 116.28, 103.53]) |
|
self.std = torch.FloatTensor([58.395, 57.12, 57.375]) |
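        # `mean`/`std` above are the standard ImageNet statistics on a 0-255
        # scale. The annotation file is expected to contain one
        # `path,<unused>,<unused>,score` line per video, with paths taken
        # relative to `data_prefix`.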
|
if isinstance(self.ann_file, list): |
|
self.video_infos = self.ann_file |
|
else: |
|
with open(self.ann_file, "r") as fin: |
|
for line in fin: |
|
line_split = line.strip().split(",") |
|
filename, _, _, label = line_split |
|
label = float(label) |
|
filename = osp.join(self.data_prefix, filename) |
|
self.video_infos.append(dict(filename=filename, label=label)) |
|
if cache_in_memory: |
|
self.cache = {} |
|
for i in tqdm(range(len(self)), desc="Caching fragments"): |
|
self.cache[i] = self.__getitem__(i, tocache=True) |
|
else: |
|
self.cache = None |
|
|
|
def __getitem__( |
|
self, index, tocache=False, need_original_frames=False, |
|
): |
|
if tocache or self.cache is None: |
|
fx, fy = self.fragments[1:] |
|
fsx, fsy = self.fsize[1:] |
|
video_info = self.video_infos[index] |
|
filename = video_info["filename"] |
|
label = video_info["label"] |
|
if filename.endswith(".yuv"): |
|
video = skvideo.io.vread( |
|
filename, 1080, 1920, inputdict={"-pix_fmt": "yuvj420p"} |
|
) |
|
frame_inds = self.sampler(video.shape[0], self.phase == "train") |
|
imgs = [torch.from_numpy(video[idx]) for idx in frame_inds] |
|
else: |
|
vreader = VideoReader(filename) |
|
frame_inds = self.sampler(len(vreader), self.phase == "train") |
|
frame_dict = {idx: vreader[idx] for idx in np.unique(frame_inds)} |
|
imgs = [frame_dict[idx] for idx in frame_inds] |
|
img_shape = imgs[0].shape |
|
video = torch.stack(imgs, 0) |
|
video = video.permute(3, 0, 1, 2) |
|
if self.nfrags == 1: |
|
vfrag = get_spatial_fragments( |
|
video, |
|
fx, |
|
fy, |
|
fsx, |
|
fsy, |
|
aligned=self.aligned, |
|
fallback_type=self.fallback_type, |
|
) |
|
else: |
|
vfrag = get_spatial_fragments( |
|
video, |
|
fx, |
|
fy, |
|
fsx, |
|
fsy, |
|
aligned=self.aligned, |
|
fallback_type=self.fallback_type, |
|
) |
|
for i in range(1, self.nfrags): |
|
vfrag = torch.cat( |
|
( |
|
vfrag, |
|
                            get_spatial_fragments(
                                video,
                                fx,
                                fy,
                                fsx,
                                fsy,
                                aligned=self.aligned,
                                fallback_type=self.fallback_type,
                            ),
|
), |
|
1, |
|
) |
|
if tocache: |
|
return (vfrag, frame_inds, label, img_shape) |
|
else: |
|
vfrag, frame_inds, label, img_shape = self.cache[index] |
|
vfrag = ((vfrag.permute(1, 2, 3, 0) - self.mean) / self.std).permute(3, 0, 1, 2) |
|
data = { |
|
"video": vfrag.reshape( |
|
(-1, self.nfrags * self.num_clips, self.clip_len) + vfrag.shape[2:] |
|
).transpose( |
|
0, 1 |
|
), |
|
"frame_inds": frame_inds, |
|
"gt_label": label, |
|
"original_shape": img_shape, |
|
} |
|
if need_original_frames: |
|
data["original_video"] = video.reshape( |
|
(-1, self.nfrags * self.num_clips, self.clip_len) + video.shape[2:] |
|
).transpose(0, 1) |
|
return data |
|
|
|
def __len__(self): |
|
return len(self.video_infos) |
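# Illustrative usage (a sketch; the annotation path and data prefix below are
# placeholders, not part of the project):
#
#     dataset = FastVQAPlusPlusDataset("labels.txt", "videos/")
#     sample = dataset[0]
#     sample["video"].shape  # (nfrags * num_clips, 3, clip_len, 256, 256),
#                            # i.e. (1, 3, 32, 256, 256) with the defaults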
|
|
|
|
|
class FragmentVideoDataset(torch.utils.data.Dataset): |
|
def __init__( |
|
self, |
|
ann_file, |
|
data_prefix, |
|
clip_len=32, |
|
frame_interval=2, |
|
num_clips=4, |
|
aligned=32, |
|
fragments=7, |
|
fsize=32, |
|
nfrags=1, |
|
cache_in_memory=False, |
|
phase="test", |
|
): |
|
""" |
|
Fragments. |
|
args: |
|
fragments: G_f as in the paper. |
|
fsize: S_f as in the paper. |
|
nfrags: number of samples as in the paper. |
|
""" |
|
self.ann_file = ann_file |
|
self.data_prefix = data_prefix |
|
self.clip_len = clip_len |
|
self.frame_interval = frame_interval |
|
self.num_clips = num_clips |
|
self.fragments = fragments |
|
self.fsize = fsize |
|
self.nfrags = nfrags |
|
self.aligned = aligned |
|
self.sampler = SampleFrames(clip_len, frame_interval, num_clips) |
|
self.video_infos = [] |
|
self.phase = phase |
|
self.mean = torch.FloatTensor([123.675, 116.28, 103.53]) |
|
self.std = torch.FloatTensor([58.395, 57.12, 57.375]) |
|
if isinstance(self.ann_file, list): |
|
self.video_infos = self.ann_file |
|
else: |
|
with open(self.ann_file, "r") as fin: |
|
for line in fin: |
|
line_split = line.strip().split(",") |
|
filename, _, _, label = line_split |
|
label = float(label) |
|
filename = osp.join(self.data_prefix, filename) |
|
self.video_infos.append(dict(filename=filename, label=label)) |
|
if cache_in_memory: |
|
self.cache = {} |
|
for i in tqdm(range(len(self)), desc="Caching fragments"): |
|
self.cache[i] = self.__getitem__(i, tocache=True) |
|
else: |
|
self.cache = None |
|
|
|
def __getitem__( |
|
self, index, fragments=-1, fsize=-1, tocache=False, need_original_frames=False, |
|
): |
|
if tocache or self.cache is None: |
|
if fragments == -1: |
|
fragments = self.fragments |
|
if fsize == -1: |
|
fsize = self.fsize |
|
video_info = self.video_infos[index] |
|
filename = video_info["filename"] |
|
label = video_info["label"] |
|
if filename.endswith(".yuv"): |
|
video = skvideo.io.vread( |
|
filename, 1080, 1920, inputdict={"-pix_fmt": "yuvj420p"} |
|
) |
|
frame_inds = self.sampler(video.shape[0], self.phase == "train") |
|
imgs = [torch.from_numpy(video[idx]) for idx in frame_inds] |
|
else: |
|
vreader = VideoReader(filename) |
|
frame_inds = self.sampler(len(vreader), self.phase == "train") |
|
frame_dict = {idx: vreader[idx] for idx in np.unique(frame_inds)} |
|
imgs = [frame_dict[idx] for idx in frame_inds] |
|
img_shape = imgs[0].shape |
|
video = torch.stack(imgs, 0) |
|
video = video.permute(3, 0, 1, 2) |
|
if self.nfrags == 1: |
|
vfrag = get_spatial_fragments( |
|
video, fragments, fragments, fsize, fsize, aligned=self.aligned |
|
) |
|
else: |
|
vfrag = get_spatial_fragments( |
|
video, fragments, fragments, fsize, fsize, aligned=self.aligned |
|
) |
|
for i in range(1, self.nfrags): |
|
vfrag = torch.cat( |
|
( |
|
vfrag, |
|
get_spatial_fragments( |
|
video, |
|
fragments, |
|
fragments, |
|
fsize, |
|
fsize, |
|
aligned=self.aligned, |
|
), |
|
), |
|
1, |
|
) |
|
if tocache: |
|
return (vfrag, frame_inds, label, img_shape) |
|
else: |
|
vfrag, frame_inds, label, img_shape = self.cache[index] |
|
vfrag = ((vfrag.permute(1, 2, 3, 0) - self.mean) / self.std).permute(3, 0, 1, 2) |
|
data = { |
|
"video": vfrag.reshape( |
|
(-1, self.nfrags * self.num_clips, self.clip_len) + vfrag.shape[2:] |
|
).transpose( |
|
0, 1 |
|
), |
|
"frame_inds": frame_inds, |
|
"gt_label": label, |
|
"original_shape": img_shape, |
|
} |
|
if need_original_frames: |
|
data["original_video"] = video.reshape( |
|
(-1, self.nfrags * self.num_clips, self.clip_len) + video.shape[2:] |
|
).transpose(0, 1) |
|
return data |
|
|
|
def __len__(self): |
|
return len(self.video_infos) |
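# Illustrative usage (a sketch; paths are placeholders):
#
#     dataset = FragmentVideoDataset("labels.txt", "videos/", phase="train")
#     sample = dataset[0]
#     sample["video"].shape  # (num_clips, 3, clip_len, 224, 224),
#                            # i.e. (4, 3, 32, 224, 224) with the defaults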
|
|
|
|
|
class ResizedVideoDataset(torch.utils.data.Dataset): |
|
def __init__( |
|
self, |
|
ann_file, |
|
data_prefix, |
|
clip_len=32, |
|
frame_interval=2, |
|
num_clips=4, |
|
aligned=32, |
|
size=224, |
|
cache_in_memory=False, |
|
phase="test", |
|
): |
|
""" |
|
Using resizing. |
|
""" |
|
self.ann_file = ann_file |
|
self.data_prefix = data_prefix |
|
self.clip_len = clip_len |
|
self.frame_interval = frame_interval |
|
self.num_clips = num_clips |
|
self.size = size |
|
self.aligned = aligned |
|
self.sampler = SampleFrames(clip_len, frame_interval, num_clips) |
|
self.video_infos = [] |
|
self.phase = phase |
|
self.mean = torch.FloatTensor([123.675, 116.28, 103.53]) |
|
self.std = torch.FloatTensor([58.395, 57.12, 57.375]) |
|
if isinstance(self.ann_file, list): |
|
self.video_infos = self.ann_file |
|
else: |
|
with open(self.ann_file, "r") as fin: |
|
for line in fin: |
|
line_split = line.strip().split(",") |
|
filename, _, _, label = line_split |
|
label = float(label) |
|
filename = osp.join(self.data_prefix, filename) |
|
self.video_infos.append(dict(filename=filename, label=label)) |
|
if cache_in_memory: |
|
self.cache = {} |
|
for i in tqdm(range(len(self)), desc="Caching resized videos"): |
|
self.cache[i] = self.__getitem__(i, tocache=True) |
|
else: |
|
self.cache = None |
|
|
|
def __getitem__(self, index, tocache=False, need_original_frames=False): |
|
if tocache or self.cache is None: |
|
video_info = self.video_infos[index] |
|
filename = video_info["filename"] |
|
label = video_info["label"] |
|
vreader = VideoReader(filename) |
|
frame_inds = self.sampler(len(vreader), self.phase == "train") |
|
frame_dict = {idx: vreader[idx] for idx in np.unique(frame_inds)} |
|
imgs = [frame_dict[idx] for idx in frame_inds] |
|
img_shape = imgs[0].shape |
|
video = torch.stack(imgs, 0) |
|
video = video.permute(3, 0, 1, 2) |
|
            video = torch.nn.functional.interpolate(video, size=(self.size, self.size))
            # The resized video plays the role that the fragmented video
            # (`vfrag`) plays in the other datasets.
            vfrag = video
            if tocache:
                return (vfrag, frame_inds, label, img_shape)
|
else: |
|
vfrag, frame_inds, label, img_shape = self.cache[index] |
|
vfrag = ((vfrag.permute(1, 2, 3, 0) - self.mean) / self.std).permute(3, 0, 1, 2) |
|
data = { |
|
"video": vfrag.reshape( |
|
(-1, self.num_clips, self.clip_len) + vfrag.shape[2:] |
|
).transpose( |
|
0, 1 |
|
), |
|
"frame_inds": frame_inds, |
|
"gt_label": label, |
|
"original_shape": img_shape, |
|
} |
|
if need_original_frames: |
|
data["original_video"] = video.reshape( |
|
                (-1, self.num_clips, self.clip_len) + video.shape[2:]
|
).transpose(0, 1) |
|
return data |
|
|
|
def __len__(self): |
|
return len(self.video_infos) |
|
|
|
|
|
class CroppedVideoDataset(FragmentVideoDataset): |
|
def __init__( |
|
self, |
|
ann_file, |
|
data_prefix, |
|
clip_len=32, |
|
frame_interval=2, |
|
num_clips=4, |
|
aligned=32, |
|
size=224, |
|
ncrops=1, |
|
cache_in_memory=False, |
|
phase="test", |
|
): |
|
|
|
""" |
|
        Regard cropping as a special case of fragments with a 1x1 grid.
|
""" |
|
super().__init__( |
|
ann_file, |
|
data_prefix, |
|
clip_len=clip_len, |
|
frame_interval=frame_interval, |
|
num_clips=num_clips, |
|
aligned=aligned, |
|
fragments=1, |
|
            fsize=size,
|
nfrags=ncrops, |
|
cache_in_memory=cache_in_memory, |
|
phase=phase, |
|
) |
|
|
|
|
|
class FragmentImageDataset(torch.utils.data.Dataset): |
|
def __init__( |
|
self, |
|
ann_file, |
|
data_prefix, |
|
fragments=7, |
|
fsize=32, |
|
nfrags=1, |
|
cache_in_memory=False, |
|
phase="test", |
|
): |
|
self.ann_file = ann_file |
|
self.data_prefix = data_prefix |
|
self.fragments = fragments |
|
self.fsize = fsize |
|
self.nfrags = nfrags |
|
self.image_infos = [] |
|
self.phase = phase |
|
self.mean = torch.FloatTensor([123.675, 116.28, 103.53]) |
|
self.std = torch.FloatTensor([58.395, 57.12, 57.375]) |
|
if isinstance(self.ann_file, list): |
|
self.image_infos = self.ann_file |
|
else: |
|
with open(self.ann_file, "r") as fin: |
|
for line in fin: |
|
line_split = line.strip().split(",") |
|
filename, _, _, label = line_split |
|
label = float(label) |
|
filename = osp.join(self.data_prefix, filename) |
|
self.image_infos.append(dict(filename=filename, label=label)) |
|
if cache_in_memory: |
|
self.cache = {} |
|
for i in tqdm(range(len(self)), desc="Caching fragments"): |
|
self.cache[i] = self.__getitem__(i, tocache=True) |
|
else: |
|
self.cache = None |
|
|
|
def __getitem__( |
|
self, index, fragments=-1, fsize=-1, tocache=False, need_original_frames=False |
|
): |
|
if tocache or self.cache is None: |
|
if fragments == -1: |
|
fragments = self.fragments |
|
if fsize == -1: |
|
fsize = self.fsize |
|
image_info = self.image_infos[index] |
|
filename = image_info["filename"] |
|
label = image_info["label"] |
|
            try:
                img = torchvision.io.read_image(filename)
            except Exception:
                # Fall back to OpenCV when torchvision cannot decode the file;
                # convert BGR to RGB and HWC to CHW to match read_image.
                img = cv2.imread(filename)
                img = torch.from_numpy(img[:, :, [2, 1, 0]]).permute(2, 0, 1)
|
img_shape = img.shape[1:] |
|
image = img.unsqueeze(1) |
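            # Add a singleton temporal dimension so the image can be treated
            # as a one-frame video by get_spatial_fragments.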
|
if self.nfrags == 1: |
|
ifrag = get_spatial_fragments(image, fragments, fragments, fsize, fsize) |
|
else: |
|
ifrag = get_spatial_fragments(image, fragments, fragments, fsize, fsize) |
|
for i in range(1, self.nfrags): |
|
ifrag = torch.cat( |
|
( |
|
ifrag, |
|
get_spatial_fragments( |
|
image, fragments, fragments, fsize, fsize |
|
), |
|
), |
|
1, |
|
) |
|
if tocache: |
|
return (ifrag, label, img_shape) |
|
else: |
|
ifrag, label, img_shape = self.cache[index] |
|
if self.nfrags == 1: |
|
ifrag = ( |
|
((ifrag.permute(1, 2, 3, 0) - self.mean) / self.std) |
|
.squeeze(0) |
|
.permute(2, 0, 1) |
|
) |
|
else: |
|
|
|
ifrag = ( |
|
((ifrag.permute(1, 2, 3, 0) - self.mean) / self.std) |
|
.squeeze(0) |
|
.permute(0, 3, 1, 2) |
|
) |
|
data = { |
|
"image": ifrag, |
|
"gt_label": label, |
|
"original_shape": img_shape, |
|
"name": filename, |
|
} |
|
if need_original_frames: |
|
data["original_image"] = image.squeeze(1) |
|
return data |
|
|
|
def __len__(self): |
|
return len(self.image_infos) |
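# Illustrative usage (a sketch; paths are placeholders):
#
#     dataset = FragmentImageDataset("labels.txt", "images/")
#     dataset[0]["image"].shape  # (3, 224, 224) with the default
#                                # 7x7 grid of 32px fragments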
|
|
|
|
|
class ResizedImageDataset(torch.utils.data.Dataset): |
|
def __init__( |
|
self, ann_file, data_prefix, size=224, cache_in_memory=False, phase="test", |
|
): |
|
self.ann_file = ann_file |
|
self.data_prefix = data_prefix |
|
self.size = size |
|
self.image_infos = [] |
|
self.phase = phase |
|
self.mean = torch.FloatTensor([123.675, 116.28, 103.53]) |
|
self.std = torch.FloatTensor([58.395, 57.12, 57.375]) |
|
if isinstance(self.ann_file, list): |
|
self.image_infos = self.ann_file |
|
else: |
|
with open(self.ann_file, "r") as fin: |
|
for line in fin: |
|
line_split = line.strip().split(",") |
|
filename, _, _, label = line_split |
|
label = float(label) |
|
filename = osp.join(self.data_prefix, filename) |
|
self.image_infos.append(dict(filename=filename, label=label)) |
|
if cache_in_memory: |
|
self.cache = {} |
|
for i in tqdm(range(len(self)), desc="Caching fragments"): |
|
self.cache[i] = self.__getitem__(i, tocache=True) |
|
else: |
|
self.cache = None |
|
|
|
def __getitem__( |
|
self, index, fragments=-1, fsize=-1, tocache=False, need_original_frames=False |
|
): |
|
        if tocache or self.cache is None:
            image_info = self.image_infos[index]
            filename = image_info["filename"]
            label = image_info["label"]
            img = torchvision.io.read_image(filename)
            img_shape = img.shape[1:]
            image = img.unsqueeze(1)
            # Mirror ResizedVideoDataset: resize the (single-frame) image to
            # `size` x `size` instead of sampling fragments; `ifrag` keeps the
            # name used by the fragment-based datasets.
            ifrag = torch.nn.functional.interpolate(
                image.float(), size=(self.size, self.size), mode="bilinear"
            ).type_as(image)
|
if tocache: |
|
return (ifrag, label, img_shape) |
|
else: |
|
ifrag, label, img_shape = self.cache[index] |
|
ifrag = ( |
|
((ifrag.permute(1, 2, 3, 0) - self.mean) / self.std) |
|
.squeeze(0) |
|
.permute(2, 0, 1) |
|
) |
|
data = { |
|
"image": ifrag, |
|
"gt_label": label, |
|
"original_shape": img_shape, |
|
} |
|
if need_original_frames: |
|
data["original_image"] = image.squeeze(1) |
|
return data |
|
|
|
def __len__(self): |
|
return len(self.image_infos) |
|
|
|
|
|
class CroppedImageDataset(FragmentImageDataset): |
|
def __init__( |
|
self, |
|
ann_file, |
|
data_prefix, |
|
size=224, |
|
ncrops=1, |
|
cache_in_memory=False, |
|
phase="test", |
|
): |
|
|
|
""" |
|
        Regard cropping as a special case of fragments with a 1x1 grid.
|
""" |
|
super().__init__( |
|
ann_file, |
|
data_prefix, |
|
fragments=1, |
|
            fsize=size,
|
nfrags=ncrops, |
|
cache_in_memory=cache_in_memory, |
|
phase=phase, |
|
) |
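

# A minimal smoke-test sketch (assumes a CSV annotation file and default
# DataLoader collation; names below are placeholders):
#
#     if __name__ == "__main__":
#         dataset = CroppedImageDataset("labels.txt", "images/", ncrops=1)
#         loader = torch.utils.data.DataLoader(dataset, batch_size=2)
#         batch = next(iter(loader))
#         print(batch["image"].shape, batch["gt_label"])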
|
|