Spaces:
Sleeping
Sleeping
import os | |
import numpy as np | |
from PIL import Image | |
import torch | |
from torch.utils.data import Dataset | |
from torchvision.transforms import ToTensor | |
def get_mgrid(sidelen, vmin=-1, vmax=1): | |
if type(vmin) is not list: | |
vmin = [vmin for _ in range(len(sidelen))] | |
if type(vmax) is not list: | |
vmax = [vmax for _ in range(len(sidelen))] | |
tensors = tuple([torch.linspace(vmin[i], vmax[i], steps=sidelen[i]) for i in range(len(sidelen))]) | |
mgrid = torch.stack(torch.meshgrid(*tensors), dim=-1) | |
mgrid = mgrid.reshape(-1, len(sidelen)) | |
return mgrid | |
def apply_homography(x, h): | |
h = torch.cat([h, torch.ones_like(h[:, [0]])], -1) | |
h = h.view(-1, 3, 3) | |
x = torch.cat([x, torch.ones_like(x[:, 0]).unsqueeze(-1)], -1).unsqueeze(-1) | |
o = torch.bmm(h, x).squeeze(-1) | |
o = o[:, :-1] / o[:, [-1]] | |
return o | |
def jacobian(y, x): | |
B, N = y.shape | |
jacobian = list() | |
for i in range(N): | |
v = torch.zeros_like(y) | |
v[:, i] = 1. | |
dy_i_dx = torch.autograd.grad(y, | |
x, | |
grad_outputs=v, | |
retain_graph=True, | |
create_graph=True)[0] # shape [B, N] | |
jacobian.append(dy_i_dx) | |
jacobian = torch.stack(jacobian, dim=1).requires_grad_() | |
return jacobian | |
def overlap_mix(img1, img2, img_order, overlap_num): | |
w1 = np.linspace(0, 1, overlap_num)[::-1] | |
w2 = 1 - w1 | |
return w1[img_order] * img1 + w2[img_order] * img2 | |
class VideoFitting(Dataset): | |
def __init__(self, path, transform=None): | |
super().__init__() | |
self.path = path | |
if transform is None: | |
self.transform = ToTensor() | |
else: | |
self.transform = transform | |
self.video = self.get_video_tensor() | |
self.num_frames, _, self.H, self.W = self.video.size() | |
self.pixels = self.video.permute(2, 3, 0, 1).contiguous().view(-1, 3) | |
self.coords = get_mgrid([self.H, self.W, self.num_frames]) | |
shuffle = torch.randperm(len(self.pixels)) | |
self.pixels = self.pixels[shuffle] | |
self.coords = self.coords[shuffle] | |
def get_video_tensor(self): | |
frames = sorted(os.listdir(self.path)) | |
video = [] | |
for i in range(len(frames)): | |
img = Image.open(os.path.join(self.path, frames[i])) | |
img = self.transform(img) | |
video.append(img) | |
return torch.stack(video, 0) | |
def __len__(self): | |
return 1 | |
def __getitem__(self, idx): | |
if idx > 0: raise IndexError | |
return self.coords, self.pixels | |
class TestVideoFitting(Dataset): | |
def __init__(self, path, transform=None): | |
super().__init__() | |
self.path = path | |
if transform is None: | |
self.transform = ToTensor() | |
else: | |
self.transform = transform | |
self.video = self.get_video_tensor() | |
self.num_frames, _, self.H, self.W = self.video.size() | |
self.pixels = self.video.permute(2, 3, 0, 1).contiguous().view(-1, 3) | |
self.coords = get_mgrid([self.H, self.W, self.num_frames]) | |
def get_video_tensor(self): | |
frames = sorted(os.listdir(self.path)) | |
video = [] | |
for i in range(len(frames)): | |
img = Image.open(os.path.join(self.path, frames[i])) | |
img = self.transform(img) | |
video.append(img) | |
return torch.stack(video, 0) | |
def __len__(self): | |
return 1 | |
def __getitem__(self, idx): | |
if idx > 0: raise IndexError | |
return self.coords, self.pixels | |
class GroupVideoFitting(Dataset): | |
def __init__(self, path, mask_path, transform=None, mask_transform=None): | |
super().__init__() | |
self.path = path | |
self.mask_path = mask_path | |
if transform is None: | |
self.transform = ToTensor() | |
else: | |
self.transform = transform | |
if mask_transform is None: | |
self.mask_transform = ToTensor() | |
else: | |
self.mask_transform = mask_transform | |
self.video = self.get_video_tensor() | |
self.mask = self.get_mask_tensor() | |
self.num_frames, _, self.H, self.W = self.video.size() | |
self.pixels = self.video.permute(2, 3, 0, 1).contiguous().view(-1, 3) | |
self.mask_pixels = self.mask.permute(2, 3, 0, 1).contiguous().view(-1, 1) | |
self.coords = get_mgrid([self.H, self.W, self.num_frames]) | |
shuffle = torch.randperm(len(self.pixels)) | |
self.pixels = self.pixels[shuffle] | |
self.coords = self.coords[shuffle] | |
self.mask_pixels = self.mask_pixels[shuffle] | |
def get_video_tensor(self): | |
frames = sorted(os.listdir(self.path)) | |
video = [] | |
for i in range(len(frames)): | |
img = Image.open(os.path.join(self.path, frames[i])) | |
img = self.transform(img) | |
video.append(img) | |
return torch.stack(video, 0) | |
def get_mask_tensor(self): | |
masks = sorted(os.listdir(self.mask_path)) | |
all_mask = [] | |
for i in range(len(masks)): | |
mask = Image.open(os.path.join(self.mask_path, masks[i])) | |
mask = self.mask_transform(mask) | |
all_mask.append(mask) | |
return torch.stack(all_mask, 0) | |
def __len__(self): | |
return 1 | |
def __getitem__(self, idx): | |
if idx > 0: raise IndexError | |
return self.coords, self.pixels, self.mask_pixels | |
class TestGroupVideoFitting(Dataset): | |
def __init__(self, path, mask_path, back_mask_path, transform=None, mask_transform=None): | |
super().__init__() | |
self.path = path | |
self.mask_path = mask_path | |
self.back_mask_path = back_mask_path | |
if transform is None: | |
self.transform = ToTensor() | |
else: | |
self.transform = transform | |
if mask_transform is None: | |
self.mask_transform = ToTensor() | |
else: | |
self.mask_transform = mask_transform | |
self.video = self.get_video_tensor() | |
self.mask = self.get_mask_tensor() | |
self.back_mask = self.get_back_mask_tensor() | |
self.num_frames, _, self.H, self.W = self.video.size() | |
self.pixels = self.video.permute(2, 3, 0, 1).contiguous().view(-1, 3) | |
self.mask_pixels = self.mask.permute(2, 3, 0, 1).contiguous().view(-1, 1) | |
self.back_mask_pixels = self.back_mask.permute(2, 3, 0, 1).contiguous().view(-1, 1) | |
self.coords = get_mgrid([self.H, self.W, self.num_frames]) | |
def get_video_tensor(self): | |
frames = sorted(os.listdir(self.path)) | |
video = [] | |
for i in range(len(frames)): | |
img = Image.open(os.path.join(self.path, frames[i])) | |
img = self.transform(img) | |
video.append(img) | |
return torch.stack(video, 0) | |
def get_mask_tensor(self): | |
masks = sorted(os.listdir(self.mask_path)) | |
all_mask = [] | |
for i in range(len(masks)): | |
mask = Image.open(os.path.join(self.mask_path, masks[i])) | |
mask = self.mask_transform(mask) | |
all_mask.append(mask) | |
return torch.stack(all_mask, 0) | |
def get_back_mask_tensor(self): | |
masks = sorted(os.listdir(self.back_mask_path)) | |
all_mask = [] | |
for i in range(len(masks)): | |
mask = Image.open(os.path.join(self.back_mask_path, masks[i])) | |
mask = self.mask_transform(mask) | |
all_mask.append(mask) | |
return torch.stack(all_mask, 0) | |
def __len__(self): | |
return 1 | |
def __getitem__(self, idx): | |
if idx > 0: raise IndexError | |
return self.coords, self.pixels, self.mask_pixels, self.back_mask_pixels |