mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Fix streaming speech generation bug
This commit is contained in:
parent
d7bc9c54b6
commit
1587194e75
@ -2982,10 +2982,10 @@ class LLMASRXvecSlotTTS(nn.Module):
|
|||||||
text = "".join(normed_text)
|
text = "".join(normed_text)
|
||||||
|
|
||||||
cur_token, feat, wav = None, None, None
|
cur_token, feat, wav = None, None, None
|
||||||
t_size = len(self.tts_text_tokenizer.text2tokens(text))
|
_text = f"<|endofprompt|><|sil|>{text}" + ("<|sil|>" if is_last else "")
|
||||||
if (t_size - last_t_size) >= tts_text_chunk_size:
|
text_token = self.tts_tokenizer_warpper(_text)
|
||||||
_text = f"<|endofprompt|><|sil|>{text}" + ("<|sil|>" if is_last else "")
|
t_size = len(text_token)
|
||||||
text_token = self.tts_tokenizer_warpper(_text)
|
if (t_size - last_t_size) >= tts_text_chunk_size or is_last:
|
||||||
text_token = torch.tensor([text_token], dtype=torch.long, device=device)
|
text_token = torch.tensor([text_token], dtype=torch.long, device=device)
|
||||||
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.long, device=device)
|
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.long, device=device)
|
||||||
cur_token, feat = self.tts_model.streaming_one_step(
|
cur_token, feat = self.tts_model.streaming_one_step(
|
||||||
@ -3000,7 +3000,7 @@ class LLMASRXvecSlotTTS(nn.Module):
|
|||||||
sampling="threshold_1e-6",
|
sampling="threshold_1e-6",
|
||||||
chunk_idx=chunk_idx,
|
chunk_idx=chunk_idx,
|
||||||
)
|
)
|
||||||
if cur_token is not None:
|
if cur_token is not None and cur_token.shape[1] > 0 and feat.shape[2] > 0:
|
||||||
# process first package, token in B,T,D, feat in B,F,T
|
# process first package, token in B,T,D, feat in B,F,T
|
||||||
if prompt_token[0] is None:
|
if prompt_token[0] is None:
|
||||||
prompt_token = [cur_token, torch.tensor([cur_token.shape[1]], dtype=torch.long, device=device)]
|
prompt_token = [cur_token, torch.tensor([cur_token.shape[1]], dtype=torch.long, device=device)]
|
||||||
@ -3012,6 +3012,8 @@ class LLMASRXvecSlotTTS(nn.Module):
|
|||||||
prompt_audio[0] = torch.concat([prompt_audio[0], feat.transpose(1, 2)], dim=1)
|
prompt_audio[0] = torch.concat([prompt_audio[0], feat.transpose(1, 2)], dim=1)
|
||||||
wav = self.vocoder.inference(feat.transpose(1, 2))
|
wav = self.vocoder.inference(feat.transpose(1, 2))
|
||||||
chunk_idx += 1
|
chunk_idx += 1
|
||||||
|
else:
|
||||||
|
cur_token, feat, wav = None, None, None
|
||||||
|
|
||||||
# post process
|
# post process
|
||||||
last_t_size = t_size
|
last_t_size = t_size
|
||||||
@ -3042,15 +3044,15 @@ class LLMASRXvecSlotTTS(nn.Module):
|
|||||||
token_list, feat_list, wav_list = [], [], []
|
token_list, feat_list, wav_list = [], [], []
|
||||||
prompt_token, prompt_audio = [None, None], [None, None]
|
prompt_token, prompt_audio = [None, None], [None, None]
|
||||||
new_text, last_t_size, chunk_idx = "", 0, 0
|
new_text, last_t_size, chunk_idx = "", 0, 0
|
||||||
i = 0
|
st, count = 0, 0
|
||||||
while i < preds.shape[1]:
|
while st < preds.shape[1]:
|
||||||
chunk_size = int(llm_token_num_per_call / (given_rtf ** min(i, 2)))
|
chunk_size = int(llm_token_num_per_call / (given_rtf ** min(count, 2)))
|
||||||
_resp = llm_tokenizer.batch_decode(
|
_resp = llm_tokenizer.batch_decode(
|
||||||
preds[:, i:i + chunk_size],
|
preds[:, st:st + chunk_size],
|
||||||
add_special_tokens=False,
|
add_special_tokens=False,
|
||||||
skip_special_tokens=True,
|
skip_special_tokens=True,
|
||||||
)[0]
|
)[0]
|
||||||
is_last = (i + chunk_size >= preds.shape[1])
|
is_last = (st + chunk_size >= preds.shape[1])
|
||||||
|
|
||||||
new_text = new_text + _resp
|
new_text = new_text + _resp
|
||||||
rt_value, states = self.generate_speech_one_step(
|
rt_value, states = self.generate_speech_one_step(
|
||||||
@ -3063,12 +3065,13 @@ class LLMASRXvecSlotTTS(nn.Module):
|
|||||||
cur_token, feat, wav = rt_value
|
cur_token, feat, wav = rt_value
|
||||||
new_text, last_t_size, prompt_token, prompt_audio, chunk_idx = states
|
new_text, last_t_size, prompt_token, prompt_audio, chunk_idx = states
|
||||||
# save results
|
# save results
|
||||||
if cur_token is not None:
|
if cur_token is not None and feat is not None and wav is not None:
|
||||||
token_list.append(cur_token)
|
token_list.append(cur_token)
|
||||||
feat_list.append(feat)
|
feat_list.append(feat)
|
||||||
wav_list.append(wav)
|
wav_list.append(wav)
|
||||||
|
|
||||||
i += chunk_size
|
st += chunk_size
|
||||||
|
count += 1
|
||||||
|
|
||||||
speech_tokens = torch.cat(token_list, dim=1)
|
speech_tokens = torch.cat(token_list, dim=1)
|
||||||
mel_feats = torch.cat(feat_list, dim=2)
|
mel_feats = torch.cat(feat_list, dim=2)
|
||||||
|
|||||||
@ -911,10 +911,7 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
|
|||||||
# streaming related config
|
# streaming related config
|
||||||
chunk_size = kwargs.get("streaming_chunk_size", 4)
|
chunk_size = kwargs.get("streaming_chunk_size", 4)
|
||||||
chunk_size_maxium = kwargs.get("chunk_size_maxium", 16)
|
chunk_size_maxium = kwargs.get("chunk_size_maxium", 16)
|
||||||
try:
|
lookahead_size = self.am_model.encoder.pre_lookahead_len
|
||||||
lookahead_size = self.am_model.encoder.pre_lookahead_len
|
|
||||||
except AttributeError:
|
|
||||||
lookahead_size = 0
|
|
||||||
hint_once(f"chunk_size={chunk_size}, chunk_size_maxium={chunk_size_maxium}, "
|
hint_once(f"chunk_size={chunk_size}, chunk_size_maxium={chunk_size_maxium}, "
|
||||||
f"pre lookahead size={lookahead_size}.",
|
f"pre lookahead size={lookahead_size}.",
|
||||||
"pre_lookahead_len")
|
"pre_lookahead_len")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user