mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
cut paragraph for streaming s2s
This commit is contained in:
parent
2617c07387
commit
7a9e0545a9
@ -3024,25 +3024,24 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
# remove duplicated pounctuations
|
||||
normed_preds = []
|
||||
for i, c in enumerate(preds):
|
||||
if i > 0 and text[i - 1] in pounc and text[i] in pounc:
|
||||
if i > 0 and preds[i - 1] in pounc and preds[i] in pounc:
|
||||
continue
|
||||
normed_preds.append(c)
|
||||
preds = self.split_characters_and_words("".join(normed_preds))
|
||||
normed_preds = "".join(normed_preds)
|
||||
preds = self.split_characters_and_words(normed_preds)
|
||||
idx = -1
|
||||
for p in pounc:
|
||||
idx = preds.index(p)
|
||||
if idx > -1:
|
||||
break
|
||||
|
||||
_text = f"<|endofprompt|><|sil|>{text}" + ("<|sil|>" if is_last else "")
|
||||
_text = f"<|endofprompt|><|sil|>{text+normed_preds}" + ("<|sil|>" if is_last else "")
|
||||
para_end = False
|
||||
if idx > -1 and not is_last:
|
||||
pre_part = "".join(preds[:idx+1])
|
||||
if len(self.tts_tokenizer_warpper(text+pre_part)) >= para_phone_len:
|
||||
_text = f"<|endofprompt|><|sil|>{text+pre_part}<|sil|>"
|
||||
para_end = True
|
||||
text = "".join(preds[idx+1:])
|
||||
last_t_size = 0
|
||||
|
||||
cur_token, feat, wav = None, None, None
|
||||
text_token = self.tts_tokenizer_warpper(_text)
|
||||
@ -3089,6 +3088,12 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
if not para_end:
|
||||
last_t_size = t_size
|
||||
|
||||
if para_end:
|
||||
text = "".join(preds[idx + 1:])
|
||||
last_t_size = 0
|
||||
prompt_token, prompt_audio = [None, None], [None, None]
|
||||
wav = torch.cat([wav, torch.zeros([1, 4410]).to(wav)], dim=1)
|
||||
|
||||
return ((cur_token, feat, wav), (text, last_t_size, prompt_token, prompt_audio, chunk_idx))
|
||||
|
||||
def convert_wav_to_mp3(self, wav: torch.Tensor):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user