websocket

This commit is contained in:
游雁 2023-03-24 18:35:14 +08:00
parent b831792f03
commit d14855ef20

View File

@ -62,7 +62,7 @@ inference_pipeline_vad = pipeline(
mode='online',
ngpu=args.ngpu,
)
param_dict_vad = {'in_cache': dict(), "is_final": False}
# param_dict_vad = {'in_cache': dict(), "is_final": False}
# asr
param_dict_asr = {}
@ -74,7 +74,7 @@ inference_pipeline_asr = pipeline(
ngpu=args.ngpu,
)
if args.punc_model != "":
param_dict_punc = {'cache': list()}
# param_dict_punc = {'cache': list()}
inference_pipeline_punc = pipeline(
task=Tasks.punctuation,
model=args.punc_model,
@ -96,6 +96,8 @@ async def ws_serve(websocket, path):
global websocket_users
speech_start, speech_end = False, False
# 调用asr函数
websocket.param_dict_vad = {'in_cache': dict(), "is_final": False}
websocket.param_dict_punc = {'cache': list()}
websocket.speek = Queue() #websocket 添加进队列对象 让asr读取语音数据包
websocket.send_msg = Queue() #websocket 添加个队列对象 让ws发送消息到客户端
websocket_users.add(websocket)
@ -114,7 +116,7 @@ async def ws_serve(websocket, path):
if speech_start:
frames.append(message)
RECORD_NUM += 1
speech_start_i, speech_end_i = vad(message)
speech_start_i, speech_end_i = vad(message, websocket)
#print(speech_start_i, speech_end_i)
if speech_start_i:
speech_start = speech_start_i
@ -143,7 +145,7 @@ async def ws_serve(websocket, path):
def asr(websocket): # ASR推理
global inference_pipeline2
global param_dict_punc
# global param_dict_punc
global websocket_users
while websocket in websocket_users:
if not websocket.speek.empty():
@ -152,17 +154,18 @@ def asr(websocket): # ASR推理
if len(audio_in) > 0:
rec_result = inference_pipeline_asr(audio_in=audio_in)
if inference_pipeline_punc is not None and 'text' in rec_result:
rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict=param_dict_punc)
rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict=websocket.param_dict_punc)
# print(rec_result)
if "text" in rec_result:
websocket.send_msg.put(rec_result["text"]) # 存入发送队列 直接调用send发送不了
time.sleep(0.1)
def vad(data): # VAD推理
def vad(data, websocket): # VAD推理
global vad_pipline, param_dict_vad
#print(type(data))
# print(param_dict_vad)
segments_result = inference_pipeline_vad(audio_in=data, param_dict=param_dict_vad)
segments_result = inference_pipeline_vad(audio_in=data, param_dict=websocket.param_dict_vad)
# print(segments_result)
# print(param_dict_vad)
speech_start = False