FIX: When a prompt is present, concatenate two patches as the context for VAE decoding

刘鑫
2025-12-15 20:35:46 +08:00
parent aabda60833
commit b3a2d95fec


@@ -452,6 +452,7 @@ class VoxCPMModel(nn.Module):
             patch_len = self.patch_size * self.chunk_size
             for latent_pred, _ in inference_result:
                 decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
+                print(decode_audio.shape)
                 decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
                 yield decode_audio
             break
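
Standalone sketch of the streaming tail-slice in the hunk above: the latent handed to the VAE now carries a few context patches plus one new patch, and only the samples belonging to the new patch are kept. All shapes, the patch_size/chunk_size values, and the toy_decode stand-in for audio_vae.decode are assumptions for illustration, not values from this repository.

import torch

patch_size, chunk_size = 2, 320                # hypothetical values
patch_len = patch_size * chunk_size            # audio samples contributed by one patch

def toy_decode(latent):
    # Stand-in for audio_vae.decode: [b, d, t] latent frames -> [b, 1, t * chunk_size] audio
    b, d, t = latent.shape
    return torch.zeros(b, 1, t * chunk_size)

context_patches, new_patches = 2, 1            # e.g. streaming_prefix_len = 3
latent_pred = torch.randn(1, 64, (context_patches + new_patches) * patch_size)
decode_audio = toy_decode(latent_pred.to(torch.float32))
# Keep only the tail that corresponds to the newly generated patch.
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
assert decode_audio.shape == (1, patch_len)
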
@@ -583,6 +584,7 @@ class VoxCPMModel(nn.Module):
         retry_badcase_max_times: int = 3,
         retry_badcase_ratio_threshold: float = 6.0,
         streaming: bool = False,
+        streaming_prefix_len: int = 3,
     ) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
         """
         Generate audio using pre-built prompt cache.
@@ -598,6 +600,7 @@ class VoxCPMModel(nn.Module):
             retry_badcase_max_times: Maximum retry attempts
             retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
             streaming: Whether to return a generator of audio chunks
+            streaming_prefix_len: Number of prefix audio patches to use for streaming mode
 
         Returns:
             Generator of Tuple containing:
@@ -664,6 +667,7 @@ class VoxCPMModel(nn.Module):
                 inference_timesteps=inference_timesteps,
                 cfg_value=cfg_value,
                 streaming=streaming,
+                streaming_prefix_len=streaming_prefix_len,
             )
             if streaming:
                 patch_len = self.patch_size * self.chunk_size
@@ -688,8 +692,12 @@ class VoxCPMModel(nn.Module):
             else:
                 break
         if not streaming:
-            decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
+            decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
+            patch_len = self.patch_size * self.chunk_size
+            if audio_mask.sum().item() > 0:
+                decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu()
+            else:
+                decode_audio = decode_audio[..., :].squeeze(1).cpu()
             yield (
                 decode_audio,
                 target_text_token,
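
Sketch of the non-streaming trim in the hunk above: when a prompt is present (audio_mask has non-zero entries), the decoded waveform begins with (streaming_prefix_len - 1) patches of prompt audio that are sliced off before yielding. The shapes and values below are assumptions for illustration only.

import torch

patch_size, chunk_size, streaming_prefix_len = 2, 320, 3
patch_len = patch_size * chunk_size
generated_patches = 10

# Pretend the VAE decoded (streaming_prefix_len - 1) prompt context patches
# followed by generated_patches newly generated patches.
decode_audio = torch.zeros(1, 1, (streaming_prefix_len - 1 + generated_patches) * patch_len)
has_prompt = True                              # stands in for audio_mask.sum().item() > 0

if has_prompt:
    decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu()
else:
    decode_audio = decode_audio.squeeze(1).cpu()

assert decode_audio.shape == (1, generated_patches * patch_len)
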
@@ -754,6 +762,17 @@ class VoxCPMModel(nn.Module):
         pred_feat_seq = []  # b, t, p, d
         curr_embed = None
 
+        # Prepare prompt context patches for streaming mode
+        # When there's a prompt audio, use its last (streaming_prefix_len - 1) patches as initial context
+        prompt_context_patches = []
+        audio_patch_count = int(feat_mask.sum().item())
+        if audio_patch_count > 0:
+            context_len = min(streaming_prefix_len - 1, audio_patch_count)
+            # Take the last context_len patches from prompt audio as initial context
+            # Split into list of [b, 1, p, d] tensors to match pred_feat_seq format
+            prompt_context_patches = list(feat[:, -context_len:, :, :].split(1, dim=1))
+            pred_feat_seq = prompt_context_patches + pred_feat_seq
+
         enc_outputs, kv_cache_tuple = self.base_lm(
             inputs_embeds=combined_embed,
             is_causal=True,
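
Standalone sketch of the prompt-context preparation in the hunk above: the last (streaming_prefix_len - 1) prompt patches are split into per-patch tensors and prepended to pred_feat_seq, so they act as decoding context for the first generated patches. The [b, t, p, d] shape of feat, the feat_mask layout, and all sizes below are assumptions for illustration, not the repository's real configuration.

import torch

b, t, p, d = 1, 8, 2, 64
streaming_prefix_len = 3
feat = torch.randn(b, t, p, d)                 # hypothetical prompt audio patches
feat_mask = torch.ones(b, t)                   # marks valid prompt patch positions (batch size 1)

pred_feat_seq = []                             # generated patches, [b, 1, p, d] each
audio_patch_count = int(feat_mask.sum().item())
if audio_patch_count > 0:
    context_len = min(streaming_prefix_len - 1, audio_patch_count)
    # Last context_len prompt patches, split into [b, 1, p, d] tensors so they
    # line up with the per-step entries appended to pred_feat_seq later.
    prompt_context_patches = list(feat[:, -context_len:, :, :].split(1, dim=1))
    pred_feat_seq = prompt_context_patches + pred_feat_seq

assert len(pred_feat_seq) == 2 and pred_feat_seq[0].shape == (b, 1, p, d)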