From ba736edb14b3b5c978f14282285b2a5bbd91f9b4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=BF=97=E6=B5=A9?=
Date: Thu, 12 Sep 2024 19:08:47 +0800
Subject: [PATCH] cut paragraph for streaming s2s

---
 funasr/models/llm_asr/model.py | 51 ++++++++++++++++++++---------------
 1 file changed, 27 insertions(+), 24 deletions(-)

diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 3522ce8ad..0f0dec914 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -3007,7 +3007,7 @@ class LLMASRXvecSlotTTS(nn.Module):
 
     def generate_speech_one_step(
         self,
-        text: str,
+        text: str, preds: str,
         last_t_size,
         llm_cur_kv_cache,
         llm_cur_kv_cache_len,
@@ -3016,24 +3016,38 @@ class LLMASRXvecSlotTTS(nn.Module):
         tts_text_chunk_size,
         chunk_idx,
         is_last,
-        para_len=30,
+        para_phone_len=200,
     ):
 
         device = llm_cur_kv_cache.device
         pounc = ["。", "?", "!", ";", ":", ".", "?", "!", ";", "\n"]
         # remove duplicated pounctuations
-        normed_text = []
-        for i, c in enumerate(text):
-            if i > 0 and text[i - 1] in pounc and text[i] in pounc:
+        normed_preds = []
+        for i, c in enumerate(preds):
+            if i > 0 and preds[i - 1] in pounc and preds[i] in pounc:
                 continue
-            normed_text.append(c)
-        text = "".join(normed_text)
+            normed_preds.append(c)
+        preds = self.split_characters_and_words("".join(normed_preds))
+        idx = -1
+        for i, ch in enumerate(preds):
+            if ch in pounc:
+                idx = i
+                break
+
+        _text = f"<|endofprompt|><|sil|>{text}" + ("<|sil|>" if is_last else "")
+        para_end = False
+        if idx > -1 and not is_last:
+            pre_part = "".join(preds[:idx+1])
+            if len(self.tts_tokenizer_warpper(text+pre_part)) >= para_phone_len:
+                _text = f"<|endofprompt|><|sil|>{text+pre_part}<|sil|>"
+                para_end = True
+                text = "".join(preds[idx+1:])
+                last_t_size = 0
 
         cur_token, feat, wav = None, None, None
-        _text = f"<|endofprompt|><|sil|>{text}" + ("<|sil|>" if is_last else "")
         text_token = self.tts_tokenizer_warpper(_text)
         t_size = len(text_token)
-        if (t_size - last_t_size) >= tts_text_chunk_size or is_last:
+        if (t_size - last_t_size) >= tts_text_chunk_size or is_last or para_end:
             text_token = torch.tensor([text_token], dtype=torch.long, device=device)
             text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.long, device=device)
             cur_token, feat = self.tts_model.streaming_one_step(
@@ -3072,19 +3086,8 @@ class LLMASRXvecSlotTTS(nn.Module):
             cur_token, feat, wav = None, None, None
 
         # post process
-        last_t_size = t_size
-        # restart a new paragraph
-        # char_words = self.split_characters_and_words(text)
-        # if len(char_words) > para_len:
-        #     # find the last pounc to split paragraph
-        #     idx = -1
-        #     for i in range(len(char_words)-1, -1, -1):
-        #         if char_words[i] in pounc:
-        #             idx = i
-        #             break
-        #     if idx > 0:
-        #         text = text[idx+1:]
-        #         last_t_size = len(self.tts_tokenizer_warpper(text))
+        if not para_end:
+            last_t_size = t_size
         return ((cur_token, feat, wav), (text, last_t_size, prompt_token, prompt_audio, chunk_idx))
 
 
@@ -3187,9 +3190,9 @@ class LLMASRXvecSlotTTS(nn.Module):
                 states["prompt_audio"],
                 states["chunk_idx"],
             )
-            new_text = new_text + preds
+            # new_text = new_text + preds
             rt_value, states_ret = self.generate_speech_one_step(
-                new_text,
+                new_text, preds,
                 last_t_size,
                 llm_cur_kv_cache,
                 llm_cur_kv_cache_len,