mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
cut paragraph for streaming s2s
This commit is contained in:
parent
4b65ccee2a
commit
8a06c8d44e
@ -3005,6 +3005,7 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
text_token = text_token[:-1]
|
||||
return text_token
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_speech_one_step(
|
||||
self,
|
||||
text: str, preds: str,
|
||||
@ -3095,7 +3096,8 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
text = "".join(preds[idx + 1:])
|
||||
last_t_size = 0
|
||||
prompt_token, prompt_audio = [None, None], [None, None]
|
||||
wav = torch.cat([wav, torch.zeros([1, 4410]).to(wav)], dim=1)
|
||||
wav = torch.cat([wav, torch.zeros([1, 2205]).to(wav)], dim=1)
|
||||
chunk_idx = 0
|
||||
else:
|
||||
text = text + normed_preds
|
||||
|
||||
@ -3111,7 +3113,7 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
channels=1,
|
||||
)
|
||||
mp3_buffer = BytesIO()
|
||||
mp3.export(mp3_buffer, format="mp3", bitrate="48k")
|
||||
mp3.export(mp3_buffer, format="mp3", bitrate="192k")
|
||||
# we should return this to web page.
|
||||
mp3_bytes_data = mp3_buffer.getvalue()
|
||||
|
||||
@ -3201,17 +3203,18 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
states["chunk_idx"],
|
||||
)
|
||||
# new_text = new_text + preds
|
||||
rt_value, states_ret = self.generate_speech_one_step(
|
||||
new_text, preds,
|
||||
last_t_size,
|
||||
llm_cur_kv_cache,
|
||||
llm_cur_kv_cache_len,
|
||||
prompt_token,
|
||||
prompt_audio,
|
||||
text_chunk_size,
|
||||
chunk_idx,
|
||||
is_last,
|
||||
)
|
||||
with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
|
||||
rt_value, states_ret = self.generate_speech_one_step(
|
||||
new_text, preds,
|
||||
last_t_size,
|
||||
llm_cur_kv_cache,
|
||||
llm_cur_kv_cache_len,
|
||||
prompt_token,
|
||||
prompt_audio,
|
||||
text_chunk_size,
|
||||
chunk_idx,
|
||||
is_last,
|
||||
)
|
||||
cur_token, feat, wav = rt_value
|
||||
new_text, last_t_size, prompt_token, prompt_audio, chunk_idx = states_ret
|
||||
states["new_text"] = new_text
|
||||
|
||||
@ -10,6 +10,7 @@ from funasr.utils.hinter import hint_once
|
||||
from funasr.models.transformer.utils.nets_utils import pad_list
|
||||
import numpy as np
|
||||
import random
|
||||
from funasr.train_utils.set_all_random_seed import set_all_random_seed
|
||||
|
||||
|
||||
def norm_and_sample_xvec(xvec, xvec_lengths):
|
||||
@ -335,6 +336,7 @@ class UpsampleCtcTokenDiffModel(nn.Module):
|
||||
)
|
||||
|
||||
# forward FM model
|
||||
set_all_random_seed(0)
|
||||
feat = self.fm_model.inference(
|
||||
aligned_token_emb, aligned_token_lens,
|
||||
prompt=dict(
|
||||
@ -1011,7 +1013,7 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
|
||||
**kwargs,
|
||||
)
|
||||
feat = self.rms_rescale_feat(feat)
|
||||
print_token = tokens.cpu().squeeze().tolist()
|
||||
print_token = tokens.cpu().squeeze(0).squeeze(-1).tolist()
|
||||
logging.info(f"valid_tokens: {print_token[:len(print_token) - token_hop_len]}, "
|
||||
f"pad_tokens: {print_token[len(print_token) - token_hop_len:]}.")
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user