diff --git a/src/voxcpm/model/voxcpm.py b/src/voxcpm/model/voxcpm.py
index 8e3ed67..5664353 100644
--- a/src/voxcpm/model/voxcpm.py
+++ b/src/voxcpm/model/voxcpm.py
@@ -583,6 +583,7 @@ class VoxCPMModel(nn.Module):
         retry_badcase_max_times: int = 3,
         retry_badcase_ratio_threshold: float = 6.0,
         streaming: bool = False,
+        streaming_prefix_len: int = 3,
     ) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
         """
         Generate audio using pre-built prompt cache.
@@ -598,6 +599,7 @@ class VoxCPMModel(nn.Module):
             retry_badcase_max_times: Maximum retry attempts
             retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
             streaming: Whether to return a generator of audio chunks
+            streaming_prefix_len: Number of prefix audio patches used as decoding context in streaming mode
 
         Returns:
             Generator of Tuple containing:
@@ -664,6 +666,7 @@ class VoxCPMModel(nn.Module):
             inference_timesteps=inference_timesteps,
             cfg_value=cfg_value,
             streaming=streaming,
+            streaming_prefix_len=streaming_prefix_len,
         )
         if streaming:
             patch_len = self.patch_size * self.chunk_size
@@ -688,8 +691,13 @@ class VoxCPMModel(nn.Module):
             else:
                 break
         if not streaming:
-            decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
-
+            decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
+            patch_len = self.patch_size * self.chunk_size
+            if audio_mask.sum().item() > 0:
+                # Drop the (streaming_prefix_len - 1) prompt-context patches prepended during generation
+                decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu()
+            else:
+                decode_audio = decode_audio.squeeze(1).cpu()
             yield (
                 decode_audio,
                 target_text_token,
@@ -754,6 +762,17 @@ class VoxCPMModel(nn.Module):
         pred_feat_seq = []  # b, t, p, d
         curr_embed = None
 
+        # Prepare prompt context patches for streaming mode.
+        # When there is prompt audio, use its last (streaming_prefix_len - 1) patches as initial context.
+        prompt_context_patches = []
+        audio_patch_count = int(feat_mask.sum().item())
+        if audio_patch_count > 0:
+            context_len = min(streaming_prefix_len - 1, audio_patch_count)
+            # Take the last context_len patches from the prompt audio as initial context.
+            # Split into a list of [b, 1, p, d] tensors to match the pred_feat_seq format.
+            prompt_context_patches = list(feat[:, -context_len:, :, :].split(1, dim=1))
+        pred_feat_seq = prompt_context_patches + pred_feat_seq
+
         enc_outputs, kv_cache_tuple = self.base_lm(
             inputs_embeds=combined_embed,
             is_causal=True,
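
What the new `streaming_prefix_len` parameter buys: in streaming mode each new latent patch is decoded together with the last few patches as left context, and only the newest `patch_size * chunk_size` samples are emitted, so chunk boundaries are synthesized with real context instead of in isolation. A minimal self-contained sketch of that sliding-window decode, using a toy stand-in for `audio_vae.decode` (the names `toy_decode`, `PATCH_LEN`, and `PREFIX_LEN` are illustrative, not VoxCPM API):

```python
import torch

PATCH_LEN = 4    # stands in for patch_size * chunk_size
PREFIX_LEN = 3   # stands in for streaming_prefix_len

def toy_decode(latents: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for audio_vae.decode: [B, T, D] latents -> [B, 1, T * PATCH_LEN] waveform.
    b, t, _ = latents.shape
    return latents.mean(dim=-1, keepdim=True).repeat(1, 1, PATCH_LEN).reshape(b, 1, -1)

def stream_chunks(latent_patches):
    # Decode each new patch together with up to PREFIX_LEN - 1 patches of left
    # context, then keep only the trailing PATCH_LEN samples of the result.
    history = []
    for patch in latent_patches:               # each patch: [B, 1, D]
        history.append(patch)
        window = torch.cat(history[-PREFIX_LEN:], dim=1)
        audio = toy_decode(window)
        yield audio[..., -PATCH_LEN:]          # only the newest patch's samples

# Five random patches in -> five PATCH_LEN-sample chunks out.
patches = [torch.randn(1, 1, 8) for _ in range(5)]
chunks = list(stream_chunks(patches))
assert all(c.shape == (1, 1, PATCH_LEN) for c in chunks)
```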
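
The non-streaming bookkeeping in the last two hunks can be checked the same way: the last `streaming_prefix_len - 1` prompt patches are prepended to `pred_feat_seq` as `[b, 1, p, d]` tensors, and after decoding, `patch_len * (streaming_prefix_len - 1)` samples are sliced off the front so the prompt context does not leak into the returned audio. A toy sketch of that arithmetic (all sizes here are made up for illustration):

```python
import torch

b, p, d = 1, 2, 8              # illustrative batch / patch / feature sizes
patch_size, chunk_size = 2, 2  # illustrative; real values come from the model
patch_len = patch_size * chunk_size
streaming_prefix_len = 3

# Prompt features [b, t, p, d]: take the last (streaming_prefix_len - 1)
# patches and split them into [b, 1, p, d] pieces, as the diff does.
feat = torch.randn(b, 10, p, d)
context_len = min(streaming_prefix_len - 1, feat.shape[1])
prompt_context_patches = list(feat[:, -context_len:, :, :].split(1, dim=1))
assert all(t.shape == (b, 1, p, d) for t in prompt_context_patches)

# The decoded waveform covers (context + new) patches; trimming
# patch_len * (streaming_prefix_len - 1) samples leaves only the new audio.
new_patches = 5
decoded = torch.randn(b, 1, (context_len + new_patches) * patch_len)
trimmed = decoded[..., patch_len * (streaming_prefix_len - 1):]
assert trimmed.shape[-1] == new_patches * patch_len
```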