mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
simple streaming
This commit is contained in:
parent
d82cfa21a5
commit
e5696954a9
@ -3005,6 +3005,15 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
text_token = text_token[:-1]
|
||||
return text_token
|
||||
|
||||
def find_pounc_idx(self, pouncs: list, text: str):
|
||||
idx = -1
|
||||
for p in pouncs:
|
||||
idx = text.find(p)
|
||||
if idx >= 0:
|
||||
break
|
||||
|
||||
return idx
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_speech_one_step(
|
||||
self,
|
||||
@ -3103,6 +3112,63 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
|
||||
return ((cur_token, feat, wav), (text, last_t_size, prompt_token, prompt_audio, chunk_idx))
|
||||
|
||||
@torch.no_grad()
|
||||
def simple_generate_speech_one_step(
|
||||
self,
|
||||
text: str, preds: str,
|
||||
last_t_size,
|
||||
llm_cur_kv_cache,
|
||||
llm_cur_kv_cache_len,
|
||||
prompt_token,
|
||||
prompt_audio,
|
||||
tts_text_chunk_size,
|
||||
chunk_idx,
|
||||
is_last,
|
||||
para_phone_len=200,
|
||||
):
|
||||
device = llm_cur_kv_cache.device
|
||||
_text = f"<|endofprompt|><|sil|>{text}" + ("<|sil|>" if is_last else "")
|
||||
text_token = self.tts_tokenizer_warpper(_text)
|
||||
|
||||
text_token = torch.tensor([text_token], dtype=torch.long, device=device)
|
||||
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.long, device=device)
|
||||
cur_token, feat = self.tts_model.streaming_one_step(
|
||||
text_token,
|
||||
text_token_len,
|
||||
xvec=None,
|
||||
xvec_lengths=None,
|
||||
prompt_dict={
|
||||
"prompt_token": prompt_token,
|
||||
"prompt_audio": prompt_audio,
|
||||
},
|
||||
outside_prompt=llm_cur_kv_cache,
|
||||
outside_prompt_lengths=llm_cur_kv_cache_len,
|
||||
sampling="threshold_1e-6",
|
||||
chunk_idx=chunk_idx,
|
||||
)
|
||||
if cur_token is not None and cur_token.shape[1] > 0 and feat.shape[2] > 0:
|
||||
# process first package, token in B,T,D, feat in B,F,T
|
||||
if prompt_token[0] is None:
|
||||
prompt_token = [
|
||||
cur_token,
|
||||
torch.tensor([cur_token.shape[1]], dtype=torch.long, device=device),
|
||||
]
|
||||
prompt_audio = [
|
||||
feat.transpose(1, 2),
|
||||
torch.tensor([feat.shape[2]], dtype=torch.long, device=device),
|
||||
]
|
||||
else:
|
||||
prompt_token[1] = prompt_token[1] + cur_token.shape[1]
|
||||
prompt_token[0] = torch.concat([prompt_token[0], cur_token], dim=1)
|
||||
prompt_audio[1] = prompt_audio[1] + feat.shape[2]
|
||||
prompt_audio[0] = torch.concat([prompt_audio[0], feat.transpose(1, 2)], dim=1)
|
||||
wav = self.vocoder.inference(feat.transpose(1, 2))
|
||||
chunk_idx += 1
|
||||
else:
|
||||
cur_token, feat, wav = None, None, None
|
||||
|
||||
return ((cur_token, feat, wav), (text, last_t_size, prompt_token, prompt_audio, chunk_idx))
|
||||
|
||||
def convert_wav_to_mp3(self, wav: torch.Tensor):
|
||||
wav = wav.detach().cpu().numpy()
|
||||
wav = (wav * (2**15 - 1) * 0.8).astype(np.int16)
|
||||
@ -3204,7 +3270,7 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
)
|
||||
# new_text = new_text + preds
|
||||
with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
|
||||
rt_value, states_ret = self.generate_speech_one_step(
|
||||
rt_value, states_ret = self.simple_generate_speech_one_step(
|
||||
new_text, preds,
|
||||
last_t_size,
|
||||
llm_cur_kv_cache,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user