cut paragraph for streaming s2s

This commit is contained in:
志浩 2024-09-13 11:03:38 +08:00
parent 8a06c8d44e
commit d82cfa21a5
3 changed files with 5 additions and 2 deletions

View File

@ -5,6 +5,7 @@ from funasr.models.llm_asr.diffusion_models.matcha_decoder import (Decoder, Cond
import logging import logging
from funasr.utils.hinter import hint_once from funasr.utils.hinter import hint_once
import time import time
from funasr.train_utils.set_all_random_seed import set_all_random_seed
class BASECFM(torch.nn.Module, ABC): class BASECFM(torch.nn.Module, ABC):
@ -47,6 +48,7 @@ class BASECFM(torch.nn.Module, ABC):
sample: generated mel-spectrogram sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps) shape: (batch_size, n_feats, mel_timesteps)
""" """
set_all_random_seed(0)
z = torch.randn_like(mu) * temperature z = torch.randn_like(mu) * temperature
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
if self.t_scheduler == 'cosine': if self.t_scheduler == 'cosine':

View File

@ -3142,9 +3142,9 @@ class LLMASRXvecSlotTTS(nn.Module):
)[0] )[0]
is_last = st + chunk_size >= preds.shape[1] is_last = st + chunk_size >= preds.shape[1]
new_text = new_text + _resp # new_text = new_text + _resp
rt_value, states = self.generate_speech_one_step( rt_value, states = self.generate_speech_one_step(
new_text, new_text, _resp,
last_t_size, last_t_size,
llm_cur_kv_cache, llm_cur_kv_cache,
llm_cur_kv_cache_len, llm_cur_kv_cache_len,

View File

@ -1003,6 +1003,7 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
cur_token_len = cur_token_len - token_hop_len cur_token_len = cur_token_len - token_hop_len
# forward FM model # forward FM model
set_all_random_seed(0)
feat = self.fm_model.inference( feat = self.fm_model.inference(
cur_token, cur_token_len, cur_token, cur_token_len,
xvec, xvec_lengths, xvec, xvec_lengths,