Spaces:
No application file
No application file
File size: 4,285 Bytes
6755a2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
# TODO: Adapted from cli
import math
from typing import Callable, List, Optional
import numpy as np
from mmcm.utils.itertools_util import generate_sample_idxs
# Copied from https://github.com/MooreThreads/Moore-AnimateAnyone/blob/master/src/pipelines/context.py
def ordered_halving(val):
    """Return the base-2 radical inverse of ``val`` as a float.

    The 64-digit, zero-padded binary form of ``val`` is reversed and read back
    as an integer, then divided by ``2**64``. Consequently 0, 1, 2, 3, ... map
    to 0.0, 0.5, 0.25, 0.75, ... (the van der Corput sequence), which spreads
    consecutive integers evenly over [0, 1).

    Intended for ``0 <= val < 2**64``; for larger values the result is not
    guaranteed to lie in [0, 1), and negative values raise ``ValueError``
    (the minus sign ends up inside the reversed digit string).
    """
    reversed_bits = format(val, "064b")[::-1]
    return int(reversed_bits, 2) / 2**64
# TODO: closed_loop does not work correctly yet; fix it.
def uniform(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Yield overlapping, possibly strided, windows of frame indices.

    For each stride level ``context_step`` in ``1, 2, 4, ...`` the clip is
    tiled with windows of ``context_size`` indices spaced ``context_step``
    apart. Consecutive window starts are
    ``context_size * context_step - context_overlap`` frames apart. All starts
    are shifted by an offset derived from ``ordered_halving(step)``, so
    different denoising steps use different tilings, and every index is
    wrapped modulo ``num_frames``.

    Args:
        step: Denoising step index; only used to derive the start offset.
        num_steps: Unused; accepted so all schedulers share one signature.
        num_frames: Total number of frames in the clip.
        context_size: Number of indices per window. ``None`` means a single
            window covering the whole clip.
        context_stride: Maximum number of stride levels (1, 2, 4, ...). It is
            further capped so the largest stride is the first one whose
            window spans the whole clip.
        context_overlap: Frames shared by consecutive windows at one level.
        closed_loop: If False, the range of window start positions is
            shortened by ``context_overlap``. NOTE(review): windows can still
            contain wrapped-around indices in that case; the upstream TODO
            says this mode does not work correctly yet.

    Yields:
        Lists of frame indices in ``[0, num_frames)``.
    """
    if context_size is None:
        # No window size given: treat the whole clip as a single window
        # (previously this crashed comparing an int with None).
        context_size = num_frames
    if num_frames <= context_size:
        yield list(range(num_frames))
        return

    # Strides beyond the first one whose window (context_size * stride)
    # covers the whole clip would be redundant, so cap the number of levels.
    context_stride = min(
        context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
    )
    # Both values depend only on `step`, so compute them once instead of per
    # stride level.
    phase = ordered_halving(step)
    pad = int(round(num_frames * phase))
    for context_step in 1 << np.arange(context_stride):
        for j in range(
            int(phase * context_step) + pad,
            num_frames + pad + (0 if closed_loop else -context_overlap),
            (context_size * context_step - context_overlap),
        ):
            yield [
                e % num_frames
                for e in range(j, j + context_size * context_step, context_step)
            ]
def uniform_v2(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Split ``num_frames`` frame indices into fixed-size, overlapping windows.

    Delegates to ``generate_sample_idxs`` with ``window_size=context_size``,
    ``step=context_size - context_overlap``, ``sample_rate=1`` and
    ``drop_last=False``. ``step`` here is presumably the hop between window
    starts; verify against ``mmcm.utils.itertools_util``.

    ``step``, ``num_steps``, ``context_stride`` and ``closed_loop`` are
    accepted only so the signature matches ``uniform``; they are ignored.
    ``context_size`` must be an int (``None`` fails on the subtraction).

    Returns:
        Whatever ``generate_sample_idxs`` returns for these arguments.
    """
    hop = context_size - context_overlap
    return generate_sample_idxs(
        total=num_frames,
        window_size=context_size,
        step=hop,
        sample_rate=1,
        drop_last=False,
    )
def get_context_scheduler(name: str) -> Callable:
    """Look up a context scheduler function by name.

    Args:
        name: ``"uniform"`` or ``"uniform_v2"``.

    Returns:
        The matching scheduler function.

    Raises:
        ValueError: If ``name`` is not a known scheduler.
    """
    if name == "uniform":
        return uniform
    if name == "uniform_v2":
        return uniform_v2
    raise ValueError(f"Unknown context_overlap policy {name}")
def get_total_steps(
    scheduler,
    timesteps: List[int],
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Count the context windows ``scheduler`` produces over all steps.

    ``scheduler`` is called once per denoising step index
    ``0 .. len(timesteps) - 1``. The arguments are passed positionally as
    ``(step, num_steps, num_frames, context_size, context_stride,
    context_overlap, closed_loop)``, matching ``uniform`` / ``uniform_v2``.
    Only the length of ``timesteps`` is used, not its values.

    Returns:
        The total number of windows yielded across all steps.
    """
    total = 0
    for i in range(len(timesteps)):
        contexts = scheduler(
            i,
            num_steps,
            num_frames,
            context_size,
            context_stride,
            context_overlap,
            # Previously dropped, so the scheduler always saw its own default.
            closed_loop,
        )
        # Count without materializing a throwaway list.
        total += sum(1 for _ in contexts)
    return total
def drop_last_repeat_context(contexts: List[List[int]]) -> List[List[int]]:
    """Drop the final context if it ends on the same index as the one before it.

    Only the last element of the last two contexts is compared. For contexts
    whose indices ascend, this is the same as comparing their maximum index.

    Args:
        contexts: Sequence of context windows (lists of frame indices).

    Returns:
        ``contexts[:-1]`` when there are at least two contexts and the last two
        end on the same index; otherwise ``contexts`` itself, unchanged.
    """
    redundant_tail = len(contexts) > 1 and contexts[-1][-1] == contexts[-2][-1]
    return contexts[:-1] if redundant_tail else contexts
def prepare_global_context(
    context_schedule: str,
    num_inference_steps: int,
    time_size: int,
    context_frames: int,
    context_stride: int,
    context_overlap: int,
    context_batch_size: int,
):
    """Build the step-0 context windows and group them into batches.

    Args:
        context_schedule: Scheduler name passed to ``get_context_scheduler``.
        num_inference_steps: Forwarded to the scheduler as ``num_steps``.
        time_size: Total number of frames (the scheduler's ``num_frames``).
        context_frames: Frames per window (the scheduler's ``context_size``).
        context_stride: Forwarded to the scheduler unchanged.
        context_overlap: Forwarded to the scheduler unchanged.
        context_batch_size: Maximum number of windows per batch.

    Returns:
        A list of batches, each a list of up to ``context_batch_size``
        consecutive windows, in scheduler order.
    """
    scheduler = get_context_scheduler(context_schedule)
    windows = list(
        scheduler(
            step=0,
            num_steps=num_inference_steps,
            num_frames=time_size,
            context_size=context_frames,
            context_stride=context_stride,
            context_overlap=context_overlap,
        )
    )
    # If the last window ends on the same index as the second-to-last one, it
    # is a redundant leftover of the window step, so drop it.
    windows = drop_last_repeat_context(windows)
    n_batches = math.ceil(len(windows) / context_batch_size)
    return [
        windows[b * context_batch_size : (b + 1) * context_batch_size]
        for b in range(n_batches)
    ]
|