Spaces:
Running
on
A10G
Running
on
A10G
File size: 2,927 Bytes
0a3525d 69e8a46 0a3525d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import os
from glob import glob
from pathlib import Path
from typing import Union
from loguru import logger
from natsort import natsorted
# Filename suffixes (lowercase, with leading dot) that are treated as audio files.
AUDIO_EXTENSIONS = set(
    ".mp3 .wav .flac .ogg .m4a .wma .aac .aiff .aif .aifc".split()
)
def list_files(
path: Union[Path, str],
extensions: set[str] = None,
recursive: bool = False,
sort: bool = True,
) -> list[Path]:
"""List files in a directory.
Args:
path (Path): Path to the directory.
extensions (set, optional): Extensions to filter. Defaults to None.
recursive (bool, optional): Whether to search recursively. Defaults to False.
sort (bool, optional): Whether to sort the files. Defaults to True.
Returns:
list: List of files.
"""
if isinstance(path, str):
path = Path(path)
if not path.exists():
raise FileNotFoundError(f"Directory {path} does not exist.")
files = [file for ext in extensions for file in path.rglob(f"*{ext}")]
if sort:
files = natsorted(files)
return files
def get_latest_checkpoint(path: Path | str) -> Path | None:
# Find the latest checkpoint
ckpt_dir = Path(path)
if ckpt_dir.exists() is False:
return None
ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
if len(ckpts) == 0:
return None
return ckpts[-1]
def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
"""
Load a Bert-VITS2 style filelist.
"""
files = set()
results = []
count_duplicated, count_not_found = 0, 0
LANGUAGE_TO_LANGUAGES = {
"zh": ["zh", "en"],
"jp": ["jp", "en"],
"en": ["en"],
}
with open(path, "r", encoding="utf-8") as f:
for line in f.readlines():
splits = line.strip().split("|", maxsplit=3)
if len(splits) != 4:
logger.warning(f"Invalid line: {line}")
continue
filename, speaker, language, text = splits
file = Path(filename)
language = language.strip().lower()
if language == "ja":
language = "jp"
assert language in ["zh", "jp", "en"], f"Invalid language {language}"
languages = LANGUAGE_TO_LANGUAGES[language]
if file in files:
logger.warning(f"Duplicated file: {file}")
count_duplicated += 1
continue
if not file.exists():
logger.warning(f"File not found: {file}")
count_not_found += 1
continue
results.append((file, speaker, languages, text))
if count_duplicated > 0:
logger.warning(f"Total duplicated files: {count_duplicated}")
if count_not_found > 0:
logger.warning(f"Total files not found: {count_not_found}")
return results
|