diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 01862b76a..127e15712 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -3005,6 +3005,21 @@ class LLMASRXvecSlotTTS(nn.Module):
             text_token = text_token[:-1]
         return text_token
 
+    def find_pounc_idx(self, pouncs: list, text: str):
+        """Return the position in ``text`` of the first entry of ``pouncs`` found.
+
+        Entries are tried in list order and the search stops at the first one
+        that occurs, so the result is not necessarily the earliest match in
+        ``text``. Returns -1 when no entry occurs or ``pouncs`` is empty.
+        """
+        idx = -1
+        for p in pouncs:
+            idx = text.find(p)
+            if idx >= 0:
+                break
+
+        return idx
+
     @torch.no_grad()
     def generate_speech_one_step(
         self,
@@ -3103,6 +3118,82 @@ class LLMASRXvecSlotTTS(nn.Module):
 
         return ((cur_token, feat, wav), (text, last_t_size, prompt_token, prompt_audio, chunk_idx))
 
+    @torch.no_grad()
+    def simple_generate_speech_one_step(
+        self,
+        text: str, preds: str,
+        last_t_size,
+        llm_cur_kv_cache,
+        llm_cur_kv_cache_len,
+        prompt_token,
+        prompt_audio,
+        tts_text_chunk_size,
+        chunk_idx,
+        is_last,
+        para_phone_len=200,
+    ):
+        """Synthesize one streaming chunk for ``text`` with a single TTS call.
+
+        ``text`` is wrapped as ``<|endofprompt|><|sil|>{text}`` (plus a trailing
+        ``<|sil|>`` when ``is_last`` is true), tokenized with
+        ``self.tts_tokenizer_warpper`` and passed to
+        ``self.tts_model.streaming_one_step`` with ``llm_cur_kv_cache`` as the
+        outside prompt. When non-empty tokens and features come back, they are
+        appended to ``prompt_token`` / ``prompt_audio`` (existing lists are
+        updated in place), ``self.vocoder.inference`` is run on the features
+        and ``chunk_idx`` is incremented.
+
+        ``preds``, ``tts_text_chunk_size`` and ``para_phone_len`` are not used
+        by this method; ``last_t_size`` is only passed through to the result.
+
+        Returns:
+            ``((cur_token, feat, wav), (text, last_t_size, prompt_token,
+            prompt_audio, chunk_idx))``; the first tuple is all ``None`` when
+            nothing was generated for this call.
+        """
+        device = llm_cur_kv_cache.device
+        _text = f"<|endofprompt|><|sil|>{text}" + ("<|sil|>" if is_last else "")
+        text_token = self.tts_tokenizer_warpper(_text)
+
+        text_token = torch.tensor([text_token], dtype=torch.long, device=device)
+        text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.long, device=device)
+        cur_token, feat = self.tts_model.streaming_one_step(
+            text_token,
+            text_token_len,
+            xvec=None,
+            xvec_lengths=None,
+            prompt_dict={
+                "prompt_token": prompt_token,
+                "prompt_audio": prompt_audio,
+            },
+            outside_prompt=llm_cur_kv_cache,
+            outside_prompt_lengths=llm_cur_kv_cache_len,
+            sampling="threshold_1e-6",
+            chunk_idx=chunk_idx,
+        )
+        if cur_token is not None and cur_token.shape[1] > 0 and feat.shape[2] > 0:
+            # process first package, token in B,T,D, feat in B,F,T
+            if prompt_token[0] is None:
+                prompt_token = [
+                    cur_token,
+                    torch.tensor([cur_token.shape[1]], dtype=torch.long, device=device),
+                ]
+                prompt_audio = [
+                    feat.transpose(1, 2),
+                    torch.tensor([feat.shape[2]], dtype=torch.long, device=device),
+                ]
+            else:
+                prompt_token[1] = prompt_token[1] + cur_token.shape[1]
+                prompt_token[0] = torch.concat([prompt_token[0], cur_token], dim=1)
+                prompt_audio[1] = prompt_audio[1] + feat.shape[2]
+                prompt_audio[0] = torch.concat([prompt_audio[0], feat.transpose(1, 2)], dim=1)
+            wav = self.vocoder.inference(feat.transpose(1, 2))
+            chunk_idx += 1
+        else:
+            cur_token, feat, wav = None, None, None
+
+        return ((cur_token, feat, wav), (text, last_t_size, prompt_token, prompt_audio, chunk_idx))
+
     def convert_wav_to_mp3(self, wav: torch.Tensor):
         wav = wav.detach().cpu().numpy()
         wav = (wav * (2**15 - 1) * 0.8).astype(np.int16)
@@ -3204,7 +3295,7 @@ class LLMASRXvecSlotTTS(nn.Module):
                         )
                         # new_text = new_text + preds
                         with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
-                            rt_value, states_ret = self.generate_speech_one_step(
+                            rt_value, states_ret = self.simple_generate_speech_one_step(
                                 new_text, preds,
                                 last_t_size,
                                 llm_cur_kv_cache,