diff --git a/funasr/models/llm_asr/diffusion_models/flow_matching.py b/funasr/models/llm_asr/diffusion_models/flow_matching.py
index cbf67e17c..d128d6ede 100644
--- a/funasr/models/llm_asr/diffusion_models/flow_matching.py
+++ b/funasr/models/llm_asr/diffusion_models/flow_matching.py
@@ -5,6 +5,7 @@ from funasr.models.llm_asr.diffusion_models.matcha_decoder import (Decoder, Cond
 import logging
 from funasr.utils.hinter import hint_once
 import time
+from funasr.train_utils.set_all_random_seed import set_all_random_seed
 
 
 class BASECFM(torch.nn.Module, ABC):
@@ -47,6 +48,7 @@ class BASECFM(torch.nn.Module, ABC):
             sample: generated mel-spectrogram
                 shape: (batch_size, n_feats, mel_timesteps)
         """
+        set_all_random_seed(0)
         z = torch.randn_like(mu) * temperature
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
         if self.t_scheduler == 'cosine':
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 3522ce8ad..01862b76a 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -3005,9 +3005,10 @@ class LLMASRXvecSlotTTS(nn.Module):
             text_token = text_token[:-1]
         return text_token
 
+    @torch.no_grad()
     def generate_speech_one_step(
         self,
-        text: str,
+        text: str, preds: str,
         last_t_size,
         llm_cur_kv_cache,
         llm_cur_kv_cache_len,
@@ -3016,24 +3017,40 @@
         tts_text_chunk_size,
         chunk_idx,
         is_last,
-        para_len=30,
+        para_phone_len=200,
     ):
         device = llm_cur_kv_cache.device
 
         pounc = ["。", "?", "!", ";", ":", ".", "?", "!", ";", "\n"]
         # remove duplicated pounctuations
-        normed_text = []
-        for i, c in enumerate(text):
-            if i > 0 and text[i - 1] in pounc and text[i] in pounc:
+        normed_preds = []
+        for i, c in enumerate(preds):
+            if i > 0 and preds[i - 1] in pounc and preds[i] in pounc:
                 continue
-            normed_text.append(c)
-        text = "".join(normed_text)
+            normed_preds.append(c)
+        normed_preds = "".join(normed_preds)
+        idx = -1
+        for p in pounc:
+            str_idx = normed_preds.find(p)
+            if str_idx > -1:
+                preds = self.split_characters_and_words(normed_preds[:str_idx])
+                idx = len(preds)
+                preds.append(normed_preds[str_idx])
+                preds.extend(self.split_characters_and_words(normed_preds[str_idx+1:]))
+                break
+
+        _text = f"<|endofprompt|><|sil|>{text+normed_preds}" + ("<|sil|>" if is_last else "")
+        para_end = False
+        if idx > -1 and not is_last:
+            pre_part = "".join(preds[:idx+1])
+            if len(self.tts_tokenizer_warpper(text+pre_part)) >= para_phone_len:
+                _text = f"<|endofprompt|><|sil|>{text+pre_part}<|sil|>"
+                para_end = True
 
         cur_token, feat, wav = None, None, None
-        _text = f"<|endofprompt|><|sil|>{text}" + ("<|sil|>" if is_last else "")
         text_token = self.tts_tokenizer_warpper(_text)
         t_size = len(text_token)
-        if (t_size - last_t_size) >= tts_text_chunk_size or is_last:
+        if (t_size - last_t_size) >= tts_text_chunk_size or is_last or para_end:
             text_token = torch.tensor([text_token], dtype=torch.long, device=device)
             text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.long, device=device)
             cur_token, feat = self.tts_model.streaming_one_step(
@@ -3072,19 +3089,17 @@
             cur_token, feat, wav = None, None, None
 
         # post process
-        last_t_size = t_size
-        # restart a new paragraph
-        # char_words = self.split_characters_and_words(text)
-        # if len(char_words) > para_len:
-        #     # find the last pounc to split paragraph
-        #     idx = -1
-        #     for i in range(len(char_words)-1, -1, -1):
-        #         if char_words[i] in pounc:
-        #             idx = i
-        #             break
-        #     if idx > 0:
-        #         text = text[idx+1:]
-        #         last_t_size = len(self.tts_tokenizer_warpper(text))
+        if not para_end:
+            last_t_size = t_size
+
+        if para_end:
+            text = "".join(preds[idx + 1:])
+            last_t_size = 0
+            prompt_token, prompt_audio = [None, None], [None, None]
+            wav = torch.cat([wav, torch.zeros([1, 2205]).to(wav)], dim=1)
+            chunk_idx = 0
+        else:
+            text = text + normed_preds
 
         return ((cur_token, feat, wav),
                 (text, last_t_size, prompt_token, prompt_audio, chunk_idx))
@@ -3098,7 +3113,7 @@ class LLMASRXvecSlotTTS(nn.Module):
             channels=1,
         )
         mp3_buffer = BytesIO()
-        mp3.export(mp3_buffer, format="mp3", bitrate="48k")
+        mp3.export(mp3_buffer, format="mp3", bitrate="192k")
         # we should return this to web page.
         mp3_bytes_data = mp3_buffer.getvalue()
 
@@ -3127,9 +3142,9 @@ class LLMASRXvecSlotTTS(nn.Module):
             )[0]
             is_last = st + chunk_size >= preds.shape[1]
 
-            new_text = new_text + _resp
+            # new_text = new_text + _resp
             rt_value, states = self.generate_speech_one_step(
-                new_text,
+                new_text, _resp,
                 last_t_size,
                 llm_cur_kv_cache,
                 llm_cur_kv_cache_len,
@@ -3187,18 +3202,19 @@ class LLMASRXvecSlotTTS(nn.Module):
             states["prompt_audio"],
             states["chunk_idx"],
         )
-        new_text = new_text + preds
-        rt_value, states_ret = self.generate_speech_one_step(
-            new_text,
-            last_t_size,
-            llm_cur_kv_cache,
-            llm_cur_kv_cache_len,
-            prompt_token,
-            prompt_audio,
-            text_chunk_size,
-            chunk_idx,
-            is_last,
-        )
+        # new_text = new_text + preds
+        with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
+            rt_value, states_ret = self.generate_speech_one_step(
+                new_text, preds,
+                last_t_size,
+                llm_cur_kv_cache,
+                llm_cur_kv_cache_len,
+                prompt_token,
+                prompt_audio,
+                text_chunk_size,
+                chunk_idx,
+                is_last,
+            )
         cur_token, feat, wav = rt_value
         new_text, last_t_size, prompt_token, prompt_audio, chunk_idx = states_ret
         states["new_text"] = new_text
diff --git a/funasr/models/llm_asr/tts_models/e2e_model.py b/funasr/models/llm_asr/tts_models/e2e_model.py
index d8f9a3204..6ee922b96 100644
--- a/funasr/models/llm_asr/tts_models/e2e_model.py
+++ b/funasr/models/llm_asr/tts_models/e2e_model.py
@@ -10,6 +10,7 @@ from funasr.utils.hinter import hint_once
 from funasr.models.transformer.utils.nets_utils import pad_list
 import numpy as np
 import random
+from funasr.train_utils.set_all_random_seed import set_all_random_seed
 
 
 def norm_and_sample_xvec(xvec, xvec_lengths):
@@ -335,6 +336,7 @@ class UpsampleCtcTokenDiffModel(nn.Module):
         )
 
         # forward FM model
+        set_all_random_seed(0)
         feat = self.fm_model.inference(
             aligned_token_emb, aligned_token_lens,
             prompt=dict(
@@ -1001,6 +1003,7 @@
             cur_token_len = cur_token_len - token_hop_len
 
         # forward FM model
+        set_all_random_seed(0)
         feat = self.fm_model.inference(
             cur_token, cur_token_len,
             xvec, xvec_lengths,
@@ -1011,7 +1014,7 @@
             **kwargs,
         )
         feat = self.rms_rescale_feat(feat)
-        print_token = tokens.cpu().squeeze().tolist()
+        print_token = tokens.cpu().squeeze(0).squeeze(-1).tolist()
         logging.info(f"valid_tokens: {print_token[:len(print_token) - token_hop_len]}, "
                      f"pad_tokens: {print_token[len(print_token) - token_hop_len:]}.")
 
diff --git a/runtime/python/websocket/funasr_wss_server_streaming_llm.py b/runtime/python/websocket/funasr_wss_server_streaming_llm.py
index 868b8e9c4..50a6f35a1 100644
--- a/runtime/python/websocket/funasr_wss_server_streaming_llm.py
+++ b/runtime/python/websocket/funasr_wss_server_streaming_llm.py
@@ -612,6 +612,15 @@ async def ws_serve(websocket, path):
             websocket.streaming_state["previous_asr_text"] = ""
             websocket.streaming_state["previous_s2tt_text"] = ""
         if not websocket.is_speaking:
+            if is_alpha_ending(websocket.streaming_state["onscreen_asr_res"]):
+                websocket.streaming_state["onscreen_asr_res"] += "."
+            elif is_chinese_ending(websocket.streaming_state["onscreen_asr_res"]):
+                websocket.streaming_state["onscreen_asr_res"] += "。"
+
+            if is_alpha_ending(websocket.streaming_state["onscreen_s2tt_res"]):
+                websocket.streaming_state["onscreen_s2tt_res"] += "."
+            elif is_chinese_ending(websocket.streaming_state["onscreen_s2tt_res"]):
+                websocket.streaming_state["onscreen_s2tt_res"] += "。"
             message = json.dumps(
                 {
                     "mode": "online",
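
Two illustrative sketches of logic from the hunks above; both are standalone approximations, not code from the patch.

The set_all_random_seed(0) calls inserted before each flow-matching inference make streaming synthesis reproducible: BASECFM starts sampling from torch.randn_like(mu), so reseeding pins the initial noise z on every chunk and every run. A minimal sketch of that effect, using stock torch.manual_seed as a stand-in for set_all_random_seed (sample_initial_noise is a hypothetical name):

    import torch

    def sample_initial_noise(mu: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
        # Reseed before sampling, mirroring the reseeding added in the patch.
        torch.manual_seed(0)
        return torch.randn_like(mu) * temperature

    mu = torch.zeros(2, 80, 120)  # (batch, n_feats, mel_timesteps)
    # Identical noise on every call, hence deterministic chunk starts.
    assert torch.equal(sample_initial_noise(mu), sample_initial_noise(mu))

The punctuation normalization that generate_speech_one_step now applies to preds can also be read in isolation: any punctuation mark immediately following another is dropped before the text is tokenized. A standalone sketch under the same pounc list as the patch (dedupe_punctuation is a hypothetical name):

    PUNC = ["。", "?", "!", ";", ":", ".", "?", "!", ";", "\n"]

    def dedupe_punctuation(preds: str) -> str:
        out = []
        for i, c in enumerate(preds):
            # Drop a punctuation mark that directly follows another one.
            if i > 0 and preds[i - 1] in PUNC and c in PUNC:
                continue
            out.append(c)
        return "".join(out)

    print(dedupe_punctuation("你好!!世界。。"))  # -> 你好!世界。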