cut paragraph for streaming s2s

This commit is contained in:
志浩 2024-09-12 19:08:47 +08:00
parent f4b5af8473
commit ba736edb14

View File

@ -3007,7 +3007,7 @@ class LLMASRXvecSlotTTS(nn.Module):
def generate_speech_one_step( def generate_speech_one_step(
self, self,
text: str, text: str, preds: str,
last_t_size, last_t_size,
llm_cur_kv_cache, llm_cur_kv_cache,
llm_cur_kv_cache_len, llm_cur_kv_cache_len,
@ -3016,24 +3016,38 @@ class LLMASRXvecSlotTTS(nn.Module):
tts_text_chunk_size, tts_text_chunk_size,
chunk_idx, chunk_idx,
is_last, is_last,
para_len=30, para_phone_len=200,
): ):
device = llm_cur_kv_cache.device device = llm_cur_kv_cache.device
pounc = ["。", "？", "！", "；", "：", ".", "?", "!", ";", "\n"] pounc = ["。", "？", "！", "；", "：", ".", "?", "!", ";", "\n"]
# remove duplicated pounctuations # remove duplicated pounctuations
normed_text = [] normed_preds = []
for i, c in enumerate(text): for i, c in enumerate(preds):
if i > 0 and text[i - 1] in pounc and text[i] in pounc: if i > 0 and text[i - 1] in pounc and text[i] in pounc:
continue continue
normed_text.append(c) normed_preds.append(c)
text = "".join(normed_text) preds = self.split_characters_and_words("".join(normed_preds))
idx = -1
for p in pounc:
idx = preds.index(p)
if idx > -1:
break
_text = f"<|endofprompt|><|sil|>{text}" + ("<|sil|>" if is_last else "")
para_end = False
if idx > -1 and not is_last:
pre_part = "".join(preds[:idx+1])
if len(self.tts_tokenizer_warpper(text+pre_part)) >= para_phone_len:
_text = f"<|endofprompt|><|sil|>{text+pre_part}<|sil|>"
para_end = True
text = "".join(preds[idx+1:])
last_t_size = 0
cur_token, feat, wav = None, None, None cur_token, feat, wav = None, None, None
_text = f"<|endofprompt|><|sil|>{text}" + ("<|sil|>" if is_last else "")
text_token = self.tts_tokenizer_warpper(_text) text_token = self.tts_tokenizer_warpper(_text)
t_size = len(text_token) t_size = len(text_token)
if (t_size - last_t_size) >= tts_text_chunk_size or is_last: if (t_size - last_t_size) >= tts_text_chunk_size or is_last or para_end:
text_token = torch.tensor([text_token], dtype=torch.long, device=device) text_token = torch.tensor([text_token], dtype=torch.long, device=device)
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.long, device=device) text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.long, device=device)
cur_token, feat = self.tts_model.streaming_one_step( cur_token, feat = self.tts_model.streaming_one_step(
@ -3072,19 +3086,8 @@ class LLMASRXvecSlotTTS(nn.Module):
cur_token, feat, wav = None, None, None cur_token, feat, wav = None, None, None
# post process # post process
last_t_size = t_size if not para_end:
# restart a new paragraph last_t_size = t_size
# char_words = self.split_characters_and_words(text)
# if len(char_words) > para_len:
# # find the last pounc to split paragraph
# idx = -1
# for i in range(len(char_words)-1, -1, -1):
# if char_words[i] in pounc:
# idx = i
# break
# if idx > 0:
# text = text[idx+1:]
# last_t_size = len(self.tts_tokenizer_warpper(text))
return ((cur_token, feat, wav), (text, last_t_size, prompt_token, prompt_audio, chunk_idx)) return ((cur_token, feat, wav), (text, last_t_size, prompt_token, prompt_audio, chunk_idx))
@ -3187,9 +3190,9 @@ class LLMASRXvecSlotTTS(nn.Module):
states["prompt_audio"], states["prompt_audio"],
states["chunk_idx"], states["chunk_idx"],
) )
new_text = new_text + preds # new_text = new_text + preds
rt_value, states_ret = self.generate_speech_one_step( rt_value, states_ret = self.generate_speech_one_step(
new_text, new_text, preds,
last_t_size, last_t_size,
llm_cur_kv_cache, llm_cur_kv_cache,
llm_cur_kv_cache_len, llm_cur_kv_cache_len,