mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
bf42d9e8ac
commit
a4bb21b888
@ -50,7 +50,7 @@ model_asr = AutoModel(
|
||||
output_dir=output_dir,
|
||||
device=device,
|
||||
fp16=False,
|
||||
bf16=False,
|
||||
bf16=True,
|
||||
llm_dtype="bf16",
|
||||
)
|
||||
|
||||
@ -137,7 +137,7 @@ def load_model(init_param, his_state):
|
||||
return his_state, f"Model has been loaded! time: {end-beg:.2f}"
|
||||
|
||||
|
||||
def model_inference(his_state, input_wav, text_inputs, state, turn_num, history, text_usr):
|
||||
def model_inference(his_state, input_wav, text_inputs, state, turn_num, history, text_usr, do_asr):
|
||||
if his_state is None:
|
||||
his_state = model_dict
|
||||
model = his_state["model"]
|
||||
@ -167,7 +167,9 @@ def model_inference(his_state, input_wav, text_inputs, state, turn_num, history,
|
||||
input_wav_t = torch.from_numpy(input_wav).to(torch.float32)
|
||||
input_wav = resampler(input_wav_t[None, :])[0, :].numpy().astype("float32")
|
||||
beg_asr = time.time()
|
||||
asr_out = model_asr.generate(input_wav)[0]["text"]
|
||||
asr_out = "User audio input"
|
||||
if do_asr:
|
||||
asr_out = model_asr.generate(input_wav)[0]["text"]
|
||||
end_asr = time.time()
|
||||
|
||||
print(f"asr_out: {asr_out}, time: {end_asr-beg_asr:.2f}")
|
||||
@ -228,19 +230,20 @@ def model_inference(his_state, input_wav, text_inputs, state, turn_num, history,
|
||||
|
||||
|
||||
def clear_state(his_state):
|
||||
model = his_state["model"]
|
||||
frontend = his_state["frontend"]
|
||||
tokenizer = his_state["tokenizer"]
|
||||
del model
|
||||
del frontend
|
||||
del tokenizer
|
||||
del his_state["model"]
|
||||
del his_state["frontend"]
|
||||
del his_state["tokenizer"]
|
||||
import gc
|
||||
if his_state is not None:
|
||||
model = his_state["model"]
|
||||
frontend = his_state["frontend"]
|
||||
tokenizer = his_state["tokenizer"]
|
||||
del model
|
||||
del frontend
|
||||
del tokenizer
|
||||
del his_state["model"]
|
||||
del his_state["frontend"]
|
||||
del his_state["tokenizer"]
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return None, None, None, None, None
|
||||
|
||||
@ -305,12 +308,6 @@ def launch():
|
||||
chatbot = gr.Chatbot()
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
# text_ckpt = gr.Text(
|
||||
# label="Set the model path",
|
||||
# value="",
|
||||
# )
|
||||
# fn_load_model = gr.Button("Load model")
|
||||
# text_outputs = gr.HTML(label="Results")
|
||||
|
||||
audio_inputs = gr.Audio(label="Upload audio or use the microphone")
|
||||
with gr.Column():
|
||||
@ -318,12 +315,14 @@ def launch():
|
||||
label="System Prompt",
|
||||
value="你是小夏,一位典型的温婉江南姑娘。你出生于杭州,声音清甜并有亲近感,会用简洁语言表达你的想法。你是用户的好朋友。你的回答将通过逼真的文字转语音技术读出。",
|
||||
)
|
||||
text_inputs_usr = gr.Text(
|
||||
label="User Prompt",
|
||||
)
|
||||
with gr.Row():
|
||||
turn_num = gr.Number(label="Max dialog turns", value=5, maximum=5)
|
||||
text_inputs_usr = gr.Text(
|
||||
label="User Prompt",
|
||||
do_asr = gr.Dropdown(
|
||||
choices=[False, True], value=False, label="Wether do asr"
|
||||
)
|
||||
do_asr = gr.Dropdown(choices=[True, False], value=False, label="Wether do asr")
|
||||
with gr.Row():
|
||||
model_ckpt_list = gr.Dropdown(
|
||||
choices=all_file_paths, value=all_file_paths[0], label="Model ckpt path"
|
||||
@ -342,12 +341,30 @@ def launch():
|
||||
# gr.HTML(centered_table_html)
|
||||
audio_inputs.stop_recording(
|
||||
model_inference,
|
||||
inputs=[state_m, audio_inputs, text_inputs, state, turn_num, chatbot, text_inputs_usr],
|
||||
inputs=[
|
||||
state_m,
|
||||
audio_inputs,
|
||||
text_inputs,
|
||||
state,
|
||||
turn_num,
|
||||
chatbot,
|
||||
text_inputs_usr,
|
||||
do_asr,
|
||||
],
|
||||
outputs=[state, chatbot],
|
||||
)
|
||||
audio_inputs.upload(
|
||||
model_inference,
|
||||
inputs=[state_m, audio_inputs, text_inputs, state, turn_num, chatbot, text_inputs_usr],
|
||||
inputs=[
|
||||
state_m,
|
||||
audio_inputs,
|
||||
text_inputs,
|
||||
state,
|
||||
turn_num,
|
||||
chatbot,
|
||||
text_inputs_usr,
|
||||
do_asr,
|
||||
],
|
||||
outputs=[state, chatbot],
|
||||
)
|
||||
|
||||
@ -364,7 +381,7 @@ def launch():
|
||||
demo.launch(
|
||||
share=False,
|
||||
server_name="0.0.0.0",
|
||||
server_port=12345,
|
||||
server_port=12346,
|
||||
ssl_certfile="./cert.pem",
|
||||
ssl_keyfile="./key.pem",
|
||||
inbrowser=True,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user