MuseV-test / mmcm /utils /itertools_util.py
kevinwang676's picture
Upload folder using huggingface_hub
6755a2d verified
from typing import Any, List, Union, Sequence, Tuple
import numpy as np
def generate_sample_idxs(
total: int,
window_size: int,
step: int,
sample_rate: int = 1,
drop_last: bool = False,
max_num_per_window: int = None,
) -> List[List[int]]:
"""generate sample idxs list by given relate parameters
Args:
total (int): total num of sampling source
window_size (int):
step (int): _description_
sample_rate (int, optional): _description_. Defaults to 1.
drop_last (bool, optional): wthether drop the last, if not enough for window_size. Defaults to False.
Returns:
List[List[int]]: sample idx list
"""
idxs = range(total)
idxs = [idx for i, idx in enumerate(idxs) if i % sample_rate == 0]
sample_idxs = []
new_total = len(idxs)
last_idx = new_total - 1
window_start = 0
while window_start < new_total:
window_end = window_start + window_size
window = idxs[window_start:window_end]
if max_num_per_window is not None and len(window) > max_num_per_window:
window = uniform_sample_subseq(
window, max_num=max_num_per_window, need_index=False
)
if window_end > new_total and drop_last:
break
else:
sample_idxs.append(window)
window_start += step
return sample_idxs
def overlap2step(overlap: Union[int, float], window_size: int) -> int:
if isinstance(overlap, int):
step = window_size - overlap
elif isinstance(overlap, float):
if overlap <= 0:
raise ValueError(f"relative overlap should be > 0, but given{overlap}")
overlap = int(overlap * window_size)
else:
raise ValueError(
f"overlap only support int(>0) or float(>0), but given {overlap} type({type(overlap)})"
)
return step
def step2overlap(step: int, window_size: int) -> int:
overlap = window_size - step
return overlap
def uniform_sample_subseq(
seq: Sequence, max_num: int, need_index: bool = False
) -> Union[Sequence, Tuple[Sequence, Sequence]]:
n_seq = len(seq)
sample_num = min(n_seq, max_num)
if n_seq <= max_num:
if need_index:
return seq, list(range(n_seq))
else:
return seq
idx = sorted(list(set(np.linspace(0, n_seq - 1, dtype=int))))
subseq = [seq[i] for i in idx]
if need_index:
return subseq, idx
else:
return subseq
def convert_list_flat2nest(
seq: Sequence,
window: int,
) -> List[List[Any]]:
n_seq = len(seq)
n_lst = n_seq // window + int(n_seq % window > 0)
res = [seq[i * window : (i + 1) * window] for i in range(n_lst)]
return res