simple streaming

This commit is contained in:
志浩 2024-09-13 16:24:06 +08:00
parent 441b997f19
commit 4a8cb6f0c4

View File

@@ -3130,42 +3130,43 @@ class LLMASRXvecSlotTTS(nn.Module):
_text = f"<|endofprompt|><|sil|>{text}" + ("<|sil|>" if is_last else "") _text = f"<|endofprompt|><|sil|>{text}" + ("<|sil|>" if is_last else "")
text_token = self.tts_tokenizer_warpper(_text) text_token = self.tts_tokenizer_warpper(_text)
text_token = torch.tensor([text_token], dtype=torch.long, device=device) cur_token, feat, wav = None, None, None
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.long, device=device) if len(text_token) > tts_text_chunk_size:
cur_token, feat = self.tts_model.streaming_one_step( text_token = torch.tensor([text_token], dtype=torch.long, device=device)
text_token, text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.long, device=device)
text_token_len, cur_token, feat = self.tts_model.streaming_one_step(
xvec=None, text_token,
xvec_lengths=None, text_token_len,
prompt_dict={ xvec=None,
"prompt_token": prompt_token, xvec_lengths=None,
"prompt_audio": prompt_audio, prompt_dict={
}, "prompt_token": prompt_token,
outside_prompt=llm_cur_kv_cache, "prompt_audio": prompt_audio,
outside_prompt_lengths=llm_cur_kv_cache_len, },
sampling="threshold_1e-6", outside_prompt=llm_cur_kv_cache,
chunk_idx=chunk_idx, outside_prompt_lengths=llm_cur_kv_cache_len,
) sampling="threshold_1e-6",
if cur_token is not None and cur_token.shape[1] > 0 and feat.shape[2] > 0: chunk_idx=chunk_idx,
# process first package, token in B,T,D, feat in B,F,T diff_steps=5,
if prompt_token[0] is None: )
prompt_token = [ if cur_token is not None and cur_token.shape[1] > 0 and feat.shape[2] > 0:
cur_token, # process first package, token in B,T,D, feat in B,F,T
torch.tensor([cur_token.shape[1]], dtype=torch.long, device=device), if prompt_token[0] is None:
] prompt_token = [
prompt_audio = [ cur_token,
feat.transpose(1, 2), torch.tensor([cur_token.shape[1]], dtype=torch.long, device=device),
torch.tensor([feat.shape[2]], dtype=torch.long, device=device), ]
] prompt_audio = [
else: feat.transpose(1, 2),
prompt_token[1] = prompt_token[1] + cur_token.shape[1] torch.tensor([feat.shape[2]], dtype=torch.long, device=device),
prompt_token[0] = torch.concat([prompt_token[0], cur_token], dim=1) ]
prompt_audio[1] = prompt_audio[1] + feat.shape[2] else:
prompt_audio[0] = torch.concat([prompt_audio[0], feat.transpose(1, 2)], dim=1) prompt_token[1] = prompt_token[1] + cur_token.shape[1]
wav = self.vocoder.inference(feat.transpose(1, 2)) prompt_token[0] = torch.concat([prompt_token[0], cur_token], dim=1)
chunk_idx += 1 prompt_audio[1] = prompt_audio[1] + feat.shape[2]
else: prompt_audio[0] = torch.concat([prompt_audio[0], feat.transpose(1, 2)], dim=1)
cur_token, feat, wav = None, None, None wav = self.vocoder.inference(feat.transpose(1, 2))
chunk_idx += 1
return ((cur_token, feat, wav), (text, last_t_size, prompt_token, prompt_audio, chunk_idx)) return ((cur_token, feat, wav), (text, last_t_size, prompt_token, prompt_audio, chunk_idx))