Merge branch 'main' of github.com:alibaba-damo-academy/FunASR

Commit 2bd8241948
@@ -78,10 +78,11 @@ Here we provide several pretrained models on different datasets. The details of

 ### Punctuation Restoration

-| Model Name | Training Data | Parameters | Vocab Size | Offline/Online | Notes |
+| Model Name | Language | Training Data | Parameters | Vocab Size | Offline/Online | Notes |
-|:---:|:---:|:---:|:---:|:---:|:---|
+|:---:|:---|:---:|:---:|:---:|:---:|:---|
-| [CT-Transformer](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary) | Alibaba Text Data | 70M | 272727 | Offline | offline punctuation model |
+| [CT-Transformer-Large](https://modelscope.cn/models/damo/punc_ct-transformer_cn-en-common-vocab471067-large/summary) | CN & EN | Alibaba Text Data(100M) | 1.1G | 471067 | Offline | large offline punctuation model |
-| [CT-Transformer](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727/summary) | Alibaba Text Data | 70M | 272727 | Online | online punctuation model |
+| [CT-Transformer](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary) | CN & EN | Alibaba Text Data(70M) | 291M | 272727 | Offline | offline punctuation model |
+| [CT-Transformer-Realtime](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727/summary) | CN & EN | Alibaba Text Data(70M) | 288M | 272727 | Online | online punctuation model |

 ### Language Models
@@ -83,10 +83,11 @@

 ### Punctuation Restoration Models

-| Model Name | Training Data | Parameters | Vocab Size | Offline/Online | Notes |
+| Model Name | Language | Training Data | Parameters | Vocab Size | Offline/Online | Notes |
-|:---:|:---:|:---:|:---:|:---:|:---|
+|:---:|:---:|:---:|:---:|:---:|:---:|:---|
-| [CT-Transformer](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary) | Alibaba Text Data | 70M | 272727 | Offline | supports Chinese and English punctuation |
+| [CT-Transformer-Large](https://modelscope.cn/models/damo/punc_ct-transformer_cn-en-common-vocab471067-large/summary) | Chinese & English | Alibaba Text Data(100M) | 1.1G | 471067 | Offline | large model supporting Chinese and English punctuation |
-| [CT-Transformer](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727/summary) | Alibaba Text Data | 70M | 272727 | Online | real-time at VAD points |
+| [CT-Transformer](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary) | Chinese & English | Alibaba Text Data(70M) | 291M | 272727 | Offline | supports Chinese and English punctuation |
+| [CT-Transformer-Realtime](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727/summary) | Chinese & English | Alibaba Text Data(70M) | 288M | 272727 | Online | real-time punctuation at VAD points |

 ### Language Models
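For orientation, these table entries are ModelScope model cards that can be driven through the modelscope `pipeline` API, the same way the HTTP server changes later in this commit do; a minimal sketch (the input sentence is made up, the `cache` argument follows the server code below):

```python
# Minimal sketch: offline punctuation restoration via the modelscope pipeline,
# mirroring the pipeline usage in server.py later in this commit.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

punc_pipeline = pipeline(
    task=Tasks.punctuation,
    model="damo/punc_ct-transformer_cn-en-common-vocab471067-large",
)
# The realtime variant keeps a cache across calls; for one-shot offline
# use, an empty cache suffices, as in the server code below.
result = punc_pipeline(text_in="hello everyone welcome to the meeting",
                       param_dict={"cache": list()})
print(result["text"])
```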
@@ -1337,7 +1337,7 @@ class Speech2TextTransducer:
         quantize_dtype: str = "qint8",
         nbest: int = 1,
         streaming: bool = False,
-        simu_streaming: bool = False,
+        fake_streaming: bool = False,
         full_utt: bool = False,
         chunk_size: int = 16,
         left_context: int = 32,
@@ -1432,7 +1432,7 @@ class Speech2TextTransducer:

         self.beam_search = beam_search
         self.streaming = streaming
-        self.simu_streaming = simu_streaming
+        self.fake_streaming = fake_streaming
         self.full_utt = full_utt
         self.chunk_size = max(chunk_size, 0)
         self.left_context = left_context
@@ -1442,8 +1442,8 @@ class Speech2TextTransducer:
             self.streaming = False
             self.asr_model.encoder.dynamic_chunk_training = False

-        if not simu_streaming or chunk_size == 0:
-            self.simu_streaming = False
+        if not fake_streaming or chunk_size == 0:
+            self.fake_streaming = False
             self.asr_model.encoder.dynamic_chunk_training = False

         self.frontend = frontend
@@ -1520,7 +1520,7 @@ class Speech2TextTransducer:
         return nbest_hyps

     @torch.no_grad()
-    def simu_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
+    def fake_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
         """Speech2Text call.

         Args:
             speech: Speech data. (S)
@@ -427,7 +427,7 @@ def inference_paraformer(
             else:
                 text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
             item = {'key': key, 'value': text_postprocessed}
-            if timestamp_postprocessed != "" or len(timestamp) == 0:
+            if timestamp_postprocessed != "":
                 item['timestamp'] = timestamp_postprocessed
             asr_result_list.append(item)
             finish_count += 1
@@ -719,7 +719,7 @@ def inference_paraformer_vad_punc(
             item = {'key': key, 'value': text_postprocessed_punc}
             if text_postprocessed != "":
                 item['text_postprocessed'] = text_postprocessed
-            if time_stamp_postprocessed != "" or len(time_stamp) == 0:
+            if time_stamp_postprocessed != "":
                 item['time_stamp'] = time_stamp_postprocessed

             item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)
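In both hunks the guard is tightened so that an empty postprocessed timestamp no longer attaches an empty field to the result item; a small behavioural sketch of the change (values are hypothetical):

```python
# Behavioural sketch of the tightened guard above (hypothetical values).
item = {"key": "utt1", "value": "hello world"}
timestamp_postprocessed = ""  # postprocessing produced no timestamps

# The old condition, `timestamp_postprocessed != "" or len(timestamp) == 0`,
# could still attach an empty string whenever the raw timestamp list was empty.
# The new condition only attaches a non-empty result:
if timestamp_postprocessed != "":
    item["timestamp"] = timestamp_postprocessed

assert "timestamp" not in item  # empty timestamps are now omitted entirely
```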
@@ -1297,7 +1297,7 @@ def inference_transducer(
    quantize_modules: Optional[List[str]] = None,
    quantize_dtype: Optional[str] = "float16",
    streaming: Optional[bool] = False,
-   simu_streaming: Optional[bool] = False,
+   fake_streaming: Optional[bool] = False,
    full_utt: Optional[bool] = False,
    chunk_size: Optional[int] = 16,
    left_context: Optional[int] = 16,
@@ -1374,7 +1374,7 @@ def inference_transducer(
        quantize_modules=quantize_modules,
        quantize_dtype=quantize_dtype,
        streaming=streaming,
-       simu_streaming=simu_streaming,
+       fake_streaming=fake_streaming,
        full_utt=full_utt,
        chunk_size=chunk_size,
        left_context=left_context,
@@ -1432,8 +1432,8 @@ def inference_transducer(
                final_hyps = speech2text.streaming_decode(
                    speech[_end: len(speech)], is_final=True
                )
-            elif speech2text.simu_streaming:
-                final_hyps = speech2text.simu_streaming_decode(**batch)
+            elif speech2text.fake_streaming:
+                final_hyps = speech2text.fake_streaming_decode(**batch)
            elif speech2text.full_utt:
                final_hyps = speech2text.full_utt_decode(**batch)
            else:
@@ -1823,7 +1823,7 @@ def get_parser():
    group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
    group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
    group.add_argument("--streaming", type=str2bool, default=False)
-   group.add_argument("--simu_streaming", type=str2bool, default=False)
+   group.add_argument("--fake_streaming", type=str2bool, default=False)
    group.add_argument("--full_utt", type=str2bool, default=False)
    group.add_argument("--chunk_size", type=int, default=16)
    group.add_argument("--left_context", type=int, default=16)
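Together, the hunks above rename the simulated-streaming path end to end: constructor argument, instance attribute, decode method, the dispatch in `inference_transducer`, and the CLI flag. A sketch of the dispatch after the rename, annotated with the old names (an abbreviated excerpt, not standalone code; `speech2text`, `batch`, `speech`, and `_end` come from the surrounding function, and the final `else` body lies outside these hunks):

```python
# Decode-path dispatch after the rename (annotated excerpt).
if speech2text.streaming:
    final_hyps = speech2text.streaming_decode(
        speech[_end: len(speech)], is_final=True
    )
elif speech2text.fake_streaming:                              # was: simu_streaming
    final_hyps = speech2text.fake_streaming_decode(**batch)   # was: simu_streaming_decode
elif speech2text.full_utt:
    final_hyps = speech2text.full_utt_decode(**batch)
else:
    ...  # default decode path; its body is outside these hunks
```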
@@ -201,7 +201,7 @@ class CommonPreprocessor(AbsPreprocessor):
         self.seg_dict = None
         if seg_dict_file is not None:
             self.seg_dict = {}
-            with open(seg_dict_file) as f:
+            with open(seg_dict_file, "r", encoding="utf8") as f:
                 lines = f.readlines()
                 for line in lines:
                     s = line.strip().split()
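Pinning the encoding matters because the bare `open(seg_dict_file)` fell back to the platform default encoding, which can fail on a UTF-8 seg dict containing Chinese entries; a minimal sketch of the failure mode (file contents are hypothetical):

```python
# Sketch: why the explicit encoding matters (hypothetical seg-dict entry).
with open("seg_dict.txt", "w", encoding="utf8") as f:
    f.write("你好 你 好\n")  # a word and its segmentation

# Without `encoding=`, Python uses locale.getpreferredencoding(), which
# raises UnicodeDecodeError on non-UTF-8 locales (e.g. a C/ASCII locale).
with open("seg_dict.txt", "r", encoding="utf8") as f:  # the fixed form
    for line in f:
        print(line.strip().split())
```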
@@ -21,9 +21,11 @@ python server.py \
    --host [host ip] \
    --port [server port] \
    --asr_model [asr model_name] \
+   --vad_model [vad model_name] \
    --punc_model [punc model_name] \
    --ngpu [0 or 1] \
    --ncpu [1 or 4] \
+   --hotword_path [path of hot word txt] \
    --certfile [path of certfile for ssl] \
    --keyfile [path of keyfile for ssl] \
    --temp_dir [upload file temp dir]
@@ -45,3 +47,22 @@ python server.py \
    --add_pun [add pun to result] \
    --audio_path [use audio path]
 ```
+
+## Multi-process support
+
+The approach is to start several `server.py` instances and distribute requests across them with Nginx load balancing, so that multiple users can connect at the same time. The steps below assume Nginx is already installed; if not, see the [official installation guide](https://nginx.org/en/linux_packages.html#Ubuntu).
+
+Configure Nginx:
+```shell
+sudo cp -f asr_nginx.conf /etc/nginx/nginx.conf
+sudo service nginx reload
+```
+
+Then use the script to start several services, each on a different port:
+```shell
+sudo chmod +x start_server.sh
+./start_server.sh
+```
+
+**Note:** the default is 3 processes. To change this, first edit the last part of `start_server.sh` to adjust how many instances are launched, then extend the `upstream backend` section of `asr_nginx.conf` with the newly started services; entries may also point to services on other machines.
funasr/runtime/python/http/asr_nginx.conf (new file, 44 lines)
@@ -0,0 +1,44 @@
+user nginx;
+worker_processes auto;
+
+error_log /var/log/nginx/error.log notice;
+pid /var/run/nginx.pid;
+
+events {
+    worker_connections 1024;
+}
+
+http {
+    include /etc/nginx/mime.types;
+    default_type application/octet-stream;
+
+    log_format main '$remote_addr - $remote_user [$time_local] "$request" '
+                    '$status $body_bytes_sent "$http_referer" '
+                    '"$http_user_agent" "$http_x_forwarded_for"';
+
+    access_log /var/log/nginx/access.log main;
+
+    sendfile on;
+    keepalive_timeout 65;
+
+    upstream backend {
+        # least-connections balancing algorithm
+        least_conn;
+        # addresses of the started services
+        server localhost:8001;
+        server localhost:8002;
+        server localhost:8003;
+    }
+
+    server {
+        # the port clients actually connect to
+        listen 8000;
+
+        location / {
+            proxy_pass http://backend;
+        }
+    }
+
+    include /etc/nginx/conf.d/*.conf;
+}
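For a quick end-to-end check of the balanced setup, a request can be sent to the front port (8000 above) the same way `client.py` below does; a minimal sketch (host and audio path are placeholders):

```python
# Sketch: smoke-test the Nginx front port, mirroring client.py below.
import os
import requests

url = "http://localhost:8000/recognition"  # the Nginx `listen` port above
audio_path = "asr_example.wav"             # placeholder audio file

with open(audio_path, "rb") as f:
    files = [("audio", (os.path.basename(audio_path), f, "application/octet-stream"))]
    response = requests.post(url, data={"add_pun": 1}, files=files)
print(response.text)
```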
@@ -1,3 +1,5 @@
+import os
+
 import requests
 import argparse

@@ -23,12 +25,16 @@ parser.add_argument("--audio_path",
                     required=False,
                     help="use audio path")
 args = parser.parse_args()
+print("----------- Configuration Arguments -----------")
+for arg, value in vars(args).items():
+    print("%s: %s" % (arg, value))
+print("------------------------------------------------")

 url = f'http://{args.host}:{args.port}/recognition'
 data = {'add_pun': args.add_pun}
 headers = {}
-files = [('audio', ('file', open(args.audio_path, 'rb'), 'application/octet-stream'))]
+files = [('audio', (os.path.basename(args.audio_path), open(args.audio_path, 'rb'), 'application/octet-stream'))]

 response = requests.post(url, headers=headers, data=data, files=files)
 print(response.text)
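The switch from the literal `'file'` to `os.path.basename(...)` matters because the server (below) recovers the temp-file suffix from the uploaded filename; a sketch of the difference, with a hypothetical path:

```python
# Sketch: the server derives the suffix via `audio.filename.split('.')[-1]`
# (see server.py below), so the uploaded filename must keep its extension.
import os

audio_path = "/data/audio/asr_example.wav"  # hypothetical path

old_name = "file"                           # old client: suffix becomes "file"
new_name = os.path.basename(audio_path)    # new client: "asr_example.wav"
print(old_name.split(".")[-1])   # -> "file" (wrong extension)
print(new_name.split(".")[-1])   # -> "wav"  (correct extension)
```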
@@ -1,8 +1,7 @@
 import argparse
 import logging
 import os
-import random
-import time
+import uuid

 import aiofiles
 import ffmpeg
@@ -29,11 +28,15 @@ parser.add_argument("--port",
 parser.add_argument("--asr_model",
                     type=str,
                     default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
-                    help="model from modelscope")
+                    help="offline asr model from modelscope")
+parser.add_argument("--vad_model",
+                    type=str,
+                    default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+                    help="vad model from modelscope")
 parser.add_argument("--punc_model",
                     type=str,
-                    default="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
-                    help="model from modelscope")
+                    default="damo/punc_ct-transformer_cn-en-common-vocab471067-large",
+                    help="punc model from modelscope")
 parser.add_argument("--ngpu",
                     type=int,
                     default=1,
@@ -42,6 +45,10 @@ parser.add_argument("--ncpu",
                     type=int,
                     default=4,
                     help="cpu cores")
+parser.add_argument("--hotword_path",
+                    type=str,
+                    default=None,
+                    help="hot word txt path; only takes effect with hotword models")
 parser.add_argument("--certfile",
                     type=str,
                     default=None,
@@ -58,22 +65,30 @@ parser.add_argument("--temp_dir",
                     required=False,
                     help="temp dir")
 args = parser.parse_args()
+print("----------- Configuration Arguments -----------")
+for arg, value in vars(args).items():
+    print("%s: %s" % (arg, value))
+print("------------------------------------------------")

 os.makedirs(args.temp_dir, exist_ok=True)

 print("model loading")
+param_dict = {}
+if args.hotword_path is not None and os.path.exists(args.hotword_path):
+    param_dict['hotword'] = args.hotword_path
 # asr
 inference_pipeline_asr = pipeline(task=Tasks.auto_speech_recognition,
                                   model=args.asr_model,
+                                  vad_model=args.vad_model,
                                   ngpu=args.ngpu,
                                   ncpu=args.ncpu,
-                                  model_revision=None)
+                                  param_dict=param_dict)
 print(f'loaded asr models.')

 if args.punc_model != "":
     inference_pipeline_punc = pipeline(task=Tasks.punctuation,
                                        model=args.punc_model,
-                                       model_revision="v1.0.2",
                                        ngpu=args.ngpu,
                                        ncpu=args.ncpu)
     print(f'loaded pun models.')
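With the VAD front-end attached, long audio is segmented before recognition, and the hotword list rides along in `param_dict`; a minimal sketch of invoking the combined pipeline (the `audio_in` keyword is the customary modelscope entry point and is assumed here, since the actual call site is outside this hunk):

```python
# Sketch: one-shot use of the combined asr+vad pipeline configured above.
# `audio_in` is assumed; the real call site is outside this hunk.
rec_result = inference_pipeline_asr(audio_in="asr_example_long.wav",
                                    param_dict=param_dict)
print(rec_result['text'])
```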
@@ -87,7 +102,7 @@ app = FastAPI(title="FunASR")
 async def api_recognition(audio: UploadFile = File(..., description="audio file"),
                           add_pun: int = Body(1, description="add punctuation", embed=True)):
     suffix = audio.filename.split('.')[-1]
-    audio_path = f'{args.temp_dir}/{int(time.time() * 1000)}_{random.randint(100, 999)}.{suffix}'
+    audio_path = f'{args.temp_dir}/{str(uuid.uuid1())}.{suffix}'
     async with aiofiles.open(audio_path, 'wb') as out_file:
         content = await audio.read()
         await out_file.write(content)
@@ -100,6 +115,7 @@ async def api_recognition(audio: UploadFile = File(..., description="audio file"
     if add_pun:
         rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict={'cache': list()})
     ret = {"results": rec_result['text'], "code": 0}
+    print(ret)
     return ret
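Switching the temp-file name from millisecond time plus a three-digit random number to `uuid.uuid1()` removes the collision window when two uploads land in the same millisecond; a tiny sketch (the suffix is hypothetical):

```python
# Sketch: collision-free temp-file naming as in the handler above.
import uuid

suffix = "wav"  # hypothetical, derived from the uploaded filename
audio_path = f"/tmp/asr/{uuid.uuid1()}.{suffix}"
print(audio_path)  # unique per call, e.g. /tmp/asr/<uuid>.wav
```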
funasr/runtime/python/http/start_server.sh (new file, 21 lines)
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+# create the log directory
+if [ ! -d "log/" ];then
+    mkdir log
+fi
+
+# kill any previously started server processes
+server_id=`ps -ef | grep server.py | grep -v "grep" | awk '{print $2}'`
+echo $server_id
+
+for id in $server_id
+do
+    kill -9 $id
+    echo "killed $id"
+done
+
+# start several services; different GPUs can be assigned via CUDA_VISIBLE_DEVICES
+CUDA_VISIBLE_DEVICES=0 nohup python -u server.py --host=localhost --port=8001 >> log/output1.log 2>&1 &
+CUDA_VISIBLE_DEVICES=0 nohup python -u server.py --host=localhost --port=8002 >> log/output2.log 2>&1 &
+CUDA_VISIBLE_DEVICES=0 nohup python -u server.py --host=localhost --port=8003 >> log/output3.log 2>&1 &
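A quick way to confirm all three instances came up is to probe their ports; a minimal sketch, using the ports configured in the script above:

```python
# Sketch: verify the three server.py instances are listening.
import socket

for port in (8001, 8002, 8003):  # ports from start_server.sh above
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.settimeout(1.0)
        ok = s.connect_ex(("localhost", port)) == 0
        print(f"port {port}: {'up' if ok else 'down'}")
```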
@@ -107,6 +107,20 @@ Loading from wav.scp (kaldi style)
 # --chunk_size, "5,10,5"=600ms, "8,8,4"=480ms
 python funasr_wss_client.py --host "0.0.0.0" --port 10095 --mode 2pass --chunk_size "8,8,4" --audio_in "./data/wav.scp" --output_dir "./results"
 ```
+
+#### Websocket api
+```shell
+# class Funasr_websocket_recognizer example in 3 steps
+# 1. create a recognizer
+rcg = Funasr_websocket_recognizer(host="127.0.0.1", port="30035", is_ssl=True, mode="2pass")
+# 2. send pcm data to the asr engine and get the asr result
+text = rcg.feed_chunk(data)
+print("text", text)
+# 3. get the last result, with timeout=3
+text = rcg.close(timeout=3)
+print("text", text)
+```

 ## Acknowledge
 1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
 2. We acknowledge [zhaoming](https://github.com/zhaomingwork/FunASR/tree/fix_bug_for_python_websocket) for contributing the websocket service.
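For reference, the `chunk_size` triple maps to a chunk duration of 60 ms times its middle element, which is where the 600 ms and 480 ms figures in the comment above come from; the per-send byte stride used by the client below divides that duration by `chunk_interval` at 16 kHz, 16-bit mono:

```python
# Sketch: how chunk_size maps to the byte stride in funasr_client_api.py.
chunk_size = [5, 10, 5]   # "5,10,5" -> 60 ms * 10 = 600 ms per chunk
chunk_interval = 10

duration_ms = 60 * chunk_size[1] / chunk_interval   # 60.0 ms per send
stride = int(duration_ms / 1000 * 16000 * 2)        # 16 kHz, 2 bytes/sample
print(duration_ms, stride)                          # 60.0 1920
```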
funasr/runtime/python/websocket/funasr_client_api.py (new file, 134 lines)
@@ -0,0 +1,134 @@
+'''
+Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+Reserved. MIT License (https://opensource.org/licenses/MIT)
+
+2022-2023 by zhaomingwork@qq.com
+'''
+# pip install websocket-client
+import ssl
+import threading
+import traceback
+import json
+import time
+from queue import Queue
+
+from websocket import ABNF
+from websocket import create_connection
+
+
+# recognizer class over websocket
+class Funasr_websocket_recognizer():
+    '''
+    python asr recognizer lib
+    '''
+
+    def __init__(self, host="127.0.0.1", port="30035", is_ssl=True, chunk_size="5, 10, 5",
+                 chunk_interval=10, mode="offline", wav_name="default"):
+        '''
+        host: server host ip
+        port: server port
+        is_ssl: True for the wss protocol, False for ws
+        '''
+        try:
+            if is_ssl:
+                ssl_context = ssl.SSLContext()
+                ssl_context.check_hostname = False
+                ssl_context.verify_mode = ssl.CERT_NONE
+                uri = "wss://{}:{}".format(host, port)
+                ssl_opt = {"cert_reqs": ssl.CERT_NONE}
+            else:
+                uri = "ws://{}:{}".format(host, port)
+                ssl_context = None
+                ssl_opt = None
+            self.host = host
+            self.port = port
+
+            # queue holding the recognized result text
+            self.msg_queue = Queue()
+
+            print("connect to url", uri)
+            self.websocket = create_connection(uri, ssl=ssl_context, sslopt=ssl_opt)
+
+            # background thread that receives messages from the server
+            self.thread_msg = threading.Thread(target=Funasr_websocket_recognizer.thread_rec_msg, args=(self,))
+            self.thread_msg.start()
+
+            chunk_size = [int(x) for x in chunk_size.split(",")]
+            # note: the byte stride and chunk count are computed by the caller
+            # when slicing the audio (see the __main__ example below)
+
+            message = json.dumps({"mode": mode, "chunk_size": chunk_size, "chunk_interval": chunk_interval,
+                                  "wav_name": wav_name, "is_speaking": True})
+            self.websocket.send(message)
+            print("send json", message)
+        except Exception as e:
+            print("Exception:", e)
+            traceback.print_exc()
+
+    # receiver thread: pushes every server message onto the queue
+    def thread_rec_msg(self):
+        try:
+            while True:
+                msg = self.websocket.recv()
+                if msg is None or len(msg) == 0:
+                    continue
+                msg = json.loads(msg)
+                self.msg_queue.put(msg)
+        except Exception:
+            print("client closed")
+
+    # feed data to the asr engine; wait_time is the timeout for one result
+    def feed_chunk(self, chunk, wait_time=0.01):
+        try:
+            self.websocket.send(chunk, ABNF.OPCODE_BINARY)
+            # drain queued messages, keeping the most recent one
+            while True:
+                msg = self.msg_queue.get(timeout=wait_time)
+                if self.msg_queue.empty():
+                    break
+            return msg
+        except Exception:
+            # timed out: no result available yet
+            return ""
+
+    def close(self, timeout=1):
+        message = json.dumps({"is_speaking": False})
+        self.websocket.send(message)
+        # sleep for timeout seconds to wait for the final result
+        time.sleep(timeout)
+        msg = ""
+        while not self.msg_queue.empty():
+            msg = self.msg_queue.get()
+        self.websocket.close()
+        # only return the last message
+        return msg
+
+
+if __name__ == '__main__':
+    print('example for Funasr_websocket_recognizer')
+    import wave
+    wav_path = "asr_example.wav"
+    with wave.open(wav_path, "rb") as wav_file:
+        params = wav_file.getparams()
+        frames = wav_file.readframes(wav_file.getnframes())
+        audio_bytes = bytes(frames)
+
+    # 60 ms * chunk_size[1] / chunk_interval of 16 kHz 16-bit audio per send
+    stride = int(60 * 10 / 10 / 1000 * 16000 * 2)
+    chunk_num = (len(audio_bytes) - 1) // stride + 1
+    # create a recognizer
+    rcg = Funasr_websocket_recognizer(host="127.0.0.1", port="30035", is_ssl=True, mode="2pass")
+    # loop to send chunks
+    for i in range(chunk_num):
+        beg = i * stride
+        data = audio_bytes[beg:beg + stride]
+
+        text = rcg.feed_chunk(data, wait_time=0.02)
+        if len(text) > 0:
+            print("text", text)
+        time.sleep(0.05)
+
+    # get the last message
+    text = rcg.close(timeout=3)
+    print("text", text)