mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
多端连接
This commit is contained in:
parent
a58e077c35
commit
45698b0b21
@ -26,30 +26,13 @@ parser.add_argument("--chunk_size",
|
||||
args = parser.parse_args()
|
||||
|
||||
voices = Queue()
|
||||
async def ws_client():
    """Open the long-lived websocket connection and publish it as the global ``ws``.

    The server address is built from ``args.host`` and ``args.port``.
    """
    # NOTE(review): SOURCE contains a second ``ws_client`` definition further
    # down; if both live in one module the later one replaces this one.
    global ws  # connection object shared with send()
    endpoint = "ws://{}:{}".format(args.host, args.port)
    connection = await websockets.connect(endpoint, subprotocols=["binary"])
    connection.max_size = 1024 * 1024 * 20
    ws = connection
    print("connected ws server")
|
||||
|
||||
|
||||
|
||||
async def send(data):
    """Send ``data`` over the global ``ws`` connection.

    Any failure is printed rather than raised, so callers never see an
    exception from this helper.
    """
    try:
        await ws.send(data)  # ws is the module-level connection object
    except Exception as err:
        print('Exception occurred:', err)
|
||||
|
||||
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(ws_client()) # 启动协程
|
||||
|
||||
|
||||
# 其他函数可以通过调用send(data)来发送数据,例如:
|
||||
async def test():
|
||||
async def record():
|
||||
#print("2")
|
||||
global voices
|
||||
global voices
|
||||
FORMAT = pyaudio.paInt16
|
||||
CHANNELS = 1
|
||||
RATE = 16000
|
||||
@ -69,27 +52,49 @@ async def test():
|
||||
|
||||
voices.put(data)
|
||||
#print(voices.qsize())
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
async def ws_send():
    """Forward queued audio chunks from ``voices`` to the server, forever.

    Reads the module-level ``voices`` queue and the module-level
    ``websocket`` connection. Each chunk is sent exactly once; a failed send
    is printed and the chunk is dropped.
    """
    print("started to sending data!")
    while True:
        while not voices.empty():
            data = voices.get()
            voices.task_done()
            try:
                # Single send path: SOURCE also pushed the same chunk through
                # the send() helper (global ``ws``), transmitting it twice.
                await websocket.send(data)
            except Exception as e:
                print('Exception occurred:', e)
            # Yield between chunks so the other tasks can run.
            await asyncio.sleep(0.01)
        # Queue is empty: idle briefly instead of spinning.
        await asyncio.sleep(0.01)
|
||||
|
||||
async def main():
    """Run ``test()`` and ``ws_send()`` concurrently as background tasks."""
    workers = [
        asyncio.create_task(test()),
        asyncio.create_task(ws_send()),
    ]
    await asyncio.gather(*workers)
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
async def message():
    """Print every message received on the module-level ``websocket``.

    Returns once the connection is closed. Any other error is printed and
    the loop keeps receiving.
    """
    while True:
        try:
            print(await websocket.recv())
        except websockets.ConnectionClosed as e:
            # A closed connection raises on every recv(); looping on would
            # print the same error endlessly, so stop here.
            print("Exception:", e)
            return
        except Exception as e:
            print("Exception:", e)
            # Yield so a persistent error cannot starve the event loop.
            await asyncio.sleep(0.01)
|
||||
|
||||
|
||||
|
||||
async def ws_client():
    """Connect to the server and run the record / send / receive tasks.

    The connection is stored in the module-level ``websocket`` so that
    ``ws_send()`` and ``message()`` can use it. Iterating over
    ``websockets.connect(...)`` provides a new connection each time the loop
    body is re-entered.
    """
    global websocket  # connection object shared with the worker tasks
    uri = "ws://{}:{}".format(args.host, args.port)
    async for websocket in websockets.connect(uri, subprotocols=["binary"], ping_interval=None):
        tasks = [
            asyncio.create_task(record()),   # background recording task
            asyncio.create_task(ws_send()),  # background sending task
            asyncio.create_task(message()),  # background receiving task
        ]
        try:
            await asyncio.gather(*tasks)
        except websockets.ConnectionClosed:
            # Connection dropped: move on to the next connection attempt.
            continue
        finally:
            # Do not leave workers from this connection running alongside
            # the ones created for the next connection.
            for task in tasks:
                task.cancel()
|
||||
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(ws_client()) # 启动协程
|
||||
asyncio.get_event_loop().run_forever()
|
||||
|
||||
@ -9,11 +9,15 @@ from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
import logging
|
||||
import tracemalloc
|
||||
tracemalloc.start()
|
||||
|
||||
logger = get_logger(log_level=logging.CRITICAL)
|
||||
logger.setLevel(logging.CRITICAL)
|
||||
|
||||
|
||||
websocket_users = set() #维护客户端列表
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host",
|
||||
type=str,
|
||||
@ -46,8 +50,7 @@ parser.add_argument("--ngpu",
|
||||
args = parser.parse_args()
|
||||
|
||||
print("model loading")
|
||||
voices = Queue()
|
||||
speek = Queue()
|
||||
|
||||
|
||||
# vad
|
||||
inference_pipeline_vad = pipeline(
|
||||
@ -86,20 +89,76 @@ print("model loaded")
|
||||
|
||||
|
||||
async def ws_serve(websocket, path):
    """Serve one client: cut incoming audio into utterances and relay results.

    Two queues are attached to the connection object: ``speek`` carries
    finished utterances to the ``asr`` worker thread started here, and
    ``send_msg`` carries recognised text back for delivery to the client.

    Args:
        websocket: the client connection.
        path: request path passed in by the websockets server; not used.
    """
    frames = []      # chunks of the utterance currently being recorded
    buffer = []      # the most recent chunks (at most two), kept as lead-in
    RECORD_NUM = 0   # chunks recorded since speech start
    speech_start = False

    websocket.speek = Queue()     # utterances for asr() to consume
    websocket.send_msg = Queue()  # results for this coroutine to send
    websocket_users.add(websocket)
    ss = threading.Thread(target=asr, args=(websocket,))
    ss.start()

    try:
        async for message in websocket:
            # Each frame is handled here only. SOURCE also pushed it onto the
            # shared ``voices`` queue, which the module-level main() reads, so
            # the frame would have gone through VAD a second time there.
            buffer.append(message)
            if len(buffer) > 2:
                buffer.pop(0)  # keep only the two most recent chunks

            if speech_start:
                frames.append(message)
                RECORD_NUM += 1
            speech_start_i, speech_end_i = vad(message)
            if speech_start_i:
                speech_start = speech_start_i
                frames = []
                frames.extend(buffer)  # prepend the buffered lead-in chunks
            if speech_end_i or RECORD_NUM > 300:
                # Utterance finished, or the 300-chunk cap was exceeded:
                # hand the audio to the ASR thread and reset all state.
                speech_start = False
                audio_in = b"".join(frames)
                websocket.speek.put(audio_in)
                frames = []
                buffer = []
                RECORD_NUM = 0
            # NOTE(review): results are flushed only when another audio frame
            # arrives, and only one per frame.
            if not websocket.send_msg.empty():
                await websocket.send(websocket.send_msg.get())
                websocket.send_msg.task_done()
    except websockets.ConnectionClosed:
        print("ConnectionClosed...", websocket_users)  # connection dropped
    except websockets.InvalidState:
        print("InvalidState...")  # invalid connection state
    except Exception as e:
        print("Exception:", e)
    finally:
        # asr() polls ``websocket in websocket_users``; always unregister,
        # however the loop ended, so that thread can terminate.
        websocket_users.discard(websocket)
|
||||
|
||||
|
||||
start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
|
||||
def asr(websocket):  # ASR inference
    """Worker-thread loop: recognise utterances queued for one connection.

    Runs while ``websocket`` is a member of ``websocket_users``. Each
    non-empty utterance taken from ``websocket.speek`` is passed through
    ``inference_pipeline_asr`` and, when a punctuation pipeline is configured
    and the result has a ``'text'`` entry, through ``inference_pipeline_punc``.
    The outcome is placed on ``websocket.send_msg``.

    Args:
        websocket: connection object carrying the ``speek`` and ``send_msg``
            queues.
    """
    while websocket in websocket_users:
        if not websocket.speek.empty():
            audio_in = websocket.speek.get()
            websocket.speek.task_done()
            if len(audio_in) > 0:
                rec_result = inference_pipeline_asr(audio_in=audio_in)
                if inference_pipeline_punc is not None and 'text' in rec_result:
                    rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict=param_dict_punc)
                results = rec_result["text"] if "text" in rec_result else rec_result
                # Queue the result for delivery rather than calling send()
                # from this thread (the original comment notes that a direct
                # send does not work here).
                websocket.send_msg.put(results)
        # SOURCE's indentation is lost; the sleep is placed at loop level,
        # the only position that paces polling while the queue is empty.
        time.sleep(0.1)
|
||||
|
||||
|
||||
def vad(data): # 推理
|
||||
def vad(data): # VAD推理
|
||||
global vad_pipline, param_dict_vad
|
||||
#print(type(data))
|
||||
# print(param_dict_vad)
|
||||
@ -117,79 +176,7 @@ def vad(data): # 推理
|
||||
speech_end = True
|
||||
return speech_start, speech_end
|
||||
|
||||
def asr():  # inference
    """Polling loop: recognise utterances from the shared ``speek`` queue.

    Each non-empty utterance is passed through ``inference_pipeline_asr``
    and, when a punctuation pipeline is configured and the result has a
    ``'text'`` entry, through ``inference_pipeline_punc``. The outcome is
    printed. Never returns.
    """
    # NOTE(review): SOURCE also defines ``asr(websocket)``; if both live in
    # one module this later, zero-argument definition replaces it.
    while True:
        while not speek.empty():
            audio_in = speek.get()
            speek.task_done()
            if len(audio_in) > 0:
                rec_result = inference_pipeline_asr(audio_in=audio_in)
                if inference_pipeline_punc is not None and 'text' in rec_result:
                    rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict=param_dict_punc)
                print(rec_result["text"] if "text" in rec_result else rec_result)
            time.sleep(0.1)
        # Queue drained: idle briefly before polling again.
        time.sleep(0.1)
|
||||
|
||||
|
||||
def main():  # inference
    """Polling loop: cut audio from the shared ``voices`` queue into utterances.

    Every chunk is run through ``vad()``. From a detected speech start until
    a detected speech end (or until more than 300 chunks have been recorded)
    chunks are accumulated, then joined and put on the shared ``speek``
    queue. Never returns.
    """
    frames = []      # chunks of the utterance currently being recorded
    buffer = []      # the most recent chunks (at most two), kept as lead-in
    RECORD_NUM = 0   # chunks recorded since speech start
    speech_start = False
    while True:
        while not voices.empty():
            data = voices.get()
            voices.task_done()
            buffer.append(data)
            if len(buffer) > 2:
                buffer.pop(0)  # keep only the two most recent chunks

            if speech_start:
                frames.append(data)
                RECORD_NUM += 1
            speech_start_i, speech_end_i = vad(data)
            if speech_start_i:
                speech_start = speech_start_i
                frames = []
                frames.extend(buffer)  # prepend the buffered lead-in chunks
            if speech_end_i or RECORD_NUM > 300:
                # Utterance finished, or the 300-chunk cap was exceeded:
                # hand the audio to the ASR consumer and reset all state.
                speech_start = False
                audio_in = b"".join(frames)
                speek.put(audio_in)
                frames = []
                buffer = []
                RECORD_NUM = 0
            time.sleep(0.01)
        # Queue drained: idle briefly before polling again.
        time.sleep(0.01)
|
||||
|
||||
|
||||
|
||||
s = threading.Thread(target=main)
|
||||
s.start()
|
||||
s = threading.Thread(target=asr)
|
||||
s.start()
|
||||
|
||||
|
||||
start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
|
||||
asyncio.get_event_loop().run_until_complete(start_server)
|
||||
asyncio.get_event_loop().run_forever()
|
||||
Loading…
Reference in New Issue
Block a user