mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
fix streaming speech gen bug
This commit is contained in:
parent
d7bc9c54b6
commit
1587194e75
@ -2982,10 +2982,10 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
text = "".join(normed_text)
|
||||
|
||||
cur_token, feat, wav = None, None, None
|
||||
t_size = len(self.tts_text_tokenizer.text2tokens(text))
|
||||
if (t_size - last_t_size) >= tts_text_chunk_size:
|
||||
_text = f"<|endofprompt|><|sil|>{text}" + ("<|sil|>" if is_last else "")
|
||||
text_token = self.tts_tokenizer_warpper(_text)
|
||||
_text = f"<|endofprompt|><|sil|>{text}" + ("<|sil|>" if is_last else "")
|
||||
text_token = self.tts_tokenizer_warpper(_text)
|
||||
t_size = len(text_token)
|
||||
if (t_size - last_t_size) >= tts_text_chunk_size or is_last:
|
||||
text_token = torch.tensor([text_token], dtype=torch.long, device=device)
|
||||
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.long, device=device)
|
||||
cur_token, feat = self.tts_model.streaming_one_step(
|
||||
@ -3000,7 +3000,7 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
sampling="threshold_1e-6",
|
||||
chunk_idx=chunk_idx,
|
||||
)
|
||||
if cur_token is not None:
|
||||
if cur_token is not None and cur_token.shape[1] > 0 and feat.shape[2] > 0:
|
||||
# process first package, token in B,T,D, feat in B,F,T
|
||||
if prompt_token[0] is None:
|
||||
prompt_token = [cur_token, torch.tensor([cur_token.shape[1]], dtype=torch.long, device=device)]
|
||||
@ -3012,6 +3012,8 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
prompt_audio[0] = torch.concat([prompt_audio[0], feat.transpose(1, 2)], dim=1)
|
||||
wav = self.vocoder.inference(feat.transpose(1, 2))
|
||||
chunk_idx += 1
|
||||
else:
|
||||
cur_token, feat, wav = None, None, None
|
||||
|
||||
# post process
|
||||
last_t_size = t_size
|
||||
@ -3042,15 +3044,15 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
token_list, feat_list, wav_list = [], [], []
|
||||
prompt_token, prompt_audio = [None, None], [None, None]
|
||||
new_text, last_t_size, chunk_idx = "", 0, 0
|
||||
i = 0
|
||||
while i < preds.shape[1]:
|
||||
chunk_size = int(llm_token_num_per_call / (given_rtf ** min(i, 2)))
|
||||
st, count = 0, 0
|
||||
while st < preds.shape[1]:
|
||||
chunk_size = int(llm_token_num_per_call / (given_rtf ** min(count, 2)))
|
||||
_resp = llm_tokenizer.batch_decode(
|
||||
preds[:, i:i + chunk_size],
|
||||
preds[:, st:st + chunk_size],
|
||||
add_special_tokens=False,
|
||||
skip_special_tokens=True,
|
||||
)[0]
|
||||
is_last = (i + chunk_size >= preds.shape[1])
|
||||
is_last = (st + chunk_size >= preds.shape[1])
|
||||
|
||||
new_text = new_text + _resp
|
||||
rt_value, states = self.generate_speech_one_step(
|
||||
@ -3063,12 +3065,13 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
cur_token, feat, wav = rt_value
|
||||
new_text, last_t_size, prompt_token, prompt_audio, chunk_idx = states
|
||||
# save results
|
||||
if cur_token is not None:
|
||||
if cur_token is not None and feat is not None and wav is not None:
|
||||
token_list.append(cur_token)
|
||||
feat_list.append(feat)
|
||||
wav_list.append(wav)
|
||||
|
||||
i += chunk_size
|
||||
st += chunk_size
|
||||
count += 1
|
||||
|
||||
speech_tokens = torch.cat(token_list, dim=1)
|
||||
mel_feats = torch.cat(feat_list, dim=2)
|
||||
|
||||
@ -911,10 +911,7 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
|
||||
# streaming related config
|
||||
chunk_size = kwargs.get("streaming_chunk_size", 4)
|
||||
chunk_size_maxium = kwargs.get("chunk_size_maxium", 16)
|
||||
try:
|
||||
lookahead_size = self.am_model.encoder.pre_lookahead_len
|
||||
except AttributeError:
|
||||
lookahead_size = 0
|
||||
lookahead_size = self.am_model.encoder.pre_lookahead_len
|
||||
hint_once(f"chunk_size={chunk_size}, chunk_size_maxium={chunk_size_maxium}, "
|
||||
f"pre lookahead size={lookahead_size}.",
|
||||
"pre_lookahead_len")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user