This commit is contained in:
游雁 2024-07-17 10:13:42 +08:00
parent bf42d9e8ac
commit a4bb21b888

View File

@ -50,7 +50,7 @@ model_asr = AutoModel(
output_dir=output_dir,
device=device,
fp16=False,
bf16=False,
bf16=True,
llm_dtype="bf16",
)
@ -137,7 +137,7 @@ def load_model(init_param, his_state):
return his_state, f"Model has been loaded! time: {end-beg:.2f}"
def model_inference(his_state, input_wav, text_inputs, state, turn_num, history, text_usr):
def model_inference(his_state, input_wav, text_inputs, state, turn_num, history, text_usr, do_asr):
if his_state is None:
his_state = model_dict
model = his_state["model"]
@ -167,7 +167,9 @@ def model_inference(his_state, input_wav, text_inputs, state, turn_num, history,
input_wav_t = torch.from_numpy(input_wav).to(torch.float32)
input_wav = resampler(input_wav_t[None, :])[0, :].numpy().astype("float32")
beg_asr = time.time()
asr_out = model_asr.generate(input_wav)[0]["text"]
asr_out = "User audio input"
if do_asr:
asr_out = model_asr.generate(input_wav)[0]["text"]
end_asr = time.time()
print(f"asr_out: {asr_out}, time: {end_asr-beg_asr:.2f}")
@ -228,19 +230,20 @@ def model_inference(his_state, input_wav, text_inputs, state, turn_num, history,
def clear_state(his_state):
model = his_state["model"]
frontend = his_state["frontend"]
tokenizer = his_state["tokenizer"]
del model
del frontend
del tokenizer
del his_state["model"]
del his_state["frontend"]
del his_state["tokenizer"]
import gc
if his_state is not None:
model = his_state["model"]
frontend = his_state["frontend"]
tokenizer = his_state["tokenizer"]
del model
del frontend
del tokenizer
del his_state["model"]
del his_state["frontend"]
del his_state["tokenizer"]
import gc
gc.collect()
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()
return None, None, None, None, None
@ -305,12 +308,6 @@ def launch():
chatbot = gr.Chatbot()
with gr.Column():
with gr.Row():
# text_ckpt = gr.Text(
# label="Set the model path",
# value="",
# )
# fn_load_model = gr.Button("Load model")
# text_outputs = gr.HTML(label="Results")
audio_inputs = gr.Audio(label="Upload audio or use the microphone")
with gr.Column():
@ -318,12 +315,14 @@ def launch():
label="System Prompt",
value="你是小夏,一位典型的温婉江南姑娘。你出生于杭州,声音清甜并有亲近感,会用简洁语言表达你的想法。你是用户的好朋友。你的回答将通过逼真的文字转语音技术读出。",
)
text_inputs_usr = gr.Text(
label="User Prompt",
)
with gr.Row():
turn_num = gr.Number(label="Max dialog turns", value=5, maximum=5)
text_inputs_usr = gr.Text(
label="User Prompt",
do_asr = gr.Dropdown(
choices=[False, True], value=False, label="Wether do asr"
)
do_asr = gr.Dropdown(choices=[True, False], value=False, label="Wether do asr")
with gr.Row():
model_ckpt_list = gr.Dropdown(
choices=all_file_paths, value=all_file_paths[0], label="Model ckpt path"
@ -342,12 +341,30 @@ def launch():
# gr.HTML(centered_table_html)
audio_inputs.stop_recording(
model_inference,
inputs=[state_m, audio_inputs, text_inputs, state, turn_num, chatbot, text_inputs_usr],
inputs=[
state_m,
audio_inputs,
text_inputs,
state,
turn_num,
chatbot,
text_inputs_usr,
do_asr,
],
outputs=[state, chatbot],
)
audio_inputs.upload(
model_inference,
inputs=[state_m, audio_inputs, text_inputs, state, turn_num, chatbot, text_inputs_usr],
inputs=[
state_m,
audio_inputs,
text_inputs,
state,
turn_num,
chatbot,
text_inputs_usr,
do_asr,
],
outputs=[state, chatbot],
)
@ -364,7 +381,7 @@ def launch():
demo.launch(
share=False,
server_name="0.0.0.0",
server_port=12345,
server_port=12346,
ssl_certfile="./cert.pem",
ssl_keyfile="./key.pem",
inbrowser=True,