# NOTE(review): the following three header lines appear to be extraction
# residue, not Python code; kept as comments so the module stays importable.
# Spaces: / No application file / No application file
from typing import List, Dict, Literal, Union, Tuple | |
import os | |
import string | |
import logging | |
import torch | |
import numpy as np | |
from einops import rearrange, repeat | |
logger = logging.getLogger(__name__) | |
def generate_tasks_of_dir(
    path: str,
    output_dir: str,
    exts: Tuple[str],
    same_dir_name: bool = False,
    **kwargs,
) -> List[Dict]:
    """Convert a directory of videos into a list of task dicts.

    Args:
        path (str): root directory, walked recursively.
        output_dir (str): directory where the .h5py outputs will be stored.
        exts (Tuple[str]): accepted lower-case filename suffixes, e.g. (".mp4",).
        same_dir_name (bool, optional): whether to keep the same parent dir name
            as the source video in the output path. Defaults to False.
        **kwargs: extra key/value pairs merged into every task dict.

    Returns:
        List[Dict]: one task per matched file with keys video_path, output_path,
            output_dir, filename, ext (plus **kwargs).
    """
    tasks = []
    for rootdir, dirs, files in os.walk(path):
        for basename in files:
            if not basename.lower().endswith(exts):
                continue
            video_path = os.path.join(rootdir, basename)
            # os.path.splitext is robust to filenames containing extra dots,
            # unlike the original basename.split(".") which raised ValueError;
            # keep ext without the leading dot to match the old semantics.
            filename, ext = os.path.splitext(basename)
            ext = ext.lstrip(".")
            rootdir_name = os.path.basename(rootdir)
            if same_dir_name:
                save_dir = os.path.join(output_dir, rootdir_name)
            else:
                save_dir = output_dir
            # output is named after the source file (the original literal
            # "(unknown)" looks like a redacted {filename} placeholder)
            save_path = os.path.join(save_dir, f"{filename}.h5py")
            task = {
                "video_path": video_path,
                "output_path": save_path,
                "output_dir": save_dir,
                "filename": filename,
                "ext": ext,
            }
            task.update(kwargs)
            tasks.append(task)
    return tasks
def sample_by_idx(
    T: int,
    n_sample: int,
    sample_rate: int,
    sample_start_idx: int = None,
    change_sample_rate: bool = False,
    seed: int = None,
    whether_random: bool = True,
    n_independent: int = 0,
) -> Tuple[List[int], int, Union[np.ndarray, None]]:
    """Sample `n_sample` indices from range(T) with stride `sample_rate`.

    Args:
        T (int): size of the candidate index range (indices 0..T-1).
        n_sample (int): number of indices to sample.
        sample_rate (int): sample interval, pick one per sample_rate indices.
        sample_start_idx (int, optional): start position of the strided window.
            If None, chosen randomly (when whether_random) or set to 0.
        change_sample_rate (bool, optional): whether sample_rate may be lowered
            automatically when T / sample_rate < n_sample. Defaults to False.
        seed (int, optional): numpy seed used for the random start position.
        whether_random (bool, optional): randomly choose the start position when
            sample_start_idx is None. Defaults to True.
        n_independent (int, optional): number of extra indices sampled outside
            the strided window. Defaults to 0.

    Raises:
        ValueError: when T is too small for the requested sampling.

    Returns:
        Tuple[List[int], int, Union[np.ndarray, None]]: sampled indices, the
            (possibly lowered) sample_rate, and the independent indices
            (None when n_independent == 0).
    """
    if T < n_sample:
        raise ValueError(f"T({T}) < n_sample({n_sample})")
    if T / sample_rate < n_sample:
        if not change_sample_rate:
            raise ValueError(
                f"T({T}) / sample_rate({sample_rate}) < n_sample({n_sample})"
            )
        while T / sample_rate < n_sample:
            sample_rate -= 1
            logger.error(
                f"sample_rate{sample_rate+1} is too large, decrease to {sample_rate}"
            )
            if sample_rate == 0:
                raise ValueError("T / sample_rate < n_sample")
    if sample_start_idx is None:
        if whether_random:
            # +1 so the last valid start (window ending exactly at T) is a
            # candidate and the array is never empty when
            # T == n_sample * sample_rate (the original crashed there).
            sample_start_idx_candidates = np.arange(T - n_sample * sample_rate + 1)
            if seed is not None:
                np.random.seed(seed)
            sample_start_idx = np.random.choice(sample_start_idx_candidates, 1)[0]
        else:
            sample_start_idx = 0
    sample_end_idx = sample_start_idx + sample_rate * n_sample
    sample = list(range(sample_start_idx, sample_end_idx, sample_rate))
    if n_independent == 0:
        n_independent_sample = None
    else:
        # prefer the leftover space on both sides of the strided window
        left_candidate = np.array(
            list(range(0, sample_start_idx)) + list(range(sample_end_idx, T))
        )
        if len(left_candidate) < n_independent:
            # no room on the ends: fall back to any index not in sample
            # (original `set(range(T) - set(sample))` raised TypeError)
            left_candidate = np.array(sorted(set(range(T)) - set(sample)))
        n_independent_sample = np.random.choice(left_candidate, n_independent)
    return sample, sample_rate, n_independent_sample
