From 829b1f8c7247e71f2936318f4200d28c9ed4f5a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BF=97=E6=B5=A9?= Date: Wed, 18 Sep 2024 23:04:25 +0800 Subject: [PATCH] redo:fix mp3 bug --- funasr/models/llm_asr/model.py | 31 +++++++-- funasr/models/llm_asr/tts_models/e2e_model.py | 69 +++++-------------- 2 files changed, 45 insertions(+), 55 deletions(-) diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index 36ef5b3e9..c452f286b 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -3068,7 +3068,7 @@ class LLMASRXvecSlotTTS(nn.Module): if (t_size - last_t_size) >= tts_text_chunk_size or is_last or para_end: text_token = torch.tensor([text_token], dtype=torch.long, device=device) text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.long, device=device) - cur_token, feat, wav, prompt_token, prompt_audio = self.tts_model.streaming_one_step( + cur_token, feat = self.tts_model.streaming_one_step( text_token, text_token_len, xvec=None, @@ -3081,10 +3081,25 @@ class LLMASRXvecSlotTTS(nn.Module): outside_prompt_lengths=llm_cur_kv_cache_len, sampling="threshold_6e-1", chunk_idx=chunk_idx, - vocoder=self.vocoder, diff_steps=5, ) if cur_token is not None and cur_token.shape[1] > 0 and feat.shape[2] > 0: + # process first package, token in B,T,D, feat in B,F,T + if prompt_token[0] is None: + prompt_token = [ + cur_token, + torch.tensor([cur_token.shape[1]], dtype=torch.long, device=device), + ] + prompt_audio = [ + feat.transpose(1, 2), + torch.tensor([feat.shape[2]], dtype=torch.long, device=device), + ] + else: + prompt_token[1] = prompt_token[1] + cur_token.shape[1] + prompt_token[0] = torch.concat([prompt_token[0], cur_token], dim=1) + prompt_audio[1] = prompt_audio[1] + feat.shape[2] + prompt_audio[0] = torch.concat([prompt_audio[0], feat.transpose(1, 2)], dim=1) + wav = self.vocoder.inference(feat.transpose(1, 2)) chunk_idx += 1 else: cur_token, feat, wav = None, None, None @@ -3096,7 +3111,7 @@ class LLMASRXvecSlotTTS(nn.Module): if para_end: text = "".join(preds[idx + 1:]) last_t_size = 0 - prompt_token, prompt_audio = ([None, None], 0), ([None, None], 0) + prompt_token, prompt_audio = [None, None], [None, None] wav = torch.cat([wav, torch.zeros([1, 2205]).to(wav)], dim=1) chunk_idx = 0 else: @@ -3178,6 +3193,12 @@ class LLMASRXvecSlotTTS(nn.Module): mp3_data = self.mp3_encoder.encode(wav.tobytes()) if is_last: mp3_data += self.mp3_encoder.flush() + import lameenc + self.mp3_encoder = lameenc.Encoder() + self.mp3_encoder.set_bit_rate(128) + self.mp3_encoder.set_in_sample_rate(22050) + self.mp3_encoder.set_channels(1) + self.mp3_encoder.set_quality(2) return mp3_data @@ -3242,8 +3263,8 @@ class LLMASRXvecSlotTTS(nn.Module): states = {} states["new_text"] = "" states["last_t_size"] = 0 - states["prompt_token"] = ([None, None], 0) - states["prompt_audio"] = ([None, None], 0) + states["prompt_token"] = [None, None] + states["prompt_audio"] = [None, None] states["chunk_idx"] = 0 def streaming_generate_speech( diff --git a/funasr/models/llm_asr/tts_models/e2e_model.py b/funasr/models/llm_asr/tts_models/e2e_model.py index baa34bf19..482328407 100644 --- a/funasr/models/llm_asr/tts_models/e2e_model.py +++ b/funasr/models/llm_asr/tts_models/e2e_model.py @@ -349,7 +349,7 @@ class UpsampleCtcTokenDiffModel(nn.Module): return tokens, feat - def rms_rescale_feat(self, feat, target_feat_rms=3.5, feat_sil_th=0.5): + def rms_rescale_feat(self, feat, target_feat_rms=3.5, feat_sil_th=0.1): feat_power = feat.exp().sum(1) # not silence if feat_power.max() > feat_sil_th: @@ -902,19 +902,6 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel): return tokens, prompt_audio[0].transpose(1, 2) - def cross_fade(self, pre: torch.Tensor, feat: torch.Tensor, hop_size: int): - if pre is not None: - hop_len = min(hop_size, feat.shape[1]) - sin_wind = torch.tensor(np.sin((np.arange(hop_len * 2) + 1) / (hop_len * 2 + 1) * np.pi)[None, :]).to(feat) - if feat.dim() == 3: - sin_wind = sin_wind.unsqueeze(-1) - cf_overlap = ((pre * sin_wind[:, -hop_len:] + - feat[:, :hop_len] * sin_wind[:, :hop_len]) / - (sin_wind[:, :hop_len] + sin_wind[:, -hop_len:])) - feat[:, :hop_len] = cf_overlap - - return feat - def streaming_one_step( self, text: torch.Tensor, text_lengths: torch.Tensor, xvec: Optional[torch.Tensor] = None, xvec_lengths: Optional[torch.Tensor] = None, @@ -930,12 +917,12 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel): hint_once(f"chunk_size={chunk_size}, chunk_size_maxium={chunk_size_maxium}, " f"pre lookahead size={lookahead_size}.", "pre_lookahead_len") - wav_vocoder = kwargs.get("vocoder", None) + blank_penalty = kwargs.get("blank_penalty", 0.0) sampling = kwargs.get("sampling", "greedy") prompt_dict = kwargs.get("prompt_dict", {}) - prompt_token, pre_token_lb = list(prompt_dict.get("prompt_token", ([None, None], 0))) - prompt_audio, pre_feat_lb = list(prompt_dict.get("prompt_audio", ([None, None], 0))) + prompt_token = list(prompt_dict.get("prompt_token", [None, None])) + prompt_audio = list(prompt_dict.get("prompt_audio", [None, None])) ftype = self.text_embedding.weight.dtype if prompt_token[0] is None: @@ -998,32 +985,33 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel): return_hidden=True, use_causal_prob=use_causal_prob, ) - tokens, fa_tokens = tokens - token_hop_len = self.get_hop_lens(fa_tokens, lookahead_size) + token_hop_len, mel_hop_len = 0, 0 + if isinstance(tokens, tuple): + tokens, fa_tokens = tokens + token_hop_len = self.get_hop_lens(fa_tokens, lookahead_size) + mel_hop_len = int(round(token_hop_len * self.fm_model.length_normalizer_ratio)) - cur_token, feat, wav = None, None, None - token_reback, feat_reback = pre_token_lb, pre_feat_lb - # generate feat, exclude empty tokens. + cur_token, feat = None, None + # exclude empty tokens. if aligned_token_emb.shape[1] > prompt_token[0].shape[1]: - # need synthesize extra overlap parts - cur_token = aligned_token_emb[:, prompt_token[0].shape[1] - pre_token_lb:] - cur_token_len = aligned_token_lens - prompt_token[1] + pre_token_lb + cur_token = aligned_token_emb[:, prompt_token[0].shape[1]:] + cur_token_len = aligned_token_lens - prompt_token[1] # v2: excluding lookahead tokens for not-last packages if text[0, -1] != self.endofprompt_token_id+1: cur_token = cur_token[:, :cur_token.shape[1] - token_hop_len, :] cur_token_len = cur_token_len - token_hop_len - if cur_token_len[0] < 1: - return None, None, None, (prompt_token, pre_token_lb), (prompt_audio, pre_feat_lb) # forward FM model # set_all_random_seed(0) + if cur_token_len[0] < 1: + return None, None feat = self.fm_model.inference( cur_token, cur_token_len, xvec, xvec_lengths, prompt=dict( - prompt_text=[prompt_token[0][:, :-pre_token_lb], prompt_token[1] - pre_token_lb], - prompt_audio=[prompt_audio[0][:, :-pre_feat_lb], prompt_audio[1] - pre_feat_lb], + prompt_text=prompt_token, + prompt_audio=prompt_audio, ), **kwargs, ) @@ -1031,23 +1019,6 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel): print_token = tokens.cpu().squeeze(0).squeeze(-1).tolist() logging.info(f"valid_tokens: {print_token[:len(print_token) - token_hop_len]}, " f"pad_tokens: {print_token[len(print_token) - token_hop_len:]}.") - if prompt_audio[0].shape[1] > 0: - feat = self.cross_fade(prompt_audio[0][:, -pre_feat_lb:], feat.transpose(1, 2), pre_feat_lb).transpose(1, 2) - - wav = wav_vocoder.inference(feat.transpose(1, 2)) - if prompt_audio[0].shape[1] > 0: - pre_wav = wav_vocoder.inference(prompt_audio[0]) - pre_wav_lb = int(1.0 / 25 * wav_vocoder.sample_rate * pre_token_lb) - wav = self.cross_fade(pre_wav[:, -pre_wav_lb:], wav, pre_wav_lb) - - prompt_token = [ - torch.cat([prompt_token[0][:, :-pre_token_lb], cur_token], dim=1), - prompt_token[1] + cur_token_len - pre_token_lb, - ] - prompt_audio = [ - torch.cat([prompt_audio[0][:, :-pre_feat_lb], feat.transpose(1, 2)], dim=1), - prompt_audio[1] + feat.shape[2] - pre_feat_lb - ] # v2: reback token and mel feat if text[0, -1] != self.endofprompt_token_id+1: @@ -1055,9 +1026,7 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel): token_hop_len_2 = self.get_hop_lens(fa_tokens, lookahead_size + text_reback) token_reback = token_hop_len_2 - token_hop_len cur_token = cur_token[:, :cur_token.shape[1] - token_reback, :] - feat_reback = int(token_reback * self.fm_model.length_normalizer_ratio) + feat_reback = int(round(token_reback * self.fm_model.length_normalizer_ratio)) feat = feat[:, :, :feat.shape[2] - feat_reback] - wav_reback = int(1.0 / 25 * wav_vocoder.sample_rate * token_reback) - wav = wav[:, :wav.shape[1] - wav_reback] - return cur_token, feat, wav, (prompt_token, token_reback), (prompt_audio, feat_reback) + return cur_token, feat