From 4d001cc1858ed2230a9c829be874897ca98b583d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BF=97=E6=B5=A9?= Date: Wed, 18 Sep 2024 15:14:52 +0800 Subject: [PATCH] add cross fade for 4o --- funasr/models/llm_asr/tts_models/e2e_model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/funasr/models/llm_asr/tts_models/e2e_model.py b/funasr/models/llm_asr/tts_models/e2e_model.py index 29fa347bc..5f78870ea 100644 --- a/funasr/models/llm_asr/tts_models/e2e_model.py +++ b/funasr/models/llm_asr/tts_models/e2e_model.py @@ -904,11 +904,11 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel): def cross_fade(self, pre: torch.Tensor, feat: torch.Tensor, hop_size: int): if pre is not None: - hop_len = min(hop_size, feat.shape[1], pre.shape[1]) + hop_len = min(hop_size, feat.shape[1]) sin_wind = torch.tensor(np.sin((np.arange(hop_len * 2) + 1) / (hop_len * 2 + 1) * np.pi)[None, :]).to(feat) if feat.dim() == 3: sin_wind = sin_wind.unsqueeze(-1) - cf_overlap = ((pre[:, -hop_len:] * sin_wind[:, -hop_len:] + + cf_overlap = ((pre * sin_wind[:, -hop_len:] + feat[:, :hop_len] * sin_wind[:, :hop_len]) / (sin_wind[:, :hop_len] + sin_wind[:, -hop_len:])) feat[:, :hop_len] = cf_overlap @@ -1032,13 +1032,13 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel): logging.info(f"valid_tokens: {print_token[:len(print_token) - token_hop_len]}, " f"pad_tokens: {print_token[len(print_token) - token_hop_len:]}.") if prompt_audio[0].shape[1] > 0: - feat = self.cross_fade(prompt_audio[0], feat.transpose(1, 2), pre_feat_lb).transpose(1, 2) + feat = self.cross_fade(prompt_audio[0][:, -pre_feat_lb:], feat.transpose(1, 2), pre_feat_lb).transpose(1, 2) wav = wav_vocoder.inference(feat.transpose(1, 2)) if prompt_audio[0].shape[1] > 0: pre_wav = wav_vocoder.inference(prompt_audio[0]) pre_wav_lb = int(1.0 / pre_token_lb * wav_vocoder.sample_rate) - wav = self.cross_fade(pre_wav, wav, pre_wav_lb) + wav = self.cross_fade(pre_wav[:, -pre_wav_lb:], wav, pre_wav_lb) prompt_token = [ torch.cat([prompt_token[0][:, :-pre_token_lb], cur_token], dim=1),