def sample_tensor_by_idx(
    tensor: Union[torch.Tensor, np.ndarray],
    n_sample: int,
    sample_rate: int,
    sample_start_idx: int = 0,
    change_sample_rate: bool = False,
    seed: int = None,
    dim: int = 0,
    return_type: Literal["numpy", "torch"] = "torch",
    whether_random: bool = True,
    n_independent: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]:
    """Sample a sub-tensor along `dim` using sample_by_idx.

    Args:
        tensor (Union[torch.Tensor, np.ndarray]): input; ndarray is converted.
        n_sample (int): number of entries to sample along `dim`.
        sample_rate (int): sample interval.
        sample_start_idx (int, optional): start position. Defaults to 0.
        change_sample_rate (bool, optional): allow lowering sample_rate.
        seed (int, optional): seed for random start. Defaults to None.
        dim (int, optional): dimension to sample along. Defaults to 0.
        return_type (Literal["numpy", "torch"], optional): Defaults to "torch".
        whether_random (bool, optional): Defaults to True.
        n_independent (int, optional): extra samples independent of the strided
            window. Defaults to 0.

    Returns:
        Tuple: (sample, sample_idx, sample_rate, independent_sample,
        independent_sample_idx); the last two are None when n_independent == 0.
    """
    if isinstance(tensor, np.ndarray):
        tensor = torch.from_numpy(tensor)
    n_candidates = tensor.shape[dim]
    picked, sample_rate, extra = sample_by_idx(
        n_candidates,
        n_sample,
        sample_rate,
        sample_start_idx,
        change_sample_rate,
        seed,
        whether_random=whether_random,
        n_independent=n_independent,
    )
    sample_idx = torch.LongTensor(picked)
    sample = torch.index_select(tensor, dim, sample_idx)
    if extra is None:
        independent_sample = None
        independent_sample_idx = None
    else:
        independent_sample_idx = torch.LongTensor(extra)
        independent_sample = torch.index_select(tensor, dim, independent_sample_idx)
    # NOTE: only `sample` is converted for return_type="numpy"; indices and the
    # independent sample intentionally stay torch tensors (original contract).
    if return_type == "numpy":
        sample = sample.cpu().numpy()
    return sample, sample_idx, sample_rate, independent_sample, independent_sample_idx
