多端连接

This commit is contained in:
cgisky1980 2023-03-24 15:11:56 +08:00
parent a58e077c35
commit 45698b0b21
2 changed files with 106 additions and 114 deletions

View File

@ -26,30 +26,13 @@ parser.add_argument("--chunk_size",
# Parse the CLI options declared above (host / port / chunk_size).
args = parser.parse_args()
# Shared FIFO of captured audio chunks, filled by record() and drained by ws_send().
voices = Queue()
async def ws_client():
    """Open a long-lived websocket connection and store it in the global `ws`."""
    global ws  # global variable holding the websocket connection object
    # uri = "ws://11.167.134.197:8899"
    uri = "ws://{}:{}".format(args.host, args.port)
    ws = await websockets.connect(uri, subprotocols=["binary"])  # create a long-lived connection
    # NOTE(review): max_size is normally passed to connect(); setting the
    # attribute after the handshake may not take effect — verify.
    ws.max_size = 1024 * 1024 * 20
    print("connected ws server")
async def send(data):
    """Best-effort transmit of *data* over the global websocket `ws`.

    Failures are reported to stdout rather than propagated, so a dropped
    connection does not crash the caller.
    """
    global ws
    try:
        await ws.send(data)
    except Exception as err:
        print('Exception occurred:', err)
asyncio.get_event_loop().run_until_complete(ws_client())  # start the coroutine
# Other code can transmit data by calling send(data), for example:
async def test():
async def record():
#print("2")
global voices
global voices
FORMAT = pyaudio.paInt16
CHANNELS = 1
RATE = 16000
@ -69,27 +52,49 @@ async def test():
voices.put(data)
#print(voices.qsize())
await asyncio.sleep(0.01)
async def ws_send():
    """Drain the shared `voices` queue and push each audio chunk to the server.

    Runs forever as a background task; send errors are reported but do not
    stop the loop.  (Removed the leftover `await send(data)` from the
    pre-refactor version, which transmitted every chunk a second time.)
    """
    global voices
    global websocket
    print("started to sending data!")
    while True:
        while not voices.empty():
            data = voices.get()
            voices.task_done()
            try:
                # Send one audio chunk over the shared websocket connection.
                await websocket.send(data)
            except Exception as e:
                print('Exception occurred:', e)
            await asyncio.sleep(0.01)
        await asyncio.sleep(0.01)
async def main():
    """Legacy client entry point: run capture and sending concurrently."""
    task = asyncio.create_task(test())      # create a background task (audio capture)
    task2 = asyncio.create_task(ws_send())  # create a background task (chunk sending)
    await asyncio.gather(task, task2)
asyncio.run(main())
async def message():
    """Print every message received from the server.

    Returns once `recv()` fails: after the connection closes, `recv()`
    raises on every call, so looping on would spin forever printing the
    same exception.  Returning lets the enclosing gather() finish and the
    reconnect loop in ws_client() take over.
    """
    global websocket
    while True:
        try:
            print(await websocket.recv())
        except Exception as e:
            print("Exception:", e)
            return
async def ws_client():
    """Connect to the ASR server, reconnecting automatically if the link drops.

    Each (re)connection runs three concurrent tasks: microphone capture
    (record), chunk sending (ws_send) and server-reply printing (message).
    """
    global websocket  # current connection, shared with ws_send()/message()
    # uri = "ws://11.167.134.197:8899"
    uri = "ws://{}:{}".format(args.host, args.port)
    # Iterating websockets.connect() yields a fresh connection whenever the
    # previous one fails — the library's documented reconnect pattern.
    async for websocket in websockets.connect(uri, subprotocols=["binary"], ping_interval=None):
        try:
            task = asyncio.create_task(record())    # background task: record audio
            task2 = asyncio.create_task(ws_send())  # background task: send chunks
            task3 = asyncio.create_task(message())  # background task: receive replies
            await asyncio.gather(task, task2, task3)
        except websockets.ConnectionClosed:
            # Without this handler the exception would escape the async-for
            # and kill the client instead of letting it reconnect.
            continue
asyncio.get_event_loop().run_until_complete(ws_client())  # start the coroutine
asyncio.get_event_loop().run_forever()

View File

@ -9,11 +9,15 @@ from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
import logging
import tracemalloc
tracemalloc.start()  # trace memory allocations for debugging
logger = get_logger(log_level=logging.CRITICAL)
logger.setLevel(logging.CRITICAL)  # silence framework logging
websocket_users = set()  # registry of currently connected clients
parser = argparse.ArgumentParser()
parser.add_argument("--host",
type=str,
@ -46,8 +50,7 @@ parser.add_argument("--ngpu",
args = parser.parse_args()
print("model loading")
# Shared queues used by the legacy single-client pipeline (main()/asr() below).
voices = Queue()
speek = Queue()
# vad
inference_pipeline_vad = pipeline(
@ -86,20 +89,76 @@ print("model loaded")
async def ws_serve(websocket, path):
    """Per-client handler: run VAD over incoming audio chunks, hand complete
    speech segments to this client's ASR thread, and relay results back.

    Cleanup of stale pre-refactor lines: the duplicate `voices.put(message)`
    (obsolete global queue) and the redundant ConnectionClosedError handler
    are gone; unregistration now happens in `finally` so the per-client
    asr() thread terminates on ANY exit path, not only ConnectionClosed.
    """
    global websocket_users
    frames = []   # chunks of the speech segment currently being recorded
    buffer = []   # rolling cache of the two most recent chunks (pre-roll)
    RECORD_NUM = 0
    speech_start, speech_end = False, False
    # Per-connection queues: audio segments for the ASR thread, and
    # recognition results for this handler to send back to the client.
    websocket.speek = Queue()
    websocket.send_msg = Queue()
    websocket_users.add(websocket)
    # Dedicated ASR worker thread for this client; it exits once the
    # websocket is removed from websocket_users (see asr()).
    ss = threading.Thread(target=asr, args=(websocket,))
    ss.start()
    try:
        async for message in websocket:
            buffer.append(message)
            if len(buffer) > 2:
                buffer.pop(0)  # keep only the two newest chunks
            if speech_start:
                frames.append(message)
                RECORD_NUM += 1
            speech_start_i, speech_end_i = vad(message)
            if speech_start_i:
                speech_start = speech_start_i
                frames = []
                frames.extend(buffer)  # prepend the two cached pre-speech chunks
            if speech_end_i or RECORD_NUM > 300:  # speech ended or hit the chunk cap
                speech_start = False
                audio_in = b"".join(frames)
                websocket.speek.put(audio_in)  # hand the segment to the ASR thread
                frames = []
                buffer = []
                RECORD_NUM = 0
            if not websocket.send_msg.empty():
                await websocket.send(websocket.send_msg.get())
                websocket.send_msg.task_done()
    except websockets.ConnectionClosed:
        print("ConnectionClosed...", websocket_users)  # client disconnected
    except websockets.InvalidState:
        print("InvalidState...")  # invalid connection state
    except Exception as e:
        print("Exception:", e)
    finally:
        # Always unregister so the asr() worker loop can stop; previously
        # only the ConnectionClosed branch removed the client.
        websocket_users.discard(websocket)
# Serve ws_serve to binary-subprotocol clients; ping_interval=None disables keepalive pings.
start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
def asr(websocket):  # per-client ASR worker (runs in its own thread)
    """Poll this client's `speek` queue, run ASR (plus optional punctuation
    restoration), and queue the text result for the websocket handler.

    Exits when the client is removed from `websocket_users`.
    (Dropped the stale `global inference_pipeline2` — that name is not used;
    the actual pipeline is `inference_pipeline_asr`.)
    """
    global param_dict_punc
    global websocket_users
    while websocket in websocket_users:
        if not websocket.speek.empty():
            audio_in = websocket.speek.get()
            websocket.speek.task_done()
            if len(audio_in) > 0:
                rec_result = inference_pipeline_asr(audio_in=audio_in)
                # Optional second pass: restore punctuation on the transcript.
                if inference_pipeline_punc is not None and 'text' in rec_result:
                    rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict=param_dict_punc)
                results = (rec_result["text"] if "text" in rec_result else rec_result)
                # Can't await websocket.send() from this thread; hand the
                # result to the async handler via the per-client queue.
                websocket.send_msg.put(results)
        time.sleep(0.1)  # poll interval
def vad(data): # 推理
def vad(data): # VAD推理
global vad_pipline, param_dict_vad
#print(type(data))
# print(param_dict_vad)
@ -117,79 +176,7 @@ def vad(data): # 推理
speech_end = True
return speech_start, speech_end
def asr():  # inference — legacy worker draining the shared `speek` queue
    global inference_pipeline2
    global speek, param_dict_punc
    while True:
        while not speek.empty():
            audio_in = speek.get()
            speek.task_done()
            if len(audio_in) > 0:
                rec_result = inference_pipeline_asr(audio_in=audio_in)
                # Optional punctuation-restoration pass on the transcript.
                if inference_pipeline_punc is not None and 'text' in rec_result:
                    rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict=param_dict_punc)
                print(rec_result["text"] if "text" in rec_result else rec_result)
            time.sleep(0.1)
        time.sleep(0.1)
def main():  # inference
    """Legacy single-client VAD loop: segment chunks from the shared `voices`
    queue and hand each complete speech segment to the ASR worker via `speek`.

    (Removed the large amount of commented-out dead code; logic unchanged.)
    """
    frames = []   # chunks of the speech segment currently being recorded
    buffer = []   # rolling cache of the two most recent chunks (pre-roll)
    RECORD_NUM = 0
    global voices
    global speek
    speech_start, speech_end = False, False
    while True:
        while not voices.empty():
            data = voices.get()
            voices.task_done()
            buffer.append(data)
            if len(buffer) > 2:
                buffer.pop(0)  # keep only the two newest chunks
            if speech_start:
                frames.append(data)
                RECORD_NUM += 1
            speech_start_i, speech_end_i = vad(data)
            if speech_start_i:
                speech_start = speech_start_i
                frames = []
                frames.extend(buffer)  # prepend the two cached pre-speech chunks
            if speech_end_i or RECORD_NUM > 300:  # speech ended or hit the chunk cap
                speech_start = False
                audio_in = b"".join(frames)
                speek.put(audio_in)  # hand the segment to the asr() thread
                frames = []
                buffer = []
                RECORD_NUM = 0
            time.sleep(0.01)
        time.sleep(0.01)
# Start the legacy background workers, then serve websocket clients forever.
# (Previously the same variable `s` was rebound for the second thread,
# discarding the first thread's handle.)
vad_thread = threading.Thread(target=main)
vad_thread.start()
asr_thread = threading.Thread(target=asr)
asr_thread.start()
start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever()