diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index 0b8ac336a..4713b52b6 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -3062,7 +3062,7 @@ class LLMASRXvecSlotTTS(nn.Module): if (t_size - last_t_size) >= tts_text_chunk_size or is_last or para_end: text_token = torch.tensor([text_token], dtype=torch.long, device=device) text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.long, device=device) - cur_token, feat = self.tts_model.streaming_one_step( + cur_token, feat, wav, prompt_token, prompt_audio = self.tts_model.streaming_one_step( text_token, text_token_len, xvec=None, @@ -3075,24 +3075,9 @@ class LLMASRXvecSlotTTS(nn.Module): outside_prompt_lengths=llm_cur_kv_cache_len, sampling="threshold_6e-1", chunk_idx=chunk_idx, + vocoder=self.vocoder, ) if cur_token is not None and cur_token.shape[1] > 0 and feat.shape[2] > 0: - # process first package, token in B,T,D, feat in B,F,T - if prompt_token[0] is None: - prompt_token = [ - cur_token, - torch.tensor([cur_token.shape[1]], dtype=torch.long, device=device), - ] - prompt_audio = [ - feat.transpose(1, 2), - torch.tensor([feat.shape[2]], dtype=torch.long, device=device), - ] - else: - prompt_token[1] = prompt_token[1] + cur_token.shape[1] - prompt_token[0] = torch.concat([prompt_token[0], cur_token], dim=1) - prompt_audio[1] = prompt_audio[1] + feat.shape[2] - prompt_audio[0] = torch.concat([prompt_audio[0], feat.transpose(1, 2)], dim=1) - wav = self.vocoder.inference(feat.transpose(1, 2)) chunk_idx += 1 else: cur_token, feat, wav = None, None, None @@ -3247,8 +3232,8 @@ class LLMASRXvecSlotTTS(nn.Module): states = {} states["new_text"] = "" states["last_t_size"] = 0 - states["prompt_token"] = [None, None] - states["prompt_audio"] = [None, None] + states["prompt_token"] = ([None, None], 0) + states["prompt_audio"] = ([None, None], 0) states["chunk_idx"] = 0 def streaming_generate_speech( diff --git a/funasr/models/llm_asr/tts_models/e2e_model.py b/funasr/models/llm_asr/tts_models/e2e_model.py index 482328407..973d7a550 100644 --- a/funasr/models/llm_asr/tts_models/e2e_model.py +++ b/funasr/models/llm_asr/tts_models/e2e_model.py @@ -349,7 +349,7 @@ class UpsampleCtcTokenDiffModel(nn.Module): return tokens, feat - def rms_rescale_feat(self, feat, target_feat_rms=3.5, feat_sil_th=0.1): + def rms_rescale_feat(self, feat, target_feat_rms=3.5, feat_sil_th=0.5): feat_power = feat.exp().sum(1) # not silence if feat_power.max() > feat_sil_th: @@ -902,6 +902,17 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel): return tokens, prompt_audio[0].transpose(1, 2) + def cross_fade(self, pre: torch.Tensor, feat: torch.Tensor, hop_size: int): + if pre is not None: + hop_len = min(hop_size, feat.shape[1], pre.shape[1]) + sin_wind = torch.tensor(np.sin((np.arange(hop_len * 2) + 1) / (hop_len * 2 + 1) * np.pi)[None, :, None]).to(feat) + cf_overlap = ((pre * sin_wind[:, -hop_len:] + + feat[:, :hop_len] * sin_wind[:, :hop_len]) / + (sin_wind[:, :hop_len] + sin_wind[:, -hop_len:])) + feat[:, :hop_len] = cf_overlap + + return feat + def streaming_one_step( self, text: torch.Tensor, text_lengths: torch.Tensor, xvec: Optional[torch.Tensor] = None, xvec_lengths: Optional[torch.Tensor] = None, @@ -917,12 +928,12 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel): hint_once(f"chunk_size={chunk_size}, chunk_size_maxium={chunk_size_maxium}, " f"pre lookahead size={lookahead_size}.", "pre_lookahead_len") - + wav_vocoder = kwargs.get("vocoder", None) blank_penalty = kwargs.get("blank_penalty", 0.0) sampling = kwargs.get("sampling", "greedy") prompt_dict = kwargs.get("prompt_dict", {}) - prompt_token = list(prompt_dict.get("prompt_token", [None, None])) - prompt_audio = list(prompt_dict.get("prompt_audio", [None, None])) + prompt_token, pre_token_lb = list(prompt_dict.get("prompt_token", ([None, None], 0))) + prompt_audio, pre_feat_lb = list(prompt_dict.get("prompt_audio", ([None, None], 0))) ftype = self.text_embedding.weight.dtype if prompt_token[0] is None: @@ -985,33 +996,32 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel): return_hidden=True, use_causal_prob=use_causal_prob, ) - token_hop_len, mel_hop_len = 0, 0 - if isinstance(tokens, tuple): - tokens, fa_tokens = tokens - token_hop_len = self.get_hop_lens(fa_tokens, lookahead_size) - mel_hop_len = int(round(token_hop_len * self.fm_model.length_normalizer_ratio)) + tokens, fa_tokens = tokens + token_hop_len = self.get_hop_lens(fa_tokens, lookahead_size) - cur_token, feat = None, None - # exclude empty tokens. + cur_token, feat, wav = None, None, None + token_reback, feat_reback = pre_token_lb, pre_feat_lb + # generate feat, exclude empty tokens. if aligned_token_emb.shape[1] > prompt_token[0].shape[1]: - cur_token = aligned_token_emb[:, prompt_token[0].shape[1]:] - cur_token_len = aligned_token_lens - prompt_token[1] + # need synthesize extra overlap parts + cur_token = aligned_token_emb[:, prompt_token[0].shape[1] - pre_token_lb:] + cur_token_len = aligned_token_lens - prompt_token[1] + pre_token_lb # v2: excluding lookahead tokens for not-last packages if text[0, -1] != self.endofprompt_token_id+1: cur_token = cur_token[:, :cur_token.shape[1] - token_hop_len, :] cur_token_len = cur_token_len - token_hop_len + if cur_token_len[0] < 1: + return None, None, None, (prompt_token, pre_token_lb), (prompt_audio, pre_feat_lb) # forward FM model # set_all_random_seed(0) - if cur_token_len[0] < 1: - return None, None feat = self.fm_model.inference( cur_token, cur_token_len, xvec, xvec_lengths, prompt=dict( - prompt_text=prompt_token, - prompt_audio=prompt_audio, + prompt_text=[prompt_token[0][:, :-pre_token_lb], prompt_token[1] - pre_token_lb], + prompt_audio=[prompt_audio[0][:, :-pre_feat_lb], prompt_audio[1] - pre_feat_lb], ), **kwargs, ) @@ -1019,6 +1029,20 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel): print_token = tokens.cpu().squeeze(0).squeeze(-1).tolist() logging.info(f"valid_tokens: {print_token[:len(print_token) - token_hop_len]}, " f"pad_tokens: {print_token[len(print_token) - token_hop_len:]}.") + feat = self.cross_fade(prompt_audio[0], feat.transpose(1, 2), pre_feat_lb) + pre_wav = wav_vocoder.inference(prompt_audio[0]) + cur_wav = wav_vocoder.inference(feat.transpose(1, 2)) + pre_wav_lb = int(1.0 / pre_token_lb * wav_vocoder.sample_rate) + wav = self.cross_fade(pre_wav, cur_wav, pre_wav_lb) + + prompt_token = [ + torch.cat([prompt_token[0][:, :-pre_token_lb], cur_token], dim=1), + prompt_token[1] + cur_token_len - pre_token_lb, + ] + prompt_audio = [ + torch.cat([prompt_audio[0][:, :-pre_feat_lb], feat.transpose(1, 2)], dim=1), + prompt_audio[1] + feat.shape[2] - pre_feat_lb + ] # v2: reback token and mel feat if text[0, -1] != self.endofprompt_token_id+1: @@ -1026,7 +1050,9 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel): token_hop_len_2 = self.get_hop_lens(fa_tokens, lookahead_size + text_reback) token_reback = token_hop_len_2 - token_hop_len cur_token = cur_token[:, :cur_token.shape[1] - token_reback, :] - feat_reback = int(round(token_reback * self.fm_model.length_normalizer_ratio)) + feat_reback = int(token_reback * self.fm_model.length_normalizer_ratio) feat = feat[:, :, :feat.shape[2] - feat_reback] + wav_reback = int(1.0 / token_reback * wav_vocoder.sample_rate) + wav = wav[:, :wav.shape[1] - wav_reback] - return cur_token, feat + return cur_token, feat, wav, (prompt_token, token_reback), (prompt_audio, feat_reback)