diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index f50fb6658..319e1a88d 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -2871,9 +2871,6 @@ class LLMASRXvecSlotTTS(nn.Module):
 
         # tts related inference, require the kv cache of llm last layer for only the current inputs
         # TODO: select kv cache of the current turn inputs
-        import pdb
-
-        pdb.set_trace()
         attention_mask = batch.get("attention_mask", None)
         model_outputs = self.llm(
             inputs_embeds=inputs_embeds,
@@ -2918,8 +2915,11 @@ class LLMASRXvecSlotTTS(nn.Module):
         ):
             assert llm_cur_kv_cache is not None
             set_all_random_seed(rand_seed)
-            speech_tokens, mel, wav = self.generate_speech(
-                response, llm_cur_kv_cache, llm_cur_kv_cache_len, dtype_map[tts_dtype]
+            # speech_tokens, mel, wav = self.generate_speech(
+            #     response, llm_cur_kv_cache, llm_cur_kv_cache_len, dtype_map[tts_dtype]
+            # )
+            speech_tokens, mel, wav = self.simulate_streaming_generate_speech(
+                target_ids, llm_cur_kv_cache, llm_cur_kv_cache_len, dtype_map[tts_dtype], tokenizer
             )
             self.write_mel_wav(kwargs.get("output_dir"), mel, wav, key[0])
 
@@ -2942,12 +2942,141 @@ class LLMASRXvecSlotTTS(nn.Module):
             None,
             outside_prompt=llm_cur_kv_cache,
             outside_prompt_lengths=llm_cur_kv_cache_len,
+            sampling="threshold_1e-6",
         )
         # vocoder forward
         wav = self.vocoder.inference(mel_feats.transpose(1, 2))
 
         return speech_tokens, mel_feats, wav
 
+    def split_characters_and_words(self, input_string):
+        # match single CJK characters, whole words, and punctuation marks
+        pattern = r'[\u4e00-\u9fff]|[\w]+|[^\w\s]'
+        # collect all matched characters and words with re.findall
+        results = re.findall(pattern, input_string)
+        return results
+
+    def tts_tokenizer_wrapper(self, text):
+        text_token = self.tts_text_tokenizer.text2tokens(text)
+        # remove the trailing punctuation token added by ttsfrd.
+        if text[-1] != "。" and text_token[-1] == 1542:
+            text_token = text_token[:-1]
+        return text_token
+
+    def generate_speech_one_step(
+        self,
+        text: str, last_t_size,
+        llm_cur_kv_cache, llm_cur_kv_cache_len,
+        prompt_token, prompt_audio, tts_text_chunk_size,
+        chunk_idx, is_last, para_len=30,
+    ):
+        device = llm_cur_kv_cache.device
+        puncs = ['。', '?', '!', ';', ':', '.', '?', '!', ';', '\n']
+
+        # remove duplicated punctuation marks
+        normed_text = []
+        for i, c in enumerate(text):
+            if i > 0 and text[i-1] in puncs and text[i] in puncs:
+                continue
+            normed_text.append(c)
+        text = "".join(normed_text)
+
+        cur_token, feat, wav = None, None, None
+        _text = f"<|endofprompt|><|sil|>{text}" + ("<|sil|>" if is_last else "")
+        text_token = self.tts_tokenizer_wrapper(_text)
+        t_size = len(text_token)
+        if (t_size - last_t_size) >= tts_text_chunk_size or is_last:
+            text_token = torch.tensor([text_token], dtype=torch.long, device=device)
+            text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.long, device=device)
+            cur_token, feat = self.tts_model.streaming_one_step(
+                text_token, text_token_len,
+                xvec=None, xvec_lengths=None,
+                prompt_dict={
+                    "prompt_token": prompt_token,
+                    "prompt_audio": prompt_audio,
+                },
+                outside_prompt=llm_cur_kv_cache,
+                outside_prompt_lengths=llm_cur_kv_cache_len,
+                sampling="threshold_1e-6",
+                chunk_idx=chunk_idx,
+            )
+            if cur_token is not None and cur_token.shape[1] > 0 and feat.shape[2] > 0:
+                # process first package, token in B,T,D, feat in B,F,T
+                if prompt_token[0] is None:
+                    prompt_token = [cur_token, torch.tensor([cur_token.shape[1]], dtype=torch.long, device=device)]
+                    prompt_audio = [feat.transpose(1, 2), torch.tensor([feat.shape[2]], dtype=torch.long, device=device)]
+                else:
+                    prompt_token[1] = prompt_token[1] + cur_token.shape[1]
+                    prompt_token[0] = torch.concat([prompt_token[0], cur_token], dim=1)
+                    prompt_audio[1] = prompt_audio[1] + feat.shape[2]
+                    prompt_audio[0] = torch.concat([prompt_audio[0], feat.transpose(1, 2)], dim=1)
+                wav = self.vocoder.inference(feat.transpose(1, 2))
+                chunk_idx += 1
+            else:
+                cur_token, feat, wav = None, None, None
+
+        # post process
+        last_t_size = t_size
+        # restart a new paragraph
+        # char_words = self.split_characters_and_words(text)
+        # if len(char_words) > para_len:
+        #     # find the last punctuation mark at which to split the paragraph
+        #     idx = -1
+        #     for i in range(len(char_words)-1, -1, -1):
+        #         if char_words[i] in puncs:
+        #             idx = i
+        #             break
+        #     if idx > 0:
+        #         text = text[idx+1:]
+        #         last_t_size = len(self.tts_tokenizer_wrapper(text))
+
+        return ((cur_token, feat, wav),
+                (text, last_t_size, prompt_token, prompt_audio, chunk_idx))
+
+    def simulate_streaming_generate_speech(self, preds, llm_cur_kv_cache, llm_cur_kv_cache_len, llm_dtype, llm_tokenizer):
+        self.vocoder.to(llm_dtype)
+        self.tts_model.to(llm_dtype)
+        llm_token_num_per_call = 3
+        text_chunk_size = 8
+        given_rtf = 0.5
+
+        token_list, feat_list, wav_list = [], [], []
+        prompt_token, prompt_audio = [None, None], [None, None]
+        new_text, last_t_size, chunk_idx = "", 0, 0
+        st, count = 0, 0
+        while st < preds.shape[1]:
+            chunk_size = int(llm_token_num_per_call / (given_rtf ** min(count, 2)))
+            _resp = llm_tokenizer.batch_decode(
+                preds[:, st:st + chunk_size],
+                add_special_tokens=False,
+                skip_special_tokens=True,
+            )[0]
+            is_last = (st + chunk_size >= preds.shape[1])
+
+            new_text = new_text + _resp
+            rt_value, states = self.generate_speech_one_step(
+                new_text, last_t_size,
+                llm_cur_kv_cache, llm_cur_kv_cache_len,
+                prompt_token, prompt_audio,
+                text_chunk_size,
+                chunk_idx, is_last,
+            )
+            cur_token, feat, wav = rt_value
+            new_text, last_t_size, prompt_token, prompt_audio, chunk_idx = states
+            # save results
+            if cur_token is not None and feat is not None and wav is not None:
+                token_list.append(cur_token)
+                feat_list.append(feat)
+                wav_list.append(wav)
+
+            st += chunk_size
+            count += 1
+
+        speech_tokens = torch.cat(token_list, dim=1)
+        mel_feats = torch.cat(feat_list, dim=2)
+        wav = torch.cat(wav_list, dim=1)
+        return speech_tokens, mel_feats, wav
+
     def write_mel_wav(self, output_dir, feat, wav, key):
         out_dir = os.path.join(output_dir, "1best_recog", "mels")
         os.makedirs(out_dir, exist_ok=True)
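For reference, the pacing rule in simulate_streaming_generate_speech can be read in isolation. The following is a minimal standalone sketch (editorial, not part of the patch; total_tokens is a made-up stand-in for preds.shape[1]) showing how successive calls consume 3, 6, then 12 LLM tokens, i.e. the chunk doubles twice and then stays capped by min(count, 2), mimicking a decoder that stays ahead of playback at the given real-time factor:

llm_token_num_per_call = 3  # mirrors the constant above
given_rtf = 0.5             # mirrors the constant above
total_tokens = 40           # hypothetical preds.shape[1]

st, count = 0, 0
while st < total_tokens:
    chunk_size = int(llm_token_num_per_call / (given_rtf ** min(count, 2)))
    is_last = st + chunk_size >= total_tokens
    print(f"call {count}: tokens [{st}, {min(st + chunk_size, total_tokens)}), is_last={is_last}")
    st += chunk_size
    count += 1

Run as-is, this prints five calls covering token ranges [0, 3), [3, 9), [9, 21), [21, 33) and [33, 40), with only the final call flagged is_last.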
diff --git a/funasr/models/llm_asr/tts_models/e2e_model.py b/funasr/models/llm_asr/tts_models/e2e_model.py
index 9670405ab..d8f9a3204 100644
--- a/funasr/models/llm_asr/tts_models/e2e_model.py
+++ b/funasr/models/llm_asr/tts_models/e2e_model.py
@@ -741,7 +741,7 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
         device = text.device
         use_causal_prob = kwargs.get("use_causal_prob", 1.0)
         # streaming related config
-        chunk_size = kwargs.get("streaming_chunk_size", 1)
+        chunk_size = kwargs.get("streaming_chunk_size", 4)
         chunk_size_maxium = kwargs.get("chunk_size_maxium", 16)
         try:
             lookahead_size = self.am_model.encoder.pre_lookahead_len
@@ -899,3 +899,129 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
                 break
 
         return tokens, prompt_audio[0].transpose(1, 2)
+
+    def streaming_one_step(
+        self, text: torch.Tensor, text_lengths: torch.Tensor,
+        xvec: Optional[torch.Tensor] = None, xvec_lengths: Optional[torch.Tensor] = None,
+        chunk_idx=0,
+        **kwargs
+    ):
+        device = text.device
+        use_causal_prob = kwargs.get("use_causal_prob", 1.0)
+        # streaming related config
+        chunk_size = kwargs.get("streaming_chunk_size", 4)
+        chunk_size_maxium = kwargs.get("chunk_size_maxium", 16)
+        lookahead_size = self.am_model.encoder.pre_lookahead_len
+        hint_once(f"chunk_size={chunk_size}, chunk_size_maxium={chunk_size_maxium}, "
+                  f"pre lookahead size={lookahead_size}.",
+                  "pre_lookahead_len")
+
+        blank_penalty = kwargs.get("blank_penalty", 0.0)
+        sampling = kwargs.get("sampling", "greedy")
+        prompt_dict = kwargs.get("prompt_dict", {})
+        prompt_token = list(prompt_dict.get("prompt_token", [None, None]))
+        prompt_audio = list(prompt_dict.get("prompt_audio", [None, None]))
+
+        ftype = self.text_embedding.weight.dtype
+        if prompt_token[0] is None:
+            prompt_token[0] = torch.zeros([1, 0, self.output_size], device=device, dtype=ftype)
+            prompt_token[1] = torch.tensor([0], device=device, dtype=torch.long)
+        if prompt_audio[0] is None:
+            prompt_audio[0] = torch.zeros(
+                [1, 0, self.fm_model.mel_extractor.num_mels],
+                device=device, dtype=ftype
+            )
+            prompt_audio[1] = torch.tensor([0], device=device, dtype=torch.long)
+
+        # embed text inputs
+        mask = (text != -1).float().unsqueeze(-1)
+        text_emb = self.text_embedding(torch.clamp(text, min=0)) * mask
+        text_emb_lengths = text_lengths
+
+        batch_size = text.shape[0]
+
+        prompt, prompt_lens, text_emb, text_emb_lengths = self.split_prompt(
+            text_emb, text_emb_lengths, text, text_lengths
+        )
+        if "outside_prompt" in kwargs:
+            prompt = kwargs["outside_prompt"].to(device)
+            if "outside_prompt_lengths" in kwargs:
+                prompt_lens = kwargs["outside_prompt_lengths"]
+            else:
+                prompt_lens = torch.tensor([prompt.shape[1]]).to(text_lengths)
+            prompt = self.outside_prompt_poj(prompt)
+            hint_once("use outside_prompt", "outside_prompt")
+
+        if xvec is not None:
+            # using speaker embedding
+            hint_once("using speaker embedding for slot.", "use_spk_emb")
+            xvec = xvec[:, :xvec_lengths.max()]
+        else:
+            # textual prompt xvec
+            hint_once("using textual prompt for slot.", "use_spk_emb")
+            prompt_xvec = self.spk_aggregator(
+                prompt, prompt_lens,
+                self.spk_query.expand([batch_size, -1, -1]), torch.tensor([1] * batch_size).to(prompt_lens)
+            )[0]
+            xvec = self.prompt_xvec_proj(prompt_xvec)
+            xvec_lengths = torch.tensor([1] * batch_size).to(text_lengths)
+
+        chunk_text_emb = text_emb
+        chunk_text_emb_lengths = torch.tensor([chunk_text_emb.shape[1]], dtype=torch.long, device=device)
+
+        outs_tuple = self.text_encoder(chunk_text_emb, ilens=chunk_text_emb_lengths)
+        text_enc = outs_tuple[0]
+        text_enc_lens = chunk_text_emb_lengths
+
+        # forward AM model
+        tokens, aligned_token_emb, aligned_token_lens = self.am_model.inference(
+            text_enc, text_enc_lens,
+            xvec, xvec_lengths,
+            sampling=sampling,
+            blank_penalty=blank_penalty,
+            text_is_embedding=True,
+            return_hidden=True,
+            use_causal_prob=use_causal_prob,
+        )
+        token_hop_len, mel_hop_len = 0, 0
+        if isinstance(tokens, tuple):
+            tokens, fa_tokens = tokens
+            token_hop_len = self.get_hop_lens(fa_tokens, lookahead_size)
+            mel_hop_len = int(round(token_hop_len * self.fm_model.length_normalizer_ratio))
+
+        cur_token, feat = None, None
+        # exclude empty tokens.
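The tail-trimming arithmetic in streaming_one_step is easiest to see with concrete numbers. Below is an editorial sketch with made-up values: get_hop_lens_stub, the constants, and the linear tokens-per-text-token ratio are illustrative assumptions only, not the real get_hop_lens, which derives hop lengths from the forced-alignment tokens.

def get_hop_lens_stub(num_text_tokens: int) -> int:
    # assumption for illustration: 2 aligned tokens per text token
    return 2 * num_text_tokens

lookahead_size = 3                 # pre-lookahead of the AM encoder (assumed)
length_normalizer_ratio = 2.0      # mel frames per aligned token (assumed)
chunk_idx = 1                      # any non-first, non-final chunk

aligned_len, prompt_len = 50, 30   # aligned tokens vs. tokens already emitted
new_tokens = aligned_len - prompt_len               # 20 fresh tokens

token_hop_len = get_hop_lens_stub(lookahead_size)   # 6 lookahead tokens dropped
new_tokens -= token_hop_len                         # 14 remain

text_reback = 2 if chunk_idx == 0 else 4            # v2 roll-back, in text tokens
token_reback = get_hop_lens_stub(lookahead_size + text_reback) - token_hop_len
new_tokens -= token_reback                          # 14 - 8 = 6 tokens emitted
feat_reback = int(round(token_reback * length_normalizer_ratio))
print(new_tokens, feat_reback)                      # 6 tokens kept, 16 mel frames rolled back

The rolled-back tokens and frames are re-synthesized on the next call once more right context is available, which is what lets the first package ship early without committing to its unstable tail.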
+        if aligned_token_emb.shape[1] > prompt_token[0].shape[1]:
+            cur_token = aligned_token_emb[:, prompt_token[0].shape[1]:]
+            cur_token_len = aligned_token_lens - prompt_token[1]
+
+            # v2: exclude lookahead tokens for non-final packages
+            if text[0, -1] != self.endofprompt_token_id + 1:
+                cur_token = cur_token[:, :cur_token.shape[1] - token_hop_len, :]
+                cur_token_len = cur_token_len - token_hop_len
+
+            # forward FM model
+            feat = self.fm_model.inference(
+                cur_token, cur_token_len,
+                xvec, xvec_lengths,
+                prompt=dict(
+                    prompt_text=prompt_token,
+                    prompt_audio=prompt_audio,
+                ),
+                **kwargs,
+            )
+            feat = self.rms_rescale_feat(feat)
+            print_token = tokens.cpu().squeeze().tolist()
+            logging.info(f"valid_tokens: {print_token[:len(print_token) - token_hop_len]}, "
+                         f"pad_tokens: {print_token[len(print_token) - token_hop_len:]}.")
+
+            # v2: roll back tokens and mel feats
+            if text[0, -1] != self.endofprompt_token_id + 1:
+                text_reback = 2 if chunk_idx == 0 else 4
+                token_hop_len_2 = self.get_hop_lens(fa_tokens, lookahead_size + text_reback)
+                token_reback = token_hop_len_2 - token_hop_len
+                cur_token = cur_token[:, :cur_token.shape[1] - token_reback, :]
+                feat_reback = int(round(token_reback * self.fm_model.length_normalizer_ratio))
+                feat = feat[:, :, :feat.shape[2] - feat_reback]
+
+        return cur_token, feat
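diff --git a/runtime/python/websocket/funasr_wss_server_streaming_llm.py b/runtime/python/websocket/funasr_wss_server_streaming_llm.py
index fb1bada7a..74d9cace6 100644
--- a/runtime/python/websocket/funasr_wss_server_streaming_llm.py
+++ b/runtime/python/websocket/funasr_wss_server_streaming_llm.py
@@ -29,6 +29,7 @@ parser.add_argument("--ngpu", type=int, default=1, help="0 for cpu, 1 for gpu")
 parser.add_argument("--device", type=str, default="cuda", help="cuda, cpu")
 parser.add_argument("--ncpu", type=int, default=4, help="cpu cores")
 parser.add_argument("--return_sentence", action="store_true", help="return sentence or all_res")
+parser.add_argument("--no_vad", action="store_true", help="infer without vad")
 parser.add_argument(
     "--certfile",
     type=str,
@@ -78,6 +79,7 @@ audio_encoder_dir = snapshot_download("iic/SenseVoice", cache_dir=None, revision
 # audio_encoder_dir = "/nfs/yangyexin.yyx/init_model/iic/SenseVoiceModelscope_0712"
 device = "cuda:0"
 all_file_paths = [
+    # "/nfs/yangyexin.yyx/init_model/s2tt/qwen2_7b_mmt_v15_20240910_streaming",
     "FunAudioLLM/qwen2_7b_mmt_v15_20240910_streaming",
     "FunAudioLLM/qwen2_7b_mmt_v15_20240902",
     "FunAudioLLM/qwen2_7b_mmt_v14_20240830",
@@ -91,7 +93,6 @@ llm_kwargs = {"num_beams": 1, "do_sample": False, "repetition_penalty": 1.3}
 UNFIX_LEN = 5
 MIN_LEN_PER_PARAGRAPH = 25
 MIN_LEN_SEC_AUDIO_FIX = 1.1
-MAX_ITER_PER_CHUNK = 20
 
 ckpt_dir = all_file_paths[0]

With the new --no_vad flag the server skips VAD segmentation entirely and treats every incoming frame as speech. A typical launch (hypothetical deployment, using only the flags defined above) would be:

python funasr_wss_server_streaming_llm.py --ngpu 1 --device cuda --no_vad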
@@ -483,32 +484,51 @@ async def ws_serve(websocket, path):
             frames_asr.append(message)
 
             # vad online
-            try:
-                speech_start_i, speech_end_i = await async_vad(websocket, message)
-            except:
-                print("error in vad")
-            if speech_start_i != -1:
+            if not args.no_vad:
+                try:
+                    speech_start_i, speech_end_i = await async_vad(websocket, message)
+                except:
+                    print("error in vad")
+                if speech_start_i != -1:
+                    speech_start = True
+                    speech_end_i = -1
+                    beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms
+                    frames_pre = frames[-beg_bias:]
+                    frames_asr = []
+                    frames_asr.extend(frames_pre)
+            else:
                 speech_start = True
-                beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms
-                frames_pre = frames[-beg_bias:]
+                speech_end_i = -1
                 frames_asr = []
-                frames_asr.extend(frames_pre)
+                frames_asr.extend(frames)
 
             # vad end
             if speech_end_i != -1 or not websocket.is_speaking:
-                audio_in = b"".join(frames_asr)
-                try:
-                    await streaming_transcribe(
-                        websocket, audio_in, is_vad_end=True, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt
-                    )
-                except Exception as e:
-                    print(f"error in streaming, {e}")
-                    print(f"error in streaming, {websocket.streaming_state}")
+                if speech_end_i != -1:
+                    audio_in = b"".join(frames_asr)
+                    try:
+                        await streaming_transcribe(
+                            websocket, audio_in, is_vad_end=True, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt
+                        )
+                    except Exception as e:
+                        print(f"error in streaming, {e}")
+                        print(f"error in streaming, {websocket.streaming_state}")
                 frames_asr = []
                 speech_start = False
                 websocket.streaming_state["previous_asr_text"] = ""
                 websocket.streaming_state["previous_s2tt_text"] = ""
                 if not websocket.is_speaking:
+                    message = json.dumps(
+                        {
+                            "mode": "online",
+                            "asr_text": websocket.streaming_state["onscreen_asr_res"],
+                            "s2tt_text": websocket.streaming_state["onscreen_s2tt_res"],
+                            "wav_name": websocket.wav_name,
+                            "is_final": websocket.is_speaking,
+                            "is_sentence_end": True,
+                        }
+                    )
+                    await websocket.send(message)
                     await clear_websocket()
                     if args.return_sentence:
                         websocket.streaming_state["previous_vad_onscreen_asr_text"] = ""
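On the client side, the new sentence-end notification can be consumed by matching the fields of the json.dumps payload above. A minimal sketch of a hypothetical handler (handle_server_message is illustrative, not part of the patch):

import json

def handle_server_message(raw: str) -> None:
    msg = json.loads(raw)
    # sent once per utterance, after the client stops speaking
    if msg.get("mode") == "online" and msg.get("is_sentence_end"):
        print("ASR :", msg.get("asr_text", ""))
        print("S2TT:", msg.get("s2tt_text", ""))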