From d82cfa21a556ac13e3145d0685be4a43ad6e21be Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=BF=97=E6=B5=A9?=
Date: Fri, 13 Sep 2024 11:03:38 +0800
Subject: [PATCH] cut paragraph for streaming s2s

---
 funasr/models/llm_asr/diffusion_models/flow_matching.py | 2 ++
 funasr/models/llm_asr/model.py                          | 4 ++--
 funasr/models/llm_asr/tts_models/e2e_model.py           | 1 +
 3 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/funasr/models/llm_asr/diffusion_models/flow_matching.py b/funasr/models/llm_asr/diffusion_models/flow_matching.py
index cbf67e17c..d128d6ede 100644
--- a/funasr/models/llm_asr/diffusion_models/flow_matching.py
+++ b/funasr/models/llm_asr/diffusion_models/flow_matching.py
@@ -5,6 +5,7 @@ from funasr.models.llm_asr.diffusion_models.matcha_decoder import (Decoder, Cond
 import logging
 from funasr.utils.hinter import hint_once
 import time
+from funasr.train_utils.set_all_random_seed import set_all_random_seed


 class BASECFM(torch.nn.Module, ABC):
@@ -47,6 +48,7 @@ class BASECFM(torch.nn.Module, ABC):
             sample: generated mel-spectrogram
                 shape: (batch_size, n_feats, mel_timesteps)
         """
+        set_all_random_seed(0)
         z = torch.randn_like(mu) * temperature
         t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
         if self.t_scheduler == 'cosine':
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index ca573e33b..01862b76a 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -3142,9 +3142,9 @@ class LLMASRXvecSlotTTS(nn.Module):
             )[0]

             is_last = st + chunk_size >= preds.shape[1]
-            new_text = new_text + _resp
+            # new_text = new_text + _resp
             rt_value, states = self.generate_speech_one_step(
-                new_text,
+                new_text, _resp,
                 last_t_size,
                 llm_cur_kv_cache,
                 llm_cur_kv_cache_len,
diff --git a/funasr/models/llm_asr/tts_models/e2e_model.py b/funasr/models/llm_asr/tts_models/e2e_model.py
index 23633651e..6ee922b96 100644
--- a/funasr/models/llm_asr/tts_models/e2e_model.py
+++ b/funasr/models/llm_asr/tts_models/e2e_model.py
@@ -1003,6 +1003,7 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
             cur_token_len = cur_token_len - token_hop_len

             # forward FM model
+            set_all_random_seed(0)
             feat = self.fm_model.inference(
                 cur_token, cur_token_len,
                 xvec, xvec_lengths,