Merge branch 'dev_gzf_deepspeed' of github.com:alibaba-damo-academy/FunASR into dev_gzf_deepspeed

merge
游雁 2024-06-06 09:56:55 +08:00
commit 783a051f65
69 changed files with 4092 additions and 146 deletions

View File

@ -2,8 +2,9 @@
([简体中文](./README_zh.md)|English)
# FunASR: A Fundamental End-to-End Speech Recognition Toolkit
[//]: # (# FunASR: A Fundamental End-to-End Speech Recognition Toolkit)
[![SVG Banners](https://svg-banners.vercel.app/api?type=origin&text1=FunASR🤠&text2=💖%20A%20Fundamental%20End-to-End%20Speech%20Recognition%20Toolkit&width=800&height=210)](https://github.com/Akshay090/svg-banners)
[![PyPI](https://img.shields.io/pypi/v/funasr)](https://pypi.org/project/funasr/)
@ -34,6 +35,9 @@
- 2024/03/05: Added support for the Whisper-large-v3 model, a multitasking model that can perform multilingual speech recognition, speech translation, and language identification. It can be downloaded from [modelscope](examples/industrial_data_pretraining/whisper/demo.py) or [openai](examples/industrial_data_pretraining/whisper/demo_from_openai.py).
- 2024/03/05: Offline File Transcription Service 4.4, Offline File Transcription Service of English 1.5, and Real-time Transcription Service 1.9 released; the docker image now supports the ARM64 platform and modelscope has been updated ([docs](runtime/readme.md))
- 2024/01/30: funasr-1.0 has been released ([docs](https://github.com/alibaba-damo-academy/FunASR/discussions/1319))
<details><summary>Full Changelog</summary>
- 2024/01/30: Emotion recognition models are now supported. [model link](https://www.modelscope.cn/models/iic/emotion2vec_base_finetuned/summary), modified from [repo](https://github.com/ddlBoJack/emotion2vec).
- 2024/01/25: Offline File Transcription Service 4.2 and Offline File Transcription Service of English 1.3 released; optimized the VAD (Voice Activity Detection) data processing, significantly reducing peak memory usage, and fixed a memory leak. Real-time Transcription Service 1.7 released with an optimized client ([docs](runtime/readme.md))
- 2024/01/09: The FunASR SDK for Windows version 2.0 has been released, featuring support for the offline file transcription service (CPU) of Mandarin 4.1, the offline file transcription service (CPU) of English 1.2, and the real-time transcription service (CPU) of Mandarin 1.6. For more details, please refer to the official documentation or release notes ([FunASR-Runtime-Windows](https://www.modelscope.cn/models/damo/funasr-runtime-win-cpu-x64/summary))
@ -51,22 +55,31 @@
- 2023/07/17: BAT is released, which is a low-latency and low-memory-consumption RNN-T model. For more details, please refer to ([BAT](egs/aishell/bat)).
- 2023/06/26: ASRU2023 Multi-Channel Multi-Party Meeting Transcription Challenge 2.0 completed the competition and announced the results. For more details, please refer to ([M2MeT2.0](https://alibaba-damo-academy.github.io/FunASR/m2met2/index.html)).
</details>
<a name="Installation"></a>
## Installation
- Requirements
```text
python>=3.8
torch>=1.13
torchaudio
```
- Install from pip
```shell
pip3 install -U funasr
```
Or install from source code
- Or install from source code
``` sh
git clone https://github.com/alibaba/FunASR.git && cd FunASR
pip3 install -e ./
```
Install modelscope for the pretrained models (Optional)
- Install modelscope or huggingface_hub for the pretrained models (Optional)
```shell
pip3 install -U modelscope
pip3 install -U modelscope huggingface_hub
```
## Model Zoo
@ -77,19 +90,19 @@ FunASR has open-sourced a large number of pre-trained models on industrial data.
| Model Name | Task Details | Training Data | Parameters |
|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:-----------------------------------------------------:|:--------------------------------:|:----------:|
| paraformer-zh <br> ([⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) [🤗](https://huggingface.co/funasr/paraformer-tp) ) | speech recognition, with timestamps, non-streaming | 60000 hours, Mandarin | 220M |
| paraformer-zh <br> ([⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) [🤗](https://huggingface.co/funasr/paraformer-zh) ) | speech recognition, with timestamps, non-streaming | 60000 hours, Mandarin | 220M |
| <nobr>paraformer-zh-streaming <br> ( [⭐](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/summary) [🤗](https://huggingface.co/funasr/paraformer-zh-streaming) )</nobr> | speech recognition, streaming | 60000 hours, Mandarin | 220M |
| paraformer-en <br> ( [⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/summary) [🤗](https://huggingface.co/funasr/paraformer-en) ) | speech recognition, without timestamps, non-streaming | 50000 hours, English | 220M |
| conformer-en <br> ( [⭐](https://modelscope.cn/models/damo/speech_conformer_asr-en-16k-vocab4199-pytorch/summary) [🤗](https://huggingface.co/funasr/conformer-en) ) | speech recognition, non-streaming | 50000 hours, English | 220M |
| ct-punc <br> ( [⭐](https://modelscope.cn/models/damo/punc_ct-transformer_cn-en-common-vocab471067-large/summary) [🤗](https://huggingface.co/funasr/ct-punc) ) | punctuation restoration | 100M, Mandarin and English | 1.1G |
| ct-punc <br> ( [⭐](https://modelscope.cn/models/damo/punc_ct-transformer_cn-en-common-vocab471067-large/summary) [🤗](https://huggingface.co/funasr/ct-punc) ) | punctuation restoration | 100M, Mandarin and English | 290M |
| fsmn-vad <br> ( [⭐](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) [🤗](https://huggingface.co/funasr/fsmn-vad) ) | voice activity detection | 5000 hours, Mandarin and English | 0.4M |
| fa-zh <br> ( [⭐](https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary) [🤗](https://huggingface.co/funasr/fa-zh) ) | timestamp prediction | 5000 hours, Mandarin | 38M |
| cam++ <br> ( [⭐](https://modelscope.cn/models/iic/speech_campplus_sv_zh-cn_16k-common/summary) [🤗](https://huggingface.co/funasr/campplus) ) | speaker verification/diarization | 5000 hours | 7.2M |
| Whisper-large-v2 <br> ([⭐](https://www.modelscope.cn/models/iic/speech_whisper-large_asr_multilingual/summary) [🍀](https://github.com/openai/whisper) ) | speech recognition, with timestamps, non-streaming | multilingual | 1550 M |
| Whisper-large-v3 <br> ([⭐](https://www.modelscope.cn/models/iic/Whisper-large-v3/summary) [🍀](https://github.com/openai/whisper) ) | speech recognition, with timestamps, non-streaming | multilingual | 1550 M |
| Qwen-Audio <br> ([⭐](examples/industrial_data_pretraining/qwen_audio/demo.py) [🤗](https://huggingface.co/Qwen/Qwen-Audio) ) | audio-text multimodal models (pretraining) | multilingual | 8B |
| Qwen-Audio-Chat <br> ([⭐](examples/industrial_data_pretraining/qwen_audio/demo_chat.py) [🤗](https://huggingface.co/Qwen/Qwen-Audio-Chat) ) | audio-text multimodal models (chat) | multilingual | 8B |
| emotion2vec+large <br> ([⭐](https://modelscope.cn/models/iic/emotion2vec_plus_large/summary) [🤗](https://huggingface.co/emotion2vec/emotion2vec_plus_large) ) | speech emotion recognition | 40000 hours | 300M |
@ -153,6 +166,8 @@ for i in range(total_chunk_num):
```
Note: `chunk_size` is the configuration for streaming latency. `[0,10,5]` indicates that the real-time display granularity is `10*60=600ms`, and the lookahead information is `5*60=300ms`. Each inference input is `600ms` (`16000*0.6=9600` sample points), and the output is the corresponding text. For the last speech segment, `is_final=True` must be set to force the output of the last word.
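The arithmetic above can be sanity-checked with a few lines; this is an illustrative sketch only, assuming 16 kHz mono audio and 60 ms frames as in the streaming example above.
```python
# Illustrative sketch: how chunk_size maps to latency and raw sample counts.
# Assumes 16 kHz mono audio and 60 ms frames, matching the streaming example above.
chunk_size = [0, 10, 5]   # [padding, display granularity, lookahead] in 60 ms frames
frame_ms = 60
sample_rate = 16000

display_ms = chunk_size[1] * frame_ms            # 10 * 60 = 600 ms per inference step
lookahead_ms = chunk_size[2] * frame_ms          # 5 * 60 = 300 ms of future context
chunk_stride = sample_rate * display_ms // 1000  # 9600 samples fed per step

print(display_ms, lookahead_ms, chunk_stride)    # 600 300 9600
```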
<details><summary>More Examples</summary>
### Voice Activity Detection (Non-Streaming)
```python
from funasr import AutoModel
@ -211,9 +226,24 @@ text_file = f"{model.model_path}/example/text.txt"
res = model.generate(input=(wav_file, text_file), data_type=("sound", "text"))
print(res)
```
### Speech Emotion Recognition
```python
from funasr import AutoModel
model = AutoModel(model="emotion2vec_plus_large")
wav_file = f"{model.model_path}/example/test.wav"
res = model.generate(wav_file, output_dir="./outputs", granularity="utterance", extract_embedding=False)
print(res)
```
For more usage, refer to the [docs](docs/tutorial/README_zh.md); for more examples, refer to the [demos](https://github.com/alibaba-damo-academy/FunASR/tree/main/examples/industrial_data_pretraining)
</details>
## Export ONNX

View File

@ -2,7 +2,11 @@
(Simplified Chinese | [English](./README.md))
# FunASR: A Fundamental End-to-End Speech Recognition Toolkit
[![SVG Banners](https://svg-banners.vercel.app/api?type=origin&text1=FunASR🤠&text2=💖%20A%20Fundamental%20End-to-End%20Speech%20Recognition%20Toolkit&width=800&height=210)](https://github.com/Akshay090/svg-banners)
[//]: # (# FunASR: A Fundamental End-to-End Speech Recognition Toolkit)
[![PyPI](https://img.shields.io/pypi/v/funasr)](https://pypi.org/project/funasr/)
@ -35,6 +39,9 @@ FunASR aims to build a bridge between academic research and industrial applications of speech recognition
- 2024/03/05: Added the Whisper-large-v3 model, supporting multilingual speech recognition, speech translation, and language identification; the model can be downloaded from the [modelscope](examples/industrial_data_pretraining/whisper/demo.py) repo or the [openai](examples/industrial_data_pretraining/whisper/demo_from_openai.py) repo.
- 2024/03/05: Chinese offline file transcription service 4.4, English offline file transcription service 1.5, and Chinese real-time transcription service 1.9 released; the docker image now supports the arm64 platform and the modelscope version has been upgraded. For details, see the ([deployment docs](runtime/readme_cn.md))
- 2024/01/30: funasr-1.0 released; see the release notes ([docs](https://github.com/alibaba-damo-academy/FunASR/discussions/1319))
<details><summary>Full Changelog</summary>
- 2024/01/30: Added emotion recognition models. [model link](https://www.modelscope.cn/models/iic/emotion2vec_base_finetuned/summary), modified from the original [repo](https://github.com/ddlBoJack/emotion2vec).
- 2024/01/25: Chinese offline file transcription service 4.2 and English offline file transcription service 1.3 released; optimized the VAD data processing, greatly reducing peak memory usage, and fixed a memory leak. Chinese real-time transcription service 1.7 released with client-side optimizations. For details, see the ([deployment docs](runtime/readme_cn.md))
- 2024/01/09: The FunASR runtime package for Windows 2.0 released, supporting the latest Chinese offline file transcription 4.1, English offline file transcription 1.2, and Chinese real-time transcription 1.6 services. For details, see ([FunASR runtime package for Windows](https://www.modelscope.cn/models/damo/funasr-runtime-win-cpu-x64/summary))
@ -52,21 +59,33 @@ FunASR aims to build a bridge between academic research and industrial applications of speech recognition
- 2023.07.17: BAT, a low-latency, low-memory-consumption RNN-T model, released. For details, see ([BAT](egs/aishell/bat))
- 2023.06.26: The ASRU2023 Multi-Channel Multi-Party Meeting Transcription Challenge 2.0 has concluded and the results have been announced. For details, see ([M2MeT2.0](https://alibaba-damo-academy.github.io/FunASR/m2met2_cn/index.html))
</details>
<a name="安装教程"></a>
## Installation
- Before installing funasr, make sure the following dependencies are installed:
```text
python>=3.8
torch>=1.13
torchaudio
```
- Install from pip
```shell
pip3 install -U funasr
```
Or install from source code
- Or install from source code
``` sh
git clone https://github.com/alibaba/FunASR.git && cd FunASR
pip3 install -e ./
```
To use the industrial pretrained models, install modelscope (optional)
To use the industrial pretrained models, install modelscope and huggingface_hub (optional)
```shell
pip3 install -U modelscope
pip3 install -U modelscope huggingface huggingface_hub
```
## Model Zoo
@ -78,11 +97,11 @@ FunASR has open-sourced a large number of models pretrained on industrial data; you can use them under the [Model Lic
| Model Name | Task Details | Training Data | Parameters |
|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------:|:--------------:|:------:|
| paraformer-zh <br> ([⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) [🤗](https://huggingface.co/funasr/paraformer-tp) ) | speech recognition, with timestamps, non-streaming | 60000 hours, Mandarin | 220M |
| paraformer-zh <br> ([⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) [🤗](https://huggingface.co/funasr/paraformer-zh) ) | speech recognition, with timestamps, non-streaming | 60000 hours, Mandarin | 220M |
| paraformer-zh-streaming <br> ( [⭐](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/summary) [🤗](https://huggingface.co/funasr/paraformer-zh-streaming) ) | speech recognition, streaming | 60000 hours, Mandarin | 220M |
| paraformer-en <br> ( [⭐](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020/summary) [🤗](https://huggingface.co/funasr/paraformer-en) ) | speech recognition, non-streaming | 50000 hours, English | 220M |
| conformer-en <br> ( [⭐](https://modelscope.cn/models/damo/speech_conformer_asr-en-16k-vocab4199-pytorch/summary) [🤗](https://huggingface.co/funasr/conformer-en) ) | speech recognition, non-streaming | 50000 hours, English | 220M |
| ct-punc <br> ( [⭐](https://modelscope.cn/models/damo/punc_ct-transformer_cn-en-common-vocab471067-large/summary) [🤗](https://huggingface.co/funasr/ct-punc) ) | punctuation restoration | 100M, Mandarin and English | 1.1B |
| ct-punc <br> ( [⭐](https://modelscope.cn/models/damo/punc_ct-transformer_cn-en-common-vocab471067-large/summary) [🤗](https://huggingface.co/funasr/ct-punc) ) | punctuation restoration | 100M, Mandarin and English | 290M |
| fsmn-vad <br> ( [⭐](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) [🤗](https://huggingface.co/funasr/fsmn-vad) ) | voice activity detection, real-time | 5000 hours, Mandarin and English | 0.4M |
| fa-zh <br> ( [⭐](https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary) [🤗](https://huggingface.co/funasr/fa-zh) ) | character-level timestamp prediction | 50000 hours, Mandarin | 38M |
| cam++ <br> ( [⭐](https://modelscope.cn/models/iic/speech_campplus_sv_zh-cn_16k-common/summary) [🤗](https://huggingface.co/funasr/campplus) ) | speaker verification/diarization | 5000 hours | 7.2M |
@ -148,6 +167,8 @@ for i in range(total_chunk_num):
Note: `chunk_size` is the streaming latency configuration. `[0,10,5]` means the real-time display granularity is `10*60=600ms` and the lookahead information is `5*60=300ms`. Each inference takes `600ms` of input (`16000*0.6=9600` sample points) and outputs the corresponding text. For the last speech segment, `is_final=True` must be set to force the output of the last word.
<details><summary>More Examples</summary>
### Voice Activity Detection (Non-Streaming)
```python
from funasr import AutoModel
@ -211,9 +232,24 @@ text_file = f"{model.model_path}/example/text.txt"
res = model.generate(input=(wav_file, text_file), data_type=("sound", "text"))
print(res)
```
### Speech Emotion Recognition
```python
from funasr import AutoModel
model = AutoModel(model="emotion2vec_plus_large")
wav_file = f"{model.model_path}/example/test.wav"
res = model.generate(wav_file, output_dir="./outputs", granularity="utterance", extract_embedding=False)
print(res)
```
For more details, see the [tutorial docs](docs/tutorial/README_zh.md);
for more examples, see the [model demos](https://github.com/alibaba-damo-academy/FunASR/tree/main/examples/industrial_data_pretraining)
</details>
## Export ONNX
### Export from the command line
```shell

Binary file not shown (image updated: 181 KiB → 183 KiB)

View File

@ -52,7 +52,7 @@ class AutoFrontend:
key_list, data_list = prepare_data_iterator(input, input_len=input_len)
batch_size = kwargs.get("batch_size", 1)
device = kwargs.get("device", "cpu")
device = kwargs.get("device", "cuda")
if device == "cpu":
batch_size = 1
@ -60,7 +60,7 @@ class AutoFrontend:
result_list = []
num_samples = len(data_list)
pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True)
# pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True)
time0 = time.perf_counter()
for beg_idx in range(0, num_samples, batch_size):
@ -87,15 +87,23 @@ class AutoFrontend:
speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
)
speech.to(device=device), speech_lengths.to(device=device)
batch = {"input": speech, "input_len": speech_lengths, "key": key_batch}
if kwargs.get("return_pt", True):
speech, speech_lengths = speech.to(device=device), speech_lengths.to(device=device)
else:
speech, speech_lengths = speech.numpy(), speech_lengths.numpy()
batch = {
"input": speech,
"input_len": speech_lengths,
"key": key_batch,
"data_type": "fbank",
}
result_list.append(batch)
pbar.update(1)
description = f"{meta_data}, "
pbar.set_description(description)
# pbar.update(1)
# description = f"{meta_data}, "
# pbar.set_description(description)
time_end = time.perf_counter()
pbar.set_description(f"time escaped total: {time_end - time0:0.3f}")
# pbar.set_description(f"time escaped total: {time_end - time0:0.3f}")
return result_list

View File

@ -42,8 +42,9 @@ def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
chars = string.ascii_letters + string.digits
if isinstance(data_in, str) and data_in.startswith("http"): # url
data_in = download_from_url(data_in)
if isinstance(data_in, str):
if data_in.startswith("http://") or data_in.startswith("https://"): # url
data_in = download_from_url(data_in)
if isinstance(data_in, str) and os.path.exists(
data_in
@ -284,7 +285,7 @@ class AutoModel:
with torch.no_grad():
res = model.inference(**batch, **kwargs)
if isinstance(res, (list, tuple)):
results = res[0]
results = res[0] if len(res) > 0 else [{"text": ""}]
meta_data = res[1] if len(res) > 1 else {}
time2 = time.perf_counter()
@ -358,6 +359,7 @@ class AutoModel:
results_sorted = []
if not len(sorted_data):
results_ret_list.append({"key": key, "text": "", "timestamp": []})
logging.info("decoding, utt: {}, empty speech".format(key))
continue
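A minimal usage sketch of the behavior changed in this file (model name and URL are illustrative placeholders, not taken from the diff): `http://` and `https://` inputs are downloaded before decoding, and an empty inference result now falls back to `[{"text": ""}]` instead of raising, while empty speech segments are logged and skipped.
```python
# Minimal sketch; "paraformer-zh" and the URL are placeholders.
from funasr import AutoModel

model = AutoModel(model="paraformer-zh")
res = model.generate(input="https://example.com/asr_example.wav")  # URL is downloaded first
print(res)  # an empty decode yields [{"text": ""}] rather than an exception
```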

View File

@ -147,7 +147,9 @@ class EspnetStyleBatchSampler(DistributedSampler):
start_idx = self.rank * batches_per_rank
end_idx = start_idx + batches_per_rank
rank_batches = buffer_batches[start_idx + self.start_step : end_idx]
self.batch_num = len(rank_batches)
logging.info(
f"rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num: {end_idx-start_idx}, batch_num_after_step: {len(rank_batches)}"
)

View File

@ -12,10 +12,30 @@ name_maps_ms = {
"Whisper-large-v2": "iic/speech_whisper-large_asr_multilingual",
"Whisper-large-v3": "iic/Whisper-large-v3",
"Qwen-Audio": "Qwen/Qwen-Audio",
"emotion2vec_plus_large": "iic/emotion2vec_plus_large",
"emotion2vec_plus_base": "iic/emotion2vec_plus_base",
"emotion2vec_plus_seed": "iic/emotion2vec_plus_seed",
}
name_maps_hf = {
"": "",
"paraformer": "funasr/paraformer-zh",
"paraformer-zh": "funasr/paraformer-zh",
"paraformer-en": "funasr/paraformer-zh",
"paraformer-zh-streaming": "funasr/paraformer-zh-streaming",
"fsmn-vad": "funasr/fsmn-vad",
"ct-punc": "funasr/ct-punc",
"ct-punc-c": "iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
"fa-zh": "funasr/fa-zh",
"cam++": "funasr/campplus",
"Whisper-large-v2": "iic/speech_whisper-large_asr_multilingual",
"Whisper-large-v3": "iic/Whisper-large-v3",
"Qwen-Audio": "Qwen/Qwen-Audio",
"emotion2vec_plus_large": "emotion2vec/emotion2vec_plus_large",
"iic/emotion2vec_plus_large": "emotion2vec/emotion2vec_plus_large",
"emotion2vec_plus_base": "emotion2vec/emotion2vec_plus_base",
"iic/emotion2vec_plus_base": "emotion2vec/emotion2vec_plus_base",
"emotion2vec_plus_seed": "emotion2vec/emotion2vec_plus_seed",
"iic/emotion2vec_plus_seed": "emotion2vec/emotion2vec_plus_seed",
}
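A hypothetical helper (not part of FunASR) showing how these alias tables could be consulted, falling back to the raw name when no alias exists:
```python
def resolve_model_id(name: str, hub: str = "ms") -> str:
    """Map a short model alias to a hub repo id; hypothetical illustration."""
    table = name_maps_hf if hub == "hf" else name_maps_ms
    return table.get(name, name)

# resolve_model_id("emotion2vec_plus_large", hub="hf") -> "emotion2vec/emotion2vec_plus_large"
```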
name_maps_openai = {

View File

@ -1,5 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.register import tables
@ -63,3 +65,64 @@ class EncoderProjectorQFormer(nn.Module):
query_proj = self.norm(self.linear(query_output.last_hidden_state))
return query_proj
@tables.register("adaptor_classes", "Transformer")
class Transformer(nn.Module):
def __init__(
self, downsample_rate=2, encoder_dim=1280, llm_dim=4096, ffn_dim: int = 2048, **kwargs
):
super().__init__()
self.k = downsample_rate
self.encoder_dim = encoder_dim
self.llm_dim = llm_dim
self.linear1 = nn.Linear(self.encoder_dim * self.k, ffn_dim)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(ffn_dim, self.llm_dim)
from funasr.models.transformer.encoder import EncoderLayer
from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
self.blocks = nn.ModuleList(
[
EncoderLayer(
llm_dim,
MultiHeadedAttention(
kwargs.get("attention_heads", 8),
llm_dim,
kwargs.get("attention_dropout_rate", 0.0),
),
PositionwiseFeedForward(
llm_dim,
llm_dim // 4,
kwargs.get("dropout_rate", 0.0),
),
kwargs.get("dropout_rate", 0.0),
)
for i in range(kwargs.get("n_layer", 2))
]
)
def forward(self, x, ilens=None):
batch_size, seq_len, dim = x.size()
# num_frames_to_discard = seq_len % self.k
chunk_num = (seq_len - 1) // self.k + 1
pad_num = chunk_num * self.k - seq_len
x = F.pad(x, (0, 0, 0, pad_num, 0, 0), value=0.0)
# if num_frames_to_discard > 0:
# x = x[:, :-num_frames_to_discard, :]
seq_len = x.size(1)
x = x.contiguous()
x = x.view(batch_size, chunk_num, dim * self.k)
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
olens = None
olens = (ilens - 1) // self.k + 1
masks = (~make_pad_mask(olens)[:, None, :]).to(x.device)
for layer, block in enumerate(self.blocks):
x, masks = block(x, masks)
return x, olens
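A shape-check sketch for the adaptor above (illustrative only; assumes torch and funasr are importable, sizes are arbitrary): k consecutive encoder frames are stacked, so T frames of width encoder_dim become ceil(T/k) frames of width llm_dim.
```python
import torch

# Arbitrary sizes for illustration only.
adaptor = Transformer(downsample_rate=2, encoder_dim=1280, llm_dim=4096)
x = torch.randn(1, 99, 1280)        # (batch, frames, encoder_dim)
ilens = torch.tensor([99])
y, olens = adaptor(x, ilens)
print(y.shape, olens)               # torch.Size([1, 50, 4096]) tensor([50])
```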

View File

@ -163,7 +163,11 @@ def export_backbone_forward(
dha_ids = dha_pred.max(-1)[-1]
dha_mask = (dha_ids == self.NOBIAS).int().unsqueeze(-1)
decoder_out = decoder_out * dha_mask + dha_pred * (1 - dha_mask)
return decoder_out, pre_token_length, alphas
# get predicted timestamps
us_alphas, us_cif_peak = self.predictor.get_upsample_timestmap(enc, mask, pre_token_length)
return decoder_out, pre_token_length, us_alphas, us_cif_peak
def export_backbone_dummy_inputs(self):
@ -178,7 +182,7 @@ def export_backbone_input_names(self):
def export_backbone_output_names(self):
return ["logits", "token_num", "alphas"]
return ["logits", "token_num", "us_alphas", "us_cif_peak"]
def export_backbone_dynamic_axes(self):
@ -190,6 +194,8 @@ def export_backbone_dynamic_axes(self):
"bias_embed": {0: "batch_size", 1: "num_hotwords"},
"logits": {0: "batch_size", 1: "logits_length"},
"pre_acoustic_embeds": {1: "feats_length1"},
"us_alphas": {0: "batch_size", 1: "alphas_length"},
"us_cif_peak": {0: "batch_size", 1: "alphas_length"},
}
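One way to verify the extra timestamp outputs after export is to inspect the ONNX graph; a hedged sketch (the model path is a placeholder and onnxruntime is assumed to be installed):
```python
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx")   # placeholder path to the exported backbone
print([o.name for o in sess.get_outputs()])
# expected with this change: ['logits', 'token_num', 'us_alphas', 'us_cif_peak']
```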

View File

@ -360,6 +360,7 @@ class SenseVoiceDecoder(nn.Module):
"""Score."""
ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
logp = torch.log_softmax(logp, dim=-1)
return logp.squeeze(0)[-1, :], state

View File

@ -1264,15 +1264,29 @@ class SenseVoiceSANM(nn.Module):
if isinstance(task, str):
task = [task]
task = "".join([f"<|{x}|>" for x in task])
initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
sos = kwargs.get("model_conf").get("sos")
if isinstance(sos, str):
initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
language = DecodingOptions.get("language", None)
language = None if language == "auto" else language
language = DecodingOptions.get("language", None)
language = None if language == "auto" else language
sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
sos_int = tokenizer.encode(sos, allowed_special="all")
sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
sos_int = tokenizer.encode(sos, allowed_special="all")
else:
language = DecodingOptions.get("language", None)
language = None if language == "auto" else language
initial_prompt = kwargs.get("initial_prompt", f"{task}")
initial_prompt_lid = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
initial_prompt_lid_int = tokenizer.encode(initial_prompt_lid, allowed_special="all")
sos_int = [sos] + initial_prompt_lid_int
eos = kwargs.get("model_conf").get("eos")
eos_int = tokenizer.encode(eos, allowed_special="all")
if isinstance(eos, str):
eos_int = tokenizer.encode(eos, allowed_special="all")
else:
eos_int = [eos]
self.beam_search.sos = sos_int
self.beam_search.eos = eos_int[0]
@ -1298,7 +1312,7 @@ class SenseVoiceSANM(nn.Module):
self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])
encoder_out, encoder_out_lens = self.encode(
speech[None, :, :].permute(0, 2, 1), speech_lengths
speech[None, :, :], speech_lengths
)
if text_token_int is not None:

View File

@ -27,9 +27,24 @@ class ModelDimensions:
n_text_layer: int
# class LayerNorm(nn.LayerNorm):
# def forward(self, x: Tensor) -> Tensor:
# return super().forward(x.float()).type(x.dtype)
class LayerNorm(nn.LayerNorm):
def forward(self, x: Tensor) -> Tensor:
return super().forward(x.float()).type(x.dtype)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, input):
output = F.layer_norm(
input.float(),
self.normalized_shape,
self.weight.float() if self.weight is not None else None,
self.bias.float() if self.bias is not None else None,
self.eps,
)
return output.type_as(input)
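A quick check of the float32 LayerNorm above (illustrative; assumes torch is installed): the normalization runs in float32 even for half-precision inputs, and the result is cast back to the input dtype.
```python
import torch

ln = LayerNorm(8).half()                      # half-precision weights
x = torch.randn(2, 8, dtype=torch.float16)
y = ln(x)
print(y.dtype)                                # torch.float16; math is done in float32
```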
class Linear(nn.Linear):

View File

@ -64,7 +64,7 @@ class EncoderLayer(nn.Module):
stochastic_depth_rate=0.0,
):
"""Construct an EncoderLayer object."""
super(EncoderLayer, self).__init__()
super().__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(size)

View File

@ -621,7 +621,6 @@ class Trainer:
self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
def forward_step(self, model, batch, loss_dict={}):
dtype = torch.bfloat16
with maybe_autocast(dtype=self.dtype, use_deepspeed=self.use_deepspeed):
retval = model(**batch)

View File

@ -0,0 +1,72 @@
# python funasr_api
This is the Python API for the FunASR engine; only the 2pass server is supported.
## Install
### Install websocket-client and ffmpeg
```shell
pip install websocket-client
apt install ffmpeg -y
```
#### Recognizer examples
Many audio types are supported (anything ffmpeg can decode); for details see FunASR/runtime/funasr_api/example.py
```python
# create a recognizer
rcg = FunasrApi(
uri="wss://www.funasr.com:10096/"
)
# recognizer by filepath
text=rcg.rec_file("asr_example.mp3")
print("recognizer by filepath result=",text)
# recognizer by buffer
# rec_buf(audio_buf,ffmpeg_decode=False),set ffmpeg_decode=True if audio is not PCM or WAV type
with open("asr_example.wav", "rb") as f:
audio_bytes = f.read()
text=rcg.rec_buf(audio_bytes)
print("recognizer by buffer result=",text)
```
#### Streaming recognizer example; use FunasrApi.audio2wav to convert to WAV first if needed
```python
rcg = FunasrApi(
uri="wss://www.funasr.com:10096/"
)
# define a callback function for received messages
def on_msg(msg):
print("stream msg=",msg)
stream=rcg.create_stream(msg_callback=on_msg)
wav_path = "asr_example.wav"
with open(wav_path, "rb") as f:
audio_bytes = f.read()
# use FunasrApi's audio2wav to convert other audio formats to WAV if needed
#import os
#from funasr_tools import FunasrTools
#file_ext=os.path.splitext(wav_path)[-1].upper()
#if not file_ext =="PCM" and not file_ext =="WAV":
# audio_bytes=FunasrTools.audio2wav(audio_bytes)
stride = int(60 * 10 / 10 / 1000 * 16000 * 2)
chunk_num = (len(audio_bytes) - 1) // stride + 1
for i in range(chunk_num):
beg = i * stride
data = audio_bytes[beg : beg + stride]
stream.feed_chunk(data)
final_result=stream.wait_for_end()
print("asr_example.wav stream_result=",final_result)
```
## Acknowledge
1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
2. We acknowledge [zhaoming](https://github.com/zhaomingwork/FunASR/tree/fix_bug_for_python_websocket) for contributing the websocket service.
3. We acknowledge [cgisky1980](https://github.com/cgisky1980/FunASR) for contributing the websocket service of offline model.

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,70 @@
"""
Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
Reserved. MIT License (https://opensource.org/licenses/MIT)
2023-2024 by zhaomingwork@qq.com
"""
from funasr_api import FunasrApi
import wave
def recognizer_example():
# create a recognizer
rcg = FunasrApi(
uri="wss://www.funasr.com:10096/"
)
# recognizer by filepath
text=rcg.rec_file("asr_example.mp3")
print("recognizer by filepath result=",text)
# recognizer by buffer
# rec_buf(audio_buf,ffmpeg_decode=False),set ffmpeg_decode=True if audio is not PCM or WAV type
with open("asr_example.wav", "rb") as f:
audio_bytes = f.read()
text=rcg.rec_buf(audio_bytes)
print("recognizer by buffer result=",text)
def recognizer_stream_example():
rcg = FunasrApi(
uri="wss://www.funasr.com:10096/"
)
# define a callback function for received messages
def on_msg(msg):
print("stream msg=",msg)
stream=rcg.create_stream(msg_callback=on_msg)
wav_path = "asr_example.wav"
with open(wav_path, "rb") as f:
audio_bytes = f.read()
# use FunasrApi's audio2wav to convert other audio formats to WAV if needed
#import os
#from funasr_tools import FunasrTools
#file_ext=os.path.splitext(wav_path)[-1].upper()
#if not file_ext =="PCM" and not file_ext =="WAV":
# audio_bytes=FunasrTools.audio2wav(audio_bytes)
stride = int(60 * 10 / 10 / 1000 * 16000 * 2)
chunk_num = (len(audio_bytes) - 1) // stride + 1
for i in range(chunk_num):
beg = i * stride
data = audio_bytes[beg : beg + stride]
stream.feed_chunk(data)
final_result=stream.wait_for_end()
print("asr_example.wav stream_result=",final_result)
if __name__ == "__main__":
print("example for Funasr_websocket_recognizer")
recognizer_stream_example()
recognizer_example()

View File

@ -0,0 +1,96 @@
"""
Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
Reserved. MIT License (https://opensource.org/licenses/MIT)
2023-2024 by zhaomingwork@qq.com
"""
# pip install websocket-client
# apt install ffmpeg
import threading
import traceback
import json
import time
import numpy as np
from funasr_stream import FunasrStream
from funasr_tools import FunasrTools
from funasr_core import FunasrCore
# class for recognizer in websocket
class FunasrApi:
"""
python asr recognizer lib
"""
def __init__(
self,
uri="wss://www.funasr.com:10096/",
timeout=1000,
msg_callback=None,
):
"""
uri: ws or wss server uri
msg_callback: for message received
timeout: timeout for get result
"""
try:
self.uri=uri
self.timeout=timeout
self.msg_callback=msg_callback
self.funasr_core=None
except Exception as e:
print("Exception:", e)
traceback.print_exc()
def create_stream(self,msg_callback=None):
if self.funasr_core is not None:
self.funasr_core.close()
funasr_core=self.new_core(msg_callback=msg_callback)
return FunasrStream(funasr_core)
def new_core(self,msg_callback=None):
try:
if self.funasr_core is not None:
self.funasr_core.close()
if msg_callback==None:
msg_callback=self.msg_callback
funasr_core=FunasrCore(self.uri,msg_callback=msg_callback,timeout=self.timeout)
funasr_core.new_connection()
self.funasr_core=funasr_core
return funasr_core
except Exception as e:
print("init_core",e)
exit(0)
# rec buffer, set ffmpeg_decode=True if audio is not PCM or WAV type
def rec_buf(self,audio_buf,ffmpeg_decode=False):
try:
funasr_core=self.new_core()
funasr_core.rec_buf(audio_buf,ffmpeg_decode=ffmpeg_decode)
return funasr_core.get_result()
except Exception as e:
print("rec_file",e)
return
# rec file
def rec_file(self,file_path):
try:
funasr_core=self.new_core()
funasr_core.rec_file(file_path)
return funasr_core.get_result()
except Exception as e:
print("rec_file",e)
return

View File

@ -0,0 +1,230 @@
"""
Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
Reserved. MIT License (https://opensource.org/licenses/MIT)
2023-2024 by zhaomingwork@qq.com
"""
# pip install websocket-client
# apt install ffmpeg
import ssl
from websocket import ABNF
from websocket import create_connection
from queue import Queue
import threading
import traceback
import json
import time
import numpy as np
from funasr_tools import FunasrTools
# class for recognizer in websocket
class FunasrCore:
"""
python asr recognizer lib
"""
def __init__(
self,
uri="wss://www.funasr.com:10096/",
msg_callback=None,
timeout=1000,
):
"""
uri: ws or wss server uri
msg_callback: for message received
timeout: timeout for get result
"""
try:
if uri.find("wss://"):
is_ssl=True
elif uri.find("ws://"):
is_ssl=False
else:
print("not support uri",uri)
exit(0)
if is_ssl == True:
ssl_context = ssl.SSLContext()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
uri = uri
ssl_opt = {"cert_reqs": ssl.CERT_NONE}
else:
uri = uri
ssl_context = None
ssl_opt = None
self.ssl_opt=ssl_opt
self.ssl_context=ssl_context
self.uri = uri
print("connect to url", uri)
self.msg_callback=msg_callback
self.is_final=False
self.rec_text=""
self.timeout=timeout
self.rec_file_len=0
self.connect_state=0
except Exception as e:
print("Exception:", e)
traceback.print_exc()
def new_connection(self):
try:
self.websocket = create_connection(self.uri, ssl=self.ssl_context, sslopt=self.ssl_opt)
self.is_final=False
self.rec_text=""
self.rec_file_len=0
self.connect_state=0
message = json.dumps(
{
"mode": "2pass",
"chunk_size": [int(x) for x in "0,10,5".split(",")],
"encoder_chunk_look_back": 4,
"decoder_chunk_look_back": 1,
"chunk_interval": 10,
"wav_name": "funasr_api",
"is_speaking": True,
}
)
self.websocket.send(message)
self.connect_state=1
# thread for receive message
self.thread_msg = threading.Thread(
target=FunasrCore.thread_rec_msg, args=(self,)
)
self.thread_msg.start()
print("new_connection: ",message)
except Exception as e:
print("new_connection",e)
# threads for rev msg
def thread_rec_msg(self):
try:
while True:
if self.connect_state==0:
time.sleep(0.1)
continue
if self.connect_state==2:
break
msg = self.websocket.recv()
if msg is None or len(msg) == 0:
continue
msg = json.loads(msg)
if msg['is_final']==True:
self.is_final=True
if msg['mode']=='2pass-offline':
self.rec_text=self.rec_text+msg['text']
if not self.msg_callback is None:
self.msg_callback(msg)
except Exception as e:
#print("client closed")
return
# feed data to asr engine in stream way
def feed_chunk(self, chunk):
try:
self.websocket.send(chunk, ABNF.OPCODE_BINARY)
return
except:
print("feed chunk error")
return
def close(self):
self.connect_state=2
self.websocket.close()
def rec_buf(self,audio_bytes,ffmpeg_decode=False):
try:
if ffmpeg_decode:
audio_bytes=FunasrTools.audio2wav(audio_bytes)
self.rec_file_len=len(audio_bytes)
stride = int(60 * 10 / 10 / 1000 * 16000 * 2)
chunk_num = (len(audio_bytes) - 1) // stride + 1
for i in range(chunk_num):
beg = i * stride
data = audio_bytes[beg : beg + stride]
self.feed_chunk(data)
return self.get_result()
except Exception as e:
print("rec_file",e)
return
# rec file
def rec_file(self,file_path):
try:
#self.new_connection()
import os
file_ext=os.path.splitext(file_path)[-1].upper()
with open(file_path, "rb") as f:
audio_bytes = f.read()
if not file_ext =="PCM" and not file_ext =="WAV":
audio_bytes=FunasrTools.audio2wav(audio_bytes)
if audio_bytes==None:
print("error, ffmpeg can not decode such file!")
exit(0)
return self.rec_buf(audio_bytes)
except Exception as e:
print("rec_file",e)
return
def wait_for_result(self):
try:
timeout=self.timeout
file_dur=self.rec_file_len/16000/2*100
if file_dur>timeout:
timeout=file_dur
self.timeout=timeout
#print("wait_for_result timeout=",timeout)
# if file_dur==0 means in stream way and no timeout
while(self.is_final==False and (timeout>0 or file_dur==0 )):
time.sleep(0.01)
timeout=timeout-1
if timeout<=0 and not file_dur==0:
print("time out!",self.timeout)
except Exception as e:
print("wait_for_result",e)
return
def get_result(self):
try:
message = json.dumps({"is_speaking": False})
self.websocket.send(message)
self.wait_for_result()
self.close()
# return the msg
return self.rec_text
except Exception as e:
#print("get_result ",e)
return self.rec_text

View File

@ -0,0 +1,72 @@
"""
Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
Reserved. MIT License (https://opensource.org/licenses/MIT)
2023-2024 by zhaomingwork@qq.com
"""
# pip install websocket-client
# apt install ffmpeg
import threading
import traceback
import json
import time
# class for recognizer in websocket
class FunasrStream:
"""
python asr recognizer lib
"""
def __init__(
self,
funasr_core
):
"""
uri: ws or wss server uri
msg_callback: for message received
timeout: timeout for get result
"""
try:
self.funasr_core=funasr_core
except Exception as e:
print("FunasrStream init Exception:", e)
traceback.print_exc()
# feed data to asr engine in stream way
def feed_chunk(self, chunk):
try:
if self.funasr_core is None:
print("error in stream, funasr_core is None")
exit(0)
self.funasr_core.feed_chunk(chunk)
return
except:
print("feed chunk error")
return
# return all result for this stream
def wait_for_end(self):
try:
message = json.dumps({"is_speaking": False})
self.funasr_core.websocket.send(message)
self.funasr_core.wait_for_result()
self.funasr_core.close()
# return the msg
return self.funasr_core.rec_text
except Exception as e:
print("error get_final_result ",e)
return ""

View File

@ -0,0 +1,84 @@
"""
Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
Reserved. MIT License (https://opensource.org/licenses/MIT)
2023-2024 by zhaomingwork@qq.com
"""
# pip install websocket-client
# apt install ffmpeg
import threading
import traceback
import time
# class for recognizer in websocket
class FunasrTools:
"""
python asr recognizer lib
"""
def __init__(
self
):
"""
"""
try:
if FunasrTools.check_ffmpeg()==False:
print("pls instal ffmpeg firest, in ubuntu, you can type apt install -y ffmpeg")
exit(0)
except Exception as e:
print("Exception:", e)
traceback.print_exc()
# check ffmpeg installed
@staticmethod
def check_ffmpeg():
import subprocess
try:
subprocess.run(['ffmpeg', '-version'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
return True
except FileNotFoundError:
return False
# use ffmpeg to convert audio to wav
@staticmethod
def audio2wav(audiobuf):
try:
import os
import subprocess
if FunasrTools.check_ffmpeg()==False:
print("please install ffmpeg first; on Ubuntu: apt install -y ffmpeg")
exit(0)
ffmpeg_target_to_outwav = ["ffmpeg", "-i", '-', "-ac", "1", "-ar", "16000", "-f", "wav", "pipe:1"]
pipe_to = subprocess.Popen(ffmpeg_target_to_outwav,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
wavbuf, err = pipe_to.communicate(audiobuf)
if str(err).find("Error")>=0 or str(err).find("Unknown")>=0 or str(err).find("Invalid")>=0:
print("ffmpeg err",err)
return None
return wavbuf
except Exception as e:
print("audio2wav",e)
return None

runtime/http/CMakeLists.txt (new file, 142 lines)
View File

@ -0,0 +1,142 @@
cmake_minimum_required(VERSION 3.16)
project(FunASRWebscoket)
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
option(ENABLE_HTTP "Whether to build http server" ON)
option(ENABLE_PORTAUDIO "Whether to build portaudio" ON)
if(WIN32)
file(REMOVE ${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/config.h
${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/glog/export.h
${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/glog/logging.h
${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/glog/raw_logging.h
${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/glog/stl_logging.h
${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/glog/vlog_is_on.h)
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread -fPIC")
endif()
option(ENABLE_GLOG "Whether to build glog" ON)
option(ENABLE_FST "Whether to build openfst" ON) # ITN need openfst compiled
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
if(ENABLE_HTTP)
# cmake_policy(SET CMP0135 NEW)
include(FetchContent)
if(NOT EXISTS ${PROJECT_SOURCE_DIR}/third_party/asio/asio )
FetchContent_Declare(asio
URL https://github.com/chriskohlhoff/asio/archive/refs/tags/asio-1-24-0.tar.gz
SOURCE_DIR ${PROJECT_SOURCE_DIR}/third_party/asio
)
FetchContent_MakeAvailable(asio)
endif()
include_directories(${PROJECT_SOURCE_DIR}/third_party/asio/asio/include)
if(NOT EXISTS ${PROJECT_SOURCE_DIR}/third_party/json/ChangeLog.md )
FetchContent_Declare(json
URL https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.tar.gz
SOURCE_DIR ${PROJECT_SOURCE_DIR}/third_party/json
)
FetchContent_MakeAvailable(json)
endif()
include_directories(${PROJECT_SOURCE_DIR}/third_party/json/include)
endif()
if(ENABLE_PORTAUDIO)
include(FetchContent)
set(portaudio_URL "http://files.portaudio.com/archives/pa_stable_v190700_20210406.tgz")
set(portaudio_URL2 "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/pa_stable_v190700_20210406.tgz")
set(portaudio_HASH "SHA256=47efbf42c77c19a05d22e627d42873e991ec0c1357219c0d74ce6a2948cb2def")
FetchContent_Declare(portaudio
URL
${portaudio_URL}
${portaudio_URL2}
URL_HASH ${portaudio_HASH}
)
FetchContent_GetProperties(portaudio)
if(NOT portaudio_POPULATED)
message(STATUS "Downloading portaudio from ${portaudio_URL}")
FetchContent_Populate(portaudio)
endif()
message(STATUS "portaudio is downloaded to ${portaudio_SOURCE_DIR}")
message(STATUS "portaudio's binary dir is ${portaudio_BINARY_DIR}")
add_subdirectory(${portaudio_SOURCE_DIR} ${portaudio_BINARY_DIR} EXCLUDE_FROM_ALL)
if(NOT WIN32)
target_compile_options(portaudio PRIVATE "-Wno-deprecated-declarations")
else()
install(TARGETS portaudio DESTINATION ..)
endif()
endif()
# Include generated *.pb.h files
link_directories(${ONNXRUNTIME_DIR}/lib)
link_directories(${FFMPEG_DIR}/lib)
if(ENABLE_GLOG)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src)
set(BUILD_TESTING OFF)
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog glog)
include_directories(${glog_BINARY_DIR})
endif()
if(ENABLE_FST)
# fst depend on glog and gflags
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/gflags)
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/gflags gflags)
include_directories(${gflags_BINARY_DIR}/include)
# the following openfst if cloned from https://github.com/kkm000/openfst.git
# with some patch to fix the make errors.
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/openfst openfst)
include_directories(${openfst_SOURCE_DIR}/src/include)
if(WIN32)
include_directories(${openfst_SOURCE_DIR}/src/lib)
endif()
endif()
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/include/)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/src)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/yaml-cpp/include/)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi-native-fbank)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/jieba/include)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/jieba/include/limonp/include)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi)
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/yaml-cpp yaml-cpp)
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi-native-fbank/kaldi-native-fbank/csrc csrc)
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/src src)
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi kaldi)
# install openssl first apt-get install libssl-dev
find_package(OpenSSL REQUIRED)
message("CXX_FLAGS "${CMAKE_CXX_FLAGS})
#
get_property(includes DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES)
#
foreach(include ${includes})
message("Include directory: ${include}")
endforeach()
add_subdirectory(bin)

View File

@ -0,0 +1,23 @@
if(WIN32)
include_directories(${ONNXRUNTIME_DIR}/include)
include_directories(${FFMPEG_DIR}/include)
include_directories(${OPENSSL_ROOT_DIR}//include)
link_directories(${OPENSSL_ROOT_DIR}/lib)
add_definitions(-D_WEBSOCKETPP_CPP11_RANDOM_DEVICE_)
add_definitions(-D_WEBSOCKETPP_CPP11_TYPE_TRAITS_)
add_compile_options("$<$<CXX_COMPILER_ID:MSVC>:/bigobj>")
add_compile_options("$<$<CXX_COMPILER_ID:MSVC>:/utf-8>")
SET(RELATION_SOURCE "../../onnxruntime/src/resample.cpp" "../../onnxruntime/src/util.cpp" "../../onnxruntime/src/alignedmem.cpp" "../../onnxruntime/src/encode_converter.cpp")
endif()
find_package(ZLIB REQUIRED)
file(GLOB SRC_FILES "*.cpp")
add_executable(funasr-http-server ${SRC_FILES} ${RELATION_SOURCE})
target_link_libraries(funasr-http-server PUBLIC funasr ${OPENSSL_CRYPTO_LIBRARY} ${OPENSSL_SSL_LIBRARY})

View File

@ -0,0 +1,20 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
// FUNASR_MESSAGE define the needed message between funasr engine and http server
#ifndef HTTP_SERVER2_SESSIONS_HPP
#define HTTP_SERVER2_SESSIONS_HPP
#include "funasrruntime.h"
#include "nlohmann/json.hpp"
#include <atomic>
typedef struct {
nlohmann::json msg;
std::shared_ptr<std::vector<char>> samples;
std::shared_ptr<std::vector<std::vector<float>>> hotwords_embedding=nullptr;
FUNASR_DEC_HANDLE decoder_handle=nullptr;
std::atomic<int> status;
} FUNASR_MESSAGE;
#endif // HTTP_SERVER2_SESSIONS_HPP

View File

@ -0,0 +1,196 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
//
// connection.cpp
// copy some codes from http://www.boost.org/
#include "connection.hpp"
#include <thread>
#include <utility>
namespace http {
namespace server2 {
//std::ofstream fwout("out.data", std::ios::binary);
std::shared_ptr<FUNASR_MESSAGE> &connection::get_data_msg() { return data_msg; }
connection::connection(asio::ip::tcp::socket socket,
asio::io_context &io_decoder, int connection_id,
std::shared_ptr<ModelDecoder> model_decoder)
: socket_(std::move(socket)),
io_decoder(io_decoder),
connection_id(connection_id),
model_decoder(model_decoder)
{
s_timer = std::make_shared<asio::steady_timer>(io_decoder);
}
void connection::setup_timer() {
if (data_msg->status == 1) return;
s_timer->expires_after(std::chrono::seconds(3));
s_timer->async_wait([=](const asio::error_code &ec) {
if (!ec) {
std::cout << "time is out!" << std::endl;
if (data_msg->status == 1) return;
data_msg->status = 1;
s_timer->cancel();
auto wf = std::bind(&connection::write_back, std::ref(*this), "");
// close the connection
strand_->post(wf);
}
});
}
void connection::start() {
std::lock_guard<std::mutex> lock(m_lock); // for thread safety
try {
data_msg = std::make_shared<FUNASR_MESSAGE>(); // put a new data vector for
// new connection
data_msg->samples = std::make_shared<std::vector<char>>();
//data_msg->samples->reserve(16000*20);
data_msg->msg = nlohmann::json::parse("{}");
data_msg->msg["wav_format"] = "pcm";
data_msg->msg["wav_name"] = "wav-default-id";
data_msg->msg["itn"] = true;
data_msg->msg["audio_fs"] = 16000; // default is 16k
data_msg->msg["access_num"] = 0; // the number of access for this object,
// when it is 0, we can free it saftly
data_msg->msg["is_eof"] = false;
data_msg->status = 0;
strand_ = std::make_shared<asio::io_context::strand>(io_decoder);
FUNASR_DEC_HANDLE decoder_handle = FunASRWfstDecoderInit(
model_decoder->get_asr_handle(), ASR_OFFLINE, global_beam_, lattice_beam_, am_scale_);
data_msg->decoder_handle = decoder_handle;
if (data_msg->hotwords_embedding == nullptr) {
std::unordered_map<std::string, int> merged_hws_map;
std::string nn_hotwords = "";
if (true) {
std::string json_string = "{}";
if (!json_string.empty()) {
nlohmann::json json_fst_hws;
try {
json_fst_hws = nlohmann::json::parse(json_string);
if (json_fst_hws.type() == nlohmann::json::value_t::object) {
// fst
try {
std::unordered_map<std::string, int> client_hws_map =
json_fst_hws;
merged_hws_map.insert(client_hws_map.begin(),
client_hws_map.end());
} catch (const std::exception &e) {
std::cout << e.what();
}
}
} catch (std::exception const &e) {
std::cout << e.what();
// nn
std::string client_nn_hws = "{}";
nn_hotwords += " " + client_nn_hws;
std::cout << "nn hotwords: " << client_nn_hws;
}
}
}
merged_hws_map.insert(hws_map_.begin(), hws_map_.end());
// fst
std::cout << "hotwords: ";
for (const auto &pair : merged_hws_map) {
nn_hotwords += " " + pair.first;
std::cout << pair.first << " : " << pair.second;
}
FunWfstDecoderLoadHwsRes(data_msg->decoder_handle, fst_inc_wts_,
merged_hws_map);
// nn
std::vector<std::vector<float>> new_hotwords_embedding =
CompileHotwordEmbedding(model_decoder->get_asr_handle(), nn_hotwords);
data_msg->hotwords_embedding =
std::make_shared<std::vector<std::vector<float>>>(
new_hotwords_embedding);
}
file_parse = std::make_shared<http::server2::file_parser>(data_msg);
do_read();
} catch (const std::exception &e) {
std::cout << "error:" << e.what();
}
}
void connection::write_back(std::string str) {
s_timer->cancel();
std::cout << "jsonresult=" << data_msg->msg["asr_result"].dump() << std::endl;
reply_ = reply::stock_reply(
data_msg->msg["asr_result"].dump()); // reply::stock_reply();
do_write();
}
void connection::do_read() {
// status==1 means time out
if (data_msg->status == 1) return;
s_timer->cancel();
setup_timer();
auto self(shared_from_this());
socket_.async_read_some(
asio::buffer(buffer_),
[this, self](asio::error_code ec, std::size_t bytes_transferred) {
if (!ec) {
auto is = std::begin(buffer_);
auto ie = std::next(is, bytes_transferred);
http::server2::file_parser::result_type rtype =
file_parse->parse_file(is, ie);
if (rtype == http::server2::file_parser::result_type::ok) {
//fwout.write(data_msg->samples->data(),data_msg->samples->size());
//fwout.flush();
auto wf = std::bind(&connection::write_back, std::ref(*this), "aa");
auto f = std::bind(&ModelDecoder::do_decoder,
std::ref(*model_decoder), std::ref(data_msg));
// for decode task
strand_->post(f);
// for close task
strand_->post(wf);
// std::this_thread::sleep_for(std::chrono::milliseconds(1000*10));
}
do_read();
}
});
}
void connection::do_write() {
auto self(shared_from_this());
asio::async_write(socket_, reply_.to_buffers(),
[this, self](asio::error_code ec, std::size_t) {
if (!ec) {
// Initiate graceful connection closure.
asio::error_code ignored_ec;
socket_.shutdown(asio::ip::tcp::socket::shutdown_both,
ignored_ec);
}
// No new asynchronous operations are started. This means
// that all shared_ptr references to the connection object
// will disappear and the object will be destroyed
// automatically after this handler returns. The
// connection class's destructor closes the socket.
});
}
} // namespace server2
} // namespace http

View File

@ -0,0 +1,104 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
//
// copy some codes from http://www.boost.org/
//
#ifndef HTTP_SERVER2_CONNECTION_HPP
#define HTTP_SERVER2_CONNECTION_HPP
#include <array>
#include <asio.hpp>
#include <atomic>
#include <iostream>
#include <memory>
#include "reply.hpp"
#include <fstream>
#include "file_parse.hpp"
#include "model-decoder.h"
extern std::unordered_map<std::string, int> hws_map_;
extern int fst_inc_wts_;
extern float global_beam_, lattice_beam_, am_scale_;
namespace http {
namespace server2 {
/// Represents a single connection from a client.
class connection : public std::enable_shared_from_this<connection> {
public:
connection(const connection &) = delete;
connection &operator=(const connection &) = delete;
~connection() { std::cout << "one connection is close()" << std::endl; };
/// Construct a connection with the given socket.
explicit connection(asio::ip::tcp::socket socket,
asio::io_context &io_decoder, int connection_id,
std::shared_ptr<ModelDecoder> model_decoder);
/// Start the first asynchronous operation for the connection.
void start();
std::shared_ptr<FUNASR_MESSAGE> &get_data_msg();
void write_back(std::string str);
private:
/// Perform an asynchronous read operation.
void do_read();
/// Perform an asynchronous write operation.
void do_write();
void do_decoder();
void setup_timer();
/// Socket for the connection.
asio::ip::tcp::socket socket_;
/// Buffer for incoming data.
std::array<char, 8192> buffer_;
/// for time out
std::shared_ptr<asio::steady_timer> s_timer;
std::shared_ptr<ModelDecoder> model_decoder;
int connection_id = 0;
/// The reply to be sent back to the client.
reply reply_;
asio::io_context &io_decoder;
std::shared_ptr<FUNASR_MESSAGE> data_msg;
std::mutex m_lock;
std::shared_ptr<asio::io_context::strand> strand_;
std::shared_ptr<http::server2::file_parser> file_parse;
};
typedef std::shared_ptr<connection> connection_ptr;
} // namespace server2
} // namespace http
#endif // HTTP_SERVER2_CONNECTION_HPP

View File

@ -0,0 +1,29 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
#include "file_parse.hpp"
namespace http {
namespace server2 {
file_parser::file_parser(std::shared_ptr<FUNASR_MESSAGE> data_msg)
:data_msg(data_msg)
{
now_state=start;
}
} // namespace server2
} // namespace http

View File

@ -0,0 +1,234 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
// ~~~~~~~~~~~~~~~~~~
#ifndef HTTP_SERVER2_REQUEST_FILEPARSER_HPP
#define HTTP_SERVER2_REQUEST_FILEPARSER_HPP
#include <iostream>
#include <memory>
#include <tuple>
#include "asr_sessions.h"
namespace http {
namespace server2 {
/// Parser for incoming requests.
class file_parser {
public:
/// Construct ready to parse the request method.
explicit file_parser(std::shared_ptr<FUNASR_MESSAGE> data_msg);
/// Result of parse.
enum result_type { start, in_boundary, data, ok };
template <typename InputIterator>
void parse_one_line(InputIterator &is, InputIterator &ie, InputIterator &it) {
if (is != it) {
is = it;
}
if (*it == '\n') {
is = std::next(is);
}
it = std::find(is, ie, '\n');
std::string str(is, it);
}
std::string trim_name(std::string raw_string) {
int pos = raw_string.find('\"');
if (pos != std::string::npos) {
raw_string = raw_string.substr(pos + 1);
pos = raw_string.find('\"');
raw_string = raw_string.substr(0, pos);
}
return raw_string;
}
std::string parese_file_ext(std::string file_name) {
int pos = file_name.rfind('.');
std::string ext = "";
if (pos != std::string::npos) ext = file_name.substr(pos + 1);
return ext;
}
template <typename InputIterator>
int parse_data_content(InputIterator is, InputIterator ie, InputIterator it) {
int len = std::distance(it + 1, ie);
if (len <= 0) {
return 0;
}
std::string str(it + 1, ie);
// check if at the end, "--boundary--" need +4 for "--"
if (len == boundary.length() + 4)
{
std::string str(it + 1, ie);
// std::cout << "len good=" << str << std::endl;
if (boundary.length() > 1 && boundary[boundary.length() - 1] == '\n') {
// remove '\n' in boundary
boundary = boundary.substr(0, boundary.length() - 2);
}
if (boundary.length() > 1 && boundary[boundary.length() - 1] == '\r') {
// remove '\r' in boundary
boundary = boundary.substr(0, boundary.length() - 2);
}
auto found_boundary = str.find(boundary);
if (found_boundary == std::string::npos) {
std::cout << "not found end boundary!=" << found_boundary << std::endl;
return 0;
}
// remove the end of data that contains '\n' or '\r'
int last_sub = 0;
if (*(it) == '\n') {
last_sub++;
}
int lasts_len = std::distance(it, ie);
data_msg->samples->erase(data_msg->samples->end() - last_sub - lasts_len,
data_msg->samples->end());
std::cout << "one file finished, file size=" << data_msg->samples->size()
<< std::endl;
return 1;
}
}
template <typename InputIterator>
void parse_boundary_content(InputIterator is, InputIterator ie,
InputIterator it) {
parse_one_line(is, ie, it);
std::string str;
while (it != ie) {
str = std::string(is, it);
auto found_content = str.find("Content-Disposition:");
auto found_filename = str.find("filename=");
if (found_content != std::string::npos &&
found_filename != std::string::npos) {
std::string file_name =
str.substr(found_filename + 9, std::string::npos);
file_name = trim_name(file_name);
std::string ext = parese_file_ext(file_name);
if (file_name.find(".wav") != std::string::npos) {
std::cout << "set wav_format=pcm, file_name=" << file_name
<< std::endl;
data_msg->msg["wav_format"] = "pcm";
} else {
std::cout << "set wav_format=" << ext << ", file_name=" << file_name
<< std::endl;
data_msg->msg["wav_format"] = ext;
}
data_msg->msg["wav_name"] = file_name;
now_state = data;
} else {
auto found_content = str.find("Content-Disposition:");
auto found_name = str.find("name=");
if (found_content != std::string::npos &&
found_name != std::string::npos) {
std::string name = str.substr(found_name + 5, std::string::npos);
name = trim_name(name);
parse_one_line(is, ie, it);
if (*it == '\n') it++;
parse_one_line(is, ie, it);
str = std::string(is, it);
std::cout << "para: name=" << name << ",value=" << str << std::endl;
}
}
parse_one_line(is, ie, it);
if (now_state == data && std::distance(is, it) <= 2) {
break;
}
}
if (now_state == data) {
if (*it == '\n') it++;
data_msg->samples->insert(data_msg->samples->end(), it,
it + std::distance(it, ie));
// it=ie;
}
}
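/// Entry point: feed one received chunk [is, ie). Header lines are scanned
/// until the boundary is known; once in the data state the raw bytes are
/// appended to data_msg->samples, and ok is returned after the closing boundary.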
template <typename InputIterator>
result_type parse_file(InputIterator is, InputIterator ie) {
if (now_state == data) {
data_msg->samples->insert(data_msg->samples->end(), is, ie);
}
auto it = is;
while (it != ie) {
std::string str(is, it);
parse_one_line(is, ie, it);
if (now_state == data) {
// for data end search
int ret = parse_data_content(is, ie, it);
if (ret == 0) continue;
return ok;
} else {
std::string str(is, it + 1);
if (now_state == start) {
auto found_boundary = str.find("Content-Length:");
if (found_boundary != std::string::npos) {
std::string file_len =
str.substr(found_boundary + 15, std::string::npos);
data_msg->samples->reserve(std::stoi(file_len));
}
found_boundary = str.find("boundary=");
if (found_boundary != std::string::npos) {
boundary = str.substr(found_boundary + 9, std::string::npos);
now_state = in_boundary;
}
} else if (now_state == in_boundary) {
// for file header
auto found_boundary = str.find(boundary);
if (found_boundary != std::string::npos) {
parse_boundary_content(is, ie, it);
}
}
}
}
return now_state;
}
private:
std::shared_ptr<FUNASR_MESSAGE> data_msg;
result_type now_state;
std::string boundary = "";
};
} // namespace server2
} // namespace http
#endif // HTTP_SERVER2_REQUEST_FILEPARSER_HPP
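For orientation, a minimal sketch (not part of this commit) of the multipart/form-data upload that file_parser walks through: the "Content-Length:" header seeds samples->reserve(), "boundary=" moves the state machine to in_boundary, the part header carrying "Content-Disposition: ... filename=" switches it to data, and the closing "--boundary--" line ends the upload.
```cpp
// Hypothetical request, for illustration only; the header names and state
// transitions mirror the string searches in file_parser above.
#include <string>

const std::string kExampleUpload =
    "POST / HTTP/1.1\r\n"
    "Content-Length: 1024\r\n"
    "Content-Type: multipart/form-data; boundary=----funasr\r\n"
    "\r\n"
    "------funasr\r\n"
    "Content-Disposition: form-data; name=\"file\"; filename=\"example.wav\"\r\n"
    "Content-Type: audio/wav\r\n"
    "\r\n"
    // ... raw wav bytes, collected into data_msg->samples ...
    "------funasr--\r\n";
```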

View File

@ -0,0 +1,523 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
#include "funasr-http-main.hpp"
#ifdef _WIN32
#include <windows.h>
#include "win_func.h"
#else
#include <unistd.h>
#endif
#include <fstream>
#include "util.h"
// hotwords
std::unordered_map<std::string, int> hws_map_;
int fst_inc_wts_ = 20;
float global_beam_, lattice_beam_, am_scale_;
using namespace std;
void GetValue(TCLAP::ValueArg<std::string> &value_arg, string key,
std::map<std::string, std::string> &model_path) {
model_path.insert({key, value_arg.getValue()});
LOG(INFO) << key << " : " << value_arg.getValue();
}
FUNASR_HANDLE initAsr(std::map<std::string, std::string> &model_path,
int thread_num) {
try {
// init model with api
FUNASR_HANDLE asr_handle = FunOfflineInit(model_path, thread_num);
LOG(INFO) << "model successfully inited";
LOG(INFO) << "initAsr run check_and_clean_connection";
// std::thread
// clean_thread(&ModelDecoderSrv::check_and_clean_connection,this);
// clean_thread.detach();
LOG(INFO) << "initAsr run check_and_clean_connection finished";
return asr_handle;
} catch (const std::exception &e) {
LOG(INFO) << e.what();
return nullptr;
}
}
int main(int argc, char *argv[]) {
#ifdef _WIN32
SetConsoleOutputCP(65001);
#endif
try {
google::InitGoogleLogging(argv[0]);
FLAGS_logtostderr = true;
std::string offline_version = "";
#ifdef _WIN32
offline_version = "0.1.0";
#endif
TCLAP::CmdLine cmd("funasr-wss-server", ' ', offline_version);
TCLAP::ValueArg<std::string> download_model_dir(
"", "download-model-dir",
"Download model from Modelscope to download_model_dir", false,
"/workspace/models", "string");
TCLAP::ValueArg<std::string> model_dir(
"", OFFLINE_MODEL_DIR,
"default: "
"damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx, "
"the asr model path, which "
"contains model_quant.onnx, config.yaml, am.mvn",
false,
"damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx",
"string");
TCLAP::ValueArg<std::string> model_revision("", "offline-model-revision",
"ASR offline model revision",
false, "v2.0.4", "string");
TCLAP::ValueArg<std::string> quantize(
"", QUANTIZE,
"true (Default), load the model of model_quant.onnx in model_dir. If "
"set "
"false, load the model of model.onnx in model_dir",
false, "true", "string");
TCLAP::ValueArg<std::string> vad_dir(
"", VAD_DIR,
"default: damo/speech_fsmn_vad_zh-cn-16k-common-onnx, the vad model "
"path, which contains "
"model_quant.onnx, vad.yaml, vad.mvn",
false, "damo/speech_fsmn_vad_zh-cn-16k-common-onnx", "string");
TCLAP::ValueArg<std::string> vad_revision(
"", "vad-revision", "VAD model revision", false, "v2.0.4", "string");
TCLAP::ValueArg<std::string> vad_quant(
"", VAD_QUANT,
"true (Default), load the model of model_quant.onnx in vad_dir. If set "
"false, load the model of model.onnx in vad_dir",
false, "true", "string");
TCLAP::ValueArg<std::string> punc_dir(
"", PUNC_DIR,
"default: "
"damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727-onnx, "
"the punc model path, which contains "
"model_quant.onnx, punc.yaml",
false,
"damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727-onnx",
"string");
TCLAP::ValueArg<std::string> punc_revision(
"", "punc-revision", "PUNC model revision", false, "v2.0.4", "string");
TCLAP::ValueArg<std::string> punc_quant(
"", PUNC_QUANT,
"true (Default), load the model of model_quant.onnx in punc_dir. If "
"set "
"false, load the model of model.onnx in punc_dir",
false, "true", "string");
TCLAP::ValueArg<std::string> itn_dir(
"", ITN_DIR,
"default: thuduj12/fst_itn_zh, the itn model path, which contains "
"zh_itn_tagger.fst, zh_itn_verbalizer.fst",
false, "", "string");
TCLAP::ValueArg<std::string> itn_revision(
"", "itn-revision", "ITN model revision", false, "v1.0.1", "string");
TCLAP::ValueArg<std::string> listen_ip("", "listen-ip", "listen ip", false,
"0.0.0.0", "string");
TCLAP::ValueArg<int> port("", "port", "port", false, 80, "int");
TCLAP::ValueArg<int> io_thread_num("", "io-thread-num", "io thread num",
false, 8, "int");
TCLAP::ValueArg<int> decoder_thread_num(
"", "decoder-thread-num", "decoder thread num", false, 32, "int");
TCLAP::ValueArg<int> model_thread_num("", "model-thread-num",
"model thread num", false, 1, "int");
TCLAP::ValueArg<std::string> certfile(
"", "certfile",
"default: ../../../ssl_key/server.crt, path of certficate for WSS "
"connection. if it is empty, it will be in WS mode.",
false, "../../../ssl_key/server.crt", "string");
TCLAP::ValueArg<std::string> keyfile(
"", "keyfile",
"default: ../../../ssl_key/server.key, path of keyfile for WSS "
"connection",
false, "../../../ssl_key/server.key", "string");
TCLAP::ValueArg<float> global_beam("", GLOB_BEAM,
"the decoding beam for beam searching ",
false, 3.0, "float");
TCLAP::ValueArg<float> lattice_beam(
"", LAT_BEAM, "the lattice generation beam for beam searching ", false,
3.0, "float");
TCLAP::ValueArg<float> am_scale("", AM_SCALE,
"the acoustic scale for beam searching ",
false, 10.0, "float");
TCLAP::ValueArg<std::string> lm_dir(
"", LM_DIR,
"the LM model path, which contains compiled models: TLG.fst, "
"config.yaml ",
false, "", "string");
TCLAP::ValueArg<std::string> lm_revision(
"", "lm-revision", "LM model revision", false, "v1.0.2", "string");
TCLAP::ValueArg<std::string> hotword(
"", HOTWORD,
"the hotword file, one hotword perline, Format: Hotword Weight (could "
"be: 阿里巴巴 20)",
false, "/workspace/resources/hotwords.txt", "string");
TCLAP::ValueArg<std::int32_t> fst_inc_wts(
"", FST_INC_WTS, "the fst hotwords incremental bias", false, 20,
"int32_t");
// add file
cmd.add(hotword);
cmd.add(fst_inc_wts);
cmd.add(global_beam);
cmd.add(lattice_beam);
cmd.add(am_scale);
cmd.add(certfile);
cmd.add(keyfile);
cmd.add(download_model_dir);
cmd.add(model_dir);
cmd.add(model_revision);
cmd.add(quantize);
cmd.add(vad_dir);
cmd.add(vad_revision);
cmd.add(vad_quant);
cmd.add(punc_dir);
cmd.add(punc_revision);
cmd.add(punc_quant);
cmd.add(itn_dir);
cmd.add(itn_revision);
cmd.add(lm_dir);
cmd.add(lm_revision);
cmd.add(listen_ip);
cmd.add(port);
cmd.add(io_thread_num);
cmd.add(decoder_thread_num);
cmd.add(model_thread_num);
cmd.parse(argc, argv);
std::map<std::string, std::string> model_path;
GetValue(model_dir, MODEL_DIR, model_path);
GetValue(quantize, QUANTIZE, model_path);
GetValue(vad_dir, VAD_DIR, model_path);
GetValue(vad_quant, VAD_QUANT, model_path);
GetValue(punc_dir, PUNC_DIR, model_path);
GetValue(punc_quant, PUNC_QUANT, model_path);
GetValue(itn_dir, ITN_DIR, model_path);
GetValue(lm_dir, LM_DIR, model_path);
GetValue(hotword, HOTWORD, model_path);
GetValue(model_revision, "model-revision", model_path);
GetValue(vad_revision, "vad-revision", model_path);
GetValue(punc_revision, "punc-revision", model_path);
GetValue(itn_revision, "itn-revision", model_path);
GetValue(lm_revision, "lm-revision", model_path);
global_beam_ = global_beam.getValue();
lattice_beam_ = lattice_beam.getValue();
am_scale_ = am_scale.getValue();
// Download model from ModelScope
try {
std::string s_download_model_dir = download_model_dir.getValue();
std::string s_vad_path = model_path[VAD_DIR];
std::string s_vad_quant = model_path[VAD_QUANT];
std::string s_asr_path = model_path[MODEL_DIR];
std::string s_asr_quant = model_path[QUANTIZE];
std::string s_punc_path = model_path[PUNC_DIR];
std::string s_punc_quant = model_path[PUNC_QUANT];
std::string s_itn_path = model_path[ITN_DIR];
std::string s_lm_path = model_path[LM_DIR];
std::string python_cmd =
"python -m funasr.download.runtime_sdk_download_tool --type onnx "
"--quantize True ";
if (vad_dir.isSet() && !s_vad_path.empty()) {
std::string python_cmd_vad;
std::string down_vad_path;
std::string down_vad_model;
if (access(s_vad_path.c_str(), F_OK) == 0) {
// local
python_cmd_vad = python_cmd + " --model-name " + s_vad_path +
" --export-dir ./ " + " --model_revision " +
model_path["vad-revision"];
down_vad_path = s_vad_path;
} else {
// modelscope
LOG(INFO) << "Download model: " << s_vad_path << " from modelscope: ";
python_cmd_vad = python_cmd + " --model-name " + s_vad_path +
" --export-dir " + s_download_model_dir +
" --model_revision " + model_path["vad-revision"];
down_vad_path = s_download_model_dir + "/" + s_vad_path;
}
int ret = system(python_cmd_vad.c_str());
if (ret != 0) {
LOG(INFO) << "Failed to download model from modelscope. If you set "
"local vad model path, you can ignore the errors.";
}
down_vad_model = down_vad_path + "/model_quant.onnx";
if (s_vad_quant == "false" || s_vad_quant == "False" ||
s_vad_quant == "FALSE") {
down_vad_model = down_vad_path + "/model.onnx";
}
if (access(down_vad_model.c_str(), F_OK) != 0) {
LOG(ERROR) << down_vad_model << " does not exist.";
exit(-1);
} else {
model_path[VAD_DIR] = down_vad_path;
LOG(INFO) << "Set " << VAD_DIR << " : " << model_path[VAD_DIR];
}
} else {
LOG(INFO) << "VAD model is not set, use default.";
}
if (model_dir.isSet() && !s_asr_path.empty()) {
std::string python_cmd_asr;
std::string down_asr_path;
std::string down_asr_model;
// modify model-revision by model name
size_t found = s_asr_path.find(
"speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-"
"vocab8404");
if (found != std::string::npos) {
model_path["model-revision"] = "v1.2.4";
}
found = s_asr_path.find(
"speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-"
"vocab8404");
if (found != std::string::npos) {
model_path["model-revision"] = "v1.0.5";
}
found = s_asr_path.find(
"speech_paraformer-large_asr_nat-en-16k-common-vocab10020");
if (found != std::string::npos) {
model_path["model-revision"] = "v1.0.0";
s_itn_path = "";
s_lm_path = "";
}
if (access(s_asr_path.c_str(), F_OK) == 0) {
// local
python_cmd_asr = python_cmd + " --model-name " + s_asr_path +
" --export-dir ./ " + " --model_revision " +
model_path["model-revision"];
down_asr_path = s_asr_path;
} else {
// modelscope
LOG(INFO) << "Download model: " << s_asr_path << " from modelscope: ";
python_cmd_asr = python_cmd + " --model-name " + s_asr_path +
" --export-dir " + s_download_model_dir +
" --model_revision " + model_path["model-revision"];
down_asr_path = s_download_model_dir + "/" + s_asr_path;
}
int ret = system(python_cmd_asr.c_str());
if (ret != 0) {
LOG(INFO) << "Failed to download model from modelscope. If you set "
"local asr model path, you can ignore the errors.";
}
down_asr_model = down_asr_path + "/model_quant.onnx";
if (s_asr_quant == "false" || s_asr_quant == "False" ||
s_asr_quant == "FALSE") {
down_asr_model = down_asr_path + "/model.onnx";
}
if (access(down_asr_model.c_str(), F_OK) != 0) {
LOG(ERROR) << down_asr_model << " does not exist.";
exit(-1);
} else {
model_path[MODEL_DIR] = down_asr_path;
LOG(INFO) << "Set " << MODEL_DIR << " : " << model_path[MODEL_DIR];
}
} else {
LOG(INFO) << "ASR model is not set, use default.";
}
if (!s_itn_path.empty()) {
std::string python_cmd_itn;
std::string down_itn_path;
std::string down_itn_model;
if (access(s_itn_path.c_str(), F_OK) == 0) {
// local
python_cmd_itn = python_cmd + " --model-name " + s_itn_path +
" --export-dir ./ " + " --model_revision " +
model_path["itn-revision"] + " --export False ";
down_itn_path = s_itn_path;
} else {
// modelscope
LOG(INFO) << "Download model: " << s_itn_path
<< " from modelscope : ";
python_cmd_itn = python_cmd + " --model-name " + s_itn_path +
" --export-dir " + s_download_model_dir +
" --model_revision " + model_path["itn-revision"] +
" --export False ";
down_itn_path = s_download_model_dir + "/" + s_itn_path;
}
int ret = system(python_cmd_itn.c_str());
if (ret != 0) {
LOG(INFO) << "Failed to download model from modelscope. If you set "
"local itn model path, you can ignore the errors.";
}
down_itn_model = down_itn_path + "/zh_itn_tagger.fst";
if (access(down_itn_model.c_str(), F_OK) != 0) {
LOG(ERROR) << down_itn_model << " does not exist.";
exit(-1);
} else {
model_path[ITN_DIR] = down_itn_path;
LOG(INFO) << "Set " << ITN_DIR << " : " << model_path[ITN_DIR];
}
} else {
LOG(INFO) << "ITN model is not set, not executed.";
}
if (!s_lm_path.empty() && s_lm_path != "NONE" && s_lm_path != "none") {
std::string python_cmd_lm;
std::string down_lm_path;
std::string down_lm_model;
if (access(s_lm_path.c_str(), F_OK) == 0) {
// local
python_cmd_lm = python_cmd + " --model-name " + s_lm_path +
" --export-dir ./ " + " --model_revision " +
model_path["lm-revision"] + " --export False ";
down_lm_path = s_lm_path;
} else {
// modelscope
LOG(INFO) << "Download model: " << s_lm_path << " from modelscope : ";
python_cmd_lm = python_cmd + " --model-name " + s_lm_path +
" --export-dir " + s_download_model_dir +
" --model_revision " + model_path["lm-revision"] +
" --export False ";
down_lm_path = s_download_model_dir + "/" + s_lm_path;
}
int ret = system(python_cmd_lm.c_str());
if (ret != 0) {
LOG(INFO) << "Failed to download model from modelscope. If you set "
"local lm model path, you can ignore the errors.";
}
down_lm_model = down_lm_path + "/TLG.fst";
if (access(down_lm_model.c_str(), F_OK) != 0) {
LOG(ERROR) << down_lm_model << " does not exist.";
exit(-1);
} else {
model_path[LM_DIR] = down_lm_path;
LOG(INFO) << "Set " << LM_DIR << " : " << model_path[LM_DIR];
}
} else {
LOG(INFO) << "LM model is not set, not executed.";
model_path[LM_DIR] = "";
}
if (punc_dir.isSet() && !s_punc_path.empty()) {
std::string python_cmd_punc;
std::string down_punc_path;
std::string down_punc_model;
if (access(s_punc_path.c_str(), F_OK) == 0) {
// local
python_cmd_punc = python_cmd + " --model-name " + s_punc_path +
" --export-dir ./ " + " --model_revision " +
model_path["punc-revision"];
down_punc_path = s_punc_path;
} else {
// modelscope
LOG(INFO) << "Download model: " << s_punc_path
<< " from modelscope: ";
python_cmd_punc = python_cmd + " --model-name " + s_punc_path +
" --export-dir " + s_download_model_dir +
" --model_revision " + model_path["punc-revision"];
down_punc_path = s_download_model_dir + "/" + s_punc_path;
}
int ret = system(python_cmd_punc.c_str());
if (ret != 0) {
LOG(INFO) << "Failed to download model from modelscope. If you set "
"local punc model path, you can ignore the errors.";
}
down_punc_model = down_punc_path + "/model_quant.onnx";
if (s_punc_quant == "false" || s_punc_quant == "False" ||
s_punc_quant == "FALSE") {
down_punc_model = down_punc_path + "/model.onnx";
}
if (access(down_punc_model.c_str(), F_OK) != 0) {
LOG(ERROR) << down_punc_model << " does not exist.";
exit(-1);
} else {
model_path[PUNC_DIR] = down_punc_path;
LOG(INFO) << "Set " << PUNC_DIR << " : " << model_path[PUNC_DIR];
}
} else {
LOG(INFO) << "PUNC model is not set, use default.";
}
} catch (std::exception const &e) {
LOG(ERROR) << "Error: " << e.what();
}
std::string s_listen_ip = listen_ip.getValue();
int s_port = port.getValue();
int s_io_thread_num = io_thread_num.getValue();
int s_decoder_thread_num = decoder_thread_num.getValue();
int s_model_thread_num = model_thread_num.getValue();
asio::io_context io_decoder; // context for decoding
std::vector<std::thread> decoder_threads;
// hotword file
std::string hotword_path;
hotword_path = model_path.at(HOTWORD);
fst_inc_wts_ = fst_inc_wts.getValue();
LOG(INFO) << "hotword path: " << hotword_path;
funasr::ExtractHws(hotword_path, hws_map_);
auto conn_guard = asio::make_work_guard(
io_decoder); // make sure threads can wait in the queue
// create threads pool
for (int32_t i = 0; i < s_decoder_thread_num; ++i) {
decoder_threads.emplace_back([&io_decoder]() { io_decoder.run(); });
}
// ModelDecoderSrv modelSrv(
// io_decoder); // websocket server for asr engine
// modelSrv.initAsr(model_path, s_model_thread_num); // init asr model
// FUNASR_HANDLE asr_handle= initAsr();
LOG(INFO) << "decoder-thread-num: " << s_decoder_thread_num;
LOG(INFO) << "io-thread-num: " << s_io_thread_num;
LOG(INFO) << "model-thread-num: " << s_model_thread_num;
http::server2::server s(s_listen_ip, std::to_string(s_port), "./",
s_io_thread_num, io_decoder, model_path,
s_model_thread_num);
s.run();
LOG(INFO) << "http model loop " << s_port;
// wait for threads
for (auto &t : decoder_threads) {
t.join();
}
} catch (std::exception const &e) {
LOG(ERROR) << "Error: " << e.what();
}
return 0;
}

View File

@ -0,0 +1,20 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
#ifndef HTTP_SERVER2_MAIN_HPP
#define HTTP_SERVER2_MAIN_HPP
#include "model-decoder.h"
#include "server.hpp"
namespace http {
namespace server2 {
} // namespace server2
} // namespace http
#endif // HTTP_SERVER2_MAIN_HPP

View File

@ -0,0 +1,27 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
//
// header.hpp
// copy some codes from http://www.boost.org/
#ifndef HTTP_SERVER2_HEADER_HPP
#define HTTP_SERVER2_HEADER_HPP
#include <string>
namespace http {
namespace server2 {
struct header
{
std::string name;
std::string value;
};
} // namespace server2
} // namespace http
#endif // HTTP_SERVER2_HEADER_HPP

View File

@ -0,0 +1,66 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
//
// io_context_pool.cpp
// ~~~~~~~~~~~~~~~~~~~
// copy some codes from http://www.boost.org/
#include "io_context_pool.hpp"
#include <stdexcept>
#include <thread>
namespace http {
namespace server2 {
io_context_pool::io_context_pool(std::size_t pool_size)
: next_io_context_(0)
{
if (pool_size == 0)
throw std::runtime_error("io_context_pool size is 0");
// Give all the io_contexts work to do so that their run() functions will not
// exit until they are explicitly stopped.
for (std::size_t i = 0; i < pool_size; ++i)
{
io_context_ptr io_context(new asio::io_context);
io_contexts_.push_back(io_context);
work_.push_back(asio::make_work_guard(*io_context));
}
}
void io_context_pool::run()
{
// Create a pool of threads to run all of the io_contexts.
std::vector<std::thread> threads;
for (std::size_t i = 0; i < io_contexts_.size(); ++i)
threads.emplace_back([this, i]{ io_contexts_[i]->run(); });
// Wait for all threads in the pool to exit.
for (std::size_t i = 0; i < threads.size(); ++i)
threads[i].join();
}
void io_context_pool::stop()
{
// Explicitly stop all io_contexts.
for (std::size_t i = 0; i < io_contexts_.size(); ++i)
io_contexts_[i]->stop();
}
asio::io_context& io_context_pool::get_io_context()
{
// Use a round-robin scheme to choose the next io_context to use.
asio::io_context& io_context = *io_contexts_[next_io_context_];
++next_io_context_;
if (next_io_context_ == io_contexts_.size())
next_io_context_ = 0;
return io_context;
}
} // namespace server2
} // namespace http

View File

@ -0,0 +1,59 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
//
// io_context_pool.hpp
// ~~~~~~~~~~~~~~~~~~~
// copy some codes from http://www.boost.org/
#ifndef HTTP_SERVER2_IO_SERVICE_POOL_HPP
#define HTTP_SERVER2_IO_SERVICE_POOL_HPP
#include <asio.hpp>
#include <list>
#include <memory>
#include <vector>
namespace http {
namespace server2 {
/// A pool of io_context objects.
class io_context_pool
{
public:
/// Construct the io_context pool.
explicit io_context_pool(std::size_t pool_size);
/// Run all io_context objects in the pool.
void run();
/// Stop all io_context objects in the pool.
void stop();
/// Get an io_context to use.
asio::io_context& get_io_context();
private:
io_context_pool(const io_context_pool&) = delete;
io_context_pool& operator=(const io_context_pool&) = delete;
typedef std::shared_ptr<::asio::io_context> io_context_ptr;
typedef asio::executor_work_guard<
asio::io_context::executor_type> io_context_work;
/// The pool of io_contexts.
std::vector<io_context_ptr> io_contexts_;
/// The work that keeps the io_contexts running.
std::list<io_context_work> work_;
/// The next io_context to use for a connection.
std::size_t next_io_context_;
};
} // namespace server2
} // namespace http
#endif // HTTP_SERVER2_IO_SERVICE_POOL_HPP
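A minimal usage sketch (not from this commit; it assumes standalone Asio and this header are on the include path): handlers are dispatched round-robin through get_io_context(), and run() spawns one thread per io_context and blocks until stop() is called, the same pattern the HTTP server class uses for its accept loop.
```cpp
#include <asio.hpp>
#include <chrono>
#include <iostream>
#include <thread>
#include "io_context_pool.hpp"

int main() {
  http::server2::io_context_pool pool(2);        // two io_contexts
  std::thread runner([&pool] { pool.run(); });   // run() joins its worker threads
  for (int i = 0; i < 4; ++i) {
    // round-robin: tasks alternate between the two io_contexts
    asio::post(pool.get_io_context(),
               [i] { std::cout << "task " << i << std::endl; });
  }
  std::this_thread::sleep_for(std::chrono::seconds(1));  // crude wait, sketch only
  pool.stop();                                   // makes run() (and runner) return
  runner.join();
  return 0;
}
```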

View File

@ -0,0 +1,119 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
// funasr asr engine
#include "model-decoder.h"
#include <thread>
#include <utility>
#include <vector>
extern std::unordered_map<std::string, int> hws_map_;
extern int fst_inc_wts_;
extern float global_beam_, lattice_beam_, am_scale_;
// feed msg to asr engine for decoder
void ModelDecoder::do_decoder(std::shared_ptr<FUNASR_MESSAGE> session_msg) {
try {
// std::this_thread::sleep_for(std::chrono::milliseconds(1000*10));
if (session_msg->status == 1) return;
//std::cout << "in do_decoder" << std::endl;
std::shared_ptr<std::vector<char>> buffer = session_msg->samples;
int num_samples = buffer->size(); // the size of the buf
std::string wav_name = session_msg->msg["wav_name"];
bool itn = session_msg->msg["itn"];
int audio_fs = session_msg->msg["audio_fs"];
std::string wav_format = session_msg->msg["wav_format"];
if (num_samples > 0 && session_msg->hotwords_embedding->size() > 0) {
std::string asr_result = "";
std::string stamp_res = "";
std::string stamp_sents = "";
try {
std::vector<std::vector<float>> hotwords_embedding_(
*(session_msg->hotwords_embedding));
FUNASR_RESULT Result = FunOfflineInferBuffer(
asr_handle, buffer->data(), buffer->size(), RASR_NONE, nullptr,
std::move(hotwords_embedding_), audio_fs, wav_format, itn,
session_msg->decoder_handle);
if (Result != nullptr) {
asr_result = FunASRGetResult(Result, 0); // get decode result
stamp_res = FunASRGetStamp(Result);
stamp_sents = FunASRGetStampSents(Result);
FunASRFreeResult(Result);
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(20));
}
} catch (std::exception const &e) {
std::cout << "error in decoder!!! "<<e.what() <<std::endl;
}
nlohmann::json jsonresult; // result json
jsonresult["text"] = asr_result; // put result in 'text'
jsonresult["mode"] = "offline";
jsonresult["is_final"] = false;
if (stamp_res != "") {
jsonresult["timestamp"] = stamp_res;
}
if (stamp_sents != "") {
try {
nlohmann::json json_stamp = nlohmann::json::parse(stamp_sents);
jsonresult["stamp_sents"] = json_stamp;
} catch (std::exception const &e) {
std::cout << "error:" << e.what();
jsonresult["stamp_sents"] = "";
}
}
jsonresult["wav_name"] = wav_name;
std::cout << "buffer.size=" << buffer->size()
<< ",result json=" << jsonresult.dump() << std::endl;
FunWfstDecoderUnloadHwsRes(session_msg->decoder_handle);
FunASRWfstDecoderUninit(session_msg->decoder_handle);
session_msg->status = 1;
session_msg->msg["asr_result"] = jsonresult;
return;
} else {
std::cout << "Received empty input, sending empty result" << std::endl;
nlohmann::json jsonresult; // result json
jsonresult["text"] = ""; // put result in 'text'
jsonresult["mode"] = "offline";
jsonresult["is_final"] = false;
jsonresult["wav_name"] = wav_name;
session_msg->status = 1;
session_msg->msg["asr_result"] = jsonresult;
}
} catch (std::exception const &e) {
std::cerr << "Error: " << e.what() << std::endl;
}
}
// init asr model
FUNASR_HANDLE ModelDecoder::initAsr(std::map<std::string, std::string> &model_path,
int thread_num) {
try {
// init model with api
asr_handle = FunOfflineInit(model_path, thread_num);
LOG(INFO) << "model successfully inited";
return asr_handle;
} catch (const std::exception &e) {
LOG(INFO) << e.what();
return nullptr;
}
}

View File

@ -0,0 +1,60 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
// funasr asr engine
#ifndef MODEL_DECODER_SERVER_H_
#define MODEL_DECODER_SERVER_H_
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#define ASIO_STANDALONE 1 // not boost
#include <glog/logging.h>
#include <fstream>
#include <functional>
#include "asio.hpp"
#include "asr_sessions.h"
#include "com-define.h"
#include "funasrruntime.h"
#include "nlohmann/json.hpp"
#include "tclap/CmdLine.h"
#include "util/text-utils.h"
class ModelDecoder {
public:
ModelDecoder(asio::io_context &io_decoder,
std::map<std::string, std::string> &model_path, int thread_num)
: io_decoder_(io_decoder) {
asr_handle = initAsr(model_path, thread_num);
}
void do_decoder(std::shared_ptr<FUNASR_MESSAGE> session_msg);
FUNASR_HANDLE initAsr(std::map<std::string, std::string> &model_path, int thread_num);
asio::io_context &io_decoder_; // threads for asr decoder
FUNASR_HANDLE get_asr_handle()
{
return asr_handle;
}
private:
FUNASR_HANDLE asr_handle; // asr engine handle
bool isonline = false; // online or offline engine; currently only offline is supported
};
#endif // MODEL_DECODER_SERVER_H_
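A sketch (not part of this commit) of how a finished upload could be handed off for decoding; it assumes FUNASR_MESSAGE is the session struct from asr_sessions.h that the connection fills once file_parser returns ok.
```cpp
#include <memory>
#include "model-decoder.h"

// Post the blocking inference onto the decoder io_context so the HTTP I/O
// threads never wait on the model.
inline void enqueue_decode(std::shared_ptr<ModelDecoder> decoder,
                           std::shared_ptr<FUNASR_MESSAGE> session_msg) {
  asio::post(decoder->io_decoder_,
             [decoder, session_msg] { decoder->do_decoder(session_msg); });
}
```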

245
runtime/http/bin/reply.cpp Normal file
View File

@ -0,0 +1,245 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
// reply.cpp
// ~~~~~~~~~
//
// copy some codes from http://www.boost.org/
#include "reply.hpp"
#include <iostream>
#include <string>
namespace http {
namespace server2 {
namespace status_strings {
const std::string ok = "HTTP/1.0 200 OK\r\n";
const std::string created = "HTTP/1.0 201 Created\r\n";
const std::string accepted = "HTTP/1.0 202 Accepted\r\n";
const std::string no_content = "HTTP/1.0 204 No Content\r\n";
const std::string multiple_choices = "HTTP/1.0 300 Multiple Choices\r\n";
const std::string moved_permanently = "HTTP/1.0 301 Moved Permanently\r\n";
const std::string moved_temporarily = "HTTP/1.0 302 Moved Temporarily\r\n";
const std::string not_modified = "HTTP/1.0 304 Not Modified\r\n";
const std::string bad_request = "HTTP/1.0 400 Bad Request\r\n";
const std::string unauthorized = "HTTP/1.0 401 Unauthorized\r\n";
const std::string forbidden = "HTTP/1.0 403 Forbidden\r\n";
const std::string not_found = "HTTP/1.0 404 Not Found\r\n";
const std::string internal_server_error =
"HTTP/1.0 500 Internal Server Error\r\n";
const std::string not_implemented = "HTTP/1.0 501 Not Implemented\r\n";
const std::string bad_gateway = "HTTP/1.0 502 Bad Gateway\r\n";
const std::string service_unavailable = "HTTP/1.0 503 Service Unavailable\r\n";
asio::const_buffer to_buffer(reply::status_type status) {
switch (status) {
case reply::ok:
return asio::buffer(ok);
case reply::created:
return asio::buffer(created);
case reply::accepted:
return asio::buffer(accepted);
case reply::no_content:
return asio::buffer(no_content);
case reply::multiple_choices:
return asio::buffer(multiple_choices);
case reply::moved_permanently:
return asio::buffer(moved_permanently);
case reply::moved_temporarily:
return asio::buffer(moved_temporarily);
case reply::not_modified:
return asio::buffer(not_modified);
case reply::bad_request:
return asio::buffer(bad_request);
case reply::unauthorized:
return asio::buffer(unauthorized);
case reply::forbidden:
return asio::buffer(forbidden);
case reply::not_found:
return asio::buffer(not_found);
case reply::internal_server_error:
return asio::buffer(internal_server_error);
case reply::not_implemented:
return asio::buffer(not_implemented);
case reply::bad_gateway:
return asio::buffer(bad_gateway);
case reply::service_unavailable:
return asio::buffer(service_unavailable);
default:
return asio::buffer(internal_server_error);
}
}
} // namespace status_strings
namespace misc_strings {
const char name_value_separator[] = {':', ' '};
const char crlf[] = {'\r', '\n'};
} // namespace misc_strings
std::vector<::asio::const_buffer> reply::to_buffers() {
std::vector<::asio::const_buffer> buffers;
buffers.push_back(status_strings::to_buffer(status));
for (std::size_t i = 0; i < headers.size(); ++i) {
header &h = headers[i];
buffers.push_back(asio::buffer(h.name));
buffers.push_back(asio::buffer(misc_strings::name_value_separator));
buffers.push_back(asio::buffer(h.value));
buffers.push_back(asio::buffer(misc_strings::crlf));
}
buffers.push_back(asio::buffer(misc_strings::crlf));
buffers.push_back(asio::buffer(content));
return buffers;
}
namespace stock_replies {
const char ok[] = "";
const char created[] =
"<html>"
"<head><title>Created</title></head>"
"<body><h1>201 Created</h1></body>"
"</html>";
const char accepted[] =
"<html>"
"<head><title>Accepted</title></head>"
"<body><h1>202 Accepted</h1></body>"
"</html>";
const char no_content[] =
"<html>"
"<head><title>No Content</title></head>"
"<body><h1>204 Content</h1></body>"
"</html>";
const char multiple_choices[] =
"<html>"
"<head><title>Multiple Choices</title></head>"
"<body><h1>300 Multiple Choices</h1></body>"
"</html>";
const char moved_permanently[] =
"<html>"
"<head><title>Moved Permanently</title></head>"
"<body><h1>301 Moved Permanently</h1></body>"
"</html>";
const char moved_temporarily[] =
"<html>"
"<head><title>Moved Temporarily</title></head>"
"<body><h1>302 Moved Temporarily</h1></body>"
"</html>";
const char not_modified[] =
"<html>"
"<head><title>Not Modified</title></head>"
"<body><h1>304 Not Modified</h1></body>"
"</html>";
const char bad_request[] =
"<html>"
"<head><title>Bad Request</title></head>"
"<body><h1>400 Bad Request</h1></body>"
"</html>";
const char unauthorized[] =
"<html>"
"<head><title>Unauthorized</title></head>"
"<body><h1>401 Unauthorized</h1></body>"
"</html>";
const char forbidden[] =
"<html>"
"<head><title>Forbidden</title></head>"
"<body><h1>403 Forbidden</h1></body>"
"</html>";
const char not_found[] =
"<html>"
"<head><title>Not Found</title></head>"
"<body><h1>404 Not Found</h1></body>"
"</html>";
const char internal_server_error[] =
"<html>"
"<head><title>Internal Server Error</title></head>"
"<body><h1>500 Internal Server Error</h1></body>"
"</html>";
const char not_implemented[] =
"<html>"
"<head><title>Not Implemented</title></head>"
"<body><h1>501 Not Implemented</h1></body>"
"</html>";
const char bad_gateway[] =
"<html>"
"<head><title>Bad Gateway</title></head>"
"<body><h1>502 Bad Gateway</h1></body>"
"</html>";
const char service_unavailable[] =
"<html>"
"<head><title>Service Unavailable</title></head>"
"<body><h1>503 Service Unavailable</h1></body>"
"</html>";
std::string to_string(reply::status_type status) {
switch (status) {
case reply::ok:
return ok;
case reply::created:
return created;
case reply::accepted:
return accepted;
case reply::no_content:
return no_content;
case reply::multiple_choices:
return multiple_choices;
case reply::moved_permanently:
return moved_permanently;
case reply::moved_temporarily:
return moved_temporarily;
case reply::not_modified:
return not_modified;
case reply::bad_request:
return bad_request;
case reply::unauthorized:
return unauthorized;
case reply::forbidden:
return forbidden;
case reply::not_found:
return not_found;
case reply::internal_server_error:
return internal_server_error;
case reply::not_implemented:
return not_implemented;
case reply::bad_gateway:
return bad_gateway;
case reply::service_unavailable:
return service_unavailable;
default:
return internal_server_error;
}
}
} // namespace stock_replies
reply reply::stock_reply(std::string jsonresult) {
reply rep;
rep.status = reply::ok;
rep.content = jsonresult+"\n";
rep.headers.resize(2);
rep.headers[0].name = "Content-Length";
rep.headers[0].value = std::to_string(rep.content.size());
rep.headers[1].name = "Content-Type";
rep.headers[1].value = "text/html;charset=utf-8";
return rep;
}
reply reply::stock_reply(reply::status_type status) {
reply rep;
rep.status = status;
rep.content = stock_replies::to_string(status);
rep.headers.resize(2);
rep.headers[0].name = "Content-Length";
rep.headers[0].value = std::to_string(rep.content.size());
rep.headers[1].name = "Content-Type";
rep.headers[1].value = "text/html";
return rep;
}
} // namespace server2
} // namespace http

View File

@ -0,0 +1,64 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
// reply.hpp
// ~~~~~~~~~
//
// copy some codes from http://www.boost.org/
#ifndef HTTP_SERVER2_REPLY_HPP
#define HTTP_SERVER2_REPLY_HPP
#include <asio.hpp>
#include <string>
#include <vector>
#include "header.hpp"
namespace http {
namespace server2 {
/// A reply to be sent to a client.
struct reply {
/// The status of the reply.
enum status_type {
ok = 200,
created = 201,
accepted = 202,
no_content = 204,
multiple_choices = 300,
moved_permanently = 301,
moved_temporarily = 302,
not_modified = 304,
bad_request = 400,
unauthorized = 401,
forbidden = 403,
not_found = 404,
internal_server_error = 500,
not_implemented = 501,
bad_gateway = 502,
service_unavailable = 503
} status;
/// The headers to be included in the reply.
std::vector<header> headers;
/// The content to be sent in the reply.
std::string content;
/// Convert the reply into a vector of buffers. The buffers do not own the
/// underlying memory blocks, therefore the reply object must remain valid and
/// not be changed until the write operation has completed.
std::vector<::asio::const_buffer> to_buffers();
/// Get a stock reply.
static reply stock_reply(status_type status);
static reply stock_reply(std::string jsonresult);
};
} // namespace server2
} // namespace http
#endif // HTTP_SERVER2_REPLY_HPP
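A sketch (not in this commit) of how a connection could send an ASR result with these helpers: stock_reply(std::string) builds a 200 response with Content-Length/Content-Type headers, and to_buffers() returns a buffer sequence whose memory must stay valid until the write completes, hence the shared_ptr capture.
```cpp
#include <asio.hpp>
#include <memory>
#include <string>
#include "reply.hpp"

void send_json_result(asio::ip::tcp::socket &socket, const std::string &json) {
  auto rep = std::make_shared<http::server2::reply>(
      http::server2::reply::stock_reply(json));
  asio::async_write(socket, rep->to_buffers(),
                    [rep](const asio::error_code &, std::size_t) {
                      // rep is kept alive by the capture until the write finishes
                    });
}
```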

113
runtime/http/bin/server.cpp Normal file
View File

@ -0,0 +1,113 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
//
// server.cpp
// copy some codes from http://www.boost.org/
#include "server.hpp"
#include <signal.h>
#include <fstream>
#include <iostream>
#include <utility>
#include "util.h"
namespace http {
namespace server2 {
server::server(const std::string &address, const std::string &port,
const std::string &doc_root, std::size_t io_context_pool_size,
asio::io_context &decoder_context,
std::map<std::string, std::string> &model_path, int thread_num)
: io_context_pool_(io_context_pool_size),
signals_(io_context_pool_.get_io_context()),
acceptor_(io_context_pool_.get_io_context()),
decoder_context(decoder_context) {
// Register to handle the signals that indicate when the server should exit.
// It is safe to register for the same signal multiple times in a program,
// provided all registration for the specified signal is made through Asio.
try {
model_decoder =
std::make_shared<ModelDecoder>(decoder_context, model_path, thread_num);
LOG(INFO) << "try to listen on port:" << port << std::endl;
LOG(INFO) << "still not work, pls wait... " << std::endl;
LOG(INFO) << "if always waiting here, may be port in used, pls change the "
"port or kill pre-process!"
<< std::endl;
atom_id = 0;
// init model with api
signals_.add(SIGINT);
signals_.add(SIGTERM);
#if defined(SIGQUIT)
signals_.add(SIGQUIT);
#endif // defined(SIGQUIT)
do_await_stop();
// Open the acceptor with the option to reuse the address (i.e.
// SO_REUSEADDR).
asio::ip::tcp::resolver resolver(acceptor_.get_executor());
asio::ip::tcp::endpoint endpoint = *resolver.resolve(address, port).begin();
acceptor_.open(endpoint.protocol());
acceptor_.set_option(asio::ip::tcp::acceptor::reuse_address(true));
acceptor_.bind(endpoint);
acceptor_.listen();
do_accept();
std::cout << "use curl to test,just as " << std::endl;
std::cout << "curl -F \"file=@example.wav\" 127.0.0.1:80" << std::endl;
std::cout << "http post only support offline mode, if you want online "
"mode, pls try websocket!"
<< std::endl;
std::cout << "now succeed listen on port " << address << ":" << port
<< ", can accept data now!!!" << std::endl;
} catch (const std::exception &e) {
std::cout << "error:" << e.what();
}
}
void server::run() { io_context_pool_.run(); }
void server::do_accept() {
acceptor_.async_accept(
io_context_pool_.get_io_context(),
[this](asio::error_code ec, asio::ip::tcp::socket socket) {
// Check whether the server was stopped by a signal before this
// completion handler had a chance to run.
if (!acceptor_.is_open()) {
return;
}
if (!ec) {
std::lock_guard<std::mutex> lk(m_lock);
atom_id = atom_id + 1;
std::make_shared<connection>(std::move(socket), decoder_context,
(atom_id).load(), model_decoder)
->start();
}
do_accept();
});
}
void server::do_await_stop() {
signals_.async_wait([this](asio::error_code /*ec*/, int /*signo*/) {
io_context_pool_.stop();
});
}
} // namespace server2
} // namespace http

View File

@ -0,0 +1,71 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
* Reserved. MIT License (https://opensource.org/licenses/MIT)
*/
/* 2023-2024 by zhaomingwork@qq.com */
//
// server.hpp
// ~~~~~~~~~~
// copy some codes from http://www.boost.org/
#ifndef HTTP_SERVER2_SERVER_HPP
#define HTTP_SERVER2_SERVER_HPP
#include <asio.hpp>
#include <atomic>
#include <string>
#include "connection.hpp"
#include "funasrruntime.h"
#include "io_context_pool.hpp"
#include "model-decoder.h"
#include "util.h"
namespace http {
namespace server2 {
/// The top-level class of the HTTP server.
class server {
public:
server(const server &) = delete;
server &operator=(const server &) = delete;
/// Construct the server to listen on the specified TCP address and port, and
/// serve up files from the given directory.
explicit server(const std::string &address, const std::string &port,
const std::string &doc_root, std::size_t io_context_pool_size,
asio::io_context &decoder_context,
std::map<std::string, std::string> &model_path,
int thread_num);
/// Run the server's io_context loop.
void run();
private:
/// Perform an asynchronous accept operation.
void do_accept();
/// Wait for a request to stop the server.
void do_await_stop();
/// The pool of io_context objects used to perform asynchronous operations.
io_context_pool io_context_pool_;
asio::io_context &decoder_context;
/// The signal_set is used to register for process termination notifications.
asio::signal_set signals_;
/// Acceptor used to listen for incoming connections.
asio::ip::tcp::acceptor acceptor_;
std::shared_ptr<ModelDecoder> model_decoder;
std::atomic<int> atom_id;
std::mutex m_lock;
};
} // namespace server2
} // namespace http
#endif // HTTP_SERVER2_SERVER_HPP

58
runtime/http/readme.md Normal file
View File

@ -0,0 +1,58 @@
# Advanced Development Guide (File transcription service) ([click](../docs/SDK_advanced_guide_offline.md))
# Real-time Speech Transcription Service Development Guide ([click](../docs/SDK_advanced_guide_online.md))
# If you want to compile the file yourself, you can follow the steps below.
## Building for Linux/Unix
### Download onnxruntime
```shell
wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/onnxruntime-linux-x64-1.14.0.tgz
tar -zxvf onnxruntime-linux-x64-1.14.0.tgz
```
### Download ffmpeg
```shell
wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/ffmpeg-master-latest-linux64-gpl-shared.tar.xz
tar -xvf ffmpeg-master-latest-linux64-gpl-shared.tar.xz
```
### Install deps
```shell
# openblas
sudo apt-get install libopenblas-dev #ubuntu
# sudo yum -y install openblas-devel #centos
# openssl
apt-get install libssl-dev #ubuntu
# yum install openssl-devel #centos
```
### Build runtime
```shell
git clone https://github.com/alibaba-damo-academy/FunASR.git && cd FunASR/runtime/http
mkdir build && cd build
cmake -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0 -DFFMPEG_DIR=/path/to/ffmpeg-master-latest-linux64-gpl-shared
make -j 4
```
### Run
```shell
./funasr-http-server \
--lm-dir '' \
--itn-dir '' \
--download-model-dir ${download_model_dir} \
--model-dir ${model_dir} \
--vad-dir ${vad_dir} \
--punc-dir ${punc_dir} \
--decoder-thread-num ${decoder_thread_num} \
--io-thread-num ${io_thread_num} \
--port ${port}
```
### Test
Once the server is running, post a wav file with curl:
```shell
curl -F "file=@example.wav" 127.0.0.1:80
```

61
runtime/http/readme_zh.md Normal file
View File

@ -0,0 +1,61 @@
# FunASR Offline File Transcription Service Development Guide ([click here](../docs/SDK_advanced_guide_offline_zh.md))
# FunASR Real-time Speech Transcription Service Development Guide ([click here](../docs/SDK_advanced_guide_online_zh.md))
# If you want to build it yourself, you can follow the steps below.
## Building for Linux/Unix
### Download onnxruntime
```shell
wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/onnxruntime-linux-x64-1.14.0.tgz
tar -zxvf onnxruntime-linux-x64-1.14.0.tgz
```
### Download ffmpeg
```shell
wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/ffmpeg-master-latest-linux64-gpl-shared.tar.xz
tar -xvf ffmpeg-master-latest-linux64-gpl-shared.tar.xz
```
### Install deps
```shell
# openblas
sudo apt-get install libopenblas-dev #ubuntu
# sudo yum -y install openblas-devel #centos
# openssl
apt-get install libssl-dev #ubuntu
# yum install openssl-devel #centos
```
### Build runtime
```shell
git clone https://github.com/alibaba-damo-academy/FunASR.git && cd FunASR/runtime/http
mkdir build && cd build
cmake -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0 -DFFMPEG_DIR=/path/to/ffmpeg-master-latest-linux64-gpl-shared
make -j 4
```
### Run
```shell
./funasr-http-server \
--lm-dir '' \
--itn-dir '' \
--download-model-dir ${download_model_dir} \
--model-dir ${model_dir} \
--vad-dir ${vad_dir} \
--punc-dir ${punc_dir} \
--decoder-thread-num ${decoder_thread_num} \
--io-thread-num ${io_thread_num} \
--port ${port}
```
### Test
Once the server is running, post a wav file with curl:
```shell
curl -F "file=@example.wav" 127.0.0.1:80
```

View File

@ -0,0 +1,15 @@
#### Download onnxruntime
```shell
bash third_party/download_onnxruntime.sh
```
#### Download ffmpeg
```shell
bash third_party/download_ffmpeg.sh
```
#### Install openblas and openssl
```shell
sudo apt-get install libopenblas-dev libssl-dev #ubuntu
# sudo yum -y install openblas-devel openssl-devel #centos
```

View File

@ -4,6 +4,7 @@ project(FunASROnnx)
option(ENABLE_GLOG "Whether to build glog" ON)
option(ENABLE_FST "Whether to build openfst" ON) # ITN need openfst compiled
option(GPU "Whether to build with GPU" OFF)
# set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
@ -49,6 +50,17 @@ include_directories(${PROJECT_SOURCE_DIR}/third_party/jieba/include/limonp/inclu
include_directories(${PROJECT_SOURCE_DIR}/third_party/kaldi)
include_directories(${PROJECT_SOURCE_DIR}/third_party/json/include)
if(GPU)
add_definitions(-DUSE_GPU)
set(TORCH_DIR "/usr/local/lib/python3.8/dist-packages/torch")
set(TORCH_BLADE_DIR "/usr/local/lib/python3.8/dist-packages/torch_blade")
include_directories(${TORCH_DIR}/include)
include_directories(${TORCH_DIR}/include/torch/csrc/api/include)
link_directories(${TORCH_DIR}/lib)
link_directories(${TORCH_BLADE_DIR})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -D_GLIBCXX_USE_CXX11_ABI=0")
endif()
if(ENABLE_GLOG)
include_directories(${PROJECT_SOURCE_DIR}/third_party/glog/src)
set(BUILD_TESTING OFF)

View File

@ -10,33 +10,43 @@ SET(RELATION_SOURCE "../src/resample.cpp" "../src/util.cpp" "../src/alignedmem.c
endif()
add_executable(funasr-onnx-offline "funasr-onnx-offline.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-offline PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-offline PUBLIC funasr)
add_executable(funasr-onnx-offline-vad "funasr-onnx-offline-vad.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-offline-vad PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-offline-vad PUBLIC funasr)
add_executable(funasr-onnx-online-vad "funasr-onnx-online-vad.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-online-vad PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-online-vad PUBLIC funasr)
add_executable(funasr-onnx-online-asr "funasr-onnx-online-asr.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-online-asr PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-online-asr PUBLIC funasr)
add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-offline-punc PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)
add_executable(funasr-onnx-online-punc "funasr-onnx-online-punc.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-online-punc PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-online-punc PUBLIC funasr)
add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-offline-rtf PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)
add_executable(funasr-onnx-2pass "funasr-onnx-2pass.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-2pass PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-2pass PUBLIC funasr)
add_executable(funasr-onnx-2pass-rtf "funasr-onnx-2pass-rtf.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-2pass-rtf PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-2pass-rtf PUBLIC funasr)
add_executable(funasr-onnx-online-rtf "funasr-onnx-online-rtf.cpp" ${RELATION_SOURCE})
target_link_options(funasr-onnx-online-rtf PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-onnx-online-rtf PUBLIC funasr)
# include_directories(${FFMPEG_DIR}/include)

View File

@ -52,7 +52,7 @@ void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wa
std::vector<std::vector<float>> hotwords_embedding = CompileHotwordEmbedding(asr_handle, nn_hotwords_);
// warm up
for (size_t i = 0; i < 1; i++)
for (size_t i = 0; i < 10; i++)
{
FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, nullptr, hotwords_embedding, audio_fs, true, decoder_handle);
if(result){
@ -127,6 +127,7 @@ int main(int argc, char *argv[])
TCLAP::CmdLine cmd("funasr-onnx-offline-rtf", ' ', "1.0");
TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
TCLAP::ValueArg<std::string> bladedisc("", BLADEDISC, "true (Default), load the model of bladedisc in model_dir.", false, "true", "string");
TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
@ -140,11 +141,14 @@ int main(int argc, char *argv[])
TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
TCLAP::ValueArg<std::int32_t> audio_fs("", AUDIO_FS, "the sample rate of audio", false, 16000, "int32_t");
TCLAP::ValueArg<std::int32_t> thread_num("", THREAD_NUM, "multi-thread num for rtf", true, 0, "int32_t");
TCLAP::ValueArg<std::int32_t> thread_num("", THREAD_NUM, "multi-thread num for rtf", false, 1, "int32_t");
TCLAP::ValueArg<std::string> hotword("", HOTWORD, "the hotword file, one hotword perline, Format: Hotword Weight (could be: 阿里巴巴 20)", false, "", "string");
TCLAP::SwitchArg use_gpu("", INFER_GPU, "Whether to use GPU for inference, default is false", false);
TCLAP::ValueArg<std::int32_t> batch_size("", BATCHSIZE, "batch_size for ASR model when using GPU", false, 4, "int32_t");
cmd.add(model_dir);
cmd.add(quantize);
cmd.add(bladedisc);
cmd.add(vad_dir);
cmd.add(vad_quant);
cmd.add(punc_dir);
@ -159,11 +163,14 @@ int main(int argc, char *argv[])
cmd.add(wav_path);
cmd.add(audio_fs);
cmd.add(thread_num);
cmd.add(use_gpu);
cmd.add(batch_size);
cmd.parse(argc, argv);
std::map<std::string, std::string> model_path;
GetValue(model_dir, MODEL_DIR, model_path);
GetValue(quantize, QUANTIZE, model_path);
GetValue(bladedisc, BLADEDISC, model_path);
GetValue(vad_dir, VAD_DIR, model_path);
GetValue(vad_quant, VAD_QUANT, model_path);
GetValue(punc_dir, PUNC_DIR, model_path);
@ -175,7 +182,9 @@ int main(int argc, char *argv[])
struct timeval start, end;
gettimeofday(&start, nullptr);
FUNASR_HANDLE asr_handle=FunOfflineInit(model_path, 1);
bool use_gpu_ = use_gpu.getValue();
int batch_size_ = batch_size.getValue();
FUNASR_HANDLE asr_handle=FunOfflineInit(model_path, 1, use_gpu_, batch_size_);
if (!asr_handle)
{

View File

@ -19,6 +19,7 @@
#include "com-define.h"
#include <unordered_map>
#include "util.h"
#include "audio.h"
using namespace std;
bool is_target_file(const std::string& filename, const std::string target) {
@ -44,6 +45,7 @@ int main(int argc, char** argv)
TCLAP::CmdLine cmd("funasr-onnx-offline", ' ', "1.0");
TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the asr model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
TCLAP::ValueArg<std::string> bladedisc("", BLADEDISC, "true (Default), load the model of bladedisc in model_dir.", false, "true", "string");
TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
@ -57,9 +59,12 @@ int main(int argc, char** argv)
TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
TCLAP::ValueArg<std::int32_t> audio_fs("", AUDIO_FS, "the sample rate of audio", false, 16000, "int32_t");
TCLAP::ValueArg<std::string> hotword("", HOTWORD, "the hotword file, one hotword perline, Format: Hotword Weight (could be: 阿里巴巴 20)", false, "", "string");
TCLAP::SwitchArg use_gpu("", INFER_GPU, "Whether to use GPU for inference, default is false", false);
TCLAP::ValueArg<std::int32_t> batch_size("", BATCHSIZE, "batch_size for ASR model when using GPU", false, 4, "int32_t");
cmd.add(model_dir);
cmd.add(quantize);
cmd.add(bladedisc);
cmd.add(vad_dir);
cmd.add(vad_quant);
cmd.add(punc_dir);
@ -73,11 +78,14 @@ int main(int argc, char** argv)
cmd.add(wav_path);
cmd.add(audio_fs);
cmd.add(hotword);
cmd.add(use_gpu);
cmd.add(batch_size);
cmd.parse(argc, argv);
std::map<std::string, std::string> model_path;
GetValue(model_dir, MODEL_DIR, model_path);
GetValue(quantize, QUANTIZE, model_path);
GetValue(bladedisc, BLADEDISC, model_path);
GetValue(vad_dir, VAD_DIR, model_path);
GetValue(vad_quant, VAD_QUANT, model_path);
GetValue(punc_dir, PUNC_DIR, model_path);
@ -89,7 +97,9 @@ int main(int argc, char** argv)
struct timeval start, end;
gettimeofday(&start, nullptr);
int thread_num = 1;
FUNASR_HANDLE asr_hanlde=FunOfflineInit(model_path, thread_num);
bool use_gpu_ = use_gpu.getValue();
int batch_size_ = batch_size.getValue();
FUNASR_HANDLE asr_hanlde=FunOfflineInit(model_path, thread_num, use_gpu_, batch_size_);
if (!asr_hanlde)
{
@ -156,7 +166,33 @@ int main(int argc, char** argv)
for (int i = 0; i < wav_list.size(); i++) {
auto& wav_file = wav_list[i];
auto& wav_id = wav_ids[i];
gettimeofday(&start, nullptr);
// For debug:begin
// int32_t sampling_rate_ = audio_fs.getValue();
// funasr::Audio audio(1);
// if(is_target_file(wav_file.c_str(), "wav")){
// if(!audio.LoadWav2Char(wav_file.c_str(), &sampling_rate_)){
// LOG(ERROR)<<"Failed to load "<< wav_file;
// exit(-1);
// }
// }else if(is_target_file(wav_file.c_str(), "pcm")){
// if (!audio.LoadPcmwav2Char(wav_file.c_str(), &sampling_rate_)){
// LOG(ERROR)<<"Failed to load "<< wav_file;
// exit(-1);
// }
// }else{
// if (!audio.FfmpegLoad(wav_file.c_str(), true)){
// LOG(ERROR)<<"Failed to load "<< wav_file;
// exit(-1);
// }
// }
// char* speech_buff = audio.GetSpeechChar();
// int buff_len = audio.GetSpeechLen()*2;
// gettimeofday(&start, nullptr);
// FUNASR_RESULT result=FunOfflineInferBuffer(asr_hanlde, speech_buff, buff_len, RASR_NONE, nullptr, hotwords_embedding, audio_fs.getValue(), "pcm", true, decoder_handle);
// For debug:end
FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, nullptr, hotwords_embedding, audio_fs.getValue(), true, decoder_handle);
gettimeofday(&end, nullptr);
seconds = (end.tv_sec - start.tv_sec);

View File

@ -83,9 +83,11 @@ class DLLAPI Audio {
int FetchTpass(AudioFrame *&frame);
int Fetch(float *&dout, int &len, int &flag);
int Fetch(float *&dout, int &len, int &flag, float &start_time);
int Fetch(float **&dout, int *&len, int *&flag, float*& start_time, int batch_size, int &batch_in);
int FetchDynamic(float **&dout, int *&len, int *&flag, float*& start_time, int batch_size, int &batch_in);
void Padding();
void Split(OfflineStream* offline_streamj);
void CutSplit(OfflineStream* offline_streamj);
void CutSplit(OfflineStream* offline_streamj, std::vector<int> &index_vector);
void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished=true);
void Split(VadModel* vad_obj, int chunk_len, bool input_finished=true, ASR_TYPE asr_mode=ASR_TWO_PASS);
float GetTimeLen();

View File

@ -51,6 +51,15 @@ namespace funasr {
#define QUANT_MODEL_NAME "model_quant.onnx"
#define VAD_CMVN_NAME "am.mvn"
#define VAD_CONFIG_NAME "config.yaml"
// gpu models
#define INFER_GPU "gpu"
#define BATCHSIZE "batch-size"
#define TORCH_MODEL_NAME "model.torchscripts"
#define TORCH_QUANT_MODEL_NAME "model_quant.torchscripts"
#define BLADE_MODEL_NAME "model.blade.fp16.pt"
#define BLADEDISC "bladedisc"
#define AM_CMVN_NAME "am.mvn"
#define AM_CONFIG_NAME "config.yaml"
#define LM_CONFIG_NAME "config.yaml"

View File

@ -96,7 +96,7 @@ _FUNASRAPI void CTTransformerFreeResult(FUNASR_RESULT result);
_FUNASRAPI void CTTransformerUninit(FUNASR_HANDLE handle);
//OfflineStream
_FUNASRAPI FUNASR_HANDLE FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num);
_FUNASRAPI FUNASR_HANDLE FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu=false, int batch_size=1);
_FUNASRAPI void FunOfflineReset(FUNASR_HANDLE handle, FUNASR_DEC_HANDLE dec_handle=nullptr);
// buffer
_FUNASRAPI FUNASR_RESULT FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len,
@ -106,9 +106,9 @@ _FUNASRAPI FUNASR_RESULT FunOfflineInferBuffer(FUNASR_HANDLE handle, const char*
_FUNASRAPI FUNASR_RESULT FunOfflineInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode,
QM_CALLBACK fn_callback, const std::vector<std::vector<float>> &hw_emb,
int sampling_rate=16000, bool itn=true, FUNASR_DEC_HANDLE dec_handle=nullptr);
#if !defined(__APPLE__)
//#if !defined(__APPLE__)
_FUNASRAPI const std::vector<std::vector<float>> CompileHotwordEmbedding(FUNASR_HANDLE handle, std::string &hotwords, ASR_TYPE mode=ASR_OFFLINE);
#endif
//#endif
_FUNASRAPI void FunOfflineUninit(FUNASR_HANDLE handle);
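Taken together, the extended offline C API can be driven end to end like the hypothetical caller below; the model map key, model path and sampling rate are placeholders, and error handling is reduced to null checks.

```cpp
#include <map>
#include <string>
#include <vector>
#include "funasrruntime.h"

int recognize_one(const char* wav_path) {
    std::map<std::string, std::string> model_path = {{"model-dir", "/workspace/models"}};  // placeholder key/path
    // New parameters: use_gpu and batch_size (defaults preserve the old CPU behaviour).
    FUNASR_HANDLE handle = FunOfflineInit(model_path, /*thread_num=*/1, /*use_gpu=*/true, /*batch_size=*/4);
    if (!handle) return -1;

    std::string hotwords;  // empty: no hotword biasing
    const std::vector<std::vector<float>> hw_emb = CompileHotwordEmbedding(handle, hotwords, ASR_OFFLINE);

    FUNASR_RESULT result = FunOfflineInfer(handle, wav_path, RASR_NONE, nullptr,
                                           hw_emb, /*sampling_rate=*/16000, /*itn=*/true, nullptr);
    // ... read the text/timestamps out of result and free it with the matching API ...
    FunOfflineUninit(handle);
    return result ? 0 : 1;
}
```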

View File

@ -5,6 +5,10 @@
#include <string>
#include <map>
#include "funasrruntime.h"
#include "vocab.h"
#include "phone-set.h"
#include "fst/fstlib.h"
#include "fst/symbol-table.h"
namespace funasr {
class Model {
public:
@ -18,13 +22,19 @@ class Model {
virtual void InitLm(const std::string &lm_file, const std::string &lm_config, const std::string &lex_file){};
virtual void InitFstDecoder(){};
virtual std::string Forward(float *din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr){return "";};
virtual std::vector<std::string> Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1)
{return std::vector<string>();};
virtual std::string Rescoring() = 0;
virtual void InitHwCompiler(const std::string &hw_model, int thread_num){};
virtual void InitSegDict(const std::string &seg_dict_model){};
virtual std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords){return std::vector<std::vector<float>>();};
virtual std::string GetLang(){return "";};
virtual int GetAsrSampleRate() = 0;
virtual void SetBatchSize(int batch_size) {};
virtual int GetBatchSize() {return 0;};
virtual Vocab* GetVocab() {return nullptr;};
virtual Vocab* GetLmVocab() {return nullptr;};
virtual PhoneSet* GetPhoneSet() {return nullptr;};
};
Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE);
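The extended base class keeps single-utterance backends source-compatible (the new virtuals all have defaults) while letting a GPU backend opt into batching. A hypothetical minimal subclass, sketching only the hooks touched by this patch (other pure virtuals of Model, if any, are omitted):

```cpp
#include <string>
#include <vector>
#include "model.h"  // the header shown above; include path assumed

class MyBatchedBackend : public funasr::Model {
 public:
    // Batched inference entry point added in this patch.
    std::vector<std::string> Forward(float** din, int* len, bool input_finished,
                                     const std::vector<std::vector<float>>& hw_emb,
                                     void* wfst_decoder, int batch_in) override {
        std::vector<std::string> out(batch_in);
        // ... run batched decoding over din[0..batch_in-1] with lengths len[...] ...
        return out;
    }
    void SetBatchSize(int batch_size) override { batch_size_ = batch_size; }
    int  GetBatchSize() override { return batch_size_; }
    int  GetAsrSampleRate() override { return 16000; }
    std::string Rescoring() override { return ""; }
 private:
    int batch_size_ = 1;
};
```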

View File

@ -14,7 +14,7 @@
namespace funasr {
class OfflineStream {
public:
OfflineStream(std::map<std::string, std::string>& model_path, int thread_num);
OfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu=false, int batch_size=1);
~OfflineStream(){};
std::unique_ptr<VadModel> vad_handle= nullptr;
@ -33,6 +33,6 @@ class OfflineStream {
bool use_itn=false;
};
OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num=1);
OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num=1, bool use_gpu=false, int batch_size=1);
} // namespace funasr
#endif

View File

@ -1,7 +1,16 @@
file(GLOB files1 "*.cpp")
if(APPLE)
file(GLOB itn_files "itn-*.cpp")
list(REMOVE_ITEM files1 ${itn_files})
endif(APPLE)
list(REMOVE_ITEM files1 "${CMAKE_CURRENT_SOURCE_DIR}/paraformer-torch.cpp")
set(files ${files1})
if(GPU)
set(files ${files} "${CMAKE_CURRENT_SOURCE_DIR}/paraformer-torch.cpp")
endif()
message("files: "${files})
if(WIN32)
@ -23,9 +32,17 @@ else()
set(EXTRA_LIBS pthread yaml-cpp csrc kaldi-decoder fst glog gflags avutil avcodec avformat swresample)
include_directories(${ONNXRUNTIME_DIR}/include)
include_directories(${FFMPEG_DIR}/include)
if(APPLE)
target_link_directories(funasr PUBLIC ${ONNXRUNTIME_DIR}/lib)
target_link_directories(funasr PUBLIC ${FFMPEG_DIR}/lib)
endif(APPLE)
endif()
if(GPU)
set(TORCH_DEPS torch torch_cuda torch_cpu c10 c10_cuda torch_blade ral_base_context)
endif()
#message("CXX_FLAGS "${CMAKE_CXX_FLAGS})
include_directories(${CMAKE_SOURCE_DIR}/include)
include_directories(${CMAKE_SOURCE_DIR}/third_party)
target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS})
target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS} ${TORCH_DEPS})

View File

@ -1023,6 +1023,90 @@ int Audio::Fetch(float *&dout, int &len, int &flag, float &start_time)
}
}
int Audio::Fetch(float**& dout, int*& len, int*& flag, float*& start_time, int batch_size, int &batch_in)
{
batch_in = std::min((int)frame_queue.size(), batch_size);
if (batch_in == 0){
return 0;
} else{
// init
dout = new float*[batch_in];
len = new int[batch_in];
flag = new int[batch_in];
start_time = new float[batch_in];
for(int idx=0; idx < batch_in; idx++){
AudioFrame *frame = frame_queue.front();
frame_queue.pop();
start_time[idx] = (float)(frame->GetStart())/ dest_sample_rate;
dout[idx] = speech_data + frame->GetStart();
len[idx] = frame->GetLen();
delete frame;
flag[idx] = S_END;
}
return 1;
}
}
int Audio::FetchDynamic(float**& dout, int*& len, int*& flag, float*& start_time, int batch_size, int &batch_in)
{
//compute batch size
queue<AudioFrame *> frame_batch;
int max_acc = 300*1000*seg_sample;
int max_sent = 60*1000*seg_sample;
int bs_acc = 0;
int max_len = 0;
int max_batch = 1;
#ifdef USE_GPU
max_batch = batch_size;
#endif
max_batch = std::min(max_batch, (int)frame_queue.size());
for(int idx=0; idx < max_batch; idx++){
AudioFrame *frame = frame_queue.front();
int length = frame->GetLen();
if(length >= max_sent){
if(bs_acc == 0){
bs_acc++;
frame_batch.push(frame);
frame_queue.pop();
}
break;
}
max_len = std::max(max_len, frame->GetLen());
if(max_len*(bs_acc+1) > max_acc){
break;
}
bs_acc++;
frame_batch.push(frame);
frame_queue.pop();
}
batch_in = (int)frame_batch.size();
if (batch_in == 0){
return 0;
} else{
// init
dout = new float*[batch_in];
len = new int[batch_in];
flag = new int[batch_in];
start_time = new float[batch_in];
for(int idx=0; idx < batch_in; idx++){
AudioFrame *frame = frame_batch.front();
frame_batch.pop();
start_time[idx] = (float)(frame->GetStart())/ dest_sample_rate;
dout[idx] = speech_data + frame->GetStart();
len[idx] = frame->GetLen();
delete frame;
flag[idx] = S_END;
}
return 1;
}
}
void Audio::Padding()
{
float num_samples = speech_len;
@ -1085,7 +1169,7 @@ void Audio::Split(OfflineStream* offline_stream)
}
}
void Audio::CutSplit(OfflineStream* offline_stream)
void Audio::CutSplit(OfflineStream* offline_stream, std::vector<int> &index_vector)
{
std::unique_ptr<VadModel> vad_online_handle = make_unique<FsmnVadOnline>((FsmnVad*)(offline_stream->vad_handle).get());
AudioFrame *frame;
@ -1112,6 +1196,7 @@ void Audio::CutSplit(OfflineStream* offline_stream)
}
int speech_start_i = -1, speech_end_i =-1;
std::vector<AudioFrame*> vad_frames;
for(vector<int> vad_segment:vad_segments)
{
if(vad_segment.size() != 2){
@ -1126,16 +1211,31 @@ void Audio::CutSplit(OfflineStream* offline_stream)
}
if(speech_start_i!=-1 && speech_end_i!=-1){
frame = new AudioFrame();
int start = speech_start_i*seg_sample;
int end = speech_end_i*seg_sample;
frame = new AudioFrame(end-start);
frame->SetStart(start);
frame->SetEnd(end);
frame_queue.push(frame);
vad_frames.push_back(frame);
frame = nullptr;
speech_start_i=-1;
speech_end_i=-1;
}
}
// sort
{
index_vector.clear();
index_vector.resize(vad_frames.size());
for (int i = 0; i < index_vector.size(); ++i) {
index_vector[i] = i;
}
std::sort(index_vector.begin(), index_vector.end(), [&vad_frames](const int a, const int b) {
return vad_frames[a]->len < vad_frames[b]->len;
});
for (int idx : index_vector) {
frame_queue.push(vad_frames[idx]);
}
}
}
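Two things happen above: CutSplit now records, in index_vector, the original position of each VAD segment and pushes the segments sorted by length (so a batch contains similarly sized utterances and padding is minimised), and FetchDynamic admits segments into a batch by padded size rather than by count. A standalone restatement of that admission rule, assuming 16 kHz audio so seg_sample is 16 samples per millisecond:

```cpp
#include <algorithm>
#include <vector>

// Sketch of FetchDynamic's batch-planning rule (not library code).
int plan_batch(const std::vector<int>& seg_lens_sorted, int max_batch, int seg_sample = 16) {
    const int max_acc  = 300 * 1000 * seg_sample;  // cap: ~300 s of padded audio per batch
    const int max_sent = 60  * 1000 * seg_sample;  // a segment of 60 s or more is dispatched alone
    int bs = 0, max_len = 0;
    const int limit = std::min<int>(max_batch, seg_lens_sorted.size());
    for (int i = 0; i < limit; ++i) {
        const int length = seg_lens_sorted[i];
        if (length >= max_sent) { if (bs == 0) bs = 1; break; }
        max_len = std::max(max_len, length);
        if (max_len * (bs + 1) > max_acc) break;   // padding everyone to max_len would exceed the cap
        ++bs;
    }
    return bs;  // number of segments to pop for this batch
}
```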

View File

@ -33,9 +33,9 @@
return mm;
}
_FUNASRAPI FUNASR_HANDLE FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num)
_FUNASRAPI FUNASR_HANDLE FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)
{
funasr::OfflineStream* mm = funasr::CreateOfflineStream(model_path, thread_num);
funasr::OfflineStream* mm = funasr::CreateOfflineStream(model_path, thread_num, use_gpu, batch_size);
return mm;
}
@ -74,16 +74,11 @@
if(p_result->snippet_time == 0){
return p_result;
}
int n_step = 0;
int n_total = audio.GetQueueSize();
while (audio.Fetch(buff, len, flag) > 0) {
string msg = recog_obj->Forward(buff, len, input_finished);
p_result->msg += msg;
n_step++;
if (fn_callback)
fn_callback(n_step, n_total);
}
return p_result;
}
@ -109,8 +104,6 @@
float* buff;
int len;
int flag = 0;
int n_step = 0;
int n_total = audio.GetQueueSize();
funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
p_result->snippet_time = audio.GetTimeLen();
if(p_result->snippet_time == 0){
@ -119,11 +112,7 @@
while (audio.Fetch(buff, len, flag) > 0) {
string msg = recog_obj->Forward(buff, len, true);
p_result->msg += msg;
n_step++;
if (fn_callback)
fn_callback(n_step, n_total);
}
return p_result;
}
@ -244,26 +233,53 @@
if(p_result->snippet_time == 0){
return p_result;
}
std::vector<int> index_vector={0};
int msg_idx = 0;
if(offline_stream->UseVad()){
audio.CutSplit(offline_stream);
audio.CutSplit(offline_stream, index_vector);
}
std::vector<string> msgs(index_vector.size());
std::vector<float> msg_stimes(index_vector.size());
float* buff;
int len;
int flag = 0;
float** buff;
int* len;
int* flag;
float* start_time;
int batch_size = offline_stream->asr_handle->GetBatchSize();
int batch_in = 0;
int n_step = 0;
int n_total = audio.GetQueueSize();
float start_time = 0.0;
std::string cur_stamp = "[";
std::string lang = (offline_stream->asr_handle)->GetLang();
while (audio.Fetch(buff, len, flag, start_time) > 0) {
while (audio.FetchDynamic(buff, len, flag, start_time, batch_size, batch_in) > 0) {
// dec reset
funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
if (wfst_decoder){
wfst_decoder->StartUtterance();
}
string msg = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle);
vector<string> msg_batch = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle, batch_in);
for(int idx=0; idx<batch_in; idx++){
string msg = msg_batch[idx];
if(msg_idx < index_vector.size()){
msgs[index_vector[msg_idx]] = msg;
msg_stimes[index_vector[msg_idx]] = start_time[idx];
msg_idx++;
}else{
LOG(ERROR) << "msg_idx: " << msg_idx <<" is out of range " << index_vector.size();
}
}
// release
delete[] buff;
buff = nullptr;
delete[] len;
len = nullptr;
delete[] flag;
flag = nullptr;
delete[] start_time;
start_time = nullptr;
}
for(int idx=0; idx<msgs.size(); idx++){
string msg = msgs[idx];
std::vector<std::string> msg_vec = funasr::split(msg, '|');
if(msg_vec.size()==0){
continue;
@ -276,14 +292,11 @@
if(msg_vec.size() > 1){
std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
for(int i=0; i<msg_stamp.size()-1; i+=2){
float begin = std::stof(msg_stamp[i])+start_time;
float end = std::stof(msg_stamp[i+1])+start_time;
float begin = std::stof(msg_stamp[i])+msg_stimes[idx];
float end = std::stof(msg_stamp[i+1])+msg_stimes[idx];
cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
}
}
n_step++;
if (fn_callback)
fn_callback(n_step, n_total);
}
if(cur_stamp != "["){
cur_stamp.erase(cur_stamp.length() - 1);
@ -342,25 +355,53 @@
if(p_result->snippet_time == 0){
return p_result;
}
std::vector<int> index_vector={0};
int msg_idx = 0;
if(offline_stream->UseVad()){
audio.CutSplit(offline_stream);
audio.CutSplit(offline_stream, index_vector);
}
std::vector<string> msgs(index_vector.size());
std::vector<float> msg_stimes(index_vector.size());
float** buff;
int* len;
int* flag;
float* start_time;
int batch_size = offline_stream->asr_handle->GetBatchSize();
int batch_in = 0;
float* buff;
int len;
int flag = 0;
int n_step = 0;
int n_total = audio.GetQueueSize();
float start_time = 0.0;
std::string cur_stamp = "[";
std::string lang = (offline_stream->asr_handle)->GetLang();
while (audio.Fetch(buff, len, flag, start_time) > 0) {
while (audio.FetchDynamic(buff, len, flag, start_time, batch_size, batch_in) > 0) {
// dec reset
funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
if (wfst_decoder){
wfst_decoder->StartUtterance();
}
string msg = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle);
vector<string> msg_batch = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle, batch_in);
for(int idx=0; idx<batch_in; idx++){
string msg = msg_batch[idx];
if(msg_idx < index_vector.size()){
msgs[index_vector[msg_idx]] = msg;
msg_stimes[index_vector[msg_idx]] = start_time[idx];
msg_idx++;
}else{
LOG(ERROR) << "msg_idx: " << msg_idx <<" is out of range " << index_vector.size();
}
}
// release
delete[] buff;
buff = nullptr;
delete[] len;
len = nullptr;
delete[] flag;
flag = nullptr;
delete[] start_time;
start_time = nullptr;
}
for(int idx=0; idx<msgs.size(); idx++){
string msg = msgs[idx];
std::vector<std::string> msg_vec = funasr::split(msg, '|');
if(msg_vec.size()==0){
continue;
@ -373,15 +414,11 @@
if(msg_vec.size() > 1){
std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
for(int i=0; i<msg_stamp.size()-1; i+=2){
float begin = std::stof(msg_stamp[i])+start_time;
float end = std::stof(msg_stamp[i+1])+start_time;
float begin = std::stof(msg_stamp[i])+msg_stimes[idx];
float end = std::stof(msg_stamp[i+1])+msg_stimes[idx];
cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
}
}
n_step++;
if (fn_callback)
fn_callback(n_step, n_total);
}
if(cur_stamp != "["){
cur_stamp.erase(cur_stamp.length() - 1);
@ -409,7 +446,7 @@
return p_result;
}
#if !defined(__APPLE__)
//#if !defined(__APPLE__)
_FUNASRAPI const std::vector<std::vector<float>> CompileHotwordEmbedding(FUNASR_HANDLE handle, std::string &hotwords, ASR_TYPE mode)
{
if (mode == ASR_OFFLINE){
@ -433,7 +470,7 @@
}
}
#endif
//#endif
// APIs for 2pass-stream Infer
_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf,
@ -518,8 +555,14 @@
if (wfst_decoder){
wfst_decoder->StartUtterance();
}
string msg = ((funasr::Paraformer*)asr_handle)->Forward(frame->data, frame->len, frame->is_final, hw_emb, dec_handle);
float** buff;
int* len;
buff = new float*[1];
len = new int[1];
buff[0] = frame->data;
len[0] = frame->len;
vector<string> msgs = ((funasr::Paraformer*)asr_handle)->Forward(buff, len, frame->is_final, hw_emb, dec_handle);
string msg = msgs.size()>0?msgs[0]:"";
std::vector<std::string> msg_vec = funasr::split(msg, '|'); // split with timestamp
if(msg_vec.size()==0){
continue;
@ -767,16 +810,45 @@
funasr::WfstDecoder* mm = nullptr;
if (asr_type == ASR_OFFLINE) {
funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
funasr::Paraformer* paraformer = (funasr::Paraformer*)offline_stream->asr_handle.get();
if (paraformer->lm_)
mm = new funasr::WfstDecoder(paraformer->lm_.get(),
paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
auto paraformer = dynamic_cast<funasr::Paraformer*>(offline_stream->asr_handle.get());
if(paraformer !=nullptr){
if (paraformer->lm_){
mm = new funasr::WfstDecoder(paraformer->lm_.get(),
paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
}
return mm;
}
#ifdef USE_GPU
auto paraformer_torch = dynamic_cast<funasr::ParaformerTorch*>(offline_stream->asr_handle.get());
if(paraformer_torch !=nullptr){
if (paraformer_torch->lm_){
mm = new funasr::WfstDecoder(paraformer_torch->lm_.get(),
paraformer_torch->GetPhoneSet(), paraformer_torch->GetLmVocab(), glob_beam, lat_beam, am_scale);
}
return mm;
}
#endif
} else if (asr_type == ASR_TWO_PASS){
funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
funasr::Paraformer* paraformer = (funasr::Paraformer*)tpass_stream->asr_handle.get();
if (paraformer->lm_)
mm = new funasr::WfstDecoder(paraformer->lm_.get(),
paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
auto paraformer = dynamic_cast<funasr::Paraformer*>(tpass_stream->asr_handle.get());
if(paraformer !=nullptr){
if (paraformer->lm_){
mm = new funasr::WfstDecoder(paraformer->lm_.get(),
paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
}
return mm;
}
#ifdef USE_GPU
auto paraformer_torch = dynamic_cast<funasr::ParaformerTorch*>(tpass_stream->asr_handle.get());
if(paraformer_torch !=nullptr){
if (paraformer_torch->lm_){
mm = new funasr::WfstDecoder(paraformer_torch->lm_.get(),
paraformer_torch->GetPhoneSet(), paraformer_torch->GetLmVocab(), glob_beam, lat_beam, am_scale);
}
return mm;
}
#endif
}
return mm;
}
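Because CutSplit queues segments in length-sorted order, the batched loops above receive hypotheses out of original order; index_vector is used to scatter each result (and its segment start time) back to its source position before the final text and timestamp strings are assembled. A toy version of that scatter step:

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Toy re-ordering step: sorted_results[i] is the i-th decoded segment in
// length-sorted order, index_vector[i] its original position after VAD.
std::vector<std::string> restore_order(const std::vector<std::string>& sorted_results,
                                       const std::vector<int>& index_vector) {
    std::vector<std::string> ordered(index_vector.size());
    for (std::size_t i = 0; i < sorted_results.size() && i < index_vector.size(); ++i)
        ordered[index_vector[i]] = sorted_results[i];
    return ordered;
}
```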

View File

@ -1,7 +1,7 @@
#include "precomp.h"
namespace funasr {
OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int thread_num)
OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)
{
// VAD model
if(model_path.find(VAD_DIR) != model_path.end()){
@ -36,7 +36,19 @@ OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int
string hw_compile_model_path;
string seg_dict_path;
asr_handle = make_unique<Paraformer>();
if(use_gpu){
#ifdef USE_GPU
asr_handle = make_unique<ParaformerTorch>();
asr_handle->SetBatchSize(batch_size);
#else
LOG(ERROR) << "GPU is not supported in this build! CPU will be used. To enable GPU, please add -DGPU=ON when running cmake.";
asr_handle = make_unique<Paraformer>();
use_gpu = false;
#endif
}else{
asr_handle = make_unique<Paraformer>();
}
bool enable_hotword = false;
hw_compile_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_EB_NAME);
seg_dict_path = PathAppend(model_path.at(MODEL_DIR), MODEL_SEG_DICT);
@ -55,6 +67,15 @@ OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int
if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
}
if(use_gpu){
am_model_path = PathAppend(model_path.at(MODEL_DIR), TORCH_MODEL_NAME);
if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
am_model_path = PathAppend(model_path.at(MODEL_DIR), TORCH_QUANT_MODEL_NAME);
}
if(model_path.find(BLADEDISC) != model_path.end() && model_path.at(BLADEDISC) == "true"){
am_model_path = PathAppend(model_path.at(MODEL_DIR), BLADE_MODEL_NAME);
}
}
}
am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
@ -120,10 +141,10 @@ OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int
#endif
}
OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num)
OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)
{
OfflineStream *mm;
mm = new OfflineStream(model_path, thread_num);
mm = new OfflineStream(model_path, thread_num, use_gpu, batch_size);
return mm;
}
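For reference, the AM model-file selection that the constructor now performs can be condensed as below; the key strings are assumed to match the corresponding command-line option names (quantize, bladedisc), and the BladeDISC artifact takes precedence over the quantised TorchScript model.

```cpp
#include <map>
#include <string>

// Condensed sketch of the model-file choice in OfflineStream's constructor.
std::string pick_am_model(const std::map<std::string, std::string>& opts, bool use_gpu) {
    const bool quant = opts.count("quantize")  && opts.at("quantize")  == "true";
    const bool blade = opts.count("bladedisc") && opts.at("bladedisc") == "true";
    if (!use_gpu) return quant ? "model_quant.onnx" : "model.onnx";     // ONNX (CPU) path
    if (blade)    return "model.blade.fp16.pt";                         // BLADE_MODEL_NAME
    return quant ? "model_quant.torchscripts" : "model.torchscripts";   // TORCH_(QUANT_)MODEL_NAME
}
```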

View File

@ -0,0 +1,415 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
* MIT License (https://opensource.org/licenses/MIT)
*/
#include "precomp.h"
#include "paraformer-torch.h"
#include "encode_converter.h"
#include <cstddef>
using namespace std;
namespace funasr {
ParaformerTorch::ParaformerTorch()
:use_hotword(false){
}
// offline
void ParaformerTorch::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){
LoadConfigFromYaml(am_config.c_str());
// knf options
fbank_opts_.frame_opts.dither = 0;
fbank_opts_.mel_opts.num_bins = n_mels;
fbank_opts_.frame_opts.samp_freq = asr_sample_rate;
fbank_opts_.frame_opts.window_type = window_type;
fbank_opts_.frame_opts.frame_shift_ms = frame_shift;
fbank_opts_.frame_opts.frame_length_ms = frame_length;
fbank_opts_.energy_floor = 0;
fbank_opts_.mel_opts.debug_mel = false;
vocab = new Vocab(token_file.c_str());
phone_set_ = new PhoneSet(token_file.c_str());
LoadCmvn(am_cmvn.c_str());
torch::DeviceType device = at::kCPU;
#ifdef USE_GPU
if (!torch::cuda::is_available()) {
LOG(ERROR) << "CUDA is not available! Please check your GPU settings";
exit(-1);
} else {
LOG(INFO) << "CUDA is available, running on GPU";
device = at::kCUDA;
}
#endif
#ifdef USE_IPEX
torch::jit::setTensorExprFuserEnabled(false);
#endif
try {
torch::jit::script::Module model = torch::jit::load(am_model, device);
model_ = std::make_shared<TorchModule>(std::move(model));
LOG(INFO) << "Successfully load model from " << am_model;
} catch (std::exception const &e) {
LOG(ERROR) << "Error when loading am model " << am_model << ": " << e.what();
exit(-1);
}
}
void ParaformerTorch::InitLm(const std::string &lm_file,
const std::string &lm_cfg_file,
const std::string &lex_file) {
try {
lm_ = std::shared_ptr<fst::Fst<fst::StdArc>>(
fst::Fst<fst::StdArc>::Read(lm_file));
if (lm_){
lm_vocab = new Vocab(lm_cfg_file.c_str(), lex_file.c_str());
LOG(INFO) << "Successfully load lm file " << lm_file;
}else{
LOG(ERROR) << "Failed to load lm file " << lm_file;
}
} catch (std::exception const &e) {
LOG(ERROR) << "Error when load lm file: " << e.what();
exit(0);
}
}
void ParaformerTorch::LoadConfigFromYaml(const char* filename){
YAML::Node config;
try{
config = YAML::LoadFile(filename);
}catch(exception const &e){
LOG(ERROR) << "Error loading config file: the yaml file is malformed or does not exist.";
exit(-1);
}
try{
YAML::Node frontend_conf = config["frontend_conf"];
this->asr_sample_rate = frontend_conf["fs"].as<int>();
YAML::Node lang_conf = config["lang"];
if (lang_conf.IsDefined()){
language = lang_conf.as<string>();
}
}catch(exception const &e){
LOG(ERROR) << "Error when loading arguments from the asr config YAML.";
exit(-1);
}
}
void ParaformerTorch::InitHwCompiler(const std::string &hw_model, int thread_num) {
// TODO
use_hotword = true;
}
void ParaformerTorch::InitSegDict(const std::string &seg_dict_model) {
seg_dict = new SegDict(seg_dict_model.c_str());
}
ParaformerTorch::~ParaformerTorch()
{
if(vocab){
delete vocab;
}
if(lm_vocab){
delete lm_vocab;
}
if(seg_dict){
delete seg_dict;
}
if(phone_set_){
delete phone_set_;
}
}
void ParaformerTorch::StartUtterance()
{
}
void ParaformerTorch::EndUtterance()
{
}
void ParaformerTorch::Reset()
{
}
void ParaformerTorch::FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats) {
knf::OnlineFbank fbank_(fbank_opts_);
std::vector<float> buf(len);
for (int32_t i = 0; i != len; ++i) {
buf[i] = waves[i] * 32768;
}
fbank_.AcceptWaveform(sample_rate, buf.data(), buf.size());
int32_t frames = fbank_.NumFramesReady();
for (int32_t i = 0; i != frames; ++i) {
const float *frame = fbank_.GetFrame(i);
std::vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
asr_feats.emplace_back(frame_vector);
}
}
void ParaformerTorch::LoadCmvn(const char *filename)
{
ifstream cmvn_stream(filename);
if (!cmvn_stream.is_open()) {
LOG(ERROR) << "Failed to open file: " << filename;
exit(-1);
}
string line;
while (getline(cmvn_stream, line)) {
istringstream iss(line);
vector<string> line_item{istream_iterator<string>{iss}, istream_iterator<string>{}};
if (line_item[0] == "<AddShift>") {
getline(cmvn_stream, line);
istringstream means_lines_stream(line);
vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
if (means_lines[0] == "<LearnRateCoef>") {
for (int j = 3; j < means_lines.size() - 1; j++) {
means_list_.push_back(stof(means_lines[j]));
}
continue;
}
}
else if (line_item[0] == "<Rescale>") {
getline(cmvn_stream, line);
istringstream vars_lines_stream(line);
vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
if (vars_lines[0] == "<LearnRateCoef>") {
for (int j = 3; j < vars_lines.size() - 1; j++) {
vars_list_.push_back(stof(vars_lines[j])*scale);
}
continue;
}
}
}
}
string ParaformerTorch::GreedySearch(float * in, int n_len, int64_t token_nums, bool is_stamp, std::vector<float> us_alphas, std::vector<float> us_cif_peak)
{
vector<int> hyps;
int Tmax = n_len;
for (int i = 0; i < Tmax; i++) {
int max_idx;
float max_val;
FindMax(in + i * token_nums, token_nums, max_val, max_idx);
hyps.push_back(max_idx);
}
if(!is_stamp){
return vocab->Vector2StringV2(hyps, language);
}else{
std::vector<string> char_list;
std::vector<std::vector<float>> timestamp_list;
std::string res_str;
vocab->Vector2String(hyps, char_list);
std::vector<string> raw_char(char_list);
TimestampOnnx(us_alphas, us_cif_peak, char_list, res_str, timestamp_list);
return PostProcess(raw_char, timestamp_list);
}
}
string ParaformerTorch::BeamSearch(WfstDecoder* &wfst_decoder, float *in, int len, int64_t token_nums)
{
return wfst_decoder->Search(in, len, token_nums);
}
string ParaformerTorch::FinalizeDecode(WfstDecoder* &wfst_decoder,
bool is_stamp, std::vector<float> us_alphas, std::vector<float> us_cif_peak)
{
return wfst_decoder->FinalizeDecode(is_stamp, us_alphas, us_cif_peak);
}
void ParaformerTorch::LfrCmvn(std::vector<std::vector<float>> &asr_feats) {
std::vector<std::vector<float>> out_feats;
int T = asr_feats.size();
int T_lrf = ceil(1.0 * T / lfr_n);
// Pad frames at start(copy first frame)
for (int i = 0; i < (lfr_m - 1) / 2; i++) {
asr_feats.insert(asr_feats.begin(), asr_feats[0]);
}
// Merge lfr_m frames as one, lfr_n frames per window
T = T + (lfr_m - 1) / 2;
std::vector<float> p;
for (int i = 0; i < T_lrf; i++) {
if (lfr_m <= T - i * lfr_n) {
for (int j = 0; j < lfr_m; j++) {
p.insert(p.end(), asr_feats[i * lfr_n + j].begin(), asr_feats[i * lfr_n + j].end());
}
out_feats.emplace_back(p);
p.clear();
} else {
// Fill to lfr_m frames at last window if less than lfr_m frames (copy last frame)
int num_padding = lfr_m - (T - i * lfr_n);
for (int j = 0; j < (asr_feats.size() - i * lfr_n); j++) {
p.insert(p.end(), asr_feats[i * lfr_n + j].begin(), asr_feats[i * lfr_n + j].end());
}
for (int j = 0; j < num_padding; j++) {
p.insert(p.end(), asr_feats[asr_feats.size() - 1].begin(), asr_feats[asr_feats.size() - 1].end());
}
out_feats.emplace_back(p);
p.clear();
}
}
// Apply cmvn
for (auto &out_feat: out_feats) {
for (int j = 0; j < means_list_.size(); j++) {
out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
}
}
asr_feats = out_feats;
}
std::vector<std::string> ParaformerTorch::Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle, int batch_in)
{
WfstDecoder* wfst_decoder = (WfstDecoder*)decoder_handle;
int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
int32_t feature_dim = lfr_m*in_feat_dim;
std::vector<vector<float>> feats_batch;
std::vector<int32_t> paraformer_length;
int max_size = 0;
int max_frames = 0;
for(int index=0; index<batch_in; index++){
std::vector<std::vector<float>> asr_feats;
FbankKaldi(asr_sample_rate, din[index], len[index], asr_feats);
if(asr_feats.size() != 0){
LfrCmvn(asr_feats);
}
int32_t num_frames = asr_feats.size();
paraformer_length.emplace_back(num_frames);
if(max_size < asr_feats.size()*feature_dim){
max_size = asr_feats.size()*feature_dim;
max_frames = num_frames;
}
std::vector<float> flattened;
for (const auto& sub_vector : asr_feats) {
flattened.insert(flattened.end(), sub_vector.begin(), sub_vector.end());
}
feats_batch.emplace_back(flattened);
}
torch::NoGradGuard no_grad;
model_->eval();
// padding
std::vector<float> all_feats(batch_in * max_frames * feature_dim);
for(int index=0; index<batch_in; index++){
feats_batch[index].resize(max_size);
std::memcpy(&all_feats[index * max_frames * feature_dim], feats_batch[index].data(),
max_frames * feature_dim * sizeof(float));
}
torch::Tensor feats =
torch::from_blob(all_feats.data(),
{batch_in, max_frames, feature_dim}, torch::kFloat).contiguous();
torch::Tensor feat_lens = torch::from_blob(paraformer_length.data(),
{batch_in}, torch::kInt32);
// 2. forward
#ifdef USE_GPU
feats = feats.to(at::kCUDA);
feat_lens = feat_lens.to(at::kCUDA);
#endif
std::vector<torch::jit::IValue> inputs = {feats, feat_lens};
vector<std::string> results;
try {
auto outputs = model_->forward(inputs).toTuple()->elements();
torch::Tensor am_scores;
torch::Tensor valid_token_lens;
#ifdef USE_GPU
am_scores = outputs[0].toTensor().to(at::kCPU);
valid_token_lens = outputs[1].toTensor().to(at::kCPU);
#else
am_scores = outputs[0].toTensor();
valid_token_lens = outputs[1].toTensor();
#endif
// timestamp
for(int index=0; index<batch_in; index++){
string result="";
if(outputs.size() == 4){
torch::Tensor us_alphas_tensor;
torch::Tensor us_peaks_tensor;
#ifdef USE_GPU
us_alphas_tensor = outputs[2].toTensor().to(at::kCPU);
us_peaks_tensor = outputs[3].toTensor().to(at::kCPU);
#else
us_alphas_tensor = outputs[2].toTensor();
us_peaks_tensor = outputs[3].toTensor();
#endif
float* us_alphas_data = us_alphas_tensor[index].data_ptr<float>();
std::vector<float> us_alphas(paraformer_length[index]);
for (int i = 0; i < us_alphas.size(); i++) {
us_alphas[i] = us_alphas_data[i];
}
float* us_peaks_data = us_peaks_tensor[index].data_ptr<float>();
std::vector<float> us_peaks(paraformer_length[index]);
for (int i = 0; i < us_peaks.size(); i++) {
us_peaks[i] = us_peaks_data[i];
}
if (lm_ == nullptr) {
result = GreedySearch(am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2), true, us_alphas, us_peaks);
} else {
result = BeamSearch(wfst_decoder, am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
if (input_finished) {
result = FinalizeDecode(wfst_decoder, true, us_alphas, us_peaks);
}
}
}else{
if (lm_ == nullptr) {
result = GreedySearch(am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
} else {
result = BeamSearch(wfst_decoder, am_scores[index].data_ptr<float>(), valid_token_lens[index].item<int>(), am_scores.size(2));
if (input_finished) {
result = FinalizeDecode(wfst_decoder);
}
}
}
results.push_back(result);
if (wfst_decoder){
wfst_decoder->StartUtterance();
}
}
}
catch (std::exception const &e)
{
LOG(ERROR)<<e.what();
}
return results;
}
std::vector<std::vector<float>> ParaformerTorch::CompileHotwordEmbedding(std::string &hotwords) {
// TODO
std::vector<std::vector<float>> result(1, std::vector<float>(10, 0.0f));
return result;
}
Vocab* ParaformerTorch::GetVocab()
{
return vocab;
}
Vocab* ParaformerTorch::GetLmVocab()
{
return lm_vocab;
}
PhoneSet* ParaformerTorch::GetPhoneSet()
{
return phone_set_;
}
string ParaformerTorch::Rescoring()
{
LOG(ERROR) << "Not implemented!";
return "";
}
} // namespace funasr
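ParaformerTorch::Forward pads every utterance's LFR-CMVN features to the longest sequence in the batch before building a dense [batch, frames, dim] tensor for the TorchScript module. A condensed, standalone version of that padding step (assuming LibTorch is on the include path); unlike the in-place code above, the sketch clones the tensor because its source buffer is local:

```cpp
#include <torch/torch.h>
#include <algorithm>
#include <cstring>
#include <vector>

// Sketch: zero-pad per-utterance flattened features into a dense [B, T_max, D] tensor.
torch::Tensor pad_feature_batch(std::vector<std::vector<float>>& feats_batch,
                                const std::vector<int32_t>& num_frames, int feature_dim) {
    const int batch = static_cast<int>(feats_batch.size());
    int max_frames = 0;
    for (int32_t n : num_frames) max_frames = std::max<int>(max_frames, n);

    std::vector<float> all_feats(static_cast<size_t>(batch) * max_frames * feature_dim, 0.0f);
    for (int b = 0; b < batch; ++b) {
        feats_batch[b].resize(static_cast<size_t>(max_frames) * feature_dim);  // zero-pad shorter utterances
        std::memcpy(&all_feats[static_cast<size_t>(b) * max_frames * feature_dim],
                    feats_batch[b].data(),
                    static_cast<size_t>(max_frames) * feature_dim * sizeof(float));
    }
    return torch::from_blob(all_feats.data(), {batch, max_frames, feature_dim},
                            torch::kFloat).clone();  // clone: all_feats goes out of scope
}
```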

View File

@ -0,0 +1,96 @@
/**
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
* MIT License (https://opensource.org/licenses/MIT)
*/
#pragma once
#define C10_USE_GLOG
#include <torch/serialize.h>
#include <torch/script.h>
#include <torch/torch.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include "precomp.h"
#include "fst/fstlib.h"
#include "fst/symbol-table.h"
#include "bias-lm.h"
#include "phone-set.h"
namespace funasr {
class ParaformerTorch : public Model {
/**
* Author: Speech Lab of DAMO Academy, Alibaba Group
* Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
* https://arxiv.org/pdf/2206.08317.pdf
*/
private:
Vocab* vocab = nullptr;
Vocab* lm_vocab = nullptr;
SegDict* seg_dict = nullptr;
PhoneSet* phone_set_ = nullptr;
//const float scale = 22.6274169979695;
const float scale = 1.0;
void LoadConfigFromYaml(const char* filename);
void LoadCmvn(const char *filename);
void LfrCmvn(std::vector<std::vector<float>> &asr_feats);
using TorchModule = torch::jit::script::Module;
std::shared_ptr<TorchModule> model_ = nullptr;
std::vector<torch::Tensor> encoder_outs_;
bool use_hotword;
public:
ParaformerTorch();
~ParaformerTorch();
void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num);
void InitHwCompiler(const std::string &hw_model, int thread_num);
void InitSegDict(const std::string &seg_dict_model);
std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
void Reset();
void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1);
string GreedySearch( float* in, int n_len, int64_t token_nums,
bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
string Rescoring();
string GetLang(){return language;};
int GetAsrSampleRate() { return asr_sample_rate; };
void SetBatchSize(int batch_size) {batch_size_ = batch_size;};
int GetBatchSize() {return batch_size_;};
void StartUtterance();
void EndUtterance();
void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
string BeamSearch(WfstDecoder* &wfst_decoder, float* in, int n_len, int64_t token_nums);
string FinalizeDecode(WfstDecoder* &wfst_decoder,
bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
Vocab* GetVocab();
Vocab* GetLmVocab();
PhoneSet* GetPhoneSet();
knf::FbankOptions fbank_opts_;
vector<float> means_list_;
vector<float> vars_list_;
int lfr_m = PARA_LFR_M;
int lfr_n = PARA_LFR_N;
// paraformer-offline
std::string language="zh-cn";
// lm
std::shared_ptr<fst::Fst<fst::StdArc>> lm_ = nullptr;
string window_type = "hamming";
int frame_length = 25;
int frame_shift = 10;
int n_mels = 80;
int encoder_size = 512;
int fsmn_layers = 16;
int fsmn_lorder = 10;
int fsmn_dims = 512;
float cif_threshold = 1.0;
float tail_alphas = 0.45;
int asr_sample_rate = MODEL_SAMPLE_RATE;
int batch_size_ = 1;
};
} // namespace funasr

View File

@ -462,15 +462,23 @@ void Paraformer::LfrCmvn(std::vector<std::vector<float>> &asr_feats) {
asr_feats = out_feats;
}
string Paraformer::Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle)
std::vector<std::string> Paraformer::Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb, void* decoder_handle, int batch_in)
{
std::vector<std::string> results;
string result="";
WfstDecoder* wfst_decoder = (WfstDecoder*)decoder_handle;
int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
if(batch_in != 1){
results.push_back(result);
return results;
}
std::vector<std::vector<float>> asr_feats;
FbankKaldi(asr_sample_rate, din, len, asr_feats);
FbankKaldi(asr_sample_rate, din[0], len[0], asr_feats);
if(asr_feats.size() == 0){
return "";
results.push_back(result);
return results;
}
LfrCmvn(asr_feats);
int32_t feat_dim = lfr_m*in_feat_dim;
@ -509,7 +517,8 @@ string Paraformer::Forward(float* din, int len, bool input_finished, const std::
if (use_hotword) {
if(hw_emb.size()<=0){
LOG(ERROR) << "hw_emb is null";
return "";
results.push_back(result);
return results;
}
//PrintMat(hw_emb, "input_clas_emb");
const int64_t hotword_shape[3] = {1, static_cast<int64_t>(hw_emb.size()), static_cast<int64_t>(hw_emb[0].size())};
@ -526,10 +535,10 @@ string Paraformer::Forward(float* din, int len, bool input_finished, const std::
}catch (std::exception const &e)
{
LOG(ERROR)<<e.what();
return "";
results.push_back(result);
return results;
}
string result="";
try {
auto outputTensor = m_session_->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), input_onnx.size(), m_szOutputNames.data(), m_szOutputNames.size());
std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
@ -577,7 +586,8 @@ string Paraformer::Forward(float* din, int len, bool input_finished, const std::
LOG(ERROR)<<e.what();
}
return result;
results.push_back(result);
return results;
}

View File

@ -52,13 +52,14 @@ namespace funasr {
std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
void Reset();
void FbankKaldi(float sample_rate, const float* waves, int len, std::vector<std::vector<float>> &asr_feats);
string Forward(float* din, int len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr);
std::vector<std::string> Forward(float** din, int* len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1);
string GreedySearch( float* in, int n_len, int64_t token_nums,
bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
string Rescoring();
string GetLang(){return language;};
int GetAsrSampleRate() { return asr_sample_rate; };
int GetBatchSize() {return batch_size_;};
void StartUtterance();
void EndUtterance();
void InitLm(const std::string &lm_file, const std::string &lm_cfg_file, const std::string &lex_file);
@ -110,6 +111,7 @@ namespace funasr {
float cif_threshold = 1.0;
float tail_alphas = 0.45;
int asr_sample_rate = MODEL_SAMPLE_RATE;
int batch_size_ = 1;
};
} // namespace funasr

View File

@ -64,6 +64,9 @@ using namespace std;
#include "seg_dict.h"
#include "resample.h"
#include "paraformer.h"
#ifdef USE_GPU
#include "paraformer-torch.h"
#endif
#include "paraformer-online.h"
#include "offline-stream.h"
#include "tpass-stream.h"

View File

@ -70,13 +70,13 @@ ostream& operator << (ostream& os, const deque<T>& dq) {
return os;
}
#ifndef USE_GPU
template<class T1, class T2>
ostream& operator << (ostream& os, const pair<T1, T2>& pr) {
os << pr.first << ":" << pr.second ;
return os;
}
#endif
template<class T>
string& operator << (string& str, const T& obj) {

View File

@ -326,6 +326,9 @@ class ContextualParaformer(Paraformer):
def __call__(
self, wav_content: Union[str, np.ndarray, List[str]], hotwords: str, **kwargs
) -> List:
# def __call__(
# self, waveform_list:list, hotwords: str, **kwargs
# ) -> List:
# make hotword list
hotwords, hotwords_length = self.proc_hotword(hotwords)
# import pdb; pdb.set_trace()
@ -345,15 +348,47 @@ class ContextualParaformer(Paraformer):
try:
outputs = self.bb_infer(feats, feats_len, bias_embed)
am_scores, valid_token_lens = outputs[0], outputs[1]
if len(outputs) == 4:
# for BiCifParaformer Inference
us_alphas, us_peaks = outputs[2], outputs[3]
else:
us_alphas, us_peaks = None, None
except ONNXRuntimeError:
# logging.warning(traceback.format_exc())
logging.warning("input wav is silence or noise")
preds = [""]
else:
preds = self.decode(am_scores, valid_token_lens)
for pred in preds:
pred = sentence_postprocess(pred)
asr_res.append({"preds": pred})
if us_peaks is None:
for pred in preds:
if self.language == "en-bpe":
pred = sentence_postprocess_sentencepiece(pred)
else:
pred = sentence_postprocess(pred)
asr_res.append({"preds": pred})
else:
for pred, us_peaks_ in zip(preds, us_peaks):
raw_tokens = pred
timestamp, timestamp_raw = time_stamp_lfr6_onnx(
us_peaks_, copy.copy(raw_tokens)
)
text_proc, timestamp_proc, _ = sentence_postprocess(
raw_tokens, timestamp_raw
)
# logging.warning(timestamp)
if len(self.plot_timestamp_to):
self.plot_wave_timestamp(
waveform_list[0], timestamp, self.plot_timestamp_to
)
asr_res.append(
{
"preds": text_proc,
"timestamp": timestamp_proc,
"raw_tokens": raw_tokens,
}
)
return asr_res
def proc_hotword(self, hotwords):

View File

@ -8,6 +8,10 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
option(ENABLE_WEBSOCKET "Whether to build websocket server" ON)
option(ENABLE_PORTAUDIO "Whether to build portaudio" ON)
option(ENABLE_GLOG "Whether to build glog" ON)
option(ENABLE_FST "Whether to build openfst" ON) # ITN need openfst compiled
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
option(GPU "Whether to build with GPU" OFF)
if(WIN32)
file(REMOVE ${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog/src/config.h
@ -20,12 +24,16 @@ else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread -fPIC")
endif()
option(ENABLE_GLOG "Whether to build glog" ON)
option(ENABLE_FST "Whether to build openfst" ON) # ITN need openfst compiled
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
if(GPU)
add_definitions(-DUSE_GPU)
set(TORCH_DIR "/usr/local/lib/python3.8/dist-packages/torch")
set(TORCH_BLADE_DIR "/usr/local/lib/python3.8/dist-packages/torch_blade")
include_directories(${TORCH_DIR}/include)
include_directories(${TORCH_DIR}/include/torch/csrc/api/include)
link_directories(${TORCH_DIR}/lib)
link_directories(${TORCH_BLADE_DIR})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -D_GLIBCXX_USE_CXX11_ABI=0")
endif()
if(ENABLE_WEBSOCKET)
# cmake_policy(SET CMP0135 NEW)

View File

@ -1,5 +1,4 @@
if(WIN32)
include_directories(${ONNXRUNTIME_DIR}/include)
include_directories(${FFMPEG_DIR}/include)
@ -12,15 +11,14 @@ if(WIN32)
SET(RELATION_SOURCE "../../onnxruntime/src/resample.cpp" "../../onnxruntime/src/util.cpp" "../../onnxruntime/src/alignedmem.cpp" "../../onnxruntime/src/encode_converter.cpp")
endif()
add_executable(funasr-wss-server "funasr-wss-server.cpp" "websocket-server.cpp" ${RELATION_SOURCE})
add_executable(funasr-wss-server-2pass "funasr-wss-server-2pass.cpp" "websocket-server-2pass.cpp" ${RELATION_SOURCE})
add_executable(funasr-wss-client "funasr-wss-client.cpp" ${RELATION_SOURCE})
add_executable(funasr-wss-client-2pass "funasr-wss-client-2pass.cpp" "microphone.cpp" ${RELATION_SOURCE})
target_link_options(funasr-wss-server PRIVATE "-Wl,--no-as-needed")
target_link_options(funasr-wss-server-2pass PRIVATE "-Wl,--no-as-needed")
target_link_libraries(funasr-wss-client PUBLIC funasr ${OPENSSL_CRYPTO_LIBRARY} ${OPENSSL_SSL_LIBRARY})
target_link_libraries(funasr-wss-client-2pass PUBLIC funasr ${OPENSSL_CRYPTO_LIBRARY} ${OPENSSL_SSL_LIBRARY} portaudio)
target_link_libraries(funasr-wss-server PUBLIC funasr ${OPENSSL_CRYPTO_LIBRARY} ${OPENSSL_SSL_LIBRARY})

View File

@ -56,6 +56,10 @@ int main(int argc, char* argv[]) {
"true (Default), load the model of model_quant.onnx in model_dir. If set "
"false, load the model of model.onnx in model_dir",
false, "true", "string");
TCLAP::ValueArg<std::string> bladedisc(
"", BLADEDISC,
"true (Default), load the model of bladedisc in model_dir.",
false, "true", "string");
TCLAP::ValueArg<std::string> vad_dir(
"", VAD_DIR,
"default: /workspace/models/vad, the vad model path, which contains model_quant.onnx, vad.yaml, vad.mvn",
@ -121,6 +125,8 @@ int main(int argc, char* argv[]) {
false, "/workspace/resources/hotwords.txt", "string");
TCLAP::ValueArg<std::int32_t> fst_inc_wts("", FST_INC_WTS,
"the fst hotwords incremental bias", false, 20, "int32_t");
TCLAP::SwitchArg use_gpu("", INFER_GPU, "Whether to use GPU, default is false", false);
TCLAP::ValueArg<std::int32_t> batch_size("", BATCHSIZE, "batch_size for ASR model when using GPU", false, 4, "int32_t");
// add file
cmd.add(hotword);
@ -135,6 +141,7 @@ int main(int argc, char* argv[]) {
cmd.add(model_dir);
cmd.add(model_revision);
cmd.add(quantize);
cmd.add(bladedisc);
cmd.add(vad_dir);
cmd.add(vad_revision);
cmd.add(vad_quant);
@ -151,11 +158,14 @@ int main(int argc, char* argv[]) {
cmd.add(io_thread_num);
cmd.add(decoder_thread_num);
cmd.add(model_thread_num);
cmd.add(use_gpu);
cmd.add(batch_size);
cmd.parse(argc, argv);
std::map<std::string, std::string> model_path;
GetValue(model_dir, MODEL_DIR, model_path);
GetValue(quantize, QUANTIZE, model_path);
GetValue(bladedisc, BLADEDISC, model_path);
GetValue(vad_dir, VAD_DIR, model_path);
GetValue(vad_quant, VAD_QUANT, model_path);
GetValue(punc_dir, PUNC_DIR, model_path);
@ -173,6 +183,8 @@ int main(int argc, char* argv[]) {
global_beam_ = global_beam.getValue();
lattice_beam_ = lattice_beam.getValue();
am_scale_ = am_scale.getValue();
bool use_gpu_ = use_gpu.getValue();
int batch_size_ = batch_size.getValue();
// Download model form Modelscope
try{
@ -468,7 +480,7 @@ int main(int argc, char* argv[]) {
WebSocketServer websocket_srv(
io_decoder, is_ssl, server, wss_server, s_certfile,
s_keyfile); // websocket server for asr engine
websocket_srv.initAsr(model_path, s_model_thread_num); // init asr model
websocket_srv.initAsr(model_path, s_model_thread_num, use_gpu_, batch_size_); // init asr model
LOG(INFO) << "decoder-thread-num: " << s_decoder_thread_num;
LOG(INFO) << "io-thread-num: " << s_io_thread_num;

View File

@ -402,11 +402,11 @@ void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
// init asr model
void WebSocketServer::initAsr(std::map<std::string, std::string>& model_path,
int thread_num) {
int thread_num, bool use_gpu, int batch_size) {
try {
// init model with api
asr_handle = FunOfflineInit(model_path, thread_num);
asr_handle = FunOfflineInit(model_path, thread_num, use_gpu, batch_size);
LOG(INFO) << "model successfully inited";
LOG(INFO) << "initAsr run check_and_clean_connection";

View File

@ -124,7 +124,7 @@ class WebSocketServer {
std::string wav_format,
FUNASR_DEC_HANDLE& decoder_handle);
void initAsr(std::map<std::string, std::string>& model_path, int thread_num);
void initAsr(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu=false, int batch_size=1);
void on_message(websocketpp::connection_hdl hdl, message_ptr msg);
void on_open(websocketpp::connection_hdl hdl);
void on_close(websocketpp::connection_hdl hdl);

View File

@ -39,7 +39,7 @@ requirements = {
"jaconv",
"hydra-core>=1.3.2",
"tensorboardX",
"rotary_embedding_torch",
# "rotary_embedding_torch",
"openai-whisper",
],
# train: The modules invoked when training only.