diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index f6d230b9f..319e1a88d 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -2982,10 +2982,10 @@ class LLMASRXvecSlotTTS(nn.Module): text = "".join(normed_text) cur_token, feat, wav = None, None, None - t_size = len(self.tts_text_tokenizer.text2tokens(text)) - if (t_size - last_t_size) >= tts_text_chunk_size: - _text = f"<|endofprompt|><|sil|>{text}" + ("<|sil|>" if is_last else "") - text_token = self.tts_tokenizer_warpper(_text) + _text = f"<|endofprompt|><|sil|>{text}" + ("<|sil|>" if is_last else "") + text_token = self.tts_tokenizer_warpper(_text) + t_size = len(text_token) + if (t_size - last_t_size) >= tts_text_chunk_size or is_last: text_token = torch.tensor([text_token], dtype=torch.long, device=device) text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.long, device=device) cur_token, feat = self.tts_model.streaming_one_step( @@ -3000,7 +3000,7 @@ class LLMASRXvecSlotTTS(nn.Module): sampling="threshold_1e-6", chunk_idx=chunk_idx, ) - if cur_token is not None: + if cur_token is not None and cur_token.shape[1] > 0 and feat.shape[2] > 0: # process first package, token in B,T,D, feat in B,F,T if prompt_token[0] is None: prompt_token = [cur_token, torch.tensor([cur_token.shape[1]], dtype=torch.long, device=device)] @@ -3012,6 +3012,8 @@ class LLMASRXvecSlotTTS(nn.Module): prompt_audio[0] = torch.concat([prompt_audio[0], feat.transpose(1, 2)], dim=1) wav = self.vocoder.inference(feat.transpose(1, 2)) chunk_idx += 1 + else: + cur_token, feat, wav = None, None, None # post process last_t_size = t_size @@ -3042,15 +3044,15 @@ class LLMASRXvecSlotTTS(nn.Module): token_list, feat_list, wav_list = [], [], [] prompt_token, prompt_audio = [None, None], [None, None] new_text, last_t_size, chunk_idx = "", 0, 0 - i = 0 - while i < preds.shape[1]: - chunk_size = int(llm_token_num_per_call / (given_rtf ** min(i, 2))) + st, count = 0, 0 + while st < preds.shape[1]: + chunk_size = int(llm_token_num_per_call / (given_rtf ** min(count, 2))) _resp = llm_tokenizer.batch_decode( - preds[:, i:i + chunk_size], + preds[:, st:st + chunk_size], add_special_tokens=False, skip_special_tokens=True, )[0] - is_last = (i + chunk_size >= preds.shape[1]) + is_last = (st + chunk_size >= preds.shape[1]) new_text = new_text + _resp rt_value, states = self.generate_speech_one_step( @@ -3063,12 +3065,13 @@ class LLMASRXvecSlotTTS(nn.Module): cur_token, feat, wav = rt_value new_text, last_t_size, prompt_token, prompt_audio, chunk_idx = states # save results - if cur_token is not None: + if cur_token is not None and feat is not None and wav is not None: token_list.append(cur_token) feat_list.append(feat) wav_list.append(wav) - i += chunk_size + st += chunk_size + count += 1 speech_tokens = torch.cat(token_list, dim=1) mel_feats = torch.cat(feat_list, dim=2) diff --git a/funasr/models/llm_asr/tts_models/e2e_model.py b/funasr/models/llm_asr/tts_models/e2e_model.py index e682445c1..d8f9a3204 100644 --- a/funasr/models/llm_asr/tts_models/e2e_model.py +++ b/funasr/models/llm_asr/tts_models/e2e_model.py @@ -911,10 +911,7 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel): # streaming related config chunk_size = kwargs.get("streaming_chunk_size", 4) chunk_size_maxium = kwargs.get("chunk_size_maxium", 16) - try: - lookahead_size = self.am_model.encoder.pre_lookahead_len - except AttributeError: - lookahead_size = 0 + lookahead_size = self.am_model.encoder.pre_lookahead_len hint_once(f"chunk_size={chunk_size}, chunk_size_maxium={chunk_size_maxium}, " f"pre lookahead size={lookahead_size}.", "pre_lookahead_len")