FIX: When a prompt is present, concatenate two patches as the context for VAE decoding
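
The idea behind the fix, in isolation: when prompt audio exists, the last few prompt patches are fed to the VAE together with the newly predicted latents so the decoder has continuous context, and the audio corresponding to those context patches is then trimmed off. The following is a minimal, self-contained sketch of that idea only; the decoder, shapes, and constants are illustrative stand-ins, not the actual VoxCPM API.

import torch

patch_size, chunk_size = 2, 320           # assumed values, for illustration only
patch_len = patch_size * chunk_size       # waveform samples produced per latent patch

def toy_decode(latents: torch.Tensor) -> torch.Tensor:
    # stand-in for audio_vae.decode: [b, t, d] latents -> [b, 1, t * patch_len] waveform
    return latents.mean(dim=-1, keepdim=True).repeat_interleave(patch_len, dim=1).transpose(1, 2)

def decode_with_prompt_context(prompt_latents, new_latents, streaming_prefix_len=3):
    # Use the last (streaming_prefix_len - 1) prompt patches as decoding context,
    # then drop the audio that corresponds to those context patches.
    context = prompt_latents[:, -(streaming_prefix_len - 1):]        # [b, ctx, d]
    audio = toy_decode(torch.cat([context, new_latents], dim=1))     # [b, 1, (ctx + t) * patch_len]
    return audio[..., context.shape[1] * patch_len:]                 # keep only the new audio

prompt_latents = torch.randn(1, 10, 64)
new_latents = torch.randn(1, 5, 64)
assert decode_with_prompt_context(prompt_latents, new_latents).shape == (1, 1, 5 * patch_len)
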
@@ -452,6 +452,7 @@ class VoxCPMModel(nn.Module):
     patch_len = self.patch_size * self.chunk_size
     for latent_pred, _ in inference_result:
         decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
+        print(decode_audio.shape)
         decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
         yield decode_audio
 break
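
For reference, the streaming branch above keeps only the trailing patch_len samples of every decode, so each yielded chunk carries exactly one new patch of audio. A toy sketch of that slicing (the decoder and sizes are stand-ins, and latent_pred is assumed to accumulate one patch per step):

import torch

patch_len = 640                                   # assumed patch_size * chunk_size

def stream_chunks(latent_steps, decode_fn):
    for latent_pred in latent_steps:              # each step's latent includes all patches so far
        decode_audio = decode_fn(latent_pred.to(torch.float32))
        # keep only the newest patch of audio for this step
        yield decode_audio[..., -patch_len:].squeeze(1).cpu()

# toy decoder: [b, t, d] latents -> [b, 1, t * patch_len] silent waveform
decode_fn = lambda z: torch.zeros(z.shape[0], 1, z.shape[1] * patch_len)
steps = [torch.randn(1, t, 64) for t in (1, 2, 3)]
for chunk in stream_chunks(steps, decode_fn):
    assert chunk.shape == (1, patch_len)
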
@@ -583,6 +584,7 @@ class VoxCPMModel(nn.Module):
     retry_badcase_max_times: int = 3,
     retry_badcase_ratio_threshold: float = 6.0,
     streaming: bool = False,
+    streaming_prefix_len: int = 3,
 ) -> Generator[Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, List[torch.Tensor]]], None, None]:
     """
     Generate audio using pre-built prompt cache.
@@ -598,6 +600,7 @@ class VoxCPMModel(nn.Module):
     retry_badcase_max_times: Maximum retry attempts
     retry_badcase_ratio_threshold: Threshold for audio-to-text ratio
     streaming: Whether to return a generator of audio chunks
+    streaming_prefix_len: Number of prefix audio patches to use for streaming mode

 Returns:
     Generator of Tuple containing:
@@ -664,6 +667,7 @@ class VoxCPMModel(nn.Module):
     inference_timesteps=inference_timesteps,
     cfg_value=cfg_value,
     streaming=streaming,
+    streaming_prefix_len=streaming_prefix_len,
 )
 if streaming:
     patch_len = self.patch_size * self.chunk_size
@@ -688,8 +692,12 @@ class VoxCPMModel(nn.Module):
     else:
         break
 if not streaming:
-    decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32)).squeeze(1).cpu()
+    decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
+    patch_len = self.patch_size * self.chunk_size
+    if audio_mask.sum().item() > 0:
+        decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu()
+    else:
+        decode_audio = decode_audio[..., :].squeeze(1).cpu()
 yield (
     decode_audio,
     target_text_token,
@@ -754,6 +762,17 @@ class VoxCPMModel(nn.Module):
 pred_feat_seq = [] # b, t, p, d
 curr_embed = None

+# Prepare prompt context patches for streaming mode
+# When there's a prompt audio, use its last (streaming_prefix_len - 1) patches as initial context
+prompt_context_patches = []
+audio_patch_count = int(feat_mask.sum().item())
+if audio_patch_count > 0:
+    context_len = min(streaming_prefix_len - 1, audio_patch_count)
+    # Take the last context_len patches from prompt audio as initial context
+    # Split into list of [b, 1, p, d] tensors to match pred_feat_seq format
+    prompt_context_patches = list(feat[:, -context_len:, :, :].split(1, dim=1))
+    pred_feat_seq = prompt_context_patches + pred_feat_seq
+
 enc_outputs, kv_cache_tuple = self.base_lm(
     inputs_embeds=combined_embed,
     is_causal=True,
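
A quick shape check of the patch-splitting logic added above, using made-up tensors; feat is assumed to be a [b, t, p, d] prompt feature tensor and feat_mask is assumed to cover all of its patches:

import torch

b, t, p, d = 1, 8, 2, 64
feat = torch.randn(b, t, p, d)                    # made-up prompt features
streaming_prefix_len = 3
audio_patch_count = t                             # as if feat_mask covered all t patches

context_len = min(streaming_prefix_len - 1, audio_patch_count)
prompt_context_patches = list(feat[:, -context_len:, :, :].split(1, dim=1))

assert len(prompt_context_patches) == context_len
assert all(x.shape == (b, 1, p, d) for x in prompt_context_patches)

# these context patches are prepended before any predicted patches
pred_feat_seq = prompt_context_patches + []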