# Process-pool-backed frame readers for image folders and video files.
import asyncio
import math
from collections import deque
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from glob import glob
from pathlib import Path
import av
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, default_collate
def get_default_video_reader(
    data_path,
):
    """Yield decoded frames of the first video stream as uint8 ndarrays.

    ``.mp4`` files are decoded to RGB, anything else (e.g. ``.webm`` with
    alpha) to RGBA.

    Args:
        data_path: Path to the video file (``str`` or ``pathlib.Path``).

    Yields:
        ``numpy.ndarray`` of shape (H, W, 3) or (H, W, 4) per frame.
    """
    # Normalize so .suffix works even when a plain str is passed
    # (the original str(data_path) cast implies both are accepted).
    data_path = Path(data_path)
    # FFmpeg pixel-format names are "rgb24"/"rgba"; "rgb" is not a valid
    # format and makes to_ndarray raise.
    pixel_format = "rgb24" if data_path.suffix == ".mp4" else "rgba"
    with av.open(str(data_path)) as container:
        for frame in container.decode(video=0):
            yield frame.to_ndarray(format=pixel_format)
accepted_format = set([".webp", ".png", ".jpg"]) | |
def read_image(path):
    """Load the image at *path* and return it as an RGBA uint8 ndarray.

    Uses a context manager so the underlying file handle is closed promptly
    instead of leaking until garbage collection.
    """
    with Image.open(path) as img:
        # .convert forces a full load, so closing afterwards is safe.
        return np.array(img.convert("RGBA"))
class ImageDataset(Dataset):
    """Map-style dataset over the image files in a directory.

    Files are ordered by name; ``num_skip_frames`` rotates the order so
    iteration starts that many entries in, wrapping the skipped ones to
    the end.
    """

    def __init__(self, path, num_skip_frames=0):
        matching = [
            entry for entry in glob(f"{path}/*") if Path(entry).suffix in accepted_format
        ]
        matching.sort()
        # Rotate rather than drop: the first num_skip_frames files move to the tail.
        self.paths = matching[num_skip_frames:] + matching[:num_skip_frames]

    def __getitem__(self, idx):
        return read_image(self.paths[idx])

    def __len__(self):
        return len(self.paths)
class ProcessPoolIterator:
    """Iterate a map-style dataset, prefetching items via a process pool.

    Items are yielded in index order with at most ``preload`` fetches in
    flight at any time.

    Args:
        dataset: Object with ``__getitem__`` and ``__len__``; items and the
            dataset itself must be picklable.
        preload: Maximum number of in-flight item fetches.
        num_workers: Worker-process count for the pool.
    """

    def __init__(self, dataset, preload=8, num_workers=2):
        self.pool = ProcessPoolExecutor(num_workers)
        self.dataset = dataset
        # Kept for backward compatibility with code that inspects .queue;
        # iteration itself uses a local queue (see __iter__).
        self.queue = deque()
        self.preload = preload

    def __iter__(self):
        # A per-iteration queue so that interleaved or repeated iterations
        # do not corrupt each other through shared instance state.
        pending = deque()
        total = len(self.dataset)
        # Prime the pipeline.
        for i in range(min(self.preload, total)):
            pending.append(self.pool.submit(self.dataset.__getitem__, i))
        # Steady state: submit one, yield the oldest completed.
        for i in range(self.preload, total):
            pending.append(self.pool.submit(self.dataset.__getitem__, i))
            yield pending.popleft().result()
        # Drain the tail.
        while pending:
            yield pending.popleft().result()

    def __len__(self):
        return len(self.dataset)
class ProcessPoolBatchIterator:
    """Batch items prefetched by a ProcessPoolIterator and collate them.

    Args:
        dataset: Map-style dataset handed to the underlying iterator.
        batch_size: Items per batch (also used as the prefetch depth).
        num_workers: Worker-process count for the pool.
        drop_last: When True, a trailing partial batch is discarded.
    """

    def __init__(self, dataset, batch_size, num_workers=4, drop_last=False):
        self.iterator = ProcessPoolIterator(
            dataset=dataset, preload=batch_size, num_workers=num_workers
        )
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        # A plain for-loop replaces the original try/except StopIteration:
        # no exception-driven control flow and no unused exception binding.
        batch = []
        for item in self.iterator:
            batch.append(item)
            if len(batch) == self.batch_size:
                yield default_collate(batch)
                batch = []
        if batch and not self.drop_last:
            yield default_collate(batch)

    def __len__(self):
        # Batch count: floor when partial batches are dropped, ceil otherwise.
        return (
            math.floor(len(self.iterator) / self.batch_size)
            if self.drop_last
            else math.ceil(len(self.iterator) / self.batch_size)
        )
