diff --git a/funasr/runtime/python/websocket/ASR_client.py b/funasr/runtime/python/websocket/ASR_client.py
index 8010b18bf..fe6798127 100644
--- a/funasr/runtime/python/websocket/ASR_client.py
+++ b/funasr/runtime/python/websocket/ASR_client.py
@@ -26,30 +26,13 @@ parser.add_argument("--chunk_size",
 args = parser.parse_args()
 
 voices = Queue()
-async def ws_client():
-    global ws # global variable that holds the websocket connection object
-    # uri = "ws://11.167.134.197:8899"
-    uri = "ws://{}:{}".format(args.host, args.port)
-    ws = await websockets.connect(uri, subprotocols=["binary"]) # open a long-lived connection
-    ws.max_size = 1024 * 1024 * 20
-    print("connected ws server")
+
+
-async def send(data):
-    global ws # reference the global ws
-    try:
-        await ws.send(data) # send data through the ws object
-    except Exception as e:
-        print('Exception occurred:', e)
-
-
-
-asyncio.get_event_loop().run_until_complete(ws_client()) # start the coroutine
-
-
 # other functions can send data by calling send(data), for example:
-async def test():
+async def record():
     #print("2")
-    global voices
+    global voices
     FORMAT = pyaudio.paInt16
     CHANNELS = 1
     RATE = 16000
@@ -69,27 +52,49 @@ async def test():
         voices.put(data)
         #print(voices.qsize())
+        await asyncio.sleep(0.01)
-
-
 async def ws_send():
     global voices
+    global websocket
    print("started to sending data!")
     while True:
         while not voices.empty():
             data = voices.get()
             voices.task_done()
-            await send(data)
+            try:
+                await websocket.send(data) # send data through the websocket object
+            except Exception as e:
+                print('Exception occurred:', e)
             await asyncio.sleep(0.01)
         await asyncio.sleep(0.01)
-async def main():
-    task = asyncio.create_task(test()) # create a background task
-    task2 = asyncio.create_task(ws_send()) # create a background task
-
-    await asyncio.gather(task, task2)
-asyncio.run(main())
+
+async def message():
+    global websocket
+    while True:
+        try:
+            print(await websocket.recv())
+        except Exception as e:
+            print("Exception:", e)
+
+
+
+async def ws_client():
+    global websocket # global variable that holds the websocket connection object
+    # uri = "ws://11.167.134.197:8899"
+    uri = "ws://{}:{}".format(args.host, args.port)
+    #ws = await websockets.connect(uri, subprotocols=["binary"]) # open a long-lived connection
+    async for websocket in websockets.connect(uri, subprotocols=["binary"], ping_interval=None):
+        task = asyncio.create_task(record()) # background task: record audio
+        task2 = asyncio.create_task(ws_send()) # background task: send audio
+        task3 = asyncio.create_task(message()) # background task: receive results
+        await asyncio.gather(task, task2, task3)
+
+
+asyncio.get_event_loop().run_until_complete(ws_client()) # start the coroutine
+asyncio.get_event_loop().run_forever()
diff --git a/funasr/runtime/python/websocket/ASR_server.py b/funasr/runtime/python/websocket/ASR_server.py
index 1fd02b5aa..79c3a7a96 100644
--- a/funasr/runtime/python/websocket/ASR_server.py
+++ b/funasr/runtime/python/websocket/ASR_server.py
@@ -9,11 +9,15 @@ from modelscope.pipelines import pipeline
 from modelscope.utils.constant import Tasks
 from modelscope.utils.logger import get_logger
 import logging
+import tracemalloc
+tracemalloc.start()
 
 logger = get_logger(log_level=logging.CRITICAL)
 logger.setLevel(logging.CRITICAL)
 
+websocket_users = set() # keep track of connected clients
+
 parser = argparse.ArgumentParser()
 parser.add_argument("--host",
                     type=str,
@@ -46,8 +50,7 @@ parser.add_argument("--ngpu",
 args = parser.parse_args()
 
 print("model loading")
-voices = Queue()
-speek = Queue()
+
 
 # vad
 inference_pipeline_vad = pipeline(
@@ -86,20 +89,76 @@ print("model loaded")
 
 
 async def ws_serve(websocket, path):
-    global voices
+    #speek = Queue()
+    frames = [] # store all frame data
+    buffer = [] # buffered frames (at most two chunks)
+    RECORD_NUM = 0
+    global websocket_users
+    speech_start, speech_end = False, False
+    # start the asr worker for this connection
+    websocket.speek = Queue() # queue attached to the websocket so asr can read the audio packets
+    websocket.send_msg = Queue() # queue attached to the websocket so results can be sent back to the client
+    websocket_users.add(websocket)
+    ss = threading.Thread(target=asr, args=(websocket,))
+    ss.start()
+
     try:
         async for message in websocket:
-            voices.put(message)
+            #voices.put(message)
             #print("put")
-    except websockets.exceptions.ConnectionClosedError as e:
-        print('Connection closed with exception:', e)
+            #await websocket.send("123")
+            buffer.append(message)
+            if len(buffer) > 2:
+                buffer.pop(0) # if the buffer holds more than two chunks, drop the oldest one
+
+            if speech_start:
+                frames.append(message)
+                RECORD_NUM += 1
+            speech_start_i, speech_end_i = vad(message)
+            #print(speech_start_i, speech_end_i)
+            if speech_start_i:
+                speech_start = speech_start_i
+                frames = []
+                frames.extend(buffer) # start the segment with the two buffered audio chunks
+            if speech_end_i or RECORD_NUM > 300:
+                speech_start = False
+                audio_in = b"".join(frames)
+                websocket.speek.put(audio_in)
+                frames = [] # clear all frame data
+                buffer = [] # clear the buffered frames (at most two chunks)
+                RECORD_NUM = 0
+            if not websocket.send_msg.empty():
+                await websocket.send(websocket.send_msg.get())
+                websocket.send_msg.task_done()
+
+
+    except websockets.ConnectionClosed:
+        print("ConnectionClosed...", websocket_users) # connection closed
+        websocket_users.remove(websocket)
+    except websockets.InvalidState:
+        print("InvalidState...") # invalid state
     except Exception as e:
-        print('Exception occurred:', e)
+        print("Exception:", e)
+
 
-start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
+def asr(websocket): # ASR inference
+    global inference_pipeline2
+    global param_dict_punc
+    global websocket_users
+    while websocket in websocket_users:
+        if not websocket.speek.empty():
+            audio_in = websocket.speek.get()
+            websocket.speek.task_done()
+            if len(audio_in) > 0:
+                rec_result = inference_pipeline_asr(audio_in=audio_in)
+                if inference_pipeline_punc is not None and 'text' in rec_result:
+                    rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict=param_dict_punc)
+                results = (rec_result["text"] if "text" in rec_result else rec_result)
+                websocket.send_msg.put(results) # queue the result; it cannot be sent directly from this thread
+
+        time.sleep(0.1)
 
-def vad(data): # inference
+def vad(data): # VAD inference
     global vad_pipline, param_dict_vad
     #print(type(data))
     # print(param_dict_vad)
@@ -117,79 +176,7 @@ def vad(data): # inference
         speech_end = True
     return speech_start, speech_end
 
-def asr(): # inference
-    global inference_pipeline2
-    global speek, param_dict_punc
-    while True:
-        while not speek.empty():
-            audio_in = speek.get()
-            speek.task_done()
-            if len(audio_in) > 0:
-                rec_result = inference_pipeline_asr(audio_in=audio_in)
-                if inference_pipeline_punc is not None and 'text' in rec_result:
-                    rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict=param_dict_punc)
-                print(rec_result["text"] if "text" in rec_result else rec_result)
-        time.sleep(0.1)
-        time.sleep(0.1)
-
-
-def main(): # inference
-    frames = [] # store all frame data
-    buffer = [] # buffered frames (at most two chunks)
-    # silence_count = 0 # count consecutive silent chunks
-    # speech_detected = False # flag whether speech has been detected
-    RECORD_NUM = 0
-    global voices
-    global speek
-    speech_start, speech_end = False, False
-    while True:
-        while not voices.empty():
-
-            data = voices.get()
-            #print("queue length", voices.qsize())
-            voices.task_done()
-            buffer.append(data)
-            if len(buffer) > 2:
-                buffer.pop(0) # if the buffer holds more than two chunks, drop the oldest one
-
-            if speech_start:
-                frames.append(data)
-                RECORD_NUM += 1
-            speech_start_i, speech_end_i = vad(data)
-            # print(speech_start_i, speech_end_i)
-            if speech_start_i:
-                speech_start = speech_start_i
-                # if not speech_detected:
-                #     print("speech detected...")
-                #     speech_detected = True # mark speech as detected
-                frames = []
-                frames.extend(buffer) # start the segment with the two buffered audio chunks
-                # silence_count = 0 # reset the silence counter
-            if speech_end_i or RECORD_NUM > 300:
-                # silence_count += 1 # increment the silence counter
-                # speech_end = speech_end_i
-                speech_start = False
-                # if RECORD_NUM > 300: # the 50 here can be changed to a suitable number of chunks as needed
-                #     print("speech ended or maximum duration exceeded...")
-                audio_in = b"".join(frames)
-                #asrt = threading.Thread(target=asr, args=(audio_in,))
-                #asrt.start()
-                speek.put(audio_in)
-                #rec_result = inference_pipeline2(audio_in=audio_in) # run it through the ASR model
-                frames = [] # clear all frame data
-                buffer = [] # clear the buffered frames (at most two chunks)
-                # silence_count = 0 # reset the consecutive-silence counter
-                # speech_detected = False # flag whether speech has been detected
-                RECORD_NUM = 0
-        time.sleep(0.01)
-    time.sleep(0.01)
-
-
-
-s = threading.Thread(target=main)
-s.start()
-s = threading.Thread(target=asr)
-s.start()
-
+
+start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
 asyncio.get_event_loop().run_until_complete(start_server)
 asyncio.get_event_loop().run_forever()
\ No newline at end of file