mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge branch 'main' of github.com:alibaba-damo-academy/FunASR
This commit is contained in:
commit
72d561531f
@ -30,14 +30,7 @@ from funasr.models.frontend.wav_frontend import WavFrontendOnline
|
||||
from funasr.models.frontend.wav_frontend import WavFrontend
|
||||
from funasr.bin.vad_inference import Speech2VadSegment
|
||||
|
||||
header_colors = '\033[95m'
|
||||
end_colors = '\033[0m'
|
||||
|
||||
global_asr_language: str = 'zh-cn'
|
||||
global_sample_rate: Union[int, Dict[Any, int]] = {
|
||||
'audio_fs': 16000,
|
||||
'model_fs': 16000
|
||||
}
|
||||
|
||||
|
||||
class Speech2VadSegmentOnline(Speech2VadSegment):
|
||||
|
||||
@ -5,13 +5,35 @@ import websockets
|
||||
import asyncio
|
||||
from queue import Queue
|
||||
# import threading
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host",
|
||||
type=str,
|
||||
default="localhost",
|
||||
required=False,
|
||||
help="host ip, localhost, 0.0.0.0")
|
||||
parser.add_argument("--port",
|
||||
type=int,
|
||||
default=10095,
|
||||
required=False,
|
||||
help="grpc server port")
|
||||
parser.add_argument("--chunk_size",
|
||||
type=int,
|
||||
default=300,
|
||||
help="ms")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
voices = Queue()
|
||||
async def hello():
|
||||
async def ws_client():
|
||||
global ws # 定义一个全局变量ws,用于保存websocket连接对象
|
||||
uri = "ws://localhost:8899"
|
||||
# uri = "ws://11.167.134.197:8899"
|
||||
uri = "ws://{}:{}".format(args.host, args.port)
|
||||
ws = await websockets.connect(uri, subprotocols=["binary"]) # 创建一个长连接
|
||||
ws.max_size = 1024 * 1024 * 20
|
||||
print("connected ws server")
|
||||
|
||||
async def send(data):
|
||||
global ws # 引用全局变量ws
|
||||
try:
|
||||
@ -21,7 +43,7 @@ async def send(data):
|
||||
|
||||
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(hello()) # 启动协程
|
||||
asyncio.get_event_loop().run_until_complete(ws_client()) # 启动协程
|
||||
|
||||
|
||||
# 其他函数可以通过调用send(data)来发送数据,例如:
|
||||
@ -31,7 +53,7 @@ async def test():
|
||||
FORMAT = pyaudio.paInt16
|
||||
CHANNELS = 1
|
||||
RATE = 16000
|
||||
CHUNK = int(RATE / 1000 * 300)
|
||||
CHUNK = int(RATE / 1000 * args.chunk_size)
|
||||
|
||||
p = pyaudio.PyAudio()
|
||||
|
||||
@ -70,4 +92,4 @@ async def main():
|
||||
|
||||
await asyncio.gather(task, task2)
|
||||
|
||||
asyncio.run(main())
|
||||
asyncio.run(main())
|
||||
|
||||
@ -6,37 +6,73 @@ import logging
|
||||
|
||||
logger = get_logger(log_level=logging.CRITICAL)
|
||||
logger.setLevel(logging.CRITICAL)
|
||||
|
||||
import asyncio
|
||||
import websockets #区别客户端这里是 websockets库
|
||||
import websockets
|
||||
import time
|
||||
from queue import Queue
|
||||
import threading
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host",
|
||||
type=str,
|
||||
default="0.0.0.0",
|
||||
required=False,
|
||||
help="host ip, localhost, 0.0.0.0")
|
||||
parser.add_argument("--port",
|
||||
type=int,
|
||||
default=10095,
|
||||
required=False,
|
||||
help="grpc server port")
|
||||
parser.add_argument("--asr_model",
|
||||
type=str,
|
||||
default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
||||
help="model from modelscope")
|
||||
parser.add_argument("--vad_model",
|
||||
type=str,
|
||||
default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||
help="model from modelscope")
|
||||
|
||||
parser.add_argument("--punc_model",
|
||||
type=str,
|
||||
default="",
|
||||
help="model from modelscope")
|
||||
parser.add_argument("--ngpu",
|
||||
type=int,
|
||||
default=1,
|
||||
help="0 for cpu, 1 for gpu")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("model loading")
|
||||
voices = Queue()
|
||||
speek = Queue()
|
||||
|
||||
# 创建一个VAD对象
|
||||
vad_pipline = pipeline(
|
||||
task=Tasks.voice_activity_detection,
|
||||
model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||
model=args.vad_model,
|
||||
model_revision="v1.2.0",
|
||||
output_dir=None,
|
||||
batch_size=1,
|
||||
mode='online'
|
||||
)
|
||||
param_dict_vad = {'in_cache': dict(), "is_final": False}
|
||||
|
||||
# 创建一个ASR对象
|
||||
param_dict = dict()
|
||||
param_dict["hotword"] = "小五 小五月" # 设置热词,用空格隔开
|
||||
# param_dict["hotword"] = "小五 小五月" # 设置热词,用空格隔开
|
||||
inference_pipeline2 = pipeline(
|
||||
task=Tasks.auto_speech_recognition,
|
||||
model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
|
||||
model=args.asr_model,
|
||||
param_dict=param_dict,
|
||||
)
|
||||
print("model loaded")
|
||||
|
||||
|
||||
|
||||
async def echo(websocket, path):
|
||||
async def ws_serve(websocket, path):
|
||||
global voices
|
||||
try:
|
||||
async for message in websocket:
|
||||
@ -47,18 +83,26 @@ async def echo(websocket, path):
|
||||
except Exception as e:
|
||||
print('Exception occurred:', e)
|
||||
|
||||
start_server = websockets.serve(echo, "localhost", 8899, subprotocols=["binary"],ping_interval=None)
|
||||
start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
|
||||
|
||||
|
||||
def vad(data): # 推理
|
||||
global vad_pipline
|
||||
global vad_pipline, param_dict_vad
|
||||
#print(type(data))
|
||||
segments_result = vad_pipline(audio_in=data)
|
||||
#print(segments_result)
|
||||
if len(segments_result) == 0:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
# print(param_dict_vad)
|
||||
segments_result = vad_pipline(audio_in=data, param_dict=param_dict_vad)
|
||||
# print(segments_result)
|
||||
# print(param_dict_vad)
|
||||
speech_start = False
|
||||
speech_end = False
|
||||
|
||||
if len(segments_result) == 0 or len(segments_result["text"]) > 1:
|
||||
return speech_start, speech_end
|
||||
if segments_result["text"][0][0] != -1:
|
||||
speech_start = True
|
||||
if segments_result["text"][0][1] != -1:
|
||||
speech_end = True
|
||||
return speech_start, speech_end
|
||||
|
||||
def asr(): # 推理
|
||||
global inference_pipeline2
|
||||
@ -76,11 +120,12 @@ def asr(): # 推理
|
||||
def main(): # 推理
|
||||
frames = [] # 存储所有的帧数据
|
||||
buffer = [] # 存储缓存中的帧数据(最多两个片段)
|
||||
silence_count = 0 # 统计连续静音的次数
|
||||
speech_detected = False # 标记是否检测到语音
|
||||
# silence_count = 0 # 统计连续静音的次数
|
||||
# speech_detected = False # 标记是否检测到语音
|
||||
RECORD_NUM = 0
|
||||
global voices
|
||||
global speek
|
||||
speech_start, speech_end = False, False
|
||||
while True:
|
||||
while not voices.empty():
|
||||
|
||||
@ -91,32 +136,35 @@ def main(): # 推理
|
||||
if len(buffer) > 2:
|
||||
buffer.pop(0) # 如果缓存超过两个片段,则删除最早的一个
|
||||
|
||||
if speech_detected:
|
||||
if speech_start:
|
||||
frames.append(data)
|
||||
RECORD_NUM += 1
|
||||
|
||||
if vad(data):
|
||||
if not speech_detected:
|
||||
print("检测到人声...")
|
||||
speech_detected = True # 标记为检测到语音
|
||||
frames = []
|
||||
frames.extend(buffer) # 把之前2个语音数据快加入
|
||||
silence_count = 0 # 重置静音次数
|
||||
else:
|
||||
silence_count += 1 # 增加静音次数
|
||||
|
||||
if speech_detected and (silence_count > 4 or RECORD_NUM > 50): #这里 50 可根据需求改为合适的数据快数量
|
||||
print("说话结束或者超过设置最长时间...")
|
||||
audio_in = b"".join(frames)
|
||||
#asrt = threading.Thread(target=asr,args=(audio_in,))
|
||||
#asrt.start()
|
||||
speek.put(audio_in)
|
||||
#rec_result = inference_pipeline2(audio_in=audio_in) # ASR 模型里跑一跑
|
||||
frames = [] # 清空所有的帧数据
|
||||
buffer = [] # 清空缓存中的帧数据(最多两个片段)
|
||||
silence_count = 0 # 统计连续静音的次数清零
|
||||
speech_detected = False # 标记是否检测到语音
|
||||
RECORD_NUM = 0
|
||||
RECORD_NUM += 1
|
||||
speech_start_i, speech_end_i = vad(data)
|
||||
# print(speech_start_i, speech_end_i)
|
||||
if speech_start_i:
|
||||
speech_start = speech_start_i
|
||||
# if not speech_detected:
|
||||
# print("检测到人声...")
|
||||
# speech_detected = True # 标记为检测到语音
|
||||
frames = []
|
||||
frames.extend(buffer) # 把之前2个语音数据快加入
|
||||
# silence_count = 0 # 重置静音次数
|
||||
if speech_end_i or RECORD_NUM > 300:
|
||||
# silence_count += 1 # 增加静音次数
|
||||
# speech_end = speech_end_i
|
||||
speech_start = False
|
||||
# if RECORD_NUM > 300: #这里 50 可根据需求改为合适的数据快数量
|
||||
# print("说话结束或者超过设置最长时间...")
|
||||
audio_in = b"".join(frames)
|
||||
#asrt = threading.Thread(target=asr,args=(audio_in,))
|
||||
#asrt.start()
|
||||
speek.put(audio_in)
|
||||
#rec_result = inference_pipeline2(audio_in=audio_in) # ASR 模型里跑一跑
|
||||
frames = [] # 清空所有的帧数据
|
||||
buffer = [] # 清空缓存中的帧数据(最多两个片段)
|
||||
# silence_count = 0 # 统计连续静音的次数清零
|
||||
# speech_detected = False # 标记是否检测到语音
|
||||
RECORD_NUM = 0
|
||||
time.sleep(0.01)
|
||||
time.sleep(0.01)
|
||||
|
||||
@ -128,16 +176,4 @@ s = threading.Thread(target=asr)
|
||||
s.start()
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(start_server)
|
||||
asyncio.get_event_loop().run_forever()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
asyncio.get_event_loop().run_forever()
|
||||
46
funasr/runtime/python/websocket/README.md
Normal file
46
funasr/runtime/python/websocket/README.md
Normal file
@ -0,0 +1,46 @@
|
||||
# Using funasr with websocket
|
||||
We can send streaming audio data to the server in real time with a websocket client (e.g., one chunk every 300 ms) and get the transcribed text when the speaker stops speaking.
|
||||
The audio data is streamed to the server, while the ASR inference itself runs in offline mode.
|
||||
|
||||
# Steps
|
||||
|
||||
## For the Server
|
||||
|
||||
Install modelscope and funasr
|
||||
|
||||
```shell
|
||||
pip install "modelscope[audio_asr]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
|
||||
git clone https://github.com/alibaba/FunASR.git && cd FunASR
|
||||
pip install --editable ./
|
||||
```
|
||||
|
||||
Install the requirements for the server
|
||||
|
||||
```shell
|
||||
cd funasr/runtime/python/websocket
|
||||
pip install -r requirements_server.txt
|
||||
```
|
||||
|
||||
Start the server
|
||||
|
||||
```shell
|
||||
python ASR_server.py --host "0.0.0.0" --port 10095 --asr_model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
||||
```
|
||||
|
||||
## For the client
|
||||
|
||||
Install the requirements for the client
|
||||
```shell
|
||||
git clone https://github.com/alibaba/FunASR.git && cd FunASR
|
||||
cd funasr/runtime/python/websocket
|
||||
pip install -r requirements_client.txt
|
||||
```
|
||||
|
||||
Start the client
|
||||
|
||||
```shell
|
||||
python ASR_client.py --host "127.0.0.1" --port 10095 --chunk_size 300
|
||||
```
|
||||
|
||||
## Acknowledgements
|
||||
1. We acknowledge [cgisky1980](https://github.com/cgisky1980/FunASR) for contributing the websocket service.
|
||||
2
funasr/runtime/python/websocket/requirements_client.txt
Normal file
2
funasr/runtime/python/websocket/requirements_client.txt
Normal file
@ -0,0 +1,2 @@
|
||||
websockets
|
||||
pyaudio
|
||||
1
funasr/runtime/python/websocket/requirements_server.txt
Normal file
1
funasr/runtime/python/websocket/requirements_server.txt
Normal file
@ -0,0 +1 @@
|
||||
websockets
|
||||
Loading…
Reference in New Issue
Block a user