From 7a9e0545a91e5b5160b7612b4bbb09ed3fe7e018 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BF=97=E6=B5=A9?= Date: Thu, 12 Sep 2024 19:19:38 +0800 Subject: [PATCH] cut paragraph for streaming s2s --- funasr/models/llm_asr/model.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index 0f0dec914..f90aa6675 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -3024,25 +3024,24 @@ class LLMASRXvecSlotTTS(nn.Module): # remove duplicated pounctuations normed_preds = [] for i, c in enumerate(preds): - if i > 0 and text[i - 1] in pounc and text[i] in pounc: + if i > 0 and preds[i - 1] in pounc and preds[i] in pounc: continue normed_preds.append(c) - preds = self.split_characters_and_words("".join(normed_preds)) + normed_preds = "".join(normed_preds) + preds = self.split_characters_and_words(normed_preds) idx = -1 for p in pounc: idx = preds.index(p) if idx > -1: break - _text = f"<|endofprompt|><|sil|>{text}" + ("<|sil|>" if is_last else "") + _text = f"<|endofprompt|><|sil|>{text+normed_preds}" + ("<|sil|>" if is_last else "") para_end = False if idx > -1 and not is_last: pre_part = "".join(preds[:idx+1]) if len(self.tts_tokenizer_warpper(text+pre_part)) >= para_phone_len: _text = f"<|endofprompt|><|sil|>{text+pre_part}<|sil|>" para_end = True - text = "".join(preds[idx+1:]) - last_t_size = 0 cur_token, feat, wav = None, None, None text_token = self.tts_tokenizer_warpper(_text) @@ -3089,6 +3088,12 @@ class LLMASRXvecSlotTTS(nn.Module): if not para_end: last_t_size = t_size + if para_end: + text = "".join(preds[idx + 1:]) + last_t_size = 0 + prompt_token, prompt_audio = [None, None], [None, None] + wav = torch.cat([wav, torch.zeros([1, 4410]).to(wav)], dim=1) + return ((cur_token, feat, wav), (text, last_t_size, prompt_token, prompt_audio, chunk_idx)) def convert_wav_to_mp3(self, wav: torch.Tensor):