From ca5def5d243ac3872877a079c486047302bb44ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=A8=E5=AE=88?= Date: Sun, 15 Sep 2024 10:14:40 +0800 Subject: [PATCH] update --- .../funasr_wss_server_streaming_llm_2.py | 751 ++++++++++++++++++ 1 file changed, 751 insertions(+) create mode 100644 runtime/python/websocket/funasr_wss_server_streaming_llm_2.py diff --git a/runtime/python/websocket/funasr_wss_server_streaming_llm_2.py b/runtime/python/websocket/funasr_wss_server_streaming_llm_2.py new file mode 100644 index 000000000..597386289 --- /dev/null +++ b/runtime/python/websocket/funasr_wss_server_streaming_llm_2.py @@ -0,0 +1,751 @@ +import os +import asyncio +import json +import websockets +import time +from datetime import datetime +import argparse +import ssl +import numpy as np +from threading import Thread +from transformers import TextIteratorStreamer +from funasr import AutoModel +from modelscope.hub.api import HubApi +from modelscope.hub.snapshot_download import snapshot_download +import torch +import traceback +import re +import soundfile as sf + +parser = argparse.ArgumentParser() +parser.add_argument( + "--host", type=str, default="0.0.0.0", required=False, help="host ip, localhost, 0.0.0.0" +) +parser.add_argument("--port", type=int, default=10095, required=False, help="grpc server port") +parser.add_argument( + "--vad_model", + type=str, + default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch", + help="model from modelscope", +) +parser.add_argument("--vad_model_revision", type=str, default="master", help="") +parser.add_argument("--model_path", type=str, default=None, help="model path (vad/sensevoice/qwen/gummy)") +parser.add_argument("--ngpu", type=int, default=1, help="0 for cpu, 1 for gpu") +parser.add_argument("--device", type=str, default="cuda", help="cuda, cpu") +parser.add_argument("--ncpu", type=int, default=4, help="cpu cores") +parser.add_argument("--return_sentence", action="store_true", help="return sentence or all_res") +parser.add_argument("--no_vad", action="store_true", help="infer without vad") +parser.add_argument( + "--certfile", + type=str, + default="", + required=False, + help="certfile for ssl", +) +parser.add_argument( + "--keyfile", + type=str, + default="ssl_key/server.key", + required=False, + help="keyfile for ssl", +) +args = parser.parse_args() + +websocket_users = set() + +if args.model_path is None: + vad_model_path = args.vad_model +else: + vad_model_path = os.path.join(args.model_path, "vad_model") + +print("model loading") +# vad +model_vad = AutoModel( + model=vad_model_path, + model_revision=args.vad_model_revision, + ngpu=args.ngpu, + ncpu=args.ncpu, + device=args.device, + disable_pbar=True, + disable_log=True, + speech_noise_thres=0.4, + max_single_segment_time=35000, + max_end_silence_time=800, + # chunk_size=60, +) + +if args.model_path is None: + api = HubApi() + key = "ed70b703-9ec7-44b8-b5ce-5f4527719810" + api.login(key) + if "key" in os.environ: + key = os.environ["key"] + api.login(key) + +# os.environ["MODELSCOPE_CACHE"] = "/nfs/zhifu.gzf/modelscope" + +if args.model_path is None: + llm_dir = snapshot_download("qwen/Qwen2-7B-Instruct", cache_dir=None, revision="master") + audio_encoder_dir = snapshot_download("iic/SenseVoice", cache_dir=None, revision="master") +else: + llm_dir = os.path.join(args.model_path, "llm_model") + audio_encoder_dir = os.path.join(args.model_path, "audio_model") + +device = "cuda:0" +all_file_paths = [ + "FunAudioLLM/qwen2_7b_mmt_v15_20240912_streaming", + 
"FunAudioLLM/qwen2_7b_mmt_v15_20240910_streaming", + "FunAudioLLM/qwen2_7b_mmt_v15_20240902", + "FunAudioLLM/qwen2_7b_mmt_v14_20240830", + "FunAudioLLM/audiolm_v11_20240807", + "FunAudioLLM/Speech2Text_Align_V0712", + "FunAudioLLM/Speech2Text_Align_V0718", + "FunAudioLLM/Speech2Text_Align_V0628", +] + +llm_kwargs = {"num_beams": 1, "do_sample": False, "repetition_penalty": 1.3} +UNFIX_LEN = 5 +MIN_LEN_PER_PARAGRAPH = 25 +MIN_LEN_SEC_AUDIO_FIX = 1.1 +DO_ASR_FRAME_INTERVAL = 12 + +ckpt_dir = all_file_paths[0] + +if args.model_path is None: + ckpt_dir = snapshot_download(ckpt_dir, cache_dir=None, revision="master") +else: + ckpt_dir = os.path.join(args.model_path, "gummy_model") + +model_llm = AutoModel( + model=ckpt_dir, + device=device, + fp16=False, + bf16=False, + llm_dtype="bf16", + max_length=1024, + llm_kwargs=llm_kwargs, + llm_conf={"init_param_path": llm_dir, "load_kwargs": {"attn_implementation": "eager"}}, + tokenizer_conf={"init_param_path": llm_dir}, + audio_encoder=audio_encoder_dir, +) + +model = model_llm.model +frontend = model_llm.kwargs["frontend"] +tokenizer = model_llm.kwargs["tokenizer"] + +model_dict = {"model": model, "frontend": frontend, "tokenizer": tokenizer} + +print("model loaded! only support one client at the same time now!!!!") + +def load_bytes(input): + middle_data = np.frombuffer(input, dtype=np.int16) + middle_data = np.asarray(middle_data) + if middle_data.dtype.kind not in "iu": + raise TypeError("'middle_data' must be an array of integers") + dtype = np.dtype("float32") + if dtype.kind != "f": + raise TypeError("'dtype' must be a floating point type") + + i = np.iinfo(middle_data.dtype) + abs_max = 2 ** (i.bits - 1) + offset = i.min + abs_max + array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32) + return array + +def is_chinese_ending(s): + return re.search(r'[\u4e00-\u9fff]$', s) is not None + +def is_alpha_ending(s): + return re.search(r'[a-zA-Z]$', s) is not None + +async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state=None, asr_prompt=None, s2tt_prompt=None): + current_time = datetime.now() + print("DEBUG:" + str(current_time) + " call streaming_transcribe function:") + if his_state is None: + his_state = model_dict + model = his_state["model"] + tokenizer = his_state["tokenizer"] + + if websocket.streaming_state is None: + previous_asr_text = "" + previous_s2tt_text = "" + previous_vad_onscreen_asr_text = "" + previous_vad_onscreen_s2tt_text = "" + # concat_asr_text = [] + # concat_s2tt_text = [] + # concat_audio = [] + # concat_audio_embedding = [] + # concat_audio_embedding_lens = [] + else: + previous_asr_text = websocket.streaming_state.get("previous_asr_text", "") + previous_s2tt_text = websocket.streaming_state.get("previous_s2tt_text", "") + previous_vad_onscreen_asr_text = websocket.streaming_state.get( + "previous_vad_onscreen_asr_text", "" + ) + previous_vad_onscreen_s2tt_text = websocket.streaming_state.get( + "previous_vad_onscreen_s2tt_text", "" + ) + # concat_asr_text = websocket.streaming_state.get("concat_asr_text", []) + # concat_s2tt_text = websocket.streaming_state.get("concat_s2tt_text", []) + # concat_audio = websocket.streaming_state.get("concat_audio", []) + # concat_audio_embedding = websocket.streaming_state.get("concat_audio_embedding", []) + # concat_audio_embedding_lens = websocket.streaming_state.get("concat_audio_embedding_lens", []) + + if asr_prompt is None or asr_prompt == "": + asr_prompt = "Speech transcription:" + if s2tt_prompt is None or s2tt_prompt 
== "": + s2tt_prompt = "Translate into English:" + + audio_seconds = len(audio_in) // 32 / 1000 + cur_audio = audio_in + print(f"Streaming audio length: {audio_seconds} seconds") + + asr_content = [] + system_prompt = "You are a helpful assistant." + asr_content.append({"role": "system", "content": system_prompt}) + s2tt_content = [] + system_prompt = "You are a helpful assistant." + s2tt_content.append({"role": "system", "content": system_prompt}) + + user_asr_prompt = f"{asr_prompt}<|startofspeech|>!!<|endofspeech|><|im_end|>\n<|im_start|>assistant\n{previous_asr_text}" + user_s2tt_prompt = f"{s2tt_prompt}<|startofspeech|>!!<|endofspeech|><|im_end|>\n<|im_start|>assistant\n{previous_s2tt_text}" + + asr_content.append({"role": "user", "content": user_asr_prompt, "audio": audio_in}) + asr_content.append({"role": "assistant", "content": "target_out"}) + s2tt_content.append({"role": "user", "content": user_s2tt_prompt, "audio": audio_in}) + s2tt_content.append({"role": "assistant", "content": "target_out"}) + + streaming_time_beg = time.time() + inputs_asr_embeds, contents, batch, source_ids, meta_data = model.inference_prepare( + [asr_content], + None, + "test_demo", + tokenizer, + frontend, + device=device, + infer_with_assistant_input=True, + ) + cur_audio_embedding, cur_audio_embedding_lens = meta_data["audio_adaptor_out"], meta_data["audio_adaptor_out_lens"] + # if not args.return_sentence and len(concat_audio_embedding) != 0: + if False: + audio_embedding = torch.cat([concat_audio_embedding[-1], cur_audio_embedding], dim=1) + audio_embedding_lens = concat_audio_embedding_lens[-1] + cur_audio_embedding_lens + + actual_prev_asr_text = concat_asr_text[-1] + previous_asr_text + actual_prev_s2tt_text = concat_s2tt_text[-1] + previous_s2tt_text + actual_audio = concat_audio[-1] + cur_audio + + user_asr_prompt = f"{asr_prompt}<|startofspeech|>!!<|endofspeech|><|im_end|>\n<|im_start|>assistant\n{actual_prev_asr_text}" + user_s2tt_prompt = f"{s2tt_prompt}<|startofspeech|>!!<|endofspeech|><|im_end|>\n<|im_start|>assistant\n{actual_prev_s2tt_text}" + + asr_content[1] = {"role": "user", "content": user_asr_prompt, "audio": actual_audio} + s2tt_content[1] = {"role": "user", "content": user_s2tt_prompt, "audio": actual_audio} + + inputs_asr_embeds, contents, batch, source_ids, meta_data = model.inference_prepare( + [asr_content], + None, + "test_demo", + tokenizer, + frontend, + device=device, + infer_with_assistant_input=True, + audio_embedding=audio_embedding, + audio_embedding_lens=audio_embedding_lens, + ) + else: + audio_embedding = cur_audio_embedding + audio_embedding_lens = cur_audio_embedding_lens + + model_asr_inputs = {} + model_asr_inputs["inputs_embeds"] = inputs_asr_embeds + inputs_s2tt_embeds, contents, batch, source_ids, meta_data = model.inference_prepare( + [s2tt_content], + None, + "test_demo", + tokenizer, + frontend, + device=device, + infer_with_assistant_input=True, + audio_embedding=audio_embedding, + audio_embedding_lens=audio_embedding_lens, + ) + model_s2tt_inputs = {} + model_s2tt_inputs["inputs_embeds"] = inputs_s2tt_embeds + + print("previous_asr_text:", previous_asr_text) + print("previous_s2tt_text:", previous_s2tt_text) + + asr_streamer = TextIteratorStreamer(tokenizer) + asr_generation_kwargs = dict(model_asr_inputs, streamer=asr_streamer, max_new_tokens=1024) + asr_generation_kwargs.update(llm_kwargs) + asr_thread = Thread(target=model.llm.generate, kwargs=asr_generation_kwargs) + asr_thread.start() + s2tt_streamer = TextIteratorStreamer(tokenizer) + 
s2tt_generation_kwargs = dict(model_s2tt_inputs, streamer=s2tt_streamer, max_new_tokens=1024) + s2tt_generation_kwargs.update(llm_kwargs) + s2tt_thread = Thread(target=model.llm.generate, kwargs=s2tt_generation_kwargs) + s2tt_thread.start() + + onscreen_asr_res = previous_asr_text + onscreen_s2tt_res = previous_s2tt_text + + remain_s2tt_text = True + + for new_asr_text in asr_streamer: + current_time = datetime.now() + print("DEBUG: " + str(current_time) + " " + f"generated new asr text: {new_asr_text}") + if len(new_asr_text) > 0: + onscreen_asr_res += new_asr_text.replace("<|im_end|>", "") + + if remain_s2tt_text: + try: + new_s2tt_text = next(s2tt_streamer) + current_time = datetime.now() + print( + "DEBUG: " + + str(current_time) + + " " + + f"generated new s2tt text: {new_s2tt_text}" + ) + + if len(new_s2tt_text) > 0: + onscreen_s2tt_res += new_s2tt_text.replace("<|im_end|>", "") + + except StopIteration: + new_s2tt_text = "" + remain_s2tt_text = False + pass + + if len(new_asr_text) > 0 or len(new_s2tt_text) > 0: + all_asr_res = previous_vad_onscreen_asr_text + onscreen_asr_res + fix_asr_part = previous_vad_onscreen_asr_text + previous_asr_text + unfix_asr_part = all_asr_res[len(fix_asr_part):] + if not is_vad_end: + return_asr_res = fix_asr_part + "" + unfix_asr_part + "" + else: + return_asr_res = fix_asr_part + unfix_asr_part + "" + all_s2tt_res = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res + fix_s2tt_part = previous_vad_onscreen_s2tt_text + previous_s2tt_text + unfix_s2tt_part = all_s2tt_res[len(fix_s2tt_part):] + if not is_vad_end: + return_s2tt_res = fix_s2tt_part + "" + unfix_s2tt_part + "" + else: + return_s2tt_res = fix_s2tt_part + unfix_s2tt_part + "" + message = json.dumps( + { + "mode": "online", + "asr_text": return_asr_res, + "s2tt_text": return_s2tt_res, + "wav_name": websocket.wav_name, + "is_final": websocket.is_speaking, + "is_sentence_end": False, + } + ) + await websocket.send(message) + websocket.streaming_state["onscreen_asr_res"] = ( + previous_vad_onscreen_asr_text + onscreen_asr_res + ) + websocket.streaming_state["onscreen_s2tt_res"] = ( + previous_vad_onscreen_s2tt_text + onscreen_s2tt_res + ) + + if remain_s2tt_text: + for new_s2tt_text in s2tt_streamer: + current_time = datetime.now() + print( + "DEBUG: " + str(current_time) + " " + f"generated new s2tt text: {new_s2tt_text}" + ) + if len(new_s2tt_text) > 0: + onscreen_s2tt_res += new_s2tt_text.replace("<|im_end|>", "") + + if len(new_s2tt_text) > 0: + all_asr_res = previous_vad_onscreen_asr_text + onscreen_asr_res + fix_asr_part = previous_vad_onscreen_asr_text + previous_asr_text + unfix_asr_part = all_asr_res[len(fix_asr_part):] + if not is_vad_end: + return_asr_res = fix_asr_part + "" + unfix_asr_part + "" + else: + return_asr_res = fix_asr_part + unfix_asr_part + "" + all_s2tt_res = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res + fix_s2tt_part = previous_vad_onscreen_s2tt_text + previous_s2tt_text + unfix_s2tt_part = all_s2tt_res[len(fix_s2tt_part):] + if not is_vad_end: + return_s2tt_res = fix_s2tt_part + "" + unfix_s2tt_part + "" + else: + return_s2tt_res = fix_s2tt_part + unfix_s2tt_part + "" + message = json.dumps( + { + "mode": "online", + "asr_text": return_asr_res, + "s2tt_text": return_s2tt_res, + "wav_name": websocket.wav_name, + "is_final": websocket.is_speaking, + "is_sentence_end": False, + } + ) + await websocket.send(message) + websocket.streaming_state["onscreen_asr_res"] = ( + previous_vad_onscreen_asr_text + onscreen_asr_res + ) + 
websocket.streaming_state["onscreen_s2tt_res"] = ( + previous_vad_onscreen_s2tt_text + onscreen_s2tt_res + ) + + if is_vad_end: + # concat_asr_text.append(onscreen_asr_res) + # concat_s2tt_text.append(onscreen_s2tt_res) + # concat_audio.append(cur_audio) + # concat_audio_embedding.append(cur_audio_embedding) + # concat_audio_embedding_lens.append(cur_audio_embedding_lens) + + # websocket.streaming_state["concat_asr_text"] = concat_asr_text + # websocket.streaming_state["concat_s2tt_text"] = concat_s2tt_text + # websocket.streaming_state["concat_audio"] = concat_audio + # websocket.streaming_state["concat_audio_embedding"] = concat_audio_embedding + # websocket.streaming_state["concat_audio_embedding_lens"] = concat_audio_embedding_lens + clean_return_asr_res = return_asr_res.replace("", "").replace("", "") + clean_return_s2tt_res = return_s2tt_res.replace("", "").replace("", "") + if is_alpha_ending(clean_return_asr_res): + return_asr_res = clean_return_asr_res + "." + onscreen_asr_res += "." + elif is_chinese_ending(clean_return_asr_res): + return_asr_res = clean_return_asr_res + "。" + onscreen_asr_res += "。" + + if is_alpha_ending(clean_return_s2tt_res): + return_s2tt_res = clean_return_s2tt_res + "." + onscreen_s2tt_res += "." + elif is_chinese_ending(clean_return_s2tt_res): + return_s2tt_res = clean_return_s2tt_res + "。" + onscreen_s2tt_res += "。" + + message = json.dumps( + { + "mode": "online", + "asr_text": return_asr_res, + "s2tt_text": return_s2tt_res, + "wav_name": websocket.wav_name, + "is_final": websocket.is_speaking, + "is_sentence_end": True, + } + ) + await websocket.send(message) + + streaming_time_end = time.time() + print(f"Streaming inference time: {streaming_time_end - streaming_time_beg}") + + asr_text_len = len(tokenizer.encode(onscreen_asr_res)) + s2tt_text_len = len(tokenizer.encode(onscreen_s2tt_res)) + + if asr_text_len > UNFIX_LEN and audio_seconds > MIN_LEN_SEC_AUDIO_FIX: + pre_previous_asr_text = previous_asr_text + previous_asr_text = tokenizer.decode( + tokenizer.encode(onscreen_asr_res)[:-UNFIX_LEN] + ).replace("�", "") + if len(previous_asr_text) <= len(pre_previous_asr_text): + previous_asr_text = pre_previous_asr_text + else: + previous_asr_text = "" + if s2tt_text_len > UNFIX_LEN and audio_seconds > MIN_LEN_SEC_AUDIO_FIX: + pre_previous_s2tt_text = previous_s2tt_text + previous_s2tt_text = tokenizer.decode( + tokenizer.encode(onscreen_s2tt_res)[:-UNFIX_LEN] + ).replace("�", "") + if len(previous_s2tt_text) <= len(pre_previous_s2tt_text): + previous_s2tt_text = pre_previous_s2tt_text + else: + previous_s2tt_text = "" + + websocket.streaming_state["previous_asr_text"] = previous_asr_text + websocket.streaming_state["onscreen_asr_res"] = ( + previous_vad_onscreen_asr_text + onscreen_asr_res + ) + websocket.streaming_state["previous_s2tt_text"] = previous_s2tt_text + websocket.streaming_state["onscreen_s2tt_res"] = ( + previous_vad_onscreen_s2tt_text + onscreen_s2tt_res + ) + + print("fix asr part:", previous_asr_text) + print("fix s2tt part:", previous_s2tt_text) + + +async def ws_reset(websocket): + print("ws reset now, total num is ", len(websocket_users)) + + websocket.streaming_state = {} + websocket.streaming_state["is_final"] = True + websocket.streaming_state["previous_asr_text"] = "" + websocket.streaming_state["previous_s2tt_text"] = "" + websocket.streaming_state["onscreen_asr_res"] = "" + websocket.streaming_state["onscreen_s2tt_res"] = "" + websocket.streaming_state["previous_vad_onscreen_asr_text"] = "" + 
websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = "" + # websocket.streaming_state["concat_asr_text"] = [] + # websocket.streaming_state["concat_s2tt_text"] = [] + # websocket.streaming_state["concat_audio"] = [] + # websocket.streaming_state["concat_audio_embedding"] = [] + # websocket.streaming_state["concat_audio_embedding_lens"] = [] + websocket.status_dict_vad["cache"] = {} + websocket.status_dict_vad["is_final"] = True + + await websocket.close() + + +async def clear_websocket(): + for websocket in websocket_users: + await ws_reset(websocket) + websocket_users.clear() + + +async def ws_serve(websocket, path): + frames = [] + frames_asr = [] + global websocket_users + # await clear_websocket() + websocket_users.add(websocket) + websocket.streaming_state = { + "previous_asr_text": "", + "previous_s2tt_text": "", + "onscreen_asr_res": "", + "onscreen_s2tt_res": "", + "previous_vad_onscreen_asr_text": "", + "previous_vad_onscreen_s2tt_text": "", + # "concat_asr_text": [], + # "concat_s2tt_text": [], + # "concat_audio": [], + # "concat_audio_embedding": [], + # "concat_audio_embedding_lens": [], + "is_final": False, + } + websocket.status_dict_vad = {"cache": {}, "is_final": False} + + websocket.chunk_interval = 10 + websocket.vad_pre_idx = 0 + speech_start = False + speech_end_i = -1 + websocket.wav_name = "microphone" + print("new user connected", flush=True) + + try: + async for message in websocket: + if isinstance(message, str): + current_time = datetime.now() + print("DEBUG:" + str(current_time) + " received message:", message) + else: + current_time = datetime.now() + print("DEBUG:" + str(current_time) + " received audio bytes:") + + if isinstance(message, str): + messagejson = json.loads(message) + + if "is_speaking" in messagejson: + websocket.is_speaking = messagejson["is_speaking"] + websocket.streaming_state["is_final"] = not websocket.is_speaking + # if not messagejson["is_speaking"]: + # await clear_websocket() + if "chunk_interval" in messagejson: + websocket.chunk_interval = messagejson["chunk_interval"] + if "wav_name" in messagejson: + websocket.wav_name = messagejson.get("wav_name") + if "chunk_size" in messagejson: + chunk_size = messagejson["chunk_size"] + if isinstance(chunk_size, str): + chunk_size = chunk_size.split(",") + chunk_size = [int(x) for x in chunk_size] + if "asr_prompt" in messagejson: + asr_prompt = messagejson["asr_prompt"] + else: + asr_prompt = "Speech transcription:" + if "s2tt_prompt" in messagejson: + s2tt_prompt = messagejson["s2tt_prompt"] + else: + s2tt_prompt = "Translate into English:" + + websocket.status_dict_vad["chunk_size"] = int( + chunk_size[1] * 60 / websocket.chunk_interval + ) + if len(frames_asr) > 0 or not isinstance(message, str): + if not isinstance(message, str): + frames.append(message) + duration_ms = len(message) // 32 + websocket.vad_pre_idx += duration_ms + + # asr online + websocket.streaming_state["is_final"] = speech_end_i != -1 + if ( + len(frames_asr) % DO_ASR_FRAME_INTERVAL == 0 + or websocket.streaming_state["is_final"] + ) and len(frames_asr) != 0: + audio_in = b"".join(frames_asr) + try: + await streaming_transcribe( + websocket, audio_in, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt + ) + except Exception as e: + print(f"error in streaming, {e}") + print(f"error in streaming, {websocket.streaming_state}") + traceback.print_exc() + if speech_start: + frames_asr.append(message) + + # vad online + if not args.no_vad: + try: + speech_start_i, speech_end_i = await async_vad(websocket, message) + except 
Exception as e:
+                            print(f"error in vad, {e}")
+                            traceback.print_exc()
+                        if speech_start_i != -1:
+                            speech_start = True
+                            speech_end_i = -1
+                            beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms
+                            frames_pre = frames[-beg_bias:]
+                            frames_asr = []
+                            frames_asr.extend(frames_pre)
+                    else:
+                        speech_start = True
+                        speech_end_i = -1
+                        frames_asr = []
+                        frames_asr.extend(frames)
+
+                # vad end
+                if speech_end_i != -1 or not websocket.is_speaking:
+                    if speech_end_i != -1:
+                        audio_in = b"".join(frames_asr)
+                        # use sf to save audio_in
+                        now = datetime.now()
+                        formatted_now = now.strftime("%Y-%m-%d_%H-%M-%S")
+                        sf.write(formatted_now + ".wav", load_bytes(audio_in), 16000)
+                        try:
+                            await streaming_transcribe(
+                                websocket, audio_in, is_vad_end=True, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt
+                            )
+                        except Exception as e:
+                            print(f"error in streaming, {e}")
+                            print(f"error in streaming, {websocket.streaming_state}")
+                            traceback.print_exc()
+                    frames_asr = []
+                    speech_start = False
+                    websocket.streaming_state["previous_asr_text"] = ""
+                    websocket.streaming_state["previous_s2tt_text"] = ""
+                    if not websocket.is_speaking:
+                        # stream finished: punctuate the on-screen text and send the final result
+                        if is_alpha_ending(websocket.streaming_state["onscreen_asr_res"]):
+                            websocket.streaming_state["onscreen_asr_res"] += "."
+                        elif is_chinese_ending(websocket.streaming_state["onscreen_asr_res"]):
+                            websocket.streaming_state["onscreen_asr_res"] += "。"
+
+                        if is_alpha_ending(websocket.streaming_state["onscreen_s2tt_res"]):
+                            websocket.streaming_state["onscreen_s2tt_res"] += "."
+                        elif is_chinese_ending(websocket.streaming_state["onscreen_s2tt_res"]):
+                            websocket.streaming_state["onscreen_s2tt_res"] += "。"
+                        message = json.dumps(
+                            {
+                                "mode": "online",
+                                "asr_text": websocket.streaming_state["onscreen_asr_res"] + "",
+                                "s2tt_text": websocket.streaming_state["onscreen_s2tt_res"] + "",
+                                "wav_name": websocket.wav_name,
+                                "is_final": websocket.is_speaking,
+                                "is_sentence_end": True,
+                            }
+                        )
+                        await websocket.send(message)
+                        await clear_websocket()
+                    if args.return_sentence:
+                        websocket.streaming_state["previous_vad_onscreen_asr_text"] = ""
+                        websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = ""
+                    else:
+                        now_onscreen_asr_res = websocket.streaming_state.get("onscreen_asr_res", "")
+                        now_onscreen_s2tt_res = websocket.streaming_state.get("onscreen_s2tt_res", "")
+                        if (
+                            len(tokenizer.encode(now_onscreen_asr_res.split("\n\n")[-1]))
+                            < MIN_LEN_PER_PARAGRAPH
+                            or len(tokenizer.encode(now_onscreen_s2tt_res.split("\n\n")[-1]))
+                            < MIN_LEN_PER_PARAGRAPH
+                        ):
+                            if (
+                                now_onscreen_asr_res.endswith(".")
+                                or now_onscreen_asr_res.endswith("?")
+                                or now_onscreen_asr_res.endswith("!")
+                            ):
+                                now_onscreen_asr_res += " "
+                            if (
+                                now_onscreen_s2tt_res.endswith(".")
+                                or now_onscreen_s2tt_res.endswith("?")
+                                or now_onscreen_s2tt_res.endswith("!")
+                            ):
+                                now_onscreen_s2tt_res += " "
+                            websocket.streaming_state["previous_vad_onscreen_asr_text"] = (
+                                now_onscreen_asr_res
+                            )
+                            websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = (
+                                now_onscreen_s2tt_res
+                            )
+                        else:
+                            websocket.streaming_state["previous_vad_onscreen_asr_text"] = (
+                                now_onscreen_asr_res + "\n\n"
+                            )
+                            websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = (
+                                now_onscreen_s2tt_res + "\n\n"
+                            )
+                    if not websocket.is_speaking:
+                        websocket.vad_pre_idx = 0
+                        frames = []
+                        websocket.status_dict_vad["cache"] = {}
+                        websocket.streaming_state["previous_asr_text"] = ""
+                        websocket.streaming_state["previous_s2tt_text"] = ""
+                        websocket.streaming_state["onscreen_asr_res"] = ""
+                        websocket.streaming_state["onscreen_s2tt_res"] = ""
+                        # websocket.streaming_state["concat_asr_text"] = []
+                        # websocket.streaming_state["concat_s2tt_text"] = []
+                        # websocket.streaming_state["concat_audio"] = []
+                        # websocket.streaming_state["concat_audio_embedding"] = []
+                        # websocket.streaming_state["concat_audio_embedding_lens"] = []
+                    else:
+                        frames = frames[-20:]
+            else:
+                print(f"message: {message}")
+    except websockets.ConnectionClosed:
+        print("ConnectionClosed...", websocket_users, flush=True)
+        await ws_reset(websocket)
+        websocket_users.remove(websocket)
+    except websockets.InvalidState:
+        print("InvalidState...")
+    except Exception as e:
+        print("Exception:", e)
+
+
+async def async_vad(websocket, audio_in):
+    current_time = datetime.now()
+    print("DEBUG:" + str(current_time) + " call vad function:")
+    segments_result = model_vad.generate(input=audio_in, **websocket.status_dict_vad)[0]["value"]
+    # print(segments_result)
+
+    speech_start = -1
+    speech_end = -1
+
+    if len(segments_result) == 0 or len(segments_result) > 1:
+        return speech_start, speech_end
+    if segments_result[0][0] != -1:
+        speech_start = segments_result[0][0]
+    if segments_result[0][1] != -1:
+        speech_end = segments_result[0][1]
+    return speech_start, speech_end
+
+
+if len(args.certfile) > 0:
+    # serve over TLS when a certificate is supplied via --certfile/--keyfile
+    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+
+    # Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
+    ssl_cert = args.certfile
+    ssl_key = args.keyfile
+
+    ssl_context.load_cert_chain(ssl_cert, keyfile=ssl_key)
+    start_server = websockets.serve(
+        ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None, ssl=ssl_context
+    )
+else:
+    start_server = websockets.serve(
+        ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None
+    )
+asyncio.get_event_loop().run_until_complete(start_server)
+asyncio.get_event_loop().run_forever()
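
For reference, a minimal client sketch that exercises this server. It is not part of the patch: the file name test_16k.wav is a placeholder, and the handshake keys, the 16 kHz / 16-bit mono PCM framing, and the reply fields are inferred from ws_serve() above, so the bundled funasr_wss_client may differ in detail.

    # client_sketch.py -- minimal example client (assumption: not shipped with this patch).
    import asyncio
    import json

    import soundfile as sf
    import websockets


    async def run_client(wav_path, uri="ws://127.0.0.1:10095", chunk_ms=300):
        audio, sample_rate = sf.read(wav_path, dtype="int16")  # expects 16 kHz mono
        pcm = audio.tobytes()
        stride = int(sample_rate * 2 * chunk_ms / 1000)  # bytes per chunk (2 bytes per sample)
        async with websockets.connect(uri, subprotocols=["binary"]) as ws:
            # config message: ws_serve reads these keys before any audio arrives
            await ws.send(json.dumps({
                "wav_name": wav_path,
                "is_speaking": True,
                "chunk_size": "5,10,5",
                "chunk_interval": 10,
                "asr_prompt": "Speech transcription:",
                "s2tt_prompt": "Translate into English:",
            }))
            for i in range(0, len(pcm), stride):
                await ws.send(pcm[i:i + stride])
                await asyncio.sleep(chunk_ms / 1000)  # roughly real-time pacing
            await ws.send(json.dumps({"is_speaking": False}))  # end of stream
            async for msg in ws:  # server closes the connection after the final result
                result = json.loads(msg)
                print(result["asr_text"], "|", result["s2tt_text"])


    if __name__ == "__main__":
        asyncio.run(run_client("test_16k.wav"))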