Spaces:
No application file
No application file
# TODO: Adapted from cli | |
import math | |
from typing import Callable, List, Optional | |
import numpy as np | |
from mmcm.utils.itertools_util import generate_sample_idxs | |
# copy from https://github.com/MooreThreads/Moore-AnimateAnyone/blob/master/src/pipelines/context.py | |
def ordered_halving(val): | |
bin_str = f"{val:064b}" | |
bin_flip = bin_str[::-1] | |
as_int = int(bin_flip, 2) | |
return as_int / (1 << 64) | |
# TODO: closed_loop not work, to fix it | |
def uniform( | |
step: int = ..., | |
num_steps: Optional[int] = None, | |
num_frames: int = ..., | |
context_size: Optional[int] = None, | |
context_stride: int = 3, | |
context_overlap: int = 4, | |
closed_loop: bool = True, | |
): | |
if num_frames <= context_size: | |
yield list(range(num_frames)) | |
return | |
context_stride = min( | |
context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1 | |
) | |
for context_step in 1 << np.arange(context_stride): | |
pad = int(round(num_frames * ordered_halving(step))) | |
for j in range( | |
int(ordered_halving(step) * context_step) + pad, | |
num_frames + pad + (0 if closed_loop else -context_overlap), | |
(context_size * context_step - context_overlap), | |
): | |
yield [ | |
e % num_frames | |
for e in range(j, j + context_size * context_step, context_step) | |
] | |
def uniform_v2( | |
step: int = ..., | |
num_steps: Optional[int] = None, | |
num_frames: int = ..., | |
context_size: Optional[int] = None, | |
context_stride: int = 3, | |
context_overlap: int = 4, | |
closed_loop: bool = True, | |
): | |
return generate_sample_idxs( | |
total=num_frames, | |
window_size=context_size, | |
step=context_size - context_overlap, | |
sample_rate=1, | |
drop_last=False, | |
) | |
def get_context_scheduler(name: str) -> Callable: | |
if name == "uniform": | |
return uniform | |
elif name == "uniform_v2": | |
return uniform_v2 | |
else: | |
raise ValueError(f"Unknown context_overlap policy {name}") | |
def get_total_steps( | |
scheduler, | |
timesteps: List[int], | |
num_steps: Optional[int] = None, | |
num_frames: int = ..., | |
context_size: Optional[int] = None, | |
context_stride: int = 3, | |
context_overlap: int = 4, | |
closed_loop: bool = True, | |
): | |
return sum( | |
len( | |
list( | |
scheduler( | |
i, | |
num_steps, | |
num_frames, | |
context_size, | |
context_stride, | |
context_overlap, | |
) | |
) | |
) | |
for i in range(len(timesteps)) | |
) | |
def drop_last_repeat_context(contexts: List[List[int]]) -> List[List[int]]: | |
"""if len(contexts)>=2 and the max value the oenultimate list same as of the last list | |
Args: | |
List (_type_): _description_ | |
Returns: | |
List[List[int]]: _description_ | |
""" | |
if len(contexts) >= 2 and contexts[-1][-1] == contexts[-2][-1]: | |
return contexts[:-1] | |
else: | |
return contexts | |
def prepare_global_context( | |
context_schedule: str, | |
num_inference_steps: int, | |
time_size: int, | |
context_frames: int, | |
context_stride: int, | |
context_overlap: int, | |
context_batch_size: int, | |
): | |
context_scheduler = get_context_scheduler(context_schedule) | |
context_queue = list( | |
context_scheduler( | |
step=0, | |
num_steps=num_inference_steps, | |
num_frames=time_size, | |
context_size=context_frames, | |
context_stride=context_stride, | |
context_overlap=context_overlap, | |
) | |
) | |
# 如果context_queue的最后一个索引最大值和倒数第二个索引最大值相同,说明最后一个列表就是因为step带来的冗余项,可以去掉 | |
# remove the last context if max index of the last context is the same as the max index of the second last context | |
context_queue = drop_last_repeat_context(context_queue) | |
num_context_batches = math.ceil(len(context_queue) / context_batch_size) | |
global_context = [] | |
for i_tmp in range(num_context_batches): | |
global_context.append( | |
context_queue[i_tmp * context_batch_size : (i_tmp + 1) * context_batch_size] | |
) | |
return global_context | |