def concat_two_tensor(
    data1: torch.Tensor,
    data2: torch.Tensor,
    dim: int,
    method: Literal[
        "first_in_first_out", "first_in_last_out", "intertwine", "index"
    ] = "first_in_first_out",
    data1_index: torch.long = None,
    data2_index: torch.long = None,
    return_index: bool = False,
):
    """Concatenate two tensors along `dim` with the given ordering method.

    Args:
        data1 (torch.Tensor): the "first in" tensor.
        data2 (torch.Tensor): the "last in" tensor.
        dim (int): concatenation dimension.
        method: ordering strategy. Defaults to "first_in_first_out".
        data1_index / data2_index: target positions, used only by method="index".
        return_index (bool, optional): also return the positions of data1/data2
            in the result. Defaults to False.

    Raises:
        NotImplementedError: method == "intertwine".
        ValueError: unknown method.

    Returns:
        torch.Tensor, or (tensor, data1_index, data2_index) when return_index.
    """
    n1 = data1.shape[dim]
    n2 = data2.shape[dim]
    if method == "first_in_first_out":
        res = torch.concat([data1, data2], dim=dim)
        data1_index = range(n1)
        data2_index = [n1 + i for i in range(n2)]
    elif method == "first_in_last_out":
        res = torch.concat([data2, data1], dim=dim)
        data2_index = range(n2)
        data1_index = [n2 + i for i in range(n1)]
    elif method == "intertwine":
        raise NotImplementedError("intertwine")
    elif method == "index":
        res = concat_two_tensor_with_index(
            data1=data1,
            data1_index=data1_index,
            data2=data2,
            data2_index=data2_index,
            dim=dim,
        )
    else:
        raise ValueError(
            "only support first_in_first_out, first_in_last_out, intertwine, index"
        )
    if return_index:
        return res, data1_index, data2_index
    return res
def concat_two_tensor_with_index(
    data1: torch.Tensor,
    data1_index: torch.LongTensor,
    data2: torch.Tensor,
    data2_index: torch.LongTensor,
    dim: int,
) -> torch.Tensor:
    """Merge two tensors along `dim`, scattering each to explicit positions.

    Args:
        data1 (torch.Tensor): b1*c1*h1*w1*...
        data1_index (torch.LongTensor): N positions; if dim=1, N=c1.
        data2 (torch.Tensor): b2*c2*h2*w2*...
        data2_index (torch.LongTensor): M positions; if dim=1, M=c2.
        dim (int): merge dimension.

    Returns:
        torch.Tensor: merged tensor whose size along dim is
            data1.shape[dim] + data2.shape[dim].
    """
    merged_shape = list(data1.shape)
    merged_shape[dim] = data1.shape[dim] + data2.shape[dim]
    merged = torch.zeros(merged_shape, device=data1.device, dtype=data1.dtype)
    # place each source block at its designated indices
    for source, index in ((data1, data1_index), (data2, data2_index)):
        merged = batch_index_copy(merged, dim=dim, index=index, source=source)
    return merged
