cut paragraph for streaming s2s

志浩 2024-09-13 10:36:34 +08:00
parent 4b65ccee2a
commit 8a06c8d44e
2 changed files with 19 additions and 14 deletions

Changed file 1 of 2:

@@ -3005,6 +3005,7 @@ class LLMASRXvecSlotTTS(nn.Module):
             text_token = text_token[:-1]
         return text_token
 
+    @torch.no_grad()
     def generate_speech_one_step(
         self,
         text: str, preds: str,
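
The new @torch.no_grad() decorator on generate_speech_one_step disables autograd tracking for the whole streaming synthesis step, so no computation graph is built and activation memory is not retained during inference. A minimal sketch of the pattern (TinyTTS and generate_one_step are illustrative names, not funasr code):

import torch
import torch.nn as nn

class TinyTTS(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    @torch.no_grad()  # no graph is recorded; outputs carry requires_grad=False
    def generate_one_step(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

model = TinyTTS()
out = model.generate_one_step(torch.randn(1, 8))
assert not out.requires_grad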
@@ -3095,7 +3096,8 @@ class LLMASRXvecSlotTTS(nn.Module):
             text = "".join(preds[idx + 1:])
             last_t_size = 0
             prompt_token, prompt_audio = [None, None], [None, None]
-            wav = torch.cat([wav, torch.zeros([1, 4410]).to(wav)], dim=1)
+            wav = torch.cat([wav, torch.zeros([1, 2205]).to(wav)], dim=1)
+            chunk_idx = 0
         else:
             text = text + normed_preds
 
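
This hunk halves the trailing silence appended at a sentence boundary, from 4410 to 2205 samples, and resets chunk_idx for the next sentence. Assuming a 22050 Hz output rate (the diff does not state it), that is a drop from 200 ms to 100 ms of padding:

import torch

SAMPLE_RATE = 22050  # assumption; not stated in the diff
pad_ms = 100         # 2205 samples / 22050 Hz = 0.1 s
pad = torch.zeros([1, SAMPLE_RATE * pad_ms // 1000])   # shape [1, 2205]
wav = torch.randn(1, SAMPLE_RATE)                      # stand-in for synthesized audio
wav = torch.cat([wav, pad.to(wav)], dim=1)             # append silence between sentences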
@@ -3111,7 +3113,7 @@ class LLMASRXvecSlotTTS(nn.Module):
                 channels=1,
             )
 
             mp3_buffer = BytesIO()
-            mp3.export(mp3_buffer, format="mp3", bitrate="48k")
+            mp3.export(mp3_buffer, format="mp3", bitrate="192k")
             # we should return this to web page.
             mp3_bytes_data = mp3_buffer.getvalue()
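
The MP3 chunk streamed back to the page is now encoded at 192 kbps instead of 48 kbps, trading bandwidth for audio quality. A self-contained sketch of the pydub round trip, assuming 16-bit mono PCM at 22050 Hz (the sample rate is an assumption; pydub needs ffmpeg on PATH for MP3 export):

from io import BytesIO
import numpy as np
from pydub import AudioSegment

sr = 22050                                            # assumed sample rate
pcm = (np.random.randn(sr) * 3000).astype(np.int16)   # stand-in for the model's waveform
mp3 = AudioSegment(
    data=pcm.tobytes(),
    sample_width=2,   # int16 -> 2 bytes per sample
    frame_rate=sr,
    channels=1,
)
mp3_buffer = BytesIO()
mp3.export(mp3_buffer, format="mp3", bitrate="192k")
mp3_bytes_data = mp3_buffer.getvalue()                # bytes returned to the web page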
@@ -3201,17 +3203,18 @@ class LLMASRXvecSlotTTS(nn.Module):
             states["chunk_idx"],
         )
         # new_text = new_text + preds
-        rt_value, states_ret = self.generate_speech_one_step(
-            new_text, preds,
-            last_t_size,
-            llm_cur_kv_cache,
-            llm_cur_kv_cache_len,
-            prompt_token,
-            prompt_audio,
-            text_chunk_size,
-            chunk_idx,
-            is_last,
-        )
+        with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
+            rt_value, states_ret = self.generate_speech_one_step(
+                new_text, preds,
+                last_t_size,
+                llm_cur_kv_cache,
+                llm_cur_kv_cache_len,
+                prompt_token,
+                prompt_audio,
+                text_chunk_size,
+                chunk_idx,
+                is_last,
+            )
         cur_token, feat, wav = rt_value
         new_text, last_t_size, prompt_token, prompt_audio, chunk_idx = states_ret
         states["new_text"] = new_text

Changed file 2 of 2:

@@ -10,6 +10,7 @@ from funasr.utils.hinter import hint_once
 from funasr.models.transformer.utils.nets_utils import pad_list
 import numpy as np
 import random
+from funasr.train_utils.set_all_random_seed import set_all_random_seed
 
 
 def norm_and_sample_xvec(xvec, xvec_lengths):
@@ -335,6 +336,7 @@ class UpsampleCtcTokenDiffModel(nn.Module):
         )
 
         # forward FM model
+        set_all_random_seed(0)
         feat = self.fm_model.inference(
             aligned_token_emb, aligned_token_lens,
             prompt=dict(
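
Seeding immediately before fm_model.inference makes the flow-matching sampler start from the same initial noise on every call, so repeated requests for the same text synthesize identical audio. The helper is imported from funasr above; a typical implementation looks like this (a sketch, not necessarily the library's exact code):

import random
import numpy as np
import torch

def set_all_random_seed(seed: int) -> None:
    # Seed Python, NumPy, and torch (CPU and all CUDA devices) for reproducible sampling.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)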
@@ -1011,7 +1013,7 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
             **kwargs,
         )
 
         feat = self.rms_rescale_feat(feat)
-        print_token = tokens.cpu().squeeze().tolist()
+        print_token = tokens.cpu().squeeze(0).squeeze(-1).tolist()
         logging.info(f"valid_tokens: {print_token[:len(print_token) - token_hop_len]}, "
                      f"pad_tokens: {print_token[len(print_token) - token_hop_len:]}.")