class AsyncProcessPoolIterator:
    """Async counterpart of ProcessPoolIterator.

    Prefetches dataset items through a process pool and yields them in
    index order, awaiting each result so the event loop stays responsive.

    Args:
        dataset: Object with ``__getitem__`` and ``__len__``; must be picklable.
        preload: Maximum number of in-flight item fetches.
        num_workers: Worker-process count for the pool.
    """

    def __init__(self, dataset, preload=8, num_workers=4):
        self.pool = ProcessPoolExecutor(num_workers)
        self.dataset = dataset
        # Kept for backward compatibility with code that inspects .queue;
        # iteration itself uses a local queue (see __aiter__).
        self.queue = deque()
        self.preload = preload

    async def __aiter__(self):
        loop = asyncio.get_running_loop()
        # Per-iteration queue so concurrent/repeated iterations don't share state.
        pending = deque()
        total = len(self.dataset)
        # Prime the pipeline.
        for i in range(min(self.preload, total)):
            pending.append(
                loop.run_in_executor(self.pool, self.dataset.__getitem__, i)
            )
        # Steady state: submit one, await and yield the oldest.
        for i in range(self.preload, total):
            pending.append(
                loop.run_in_executor(self.pool, self.dataset.__getitem__, i)
            )
            yield await pending.popleft()
        # Drain the tail.
        while pending:
            yield await pending.popleft()

    def __len__(self):
        return len(self.dataset)
class AsyncProcessPoolBatchIterator:
    """Async batching wrapper over AsyncProcessPoolIterator.

    Args:
        dataset: Map-style dataset handed to the underlying iterator.
        batch_size: Items per batch (also used as the prefetch depth).
        num_workers: Worker-process count for the pool.
        drop_last: When True, a trailing partial batch is discarded.
    """

    def __init__(self, dataset, batch_size, num_workers=4, drop_last=False):
        self.iterator = AsyncProcessPoolIterator(
            dataset=dataset, preload=batch_size, num_workers=num_workers
        )
        self.batch_size = batch_size
        self.drop_last = drop_last

    async def __aiter__(self):
        # async-for replaces the original try/except StopAsyncIteration:
        # no exception-driven control flow and no unused exception binding.
        batch = []
        async for item in self.iterator:
            batch.append(item)
            if len(batch) == self.batch_size:
                yield default_collate(batch)
                batch = []
        if batch and not self.drop_last:
            yield default_collate(batch)

    def __len__(self):
        # Batch count: floor when partial batches are dropped, ceil otherwise.
        return (
            math.floor(len(self.iterator) / self.batch_size)
            if self.drop_last
            else math.ceil(len(self.iterator) / self.batch_size)
        )
def get_image_folder_process_reader(
    data_path,
    num_skip_frames=0,
    num_workers=4,
    preload=16,
):
    """Build a ProcessPoolIterator over the image files found in *data_path*.

    Args:
        data_path: Directory containing the images.
        num_skip_frames: Rotation offset applied to the sorted file list.
        num_workers: Worker-process count for the pool.
        preload: Maximum number of in-flight item fetches.

    Returns:
        A ProcessPoolIterator yielding RGBA frames in order.
    """
    return ProcessPoolIterator(
        dataset=ImageDataset(path=data_path, num_skip_frames=num_skip_frames),
        num_workers=num_workers,
        preload=preload,
    )
def get_image_folder_async_process_reader(
    data_path,
    num_skip_frames=0,
    num_workers=4,
    preload=16,
):
    """Build an AsyncProcessPoolIterator over the image files in *data_path*.

    Args:
        data_path: Directory containing the images.
        num_skip_frames: Rotation offset applied to the sorted file list.
        num_workers: Worker-process count for the pool.
        preload: Maximum number of in-flight item fetches.

    Returns:
        An AsyncProcessPoolIterator yielding RGBA frames in order.
    """
    return AsyncProcessPoolIterator(
        dataset=ImageDataset(path=data_path, num_skip_frames=num_skip_frames),
        num_workers=num_workers,
        preload=preload,
    )