Spaces:
No application file
No application file
from typing import Any, List, Union, Sequence, Tuple | |
import numpy as np | |
def generate_sample_idxs(
    total: int,
    window_size: int,
    step: int,
    sample_rate: int = 1,
    drop_last: bool = False,
    max_num_per_window: Union[int, None] = None,
) -> List[List[int]]:
    """Generate sliding-window index lists over ``range(total)``.

    The source indices ``0 .. total - 1`` are first thinned by keeping every
    index divisible by ``sample_rate``. Windows of ``window_size`` consecutive
    kept indices are then taken, with the window start advancing by ``step``
    kept indices each time.

    Args:
        total (int): total num of sampling source.
        window_size (int): number of (already thinned) indices per window.
        step (int): distance between the starts of two consecutive windows,
            counted in thinned indices. Must be > 0.
        sample_rate (int, optional): keep only indices ``i`` with
            ``i % sample_rate == 0``. Defaults to 1.
        drop_last (bool, optional): whether to stop at the first window that
            would run past the end of the thinned indices. If False, such
            shorter trailing windows are kept. Defaults to False.
        max_num_per_window (int, optional): if given, any window longer than
            this is passed through ``uniform_sample_subseq`` with
            ``max_num=max_num_per_window``. Defaults to None (no sub-sampling).

    Returns:
        List[List[int]]: sample idx list, one inner list per window.

    Raises:
        ValueError: if ``step`` is not positive; the window start would never
            advance and the generation would not terminate.
    """
    if step <= 0:
        raise ValueError(f"step should be > 0, but given {step}")

    idxs = [idx for idx in range(total) if idx % sample_rate == 0]
    new_total = len(idxs)

    sample_idxs = []
    for window_start in range(0, new_total, step):
        window_end = window_start + window_size
        # Check before slicing/sub-sampling so no work is spent on a window
        # that is going to be discarded anyway.
        if drop_last and window_end > new_total:
            break
        window = idxs[window_start:window_end]
        if max_num_per_window is not None and len(window) > max_num_per_window:
            window = uniform_sample_subseq(
                window, max_num=max_num_per_window, need_index=False
            )
        sample_idxs.append(window)
    return sample_idxs
def overlap2step(overlap: Union[int, float], window_size: int) -> int:
    """Convert a window overlap into the step between consecutive windows.

    Args:
        overlap (Union[int, float]): an ``int`` is taken as an absolute
            number of overlapping elements; a ``float`` is taken as a ratio
            of ``window_size`` and is truncated to an element count with
            ``int(overlap * window_size)``.
        window_size (int): size of one window.

    Returns:
        int: ``window_size`` minus the absolute overlap.

    Raises:
        ValueError: if a float ``overlap`` is <= 0, or if ``overlap`` is
            neither ``int`` nor ``float``.
    """
    if isinstance(overlap, int):
        step = window_size - overlap
    elif isinstance(overlap, float):
        if overlap <= 0:
            raise ValueError(f"relative overlap should be > 0, but given{overlap}")
        # Turn the ratio into an absolute overlap first, then into a step.
        step = window_size - int(overlap * window_size)
    else:
        raise ValueError(
            f"overlap only support int(>0) or float(>0), but given {overlap} type({type(overlap)})"
        )
    return step
def step2overlap(step: int, window_size: int) -> int:
    """Return how many elements two consecutive windows share.

    Args:
        step (int): distance between the starts of consecutive windows.
        window_size (int): size of one window.

    Returns:
        int: ``window_size - step``.
    """
    return window_size - step
def uniform_sample_subseq(
    seq: Sequence, max_num: int, need_index: bool = False
) -> Union[Sequence, Tuple[Sequence, Sequence]]:
    """Pick at most ``max_num`` evenly spaced elements from ``seq``.

    Args:
        seq (Sequence): source sequence.
        max_num (int): upper bound on the number of returned elements.
        need_index (bool, optional): whether to also return the positions of
            the picked elements in ``seq``. Defaults to False.

    Returns:
        Union[Sequence, Tuple[Sequence, Sequence]]: if ``seq`` already has no
        more than ``max_num`` elements, ``seq`` itself is returned unchanged;
        otherwise a new list of the picked elements. With ``need_index=True``
        the result is a ``(subseq, idx)`` pair, where ``idx`` is the list of
        picked positions in ascending order.
    """
    n_seq = len(seq)
    if n_seq <= max_num:
        if need_index:
            return seq, list(range(n_seq))
        return seq

    # ``num`` must be given explicitly: np.linspace defaults to 50 points,
    # which would ignore ``max_num`` entirely. ``tolist`` yields plain ints;
    # the set guards against duplicate positions after integer truncation.
    positions = np.linspace(0, n_seq - 1, num=max_num, dtype=int).tolist()
    idx = sorted(set(positions))
    subseq = [seq[i] for i in idx]
    if need_index:
        return subseq, idx
    return subseq
def convert_list_flat2nest(
    seq: Sequence,
    window: int,
) -> List[List[Any]]:
    """Split a flat sequence into consecutive chunks of ``window`` elements.

    Args:
        seq (Sequence): flat source sequence.
        window (int): number of elements per chunk.

    Returns:
        List[List[Any]]: slices of ``seq`` in order; the final chunk is
        shorter when ``len(seq)`` is not a multiple of ``window``.
    """
    full_chunks, leftover = divmod(len(seq), window)
    # One extra, shorter chunk is needed when elements are left over.
    chunk_count = full_chunks + (1 if leftover > 0 else 0)
    nested = []
    for chunk_no in range(chunk_count):
        begin = chunk_no * window
        nested.append(seq[begin : begin + window])
    return nested