Merge branch 'dev_gzf_deepspeed' of gitlab.alibaba-inc.com:zhifu.gzf/FunASR into dev_gzf_deepspeed

merge
This commit is contained in:
游雁 2024-09-11 16:16:24 +08:00
commit bdd66d1865
2 changed files with 100 additions and 52 deletions

View File

@ -25,8 +25,9 @@ parser.add_argument("--chunk_size", type=str, default="5, 10, 5", help="chunk")
parser.add_argument("--chunk_interval", type=int, default=10, help="chunk")
parser.add_argument("--audio_in", type=str, default=None, help="audio_in")
parser.add_argument("--audio_fs", type=int, default=16000, help="audio_fs")
parser.add_argument("--asr_prompt", type=str, default="Copy:", help="asr prompt")
parser.add_argument("--s2tt_prompt", type=str, default="Translate the following sentence into English:", help="s2tt prompt")
parser.add_argument("--asr_prompt", type=str, default="", help="asr prompt")
parser.add_argument("--s2tt_prompt", type=str, default="", help="s2tt prompt")
parser.add_argument("--return_sentence", action="store_true", help="return sentence or all_res")
parser.add_argument(
"--send_without_sleep",
@ -189,31 +190,63 @@ async def message(id):
text_print = ""
prev_asr_text = ""
prev_s2tt_text = ""
prev_sentence_asr_text = ""
prev_sentence_s2tt_text = ""
try:
while True:
meg = await websocket.recv()
meg = json.loads(meg)
asr_text = meg["asr_text"]
s2tt_text = meg["s2tt_text"]
if args.return_sentence:
is_sentence_end = meg["is_sentence_end"]
clean_prev_asr_text = prev_asr_text.replace("<em>", "").replace("</em>", "")
clean_prev_s2tt_text = prev_s2tt_text.replace("<em>", "").replace("</em>", "")
clean_asr_text = asr_text.replace("<em>", "").replace("</em>", "")
clean_s2tt_text = s2tt_text.replace("<em>", "").replace("</em>", "")
clean_prev_asr_text = prev_asr_text.replace("<em>", "").replace("</em>", "")
clean_prev_s2tt_text = prev_s2tt_text.replace("<em>", "").replace("</em>", "")
clean_asr_text = asr_text.replace("<em>", "").replace("</em>", "")
clean_s2tt_text = s2tt_text.replace("<em>", "").replace("</em>", "")
if clean_prev_asr_text.startswith(clean_asr_text):
new_asr_unfix_pos = asr_text.find("<em>")
asr_text = clean_prev_asr_text[:new_asr_unfix_pos] + "<em>" + clean_prev_asr_text[new_asr_unfix_pos:] + "</em>"
if clean_prev_asr_text.startswith(clean_asr_text):
new_asr_unfix_pos = asr_text.find("<em>")
asr_text = clean_prev_asr_text[:new_asr_unfix_pos] + "<em>" + clean_prev_asr_text[new_asr_unfix_pos:] + "</em>"
if clean_prev_s2tt_text.startswith(clean_s2tt_text):
new_s2tt_unfix_pos = s2tt_text.find("<em>")
s2tt_text = clean_prev_s2tt_text[:new_s2tt_unfix_pos] + "<em>" + clean_prev_s2tt_text[new_s2tt_unfix_pos:] + "</em>"
if clean_prev_s2tt_text.startswith(clean_s2tt_text):
new_s2tt_unfix_pos = s2tt_text.find("<em>")
s2tt_text = clean_prev_s2tt_text[:new_s2tt_unfix_pos] + "<em>" + clean_prev_s2tt_text[new_s2tt_unfix_pos:] + "</em>"
prev_asr_text = asr_text
prev_s2tt_text = s2tt_text
print_asr_text = Colors.OKGREEN + asr_text[:asr_text.find("<em>")] + Colors.ENDC + Colors.OKCYAN + asr_text[asr_text.find("<em>") + len("<em>"): -len("</em>")] + Colors.ENDC
print_s2tt_text = Colors.OKGREEN + s2tt_text[:s2tt_text.find("<em>")] + Colors.ENDC + Colors.OKCYAN + s2tt_text[s2tt_text.find("<em>") + len("<em>"): -len("</em>")] + Colors.ENDC
text_print = "\n\n" + "ASR: " + print_asr_text + "\n\n" + "S2TT: " + print_s2tt_text
prev_asr_text = asr_text
prev_s2tt_text = s2tt_text
print_asr_text = Colors.OKGREEN + prev_sentence_asr_text + asr_text[:asr_text.find("<em>")] + Colors.ENDC + Colors.OKCYAN + asr_text[asr_text.find("<em>") + len("<em>"): -len("</em>")] + Colors.ENDC
print_s2tt_text = Colors.OKGREEN + prev_sentence_s2tt_text + s2tt_text[:s2tt_text.find("<em>")] + Colors.ENDC + Colors.OKCYAN + s2tt_text[s2tt_text.find("<em>") + len("<em>"): -len("</em>")] + Colors.ENDC
text_print = "\n\n" + "ASR: " + print_asr_text + "\n\n" + "S2TT: " + print_s2tt_text
if is_sentence_end:
prev_asr_text = ""
prev_s2tt_text = ""
clean_asr_text = asr_text.replace("<em>", "").replace("</em>", "")
clean_s2tt_text = s2tt_text.replace("<em>", "").replace("</em>", "")
prev_sentence_asr_text += clean_asr_text + "\n\n"
prev_sentence_s2tt_text += clean_s2tt_text + "\n\n"
else:
clean_prev_asr_text = prev_asr_text.replace("<em>", "").replace("</em>", "")
clean_prev_s2tt_text = prev_s2tt_text.replace("<em>", "").replace("</em>", "")
clean_asr_text = asr_text.replace("<em>", "").replace("</em>", "")
clean_s2tt_text = s2tt_text.replace("<em>", "").replace("</em>", "")
if clean_prev_asr_text.startswith(clean_asr_text):
new_asr_unfix_pos = asr_text.find("<em>")
asr_text = clean_prev_asr_text[:new_asr_unfix_pos] + "<em>" + clean_prev_asr_text[new_asr_unfix_pos:] + "</em>"
if clean_prev_s2tt_text.startswith(clean_s2tt_text):
new_s2tt_unfix_pos = s2tt_text.find("<em>")
s2tt_text = clean_prev_s2tt_text[:new_s2tt_unfix_pos] + "<em>" + clean_prev_s2tt_text[new_s2tt_unfix_pos:] + "</em>"
prev_asr_text = asr_text
prev_s2tt_text = s2tt_text
print_asr_text = Colors.OKGREEN + asr_text[:asr_text.find("<em>")] + Colors.ENDC + Colors.OKCYAN + asr_text[asr_text.find("<em>") + len("<em>"): -len("</em>")] + Colors.ENDC
print_s2tt_text = Colors.OKGREEN + s2tt_text[:s2tt_text.find("<em>")] + Colors.ENDC + Colors.OKCYAN + s2tt_text[s2tt_text.find("<em>") + len("<em>"): -len("</em>")] + Colors.ENDC
text_print = "\n\n" + "ASR: " + print_asr_text + "\n\n" + "S2TT: " + print_s2tt_text
os.system("clear")
print("\rpid" + str(id) + ": " + text_print)

View File

@ -28,6 +28,7 @@ parser.add_argument("--vad_model_revision", type=str, default="master", help="")
parser.add_argument("--ngpu", type=int, default=1, help="0 for cpu, 1 for gpu")
parser.add_argument("--device", type=str, default="cuda", help="cuda, cpu")
parser.add_argument("--ncpu", type=int, default=4, help="cpu cores")
parser.add_argument("--return_sentence", action="store_true", help="return sentence or all_res")
parser.add_argument(
"--certfile",
type=str,
@ -56,8 +57,8 @@ model_vad = AutoModel(
device=args.device,
disable_pbar=True,
disable_log=True,
speech_noise_thres=0.2,
max_single_segment_time=60000,
speech_noise_thres=0.4,
max_single_segment_time=30000,
max_end_silence_time=800,
# chunk_size=60,
)
@ -77,7 +78,7 @@ audio_encoder_dir = snapshot_download("iic/SenseVoice", cache_dir=None, revision
# audio_encoder_dir = "/nfs/yangyexin.yyx/init_model/iic/SenseVoiceModelscope_0712"
device = "cuda:0"
all_file_paths = [
# "/nfs/yangyexin.yyx/init_model/s2tt/qwen2_7b_mmt_v15_20240902",
"FunAudioLLM/qwen2_7b_mmt_v15_20240910_streaming",
"FunAudioLLM/qwen2_7b_mmt_v15_20240902",
"FunAudioLLM/qwen2_7b_mmt_v14_20240830",
"FunAudioLLM/audiolm_v11_20240807",
@ -86,7 +87,7 @@ all_file_paths = [
"FunAudioLLM/Speech2Text_Align_V0628",
]
llm_kwargs = {"num_beams": 1, "do_sample": False}
llm_kwargs = {"num_beams": 1, "do_sample": False, "repetition_penalty": 1.3}
UNFIX_LEN = 5
MIN_LEN_PER_PARAGRAPH = 25
MIN_LEN_SEC_AUDIO_FIX = 1.1
@ -221,21 +222,11 @@ async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state=
remain_s2tt_text = True
asr_iter_cnt = 0
s2tt_iter_cnt = 0
is_asr_repetition = False
is_s2tt_repetition = False
for new_asr_text in asr_streamer:
current_time = datetime.now()
print("DEBUG: " + str(current_time) + " " + f"generated new asr text {new_asr_text}")
if len(new_asr_text) > 0:
onscreen_asr_res += new_asr_text.replace("<|im_end|>", "")
if len(new_asr_text.replace("<|im_end|>", "")) > 0:
asr_iter_cnt += 1
if asr_iter_cnt > MAX_ITER_PER_CHUNK:
is_asr_repetition = True
break
if remain_s2tt_text:
try:
@ -247,9 +238,10 @@ async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state=
+ " "
+ f"generated new s2tt text {new_s2tt_text}"
)
s2tt_iter_cnt += 1
if len(new_s2tt_text) > 0:
onscreen_s2tt_res += new_s2tt_text.replace("<|im_end|>", "")
except StopIteration:
new_s2tt_text = ""
remain_s2tt_text = False
@ -296,11 +288,6 @@ async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state=
)
if len(new_s2tt_text) > 0:
onscreen_s2tt_res += new_s2tt_text.replace("<|im_end|>", "")
if len(new_s2tt_text.replace("<|im_end|>", "")) > 0:
s2tt_iter_cnt += 1
if s2tt_iter_cnt > MAX_ITER_PER_CHUNK:
is_s2tt_repetition = True
break
if len(new_s2tt_text) > 0:
all_asr_res = previous_vad_onscreen_asr_text + onscreen_asr_res
@ -354,30 +341,22 @@ async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state=
asr_text_len = len(tokenizer.encode(onscreen_asr_res))
s2tt_text_len = len(tokenizer.encode(onscreen_s2tt_res))
if asr_text_len > UNFIX_LEN and audio_seconds > MIN_LEN_SEC_AUDIO_FIX and not is_asr_repetition:
if asr_text_len > UNFIX_LEN and audio_seconds > MIN_LEN_SEC_AUDIO_FIX:
pre_previous_asr_text = previous_asr_text
previous_asr_text = tokenizer.decode(
tokenizer.encode(onscreen_asr_res)[:-UNFIX_LEN]
).replace("<EFBFBD>", "")
if len(previous_asr_text) <= len(pre_previous_asr_text):
previous_asr_text = pre_previous_asr_text
elif is_asr_repetition:
pass
else:
previous_asr_text = ""
if (
s2tt_text_len > UNFIX_LEN
and audio_seconds > MIN_LEN_SEC_AUDIO_FIX
and not is_s2tt_repetition
):
if s2tt_text_len > UNFIX_LEN and audio_seconds > MIN_LEN_SEC_AUDIO_FIX:
pre_previous_s2tt_text = previous_s2tt_text
previous_s2tt_text = tokenizer.decode(
tokenizer.encode(onscreen_s2tt_res)[:-UNFIX_LEN]
).replace("<EFBFBD>", "")
if len(previous_s2tt_text) <= len(pre_previous_s2tt_text):
previous_s2tt_text = pre_previous_s2tt_text
elif is_s2tt_repetition:
pass
else:
previous_s2tt_text = ""
@ -457,8 +436,8 @@ async def ws_serve(websocket, path):
if "is_speaking" in messagejson:
websocket.is_speaking = messagejson["is_speaking"]
websocket.streaming_state["is_final"] = not websocket.is_speaking
if not messagejson["is_speaking"]:
await clear_websocket()
# if not messagejson["is_speaking"]:
# await clear_websocket()
if "chunk_interval" in messagejson:
websocket.chunk_interval = messagejson["chunk_interval"]
if "wav_name" in messagejson:
@ -529,9 +508,45 @@ async def ws_serve(websocket, path):
speech_start = False
websocket.streaming_state["previous_asr_text"] = ""
websocket.streaming_state["previous_s2tt_text"] = ""
websocket.streaming_state["previous_vad_onscreen_asr_text"] = ""
websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = ""
if not websocket.is_speaking:
await clear_websocket()
if args.return_sentence:
websocket.streaming_state["previous_vad_onscreen_asr_text"] = ""
websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = ""
else:
now_onscreen_asr_res = websocket.streaming_state.get("onscreen_asr_res", "")
now_onscreen_s2tt_res = websocket.streaming_state.get("onscreen_s2tt_res", "")
if (
len(tokenizer.encode(now_onscreen_asr_res.split("\n\n")[-1]))
< MIN_LEN_PER_PARAGRAPH
or len(tokenizer.encode(now_onscreen_s2tt_res.split("\n\n")[-1]))
< MIN_LEN_PER_PARAGRAPH
):
if (
now_onscreen_asr_res.endswith(".")
or now_onscreen_asr_res.endswith("?")
or now_onscreen_asr_res.endswith("!")
):
now_onscreen_asr_res += " "
if (
now_onscreen_s2tt_res.endswith(".")
or now_onscreen_s2tt_res.endswith("?")
or now_onscreen_s2tt_res.endswith("!")
):
now_onscreen_s2tt_res += " "
websocket.streaming_state["previous_vad_onscreen_asr_text"] = (
now_onscreen_asr_res
)
websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = (
now_onscreen_s2tt_res
)
else:
websocket.streaming_state["previous_vad_onscreen_asr_text"] = (
now_onscreen_asr_res + "\n\n"
)
websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = (
now_onscreen_s2tt_res + "\n\n"
)
if not websocket.is_speaking:
websocket.vad_pre_idx = 0
frames = []