add cross fade for 4o

This commit is contained in:
志浩 2024-09-18 14:14:24 +08:00
parent 814dc25492
commit f7435706a9

View File

@@ -1029,11 +1029,14 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
print_token = tokens.cpu().squeeze(0).squeeze(-1).tolist()
logging.info(f"valid_tokens: {print_token[:len(print_token) - token_hop_len]}, "
f"pad_tokens: {print_token[len(print_token) - token_hop_len:]}.")
feat = self.cross_fade(prompt_audio[0], feat.transpose(1, 2), pre_feat_lb)
pre_wav = wav_vocoder.inference(prompt_audio[0])
cur_wav = wav_vocoder.inference(feat.transpose(1, 2))
pre_wav_lb = int(1.0 / pre_token_lb * wav_vocoder.sample_rate)
wav = self.cross_fade(pre_wav, cur_wav, pre_wav_lb)
if prompt_audio[0].shape[1] > 0:
feat = self.cross_fade(prompt_audio[0], feat.transpose(1, 2), pre_feat_lb).transpose(1, 2)
wav = wav_vocoder.inference(feat)
if prompt_audio[0].shape[1] > 0:
pre_wav = wav_vocoder.inference(prompt_audio[0])
pre_wav_lb = int(1.0 / pre_token_lb * wav_vocoder.sample_rate)
wav = self.cross_fade(pre_wav, wav, pre_wav_lb)
prompt_token = [
torch.cat([prompt_token[0][:, :-pre_token_lb], cur_token], dim=1),