diff --git a/examples/industrial_data_pretraining/llm_asr/app_chatbot_audio_audio.py b/examples/industrial_data_pretraining/llm_asr/app_chatbot_audio_audio.py
index f28de1faf..b8d768e8c 100644
--- a/examples/industrial_data_pretraining/llm_asr/app_chatbot_audio_audio.py
+++ b/examples/industrial_data_pretraining/llm_asr/app_chatbot_audio_audio.py
@@ -32,29 +32,34 @@ if len(sys.argv) > 1:
if len(sys.argv) > 6:
new_sys = True
else:
- ckpt_dir = "/data/zhifu.gzf/init_model/gpt4o-exp7-4"
- ckpt_id = "model.pt.ep1.140000"
+ ckpt_dir = "/data/zhifu.gzf/init_model/exp8-2"
+ ckpt_id = "model.pt.ep2.90000"
jsonl = (
"/nfs/beinian.lzr/workspace/GPT-4o/Data/Speech2Text/TestData/s2tchat.v20240619.test.jsonl"
)
dataset = jsonl.split("/")[-1]
output_dir = os.path.join(ckpt_dir, f"inference-{ckpt_id}", dataset)
- device = "cuda:1"
+ device = "cuda:5"
new_sys = True
+init_param = "/data/zhifu.gzf/init_model/gpt4o-exp7-4/model.pt.ep1.390000"
+init_param_ckpt = f"{os.path.join(ckpt_dir, ckpt_id)}"
+flow_init = "/data/zhifu.gzf/init_model/cosyvoice_flow_matching_for_streaming_with_prompt_random_cut_sft_zh_0630_25hz_1/60epoch.pth.prefix"
+vocoder_init = "/data/zhifu.gzf/init_model/hiftnet_1400k_cvt/model.pth.prefix"
+init_param = f"{init_param},{init_param_ckpt},{flow_init},{vocoder_init}"
model_llm = AutoModel(
model=ckpt_dir,
- init_param=f"{os.path.join(ckpt_dir, ckpt_id)}",
+ init_param=init_param,
output_dir=output_dir,
device=device,
fp16=False,
bf16=False,
llm_dtype="bf16",
)
-model = model_llm.model
-frontend = model_llm.kwargs["frontend"]
-tokenizer = model_llm.kwargs["tokenizer"]
+# model = model_llm.model
+# frontend = model_llm.kwargs["frontend"]
+# tokenizer = model_llm.kwargs["tokenizer"]
model_asr = AutoModel(
model="/data/zhifu.gzf/init_model/SenseVoice",
@@ -112,31 +117,22 @@ def model_inference(input_wav, text_inputs, state, turn_num, history):
print(f"contents_i: {contents_i}")
- inputs_embeds, contents, batch, source_ids, meta_data = model.inference_prepare(
- [contents_i], None, "test_demo", tokenizer, frontend, device=device
+ res = model_llm.generate(
+ input=[contents_i],
+ tearchforing=False,
+ cache={},
+ key="test_demo",
)
- model_inputs = {}
- model_inputs["inputs_embeds"] = inputs_embeds
-
- streamer = TextIteratorStreamer(tokenizer)
-
- generation_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=200)
- thread = Thread(target=model.llm.generate, kwargs=generation_kwargs)
- thread.start()
- res = ""
- for new_text in streamer:
- print(f"generated new text: {new_text}")
- res += new_text.replace("<|im_end|>", "")
- print(f"total generated: {res}")
- history[-1][1] = res
+ res_text = res[0]["text"]
+ history[-1][1] = gr.Audio((16000, res[0]["wav"].flatten()), autoplay=True)
out_his = state.get("out", "")
- out = f"{out_his}" f"
" f"Q: {asr_out}" f"
" f"A: {res}"
+ out = f"{out_his}" f"
" f"Q: {asr_out}" f"
" f"A: {res_text}"
# out = f"{res}"
- contents_i[-1]["content"] = res
+ contents_i[-1]["content"] = res_text
state["contents_i"] = contents_i
state["out"] = out
# print(f'state_1: {state["contents_i"]}')
- return state, history
+ return state, history, out
def clear_state(audio_inputs, text_inputs, state, turn_num, chatbot):
@@ -221,20 +217,19 @@ def launch():
# fn_button = gr.Button("Start")
clear_button = gr.Button("Clear")
- # text_outputs = gr.HTML(label="Results")
+ text_outputs = gr.HTML(label="Results")
# fn_button.click(model_inference, inputs=[audio_inputs, text_inputs, state, turn_num, chatbot], outputs=[state, chatbot])
- # with gr.Accordion("More examples"):
- # gr.HTML(centered_table_html)
+
audio_inputs.stop_recording(
model_inference,
inputs=[audio_inputs, text_inputs, state, turn_num, chatbot],
- outputs=[state, chatbot],
+ outputs=[state, chatbot, text_outputs],
)
audio_inputs.upload(
model_inference,
inputs=[audio_inputs, text_inputs, state, turn_num, chatbot],
- outputs=[state, chatbot],
+ outputs=[state, chatbot, text_outputs],
)
# clear.click(clear_state, inputs=[audio_inputs, text_inputs, state, turn_num, chatbot], outputs=[state, chatbot], queue=False)
@@ -248,7 +243,7 @@ def launch():
demo.launch(
share=False,
server_name="0.0.0.0",
- server_port=12340,
+ server_port=12343,
ssl_certfile="./cert.pem",
ssl_keyfile="./key.pem",
inbrowser=True,
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 3382302e2..22323e320 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -1594,6 +1594,7 @@ class LLMASR5(nn.Module):
return None
if name == "MaskedDiffWithXvec":
from funasr.models.llm_asr.flow_matching import MaskedDiffWithXvec
+
return MaskedDiffWithXvec(**conf)
return None
@@ -1602,6 +1603,7 @@ class LLMASR5(nn.Module):
return None
if name == "HifiGAN":
from funasr.models.llm_asr.hifigan import HifiGan
+
return HifiGan(**conf)
return None
@@ -1974,6 +1976,8 @@ class LLMASR5(nn.Module):
if len(input_ids) > kwargs.get("max_token_length", 1500):
break
+ if isinstance(user_prompt, (list, tuple)):
+ user_prompt, audio = user_prompt
if i == 0:
source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
else:
@@ -1998,8 +2002,9 @@ class LLMASR5(nn.Module):
)
if sub_str.startswith("!"):
sub_str = sub_str[1:]
- if sub_str.startswith("!"): # !!bytes
- sub_str = eval(sub_str[1:])
+ sub_str = sub_str[1:]
+ if sub_str.startswith("!"): # !!: audio sample point
+ sub_str = audio
try:
time1 = time.perf_counter()
data_src = load_audio_text_image_video(sub_str, fs=frontend.fs)
@@ -2275,6 +2280,7 @@ class LLMASR5(nn.Module):
"text_tn": response_clean,
"label": label,
"speech_tokens": speech_tokens,
+ "wav": wav,
}
if loss is not None:
result_i["loss"] = loss
@@ -2305,8 +2311,11 @@ class LLMASR5(nn.Module):
if wav is not None:
path = os.path.join(out_dir, f"{key}.wav")
torchaudio.save(
- path, wav.cpu(), sample_rate=self.vocoder.sample_rate,
- encoding='PCM_S', bits_per_sample=16
+ path,
+ wav.cpu(),
+ sample_rate=self.vocoder.sample_rate,
+ encoding="PCM_S",
+ bits_per_sample=16,
)
def synthesize_waveform(self, speech_tokens, spk_emb, device):
@@ -2324,14 +2333,13 @@ class LLMASR5(nn.Module):
xvec_lens = torch.tensor([xvec.shape[1]], device=device, dtype=torch.int64)
token_lens = torch.tensor([tokens.shape[1]], device=device, dtype=torch.int64)
feat = self.mel_decoder.inference(
- tokens, token_lens,
- xvec, xvec_lens,
+ tokens,
+ token_lens,
+ xvec,
+ xvec_lens,
diff_steps=10,
temperature=1.0,
- prompt=dict(
- prompt_text=(None, None),
- prompt_audio=(None, None)
- )
+ prompt=dict(prompt_text=(None, None), prompt_audio=(None, None)),
)
return feat
@@ -2355,17 +2363,18 @@ class LLMASR5(nn.Module):
)
prompt = torch.cat([sos_eos_emb, text, task_id_emb], dim=1)
seq_input = torch.zeros(
- [1, prompt.shape[1] + max_length, prompt.shape[2]],
- dtype=torch.float32, device=device
+ [1, prompt.shape[1] + max_length, prompt.shape[2]], dtype=torch.float32, device=device
)
- seq_input[:, :prompt.shape[1], :] = prompt
+ seq_input[:, : prompt.shape[1], :] = prompt
out_tokens = torch.zeros([1, max_length, 1], dtype=torch.int64, device=device)
out_token_len = 0
prompt_len = prompt.shape[1]
state, hit_eos = None, False
for i in range(max_length):
# use state for speedup
- pred, (state, _) = self.audio_decoder.score(seq_input[0, :prompt_len+out_token_len], state, prompt[0])
+ pred, (state, _) = self.audio_decoder.score(
+ seq_input[0, : prompt_len + out_token_len], state, prompt[0]
+ )
# sampling all `nq` token ids
pred = pred.reshape(self.predict_nq, -1)