mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
bf42d9e8ac
commit
a4bb21b888
@ -50,7 +50,7 @@ model_asr = AutoModel(
|
|||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
device=device,
|
device=device,
|
||||||
fp16=False,
|
fp16=False,
|
||||||
bf16=False,
|
bf16=True,
|
||||||
llm_dtype="bf16",
|
llm_dtype="bf16",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -137,7 +137,7 @@ def load_model(init_param, his_state):
|
|||||||
return his_state, f"Model has been loaded! time: {end-beg:.2f}"
|
return his_state, f"Model has been loaded! time: {end-beg:.2f}"
|
||||||
|
|
||||||
|
|
||||||
def model_inference(his_state, input_wav, text_inputs, state, turn_num, history, text_usr):
|
def model_inference(his_state, input_wav, text_inputs, state, turn_num, history, text_usr, do_asr):
|
||||||
if his_state is None:
|
if his_state is None:
|
||||||
his_state = model_dict
|
his_state = model_dict
|
||||||
model = his_state["model"]
|
model = his_state["model"]
|
||||||
@ -167,7 +167,9 @@ def model_inference(his_state, input_wav, text_inputs, state, turn_num, history,
|
|||||||
input_wav_t = torch.from_numpy(input_wav).to(torch.float32)
|
input_wav_t = torch.from_numpy(input_wav).to(torch.float32)
|
||||||
input_wav = resampler(input_wav_t[None, :])[0, :].numpy().astype("float32")
|
input_wav = resampler(input_wav_t[None, :])[0, :].numpy().astype("float32")
|
||||||
beg_asr = time.time()
|
beg_asr = time.time()
|
||||||
asr_out = model_asr.generate(input_wav)[0]["text"]
|
asr_out = "User audio input"
|
||||||
|
if do_asr:
|
||||||
|
asr_out = model_asr.generate(input_wav)[0]["text"]
|
||||||
end_asr = time.time()
|
end_asr = time.time()
|
||||||
|
|
||||||
print(f"asr_out: {asr_out}, time: {end_asr-beg_asr:.2f}")
|
print(f"asr_out: {asr_out}, time: {end_asr-beg_asr:.2f}")
|
||||||
@ -228,19 +230,20 @@ def model_inference(his_state, input_wav, text_inputs, state, turn_num, history,
|
|||||||
|
|
||||||
|
|
||||||
def clear_state(his_state):
|
def clear_state(his_state):
|
||||||
model = his_state["model"]
|
if his_state is not None:
|
||||||
frontend = his_state["frontend"]
|
model = his_state["model"]
|
||||||
tokenizer = his_state["tokenizer"]
|
frontend = his_state["frontend"]
|
||||||
del model
|
tokenizer = his_state["tokenizer"]
|
||||||
del frontend
|
del model
|
||||||
del tokenizer
|
del frontend
|
||||||
del his_state["model"]
|
del tokenizer
|
||||||
del his_state["frontend"]
|
del his_state["model"]
|
||||||
del his_state["tokenizer"]
|
del his_state["frontend"]
|
||||||
import gc
|
del his_state["tokenizer"]
|
||||||
|
import gc
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
return None, None, None, None, None
|
return None, None, None, None, None
|
||||||
|
|
||||||
@ -305,12 +308,6 @@ def launch():
|
|||||||
chatbot = gr.Chatbot()
|
chatbot = gr.Chatbot()
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
# text_ckpt = gr.Text(
|
|
||||||
# label="Set the model path",
|
|
||||||
# value="",
|
|
||||||
# )
|
|
||||||
# fn_load_model = gr.Button("Load model")
|
|
||||||
# text_outputs = gr.HTML(label="Results")
|
|
||||||
|
|
||||||
audio_inputs = gr.Audio(label="Upload audio or use the microphone")
|
audio_inputs = gr.Audio(label="Upload audio or use the microphone")
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
@ -318,12 +315,14 @@ def launch():
|
|||||||
label="System Prompt",
|
label="System Prompt",
|
||||||
value="你是小夏,一位典型的温婉江南姑娘。你出生于杭州,声音清甜并有亲近感,会用简洁语言表达你的想法。你是用户的好朋友。你的回答将通过逼真的文字转语音技术读出。",
|
value="你是小夏,一位典型的温婉江南姑娘。你出生于杭州,声音清甜并有亲近感,会用简洁语言表达你的想法。你是用户的好朋友。你的回答将通过逼真的文字转语音技术读出。",
|
||||||
)
|
)
|
||||||
|
text_inputs_usr = gr.Text(
|
||||||
|
label="User Prompt",
|
||||||
|
)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
turn_num = gr.Number(label="Max dialog turns", value=5, maximum=5)
|
turn_num = gr.Number(label="Max dialog turns", value=5, maximum=5)
|
||||||
text_inputs_usr = gr.Text(
|
do_asr = gr.Dropdown(
|
||||||
label="User Prompt",
|
choices=[False, True], value=False, label="Wether do asr"
|
||||||
)
|
)
|
||||||
do_asr = gr.Dropdown(choices=[True, False], value=False, label="Wether do asr")
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
model_ckpt_list = gr.Dropdown(
|
model_ckpt_list = gr.Dropdown(
|
||||||
choices=all_file_paths, value=all_file_paths[0], label="Model ckpt path"
|
choices=all_file_paths, value=all_file_paths[0], label="Model ckpt path"
|
||||||
@ -342,12 +341,30 @@ def launch():
|
|||||||
# gr.HTML(centered_table_html)
|
# gr.HTML(centered_table_html)
|
||||||
audio_inputs.stop_recording(
|
audio_inputs.stop_recording(
|
||||||
model_inference,
|
model_inference,
|
||||||
inputs=[state_m, audio_inputs, text_inputs, state, turn_num, chatbot, text_inputs_usr],
|
inputs=[
|
||||||
|
state_m,
|
||||||
|
audio_inputs,
|
||||||
|
text_inputs,
|
||||||
|
state,
|
||||||
|
turn_num,
|
||||||
|
chatbot,
|
||||||
|
text_inputs_usr,
|
||||||
|
do_asr,
|
||||||
|
],
|
||||||
outputs=[state, chatbot],
|
outputs=[state, chatbot],
|
||||||
)
|
)
|
||||||
audio_inputs.upload(
|
audio_inputs.upload(
|
||||||
model_inference,
|
model_inference,
|
||||||
inputs=[state_m, audio_inputs, text_inputs, state, turn_num, chatbot, text_inputs_usr],
|
inputs=[
|
||||||
|
state_m,
|
||||||
|
audio_inputs,
|
||||||
|
text_inputs,
|
||||||
|
state,
|
||||||
|
turn_num,
|
||||||
|
chatbot,
|
||||||
|
text_inputs_usr,
|
||||||
|
do_asr,
|
||||||
|
],
|
||||||
outputs=[state, chatbot],
|
outputs=[state, chatbot],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -364,7 +381,7 @@ def launch():
|
|||||||
demo.launch(
|
demo.launch(
|
||||||
share=False,
|
share=False,
|
||||||
server_name="0.0.0.0",
|
server_name="0.0.0.0",
|
||||||
server_port=12345,
|
server_port=12346,
|
||||||
ssl_certfile="./cert.pem",
|
ssl_certfile="./cert.pem",
|
||||||
ssl_keyfile="./key.pem",
|
ssl_keyfile="./key.pem",
|
||||||
inbrowser=True,
|
inbrowser=True,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user