mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
add cross fade for 4o
This commit is contained in:
parent
2c3ae95bbf
commit
814dc25492
@ -3062,7 +3062,7 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
if (t_size - last_t_size) >= tts_text_chunk_size or is_last or para_end:
|
||||
text_token = torch.tensor([text_token], dtype=torch.long, device=device)
|
||||
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.long, device=device)
|
||||
cur_token, feat = self.tts_model.streaming_one_step(
|
||||
cur_token, feat, wav, prompt_token, prompt_audio = self.tts_model.streaming_one_step(
|
||||
text_token,
|
||||
text_token_len,
|
||||
xvec=None,
|
||||
@ -3075,24 +3075,9 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
outside_prompt_lengths=llm_cur_kv_cache_len,
|
||||
sampling="threshold_6e-1",
|
||||
chunk_idx=chunk_idx,
|
||||
vocoder=self.vocoder,
|
||||
)
|
||||
if cur_token is not None and cur_token.shape[1] > 0 and feat.shape[2] > 0:
|
||||
# process first package, token in B,T,D, feat in B,F,T
|
||||
if prompt_token[0] is None:
|
||||
prompt_token = [
|
||||
cur_token,
|
||||
torch.tensor([cur_token.shape[1]], dtype=torch.long, device=device),
|
||||
]
|
||||
prompt_audio = [
|
||||
feat.transpose(1, 2),
|
||||
torch.tensor([feat.shape[2]], dtype=torch.long, device=device),
|
||||
]
|
||||
else:
|
||||
prompt_token[1] = prompt_token[1] + cur_token.shape[1]
|
||||
prompt_token[0] = torch.concat([prompt_token[0], cur_token], dim=1)
|
||||
prompt_audio[1] = prompt_audio[1] + feat.shape[2]
|
||||
prompt_audio[0] = torch.concat([prompt_audio[0], feat.transpose(1, 2)], dim=1)
|
||||
wav = self.vocoder.inference(feat.transpose(1, 2))
|
||||
chunk_idx += 1
|
||||
else:
|
||||
cur_token, feat, wav = None, None, None
|
||||
@ -3247,8 +3232,8 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
states = {}
|
||||
states["new_text"] = ""
|
||||
states["last_t_size"] = 0
|
||||
states["prompt_token"] = [None, None]
|
||||
states["prompt_audio"] = [None, None]
|
||||
states["prompt_token"] = ([None, None], 0)
|
||||
states["prompt_audio"] = ([None, None], 0)
|
||||
states["chunk_idx"] = 0
|
||||
|
||||
def streaming_generate_speech(
|
||||
|
||||
@ -349,7 +349,7 @@ class UpsampleCtcTokenDiffModel(nn.Module):
|
||||
|
||||
return tokens, feat
|
||||
|
||||
def rms_rescale_feat(self, feat, target_feat_rms=3.5, feat_sil_th=0.1):
|
||||
def rms_rescale_feat(self, feat, target_feat_rms=3.5, feat_sil_th=0.5):
|
||||
feat_power = feat.exp().sum(1)
|
||||
# not silence
|
||||
if feat_power.max() > feat_sil_th:
|
||||
@ -902,6 +902,17 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
|
||||
|
||||
return tokens, prompt_audio[0].transpose(1, 2)
|
||||
|
||||
def cross_fade(self, pre: torch.Tensor, feat: torch.Tensor, hop_size: int):
|
||||
if pre is not None:
|
||||
hop_len = min(hop_size, feat.shape[1], pre.shape[1])
|
||||
sin_wind = torch.tensor(np.sin((np.arange(hop_len * 2) + 1) / (hop_len * 2 + 1) * np.pi)[None, :, None]).to(feat)
|
||||
cf_overlap = ((pre * sin_wind[:, -hop_len:] +
|
||||
feat[:, :hop_len] * sin_wind[:, :hop_len]) /
|
||||
(sin_wind[:, :hop_len] + sin_wind[:, -hop_len:]))
|
||||
feat[:, :hop_len] = cf_overlap
|
||||
|
||||
return feat
|
||||
|
||||
def streaming_one_step(
|
||||
self, text: torch.Tensor, text_lengths: torch.Tensor,
|
||||
xvec: Optional[torch.Tensor] = None, xvec_lengths: Optional[torch.Tensor] = None,
|
||||
@ -917,12 +928,12 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
|
||||
hint_once(f"chunk_size={chunk_size}, chunk_size_maxium={chunk_size_maxium}, "
|
||||
f"pre lookahead size={lookahead_size}.",
|
||||
"pre_lookahead_len")
|
||||
|
||||
wav_vocoder = kwargs.get("vocoder", None)
|
||||
blank_penalty = kwargs.get("blank_penalty", 0.0)
|
||||
sampling = kwargs.get("sampling", "greedy")
|
||||
prompt_dict = kwargs.get("prompt_dict", {})
|
||||
prompt_token = list(prompt_dict.get("prompt_token", [None, None]))
|
||||
prompt_audio = list(prompt_dict.get("prompt_audio", [None, None]))
|
||||
prompt_token, pre_token_lb = list(prompt_dict.get("prompt_token", ([None, None], 0)))
|
||||
prompt_audio, pre_feat_lb = list(prompt_dict.get("prompt_audio", ([None, None], 0)))
|
||||
|
||||
ftype = self.text_embedding.weight.dtype
|
||||
if prompt_token[0] is None:
|
||||
@ -985,33 +996,32 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
|
||||
return_hidden=True,
|
||||
use_causal_prob=use_causal_prob,
|
||||
)
|
||||
token_hop_len, mel_hop_len = 0, 0
|
||||
if isinstance(tokens, tuple):
|
||||
tokens, fa_tokens = tokens
|
||||
token_hop_len = self.get_hop_lens(fa_tokens, lookahead_size)
|
||||
mel_hop_len = int(round(token_hop_len * self.fm_model.length_normalizer_ratio))
|
||||
tokens, fa_tokens = tokens
|
||||
token_hop_len = self.get_hop_lens(fa_tokens, lookahead_size)
|
||||
|
||||
cur_token, feat = None, None
|
||||
# exclude empty tokens.
|
||||
cur_token, feat, wav = None, None, None
|
||||
token_reback, feat_reback = pre_token_lb, pre_feat_lb
|
||||
# generate feat, exclude empty tokens.
|
||||
if aligned_token_emb.shape[1] > prompt_token[0].shape[1]:
|
||||
cur_token = aligned_token_emb[:, prompt_token[0].shape[1]:]
|
||||
cur_token_len = aligned_token_lens - prompt_token[1]
|
||||
# need synthesize extra overlap parts
|
||||
cur_token = aligned_token_emb[:, prompt_token[0].shape[1] - pre_token_lb:]
|
||||
cur_token_len = aligned_token_lens - prompt_token[1] + pre_token_lb
|
||||
|
||||
# v2: excluding lookahead tokens for not-last packages
|
||||
if text[0, -1] != self.endofprompt_token_id+1:
|
||||
cur_token = cur_token[:, :cur_token.shape[1] - token_hop_len, :]
|
||||
cur_token_len = cur_token_len - token_hop_len
|
||||
|
||||
if cur_token_len[0] < 1:
|
||||
return None, None, None, (prompt_token, pre_token_lb), (prompt_audio, pre_feat_lb)
|
||||
# forward FM model
|
||||
# set_all_random_seed(0)
|
||||
if cur_token_len[0] < 1:
|
||||
return None, None
|
||||
feat = self.fm_model.inference(
|
||||
cur_token, cur_token_len,
|
||||
xvec, xvec_lengths,
|
||||
prompt=dict(
|
||||
prompt_text=prompt_token,
|
||||
prompt_audio=prompt_audio,
|
||||
prompt_text=[prompt_token[0][:, :-pre_token_lb], prompt_token[1] - pre_token_lb],
|
||||
prompt_audio=[prompt_audio[0][:, :-pre_feat_lb], prompt_audio[1] - pre_feat_lb],
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
@ -1019,6 +1029,20 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
|
||||
print_token = tokens.cpu().squeeze(0).squeeze(-1).tolist()
|
||||
logging.info(f"valid_tokens: {print_token[:len(print_token) - token_hop_len]}, "
|
||||
f"pad_tokens: {print_token[len(print_token) - token_hop_len:]}.")
|
||||
feat = self.cross_fade(prompt_audio[0], feat.transpose(1, 2), pre_feat_lb)
|
||||
pre_wav = wav_vocoder.inference(prompt_audio[0])
|
||||
cur_wav = wav_vocoder.inference(feat.transpose(1, 2))
|
||||
pre_wav_lb = int(1.0 / pre_token_lb * wav_vocoder.sample_rate)
|
||||
wav = self.cross_fade(pre_wav, cur_wav, pre_wav_lb)
|
||||
|
||||
prompt_token = [
|
||||
torch.cat([prompt_token[0][:, :-pre_token_lb], cur_token], dim=1),
|
||||
prompt_token[1] + cur_token_len - pre_token_lb,
|
||||
]
|
||||
prompt_audio = [
|
||||
torch.cat([prompt_audio[0][:, :-pre_feat_lb], feat.transpose(1, 2)], dim=1),
|
||||
prompt_audio[1] + feat.shape[2] - pre_feat_lb
|
||||
]
|
||||
|
||||
# v2: reback token and mel feat
|
||||
if text[0, -1] != self.endofprompt_token_id+1:
|
||||
@ -1026,7 +1050,9 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
|
||||
token_hop_len_2 = self.get_hop_lens(fa_tokens, lookahead_size + text_reback)
|
||||
token_reback = token_hop_len_2 - token_hop_len
|
||||
cur_token = cur_token[:, :cur_token.shape[1] - token_reback, :]
|
||||
feat_reback = int(round(token_reback * self.fm_model.length_normalizer_ratio))
|
||||
feat_reback = int(token_reback * self.fm_model.length_normalizer_ratio)
|
||||
feat = feat[:, :, :feat.shape[2] - feat_reback]
|
||||
wav_reback = int(1.0 / token_reback * wav_vocoder.sample_rate)
|
||||
wav = wav[:, :wav.shape[1] - wav_reback]
|
||||
|
||||
return cur_token, feat
|
||||
return cur_token, feat, wav, (prompt_token, token_reback), (prompt_audio, feat_reback)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user