mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
cut paragraph for streaming s2s
This commit is contained in:
parent
8a06c8d44e
commit
d82cfa21a5
@ -5,6 +5,7 @@ from funasr.models.llm_asr.diffusion_models.matcha_decoder import (Decoder, Cond
|
|||||||
import logging
|
import logging
|
||||||
from funasr.utils.hinter import hint_once
|
from funasr.utils.hinter import hint_once
|
||||||
import time
|
import time
|
||||||
|
from funasr.train_utils.set_all_random_seed import set_all_random_seed
|
||||||
|
|
||||||
|
|
||||||
class BASECFM(torch.nn.Module, ABC):
|
class BASECFM(torch.nn.Module, ABC):
|
||||||
@ -47,6 +48,7 @@ class BASECFM(torch.nn.Module, ABC):
|
|||||||
sample: generated mel-spectrogram
|
sample: generated mel-spectrogram
|
||||||
shape: (batch_size, n_feats, mel_timesteps)
|
shape: (batch_size, n_feats, mel_timesteps)
|
||||||
"""
|
"""
|
||||||
|
set_all_random_seed(0)
|
||||||
z = torch.randn_like(mu) * temperature
|
z = torch.randn_like(mu) * temperature
|
||||||
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
|
||||||
if self.t_scheduler == 'cosine':
|
if self.t_scheduler == 'cosine':
|
||||||
|
|||||||
@ -3142,9 +3142,9 @@ class LLMASRXvecSlotTTS(nn.Module):
|
|||||||
)[0]
|
)[0]
|
||||||
is_last = st + chunk_size >= preds.shape[1]
|
is_last = st + chunk_size >= preds.shape[1]
|
||||||
|
|
||||||
new_text = new_text + _resp
|
# new_text = new_text + _resp
|
||||||
rt_value, states = self.generate_speech_one_step(
|
rt_value, states = self.generate_speech_one_step(
|
||||||
new_text,
|
new_text, _resp,
|
||||||
last_t_size,
|
last_t_size,
|
||||||
llm_cur_kv_cache,
|
llm_cur_kv_cache,
|
||||||
llm_cur_kv_cache_len,
|
llm_cur_kv_cache_len,
|
||||||
|
|||||||
@ -1003,6 +1003,7 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
|
|||||||
cur_token_len = cur_token_len - token_hop_len
|
cur_token_len = cur_token_len - token_hop_len
|
||||||
|
|
||||||
# forward FM model
|
# forward FM model
|
||||||
|
set_all_random_seed(0)
|
||||||
feat = self.fm_model.inference(
|
feat = self.fm_model.inference(
|
||||||
cur_token, cur_token_len,
|
cur_token, cur_token_len,
|
||||||
xvec, xvec_lengths,
|
xvec, xvec_lengths,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user