diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index 30eef803c..d66afb691 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -930,6 +930,7 @@ class LLMASR4(nn.Module): from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5") + llm_load_kwargs = llm_conf.get("load_kwargs", {}) if not llm_conf.get("low_cpu", False): model = AutoModelForCausalLM.from_pretrained( @@ -937,6 +938,7 @@ class LLMASR4(nn.Module): load_in_8bit=None, device_map=None, use_cache=None, + **llm_load_kwargs, ) else: import os @@ -947,6 +949,7 @@ class LLMASR4(nn.Module): load_in_8bit=None, device_map="cpu", use_cache=None, + **llm_load_kwargs, ) else: llm_config = AutoConfig.from_pretrained(init_param_path) diff --git a/runtime/python/websocket/funasr_wss_client_streaming_llm.py b/runtime/python/websocket/funasr_wss_client_streaming_llm.py index 690ad1894..71429acc5 100644 --- a/runtime/python/websocket/funasr_wss_client_streaming_llm.py +++ b/runtime/python/websocket/funasr_wss_client_streaming_llm.py @@ -22,17 +22,12 @@ parser.add_argument( ) parser.add_argument("--port", type=int, default=10095, required=False, help="grpc server port") parser.add_argument("--chunk_size", type=str, default="5, 10, 5", help="chunk") -parser.add_argument("--encoder_chunk_look_back", type=int, default=4, help="chunk") -parser.add_argument("--decoder_chunk_look_back", type=int, default=0, help="chunk") -parser.add_argument("--chunk_interval", type=int, default=10, help="chunk") -parser.add_argument( - "--hotword", - type=str, - default="", - help="hotword file path, one hotword perline (e.g.:阿里巴巴 20)", -) +parser.add_argument("--chunk_interval", type=int, default=60, help="chunk") parser.add_argument("--audio_in", type=str, default=None, help="audio_in") parser.add_argument("--audio_fs", type=int, default=16000, help="audio_fs") +parser.add_argument("--asr_prompt", type=str, default="Copy:", help="asr prompt") +parser.add_argument("--s2tt_prompt", type=str, default="Translate the following sentence into English:", help="s2tt prompt") + parser.add_argument( "--send_without_sleep", action="store_true", @@ -41,10 +36,9 @@ parser.add_argument( ) parser.add_argument("--thread_num", type=int, default=1, help="thread_num") parser.add_argument("--words_max_print", type=int, default=10000, help="chunk") -parser.add_argument("--output_dir", type=str, default=None, help="output_dir") parser.add_argument("--ssl", type=int, default=1, help="1 for ssl connect, 0 for no ssl") -parser.add_argument("--use_itn", type=int, default=1, help="1 for using itn, 0 for not itn") -parser.add_argument("--mode", type=str, default="2pass", help="offline, online, 2pass") +parser.add_argument("--mode", type=str, default="online", help="offline, online, 2pass") + args = parser.parse_args() args.chunk_size = [int(x) for x in args.chunk_size.split(",")] @@ -53,14 +47,6 @@ print(args) from queue import Queue voices = Queue() -offline_msg_done = False - -if args.output_dir is not None: - # if os.path.exists(args.output_dir): - # os.remove(args.output_dir) - - if not os.path.exists(args.output_dir): - os.makedirs(args.output_dir) async def record_microphone(): @@ -80,41 +66,16 @@ async def record_microphone(): stream = p.open( format=FORMAT, channels=CHANNELS, rate=RATE, input=True, frames_per_buffer=CHUNK ) - # hotwords - fst_dict = {} - hotword_msg = "" - if args.hotword.strip() != "": - if os.path.exists(args.hotword): - f_scp = open(args.hotword) - hot_lines = f_scp.readlines() - for line in hot_lines: - words = line.strip().split(" ") - if len(words) < 2: - print("Please checkout format of hotwords") - continue - try: - fst_dict[" ".join(words[:-1])] = int(words[-1]) - except ValueError: - print("Please checkout format of hotwords") - hotword_msg = json.dumps(fst_dict) - else: - hotword_msg = args.hotword - - use_itn = True - if args.use_itn == 0: - use_itn = False message = json.dumps( { "mode": args.mode, "chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, - "encoder_chunk_look_back": args.encoder_chunk_look_back, - "decoder_chunk_look_back": args.decoder_chunk_look_back, "wav_name": "microphone", "is_speaking": True, - "hotwords": hotword_msg, - "itn": use_itn, + "asr_prompt": args.asr_prompt, + "s2tt_prompt": args.s2tt_prompt, } ) # voices.put(message) @@ -136,32 +97,8 @@ async def record_from_scp(chunk_begin, chunk_size): else: wavs = [args.audio_in] - # hotwords - fst_dict = {} - hotword_msg = "" - if args.hotword.strip() != "": - if os.path.exists(args.hotword): - f_scp = open(args.hotword) - hot_lines = f_scp.readlines() - for line in hot_lines: - words = line.strip().split(" ") - if len(words) < 2: - print("Please checkout format of hotwords") - continue - try: - fst_dict[" ".join(words[:-1])] = int(words[-1]) - except ValueError: - print("Please checkout format of hotwords") - hotword_msg = json.dumps(fst_dict) - else: - hotword_msg = args.hotword - print(hotword_msg) - sample_rate = args.audio_fs wav_format = "pcm" - use_itn = True - if args.use_itn == 0: - use_itn = False if chunk_size > 0: wavs = wavs[chunk_begin : chunk_begin + chunk_size] @@ -198,14 +135,10 @@ async def record_from_scp(chunk_begin, chunk_size): "mode": args.mode, "chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, - "encoder_chunk_look_back": args.encoder_chunk_look_back, - "decoder_chunk_look_back": args.decoder_chunk_look_back, - "audio_fs": sample_rate, - "wav_name": wav_name, - "wav_format": wav_format, + "wav_name": "microphone", "is_speaking": True, - "hotwords": hotword_msg, - "itn": use_itn, + "asr_prompt": args.asr_prompt, + "s2tt_prompt": args.s2tt_prompt, } ) @@ -230,76 +163,24 @@ async def record_from_scp(chunk_begin, chunk_size): await asyncio.sleep(sleep_duration) - if not args.mode == "offline": - await asyncio.sleep(2) - # offline model need to wait for message recved - - if args.mode == "offline": - global offline_msg_done - while not offline_msg_done: - await asyncio.sleep(1) + await asyncio.sleep(2) await websocket.close() async def message(id): - global websocket, voices, offline_msg_done + global websocket, voices text_print = "" - text_print_2pass_online = "" - text_print_2pass_offline = "" - if args.output_dir is not None: - ibest_writer = open( - os.path.join(args.output_dir, "text.{}".format(id)), "a", encoding="utf-8" - ) - else: - ibest_writer = None try: while True: - meg = await websocket.recv() meg = json.loads(meg) - wav_name = meg.get("wav_name", "demo") - text = meg["text"] - timestamp = "" - offline_msg_done = meg.get("is_final", False) - if "timestamp" in meg: - timestamp = meg["timestamp"] + asr_text = meg["asr_text"] + s2tt_text = meg["s2tt_text"] - if ibest_writer is not None: - if timestamp != "": - text_write_line = "{}\t{}\t{}\n".format(wav_name, text, timestamp) - else: - text_write_line = "{}\t{}\n".format(wav_name, text) - ibest_writer.write(text_write_line) - - if "mode" not in meg: - continue - if meg["mode"] == "online": - text_print = text - os.system("clear") - print("\rpid" + str(id) + ": " + text_print) - elif meg["mode"] == "offline": - if timestamp != "": - text_print += "{} timestamp: {}".format(text, timestamp) - else: - text_print += "{}".format(text) - - # text_print = text_print[-args.words_max_print:] - # os.system('clear') - print("\rpid" + str(id) + ": " + wav_name + ": " + text_print) - offline_msg_done = True - else: - if meg["mode"] == "2pass-online": - text_print_2pass_online += "{}".format(text) - text_print = text_print_2pass_offline + text_print_2pass_online - else: - text_print_2pass_online = "" - text_print = text_print_2pass_offline + "{}".format(text) - text_print_2pass_offline += "{}".format(text) - text_print = text_print[-args.words_max_print :] - os.system("clear") - print("\rpid" + str(id) + ": " + text_print) - # offline_msg_done=True + text_print = "\n\n" + "ASR: " + asr_text + "\n\n" + "S2TT: " + s2tt_text + os.system("clear") + print("\rpid" + str(id) + ": " + text_print) except Exception as e: print("Exception:", e) @@ -311,10 +192,9 @@ async def ws_client(id, chunk_begin, chunk_size): if args.audio_in is None: chunk_begin = 0 chunk_size = 1 - global websocket, voices, offline_msg_done + global websocket, voices for i in range(chunk_begin, chunk_begin + chunk_size): - offline_msg_done = False voices = Queue() if args.ssl == 1: ssl_context = ssl.SSLContext() diff --git a/runtime/python/websocket/funasr_wss_server_streaming_llm.py b/runtime/python/websocket/funasr_wss_server_streaming_llm.py index 645262513..6e01a2232 100644 --- a/runtime/python/websocket/funasr_wss_server_streaming_llm.py +++ b/runtime/python/websocket/funasr_wss_server_streaming_llm.py @@ -1,38 +1,22 @@ +import os import asyncio import json import websockets import time -import logging -import tracemalloc -import numpy as np import argparse import ssl -import os -import torch -import torchaudio -from transformers import TextIteratorStreamer +import numpy as np from threading import Thread -import traceback +from transformers import TextIteratorStreamer +from funasr import AutoModel +from modelscope.hub.api import HubApi +from modelscope.hub.snapshot_download import snapshot_download parser = argparse.ArgumentParser() parser.add_argument( "--host", type=str, default="127.0.0.1", required=False, help="host ip, localhost, 0.0.0.0" ) parser.add_argument("--port", type=int, default=10095, required=False, help="grpc server port") -parser.add_argument( - "--asr_model", - type=str, - default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", - help="model from modelscope", -) -parser.add_argument("--asr_model_revision", type=str, default="master", help="") -parser.add_argument( - "--asr_model_online", - type=str, - default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", - help="model from modelscope", -) -parser.add_argument("--asr_model_online_revision", type=str, default="master", help="") parser.add_argument( "--vad_model", type=str, @@ -50,7 +34,6 @@ parser.add_argument( required=False, help="certfile for ssl", ) - parser.add_argument( "--keyfile", type=str, @@ -63,9 +46,6 @@ args = parser.parse_args() websocket_users = set() print("model loading") -from funasr import AutoModel - - # vad model_vad = AutoModel( model=args.vad_model, @@ -78,17 +58,11 @@ model_vad = AutoModel( # chunk_size=60, ) - -from funasr import AutoModel -from modelscope.hub.api import HubApi - api = HubApi() if "key" in os.environ: key = os.environ["key"] api.login(key) -from modelscope.hub.snapshot_download import snapshot_download - # os.environ["MODELSCOPE_CACHE"] = "/nfs/zhifu.gzf/modelscope" # llm_dir = snapshot_download('qwen/Qwen2-7B-Instruct', cache_dir=None, revision='master') # audio_encoder_dir = snapshot_download('iic/SenseVoice', cache_dir=None, revision='master') @@ -102,7 +76,6 @@ all_file_paths = [ llm_kwargs = {"num_beams": 1, "do_sample": False} unfix_len = 5 -max_streaming_res_onetime = 100 ckpt_dir = all_file_paths[0] @@ -114,7 +87,7 @@ model_llm = AutoModel( llm_dtype="bf16", max_length=1024, llm_kwargs=llm_kwargs, - llm_conf={"init_param_path": llm_dir}, + llm_conf={"init_param_path": llm_dir, "load_kwargs": {"attn_implementation": "eager"}}, tokenizer_conf={"init_param_path": llm_dir}, audio_encoder=audio_encoder_dir, ) @@ -125,6 +98,8 @@ tokenizer = model_llm.kwargs["tokenizer"] model_dict = {"model": model, "frontend": frontend, "tokenizer": tokenizer} +print("model loaded! only support one client at the same time now!!!!") + def load_bytes(input): middle_data = np.frombuffer(input, dtype=np.int16) middle_data = np.asarray(middle_data) @@ -140,7 +115,7 @@ def load_bytes(input): array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32) return array -async def streaming_transcribe(websocket, audio_in, his_state=None, prompt=None): +async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=None, s2tt_prompt=None): if his_state is None: his_state = model_dict model = his_state["model"] @@ -148,11 +123,19 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, prompt=None) if websocket.streaming_state is None: previous_asr_text = "" + previous_s2tt_text = "" + previous_vad_onscreen_asr_text = "" + previous_vad_onscreen_s2tt_text = "" else: - previous_asr_text = websocket.streaming_state["previous_asr_text"] + previous_asr_text = websocket.streaming_state.get("previous_asr_text", "") + previous_s2tt_text = websocket.streaming_state.get("previous_s2tt_text", "") + previous_vad_onscreen_asr_text = websocket.streaming_state.get("previous_vad_onscreen_asr_text", "") + previous_vad_onscreen_s2tt_text = websocket.streaming_state.get("previous_vad_onscreen_s2tt_text", "") - if prompt is None: - prompt = "Copy:" + if asr_prompt is None or asr_prompt == "": + asr_prompt = "Copy:" + if s2tt_prompt is None or s2tt_prompt == "": + s2tt_prompt = "Translate the following sentence into English:" audio_seconds = load_bytes(audio_in).shape[0] / 16000 print(f"Streaming audio length: {audio_seconds} seconds") @@ -160,74 +143,128 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, prompt=None) asr_content = [] system_prompt = "You are a helpful assistant." asr_content.append({"role": "system", "content": system_prompt}) + s2tt_content = [] + system_prompt = "You are a helpful assistant." + s2tt_content.append({"role": "system", "content": system_prompt}) - asr_user_prompt = f"{prompt}<|startofspeech|>!!<|endofspeech|><|im_end|>\n<|im_start|>assistant\n{previous_asr_text}" + user_asr_prompt = f"{asr_prompt}<|startofspeech|>!!<|endofspeech|><|im_end|>\n<|im_start|>assistant\n{previous_asr_text}" + user_s2tt_prompt = f"{s2tt_prompt}<|startofspeech|>!!<|endofspeech|><|im_end|>\n<|im_start|>assistant\n{previous_s2tt_text}" - asr_content.append({"role": "user", "content": asr_user_prompt, "audio": audio_in}) + asr_content.append({"role": "user", "content": user_asr_prompt, "audio": audio_in}) asr_content.append({"role": "assistant", "content": "target_out"}) + s2tt_content.append({"role": "user", "content": user_s2tt_prompt, "audio": audio_in}) + s2tt_content.append({"role": "assistant", "content": "target_out"}) - streaming_asr_time_beg = time.time() - asr_inputs_embeds, contents, batch, source_ids, meta_data = model.inference_prepare( + streaming_time_beg = time.time() + inputs_asr_embeds, contents, batch, source_ids, meta_data = model.inference_prepare( [asr_content], None, "test_demo", tokenizer, frontend, device=device, infer_with_assistant_input=True ) - asr_model_inputs = {} - asr_model_inputs["inputs_embeds"] = asr_inputs_embeds + model_asr_inputs = {} + model_asr_inputs["inputs_embeds"] = inputs_asr_embeds + inputs_s2tt_embeds, contents, batch, source_ids, meta_data = model.inference_prepare( + [s2tt_content], None, "test_demo", tokenizer, frontend, device=device, infer_with_assistant_input=True + ) + model_s2tt_inputs = {} + model_s2tt_inputs["inputs_embeds"] = inputs_s2tt_embeds print("previous_asr_text:", previous_asr_text) + print("previous_s2tt_text:", previous_s2tt_text) - streamer = TextIteratorStreamer(tokenizer) - generation_kwargs = dict(asr_model_inputs, streamer=streamer, max_new_tokens=1024) - thread = Thread(target=model.llm.generate, kwargs=generation_kwargs) - thread.start() + asr_streamer = TextIteratorStreamer(tokenizer) + asr_generation_kwargs = dict(model_asr_inputs, streamer=asr_streamer, max_new_tokens=1024) + asr_generation_kwargs.update(llm_kwargs) + asr_thread = Thread(target=model.llm.generate, kwargs=asr_generation_kwargs) + asr_thread.start() + s2tt_streamer = TextIteratorStreamer(tokenizer) + s2tt_generation_kwargs = dict(model_s2tt_inputs, streamer=s2tt_streamer, max_new_tokens=1024) + s2tt_generation_kwargs.update(llm_kwargs) + s2tt_thread = Thread(target=model.llm.generate, kwargs=s2tt_generation_kwargs) + s2tt_thread.start() onscreen_asr_res = previous_asr_text - beg_llm = time.time() - for new_text in streamer: - end_llm = time.time() - print( - f"generated new text: {new_text}, time_llm_decode: {end_llm - beg_llm:.2f}" - ) - if len(new_text) > 0: - onscreen_asr_res += new_text.replace("<|im_end|>", "") + onscreen_s2tt_res = previous_s2tt_text - mode = "online" + for new_asr_text in asr_streamer: + print(f"generated new asr text: {new_asr_text}") + if len(new_asr_text) > 0: + onscreen_asr_res += new_asr_text.replace("<|im_end|>", "") + + new_s2tt_text = next(s2tt_streamer) + print(f"generated new s2tt text: {new_s2tt_text}") + if len(new_s2tt_text) > 0: + onscreen_s2tt_res += new_s2tt_text.replace("<|im_end|>", "") + + if len(new_asr_text) > 0 or len(new_s2tt_text) > 0: message = json.dumps( { - "mode": mode, - "text": onscreen_asr_res, + "mode": "online", + "asr_text": previous_vad_onscreen_asr_text + onscreen_asr_res, + "s2tt_text": previous_vad_onscreen_s2tt_text + onscreen_s2tt_res, "wav_name": websocket.wav_name, "is_final": websocket.is_speaking, } ) await websocket.send(message) + websocket.streaming_state["onscreen_asr_res"] = previous_vad_onscreen_asr_text + onscreen_asr_res + websocket.streaming_state["onscreen_s2tt_res"] = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res - streaming_asr_time_end = time.time() - print(f"Streaming ASR inference time: {streaming_asr_time_end - streaming_asr_time_beg}") + + + for new_s2tt_text in s2tt_streamer: + print(f"generated new s2tt text: {new_s2tt_text}") + if len(new_s2tt_text) > 0: + onscreen_s2tt_res += new_s2tt_text.replace("<|im_end|>", "") + + if len(new_s2tt_text) > 0: + message = json.dumps( + { + "mode": "online", + "asr_text": previous_vad_onscreen_asr_text + onscreen_asr_res, + "s2tt_text": previous_vad_onscreen_s2tt_text + onscreen_s2tt_res, + "wav_name": websocket.wav_name, + "is_final": websocket.is_speaking, + } + ) + await websocket.send(message) + websocket.streaming_state["onscreen_asr_res"] = previous_vad_onscreen_asr_text + onscreen_asr_res + websocket.streaming_state["onscreen_s2tt_res"] = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res + + streaming_time_end = time.time() + print(f"Streaming inference time: {streaming_time_end - streaming_time_beg}") asr_text_len = len(tokenizer.encode(onscreen_asr_res)) + s2tt_text_len = len(tokenizer.encode(onscreen_s2tt_res)) if asr_text_len > unfix_len and audio_seconds > 1.1: - if asr_text_len <= max_streaming_res_onetime: - previous_asr_text = tokenizer.decode(tokenizer.encode(onscreen_asr_res)[:-unfix_len]) - else: - onscreen_asr_res = previous_asr_text + previous_asr_text = tokenizer.decode(tokenizer.encode(onscreen_asr_res)[:-unfix_len]) else: previous_asr_text = "" + if s2tt_text_len > unfix_len and audio_seconds > 1.1: + previous_s2tt_text = tokenizer.decode(tokenizer.encode(onscreen_s2tt_res)[:-unfix_len]) + else: + previous_s2tt_text = "" - websocket.streaming_state = {} websocket.streaming_state["previous_asr_text"] = previous_asr_text + websocket.streaming_state["onscreen_asr_res"] = previous_vad_onscreen_asr_text + onscreen_asr_res + websocket.streaming_state["previous_s2tt_text"] = previous_s2tt_text + websocket.streaming_state["onscreen_s2tt_res"] = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res + print("fix asr part:", previous_asr_text) - - -print("model loaded! only support one client at the same time now!!!!") + print("fix s2tt part:", previous_s2tt_text) async def ws_reset(websocket): print("ws reset now, total num is ", len(websocket_users)) - websocket.status_dict_asr_online["cache"] = {} - websocket.status_dict_asr_online["is_final"] = True - websocket.streaming_state = None + websocket.streaming_state = {} + websocket.streaming_state["is_final"] = True + websocket.streaming_state["previous_asr_text"] = "" + websocket.streaming_state["previous_s2tt_text"] = "" + websocket.streaming_state["onscreen_asr_res"] = "" + websocket.streaming_state["onscreen_s2tt_res"] = "" + websocket.streaming_state["previous_vad_onscreen_asr_text"] = "" + websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = "" + websocket.status_dict_vad["cache"] = {} websocket.status_dict_vad["is_final"] = True @@ -246,8 +283,15 @@ async def ws_serve(websocket, path): global websocket_users # await clear_websocket() websocket_users.add(websocket) - websocket.status_dict_asr = {} - websocket.status_dict_asr_online = {"cache": {}, "is_final": False} + websocket.streaming_state = { + "previous_asr_text": "", + "previous_s2tt_text": "", + "onscreen_asr_res": "", + "onscreen_s2tt_res": "", + "previous_vad_onscreen_asr_text": "", + "previous_vad_onscreen_s2tt_text": "", + "is_final": False, + } websocket.status_dict_vad = {"cache": {}, "is_final": False} websocket.chunk_interval = 10 @@ -255,8 +299,6 @@ async def ws_serve(websocket, path): speech_start = False speech_end_i = -1 websocket.wav_name = "microphone" - websocket.mode = "online" - websocket.streaming_state = None print("new user connected", flush=True) try: @@ -266,7 +308,9 @@ async def ws_serve(websocket, path): if "is_speaking" in messagejson: websocket.is_speaking = messagejson["is_speaking"] - websocket.status_dict_asr_online["is_final"] = not websocket.is_speaking + websocket.streaming_state["is_final"] = not websocket.is_speaking + if not messagejson["is_speaking"]: + await clear_websocket() if "chunk_interval" in messagejson: websocket.chunk_interval = messagejson["chunk_interval"] if "wav_name" in messagejson: @@ -275,22 +319,18 @@ async def ws_serve(websocket, path): chunk_size = messagejson["chunk_size"] if isinstance(chunk_size, str): chunk_size = chunk_size.split(",") - websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size] - if "encoder_chunk_look_back" in messagejson: - websocket.status_dict_asr_online["encoder_chunk_look_back"] = messagejson[ - "encoder_chunk_look_back" - ] - if "decoder_chunk_look_back" in messagejson: - websocket.status_dict_asr_online["decoder_chunk_look_back"] = messagejson[ - "decoder_chunk_look_back" - ] - if "hotword" in messagejson: - websocket.status_dict_asr["hotword"] = messagejson["hotwords"] - if "mode" in messagejson: - websocket.mode = messagejson["mode"] + chunk_size = [int(x) for x in chunk_size] + if "asr_prompt" in messagejson: + asr_prompt = messagejson["asr_prompt"] + else: + asr_prompt = "Copy:" + if "s2tt_prompt" in messagejson: + s2tt_prompt = messagejson["s2tt_prompt"] + else: + s2tt_prompt = "Translate the following sentence into English:" websocket.status_dict_vad["chunk_size"] = int( - websocket.status_dict_asr_online["chunk_size"][1] * 60 / websocket.chunk_interval + chunk_size[1] * 60 / websocket.chunk_interval ) if len(frames_asr) > 0 or not isinstance(message, str): if not isinstance(message, str): @@ -299,20 +339,21 @@ async def ws_serve(websocket, path): websocket.vad_pre_idx += duration_ms # asr online - websocket.status_dict_asr_online["is_final"] = speech_end_i != -1 + websocket.streaming_state["is_final"] = speech_end_i != -1 if ( (len(frames_asr) % websocket.chunk_interval == 0 - or websocket.status_dict_asr_online["is_final"]) + or websocket.streaming_state["is_final"]) and len(frames_asr) != 0 ): - if websocket.mode == "2pass" or websocket.mode == "online": - audio_in = b"".join(frames_asr) - try: - await streaming_transcribe(websocket, audio_in) - except: - print(f"error in asr streaming, {websocket.status_dict_asr_online}") + audio_in = b"".join(frames_asr) + try: + await streaming_transcribe(websocket, audio_in, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt) + except Exception as e: + print(f"error in streaming, {e}") + print(f"error in streaming, {websocket.streaming_state}") if speech_start: frames_asr.append(message) + # vad online try: speech_start_i, speech_end_i = await async_vad(websocket, message) @@ -324,17 +365,22 @@ async def ws_serve(websocket, path): frames_pre = frames[-beg_bias:] frames_asr = [] frames_asr.extend(frames_pre) - # asr punc offline + + # vad end if speech_end_i != -1 or not websocket.is_speaking: + audio_in = b"".join(frames_asr) frames_asr = [] speech_start = False - websocket.status_dict_asr_online["cache"] = {} - websocket.streaming_state = None + websocket.streaming_state["previous_asr_text"] = "" + websocket.streaming_state["previous_s2tt_text"] = "" + websocket.streaming_state["previous_vad_onscreen_asr_text"] = websocket.streaming_state.get("onscreen_asr_res", "") + "\n\n" + websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = websocket.streaming_state.get("onscreen_s2tt_res", "") + "\n\n" if not websocket.is_speaking: websocket.vad_pre_idx = 0 frames = [] websocket.status_dict_vad["cache"] = {} - websocket.streaming_state = None + websocket.streaming_state["previous_asr_text"] = "" + websocket.streaming_state["previous_s2tt_text"] = "" else: frames = frames[-20:] else: