diff --git a/examples/industrial_data_pretraining/llm_asr/app_chatbot_audio_audio.py b/examples/industrial_data_pretraining/llm_asr/app_chatbot_audio_audio.py index f28de1faf..b8d768e8c 100644 --- a/examples/industrial_data_pretraining/llm_asr/app_chatbot_audio_audio.py +++ b/examples/industrial_data_pretraining/llm_asr/app_chatbot_audio_audio.py @@ -32,29 +32,34 @@ if len(sys.argv) > 1: if len(sys.argv) > 6: new_sys = True else: - ckpt_dir = "/data/zhifu.gzf/init_model/gpt4o-exp7-4" - ckpt_id = "model.pt.ep1.140000" + ckpt_dir = "/data/zhifu.gzf/init_model/exp8-2" + ckpt_id = "model.pt.ep2.90000" jsonl = ( "/nfs/beinian.lzr/workspace/GPT-4o/Data/Speech2Text/TestData/s2tchat.v20240619.test.jsonl" ) dataset = jsonl.split("/")[-1] output_dir = os.path.join(ckpt_dir, f"inference-{ckpt_id}", dataset) - device = "cuda:1" + device = "cuda:5" new_sys = True +init_param = "/data/zhifu.gzf/init_model/gpt4o-exp7-4/model.pt.ep1.390000" +init_param_ckpt = f"{os.path.join(ckpt_dir, ckpt_id)}" +flow_init = "/data/zhifu.gzf/init_model/cosyvoice_flow_matching_for_streaming_with_prompt_random_cut_sft_zh_0630_25hz_1/60epoch.pth.prefix" +vocoder_init = "/data/zhifu.gzf/init_model/hiftnet_1400k_cvt/model.pth.prefix" +init_param = f"{init_param},{init_param_ckpt},{flow_init},{vocoder_init}" model_llm = AutoModel( model=ckpt_dir, - init_param=f"{os.path.join(ckpt_dir, ckpt_id)}", + init_param=init_param, output_dir=output_dir, device=device, fp16=False, bf16=False, llm_dtype="bf16", ) -model = model_llm.model -frontend = model_llm.kwargs["frontend"] -tokenizer = model_llm.kwargs["tokenizer"] +# model = model_llm.model +# frontend = model_llm.kwargs["frontend"] +# tokenizer = model_llm.kwargs["tokenizer"] model_asr = AutoModel( model="/data/zhifu.gzf/init_model/SenseVoice", @@ -112,31 +117,22 @@ def model_inference(input_wav, text_inputs, state, turn_num, history): print(f"contents_i: {contents_i}") - inputs_embeds, contents, batch, source_ids, meta_data = model.inference_prepare( - [contents_i], None, "test_demo", tokenizer, frontend, device=device + res = model_llm.generate( + input=[contents_i], + tearchforing=False, + cache={}, + key="test_demo", ) - model_inputs = {} - model_inputs["inputs_embeds"] = inputs_embeds - - streamer = TextIteratorStreamer(tokenizer) - - generation_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=200) - thread = Thread(target=model.llm.generate, kwargs=generation_kwargs) - thread.start() - res = "" - for new_text in streamer: - print(f"generated new text: {new_text}") - res += new_text.replace("<|im_end|>", "") - print(f"total generated: {res}") - history[-1][1] = res + res_text = res[0]["text"] + history[-1][1] = gr.Audio((16000, res[0]["wav"].flatten()), autoplay=True) out_his = state.get("out", "") - out = f"{out_his}" f"

" f"Q: {asr_out}" f"
" f"A: {res}" + out = f"{out_his}" f"

" f"Q: {asr_out}" f"
" f"A: {res_text}" # out = f"{res}" - contents_i[-1]["content"] = res + contents_i[-1]["content"] = res_text state["contents_i"] = contents_i state["out"] = out # print(f'state_1: {state["contents_i"]}') - return state, history + return state, history, out def clear_state(audio_inputs, text_inputs, state, turn_num, chatbot): @@ -221,20 +217,19 @@ def launch(): # fn_button = gr.Button("Start") clear_button = gr.Button("Clear") - # text_outputs = gr.HTML(label="Results") + text_outputs = gr.HTML(label="Results") # fn_button.click(model_inference, inputs=[audio_inputs, text_inputs, state, turn_num, chatbot], outputs=[state, chatbot]) - # with gr.Accordion("More examples"): - # gr.HTML(centered_table_html) + audio_inputs.stop_recording( model_inference, inputs=[audio_inputs, text_inputs, state, turn_num, chatbot], - outputs=[state, chatbot], + outputs=[state, chatbot, text_outputs], ) audio_inputs.upload( model_inference, inputs=[audio_inputs, text_inputs, state, turn_num, chatbot], - outputs=[state, chatbot], + outputs=[state, chatbot, text_outputs], ) # clear.click(clear_state, inputs=[audio_inputs, text_inputs, state, turn_num, chatbot], outputs=[state, chatbot], queue=False) @@ -248,7 +243,7 @@ def launch(): demo.launch( share=False, server_name="0.0.0.0", - server_port=12340, + server_port=12343, ssl_certfile="./cert.pem", ssl_keyfile="./key.pem", inbrowser=True, diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index 3382302e2..22323e320 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -1594,6 +1594,7 @@ class LLMASR5(nn.Module): return None if name == "MaskedDiffWithXvec": from funasr.models.llm_asr.flow_matching import MaskedDiffWithXvec + return MaskedDiffWithXvec(**conf) return None @@ -1602,6 +1603,7 @@ class LLMASR5(nn.Module): return None if name == "HifiGAN": from funasr.models.llm_asr.hifigan import HifiGan + return HifiGan(**conf) return None @@ -1974,6 +1976,8 @@ class LLMASR5(nn.Module): if len(input_ids) > kwargs.get("max_token_length", 1500): break + if isinstance(user_prompt, (list, tuple)): + user_prompt, audio = user_prompt if i == 0: source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n" else: @@ -1998,8 +2002,9 @@ class LLMASR5(nn.Module): ) if sub_str.startswith("!"): sub_str = sub_str[1:] - if sub_str.startswith("!"): # !!bytes - sub_str = eval(sub_str[1:]) + sub_str = sub_str[1:] + if sub_str.startswith("!"): # !!: audio sample point + sub_str = audio try: time1 = time.perf_counter() data_src = load_audio_text_image_video(sub_str, fs=frontend.fs) @@ -2275,6 +2280,7 @@ class LLMASR5(nn.Module): "text_tn": response_clean, "label": label, "speech_tokens": speech_tokens, + "wav": wav, } if loss is not None: result_i["loss"] = loss @@ -2305,8 +2311,11 @@ class LLMASR5(nn.Module): if wav is not None: path = os.path.join(out_dir, f"{key}.wav") torchaudio.save( - path, wav.cpu(), sample_rate=self.vocoder.sample_rate, - encoding='PCM_S', bits_per_sample=16 + path, + wav.cpu(), + sample_rate=self.vocoder.sample_rate, + encoding="PCM_S", + bits_per_sample=16, ) def synthesize_waveform(self, speech_tokens, spk_emb, device): @@ -2324,14 +2333,13 @@ class LLMASR5(nn.Module): xvec_lens = torch.tensor([xvec.shape[1]], device=device, dtype=torch.int64) token_lens = torch.tensor([tokens.shape[1]], device=device, dtype=torch.int64) feat = self.mel_decoder.inference( - tokens, token_lens, - xvec, xvec_lens, + tokens, + token_lens, + xvec, + xvec_lens, diff_steps=10, temperature=1.0, - prompt=dict( - prompt_text=(None, None), - prompt_audio=(None, None) - ) + prompt=dict(prompt_text=(None, None), prompt_audio=(None, None)), ) return feat @@ -2355,17 +2363,18 @@ class LLMASR5(nn.Module): ) prompt = torch.cat([sos_eos_emb, text, task_id_emb], dim=1) seq_input = torch.zeros( - [1, prompt.shape[1] + max_length, prompt.shape[2]], - dtype=torch.float32, device=device + [1, prompt.shape[1] + max_length, prompt.shape[2]], dtype=torch.float32, device=device ) - seq_input[:, :prompt.shape[1], :] = prompt + seq_input[:, : prompt.shape[1], :] = prompt out_tokens = torch.zeros([1, max_length, 1], dtype=torch.int64, device=device) out_token_len = 0 prompt_len = prompt.shape[1] state, hit_eos = None, False for i in range(max_length): # use state for speedup - pred, (state, _) = self.audio_decoder.score(seq_input[0, :prompt_len+out_token_len], state, prompt[0]) + pred, (state, _) = self.audio_decoder.score( + seq_input[0, : prompt_len + out_token_len], state, prompt[0] + ) # sampling all `nq` token ids pred = pred.reshape(self.predict_nq, -1)