This commit is contained in:
游雁 2024-07-11 10:37:21 +08:00
parent a1af287928
commit 1e9e3f864e
2 changed files with 50 additions and 46 deletions

View File

@@ -32,29 +32,34 @@ if len(sys.argv) > 1:
if len(sys.argv) > 6:
new_sys = True
else:
ckpt_dir = "/data/zhifu.gzf/init_model/gpt4o-exp7-4"
ckpt_id = "model.pt.ep1.140000"
ckpt_dir = "/data/zhifu.gzf/init_model/exp8-2"
ckpt_id = "model.pt.ep2.90000"
jsonl = (
"/nfs/beinian.lzr/workspace/GPT-4o/Data/Speech2Text/TestData/s2tchat.v20240619.test.jsonl"
)
dataset = jsonl.split("/")[-1]
output_dir = os.path.join(ckpt_dir, f"inference-{ckpt_id}", dataset)
device = "cuda:1"
device = "cuda:5"
new_sys = True
init_param = "/data/zhifu.gzf/init_model/gpt4o-exp7-4/model.pt.ep1.390000"
init_param_ckpt = f"{os.path.join(ckpt_dir, ckpt_id)}"
flow_init = "/data/zhifu.gzf/init_model/cosyvoice_flow_matching_for_streaming_with_prompt_random_cut_sft_zh_0630_25hz_1/60epoch.pth.prefix"
vocoder_init = "/data/zhifu.gzf/init_model/hiftnet_1400k_cvt/model.pth.prefix"
init_param = f"{init_param},{init_param_ckpt},{flow_init},{vocoder_init}"
model_llm = AutoModel(
model=ckpt_dir,
init_param=f"{os.path.join(ckpt_dir, ckpt_id)}",
init_param=init_param,
output_dir=output_dir,
device=device,
fp16=False,
bf16=False,
llm_dtype="bf16",
)
model = model_llm.model
frontend = model_llm.kwargs["frontend"]
tokenizer = model_llm.kwargs["tokenizer"]
# model = model_llm.model
# frontend = model_llm.kwargs["frontend"]
# tokenizer = model_llm.kwargs["tokenizer"]
model_asr = AutoModel(
model="/data/zhifu.gzf/init_model/SenseVoice",
@@ -112,31 +117,22 @@ def model_inference(input_wav, text_inputs, state, turn_num, history):
print(f"contents_i: {contents_i}")
inputs_embeds, contents, batch, source_ids, meta_data = model.inference_prepare(
[contents_i], None, "test_demo", tokenizer, frontend, device=device
res = model_llm.generate(
input=[contents_i],
tearchforing=False,
cache={},
key="test_demo",
)
model_inputs = {}
model_inputs["inputs_embeds"] = inputs_embeds
streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=200)
thread = Thread(target=model.llm.generate, kwargs=generation_kwargs)
thread.start()
res = ""
for new_text in streamer:
print(f"generated new text {new_text}")
res += new_text.replace("<|im_end|>", "")
print(f"total generated: {res}")
history[-1][1] = res
res_text = res[0]["text"]
history[-1][1] = gr.Audio((16000, res[0]["wav"].flatten()), autoplay=True)
out_his = state.get("out", "")
out = f"{out_his}" f"<br><br>" f"Q: {asr_out}" f"<br>" f"A: {res}"
out = f"{out_his}" f"<br><br>" f"Q: {asr_out}" f"<br>" f"A: {res_text}"
# out = f"{res}"
contents_i[-1]["content"] = res
contents_i[-1]["content"] = res_text
state["contents_i"] = contents_i
state["out"] = out
# print(f'state_1: {state["contents_i"]}')
return state, history
return state, history, out
def clear_state(audio_inputs, text_inputs, state, turn_num, chatbot):
@@ -221,20 +217,19 @@ def launch():
# fn_button = gr.Button("Start")
clear_button = gr.Button("Clear")
# text_outputs = gr.HTML(label="Results")
text_outputs = gr.HTML(label="Results")
# fn_button.click(model_inference, inputs=[audio_inputs, text_inputs, state, turn_num, chatbot], outputs=[state, chatbot])
# with gr.Accordion("More examples"):
# gr.HTML(centered_table_html)
audio_inputs.stop_recording(
model_inference,
inputs=[audio_inputs, text_inputs, state, turn_num, chatbot],
outputs=[state, chatbot],
outputs=[state, chatbot, text_outputs],
)
audio_inputs.upload(
model_inference,
inputs=[audio_inputs, text_inputs, state, turn_num, chatbot],
outputs=[state, chatbot],
outputs=[state, chatbot, text_outputs],
)
# clear.click(clear_state, inputs=[audio_inputs, text_inputs, state, turn_num, chatbot], outputs=[state, chatbot], queue=False)
@@ -248,7 +243,7 @@ def launch():
demo.launch(
share=False,
server_name="0.0.0.0",
server_port=12340,
server_port=12343,
ssl_certfile="./cert.pem",
ssl_keyfile="./key.pem",
inbrowser=True,

View File

@@ -1594,6 +1594,7 @@ class LLMASR5(nn.Module):
return None
if name == "MaskedDiffWithXvec":
from funasr.models.llm_asr.flow_matching import MaskedDiffWithXvec
return MaskedDiffWithXvec(**conf)
return None
@@ -1602,6 +1603,7 @@ class LLMASR5(nn.Module):
return None
if name == "HifiGAN":
from funasr.models.llm_asr.hifigan import HifiGan
return HifiGan(**conf)
return None
@@ -1974,6 +1976,8 @@ class LLMASR5(nn.Module):
if len(input_ids) > kwargs.get("max_token_length", 1500):
break
if isinstance(user_prompt, (list, tuple)):
user_prompt, audio = user_prompt
if i == 0:
source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
else:
@@ -1998,8 +2002,9 @@ class LLMASR5(nn.Module):
)
if sub_str.startswith("!"):
sub_str = sub_str[1:]
if sub_str.startswith("!"): # !!bytes
sub_str = eval(sub_str[1:])
sub_str = sub_str[1:]
if sub_str.startswith("!"): # !!: audio sample point
sub_str = audio
try:
time1 = time.perf_counter()
data_src = load_audio_text_image_video(sub_str, fs=frontend.fs)
@@ -2275,6 +2280,7 @@ class LLMASR5(nn.Module):
"text_tn": response_clean,
"label": label,
"speech_tokens": speech_tokens,
"wav": wav,
}
if loss is not None:
result_i["loss"] = loss
@@ -2305,8 +2311,11 @@ class LLMASR5(nn.Module):
if wav is not None:
path = os.path.join(out_dir, f"{key}.wav")
torchaudio.save(
path, wav.cpu(), sample_rate=self.vocoder.sample_rate,
encoding='PCM_S', bits_per_sample=16
path,
wav.cpu(),
sample_rate=self.vocoder.sample_rate,
encoding="PCM_S",
bits_per_sample=16,
)
def synthesize_waveform(self, speech_tokens, spk_emb, device):
@@ -2324,14 +2333,13 @@ class LLMASR5(nn.Module):
xvec_lens = torch.tensor([xvec.shape[1]], device=device, dtype=torch.int64)
token_lens = torch.tensor([tokens.shape[1]], device=device, dtype=torch.int64)
feat = self.mel_decoder.inference(
tokens, token_lens,
xvec, xvec_lens,
tokens,
token_lens,
xvec,
xvec_lens,
diff_steps=10,
temperature=1.0,
prompt=dict(
prompt_text=(None, None),
prompt_audio=(None, None)
)
prompt=dict(prompt_text=(None, None), prompt_audio=(None, None)),
)
return feat
@@ -2355,17 +2363,18 @@ class LLMASR5(nn.Module):
)
prompt = torch.cat([sos_eos_emb, text, task_id_emb], dim=1)
seq_input = torch.zeros(
[1, prompt.shape[1] + max_length, prompt.shape[2]],
dtype=torch.float32, device=device
[1, prompt.shape[1] + max_length, prompt.shape[2]], dtype=torch.float32, device=device
)
seq_input[:, :prompt.shape[1], :] = prompt
seq_input[:, : prompt.shape[1], :] = prompt
out_tokens = torch.zeros([1, max_length, 1], dtype=torch.int64, device=device)
out_token_len = 0
prompt_len = prompt.shape[1]
state, hit_eos = None, False
for i in range(max_length):
# use state for speedup
pred, (state, _) = self.audio_decoder.score(seq_input[0, :prompt_len+out_token_len], state, prompt[0])
pred, (state, _) = self.audio_decoder.score(
seq_input[0, : prompt_len + out_token_len], state, prompt[0]
)
# sampling all `nq` token ids
pred = pred.reshape(self.predict_nq, -1)