# Kokoro-TTS-Zero / tts_model_v1.py
# Author: Remsky
# Commit a010fd1: Enhance processing metrics in TTSModelV1 by calculating
# total audio duration and processing time per chunk
import os
import torch
import numpy as np
import time
from typing import Tuple, List
from kokoro import KPipeline
import spaces
class TTSModelV1:
    """KPipeline-based TTS model for Kokoro v1.0.0.

    The underlying KPipeline is created lazily inside generate_speech because
    it must run inside the @spaces.GPU-decorated call (ZeroGPU only allocates
    a device for decorated functions).
    """

    # Kokoro v1 emits audio at a fixed 24 kHz sample rate.
    SAMPLE_RATE = 24000

    def __init__(self):
        # Deferred: built on first generate_speech call (see class docstring).
        self.pipeline = None
        # Voice embedding files (*.pt) live next to this module.
        self.voices_dir = os.path.join(os.path.dirname(__file__), "voices_v1")

    def initialize(self) -> bool:
        """Prepare the model wrapper.

        The pipeline itself cannot be built here (no GPU is available outside
        the @spaces.GPU decorator), so this only resets state.

        Returns:
            True on success, False if anything raised.
        """
        try:
            print("Initializing v1.0.0 model...")
            self.pipeline = None  # cannot be initialized outside of GPU decorator
            print("Model initialization complete")
            return True
        except Exception as e:
            print(f"Error initializing model: {str(e)}")
            return False

    def list_voices(self) -> List[str]:
        """Return the sorted names of available voices.

        A voice is any "<name>.pt" file in the voices_v1 directory; returned
        names have the ".pt" suffix stripped.
        """
        voices = []
        if os.path.exists(self.voices_dir):
            for file in os.listdir(self.voices_dir):
                if file.endswith(".pt"):
                    voices.append(file[:-3])  # strip ".pt"
        return sorted(voices)

    @spaces.GPU(duration=None)  # Duration will be set by the UI
    def generate_speech(self, text: str, voice_names: list[str], speed: float = 1.0, gpu_timeout: int = 60, progress_callback=None, progress_state=None, progress=None) -> Tuple[np.ndarray, float, dict]:
        """Generate speech from text using KPipeline.

        Args:
            text: Input text to convert to speech.
            voice_names: List of voice names to use (joined with "_" when
                multiple, which kokoro interprets as a voice mix).
            speed: Speech speed multiplier.
            gpu_timeout: GPU allocation timeout, forwarded to the callback.
            progress_callback: Optional callable invoked after each chunk.
            progress_state: Optional dict accumulating per-chunk metrics.
            progress: Gradio progress object, forwarded to the callback.

        Returns:
            Tuple of (concatenated audio samples, total duration in seconds,
            metrics dict with per-chunk timings and token counts).

        Raises:
            ValueError: If text or voice_names is empty, or no audio was
                produced.
        """
        try:
            start_time = time.time()

            # Validate before doing any work; previously the pipeline was
            # constructed first, wasting a GPU pipeline on invalid input.
            if not text or not voice_names:
                raise ValueError("Text and voice name are required")

            if self.pipeline is None:
                # kokoro v1 convention: the first letter of the voice name
                # encodes the language (e.g. 'a' = American English).
                lang_code = voice_names[0][0]
                self.pipeline = KPipeline(lang_code=lang_code)

            # Multiple voices are joined with "_" so kokoro mixes them.
            if isinstance(voice_names, list) and len(voice_names) > 1:
                voice_name = "_".join(voice_names)
            else:
                voice_name = voice_names[0]

            # Per-chunk tracking.
            audio_chunks = []
            chunk_times = []
            chunk_sizes = []
            total_tokens = 0

            # Replace single newlines with spaces while preserving paragraph
            # breaks (double newlines); collapse doubled spaces left behind.
            processed_text = '\n\n'.join(
                paragraph.replace('\n', ' ').replace('  ', ' ').strip()
                for paragraph in text.split('\n\n')
            )

            # Get generator from pipeline.
            generator = self.pipeline(
                processed_text,
                voice=voice_name,
                speed=speed,
                split_pattern=r'\n\n+'  # split on double newlines or more
            )

            total_duration = 0.0      # total audio duration in seconds
            total_process_time = 0.0  # total processing wall-clock seconds

            for i, (gs, ps, audio) in enumerate(generator):
                # Wall-clock time spent producing this chunk.
                chunk_process_time = time.time() - start_time - total_process_time
                total_process_time += chunk_process_time
                audio_chunks.append(audio)

                # NOTE(review): len(gs) counts graphemes (characters), not
                # model tokens — kept as-is for UI metric compatibility.
                chunk_tokens = len(gs)
                total_tokens += chunk_tokens

                # Convert sample count to seconds.
                chunk_duration = len(audio) / self.SAMPLE_RATE
                total_duration += chunk_duration

                # Guard zero-length chunks to avoid ZeroDivisionError.
                if chunk_duration > 0:
                    tokens_per_sec = chunk_tokens / chunk_duration
                    rtf = chunk_process_time / chunk_duration
                else:
                    tokens_per_sec = 0.0
                    rtf = 0.0

                chunk_times.append(chunk_process_time)
                chunk_sizes.append(chunk_tokens)

                print(f"Chunk {i+1}:")
                print(f"  Process time: {chunk_process_time:.2f}s")
                print(f"  Audio duration: {chunk_duration:.2f}s")
                print(f"  Tokens/sec: {tokens_per_sec:.1f}")
                print(f"  Real-time factor: {rtf:.3f}")
                if rtf > 0:
                    print(f"  Speed: {(1/rtf):.1f}x real-time")

                # Update progress.
                if progress_callback and progress_state:
                    # setdefault replaces the three separate membership checks.
                    progress_state.setdefault("tokens_per_sec", []).append(tokens_per_sec)
                    progress_state.setdefault("rtf", []).append(rtf)
                    progress_state.setdefault("chunk_times", []).append(chunk_process_time)
                    progress_callback(
                        i + 1,
                        -1,  # let the UI determine total chunk count
                        tokens_per_sec,
                        rtf,
                        progress_state,
                        start_time,
                        gpu_timeout,
                        progress,
                    )

            # np.concatenate([]) raises a cryptic ValueError; fail clearly.
            if not audio_chunks:
                raise ValueError("Pipeline produced no audio for the given text")

            audio = np.concatenate(audio_chunks)

            # Return audio and metrics. Using .get() fixes a KeyError when
            # progress_state was passed without a progress_callback (the
            # metric lists were never initialized in that case).
            return (
                audio,
                len(audio) / self.SAMPLE_RATE,
                {
                    "chunk_times": chunk_times,
                    "chunk_sizes": chunk_sizes,
                    "tokens_per_sec": [float(x) for x in progress_state.get("tokens_per_sec", [])] if progress_state else [],
                    "rtf": [float(x) for x in progress_state.get("rtf", [])] if progress_state else [],
                    "total_tokens": total_tokens,
                    "total_time": time.time() - start_time,
                },
            )
        except Exception as e:
            print(f"Error generating speech: {str(e)}")
            raise