from typing import Callable

import torch
from loguru import logger

from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture


class VQManager:
    """Helpers that encode reference audio into VQ codes and decode VQ codes.

    This class never assigns ``decoder_model`` or ``load_audio`` itself. They
    must be provided elsewhere (presumably by a class that mixes this one in)
    before either method is called.
    """

    def __init__(self):
        # Type-only declarations so static checkers (e.g. Pylance) know these
        # attributes exist; no value is assigned here.
        self.decoder_model: FireflyArchitecture
        self.load_audio: Callable

    def decode_vq_tokens(self, codes):
        """Decode VQ ``codes`` with the Firefly decoder.

        ``codes.shape[1]`` is used as the feature length, and a leading batch
        axis is added before decoding (presumably ``codes`` is
        ``(num_codebooks, T)`` -- confirm against callers).

        Returns:
            The first element of ``decoder_model.decode(...)``, squeezed.

        Raises:
            ValueError: if ``decoder_model`` is not a ``FireflyArchitecture``.
        """
        model = self.decoder_model
        lengths = torch.tensor([codes.shape[1]], device=model.device)
        logger.info(f"VQ features: {codes.shape}")

        if not isinstance(model, FireflyArchitecture):
            raise ValueError(f"Unknown model type: {type(model)}")

        decoded = model.decode(indices=codes[None], feature_lengths=lengths)
        return decoded[0].squeeze()

    def encode_reference(self, reference_audio, enable_reference_audio):
        """Encode a reference clip into prompt tokens.

        Args:
            reference_audio: Source passed straight to ``self.load_audio``;
                ``None`` disables encoding.
            enable_reference_audio: Falsy to skip encoding regardless of
                ``reference_audio``.

        Returns:
            ``decoder_model.encode(...)[0][0]``, or ``None`` when encoding is
            disabled or no reference audio is given.

        Raises:
            ValueError: if encoding is requested and ``decoder_model`` is not a
                ``FireflyArchitecture``. This check runs only after the audio
                has already been loaded.
        """
        if not enable_reference_audio or reference_audio is None:
            logger.info("No reference audio provided")
            return None

        model = self.decoder_model
        sample_rate = model.spec_transform.sample_rate

        # Load at the model's sample rate, move to its device, and prepend two
        # singleton axes. Assumes load_audio returns a 1-D NumPy array -- TODO confirm.
        waveform = self.load_audio(reference_audio, sample_rate)
        audios = torch.from_numpy(waveform).to(model.device)[None, None, :]
        num_samples = audios.shape[2]
        audio_lengths = torch.tensor(
            [num_samples], device=model.device, dtype=torch.long
        )
        logger.info(f"Loaded audio with {num_samples / sample_rate:.2f} seconds")

        if not isinstance(model, FireflyArchitecture):
            raise ValueError(f"Unknown model type: {type(model)}")

        prompt_tokens = model.encode(audios, audio_lengths)[0][0]
        logger.info(f"Encoded prompt: {prompt_tokens.shape}")
        return prompt_tokens