Cut paragraph for streaming speech-to-speech (s2s)

This commit is contained in:
志浩 2024-09-13 11:03:38 +08:00
parent 8a06c8d44e
commit d82cfa21a5
3 changed files with 5 additions and 2 deletions

View File

@ -5,6 +5,7 @@ from funasr.models.llm_asr.diffusion_models.matcha_decoder import (Decoder, Cond
import logging
from funasr.utils.hinter import hint_once
import time
from funasr.train_utils.set_all_random_seed import set_all_random_seed
class BASECFM(torch.nn.Module, ABC):
@ -47,6 +48,7 @@ class BASECFM(torch.nn.Module, ABC):
sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
"""
set_all_random_seed(0)
z = torch.randn_like(mu) * temperature
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
if self.t_scheduler == 'cosine':

View File

@ -3142,9 +3142,9 @@ class LLMASRXvecSlotTTS(nn.Module):
)[0]
is_last = st + chunk_size >= preds.shape[1]
new_text = new_text + _resp
# new_text = new_text + _resp
rt_value, states = self.generate_speech_one_step(
new_text,
new_text, _resp,
last_t_size,
llm_cur_kv_cache,
llm_cur_kv_cache_len,

View File

@ -1003,6 +1003,7 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
cur_token_len = cur_token_len - token_hop_len
# forward FM model
set_all_random_seed(0)
feat = self.fm_model.inference(
cur_token, cur_token_len,
xvec, xvec_lengths,