diff --git a/runtime/python/websocket/funasr_wss_client_streaming_llm.py b/runtime/python/websocket/funasr_wss_client_streaming_llm.py index eb4560fce..dbd3ed741 100644 --- a/runtime/python/websocket/funasr_wss_client_streaming_llm.py +++ b/runtime/python/websocket/funasr_wss_client_streaming_llm.py @@ -25,8 +25,9 @@ parser.add_argument("--chunk_size", type=str, default="5, 10, 5", help="chunk") parser.add_argument("--chunk_interval", type=int, default=10, help="chunk") parser.add_argument("--audio_in", type=str, default=None, help="audio_in") parser.add_argument("--audio_fs", type=int, default=16000, help="audio_fs") -parser.add_argument("--asr_prompt", type=str, default="Copy:", help="asr prompt") -parser.add_argument("--s2tt_prompt", type=str, default="Translate the following sentence into English:", help="s2tt prompt") +parser.add_argument("--asr_prompt", type=str, default="", help="asr prompt") +parser.add_argument("--s2tt_prompt", type=str, default="", help="s2tt prompt") +parser.add_argument("--return_sentence", action="store_true", help="return sentence or all_res") parser.add_argument( "--send_without_sleep", @@ -181,7 +182,20 @@ async def record_from_scp(chunk_begin, chunk_size): await asyncio.sleep(2) - await websocket.close() + message = json.dumps( + { + "mode": args.mode, + "chunk_size": args.chunk_size, + "chunk_interval": args.chunk_interval, + "wav_name": "microphone", + "is_speaking": False, + "asr_prompt": args.asr_prompt, + "s2tt_prompt": args.s2tt_prompt, + } + ) + + # voices.put(message) + await websocket.send(message) async def message(id): @@ -189,31 +203,62 @@ async def message(id): text_print = "" prev_asr_text = "" prev_s2tt_text = "" + prev_sentence_asr_text = "" + prev_sentence_s2tt_text = "" try: while True: meg = await websocket.recv() meg = json.loads(meg) asr_text = meg["asr_text"] s2tt_text = meg["s2tt_text"] + if args.return_sentence: + is_sentence_end = meg["is_sentence_end"] - clean_prev_asr_text = prev_asr_text.replace("", "").replace("", "") - clean_prev_s2tt_text = prev_s2tt_text.replace("", "").replace("", "") - clean_asr_text = asr_text.replace("", "").replace("", "") - clean_s2tt_text = s2tt_text.replace("", "").replace("", "") + clean_prev_asr_text = prev_asr_text.replace("", "").replace("", "") + clean_prev_s2tt_text = prev_s2tt_text.replace("", "").replace("", "") + clean_asr_text = asr_text.replace("", "").replace("", "") + clean_s2tt_text = s2tt_text.replace("", "").replace("", "") - if clean_prev_asr_text.startswith(clean_asr_text): - new_asr_unfix_pos = asr_text.find("") - asr_text = clean_prev_asr_text[:new_asr_unfix_pos] + "" + clean_prev_asr_text[new_asr_unfix_pos:] + "" + if clean_prev_asr_text.startswith(clean_asr_text): + new_asr_unfix_pos = asr_text.find("") + asr_text = clean_prev_asr_text[:new_asr_unfix_pos] + "" + clean_prev_asr_text[new_asr_unfix_pos:] + "" - if clean_prev_s2tt_text.startswith(clean_s2tt_text): - new_s2tt_unfix_pos = s2tt_text.find("") - s2tt_text = clean_prev_s2tt_text[:new_s2tt_unfix_pos] + "" + clean_prev_s2tt_text[new_s2tt_unfix_pos:] + "" + if clean_prev_s2tt_text.startswith(clean_s2tt_text): + new_s2tt_unfix_pos = s2tt_text.find("") + s2tt_text = clean_prev_s2tt_text[:new_s2tt_unfix_pos] + "" + clean_prev_s2tt_text[new_s2tt_unfix_pos:] + "" - prev_asr_text = asr_text - prev_s2tt_text = s2tt_text - print_asr_text = Colors.OKGREEN + asr_text[:asr_text.find("")] + Colors.ENDC + Colors.OKCYAN + asr_text[asr_text.find("") + len(""): -len("")] + Colors.ENDC - print_s2tt_text = Colors.OKGREEN + s2tt_text[:s2tt_text.find("")] + Colors.ENDC + Colors.OKCYAN + s2tt_text[s2tt_text.find("") + len(""): -len("")] + Colors.ENDC - text_print = "\n\n" + "ASR: " + print_asr_text + "\n\n" + "S2TT: " + print_s2tt_text + prev_asr_text = asr_text + prev_s2tt_text = s2tt_text + print_asr_text = Colors.OKGREEN + prev_sentence_asr_text + asr_text[:asr_text.find("")] + Colors.ENDC + Colors.OKCYAN + asr_text[asr_text.find("") + len(""): -len("")] + Colors.ENDC + print_s2tt_text = Colors.OKGREEN + prev_sentence_s2tt_text + s2tt_text[:s2tt_text.find("")] + Colors.ENDC + Colors.OKCYAN + s2tt_text[s2tt_text.find("") + len(""): -len("")] + Colors.ENDC + + if is_sentence_end: + prev_asr_text = "" + prev_s2tt_text = "" + clean_asr_text = asr_text.replace("", "").replace("", "") + clean_s2tt_text = s2tt_text.replace("", "").replace("", "") + prev_sentence_asr_text = clean_asr_text + "\n\n" + prev_sentence_s2tt_text = clean_s2tt_text + "\n\n" + + else: + clean_prev_asr_text = prev_asr_text.replace("", "").replace("", "") + clean_prev_s2tt_text = prev_s2tt_text.replace("", "").replace("", "") + clean_asr_text = asr_text.replace("", "").replace("", "") + clean_s2tt_text = s2tt_text.replace("", "").replace("", "") + + if clean_prev_asr_text.startswith(clean_asr_text): + new_asr_unfix_pos = asr_text.find("") + asr_text = clean_prev_asr_text[:new_asr_unfix_pos] + "" + clean_prev_asr_text[new_asr_unfix_pos:] + "" + + if clean_prev_s2tt_text.startswith(clean_s2tt_text): + new_s2tt_unfix_pos = s2tt_text.find("") + s2tt_text = clean_prev_s2tt_text[:new_s2tt_unfix_pos] + "" + clean_prev_s2tt_text[new_s2tt_unfix_pos:] + "" + + prev_asr_text = asr_text + prev_s2tt_text = s2tt_text + print_asr_text = Colors.OKGREEN + asr_text[:asr_text.find("")] + Colors.ENDC + Colors.OKCYAN + asr_text[asr_text.find("") + len(""): -len("")] + Colors.ENDC + print_s2tt_text = Colors.OKGREEN + s2tt_text[:s2tt_text.find("")] + Colors.ENDC + Colors.OKCYAN + s2tt_text[s2tt_text.find("") + len(""): -len("")] + Colors.ENDC + text_print = "\n\n" + "ASR: " + print_asr_text + "\n\n" + "S2TT: " + print_s2tt_text os.system("clear") print("\rpid" + str(id) + ": " + text_print) diff --git a/runtime/python/websocket/funasr_wss_server_streaming_llm.py b/runtime/python/websocket/funasr_wss_server_streaming_llm.py index c3260e007..fb1bada7a 100644 --- a/runtime/python/websocket/funasr_wss_server_streaming_llm.py +++ b/runtime/python/websocket/funasr_wss_server_streaming_llm.py @@ -28,6 +28,7 @@ parser.add_argument("--vad_model_revision", type=str, default="master", help="") parser.add_argument("--ngpu", type=int, default=1, help="0 for cpu, 1 for gpu") parser.add_argument("--device", type=str, default="cuda", help="cuda, cpu") parser.add_argument("--ncpu", type=int, default=4, help="cpu cores") +parser.add_argument("--return_sentence", action="store_true", help="return sentence or all_res") parser.add_argument( "--certfile", type=str, @@ -56,8 +57,8 @@ model_vad = AutoModel( device=args.device, disable_pbar=True, disable_log=True, - speech_noise_thres=0.2, - max_single_segment_time=60000, + speech_noise_thres=0.4, + max_single_segment_time=30000, max_end_silence_time=800, # chunk_size=60, ) @@ -77,7 +78,7 @@ audio_encoder_dir = snapshot_download("iic/SenseVoice", cache_dir=None, revision # audio_encoder_dir = "/nfs/yangyexin.yyx/init_model/iic/SenseVoiceModelscope_0712" device = "cuda:0" all_file_paths = [ - # "/nfs/yangyexin.yyx/init_model/s2tt/qwen2_7b_mmt_v15_20240902", + "FunAudioLLM/qwen2_7b_mmt_v15_20240910_streaming", "FunAudioLLM/qwen2_7b_mmt_v15_20240902", "FunAudioLLM/qwen2_7b_mmt_v14_20240830", "FunAudioLLM/audiolm_v11_20240807", @@ -86,7 +87,7 @@ all_file_paths = [ "FunAudioLLM/Speech2Text_Align_V0628", ] -llm_kwargs = {"num_beams": 1, "do_sample": False} +llm_kwargs = {"num_beams": 1, "do_sample": False, "repetition_penalty": 1.3} UNFIX_LEN = 5 MIN_LEN_PER_PARAGRAPH = 25 MIN_LEN_SEC_AUDIO_FIX = 1.1 @@ -221,21 +222,11 @@ async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state= remain_s2tt_text = True - asr_iter_cnt = 0 - s2tt_iter_cnt = 0 - is_asr_repetition = False - is_s2tt_repetition = False - for new_asr_text in asr_streamer: current_time = datetime.now() print("DEBUG: " + str(current_time) + " " + f"generated new asr text: {new_asr_text}") if len(new_asr_text) > 0: onscreen_asr_res += new_asr_text.replace("<|im_end|>", "") - if len(new_asr_text.replace("<|im_end|>", "")) > 0: - asr_iter_cnt += 1 - if asr_iter_cnt > MAX_ITER_PER_CHUNK: - is_asr_repetition = True - break if remain_s2tt_text: try: @@ -247,9 +238,10 @@ async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state= + " " + f"generated new s2tt text: {new_s2tt_text}" ) - s2tt_iter_cnt += 1 + if len(new_s2tt_text) > 0: onscreen_s2tt_res += new_s2tt_text.replace("<|im_end|>", "") + except StopIteration: new_s2tt_text = "" remain_s2tt_text = False @@ -296,11 +288,6 @@ async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state= ) if len(new_s2tt_text) > 0: onscreen_s2tt_res += new_s2tt_text.replace("<|im_end|>", "") - if len(new_s2tt_text.replace("<|im_end|>", "")) > 0: - s2tt_iter_cnt += 1 - if s2tt_iter_cnt > MAX_ITER_PER_CHUNK: - is_s2tt_repetition = True - break if len(new_s2tt_text) > 0: all_asr_res = previous_vad_onscreen_asr_text + onscreen_asr_res @@ -354,30 +341,22 @@ async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state= asr_text_len = len(tokenizer.encode(onscreen_asr_res)) s2tt_text_len = len(tokenizer.encode(onscreen_s2tt_res)) - if asr_text_len > UNFIX_LEN and audio_seconds > MIN_LEN_SEC_AUDIO_FIX and not is_asr_repetition: + if asr_text_len > UNFIX_LEN and audio_seconds > MIN_LEN_SEC_AUDIO_FIX: pre_previous_asr_text = previous_asr_text previous_asr_text = tokenizer.decode( tokenizer.encode(onscreen_asr_res)[:-UNFIX_LEN] ).replace("�", "") if len(previous_asr_text) <= len(pre_previous_asr_text): previous_asr_text = pre_previous_asr_text - elif is_asr_repetition: - pass else: previous_asr_text = "" - if ( - s2tt_text_len > UNFIX_LEN - and audio_seconds > MIN_LEN_SEC_AUDIO_FIX - and not is_s2tt_repetition - ): + if s2tt_text_len > UNFIX_LEN and audio_seconds > MIN_LEN_SEC_AUDIO_FIX: pre_previous_s2tt_text = previous_s2tt_text previous_s2tt_text = tokenizer.decode( tokenizer.encode(onscreen_s2tt_res)[:-UNFIX_LEN] ).replace("�", "") if len(previous_s2tt_text) <= len(pre_previous_s2tt_text): previous_s2tt_text = pre_previous_s2tt_text - elif is_s2tt_repetition: - pass else: previous_s2tt_text = "" @@ -457,8 +436,8 @@ async def ws_serve(websocket, path): if "is_speaking" in messagejson: websocket.is_speaking = messagejson["is_speaking"] websocket.streaming_state["is_final"] = not websocket.is_speaking - if not messagejson["is_speaking"]: - await clear_websocket() + # if not messagejson["is_speaking"]: + # await clear_websocket() if "chunk_interval" in messagejson: websocket.chunk_interval = messagejson["chunk_interval"] if "wav_name" in messagejson: @@ -529,9 +508,45 @@ async def ws_serve(websocket, path): speech_start = False websocket.streaming_state["previous_asr_text"] = "" websocket.streaming_state["previous_s2tt_text"] = "" - websocket.streaming_state["previous_vad_onscreen_asr_text"] = "" - websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = "" - + if not websocket.is_speaking: + await clear_websocket() + if args.return_sentence: + websocket.streaming_state["previous_vad_onscreen_asr_text"] = "" + websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = "" + else: + now_onscreen_asr_res = websocket.streaming_state.get("onscreen_asr_res", "") + now_onscreen_s2tt_res = websocket.streaming_state.get("onscreen_s2tt_res", "") + if ( + len(tokenizer.encode(now_onscreen_asr_res.split("\n\n")[-1])) + < MIN_LEN_PER_PARAGRAPH + or len(tokenizer.encode(now_onscreen_s2tt_res.split("\n\n")[-1])) + < MIN_LEN_PER_PARAGRAPH + ): + if ( + now_onscreen_asr_res.endswith(".") + or now_onscreen_asr_res.endswith("?") + or now_onscreen_asr_res.endswith("!") + ): + now_onscreen_asr_res += " " + if ( + now_onscreen_s2tt_res.endswith(".") + or now_onscreen_s2tt_res.endswith("?") + or now_onscreen_s2tt_res.endswith("!") + ): + now_onscreen_s2tt_res += " " + websocket.streaming_state["previous_vad_onscreen_asr_text"] = ( + now_onscreen_asr_res + ) + websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = ( + now_onscreen_s2tt_res + ) + else: + websocket.streaming_state["previous_vad_onscreen_asr_text"] = ( + now_onscreen_asr_res + "\n\n" + ) + websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = ( + now_onscreen_s2tt_res + "\n\n" + ) if not websocket.is_speaking: websocket.vad_pre_idx = 0 frames = []