mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
a1af287928
commit
1e9e3f864e
@ -32,29 +32,34 @@ if len(sys.argv) > 1:
|
||||
if len(sys.argv) > 6:
|
||||
new_sys = True
|
||||
else:
|
||||
ckpt_dir = "/data/zhifu.gzf/init_model/gpt4o-exp7-4"
|
||||
ckpt_id = "model.pt.ep1.140000"
|
||||
ckpt_dir = "/data/zhifu.gzf/init_model/exp8-2"
|
||||
ckpt_id = "model.pt.ep2.90000"
|
||||
jsonl = (
|
||||
"/nfs/beinian.lzr/workspace/GPT-4o/Data/Speech2Text/TestData/s2tchat.v20240619.test.jsonl"
|
||||
)
|
||||
dataset = jsonl.split("/")[-1]
|
||||
output_dir = os.path.join(ckpt_dir, f"inference-{ckpt_id}", dataset)
|
||||
device = "cuda:1"
|
||||
device = "cuda:5"
|
||||
new_sys = True
|
||||
|
||||
init_param = "/data/zhifu.gzf/init_model/gpt4o-exp7-4/model.pt.ep1.390000"
|
||||
init_param_ckpt = f"{os.path.join(ckpt_dir, ckpt_id)}"
|
||||
flow_init = "/data/zhifu.gzf/init_model/cosyvoice_flow_matching_for_streaming_with_prompt_random_cut_sft_zh_0630_25hz_1/60epoch.pth.prefix"
|
||||
vocoder_init = "/data/zhifu.gzf/init_model/hiftnet_1400k_cvt/model.pth.prefix"
|
||||
init_param = f"{init_param},{init_param_ckpt},{flow_init},{vocoder_init}"
|
||||
|
||||
model_llm = AutoModel(
|
||||
model=ckpt_dir,
|
||||
init_param=f"{os.path.join(ckpt_dir, ckpt_id)}",
|
||||
init_param=init_param,
|
||||
output_dir=output_dir,
|
||||
device=device,
|
||||
fp16=False,
|
||||
bf16=False,
|
||||
llm_dtype="bf16",
|
||||
)
|
||||
model = model_llm.model
|
||||
frontend = model_llm.kwargs["frontend"]
|
||||
tokenizer = model_llm.kwargs["tokenizer"]
|
||||
# model = model_llm.model
|
||||
# frontend = model_llm.kwargs["frontend"]
|
||||
# tokenizer = model_llm.kwargs["tokenizer"]
|
||||
|
||||
model_asr = AutoModel(
|
||||
model="/data/zhifu.gzf/init_model/SenseVoice",
|
||||
@ -112,31 +117,22 @@ def model_inference(input_wav, text_inputs, state, turn_num, history):
|
||||
|
||||
print(f"contents_i: {contents_i}")
|
||||
|
||||
inputs_embeds, contents, batch, source_ids, meta_data = model.inference_prepare(
|
||||
[contents_i], None, "test_demo", tokenizer, frontend, device=device
|
||||
res = model_llm.generate(
|
||||
input=[contents_i],
|
||||
tearchforing=False,
|
||||
cache={},
|
||||
key="test_demo",
|
||||
)
|
||||
model_inputs = {}
|
||||
model_inputs["inputs_embeds"] = inputs_embeds
|
||||
|
||||
streamer = TextIteratorStreamer(tokenizer)
|
||||
|
||||
generation_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=200)
|
||||
thread = Thread(target=model.llm.generate, kwargs=generation_kwargs)
|
||||
thread.start()
|
||||
res = ""
|
||||
for new_text in streamer:
|
||||
print(f"generated new text: {new_text}")
|
||||
res += new_text.replace("<|im_end|>", "")
|
||||
print(f"total generated: {res}")
|
||||
history[-1][1] = res
|
||||
res_text = res[0]["text"]
|
||||
history[-1][1] = gr.Audio((16000, res[0]["wav"].flatten()), autoplay=True)
|
||||
out_his = state.get("out", "")
|
||||
out = f"{out_his}" f"<br><br>" f"Q: {asr_out}" f"<br>" f"A: {res}"
|
||||
out = f"{out_his}" f"<br><br>" f"Q: {asr_out}" f"<br>" f"A: {res_text}"
|
||||
# out = f"{res}"
|
||||
contents_i[-1]["content"] = res
|
||||
contents_i[-1]["content"] = res_text
|
||||
state["contents_i"] = contents_i
|
||||
state["out"] = out
|
||||
# print(f'state_1: {state["contents_i"]}')
|
||||
return state, history
|
||||
return state, history, out
|
||||
|
||||
|
||||
def clear_state(audio_inputs, text_inputs, state, turn_num, chatbot):
|
||||
@ -221,20 +217,19 @@ def launch():
|
||||
# fn_button = gr.Button("Start")
|
||||
clear_button = gr.Button("Clear")
|
||||
|
||||
# text_outputs = gr.HTML(label="Results")
|
||||
text_outputs = gr.HTML(label="Results")
|
||||
|
||||
# fn_button.click(model_inference, inputs=[audio_inputs, text_inputs, state, turn_num, chatbot], outputs=[state, chatbot])
|
||||
# with gr.Accordion("More examples"):
|
||||
# gr.HTML(centered_table_html)
|
||||
|
||||
audio_inputs.stop_recording(
|
||||
model_inference,
|
||||
inputs=[audio_inputs, text_inputs, state, turn_num, chatbot],
|
||||
outputs=[state, chatbot],
|
||||
outputs=[state, chatbot, text_outputs],
|
||||
)
|
||||
audio_inputs.upload(
|
||||
model_inference,
|
||||
inputs=[audio_inputs, text_inputs, state, turn_num, chatbot],
|
||||
outputs=[state, chatbot],
|
||||
outputs=[state, chatbot, text_outputs],
|
||||
)
|
||||
|
||||
# clear.click(clear_state, inputs=[audio_inputs, text_inputs, state, turn_num, chatbot], outputs=[state, chatbot], queue=False)
|
||||
@ -248,7 +243,7 @@ def launch():
|
||||
demo.launch(
|
||||
share=False,
|
||||
server_name="0.0.0.0",
|
||||
server_port=12340,
|
||||
server_port=12343,
|
||||
ssl_certfile="./cert.pem",
|
||||
ssl_keyfile="./key.pem",
|
||||
inbrowser=True,
|
||||
|
||||
@ -1594,6 +1594,7 @@ class LLMASR5(nn.Module):
|
||||
return None
|
||||
if name == "MaskedDiffWithXvec":
|
||||
from funasr.models.llm_asr.flow_matching import MaskedDiffWithXvec
|
||||
|
||||
return MaskedDiffWithXvec(**conf)
|
||||
return None
|
||||
|
||||
@ -1602,6 +1603,7 @@ class LLMASR5(nn.Module):
|
||||
return None
|
||||
if name == "HifiGAN":
|
||||
from funasr.models.llm_asr.hifigan import HifiGan
|
||||
|
||||
return HifiGan(**conf)
|
||||
return None
|
||||
|
||||
@ -1974,6 +1976,8 @@ class LLMASR5(nn.Module):
|
||||
if len(input_ids) > kwargs.get("max_token_length", 1500):
|
||||
break
|
||||
|
||||
if isinstance(user_prompt, (list, tuple)):
|
||||
user_prompt, audio = user_prompt
|
||||
if i == 0:
|
||||
source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
|
||||
else:
|
||||
@ -1998,8 +2002,9 @@ class LLMASR5(nn.Module):
|
||||
)
|
||||
if sub_str.startswith("!"):
|
||||
sub_str = sub_str[1:]
|
||||
if sub_str.startswith("!"): # !!bytes
|
||||
sub_str = eval(sub_str[1:])
|
||||
sub_str = sub_str[1:]
|
||||
if sub_str.startswith("!"): # !!: audio sample point
|
||||
sub_str = audio
|
||||
try:
|
||||
time1 = time.perf_counter()
|
||||
data_src = load_audio_text_image_video(sub_str, fs=frontend.fs)
|
||||
@ -2275,6 +2280,7 @@ class LLMASR5(nn.Module):
|
||||
"text_tn": response_clean,
|
||||
"label": label,
|
||||
"speech_tokens": speech_tokens,
|
||||
"wav": wav,
|
||||
}
|
||||
if loss is not None:
|
||||
result_i["loss"] = loss
|
||||
@ -2305,8 +2311,11 @@ class LLMASR5(nn.Module):
|
||||
if wav is not None:
|
||||
path = os.path.join(out_dir, f"{key}.wav")
|
||||
torchaudio.save(
|
||||
path, wav.cpu(), sample_rate=self.vocoder.sample_rate,
|
||||
encoding='PCM_S', bits_per_sample=16
|
||||
path,
|
||||
wav.cpu(),
|
||||
sample_rate=self.vocoder.sample_rate,
|
||||
encoding="PCM_S",
|
||||
bits_per_sample=16,
|
||||
)
|
||||
|
||||
def synthesize_waveform(self, speech_tokens, spk_emb, device):
|
||||
@ -2324,14 +2333,13 @@ class LLMASR5(nn.Module):
|
||||
xvec_lens = torch.tensor([xvec.shape[1]], device=device, dtype=torch.int64)
|
||||
token_lens = torch.tensor([tokens.shape[1]], device=device, dtype=torch.int64)
|
||||
feat = self.mel_decoder.inference(
|
||||
tokens, token_lens,
|
||||
xvec, xvec_lens,
|
||||
tokens,
|
||||
token_lens,
|
||||
xvec,
|
||||
xvec_lens,
|
||||
diff_steps=10,
|
||||
temperature=1.0,
|
||||
prompt=dict(
|
||||
prompt_text=(None, None),
|
||||
prompt_audio=(None, None)
|
||||
)
|
||||
prompt=dict(prompt_text=(None, None), prompt_audio=(None, None)),
|
||||
)
|
||||
return feat
|
||||
|
||||
@ -2355,17 +2363,18 @@ class LLMASR5(nn.Module):
|
||||
)
|
||||
prompt = torch.cat([sos_eos_emb, text, task_id_emb], dim=1)
|
||||
seq_input = torch.zeros(
|
||||
[1, prompt.shape[1] + max_length, prompt.shape[2]],
|
||||
dtype=torch.float32, device=device
|
||||
[1, prompt.shape[1] + max_length, prompt.shape[2]], dtype=torch.float32, device=device
|
||||
)
|
||||
seq_input[:, :prompt.shape[1], :] = prompt
|
||||
seq_input[:, : prompt.shape[1], :] = prompt
|
||||
out_tokens = torch.zeros([1, max_length, 1], dtype=torch.int64, device=device)
|
||||
out_token_len = 0
|
||||
prompt_len = prompt.shape[1]
|
||||
state, hit_eos = None, False
|
||||
for i in range(max_length):
|
||||
# use state for speedup
|
||||
pred, (state, _) = self.audio_decoder.score(seq_input[0, :prompt_len+out_token_len], state, prompt[0])
|
||||
pred, (state, _) = self.audio_decoder.score(
|
||||
seq_input[0, : prompt_len + out_token_len], state, prompt[0]
|
||||
)
|
||||
|
||||
# sampling all `nq` token ids
|
||||
pred = pred.reshape(self.predict_nq, -1)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user