add cross fade for 4o

志浩 2024-09-18 13:59:14 +08:00
parent 2c3ae95bbf
commit 814dc25492
2 changed files with 49 additions and 38 deletions

View File

@@ -3062,7 +3062,7 @@ class LLMASRXvecSlotTTS(nn.Module):
             if (t_size - last_t_size) >= tts_text_chunk_size or is_last or para_end:
                 text_token = torch.tensor([text_token], dtype=torch.long, device=device)
                 text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.long, device=device)
-                cur_token, feat = self.tts_model.streaming_one_step(
+                cur_token, feat, wav, prompt_token, prompt_audio = self.tts_model.streaming_one_step(
                     text_token,
                     text_token_len,
                     xvec=None,
@@ -3075,24 +3075,9 @@ class LLMASRXvecSlotTTS(nn.Module):
                     outside_prompt_lengths=llm_cur_kv_cache_len,
                     sampling="threshold_6e-1",
                     chunk_idx=chunk_idx,
+                    vocoder=self.vocoder,
                 )
                 if cur_token is not None and cur_token.shape[1] > 0 and feat.shape[2] > 0:
-                    # process first package, token in B,T,D, feat in B,F,T
-                    if prompt_token[0] is None:
-                        prompt_token = [
-                            cur_token,
-                            torch.tensor([cur_token.shape[1]], dtype=torch.long, device=device),
-                        ]
-                        prompt_audio = [
-                            feat.transpose(1, 2),
-                            torch.tensor([feat.shape[2]], dtype=torch.long, device=device),
-                        ]
-                    else:
-                        prompt_token[1] = prompt_token[1] + cur_token.shape[1]
-                        prompt_token[0] = torch.concat([prompt_token[0], cur_token], dim=1)
-                        prompt_audio[1] = prompt_audio[1] + feat.shape[2]
-                        prompt_audio[0] = torch.concat([prompt_audio[0], feat.transpose(1, 2)], dim=1)
-                    wav = self.vocoder.inference(feat.transpose(1, 2))
                     chunk_idx += 1
                 else:
                     cur_token, feat, wav = None, None, None
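With the cache bookkeeping and vocoding folded into streaming_one_step, the caller shrinks to a thin loop. A minimal sketch of the new contract (stream_tts_chunks, chunks and emit are hypothetical stand-ins; the sampling and outside-prompt arguments are elided):

def stream_tts_chunks(chunks, tts_model, vocoder, emit):
    # Empty caches: ([cache, cache_len], lookback), matching the states init below.
    prompt_token, prompt_audio = ([None, None], 0), ([None, None], 0)
    chunk_idx = 0
    for text_token, text_token_len in chunks:
        cur_token, feat, wav, prompt_token, prompt_audio = tts_model.streaming_one_step(
            text_token, text_token_len,
            xvec=None,
            prompt_dict=dict(prompt_token=prompt_token, prompt_audio=prompt_audio),
            vocoder=vocoder,  # vocoding and cache updates now happen inside
            chunk_idx=chunk_idx,
        )
        if wav is not None:
            emit(wav)  # already cross-faded against the previous chunk
        chunk_idx += 1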
@@ -3247,8 +3232,8 @@ class LLMASRXvecSlotTTS(nn.Module):
         states = {}
         states["new_text"] = ""
         states["last_t_size"] = 0
-        states["prompt_token"] = [None, None]
-        states["prompt_audio"] = [None, None]
+        states["prompt_token"] = ([None, None], 0)
+        states["prompt_audio"] = ([None, None], 0)
         states["chunk_idx"] = 0

     def streaming_generate_speech(

View File

@@ -349,7 +349,7 @@ class UpsampleCtcTokenDiffModel(nn.Module):
         return tokens, feat

-    def rms_rescale_feat(self, feat, target_feat_rms=3.5, feat_sil_th=0.1):
+    def rms_rescale_feat(self, feat, target_feat_rms=3.5, feat_sil_th=0.5):
         feat_power = feat.exp().sum(1)
         # not silence
         if feat_power.max() > feat_sil_th:
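For scale: the gate sums per-frame linear power, so raising feat_sil_th from 0.1 to 0.5 lets more near-silent chunks skip rescaling. A toy check, assuming feat is a log-magnitude tensor of shape [B, F, T]:

import torch

feat = torch.full((1, 80, 10), -6.0)   # a very quiet 80-bin chunk
feat_power = feat.exp().sum(1)         # per-frame linear power, shape [B, T]
print(feat_power.max().item())         # ~0.198: rescaled under 0.1, treated as silence under 0.5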
@@ -902,6 +902,17 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
         return tokens, prompt_audio[0].transpose(1, 2)

+    def cross_fade(self, pre: torch.Tensor, feat: torch.Tensor, hop_size: int):
+        if pre is not None:
+            hop_len = min(hop_size, feat.shape[1], pre.shape[1])
+            sin_wind = torch.tensor(np.sin((np.arange(hop_len * 2) + 1) / (hop_len * 2 + 1) * np.pi)[None, :, None]).to(feat)
+            cf_overlap = ((pre[:, -hop_len:] * sin_wind[:, -hop_len:] +
+                           feat[:, :hop_len] * sin_wind[:, :hop_len]) /
+                          (sin_wind[:, :hop_len] + sin_wind[:, -hop_len:]))
+            feat[:, :hop_len] = cf_overlap
+        return feat
+
     def streaming_one_step(
         self, text: torch.Tensor, text_lengths: torch.Tensor,
         xvec: Optional[torch.Tensor] = None, xvec_lengths: Optional[torch.Tensor] = None,
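Because the blend is renormalized by the window sum, two chunks that agree in the overlap pass through unchanged; only genuine discontinuities get smoothed. A self-contained sketch of the same windowing (layout [B, T, C], as at the call sites):

import numpy as np
import torch

def cross_fade(pre, feat, hop_size):
    # Blend the head of `feat` with the tail of `pre` using a normalized sine window.
    if pre is not None:
        hop_len = min(hop_size, feat.shape[1], pre.shape[1])
        win = torch.tensor(
            np.sin((np.arange(hop_len * 2) + 1) / (hop_len * 2 + 1) * np.pi)[None, :, None]
        ).to(feat)
        feat[:, :hop_len] = (
            (pre[:, -hop_len:] * win[:, -hop_len:] + feat[:, :hop_len] * win[:, :hop_len])
            / (win[:, :hop_len] + win[:, -hop_len:])
        )
    return feat

# Identical signals blend back to themselves, so a seam between matching
# chunks is inaudible:
pre, cur = torch.ones(1, 32, 1), torch.ones(1, 32, 1)
out = cross_fade(pre, cur, hop_size=8)
assert torch.allclose(out, torch.ones_like(out))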
@@ -917,12 +928,12 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
         hint_once(f"chunk_size={chunk_size}, chunk_size_maxium={chunk_size_maxium}, "
                   f"pre lookahead size={lookahead_size}.",
                   "pre_lookahead_len")
+        wav_vocoder = kwargs.get("vocoder", None)
         blank_penalty = kwargs.get("blank_penalty", 0.0)
         sampling = kwargs.get("sampling", "greedy")
         prompt_dict = kwargs.get("prompt_dict", {})
-        prompt_token = list(prompt_dict.get("prompt_token", [None, None]))
-        prompt_audio = list(prompt_dict.get("prompt_audio", [None, None]))
+        prompt_token, pre_token_lb = list(prompt_dict.get("prompt_token", ([None, None], 0)))
+        prompt_audio, pre_feat_lb = list(prompt_dict.get("prompt_audio", ([None, None], 0)))
         ftype = self.text_embedding.weight.dtype

         if prompt_token[0] is None:
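On the first streaming step no caches exist, so the tuple defaults above unpack to empty state; the idiom round-trips the (cache, lookback) pairs returned by the previous call:

prompt_dict = {}  # nothing cached yet
prompt_token, pre_token_lb = list(prompt_dict.get("prompt_token", ([None, None], 0)))
prompt_audio, pre_feat_lb = list(prompt_dict.get("prompt_audio", ([None, None], 0)))
assert prompt_token == [None, None] and pre_token_lb == 0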
@@ -985,33 +996,32 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
             return_hidden=True,
             use_causal_prob=use_causal_prob,
         )
-        token_hop_len, mel_hop_len = 0, 0
-        if isinstance(tokens, tuple):
-            tokens, fa_tokens = tokens
-            token_hop_len = self.get_hop_lens(fa_tokens, lookahead_size)
-            mel_hop_len = int(round(token_hop_len * self.fm_model.length_normalizer_ratio))
+        tokens, fa_tokens = tokens
+        token_hop_len = self.get_hop_lens(fa_tokens, lookahead_size)

-        cur_token, feat = None, None
-        # exclude empty tokens.
+        cur_token, feat, wav = None, None, None
+        token_reback, feat_reback = pre_token_lb, pre_feat_lb
+        # generate feat, exclude empty tokens.
         if aligned_token_emb.shape[1] > prompt_token[0].shape[1]:
-            cur_token = aligned_token_emb[:, prompt_token[0].shape[1]:]
-            cur_token_len = aligned_token_lens - prompt_token[1]
+            # need synthesize extra overlap parts
+            cur_token = aligned_token_emb[:, prompt_token[0].shape[1] - pre_token_lb:]
+            cur_token_len = aligned_token_lens - prompt_token[1] + pre_token_lb
             # v2: excluding lookahead tokens for not-last packages
             if text[0, -1] != self.endofprompt_token_id+1:
                 cur_token = cur_token[:, :cur_token.shape[1] - token_hop_len, :]
                 cur_token_len = cur_token_len - token_hop_len
+            if cur_token_len[0] < 1:
+                return None, None, None, (prompt_token, pre_token_lb), (prompt_audio, pre_feat_lb)

             # forward FM model
             # set_all_random_seed(0)
-            if cur_token_len[0] < 1:
-                return None, None
             feat = self.fm_model.inference(
                 cur_token, cur_token_len,
                 xvec, xvec_lengths,
                 prompt=dict(
-                    prompt_text=prompt_token,
-                    prompt_audio=prompt_audio,
+                    prompt_text=[prompt_token[0][:, :-pre_token_lb], prompt_token[1] - pre_token_lb],
+                    prompt_audio=[prompt_audio[0][:, :-pre_feat_lb], prompt_audio[1] - pre_feat_lb],
                 ),
                 **kwargs,
             )
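The slicing above re-generates the tail of the previous chunk so there is fresh material to cross-fade: cur_token starts pre_token_lb tokens early, and the prompt handed to the flow-matching model is shortened by the same lookback, so those positions are synthesized rather than conditioned on. With toy numbers (all values invented):

prompt_len   = 10  # tokens already cached (prompt_token[1])
pre_token_lb = 2   # lookback tokens to re-synthesize
aligned_len  = 15  # aligned tokens after the new text chunk

cur_start = prompt_len - pre_token_lb                # 8: start two tokens early
cur_len   = aligned_len - prompt_len + pre_token_lb  # 7: five new + two overlap tokens
fm_prompt = prompt_len - pre_token_lb                # 8: prompt shrinks by the lookback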
@@ -1019,6 +1029,20 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
             print_token = tokens.cpu().squeeze(0).squeeze(-1).tolist()
             logging.info(f"valid_tokens: {print_token[:len(print_token) - token_hop_len]}, "
                          f"pad_tokens: {print_token[len(print_token) - token_hop_len:]}.")

+            feat = self.cross_fade(prompt_audio[0], feat.transpose(1, 2), pre_feat_lb).transpose(1, 2)
+            pre_wav = wav_vocoder.inference(prompt_audio[0])
+            cur_wav = wav_vocoder.inference(feat.transpose(1, 2))
+            pre_wav_lb = int(1.0 / pre_token_lb * wav_vocoder.sample_rate)
+            wav = self.cross_fade(pre_wav, cur_wav, pre_wav_lb)
+            prompt_token = [
+                torch.cat([prompt_token[0][:, :-pre_token_lb], cur_token], dim=1),
+                prompt_token[1] + cur_token_len - pre_token_lb,
+            ]
+            prompt_audio = [
+                torch.cat([prompt_audio[0][:, :-pre_feat_lb], feat.transpose(1, 2)], dim=1),
+                prompt_audio[1] + feat.shape[2] - pre_feat_lb,
+            ]

             # v2: reback token and mel feat
             if text[0, -1] != self.endofprompt_token_id+1:
@@ -1026,7 +1050,9 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
                 token_hop_len_2 = self.get_hop_lens(fa_tokens, lookahead_size + text_reback)
                 token_reback = token_hop_len_2 - token_hop_len
                 cur_token = cur_token[:, :cur_token.shape[1] - token_reback, :]
-                feat_reback = int(round(token_reback * self.fm_model.length_normalizer_ratio))
+                feat_reback = int(token_reback * self.fm_model.length_normalizer_ratio)
                 feat = feat[:, :, :feat.shape[2] - feat_reback]
+                wav_reback = int(1.0 / token_reback * wav_vocoder.sample_rate)
+                wav = wav[:, :wav.shape[1] - wav_reback]

-        return cur_token, feat
+        return cur_token, feat, wav, (prompt_token, token_reback), (prompt_audio, feat_reback)
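The reback counts returned alongside the caches tell the next call how much to redo: token_reback tokens were trimmed here and will be re-synthesized, and feat_reback mel frames will be cross-faded over. A worked example (the frames-per-token ratio is an assumption, not a value from the commit):

length_normalizer_ratio = 2            # assumed mel frames per token
token_hop_len, token_hop_len_2 = 3, 5  # lookahead vs. lookahead + text reback
token_reback = token_hop_len_2 - token_hop_len             # 2 tokens trimmed
feat_reback = int(token_reback * length_normalizer_ratio)  # 4 mel frames trimmed
# Both counts ride back inside the cache tuples and become the next call's
# pre_token_lb / pre_feat_lb.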