def repeat_index_to_target_size(
    index: torch.LongTensor, target_size: int
) -> torch.LongTensor:
    """Broadcast/repeat an index tensor along the batch dim to target_size rows.

    A 1-D index of shape (n,) is expanded to (target_size, n); a 2-D index of
    shape (b, n) is repeated block-wise to (target_size, n), which requires
    target_size to be a multiple of b.

    Args:
        index (torch.LongTensor): shape (n,) or (b, n).
        target_size (int): desired batch size.

    Raises:
        AssertionError: when target_size is not a multiple of index.shape[0].

    Returns:
        torch.LongTensor: shape (target_size, n).
    """
    if len(index.shape) == 1:
        # native equivalent of einops repeat("n -> b n", b=target_size)
        index = index.unsqueeze(0).repeat(target_size, 1)
    if len(index.shape) == 2:
        remainder = target_size % index.shape[0]
        assert (
            remainder == 0
        ), f"target_size % index.shape[0] must be zero, but give {remainder}"
        # native equivalent of einops repeat("b n -> (b c) n"): each row is
        # repeated c consecutive times
        index = index.repeat_interleave(target_size // index.shape[0], dim=0)
    return index
def batch_concat_two_tensor_with_index(
    data1: torch.Tensor,
    data1_index: torch.LongTensor,
    data2: torch.Tensor,
    data2_index: torch.LongTensor,
    dim: int,
) -> torch.Tensor:
    """Batch alias of concat_two_tensor_with_index (same behavior)."""
    return concat_two_tensor_with_index(
        data1=data1,
        data1_index=data1_index,
        data2=data2,
        data2_index=data2_index,
        dim=dim,
    )
def interwine_two_tensor(
    data1: torch.Tensor,
    data2: torch.Tensor,
    dim: int,
    return_index: bool = False,
) -> torch.Tensor:
    """Interleave data1 and data2 along `dim`, data1 at even slots 0,2,4,...

    data2 fills the remaining slots in ascending order. Optionally also returns
    the (CPU LongTensor) positions each input ended up at.
    """
    n1 = data1.shape[dim]
    n_total = n1 + data2.shape[dim]
    merged_shape = list(data1.shape)
    merged_shape[dim] = n_total
    merged = torch.zeros(merged_shape, device=data1.device, dtype=data1.dtype)
    even_slots = range(0, 2 * n1, 2)
    other_slots = sorted(set(range(n_total)) - set(even_slots))
    data1_index = torch.LongTensor(list(even_slots))
    data2_index = torch.LongTensor(other_slots)
    # index_copy_ along `dim` replaces the original swapaxes + fancy-indexing
    merged.index_copy_(dim, data1_index, data1)
    merged.index_copy_(dim, data2_index, data2)
    if return_index:
        return merged, data1_index, data2_index
    return merged
def split_index(
    indexs: torch.Tensor,
    n_first: int = None,
    n_last: int = None,
    method: Literal[
        "first_in_first_out", "first_in_last_out", "intertwine", "index", "random"
    ] = "first_in_first_out",
):
    """Split a 1-D index tensor into a `first` part and a `last` part.

    Args:
        indexs (torch.Tensor): 1-D tensor of indices to split.
        n_first (int, optional): size of the first part. Defaults to None.
        n_last (int, optional): size of the last part. Defaults to None.
            At least one of n_first/n_last must be given; the other is derived.
        method: split strategy. Defaults to "first_in_first_out".

    Raises:
        ValueError: when neither n_first nor n_last is given, or the method is
            not supported ("index" is accepted by the type hint but has no
            implementation here).
        NotImplementedError: method == "intertwine".

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: (first_index, last_index).
    """
    if n_first is None and n_last is None:
        # the original only had this as a commented-out (and inverted) assert
        raise ValueError("must assign one value for n_first or n_last")
    n_total = len(indexs)
    if n_first is None:
        n_first = n_total - n_last
    if n_last is None:
        n_last = n_total - n_first
    assert len(indexs) == n_first + n_last
    if method == "first_in_first_out":
        first_index = indexs[:n_first]
        last_index = indexs[n_first:]
    elif method == "first_in_last_out":
        first_index = indexs[n_last:]
        last_index = indexs[:n_last]
    elif method == "intertwine":
        raise NotImplementedError
    elif method == "random":
        idx_ = torch.randperm(len(indexs))
        first_index = indexs[idx_[:n_first]]
        last_index = indexs[idx_[n_first:]]
    else:
        # the original fell through here and raised UnboundLocalError on return
        raise ValueError(
            "only support first_in_first_out, first_in_last_out, intertwine, random"
        )
    return first_index, last_index
def split_tensor(
    tensor: torch.Tensor,
    dim: int,
    n_first=None,
    n_last=None,
    method: Literal[
        "first_in_first_out", "first_in_last_out", "intertwine", "index", "random"
    ] = "first_in_first_out",
    need_return_index: bool = False,
):
    """Split a tensor into two parts along `dim` using split_index.

    Args:
        tensor (torch.Tensor): input tensor.
        dim (int): dimension to split along.
        n_first (int, optional): size of the first part; derived from n_last
            when None.
        n_last (int, optional): size of the last part; derived from n_first
            when None.
        method: split strategy, forwarded to split_index.
        need_return_index (bool, optional): also return the index tensors used.

    Returns:
        (first_tensor, last_tensor) or
        (first_tensor, last_tensor, first_index, last_index).
    """
    n_total = tensor.shape[dim]
    if n_first is None:
        n_first = n_total - n_last
    if n_last is None:
        n_last = n_total - n_first
    all_index = torch.arange(n_total, dtype=torch.long, device=tensor.device)
    first_index, last_index = split_index(
        indexs=all_index,
        n_first=n_first,
        method=method,
    )
    first_tensor = torch.index_select(tensor, dim=dim, index=first_index)
    last_tensor = torch.index_select(tensor, dim=dim, index=last_index)
    if need_return_index:
        return first_tensor, last_tensor, first_index, last_index
    return first_tensor, last_tensor
# TODO: optimization of batch_index_select still to be determined
def batch_index_select(
    tensor: torch.Tensor, index: torch.LongTensor, dim: int
) -> torch.Tensor:
    """index_select with an optional per-batch index.

    Args:
        tensor (torch.Tensor): D1*D2*D3*D4...
        index (torch.LongTensor): shape N (shared across the batch) or D1*N,
            with N <= tensor.shape[dim].
        dim (int): dimension of `tensor` to select along.

    Returns:
        torch.Tensor: D1*...*N*...
    """
    # TODO: currently only supports the same N for every batch element
    if len(index.shape) == 1:
        # one shared index: a single index_select covers the whole batch
        return torch.index_select(tensor, dim=dim, index=index)
    index = repeat_index_to_target_size(index, tensor.shape[0])
    # per-batch select; dim-1 because we operate inside tensor[b]
    selected = [
        torch.index_select(tensor[b], dim=dim - 1, index=index[b])
        for b in range(tensor.shape[0])
    ]
    return torch.stack(selected).to(dtype=tensor.dtype)
def batch_index_copy(
    tensor: torch.Tensor, dim: int, index: torch.LongTensor, source: torch.Tensor
) -> torch.Tensor:
    """index_copy_ with an optional per-batch index; modifies `tensor` in place.

    Args:
        tensor (torch.Tensor): b*c*h*... target tensor.
        dim (int): dimension of `tensor` to copy along.
        index (torch.LongTensor): shape d (shared) or b*d (per batch element).
        source (torch.Tensor):
            b*d*h*... if dim=1,
            b*c*d*... if dim=2.

    Returns:
        torch.Tensor: the same tensor, updated in place.
    """
    if len(index.shape) == 1:
        # shared index: one in-place copy for the whole batch
        tensor.index_copy_(dim=dim, index=index, source=source)
        return tensor
    index = repeat_index_to_target_size(index, tensor.shape[0])
    for b in range(tensor.shape[0]):
        # tensor[b] is a view, so the in-place copy already updates `tensor`
        row = tensor[b]
        row.index_copy_(dim=dim - 1, index=index[b], source=source[b])
        tensor[b] = row
    return tensor
def batch_index_fill(
    tensor: torch.Tensor,
    dim: int,
    index: torch.LongTensor,
    value: Union[torch.Tensor, float],
) -> torch.Tensor:
    """Per-batch index_fill_: fill tensor[b] at index[b] along dim-1 with value.

    Args:
        tensor (torch.Tensor): b*c*h*... target, modified in place.
        dim (int): dimension of `tensor` to fill along (dim-1 inside tensor[b]).
        index (torch.LongTensor): shape b*d (or d, broadcast over the batch).
        value (Union[torch.Tensor, float]): per-batch values of shape b, or one
            scalar used for every batch element.
            (The original annotation Literal[torch.Tensor, torch.float] was
            invalid typing — Literal takes literal values, not types.)

    Returns:
        torch.Tensor: the same tensor, updated in place.
    """
    index = repeat_index_to_target_size(index, tensor.shape[0])
    batch_size = tensor.shape[0]
    for b in torch.arange(batch_size):
        sub_index = index[b]
        sub_value = value[b] if isinstance(value, torch.Tensor) else value
        # tensor[b] is a view; index_fill_ updates `tensor` in place
        sub_tensor = tensor[b]
        sub_tensor.index_fill_(dim - 1, sub_index, sub_value)
        tensor[b] = sub_tensor
    return tensor
def adaptive_instance_normalization(
    src: torch.Tensor,
    dst: torch.Tensor,
    eps: float = 1e-6,
):
    """AdaIN: renormalize `src` to the per-channel statistics of `dst`.

    Args:
        src (torch.Tensor): b c t h w (ndim 4 and 3 layouts also supported;
            statistics are taken over all dims after the channel dim).
        dst (torch.Tensor): same layout as src; its batch dim is repeated to
            src's batch size via align_repeat_tensor_single_dim.
        eps (float, optional): numerical floor for the variance. Defaults to 1e-6.

    Raises:
        ValueError: when src.ndim is not 3, 4 or 5.

    Returns:
        torch.Tensor: src normalized to dst's mean/std.
    """
    ndim = src.ndim
    if ndim == 5:
        dim = (2, 3, 4)
    elif ndim == 4:
        dim = (2, 3)
    elif ndim == 3:
        dim = 2
    else:
        # original message was missing the f prefix, printing "{ndim}" literally
        raise ValueError(f"only support ndim in [3,4,5], but given {ndim}")
    var, mean = torch.var_mean(src, dim=dim, keepdim=True, correction=0)
    std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
    dst = align_repeat_tensor_single_dim(dst, src.shape[0], dim=0)
    # torch.var_mean returns (var, mean); the original unpacked into
    # (mean_acc, var_acc), swapping dst's statistics so the "std" was
    # actually derived from the mean.
    var_acc, mean_acc = torch.var_mean(dst, dim=dim, keepdim=True, correction=0)
    std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
    src = (((src - mean) / std) * std_acc) + mean_acc
    return src
def adaptive_instance_normalization_with_ref(
    src: torch.Tensor,
    dst: torch.Tensor,
    style_fidelity: float = 0.5,
    do_classifier_free_guidance: bool = True,
):
    """AdaIN of src against dst, blended so the unconditional half of a
    classifier-free-guidance batch keeps its original (un-normalized) values.

    NOTE(review): assumes src is stacked as [uncond batch; cond batch] along
    dim 0, i.e. src.shape[0] is even — confirm against callers.

    Args:
        src (torch.Tensor): features to normalize.
        dst (torch.Tensor): reference features providing the statistics.
        style_fidelity (float, optional): blend weight of the branch that keeps
            the unconditional half untouched. Defaults to 0.5.
        do_classifier_free_guidance (bool, optional): Defaults to True.

    Returns:
        torch.Tensor: blended result, same shape as src.
    """
    # logger.debug(
    #     f"src={src.shape}, min={src.min()}, max={src.max()}, mean={src.mean()}, \n"
    #     f"dst={src.shape}, min={dst.min()}, max={dst.max()}, mean={dst.mean()}"
    # )
    batch_size = src.shape[0] // 2
    # boolean mask selecting the first (unconditional) half of the batch
    uc_mask = torch.Tensor([1] * batch_size + [0] * batch_size).type_as(src).bool()
    src_uc = adaptive_instance_normalization(src, dst)
    src_c = src_uc.clone()
    # TODO: this branch assumes do_classifier_free_guidance and
    # style_fidelity > 0 are both True by default
    if do_classifier_free_guidance and style_fidelity > 0:
        src_c[uc_mask] = src[uc_mask]
        src = style_fidelity * src_c + (1.0 - style_fidelity) * src_uc
    return src
def batch_adain_conditioned_tensor(
    tensor: torch.Tensor,
    src_index: torch.LongTensor,
    dst_index: torch.LongTensor,
    keep_dim: bool = True,
    num_frames: int = None,
    dim: int = 2,
    style_fidelity: float = 0.5,
    do_classifier_free_guidance: bool = True,
    need_style_fidelity: bool = False,
):
    """AdaIN-normalize the src frames of `tensor` against its dst frames.

    Args:
        tensor (torch.Tensor): b c t h w, or (b t) c h w when num_frames given.
        src_index (torch.LongTensor): frame indices to be normalized.
        dst_index (torch.LongTensor): frame indices providing the statistics.
        keep_dim (bool, optional): if True, reassemble the full tensor with the
            normalized src frames put back in place. Defaults to True.
        num_frames (int, optional): t used to fold a (b t) c h w input back to
            b c t h w. Defaults to None.
        dim (int, optional): frame dimension. Defaults to 2.
        style_fidelity (float, optional): blend factor for the CFG-aware
            variant. Defaults to 0.5.
        do_classifier_free_guidance (bool, optional): Defaults to True.
        need_style_fidelity (bool, optional): use the CFG-aware AdaIN variant.
            Defaults to False.

    Returns:
        torch.Tensor: same layout as the input tensor.
    """
    ndim = tensor.ndim
    dtype = tensor.dtype
    if ndim == 4 and num_frames is not None:
        tensor = rearrange(tensor, "(b t) c h w-> b c t h w ", t=num_frames)
    src = batch_index_select(tensor, dim=dim, index=src_index).contiguous()
    dst = batch_index_select(tensor, dim=dim, index=dst_index).contiguous()
    if need_style_fidelity:
        # NOTE: adaptive_instance_normalization_with_ref takes no
        # need_style_fidelity parameter; forwarding it (as the original did)
        # raised TypeError.
        src = adaptive_instance_normalization_with_ref(
            src=src,
            dst=dst,
            style_fidelity=style_fidelity,
            do_classifier_free_guidance=do_classifier_free_guidance,
        )
    else:
        src = adaptive_instance_normalization(
            src=src,
            dst=dst,
        )
    if keep_dim:
        src = batch_concat_two_tensor_with_index(
            src.to(dtype=dtype),
            src_index,
            dst.to(dtype=dtype),
            dst_index,
            dim=dim,
        )
    if ndim == 4 and num_frames is not None:
        # the original rearranged `tensor` here, discarding the computed result
        src = rearrange(src, "b c t h w ->(b t) c h w")
    return src
def align_repeat_tensor_single_dim(
    src: torch.Tensor,
    target_length: int,
    dim: int = 0,
    n_src_base_length: int = 1,
    src_base_index: List[int] = None,
) -> torch.Tensor:
    """Align the length of src to target_length along dim.

    When src is shorter, it is repeated element-wise (repeat_interleave) up to
    target_length; if target_length is not a multiple of src's length, only the
    first n_src_base_length elements (or src_base_index) serve as the repeat
    unit. When src is longer, it is truncated to the first target_length
    elements.

    Args:
        src (torch.Tensor): input tensor.
        target_length (int): desired length along dim.
        dim (int, optional): dimension to align. Defaults to 0.
        n_src_base_length (int, optional): length of the repeat unit used in the
            non-divisible case. Defaults to 1.
        src_base_index (List[int], optional): explicit indices of the repeat
            unit; takes precedence over n_src_base_length. Defaults to None.

    Returns:
        torch.Tensor: src aligned along dim.
    """
    current_length = src.shape[dim]
    if target_length == current_length:
        return src
    if target_length < current_length:
        # truncate: keep the first target_length entries
        keep = torch.arange(target_length, dtype=torch.long, device=src.device)
        return src.index_select(dim=dim, index=keep)
    if target_length % current_length == 0:
        # whole tensor divides evenly into the target
        return src.repeat_interleave(
            repeats=target_length // current_length, dim=dim
        )
    # non-divisible: repeat only the base unit
    if src_base_index is None and n_src_base_length is not None:
        src_base_index = torch.arange(n_src_base_length)
    base_unit = src.index_select(
        dim=dim,
        index=torch.LongTensor(src_base_index).to(device=src.device),
    )
    return base_unit.repeat_interleave(
        repeats=target_length // len(src_base_index),
        dim=dim,
    )
def fuse_part_tensor(
    src: torch.Tensor,
    dst: torch.Tensor,
    overlap: int,
    weight: float = 0.5,
    skip_step: int = 0,
) -> torch.Tensor:
    """Blend the last `overlap` frames of src into a window of dst, in place.

    For the overlapping window:
        out = weight * src_tail + (1 - weight) * dst_window

    Args:
        src (torch.Tensor): b c t h w.
        dst (torch.Tensor): b c t h w; modified in place and returned.
        overlap (int): number of frames to blend; 0 leaves dst untouched.
        weight (float, optional): contribution of src. Defaults to 0.5.
        skip_step (int, optional): offset of the window in dst. Defaults to 0.

    Returns:
        torch.Tensor: dst with the window blended.
    """
    if overlap == 0:
        return dst
    window = slice(skip_step, skip_step + overlap)
    blended = weight * src[:, :, -overlap:] + (1 - weight) * dst[:, :, window]
    dst[:, :, window] = blended
    return dst