fix streaming speech gen bug

志浩 2024-09-12 11:45:45 +08:00
parent d7bc9c54b6
commit 1587194e75
2 changed files with 16 additions and 16 deletions

@@ -2982,10 +2982,10 @@ class LLMASRXvecSlotTTS(nn.Module):
         text = "".join(normed_text)
         cur_token, feat, wav = None, None, None
-        t_size = len(self.tts_text_tokenizer.text2tokens(text))
-        if (t_size - last_t_size) >= tts_text_chunk_size:
-            _text = f"<|endofprompt|><|sil|>{text}" + ("<|sil|>" if is_last else "")
-            text_token = self.tts_tokenizer_warpper(_text)
+        _text = f"<|endofprompt|><|sil|>{text}" + ("<|sil|>" if is_last else "")
+        text_token = self.tts_tokenizer_warpper(_text)
+        t_size = len(text_token)
+        if (t_size - last_t_size) >= tts_text_chunk_size or is_last:
             text_token = torch.tensor([text_token], dtype=torch.long, device=device)
             text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.long, device=device)
             cur_token, feat = self.tts_model.streaming_one_step(
@@ -3000,7 +3000,7 @@ class LLMASRXvecSlotTTS(nn.Module):
                 sampling="threshold_1e-6",
                 chunk_idx=chunk_idx,
             )
-        if cur_token is not None:
+        if cur_token is not None and cur_token.shape[1] > 0 and feat.shape[2] > 0:
             # process first package, token in B,T,D, feat in B,F,T
             if prompt_token[0] is None:
                 prompt_token = [cur_token, torch.tensor([cur_token.shape[1]], dtype=torch.long, device=device)]
@@ -3012,6 +3012,8 @@ class LLMASRXvecSlotTTS(nn.Module):
                 prompt_audio[0] = torch.concat([prompt_audio[0], feat.transpose(1, 2)], dim=1)
             wav = self.vocoder.inference(feat.transpose(1, 2))
             chunk_idx += 1
+        else:
+            cur_token, feat, wav = None, None, None
         # post process
         last_t_size = t_size
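
The hunks above change the one-step path in two ways: tokenization now happens before the chunk-size check, so t_size counts the exact token sequence fed to streaming_one_step (the tts_tokenizer_warpper output, including <|endofprompt|> and <|sil|>) rather than a separate text2tokens pass, with `or is_last` forcing a final flush even when the remainder is shorter than tts_text_chunk_size; and zero-length outputs from streaming_one_step are reset to None instead of reaching the vocoder. A minimal sketch of the new guard (tensor shapes are assumptions for illustration, not taken from this repo):

    import torch

    # Hypothetical empty-chunk outputs of streaming_one_step, using an
    # assumed B,T token layout and B,F,T mel-feature layout:
    cur_token = torch.zeros(1, 0, dtype=torch.long)
    feat = torch.zeros(1, 80, 0)
    wav = None

    # The old check (cur_token is not None) let empty tensors through to
    # the vocoder; the new one filters them and resets the step outputs:
    if cur_token is not None and cur_token.shape[1] > 0 and feat.shape[2] > 0:
        pass  # run the vocoder, extend prompt_token / prompt_audio
    else:
        cur_token, feat, wav = None, None, None
    print(cur_token, feat, wav)  # None None None
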
@@ -3042,15 +3044,15 @@ class LLMASRXvecSlotTTS(nn.Module):
         token_list, feat_list, wav_list = [], [], []
         prompt_token, prompt_audio = [None, None], [None, None]
         new_text, last_t_size, chunk_idx = "", 0, 0
-        i = 0
-        while i < preds.shape[1]:
-            chunk_size = int(llm_token_num_per_call / (given_rtf ** min(i, 2)))
+        st, count = 0, 0
+        while st < preds.shape[1]:
+            chunk_size = int(llm_token_num_per_call / (given_rtf ** min(count, 2)))
             _resp = llm_tokenizer.batch_decode(
-                preds[:, i:i + chunk_size],
+                preds[:, st:st + chunk_size],
                 add_special_tokens=False,
                 skip_special_tokens=True,
             )[0]
-            is_last = (i + chunk_size >= preds.shape[1])
+            is_last = (st + chunk_size >= preds.shape[1])
             new_text = new_text + _resp
             rt_value, states = self.generate_speech_one_step(
@@ -3063,12 +3065,13 @@ class LLMASRXvecSlotTTS(nn.Module):
             cur_token, feat, wav = rt_value
             new_text, last_t_size, prompt_token, prompt_audio, chunk_idx = states
             # save results
-            if cur_token is not None:
+            if cur_token is not None and feat is not None and wav is not None:
                 token_list.append(cur_token)
                 feat_list.append(feat)
                 wav_list.append(wav)
-            i += chunk_size
+            st += chunk_size
+            count += 1
         speech_tokens = torch.cat(token_list, dim=1)
         mel_feats = torch.cat(feat_list, dim=2)
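
In the caller, the decode offset and the call counter were conflated: `i` advanced by chunk_size on every call, so min(i, 2) jumped from 0 straight to 2 and the intermediate given_rtf ** 1 step never occurred. Splitting the state into st (token offset into preds) and count (call index) restores the intended three-step schedule. A standalone sketch, with assumed values for llm_token_num_per_call and given_rtf:

    llm_token_num_per_call = 60  # assumption: per-call token budget
    given_rtf = 1.5              # assumption: target real-time factor

    # Old: min(i, 2) with i += chunk_size  -> 60, 26, 26, 26
    # New: min(count, 2) with count += 1   -> 60, 40, 26, 26
    for count in range(4):
        chunk_size = int(llm_token_num_per_call / (given_rtf ** min(count, 2)))
        print(count, chunk_size)

The guard on the saved results is also tightened from `cur_token is not None` to checking all three outputs, matching the None reset added in generate_speech_one_step.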

@@ -911,10 +911,7 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
         # streaming related config
         chunk_size = kwargs.get("streaming_chunk_size", 4)
         chunk_size_maxium = kwargs.get("chunk_size_maxium", 16)
-        try:
-            lookahead_size = self.am_model.encoder.pre_lookahead_len
-        except AttributeError:
-            lookahead_size = 0
+        lookahead_size = self.am_model.encoder.pre_lookahead_len
         hint_once(f"chunk_size={chunk_size}, chunk_size_maxium={chunk_size_maxium}, "
                   f"pre lookahead size={lookahead_size}.",
                   "pre_lookahead_len")