# Source: ProDiff/utils/multiprocess_utils.py (Rongjiehuang, commit 64e7f2f "init", 4.94 kB)
import os
import traceback
from functools import partial
from tqdm import tqdm
def chunked_worker(worker_id, args_queue=None, results_queue=None, init_ctx_func=None):
    """Worker loop: take jobs from ``args_queue`` until the ``'<KILL>'`` sentinel arrives.

    Each job is a ``(job_idx, map_func, arg)`` tuple. ``arg`` is expanded as
    keyword arguments when it is a dict, as positional arguments when it is a
    list or tuple, and passed as a single argument otherwise.

    If ``init_ctx_func`` is given it is called once with ``worker_id``; a
    non-None result is handed to every job as the ``ctx`` keyword argument.

    Every job produces exactly one ``(job_idx, result)`` entry on
    ``results_queue``; a job that raises has its traceback printed and reports
    ``None`` as its result.
    """
    ctx = None if init_ctx_func is None else init_ctx_func(worker_id)
    # iter(callable, sentinel) keeps calling args_queue.get() until it returns
    # a value equal to '<KILL>'.
    for job in iter(args_queue.get, '<KILL>'):
        job_idx, map_func, arg = job
        try:
            call = map_func if ctx is None else partial(map_func, ctx=ctx)
            if isinstance(arg, dict):
                outcome = call(**arg)
            elif isinstance(arg, (list, tuple)):
                outcome = call(*arg)
            else:
                outcome = call(arg)
            results_queue.put((job_idx, outcome))
        except BaseException:
            # Deliberately catch everything: a result must always be reported
            # for this job_idx, or a consumer blocking on results_queue.get()
            # would wait for it forever.
            traceback.print_exc()
            results_queue.put((job_idx, None))
class MultiprocessManager:
    """Pool of long-lived workers (processes or threads) fed from a shared job queue.

    Submit jobs with :meth:`add_job`, then iterate :meth:`get_results`. Every
    worker runs ``chunked_worker``; results come back as ``(job_id, result)``
    pairs in completion order, which is not necessarily submission order.
    """

    def __init__(self, num_workers=None, init_ctx_func=None, multithread=False):
        """Create the job/result queues and start ``num_workers`` daemon workers.

        :param num_workers: worker count; defaults to ``$N_PROC``, else the CPU count.
        :param init_ctx_func: optional ``f(worker_id) -> ctx`` forwarded to every worker.
        :param multithread: use threads (``multiprocessing.dummy``) instead of processes.
        """
        if multithread:
            from multiprocessing.dummy import Queue, Process
        else:
            from multiprocessing import Queue, Process
        if num_workers is None:
            # os.cpu_count() may return None; fall back to a single worker.
            num_workers = int(os.getenv('N_PROC', os.cpu_count() or 1))
        self.num_workers = num_workers
        # maxsize <= 0 means "no explicit bound" for both queue implementations.
        self.results_queue = Queue(maxsize=-1)
        self.args_queue = Queue(maxsize=-1)
        self.workers = []
        self.total_jobs = 0
        for i in range(num_workers):
            p = Process(target=chunked_worker,
                        args=(i, self.args_queue, self.results_queue, init_ctx_func))
            # Set the flag as an attribute rather than a constructor kwarg:
            # multiprocessing.dummy.Process (DummyProcess) does not accept
            # ``daemon=`` and would raise TypeError in multithread mode.
            p.daemon = True
            self.workers.append(p)
            p.start()

    def add_job(self, func, args):
        """Queue ``func`` to be called with ``args``; the job id is its submission index."""
        self.args_queue.put((self.total_jobs, func, args))
        self.total_jobs += 1

    def get_results(self):
        """Yield ``(job_id, result)`` for every submitted job, then join the workers.

        One ``'<KILL>'`` sentinel per worker is queued when iteration starts, so
        all jobs must be added before iterating: a job queued behind the
        sentinels would never be picked up.
        """
        for _ in range(self.num_workers):
            self.args_queue.put("<KILL>")
        self.n_finished = 0
        while self.n_finished < self.total_jobs:
            job_id, res = self.results_queue.get()
            yield job_id, res
            self.n_finished += 1
        for w in self.workers:
            w.join()

    def __len__(self):
        """Return the number of jobs submitted so far."""
        return self.total_jobs
def multiprocess_run_tqdm(map_func, args, num_workers=None, ordered=True, init_ctx_func=None,
                          multithread=False, desc=None):
    """Same as ``multiprocess_run`` but with a ``tqdm`` progress bar.

    Yields ``(i, item)`` pairs, where ``i`` counts yielded items from 0 and
    ``item`` is whatever ``multiprocess_run`` yields for the given ``ordered``
    mode. The bar's total is ``len(args)``, so ``args`` must support ``len``.
    ``desc`` is shown as the bar's label.
    """
    stream = multiprocess_run(map_func, args, num_workers, ordered, init_ctx_func, multithread)
    progress = tqdm(enumerate(stream), total=len(args), desc=desc)
    for position, item in progress:
        yield position, item
def multiprocess_run(map_func, args, num_workers=None, ordered=True, init_ctx_func=None, multithread=False):
    """
    Run ``map_func`` over every item of ``args`` in a pool of workers.

    Each item is dispatched by ``chunked_worker``: dicts as keyword arguments,
    lists/tuples as positional arguments, anything else as a single argument.
    A job that raises has its traceback printed and its result replaced by None.

    Example::

        for res in tqdm(multiprocess_run(job_func, args)):
            print(res)

    :param map_func: callable applied to each job.
    :param args: iterable of job arguments.
    :param num_workers: worker count; defaults to ``$N_PROC``, else the CPU count.
    :param ordered: if True, yield bare results in submission order; otherwise
        yield ``(job_idx, result)`` pairs in completion order.
    :param init_ctx_func: optional ``f(worker_id) -> ctx``; a non-None ctx is
        passed to ``map_func`` as the ``ctx`` keyword argument.
    :param multithread: use threads instead of processes.
    :return: generator of results (ordered) or ``(job_idx, result)`` pairs (unordered).
    """
    if num_workers is None:
        # os.cpu_count() may return None; fall back to a single worker.
        num_workers = int(os.getenv('N_PROC', os.cpu_count() or 1))
    manager = MultiprocessManager(num_workers, init_ctx_func, multithread)
    for arg in args:
        manager.add_job(map_func, arg)
    if ordered:
        # Count the submitted jobs instead of len(args) so any iterable works.
        n_jobs = len(manager)
        # A private sentinel cannot collide with a real result, unlike a
        # marker string that a job might legitimately return.
        pending = object()
        results = [pending] * n_jobs
        i_now = 0
        for job_i, res in manager.get_results():
            results[job_i] = res
            # Flush the longest ready prefix to keep submission order.
            while i_now < n_jobs and results[i_now] is not pending:
                yield results[i_now]
                results[i_now] = None  # drop the reference once it has been yielded
                i_now += 1
    else:
        for res in manager.get_results():
            yield res
def _chunked_list_worker(worker_id, map_func, args, results_queue, init_ctx_func=None):
    """Run ``map_func`` over a fixed list of ``(job_idx, arg)`` pairs (used by chunked_multiprocess_run).

    Arguments are dispatched like ``chunked_worker``: dicts as keyword arguments,
    lists/tuples as positional arguments, anything else as a single argument.
    Each job puts ``(job_idx, result)`` on ``results_queue``; a job that raises
    has its traceback printed and reports ``None``.
    """
    ctx = init_ctx_func(worker_id) if init_ctx_func is not None else None
    func = partial(map_func, ctx=ctx) if ctx is not None else map_func
    for job_idx, arg in args:
        try:
            if isinstance(arg, dict):
                res = func(**arg)
            elif isinstance(arg, (list, tuple)):
                res = func(*arg)
            else:
                res = func(arg)
            results_queue.put((job_idx, res))
        except BaseException:
            # Always report a result: the consumer expects exactly one entry per
            # job and would otherwise block forever on results_queue.get().
            traceback.print_exc()
            results_queue.put((job_idx, None))


def chunked_multiprocess_run(
        map_func, args, num_workers=None, ordered=True,
        init_ctx_func=None, q_max_size=1000, multithread=False):
    """Run ``map_func`` over ``args`` with each worker owning a fixed slice of the jobs.

    Worker ``i`` receives jobs ``i, i + num_workers, i + 2 * num_workers, ...``
    up front. Bare results are yielded; with ``ordered=True`` they come in
    submission order, otherwise in completion order. A job that raises yields None.

    :param map_func: callable applied to each job (dicts expand to keyword
        arguments, lists/tuples to positional arguments).
    :param args: iterable of job arguments.
    :param num_workers: worker count; defaults to ``$N_PROC``, else the CPU count.
    :param ordered: yield results in submission order.
    :param init_ctx_func: optional ``f(worker_id) -> ctx``; a non-None ctx is
        passed to ``map_func`` as the ``ctx`` keyword argument.
    :param q_max_size: total capacity of the result queue(s); divided evenly
        across per-worker queues when ``ordered``.
    :param multithread: use threads (``multiprocessing.dummy``) instead of processes.
    """
    if multithread:
        from multiprocessing.dummy import Queue, Process
    else:
        from multiprocessing import Queue, Process
    args = list(enumerate(args))
    n_jobs = len(args)
    if num_workers is None:
        # os.cpu_count() may return None; fall back to a single worker.
        num_workers = int(os.getenv('N_PROC', os.cpu_count() or 1))
    if ordered:
        # One queue per worker: worker i processes its slice in order, so
        # reading the queues round-robin reproduces submission order.
        results_queues = [Queue(maxsize=q_max_size // num_workers) for _ in range(num_workers)]
    else:
        results_queues = [Queue(maxsize=q_max_size)] * num_workers
    workers = []
    for i in range(num_workers):
        p = Process(target=_chunked_list_worker, args=(
            i, map_func, args[i::num_workers], results_queues[i], init_ctx_func))
        # multiprocessing.dummy.Process does not accept ``daemon=`` as a kwarg.
        p.daemon = True
        workers.append(p)
        p.start()
    for n_finished in range(n_jobs):
        job_idx, res = results_queues[n_finished % num_workers].get()
        assert job_idx == n_finished or not ordered, (job_idx, n_finished)
        yield res
    for w in workers:
        w.join()