mirror of https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00

Merge branch 'dev_gzf_deepspeed' of github.com:alibaba-damo-academy/FunASR into dev_gzf_deepspeed

commit 783a051f65

README.md (52 changed lines)
@@ -2,8 +2,9 @@

([简体中文](./README_zh.md)|English)

# FunASR: A Fundamental End-to-End Speech Recognition Toolkit

[//]: # (# FunASR: A Fundamental End-to-End Speech Recognition Toolkit)

[](https://github.com/Akshay090/svg-banners)

[](https://pypi.org/project/funasr/)
@@ -34,6 +35,9 @@

- 2024/03/05: Added support for the Whisper-large-v3 model, a multitasking model that can perform multilingual speech recognition, speech translation, and language identification. It can be downloaded from [modelscope](examples/industrial_data_pretraining/whisper/demo.py) and [openai](examples/industrial_data_pretraining/whisper/demo_from_openai.py).
- 2024/03/05: Offline File Transcription Service 4.4, Offline File Transcription Service of English 1.5, and Real-time Transcription Service 1.9 released; the docker image now supports the ARM64 platform and modelscope has been updated ([docs](runtime/readme.md)).
- 2024/01/30: funasr-1.0 has been released ([docs](https://github.com/alibaba-damo-academy/FunASR/discussions/1319))

<details><summary>Full Changelog</summary>

- 2024/01/30: Emotion recognition models are newly supported. [model link](https://www.modelscope.cn/models/iic/emotion2vec_base_finetuned/summary), modified from [repo](https://github.com/ddlBoJack/emotion2vec).
- 2024/01/25: Offline File Transcription Service 4.2 and Offline File Transcription Service of English 1.3 released, with an optimized VAD (Voice Activity Detection) data processing method that significantly reduces peak memory usage, plus memory-leak fixes; Real-time Transcription Service 1.7 released with client-side optimizations ([docs](runtime/readme.md)).
- 2024/01/09: The FunASR SDK for Windows version 2.0 has been released, featuring support for the offline file transcription service (CPU) of Mandarin 4.1, the offline file transcription service (CPU) of English 1.2, and the real-time transcription service (CPU) of Mandarin 1.6. For more details, please refer to the official documentation or release notes ([FunASR-Runtime-Windows](https://www.modelscope.cn/models/damo/funasr-runtime-win-cpu-x64/summary)).
@@ -51,22 +55,31 @@

- 2023/07/17: BAT is released, a low-latency and low-memory-consumption RNN-T model. For more details, please refer to ([BAT](egs/aishell/bat)).
- 2023/06/26: The ASRU2023 Multi-Channel Multi-Party Meeting Transcription Challenge 2.0 concluded and announced its results. For more details, please refer to ([M2MeT2.0](https://alibaba-damo-academy.github.io/FunASR/m2met2/index.html)).

</details>

<a name="Installation"></a>
## Installation

- Requirements
```text
python>=3.8
torch>=1.13
torchaudio
```

- Install from pypi
```shell
pip3 install -U funasr
```
-Or install from source code
+- Or install from source code
```sh
git clone https://github.com/alibaba/FunASR.git && cd FunASR
pip3 install -e ./
```
-Install modelscope for the pretrained models (optional)
+- Install modelscope or huggingface_hub for the pretrained models (optional)

```shell
-pip3 install -U modelscope
+pip3 install -U modelscope huggingface_hub
```
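With either hub installed, FunASR can fetch pretrained weights from ModelScope or Hugging Face. A minimal download-and-infer sketch; the `hub` argument is assumed from the FunASR 1.x `AutoModel` interface and may differ across versions:

```python
from funasr import AutoModel

# hub="ms" is assumed to select ModelScope (the default) and hub="hf"
# Hugging Face; check your installed FunASR version before relying on it.
model = AutoModel(model="paraformer-zh", hub="ms")
res = model.generate(input="asr_example.wav")  # a local 16 kHz wav file
print(res)
```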
## Model Zoo

@@ -77,19 +90,19 @@ FunASR has open-sourced a large number of pre-trained models on industrial data.

| Model Name | Task Details | Training Data | Parameters |
|:----------:|:------------:|:-------------:|:----------:|
-| paraformer-zh <br> ([⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) [🤗](https://huggingface.co/funasr/paraformer-tp) ) | speech recognition, with timestamps, non-streaming | 60000 hours, Mandarin | 220M |
+| paraformer-zh <br> ([⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) [🤗](https://huggingface.co/funasr/paraformer-zh) ) | speech recognition, with timestamps, non-streaming | 60000 hours, Mandarin | 220M |
| <nobr>paraformer-zh-streaming <br> ( [⭐](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/summary) [🤗](https://huggingface.co/funasr/paraformer-zh-streaming) )</nobr> | speech recognition, streaming | 60000 hours, Mandarin | 220M |
| paraformer-en <br> ( [⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/summary) [🤗](https://huggingface.co/funasr/paraformer-en) ) | speech recognition, without timestamps, non-streaming | 50000 hours, English | 220M |
| conformer-en <br> ( [⭐](https://modelscope.cn/models/damo/speech_conformer_asr-en-16k-vocab4199-pytorch/summary) [🤗](https://huggingface.co/funasr/conformer-en) ) | speech recognition, non-streaming | 50000 hours, English | 220M |
-| ct-punc <br> ( [⭐](https://modelscope.cn/models/damo/punc_ct-transformer_cn-en-common-vocab471067-large/summary) [🤗](https://huggingface.co/funasr/ct-punc) ) | punctuation restoration | 100M, Mandarin and English | 1.1G |
+| ct-punc <br> ( [⭐](https://modelscope.cn/models/damo/punc_ct-transformer_cn-en-common-vocab471067-large/summary) [🤗](https://huggingface.co/funasr/ct-punc) ) | punctuation restoration | 100M, Mandarin and English | 290M |
| fsmn-vad <br> ( [⭐](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) [🤗](https://huggingface.co/funasr/fsmn-vad) ) | voice activity detection | 5000 hours, Mandarin and English | 0.4M |
| fa-zh <br> ( [⭐](https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary) [🤗](https://huggingface.co/funasr/fa-zh) ) | timestamp prediction | 5000 hours, Mandarin | 38M |
| cam++ <br> ( [⭐](https://modelscope.cn/models/iic/speech_campplus_sv_zh-cn_16k-common/summary) [🤗](https://huggingface.co/funasr/campplus) ) | speaker verification/diarization | 5000 hours | 7.2M |
| Whisper-large-v2 <br> ([⭐](https://www.modelscope.cn/models/iic/speech_whisper-large_asr_multilingual/summary) [🍀](https://github.com/openai/whisper) ) | speech recognition, with timestamps, non-streaming | multilingual | 1550M |
| Whisper-large-v3 <br> ([⭐](https://www.modelscope.cn/models/iic/Whisper-large-v3/summary) [🍀](https://github.com/openai/whisper) ) | speech recognition, with timestamps, non-streaming | multilingual | 1550M |
| Qwen-Audio <br> ([⭐](examples/industrial_data_pretraining/qwen_audio/demo.py) [🤗](https://huggingface.co/Qwen/Qwen-Audio) ) | audio-text multimodal model (pretraining) | multilingual | 8B |
| Qwen-Audio-Chat <br> ([⭐](examples/industrial_data_pretraining/qwen_audio/demo_chat.py) [🤗](https://huggingface.co/Qwen/Qwen-Audio-Chat) ) | audio-text multimodal model (chat) | multilingual | 8B |
| emotion2vec+large <br> ([⭐](https://modelscope.cn/models/iic/emotion2vec_plus_large/summary) [🤗](https://huggingface.co/emotion2vec/emotion2vec_plus_large) ) | speech emotion recognition | 40000 hours | 300M |
@@ -153,6 +166,8 @@ for i in range(total_chunk_num):

```
Note: `chunk_size` is the configuration for streaming latency. `[0,10,5]` indicates that the real-time display granularity is `10*60=600ms` and the lookahead information is `5*60=300ms`. Each inference input is `600ms` (`16000*0.6=9600` sample points), and the output is the corresponding text. For the last speech segment, `is_final=True` must be set to force the output of the last word.
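To make those numbers concrete, here is a small sketch in plain Python of the arithmetic the note describes (60 ms frames, 16 kHz audio):

```python
# chunk_size = [0, 10, 5] with 60 ms frames: each step consumes 600 ms of
# audio and may look 300 ms into the future.
frame_ms = 60
chunk_size = [0, 10, 5]
sample_rate = 16000

step_ms = chunk_size[1] * frame_ms                # 10 * 60 = 600 ms per input
lookahead_ms = chunk_size[2] * frame_ms           # 5 * 60 = 300 ms of lookahead
samples_per_step = sample_rate * step_ms // 1000  # 9600 samples per step
print(step_ms, lookahead_ms, samples_per_step)    # 600 300 9600
```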
<details><summary>More Examples</summary>

### Voice Activity Detection (Non-Streaming)
```python
from funasr import AutoModel

@@ -211,9 +226,24 @@ text_file = f"{model.model_path}/example/text.txt"

res = model.generate(input=(wav_file, text_file), data_type=("sound", "text"))
print(res)
```

### Speech Emotion Recognition
```python
from funasr import AutoModel

model = AutoModel(model="emotion2vec_plus_large")

wav_file = f"{model.model_path}/example/test.wav"

res = model.generate(wav_file, output_dir="./outputs", granularity="utterance", extract_embedding=False)
print(res)
```

More usage details are given in the [docs](docs/tutorial/README_zh.md);
more examples can be found in the [demo](https://github.com/alibaba-damo-academy/FunASR/tree/main/examples/industrial_data_pretraining) directory.

</details>

## Export ONNX
README_zh.md (48 changed lines)

@@ -2,7 +2,11 @@

(简体中文|[English](./README.md))

# FunASR: A Fundamental End-to-End Speech Recognition Toolkit

[](https://github.com/Akshay090/svg-banners)

[//]: # (# FunASR: A Fundamental End-to-End Speech Recognition Toolkit)

[](https://pypi.org/project/funasr/)
@@ -35,6 +39,9 @@ FunASR aims to build a bridge between academic research on speech recognition and its industrial applications

- 2024/03/05: Added support for the Whisper-large-v3 model (multilingual speech recognition, speech translation, and language identification); the model can be downloaded from the [modelscope](examples/industrial_data_pretraining/whisper/demo.py) repository or from the [openai](examples/industrial_data_pretraining/whisper/demo_from_openai.py) repository.
- 2024/03/05: Chinese offline file transcription service 4.4, English offline file transcription service 1.5, and Chinese real-time transcription service 1.9 released; the docker image now supports the arm64 platform and the modelscope version has been upgraded; see the ([deployment docs](runtime/readme_cn.md)) for details.
- 2024/01/30: funasr-1.0 released; see the release notes [docs](https://github.com/alibaba-damo-academy/FunASR/discussions/1319)

<details><summary>Full changelog</summary>

- 2024/01/30: Added emotion recognition [model link](https://www.modelscope.cn/models/iic/emotion2vec_base_finetuned/summary), based on the original [repo](https://github.com/ddlBoJack/emotion2vec).
- 2024/01/25: Chinese offline file transcription service 4.2 and English offline file transcription service 1.3 released, with optimized VAD data processing that greatly reduces peak memory usage and fixes a memory leak; Chinese real-time transcription service 1.7 released with client-side optimizations; see the ([deployment docs](runtime/readme_cn.md)) for details.
- 2024/01/09: FunASR community Windows package 2.0 released, supporting the latest Chinese offline file transcription 4.1, English offline file transcription 1.2, and Chinese real-time transcription 1.6; see ([FunASR-Runtime-Windows](https://www.modelscope.cn/models/damo/funasr-runtime-win-cpu-x64/summary)) for details.
@@ -52,21 +59,33 @@ FunASR aims to build a bridge between academic research on speech recognition and its industrial applications

- 2023.07.17: BAT, a low-latency and low-memory-consumption RNN-T model, released; see ([BAT](egs/aishell/bat)) for details
- 2023.06.26: The ASRU2023 Multi-Channel Multi-Party Meeting Transcription Challenge 2.0 concluded and announced its results; see ([M2MeT2.0](https://alibaba-damo-academy.github.io/FunASR/m2met2_cn/index.html)) for details

</details>

<a name="安装教程"></a>
## Installation

- Before installing funasr, make sure the following dependencies are installed:
```text
python>=3.8
torch>=1.13
torchaudio
```

- Install via pip
```shell
pip3 install -U funasr
```
-Or install from source code
+- Or install from source code
```sh
git clone https://github.com/alibaba/FunASR.git && cd FunASR
pip3 install -e ./
```
-To use the industrial pretrained models, install modelscope (optional)
+To use the industrial pretrained models, install modelscope and huggingface_hub (optional)

```shell
-pip3 install -U modelscope
+pip3 install -U modelscope huggingface huggingface_hub
```

## Model Zoo
@@ -78,11 +97,11 @@ FunASR has open-sourced a large number of models pretrained on industrial data; you can use them under the [Model License

| Model Name | Task Details | Training Data | Parameters |
|:----------:|:------------:|:-------------:|:----------:|
-| paraformer-zh <br> ([⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) [🤗](https://huggingface.co/funasr/paraformer-tp) ) | speech recognition, with timestamps, non-streaming | 60000 hours, Mandarin | 220M |
+| paraformer-zh <br> ([⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) [🤗](https://huggingface.co/funasr/paraformer-zh) ) | speech recognition, with timestamps, non-streaming | 60000 hours, Mandarin | 220M |
| paraformer-zh-streaming <br> ( [⭐](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/summary) [🤗](https://huggingface.co/funasr/paraformer-zh-streaming) ) | speech recognition, streaming | 60000 hours, Mandarin | 220M |
| paraformer-en <br> ( [⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/summary) [🤗](https://huggingface.co/funasr/paraformer-en) ) | speech recognition, non-streaming | 50000 hours, English | 220M |
| conformer-en <br> ( [⭐](https://modelscope.cn/models/damo/speech_conformer_asr-en-16k-vocab4199-pytorch/summary) [🤗](https://huggingface.co/funasr/conformer-en) ) | speech recognition, non-streaming | 50000 hours, English | 220M |
-| ct-punc <br> ( [⭐](https://modelscope.cn/models/damo/punc_ct-transformer_cn-en-common-vocab471067-large/summary) [🤗](https://huggingface.co/funasr/ct-punc) ) | punctuation restoration | 100M, Mandarin and English | 1.1B |
+| ct-punc <br> ( [⭐](https://modelscope.cn/models/damo/punc_ct-transformer_cn-en-common-vocab471067-large/summary) [🤗](https://huggingface.co/funasr/ct-punc) ) | punctuation restoration | 100M, Mandarin and English | 290M |
| fsmn-vad <br> ( [⭐](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) [🤗](https://huggingface.co/funasr/fsmn-vad) ) | voice activity detection, streaming | 5000 hours, Mandarin and English | 0.4M |
| fa-zh <br> ( [⭐](https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary) [🤗](https://huggingface.co/funasr/fa-zh) ) | character-level timestamp prediction | 50000 hours, Mandarin | 38M |
| cam++ <br> ( [⭐](https://modelscope.cn/models/iic/speech_campplus_sv_zh-cn_16k-common/summary) [🤗](https://huggingface.co/funasr/campplus) ) | speaker verification/diarization | 5000 hours | 7.2M |
@@ -148,6 +167,8 @@ for i in range(total_chunk_num):

```
Note: `chunk_size` is the streaming latency configuration. `[0,10,5]` means the real-time display granularity is `10*60=600ms` and the lookahead is `5*60=300ms`. Each inference input is `600ms` (`16000*0.6=9600` sample points) and the output is the corresponding text; for the last speech segment, `is_final=True` must be set to force output of the final word.

<details><summary>More Examples</summary>

### Voice Activity Detection (Non-Streaming)
```python
from funasr import AutoModel

@@ -211,9 +232,24 @@ text_file = f"{model.model_path}/example/text.txt"

res = model.generate(input=(wav_file, text_file), data_type=("sound", "text"))
print(res)
```

### Speech Emotion Recognition
```python
from funasr import AutoModel

model = AutoModel(model="emotion2vec_plus_large")

wav_file = f"{model.model_path}/example/test.wav"

res = model.generate(wav_file, output_dir="./outputs", granularity="utterance", extract_embedding=False)
print(res)
```

More details are given in the ([tutorial](docs/tutorial/README_zh.md));
more ([model examples](https://github.com/alibaba-damo-academy/FunASR/tree/main/examples/industrial_data_pretraining)) are available.

</details>

## Export ONNX
### Export from the command line
```shell
Binary file not shown. (Before: 181 KiB, After: 183 KiB)
@@ -52,7 +52,7 @@ class AutoFrontend:

        key_list, data_list = prepare_data_iterator(input, input_len=input_len)
        batch_size = kwargs.get("batch_size", 1)
-        device = kwargs.get("device", "cpu")
+        device = kwargs.get("device", "cuda")
        if device == "cpu":
            batch_size = 1
@@ -60,7 +60,7 @@ class AutoFrontend:

        result_list = []
        num_samples = len(data_list)
-        pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True)
+        # pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True)

        time0 = time.perf_counter()
        for beg_idx in range(0, num_samples, batch_size):
@@ -87,15 +87,23 @@ class AutoFrontend:

                speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
            )

-            speech.to(device=device), speech_lengths.to(device=device)
-            batch = {"input": speech, "input_len": speech_lengths, "key": key_batch}
+            if kwargs.get("return_pt", True):
+                speech, speech_lengths = speech.to(device=device), speech_lengths.to(device=device)
+            else:
+                speech, speech_lengths = speech.numpy(), speech_lengths.numpy()
+            batch = {
+                "input": speech,
+                "input_len": speech_lengths,
+                "key": key_batch,
+                "data_type": "fbank",
+            }
            result_list.append(batch)

-            pbar.update(1)
-            description = f"{meta_data}, "
-            pbar.set_description(description)
+            # pbar.update(1)
+            # description = f"{meta_data}, "
+            # pbar.set_description(description)

        time_end = time.perf_counter()
-        pbar.set_description(f"time escaped total: {time_end - time0:0.3f}")
+        # pbar.set_description(f"time escaped total: {time_end - time0:0.3f}")

        return result_list
@@ -42,8 +42,9 @@ def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
    filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]

    chars = string.ascii_letters + string.digits
-    if isinstance(data_in, str) and data_in.startswith("http"):  # url
-        data_in = download_from_url(data_in)
+    if isinstance(data_in, str):
+        if data_in.startswith("http://") or data_in.startswith("https://"):  # url
+            data_in = download_from_url(data_in)

    if isinstance(data_in, str) and os.path.exists(
        data_in
@@ -284,7 +285,7 @@ class AutoModel:
        with torch.no_grad():
            res = model.inference(**batch, **kwargs)
            if isinstance(res, (list, tuple)):
-                results = res[0]
+                results = res[0] if len(res) > 0 else [{"text": ""}]
                meta_data = res[1] if len(res) > 1 else {}
        time2 = time.perf_counter()
@@ -358,6 +359,7 @@ class AutoModel:
        results_sorted = []

        if not len(sorted_data):
            results_ret_list.append({"key": key, "text": "", "timestamp": []})
            logging.info("decoding, utt: {}, empty speech".format(key))
            continue
@@ -147,7 +147,9 @@ class EspnetStyleBatchSampler(DistributedSampler):
        start_idx = self.rank * batches_per_rank
        end_idx = start_idx + batches_per_rank
        rank_batches = buffer_batches[start_idx + self.start_step : end_idx]

        self.batch_num = len(rank_batches)

        logging.info(
            f"rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num: {end_idx-start_idx}, batch_num_after_step: {len(rank_batches)}"
        )
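The slicing above is what lets training resume mid-epoch: each rank takes its contiguous share of the pre-built batches and then skips the `start_step` batches it has already consumed. A toy illustration of the arithmetic (plain Python, illustrative values only):

```python
buffer_batches = list(range(10))        # pretend: 10 pre-built batches
rank, batches_per_rank, start_step = 1, 5, 2

start_idx = rank * batches_per_rank     # 5
end_idx = start_idx + batches_per_rank  # 10
rank_batches = buffer_batches[start_idx + start_step : end_idx]
print(rank_batches)                     # [7, 8, 9] -> the first 2 steps were skipped
```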
@@ -12,10 +12,30 @@ name_maps_ms = {
    "Whisper-large-v2": "iic/speech_whisper-large_asr_multilingual",
    "Whisper-large-v3": "iic/Whisper-large-v3",
    "Qwen-Audio": "Qwen/Qwen-Audio",
    "emotion2vec_plus_large": "iic/emotion2vec_plus_large",
    "emotion2vec_plus_base": "iic/emotion2vec_plus_base",
    "emotion2vec_plus_seed": "iic/emotion2vec_plus_seed",
}

name_maps_hf = {
    "": "",
    "paraformer": "funasr/paraformer-zh",
    "paraformer-zh": "funasr/paraformer-zh",
    "paraformer-en": "funasr/paraformer-zh",
    "paraformer-zh-streaming": "funasr/paraformer-zh-streaming",
    "fsmn-vad": "funasr/fsmn-vad",
    "ct-punc": "funasr/ct-punc",
    "ct-punc-c": "iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
    "fa-zh": "funasr/fa-zh",
    "cam++": "funasr/campplus",
    "Whisper-large-v2": "iic/speech_whisper-large_asr_multilingual",
    "Whisper-large-v3": "iic/Whisper-large-v3",
    "Qwen-Audio": "Qwen/Qwen-Audio",
    "emotion2vec_plus_large": "emotion2vec/emotion2vec_plus_large",
    "iic/emotion2vec_plus_large": "emotion2vec/emotion2vec_plus_large",
    "emotion2vec_plus_base": "emotion2vec/emotion2vec_plus_base",
    "iic/emotion2vec_plus_base": "emotion2vec/emotion2vec_plus_base",
    "emotion2vec_plus_seed": "emotion2vec/emotion2vec_plus_seed",
    "iic/emotion2vec_plus_seed": "emotion2vec/emotion2vec_plus_seed",
}

name_maps_openai = {
@@ -1,5 +1,7 @@
import torch
import torch.nn as nn
+import torch.nn.functional as F
+from funasr.models.transformer.utils.nets_utils import make_pad_mask

from funasr.register import tables

@@ -63,3 +65,64 @@ class EncoderProjectorQFormer(nn.Module):
        query_proj = self.norm(self.linear(query_output.last_hidden_state))

        return query_proj


@tables.register("adaptor_classes", "Transformer")
class Transformer(nn.Module):
    def __init__(
        self, downsample_rate=2, encoder_dim=1280, llm_dim=4096, ffn_dim: int = 2048, **kwargs
    ):
        super().__init__()
        self.k = downsample_rate
        self.encoder_dim = encoder_dim
        self.llm_dim = llm_dim
        self.linear1 = nn.Linear(self.encoder_dim * self.k, ffn_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(ffn_dim, self.llm_dim)
        from funasr.models.transformer.encoder import EncoderLayer
        from funasr.models.transformer.attention import MultiHeadedAttention
        from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward

        self.blocks = nn.ModuleList(
            [
                EncoderLayer(
                    llm_dim,
                    MultiHeadedAttention(
                        kwargs.get("attention_heads", 8),
                        llm_dim,
                        kwargs.get("attention_dropout_rate", 0.0),
                    ),
                    PositionwiseFeedForward(
                        llm_dim,
                        llm_dim // 4,
                        kwargs.get("dropout_rate", 0.0),
                    ),
                    kwargs.get("dropout_rate", 0.0),
                )
                for i in range(kwargs.get("n_layer", 2))
            ]
        )

    def forward(self, x, ilens=None):

        batch_size, seq_len, dim = x.size()
        # pad the time axis to a multiple of the downsample rate k
        chunk_num = (seq_len - 1) // self.k + 1
        pad_num = chunk_num * self.k - seq_len
        x = F.pad(x, (0, 0, 0, pad_num, 0, 0), value=0.0)
        seq_len = x.size(1)

        # stack k consecutive frames along the feature axis, then project to llm_dim
        x = x.contiguous()
        x = x.view(batch_size, chunk_num, dim * self.k)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)

        olens = (ilens - 1) // self.k + 1
        masks = (~make_pad_mask(olens)[:, None, :]).to(x.device)
        for layer, block in enumerate(self.blocks):
            x, masks = block(x, masks)
        return x, olens
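For orientation, the adaptor above downsamples the encoder sequence by a factor of k and projects it into the LLM embedding space. A minimal shape sketch of the same arithmetic in plain PyTorch (illustrative values, not the FunASR class itself):

```python
import torch
import torch.nn.functional as F

# T=99 frames of dim 1280, downsample rate k=2, as in the adaptor's forward
B, T, D, k = 4, 99, 1280, 2
x = torch.randn(B, T, D)

chunk_num = (T - 1) // k + 1                # ceil(99 / 2) = 50
x = F.pad(x, (0, 0, 0, chunk_num * k - T))  # pad the time axis to 100 frames
x = x.view(B, chunk_num, D * k)             # (4, 50, 2560): k frames stacked per step
print(x.shape)
```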
@@ -163,7 +163,11 @@ def export_backbone_forward(
        dha_ids = dha_pred.max(-1)[-1]
        dha_mask = (dha_ids == self.NOBIAS).int().unsqueeze(-1)
        decoder_out = decoder_out * dha_mask + dha_pred * (1 - dha_mask)
-    return decoder_out, pre_token_length, alphas
+
+    # get predicted timestamps
+    us_alphas, us_cif_peak = self.predictor.get_upsample_timestmap(enc, mask, pre_token_length)
+
+    return decoder_out, pre_token_length, us_alphas, us_cif_peak


def export_backbone_dummy_inputs(self):

@@ -178,7 +182,7 @@ def export_backbone_input_names(self):


def export_backbone_output_names(self):
-    return ["logits", "token_num", "alphas"]
+    return ["logits", "token_num", "us_alphas", "us_cif_peak"]


def export_backbone_dynamic_axes(self):

@@ -190,6 +194,8 @@ def export_backbone_dynamic_axes(self):
        "bias_embed": {0: "batch_size", 1: "num_hotwords"},
        "logits": {0: "batch_size", 1: "logits_length"},
        "pre_acoustic_embeds": {1: "feats_length1"},
+        "us_alphas": {0: "batch_size", 1: "alphas_length"},
+        "us_cif_peak": {0: "batch_size", 1: "alphas_length"},
    }
@@ -360,6 +360,7 @@ class SenseVoiceDecoder(nn.Module):
        """Score."""
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
+        logp = torch.log_softmax(logp, dim=-1)
        return logp.squeeze(0)[-1, :], state
@@ -1264,15 +1264,29 @@ class SenseVoiceSANM(nn.Module):
        if isinstance(task, str):
            task = [task]
        task = "".join([f"<|{x}|>" for x in task])
-        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")

-        language = DecodingOptions.get("language", None)
-        language = None if language == "auto" else language
+        sos = kwargs.get("model_conf").get("sos")
+        if isinstance(sos, str):
+            initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")

-        sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
-        sos_int = tokenizer.encode(sos, allowed_special="all")
+            language = DecodingOptions.get("language", None)
+            language = None if language == "auto" else language
+
+            sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
+            sos_int = tokenizer.encode(sos, allowed_special="all")
+        else:
+            language = DecodingOptions.get("language", None)
+            language = None if language == "auto" else language
+            initial_prompt = kwargs.get("initial_prompt", f"{task}")
+            initial_prompt_lid = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
+            initial_prompt_lid_int = tokenizer.encode(initial_prompt_lid, allowed_special="all")
+            sos_int = [sos] + initial_prompt_lid_int
        eos = kwargs.get("model_conf").get("eos")
-        eos_int = tokenizer.encode(eos, allowed_special="all")
+        if isinstance(eos, str):
+            eos_int = tokenizer.encode(eos, allowed_special="all")
+        else:
+            eos_int = [eos]

        self.beam_search.sos = sos_int
        self.beam_search.eos = eos_int[0]

@@ -1298,7 +1312,7 @@ class SenseVoiceSANM(nn.Module):
        self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])

        encoder_out, encoder_out_lens = self.encode(
-            speech[None, :, :].permute(0, 2, 1), speech_lengths
+            speech[None, :, :], speech_lengths
        )

        if text_token_int is not None:
@@ -27,9 +27,24 @@ class ModelDimensions:
    n_text_layer: int


-class LayerNorm(nn.LayerNorm):
-    def forward(self, x: Tensor) -> Tensor:
-        return super().forward(x.float()).type(x.dtype)
+# class LayerNorm(nn.LayerNorm):
+#     def forward(self, x: Tensor) -> Tensor:
+#         return super().forward(x.float()).type(x.dtype)
+
+
+class LayerNorm(nn.LayerNorm):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, input):
+        output = F.layer_norm(
+            input.float(),
+            self.normalized_shape,
+            self.weight.float() if self.weight is not None else None,
+            self.bias.float() if self.bias is not None else None,
+            self.eps,
+        )
+        return output.type_as(input)


class Linear(nn.Linear):
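The rewritten LayerNorm keeps the float32-normalization trick of the commented-out Whisper version while casting weight and bias explicitly, so half-precision inputs stay numerically stable. A standalone sketch of the same idea (plain PyTorch, not the file's class):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Float32LayerNorm(nn.LayerNorm):
    """Normalize in float32 for stability, then return the input's dtype."""
    def forward(self, input):
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)

x = torch.randn(2, 8, dtype=torch.float16)
print(Float32LayerNorm(8)(x).dtype)  # torch.float16
```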
@@ -64,7 +64,7 @@ class EncoderLayer(nn.Module):
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
-        super(EncoderLayer, self).__init__()
+        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
@@ -621,7 +621,6 @@ class Trainer:
        self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size

    def forward_step(self, model, batch, loss_dict={}):
-        dtype = torch.bfloat16
        with maybe_autocast(dtype=self.dtype, use_deepspeed=self.use_deepspeed):
            retval = model(**batch)
runtime/funasr_api/README.md (new file, 72 lines)

@@ -0,0 +1,72 @@

# python funasr_api

This is the Python API for the FunASR engine; it only supports the 2pass server.

## Installation

### Install websocket-client and ffmpeg

```shell
pip install websocket-client
apt install ffmpeg -y
```

#### Recognizer examples

Any audio type that ffmpeg supports can be used; for details see FunASR/runtime/funasr_api/example.py.
```python
# create a recognizer
rcg = FunasrApi(
    uri="wss://www.funasr.com:10096/"
)
# recognize by file path
text = rcg.rec_file("asr_example.mp3")
print("recognize by filepath result=", text)


# recognize from a buffer
# rec_buf(audio_buf, ffmpeg_decode=False); set ffmpeg_decode=True if the audio is not PCM or WAV
with open("asr_example.wav", "rb") as f:
    audio_bytes = f.read()
text = rcg.rec_buf(audio_bytes)
print("recognize by buffer result=", text)
```

#### Streaming recognizer example; use FunasrTools.audio2wav to convert to WAV first if needed

```python
rcg = FunasrApi(
    uri="wss://www.funasr.com:10096/"
)
# define a callback function for messages
def on_msg(msg):
    print("stream msg=", msg)
stream = rcg.create_stream(msg_callback=on_msg)

wav_path = "asr_example.wav"

with open(wav_path, "rb") as f:
    audio_bytes = f.read()

# use FunasrTools.audio2wav to convert other audio types to PCM if needed
# import os
# from funasr_tools import FunasrTools
# file_ext = os.path.splitext(wav_path)[-1].upper()
# if not file_ext == "PCM" and not file_ext == "WAV":
#     audio_bytes = FunasrTools.audio2wav(audio_bytes)

stride = int(60 * 10 / 10 / 1000 * 16000 * 2)  # 60 ms of 16 kHz 16-bit audio = 1920 bytes per send
chunk_num = (len(audio_bytes) - 1) // stride + 1

for i in range(chunk_num):
    beg = i * stride
    data = audio_bytes[beg : beg + stride]
    stream.feed_chunk(data)
final_result = stream.wait_for_end()
print("asr_example.wav stream_result=", final_result)
```

## Acknowledge

1. This project is maintained by the [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
2. We acknowledge [zhaoming](https://github.com/zhaomingwork/FunASR/tree/fix_bug_for_python_websocket) for contributing the websocket service.
3. We acknowledge [cgisky1980](https://github.com/cgisky1980/FunASR) for contributing the websocket service of the offline model.
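As a quick check on the stride used above: the client configures the server with `chunk_size=[0,10,5]`, 60 ms frames, and `chunk_interval=10` (see funasr_core.py later in this commit), so each `feed_chunk` call carries 60 ms of 16 kHz 16-bit audio:

```python
frame_ms, chunk_mult, chunk_interval = 60, 10, 10
sample_rate, bytes_per_sample = 16000, 2

send_ms = frame_ms * chunk_mult / chunk_interval   # 60.0 ms per send
stride = int(send_ms / 1000 * sample_rate * bytes_per_sample)
print(stride)  # 1920, matching int(60 * 10 / 10 / 1000 * 16000 * 2)
```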
runtime/funasr_api/asr_example.mp3 (new binary file, not shown)
runtime/funasr_api/asr_example.wav (new binary file, not shown)
runtime/funasr_api/example.py (new file, 70 lines)

@@ -0,0 +1,70 @@
"""
Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
Reserved. MIT License (https://opensource.org/licenses/MIT)

2023-2024 by zhaomingwork@qq.com
"""

from funasr_api import FunasrApi


def recognizer_example():
    # create a recognizer
    rcg = FunasrApi(
        uri="wss://www.funasr.com:10096/"
    )
    # recognize by file path
    text = rcg.rec_file("asr_example.mp3")
    print("recognize by filepath result=", text)

    # recognize from a buffer
    # rec_buf(audio_buf, ffmpeg_decode=False); set ffmpeg_decode=True if the audio is not PCM or WAV
    with open("asr_example.wav", "rb") as f:
        audio_bytes = f.read()
    text = rcg.rec_buf(audio_bytes)
    print("recognize by buffer result=", text)


def recognizer_stream_example():
    rcg = FunasrApi(
        uri="wss://www.funasr.com:10096/"
    )

    # define a callback function for messages
    def on_msg(msg):
        print("stream msg=", msg)

    stream = rcg.create_stream(msg_callback=on_msg)

    wav_path = "asr_example.wav"

    with open(wav_path, "rb") as f:
        audio_bytes = f.read()

    # use FunasrTools.audio2wav to convert other audio types to PCM if needed
    # import os
    # from funasr_tools import FunasrTools
    # file_ext = os.path.splitext(wav_path)[-1].upper()
    # if not file_ext == "PCM" and not file_ext == "WAV":
    #     audio_bytes = FunasrTools.audio2wav(audio_bytes)

    stride = int(60 * 10 / 10 / 1000 * 16000 * 2)  # 60 ms of 16 kHz 16-bit audio = 1920 bytes per send
    chunk_num = (len(audio_bytes) - 1) // stride + 1

    for i in range(chunk_num):
        beg = i * stride
        data = audio_bytes[beg : beg + stride]
        stream.feed_chunk(data)
    final_result = stream.wait_for_end()
    print("asr_example.wav stream_result=", final_result)


if __name__ == "__main__":
    print("example for Funasr_websocket_recognizer")

    recognizer_stream_example()

    recognizer_example()
runtime/funasr_api/funasr_api.py (new file, 96 lines)

@@ -0,0 +1,96 @@
"""
Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
Reserved. MIT License (https://opensource.org/licenses/MIT)

2023-2024 by zhaomingwork@qq.com
"""

# pip install websocket-client
# apt install ffmpeg

import traceback

from funasr_stream import FunasrStream
from funasr_core import FunasrCore


# class for the websocket recognizer
class FunasrApi:
    """
    Python ASR recognizer library.
    """

    def __init__(
        self,
        uri="wss://www.funasr.com:10096/",
        timeout=1000,
        msg_callback=None,
    ):
        """
        uri: ws or wss server uri
        msg_callback: callback for received messages
        timeout: timeout for getting the result
        """
        try:
            self.uri = uri
            self.timeout = timeout
            self.msg_callback = msg_callback
            self.funasr_core = None

        except Exception as e:
            print("Exception:", e)
            traceback.print_exc()

    def create_stream(self, msg_callback=None):
        if self.funasr_core is not None:
            self.funasr_core.close()
        funasr_core = self.new_core(msg_callback=msg_callback)
        return FunasrStream(funasr_core)

    def new_core(self, msg_callback=None):
        try:
            if self.funasr_core is not None:
                self.funasr_core.close()

            if msg_callback is None:
                msg_callback = self.msg_callback
            funasr_core = FunasrCore(self.uri, msg_callback=msg_callback, timeout=self.timeout)
            funasr_core.new_connection()
            self.funasr_core = funasr_core
            return funasr_core

        except Exception as e:
            print("new_core", e)
            exit(0)

    # recognize a buffer; set ffmpeg_decode=True if the audio is not PCM or WAV
    def rec_buf(self, audio_buf, ffmpeg_decode=False):
        try:
            funasr_core = self.new_core()
            funasr_core.rec_buf(audio_buf, ffmpeg_decode=ffmpeg_decode)
            return funasr_core.get_result()
        except Exception as e:
            print("rec_buf", e)
            return

    # recognize a file
    def rec_file(self, file_path):
        try:
            funasr_core = self.new_core()
            funasr_core.rec_file(file_path)
            return funasr_core.get_result()
        except Exception as e:
            print("rec_file", e)
            return
runtime/funasr_api/funasr_core.py (new file, 230 lines)

@@ -0,0 +1,230 @@
"""
Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
Reserved. MIT License (https://opensource.org/licenses/MIT)

2023-2024 by zhaomingwork@qq.com
"""

# pip install websocket-client
# apt install ffmpeg
import ssl
from websocket import ABNF
from websocket import create_connection
import threading
import traceback
import json
import time

from funasr_tools import FunasrTools


# class for the websocket recognizer
class FunasrCore:
    """
    Python ASR recognizer library.
    """

    def __init__(
        self,
        uri="wss://www.funasr.com:10096/",
        msg_callback=None,
        timeout=1000,
    ):
        """
        uri: ws or wss server uri
        msg_callback: callback for received messages
        timeout: timeout for getting the result
        """
        try:
            # str.find() returns 0 for a match at the start (falsy), so use
            # startswith() to detect the scheme
            if uri.startswith("wss://"):
                is_ssl = True
            elif uri.startswith("ws://"):
                is_ssl = False
            else:
                print("unsupported uri", uri)
                exit(0)

            if is_ssl:
                ssl_context = ssl.SSLContext()
                ssl_context.check_hostname = False
                ssl_context.verify_mode = ssl.CERT_NONE
                ssl_opt = {"cert_reqs": ssl.CERT_NONE}
            else:
                ssl_context = None
                ssl_opt = None

            self.ssl_opt = ssl_opt
            self.ssl_context = ssl_context
            self.uri = uri

            print("connect to url", uri)

            self.msg_callback = msg_callback
            self.is_final = False
            self.rec_text = ""
            self.timeout = timeout
            self.rec_file_len = 0
            self.connect_state = 0

        except Exception as e:
            print("Exception:", e)
            traceback.print_exc()

    def new_connection(self):
        try:
            self.websocket = create_connection(self.uri, ssl=self.ssl_context, sslopt=self.ssl_opt)

            self.is_final = False
            self.rec_text = ""
            self.rec_file_len = 0
            self.connect_state = 0

            message = json.dumps(
                {
                    "mode": "2pass",
                    "chunk_size": [int(x) for x in "0,10,5".split(",")],
                    "encoder_chunk_look_back": 4,
                    "decoder_chunk_look_back": 1,
                    "chunk_interval": 10,
                    "wav_name": "funasr_api",
                    "is_speaking": True,
                }
            )

            self.websocket.send(message)
            self.connect_state = 1
            # thread for receiving messages
            self.thread_msg = threading.Thread(
                target=FunasrCore.thread_rec_msg, args=(self,)
            )
            self.thread_msg.start()

            print("new_connection: ", message)
        except Exception as e:
            print("new_connection", e)

    # thread for receiving messages
    def thread_rec_msg(self):
        try:
            while True:
                if self.connect_state == 0:
                    time.sleep(0.1)
                    continue
                if self.connect_state == 2:
                    break
                msg = self.websocket.recv()

                if msg is None or len(msg) == 0:
                    continue
                msg = json.loads(msg)

                if msg["is_final"]:
                    self.is_final = True

                if msg["mode"] == "2pass-offline":
                    self.rec_text = self.rec_text + msg["text"]
                if self.msg_callback is not None:
                    self.msg_callback(msg)

        except Exception:
            # print("client closed")
            return

    # feed data to the asr engine in a streaming way
    def feed_chunk(self, chunk):
        try:
            self.websocket.send(chunk, ABNF.OPCODE_BINARY)
            return
        except Exception:
            print("feed chunk error")
            return

    def close(self):
        self.connect_state = 2  # was "==", a no-op comparison
        self.websocket.close()

    def rec_buf(self, audio_bytes, ffmpeg_decode=False):
        try:
            if ffmpeg_decode:
                audio_bytes = FunasrTools.audio2wav(audio_bytes)
            self.rec_file_len = len(audio_bytes)
            stride = int(60 * 10 / 10 / 1000 * 16000 * 2)  # 60 ms of 16 kHz 16-bit audio = 1920 bytes
            chunk_num = (len(audio_bytes) - 1) // stride + 1

            for i in range(chunk_num):
                beg = i * stride
                data = audio_bytes[beg : beg + stride]
                self.feed_chunk(data)
            return self.get_result()
        except Exception as e:
            print("rec_buf", e)
            return

    # recognize a file
    def rec_file(self, file_path):
        try:
            import os
            file_ext = os.path.splitext(file_path)[-1].upper()

            with open(file_path, "rb") as f:
                audio_bytes = f.read()
                # splitext keeps the leading dot, so compare against ".PCM"/".WAV"
                if file_ext not in (".PCM", ".WAV"):
                    audio_bytes = FunasrTools.audio2wav(audio_bytes)
                    if audio_bytes is None:
                        print("error, ffmpeg can not decode such a file!")
                        exit(0)
            return self.rec_buf(audio_bytes)
        except Exception as e:
            print("rec_file", e)
            return

    def wait_for_result(self):
        try:
            timeout = self.timeout

            # audio duration in 10 ms ticks: bytes / 16000 Hz / 2 bytes * 100
            file_dur = self.rec_file_len / 16000 / 2 * 100
            if file_dur > timeout:
                timeout = file_dur
            self.timeout = timeout

            # file_dur == 0 means streaming mode with no timeout
            while self.is_final == False and (timeout > 0 or file_dur == 0):
                time.sleep(0.01)
                timeout = timeout - 1

            if timeout <= 0 and not file_dur == 0:
                print("time out!", self.timeout)
        except Exception as e:
            print("wait_for_result", e)
            return

    def get_result(self):
        try:
            message = json.dumps({"is_speaking": False})
            self.websocket.send(message)
            self.wait_for_result()
            self.close()

            # return the accumulated text
            return self.rec_text
        except Exception:
            return self.rec_text
runtime/funasr_api/funasr_stream.py (new file, 72 lines)

@@ -0,0 +1,72 @@
"""
Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
Reserved. MIT License (https://opensource.org/licenses/MIT)

2023-2024 by zhaomingwork@qq.com
"""

# pip install websocket-client
# apt install ffmpeg

import traceback
import json


# class for a streaming recognizer session over websocket
class FunasrStream:
    """
    Python ASR recognizer library.
    """

    def __init__(
        self,
        funasr_core,
    ):
        """
        funasr_core: the underlying FunasrCore connection to stream through
        """
        try:
            self.funasr_core = funasr_core

        except Exception as e:
            print("FunasrStream init Exception:", e)
            traceback.print_exc()

    # feed data to the asr engine in a streaming way
    def feed_chunk(self, chunk):
        try:
            if self.funasr_core is None:
                print("error in stream, funasr_core is None")
                exit(0)
            self.funasr_core.feed_chunk(chunk)
            return
        except Exception:
            print("feed chunk error")
            return

    # return the full result for this stream
    def wait_for_end(self):
        try:
            message = json.dumps({"is_speaking": False})
            self.funasr_core.websocket.send(message)
            self.funasr_core.wait_for_result()
            self.funasr_core.close()

            # return the accumulated text
            return self.funasr_core.rec_text
        except Exception as e:
            print("error get_final_result ", e)
            return ""
runtime/funasr_api/funasr_tools.py (new file, 84 lines)

@@ -0,0 +1,84 @@
"""
Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
Reserved. MIT License (https://opensource.org/licenses/MIT)

2023-2024 by zhaomingwork@qq.com
"""

# pip install websocket-client
# apt install ffmpeg

import traceback


# helper tools for the websocket recognizer
class FunasrTools:
    """
    Python ASR recognizer library.
    """

    def __init__(
        self,
    ):
        try:
            if FunasrTools.check_ffmpeg() == False:
                print("please install ffmpeg first; on Ubuntu: apt install -y ffmpeg")
                exit(0)

        except Exception as e:
            print("Exception:", e)
            traceback.print_exc()

    # check whether ffmpeg is installed
    @staticmethod
    def check_ffmpeg():
        import subprocess
        try:
            subprocess.run(['ffmpeg', '-version'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            return True
        except FileNotFoundError:
            return False

    # use ffmpeg to convert audio to wav (mono, 16 kHz)
    @staticmethod
    def audio2wav(audiobuf):
        try:
            import subprocess
            if FunasrTools.check_ffmpeg() == False:
                print("please install ffmpeg first; on Ubuntu: apt install -y ffmpeg")
                return None  # the original exit(0) here made the following return unreachable

            ffmpeg_target_to_outwav = ["ffmpeg", "-i", "-", "-ac", "1", "-ar", "16000", "-f", "wav", "pipe:1"]
            pipe_to = subprocess.Popen(ffmpeg_target_to_outwav,
                                       stdin=subprocess.PIPE,
                                       stdout=subprocess.PIPE,
                                       stderr=subprocess.PIPE)
            wavbuf, err = pipe_to.communicate(audiobuf)
            if str(err).find("Error") >= 0 or str(err).find("Unknown") >= 0 or str(err).find("Invalid") >= 0:
                print("ffmpeg err", err)
                return None
            return wavbuf
        except Exception as e:
            print("audio2wav", e)
            return None
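A hypothetical usage sketch for the converter above (assumes ffmpeg is on PATH and an `asr_example.mp3` sits next to the script):

```python
from funasr_tools import FunasrTools

with open("asr_example.mp3", "rb") as f:
    mp3_bytes = f.read()

wav_bytes = FunasrTools.audio2wav(mp3_bytes)  # mono 16 kHz WAV bytes, or None on failure
print(None if wav_bytes is None else len(wav_bytes))
```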
runtime/http/CMakeLists.txt (new file, 142 lines)

@@ -0,0 +1,142 @@
cmake_minimum_required(VERSION 3.16)

project(FunASRWebscoket)

set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)

option(ENABLE_HTTP "Whether to build the http server" ON)
option(ENABLE_PORTAUDIO "Whether to build portaudio" ON)

if(WIN32)
  file(REMOVE ${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/config.h
    ${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/glog/export.h
    ${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/glog/logging.h
    ${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/glog/raw_logging.h
    ${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/glog/stl_logging.h
    ${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/glog/vlog_is_on.h)
else()
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread -fPIC")
endif()

option(ENABLE_GLOG "Whether to build glog" ON)
option(ENABLE_FST "Whether to build openfst" ON) # ITN needs openfst compiled
option(BUILD_SHARED_LIBS "Build shared libraries" ON)

if(ENABLE_HTTP)
  # cmake_policy(SET CMP0135 NEW)
  include(FetchContent)

  if(NOT EXISTS ${PROJECT_SOURCE_DIR}/third_party/asio/asio)
    FetchContent_Declare(asio
      URL https://github.com/chriskohlhoff/asio/archive/refs/tags/asio-1-24-0.tar.gz
      SOURCE_DIR ${PROJECT_SOURCE_DIR}/third_party/asio
    )
    FetchContent_MakeAvailable(asio)
  endif()
  include_directories(${PROJECT_SOURCE_DIR}/third_party/asio/asio/include)

  if(NOT EXISTS ${PROJECT_SOURCE_DIR}/third_party/json/ChangeLog.md)
    FetchContent_Declare(json
      URL https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.tar.gz
      SOURCE_DIR ${PROJECT_SOURCE_DIR}/third_party/json
    )
    FetchContent_MakeAvailable(json)
  endif()
  include_directories(${PROJECT_SOURCE_DIR}/third_party/json/include)

endif()

if(ENABLE_PORTAUDIO)
  include(FetchContent)

  set(portaudio_URL "http://files.portaudio.com/archives/pa_stable_v190700_20210406.tgz")
  set(portaudio_URL2 "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/pa_stable_v190700_20210406.tgz")
  set(portaudio_HASH "SHA256=47efbf42c77c19a05d22e627d42873e991ec0c1357219c0d74ce6a2948cb2def")

  FetchContent_Declare(portaudio
    URL
      ${portaudio_URL}
      ${portaudio_URL2}
    URL_HASH ${portaudio_HASH}
  )

  FetchContent_GetProperties(portaudio)
  if(NOT portaudio_POPULATED)
    message(STATUS "Downloading portaudio from ${portaudio_URL}")
    FetchContent_Populate(portaudio)
  endif()
  message(STATUS "portaudio is downloaded to ${portaudio_SOURCE_DIR}")
  message(STATUS "portaudio's binary dir is ${portaudio_BINARY_DIR}")

  add_subdirectory(${portaudio_SOURCE_DIR} ${portaudio_BINARY_DIR} EXCLUDE_FROM_ALL)
  if(NOT WIN32)
    target_compile_options(portaudio PRIVATE "-Wno-deprecated-declarations")
  else()
    install(TARGETS portaudio DESTINATION ..)
  endif()

endif()

# Include generated *.pb.h files
link_directories(${ONNXRUNTIME_DIR}/lib)
link_directories(${FFMPEG_DIR}/lib)

if(ENABLE_GLOG)
  include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src)
  set(BUILD_TESTING OFF)
  add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog glog)
  include_directories(${glog_BINARY_DIR})
endif()

if(ENABLE_FST)
  # fst depends on glog and gflags
  include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/gflags)
  add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/gflags gflags)
  include_directories(${gflags_BINARY_DIR}/include)

  # the following openfst is cloned from https://github.com/kkm000/openfst.git
  # with some patches to fix the make errors.
  add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/openfst openfst)
  include_directories(${openfst_SOURCE_DIR}/src/include)
  if(WIN32)
    include_directories(${openfst_SOURCE_DIR}/src/lib)
  endif()
endif()

include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/include/)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/src)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/yaml-cpp/include/)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi-native-fbank)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/jieba/include)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/jieba/include/limonp/include)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi)

add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/yaml-cpp yaml-cpp)
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi-native-fbank/kaldi-native-fbank/csrc csrc)
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/src src)
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi kaldi)

# install openssl first: apt-get install libssl-dev
find_package(OpenSSL REQUIRED)

message("CXX_FLAGS "${CMAKE_CXX_FLAGS})
# collect all include directories in the project
get_property(includes DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES)
# iterate over them and print each include directory
foreach(include ${includes})
  message("Include directory: ${include}")
endforeach()

add_subdirectory(bin)
runtime/http/bin/CMakeLists.txt (new file, 23 lines)

@@ -0,0 +1,23 @@

if(WIN32)
  include_directories(${ONNXRUNTIME_DIR}/include)
  include_directories(${FFMPEG_DIR}/include)
  include_directories(${OPENSSL_ROOT_DIR}/include)
  link_directories(${OPENSSL_ROOT_DIR}/lib)
  add_definitions(-D_WEBSOCKETPP_CPP11_RANDOM_DEVICE_)
  add_definitions(-D_WEBSOCKETPP_CPP11_TYPE_TRAITS_)
  add_compile_options("$<$<CXX_COMPILER_ID:MSVC>:/bigobj>")
  add_compile_options("$<$<CXX_COMPILER_ID:MSVC>:/utf-8>")
  SET(RELATION_SOURCE "../../onnxruntime/src/resample.cpp" "../../onnxruntime/src/util.cpp" "../../onnxruntime/src/alignedmem.cpp" "../../onnxruntime/src/encode_converter.cpp")
endif()

find_package(ZLIB REQUIRED)

file(GLOB SRC_FILES "*.cpp")
add_executable(funasr-http-server ${SRC_FILES} ${RELATION_SOURCE})

target_link_libraries(funasr-http-server PUBLIC funasr ${OPENSSL_CRYPTO_LIBRARY} ${OPENSSL_SSL_LIBRARY})
runtime/http/bin/asr_sessions.h (new file, 20 lines)

@@ -0,0 +1,20 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
// FUNASR_MESSAGE defines the message shared between the funasr engine and the http server
#ifndef HTTP_SERVER2_SESSIONS_HPP
#define HTTP_SERVER2_SESSIONS_HPP
#include "funasrruntime.h"
#include "nlohmann/json.hpp"
#include <atomic>
typedef struct {
  nlohmann::json msg;
  std::shared_ptr<std::vector<char>> samples;
  std::shared_ptr<std::vector<std::vector<float>>> hotwords_embedding = nullptr;

  FUNASR_DEC_HANDLE decoder_handle = nullptr;
  std::atomic<int> status;
} FUNASR_MESSAGE;
#endif  // HTTP_SERVER2_SESSIONS_HPP
196
runtime/http/bin/connection.cpp
Normal file
196
runtime/http/bin/connection.cpp
Normal file
@ -0,0 +1,196 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
//
// connection.cpp
// copy some codes from http://www.boost.org/
#include "connection.hpp"

#include <thread>
#include <utility>

namespace http {
namespace server2 {
// std::ofstream fwout("out.data", std::ios::binary);
std::shared_ptr<FUNASR_MESSAGE> &connection::get_data_msg() { return data_msg; }
connection::connection(asio::ip::tcp::socket socket,
                       asio::io_context &io_decoder, int connection_id,
                       std::shared_ptr<ModelDecoder> model_decoder)
    : socket_(std::move(socket)),
      io_decoder(io_decoder),
      connection_id(connection_id),
      model_decoder(model_decoder) {
  s_timer = std::make_shared<asio::steady_timer>(io_decoder);
}

void connection::setup_timer() {
  if (data_msg->status == 1) return;

  s_timer->expires_after(std::chrono::seconds(3));
  s_timer->async_wait([=](const asio::error_code &ec) {
    if (!ec) {
      std::cout << "time is out!" << std::endl;
      if (data_msg->status == 1) return;
      data_msg->status = 1;
      s_timer->cancel();
      auto wf = std::bind(&connection::write_back, std::ref(*this), "");
      // close the connection
      strand_->post(wf);
    }
  });
}

void connection::start() {
  std::lock_guard<std::mutex> lock(m_lock);  // for thread safety
  try {
    // put a new data vector in place for the new connection
    data_msg = std::make_shared<FUNASR_MESSAGE>();
    data_msg->samples = std::make_shared<std::vector<char>>();
    // data_msg->samples->reserve(16000*20);
    data_msg->msg = nlohmann::json::parse("{}");
    data_msg->msg["wav_format"] = "pcm";
    data_msg->msg["wav_name"] = "wav-default-id";
    data_msg->msg["itn"] = true;
    data_msg->msg["audio_fs"] = 16000;  // default is 16k
    data_msg->msg["access_num"] = 0;  // the number of accesses to this object;
                                      // when it is 0, we can free it safely
    data_msg->msg["is_eof"] = false;
    data_msg->status = 0;

    strand_ = std::make_shared<asio::io_context::strand>(io_decoder);

    FUNASR_DEC_HANDLE decoder_handle = FunASRWfstDecoderInit(
        model_decoder->get_asr_handle(), ASR_OFFLINE, global_beam_, lattice_beam_, am_scale_);

    data_msg->decoder_handle = decoder_handle;

    if (data_msg->hotwords_embedding == nullptr) {
      std::unordered_map<std::string, int> merged_hws_map;
      std::string nn_hotwords = "";

      if (true) {
        std::string json_string = "{}";
        if (!json_string.empty()) {
          nlohmann::json json_fst_hws;
          try {
            json_fst_hws = nlohmann::json::parse(json_string);
            if (json_fst_hws.type() == nlohmann::json::value_t::object) {
              // fst
              try {
                std::unordered_map<std::string, int> client_hws_map =
                    json_fst_hws;
                merged_hws_map.insert(client_hws_map.begin(),
                                      client_hws_map.end());
              } catch (const std::exception &e) {
                std::cout << e.what();
              }
            }
          } catch (std::exception const &e) {
            std::cout << e.what();
            // nn
            std::string client_nn_hws = "{}";
            nn_hotwords += " " + client_nn_hws;
            std::cout << "nn hotwords: " << client_nn_hws;
          }
        }
      }
      merged_hws_map.insert(hws_map_.begin(), hws_map_.end());

      // fst
      std::cout << "hotwords: ";
      for (const auto &pair : merged_hws_map) {
        nn_hotwords += " " + pair.first;
        std::cout << pair.first << " : " << pair.second;
      }
      FunWfstDecoderLoadHwsRes(data_msg->decoder_handle, fst_inc_wts_,
                               merged_hws_map);

      // nn
      std::vector<std::vector<float>> new_hotwords_embedding =
          CompileHotwordEmbedding(model_decoder->get_asr_handle(), nn_hotwords);
      data_msg->hotwords_embedding =
          std::make_shared<std::vector<std::vector<float>>>(
              new_hotwords_embedding);
    }

    file_parse = std::make_shared<http::server2::file_parser>(data_msg);
    do_read();
  } catch (const std::exception &e) {
    std::cout << "error:" << e.what();
  }
}

void connection::write_back(std::string str) {
  s_timer->cancel();
  std::cout << "jsonresult=" << data_msg->msg["asr_result"].dump() << std::endl;
  reply_ = reply::stock_reply(data_msg->msg["asr_result"].dump());
  do_write();
}
void connection::do_read() {
  // status==1 means time out
  if (data_msg->status == 1) return;

  s_timer->cancel();
  setup_timer();
  auto self(shared_from_this());
  socket_.async_read_some(
      asio::buffer(buffer_),
      [this, self](asio::error_code ec, std::size_t bytes_transferred) {
        if (!ec) {
          auto is = std::begin(buffer_);
          auto ie = std::next(is, bytes_transferred);

          http::server2::file_parser::result_type rtype =
              file_parse->parse_file(is, ie);
          if (rtype == http::server2::file_parser::result_type::ok) {
            // fwout.write(data_msg->samples->data(), data_msg->samples->size());
            // fwout.flush();
            auto wf = std::bind(&connection::write_back, std::ref(*this), "aa");
            auto f = std::bind(&ModelDecoder::do_decoder,
                               std::ref(*model_decoder), std::ref(data_msg));

            // for the decode task
            strand_->post(f);
            // for the close task
            strand_->post(wf);

            // std::this_thread::sleep_for(std::chrono::milliseconds(1000*10));
          }

          do_read();
        }
      });
}

void connection::do_write() {
  auto self(shared_from_this());
  asio::async_write(socket_, reply_.to_buffers(),
                    [this, self](asio::error_code ec, std::size_t) {
                      if (!ec) {
                        // Initiate graceful connection closure.
                        asio::error_code ignored_ec;
                        socket_.shutdown(asio::ip::tcp::socket::shutdown_both,
                                         ignored_ec);
                      }

                      // No new asynchronous operations are started. This means
                      // that all shared_ptr references to the connection object
                      // will disappear and the object will be destroyed
                      // automatically after this handler returns. The
                      // connection class's destructor closes the socket.
                    });
}

}  // namespace server2
}  // namespace http
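connection::do_read() posts the decode task and the write-back task to the same strand, so they execute in posting order even though io_decoder is driven by a whole thread pool. A standalone sketch of that ordering guarantee, assuming standalone asio:

```cpp
// Sketch: two handlers posted to one strand never run concurrently and
// preserve posting order, even with multiple threads running the context.
#include <asio.hpp>
#include <iostream>
#include <thread>
#include <vector>

int main() {
  asio::io_context ctx;
  auto guard = asio::make_work_guard(ctx);
  asio::io_context::strand strand(ctx);

  strand.post([] { std::cout << "decode\n"; });      // runs first
  strand.post([] { std::cout << "write back\n"; });  // runs second

  std::vector<std::thread> pool;
  for (int i = 0; i < 4; ++i) pool.emplace_back([&] { ctx.run(); });
  guard.reset();  // let run() return once the queue drains
  for (auto &t : pool) t.join();
}
```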
104 runtime/http/bin/connection.hpp Normal file
@ -0,0 +1,104 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
//
// copy some codes from http://www.boost.org/
//

#ifndef HTTP_SERVER2_CONNECTION_HPP
#define HTTP_SERVER2_CONNECTION_HPP

#include <array>
#include <asio.hpp>
#include <atomic>
#include <iostream>
#include <memory>
#include <mutex>
#include <unordered_map>

#include "reply.hpp"

#include <fstream>

#include "file_parse.hpp"
#include "model-decoder.h"

extern std::unordered_map<std::string, int> hws_map_;
extern int fst_inc_wts_;
extern float global_beam_, lattice_beam_, am_scale_;

namespace http {
namespace server2 {

/// Represents a single connection from a client.
class connection : public std::enable_shared_from_this<connection> {
 public:
  connection(const connection &) = delete;
  connection &operator=(const connection &) = delete;
  ~connection() { std::cout << "one connection is closed" << std::endl; }

  /// Construct a connection with the given socket.
  explicit connection(asio::ip::tcp::socket socket,
                      asio::io_context &io_decoder, int connection_id,
                      std::shared_ptr<ModelDecoder> model_decoder);

  /// Start the first asynchronous operation for the connection.
  void start();
  std::shared_ptr<FUNASR_MESSAGE> &get_data_msg();
  void write_back(std::string str);

 private:
  /// Perform an asynchronous read operation.
  void do_read();

  /// Perform an asynchronous write operation.
  void do_write();

  void do_decoder();

  void setup_timer();

  /// Socket for the connection.
  asio::ip::tcp::socket socket_;

  /// Buffer for incoming data.
  std::array<char, 8192> buffer_;
  /// For the read timeout.
  std::shared_ptr<asio::steady_timer> s_timer;

  std::shared_ptr<ModelDecoder> model_decoder;

  int connection_id = 0;

  /// The reply to be sent back to the client.
  reply reply_;

  asio::io_context &io_decoder;

  std::shared_ptr<FUNASR_MESSAGE> data_msg;

  std::mutex m_lock;

  std::shared_ptr<asio::io_context::strand> strand_;

  std::shared_ptr<http::server2::file_parser> file_parse;
};

typedef std::shared_ptr<connection> connection_ptr;

}  // namespace server2
}  // namespace http

#endif  // HTTP_SERVER2_CONNECTION_HPP
29 runtime/http/bin/file_parse.cpp Normal file
@ -0,0 +1,29 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */

#include "file_parse.hpp"

namespace http {
namespace server2 {

file_parser::file_parser(std::shared_ptr<FUNASR_MESSAGE> data_msg)
    : data_msg(data_msg) {
  now_state = start;
}

}  // namespace server2
}  // namespace http
234 runtime/http/bin/file_parse.hpp Normal file
@ -0,0 +1,234 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
// ~~~~~~~~~~~~~~~~~~

#ifndef HTTP_SERVER2_REQUEST_FILEPARSER_HPP
#define HTTP_SERVER2_REQUEST_FILEPARSER_HPP

#include <algorithm>
#include <iostream>
#include <memory>
#include <string>
#include <tuple>

#include "asr_sessions.h"
namespace http {
namespace server2 {

/// Parser for incoming multipart/form-data requests.
class file_parser {
 public:
  /// Construct ready to parse the request method.
  explicit file_parser(std::shared_ptr<FUNASR_MESSAGE> data_msg);

  /// Result of parse (also used as the parser state).
  enum result_type { start, in_boundary, data, ok };

  template <typename InputIterator>
  void parse_one_line(InputIterator &is, InputIterator &ie, InputIterator &it) {
    if (is != it) {
      is = it;
    }
    if (*it == '\n') {
      is = std::next(is);
    }

    it = std::find(is, ie, '\n');
  }
  std::string trim_name(std::string raw_string) {
    size_t pos = raw_string.find('\"');

    if (pos != std::string::npos) {
      raw_string = raw_string.substr(pos + 1);
      pos = raw_string.find('\"');
      raw_string = raw_string.substr(0, pos);
    }
    return raw_string;
  }

  std::string parse_file_ext(std::string file_name) {
    size_t pos = file_name.rfind('.');
    std::string ext = "";
    if (pos != std::string::npos) ext = file_name.substr(pos + 1);

    return ext;
  }
  template <typename InputIterator>
  int parse_data_content(InputIterator is, InputIterator ie, InputIterator it) {
    int len = std::distance(it + 1, ie);
    if (len <= 0) {
      return 0;
    }

    // check if we are at the end: "--boundary--" needs +4 for the extra "--" pairs
    if (len == static_cast<int>(boundary.length()) + 4) {
      std::string str(it + 1, ie);
      // std::cout << "len good=" << str << std::endl;
      if (boundary.length() > 1 && boundary[boundary.length() - 1] == '\n') {
        // remove '\n' in boundary
        boundary = boundary.substr(0, boundary.length() - 2);
      }
      if (boundary.length() > 1 && boundary[boundary.length() - 1] == '\r') {
        // remove '\r' in boundary
        boundary = boundary.substr(0, boundary.length() - 2);
      }

      auto found_boundary = str.find(boundary);

      if (found_boundary == std::string::npos) {
        std::cout << "not found end boundary!=" << found_boundary << std::endl;
        return 0;
      }
      // remove the end of data that contains '\n' or '\r'
      int last_sub = 0;
      if (*(it) == '\n') {
        last_sub++;
      }

      int lasts_len = std::distance(it, ie);

      data_msg->samples->erase(data_msg->samples->end() - last_sub - lasts_len,
                               data_msg->samples->end());
      std::cout << "one file finished, file size=" << data_msg->samples->size()
                << std::endl;
      return 1;
    }
    return 0;  // not at the end boundary yet
  }
  template <typename InputIterator>
  void parse_boundary_content(InputIterator is, InputIterator ie,
                              InputIterator it) {
    parse_one_line(is, ie, it);
    std::string str;

    while (it != ie) {
      str = std::string(is, it);

      auto found_content = str.find("Content-Disposition:");
      auto found_filename = str.find("filename=");
      if (found_content != std::string::npos &&
          found_filename != std::string::npos) {
        std::string file_name =
            str.substr(found_filename + 9, std::string::npos);
        file_name = trim_name(file_name);

        std::string ext = parse_file_ext(file_name);

        if (file_name.find(".wav") != std::string::npos) {
          std::cout << "set wav_format=pcm, file_name=" << file_name
                    << std::endl;
          data_msg->msg["wav_format"] = "pcm";
        } else {
          std::cout << "set wav_format=" << ext << ", file_name=" << file_name
                    << std::endl;
          data_msg->msg["wav_format"] = ext;
        }
        data_msg->msg["wav_name"] = file_name;
        now_state = data;
      } else {
        auto found_content = str.find("Content-Disposition:");
        auto found_name = str.find("name=");
        if (found_content != std::string::npos &&
            found_name != std::string::npos) {
          std::string name = str.substr(found_name + 5, std::string::npos);
          name = trim_name(name);
          parse_one_line(is, ie, it);
          if (*it == '\n') it++;
          parse_one_line(is, ie, it);
          str = std::string(is, it);
          std::cout << "para: name=" << name << ",value=" << str << std::endl;
        }
      }

      parse_one_line(is, ie, it);
      if (now_state == data && std::distance(is, it) <= 2) {
        break;
      }
    }

    if (now_state == data) {
      if (*it == '\n') it++;

      data_msg->samples->insert(data_msg->samples->end(), it,
                                it + std::distance(it, ie));
      // it = ie;
    }
  }
  template <typename InputIterator>
  result_type parse_file(InputIterator is, InputIterator ie) {
    if (now_state == data) {
      data_msg->samples->insert(data_msg->samples->end(), is, ie);
    }
    auto it = is;

    while (it != ie) {
      parse_one_line(is, ie, it);
      if (now_state == data) {
        // search for the end of the data part
        int ret = parse_data_content(is, ie, it);
        if (ret == 0) continue;
        return ok;
      } else {
        std::string str(is, it + 1);

        if (now_state == start) {
          auto found_boundary = str.find("Content-Length:");
          if (found_boundary != std::string::npos) {
            std::string file_len =
                str.substr(found_boundary + 15, std::string::npos);

            data_msg->samples->reserve(std::stoi(file_len));
          }
          found_boundary = str.find("boundary=");
          if (found_boundary != std::string::npos) {
            boundary = str.substr(found_boundary + 9, std::string::npos);
            now_state = in_boundary;
          }
        } else if (now_state == in_boundary) {
          // for the file header
          auto found_boundary = str.find(boundary);
          if (found_boundary != std::string::npos) {
            parse_boundary_content(is, ie, it);
          }
        }
      }
    }

    return now_state;
  }

 private:
  std::shared_ptr<FUNASR_MESSAGE> data_msg;
  result_type now_state;
  std::string boundary = "";
};

}  // namespace server2
}  // namespace http

#endif  // HTTP_SERVER2_REQUEST_FILEPARSER_HPP
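For reference, a sketch of the multipart/form-data request shape that file_parser expects; the boundary token, field name, and lengths below are placeholders. The parser keys on the "Content-Length:", "boundary=", "Content-Disposition:", and "filename=" substrings.

```cpp
// Illustrative request shape only; all values are placeholders.
const char *kExampleRequest =
    "POST / HTTP/1.0\r\n"
    "Content-Type: multipart/form-data; boundary=boundary123\r\n"
    "Content-Length: 256\r\n"  // used to reserve the samples buffer
    "\r\n"
    "--boundary123\r\n"
    "Content-Disposition: form-data; name=\"file\"; filename=\"test.wav\"\r\n"
    "\r\n"
    /* ... raw wav/pcm bytes ... */
    "\r\n--boundary123--\r\n";
```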
523 runtime/http/bin/funasr-http-main.cpp Normal file
@ -0,0 +1,523 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */

#include "funasr-http-main.hpp"
#ifdef _WIN32
#include "win_func.h"
#else
#include <unistd.h>
#endif
#include <fstream>

#include "util.h"

// hotwords
std::unordered_map<std::string, int> hws_map_;
int fst_inc_wts_ = 20;
float global_beam_, lattice_beam_, am_scale_;

using namespace std;
void GetValue(TCLAP::ValueArg<std::string> &value_arg, string key,
              std::map<std::string, std::string> &model_path) {
  model_path.insert({key, value_arg.getValue()});
  LOG(INFO) << key << " : " << value_arg.getValue();
}

FUNASR_HANDLE initAsr(std::map<std::string, std::string> &model_path,
                      int thread_num) {
  try {
    // init model with api
    FUNASR_HANDLE asr_handle = FunOfflineInit(model_path, thread_num);
    LOG(INFO) << "model successfully inited";

    LOG(INFO) << "initAsr run check_and_clean_connection";
    // std::thread
    // clean_thread(&ModelDecoderSrv::check_and_clean_connection, this);
    // clean_thread.detach();
    LOG(INFO) << "initAsr run check_and_clean_connection finished";
    return asr_handle;
  } catch (const std::exception &e) {
    LOG(INFO) << e.what();
    return nullptr;
  }
}

int main(int argc, char *argv[]) {
#ifdef _WIN32
#include <windows.h>
  SetConsoleOutputCP(65001);
#endif
  try {
    google::InitGoogleLogging(argv[0]);
    FLAGS_logtostderr = true;
    std::string offline_version = "";
#ifdef _WIN32
    offline_version = "0.1.0";
#endif
    TCLAP::CmdLine cmd("funasr-http-server", ' ', offline_version);
    TCLAP::ValueArg<std::string> download_model_dir(
        "", "download-model-dir",
        "Download model from Modelscope to download_model_dir", false,
        "/workspace/models", "string");
    TCLAP::ValueArg<std::string> model_dir(
        "", OFFLINE_MODEL_DIR,
        "default: "
        "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx, "
        "the asr model path, which "
        "contains model_quant.onnx, config.yaml, am.mvn",
        false,
        "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx",
        "string");
    TCLAP::ValueArg<std::string> model_revision("", "offline-model-revision",
                                                "ASR offline model revision",
                                                false, "v2.0.4", "string");
    TCLAP::ValueArg<std::string> quantize(
        "", QUANTIZE,
        "true (default), load the model of model_quant.onnx in model_dir. If "
        "set to "
        "false, load the model of model.onnx in model_dir",
        false, "true", "string");
    TCLAP::ValueArg<std::string> vad_dir(
        "", VAD_DIR,
        "default: damo/speech_fsmn_vad_zh-cn-16k-common-onnx, the vad model "
        "path, which contains "
        "model_quant.onnx, vad.yaml, vad.mvn",
        false, "damo/speech_fsmn_vad_zh-cn-16k-common-onnx", "string");
    TCLAP::ValueArg<std::string> vad_revision(
        "", "vad-revision", "VAD model revision", false, "v2.0.4", "string");
    TCLAP::ValueArg<std::string> vad_quant(
        "", VAD_QUANT,
        "true (default), load the model of model_quant.onnx in vad_dir. If set to "
        "false, load the model of model.onnx in vad_dir",
        false, "true", "string");
    TCLAP::ValueArg<std::string> punc_dir(
        "", PUNC_DIR,
        "default: "
        "damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727-onnx, "
        "the punc model path, which contains "
        "model_quant.onnx, punc.yaml",
        false,
        "damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727-onnx",
        "string");
    TCLAP::ValueArg<std::string> punc_revision(
        "", "punc-revision", "PUNC model revision", false, "v2.0.4", "string");
    TCLAP::ValueArg<std::string> punc_quant(
        "", PUNC_QUANT,
        "true (default), load the model of model_quant.onnx in punc_dir. If "
        "set to "
        "false, load the model of model.onnx in punc_dir",
        false, "true", "string");
    TCLAP::ValueArg<std::string> itn_dir(
        "", ITN_DIR,
        "default: thuduj12/fst_itn_zh, the itn model path, which contains "
        "zh_itn_tagger.fst, zh_itn_verbalizer.fst",
        false, "", "string");
    TCLAP::ValueArg<std::string> itn_revision(
        "", "itn-revision", "ITN model revision", false, "v1.0.1", "string");

    TCLAP::ValueArg<std::string> listen_ip("", "listen-ip", "listen ip", false,
                                           "0.0.0.0", "string");
    TCLAP::ValueArg<int> port("", "port", "port", false, 80, "int");
    TCLAP::ValueArg<int> io_thread_num("", "io-thread-num", "io thread num",
                                       false, 8, "int");
    TCLAP::ValueArg<int> decoder_thread_num(
        "", "decoder-thread-num", "decoder thread num", false, 32, "int");
    TCLAP::ValueArg<int> model_thread_num("", "model-thread-num",
                                          "model thread num", false, 1, "int");

    TCLAP::ValueArg<std::string> certfile(
        "", "certfile",
        "default: ../../../ssl_key/server.crt, path of certificate for WSS "
        "connection. If it is empty, it will be in WS mode.",
        false, "../../../ssl_key/server.crt", "string");
    TCLAP::ValueArg<std::string> keyfile(
        "", "keyfile",
        "default: ../../../ssl_key/server.key, path of keyfile for WSS "
        "connection",
        false, "../../../ssl_key/server.key", "string");

    TCLAP::ValueArg<float> global_beam("", GLOB_BEAM,
                                       "the decoding beam for beam searching",
                                       false, 3.0, "float");
    TCLAP::ValueArg<float> lattice_beam(
        "", LAT_BEAM, "the lattice generation beam for beam searching", false,
        3.0, "float");
    TCLAP::ValueArg<float> am_scale("", AM_SCALE,
                                    "the acoustic scale for beam searching",
                                    false, 10.0, "float");

    TCLAP::ValueArg<std::string> lm_dir(
        "", LM_DIR,
        "the LM model path, which contains compiled models: TLG.fst, "
        "config.yaml",
        false, "", "string");
    TCLAP::ValueArg<std::string> lm_revision(
        "", "lm-revision", "LM model revision", false, "v1.0.2", "string");
    TCLAP::ValueArg<std::string> hotword(
        "", HOTWORD,
        "the hotword file, one hotword per line, Format: Hotword Weight (could "
        "be: 阿里巴巴 20)",
        false, "/workspace/resources/hotwords.txt", "string");
    TCLAP::ValueArg<std::int32_t> fst_inc_wts(
        "", FST_INC_WTS, "the fst hotwords incremental bias", false, 20,
        "int32_t");

    // add args
    cmd.add(hotword);
    cmd.add(fst_inc_wts);
    cmd.add(global_beam);
    cmd.add(lattice_beam);
    cmd.add(am_scale);

    cmd.add(certfile);
    cmd.add(keyfile);
    cmd.add(download_model_dir);
    cmd.add(model_dir);
    cmd.add(model_revision);
    cmd.add(quantize);
    cmd.add(vad_dir);
    cmd.add(vad_revision);
    cmd.add(vad_quant);
    cmd.add(punc_dir);
    cmd.add(punc_revision);
    cmd.add(punc_quant);
    cmd.add(itn_dir);
    cmd.add(itn_revision);
    cmd.add(lm_dir);
    cmd.add(lm_revision);

    cmd.add(listen_ip);
    cmd.add(port);
    cmd.add(io_thread_num);
    cmd.add(decoder_thread_num);
    cmd.add(model_thread_num);
    cmd.parse(argc, argv);

    std::map<std::string, std::string> model_path;
    GetValue(model_dir, MODEL_DIR, model_path);
    GetValue(quantize, QUANTIZE, model_path);
    GetValue(vad_dir, VAD_DIR, model_path);
    GetValue(vad_quant, VAD_QUANT, model_path);
    GetValue(punc_dir, PUNC_DIR, model_path);
    GetValue(punc_quant, PUNC_QUANT, model_path);
    GetValue(itn_dir, ITN_DIR, model_path);
    GetValue(lm_dir, LM_DIR, model_path);
    GetValue(hotword, HOTWORD, model_path);

    GetValue(model_revision, "model-revision", model_path);
    GetValue(vad_revision, "vad-revision", model_path);
    GetValue(punc_revision, "punc-revision", model_path);
    GetValue(itn_revision, "itn-revision", model_path);
    GetValue(lm_revision, "lm-revision", model_path);

    global_beam_ = global_beam.getValue();
    lattice_beam_ = lattice_beam.getValue();
    am_scale_ = am_scale.getValue();

    // Download models from Modelscope
    try {
      std::string s_download_model_dir = download_model_dir.getValue();

      std::string s_vad_path = model_path[VAD_DIR];
      std::string s_vad_quant = model_path[VAD_QUANT];
      std::string s_asr_path = model_path[MODEL_DIR];
      std::string s_asr_quant = model_path[QUANTIZE];
      std::string s_punc_path = model_path[PUNC_DIR];
      std::string s_punc_quant = model_path[PUNC_QUANT];
      std::string s_itn_path = model_path[ITN_DIR];
      std::string s_lm_path = model_path[LM_DIR];

      std::string python_cmd =
          "python -m funasr.download.runtime_sdk_download_tool --type onnx "
          "--quantize True ";

      if (vad_dir.isSet() && !s_vad_path.empty()) {
        std::string python_cmd_vad;
        std::string down_vad_path;
        std::string down_vad_model;

        if (access(s_vad_path.c_str(), F_OK) == 0) {
          // local
          python_cmd_vad = python_cmd + " --model-name " + s_vad_path +
                           " --export-dir ./ " + " --model_revision " +
                           model_path["vad-revision"];
          down_vad_path = s_vad_path;
        } else {
          // modelscope
          LOG(INFO) << "Download model: " << s_vad_path << " from modelscope: ";
          python_cmd_vad = python_cmd + " --model-name " + s_vad_path +
                           " --export-dir " + s_download_model_dir +
                           " --model_revision " + model_path["vad-revision"];
          down_vad_path = s_download_model_dir + "/" + s_vad_path;
        }

        int ret = system(python_cmd_vad.c_str());
        if (ret != 0) {
          LOG(INFO) << "Failed to download model from modelscope. If you set "
                       "a local vad model path, you can ignore the errors.";
        }
        down_vad_model = down_vad_path + "/model_quant.onnx";
        if (s_vad_quant == "false" || s_vad_quant == "False" ||
            s_vad_quant == "FALSE") {
          down_vad_model = down_vad_path + "/model.onnx";
        }

        if (access(down_vad_model.c_str(), F_OK) != 0) {
          LOG(ERROR) << down_vad_model << " does not exist.";
          exit(-1);
        } else {
          model_path[VAD_DIR] = down_vad_path;
          LOG(INFO) << "Set " << VAD_DIR << " : " << model_path[VAD_DIR];
        }
      } else {
        LOG(INFO) << "VAD model is not set, use default.";
      }

      if (model_dir.isSet() && !s_asr_path.empty()) {
        std::string python_cmd_asr;
        std::string down_asr_path;
        std::string down_asr_model;

        // modify model-revision by model name
        size_t found = s_asr_path.find(
            "speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-"
            "vocab8404");
        if (found != std::string::npos) {
          model_path["model-revision"] = "v1.2.4";
        }

        found = s_asr_path.find(
            "speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-"
            "vocab8404");
        if (found != std::string::npos) {
          model_path["model-revision"] = "v1.0.5";
        }

        found = s_asr_path.find(
            "speech_paraformer-large_asr_nat-en-16k-common-vocab10020");
        if (found != std::string::npos) {
          model_path["model-revision"] = "v1.0.0";
          s_itn_path = "";
          s_lm_path = "";
        }

        if (access(s_asr_path.c_str(), F_OK) == 0) {
          // local
          python_cmd_asr = python_cmd + " --model-name " + s_asr_path +
                           " --export-dir ./ " + " --model_revision " +
                           model_path["model-revision"];
          down_asr_path = s_asr_path;
        } else {
          // modelscope
          LOG(INFO) << "Download model: " << s_asr_path << " from modelscope: ";
          python_cmd_asr = python_cmd + " --model-name " + s_asr_path +
                           " --export-dir " + s_download_model_dir +
                           " --model_revision " + model_path["model-revision"];
          down_asr_path = s_download_model_dir + "/" + s_asr_path;
        }

        int ret = system(python_cmd_asr.c_str());
        if (ret != 0) {
          LOG(INFO) << "Failed to download model from modelscope. If you set "
                       "a local asr model path, you can ignore the errors.";
        }
        down_asr_model = down_asr_path + "/model_quant.onnx";
        if (s_asr_quant == "false" || s_asr_quant == "False" ||
            s_asr_quant == "FALSE") {
          down_asr_model = down_asr_path + "/model.onnx";
        }

        if (access(down_asr_model.c_str(), F_OK) != 0) {
          LOG(ERROR) << down_asr_model << " does not exist.";
          exit(-1);
        } else {
          model_path[MODEL_DIR] = down_asr_path;
          LOG(INFO) << "Set " << MODEL_DIR << " : " << model_path[MODEL_DIR];
        }
      } else {
        LOG(INFO) << "ASR model is not set, use default.";
      }

      if (!s_itn_path.empty()) {
        std::string python_cmd_itn;
        std::string down_itn_path;
        std::string down_itn_model;

        if (access(s_itn_path.c_str(), F_OK) == 0) {
          // local
          python_cmd_itn = python_cmd + " --model-name " + s_itn_path +
                           " --export-dir ./ " + " --model_revision " +
                           model_path["itn-revision"] + " --export False ";
          down_itn_path = s_itn_path;
        } else {
          // modelscope
          LOG(INFO) << "Download model: " << s_itn_path
                    << " from modelscope : ";
          python_cmd_itn = python_cmd + " --model-name " + s_itn_path +
                           " --export-dir " + s_download_model_dir +
                           " --model_revision " + model_path["itn-revision"] +
                           " --export False ";
          down_itn_path = s_download_model_dir + "/" + s_itn_path;
        }

        int ret = system(python_cmd_itn.c_str());
        if (ret != 0) {
          LOG(INFO) << "Failed to download model from modelscope. If you set "
                       "a local itn model path, you can ignore the errors.";
        }
        down_itn_model = down_itn_path + "/zh_itn_tagger.fst";

        if (access(down_itn_model.c_str(), F_OK) != 0) {
          LOG(ERROR) << down_itn_model << " does not exist.";
          exit(-1);
        } else {
          model_path[ITN_DIR] = down_itn_path;
          LOG(INFO) << "Set " << ITN_DIR << " : " << model_path[ITN_DIR];
        }
      } else {
        LOG(INFO) << "ITN model is not set, not executed.";
      }

      if (!s_lm_path.empty() && s_lm_path != "NONE" && s_lm_path != "none") {
        std::string python_cmd_lm;
        std::string down_lm_path;
        std::string down_lm_model;

        if (access(s_lm_path.c_str(), F_OK) == 0) {
          // local
          python_cmd_lm = python_cmd + " --model-name " + s_lm_path +
                          " --export-dir ./ " + " --model_revision " +
                          model_path["lm-revision"] + " --export False ";
          down_lm_path = s_lm_path;
        } else {
          // modelscope
          LOG(INFO) << "Download model: " << s_lm_path << " from modelscope : ";
          python_cmd_lm = python_cmd + " --model-name " + s_lm_path +
                          " --export-dir " + s_download_model_dir +
                          " --model_revision " + model_path["lm-revision"] +
                          " --export False ";
          down_lm_path = s_download_model_dir + "/" + s_lm_path;
        }

        int ret = system(python_cmd_lm.c_str());
        if (ret != 0) {
          LOG(INFO) << "Failed to download model from modelscope. If you set "
                       "a local lm model path, you can ignore the errors.";
        }
        down_lm_model = down_lm_path + "/TLG.fst";

        if (access(down_lm_model.c_str(), F_OK) != 0) {
          LOG(ERROR) << down_lm_model << " does not exist.";
          exit(-1);
        } else {
          model_path[LM_DIR] = down_lm_path;
          LOG(INFO) << "Set " << LM_DIR << " : " << model_path[LM_DIR];
        }
      } else {
        LOG(INFO) << "LM model is not set, not executed.";
        model_path[LM_DIR] = "";
      }

      if (punc_dir.isSet() && !s_punc_path.empty()) {
        std::string python_cmd_punc;
        std::string down_punc_path;
        std::string down_punc_model;

        if (access(s_punc_path.c_str(), F_OK) == 0) {
          // local
          python_cmd_punc = python_cmd + " --model-name " + s_punc_path +
                            " --export-dir ./ " + " --model_revision " +
                            model_path["punc-revision"];
          down_punc_path = s_punc_path;
        } else {
          // modelscope
          LOG(INFO) << "Download model: " << s_punc_path
                    << " from modelscope: ";
          python_cmd_punc = python_cmd + " --model-name " + s_punc_path +
                            " --export-dir " + s_download_model_dir +
                            " --model_revision " + model_path["punc-revision"];
          down_punc_path = s_download_model_dir + "/" + s_punc_path;
        }

        int ret = system(python_cmd_punc.c_str());
        if (ret != 0) {
          LOG(INFO) << "Failed to download model from modelscope. If you set "
                       "a local punc model path, you can ignore the errors.";
        }
        down_punc_model = down_punc_path + "/model_quant.onnx";
        if (s_punc_quant == "false" || s_punc_quant == "False" ||
            s_punc_quant == "FALSE") {
          down_punc_model = down_punc_path + "/model.onnx";
        }

        if (access(down_punc_model.c_str(), F_OK) != 0) {
          LOG(ERROR) << down_punc_model << " does not exist.";
          exit(-1);
        } else {
          model_path[PUNC_DIR] = down_punc_path;
          LOG(INFO) << "Set " << PUNC_DIR << " : " << model_path[PUNC_DIR];
        }
      } else {
        LOG(INFO) << "PUNC model is not set, use default.";
      }

    } catch (std::exception const &e) {
      LOG(ERROR) << "Error: " << e.what();
    }

    std::string s_listen_ip = listen_ip.getValue();
    int s_port = port.getValue();
    int s_io_thread_num = io_thread_num.getValue();
    int s_decoder_thread_num = decoder_thread_num.getValue();

    int s_model_thread_num = model_thread_num.getValue();

    asio::io_context io_decoder;  // context for decoding

    std::vector<std::thread> decoder_threads;

    // hotword file
    std::string hotword_path;
    hotword_path = model_path.at(HOTWORD);
    fst_inc_wts_ = fst_inc_wts.getValue();
    LOG(INFO) << "hotword path: " << hotword_path;
    funasr::ExtractHws(hotword_path, hws_map_);

    auto conn_guard = asio::make_work_guard(
        io_decoder);  // make sure threads can wait in the queue

    // create the thread pool
    for (int32_t i = 0; i < s_decoder_thread_num; ++i) {
      decoder_threads.emplace_back([&io_decoder]() { io_decoder.run(); });
    }

    // ModelDecoderSrv modelSrv(
    //     io_decoder);  // websocket server for asr engine
    // modelSrv.initAsr(model_path, s_model_thread_num);  // init asr model
    // FUNASR_HANDLE asr_handle = initAsr();
    LOG(INFO) << "decoder-thread-num: " << s_decoder_thread_num;
    LOG(INFO) << "io-thread-num: " << s_io_thread_num;
    LOG(INFO) << "model-thread-num: " << s_model_thread_num;

    http::server2::server s(s_listen_ip, std::to_string(s_port), "./",
                            s_io_thread_num, io_decoder, model_path,
                            s_model_thread_num);

    s.run();
    LOG(INFO) << "http model loop " << s_port;
    // wait for the threads
    for (auto &t : decoder_threads) {
      t.join();
    }

  } catch (std::exception const &e) {
    LOG(ERROR) << "Error: " << e.what();
  }

  return 0;
}
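A hypothetical client sketch, assuming the server is reachable at 127.0.0.1:80 and standalone asio is available; it posts a wav file in the shape shown earlier and prints the JSON reply. Host, port, file name, and boundary are assumptions, not values mandated by this commit.

```cpp
// Hypothetical client; host/port, file name, and boundary are assumptions.
#include <array>
#include <asio.hpp>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>

int main() {
  std::ifstream wav("test.wav", std::ios::binary);
  std::stringstream ss;
  ss << wav.rdbuf();
  std::string data = ss.str();

  std::string head =
      "--boundary123\r\n"
      "Content-Disposition: form-data; name=\"file\"; filename=\"test.wav\"\r\n"
      "\r\n";
  std::string tail = "\r\n--boundary123--\r\n";

  asio::io_context ioc;
  asio::ip::tcp::socket sock(ioc);
  sock.connect({asio::ip::make_address("127.0.0.1"), 80});

  std::ostringstream req;
  req << "POST / HTTP/1.0\r\n"
      << "Content-Type: multipart/form-data; boundary=boundary123\r\n"
      << "Content-Length: " << head.size() + data.size() + tail.size()
      << "\r\n\r\n"
      << head << data << tail;
  asio::write(sock, asio::buffer(req.str()));

  // The server shuts the socket down after replying, so read until EOF.
  asio::error_code ec;
  std::array<char, 4096> buf;
  while (std::size_t n = sock.read_some(asio::buffer(buf), ec))
    std::cout.write(buf.data(), n);
}
```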
20 runtime/http/bin/funasr-http-main.hpp Normal file
@ -0,0 +1,20 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */

#ifndef HTTP_SERVER2_MAIN_HPP
#define HTTP_SERVER2_MAIN_HPP

#include "model-decoder.h"
#include "server.hpp"
namespace http {
namespace server2 {

}  // namespace server2
}  // namespace http

#endif  // HTTP_SERVER2_MAIN_HPP
27 runtime/http/bin/header.hpp Normal file
@ -0,0 +1,27 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
//
// header.hpp
// copy some codes from http://www.boost.org/

#ifndef HTTP_SERVER2_HEADER_HPP
#define HTTP_SERVER2_HEADER_HPP

#include <string>

namespace http {
namespace server2 {

struct header {
  std::string name;
  std::string value;
};

}  // namespace server2
}  // namespace http

#endif  // HTTP_SERVER2_HEADER_HPP
66 runtime/http/bin/io_context_pool.cpp Normal file
@ -0,0 +1,66 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
//
// io_context_pool.cpp
// ~~~~~~~~~~~~~~~~~~~
// copy some codes from http://www.boost.org/

#include "io_context_pool.hpp"
#include <stdexcept>
#include <thread>

namespace http {
namespace server2 {

io_context_pool::io_context_pool(std::size_t pool_size)
    : next_io_context_(0) {
  if (pool_size == 0)
    throw std::runtime_error("io_context_pool size is 0");

  // Give all the io_contexts work to do so that their run() functions will not
  // exit until they are explicitly stopped.
  for (std::size_t i = 0; i < pool_size; ++i) {
    io_context_ptr io_context(new asio::io_context);
    io_contexts_.push_back(io_context);
    work_.push_back(asio::make_work_guard(*io_context));
  }
}

void io_context_pool::run() {
  // Create a pool of threads to run all of the io_contexts.
  std::vector<std::thread> threads;
  for (std::size_t i = 0; i < io_contexts_.size(); ++i)
    threads.emplace_back([this, i] { io_contexts_[i]->run(); });

  // Wait for all threads in the pool to exit.
  for (std::size_t i = 0; i < threads.size(); ++i)
    threads[i].join();
}

void io_context_pool::stop() {
  // Explicitly stop all io_contexts.
  for (std::size_t i = 0; i < io_contexts_.size(); ++i)
    io_contexts_[i]->stop();
}

asio::io_context &io_context_pool::get_io_context() {
  // Use a round-robin scheme to choose the next io_context to use.
  asio::io_context &io_context = *io_contexts_[next_io_context_];

  ++next_io_context_;
  if (next_io_context_ == io_contexts_.size())
    next_io_context_ = 0;
  return io_context;
}

}  // namespace server2
}  // namespace http
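A usage sketch for the pool: get_io_context() hands out contexts round-robin (typically one per accepted connection, as server.cpp does below), and run() blocks with one thread per context until stop() is called.

```cpp
#include "io_context_pool.hpp"

void demo_pool() {
  http::server2::io_context_pool pool(4);  // four contexts, one thread each

  // Successive calls hand out contexts round-robin; bind one per connection.
  asio::io_context &ctx_a = pool.get_io_context();
  asio::io_context &ctx_b = pool.get_io_context();  // a different context
  (void)ctx_a;
  (void)ctx_b;

  // pool.run() blocks until pool.stop() is called; the work guards keep
  // each context alive even while its queue is empty.
  // pool.run();
}
```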
59 runtime/http/bin/io_context_pool.hpp Normal file
@ -0,0 +1,59 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
//
// io_context_pool.hpp
// ~~~~~~~~~~~~~~~~~~~
// copy some codes from http://www.boost.org/

#ifndef HTTP_SERVER2_IO_SERVICE_POOL_HPP
#define HTTP_SERVER2_IO_SERVICE_POOL_HPP

#include <asio.hpp>
#include <list>
#include <memory>
#include <vector>

namespace http {
namespace server2 {

/// A pool of io_context objects.
class io_context_pool {
 public:
  /// Construct the io_context pool.
  explicit io_context_pool(std::size_t pool_size);

  /// Run all io_context objects in the pool.
  void run();

  /// Stop all io_context objects in the pool.
  void stop();

  /// Get an io_context to use.
  asio::io_context &get_io_context();

 private:
  io_context_pool(const io_context_pool &) = delete;
  io_context_pool &operator=(const io_context_pool &) = delete;

  typedef std::shared_ptr<asio::io_context> io_context_ptr;
  typedef asio::executor_work_guard<asio::io_context::executor_type>
      io_context_work;

  /// The pool of io_contexts.
  std::vector<io_context_ptr> io_contexts_;

  /// The work that keeps the io_contexts running.
  std::list<io_context_work> work_;

  /// The next io_context to use for a connection.
  std::size_t next_io_context_;
};

}  // namespace server2
}  // namespace http

#endif  // HTTP_SERVER2_IO_SERVICE_POOL_HPP
119 runtime/http/bin/model-decoder.cpp Normal file
@ -0,0 +1,119 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */

// funasr asr engine

#include "model-decoder.h"

#include <thread>
#include <utility>
#include <vector>

extern std::unordered_map<std::string, int> hws_map_;
extern int fst_inc_wts_;
extern float global_beam_, lattice_beam_, am_scale_;

// feed msg to the asr engine for decoding
void ModelDecoder::do_decoder(std::shared_ptr<FUNASR_MESSAGE> session_msg) {
  try {
    // std::this_thread::sleep_for(std::chrono::milliseconds(1000*10));
    if (session_msg->status == 1) return;
    // std::cout << "in do_decoder" << std::endl;
    std::shared_ptr<std::vector<char>> buffer = session_msg->samples;
    int num_samples = buffer->size();  // the size of the buf
    std::string wav_name = session_msg->msg["wav_name"];
    bool itn = session_msg->msg["itn"];
    int audio_fs = session_msg->msg["audio_fs"];
    std::string wav_format = session_msg->msg["wav_format"];

    if (num_samples > 0 && session_msg->hotwords_embedding->size() > 0) {
      std::string asr_result = "";
      std::string stamp_res = "";
      std::string stamp_sents = "";

      try {
        std::vector<std::vector<float>> hotwords_embedding_(
            *(session_msg->hotwords_embedding));

        FUNASR_RESULT Result = FunOfflineInferBuffer(
            asr_handle, buffer->data(), buffer->size(), RASR_NONE, nullptr,
            std::move(hotwords_embedding_), audio_fs, wav_format, itn,
            session_msg->decoder_handle);

        if (Result != nullptr) {
          asr_result = FunASRGetResult(Result, 0);  // get decode result
          stamp_res = FunASRGetStamp(Result);
          stamp_sents = FunASRGetStampSents(Result);
          FunASRFreeResult(Result);
        } else {
          std::this_thread::sleep_for(std::chrono::milliseconds(20));
        }
      } catch (std::exception const &e) {
        std::cout << "error in decoder!!! " << e.what() << std::endl;
      }

      nlohmann::json jsonresult;        // result json
      jsonresult["text"] = asr_result;  // put result in 'text'
      jsonresult["mode"] = "offline";
      jsonresult["is_final"] = false;
      if (stamp_res != "") {
        jsonresult["timestamp"] = stamp_res;
      }
      if (stamp_sents != "") {
        try {
          nlohmann::json json_stamp = nlohmann::json::parse(stamp_sents);
          jsonresult["stamp_sents"] = json_stamp;
        } catch (std::exception const &e) {
          std::cout << "error:" << e.what();
          jsonresult["stamp_sents"] = "";
        }
      }
      jsonresult["wav_name"] = wav_name;

      std::cout << "buffer.size=" << buffer->size()
                << ",result json=" << jsonresult.dump() << std::endl;

      FunWfstDecoderUnloadHwsRes(session_msg->decoder_handle);
      FunASRWfstDecoderUninit(session_msg->decoder_handle);
      session_msg->status = 1;
      session_msg->msg["asr_result"] = jsonresult;
      return;
    } else {
      std::cout << "Empty message received";

      nlohmann::json jsonresult;  // result json
      jsonresult["text"] = "";    // put result in 'text'
      jsonresult["mode"] = "offline";
      jsonresult["is_final"] = false;
      jsonresult["wav_name"] = wav_name;
      // store the empty result so write_back can still reply
      session_msg->msg["asr_result"] = jsonresult;
      session_msg->status = 1;
    }
  } catch (std::exception const &e) {
    std::cerr << "Error: " << e.what() << std::endl;
  }
}

// init asr model
FUNASR_HANDLE ModelDecoder::initAsr(std::map<std::string, std::string> &model_path,
                                    int thread_num) {
  try {
    // init model with api
    asr_handle = FunOfflineInit(model_path, thread_num);
    LOG(INFO) << "model successfully inited";

    return asr_handle;
  } catch (const std::exception &e) {
    LOG(INFO) << e.what();
    return nullptr;
  }
}
60 runtime/http/bin/model-decoder.h Normal file
@ -0,0 +1,60 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */

// funasr asr engine

#ifndef MODEL_DECODER_SERVER_H_
#define MODEL_DECODER_SERVER_H_

#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#define ASIO_STANDALONE 1  // not boost
#include <glog/logging.h>

#include <fstream>
#include <functional>

#include "asio.hpp"
#include "asr_sessions.h"
#include "com-define.h"
#include "funasrruntime.h"
#include "nlohmann/json.hpp"
#include "tclap/CmdLine.h"
#include "util/text-utils.h"

class ModelDecoder {
 public:
  ModelDecoder(asio::io_context &io_decoder,
               std::map<std::string, std::string> &model_path, int thread_num)
      : io_decoder_(io_decoder) {
    asr_handle = initAsr(model_path, thread_num);
  }
  void do_decoder(std::shared_ptr<FUNASR_MESSAGE> session_msg);

  FUNASR_HANDLE initAsr(std::map<std::string, std::string> &model_path, int thread_num);

  asio::io_context &io_decoder_;  // threads for the asr decoder
  FUNASR_HANDLE get_asr_handle() { return asr_handle; }

 private:
  FUNASR_HANDLE asr_handle;  // asr engine handle
  bool isonline = false;  // online or offline engine; currently only offline is supported
};

#endif  // MODEL_DECODER_SERVER_H_
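A hedged wiring sketch for ModelDecoder: a decoder bound to a dedicated io_context, with a session posted to it the same way connection::do_read() does. The model_path contents and session fields are placeholders, not values fixed by this commit.

```cpp
// Sketch only; model_path keys and values are placeholders.
#include "model-decoder.h"

void demo_decoder() {
  asio::io_context io_decoder;
  std::map<std::string, std::string> model_path;  // fill with real model dirs

  auto decoder = std::make_shared<ModelDecoder>(io_decoder, model_path, 1);

  auto msg = std::make_shared<FUNASR_MESSAGE>();
  msg->samples = std::make_shared<std::vector<char>>();  // audio bytes go here
  msg->msg = nlohmann::json::object();
  msg->msg["wav_name"] = "demo";
  msg->msg["wav_format"] = "pcm";
  msg->msg["itn"] = true;
  msg->msg["audio_fs"] = 16000;
  msg->status = 0;

  // Hand the session to the decoder context, as connection::do_read() does.
  asio::post(io_decoder, [decoder, msg] { decoder->do_decoder(msg); });
  io_decoder.run();  // in the server this runs on a pool of threads
}
```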
245 runtime/http/bin/reply.cpp Normal file
@ -0,0 +1,245 @@
|
||||
/**
|
||||
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
|
||||
* Reserved. MIT License (https://opensource.org/licenses/MIT)
|
||||
*/
|
||||
/* 2023-2024 by zhaomingwork@qq.com */
|
||||
// reply.cpp
|
||||
// ~~~~~~~~~
|
||||
//
|
||||
// copy some codes from http://www.boost.org/
|
||||
|
||||
#include "reply.hpp"
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
namespace http {
|
||||
namespace server2 {
|
||||
|
||||
namespace status_strings {
|
||||
|
||||
const std::string ok = "HTTP/1.0 200 OK\r\n";
|
||||
const std::string created = "HTTP/1.0 201 Created\r\n";
|
||||
const std::string accepted = "HTTP/1.0 202 Accepted\r\n";
|
||||
const std::string no_content = "HTTP/1.0 204 No Content\r\n";
|
||||
const std::string multiple_choices = "HTTP/1.0 300 Multiple Choices\r\n";
|
||||
const std::string moved_permanently = "HTTP/1.0 301 Moved Permanently\r\n";
|
||||
const std::string moved_temporarily = "HTTP/1.0 302 Moved Temporarily\r\n";
|
||||
const std::string not_modified = "HTTP/1.0 304 Not Modified\r\n";
|
||||
const std::string bad_request = "HTTP/1.0 400 Bad Request\r\n";
|
||||
const std::string unauthorized = "HTTP/1.0 401 Unauthorized\r\n";
|
||||
const std::string forbidden = "HTTP/1.0 403 Forbidden\r\n";
|
||||
const std::string not_found = "HTTP/1.0 404 Not Found\r\n";
|
||||
const std::string internal_server_error =
|
||||
"HTTP/1.0 500 Internal Server Error\r\n";
|
||||
const std::string not_implemented = "HTTP/1.0 501 Not Implemented\r\n";
|
||||
const std::string bad_gateway = "HTTP/1.0 502 Bad Gateway\r\n";
|
||||
const std::string service_unavailable = "HTTP/1.0 503 Service Unavailable\r\n";
|
||||
|
||||
asio::const_buffer to_buffer(reply::status_type status) {
|
||||
switch (status) {
|
||||
case reply::ok:
|
||||
return asio::buffer(ok);
|
||||
case reply::created:
|
||||
return asio::buffer(created);
|
||||
case reply::accepted:
|
||||
return asio::buffer(accepted);
|
||||
case reply::no_content:
|
||||
return asio::buffer(no_content);
|
||||
case reply::multiple_choices:
|
||||
return asio::buffer(multiple_choices);
|
||||
case reply::moved_permanently:
|
||||
return asio::buffer(moved_permanently);
|
||||
case reply::moved_temporarily:
|
||||
return asio::buffer(moved_temporarily);
|
||||
case reply::not_modified:
|
||||
return asio::buffer(not_modified);
|
||||
case reply::bad_request:
|
||||
return asio::buffer(bad_request);
|
||||
case reply::unauthorized:
|
||||
return asio::buffer(unauthorized);
|
||||
case reply::forbidden:
|
||||
return asio::buffer(forbidden);
|
||||
case reply::not_found:
|
||||
return asio::buffer(not_found);
|
||||
case reply::internal_server_error:
|
||||
return asio::buffer(internal_server_error);
|
||||
case reply::not_implemented:
|
||||
return asio::buffer(not_implemented);
|
||||
case reply::bad_gateway:
|
||||
return asio::buffer(bad_gateway);
|
||||
case reply::service_unavailable:
|
||||
return asio::buffer(service_unavailable);
|
||||
default:
|
||||
return asio::buffer(internal_server_error);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace status_strings
|
||||
|
||||
namespace misc_strings {
|
||||
|
||||
const char name_value_separator[] = {':', ' '};
|
||||
const char crlf[] = {'\r', '\n'};
|
||||
|
||||
} // namespace misc_strings
|
||||
|
||||
std::vector<::asio::const_buffer> reply::to_buffers() {
|
||||
std::vector<::asio::const_buffer> buffers;
|
||||
buffers.push_back(status_strings::to_buffer(status));
|
||||
for (std::size_t i = 0; i < headers.size(); ++i) {
|
||||
header &h = headers[i];
|
||||
buffers.push_back(asio::buffer(h.name));
|
||||
buffers.push_back(asio::buffer(misc_strings::name_value_separator));
|
||||
buffers.push_back(asio::buffer(h.value));
|
||||
buffers.push_back(asio::buffer(misc_strings::crlf));
|
||||
}
|
||||
buffers.push_back(asio::buffer(misc_strings::crlf));
|
||||
buffers.push_back(asio::buffer(content));
|
||||
|
||||
return buffers;
|
||||
}
|
||||
|
||||
namespace stock_replies {
|
||||
|
||||
const char ok[] = "";
|
||||
const char created[] =
|
||||
"<html>"
|
||||
"<head><title>Created</title></head>"
|
||||
"<body><h1>201 Created</h1></body>"
|
||||
"</html>";
|
||||
const char accepted[] =
|
||||
"<html>"
|
||||
"<head><title>Accepted</title></head>"
|
||||
"<body><h1>202 Accepted</h1></body>"
|
||||
"</html>";
|
||||
const char no_content[] =
|
||||
"<html>"
|
||||
"<head><title>No Content</title></head>"
|
||||
"<body><h1>204 Content</h1></body>"
|
||||
"</html>";
|
||||
const char multiple_choices[] =
|
||||
"<html>"
|
||||
"<head><title>Multiple Choices</title></head>"
|
||||
"<body><h1>300 Multiple Choices</h1></body>"
|
||||
"</html>";
|
||||
const char moved_permanently[] =
|
||||
"<html>"
|
||||
"<head><title>Moved Permanently</title></head>"
|
||||
"<body><h1>301 Moved Permanently</h1></body>"
|
||||
"</html>";
|
||||
const char moved_temporarily[] =
|
||||
"<html>"
|
||||
"<head><title>Moved Temporarily</title></head>"
|
||||
"<body><h1>302 Moved Temporarily</h1></body>"
|
||||
"</html>";
|
||||
const char not_modified[] =
    "<html>"
    "<head><title>Not Modified</title></head>"
    "<body><h1>304 Not Modified</h1></body>"
    "</html>";
const char bad_request[] =
    "<html>"
    "<head><title>Bad Request</title></head>"
    "<body><h1>400 Bad Request</h1></body>"
    "</html>";
const char unauthorized[] =
    "<html>"
    "<head><title>Unauthorized</title></head>"
    "<body><h1>401 Unauthorized</h1></body>"
    "</html>";
const char forbidden[] =
    "<html>"
    "<head><title>Forbidden</title></head>"
    "<body><h1>403 Forbidden</h1></body>"
    "</html>";
const char not_found[] =
    "<html>"
    "<head><title>Not Found</title></head>"
    "<body><h1>404 Not Found</h1></body>"
    "</html>";
const char internal_server_error[] =
    "<html>"
    "<head><title>Internal Server Error</title></head>"
    "<body><h1>500 Internal Server Error</h1></body>"
    "</html>";
const char not_implemented[] =
    "<html>"
    "<head><title>Not Implemented</title></head>"
    "<body><h1>501 Not Implemented</h1></body>"
    "</html>";
const char bad_gateway[] =
    "<html>"
    "<head><title>Bad Gateway</title></head>"
    "<body><h1>502 Bad Gateway</h1></body>"
    "</html>";
const char service_unavailable[] =
    "<html>"
    "<head><title>Service Unavailable</title></head>"
    "<body><h1>503 Service Unavailable</h1></body>"
    "</html>";

std::string to_string(reply::status_type status) {
  switch (status) {
    case reply::ok:
      return ok;
    case reply::created:
      return created;
    case reply::accepted:
      return accepted;
    case reply::no_content:
      return no_content;
    case reply::multiple_choices:
      return multiple_choices;
    case reply::moved_permanently:
      return moved_permanently;
    case reply::moved_temporarily:
      return moved_temporarily;
    case reply::not_modified:
      return not_modified;
    case reply::bad_request:
      return bad_request;
    case reply::unauthorized:
      return unauthorized;
    case reply::forbidden:
      return forbidden;
    case reply::not_found:
      return not_found;
    case reply::internal_server_error:
      return internal_server_error;
    case reply::not_implemented:
      return not_implemented;
    case reply::bad_gateway:
      return bad_gateway;
    case reply::service_unavailable:
      return service_unavailable;
    default:
      return internal_server_error;
  }
}

} // namespace stock_replies

reply reply::stock_reply(std::string jsonresult) {
  reply rep;
  rep.status = reply::ok;
  rep.content = jsonresult + "\n";
  rep.headers.resize(2);
  rep.headers[0].name = "Content-Length";
  rep.headers[0].value = std::to_string(rep.content.size());
  rep.headers[1].name = "Content-Type";
  rep.headers[1].value = "text/html;charset=utf-8";
  return rep;
}

reply reply::stock_reply(reply::status_type status) {
  reply rep;
  rep.status = status;
  rep.content = stock_replies::to_string(status);
  rep.headers.resize(2);
  rep.headers[0].name = "Content-Length";
  rep.headers[0].value = std::to_string(rep.content.size());
  rep.headers[1].name = "Content-Type";
  rep.headers[1].value = "text/html";
  return rep;
}

} // namespace server2
} // namespace http
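For orientation, a request handler would use these two overloads roughly as follows (a sketch; the `connection` class that actually calls them is not part of this commit, so the handler function is hypothetical):

```cpp
// Sketch of how a handler could use the two stock_reply overloads
// (hypothetical helper; only reply.hpp/reply.cpp above are assumed).
#include "reply.hpp"

http::server2::reply make_response(bool ok, const std::string& jsonresult) {
  if (!ok) {
    // Canned HTML body for an error status.
    return http::server2::reply::stock_reply(
        http::server2::reply::bad_request);
  }
  // 200 OK wrapping the recognizer's JSON output.
  return http::server2::reply::stock_reply(jsonresult);
}
```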
64 runtime/http/bin/reply.hpp Normal file
@@ -0,0 +1,64 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
// reply.hpp
// ~~~~~~~~~
//
// copy some codes from http://www.boost.org/

#ifndef HTTP_SERVER2_REPLY_HPP
#define HTTP_SERVER2_REPLY_HPP

#include <asio.hpp>
#include <string>
#include <vector>

#include "header.hpp"

namespace http {
namespace server2 {

/// A reply to be sent to a client.
struct reply {
  /// The status of the reply.
  enum status_type {
    ok = 200,
    created = 201,
    accepted = 202,
    no_content = 204,
    multiple_choices = 300,
    moved_permanently = 301,
    moved_temporarily = 302,
    not_modified = 304,
    bad_request = 400,
    unauthorized = 401,
    forbidden = 403,
    not_found = 404,
    internal_server_error = 500,
    not_implemented = 501,
    bad_gateway = 502,
    service_unavailable = 503
  } status;

  /// The headers to be included in the reply.
  std::vector<header> headers;

  /// The content to be sent in the reply.
  std::string content;

  /// Convert the reply into a vector of buffers. The buffers do not own the
  /// underlying memory blocks, therefore the reply object must remain valid and
  /// not be changed until the write operation has completed.
  std::vector<::asio::const_buffer> to_buffers();

  /// Get a stock reply.
  static reply stock_reply(status_type status);
  static reply stock_reply(std::string jsonresult);
};

} // namespace server2
} // namespace http

#endif // HTTP_SERVER2_REPLY_HPP
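The lifetime note on `to_buffers()` matters in practice: the returned buffers alias the reply's internal strings. A sketch of the usual write pattern (the surrounding `connection` code is not shown in this commit; `socket_` is an assumed member):

```cpp
// Keep the reply alive until the asynchronous write completes, since
// to_buffers() only references rep's headers/content strings.
auto rep = std::make_shared<http::server2::reply>(
    http::server2::reply::stock_reply(http::server2::reply::ok));
asio::async_write(socket_, rep->to_buffers(),
                  [rep](asio::error_code /*ec*/, std::size_t /*bytes*/) {
                    // rep captured by value extends its lifetime to here.
                  });
```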
113 runtime/http/bin/server.cpp Normal file
@@ -0,0 +1,113 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
//
// server.cpp
// copy some codes from http://www.boost.org/

#include "server.hpp"

#include <signal.h>

#include <fstream>
#include <iostream>
#include <utility>

#include "util.h"
namespace http {
namespace server2 {

server::server(const std::string &address, const std::string &port,
               const std::string &doc_root, std::size_t io_context_pool_size,
               asio::io_context &decoder_context,
               std::map<std::string, std::string> &model_path, int thread_num)
    : io_context_pool_(io_context_pool_size),
      signals_(io_context_pool_.get_io_context()),
      acceptor_(io_context_pool_.get_io_context()),
      decoder_context(decoder_context) {
  // Register to handle the signals that indicate when the server should exit.
  // It is safe to register for the same signal multiple times in a program,
  // provided all registration for the specified signal is made through Asio.
  try {
    model_decoder =
        std::make_shared<ModelDecoder>(decoder_context, model_path, thread_num);

    LOG(INFO) << "trying to listen on port: " << port << std::endl;
    LOG(INFO) << "not ready yet, please wait... " << std::endl;
    LOG(INFO) << "if it keeps waiting here, the port may already be in use; "
                 "please change the port or kill the previous process!"
              << std::endl;

    atom_id = 0;

    // init model with api

    signals_.add(SIGINT);
    signals_.add(SIGTERM);
#if defined(SIGQUIT)
    signals_.add(SIGQUIT);
#endif  // defined(SIGQUIT)

    do_await_stop();

    // Open the acceptor with the option to reuse the address (i.e.
    // SO_REUSEADDR).
    asio::ip::tcp::resolver resolver(acceptor_.get_executor());
    asio::ip::tcp::endpoint endpoint = *resolver.resolve(address, port).begin();

    acceptor_.open(endpoint.protocol());
    acceptor_.set_option(asio::ip::tcp::acceptor::reuse_address(true));

    acceptor_.bind(endpoint);

    acceptor_.listen();

    do_accept();
    std::cout << "use curl to test, for example: " << std::endl;
    std::cout << "curl -F \"file=@example.wav\" 127.0.0.1:80" << std::endl;

    std::cout << "HTTP POST only supports offline mode; if you want online "
                 "mode, please try the websocket server!"
              << std::endl;
    std::cout << "now listening on " << address << ":" << port
              << ", ready to accept data!" << std::endl;
  } catch (const std::exception &e) {
    std::cout << "error:" << e.what();
  }
}

void server::run() { io_context_pool_.run(); }

void server::do_accept() {
  acceptor_.async_accept(
      io_context_pool_.get_io_context(),
      [this](asio::error_code ec, asio::ip::tcp::socket socket) {
        // Check whether the server was stopped by a signal before this
        // completion handler had a chance to run.
        if (!acceptor_.is_open()) {
          return;
        }

        if (!ec) {
          std::lock_guard<std::mutex> lk(m_lock);
          atom_id = atom_id + 1;

          std::make_shared<connection>(std::move(socket), decoder_context,
                                       (atom_id).load(), model_decoder)
              ->start();
        }

        do_accept();
      });
}

void server::do_await_stop() {
  signals_.async_wait([this](asio::error_code /*ec*/, int /*signo*/) {
    io_context_pool_.stop();
  });
}

} // namespace server2
} // namespace http
71 runtime/http/bin/server.hpp Normal file
@@ -0,0 +1,71 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
 * Reserved. MIT License (https://opensource.org/licenses/MIT)
 */
/* 2023-2024 by zhaomingwork@qq.com */
//
// server.hpp
// ~~~~~~~~~~
// copy some codes from http://www.boost.org/

#ifndef HTTP_SERVER2_SERVER_HPP
#define HTTP_SERVER2_SERVER_HPP
#include <asio.hpp>
#include <atomic>
#include <map>    // for the model_path parameter below
#include <mutex>  // for the m_lock member below
#include <string>

#include "connection.hpp"
#include "funasrruntime.h"
#include "io_context_pool.hpp"
#include "model-decoder.h"
#include "util.h"
namespace http {
namespace server2 {

/// The top-level class of the HTTP server.
class server {
 public:
  server(const server &) = delete;
  server &operator=(const server &) = delete;

  /// Construct the server to listen on the specified TCP address and port, and
  /// serve up files from the given directory.
  explicit server(const std::string &address, const std::string &port,
                  const std::string &doc_root, std::size_t io_context_pool_size,
                  asio::io_context &decoder_context,
                  std::map<std::string, std::string> &model_path,
                  int thread_num);

  /// Run the server's io_context loop.
  void run();

 private:
  /// Perform an asynchronous accept operation.
  void do_accept();

  /// Wait for a request to stop the server.
  void do_await_stop();

  /// The pool of io_context objects used to perform asynchronous operations.
  io_context_pool io_context_pool_;

  asio::io_context &decoder_context;

  /// The signal_set is used to register for process termination notifications.
  asio::signal_set signals_;

  /// Acceptor used to listen for incoming connections.
  asio::ip::tcp::acceptor acceptor_;

  std::shared_ptr<ModelDecoder> model_decoder;

  std::atomic<int> atom_id;
  std::mutex m_lock;
};

} // namespace server2
} // namespace http

#endif  // HTTP_SERVER2_SERVER_HPP
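For orientation, a minimal driver for this class might look like the sketch below; the actual `main.cpp` is not part of this diff, so the argument values are illustrative:

```cpp
// Hypothetical driver (not from this commit); values are placeholders.
#include <map>
#include <string>
#include "server.hpp"

int main() {
  asio::io_context decoder_context;
  std::map<std::string, std::string> model_path;  // fill with model directories
  // 4 io_contexts for network I/O, 2 decoder threads (illustrative numbers).
  http::server2::server srv("0.0.0.0", "80", ".", 4, decoder_context,
                            model_path, 2);
  srv.run();  // blocks until SIGINT/SIGTERM stops the io_context pool
  return 0;
}
```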
58 runtime/http/readme.md Normal file
@@ -0,0 +1,58 @@
# Advanced Development Guide (File transcription service) ([click](../docs/SDK_advanced_guide_offline.md))
# Real-time Speech Transcription Service Development Guide ([click](../docs/SDK_advanced_guide_online.md))


# If you want to compile the files yourself, you can follow the steps below.
## Building for Linux/Unix
### Download onnxruntime
```shell
wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/onnxruntime-linux-x64-1.14.0.tgz
tar -zxvf onnxruntime-linux-x64-1.14.0.tgz
```

### Download ffmpeg
```shell
wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/ffmpeg-master-latest-linux64-gpl-shared.tar.xz
tar -xvf ffmpeg-master-latest-linux64-gpl-shared.tar.xz
```

### Install deps
```shell
# openblas
sudo apt-get install libopenblas-dev #ubuntu
# sudo yum -y install openblas-devel #centos

# openssl
sudo apt-get install libssl-dev #ubuntu
# yum install openssl-devel #centos
```

### Build runtime
```shell
git clone https://github.com/alibaba-damo-academy/FunASR.git && cd FunASR/runtime/http
mkdir build && cd build
cmake -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0 -DFFMPEG_DIR=/path/to/ffmpeg-master-latest-linux64-gpl-shared
make -j 4
```

### Run

```shell
./funasr-http-server \
  --lm-dir '' \
  --itn-dir '' \
  --download-model-dir ${download_model_dir} \
  --model-dir ${model_dir} \
  --vad-dir ${vad_dir} \
  --punc-dir ${punc_dir} \
  --decoder-thread-num ${decoder_thread_num} \
  --io-thread-num ${io_thread_num} \
  --port ${port}
```

### Test

```shell
curl -F "file=@example.wav" 127.0.0.1:80
```
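For reference, a concrete invocation might look like the following sketch. The model directories are illustrative ModelScope IDs (substitute whatever models you actually deploy), and `--port 80` matches the curl test above:

```shell
# Hypothetical example; model IDs and paths are placeholders.
./funasr-http-server \
  --download-model-dir /workspace/models \
  --model-dir damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-onnx \
  --vad-dir damo/speech_fsmn_vad_zh-cn-16k-common-onnx \
  --punc-dir damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx \
  --lm-dir '' --itn-dir '' \
  --decoder-thread-num 16 --io-thread-num 4 \
  --port 80
```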
61 runtime/http/readme_zh.md Normal file
@@ -0,0 +1,61 @@
# FunASR离线文件转写服务开发指南([点击此处](../docs/SDK_advanced_guide_offline_zh.md))

# FunASR实时语音听写服务开发指南([点击此处](../docs/SDK_advanced_guide_online_zh.md))

# 如果您想自己编译文件,可以参考下述步骤
## Linux/Unix 平台编译
### 下载 onnxruntime
```shell
wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/onnxruntime-linux-x64-1.14.0.tgz
tar -zxvf onnxruntime-linux-x64-1.14.0.tgz
```

### 下载 ffmpeg
```shell
wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/ffmpeg-master-latest-linux64-gpl-shared.tar.xz
tar -xvf ffmpeg-master-latest-linux64-gpl-shared.tar.xz
```

### 安装依赖
```shell
# openblas
sudo apt-get install libopenblas-dev #ubuntu
# sudo yum -y install openblas-devel #centos

# openssl
sudo apt-get install libssl-dev #ubuntu
# yum install openssl-devel #centos
```

### 编译 runtime

```shell
git clone https://github.com/alibaba-damo-academy/FunASR.git && cd FunASR/runtime/http
mkdir build && cd build
cmake -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0 -DFFMPEG_DIR=/path/to/ffmpeg-master-latest-linux64-gpl-shared
make -j 4
```

### 运行

```shell
./funasr-http-server \
  --lm-dir '' \
  --itn-dir '' \
  --download-model-dir ${download_model_dir} \
  --model-dir ${model_dir} \
  --vad-dir ${vad_dir} \
  --punc-dir ${punc_dir} \
  --decoder-thread-num ${decoder_thread_num} \
  --io-thread-num ${io_thread_num} \
  --port ${port}
```

### 测试

```shell
curl -F "file=@example.wav" 127.0.0.1:80
```
15 runtime/http/requirements_install.md Normal file
@@ -0,0 +1,15 @@
#### Download onnxruntime
```shell
bash third_party/download_onnxruntime.sh
```

#### Download ffmpeg
```shell
bash third_party/download_ffmpeg.sh
```

#### Install openblas and openssl
```shell
sudo apt-get install libopenblas-dev libssl-dev #ubuntu
# sudo yum -y install openblas-devel openssl-devel #centos
```
@@ -4,6 +4,7 @@ project(FunASROnnx)

option(ENABLE_GLOG "Whether to build glog" ON)
option(ENABLE_FST "Whether to build openfst" ON) # ITN need openfst compiled
option(GPU "Whether to build with GPU" OFF)

# set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
@@ -49,6 +50,17 @@ include_directories(${PROJECT_SOURCE_DIR}/third_party/jieba/include/limonp/inclu
include_directories(${PROJECT_SOURCE_DIR}/third_party/kaldi)
include_directories(${PROJECT_SOURCE_DIR}/third_party/json/include)

if(GPU)
  add_definitions(-DUSE_GPU)
  set(TORCH_DIR "/usr/local/lib/python3.8/dist-packages/torch")
  set(TORCH_BLADE_DIR "/usr/local/lib/python3.8/dist-packages/torch_blade")
  include_directories(${TORCH_DIR}/include)
  include_directories(${TORCH_DIR}/include/torch/csrc/api/include)
  link_directories(${TORCH_DIR}/lib)
  link_directories(${TORCH_BLADE_DIR})
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -D_GLIBCXX_USE_CXX11_ABI=0")
endif()

if(ENABLE_GLOG)
  include_directories(${PROJECT_SOURCE_DIR}/third_party/glog/src)
  set(BUILD_TESTING OFF)
@@ -10,33 +10,43 @@ SET(RELATION_SOURCE "../src/resample.cpp" "../src/util.cpp" "../src/alignedmem.c
endif()

add_executable(funasr-onnx-offline "funasr-onnx-offline.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-offline PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-offline PUBLIC funasr)

add_executable(funasr-onnx-offline-vad "funasr-onnx-offline-vad.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-offline-vad PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-offline-vad PUBLIC funasr)

add_executable(funasr-onnx-online-vad "funasr-onnx-online-vad.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-online-vad PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-online-vad PUBLIC funasr)

add_executable(funasr-onnx-online-asr "funasr-onnx-online-asr.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-online-asr PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-online-asr PUBLIC funasr)

add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-offline-punc PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)

add_executable(funasr-onnx-online-punc "funasr-onnx-online-punc.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-online-punc PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-online-punc PUBLIC funasr)

add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-offline-rtf PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)

add_executable(funasr-onnx-2pass "funasr-onnx-2pass.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-2pass PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-2pass PUBLIC funasr)

add_executable(funasr-onnx-2pass-rtf "funasr-onnx-2pass-rtf.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-2pass-rtf PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-2pass-rtf PUBLIC funasr)

add_executable(funasr-onnx-online-rtf "funasr-onnx-online-rtf.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-online-rtf PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-online-rtf PUBLIC funasr)

# include_directories(${FFMPEG_DIR}/include)
@@ -52,7 +52,7 @@ void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wa
    std::vector<std::vector<float>> hotwords_embedding = CompileHotwordEmbedding(asr_handle, nn_hotwords_);

    // warm up
    for (size_t i = 0; i < 1; i++)
    for (size_t i = 0; i < 10; i++)
    {
        FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, nullptr, hotwords_embedding, audio_fs, true, decoder_handle);
        if(result){
@@ -127,6 +127,7 @@ int main(int argc, char *argv[])
    TCLAP::CmdLine cmd("funasr-onnx-offline-rtf", ' ', "1.0");
    TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
    TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "true (default); if true, load model_quant.onnx in model_dir, otherwise load model.onnx", false, "true", "string");
    TCLAP::ValueArg<std::string> bladedisc("", BLADEDISC, "true (default), load the bladedisc model in model_dir.", false, "true", "string");
    TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
    TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "true (default); if true, load model_quant.onnx in vad_dir, otherwise load model.onnx", false, "true", "string");
    TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
@@ -140,11 +141,14 @@ int main(int argc, char *argv[])

    TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
    TCLAP::ValueArg<std::int32_t> audio_fs("", AUDIO_FS, "the sample rate of audio", false, 16000, "int32_t");
    TCLAP::ValueArg<std::int32_t> thread_num("", THREAD_NUM, "multi-thread num for rtf", true, 0, "int32_t");
    TCLAP::ValueArg<std::int32_t> thread_num("", THREAD_NUM, "multi-thread num for rtf", false, 1, "int32_t");
    TCLAP::ValueArg<std::string> hotword("", HOTWORD, "the hotword file, one hotword per line, format: Hotword Weight (e.g.: 阿里巴巴 20)", false, "", "string");
    TCLAP::SwitchArg use_gpu("", INFER_GPU, "whether to use GPU for inference, default is false", false);
    TCLAP::ValueArg<std::int32_t> batch_size("", BATCHSIZE, "batch_size for the ASR model when using GPU", false, 4, "int32_t");

    cmd.add(model_dir);
    cmd.add(quantize);
    cmd.add(bladedisc);
    cmd.add(vad_dir);
    cmd.add(vad_quant);
    cmd.add(punc_dir);
@@ -159,11 +163,14 @@ int main(int argc, char *argv[])
    cmd.add(wav_path);
    cmd.add(audio_fs);
    cmd.add(thread_num);
    cmd.add(use_gpu);
    cmd.add(batch_size);
    cmd.parse(argc, argv);

    std::map<std::string, std::string> model_path;
    GetValue(model_dir, MODEL_DIR, model_path);
    GetValue(quantize, QUANTIZE, model_path);
    GetValue(bladedisc, BLADEDISC, model_path);
    GetValue(vad_dir, VAD_DIR, model_path);
    GetValue(vad_quant, VAD_QUANT, model_path);
    GetValue(punc_dir, PUNC_DIR, model_path);
@@ -175,7 +182,9 @@ int main(int argc, char *argv[])

    struct timeval start, end;
    gettimeofday(&start, nullptr);
    FUNASR_HANDLE asr_handle=FunOfflineInit(model_path, 1);
    bool use_gpu_ = use_gpu.getValue();
    int batch_size_ = batch_size.getValue();
    FUNASR_HANDLE asr_handle=FunOfflineInit(model_path, 1, use_gpu_, batch_size_);

    if (!asr_handle)
    {
@@ -19,6 +19,7 @@
#include "com-define.h"
#include <unordered_map>
#include "util.h"
#include "audio.h"
using namespace std;

bool is_target_file(const std::string& filename, const std::string target) {
@@ -44,6 +45,7 @@ int main(int argc, char** argv)
    TCLAP::CmdLine cmd("funasr-onnx-offline", ' ', "1.0");
    TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the asr model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
    TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "true (default); if true, load model_quant.onnx in model_dir, otherwise load model.onnx", false, "true", "string");
    TCLAP::ValueArg<std::string> bladedisc("", BLADEDISC, "true (default), load the bladedisc model in model_dir.", false, "true", "string");
    TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
    TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "true (default); if true, load model_quant.onnx in vad_dir, otherwise load model.onnx", false, "true", "string");
    TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
@@ -57,9 +59,12 @@ int main(int argc, char** argv)
    TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
    TCLAP::ValueArg<std::int32_t> audio_fs("", AUDIO_FS, "the sample rate of audio", false, 16000, "int32_t");
    TCLAP::ValueArg<std::string> hotword("", HOTWORD, "the hotword file, one hotword per line, format: Hotword Weight (e.g.: 阿里巴巴 20)", false, "", "string");
    TCLAP::SwitchArg use_gpu("", INFER_GPU, "whether to use GPU for inference, default is false", false);
    TCLAP::ValueArg<std::int32_t> batch_size("", BATCHSIZE, "batch_size for the ASR model when using GPU", false, 4, "int32_t");

    cmd.add(model_dir);
    cmd.add(quantize);
    cmd.add(bladedisc);
    cmd.add(vad_dir);
    cmd.add(vad_quant);
    cmd.add(punc_dir);
@@ -73,11 +78,14 @@ int main(int argc, char** argv)
    cmd.add(wav_path);
    cmd.add(audio_fs);
    cmd.add(hotword);
    cmd.add(use_gpu);
    cmd.add(batch_size);
    cmd.parse(argc, argv);

    std::map<std::string, std::string> model_path;
    GetValue(model_dir, MODEL_DIR, model_path);
    GetValue(quantize, QUANTIZE, model_path);
    GetValue(bladedisc, BLADEDISC, model_path);
    GetValue(vad_dir, VAD_DIR, model_path);
    GetValue(vad_quant, VAD_QUANT, model_path);
    GetValue(punc_dir, PUNC_DIR, model_path);
@@ -89,7 +97,9 @@ int main(int argc, char** argv)
    struct timeval start, end;
    gettimeofday(&start, nullptr);
    int thread_num = 1;
    FUNASR_HANDLE asr_hanlde=FunOfflineInit(model_path, thread_num);
    bool use_gpu_ = use_gpu.getValue();
    int batch_size_ = batch_size.getValue();
    FUNASR_HANDLE asr_hanlde=FunOfflineInit(model_path, thread_num, use_gpu_, batch_size_);

    if (!asr_hanlde)
    {
@@ -156,7 +166,33 @@ int main(int argc, char** argv)
    for (int i = 0; i < wav_list.size(); i++) {
        auto& wav_file = wav_list[i];
        auto& wav_id = wav_ids[i];
        gettimeofday(&start, nullptr);

        // For debug:begin
        // int32_t sampling_rate_ = audio_fs.getValue();
        // funasr::Audio audio(1);
        // if(is_target_file(wav_file.c_str(), "wav")){
        //     if(!audio.LoadWav2Char(wav_file.c_str(), &sampling_rate_)){
        //         LOG(ERROR)<<"Failed to load "<< wav_file;
        //         exit(-1);
        //     }
        // }else if(is_target_file(wav_file.c_str(), "pcm")){
        //     if (!audio.LoadPcmwav2Char(wav_file.c_str(), &sampling_rate_)){
        //         LOG(ERROR)<<"Failed to load "<< wav_file;
        //         exit(-1);
        //     }
        // }else{
        //     if (!audio.FfmpegLoad(wav_file.c_str(), true)){
        //         LOG(ERROR)<<"Failed to load "<< wav_file;
        //         exit(-1);
        //     }
        // }
        // char* speech_buff = audio.GetSpeechChar();
        // int buff_len = audio.GetSpeechLen()*2;

        // gettimeofday(&start, nullptr);
        // FUNASR_RESULT result=FunOfflineInferBuffer(asr_hanlde, speech_buff, buff_len, RASR_NONE, nullptr, hotwords_embedding, audio_fs.getValue(), "pcm", true, decoder_handle);
        // For debug:end

        FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, nullptr, hotwords_embedding, audio_fs.getValue(), true, decoder_handle);
        gettimeofday(&end, nullptr);
        seconds = (end.tv_sec - start.tv_sec);
@@ -83,9 +83,11 @@ class DLLAPI Audio {
    int FetchTpass(AudioFrame *&frame);
    int Fetch(float *&dout, int &len, int &flag);
    int Fetch(float *&dout, int &len, int &flag, float &start_time);
    int Fetch(float **&dout, int *&len, int *&flag, float*& start_time, int batch_size, int &batch_in);
    int FetchDynamic(float **&dout, int *&len, int *&flag, float*& start_time, int batch_size, int &batch_in);
    void Padding();
    void Split(OfflineStream* offline_stream);
    void CutSplit(OfflineStream* offline_stream);
    void CutSplit(OfflineStream* offline_stream, std::vector<int> &index_vector);
    void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished=true);
    void Split(VadModel* vad_obj, int chunk_len, bool input_finished=true, ASR_TYPE asr_mode=ASR_TWO_PASS);
    float GetTimeLen();

@@ -51,6 +51,15 @@ namespace funasr {
#define QUANT_MODEL_NAME "model_quant.onnx"
#define VAD_CMVN_NAME "am.mvn"
#define VAD_CONFIG_NAME "config.yaml"

// gpu models
#define INFER_GPU "gpu"
#define BATCHSIZE "batch-size"
#define TORCH_MODEL_NAME "model.torchscripts"
#define TORCH_QUANT_MODEL_NAME "model_quant.torchscripts"
#define BLADE_MODEL_NAME "model.blade.fp16.pt"
#define BLADEDISC "bladedisc"

#define AM_CMVN_NAME "am.mvn"
#define AM_CONFIG_NAME "config.yaml"
#define LM_CONFIG_NAME "config.yaml"
@@ -96,7 +96,7 @@ _FUNASRAPI void CTTransformerFreeResult(FUNASR_RESULT result);
_FUNASRAPI void CTTransformerUninit(FUNASR_HANDLE handle);

//OfflineStream
_FUNASRAPI FUNASR_HANDLE FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num);
_FUNASRAPI FUNASR_HANDLE FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu=false, int batch_size=1);
_FUNASRAPI void FunOfflineReset(FUNASR_HANDLE handle, FUNASR_DEC_HANDLE dec_handle=nullptr);
// buffer
_FUNASRAPI FUNASR_RESULT FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len,
@@ -106,9 +106,9 @@ _FUNASRAPI FUNASR_RESULT FunOfflineInferBuffer(FUNASR_HANDLE handle, const char*
_FUNASRAPI FUNASR_RESULT FunOfflineInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode,
                                         QM_CALLBACK fn_callback, const std::vector<std::vector<float>> &hw_emb,
                                         int sampling_rate=16000, bool itn=true, FUNASR_DEC_HANDLE dec_handle=nullptr);
#if !defined(__APPLE__)
//#if !defined(__APPLE__)
_FUNASRAPI const std::vector<std::vector<float>> CompileHotwordEmbedding(FUNASR_HANDLE handle, std::string &hotwords, ASR_TYPE mode=ASR_OFFLINE);
#endif
//#endif

_FUNASRAPI void FunOfflineUninit(FUNASR_HANDLE handle);
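To make the extended entry point concrete, here is a minimal sketch of driving the new GPU/batch overload of the offline API (illustrative path; error handling and result parsing elided; only the declarations above are assumed):

```cpp
// Minimal sketch of the extended FunOfflineInit overload (values are placeholders).
#include <map>
#include <string>
#include <vector>
#include "com-define.h"  // MODEL_DIR etc.
#include "funasrruntime.h"

int main() {
  std::map<std::string, std::string> model_path;
  model_path[MODEL_DIR] = "/models/paraformer";  // hypothetical directory
  // New overload: opt in to GPU inference with a batch size of 4.
  FUNASR_HANDLE handle = FunOfflineInit(model_path, /*thread_num=*/1,
                                        /*use_gpu=*/true, /*batch_size=*/4);
  if (!handle) return -1;
  std::string hotwords;  // empty: no hotwords in this sketch
  std::vector<std::vector<float>> hw_emb =
      CompileHotwordEmbedding(handle, hotwords);
  FUNASR_RESULT result = FunOfflineInfer(handle, "asr_example.wav", RASR_NONE,
                                         nullptr, hw_emb);
  // ... read the recognition text from result, then clean up ...
  FunOfflineUninit(handle);
  return 0;
}
```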
@@ -5,6 +5,10 @@
#include <string>
#include <map>
#include "funasrruntime.h"
#include "vocab.h"
#include "phone-set.h"
#include "fst/fstlib.h"
#include "fst/symbol-table.h"
namespace funasr {
class Model {
  public:
@@ -18,13 +22,19 @@ class Model {
    virtual void InitLm(const std::string &lm_file, const std::string &lm_config, const std::string &lex_file){};
    virtual void InitFstDecoder(){};
    virtual std::string Forward(float *din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr){return "";};
    virtual std::vector<std::string> Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1)
    {return std::vector<string>();};
    virtual std::string Rescoring() = 0;
    virtual void InitHwCompiler(const std::string &hw_model, int thread_num){};
    virtual void InitSegDict(const std::string &seg_dict_model){};
    virtual std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords){return std::vector<std::vector<float>>();};
    virtual std::string GetLang(){return "";};
    virtual int GetAsrSampleRate() = 0;

    virtual void SetBatchSize(int batch_size) {};
    virtual int GetBatchSize() {return 0;};
    virtual Vocab* GetVocab() {return nullptr;};
    virtual Vocab* GetLmVocab() {return nullptr;};
    virtual PhoneSet* GetPhoneSet() {return nullptr;};
};

Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE);
@@ -14,7 +14,7 @@
namespace funasr {
class OfflineStream {
  public:
    OfflineStream(std::map<std::string, std::string>& model_path, int thread_num);
    OfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu=false, int batch_size=1);
    ~OfflineStream(){};

    std::unique_ptr<VadModel> vad_handle= nullptr;
@@ -33,6 +33,6 @@ class OfflineStream {
    bool use_itn=false;
};

OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num=1);
OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num=1, bool use_gpu=false, int batch_size=1);
} // namespace funasr
#endif
@@ -1,7 +1,16 @@

file(GLOB files1 "*.cpp")
if(APPLE)
  file(GLOB itn_files "itn-*.cpp")
  list(REMOVE_ITEM files1 ${itn_files})
endif(APPLE)
list(REMOVE_ITEM files1 "${CMAKE_CURRENT_SOURCE_DIR}/paraformer-torch.cpp")
set(files ${files1})

if(GPU)
  set(files ${files} "${CMAKE_CURRENT_SOURCE_DIR}/paraformer-torch.cpp")
endif()

message("files: "${files})

if(WIN32)
@@ -23,9 +32,17 @@ else()
  set(EXTRA_LIBS pthread yaml-cpp csrc kaldi-decoder fst glog gflags avutil avcodec avformat swresample)
  include_directories(${ONNXRUNTIME_DIR}/include)
  include_directories(${FFMPEG_DIR}/include)
  if(APPLE)
    target_link_directories(funasr PUBLIC ${ONNXRUNTIME_DIR}/lib)
    target_link_directories(funasr PUBLIC ${FFMPEG_DIR}/lib)
  endif(APPLE)
endif()

if(GPU)
  set(TORCH_DEPS torch torch_cuda torch_cpu c10 c10_cuda torch_blade ral_base_context)
endif()

#message("CXX_FLAGS "${CMAKE_CXX_FLAGS})
include_directories(${CMAKE_SOURCE_DIR}/include)
include_directories(${CMAKE_SOURCE_DIR}/third_party)
target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS})
target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS} ${TORCH_DEPS})
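Putting the new option together, a GPU build would be configured along these lines (a sketch; libtorch/torch_blade must exist at the `TORCH_DIR`/`TORCH_BLADE_DIR` paths hard-coded above, or those variables must be edited first):

```shell
# Hypothetical GPU configure step for the onnxruntime runtime.
cmake -DCMAKE_BUILD_TYPE=release -DGPU=ON .. \
      -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0 \
      -DFFMPEG_DIR=/path/to/ffmpeg-master-latest-linux64-gpl-shared
make -j 4
```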
@@ -1023,6 +1023,90 @@ int Audio::Fetch(float *&dout, int &len, int &flag, float &start_time)
  }
}

int Audio::Fetch(float**& dout, int*& len, int*& flag, float*& start_time, int batch_size, int &batch_in)
{
  batch_in = std::min((int)frame_queue.size(), batch_size);
  if (batch_in == 0){
    return 0;
  } else{
    // init
    dout = new float*[batch_in];
    len = new int[batch_in];
    flag = new int[batch_in];
    start_time = new float[batch_in];

    for(int idx=0; idx < batch_in; idx++){
      AudioFrame *frame = frame_queue.front();
      frame_queue.pop();

      start_time[idx] = (float)(frame->GetStart()) / dest_sample_rate;
      dout[idx] = speech_data + frame->GetStart();
      len[idx] = frame->GetLen();
      delete frame;
      flag[idx] = S_END;
    }
    return 1;
  }
}

int Audio::FetchDynamic(float**& dout, int*& len, int*& flag, float*& start_time, int batch_size, int &batch_in)
{
  // compute batch size
  queue<AudioFrame *> frame_batch;
  int max_acc = 300*1000*seg_sample;   // upper bound on padded samples in one batch
  int max_sent = 60*1000*seg_sample;   // a segment at least this long is decoded on its own
  int bs_acc = 0;
  int max_len = 0;
  int max_batch = 1;
#ifdef USE_GPU
  max_batch = batch_size;
#endif
  max_batch = std::min(max_batch, (int)frame_queue.size());

  for(int idx=0; idx < max_batch; idx++){
    AudioFrame *frame = frame_queue.front();
    int length = frame->GetLen();
    if(length >= max_sent){
      if(bs_acc == 0){
        bs_acc++;
        frame_batch.push(frame);
        frame_queue.pop();
      }
      break;
    }
    max_len = std::max(max_len, frame->GetLen());
    if(max_len*(bs_acc+1) > max_acc){
      break;
    }
    bs_acc++;
    frame_batch.push(frame);
    frame_queue.pop();
  }

  batch_in = (int)frame_batch.size();
  if (batch_in == 0){
    return 0;
  } else{
    // init
    dout = new float*[batch_in];
    len = new int[batch_in];
    flag = new int[batch_in];
    start_time = new float[batch_in];

    for(int idx=0; idx < batch_in; idx++){
      AudioFrame *frame = frame_batch.front();
      frame_batch.pop();

      start_time[idx] = (float)(frame->GetStart()) / dest_sample_rate;
      dout[idx] = speech_data + frame->GetStart();
      len[idx] = frame->GetLen();
      delete frame;
      flag[idx] = S_END;
    }
    return 1;
  }
}

void Audio::Padding()
{
  float num_samples = speech_len;
@@ -1085,7 +1169,7 @@ void Audio::Split(OfflineStream* offline_stream)
  }
}

void Audio::CutSplit(OfflineStream* offline_stream)
void Audio::CutSplit(OfflineStream* offline_stream, std::vector<int> &index_vector)
{
  std::unique_ptr<VadModel> vad_online_handle = make_unique<FsmnVadOnline>((FsmnVad*)(offline_stream->vad_handle).get());
  AudioFrame *frame;
@@ -1112,6 +1196,7 @@ void Audio::CutSplit(OfflineStream* offline_stream)
  }

  int speech_start_i = -1, speech_end_i = -1;
  std::vector<AudioFrame*> vad_frames;
  for(vector<int> vad_segment:vad_segments)
  {
    if(vad_segment.size() != 2){
@@ -1126,16 +1211,31 @@ void Audio::CutSplit(OfflineStream* offline_stream)
    }

    if(speech_start_i!=-1 && speech_end_i!=-1){
      frame = new AudioFrame();
      int start = speech_start_i*seg_sample;
      int end = speech_end_i*seg_sample;
      frame = new AudioFrame(end-start);
      frame->SetStart(start);
      frame->SetEnd(end);
      frame_queue.push(frame);
      vad_frames.push_back(frame);
      frame = nullptr;
      speech_start_i=-1;
      speech_end_i=-1;
    }

  }
  // sort
  {
    index_vector.clear();
    index_vector.resize(vad_frames.size());
    for (int i = 0; i < index_vector.size(); ++i) {
      index_vector[i] = i;
    }
    std::sort(index_vector.begin(), index_vector.end(), [&vad_frames](const int a, const int b) {
      return vad_frames[a]->len < vad_frames[b]->len;
    });
    for (int idx : index_vector) {
      frame_queue.push(vad_frames[idx]);
    }
  }
}

@@ -33,9 +33,9 @@
  return mm;
}

_FUNASRAPI FUNASR_HANDLE FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num)
_FUNASRAPI FUNASR_HANDLE FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)
{
  funasr::OfflineStream* mm = funasr::CreateOfflineStream(model_path, thread_num);
  funasr::OfflineStream* mm = funasr::CreateOfflineStream(model_path, thread_num, use_gpu, batch_size);
  return mm;
}

@@ -74,16 +74,11 @@
  if(p_result->snippet_time == 0){
    return p_result;
  }
  int n_step = 0;
  int n_total = audio.GetQueueSize();

  while (audio.Fetch(buff, len, flag) > 0) {
    string msg = recog_obj->Forward(buff, len, input_finished);
    p_result->msg += msg;
    n_step++;
    if (fn_callback)
      fn_callback(n_step, n_total);
  }

  return p_result;
}

@@ -109,8 +104,6 @@
  float* buff;
  int len;
  int flag = 0;
  int n_step = 0;
  int n_total = audio.GetQueueSize();
  funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
  p_result->snippet_time = audio.GetTimeLen();
  if(p_result->snippet_time == 0){
@@ -119,11 +112,7 @@
  while (audio.Fetch(buff, len, flag) > 0) {
    string msg = recog_obj->Forward(buff, len, true);
    p_result->msg += msg;
    n_step++;
    if (fn_callback)
      fn_callback(n_step, n_total);
  }

  return p_result;
}

@@ -244,26 +233,53 @@
  if(p_result->snippet_time == 0){
    return p_result;
  }
  std::vector<int> index_vector={0};
  int msg_idx = 0;
  if(offline_stream->UseVad()){
    audio.CutSplit(offline_stream);
    audio.CutSplit(offline_stream, index_vector);
  }
  std::vector<string> msgs(index_vector.size());
  std::vector<float> msg_stimes(index_vector.size());

  float* buff;
  int len;
  int flag = 0;
  float** buff;
  int* len;
  int* flag;
  float* start_time;
  int batch_size = offline_stream->asr_handle->GetBatchSize();
  int batch_in = 0;

  int n_step = 0;
  int n_total = audio.GetQueueSize();
  float start_time = 0.0;
  std::string cur_stamp = "[";
  std::string lang = (offline_stream->asr_handle)->GetLang();
  while (audio.Fetch(buff, len, flag, start_time) > 0) {
  while (audio.FetchDynamic(buff, len, flag, start_time, batch_size, batch_in) > 0) {
    // dec reset
    funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
    if (wfst_decoder){
      wfst_decoder->StartUtterance();
    }
    string msg = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle);
    vector<string> msg_batch = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle, batch_in);
    for(int idx=0; idx<batch_in; idx++){
      string msg = msg_batch[idx];
      if(msg_idx < index_vector.size()){
        msgs[index_vector[msg_idx]] = msg;
        msg_stimes[index_vector[msg_idx]] = start_time[idx];
        msg_idx++;
      }else{
        LOG(ERROR) << "msg_idx: " << msg_idx << " is out of range " << index_vector.size();
      }
    }

    // release
    delete[] buff;
    buff = nullptr;
    delete[] len;
    len = nullptr;
    delete[] flag;
    flag = nullptr;
    delete[] start_time;
    start_time = nullptr;
  }
  for(int idx=0; idx<msgs.size(); idx++){
    string msg = msgs[idx];
    std::vector<std::string> msg_vec = funasr::split(msg, '|');
    if(msg_vec.size()==0){
      continue;
@@ -276,14 +292,11 @@
    if(msg_vec.size() > 1){
      std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
      for(int i=0; i<msg_stamp.size()-1; i+=2){
        float begin = std::stof(msg_stamp[i])+start_time;
        float end = std::stof(msg_stamp[i+1])+start_time;
        float begin = std::stof(msg_stamp[i])+msg_stimes[idx];
        float end = std::stof(msg_stamp[i+1])+msg_stimes[idx];
        cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
      }
    }
    n_step++;
    if (fn_callback)
      fn_callback(n_step, n_total);
  }
  if(cur_stamp != "["){
    cur_stamp.erase(cur_stamp.length() - 1);
@@ -342,25 +355,53 @@
  if(p_result->snippet_time == 0){
    return p_result;
  }
  std::vector<int> index_vector={0};
  int msg_idx = 0;
  if(offline_stream->UseVad()){
    audio.CutSplit(offline_stream);
    audio.CutSplit(offline_stream, index_vector);
  }
  std::vector<string> msgs(index_vector.size());
  std::vector<float> msg_stimes(index_vector.size());

  float** buff;
  int* len;
  int* flag;
  float* start_time;
  int batch_size = offline_stream->asr_handle->GetBatchSize();
  int batch_in = 0;

  float* buff;
  int len;
  int flag = 0;
  int n_step = 0;
  int n_total = audio.GetQueueSize();
  float start_time = 0.0;
  std::string cur_stamp = "[";
  std::string lang = (offline_stream->asr_handle)->GetLang();
  while (audio.Fetch(buff, len, flag, start_time) > 0) {
  while (audio.FetchDynamic(buff, len, flag, start_time, batch_size, batch_in) > 0) {
    // dec reset
    funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
    if (wfst_decoder){
      wfst_decoder->StartUtterance();
    }
    string msg = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle);
    vector<string> msg_batch = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle, batch_in);
    for(int idx=0; idx<batch_in; idx++){
      string msg = msg_batch[idx];
      if(msg_idx < index_vector.size()){
        msgs[index_vector[msg_idx]] = msg;
        msg_stimes[index_vector[msg_idx]] = start_time[idx];
        msg_idx++;
      }else{
        LOG(ERROR) << "msg_idx: " << msg_idx << " is out of range " << index_vector.size();
      }
    }

    // release
    delete[] buff;
    buff = nullptr;
    delete[] len;
    len = nullptr;
    delete[] flag;
    flag = nullptr;
    delete[] start_time;
    start_time = nullptr;
  }
  for(int idx=0; idx<msgs.size(); idx++){
    string msg = msgs[idx];
    std::vector<std::string> msg_vec = funasr::split(msg, '|');
    if(msg_vec.size()==0){
      continue;
@@ -373,15 +414,11 @@
    if(msg_vec.size() > 1){
      std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
      for(int i=0; i<msg_stamp.size()-1; i+=2){
        float begin = std::stof(msg_stamp[i])+start_time;
        float end = std::stof(msg_stamp[i+1])+start_time;
        float begin = std::stof(msg_stamp[i])+msg_stimes[idx];
        float end = std::stof(msg_stamp[i+1])+msg_stimes[idx];
        cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
      }
    }

    n_step++;
    if (fn_callback)
      fn_callback(n_step, n_total);
  }
  if(cur_stamp != "["){
    cur_stamp.erase(cur_stamp.length() - 1);
@@ -409,7 +446,7 @@
  return p_result;
}

#if !defined(__APPLE__)
//#if !defined(__APPLE__)
_FUNASRAPI const std::vector<std::vector<float>> CompileHotwordEmbedding(FUNASR_HANDLE handle, std::string &hotwords, ASR_TYPE mode)
{
  if (mode == ASR_OFFLINE){
@@ -433,7 +470,7 @@
  }

}
#endif
//#endif

// APIs for 2pass-stream Infer
_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf,
@@ -518,8 +555,14 @@
    if (wfst_decoder){
      wfst_decoder->StartUtterance();
    }
    string msg = ((funasr::Paraformer*)asr_handle)->Forward(frame->data, frame->len, frame->is_final, hw_emb, dec_handle);

    float** buff;
    int* len;
    buff = new float*[1];
    len = new int[1];
    buff[0] = frame->data;
    len[0] = frame->len;
    vector<string> msgs = ((funasr::Paraformer*)asr_handle)->Forward(buff, len, frame->is_final, hw_emb, dec_handle);
    string msg = msgs.size()>0 ? msgs[0] : "";
    std::vector<std::string> msg_vec = funasr::split(msg, '|'); // split with timestamp
    if(msg_vec.size()==0){
      continue;
@@ -767,16 +810,45 @@
  funasr::WfstDecoder* mm = nullptr;
  if (asr_type == ASR_OFFLINE) {
    funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
    funasr::Paraformer* paraformer = (funasr::Paraformer*)offline_stream->asr_handle.get();
    if (paraformer->lm_)
      mm = new funasr::WfstDecoder(paraformer->lm_.get(),
          paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
    auto paraformer = dynamic_cast<funasr::Paraformer*>(offline_stream->asr_handle.get());
    if(paraformer != nullptr){
      if (paraformer->lm_){
        mm = new funasr::WfstDecoder(paraformer->lm_.get(),
            paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
      }
      return mm;
    }
#ifdef USE_GPU
    auto paraformer_torch = dynamic_cast<funasr::ParaformerTorch*>(offline_stream->asr_handle.get());
    if(paraformer_torch != nullptr){
      if (paraformer_torch->lm_){
        mm = new funasr::WfstDecoder(paraformer_torch->lm_.get(),
            paraformer_torch->GetPhoneSet(), paraformer_torch->GetLmVocab(), glob_beam, lat_beam, am_scale);
      }
      return mm;
    }
#endif

  } else if (asr_type == ASR_TWO_PASS){
    funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
    funasr::Paraformer* paraformer = (funasr::Paraformer*)tpass_stream->asr_handle.get();
    if (paraformer->lm_)
      mm = new funasr::WfstDecoder(paraformer->lm_.get(),
          paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
    auto paraformer = dynamic_cast<funasr::Paraformer*>(tpass_stream->asr_handle.get());
    if(paraformer != nullptr){
      if (paraformer->lm_){
        mm = new funasr::WfstDecoder(paraformer->lm_.get(),
            paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
      }
      return mm;
    }
#ifdef USE_GPU
    auto paraformer_torch = dynamic_cast<funasr::ParaformerTorch*>(tpass_stream->asr_handle.get());
    if(paraformer_torch != nullptr){
      if (paraformer_torch->lm_){
        mm = new funasr::WfstDecoder(paraformer_torch->lm_.get(),
            paraformer_torch->GetPhoneSet(), paraformer_torch->GetLmVocab(), glob_beam, lat_beam, am_scale);
      }
      return mm;
    }
#endif
  }
  return mm;
}
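A note on the design in the hunks above: `CutSplit` now pushes VAD segments in ascending length order while `index_vector` records each segment's original position, so `FetchDynamic` can batch similar-length segments together (minimizing padding) and the API loop scatters each decoded text back with `msgs[index_vector[msg_idx]]`. A standalone illustration of this argsort-and-scatter pattern (hypothetical data, not FunASR code):

```cpp
#include <algorithm>
#include <string>
#include <vector>

int main() {
  std::vector<int> seg_len = {9, 2, 5};  // seconds, in original utterance order
  std::vector<int> index_vector(seg_len.size());
  for (size_t i = 0; i < index_vector.size(); ++i) index_vector[i] = i;
  std::sort(index_vector.begin(), index_vector.end(),
            [&](int a, int b) { return seg_len[a] < seg_len[b]; });
  // index_vector == {1, 2, 0}: decode the shortest segment first...
  std::vector<std::string> decoded = {"two", "five", "nine"};  // sorted order
  std::vector<std::string> msgs(seg_len.size());
  for (size_t k = 0; k < decoded.size(); ++k)
    msgs[index_vector[k]] = decoded[k];  // ...and scatter back to original slots
  // msgs == {"nine", "two", "five"}: original utterance order restored.
  return 0;
}
```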
@@ -1,7 +1,7 @@
#include "precomp.h"

namespace funasr {
OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int thread_num)
OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)
{
  // VAD model
  if(model_path.find(VAD_DIR) != model_path.end()){
@@ -36,7 +36,19 @@ OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int
  string hw_compile_model_path;
  string seg_dict_path;

  asr_handle = make_unique<Paraformer>();
  if(use_gpu){
#ifdef USE_GPU
    asr_handle = make_unique<ParaformerTorch>();
    asr_handle->SetBatchSize(batch_size);
#else
    LOG(ERROR) << "GPU is not supported! CPU will be used instead. To enable GPU, add -DGPU=ON when running cmake.";
    asr_handle = make_unique<Paraformer>();
    use_gpu = false;
#endif
  }else{
    asr_handle = make_unique<Paraformer>();
  }

  bool enable_hotword = false;
  hw_compile_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_EB_NAME);
  seg_dict_path = PathAppend(model_path.at(MODEL_DIR), MODEL_SEG_DICT);
@@ -55,6 +67,15 @@ OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int
    if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
      am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
    }
    if(use_gpu){
      am_model_path = PathAppend(model_path.at(MODEL_DIR), TORCH_MODEL_NAME);
      if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
        am_model_path = PathAppend(model_path.at(MODEL_DIR), TORCH_QUANT_MODEL_NAME);
      }
      if(model_path.find(BLADEDISC) != model_path.end() && model_path.at(BLADEDISC) == "true"){
        am_model_path = PathAppend(model_path.at(MODEL_DIR), BLADE_MODEL_NAME);
      }
    }
  }
  am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
  am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
@@ -120,10 +141,10 @@ OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int
#endif
}

OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num)
OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)
{
  OfflineStream *mm;
  mm = new OfflineStream(model_path, thread_num);
  mm = new OfflineStream(model_path, thread_num, use_gpu, batch_size);
  return mm;
}
415 runtime/onnxruntime/src/paraformer-torch.cpp Normal file
@@ -0,0 +1,415 @@
/**
|
||||
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
* MIT License (https://opensource.org/licenses/MIT)
|
||||
*/
|
||||
|
||||
#include "precomp.h"
|
||||
#include "paraformer-torch.h"
|
||||
#include "encode_converter.h"
|
||||
#include <cstddef>
|
||||
|
||||
using namespace std;
|
||||
namespace funasr {
|
||||
|
||||
ParaformerTorch::ParaformerTorch()
|
||||
:use_hotword(false){
|
||||
}
|
||||
|
||||
// offline
|
||||
void ParaformerTorch::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){
|
||||
LoadConfigFromYaml(am_config.c_str());
|
||||
// knf options
|
||||
fbank_opts_.frame_opts.dither = 0;
|
||||
fbank_opts_.mel_opts.num_bins = n_mels;
|
||||
fbank_opts_.frame_opts.samp_freq = asr_sample_rate;
|
||||
fbank_opts_.frame_opts.window_type = window_type;
|
||||
fbank_opts_.frame_opts.frame_shift_ms = frame_shift;
|
||||
fbank_opts_.frame_opts.frame_length_ms = frame_length;
|
||||
fbank_opts_.energy_floor = 0;
|
||||
fbank_opts_.mel_opts.debug_mel = false;
|
||||
|
||||
vocab = new Vocab(token_file.c_str());
|
||||
phone_set_ = new PhoneSet(token_file.c_str());
|
||||
LoadCmvn(am_cmvn.c_str());
|
||||
|
||||
torch::DeviceType device = at::kCPU;
|
||||
#ifdef USE_GPU
|
||||
if (!torch::cuda::is_available()) {
|
||||
LOG(ERROR) << "CUDA is not available! Please check your GPU settings";
|
||||
exit(-1);
|
||||
} else {
|
||||
LOG(INFO) << "CUDA is available, running on GPU";
|
||||
device = at::kCUDA;
|
||||
}
|
||||
#endif
|
||||
#ifdef USE_IPEX
|
||||
torch::jit::setTensorExprFuserEnabled(false);
|
||||
#endif
|
||||
|
||||
try {
|
||||
torch::jit::script::Module model = torch::jit::load(am_model, device);
|
||||
model_ = std::make_shared<TorchModule>(std::move(model));
|
||||
LOG(INFO) << "Successfully load model from " << am_model;
|
||||
} catch (std::exception const &e) {
|
||||
LOG(ERROR) << "Error when load am model: " << am_model << e.what();
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
void ParaformerTorch::InitLm(const std::string &lm_file,
|
||||
const std::string &lm_cfg_file,
|
||||
const std::string &lex_file) {
|
||||
try {
|
||||
lm_ = std::shared_ptr<fst::Fst<fst::StdArc>>(
|
||||
fst::Fst<fst::StdArc>::Read(lm_file));
|
||||
if (lm_){
|
||||
lm_vocab = new Vocab(lm_cfg_file.c_str(), lex_file.c_str());
|
||||
LOG(INFO) << "Successfully load lm file " << lm_file;
|
||||
}else{
|
||||
LOG(ERROR) << "Failed to load lm file " << lm_file;
|
||||
}
|
||||
} catch (std::exception const &e) {
|
||||
LOG(ERROR) << "Error when load lm file: " << e.what();
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
|
||||
void ParaformerTorch::LoadConfigFromYaml(const char* filename){
|
||||
|
||||
YAML::Node config;
|
||||
try{
|
||||
config = YAML::LoadFile(filename);
|
||||
}catch(exception const &e){
|
||||
LOG(ERROR) << "Error loading file, yaml file error or not exist.";
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
try{
|
||||
YAML::Node frontend_conf = config["frontend_conf"];
|
||||
this->asr_sample_rate = frontend_conf["fs"].as<int>();
|
||||
|
||||
YAML::Node lang_conf = config["lang"];
|
||||
if (lang_conf.IsDefined()){
|
||||
language = lang_conf.as<string>();
|
||||
}
|
||||
}catch(exception const &e){
|
||||
LOG(ERROR) << "Error when load argument from vad config YAML.";
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
void ParaformerTorch::InitHwCompiler(const std::string &hw_model, int thread_num) {
|
||||
// TODO
|
||||
use_hotword = true;
|
||||
}
|
||||
|
||||
void ParaformerTorch::InitSegDict(const std::string &seg_dict_model) {
|
||||
seg_dict = new SegDict(seg_dict_model.c_str());
|
||||
}
|
||||
|
||||
ParaformerTorch::~ParaformerTorch()
|
||||
{
|
||||
if(vocab){
|
||||
delete vocab;
|
||||
}
|
||||
if(lm_vocab){
|
||||
delete lm_vocab;
|
||||
}
|
||||
if(seg_dict){
|
||||
delete seg_dict;
|
||||
}
|
||||
if(phone_set_){
|
||||
delete phone_set_;
|
||||
}
|
||||
}
|
||||
|
||||
void ParaformerTorch::StartUtterance()
|
||||
{
|
||||
}
|
||||
|
||||
void ParaformerTorch::EndUtterance()
|
||||
{
|
||||
}
|
||||
|
||||
void ParaformerTorch::Reset()
|
||||
{
|
||||
}
|
||||
|
||||
void ParaformerTorch::FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats) {
|
||||
knf::OnlineFbank fbank_(fbank_opts_);
|
||||
std::vector<float> buf(len);
|
||||
for (int32_t i = 0; i != len; ++i) {
|
||||
buf[i] = waves[i] * 32768;
|
||||
}
|
||||
fbank_.AcceptWaveform(sample_rate, buf.data(), buf.size());
|
||||
|
||||
int32_t frames = fbank_.NumFramesReady();
|
||||
for (int32_t i = 0; i != frames; ++i) {
|
||||
const float *frame = fbank_.GetFrame(i);
|
||||
std::vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
|
||||
asr_feats.emplace_back(frame_vector);
|
||||
}
|
||||
}
|
||||
|
||||
void ParaformerTorch::LoadCmvn(const char *filename)
{
    ifstream cmvn_stream(filename);
    if (!cmvn_stream.is_open()) {
        LOG(ERROR) << "Failed to open file: " << filename;
        exit(-1);
    }
    string line;

    while (getline(cmvn_stream, line)) {
        istringstream iss(line);
        vector<string> line_item{istream_iterator<string>{iss}, istream_iterator<string>{}};
        if (line_item[0] == "<AddShift>") {
            getline(cmvn_stream, line);
            istringstream means_lines_stream(line);
            vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
            if (means_lines[0] == "<LearnRateCoef>") {
                for (int j = 3; j < means_lines.size() - 1; j++) {
                    means_list_.push_back(stof(means_lines[j]));
                }
                continue;
            }
        }
        else if (line_item[0] == "<Rescale>") {
            getline(cmvn_stream, line);
            istringstream vars_lines_stream(line);
            vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
            if (vars_lines[0] == "<LearnRateCoef>") {
                for (int j = 3; j < vars_lines.size() - 1; j++) {
                    vars_list_.push_back(stof(vars_lines[j])*scale);
                }
                continue;
            }
        }
    }
}
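// The cmvn file parsed above follows the Kaldi nnet text format; a sketch of
// the expected layout (dimensions and values illustrative):
//   <AddShift> 560 560
//   <LearnRateCoef> 0 [ -8.31 -8.48 ... ]   negated means, added in LfrCmvn
//   <Rescale> 560 560
//   <LearnRateCoef> 0 [ 0.15 0.16 ... ]     inverse std-devs, multiplied in LfrCmvn
// The loops start at index 3 to skip the "<LearnRateCoef>", "0" and "["
// tokens, and stop before the trailing "]".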
string ParaformerTorch::GreedySearch(float * in, int n_len, int64_t token_nums, bool is_stamp, std::vector<float> us_alphas, std::vector<float> us_cif_peak)
{
    vector<int> hyps;
    int Tmax = n_len;
    // Frame-wise argmax over the vocabulary gives the token sequence.
    for (int i = 0; i < Tmax; i++) {
        int max_idx;
        float max_val;
        FindMax(in + i * token_nums, token_nums, max_val, max_idx);
        hyps.push_back(max_idx);
    }
    if(!is_stamp){
        return vocab->Vector2StringV2(hyps, language);
    }else{
        std::vector<string> char_list;
        std::vector<std::vector<float>> timestamp_list;
        std::string res_str;
        vocab->Vector2String(hyps, char_list);
        std::vector<string> raw_char(char_list);
        // Derive per-token timestamps from the upsampled CIF alphas and peaks.
        TimestampOnnx(us_alphas, us_cif_peak, char_list, res_str, timestamp_list);

        return PostProcess(raw_char, timestamp_list);
    }
}
string ParaformerTorch::BeamSearch(WfstDecoder* &wfst_decoder, float *in, int len, int64_t token_nums)
{
    return wfst_decoder->Search(in, len, token_nums);
}

string ParaformerTorch::FinalizeDecode(WfstDecoder* &wfst_decoder,
                                       bool is_stamp, std::vector<float> us_alphas, std::vector<float> us_cif_peak)
{
    return wfst_decoder->FinalizeDecode(is_stamp, us_alphas, us_cif_peak);
}
void ParaformerTorch::LfrCmvn(std::vector<std::vector<float>> &asr_feats) {

    std::vector<std::vector<float>> out_feats;
    int T = asr_feats.size();
    int T_lrf = ceil(1.0 * T / lfr_n);

    // Pad frames at the start (copy the first frame)
    for (int i = 0; i < (lfr_m - 1) / 2; i++) {
        asr_feats.insert(asr_feats.begin(), asr_feats[0]);
    }
    // Merge lfr_m frames into one, advancing lfr_n frames per window
    T = T + (lfr_m - 1) / 2;
    std::vector<float> p;
    for (int i = 0; i < T_lrf; i++) {
        if (lfr_m <= T - i * lfr_n) {
            for (int j = 0; j < lfr_m; j++) {
                p.insert(p.end(), asr_feats[i * lfr_n + j].begin(), asr_feats[i * lfr_n + j].end());
            }
            out_feats.emplace_back(p);
            p.clear();
        } else {
            // Pad the last window to lfr_m frames if fewer remain (copy the last frame)
            int num_padding = lfr_m - (T - i * lfr_n);
            for (int j = 0; j < (asr_feats.size() - i * lfr_n); j++) {
                p.insert(p.end(), asr_feats[i * lfr_n + j].begin(), asr_feats[i * lfr_n + j].end());
            }
            for (int j = 0; j < num_padding; j++) {
                p.insert(p.end(), asr_feats[asr_feats.size() - 1].begin(), asr_feats[asr_feats.size() - 1].end());
            }
            out_feats.emplace_back(p);
            p.clear();
        }
    }
    // Apply cmvn: means_list_ holds negated means, vars_list_ inverse std-devs
    for (auto &out_feat: out_feats) {
        for (int j = 0; j < means_list_.size(); j++) {
            out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
        }
    }
    asr_feats = out_feats;
}
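// Worked example of the low-frame-rate stacking above: with the usual
// Paraformer setting lfr_m = 7, lfr_n = 6 (the typical values behind
// PARA_LFR_M/PARA_LFR_N) and 80-dim fbank frames at a 10 ms shift, each
// output stacks 7 consecutive frames into a 560-dim vector and the hop
// becomes 60 ms, so T input frames yield ceil(T / 6) LFR frames.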
std::vector<std::string> ParaformerTorch::Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle, int batch_in)
{
    WfstDecoder* wfst_decoder = (WfstDecoder*)decoder_handle;
    int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
    int32_t feature_dim = lfr_m*in_feat_dim;

    // 1. fbank + LFR + CMVN per utterance
    std::vector<vector<float>> feats_batch;
    std::vector<int32_t> paraformer_length;
    int max_size = 0;
    int max_frames = 0;
    for(int index=0; index<batch_in; index++){
        std::vector<std::vector<float>> asr_feats;
        FbankKaldi(asr_sample_rate, din[index], len[index], asr_feats);
        if(asr_feats.size() != 0){
            LfrCmvn(asr_feats);
        }
        int32_t num_frames = asr_feats.size();
        paraformer_length.emplace_back(num_frames);
        if(max_size < asr_feats.size()*feature_dim){
            max_size = asr_feats.size()*feature_dim;
            max_frames = num_frames;
        }

        std::vector<float> flattened;
        for (const auto& sub_vector : asr_feats) {
            flattened.insert(flattened.end(), sub_vector.begin(), sub_vector.end());
        }
        feats_batch.emplace_back(flattened);
    }

    torch::NoGradGuard no_grad;
    model_->eval();
    // padding: zero-pad every utterance to the longest one in the batch
    std::vector<float> all_feats(batch_in * max_frames * feature_dim);
    for(int index=0; index<batch_in; index++){
        feats_batch[index].resize(max_size);
        std::memcpy(&all_feats[index * max_frames * feature_dim], feats_batch[index].data(),
                    max_frames * feature_dim * sizeof(float));
    }
    torch::Tensor feats =
        torch::from_blob(all_feats.data(),
                         {batch_in, max_frames, feature_dim}, torch::kFloat).contiguous();
    torch::Tensor feat_lens = torch::from_blob(paraformer_length.data(),
                                               {batch_in}, torch::kInt32);

    // 2. forward
#ifdef USE_GPU
    feats = feats.to(at::kCUDA);
    feat_lens = feat_lens.to(at::kCUDA);
#endif
    std::vector<torch::jit::IValue> inputs = {feats, feat_lens};

    vector<std::string> results;
    try {
        auto outputs = model_->forward(inputs).toTuple()->elements();
        torch::Tensor am_scores;
        torch::Tensor valid_token_lens;
#ifdef USE_GPU
        am_scores = outputs[0].toTensor().to(at::kCPU);
        valid_token_lens = outputs[1].toTensor().to(at::kCPU);
#else
        am_scores = outputs[0].toTensor();
        valid_token_lens = outputs[1].toTensor();
#endif
        // 3. decode per utterance; four outputs include timestamp tensors
        for(int index=0; index<batch_in; index++){
            string result="";
            if(outputs.size() == 4){
                torch::Tensor us_alphas_tensor;
                torch::Tensor us_peaks_tensor;
#ifdef USE_GPU
                us_alphas_tensor = outputs[2].toTensor().to(at::kCPU);
                us_peaks_tensor = outputs[3].toTensor().to(at::kCPU);
#else
                us_alphas_tensor = outputs[2].toTensor();
                us_peaks_tensor = outputs[3].toTensor();
#endif

                float* us_alphas_data = us_alphas_tensor[index].data_ptr<float>();
                std::vector<float> us_alphas(paraformer_length[index]);
                for (int i = 0; i < us_alphas.size(); i++) {
                    us_alphas[i] = us_alphas_data[i];
                }

                float* us_peaks_data = us_peaks_tensor[index].data_ptr<float>();
                std::vector<float> us_peaks(paraformer_length[index]);
                for (int i = 0; i < us_peaks.size(); i++) {
                    us_peaks[i] = us_peaks_data[i];
                }
                if (lm_ == nullptr) {
                    result = GreedySearch(am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2), true, us_alphas, us_peaks);
                } else {
                    result = BeamSearch(wfst_decoder, am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
                    if (input_finished) {
                        result = FinalizeDecode(wfst_decoder, true, us_alphas, us_peaks);
                    }
                }
            }else{
                if (lm_ == nullptr) {
                    result = GreedySearch(am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
                } else {
                    result = BeamSearch(wfst_decoder, am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
                    if (input_finished) {
                        result = FinalizeDecode(wfst_decoder);
                    }
                }
            }
            results.push_back(result);
            if (wfst_decoder){
                wfst_decoder->StartUtterance();
            }
        }
    }
    catch (std::exception const &e)
    {
        LOG(ERROR)<<e.what();
    }

    return results;
}
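// Call-flow sketch for the batched GPU path (illustrative): the caller hands
// Forward() batch_in waveforms with their lengths; each is featurized and
// normalized independently, padded into one (batch_in, max_frames,
// feature_dim) tensor, run through the TorchScript model once, then decoded
// per utterance with greedy search (no LM) or the WFST beam decoder.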
std::vector<std::vector<float>> ParaformerTorch::CompileHotwordEmbedding(std::string &hotwords) {
    // TODO: stub; returns a single all-zero placeholder embedding until the
    // hotword compiler is wired up (see InitHwCompiler).
    std::vector<std::vector<float>> result(1, std::vector<float>(10, 0.0f));
    return result;
}
Vocab* ParaformerTorch::GetVocab()
{
    return vocab;
}

Vocab* ParaformerTorch::GetLmVocab()
{
    return lm_vocab;
}

PhoneSet* ParaformerTorch::GetPhoneSet()
{
    return phone_set_;
}
string ParaformerTorch::Rescoring()
{
    LOG(ERROR) << "Not implemented!";
    return "";
}
} // namespace funasr
96
runtime/onnxruntime/src/paraformer-torch.h
Normal file
@ -0,0 +1,96 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 * MIT License (https://opensource.org/licenses/MIT)
 */
#pragma once
#define C10_USE_GLOG
#include <torch/serialize.h>
#include <torch/script.h>
#include <torch/torch.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include "precomp.h"
#include "fst/fstlib.h"
#include "fst/symbol-table.h"
#include "bias-lm.h"
#include "phone-set.h"

namespace funasr {

class ParaformerTorch : public Model {
/**
 * Author: Speech Lab of DAMO Academy, Alibaba Group
 * Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
 * https://arxiv.org/pdf/2206.08317.pdf
 */
  private:
    Vocab* vocab = nullptr;
    Vocab* lm_vocab = nullptr;
    SegDict* seg_dict = nullptr;
    PhoneSet* phone_set_ = nullptr;
    //const float scale = 22.6274169979695;
    const float scale = 1.0;

    void LoadConfigFromYaml(const char* filename);
    void LoadCmvn(const char *filename);
    void LfrCmvn(std::vector<std::vector<float>> &asr_feats);

    using TorchModule = torch::jit::script::Module;
    std::shared_ptr<TorchModule> model_ = nullptr;
    std::vector<torch::Tensor> encoder_outs_;
    bool use_hotword;

  public:
    ParaformerTorch();
    ~ParaformerTorch();
    void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
    void InitHwCompiler(const std::string &hw_model, int thread_num);
    void InitSegDict(const std::string &seg_dict_model);
    std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
    void Reset();
    void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
    std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1);
    string GreedySearch(float* in, int n_len, int64_t token_nums,
                        bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});

    string Rescoring();
    string GetLang(){return language;};
    int GetAsrSampleRate() { return asr_sample_rate; };
    void SetBatchSize(int batch_size) {batch_size_ = batch_size;};
    int GetBatchSize() {return batch_size_;};
    void StartUtterance();
    void EndUtterance();
    void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
    string BeamSearch(WfstDecoder* &wfst_decoder, float* in, int n_len, int64_t token_nums);
    string FinalizeDecode(WfstDecoder* &wfst_decoder,
                          bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
    Vocab* GetVocab();
    Vocab* GetLmVocab();
    PhoneSet* GetPhoneSet();

    knf::FbankOptions fbank_opts_;
    vector<float> means_list_;
    vector<float> vars_list_;
    int lfr_m = PARA_LFR_M;
    int lfr_n = PARA_LFR_N;

    // paraformer-offline
    std::string language="zh-cn";

    // lm
    std::shared_ptr<fst::Fst<fst::StdArc>> lm_ = nullptr;

    string window_type = "hamming";
    int frame_length = 25;
    int frame_shift = 10;
    int n_mels = 80;
    int encoder_size = 512;
    int fsmn_layers = 16;
    int fsmn_lorder = 10;
    int fsmn_dims = 512;
    float cif_threshold = 1.0;
    float tail_alphas = 0.45;
    int asr_sample_rate = MODEL_SAMPLE_RATE;
    int batch_size_ = 1;
};

} // namespace funasr
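// Typical call order for this class (an illustrative sketch, not part of the
// header): InitAsr(...), optionally InitLm(...) and InitHwCompiler(...), then
// CompileHotwordEmbedding(...) followed by Forward(...) per batch of waveforms.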
@ -462,15 +462,23 @@ void Paraformer::LfrCmvn(std::vector<std::vector<float>> &asr_feats) {
    asr_feats = out_feats;
}

string Paraformer::Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle)
std::vector<std::string> Paraformer::Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle, int batch_in)
{
    std::vector<std::string> results;
    string result="";
    WfstDecoder* wfst_decoder = (WfstDecoder*)decoder_handle;
    int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;

    if(batch_in != 1){
        results.push_back(result);
        return results;
    }

    std::vector<std::vector<float>> asr_feats;
    FbankKaldi(asr_sample_rate, din, len, asr_feats);
    FbankKaldi(asr_sample_rate, din[0], len[0], asr_feats);
    if(asr_feats.size() == 0){
        return "";
        results.push_back(result);
        return results;
    }
    LfrCmvn(asr_feats);
    int32_t feat_dim = lfr_m*in_feat_dim;
@ -509,7 +517,8 @@ string Paraformer::Forward(float* din, int len, bool input_finished, const std::
    if (use_hotword) {
        if(hw_emb.size()<=0){
            LOG(ERROR) << "hw_emb is null";
            return "";
            results.push_back(result);
            return results;
        }
        //PrintMat(hw_emb, "input_clas_emb");
        const int64_t hotword_shape[3] = {1, static_cast<int64_t>(hw_emb.size()), static_cast<int64_t>(hw_emb[0].size())};
@ -526,10 +535,10 @@ string Paraformer::Forward(float* din, int len, bool input_finished, const std::
    }catch (std::exception const &e)
    {
        LOG(ERROR)<<e.what();
        return "";
        results.push_back(result);
        return results;
    }

    string result="";
    try {
        auto outputTensor = m_session_->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), input_onnx.size(), m_szOutputNames.data(), m_szOutputNames.size());
        std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
@ -577,7 +586,8 @@ string Paraformer::Forward(float* din, int len, bool input_finished, const std::
        LOG(ERROR)<<e.what();
    }

    return result;
    results.push_back(result);
    return results;
}
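// Note on the hunk above: the ONNX Paraformer adopts the new batched
// signature for interface compatibility but still handles one utterance;
// batch_in != 1 short-circuits with an empty result, while real batching
// lives in the libtorch-based ParaformerTorch backend.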
@ -52,13 +52,14 @@ namespace funasr {
    std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
    void Reset();
    void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
    string Forward(float* din, int len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr);
    std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1);
    string GreedySearch( float* in, int n_len, int64_t token_nums,
                         bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});

    string Rescoring();
    string GetLang(){return language;};
    int GetAsrSampleRate() { return asr_sample_rate; };
    int GetBatchSize() {return batch_size_;};
    void StartUtterance();
    void EndUtterance();
    void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
@ -110,6 +111,7 @@ namespace funasr {
    float cif_threshold = 1.0;
    float tail_alphas = 0.45;
    int asr_sample_rate = MODEL_SAMPLE_RATE;
    int batch_size_ = 1;
};

} // namespace funasr
@ -64,6 +64,9 @@ using namespace std;
#include "seg_dict.h"
#include "resample.h"
#include "paraformer.h"
#ifdef USE_GPU
#include "paraformer-torch.h"
#endif
#include "paraformer-online.h"
#include "offline-stream.h"
#include "tpass-stream.h"
@ -70,13 +70,13 @@ ostream& operator << (ostream& os, const deque<T>& dq) {
    return os;
}

#ifndef USE_GPU
template<class T1, class T2>
ostream& operator << (ostream& os, const pair<T1, T2>& pr) {
    os << pr.first << ":" << pr.second;
    return os;
}
#endif
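// Assumption: this overload is guarded out under USE_GPU because the libtorch
// headers pulled in by that build already declare a similar operator<< for
// std::pair, which would make the overload ambiguous.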
template<class T>
string& operator << (string& str, const T& obj) {
@ -326,6 +326,9 @@ class ContextualParaformer(Paraformer):
    def __call__(
        self, wav_content: Union[str, np.ndarray, List[str]], hotwords: str, **kwargs
    ) -> List:
    # def __call__(
    #     self, waveform_list: list, hotwords: str, **kwargs
    # ) -> List:
        # make hotword list
        hotwords, hotwords_length = self.proc_hotword(hotwords)
        # import pdb; pdb.set_trace()
@ -345,15 +348,47 @@ class ContextualParaformer(Paraformer):
        try:
            outputs = self.bb_infer(feats, feats_len, bias_embed)
            am_scores, valid_token_lens = outputs[0], outputs[1]

            if len(outputs) == 4:
                # for BiCifParaformer Inference
                us_alphas, us_peaks = outputs[2], outputs[3]
            else:
                us_alphas, us_peaks = None, None

        except ONNXRuntimeError:
            # logging.warning(traceback.format_exc())
            logging.warning("input wav is silence or noise")
            preds = [""]
        else:
            preds = self.decode(am_scores, valid_token_lens)
            for pred in preds:
                pred = sentence_postprocess(pred)
                asr_res.append({"preds": pred})
            if us_peaks is None:
                for pred in preds:
                    if self.language == "en-bpe":
                        pred = sentence_postprocess_sentencepiece(pred)
                    else:
                        pred = sentence_postprocess(pred)
                    asr_res.append({"preds": pred})
            else:
                for pred, us_peaks_ in zip(preds, us_peaks):
                    raw_tokens = pred
                    timestamp, timestamp_raw = time_stamp_lfr6_onnx(
                        us_peaks_, copy.copy(raw_tokens)
                    )
                    text_proc, timestamp_proc, _ = sentence_postprocess(
                        raw_tokens, timestamp_raw
                    )
                    # logging.warning(timestamp)
                    if len(self.plot_timestamp_to):
                        self.plot_wave_timestamp(
                            waveform_list[0], timestamp, self.plot_timestamp_to
                        )
                    asr_res.append(
                        {
                            "preds": text_proc,
                            "timestamp": timestamp_proc,
                            "raw_tokens": raw_tokens,
                        }
                    )
        return asr_res

    def proc_hotword(self, hotwords):
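# Note on the branch above: BiCifParaformer-style models return four outputs
# (am_scores, valid_token_lens, us_alphas, us_peaks) and therefore yield
# per-token timestamps; models returning only the first two skip the
# timestamp path entirely.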
@ -8,6 +8,10 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)

option(ENABLE_WEBSOCKET "Whether to build websocket server" ON)
option(ENABLE_PORTAUDIO "Whether to build portaudio" ON)
option(ENABLE_GLOG "Whether to build glog" ON)
option(ENABLE_FST "Whether to build openfst" ON) # ITN needs openfst compiled
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
option(GPU "Whether to build with GPU" OFF)

if(WIN32)
    file(REMOVE ${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/config.h
@ -20,12 +24,16 @@ else()
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread -fPIC")
endif()

option(ENABLE_GLOG "Whether to build glog" ON)
option(ENABLE_FST "Whether to build openfst" ON) # ITN needs openfst compiled
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
if(GPU)
    add_definitions(-DUSE_GPU)
    set(TORCH_DIR "/usr/local/lib/python3.8/dist-packages/torch")
    set(TORCH_BLADE_DIR "/usr/local/lib/python3.8/dist-packages/torch_blade")
    include_directories(${TORCH_DIR}/include)
    include_directories(${TORCH_DIR}/include/torch/csrc/api/include)
    link_directories(${TORCH_DIR}/lib)
    link_directories(${TORCH_BLADE_DIR})
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -D_GLIBCXX_USE_CXX11_ABI=0")
endif()
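# Note (assumption): GPU=ON switches the runtime to the libtorch-based
# ParaformerTorch backend via -DUSE_GPU; the hard-coded TORCH_DIR and
# TORCH_BLADE_DIR paths presume torch/torch_blade pip-installed under
# Python 3.8 and will need adjusting for other environments.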

if(ENABLE_WEBSOCKET)
    # cmake_policy(SET CMP0135 NEW)
@ -1,5 +1,4 @@

if(WIN32)
    include_directories(${ONNXRUNTIME_DIR}/include)
    include_directories(${FFMPEG_DIR}/include)
@ -12,15 +11,14 @@ if(WIN32)
    SET(RELATION_SOURCE "../../onnxruntime/src/resample.cpp" "../../onnxruntime/src/util.cpp" "../../onnxruntime/src/alignedmem.cpp" "../../onnxruntime/src/encode_converter.cpp")
endif()

add_executable(funasr-wss-server "funasr-wss-server.cpp" "websocket-server.cpp" ${RELATION_SOURCE})
add_executable(funasr-wss-server-2pass "funasr-wss-server-2pass.cpp" "websocket-server-2pass.cpp" ${RELATION_SOURCE})
add_executable(funasr-wss-client "funasr-wss-client.cpp" ${RELATION_SOURCE})
add_executable(funasr-wss-client-2pass "funasr-wss-client-2pass.cpp" "microphone.cpp" ${RELATION_SOURCE})

target_link_options(funasr-wss-server PRIVATE "-Wl,--no-as-needed")
target_link_options(funasr-wss-server-2pass PRIVATE "-Wl,--no-as-needed")

target_link_libraries(funasr-wss-client PUBLIC funasr ${OPENSSL_CRYPTO_LIBRARY} ${OPENSSL_SSL_LIBRARY})
target_link_libraries(funasr-wss-client-2pass PUBLIC funasr ${OPENSSL_CRYPTO_LIBRARY} ${OPENSSL_SSL_LIBRARY} portaudio)
target_link_libraries(funasr-wss-server PUBLIC funasr ${OPENSSL_CRYPTO_LIBRARY} ${OPENSSL_SSL_LIBRARY})
@ -56,6 +56,10 @@ int main(int argc, char* argv[]) {
        "true (Default), load the model of model_quant.onnx in model_dir. If set "
        "false, load the model of model.onnx in model_dir",
        false, "true", "string");
    TCLAP::ValueArg<std::string> bladedisc(
        "", BLADEDISC,
        "true (Default), load the model of bladedisc in model_dir.",
        false, "true", "string");
    TCLAP::ValueArg<std::string> vad_dir(
        "", VAD_DIR,
        "default: /workspace/models/vad, the vad model path, which contains model_quant.onnx, vad.yaml, vad.mvn",
@ -121,6 +125,8 @@ int main(int argc, char* argv[]) {
        false, "/workspace/resources/hotwords.txt", "string");
    TCLAP::ValueArg<std::int32_t> fst_inc_wts("", FST_INC_WTS,
        "the fst hotwords incremental bias", false, 20, "int32_t");
    TCLAP::SwitchArg use_gpu("", INFER_GPU, "Whether to use GPU, default is false", false);
    TCLAP::ValueArg<std::int32_t> batch_size("", BATCHSIZE, "batch_size for ASR model when using GPU", false, 4, "int32_t");

    // add file
    cmd.add(hotword);
@ -135,6 +141,7 @@ int main(int argc, char* argv[]) {
    cmd.add(model_dir);
    cmd.add(model_revision);
    cmd.add(quantize);
    cmd.add(bladedisc);
    cmd.add(vad_dir);
    cmd.add(vad_revision);
    cmd.add(vad_quant);
@ -151,11 +158,14 @@ int main(int argc, char* argv[]) {
    cmd.add(io_thread_num);
    cmd.add(decoder_thread_num);
    cmd.add(model_thread_num);
    cmd.add(use_gpu);
    cmd.add(batch_size);
    cmd.parse(argc, argv);

    std::map<std::string, std::string> model_path;
    GetValue(model_dir, MODEL_DIR, model_path);
    GetValue(quantize, QUANTIZE, model_path);
    GetValue(bladedisc, BLADEDISC, model_path);
    GetValue(vad_dir, VAD_DIR, model_path);
    GetValue(vad_quant, VAD_QUANT, model_path);
    GetValue(punc_dir, PUNC_DIR, model_path);
@ -173,6 +183,8 @@ int main(int argc, char* argv[]) {
    global_beam_ = global_beam.getValue();
    lattice_beam_ = lattice_beam.getValue();
    am_scale_ = am_scale.getValue();
    bool use_gpu_ = use_gpu.getValue();
    int batch_size_ = batch_size.getValue();

    // Download model from Modelscope
    try{
@ -468,7 +480,7 @@ int main(int argc, char* argv[]) {
    WebSocketServer websocket_srv(
        io_decoder, is_ssl, server, wss_server, s_certfile,
        s_keyfile); // websocket server for asr engine
    websocket_srv.initAsr(model_path, s_model_thread_num); // init asr model
    websocket_srv.initAsr(model_path, s_model_thread_num, use_gpu_, batch_size_); // init asr model

    LOG(INFO) << "decoder-thread-num: " << s_decoder_thread_num;
    LOG(INFO) << "io-thread-num: " << s_io_thread_num;
@ -402,11 +402,11 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,

// init asr model
void WebSocketServer::initAsr(std::map<std::string, std::string>& model_path,
                              int thread_num) {
                              int thread_num, bool use_gpu, int batch_size) {
    try {
        // init model with api

        asr_handle = FunOfflineInit(model_path, thread_num);
        asr_handle = FunOfflineInit(model_path, thread_num, use_gpu, batch_size);
        LOG(INFO) << "model successfully inited";

        LOG(INFO) << "initAsr run check_and_clean_connection";

@ -124,7 +124,7 @@ class WebSocketServer {
        std::string wav_format,
        FUNASR_DEC_HANDLE& decoder_handle);

    void initAsr(std::map<std::string, std::string>& model_path, int thread_num);
    void initAsr(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu=false, int batch_size=1);
    void on_message(websocketpp::connection_hdl hdl, message_ptr msg);
    void on_open(websocketpp::connection_hdl hdl);
    void on_close(websocketpp::connection_hdl hdl);