mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
redo:fix mp3 bug
This commit is contained in:
parent
9edbcd5420
commit
829b1f8c72
@ -3068,7 +3068,7 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
if (t_size - last_t_size) >= tts_text_chunk_size or is_last or para_end:
|
||||
text_token = torch.tensor([text_token], dtype=torch.long, device=device)
|
||||
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.long, device=device)
|
||||
cur_token, feat, wav, prompt_token, prompt_audio = self.tts_model.streaming_one_step(
|
||||
cur_token, feat = self.tts_model.streaming_one_step(
|
||||
text_token,
|
||||
text_token_len,
|
||||
xvec=None,
|
||||
@ -3081,10 +3081,25 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
outside_prompt_lengths=llm_cur_kv_cache_len,
|
||||
sampling="threshold_6e-1",
|
||||
chunk_idx=chunk_idx,
|
||||
vocoder=self.vocoder,
|
||||
diff_steps=5,
|
||||
)
|
||||
if cur_token is not None and cur_token.shape[1] > 0 and feat.shape[2] > 0:
|
||||
# process first package, token in B,T,D, feat in B,F,T
|
||||
if prompt_token[0] is None:
|
||||
prompt_token = [
|
||||
cur_token,
|
||||
torch.tensor([cur_token.shape[1]], dtype=torch.long, device=device),
|
||||
]
|
||||
prompt_audio = [
|
||||
feat.transpose(1, 2),
|
||||
torch.tensor([feat.shape[2]], dtype=torch.long, device=device),
|
||||
]
|
||||
else:
|
||||
prompt_token[1] = prompt_token[1] + cur_token.shape[1]
|
||||
prompt_token[0] = torch.concat([prompt_token[0], cur_token], dim=1)
|
||||
prompt_audio[1] = prompt_audio[1] + feat.shape[2]
|
||||
prompt_audio[0] = torch.concat([prompt_audio[0], feat.transpose(1, 2)], dim=1)
|
||||
wav = self.vocoder.inference(feat.transpose(1, 2))
|
||||
chunk_idx += 1
|
||||
else:
|
||||
cur_token, feat, wav = None, None, None
|
||||
@ -3096,7 +3111,7 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
if para_end:
|
||||
text = "".join(preds[idx + 1:])
|
||||
last_t_size = 0
|
||||
prompt_token, prompt_audio = ([None, None], 0), ([None, None], 0)
|
||||
prompt_token, prompt_audio = [None, None], [None, None]
|
||||
wav = torch.cat([wav, torch.zeros([1, 2205]).to(wav)], dim=1)
|
||||
chunk_idx = 0
|
||||
else:
|
||||
@ -3178,6 +3193,12 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
mp3_data = self.mp3_encoder.encode(wav.tobytes())
|
||||
if is_last:
|
||||
mp3_data += self.mp3_encoder.flush()
|
||||
import lameenc
|
||||
self.mp3_encoder = lameenc.Encoder()
|
||||
self.mp3_encoder.set_bit_rate(128)
|
||||
self.mp3_encoder.set_in_sample_rate(22050)
|
||||
self.mp3_encoder.set_channels(1)
|
||||
self.mp3_encoder.set_quality(2)
|
||||
|
||||
return mp3_data
|
||||
|
||||
@ -3242,8 +3263,8 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
states = {}
|
||||
states["new_text"] = ""
|
||||
states["last_t_size"] = 0
|
||||
states["prompt_token"] = ([None, None], 0)
|
||||
states["prompt_audio"] = ([None, None], 0)
|
||||
states["prompt_token"] = [None, None]
|
||||
states["prompt_audio"] = [None, None]
|
||||
states["chunk_idx"] = 0
|
||||
|
||||
def streaming_generate_speech(
|
||||
|
||||
@ -349,7 +349,7 @@ class UpsampleCtcTokenDiffModel(nn.Module):
|
||||
|
||||
return tokens, feat
|
||||
|
||||
def rms_rescale_feat(self, feat, target_feat_rms=3.5, feat_sil_th=0.5):
|
||||
def rms_rescale_feat(self, feat, target_feat_rms=3.5, feat_sil_th=0.1):
|
||||
feat_power = feat.exp().sum(1)
|
||||
# not silence
|
||||
if feat_power.max() > feat_sil_th:
|
||||
@ -902,19 +902,6 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
|
||||
|
||||
return tokens, prompt_audio[0].transpose(1, 2)
|
||||
|
||||
def cross_fade(self, pre: torch.Tensor, feat: torch.Tensor, hop_size: int):
|
||||
if pre is not None:
|
||||
hop_len = min(hop_size, feat.shape[1])
|
||||
sin_wind = torch.tensor(np.sin((np.arange(hop_len * 2) + 1) / (hop_len * 2 + 1) * np.pi)[None, :]).to(feat)
|
||||
if feat.dim() == 3:
|
||||
sin_wind = sin_wind.unsqueeze(-1)
|
||||
cf_overlap = ((pre * sin_wind[:, -hop_len:] +
|
||||
feat[:, :hop_len] * sin_wind[:, :hop_len]) /
|
||||
(sin_wind[:, :hop_len] + sin_wind[:, -hop_len:]))
|
||||
feat[:, :hop_len] = cf_overlap
|
||||
|
||||
return feat
|
||||
|
||||
def streaming_one_step(
|
||||
self, text: torch.Tensor, text_lengths: torch.Tensor,
|
||||
xvec: Optional[torch.Tensor] = None, xvec_lengths: Optional[torch.Tensor] = None,
|
||||
@ -930,12 +917,12 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
|
||||
hint_once(f"chunk_size={chunk_size}, chunk_size_maxium={chunk_size_maxium}, "
|
||||
f"pre lookahead size={lookahead_size}.",
|
||||
"pre_lookahead_len")
|
||||
wav_vocoder = kwargs.get("vocoder", None)
|
||||
|
||||
blank_penalty = kwargs.get("blank_penalty", 0.0)
|
||||
sampling = kwargs.get("sampling", "greedy")
|
||||
prompt_dict = kwargs.get("prompt_dict", {})
|
||||
prompt_token, pre_token_lb = list(prompt_dict.get("prompt_token", ([None, None], 0)))
|
||||
prompt_audio, pre_feat_lb = list(prompt_dict.get("prompt_audio", ([None, None], 0)))
|
||||
prompt_token = list(prompt_dict.get("prompt_token", [None, None]))
|
||||
prompt_audio = list(prompt_dict.get("prompt_audio", [None, None]))
|
||||
|
||||
ftype = self.text_embedding.weight.dtype
|
||||
if prompt_token[0] is None:
|
||||
@ -998,32 +985,33 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
|
||||
return_hidden=True,
|
||||
use_causal_prob=use_causal_prob,
|
||||
)
|
||||
tokens, fa_tokens = tokens
|
||||
token_hop_len = self.get_hop_lens(fa_tokens, lookahead_size)
|
||||
token_hop_len, mel_hop_len = 0, 0
|
||||
if isinstance(tokens, tuple):
|
||||
tokens, fa_tokens = tokens
|
||||
token_hop_len = self.get_hop_lens(fa_tokens, lookahead_size)
|
||||
mel_hop_len = int(round(token_hop_len * self.fm_model.length_normalizer_ratio))
|
||||
|
||||
cur_token, feat, wav = None, None, None
|
||||
token_reback, feat_reback = pre_token_lb, pre_feat_lb
|
||||
# generate feat, exclude empty tokens.
|
||||
cur_token, feat = None, None
|
||||
# exclude empty tokens.
|
||||
if aligned_token_emb.shape[1] > prompt_token[0].shape[1]:
|
||||
# need synthesize extra overlap parts
|
||||
cur_token = aligned_token_emb[:, prompt_token[0].shape[1] - pre_token_lb:]
|
||||
cur_token_len = aligned_token_lens - prompt_token[1] + pre_token_lb
|
||||
cur_token = aligned_token_emb[:, prompt_token[0].shape[1]:]
|
||||
cur_token_len = aligned_token_lens - prompt_token[1]
|
||||
|
||||
# v2: excluding lookahead tokens for not-last packages
|
||||
if text[0, -1] != self.endofprompt_token_id+1:
|
||||
cur_token = cur_token[:, :cur_token.shape[1] - token_hop_len, :]
|
||||
cur_token_len = cur_token_len - token_hop_len
|
||||
|
||||
if cur_token_len[0] < 1:
|
||||
return None, None, None, (prompt_token, pre_token_lb), (prompt_audio, pre_feat_lb)
|
||||
# forward FM model
|
||||
# set_all_random_seed(0)
|
||||
if cur_token_len[0] < 1:
|
||||
return None, None
|
||||
feat = self.fm_model.inference(
|
||||
cur_token, cur_token_len,
|
||||
xvec, xvec_lengths,
|
||||
prompt=dict(
|
||||
prompt_text=[prompt_token[0][:, :-pre_token_lb], prompt_token[1] - pre_token_lb],
|
||||
prompt_audio=[prompt_audio[0][:, :-pre_feat_lb], prompt_audio[1] - pre_feat_lb],
|
||||
prompt_text=prompt_token,
|
||||
prompt_audio=prompt_audio,
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
@ -1031,23 +1019,6 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
|
||||
print_token = tokens.cpu().squeeze(0).squeeze(-1).tolist()
|
||||
logging.info(f"valid_tokens: {print_token[:len(print_token) - token_hop_len]}, "
|
||||
f"pad_tokens: {print_token[len(print_token) - token_hop_len:]}.")
|
||||
if prompt_audio[0].shape[1] > 0:
|
||||
feat = self.cross_fade(prompt_audio[0][:, -pre_feat_lb:], feat.transpose(1, 2), pre_feat_lb).transpose(1, 2)
|
||||
|
||||
wav = wav_vocoder.inference(feat.transpose(1, 2))
|
||||
if prompt_audio[0].shape[1] > 0:
|
||||
pre_wav = wav_vocoder.inference(prompt_audio[0])
|
||||
pre_wav_lb = int(1.0 / 25 * wav_vocoder.sample_rate * pre_token_lb)
|
||||
wav = self.cross_fade(pre_wav[:, -pre_wav_lb:], wav, pre_wav_lb)
|
||||
|
||||
prompt_token = [
|
||||
torch.cat([prompt_token[0][:, :-pre_token_lb], cur_token], dim=1),
|
||||
prompt_token[1] + cur_token_len - pre_token_lb,
|
||||
]
|
||||
prompt_audio = [
|
||||
torch.cat([prompt_audio[0][:, :-pre_feat_lb], feat.transpose(1, 2)], dim=1),
|
||||
prompt_audio[1] + feat.shape[2] - pre_feat_lb
|
||||
]
|
||||
|
||||
# v2: reback token and mel feat
|
||||
if text[0, -1] != self.endofprompt_token_id+1:
|
||||
@ -1055,9 +1026,7 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
|
||||
token_hop_len_2 = self.get_hop_lens(fa_tokens, lookahead_size + text_reback)
|
||||
token_reback = token_hop_len_2 - token_hop_len
|
||||
cur_token = cur_token[:, :cur_token.shape[1] - token_reback, :]
|
||||
feat_reback = int(token_reback * self.fm_model.length_normalizer_ratio)
|
||||
feat_reback = int(round(token_reback * self.fm_model.length_normalizer_ratio))
|
||||
feat = feat[:, :, :feat.shape[2] - feat_reback]
|
||||
wav_reback = int(1.0 / 25 * wav_vocoder.sample_rate * token_reback)
|
||||
wav = wav[:, :wav.shape[1] - wav_reback]
|
||||
|
||||
return cur_token, feat, wav, (prompt_token, token_reback), (prompt_audio, feat_reback)
|
||||
return cur_token, feat
|
||||
|
||||
Loading…
Reference in New Issue
Block a user