# cover/datasets/basic_datasets.py (COVER repository)
import os.path as osp
import random
import cv2
import decord
import numpy as np
import skvideo.io
import torch
import torchvision
from decord import VideoReader, cpu, gpu
from tqdm import tqdm
random.seed(42)
decord.bridge.set_bridge("torch")
def get_spatial_fragments(
video,
fragments_h=7,
fragments_w=7,
fsize_h=32,
fsize_w=32,
aligned=32,
nfrags=1,
random=False,
fallback_type="upsample",
):
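    """Pack a grid of spatial fragments sampled from ``video`` into one tensor.

    ``video`` is a (C, T, H, W) tensor. The output keeps C and T and packs a
    ``fragments_h`` x ``fragments_w`` grid of ``fsize_h`` x ``fsize_w`` patches
    into a (C, T, fragments_h * fsize_h, fragments_w * fsize_w) tensor.
    Note that the ``random`` flag shadows the stdlib ``random`` module inside
    this function; it switches to fully random (deprecated) patch positions.
    """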
size_h = fragments_h * fsize_h
size_w = fragments_w * fsize_w
    ## images are single-frame videos, so no temporal alignment is needed
    if video.shape[1] == 1:
        aligned = 1
dur_t, res_h, res_w = video.shape[-3:]
ratio = min(res_h / size_h, res_w / size_w)
if fallback_type == "upsample" and ratio < 1:
ovideo = video
video = torch.nn.functional.interpolate(
video / 255.0, scale_factor=1 / ratio, mode="bilinear"
)
video = (video * 255.0).type_as(ovideo)
    assert dur_t % aligned == 0, "The number of sampled frames must be divisible by `aligned`."
size = size_h, size_w
    ## make sure that sampling does not run outside the picture
hgrids = torch.LongTensor(
[min(res_h // fragments_h * i, res_h - fsize_h) for i in range(fragments_h)]
)
wgrids = torch.LongTensor(
[min(res_w // fragments_w * i, res_w - fsize_w) for i in range(fragments_w)]
)
hlength, wlength = res_h // fragments_h, res_w // fragments_w
if random:
print("This part is deprecated. Please remind that.")
if res_h > fsize_h:
rnd_h = torch.randint(
res_h - fsize_h, (len(hgrids), len(wgrids), dur_t // aligned)
)
else:
rnd_h = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int()
if res_w > fsize_w:
rnd_w = torch.randint(
res_w - fsize_w, (len(hgrids), len(wgrids), dur_t // aligned)
)
else:
rnd_w = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int()
else:
if hlength > fsize_h:
rnd_h = torch.randint(
hlength - fsize_h, (len(hgrids), len(wgrids), dur_t // aligned)
)
else:
rnd_h = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int()
if wlength > fsize_w:
rnd_w = torch.randint(
wlength - fsize_w, (len(hgrids), len(wgrids), dur_t // aligned)
)
else:
rnd_w = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int()
target_video = torch.zeros(video.shape[:-2] + size).to(video.device)
# target_videos = []
for i, hs in enumerate(hgrids):
for j, ws in enumerate(wgrids):
for t in range(dur_t // aligned):
t_s, t_e = t * aligned, (t + 1) * aligned
h_s, h_e = i * fsize_h, (i + 1) * fsize_h
w_s, w_e = j * fsize_w, (j + 1) * fsize_w
if random:
h_so, h_eo = rnd_h[i][j][t], rnd_h[i][j][t] + fsize_h
w_so, w_eo = rnd_w[i][j][t], rnd_w[i][j][t] + fsize_w
else:
h_so, h_eo = hs + rnd_h[i][j][t], hs + rnd_h[i][j][t] + fsize_h
w_so, w_eo = ws + rnd_w[i][j][t], ws + rnd_w[i][j][t] + fsize_w
target_video[:, t_s:t_e, h_s:h_e, w_s:w_e] = video[
:, t_s:t_e, h_so:h_eo, w_so:w_eo
]
# target_videos.append(video[:,t_s:t_e,h_so:h_eo,w_so:w_eo])
# target_video = torch.stack(target_videos, 0).reshape((dur_t // aligned, fragments, fragments,) + target_videos[0].shape).permute(3,0,4,1,5,2,6)
# target_video = target_video.reshape((-1, dur_t,) + size) ## Splicing Fragments
return target_video
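# Usage sketch (the shapes below are illustrative, not from the repository):
#   >>> clip = torch.randint(0, 256, (3, 32, 720, 1280), dtype=torch.uint8)
#   >>> get_spatial_fragments(clip, 7, 7, 32, 32, aligned=32).shape
#   torch.Size([3, 32, 224, 224])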
class FragmentSampleFrames:
def __init__(self, fsize_t, fragments_t, frame_interval=1, num_clips=1):
self.fragments_t = fragments_t
self.fsize_t = fsize_t
self.size_t = fragments_t * fsize_t
self.frame_interval = frame_interval
self.num_clips = num_clips
def get_frame_indices(self, num_frames):
tgrids = np.array(
[num_frames // self.fragments_t * i for i in range(self.fragments_t)],
dtype=np.int32,
)
tlength = num_frames // self.fragments_t
if tlength > self.fsize_t * self.frame_interval:
rnd_t = np.random.randint(
0, tlength - self.fsize_t * self.frame_interval, size=len(tgrids)
)
else:
rnd_t = np.zeros(len(tgrids), dtype=np.int32)
ranges_t = (
np.arange(self.fsize_t)[None, :] * self.frame_interval
+ rnd_t[:, None]
+ tgrids[:, None]
)
return np.concatenate(ranges_t)
def __call__(self, total_frames, train=False, start_index=0):
frame_inds = []
for i in range(self.num_clips):
frame_inds += [self.get_frame_indices(total_frames)]
frame_inds = np.concatenate(frame_inds)
frame_inds = np.mod(frame_inds + start_index, total_frames)
return frame_inds
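# Usage sketch (numbers are illustrative): for a 240-frame video, 8 temporal
# fragments of 4 frames at interval 2 yield 8 * 4 = 32 frame indices, one
# randomly placed window inside each of the 8 equal temporal grids:
#   >>> FragmentSampleFrames(fsize_t=4, fragments_t=8, frame_interval=2)(240).shape
#   (32,)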
class SampleFrames:
def __init__(self, clip_len, frame_interval=1, num_clips=1):
self.clip_len = clip_len
self.frame_interval = frame_interval
self.num_clips = num_clips
def _get_train_clips(self, num_frames):
"""Get clip offsets in train mode.
        Calculate the average interval between the selected clips and randomly
        shift each clip within [0, avg_interval). If the total number of
        frames is smaller than the number of clips or the original clip
        length, all-zero offsets are returned instead.
        Args:
            num_frames (int): Total number of frames in the video.
Returns:
np.ndarray: Sampled frame indices in train mode.
"""
ori_clip_len = self.clip_len * self.frame_interval
avg_interval = (num_frames - ori_clip_len + 1) // self.num_clips
if avg_interval > 0:
base_offsets = np.arange(self.num_clips) * avg_interval
clip_offsets = base_offsets + np.random.randint(
avg_interval, size=self.num_clips
)
elif num_frames > max(self.num_clips, ori_clip_len):
clip_offsets = np.sort(
np.random.randint(num_frames - ori_clip_len + 1, size=self.num_clips)
)
elif avg_interval == 0:
ratio = (num_frames - ori_clip_len + 1.0) / self.num_clips
clip_offsets = np.around(np.arange(self.num_clips) * ratio)
else:
            clip_offsets = np.zeros((self.num_clips,), dtype=np.int32)
return clip_offsets
def _get_test_clips(self, num_frames, start_index=0):
"""Get clip offsets in test mode.
        Calculate the average interval between the selected clips and shift
        each clip by a fixed offset of avg_interval / 2.
        Args:
            num_frames (int): Total number of frames in the video.
Returns:
np.ndarray: Sampled frame indices in test mode.
"""
ori_clip_len = self.clip_len * self.frame_interval
avg_interval = (num_frames - ori_clip_len + 1) / float(self.num_clips)
if num_frames > ori_clip_len - 1:
base_offsets = np.arange(self.num_clips) * avg_interval
clip_offsets = (base_offsets + avg_interval / 2.0).astype(np.int32)
else:
clip_offsets = np.zeros((self.num_clips,), dtype=np.int32)
return clip_offsets
def __call__(self, total_frames, train=False, start_index=0):
"""Perform the SampleFrames loading.
Args:
results (dict): The resulting dict to be modified and passed
to the next transform in pipeline.
"""
if train:
clip_offsets = self._get_train_clips(total_frames)
else:
clip_offsets = self._get_test_clips(total_frames)
frame_inds = (
clip_offsets[:, None]
+ np.arange(self.clip_len)[None, :] * self.frame_interval
)
frame_inds = np.concatenate(frame_inds)
frame_inds = frame_inds.reshape((-1, self.clip_len))
frame_inds = np.mod(frame_inds, total_frames)
frame_inds = np.concatenate(frame_inds) + start_index
return frame_inds.astype(np.int32)
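# Usage sketch (illustrative): 4 test clips of 32 frames at interval 2 from a
# 240-frame video give 4 evenly spaced clips, i.e. 4 * 32 = 128 indices:
#   >>> SampleFrames(clip_len=32, frame_interval=2, num_clips=4)(240).shape
#   (128,)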
class FastVQAPlusPlusDataset(torch.utils.data.Dataset):
def __init__(
self,
ann_file,
data_prefix,
frame_interval=2,
aligned=32,
fragments=(8, 8, 8),
fsize=(4, 32, 32),
num_clips=1,
nfrags=1,
cache_in_memory=False,
phase="test",
fallback_type="oversample",
):
"""
Fragments.
args:
fragments: G_f as in the paper.
fsize: S_f as in the paper.
nfrags: number of samples (spatially) as in the paper.
num_clips: number of samples (temporally) as in the paper.
"""
self.ann_file = ann_file
self.data_prefix = data_prefix
self.frame_interval = frame_interval
self.num_clips = num_clips
self.fragments = fragments
self.fsize = fsize
self.nfrags = nfrags
self.clip_len = fragments[0] * fsize[0]
self.aligned = aligned
self.fallback_type = fallback_type
self.sampler = FragmentSampleFrames(
fsize[0], fragments[0], frame_interval, num_clips
)
self.video_infos = []
self.phase = phase
self.mean = torch.FloatTensor([123.675, 116.28, 103.53])
self.std = torch.FloatTensor([58.395, 57.12, 57.375])
if isinstance(self.ann_file, list):
self.video_infos = self.ann_file
else:
with open(self.ann_file, "r") as fin:
for line in fin:
line_split = line.strip().split(",")
filename, _, _, label = line_split
label = float(label)
filename = osp.join(self.data_prefix, filename)
self.video_infos.append(dict(filename=filename, label=label))
if cache_in_memory:
self.cache = {}
for i in tqdm(range(len(self)), desc="Caching fragments"):
self.cache[i] = self.__getitem__(i, tocache=True)
else:
self.cache = None
def __getitem__(
self, index, tocache=False, need_original_frames=False,
):
if tocache or self.cache is None:
fx, fy = self.fragments[1:]
fsx, fsy = self.fsize[1:]
video_info = self.video_infos[index]
filename = video_info["filename"]
label = video_info["label"]
if filename.endswith(".yuv"):
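                # Raw .yuv has no header: the 1080x1920 resolution and the
                # YUVJ420P pixel format are assumptions baked in for this corpus.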
video = skvideo.io.vread(
filename, 1080, 1920, inputdict={"-pix_fmt": "yuvj420p"}
)
frame_inds = self.sampler(video.shape[0], self.phase == "train")
imgs = [torch.from_numpy(video[idx]) for idx in frame_inds]
else:
vreader = VideoReader(filename)
frame_inds = self.sampler(len(vreader), self.phase == "train")
frame_dict = {idx: vreader[idx] for idx in np.unique(frame_inds)}
imgs = [frame_dict[idx] for idx in frame_inds]
img_shape = imgs[0].shape
video = torch.stack(imgs, 0)
video = video.permute(3, 0, 1, 2)
            # One spatial sample is produced per call; additional samples
            # (nfrags > 1) are concatenated along the temporal dimension.
            vfrag = get_spatial_fragments(
                video,
                fx,
                fy,
                fsx,
                fsy,
                aligned=self.aligned,
                fallback_type=self.fallback_type,
            )
            for i in range(1, self.nfrags):
                vfrag = torch.cat(
                    (
                        vfrag,
                        get_spatial_fragments(
                            video,
                            fx,
                            fy,
                            fsx,
                            fsy,
                            aligned=self.aligned,
                            fallback_type=self.fallback_type,
                        ),
                    ),
                    1,
                )
if tocache:
return (vfrag, frame_inds, label, img_shape)
else:
vfrag, frame_inds, label, img_shape = self.cache[index]
vfrag = ((vfrag.permute(1, 2, 3, 0) - self.mean) / self.std).permute(3, 0, 1, 2)
data = {
"video": vfrag.reshape(
(-1, self.nfrags * self.num_clips, self.clip_len) + vfrag.shape[2:]
).transpose(
0, 1
), # B, V, T, C, H, W
"frame_inds": frame_inds,
"gt_label": label,
"original_shape": img_shape,
}
if need_original_frames:
data["original_video"] = video.reshape(
(-1, self.nfrags * self.num_clips, self.clip_len) + video.shape[2:]
).transpose(0, 1)
return data
def __len__(self):
return len(self.video_infos)
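# Usage sketch (the annotation file and prefix are placeholders, not real
# files): with the defaults above (8x8x8 fragments of size 4x32x32, one clip),
# each item carries a single fragmented clip:
#   >>> ds = FastVQAPlusPlusDataset("labels.txt", "videos/")
#   >>> ds[0]["video"].shape
#   torch.Size([1, 3, 32, 256, 256])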
class FragmentVideoDataset(torch.utils.data.Dataset):
def __init__(
self,
ann_file,
data_prefix,
clip_len=32,
frame_interval=2,
num_clips=4,
aligned=32,
fragments=7,
fsize=32,
nfrags=1,
cache_in_memory=False,
phase="test",
):
"""
Fragments.
args:
fragments: G_f as in the paper.
fsize: S_f as in the paper.
nfrags: number of samples as in the paper.
"""
self.ann_file = ann_file
self.data_prefix = data_prefix
self.clip_len = clip_len
self.frame_interval = frame_interval
self.num_clips = num_clips
self.fragments = fragments
self.fsize = fsize
self.nfrags = nfrags
self.aligned = aligned
self.sampler = SampleFrames(clip_len, frame_interval, num_clips)
self.video_infos = []
self.phase = phase
self.mean = torch.FloatTensor([123.675, 116.28, 103.53])
self.std = torch.FloatTensor([58.395, 57.12, 57.375])
if isinstance(self.ann_file, list):
self.video_infos = self.ann_file
else:
with open(self.ann_file, "r") as fin:
for line in fin:
line_split = line.strip().split(",")
filename, _, _, label = line_split
label = float(label)
filename = osp.join(self.data_prefix, filename)
self.video_infos.append(dict(filename=filename, label=label))
if cache_in_memory:
self.cache = {}
for i in tqdm(range(len(self)), desc="Caching fragments"):
self.cache[i] = self.__getitem__(i, tocache=True)
else:
self.cache = None
def __getitem__(
self, index, fragments=-1, fsize=-1, tocache=False, need_original_frames=False,
):
if tocache or self.cache is None:
if fragments == -1:
fragments = self.fragments
if fsize == -1:
fsize = self.fsize
video_info = self.video_infos[index]
filename = video_info["filename"]
label = video_info["label"]
if filename.endswith(".yuv"):
video = skvideo.io.vread(
filename, 1080, 1920, inputdict={"-pix_fmt": "yuvj420p"}
)
frame_inds = self.sampler(video.shape[0], self.phase == "train")
imgs = [torch.from_numpy(video[idx]) for idx in frame_inds]
else:
vreader = VideoReader(filename)
frame_inds = self.sampler(len(vreader), self.phase == "train")
frame_dict = {idx: vreader[idx] for idx in np.unique(frame_inds)}
imgs = [frame_dict[idx] for idx in frame_inds]
img_shape = imgs[0].shape
video = torch.stack(imgs, 0)
video = video.permute(3, 0, 1, 2)
            vfrag = get_spatial_fragments(
                video, fragments, fragments, fsize, fsize, aligned=self.aligned
            )
            for i in range(1, self.nfrags):
                vfrag = torch.cat(
                    (
                        vfrag,
                        get_spatial_fragments(
                            video,
                            fragments,
                            fragments,
                            fsize,
                            fsize,
                            aligned=self.aligned,
                        ),
                    ),
                    1,
                )
if tocache:
return (vfrag, frame_inds, label, img_shape)
else:
vfrag, frame_inds, label, img_shape = self.cache[index]
vfrag = ((vfrag.permute(1, 2, 3, 0) - self.mean) / self.std).permute(3, 0, 1, 2)
data = {
"video": vfrag.reshape(
(-1, self.nfrags * self.num_clips, self.clip_len) + vfrag.shape[2:]
).transpose(
0, 1
), # B, V, T, C, H, W
"frame_inds": frame_inds,
"gt_label": label,
"original_shape": img_shape,
}
if need_original_frames:
data["original_video"] = video.reshape(
(-1, self.nfrags * self.num_clips, self.clip_len) + video.shape[2:]
).transpose(0, 1)
return data
def __len__(self):
return len(self.video_infos)
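# Usage sketch (paths are placeholders): with the defaults (7x7 fragments of
# 32 px, 4 clips of 32 frames), each item carries 4 fragmented 224x224 clips:
#   >>> ds = FragmentVideoDataset("labels.txt", "videos/")
#   >>> ds[0]["video"].shape
#   torch.Size([4, 3, 32, 224, 224])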
class ResizedVideoDataset(torch.utils.data.Dataset):
def __init__(
self,
ann_file,
data_prefix,
clip_len=32,
frame_interval=2,
num_clips=4,
aligned=32,
size=224,
cache_in_memory=False,
phase="test",
):
"""
Using resizing.
"""
self.ann_file = ann_file
self.data_prefix = data_prefix
self.clip_len = clip_len
self.frame_interval = frame_interval
self.num_clips = num_clips
self.size = size
self.aligned = aligned
self.sampler = SampleFrames(clip_len, frame_interval, num_clips)
self.video_infos = []
self.phase = phase
self.mean = torch.FloatTensor([123.675, 116.28, 103.53])
self.std = torch.FloatTensor([58.395, 57.12, 57.375])
if isinstance(self.ann_file, list):
self.video_infos = self.ann_file
else:
with open(self.ann_file, "r") as fin:
for line in fin:
line_split = line.strip().split(",")
filename, _, _, label = line_split
label = float(label)
filename = osp.join(self.data_prefix, filename)
self.video_infos.append(dict(filename=filename, label=label))
if cache_in_memory:
self.cache = {}
for i in tqdm(range(len(self)), desc="Caching resized videos"):
self.cache[i] = self.__getitem__(i, tocache=True)
else:
self.cache = None
def __getitem__(self, index, tocache=False, need_original_frames=False):
if tocache or self.cache is None:
video_info = self.video_infos[index]
filename = video_info["filename"]
label = video_info["label"]
vreader = VideoReader(filename)
frame_inds = self.sampler(len(vreader), self.phase == "train")
frame_dict = {idx: vreader[idx] for idx in np.unique(frame_inds)}
imgs = [frame_dict[idx] for idx in frame_inds]
img_shape = imgs[0].shape
video = torch.stack(imgs, 0)
video = video.permute(3, 0, 1, 2)
            video = torch.nn.functional.interpolate(video, size=(self.size, self.size))
            # Keep the same variable name the downstream normalization expects.
            vfrag = video
            if tocache:
                return (vfrag, frame_inds, label, img_shape)
else:
vfrag, frame_inds, label, img_shape = self.cache[index]
vfrag = ((vfrag.permute(1, 2, 3, 0) - self.mean) / self.std).permute(3, 0, 1, 2)
data = {
"video": vfrag.reshape(
(-1, self.num_clips, self.clip_len) + vfrag.shape[2:]
).transpose(
0, 1
), # B, V, T, C, H, W
"frame_inds": frame_inds,
"gt_label": label,
"original_shape": img_shape,
}
if need_original_frames:
data["original_video"] = video.reshape(
(-1, self.nfrags * self.num_clips, self.clip_len) + video.shape[2:]
).transpose(0, 1)
return data
def __len__(self):
return len(self.video_infos)
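# Note: unlike the fragment datasets above, this one simply rescales every
# sampled frame to a (size, size) square, e.g. 224x224 with the defaults, so
# an item's "video" entry has shape (num_clips, C, clip_len, size, size).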
class CroppedVideoDataset(FragmentVideoDataset):
def __init__(
self,
ann_file,
data_prefix,
clip_len=32,
frame_interval=2,
num_clips=4,
aligned=32,
size=224,
ncrops=1,
cache_in_memory=False,
phase="test",
):
"""
        Regard cropping as the special case of fragment sampling with a 1x1 grid.
"""
super().__init__(
ann_file,
data_prefix,
clip_len=clip_len,
frame_interval=frame_interval,
num_clips=num_clips,
aligned=aligned,
fragments=1,
fsize=224,
nfrags=ncrops,
cache_in_memory=cache_in_memory,
phase=phase,
)
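# With fragments=1 and fsize=224, get_spatial_fragments degenerates to one
# randomly positioned 224x224 crop per aligned window; ncrops > 1 stacks
# several such crops along the temporal dimension, one sample per crop.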
class FragmentImageDataset(torch.utils.data.Dataset):
def __init__(
self,
ann_file,
data_prefix,
fragments=7,
fsize=32,
nfrags=1,
cache_in_memory=False,
phase="test",
):
self.ann_file = ann_file
self.data_prefix = data_prefix
self.fragments = fragments
self.fsize = fsize
self.nfrags = nfrags
self.image_infos = []
self.phase = phase
self.mean = torch.FloatTensor([123.675, 116.28, 103.53])
self.std = torch.FloatTensor([58.395, 57.12, 57.375])
if isinstance(self.ann_file, list):
self.image_infos = self.ann_file
else:
with open(self.ann_file, "r") as fin:
for line in fin:
line_split = line.strip().split(",")
filename, _, _, label = line_split
label = float(label)
filename = osp.join(self.data_prefix, filename)
self.image_infos.append(dict(filename=filename, label=label))
if cache_in_memory:
self.cache = {}
for i in tqdm(range(len(self)), desc="Caching fragments"):
self.cache[i] = self.__getitem__(i, tocache=True)
else:
self.cache = None
def __getitem__(
self, index, fragments=-1, fsize=-1, tocache=False, need_original_frames=False
):
if tocache or self.cache is None:
if fragments == -1:
fragments = self.fragments
if fsize == -1:
fsize = self.fsize
image_info = self.image_infos[index]
filename = image_info["filename"]
label = image_info["label"]
            try:
                img = torchvision.io.read_image(filename)
            except Exception:
                # Fall back to OpenCV, converting BGR to RGB and HWC to CHW.
                img = cv2.imread(filename)
                img = torch.from_numpy(img[:, :, [2, 1, 0]]).permute(2, 0, 1)
img_shape = img.shape[1:]
image = img.unsqueeze(1)
            ifrag = get_spatial_fragments(image, fragments, fragments, fsize, fsize)
            for i in range(1, self.nfrags):
                ifrag = torch.cat(
                    (
                        ifrag,
                        get_spatial_fragments(
                            image, fragments, fragments, fsize, fsize
                        ),
                    ),
                    1,
                )
if tocache:
return (ifrag, label, img_shape)
else:
ifrag, label, img_shape = self.cache[index]
if self.nfrags == 1:
ifrag = (
((ifrag.permute(1, 2, 3, 0) - self.mean) / self.std)
.squeeze(0)
.permute(2, 0, 1)
)
else:
            ### During testing, the nfrags spatial samples of one image form a batch
ifrag = (
((ifrag.permute(1, 2, 3, 0) - self.mean) / self.std)
.squeeze(0)
.permute(0, 3, 1, 2)
)
data = {
"image": ifrag,
"gt_label": label,
"original_shape": img_shape,
"name": filename,
}
if need_original_frames:
data["original_image"] = image.squeeze(1)
return data
def __len__(self):
return len(self.image_infos)
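# Usage sketch (paths are placeholders): a single image is fragmented into a
# 7x7 grid of 32 px patches and returned as one normalized 224x224 tensor:
#   >>> ds = FragmentImageDataset("labels.txt", "images/")
#   >>> ds[0]["image"].shape
#   torch.Size([3, 224, 224])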
class ResizedImageDataset(torch.utils.data.Dataset):
def __init__(
self, ann_file, data_prefix, size=224, cache_in_memory=False, phase="test",
):
self.ann_file = ann_file
self.data_prefix = data_prefix
self.size = size
self.image_infos = []
self.phase = phase
self.mean = torch.FloatTensor([123.675, 116.28, 103.53])
self.std = torch.FloatTensor([58.395, 57.12, 57.375])
if isinstance(self.ann_file, list):
self.image_infos = self.ann_file
else:
with open(self.ann_file, "r") as fin:
for line in fin:
line_split = line.strip().split(",")
filename, _, _, label = line_split
label = float(label)
filename = osp.join(self.data_prefix, filename)
self.image_infos.append(dict(filename=filename, label=label))
if cache_in_memory:
self.cache = {}
for i in tqdm(range(len(self)), desc="Caching fragments"):
self.cache[i] = self.__getitem__(i, tocache=True)
else:
self.cache = None
    def __getitem__(self, index, tocache=False, need_original_frames=False):
        if tocache or self.cache is None:
            image_info = self.image_infos[index]
            filename = image_info["filename"]
            label = image_info["label"]
            img = torchvision.io.read_image(filename)
            img_shape = img.shape[1:]
            image = img.unsqueeze(1)
            # Resize (bilinear) to the target square size; fragmentation is
            # not used in this dataset.
            ifrag = torch.nn.functional.interpolate(
                image.float(), size=(self.size, self.size), mode="bilinear"
            ).type_as(image)
if tocache:
return (ifrag, label, img_shape)
else:
ifrag, label, img_shape = self.cache[index]
ifrag = (
((ifrag.permute(1, 2, 3, 0) - self.mean) / self.std)
.squeeze(0)
.permute(2, 0, 1)
)
data = {
"image": ifrag,
"gt_label": label,
"original_shape": img_shape,
}
if need_original_frames:
data["original_image"] = image.squeeze(1)
return data
def __len__(self):
return len(self.image_infos)
class CroppedImageDataset(FragmentImageDataset):
def __init__(
self,
ann_file,
data_prefix,
size=224,
ncrops=1,
cache_in_memory=False,
phase="test",
):
"""
        Regard cropping as the special case of fragment sampling with a 1x1 grid.
"""
super().__init__(
ann_file,
data_prefix,
fragments=1,
fsize=224,
nfrags=ncrops,
cache_in_memory=cache_in_memory,
phase=phase,
)
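

if __name__ == "__main__":
    # Minimal smoke test that needs no dataset files: fragment a synthetic
    # 720p clip and sample frame indices (expected outputs in the comments).
    clip = torch.randint(0, 256, (3, 32, 720, 1280), dtype=torch.uint8)
    frags = get_spatial_fragments(clip, 7, 7, 32, 32, aligned=32)
    print(frags.shape)  # torch.Size([3, 32, 224, 224])

    t_sampler = FragmentSampleFrames(fsize_t=4, fragments_t=8, frame_interval=2)
    print(t_sampler(240).shape)  # (32,)

    clip_sampler = SampleFrames(clip_len=32, frame_interval=2, num_clips=4)
    print(clip_sampler(240, train=False).shape)  # (128,)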