Merge pull request #363 from alibaba-damo-academy/main

update with main
This commit is contained in:
zhifu gao 2023-04-16 22:29:32 +08:00 committed by GitHub
commit 937e507977
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
133 changed files with 4187 additions and 1518 deletions

View File

@ -7,12 +7,11 @@
[**News**](https://github.com/alibaba-damo-academy/FunASR#whats-new)
| [**Highlights**](#highlights)
| [**Installation**](#installation)
| [**Docs_CN**](https://alibaba-damo-academy.github.io/FunASR/cn/index.html)
| [**Docs_EN**](https://alibaba-damo-academy.github.io/FunASR/en/index.html)
| [**Tutorial**](https://github.com/alibaba-damo-academy/FunASR/wiki#funasr%E7%94%A8%E6%88%B7%E6%89%8B%E5%86%8C)
| [**Papers**](https://github.com/alibaba-damo-academy/FunASR#citations)
| [**Runtime**](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime)
| [**Model Zoo**](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
| [**Model Zoo**](https://github.com/alibaba-damo-academy/FunASR/blob/main/docs/modelscope_models.md)
| [**Contact**](#contact)
@ -29,15 +28,37 @@ For the release notes, please ref to [news](https://github.com/alibaba-damo-acad
## Installation
``` sh
pip install "modelscope[audio_asr]" --upgrade -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
git clone https://github.com/alibaba/FunASR.git && cd FunASR
pip install --editable ./
Install from pip
```shell
pip install -U funasr
# For users in China, you can install with the following command:
# pip install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
Or install from source code
``` sh
git clone https://github.com/alibaba/FunASR.git && cd FunASR
pip install -e ./
# For users in China, you can install with the following command:
# pip install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
If you want to use the pretrained models from ModelScope, you should also install modelscope:
```shell
pip install -U modelscope
# For users in China, you can install with the following command:
# pip install -U modelscope -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -i https://mirror.sjtu.edu.cn/pypi/web/simple
```
For more details, please refer to [installation](https://github.com/alibaba-damo-academy/FunASR/wiki)
## Usage
For users who are new to FunASR and ModelScope, please refer to the FunASR Docs ([CN](https://alibaba-damo-academy.github.io/FunASR/cn/index.html) / [EN](https://alibaba-damo-academy.github.io/FunASR/en/index.html)); a minimal usage example is sketched below.
[//]: # ()
[//]: # (## Usage)
[//]: # (For users who are new to FunASR and ModelScope, please refer to FunASR Docs([CN](https://alibaba-damo-academy.github.io/FunASR/cn/index.html) / [EN](https://alibaba-damo-academy.github.io/FunASR/en/index.html)))
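For a quick first run, the minimal sketch below invokes a pretrained Model Zoo model through the ModelScope pipeline; the Paraformer-large model id is taken from the Model Zoo, while the local file name `asr_example.wav` is only a placeholder (any 16 kHz wav path, URL, or wav.scp can be passed as `audio_in`).

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build an ASR inference pipeline from a pretrained Model Zoo model.
inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
)

# Run recognition on a local wav file; a URL or a wav.scp path works the same way.
result = inference_pipeline(audio_in="asr_example.wav")
print(result)
```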
## Contact

View File

@ -6,29 +6,70 @@
## Model Zoo
Here we provide several pretrained models on different datasets. The details of models and datasets can be found on [ModelScope](https://www.modelscope.cn/models?page=1&tasks=auto-speech-recognition).
| Datasets | Hours | Model | Online/Offline | Language | Framework | Checkpoint |
|:-----:|:-----:|:--------------:|:--------------:| :---: | :---: | --- |
| Alibaba Speech Data | 60000 | Paraformer | Offline | CN | Pytorch |[speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) |
| Alibaba Speech Data | 50000 | Paraformer | Offline | CN | Tensorflow |[speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) |
| Alibaba Speech Data | 50000 | Paraformer | Offline | CN | Tensorflow |[speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) |
| Alibaba Speech Data | 50000 | Paraformer | Online | CN | Tensorflow |[speech_paraformer_asr_nat-zh-cn-16k-common-vocab3444-tensorflow1-online](http://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab3444-tensorflow1-online/summary) |
| Alibaba Speech Data | 50000 | UniASR | Online | CN | Tensorflow |[speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/summary) |
| Alibaba Speech Data | 50000 | UniASR | Offline | CN | Tensorflow |[speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline](https://www.modelscope.cn/models/damo/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/summary) |
| Alibaba Speech Data | 50000 | UniASR | Online | CN&EN | Tensorflow |[speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-online/summary) |
| Alibaba Speech Data | 50000 | UniASR | Offline | CN&EN | Tensorflow |[speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-offline/summary) |
| Alibaba Speech Data | 20000 | UniASR | Online | CN-Accent | Tensorflow |[speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online/summary) |
| Alibaba Speech Data | 20000 | UniASR | Offline | CN-Accent | Tensorflow |[speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline/summary) |
| Alibaba Speech Data | 30000 | Paraformer-8K | Online | CN | Tensorflow |[speech_paraformer_asr_nat-zh-cn-8k-common-vocab3444-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-8k-common-vocab3444-tensorflow1-online/summary) |
| Alibaba Speech Data | 30000 | Paraformer-8K | Offline | CN | Tensorflow |[speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1/summary) |
| Alibaba Speech Data | 30000 | Paraformer-8K | Online | CN | Pytorch |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/summary) |
| Alibaba Speech Data | 30000 | Paraformer-8K | Offline | CN | Pytorch |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/summary) |
| Alibaba Speech Data | 30000 | UniASR-8K | Online | CN | Tensorflow |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online/summary) |
| Alibaba Speech Data | 30000 | UniASR-8K | Offline | CN | Tensorflow |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline/summary) |
| Alibaba Speech Data | 30000 | UniASR-8K | Online | CN | Pytorch |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/summary) |
| Alibaba Speech Data | 30000 | UniASR-8K | Offline | CN | Pytorch |[speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/summary) |
| AISHELL-1 | 178 | Paraformer | Offline | CN | Pytorch | [speech_paraformer_asr_nat-aishell1-pytorch](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell1-pytorch/summary) |
| AISHELL-2 | 1000 | Paraformer | Offline | CN | Pytorch | [speech_paraformer_asr_nat-aishell2-pytorch](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell2-pytorch/summary) |
| AISHELL-1 | 178 | ParaformerBert | Offline | CN | Pytorch | [speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch](https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) |
| AISHELL-2 | 1000 | ParaformerBert | Offline | CN | Pytorch | [speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch](https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) |
| AISHELL-1 | 178 | Conformer | Offline | CN | Pytorch | [speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch](https://modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) |
| AISHELL-2 | 1000 | Conformer | Offline | CN | Pytorch | [speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch](https://modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) |
### Speech Recognition Models
#### Paraformer Models
| Model Name | Language | Training Data | Vocab Size | Parameter | Offline/Online | Notes |
|:--------------------------------------------------------------------------------------------------------------------------------------------------:|:--------:|:--------------------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
| [Paraformer-large](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8404 | 220M | Offline | Duration of input wav <= 20s |
| [Paraformer-large-long](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8404 | 220M | Offline | Which could deal with input wav of arbitrary length |
| [paraformer-large-contextual](https://www.modelscope.cn/models/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8404 | 220M | Offline | Which supports hotword customization based on incentive enhancement, improving the recall and precision of hotwords. |
| [Paraformer](https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) | CN & EN | Alibaba Speech Data (50000hours) | 8358 | 68M | Offline | Duration of input wav <= 20s |
| [Paraformer-online](https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) | CN & EN | Alibaba Speech Data (50000hours) | 8404 | 68M | Online | Which could deal with streaming input |
| [Paraformer-tiny](https://www.modelscope.cn/models/damo/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch/summary) | CN | Alibaba Speech Data (200hours) | 544 | 5.2M | Offline | Lightweight Paraformer model which supports Mandarin command-word recognition |
| [Paraformer-aishell](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell1-pytorch/summary) | CN | AISHELL (178hours) | 4234 | 43M | Offline | |
| [ParaformerBert-aishell](https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) | CN | AISHELL (178hours) | 4234 | 43M | Offline | |
| [Paraformer-aishell2](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) | CN | AISHELL-2 (1000hours) | 5212 | 64M | Offline | |
| [ParaformerBert-aishell2](https://www.modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) | CN | AISHELL-2 (1000hours) | 5212 | 64M | Offline | |
#### UniASR Models
| Model Name | Language | Training Data | Vocab Size | Parameter | Offline/Online | Notes |
|:--------------------------------------------------------------------------------------------------------------------------------------:|:--------:|:--------------------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
| [UniASR](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8358 | 100M | Online | UniASR streaming offline unifying models |
| [UniASR-large](https://modelscope.cn/models/damo/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/summary) | CN & EN | Alibaba Speech Data (60000hours) | 8358 | 220M | Offline | UniASR streaming offline unifying models |
| [UniASR Burmese](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-my-16k-common-vocab696-pytorch/summary) | Burmese | Alibaba Speech Data (? hours) | 696 | 95M | Online | UniASR streaming offline unifying models |
| [UniASR Hebrew](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-he-16k-common-vocab1085-pytorch/summary) | Hebrew | Alibaba Speech Data (? hours) | 1085 | 95M | Online | UniASR streaming offline unifying models |
| [UniASR Urdu](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-ur-16k-common-vocab877-pytorch/summary) | Urdu | Alibaba Speech Data (? hours) | 877 | 95M | Online | UniASR streaming offline unifying models |
#### Conformer Models
| Model Name | Language | Training Data | Vocab Size | Parameter | Offline/Online | Notes |
|:----------------------------------------------------------------------------------------------------------------------:|:--------:|:---------------------:|:----------:|:---------:|:--------------:|:--------------------------------------------------------------------------------------------------------------------------------|
| [Conformer](https://modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) | CN | AISHELL (178hours) | 4234 | 44M | Offline | Duration of input wav <= 20s |
| [Conformer](https://www.modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) | CN | AISHELL-2 (1000hours) | 5212 | 44M | Offline | Duration of input wav <= 20s |
#### RNN-T Models
### Voice Activity Detection Models
| Model Name | Training Data | Parameters | Sampling Rate | Notes |
|:----------------------------------------------------------------------------------------------:|:----------------------------:|:----------:|:-------------:|:------|
| [FSMN-VAD](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) | Alibaba Speech Data (5000hours) | 0.4M | 16000 | |
| [FSMN-VAD](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-8k-common/summary) | Alibaba Speech Data (5000hours) | 0.4M | 8000 | |
### Punctuation Restoration Models
| Model Name | Training Data | Parameters | Vocab Size| Offline/Online | Notes |
|:--------------------------------------------------------------------------------------------------------------------------:|:----------------------------:|:----------:|:----------:|:--------------:|:------|
| [CT-Transformer](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary) | Alibaba Text Data | 70M | 272727 | Offline | offline punctuation model |
| [CT-Transformer](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727/summary) | Alibaba Text Data | 70M | 272727 | Online | online punctuation model |
### Language Models
| Model Name | Training Data | Parameters | Vocab Size | Notes |
|:----------------------------------------------------------------------------------------------------------------------:|:----------------------------:|:----------:|:----------:|:------|
| [Transformer](https://www.modelscope.cn/models/damo/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/summary) | Alibaba Speech Data (?hours) | 57M | 8404 | |
### Speaker Verification Models
| Model Name | Training Data | Parameters | Vocab Size | Notes |
|:-------------------------------------------------------------------------------------------------------------:|:-----------------:|:----------:|:----------:|:------|
| [Xvector](https://www.modelscope.cn/models/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/summary) | CNCeleb (?hours) | 17.5M | 3465 | |
| [Xvector](https://www.modelscope.cn/models/damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/summary) | CallHome (?hours) | 61M | 6135 | |
### Speaker diarization Models
| Model Name | Training Data | Parameters | Notes |
|:----------------------------------------------------------------------------------------------------------------:|:-------------------:|:----------:|:------|
| [SOND](https://www.modelscope.cn/models/damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/summary) | AliMeeting (?hours) | 40.5M | |
| [SOND](https://www.modelscope.cn/models/damo/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch/summary) | CallHome (?hours) | 12M | |

View File

@ -45,8 +45,8 @@ def compute_wer(ref_file,
if out_item['wrong'] > 0:
rst['wrong_sentences'] += 1
cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
cer_detail_writer.write("ref:" + '\t' + "".join(ref_dict[hyp_key]) + '\n')
cer_detail_writer.write("hyp:" + '\t' + "".join(hyp_dict[hyp_key]) + '\n')
cer_detail_writer.write("ref:" + '\t' + " ".join(list(map(lambda x: x.lower(), ref_dict[hyp_key]))) + '\n')
cer_detail_writer.write("hyp:" + '\t' + " ".join(list(map(lambda x: x.lower(), hyp_dict[hyp_key]))) + '\n')
if rst['Wrd'] > 0:
rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
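The rewritten lines log the reference and hypothesis as lower-cased, space-separated tokens instead of a raw concatenation. A small sketch of the difference, with a hypothetical token list:

```python
tokens = ["Hello", "WORLD", "你好"]

old_line = "".join(tokens)                              # 'HelloWORLD你好'
new_line = " ".join(map(lambda x: x.lower(), tokens))   # 'hello world 你好'
```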

View File

@ -74,7 +74,7 @@ def modelscope_infer(params):
# If text exists, compute CER
text_in = os.path.join(params["data_dir"], "text")
if os.path.exists(text_in):
text_proc_file = os.path.join(best_recog_path, "token")
text_proc_file = os.path.join(best_recog_path, "text")
compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))

View File

@ -38,7 +38,7 @@ def modelscope_infer_after_finetune(params):
# compute CER if GT text is set
text_in = os.path.join(params["data_dir"], "text")
if os.path.exists(text_in):
text_proc_file = os.path.join(decoding_path, "1best_recog/token")
text_proc_file = os.path.join(decoding_path, "1best_recog/text")
compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))

View File

@ -74,7 +74,7 @@ def modelscope_infer(params):
# If text exists, compute CER
text_in = os.path.join(params["data_dir"], "text")
if os.path.exists(text_in):
text_proc_file = os.path.join(best_recog_path, "token")
text_proc_file = os.path.join(best_recog_path, "text")
compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))

View File

@ -38,7 +38,7 @@ def modelscope_infer_after_finetune(params):
# compute CER if GT text is set
text_in = os.path.join(params["data_dir"], "text")
if os.path.exists(text_in):
text_proc_file = os.path.join(decoding_path, "1best_recog/token")
text_proc_file = os.path.join(decoding_path, "1best_recog/text")
compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))

View File

@ -63,8 +63,8 @@ fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ];then
echo "Computing WER ..."
python utils/proce_text.py ${output_dir}/1best_recog/text ${output_dir}/1best_recog/text.proc
python utils/proce_text.py ${data_dir}/text ${output_dir}/1best_recog/text.ref
cp ${output_dir}/1best_recog/text ${output_dir}/1best_recog/text.proc
cp ${data_dir}/text ${output_dir}/1best_recog/text.ref
python utils/compute_wer.py ${output_dir}/1best_recog/text.ref ${output_dir}/1best_recog/text.proc ${output_dir}/1best_recog/text.cer
tail -n 3 ${output_dir}/1best_recog/text.cer
fi

View File

@ -34,7 +34,7 @@ def modelscope_infer_after_finetune(params):
# compute CER if GT text is set
text_in = os.path.join(params["data_dir"], "text")
if os.path.exists(text_in):
text_proc_file = os.path.join(decoding_path, "1best_recog/token")
text_proc_file = os.path.join(decoding_path, "1best_recog/text")
compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))

View File

@ -63,8 +63,8 @@ fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ];then
echo "Computing WER ..."
python utils/proce_text.py ${output_dir}/1best_recog/text ${output_dir}/1best_recog/text.proc
python utils/proce_text.py ${data_dir}/text ${output_dir}/1best_recog/text.ref
cp ${output_dir}/1best_recog/text ${output_dir}/1best_recog/text.proc
cp ${data_dir}/text ${output_dir}/1best_recog/text.ref
python utils/compute_wer.py ${output_dir}/1best_recog/text.ref ${output_dir}/1best_recog/text.proc ${output_dir}/1best_recog/text.cer
tail -n 3 ${output_dir}/1best_recog/text.cer
fi

View File

@ -34,7 +34,7 @@ def modelscope_infer_after_finetune(params):
# compute CER if GT text is set
text_in = os.path.join(params["data_dir"], "text")
if os.path.exists(text_in):
text_proc_file = os.path.join(decoding_path, "1best_recog/token")
text_proc_file = os.path.join(decoding_path, "1best_recog/text")
compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))

View File

@ -23,8 +23,7 @@ def modelscope_infer_core(output_dir, split_dir, njob, idx):
batch_size=1
)
audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
inference_pipline(audio_in=audio_in, param_dict={"decoding_model": "offline"})
inference_pipline(audio_in=audio_in)
def modelscope_infer(params):
# prepare for multi-GPU decoding
@ -75,7 +74,7 @@ def modelscope_infer(params):
# If text exists, compute CER
text_in = os.path.join(params["data_dir"], "text")
if os.path.exists(text_in):
text_proc_file = os.path.join(best_recog_path, "token")
text_proc_file = os.path.join(best_recog_path, "text")
compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))

View File

@ -2,52 +2,103 @@ import json
import os
import shutil
from multiprocessing import Pool
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from funasr.utils.compute_wer import compute_wer
def modelscope_infer_after_finetune_core(model_dir, output_dir, split_dir, njob, idx):
output_dir_job = os.path.join(output_dir, "output.{}".format(idx))
gpu_id = (int(idx) - 1) // njob
if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[gpu_id])
else:
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model=model_dir,
output_dir=output_dir_job,
batch_size=1
)
audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
inference_pipeline(audio_in=audio_in)
def modelscope_infer_after_finetune(params):
# prepare for decoding
# prepare for multi-GPU decoding
model_dir = params["model_dir"]
pretrained_model_path = os.path.join(os.environ["HOME"], ".cache/modelscope/hub", params["modelscope_model_name"])
for file_name in params["required_files"]:
if file_name == "configuration.json":
with open(os.path.join(pretrained_model_path, file_name)) as f:
config_dict = json.load(f)
config_dict["model"]["am_model_name"] = params["decoding_model_name"]
with open(os.path.join(params["output_dir"], "configuration.json"), "w") as f:
with open(os.path.join(model_dir, "configuration.json"), "w") as f:
json.dump(config_dict, f, indent=4, separators=(',', ': '))
else:
shutil.copy(os.path.join(pretrained_model_path, file_name),
os.path.join(params["output_dir"], file_name))
decoding_path = os.path.join(params["output_dir"], "decode_results")
if os.path.exists(decoding_path):
shutil.rmtree(decoding_path)
os.mkdir(decoding_path)
os.path.join(model_dir, file_name))
ngpu = params["ngpu"]
njob = params["njob"]
output_dir = params["output_dir"]
if os.path.exists(output_dir):
shutil.rmtree(output_dir)
os.mkdir(output_dir)
split_dir = os.path.join(output_dir, "split")
os.mkdir(split_dir)
nj = ngpu * njob
wav_scp_file = os.path.join(params["data_dir"], "wav.scp")
with open(wav_scp_file) as f:
lines = f.readlines()
num_lines = len(lines)
num_job_lines = num_lines // nj
start = 0
for i in range(nj):
end = start + num_job_lines
file = os.path.join(split_dir, "wav.{}.scp".format(str(i + 1)))
with open(file, "w") as f:
if i == nj - 1:
f.writelines(lines[start:])
else:
f.writelines(lines[start:end])
start = end
# decoding
inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model=params["output_dir"],
output_dir=decoding_path,
batch_size=1
)
audio_in = os.path.join(params["data_dir"], "wav.scp")
inference_pipeline(audio_in=audio_in, param_dict={"decoding_model": "offline"})
p = Pool(nj)
for i in range(nj):
p.apply_async(modelscope_infer_after_finetune_core,
args=(model_dir, output_dir, split_dir, njob, str(i + 1)))
p.close()
p.join()
# computer CER if GT text is set
# combine decoding results
best_recog_path = os.path.join(output_dir, "1best_recog")
os.mkdir(best_recog_path)
files = ["text", "token", "score"]
for file in files:
with open(os.path.join(best_recog_path, file), "w") as f:
for i in range(nj):
job_file = os.path.join(output_dir, "output.{}/1best_recog".format(str(i + 1)), file)
with open(job_file) as f_job:
lines = f_job.readlines()
f.writelines(lines)
# If text exists, compute CER
text_in = os.path.join(params["data_dir"], "text")
if os.path.exists(text_in):
text_proc_file = os.path.join(decoding_path, "1best_recog/token")
compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
text_proc_file = os.path.join(best_recog_path, "token")
compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))
if __name__ == '__main__':
params = {}
params["modelscope_model_name"] = "damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline"
params["required_files"] = ["am.mvn", "decoding.yaml", "configuration.json"]
params["output_dir"] = "./checkpoint"
params["model_dir"] = "./checkpoint"
params["output_dir"] = "./results"
params["data_dir"] = "./data/test"
params["decoding_model_name"] = "20epoch.pb"
params["ngpu"] = 1
params["njob"] = 1
modelscope_infer_after_finetune(params)

View File

@ -75,7 +75,7 @@ def modelscope_infer(params):
# If text exists, compute CER
text_in = os.path.join(params["data_dir"], "text")
if os.path.exists(text_in):
text_proc_file = os.path.join(best_recog_path, "token")
text_proc_file = os.path.join(best_recog_path, "text")
compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))

View File

@ -39,7 +39,7 @@ def modelscope_infer_after_finetune(params):
# compute CER if GT text is set
text_in = os.path.join(params["data_dir"], "text")
if os.path.exists(text_in):
text_proc_file = os.path.join(decoding_path, "1best_recog/token")
text_proc_file = os.path.join(decoding_path, "1best_recog/text")
compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))

View File

@ -1,3 +1,9 @@
"""
Author: Speech Lab, Alibaba Group, China
TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
https://arxiv.org/abs/2303.05397
"""
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

View File

@ -0,0 +1,32 @@
"""
Author: Speech Lab, Alibaba Group, China
SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
https://arxiv.org/abs/2211.10243
"""
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
# Initialize the inference pipeline
# When raw audio is used as input, use the config file sond.yaml and set mode to sond_demo
inference_diar_pipline = pipeline(
mode="sond_demo",
num_workers=0,
task=Tasks.speaker_diarization,
diar_model_config="sond.yaml",
model='damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch',
sv_model="damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch",
sv_model_revision="master",
)
# Use audio_list as input: the first audio is the speech to be analyzed, and the following audios are the voiceprint enrollment utterances of the individual speakers
audio_list = [
"https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/record.wav",
"https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk1.wav",
"https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk2.wav",
"https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk3.wav",
"https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk4.wav",
]
results = inference_diar_pipline(audio_in=audio_list)
print(results)

View File

@ -2,6 +2,9 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import torch
torch.set_num_threads(1)
import argparse
import logging
import os
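Calling `torch.set_num_threads(1)` right after the import caps PyTorch's intra-op thread pool for the whole process, which typically matters when many decoding jobs run in parallel, as in the multi-process decoding added elsewhere in this commit. A minimal check:

```python
import torch

torch.set_num_threads(1)        # limit PyTorch intra-op threads for this process
print(torch.get_num_threads())  # -> 1
```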

View File

@ -797,7 +797,7 @@ def inference_modelscope(
finish_count += 1
# asr_utils.print_progress(finish_count / file_count)
if writer is not None:
ibest_writer["text"][key] = text_postprocessed
ibest_writer["text"][key] = " ".join(word_lists)
logging.info("decoding, utt: {}, predictions: {}".format(key, text))
rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor))

View File

@ -19,6 +19,7 @@ from typing import List
import numpy as np
import torch
import torchaudio
from typeguard import check_argument_types
from funasr.fileio.datadir_writer import DatadirWriter
@ -607,17 +608,21 @@ def inference_modelscope(
):
# 3. Build data-iterator
if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
raw_inputs = _load_bytes(data_path_and_name_and_type[0])
raw_inputs = torch.tensor(raw_inputs)
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, np.ndarray):
raw_inputs = torch.tensor(raw_inputs)
is_final = False
if param_dict is not None and "cache" in param_dict:
cache = param_dict["cache"]
if param_dict is not None and "is_final" in param_dict:
is_final = param_dict["is_final"]
if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
raw_inputs = _load_bytes(data_path_and_name_and_type[0])
raw_inputs = torch.tensor(raw_inputs)
if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
is_final = True
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, np.ndarray):
raw_inputs = torch.tensor(raw_inputs)
# 7 .Start for-loop
# FIXME(kamo): The output format should be discussed about
asr_result_list = []
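The reordered branches above let raw numpy chunks coexist with the streaming `cache` and `is_final` flags. A minimal sketch of driving such a streaming call from numpy chunks (hypothetical chunking; `inference_pipeline` is assumed to be a ModelScope pipeline that forwards `audio_in` and `param_dict` as in the other snippets of this commit):

```python
import numpy as np

audio = np.random.randn(16000 * 3).astype(np.float32)  # hypothetical 3 s of 16 kHz audio
chunk = len(audio) // 5
param_dict = {"cache": dict()}                          # cache dict reused across calls
for i in range(5):
    param_dict["is_final"] = (i == 4)                   # raise is_final on the last chunk
    result = inference_pipeline(
        audio_in=audio[i * chunk:(i + 1) * chunk],
        param_dict=param_dict,
    )
    print(result)
```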

View File

@ -338,7 +338,7 @@ def inference_modelscope(
ibest_writer["token"][key] = " ".join(token)
ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["vad"][key] = "{}".format(vadsegments)
ibest_writer["text"][key] = text_postprocessed
ibest_writer["text"][key] = " ".join(word_lists)
ibest_writer["text_with_punc"][key] = text_postprocessed_punc
if time_stamp_postprocessed is not None:
ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)

View File

@ -670,7 +670,7 @@ def inference_modelscope(
ibest_writer["token"][key] = " ".join(token)
ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["vad"][key] = "{}".format(vadsegments)
ibest_writer["text"][key] = text_postprocessed
ibest_writer["text"][key] = " ".join(word_lists)
ibest_writer["text_with_punc"][key] = text_postprocessed_punc
if time_stamp_postprocessed is not None:
ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)

View File

@ -738,13 +738,13 @@ def inference_modelscope(
ibest_writer["rtf"][key] = rtf_cur
if text is not None:
text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
# asr_utils.print_progress(finish_count / file_count)
if writer is not None:
ibest_writer["text"][key] = text_postprocessed
ibest_writer["text"][key] = " ".join(word_lists)
logging.info("decoding, utt: {}, predictions: {}".format(key, text))
rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor))

View File

@ -37,9 +37,6 @@ from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
header_colors = '\033[95m'
end_colors = '\033[0m'
class Speech2Text:
"""Speech2Text class
@ -507,13 +504,13 @@ def inference_modelscope(
ibest_writer["score"][key] = str(hyp.score)
if text is not None:
text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
asr_utils.print_progress(finish_count / file_count)
if writer is not None:
ibest_writer["text"][key] = text_postprocessed
ibest_writer["text"][key] = " ".join(word_lists)
return asr_result_list
return _forward

View File

@ -507,13 +507,13 @@ def inference_modelscope(
ibest_writer["score"][key] = str(hyp.score)
if text is not None:
text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
asr_utils.print_progress(finish_count / file_count)
if writer is not None:
ibest_writer["text"][key] = text_postprocessed
ibest_writer["text"][key] = " ".join(word_lists)
return asr_result_list
return _forward

View File

@ -2,6 +2,9 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import torch
torch.set_num_threads(1)
import argparse
import logging
import os

View File

@ -2,6 +2,9 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import torch
torch.set_num_threads(1)
import argparse
import logging
import os

View File

@ -2,6 +2,9 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import torch
torch.set_num_threads(1)
import argparse
import logging
import os

View File

@ -1,44 +0,0 @@
#!/usr/bin/env python3
import os
from funasr.tasks.punctuation import PunctuationTask
def parse_args():
parser = PunctuationTask.get_parser()
parser.add_argument(
"--gpu_id",
type=int,
default=0,
help="local gpu id.",
)
parser.add_argument(
"--punc_list",
type=str,
default=None,
help="Punctuation list",
)
args = parser.parse_args()
return args
def main(args=None, cmd=None):
"""
punc training.
"""
PunctuationTask.main(args=args, cmd=cmd)
if __name__ == "__main__":
args = parse_args()
# setup local gpu_id
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
# DDP settings
if args.ngpu > 1:
args.distributed = True
else:
args.distributed = False
assert args.num_worker_count == 1
main(args=args)

View File

@ -23,7 +23,7 @@ from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.punctuation.text_preprocessor import split_to_mini_sentence
from funasr.datasets.preprocessor import split_to_mini_sentence
class Text2Punc:

View File

@ -23,7 +23,7 @@ from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.punctuation.text_preprocessor import split_to_mini_sentence
from funasr.datasets.preprocessor import split_to_mini_sentence
class Text2Punc:
@ -90,7 +90,7 @@ class Text2Punc:
data = {
"text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
"text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
"vad_indexes": torch.from_numpy(np.array([len(cache)-1], dtype='int32')),
"vad_indexes": torch.from_numpy(np.array([len(cache)], dtype='int32')),
}
data = to_device(data, self.device)
y, _ = self.wrapped_model(**data)

View File

@ -2,6 +2,9 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import torch
torch.set_num_threads(1)
import argparse
import logging
import os

View File

@ -2,6 +2,9 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import torch
torch.set_num_threads(1)
import argparse
import logging
import os

View File

@ -2,6 +2,9 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import torch
torch.set_num_threads(1)
import argparse
import logging
import os

View File

@ -115,7 +115,7 @@ def sound_loader(path, dest_sample_rate=16000, float_dtype=None):
# NOTE(kamo): SoundScpReader doesn't support pipe-fashion
# like Kaldi e.g. "cat a.wav |".
# NOTE(kamo): The audio signal is normalized to [-1,1] range.
loader = SoundScpReader(path, dest_sample_rate, normalize=True, always_2d=False)
loader = SoundScpReader(path, normalize=True, always_2d=False, dest_sample_rate=dest_sample_rate)
# SoundScpReader.__getitem__() returns Tuple[int, ndarray],
# but ndarray is desired, so Adapter class is inserted here

View File

@ -47,8 +47,8 @@ def tokenize(data,
length = len(text)
for i in range(length):
x = text[i]
if i == length-1 and "punc" in data and text[i].startswith("vad:"):
vad = x[-1][4:]
if i == length-1 and "punc" in data and x.startswith("vad:"):
vad = x[4:]
if len(vad) == 0:
vad = -1
else:
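The corrected branch reads the vad value from the token itself rather than from its last character. A small sketch with a hypothetical token:

```python
x = "vad:35"    # hypothetical final token of a punctuation training sample

# old code sliced the last character, x[-1][4:], which always yields ''
vad = x[4:]     # new code strips the "vad:" prefix from the whole token
if len(vad) == 0:
    vad = -1
print(vad)      # '35'
```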

View File

@ -786,6 +786,7 @@ class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor):
) -> Dict[str, np.ndarray]:
for i in range(self.num_tokenizer):
text_name = self.text_name[i]
#import pdb; pdb.set_trace()
if text_name in data and self.tokenizer[i] is not None:
text = data[text_name]
text = self.text_cleaner(text)
@ -800,3 +801,17 @@ class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor):
data[self.vad_name] = np.array([vad], dtype=np.int64)
text_ints = self.token_id_converter[i].tokens2ids(tokens)
data[text_name] = np.array(text_ints, dtype=np.int64)
return data
def split_to_mini_sentence(words: list, word_limit: int = 20):
assert word_limit > 1
if len(words) <= word_limit:
return [words]
sentences = []
length = len(words)
sentence_len = length // word_limit
for i in range(sentence_len):
sentences.append(words[i * word_limit:(i + 1) * word_limit])
if length % word_limit > 0:
sentences.append(words[sentence_len * word_limit:])
return sentences
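A short usage sketch for the `split_to_mini_sentence` helper added above:

```python
words = [f"w{i}" for i in range(25)]
chunks = split_to_mini_sentence(words, word_limit=10)
print([len(c) for c in chunks])   # -> [10, 10, 5]
```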

View File

@ -7,7 +7,7 @@
## Install modelscope and funasr
The installation is the same as [funasr](../../README.md)
The installation is the same as [funasr](https://github.com/alibaba-damo-academy/FunASR/blob/main/README.md#installation)
## Export model
`Tips`: torch>=1.11.0

View File

@ -167,31 +167,57 @@ class ModelExport:
def export(self,
tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
mode: str = 'paraformer',
mode: str = None,
):
model_dir = tag_name
if model_dir.startswith('damo/'):
if model_dir.startswith('damo'):
from modelscope.hub.snapshot_download import snapshot_download
model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir)
asr_train_config = os.path.join(model_dir, 'config.yaml')
asr_model_file = os.path.join(model_dir, 'model.pb')
cmvn_file = os.path.join(model_dir, 'am.mvn')
json_file = os.path.join(model_dir, 'configuration.json')
if mode is None:
import json
json_file = os.path.join(model_dir, 'configuration.json')
with open(json_file, 'r') as f:
config_data = json.load(f)
mode = config_data['model']['model_config']['mode']
if config_data['task'] == "punctuation":
mode = config_data['model']['punc_model_config']['mode']
else:
mode = config_data['model']['model_config']['mode']
if mode.startswith('paraformer'):
from funasr.tasks.asr import ASRTaskParaformer as ASRTask
elif mode.startswith('uniasr'):
from funasr.tasks.asr import ASRTaskUniASR as ASRTask
config = os.path.join(model_dir, 'config.yaml')
model_file = os.path.join(model_dir, 'model.pb')
cmvn_file = os.path.join(model_dir, 'am.mvn')
model, asr_train_args = ASRTask.build_model_from_file(
config, model_file, cmvn_file, 'cpu'
)
self.frontend = model.frontend
elif mode.startswith('offline'):
from funasr.tasks.vad import VADTask
config = os.path.join(model_dir, 'vad.yaml')
model_file = os.path.join(model_dir, 'vad.pb')
cmvn_file = os.path.join(model_dir, 'vad.mvn')
model, asr_train_args = ASRTask.build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, 'cpu'
)
self.frontend = model.frontend
model, vad_infer_args = VADTask.build_model_from_file(
config, model_file, cmvn_file=cmvn_file, device='cpu'
)
self.export_config["feats_dim"] = 400
self.frontend = model.frontend
elif mode.startswith('punc'):
from funasr.tasks.punctuation import PunctuationTask as PUNCTask
punc_train_config = os.path.join(model_dir, 'config.yaml')
punc_model_file = os.path.join(model_dir, 'punc.pb')
model, punc_train_args = PUNCTask.build_model_from_file(
punc_train_config, punc_model_file, 'cpu'
)
elif mode.startswith('punc_VadRealtime'):
from funasr.tasks.punctuation import PunctuationTask as PUNCTask
punc_train_config = os.path.join(model_dir, 'config.yaml')
punc_model_file = os.path.join(model_dir, 'punc.pb')
model, punc_train_args = PUNCTask.build_model_from_file(
punc_train_config, punc_model_file, 'cpu'
)
self._export(model, tag_name)

View File

@ -0,0 +1,162 @@
from typing import Tuple
import torch
import torch.nn as nn
from funasr.models.encoder.sanm_encoder import SANMEncoder
from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
from funasr.models.encoder.sanm_encoder import SANMVadEncoder
from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export
class CT_Transformer(nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
https://arxiv.org/pdf/2003.01309.pdf
"""
def __init__(
self,
model,
max_seq_len=512,
model_name='punc_model',
**kwargs,
):
super().__init__()
onnx = False
if "onnx" in kwargs:
onnx = kwargs["onnx"]
self.embed = model.embed
self.decoder = model.decoder
# self.model = model
self.feats_dim = self.embed.embedding_dim
self.num_embeddings = self.embed.num_embeddings
self.model_name = model_name
if isinstance(model.encoder, SANMEncoder):
self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
else:
assert False, "Only support SANM encoder."
def forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
"""Compute punctuation logits for a padded token sequence.
Args:
inputs (torch.Tensor): Input token ids. (batch, len)
text_lengths (torch.Tensor): Valid length of each sequence. (batch,)
"""
x = self.embed(inputs)
# mask = self._target_mask(input)
h, _ = self.encoder(x, text_lengths)
y = self.decoder(h)
return y
def get_dummy_inputs(self):
length = 120
text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
return (text_indexes, text_lengths)
def get_input_names(self):
return ['inputs', 'text_lengths']
def get_output_names(self):
return ['logits']
def get_dynamic_axes(self):
return {
'inputs': {
0: 'batch_size',
1: 'feats_length'
},
'text_lengths': {
0: 'batch_size',
},
'logits': {
0: 'batch_size',
1: 'logits_length'
},
}
class CT_Transformer_VadRealtime(nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
https://arxiv.org/pdf/2003.01309.pdf
"""
def __init__(
self,
model,
max_seq_len=512,
model_name='punc_model',
**kwargs,
):
super().__init__()
onnx = False
if "onnx" in kwargs:
onnx = kwargs["onnx"]
self.embed = model.embed
if isinstance(model.encoder, SANMVadEncoder):
self.encoder = SANMVadEncoder_export(model.encoder, onnx=onnx)
else:
assert False, "Only support SANM encoder."
self.decoder = model.decoder
self.model_name = model_name
def forward(self, inputs: torch.Tensor,
text_lengths: torch.Tensor,
vad_indexes: torch.Tensor,
sub_masks: torch.Tensor,
) -> Tuple[torch.Tensor, None]:
"""Compute punctuation logits for a padded token sequence with VAD and subsequent masks.
Args:
inputs (torch.Tensor): Input token ids. (batch, len)
text_lengths (torch.Tensor): Valid length of each sequence. (batch,)
vad_indexes (torch.Tensor): VAD mask over the sequence. (batch, 1, len, len)
sub_masks (torch.Tensor): Subsequent (lower-triangular) mask. (batch, 1, len, len)
"""
x = self.embed(inputs)
# mask = self._target_mask(input)
h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks)
y = self.decoder(h)
return y
def with_vad(self):
return True
def get_dummy_inputs(self):
length = 120
text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length))
text_lengths = torch.tensor([length], dtype=torch.int32)
vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
sub_masks = torch.ones(length, length, dtype=torch.float32)
sub_masks = torch.tril(sub_masks).type(torch.float32)
return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
def get_input_names(self):
return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks']
def get_output_names(self):
return ['logits']
def get_dynamic_axes(self):
return {
'inputs': {
1: 'feats_length'
},
'vad_masks': {
2: 'feats_length1',
3: 'feats_length2'
},
'sub_masks': {
2: 'feats_length1',
3: 'feats_length2'
},
'logits': {
1: 'logits_length'
},
}
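The `get_dummy_inputs` / `get_input_names` / `get_output_names` / `get_dynamic_axes` hooks carry exactly the information an ONNX export needs. A minimal sketch of how they could feed `torch.onnx.export`, assuming `export_model` is one of the wrapped models defined in this file:

```python
import torch

dummy_inputs = export_model.get_dummy_inputs()
torch.onnx.export(
    export_model,                                   # wrapped CT_Transformer / CT_Transformer_VadRealtime
    dummy_inputs,
    f"{export_model.model_name}.onnx",
    input_names=export_model.get_input_names(),
    output_names=export_model.get_output_names(),
    dynamic_axes=export_model.get_dynamic_axes(),
    opset_version=13,
)
```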

View File

@ -1,13 +1,25 @@
from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.e2e_vad import E2EVadModel
from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
from funasr.models.target_delay_transformer import TargetDelayTransformer
from funasr.export.models.CT_Transformer import CT_Transformer as CT_Transformer_export
from funasr.train.abs_model import PunctuationModel
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
from funasr.export.models.CT_Transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export
def get_model(model, export_config=None):
if isinstance(model, BiCifParaformer):
return BiCifParaformer_export(model, **export_config)
elif isinstance(model, Paraformer):
return Paraformer_export(model, **export_config)
elif isinstance(model, E2EVadModel):
return E2EVadModel_export(model, **export_config)
elif isinstance(model, PunctuationModel):
if isinstance(model.punc_model, TargetDelayTransformer):
return CT_Transformer_export(model.punc_model, **export_config)
elif isinstance(model.punc_model, VadRealtimeTransformer):
return CT_Transformer_VadRealtime_export(model.punc_model, **export_config)
else:
raise TypeError("Funasr does not support the given model type currently.")
raise TypeError("Funasr does not support the given model type currently.")
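A minimal usage sketch for this dispatcher, assuming `torch_model` is a loaded FunASR model of one of the supported types; `onnx` is the only kwarg read by the export wrappers shown in this commit:

```python
export_config = {"onnx": True}       # forwarded as **kwargs to the export wrapper
export_model = get_model(torch_model, export_config=export_config)
print(type(export_model).__name__)   # the export wrapper class chosen for the model
```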

View File

@ -19,7 +19,7 @@ from funasr.export.models.decoder.transformer_decoder import ParaformerDecoderSA
class Paraformer(nn.Module):
"""
Author: Speech Lab, Alibaba Group, China
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
@ -112,7 +112,7 @@ class Paraformer(nn.Module):
class BiCifParaformer(nn.Module):
"""
Author: Speech Lab, Alibaba Group, China
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""

View File

@ -0,0 +1,60 @@
from enum import Enum
from typing import List, Tuple, Dict, Any
import torch
from torch import nn
import math
from funasr.models.encoder.fsmn_encoder import FSMN
from funasr.export.models.encoder.fsmn_encoder import FSMN as FSMN_export
class E2EVadModel(nn.Module):
def __init__(self, model,
max_seq_len=512,
feats_dim=400,
model_name='model',
**kwargs,):
super(E2EVadModel, self).__init__()
self.feats_dim = feats_dim
self.max_seq_len = max_seq_len
self.model_name = model_name
if isinstance(model.encoder, FSMN):
self.encoder = FSMN_export(model.encoder)
else:
raise TypeError("unsupported encoder")
def forward(self, feats: torch.Tensor, *args, ):
scores, out_caches = self.encoder(feats, *args)
return scores, out_caches
def get_dummy_inputs(self, frame=30):
speech = torch.randn(1, frame, self.feats_dim)
in_cache0 = torch.randn(1, 128, 19, 1)
in_cache1 = torch.randn(1, 128, 19, 1)
in_cache2 = torch.randn(1, 128, 19, 1)
in_cache3 = torch.randn(1, 128, 19, 1)
return (speech, in_cache0, in_cache1, in_cache2, in_cache3)
# def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
# import numpy as np
# fbank = np.loadtxt(txt_file)
# fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
# speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
# speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
# return (speech, speech_lengths)
def get_input_names(self):
return ['speech', 'in_cache0', 'in_cache1', 'in_cache2', 'in_cache3']
def get_output_names(self):
return ['logits', 'out_cache0', 'out_cache1', 'out_cache2', 'out_cache3']
def get_dynamic_axes(self):
return {
'speech': {
1: 'feats_length'
},
}
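A minimal sketch of running the exported VAD wrapper chunk by chunk while carrying the FSMN caches forward, assuming `vad_model` is an `E2EVadModel` export wrapper built from a trained FSMN VAD model with four FSMN blocks (matching the four dummy caches above):

```python
speech, c0, c1, c2, c3 = vad_model.get_dummy_inputs(frame=30)
for _ in range(3):                                   # feed three 30-frame chunks
    scores, (c0, c1, c2, c3) = vad_model(speech, c0, c1, c2, c3)
print(scores.shape)                                  # frame-level VAD scores for the last chunk
```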

View File

@ -0,0 +1,296 @@
from typing import Tuple, Dict
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from funasr.models.encoder.fsmn_encoder import BasicBlock
class LinearTransform(nn.Module):
def __init__(self, input_dim, output_dim):
super(LinearTransform, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.linear = nn.Linear(input_dim, output_dim, bias=False)
def forward(self, input):
output = self.linear(input)
return output
class AffineTransform(nn.Module):
def __init__(self, input_dim, output_dim):
super(AffineTransform, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.linear = nn.Linear(input_dim, output_dim)
def forward(self, input):
output = self.linear(input)
return output
class RectifiedLinear(nn.Module):
def __init__(self, input_dim, output_dim):
super(RectifiedLinear, self).__init__()
self.dim = input_dim
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.1)
def forward(self, input):
out = self.relu(input)
return out
class FSMNBlock(nn.Module):
def __init__(
self,
input_dim: int,
output_dim: int,
lorder=None,
rorder=None,
lstride=1,
rstride=1,
):
super(FSMNBlock, self).__init__()
self.dim = input_dim
if lorder is None:
return
self.lorder = lorder
self.rorder = rorder
self.lstride = lstride
self.rstride = rstride
self.conv_left = nn.Conv2d(
self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False)
if self.rorder > 0:
self.conv_right = nn.Conv2d(
self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False)
else:
self.conv_right = None
def forward(self, input: torch.Tensor, cache: torch.Tensor):
x = torch.unsqueeze(input, 1)
x_per = x.permute(0, 3, 2, 1) # B D T C
cache = cache.to(x_per.device)
y_left = torch.cat((cache, x_per), dim=2)
cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
y_left = self.conv_left(y_left)
out = x_per + y_left
if self.conv_right is not None:
# maybe need to check
y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
y_right = y_right[:, :, self.rstride:, :]
y_right = self.conv_right(y_right)
out += y_right
out_per = out.permute(0, 3, 2, 1)
output = out_per.squeeze(1)
return output, cache
class BasicBlock_export(nn.Module):
def __init__(self,
model,
):
super(BasicBlock_export, self).__init__()
self.linear = model.linear
self.fsmn_block = model.fsmn_block
self.affine = model.affine
self.relu = model.relu
def forward(self, input: torch.Tensor, in_cache: torch.Tensor):
x = self.linear(input) # B T D
# cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
# if cache_layer_name not in in_cache:
# in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
x, out_cache = self.fsmn_block(x, in_cache)
x = self.affine(x)
x = self.relu(x)
return x, out_cache
# class FsmnStack(nn.Sequential):
# def __init__(self, *args):
# super(FsmnStack, self).__init__(*args)
#
# def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
# x = input
# for module in self._modules.values():
# x = module(x, in_cache)
# return x
'''
FSMN net for keyword spotting
input_dim: input dimension
linear_dim: fsmn input dimension
proj_dim: fsmn projection dimension
lorder: fsmn left order
rorder: fsmn right order
num_syn: output dimension
fsmn_layers: no. of sequential fsmn layers
'''
class FSMN(nn.Module):
def __init__(
self, model,
):
super(FSMN, self).__init__()
# self.input_dim = input_dim
# self.input_affine_dim = input_affine_dim
# self.fsmn_layers = fsmn_layers
# self.linear_dim = linear_dim
# self.proj_dim = proj_dim
# self.output_affine_dim = output_affine_dim
# self.output_dim = output_dim
#
# self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
# self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
# self.relu = RectifiedLinear(linear_dim, linear_dim)
# self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
# range(fsmn_layers)])
# self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
# self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
# self.softmax = nn.Softmax(dim=-1)
self.in_linear1 = model.in_linear1
self.in_linear2 = model.in_linear2
self.relu = model.relu
# self.fsmn = model.fsmn
self.out_linear1 = model.out_linear1
self.out_linear2 = model.out_linear2
self.softmax = model.softmax
self.fsmn = model.fsmn
for i, d in enumerate(model.fsmn):
if isinstance(d, BasicBlock):
self.fsmn[i] = BasicBlock_export(d)
def fuse_modules(self):
pass
def forward(
self,
input: torch.Tensor,
*args,
):
"""
Args:
input (torch.Tensor): Input tensor (B, T, D)
in_cache: when in_cache is not None, the forward is in streaming. The type of in_cache is a dict, egs,
{'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame
"""
x = self.in_linear1(input)
x = self.in_linear2(x)
x = self.relu(x)
# x4 = self.fsmn(x3, in_cache) # self.in_cache will update automatically in self.fsmn
out_caches = list()
for i, d in enumerate(self.fsmn):
in_cache = args[i]
x, out_cache = d(x, in_cache)
out_caches.append(out_cache)
x = self.out_linear1(x)
x = self.out_linear2(x)
x = self.softmax(x)
return x, out_caches
'''
one deep fsmn layer
dimproj: projection dimension, input and output dimension of memory blocks
dimlinear: dimension of mapping layer
lorder: left order
rorder: right order
lstride: left stride
rstride: right stride
'''
class DFSMN(nn.Module):
def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
super(DFSMN, self).__init__()
self.lorder = lorder
self.rorder = rorder
self.lstride = lstride
self.rstride = rstride
self.expand = AffineTransform(dimproj, dimlinear)
self.shrink = LinearTransform(dimlinear, dimproj)
self.conv_left = nn.Conv2d(
dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False)
if rorder > 0:
self.conv_right = nn.Conv2d(
dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False)
else:
self.conv_right = None
def forward(self, input):
f1 = F.relu(self.expand(input))
p1 = self.shrink(f1)
x = torch.unsqueeze(p1, 1)
x_per = x.permute(0, 3, 2, 1)
y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
if self.conv_right is not None:
y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
y_right = y_right[:, :, self.rstride:, :]
out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
else:
out = x_per + self.conv_left(y_left)
out1 = out.permute(0, 3, 2, 1)
output = input + out1.squeeze(1)
return output
'''
build stacked dfsmn layers
'''
def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
repeats = [
nn.Sequential(
DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
for i in range(fsmn_layers)
]
return nn.Sequential(*repeats)
if __name__ == '__main__':
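# NOTE (added comment): this smoke test appears to target the original training-time FSMN
# (11-argument constructor and to_kaldi_net()), not the export wrapper defined above,
# which only takes a prebuilt model; kept here for reference.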
fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
print(fsmn)
num_params = sum(p.numel() for p in fsmn.parameters())
print('the number of model params: {}'.format(num_params))
x = torch.zeros(128, 200, 400) # batch-size * time * dim
y, _ = fsmn(x) # batch-size * time * dim
print('input shape: {}'.format(x.shape))
print('output shape: {}'.format(y.shape))
print(fsmn.to_kaldi_net())
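# --- Hedged usage sketch (not part of the original file) ----------------------
# The export wrapper above takes one cache tensor per FSMN block positionally and
# returns the updated caches, so a streaming loop looks roughly like the lines
# below. `fsmn_export` is an assumed FSMN(model) instance; the cache shape
# (1, proj_dim, lorder - 1, 1) and all sizes are illustrative only.
# caches = [torch.zeros(1, 128, 19, 1) for _ in range(4)]
# for chunk in feature_chunks:                      # each chunk: (1, T, input_dim)
#     scores, caches = fsmn_export(chunk, *caches)  # one cache in / out per layer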

View File

@ -9,6 +9,7 @@ from funasr.export.models.modules.encoder_layer import EncoderLayerSANM as Encod
from funasr.modules.positionwise_feed_forward import PositionwiseFeedForward
from funasr.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export
class SANMEncoder(nn.Module):
def __init__(
self,
@ -107,3 +108,106 @@ class SANMEncoder(nn.Module):
}
}
class SANMVadEncoder(nn.Module):
def __init__(
self,
model,
max_seq_len=512,
feats_dim=560,
model_name='encoder',
onnx: bool = True,
):
super().__init__()
self.embed = model.embed
self.model = model
self.feats_dim = feats_dim
self._output_size = model._output_size
if onnx:
self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
else:
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
if hasattr(model, 'encoders0'):
for i, d in enumerate(self.model.encoders0):
if isinstance(d.self_attn, MultiHeadedAttentionSANM):
d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
if isinstance(d.feed_forward, PositionwiseFeedForward):
d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
self.model.encoders0[i] = EncoderLayerSANM_export(d)
for i, d in enumerate(self.model.encoders):
if isinstance(d.self_attn, MultiHeadedAttentionSANM):
d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
if isinstance(d.feed_forward, PositionwiseFeedForward):
d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
self.model.encoders[i] = EncoderLayerSANM_export(d)
self.model_name = model_name
self.num_heads = model.encoders[0].self_attn.h
self.hidden_size = model.encoders[0].self_attn.linear_out.out_features
def prepare_mask(self, mask, sub_masks):
mask_3d_btd = mask[:, :, None]
mask_4d_bhlt = (1 - sub_masks) * -10000.0
return mask_3d_btd, mask_4d_bhlt
def forward(self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
vad_masks: torch.Tensor,
sub_masks: torch.Tensor,
):
speech = speech * self._output_size ** 0.5
mask = self.make_pad_mask(speech_lengths)
vad_masks = self.prepare_mask(mask, vad_masks)
mask = self.prepare_mask(mask, sub_masks)
if self.embed is None:
xs_pad = speech
else:
xs_pad = self.embed(speech)
encoder_outs = self.model.encoders0(xs_pad, mask)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
# encoder_outs = self.model.encoders(xs_pad, mask)
for layer_idx, encoder_layer in enumerate(self.model.encoders):
if layer_idx == len(self.model.encoders) - 1:
mask = vad_masks
encoder_outs = encoder_layer(xs_pad, mask)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
xs_pad = self.model.after_norm(xs_pad)
return xs_pad, speech_lengths
def get_output_size(self):
return self.model.encoders[0].size
# def get_dummy_inputs(self):
# feats = torch.randn(1, 100, self.feats_dim)
# return (feats)
#
# def get_input_names(self):
# return ['feats']
#
# def get_output_names(self):
# return ['encoder_out', 'encoder_out_lens', 'predictor_weight']
#
# def get_dynamic_axes(self):
# return {
# 'feats': {
# 1: 'feats_length'
# },
# 'encoder_out': {
# 1: 'enc_out_length'
# },
# 'predictor_weight': {
# 1: 'pre_out_length'
# }
#
# }
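# --- Hedged illustration (not part of the original file) ----------------------
# prepare_mask above turns a (B, T) padding mask into a (B, T, 1) multiplicative
# mask, and a binary sub/vad mask into an additive attention bias: 0 where
# attention is allowed, -10000 where it is blocked. A tiny self-contained example:
# import torch
# mask = torch.tensor([[1., 1., 1., 0.]])         # (B, T) valid-frame mask
# sub_masks = torch.tril(torch.ones(1, 1, 4, 4))  # (B, 1, T, T), no future positions
# mask_3d_btd = mask[:, :, None]                  # (B, T, 1)
# mask_4d_bhlt = (1 - sub_masks) * -10000.0       # (B, 1, T, T) additive bias
# This matches the vad_masks / sub_masks fed to the exported ONNX model in the
# test script further below.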

View File

@ -0,0 +1,18 @@
import onnxruntime
import numpy as np
if __name__ == '__main__':
onnx_path = "../damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/model.onnx"
sess = onnxruntime.InferenceSession(onnx_path)
input_name = [nd.name for nd in sess.get_inputs()]
output_name = [nd.name for nd in sess.get_outputs()]
def _get_feed_dict(text_length):
return {'inputs': np.ones((1, text_length), dtype=np.int64), 'text_lengths': np.array([text_length,], dtype=np.int32)}
def _run(feed_dict):
output = sess.run(output_name, input_feed=feed_dict)
for name, value in zip(output_name, output):
print('{}: {}'.format(name, value))
_run(_get_feed_dict(10))

View File

@ -0,0 +1,22 @@
import onnxruntime
import numpy as np
if __name__ == '__main__':
onnx_path = "./export/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727/model.onnx"
sess = onnxruntime.InferenceSession(onnx_path)
input_name = [nd.name for nd in sess.get_inputs()]
output_name = [nd.name for nd in sess.get_outputs()]
def _get_feed_dict(text_length):
return {'inputs': np.ones((1, text_length), dtype=np.int64),
'text_lengths': np.array([text_length,], dtype=np.int32),
'vad_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32),
'sub_masks': np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
}
def _run(feed_dict):
output = sess.run(output_name, input_feed=feed_dict)
for name, value in zip(output_name, output):
print('{}: {}'.format(name, value))
_run(_get_feed_dict(10))

View File

@ -0,0 +1,26 @@
import onnxruntime
import numpy as np
if __name__ == '__main__':
onnx_path = "/mnt/workspace/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/model.onnx"
sess = onnxruntime.InferenceSession(onnx_path)
input_name = [nd.name for nd in sess.get_inputs()]
output_name = [nd.name for nd in sess.get_outputs()]
def _get_feed_dict(feats_length):
return {'speech': np.random.rand(1, feats_length, 400).astype(np.float32),
'in_cache0': np.random.rand(1, 128, 19, 1).astype(np.float32),
'in_cache1': np.random.rand(1, 128, 19, 1).astype(np.float32),
'in_cache2': np.random.rand(1, 128, 19, 1).astype(np.float32),
'in_cache3': np.random.rand(1, 128, 19, 1).astype(np.float32),
}
def _run(feed_dict):
output = sess.run(output_name, input_feed=feed_dict)
for name, value in zip(output_name, output):
print('{}: {}'.format(name, value.shape))
_run(_get_feed_dict(100))
_run(_get_feed_dict(200))

View File

@ -5,7 +5,17 @@ from typing import Tuple
import torch
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
from typing import Dict
from typing import Optional
from typing import Tuple
import torch
import torch.nn.functional as F
from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
class AbsLM(torch.nn.Module, BatchScorerInterface, ABC):
"""The abstract LM class
@ -27,3 +37,122 @@ class AbsLM(torch.nn.Module, BatchScorerInterface, ABC):
self, input: torch.Tensor, hidden: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
class LanguageModel(AbsESPnetModel):
def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
assert check_argument_types()
super().__init__()
self.lm = lm
self.sos = 1
self.eos = 2
# ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
self.ignore_id = ignore_id
def nll(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
max_length: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute negative log likelihood(nll)
Normally, this function is called in batchify_nll.
Args:
text: (Batch, Length)
text_lengths: (Batch,)
max_length: int
"""
batch_size = text.size(0)
# For data parallel
if max_length is None:
text = text[:, : text_lengths.max()]
else:
text = text[:, :max_length]
# 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
# text: (Batch, Length) -> x, y: (Batch, Length + 1)
x = F.pad(text, [1, 0], "constant", self.sos)
t = F.pad(text, [0, 1], "constant", self.ignore_id)
for i, l in enumerate(text_lengths):
t[i, l] = self.eos
x_lengths = text_lengths + 1
# 2. Forward Language model
# x: (Batch, Length) -> y: (Batch, Length, NVocab)
y, _ = self.lm(x, None)
# 3. Calc negative log likelihood
# nll: (BxL,)
nll = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
# nll: (BxL,) -> (BxL,)
if max_length is None:
nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0)
else:
nll.masked_fill_(
make_pad_mask(x_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
0.0,
)
# nll: (BxL,) -> (B, L)
nll = nll.view(batch_size, -1)
return nll, x_lengths
def batchify_nll(
self, text: torch.Tensor, text_lengths: torch.Tensor, batch_size: int = 100
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute negative log likelihood(nll) from transformer language model
To avoid OOM, this fuction seperate the input into batches.
Then call nll for each batch and combine and return results.
Args:
text: (Batch, Length)
text_lengths: (Batch,)
batch_size: int, number of samples per batch when computing nll;
lower it to avoid OOM, or raise it to speed up computation
"""
total_num = text.size(0)
if total_num <= batch_size:
nll, x_lengths = self.nll(text, text_lengths)
else:
nlls = []
x_lengths = []
max_length = text_lengths.max()
start_idx = 0
while True:
end_idx = min(start_idx + batch_size, total_num)
batch_text = text[start_idx:end_idx, :]
batch_text_lengths = text_lengths[start_idx:end_idx]
# batch_nll: [B * T]
batch_nll, batch_x_lengths = self.nll(
batch_text, batch_text_lengths, max_length=max_length
)
nlls.append(batch_nll)
x_lengths.append(batch_x_lengths)
start_idx = end_idx
if start_idx == total_num:
break
nll = torch.cat(nlls)
x_lengths = torch.cat(x_lengths)
assert nll.size(0) == total_num
assert x_lengths.size(0) == total_num
return nll, x_lengths
def forward(
self, text: torch.Tensor, text_lengths: torch.Tensor
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
nll, y_lengths = self.nll(text, text_lengths)
ntokens = y_lengths.sum()
loss = nll.sum() / ntokens
stats = dict(loss=loss.detach())
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
return loss, stats, weight
def collect_feats(
self, text: torch.Tensor, text_lengths: torch.Tensor
) -> Dict[str, torch.Tensor]:
return {}
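# --- Hedged usage sketch (not part of the original file) ----------------------
# batchify_nll returns per-token negative log likelihoods (padding positions are
# zeroed) together with the padded lengths, so corpus perplexity can be derived as
# exp(total_nll / total_tokens). `lm_model` is an assumed LanguageModel instance.
# nll, lengths = lm_model.batchify_nll(text, text_lengths, batch_size=32)
# ppl = torch.exp(nll.sum() / lengths.sum())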

View File

@ -1,131 +0,0 @@
from typing import Dict
from typing import Optional
from typing import Tuple
import torch
import torch.nn.functional as F
from typeguard import check_argument_types
from funasr.modules.nets_utils import make_pad_mask
from funasr.lm.abs_model import AbsLM
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
class ESPnetLanguageModel(AbsESPnetModel):
def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
assert check_argument_types()
super().__init__()
self.lm = lm
self.sos = 1
self.eos = 2
# ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
self.ignore_id = ignore_id
def nll(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
max_length: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute negative log likelihood(nll)
Normally, this function is called in batchify_nll.
Args:
text: (Batch, Length)
text_lengths: (Batch,)
max_lengths: int
"""
batch_size = text.size(0)
# For data parallel
if max_length is None:
text = text[:, : text_lengths.max()]
else:
text = text[:, :max_length]
# 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
# text: (Batch, Length) -> x, y: (Batch, Length + 1)
x = F.pad(text, [1, 0], "constant", self.sos)
t = F.pad(text, [0, 1], "constant", self.ignore_id)
for i, l in enumerate(text_lengths):
t[i, l] = self.eos
x_lengths = text_lengths + 1
# 2. Forward Language model
# x: (Batch, Length) -> y: (Batch, Length, NVocab)
y, _ = self.lm(x, None)
# 3. Calc negative log likelihood
# nll: (BxL,)
nll = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
# nll: (BxL,) -> (BxL,)
if max_length is None:
nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0)
else:
nll.masked_fill_(
make_pad_mask(x_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
0.0,
)
# nll: (BxL,) -> (B, L)
nll = nll.view(batch_size, -1)
return nll, x_lengths
def batchify_nll(
self, text: torch.Tensor, text_lengths: torch.Tensor, batch_size: int = 100
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute negative log likelihood(nll) from transformer language model
To avoid OOM, this fuction seperate the input into batches.
Then call nll for each batch and combine and return results.
Args:
text: (Batch, Length)
text_lengths: (Batch,)
batch_size: int, samples each batch contain when computing nll,
you may change this to avoid OOM or increase
"""
total_num = text.size(0)
if total_num <= batch_size:
nll, x_lengths = self.nll(text, text_lengths)
else:
nlls = []
x_lengths = []
max_length = text_lengths.max()
start_idx = 0
while True:
end_idx = min(start_idx + batch_size, total_num)
batch_text = text[start_idx:end_idx, :]
batch_text_lengths = text_lengths[start_idx:end_idx]
# batch_nll: [B * T]
batch_nll, batch_x_lengths = self.nll(
batch_text, batch_text_lengths, max_length=max_length
)
nlls.append(batch_nll)
x_lengths.append(batch_x_lengths)
start_idx = end_idx
if start_idx == total_num:
break
nll = torch.cat(nlls)
x_lengths = torch.cat(x_lengths)
assert nll.size(0) == total_num
assert x_lengths.size(0) == total_num
return nll, x_lengths
def forward(
self, text: torch.Tensor, text_lengths: torch.Tensor
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
nll, y_lengths = self.nll(text, text_lengths)
ntokens = y_lengths.sum()
loss = nll.sum() / ntokens
stats = dict(loss=loss.detach())
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
return loss, stats, weight
def collect_feats(
self, text: torch.Tensor, text_lengths: torch.Tensor
) -> Dict[str, torch.Tensor]:
return {}

View File

@ -102,7 +102,7 @@ class ContextualBiasDecoder(nn.Module):
class ContextualParaformerDecoder(ParaformerSANMDecoder):
"""
author: Speech Lab, Alibaba Group, China
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2006.01713
"""

View File

@ -104,7 +104,6 @@ class DecoderLayerSANM(nn.Module):
x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
return x, tgt_mask, memory, memory_mask, cache
def forward_chunk(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
@ -152,7 +151,7 @@ class DecoderLayerSANM(nn.Module):
class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
"""
author: Speech Lab, Alibaba Group, China
Author: Speech Lab of DAMO Academy, Alibaba Group
SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
https://arxiv.org/abs/2006.01713
@ -400,7 +399,7 @@ class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
for i in range(self.att_layer_num):
decoder = self.decoders[i]
c = cache[i]
x, tgt_mask, memory, memory_mask, c_ret = decoder(
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
x, tgt_mask, memory, memory_mask, cache=c
)
new_cache.append(c_ret)
@ -410,13 +409,13 @@ class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
j = i + self.att_layer_num
decoder = self.decoders2[i]
c = cache[j]
x, tgt_mask, memory, memory_mask, c_ret = decoder(
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
x, tgt_mask, memory, memory_mask, cache=c
)
new_cache.append(c_ret)
for decoder in self.decoders3:
x, tgt_mask, memory, memory_mask, _ = decoder(
x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
x, tgt_mask, memory, None, cache=None
)
@ -813,7 +812,7 @@ class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
class ParaformerSANMDecoder(BaseTransformerDecoder):
"""
author: Speech Lab, Alibaba Group, China
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2006.01713
"""
@ -1077,7 +1076,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
for i in range(self.att_layer_num):
decoder = self.decoders[i]
c = cache[i]
x, tgt_mask, memory, memory_mask, c_ret = decoder(
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
x, tgt_mask, memory, None, cache=c
)
new_cache.append(c_ret)
@ -1087,14 +1086,14 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
j = i + self.att_layer_num
decoder = self.decoders2[i]
c = cache[j]
x, tgt_mask, memory, memory_mask, c_ret = decoder(
x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
x, tgt_mask, memory, None, cache=c
)
new_cache.append(c_ret)
for decoder in self.decoders3:
x, tgt_mask, memory, memory_mask, _ = decoder(
x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
x, tgt_mask, memory, None, cache=None
)

View File

@ -405,7 +405,7 @@ class TransformerDecoder(BaseTransformerDecoder):
class ParaformerDecoderSAN(BaseTransformerDecoder):
"""
author: Speech Lab, Alibaba Group, China
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2006.01713
"""

View File

@ -36,7 +36,11 @@ import pdb
import random
import math
class MFCCA(AbsESPnetModel):
"""CTC-attention hybrid Encoder-Decoder model"""
"""
Author: Audio, Speech and Language Processing Group (ASLP@NPU), Northwestern Polytechnical University
MFCCA: Multi-Frame Cross-Channel Attention for Multi-Speaker ASR in Multi-Party Meeting Scenario
https://arxiv.org/abs/2210.05265
"""
def __init__(
self,

View File

@ -44,7 +44,7 @@ else:
class Paraformer(AbsESPnetModel):
"""
Author: Speech Lab, Alibaba Group, China
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
@ -612,7 +612,7 @@ class Paraformer(AbsESPnetModel):
class ParaformerBert(Paraformer):
"""
Author: Speech Lab, Alibaba Group, China
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer2: advanced paraformer with LFMMI and bert for non-autoregressive end-to-end speech recognition
"""

View File

@ -36,8 +36,12 @@ else:
class DiarSondModel(AbsESPnetModel):
"""Speaker overlap-aware neural diarization model
reference: https://arxiv.org/abs/2211.10243
"""
Author: Speech Lab, Alibaba Group, China
SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
https://arxiv.org/abs/2211.10243
TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
https://arxiv.org/abs/2303.05397
"""
def __init__(

View File

@ -1,3 +1,7 @@
"""
Author: Speech Lab, Alibaba Group, China
"""
import logging
from contextlib import contextmanager
from distutils.version import LooseVersion

View File

@ -32,7 +32,7 @@ else:
class TimestampPredictor(AbsESPnetModel):
"""
Author: Speech Lab, Alibaba Group, China
Author: Speech Lab of DAMO Academy, Alibaba Group
"""
def __init__(

View File

@ -40,7 +40,7 @@ else:
class UniASR(AbsESPnetModel):
"""
Author: Speech Lab, Alibaba Group, China
Author: Speech Lab of DAMO Academy, Alibaba Group
"""
def __init__(

View File

@ -35,6 +35,11 @@ class VadDetectMode(Enum):
class VADXOptions:
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(
self,
sample_rate: int = 16000,
@ -99,6 +104,11 @@ class VADXOptions:
class E2EVadSpeechBufWithDoa(object):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self):
self.start_ms = 0
self.end_ms = 0
@ -117,6 +127,11 @@ class E2EVadSpeechBufWithDoa(object):
class E2EVadFrameProb(object):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self):
self.noise_prob = 0.0
self.speech_prob = 0.0
@ -126,6 +141,11 @@ class E2EVadFrameProb(object):
class WindowDetector(object):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self, window_size_ms: int, sil_to_speech_time: int,
speech_to_sil_time: int, frame_size_ms: int):
self.window_size_ms = window_size_ms
@ -192,7 +212,12 @@ class WindowDetector(object):
class E2EVadModel(nn.Module):
def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any]):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], frontend=None):
super(E2EVadModel, self).__init__()
self.vad_opts = VADXOptions(**vad_post_args)
self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
@ -229,6 +254,7 @@ class E2EVadModel(nn.Module):
self.data_buf_all = None
self.waveform = None
self.ResetDetection()
self.frontend = frontend
def AllResetDetection(self):
self.is_final = False
@ -459,8 +485,8 @@ class E2EVadModel(nn.Module):
segment_batch = []
if len(self.output_data_buf) > 0:
for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
if not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
i].contain_seg_end_point:
if not is_final and (not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
i].contain_seg_end_point):
continue
segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
segment_batch.append(segment)
@ -477,8 +503,9 @@ class E2EVadModel(nn.Module):
) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
self.waveform = waveform # compute decibel for each frame
self.ComputeDecibel()
self.ComputeScores(feats, in_cache)
self.ComputeDecibel()
if not is_final:
self.DetectCommonFrames()
else:

View File

@ -67,7 +67,7 @@ class EncoderLayer(nn.Module):
class ConvEncoder(AbsEncoder):
"""
author: Speech Lab, Alibaba Group, China
Author: Speech Lab of DAMO Academy, Alibaba Group
Convolution encoder in OpenNMT framework
"""

View File

@ -117,7 +117,7 @@ class EncoderLayer(nn.Module):
class SelfAttentionEncoder(AbsEncoder):
"""
author: Speech Lab, Alibaba Group, China
Author: Speech Lab of DAMO Academy, Alibaba Group
Self attention encoder in OpenNMT framework
"""

View File

@ -406,6 +406,12 @@ class ResNet34Diar(ResNet34):
tf2torch_tensor_name_prefix_torch="encoder",
tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder"
):
"""
Author: Speech Lab, Alibaba Group, China
SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
https://arxiv.org/abs/2211.10243
"""
super(ResNet34Diar, self).__init__(
input_size,
use_head_conv=use_head_conv,
@ -633,6 +639,12 @@ class ResNet34SpL2RegDiar(ResNet34_SP_L2Reg):
tf2torch_tensor_name_prefix_torch="encoder",
tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder"
):
"""
Author: Speech Lab, Alibaba Group, China
TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
https://arxiv.org/abs/2303.05397
"""
super(ResNet34SpL2RegDiar, self).__init__(
input_size,
use_head_conv=use_head_conv,

View File

@ -10,7 +10,7 @@ from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
@ -27,7 +27,7 @@ from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.mask import subsequent_mask, vad_mask
class EncoderLayerSANM(nn.Module):
def __init__(
@ -117,7 +117,7 @@ class EncoderLayerSANM(nn.Module):
class SANMEncoder(AbsEncoder):
"""
author: Speech Lab, Alibaba Group, China
Author: Speech Lab of DAMO Academy, Alibaba Group
San-m: Memory equipped self-attention for end-to-end speech recognition
https://arxiv.org/abs/2006.01713
@ -549,7 +549,7 @@ class SANMEncoder(AbsEncoder):
class SANMEncoderChunkOpt(AbsEncoder):
"""
author: Speech Lab, Alibaba Group, China
Author: Speech Lab of DAMO Academy, Alibaba Group
SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
https://arxiv.org/abs/2006.01713
@ -958,3 +958,231 @@ class SANMEncoderChunkOpt(AbsEncoder):
var_dict_tf[name_tf].shape))
return var_dict_torch_update
class SANMVadEncoder(AbsEncoder):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: Optional[str] = "conv2d",
pos_enc_class=SinusoidalPositionEncoder,
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = "linear",
positionwise_conv_kernel_size: int = 1,
padding_idx: int = -1,
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
kernel_size : int = 11,
sanm_shfit : int = 0,
selfattention_layer_type: str = "sanm",
):
assert check_argument_types()
super().__init__()
self._output_size = output_size
if input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(input_size, output_size),
torch.nn.LayerNorm(output_size),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
elif input_layer == "conv2d2":
self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
elif input_layer == "conv2d6":
self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
elif input_layer == "conv2d8":
self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
elif input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
SinusoidalPositionEncoder(),
)
elif input_layer is None:
if input_size == output_size:
self.embed = None
else:
self.embed = torch.nn.Linear(input_size, output_size)
elif input_layer == "pe":
self.embed = SinusoidalPositionEncoder()
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
)
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
raise NotImplementedError("Support only linear or conv1d.")
if selfattention_layer_type == "selfattn":
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
)
elif selfattention_layer_type == "sanm":
self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
encoder_selfattn_layer_args0 = (
attention_heads,
input_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
encoder_selfattn_layer_args = (
attention_heads,
output_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
self.encoders0 = repeat(
1,
lambda lnum: EncoderLayerSANM(
input_size,
output_size,
self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
self.encoders = repeat(
num_blocks-1,
lambda lnum: EncoderLayerSANM(
output_size,
output_size,
self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
self.interctc_layer_idx = interctc_layer_idx
if len(interctc_layer_idx) > 0:
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
self.interctc_use_conditioning = interctc_use_conditioning
self.conditioning_layer = None
self.dropout = nn.Dropout(dropout_rate)
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
vad_indexes: torch.Tensor,
prev_states: torch.Tensor = None,
ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
Args:
xs_pad: input tensor (B, L, D)
ilens: input length (B)
prev_states: Not to be used now.
Returns:
position embedded tensor and mask
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
no_future_masks = masks & sub_masks
xs_pad *= self.output_size()**0.5
if self.embed is None:
xs_pad = xs_pad
elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
raise TooShortUttError(
f"has {xs_pad.size(1)} frames and is too short for subsampling " +
f"(it needs more than {limit_size} frames), return empty results",
xs_pad.size(1),
limit_size,
)
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
# xs_pad = self.dropout(xs_pad)
mask_tup0 = [masks, no_future_masks]
encoder_outs = self.encoders0(xs_pad, mask_tup0)
xs_pad, _ = encoder_outs[0], encoder_outs[1]
intermediate_outs = []
for layer_idx, encoder_layer in enumerate(self.encoders):
if layer_idx + 1 == len(self.encoders):
# This is last layer.
coner_mask = torch.ones(masks.size(0),
masks.size(-1),
masks.size(-1),
device=xs_pad.device,
dtype=torch.bool)
for word_index, length in enumerate(ilens):
coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
vad_indexes[word_index],
device=xs_pad.device)
layer_mask = masks & coner_mask
else:
layer_mask = no_future_masks
mask_tup1 = [masks, layer_mask]
encoder_outs = encoder_layer(xs_pad, mask_tup1)
xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
olens = masks.squeeze(1).sum(1)
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None

View File

@ -38,7 +38,7 @@ def load_cmvn(cmvn_file):
return cmvn
def apply_cmvn(inputs, cmvn_file): # noqa
def apply_cmvn(inputs, cmvn): # noqa
"""
Apply CMVN with mvn data
"""
@ -47,7 +47,6 @@ def apply_cmvn(inputs, cmvn_file): # noqa
dtype = inputs.dtype
frame, dim = inputs.shape
cmvn = load_cmvn(cmvn_file)
means = np.tile(cmvn[0:1, :dim], (frame, 1))
vars = np.tile(cmvn[1:2, :dim], (frame, 1))
inputs += torch.from_numpy(means).type(dtype).to(device)
@ -111,6 +110,7 @@ class WavFrontend(AbsFrontend):
self.dither = dither
self.snip_edges = snip_edges
self.upsacle_samples = upsacle_samples
self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)
def output_size(self) -> int:
return self.n_mels * self.lfr_m
@ -140,8 +140,8 @@ class WavFrontend(AbsFrontend):
if self.lfr_m != 1 or self.lfr_n != 1:
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
if self.cmvn_file is not None:
mat = apply_cmvn(mat, self.cmvn_file)
if self.cmvn is not None:
mat = apply_cmvn(mat, self.cmvn)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
@ -194,8 +194,8 @@ class WavFrontend(AbsFrontend):
mat = input[i, :input_lengths[i], :]
if self.lfr_m != 1 or self.lfr_n != 1:
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
if self.cmvn_file is not None:
mat = apply_cmvn(mat, self.cmvn_file)
if self.cmvn is not None:
mat = apply_cmvn(mat, self.cmvn)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
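# --- Hedged note on the change above (not part of the original file) ----------
# load_cmvn() is now called once in WavFrontend.__init__ and the resulting array is
# reused for every utterance, instead of re-reading the stats file inside apply_cmvn
# on each call. Per the visible hunk, row 0 of the array is tiled over frames and
# added to the features; row 1 is tiled the same way (its use falls outside this hunk).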

View File

@ -234,6 +234,7 @@ class CifPredictorV2(nn.Module):
last_fire_place = len_time - 1
last_fire_remainds = 0.0
pre_alphas_length = 0
last_fire = False
mask_chunk_peak_predictor = None
if cache is not None:
@ -251,10 +252,15 @@ class CifPredictorV2(nn.Module):
if cif_peak[0][len_time - 1 - i] > self.threshold or cif_peak[0][len_time - 1 - i] == self.threshold:
last_fire_place = len_time - 1 - i
last_fire_remainds = cif_peak[0][len_time - 1 - i] - self.threshold
last_fire = True
break
last_fire_remainds = torch.tensor([last_fire_remainds], dtype=alphas.dtype).to(alphas.device)
cache["cif_hidden"] = hidden[:, last_fire_place:, :]
cache["cif_alphas"] = torch.cat((last_fire_remainds.unsqueeze(0), alphas[:, last_fire_place+1:]), -1)
if last_fire:
last_fire_remainds = torch.tensor([last_fire_remainds], dtype=alphas.dtype).to(alphas.device)
cache["cif_hidden"] = hidden[:, last_fire_place:, :]
cache["cif_alphas"] = torch.cat((last_fire_remainds.unsqueeze(0), alphas[:, last_fire_place+1:]), -1)
else:
cache["cif_hidden"] = hidden
cache["cif_alphas"] = alphas
token_num_int = token_num.floor().type(torch.int32).item()
return acoustic_embeds[:, 0:token_num_int, :], token_num, alphas, cif_peak

View File

@ -5,16 +5,19 @@ from typing import Tuple
import torch
import torch.nn as nn
from funasr.modules.embedding import PositionalEncoding
from funasr.modules.embedding import SinusoidalPositionEncoder
#from funasr.models.encoder.transformer_encoder import TransformerEncoder as Encoder
from funasr.punctuation.sanm_encoder import SANMEncoder as Encoder
from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
#from funasr.modules.mask import subsequent_n_mask
from funasr.punctuation.abs_model import AbsPunctuation
from funasr.train.abs_model import AbsPunctuation
class TargetDelayTransformer(AbsPunctuation):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
https://arxiv.org/pdf/2003.01309.pdf
"""
def __init__(
self,
vocab_size: int,

View File

@ -6,12 +6,16 @@ import torch
import torch.nn as nn
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder
from funasr.punctuation.abs_model import AbsPunctuation
from funasr.models.encoder.sanm_encoder import SANMVadEncoder as Encoder
from funasr.train.abs_model import AbsPunctuation
class VadRealtimeTransformer(AbsPunctuation):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
https://arxiv.org/pdf/2003.01309.pdf
"""
def __init__(
self,
vocab_size: int,

View File

@ -11,7 +11,7 @@ from funasr.modules.streaming_utils.utils import sequence_mask
class overlap_chunk():
"""
author: Speech Lab, Alibaba Group, China
Author: Speech Lab of DAMO Academy, Alibaba Group
San-m: Memory equipped self-attention for end-to-end speech recognition
https://arxiv.org/abs/2006.01713

View File

@ -1,31 +0,0 @@
from abc import ABC
from abc import abstractmethod
from typing import Tuple
import torch
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
"""The abstract class
To share the loss calculation way among different models,
We uses delegate pattern here:
The instance of this class should be passed to "LanguageModel"
>>> from funasr.punctuation.abs_model import AbsPunctuation
>>> punc = AbsPunctuation()
>>> model = ESPnetPunctuationModel(punc=punc)
This "model" is one of mediator objects for "Task" class.
"""
@abstractmethod
def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
@abstractmethod
def with_vad(self) -> bool:
raise NotImplementedError

View File

@ -1,590 +0,0 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging
import torch
import torch.nn as nn
from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
from funasr.modules.repeat import repeat
from funasr.modules.subsampling import Conv2dSubsampling
from funasr.modules.subsampling import Conv2dSubsampling2
from funasr.modules.subsampling import Conv2dSubsampling6
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.mask import subsequent_mask, vad_mask
class EncoderLayerSANM(nn.Module):
def __init__(
self,
in_size,
size,
self_attn,
feed_forward,
dropout_rate,
normalize_before=True,
concat_after=False,
stochastic_depth_rate=0.0,
):
"""Construct an EncoderLayer object."""
super(EncoderLayerSANM, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(in_size)
self.norm2 = LayerNorm(size)
self.dropout = nn.Dropout(dropout_rate)
self.in_size = in_size
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)
self.stochastic_depth_rate = stochastic_depth_rate
self.dropout_rate = dropout_rate
def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
"""Compute encoded features.
Args:
x_input (torch.Tensor): Input tensor (#batch, time, size).
mask (torch.Tensor): Mask tensor for the input (#batch, time).
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time).
"""
skip_layer = False
# with stochastic depth, residual connection `x + f(x)` becomes
# `x <- x + 1 / (1 - p) * f(x)` at training time.
stoch_layer_coeff = 1.0
if self.training and self.stochastic_depth_rate > 0:
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
if skip_layer:
if cache is not None:
x = torch.cat([cache, x], dim=1)
return x, mask
residual = x
if self.normalize_before:
x = self.norm1(x)
if self.concat_after:
x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
else:
x = stoch_layer_coeff * self.concat_linear(x_concat)
else:
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.dropout(
self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
)
else:
x = stoch_layer_coeff * self.dropout(
self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
)
if not self.normalize_before:
x = self.norm1(x)
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm2(x)
return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
class SANMEncoder(AbsEncoder):
"""
author: Speech Lab, Alibaba Group, China
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: Optional[str] = "conv2d",
pos_enc_class=SinusoidalPositionEncoder,
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = "linear",
positionwise_conv_kernel_size: int = 1,
padding_idx: int = -1,
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
kernel_size : int = 11,
sanm_shfit : int = 0,
selfattention_layer_type: str = "sanm",
):
assert check_argument_types()
super().__init__()
self._output_size = output_size
if input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(input_size, output_size),
torch.nn.LayerNorm(output_size),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
elif input_layer == "conv2d2":
self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
elif input_layer == "conv2d6":
self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
elif input_layer == "conv2d8":
self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
elif input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
SinusoidalPositionEncoder(),
)
elif input_layer is None:
if input_size == output_size:
self.embed = None
else:
self.embed = torch.nn.Linear(input_size, output_size)
elif input_layer == "pe":
self.embed = SinusoidalPositionEncoder()
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
)
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
raise NotImplementedError("Support only linear or conv1d.")
if selfattention_layer_type == "selfattn":
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
)
elif selfattention_layer_type == "sanm":
self.encoder_selfattn_layer = MultiHeadedAttentionSANM
encoder_selfattn_layer_args0 = (
attention_heads,
input_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
encoder_selfattn_layer_args = (
attention_heads,
output_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
self.encoders0 = repeat(
1,
lambda lnum: EncoderLayerSANM(
input_size,
output_size,
self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
self.encoders = repeat(
num_blocks-1,
lambda lnum: EncoderLayerSANM(
output_size,
output_size,
self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
self.interctc_layer_idx = interctc_layer_idx
if len(interctc_layer_idx) > 0:
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
self.interctc_use_conditioning = interctc_use_conditioning
self.conditioning_layer = None
self.dropout = nn.Dropout(dropout_rate)
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
Args:
xs_pad: input tensor (B, L, D)
ilens: input length (B)
prev_states: Not to be used now.
Returns:
position embedded tensor and mask
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
xs_pad *= self.output_size()**0.5
if self.embed is None:
xs_pad = xs_pad
elif (
isinstance(self.embed, Conv2dSubsampling)
or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6)
or isinstance(self.embed, Conv2dSubsampling8)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
raise TooShortUttError(
f"has {xs_pad.size(1)} frames and is too short for subsampling "
+ f"(it needs more than {limit_size} frames), return empty results",
xs_pad.size(1),
limit_size,
)
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
# xs_pad = self.dropout(xs_pad)
encoder_outs = self.encoders0(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
intermediate_outs = []
if len(self.interctc_layer_idx) == 0:
encoder_outs = self.encoders(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
else:
for layer_idx, encoder_layer in enumerate(self.encoders):
encoder_outs = encoder_layer(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
if layer_idx + 1 in self.interctc_layer_idx:
encoder_out = xs_pad
# intermediate outputs are also normalized
if self.normalize_before:
encoder_out = self.after_norm(encoder_out)
intermediate_outs.append((layer_idx + 1, encoder_out))
if self.interctc_use_conditioning:
ctc_out = ctc.softmax(encoder_out)
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
olens = masks.squeeze(1).sum(1)
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
class SANMVadEncoder(AbsEncoder):
"""
author: Speech Lab, Alibaba Group, China
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: Optional[str] = "conv2d",
pos_enc_class=SinusoidalPositionEncoder,
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = "linear",
positionwise_conv_kernel_size: int = 1,
padding_idx: int = -1,
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
kernel_size : int = 11,
sanm_shfit : int = 0,
selfattention_layer_type: str = "sanm",
):
assert check_argument_types()
super().__init__()
self._output_size = output_size
if input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(input_size, output_size),
torch.nn.LayerNorm(output_size),
torch.nn.Dropout(dropout_rate),
torch.nn.ReLU(),
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
elif input_layer == "conv2d2":
self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
elif input_layer == "conv2d6":
self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
elif input_layer == "conv2d8":
self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
elif input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
SinusoidalPositionEncoder(),
)
elif input_layer is None:
if input_size == output_size:
self.embed = None
else:
self.embed = torch.nn.Linear(input_size, output_size)
elif input_layer == "pe":
self.embed = SinusoidalPositionEncoder()
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
)
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
raise NotImplementedError("Support only linear or conv1d.")
if selfattention_layer_type == "selfattn":
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
)
elif selfattention_layer_type == "sanm":
self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
encoder_selfattn_layer_args0 = (
attention_heads,
input_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
encoder_selfattn_layer_args = (
attention_heads,
output_size,
output_size,
attention_dropout_rate,
kernel_size,
sanm_shfit,
)
self.encoders0 = repeat(
1,
lambda lnum: EncoderLayerSANM(
input_size,
output_size,
self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
self.encoders = repeat(
num_blocks-1,
lambda lnum: EncoderLayerSANM(
output_size,
output_size,
self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
dropout_rate,
normalize_before,
concat_after,
),
)
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
self.interctc_layer_idx = interctc_layer_idx
if len(interctc_layer_idx) > 0:
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
self.interctc_use_conditioning = interctc_use_conditioning
self.conditioning_layer = None
self.dropout = nn.Dropout(dropout_rate)
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
vad_indexes: torch.Tensor,
prev_states: torch.Tensor = None,
ctc: CTC = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
Args:
xs_pad: input tensor (B, L, D)
ilens: input length (B)
prev_states: Not to be used now.
Returns:
position embedded tensor and mask
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
no_future_masks = masks & sub_masks
xs_pad *= self.output_size()**0.5
if self.embed is None:
xs_pad = xs_pad
elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
raise TooShortUttError(
f"has {xs_pad.size(1)} frames and is too short for subsampling " +
f"(it needs more than {limit_size} frames), return empty results",
xs_pad.size(1),
limit_size,
)
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
# xs_pad = self.dropout(xs_pad)
mask_tup0 = [masks, no_future_masks]
encoder_outs = self.encoders0(xs_pad, mask_tup0)
xs_pad, _ = encoder_outs[0], encoder_outs[1]
intermediate_outs = []
#if len(self.interctc_layer_idx) == 0:
if False:
# Here, we should not use the repeat operation to do it for all layers.
encoder_outs = self.encoders(xs_pad, masks)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
else:
for layer_idx, encoder_layer in enumerate(self.encoders):
if layer_idx + 1 == len(self.encoders):
# This is last layer.
coner_mask = torch.ones(masks.size(0),
masks.size(-1),
masks.size(-1),
device=xs_pad.device,
dtype=torch.bool)
for word_index, length in enumerate(ilens):
coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
vad_indexes[word_index],
device=xs_pad.device)
layer_mask = masks & coner_mask
else:
layer_mask = no_future_masks
mask_tup1 = [masks, layer_mask]
encoder_outs = encoder_layer(xs_pad, mask_tup1)
xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
if layer_idx + 1 in self.interctc_layer_idx:
encoder_out = xs_pad
# intermediate outputs are also normalized
if self.normalize_before:
encoder_out = self.after_norm(encoder_out)
intermediate_outs.append((layer_idx + 1, encoder_out))
if self.interctc_use_conditioning:
ctc_out = ctc.softmax(encoder_out)
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
olens = masks.squeeze(1).sum(1)
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None

View File

@ -1,12 +0,0 @@
def split_to_mini_sentence(words: list, word_limit: int = 20):
assert word_limit > 1
if len(words) <= word_limit:
return [words]
sentences = []
length = len(words)
sentence_len = length // word_limit
for i in range(sentence_len):
sentences.append(words[i * word_limit:(i + 1) * word_limit])
if length % word_limit > 0:
sentences.append(words[sentence_len * word_limit:])
return sentences

View File

@ -74,7 +74,7 @@ foreach(_target
"${_target}.cc")
target_link_libraries(${_target}
rg_grpc_proto
rapidasr
funasr
${EXTRA_LIBS}
${_REFLECTION}
${_GRPC_GRPCPP}

View File

@ -53,6 +53,68 @@ cd ../python/grpc
python grpc_main_client_mic.py --host $server_ip --port 10108
```
The `grpc_main_client_mic.py` follows the [original design](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/grpc#workflow-in-desgin) by sending audio_data in chunks. If you want to send audio_data in one request, here is an example:
```python
# go to ../python/grpc to find the paraformer_pb2 package
import asyncio
import json
import queue
import time
import grpc
import soundfile as sf
import paraformer_pb2
class RecognizeStub:
def __init__(self, channel):
self.Recognize = channel.stream_stream(
'/paraformer.ASR/Recognize',
request_serializer=paraformer_pb2.Request.SerializeToString,
response_deserializer=paraformer_pb2.Response.FromString,
)
async def send(channel, data, speaking, isEnd):
stub = RecognizeStub(channel)
req = paraformer_pb2.Request()
if data:
req.audio_data = data
req.user = 'zz'
req.language = 'zh-CN'
req.speaking = speaking
req.isEnd = isEnd
q = queue.SimpleQueue()
q.put(req)
return stub.Recognize(iter(q.get, None))
# send the audio data once
async def grpc_rec(data, grpc_uri):
with grpc.insecure_channel(grpc_uri) as channel:
b = time.time()
response = await send(channel, data, False, False)
resp = next(response)
text = ''
if 'decoding' == resp.action:
resp = next(response)
if 'finish' == resp.action:
text = json.loads(resp.sentence)['text']
response = await send(channel, None, False, True)
return {
'text': text,
'time': time.time() - b,
}
async def test():
# fc = FunAsrGrpcClient('127.0.0.1', 9900)
# t = await fc.rec(wav.tobytes())
# print(t)
wav, _ = sf.read('z-10s.wav', dtype='int16')
uri = '127.0.0.1:9900'
res = await grpc_rec(wav.tobytes(), uri)
print(res)
if __name__ == '__main__':
asyncio.run(test())
```
## Acknowledgements
1. This project is maintained by the [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
2. We acknowledge [DeepScience](https://www.deepscience.cn) for contributing the grpc service.

View File

@ -15,7 +15,6 @@
#include "paraformer.grpc.pb.h"
#include "paraformer_server.h"
using grpc::Server;
using grpc::ServerBuilder;
using grpc::ServerContext;
@ -24,48 +23,35 @@ using grpc::ServerReaderWriter;
using grpc::ServerWriter;
using grpc::Status;
using paraformer::Request;
using paraformer::Response;
using paraformer::ASR;
ASRServicer::ASRServicer(const char* model_path, int thread_num, bool quantize) {
AsrHanlde=RapidAsrInit(model_path, thread_num, quantize);
AsrHanlde=FunASRInit(model_path, thread_num, quantize);
std::cout << "ASRServicer init" << std::endl;
init_flag = 0;
}
void ASRServicer::clear_states(const std::string& user) {
clear_buffers(user);
clear_transcriptions(user);
}
void ASRServicer::clear_buffers(const std::string& user) {
if (client_buffers.count(user)) {
client_buffers.erase(user);
}
}
void ASRServicer::clear_transcriptions(const std::string& user) {
if (client_transcription.count(user)) {
client_transcription.erase(user);
}
}
void ASRServicer::disconnect(const std::string& user) {
clear_states(user);
std::cout << "Disconnecting user: " << user << std::endl;
}
grpc::Status ASRServicer::Recognize(
grpc::ServerContext* context,
grpc::ServerReaderWriter<Response, Request>* stream) {
Request req;
std::unordered_map<std::string, std::string> client_buffers;
std::unordered_map<std::string, std::string> client_transcription;
while (stream->Read(&req)) {
if (req.isend()) {
std::cout << "asr end" << std::endl;
disconnect(req.user());
// disconnect
if (client_buffers.count(req.user())) {
client_buffers.erase(req.user());
}
if (client_transcription.count(req.user())) {
client_transcription.erase(req.user());
}
Response res;
res.set_sentence(
R"({"success": true, "detail": "asr end"})"
@ -88,7 +74,7 @@ grpc::Status ASRServicer::Recognize(
res.set_language(req.language());
stream->Write(res);
} else if (!req.speaking()) {
if (client_buffers.count(req.user()) == 0) {
if (client_buffers.count(req.user()) == 0 && req.audio_data().size() == 0) {
Response res;
res.set_sentence(
R"({"success": true, "detail": "waiting_for_voice"})"
@ -99,14 +85,24 @@ grpc::Status ASRServicer::Recognize(
stream->Write(res);
}else {
auto begin_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
std::string tmp_data = this->client_buffers[req.user()];
this->clear_states(req.user());
if (req.audio_data().size() > 0) {
auto& buf = client_buffers[req.user()];
buf.insert(buf.end(), req.audio_data().begin(), req.audio_data().end());
}
std::string tmp_data = client_buffers[req.user()];
// clear_states
if (client_buffers.count(req.user())) {
client_buffers.erase(req.user());
}
if (client_transcription.count(req.user())) {
client_transcription.erase(req.user());
}
Response res;
res.set_sentence(
R"({"success": true, "detail": "decoding data: " + std::to_string(tmp_data.length()) + " bytes"})"
);
int data_len_int = tmp_data.length();
int data_len_int = tmp_data.length();
std::string data_len = std::to_string(data_len_int);
std::stringstream ss;
ss << R"({"success": true, "detail": "decoding data: )" << data_len << R"( bytes")" << R"("})";
@ -129,18 +125,15 @@ grpc::Status ASRServicer::Recognize(
res.set_user(req.user());
res.set_action("finish");
res.set_language(req.language());
stream->Write(res);
}
else {
RPASR_RESULT Result= RapidAsrRecogPCMBuffer(AsrHanlde, tmp_data.c_str(), data_len_int, RASR_NONE, NULL);
std::string asr_result = ((RPASR_RECOG_RESULT*)Result)->msg;
FUNASR_RESULT Result= FunASRRecogPCMBuffer(AsrHanlde, tmp_data.c_str(), data_len_int, 16000, RASR_NONE, NULL);
std::string asr_result = ((FUNASR_RECOG_RESULT*)Result)->msg;
auto end_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
std::string delay_str = std::to_string(end_time - begin_time);
std::cout << "user: " << req.user() << " , delay(ms): " << delay_str << ", text: " << asr_result << std::endl;
Response res;
std::stringstream ss;
@ -150,8 +143,7 @@ grpc::Status ASRServicer::Recognize(
res.set_user(req.user());
res.set_action("finish");
res.set_language(req.language());
stream->Write(res);
}
}
@ -165,11 +157,10 @@ grpc::Status ASRServicer::Recognize(
res.set_language(req.language());
stream->Write(res);
}
}
}
return Status::OK;
}
void RunServer(const std::string& port, int thread_num, const char* model_path, bool quantize) {
std::string server_address;
server_address = "0.0.0.0:" + port;

View File

@ -15,7 +15,7 @@
#include <chrono>
#include "paraformer.grpc.pb.h"
#include "librapidasrapi.h"
#include "libfunasrapi.h"
using grpc::Server;
@ -35,22 +35,16 @@ typedef struct
{
std::string msg;
float snippet_time;
}RPASR_RECOG_RESULT;
}FUNASR_RECOG_RESULT;
class ASRServicer final : public ASR::Service {
private:
int init_flag;
std::unordered_map<std::string, std::string> client_buffers;
std::unordered_map<std::string, std::string> client_transcription;
public:
ASRServicer(const char* model_path, int thread_num, bool quantize);
void clear_states(const std::string& user);
void clear_buffers(const std::string& user);
void clear_transcriptions(const std::string& user);
void disconnect(const std::string& user);
grpc::Status Recognize(grpc::ServerContext* context, grpc::ServerReaderWriter<Response, Request>* stream);
RPASR_HANDLE AsrHanlde;
FUNASR_HANDLE AsrHanlde;
};

View File

@ -2,24 +2,27 @@ cmake_minimum_required(VERSION 3.10)
project(FunASRonnx)
set(CMAKE_CXX_STANDARD 11)
# set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
include(TestBigEndian)
test_big_endian(BIG_ENDIAN)
if(BIG_ENDIAN)
message("Big endian system")
else()
message("Little endian system")
endif()
# for onnxruntime
IF(WIN32)
if(CMAKE_CL_64)
link_directories(${ONNXRUNTIME_DIR}\\lib)
else()
add_definitions(-D_WIN_X86)
endif()
ELSE()
link_directories(${ONNXRUNTIME_DIR}/lib)
link_directories(${ONNXRUNTIME_DIR}/lib)
endif()
add_subdirectory("./third_party/yaml-cpp")

View File

@ -6,6 +6,13 @@
#include <queue>
#include <stdint.h>
#ifndef model_sample_rate
#define model_sample_rate 16000
#endif
#ifndef WAV_HEADER_SIZE
#define WAV_HEADER_SIZE 44
#endif
using namespace std;
class AudioFrame {
@ -32,7 +39,6 @@ class Audio {
int16_t *speech_buff;
int speech_len;
int speech_align_len;
int16_t sample_rate;
int offset;
float align_size;
int data_type;
@ -43,10 +49,11 @@ class Audio {
Audio(int data_type, int size);
~Audio();
void disp();
bool loadwav(const char* filename);
bool loadwav(const char* buf, int nLen);
bool loadpcmwav(const char* buf, int nFileLen);
bool loadpcmwav(const char* filename);
bool loadwav(const char* filename, int32_t* sampling_rate);
void wavResample(int32_t sampling_rate, const float *waveform, int32_t n);
bool loadwav(const char* buf, int nLen, int32_t* sampling_rate);
bool loadpcmwav(const char* buf, int nFileLen, int32_t* sampling_rate);
bool loadpcmwav(const char* filename, int32_t* sampling_rate);
int fetch_chunck(float *&dout, int len);
int fetch(float *&dout, int &len, int &flag);
void padding();

View File

@ -0,0 +1,77 @@
#pragma once
#ifdef WIN32
#ifdef _FUNASR_API_EXPORT
#define _FUNASRAPI __declspec(dllexport)
#else
#define _FUNASRAPI __declspec(dllimport)
#endif
#else
#define _FUNASRAPI
#endif
#ifndef _WIN32
#define FUNASR_CALLBCK_PREFIX __attribute__((__stdcall__))
#else
#define FUNASR_CALLBCK_PREFIX __stdcall
#endif
#ifdef __cplusplus
extern "C" {
#endif
typedef void* FUNASR_HANDLE;
typedef void* FUNASR_RESULT;
typedef unsigned char FUNASR_BOOL;
#define FUNASR_TRUE 1
#define FUNASR_FALSE 0
#define QM_DEFAULT_THREAD_NUM 4
typedef enum
{
RASR_NONE=-1,
RASRM_CTC_GREEDY_SEARCH=0,
RASRM_CTC_RPEFIX_BEAM_SEARCH = 1,
RASRM_ATTENSION_RESCORING = 2,
}FUNASR_MODE;
typedef enum {
FUNASR_MODEL_PADDLE = 0,
FUNASR_MODEL_PADDLE_2 = 1,
FUNASR_MODEL_K2 = 2,
FUNASR_MODEL_PARAFORMER = 3,
}FUNASR_MODEL_TYPE;
typedef void (* QM_CALLBACK)(int nCurStep, int nTotal); // nTotal: total steps; nCurStep: Current Step.
// APIs for qmasr
_FUNASRAPI FUNASR_HANDLE FunASRInit(const char* szModelDir, int nThread, bool quantize);
// if no fnCallback is given, it should be NULL
_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
_FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
_FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT Result,int nIndex);
_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT Result);
_FUNASRAPI void FunASRFreeResult(FUNASR_RESULT Result);
_FUNASRAPI void FunASRUninit(FUNASR_HANDLE Handle);
_FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT Result);
#ifdef __cplusplus
}
#endif

View File

@ -1,77 +0,0 @@
#pragma once
#ifdef WIN32
#ifdef _RPASR_API_EXPORT
#define _RAPIDASRAPI __declspec(dllexport)
#else
#define _RAPIDASRAPI __declspec(dllimport)
#endif
#else
#define _RAPIDASRAPI
#endif
#ifndef _WIN32
#define RPASR_CALLBCK_PREFIX __attribute__((__stdcall__))
#else
#define RPASR_CALLBCK_PREFIX __stdcall
#endif
#ifdef __cplusplus
extern "C" {
#endif
typedef void* RPASR_HANDLE;
typedef void* RPASR_RESULT;
typedef unsigned char RPASR_BOOL;
#define RPASR_TRUE 1
#define RPASR_FALSE 0
#define QM_DEFAULT_THREAD_NUM 4
typedef enum
{
RASR_NONE=-1,
RASRM_CTC_GREEDY_SEARCH=0,
RASRM_CTC_RPEFIX_BEAM_SEARCH = 1,
RASRM_ATTENSION_RESCORING = 2,
}RPASR_MODE;
typedef enum {
RPASR_MODEL_PADDLE = 0,
RPASR_MODEL_PADDLE_2 = 1,
RPASR_MODEL_K2 = 2,
RPASR_MODEL_PARAFORMER = 3,
}RPASR_MODEL_TYPE;
typedef void (* QM_CALLBACK)(int nCurStep, int nTotal); // nTotal: total steps; nCurStep: Current Step.
// APIs for qmasr
_RAPIDASRAPI RPASR_HANDLE RapidAsrInit(const char* szModelDir, int nThread, bool quantize);
// if not give a fnCallback ,it should be NULL
_RAPIDASRAPI RPASR_RESULT RapidAsrRecogBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback);
_RAPIDASRAPI RPASR_RESULT RapidAsrRecogPCMBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback);
_RAPIDASRAPI RPASR_RESULT RapidAsrRecogPCMFile(RPASR_HANDLE handle, const char* szFileName, RPASR_MODE Mode, QM_CALLBACK fnCallback);
_RAPIDASRAPI RPASR_RESULT RapidAsrRecogFile(RPASR_HANDLE handle, const char* szWavfile, RPASR_MODE Mode, QM_CALLBACK fnCallback);
_RAPIDASRAPI const char* RapidAsrGetResult(RPASR_RESULT Result,int nIndex);
_RAPIDASRAPI const int RapidAsrGetRetNumber(RPASR_RESULT Result);
_RAPIDASRAPI void RapidAsrFreeResult(RPASR_RESULT Result);
_RAPIDASRAPI void RapidAsrUninit(RPASR_HANDLE Handle);
_RAPIDASRAPI const float RapidAsrGetRetSnippetTime(RPASR_RESULT Result);
#ifdef __cplusplus
}
#endif

View File

@ -1,83 +1,70 @@
## Quick Start
### Windows
Install VS2022, open the CMake project under the cpp_onnx directory, and build it directly. All required dependency libraries are already bundled in this repository.
The fftw3 and onnxruntime libraries are pre-installed for Windows.
### Linux
See the bottom of this page: Building Guidance
### Run the program
tester /path/to/models_dir /path/to/wave_file quantize(true or false)
For example: tester /data/models /data/test.wav false
/data/models must contain the following three files: config.yaml, am.mvn, model.onnx (or model_quant.onnx)
## Supported platforms
- Windows
- Linux/Unix
## Dependencies
- fftw3
- openblas
- onnxruntime
## Export the model to ONNX format
Install modelscope and FunASR, which depend on torch and torchaudio; for the installation procedure, see the [detailed documentation](https://github.com/alibaba-damo-academy/FunASR/wiki)
## Demo
```shell
pip install "modelscope[audio_asr]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
git clone https://github.com/alibaba/FunASR.git && cd FunASR
pip install --editable ./
tester /path/models_dir /path/wave_file quantize(true or false)
```
Export the ONNX model ([details here](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export)); follow the examples to export a model from ModelScope.
The structure of /path/models_dir
```
config.yaml, am.mvn, model.onnx(or model_quant.onnx)
```
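With the three files above in place, a run follows the usage shown earlier (paths are placeholders):
```shell
tester /path/models_dir /path/wave_file false   # last argument: quantize (true or false)
```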
## Steps
### Export onnx
#### Install [modelscope and funasr](https://github.com/alibaba-damo-academy/FunASR#installation)
```shell
pip3 install torch torchaudio
pip install -U modelscope
pip install -U funasr
```
#### Export [onnx model](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export)
```shell
python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize True
```
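After the export finishes, the files needed by the runtime are typically written under the export directory; the exact layout may vary between versions, but it usually contains the files listed above:
```shell
ls ./export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
# config.yaml  am.mvn  model.onnx  model_quant.onnx
```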
## Building Guidance for Linux/Unix
### Building for Linux/Unix
```
git clone https://github.com/alibaba-damo-academy/FunASR.git && cd funasr/runtime/onnxruntime
mkdir build
cd build
#### Download onnxruntime
```shell
# download an appropriate onnxruntime from https://github.com/microsoft/onnxruntime/releases/tag/v1.14.0
# here we get a copy of onnxruntime for linux 64
wget https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz
tar -zxvf onnxruntime-linux-x64-1.14.0.tgz
# ls
# onnxruntime-linux-x64-1.14.0 onnxruntime-linux-x64-1.14.0.tgz
```
#install fftw3-dev
ubuntu: apt install libfftw3-dev
centos: yum install fftw fftw-devel
#### Install fftw3
```shell
sudo apt install libfftw3-dev #ubuntu
# sudo yum install fftw fftw-devel #centos
```
#install openblas
bash ./third_party/install_openblas.sh
#### Install openblas
```shell
sudo apt-get install libopenblas-dev #ubuntu
# sudo yum -y install openblas-devel #centos
```
# build
cmake -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0
make
#### Build runtime
```shell
git clone https://github.com/alibaba-damo-academy/FunASR.git && cd FunASR/funasr/runtime/onnxruntime
mkdir build && cd build
cmake -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0
make
```
# then, in the tester subfolder of the current directory, you will find the tester program
````
### The structure of a qualified onnxruntime package.
#### The structure of a qualified onnxruntime package.
```
onnxruntime_xxx
├───include
└───lib
```
## Note
This program only supports **mono** audio with a 16000 Hz sample rate and 16-bit depth.
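If a recording is in another format, converting it first usually suffices; assuming ffmpeg is available, something along these lines works:
```shell
ffmpeg -i input.mp3 -ac 1 -ar 16000 -c:a pcm_s16le output.wav
```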
### Building for Windows
Ref to win/
## Acknowledge
1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).

View File

@ -3,11 +3,96 @@
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <fstream>
#include <assert.h>
#include "Audio.h"
#include "precomp.h"
using namespace std;
// see http://soundfile.sapp.org/doc/WaveFormat/
// Note: We assume little endian here
struct WaveHeader {
bool Validate() const {
// F F I R
if (chunk_id != 0x46464952) {
printf("Expected chunk_id RIFF. Given: 0x%08x\n", chunk_id);
return false;
}
// E V A W
if (format != 0x45564157) {
printf("Expected format WAVE. Given: 0x%08x\n", format);
return false;
}
if (subchunk1_id != 0x20746d66) {
printf("Expected subchunk1_id 0x20746d66. Given: 0x%08x\n",
subchunk1_id);
return false;
}
if (subchunk1_size != 16) { // 16 for PCM
printf("Expected subchunk1_size 16. Given: %d\n",
subchunk1_size);
return false;
}
if (audio_format != 1) { // 1 for PCM
printf("Expected audio_format 1. Given: %d\n", audio_format);
return false;
}
if (num_channels != 1) { // we support only single channel for now
printf("Expected single channel. Given: %d\n", num_channels);
return false;
}
if (byte_rate != (sample_rate * num_channels * bits_per_sample / 8)) {
return false;
}
if (block_align != (num_channels * bits_per_sample / 8)) {
return false;
}
if (bits_per_sample != 16) { // we support only 16 bits per sample
printf("Expected bits_per_sample 16. Given: %d\n",
bits_per_sample);
return false;
}
return true;
}
// See https://en.wikipedia.org/wiki/WAV#Metadata and
// https://www.robotplanet.dk/audio/wav_meta_data/riff_mci.pdf
void SeekToDataChunk(std::istream &is) {
// a t a d
while (is && subchunk2_id != 0x61746164) {
// const char *p = reinterpret_cast<const char *>(&subchunk2_id);
// printf("Skip chunk (%x): %c%c%c%c of size: %d\n", subchunk2_id, p[0],
// p[1], p[2], p[3], subchunk2_size);
is.seekg(subchunk2_size, std::istream::cur);
is.read(reinterpret_cast<char *>(&subchunk2_id), sizeof(int32_t));
is.read(reinterpret_cast<char *>(&subchunk2_size), sizeof(int32_t));
}
}
int32_t chunk_id;
int32_t chunk_size;
int32_t format;
int32_t subchunk1_id;
int32_t subchunk1_size;
int16_t audio_format;
int16_t num_channels;
int32_t sample_rate;
int32_t byte_rate;
int16_t block_align;
int16_t bits_per_sample;
int32_t subchunk2_id; // a tag of this chunk
int32_t subchunk2_size; // size of subchunk2
};
static_assert(sizeof(WaveHeader) == WAV_HEADER_SIZE, "");
class AudioWindow {
private:
int *window;
@ -56,7 +141,7 @@ int AudioFrame::set_end(int val, int max_len)
float frame_length = 400;
float frame_shift = 160;
float num_new_samples =
ceil((num_samples - 400) / frame_shift) * frame_shift + frame_length;
ceil((num_samples - frame_length) / frame_shift) * frame_shift + frame_length;
end = start + num_new_samples;
len = (int)num_new_samples;
@ -111,120 +196,150 @@ Audio::~Audio()
void Audio::disp()
{
printf("Audio time is %f s. len is %d\n", (float)speech_len / 16000,
printf("Audio time is %f s. len is %d\n", (float)speech_len / model_sample_rate,
speech_len);
}
float Audio::get_time_len()
{
return (float)speech_len / 16000;
//speech_len);
return (float)speech_len / model_sample_rate;
}
bool Audio::loadwav(const char *filename)
void Audio::wavResample(int32_t sampling_rate, const float *waveform,
int32_t n)
{
printf(
"Creating a resampler:\n"
" in_sample_rate: %d\n"
" output_sample_rate: %d\n",
sampling_rate, static_cast<int32_t>(model_sample_rate));
float min_freq =
std::min<int32_t>(sampling_rate, model_sample_rate);
float lowpass_cutoff = 0.99 * 0.5 * min_freq;
int32_t lowpass_filter_width = 6;
//FIXME
//auto resampler = new LinearResample(
// sampling_rate, model_sample_rate, lowpass_cutoff, lowpass_filter_width);
auto resampler = std::make_unique<LinearResample>(
sampling_rate, model_sample_rate, lowpass_cutoff, lowpass_filter_width);
std::vector<float> samples;
resampler->Resample(waveform, n, true, &samples);
//reset speech_data
speech_len = samples.size();
if (speech_data != NULL) {
free(speech_data);
}
speech_data = (float*)malloc(sizeof(float) * speech_len);
memset(speech_data, 0, sizeof(float) * speech_len);
copy(samples.begin(), samples.end(), speech_data);
}
bool Audio::loadwav(const char *filename, int32_t* sampling_rate)
{
WaveHeader header;
if (speech_data != NULL) {
free(speech_data);
}
if (speech_buff != NULL) {
free(speech_buff);
}
offset = 0;
FILE *fp;
fp = fopen(filename, "rb");
if (fp == nullptr)
std::ifstream is(filename, std::ifstream::binary);
is.read(reinterpret_cast<char *>(&header), sizeof(header));
if(!is){
fprintf(stderr, "Failed to read %s\n", filename);
return false;
fseek(fp, 0, SEEK_END); /* seek to the end of the file */
uint32_t nFileLen = ftell(fp); /* get the file size */
fseek(fp, 44, SEEK_SET); /* skip the wav header */
speech_len = (nFileLen - 44) / 2;
speech_align_len = (int)(ceil((float)speech_len / align_size) * align_size);
speech_buff = (int16_t *)malloc(sizeof(int16_t) * speech_align_len);
}
*sampling_rate = header.sample_rate;
// header.subchunk2_size contains the number of bytes in the data.
// As we assume each sample contains two bytes, so it is divided by 2 here
speech_len = header.subchunk2_size / 2;
speech_buff = (int16_t *)malloc(sizeof(int16_t) * speech_len);
if (speech_buff)
{
memset(speech_buff, 0, sizeof(int16_t) * speech_align_len);
int ret = fread(speech_buff, sizeof(int16_t), speech_len, fp);
fclose(fp);
memset(speech_buff, 0, sizeof(int16_t) * speech_len);
is.read(reinterpret_cast<char *>(speech_buff), header.subchunk2_size);
if (!is) {
fprintf(stderr, "Failed to read %s\n", filename);
return false;
}
speech_data = (float*)malloc(sizeof(float) * speech_len);
memset(speech_data, 0, sizeof(float) * speech_len);
speech_data = (float*)malloc(sizeof(float) * speech_align_len);
memset(speech_data, 0, sizeof(float) * speech_align_len);
int i;
float scale = 1;
if (data_type == 1) {
scale = 32768;
}
for (i = 0; i < speech_len; i++) {
for (int32_t i = 0; i != speech_len; ++i) {
speech_data[i] = (float)speech_buff[i] / scale;
}
//resample
if(*sampling_rate != model_sample_rate){
wavResample(*sampling_rate, speech_data, speech_len);
}
AudioFrame* frame = new AudioFrame(speech_len);
frame_queue.push(frame);
return true;
}
else
return false;
}
bool Audio::loadwav(const char* buf, int nFileLen)
bool Audio::loadwav(const char* buf, int nFileLen, int32_t* sampling_rate)
{
WaveHeader header;
if (speech_data != NULL) {
free(speech_data);
}
if (speech_buff != NULL) {
free(speech_buff);
}
offset = 0;
size_t nOffset = 0;
std::memcpy(&header, buf, sizeof(header));
#define WAV_HEADER_SIZE 44
speech_len = (nFileLen - WAV_HEADER_SIZE) / 2;
speech_align_len = (int)(ceil((float)speech_len / align_size) * align_size);
speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_align_len);
*sampling_rate = header.sample_rate;
speech_len = header.subchunk2_size / 2;
speech_buff = (int16_t *)malloc(sizeof(int16_t) * speech_len);
if (speech_buff)
{
memset(speech_buff, 0, sizeof(int16_t) * speech_align_len);
memset(speech_buff, 0, sizeof(int16_t) * speech_len);
memcpy((void*)speech_buff, (const void*)(buf + WAV_HEADER_SIZE), speech_len * sizeof(int16_t));
speech_data = (float*)malloc(sizeof(float) * speech_len);
memset(speech_data, 0, sizeof(float) * speech_len);
speech_data = (float*)malloc(sizeof(float) * speech_align_len);
memset(speech_data, 0, sizeof(float) * speech_align_len);
int i;
float scale = 1;
if (data_type == 1) {
scale = 32768;
}
for (i = 0; i < speech_len; i++) {
for (int32_t i = 0; i != speech_len; ++i) {
speech_data[i] = (float)speech_buff[i] / scale;
}
//resample
if(*sampling_rate != model_sample_rate){
wavResample(*sampling_rate, speech_data, speech_len);
}
AudioFrame* frame = new AudioFrame(speech_len);
frame_queue.push(frame);
return true;
}
else
return false;
}
bool Audio::loadpcmwav(const char* buf, int nBufLen)
bool Audio::loadpcmwav(const char* buf, int nBufLen, int32_t* sampling_rate)
{
if (speech_data != NULL) {
free(speech_data);
@ -234,33 +349,29 @@ bool Audio::loadpcmwav(const char* buf, int nBufLen)
}
offset = 0;
size_t nOffset = 0;
speech_len = nBufLen / 2;
speech_align_len = (int)(ceil((float)speech_len / align_size) * align_size);
speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_align_len);
speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len);
if (speech_buff)
{
memset(speech_buff, 0, sizeof(int16_t) * speech_align_len);
memset(speech_buff, 0, sizeof(int16_t) * speech_len);
memcpy((void*)speech_buff, (const void*)buf, speech_len * sizeof(int16_t));
speech_data = (float*)malloc(sizeof(float) * speech_len);
memset(speech_data, 0, sizeof(float) * speech_len);
speech_data = (float*)malloc(sizeof(float) * speech_align_len);
memset(speech_data, 0, sizeof(float) * speech_align_len);
int i;
float scale = 1;
if (data_type == 1) {
scale = 32768;
}
for (i = 0; i < speech_len; i++) {
for (int32_t i = 0; i != speech_len; ++i) {
speech_data[i] = (float)speech_buff[i] / scale;
}
//resample
if(*sampling_rate != model_sample_rate){
wavResample(*sampling_rate, speech_data, speech_len);
}
AudioFrame* frame = new AudioFrame(speech_len);
frame_queue.push(frame);
@ -269,13 +380,10 @@ bool Audio::loadpcmwav(const char* buf, int nBufLen)
}
else
return false;
}
bool Audio::loadpcmwav(const char* filename)
bool Audio::loadpcmwav(const char* filename, int32_t* sampling_rate)
{
if (speech_data != NULL) {
free(speech_data);
}
@ -293,34 +401,31 @@ bool Audio::loadpcmwav(const char* filename)
fseek(fp, 0, SEEK_SET);
speech_len = (nFileLen) / 2;
speech_align_len = (int)(ceil((float)speech_len / align_size) * align_size);
speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_align_len);
speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len);
if (speech_buff)
{
memset(speech_buff, 0, sizeof(int16_t) * speech_align_len);
memset(speech_buff, 0, sizeof(int16_t) * speech_len);
int ret = fread(speech_buff, sizeof(int16_t), speech_len, fp);
fclose(fp);
speech_data = (float*)malloc(sizeof(float) * speech_align_len);
memset(speech_data, 0, sizeof(float) * speech_align_len);
speech_data = (float*)malloc(sizeof(float) * speech_len);
memset(speech_data, 0, sizeof(float) * speech_len);
int i;
float scale = 1;
if (data_type == 1) {
scale = 32768;
}
for (i = 0; i < speech_len; i++) {
for (int32_t i = 0; i != speech_len; ++i) {
speech_data[i] = (float)speech_buff[i] / scale;
}
//resample
if(*sampling_rate != model_sample_rate){
wavResample(*sampling_rate, speech_data, speech_len);
}
AudioFrame* frame = new AudioFrame(speech_len);
frame_queue.push(frame);
return true;
}
@ -329,7 +434,6 @@ bool Audio::loadpcmwav(const char* filename)
}
int Audio::fetch_chunck(float *&dout, int len)
{
if (offset >= speech_align_len) {

View File

@ -1,43 +1,44 @@
file(GLOB files1 "*.cpp")
file(GLOB files2 "*.cc")
file(GLOB files4 "paraformer/*.cpp")
set(files ${files1} ${files2} ${files3} ${files4})
# message("${files}")
add_library(rapidasr ${files})
add_library(funasr ${files})
if(WIN32)
set(EXTRA_LIBS libfftw3f-3 yaml-cpp)
if(CMAKE_CL_64)
target_link_directories(rapidasr PUBLIC ${CMAKE_SOURCE_DIR}/win/lib/x64)
target_link_directories(funasr PUBLIC ${CMAKE_SOURCE_DIR}/win/lib/x64)
else()
target_link_directories(rapidasr PUBLIC ${CMAKE_SOURCE_DIR}/win/lib/x86)
target_link_directories(funasr PUBLIC ${CMAKE_SOURCE_DIR}/win/lib/x86)
endif()
target_include_directories(rapidasr PUBLIC ${CMAKE_SOURCE_DIR}/win/include )
target_include_directories(funasr PUBLIC ${CMAKE_SOURCE_DIR}/win/include )
target_compile_definitions(rapidasr PUBLIC -D_RPASR_API_EXPORT)
target_compile_definitions(funasr PUBLIC -D_FUNASR_API_EXPORT)
else()
set(EXTRA_LIBS fftw3f pthread yaml-cpp)
target_include_directories(rapidasr PUBLIC "/usr/local/opt/fftw/include")
target_link_directories(rapidasr PUBLIC "/usr/local/opt/fftw/lib")
target_include_directories(funasr PUBLIC "/usr/local/opt/fftw/include")
target_link_directories(funasr PUBLIC "/usr/local/opt/fftw/lib")
target_include_directories(rapidasr PUBLIC "/usr/local/opt/openblas/include")
target_link_directories(rapidasr PUBLIC "/usr/local/opt/openblas/lib")
target_include_directories(funasr PUBLIC "/usr/local/opt/openblas/include")
target_link_directories(funasr PUBLIC "/usr/local/opt/openblas/lib")
target_include_directories(rapidasr PUBLIC "/usr/include")
target_link_directories(rapidasr PUBLIC "/usr/lib64")
target_include_directories(funasr PUBLIC "/usr/include")
target_link_directories(funasr PUBLIC "/usr/lib64")
target_include_directories(rapidasr PUBLIC ${FFTW3F_INCLUDE_DIR})
target_link_directories(rapidasr PUBLIC ${FFTW3F_LIBRARY_DIR})
target_include_directories(funasr PUBLIC ${FFTW3F_INCLUDE_DIR})
target_link_directories(funasr PUBLIC ${FFTW3F_LIBRARY_DIR})
include_directories(${ONNXRUNTIME_DIR}/include)
endif()
include_directories(${CMAKE_SOURCE_DIR}/include)
target_link_libraries(rapidasr PUBLIC onnxruntime ${EXTRA_LIBS})
target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS})

View File

@ -5,14 +5,10 @@ using namespace std;
FeatureExtract::FeatureExtract(int mode) : mode(mode)
{
fftw_init();
}
FeatureExtract::~FeatureExtract()
{
fftwf_free(fft_input);
fftwf_free(fft_out);
fftwf_destroy_plan(p);
}
void FeatureExtract::reset()
@ -26,34 +22,25 @@ int FeatureExtract::size()
return fqueue.size();
}
void FeatureExtract::fftw_init()
void FeatureExtract::insert(fftwf_plan plan, float *din, int len, int flag)
{
int fft_size = 512;
fft_input = (float *)fftwf_malloc(sizeof(float) * fft_size);
fft_out = (fftwf_complex *)fftwf_malloc(sizeof(fftwf_complex) * fft_size);
float* fft_input = (float *)fftwf_malloc(sizeof(float) * fft_size);
fftwf_complex* fft_out = (fftwf_complex *)fftwf_malloc(sizeof(fftwf_complex) * fft_size);
memset(fft_input, 0, sizeof(float) * fft_size);
p = fftwf_plan_dft_r2c_1d(fft_size, fft_input, fft_out, FFTW_ESTIMATE);
}
void FeatureExtract::insert(float *din, int len, int flag)
{
const float *window = (const float *)&window_hex;
if (mode == 3)
window = (const float *)&window_hamm_hex;
int window_size = 400;
int fft_size = 512;
int window_shift = 160;
speech.load(din, len);
int i, j;
float tmp_feature[80];
if (mode == 0 || mode == 2 || mode == 3) {
int ll = (speech.size() - 400) / 160 + 1;
int ll = (speech.size() - window_size) / window_shift + 1;
fqueue.reinit(ll);
}
for (i = 0; i <= speech.size() - 400; i = i + window_shift) {
for (i = 0; i <= speech.size() - window_size; i = i + window_shift) {
float tmp_mean = 0;
for (j = 0; j < window_size; j++) {
tmp_mean += speech[i + j];
@ -70,7 +57,7 @@ void FeatureExtract::insert(float *din, int len, int flag)
pre_val = cur_val;
}
fftwf_execute(p);
fftwf_execute_dft_r2c(plan, fft_input, fft_out);
melspect((float *)fft_out, tmp_feature);
int tmp_flag = S_MIDDLE;
@ -80,6 +67,8 @@ void FeatureExtract::insert(float *din, int len, int flag)
fqueue.push(tmp_feature, tmp_flag);
}
speech.update(i);
fftwf_free(fft_input);
fftwf_free(fft_out);
}
bool FeatureExtract::fetch(Tensor<float> *&dout)
@ -128,7 +117,6 @@ void FeatureExtract::global_cmvn(float *din)
void FeatureExtract::melspect(float *din, float *dout)
{
float fftmag[256];
// float tmp;
const float *melcoe = (const float *)melcoe_hex;
int i;
for (i = 0; i < 256; i++) {

View File

@ -14,12 +14,11 @@ class FeatureExtract {
SpeechWrap speech;
FeatureQueue fqueue;
int mode;
int fft_size = 512;
int window_size = 400;
int window_shift = 160;
float *fft_input;
fftwf_complex *fft_out;
fftwf_plan p;
void fftw_init();
//void fftw_init();
void melspect(float *din, float *dout);
void global_cmvn(float *din);
@ -27,9 +26,9 @@ class FeatureExtract {
FeatureExtract(int mode);
~FeatureExtract();
int size();
int status();
//int status();
void reset();
void insert(float *din, int len, int flag);
void insert(fftwf_plan plan, float *din, int len, int flag);
bool fetch(Tensor<float> *&dout);
};

View File

@ -13,21 +13,6 @@ Vocab::Vocab(const char *filename)
{
ifstream in(filename);
loadVocabFromYaml(filename);
/*
string line;
if (in) // the file exists
{
while (getline(in, line)) // line does not include the trailing newline
{
vocab.push_back(line);
}
}
else{
printf("Cannot load vocab from: %s, there must be file vocab.txt", filename);
exit(-1);
}
*/
}
Vocab::~Vocab()
{

View File

@ -5,7 +5,7 @@ typedef struct
{
std::string msg;
float snippet_time;
}RPASR_RECOG_RESULT;
}FUNASR_RECOG_RESULT;
#ifdef _WIN32
@ -53,4 +53,4 @@ inline void getOutputName(Ort::Session* session, string& outputName, int nIndex
}
}
}
}

View File

@ -5,32 +5,33 @@ extern "C" {
#endif
// APIs for qmasr
_RAPIDASRAPI RPASR_HANDLE RapidAsrInit(const char* szModelDir, int nThreadNum, bool quantize)
_FUNASRAPI FUNASR_HANDLE FunASRInit(const char* szModelDir, int nThreadNum, bool quantize)
{
Model* mm = create_model(szModelDir, nThreadNum, quantize);
return mm;
}
_RAPIDASRAPI RPASR_RESULT RapidAsrRecogBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback)
_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
{
Model* pRecogObj = (Model*)handle;
if (!pRecogObj)
return nullptr;
int32_t sampling_rate = -1;
Audio audio(1);
if (!audio.loadwav(szBuf, nLen))
if (!audio.loadwav(szBuf, nLen, &sampling_rate))
return nullptr;
//audio.split();
float* buff;
int len;
int flag=0;
RPASR_RECOG_RESULT* pResult = new RPASR_RECOG_RESULT;
FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
pResult->snippet_time = audio.get_time_len();
int nStep = 0;
int nTotal = audio.get_queue_size();
while (audio.fetch(buff, len, flag) > 0) {
pRecogObj->reset();
//pRecogObj->reset();
string msg = pRecogObj->forward(buff, len, flag);
pResult->msg += msg;
nStep++;
@ -41,26 +42,26 @@ extern "C" {
return pResult;
}
_RAPIDASRAPI RPASR_RESULT RapidAsrRecogPCMBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback)
_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
{
Model* pRecogObj = (Model*)handle;
if (!pRecogObj)
return nullptr;
Audio audio(1);
if (!audio.loadpcmwav(szBuf, nLen))
if (!audio.loadpcmwav(szBuf, nLen, &sampling_rate))
return nullptr;
//audio.split();
float* buff;
int len;
int flag = 0;
RPASR_RECOG_RESULT* pResult = new RPASR_RECOG_RESULT;
FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
pResult->snippet_time = audio.get_time_len();
int nStep = 0;
int nTotal = audio.get_queue_size();
while (audio.fetch(buff, len, flag) > 0) {
pRecogObj->reset();
//pRecogObj->reset();
string msg = pRecogObj->forward(buff, len, flag);
pResult->msg += msg;
nStep++;
@ -71,26 +72,26 @@ extern "C" {
return pResult;
}
_RAPIDASRAPI RPASR_RESULT RapidAsrRecogPCMFile(RPASR_HANDLE handle, const char* szFileName, RPASR_MODE Mode, QM_CALLBACK fnCallback)
_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
{
Model* pRecogObj = (Model*)handle;
if (!pRecogObj)
return nullptr;
Audio audio(1);
if (!audio.loadpcmwav(szFileName))
if (!audio.loadpcmwav(szFileName, &sampling_rate))
return nullptr;
//audio.split();
float* buff;
int len;
int flag = 0;
RPASR_RECOG_RESULT* pResult = new RPASR_RECOG_RESULT;
FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
pResult->snippet_time = audio.get_time_len();
int nStep = 0;
int nTotal = audio.get_queue_size();
while (audio.fetch(buff, len, flag) > 0) {
pRecogObj->reset();
//pRecogObj->reset();
string msg = pRecogObj->forward(buff, len, flag);
pResult->msg += msg;
nStep++;
@ -101,14 +102,15 @@ extern "C" {
return pResult;
}
_RAPIDASRAPI RPASR_RESULT RapidAsrRecogFile(RPASR_HANDLE handle, const char* szWavfile, RPASR_MODE Mode, QM_CALLBACK fnCallback)
_FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
{
Model* pRecogObj = (Model*)handle;
if (!pRecogObj)
return nullptr;
int32_t sampling_rate = -1;
Audio audio(1);
if(!audio.loadwav(szWavfile))
if(!audio.loadwav(szWavfile, &sampling_rate))
return nullptr;
//audio.split();
@ -117,10 +119,10 @@ extern "C" {
int flag = 0;
int nStep = 0;
int nTotal = audio.get_queue_size();
RPASR_RECOG_RESULT* pResult = new RPASR_RECOG_RESULT;
FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
pResult->snippet_time = audio.get_time_len();
while (audio.fetch(buff, len, flag) > 0) {
pRecogObj->reset();
//pRecogObj->reset();
string msg = pRecogObj->forward(buff, len, flag);
pResult->msg+= msg;
nStep++;
@ -131,7 +133,7 @@ extern "C" {
return pResult;
}
_RAPIDASRAPI const int RapidAsrGetRetNumber(RPASR_RESULT Result)
_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT Result)
{
if (!Result)
return 0;
@ -140,32 +142,32 @@ extern "C" {
}
_RAPIDASRAPI const float RapidAsrGetRetSnippetTime(RPASR_RESULT Result)
_FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT Result)
{
if (!Result)
return 0.0f;
return ((RPASR_RECOG_RESULT*)Result)->snippet_time;
return ((FUNASR_RECOG_RESULT*)Result)->snippet_time;
}
_RAPIDASRAPI const char* RapidAsrGetResult(RPASR_RESULT Result,int nIndex)
_FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT Result,int nIndex)
{
RPASR_RECOG_RESULT * pResult = (RPASR_RECOG_RESULT*)Result;
FUNASR_RECOG_RESULT * pResult = (FUNASR_RECOG_RESULT*)Result;
if(!pResult)
return nullptr;
return pResult->msg.c_str();
}
_RAPIDASRAPI void RapidAsrFreeResult(RPASR_RESULT Result)
_FUNASRAPI void FunASRFreeResult(FUNASR_RESULT Result)
{
if (Result)
{
delete (RPASR_RECOG_RESULT*)Result;
delete (FUNASR_RECOG_RESULT*)Result;
}
}
_RAPIDASRAPI void RapidAsrUninit(RPASR_HANDLE handle)
_FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle)
{
Model* pRecogObj = (Model*)handle;

View File

@ -4,7 +4,7 @@ using namespace std;
using namespace paraformer;
ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
{
:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),sessionOptions{}{
string model_path;
string cmvn_path;
string config_path;
@ -18,7 +18,10 @@ ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
cmvn_path = pathAppend(path, "am.mvn");
config_path = pathAppend(path, "config.yaml");
fe = new FeatureExtract(3);
fft_input = (float *)fftwf_malloc(sizeof(float) * fft_size);
fft_out = (fftwf_complex *)fftwf_malloc(sizeof(fftwf_complex) * fft_size);
memset(fft_input, 0, sizeof(float) * fft_size);
plan = fftwf_plan_dft_r2c_1d(fft_size, fft_input, fft_out, FFTW_ESTIMATE);
//sessionOptions.SetInterOpNumThreads(1);
sessionOptions.SetIntraOpNumThreads(nNumThread);
@ -26,20 +29,20 @@ ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
#ifdef _WIN32
wstring wstrPath = strToWstr(model_path);
m_session = new Ort::Session(env, wstrPath.c_str(), sessionOptions);
m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
#else
m_session = new Ort::Session(env, model_path.c_str(), sessionOptions);
m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
#endif
string strName;
getInputName(m_session, strName);
getInputName(m_session.get(), strName);
m_strInputNames.push_back(strName.c_str());
getInputName(m_session, strName,1);
getInputName(m_session.get(), strName,1);
m_strInputNames.push_back(strName);
getOutputName(m_session, strName);
getOutputName(m_session.get(), strName);
m_strOutputNames.push_back(strName);
getOutputName(m_session, strName,1);
getOutputName(m_session.get(), strName,1);
m_strOutputNames.push_back(strName);
for (auto& item : m_strInputNames)
@ -52,20 +55,16 @@ ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
ModelImp::~ModelImp()
{
if(fe)
delete fe;
if (m_session)
{
delete m_session;
m_session = nullptr;
}
if(vocab)
delete vocab;
fftwf_free(fft_input);
fftwf_free(fft_out);
fftwf_destroy_plan(plan);
fftwf_cleanup();
}
void ModelImp::reset()
{
fe->reset();
}
void ModelImp::apply_lfr(Tensor<float>*& din)
@ -159,14 +158,21 @@ string ModelImp::greedy_search(float * in, int nLen )
string ModelImp::forward(float* din, int len, int flag)
{
Tensor<float>* in;
fe->insert(din, len, flag);
FeatureExtract* fe = new FeatureExtract(3);
fe->reset();
fe->insert(plan, din, len, flag);
fe->fetch(in);
apply_lfr(in);
apply_cmvn(in);
Ort::RunOptions run_option;
#ifdef _WIN_X86
Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
#else
Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
#endif
std::array<int64_t, 3> input_shape_{ in->size[0],in->size[2],in->size[3] };
Ort::Value onnx_feats = Ort::Value::CreateTensor<float>(m_memoryInfo,
in->buff,
@ -192,7 +198,6 @@ string ModelImp::forward(float* din, int len, int flag)
auto outputTensor = m_session->Run(run_option, m_szInputNames.data(), input_onnx.data(), m_szInputNames.size(), m_szOutputNames.data(), m_szOutputNames.size());
std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
float* floatData = outputTensor[0].GetTensorMutableData<float>();
auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
@ -203,9 +208,14 @@ string ModelImp::forward(float* din, int len, int flag)
result = "";
}
if(in)
if(in){
delete in;
in = nullptr;
}
if(fe){
delete fe;
fe = nullptr;
}
return result;
}

View File

@ -8,7 +8,10 @@ namespace paraformer {
class ModelImp : public Model {
private:
FeatureExtract* fe;
int fft_size=512;
float *fft_input;
fftwf_complex *fft_out;
fftwf_plan plan;
Vocab* vocab;
vector<float> means_list;
@ -21,21 +24,13 @@ namespace paraformer {
string greedy_search( float* in, int nLen);
#ifdef _WIN_X86
Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
#else
Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
#endif
Ort::Session* m_session = nullptr;
Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "paraformer");
Ort::SessionOptions sessionOptions = Ort::SessionOptions();
std::unique_ptr<Ort::Session> m_session;
Ort::Env env_;
Ort::SessionOptions sessionOptions;
vector<string> m_strInputNames, m_strOutputNames;
vector<const char*> m_szInputNames;
vector<const char*> m_szOutputNames;
//string m_strInputName, m_strInputNameLen;
//string m_strOutputName, m_strOutputNameLen;
public:
ModelImp(const char* path, int nNumThread=0, bool quantize=false);

View File

@ -44,9 +44,10 @@ using namespace std;
#include "FeatureQueue.h"
#include "SpeechWrap.h"
#include <Audio.h>
#include "resample.h"
#include "Model.h"
#include "paraformer_onnx.h"
#include "librapidasrapi.h"
#include "libfunasrapi.h"
using namespace paraformer;

View File

@ -0,0 +1,305 @@
/**
* Copyright 2013 Pegah Ghahremani
* 2014 IMSL, PKU-HKUST (author: Wei Shi)
* 2014 Yanqing Sun, Junjie Wang
* 2014 Johns Hopkins University (author: Daniel Povey)
* Copyright 2023 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// this file is copied and modified from
// kaldi/src/feat/resample.cc
#include "resample.h"
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <cstdlib>
#include <type_traits>
#ifndef M_2PI
#define M_2PI 6.283185307179586476925286766559005
#endif
#ifndef M_PI
#define M_PI 3.1415926535897932384626433832795
#endif
template <class I>
I Gcd(I m, I n) {
// this function is copied from kaldi/src/base/kaldi-math.h
if (m == 0 || n == 0) {
if (m == 0 && n == 0) { // gcd not defined, as all integers are divisors.
fprintf(stderr, "Undefined GCD since m = 0, n = 0.\n");
exit(-1);
}
return (m == 0 ? (n > 0 ? n : -n) : (m > 0 ? m : -m));
// return absolute value of whichever is nonzero
}
// could use compile-time assertion
// but involves messing with complex template stuff.
static_assert(std::is_integral<I>::value, "");
while (1) {
m %= n;
if (m == 0) return (n > 0 ? n : -n);
n %= m;
if (n == 0) return (m > 0 ? m : -m);
}
}
/// Returns the least common multiple of two integers. Will
/// crash unless the inputs are positive.
template <class I>
I Lcm(I m, I n) {
// This function is copied from kaldi/src/base/kaldi-math.h
assert(m > 0 && n > 0);
I gcd = Gcd(m, n);
return gcd * (m / gcd) * (n / gcd);
}
static float DotProduct(const float *a, const float *b, int32_t n) {
float sum = 0;
for (int32_t i = 0; i != n; ++i) {
sum += a[i] * b[i];
}
return sum;
}
LinearResample::LinearResample(int32_t samp_rate_in_hz,
int32_t samp_rate_out_hz, float filter_cutoff_hz,
int32_t num_zeros)
: samp_rate_in_(samp_rate_in_hz),
samp_rate_out_(samp_rate_out_hz),
filter_cutoff_(filter_cutoff_hz),
num_zeros_(num_zeros) {
assert(samp_rate_in_hz > 0.0 && samp_rate_out_hz > 0.0 &&
filter_cutoff_hz > 0.0 && filter_cutoff_hz * 2 <= samp_rate_in_hz &&
filter_cutoff_hz * 2 <= samp_rate_out_hz && num_zeros > 0);
// base_freq is the frequency of the repeating unit, which is the gcd
// of the input frequencies.
int32_t base_freq = Gcd(samp_rate_in_, samp_rate_out_);
input_samples_in_unit_ = samp_rate_in_ / base_freq;
output_samples_in_unit_ = samp_rate_out_ / base_freq;
SetIndexesAndWeights();
Reset();
}
void LinearResample::SetIndexesAndWeights() {
first_index_.resize(output_samples_in_unit_);
weights_.resize(output_samples_in_unit_);
double window_width = num_zeros_ / (2.0 * filter_cutoff_);
for (int32_t i = 0; i < output_samples_in_unit_; i++) {
double output_t = i / static_cast<double>(samp_rate_out_);
double min_t = output_t - window_width, max_t = output_t + window_width;
// we do ceil on the min and floor on the max, because if we did it
// the other way around we would unnecessarily include indexes just
// outside the window, with zero coefficients. It's possible
// if the arguments to the ceil and floor expressions are integers
// (e.g. if filter_cutoff_ has an exact ratio with the sample rates),
// that we unnecessarily include something with a zero coefficient,
// but this is only a slight efficiency issue.
int32_t min_input_index = ceil(min_t * samp_rate_in_),
max_input_index = floor(max_t * samp_rate_in_),
num_indices = max_input_index - min_input_index + 1;
first_index_[i] = min_input_index;
weights_[i].resize(num_indices);
for (int32_t j = 0; j < num_indices; j++) {
int32_t input_index = min_input_index + j;
double input_t = input_index / static_cast<double>(samp_rate_in_),
delta_t = input_t - output_t;
// sign of delta_t doesn't matter.
weights_[i][j] = FilterFunc(delta_t) / samp_rate_in_;
}
}
}
/** Here, t is a time in seconds representing an offset from
the center of the windowed filter function, and FilterFunction(t)
returns the windowed filter function, described
in the header as h(t) = f(t)g(t), evaluated at t.
*/
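// A sketch of the formula the code below evaluates:
//   f(t) = sin(2*pi*filter_cutoff_*t) / (pi*t)                              (sinc low-pass filter)
//   g(t) = 0.5 * (1 + cos(2*pi*filter_cutoff_/num_zeros_ * t))   if |t| < num_zeros_/(2*filter_cutoff_), else 0
//   h(t) = f(t) * g(t)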
float LinearResample::FilterFunc(float t) const {
float window, // raised-cosine (Hanning) window of width
// num_zeros_/2*filter_cutoff_
filter; // sinc filter function
if (fabs(t) < num_zeros_ / (2.0 * filter_cutoff_))
window = 0.5 * (1 + cos(M_2PI * filter_cutoff_ / num_zeros_ * t));
else
window = 0.0; // outside support of window function
if (t != 0)
filter = sin(M_2PI * filter_cutoff_ * t) / (M_PI * t);
else
filter = 2 * filter_cutoff_; // limit of the function at t = 0
return filter * window;
}
void LinearResample::Reset() {
input_sample_offset_ = 0;
output_sample_offset_ = 0;
input_remainder_.resize(0);
}
void LinearResample::Resample(const float *input, int32_t input_dim, bool flush,
std::vector<float> *output) {
int64_t tot_input_samp = input_sample_offset_ + input_dim,
tot_output_samp = GetNumOutputSamples(tot_input_samp, flush);
assert(tot_output_samp >= output_sample_offset_);
output->resize(tot_output_samp - output_sample_offset_);
// samp_out is the index into the total output signal, not just the part
// of it we are producing here.
for (int64_t samp_out = output_sample_offset_; samp_out < tot_output_samp;
samp_out++) {
int64_t first_samp_in;
int32_t samp_out_wrapped;
GetIndexes(samp_out, &first_samp_in, &samp_out_wrapped);
const std::vector<float> &weights = weights_[samp_out_wrapped];
// first_input_index is the first index into "input" that we have a weight
// for.
int32_t first_input_index =
static_cast<int32_t>(first_samp_in - input_sample_offset_);
float this_output;
if (first_input_index >= 0 &&
first_input_index + static_cast<int32_t>(weights.size()) <= input_dim) {
this_output =
DotProduct(input + first_input_index, weights.data(), weights.size());
} else { // Handle edge cases.
this_output = 0.0;
for (int32_t i = 0; i < static_cast<int32_t>(weights.size()); i++) {
float weight = weights[i];
int32_t input_index = first_input_index + i;
if (input_index < 0 &&
static_cast<int32_t>(input_remainder_.size()) + input_index >= 0) {
this_output +=
weight * input_remainder_[input_remainder_.size() + input_index];
} else if (input_index >= 0 && input_index < input_dim) {
this_output += weight * input[input_index];
} else if (input_index >= input_dim) {
// We're past the end of the input and are adding zero; should only
// happen if the user specified flush == true, or else we would not
// be trying to output this sample.
assert(flush);
}
}
}
int32_t output_index =
static_cast<int32_t>(samp_out - output_sample_offset_);
(*output)[output_index] = this_output;
}
if (flush) {
Reset(); // Reset the internal state.
} else {
SetRemainder(input, input_dim);
input_sample_offset_ = tot_input_samp;
output_sample_offset_ = tot_output_samp;
}
}
int64_t LinearResample::GetNumOutputSamples(int64_t input_num_samp,
bool flush) const {
// For exact computation, we measure time in "ticks" of 1.0 / tick_freq,
// where tick_freq is the least common multiple of samp_rate_in_ and
// samp_rate_out_.
int32_t tick_freq = Lcm(samp_rate_in_, samp_rate_out_);
int32_t ticks_per_input_period = tick_freq / samp_rate_in_;
// work out the number of ticks in the time interval
// [ 0, input_num_samp/samp_rate_in_ ).
int64_t interval_length_in_ticks = input_num_samp * ticks_per_input_period;
if (!flush) {
float window_width = num_zeros_ / (2.0 * filter_cutoff_);
// To count the window-width in ticks we take the floor. This
// is because since we're looking for the largest integer num-out-samp
// that fits in the interval, which is open on the right, a reduction
// in interval length of less than a tick will never make a difference.
// For example, the largest integer in the interval [ 0, 2 ) and the
// largest integer in the interval [ 0, 2 - 0.9 ) are the same (both one).
// So when we're subtracting the window-width we can ignore the fractional
// part.
int32_t window_width_ticks = floor(window_width * tick_freq);
// The time-period of the output that we can sample gets reduced
// by the window-width (which is actually the distance from the
// center to the edge of the windowing function) if we're not
// "flushing the output".
interval_length_in_ticks -= window_width_ticks;
}
if (interval_length_in_ticks <= 0) return 0;
int32_t ticks_per_output_period = tick_freq / samp_rate_out_;
// Get the last output-sample in the closed interval, i.e. replacing [ ) with
// [ ]. Note: integer division rounds down. See
// http://en.wikipedia.org/wiki/Interval_(mathematics) for an explanation of
// the notation.
int64_t last_output_samp = interval_length_in_ticks / ticks_per_output_period;
// We need the last output-sample in the open interval, so if it takes us to
// the end of the interval exactly, subtract one.
if (last_output_samp * ticks_per_output_period == interval_length_in_ticks)
last_output_samp--;
// First output-sample index is zero, so the number of output samples
// is the last output-sample plus one.
int64_t num_output_samp = last_output_samp + 1;
return num_output_samp;
}
// inline
void LinearResample::GetIndexes(int64_t samp_out, int64_t *first_samp_in,
int32_t *samp_out_wrapped) const {
// A unit is the smallest nonzero amount of time that is an exact
// multiple of the input and output sample periods. The unit index
// is the answer to "which numbered unit we are in".
int64_t unit_index = samp_out / output_samples_in_unit_;
// samp_out_wrapped is equal to samp_out % output_samples_in_unit_
*samp_out_wrapped =
static_cast<int32_t>(samp_out - unit_index * output_samples_in_unit_);
*first_samp_in =
first_index_[*samp_out_wrapped] + unit_index * input_samples_in_unit_;
}
void LinearResample::SetRemainder(const float *input, int32_t input_dim) {
std::vector<float> old_remainder(input_remainder_);
// max_remainder_needed is the width of the filter from side to side,
// measured in input samples. you might think it should be half that,
// but you have to consider that you might be wanting to output samples
// that are "in the past" relative to the beginning of the latest
// input... anyway, storing more remainder than needed is not harmful.
int32_t max_remainder_needed =
ceil(samp_rate_in_ * num_zeros_ / filter_cutoff_);
input_remainder_.resize(max_remainder_needed);
for (int32_t index = -static_cast<int32_t>(input_remainder_.size());
index < 0; index++) {
// we interpret "index" as an offset from the end of "input" and
// from the end of input_remainder_.
int32_t input_index = index + input_dim;
if (input_index >= 0) {
input_remainder_[index + static_cast<int32_t>(input_remainder_.size())] =
input[input_index];
} else if (input_index + static_cast<int32_t>(old_remainder.size()) >= 0) {
input_remainder_[index + static_cast<int32_t>(input_remainder_.size())] =
old_remainder[input_index +
static_cast<int32_t>(old_remainder.size())];
// else leave it at zero.
}
}
}

View File

@ -0,0 +1,137 @@
/**
* Copyright 2013 Pegah Ghahremani
* 2014 IMSL, PKU-HKUST (author: Wei Shi)
* 2014 Yanqing Sun, Junjie Wang
* 2014 Johns Hopkins University (author: Daniel Povey)
* Copyright 2023 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// this file is copied and modified from
// kaldi/src/feat/resample.h
#include <cstdint>
#include <vector>
/*
We require that the input and output sampling rate be specified as
integers, as this is an easy way to specify that their ratio be rational.
*/
class LinearResample {
public:
/// Constructor. We make the input and output sample rates integers, because
/// we are going to need to find a common divisor. This should just remind
/// you that they need to be integers. The filter cutoff needs to be less
/// than samp_rate_in_hz/2 and less than samp_rate_out_hz/2. num_zeros
/// controls the sharpness of the filter, more == sharper but less efficient.
/// We suggest around 4 to 10 for normal use.
LinearResample(int32_t samp_rate_in_hz, int32_t samp_rate_out_hz,
float filter_cutoff_hz, int32_t num_zeros);
/// Calling the function Reset() resets the state of the object prior to
/// processing a new signal; it is only necessary if you have called
/// Resample(x, x_size, false, y) for some signal, leading to a remainder of
/// the signal being called, but then abandon processing the signal before
calling Resample(x, x_size, true, y) for the last piece. Calling it
unnecessarily between signals will not do any harm.
void Reset();
/// This function does the resampling. If you call it with flush == true and
/// you have never called it with flush == false, it just resamples the input
/// signal (it resizes the output to a suitable number of samples).
///
/// You can also use this function to process a signal a piece at a time.
/// suppose you break it into piece1, piece2, ... pieceN. You can call
/// \code{.cc}
/// Resample(piece1, piece1_size, false, &output1);
/// Resample(piece2, piece2_size, false, &output2);
/// Resample(piece3, piece3_size, true, &output3);
/// \endcode
/// If you call it with flush == false, it won't output the last few samples
/// but will remember them, so that if you later give it a second piece of
/// the input signal it can process it correctly.
/// If your most recent call to the object was with flush == false, it will
/// have internal state; you can remove this by calling Reset().
/// Empty input is acceptable.
void Resample(const float *input, int32_t input_dim, bool flush,
std::vector<float> *output);
//// Return the input and output sampling rates (for checks, for example)
int32_t GetInputSamplingRate() const { return samp_rate_in_; }
int32_t GetOutputSamplingRate() const { return samp_rate_out_; }
private:
void SetIndexesAndWeights();
float FilterFunc(float) const;
/// This function outputs the number of output samples we will output
/// for a signal with "input_num_samp" input samples. If flush == true,
/// we return the largest n such that
/// (n/samp_rate_out_) is in the interval [ 0, input_num_samp/samp_rate_in_ ),
/// and note that the interval is half-open. If flush == false,
/// define window_width as num_zeros / (2.0 * filter_cutoff_);
/// we return the largest n such that (n/samp_rate_out_) is in the interval
/// [ 0, input_num_samp/samp_rate_in_ - window_width ).
int64_t GetNumOutputSamples(int64_t input_num_samp, bool flush) const;
/// Given an output-sample index, this function outputs to *first_samp_in the
/// first input-sample index that we have a weight on (may be negative),
/// and to *samp_out_wrapped the index into weights_ where we can get the
/// corresponding weights on the input.
inline void GetIndexes(int64_t samp_out, int64_t *first_samp_in,
int32_t *samp_out_wrapped) const;
void SetRemainder(const float *input, int32_t input_dim);
private:
// The following variables are provided by the user.
int32_t samp_rate_in_;
int32_t samp_rate_out_;
float filter_cutoff_;
int32_t num_zeros_;
int32_t input_samples_in_unit_; ///< The number of input samples in the
///< smallest repeating unit: num_samp_in_ =
///< samp_rate_in_hz / Gcd(samp_rate_in_hz,
///< samp_rate_out_hz)
int32_t output_samples_in_unit_; ///< The number of output samples in the
///< smallest repeating unit: num_samp_out_
///< = samp_rate_out_hz /
///< Gcd(samp_rate_in_hz, samp_rate_out_hz)
/// The first input-sample index that we sum over, for this output-sample
/// index. May be negative; any truncation at the beginning is handled
/// separately. This is just for the first few output samples, but we can
/// extrapolate the correct input-sample index for arbitrary output samples.
std::vector<int32_t> first_index_;
/// Weights on the input samples, for this output-sample index.
std::vector<std::vector<float>> weights_;
// the following variables keep track of where we are in a particular signal,
// if it is being provided over multiple calls to Resample().
int64_t input_sample_offset_; ///< The number of input samples we have
///< already received for this signal
///< (including anything in remainder_)
int64_t output_sample_offset_; ///< The number of samples we have already
///< output for this signal.
std::vector<float> input_remainder_; ///< A small trailing part of the
///< previously seen input signal.
};

View File

@ -8,7 +8,7 @@ if(WIN32)
endif()
endif()
set(EXTRA_LIBS rapidasr)
set(EXTRA_LIBS funasr)
include_directories(${CMAKE_SOURCE_DIR}/include)

View File

@ -5,7 +5,7 @@
#include <win_func.h>
#endif
#include "librapidasrapi.h"
#include "libfunasrapi.h"
#include <iostream>
#include <fstream>
@ -26,7 +26,7 @@ int main(int argc, char *argv[])
// is quantize
bool quantize = false;
istringstream(argv[3]) >> boolalpha >> quantize;
RPASR_HANDLE AsrHanlde=RapidAsrInit(argv[1], nThreadNum, quantize);
FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], nThreadNum, quantize);
if (!AsrHanlde)
{
@ -42,72 +42,32 @@ int main(int argc, char *argv[])
gettimeofday(&start, NULL);
float snippet_time = 0.0f;
RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL);
FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL);
gettimeofday(&end, NULL);
if (Result)
{
string msg = RapidAsrGetResult(Result, 0);
string msg = FunASRGetResult(Result, 0);
setbuf(stdout, NULL);
cout << "Result: \"";
cout << msg << "\"." << endl;
snippet_time = RapidAsrGetRetSnippetTime(Result);
RapidAsrFreeResult(Result);
printf("Result: %s \n", msg.c_str());
snippet_time = FunASRGetRetSnippetTime(Result);
FunASRFreeResult(Result);
}
else
{
cout <<"no return data!";
}
//char* buff = nullptr;
//int len = 0;
//ifstream ifs(argv[2], std::ios::binary | std::ios::in);
//if (ifs.is_open())
//{
// ifs.seekg(0, std::ios::end);
// len = ifs.tellg();
// ifs.seekg(0, std::ios::beg);
// buff = new char[len];
// ifs.read(buff, len);
// //RPASR_RESULT Result = RapidAsrRecogPCMFile(AsrHanlde, argv[2], RASR_NONE, NULL);
// RPASR_RESULT Result=RapidAsrRecogPCMBuffer(AsrHanlde, buff,len, RASR_NONE, NULL);
// //RPASR_RESULT Result = RapidAsrRecogPCMFile(AsrHanlde, argv[2], RASR_NONE, NULL);
// gettimeofday(&end, NULL);
//
// if (Result)
// {
// string msg = RapidAsrGetResult(Result, 0);
// setbuf(stdout, NULL);
// cout << "Result: \"";
// cout << msg << endl;
// cout << "\"." << endl;
// snippet_time = RapidAsrGetRetSnippetTime(Result);
// RapidAsrFreeResult(Result);
// }
// else
// {
// cout <<"no return data!";
// }
//
//delete[]buff;
//}
printf("Audio length %lfs.\n", (double)snippet_time);
seconds = (end.tv_sec - start.tv_sec);
long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
printf("Model inference takes %lfs.\n", (double)taking_micros / 1000000);
printf("Model inference RTF: %04lf.\n", (double)taking_micros/ (snippet_time*1000000));
RapidAsrUninit(AsrHanlde);
FunASRUninit(AsrHanlde);
return 0;
}

View File

@ -5,7 +5,7 @@
#include <win_func.h>
#endif
#include "librapidasrapi.h"
#include "libfunasrapi.h"
#include <iostream>
#include <fstream>
@ -47,7 +47,7 @@ int main(int argc, char *argv[])
bool quantize = false;
istringstream(argv[3]) >> boolalpha >> quantize;
RPASR_HANDLE AsrHanlde=RapidAsrInit(argv[1], nThreadNum, quantize);
FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], nThreadNum, quantize);
if (!AsrHanlde)
{
printf("Cannot load ASR Model from: %s, there must be files model.onnx and vocab.txt", argv[1]);
@ -61,7 +61,7 @@ int main(int argc, char *argv[])
// warm up
for (size_t i = 0; i < 30; i++)
{
RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, wav_list[0].c_str(), RASR_NONE, NULL);
FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, wav_list[0].c_str(), RASR_NONE, NULL);
}
// forward
@ -72,19 +72,19 @@ int main(int argc, char *argv[])
for (size_t i = 0; i < wav_list.size(); i++)
{
gettimeofday(&start, NULL);
RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, wav_list[i].c_str(), RASR_NONE, NULL);
FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, wav_list[i].c_str(), RASR_NONE, NULL);
gettimeofday(&end, NULL);
seconds = (end.tv_sec - start.tv_sec);
long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
total_time += taking_micros;
if(Result){
string msg = RapidAsrGetResult(Result, 0);
printf("Result: %s \n", msg);
string msg = FunASRGetResult(Result, 0);
printf("Result: %s \n", msg.c_str());
snippet_time = RapidAsrGetRetSnippetTime(Result);
snippet_time = FunASRGetRetSnippetTime(Result);
total_length += snippet_time;
RapidAsrFreeResult(Result);
FunASRFreeResult(Result);
}else{
cout <<"No return data!";
}
@ -94,6 +94,6 @@ int main(int argc, char *argv[])
printf("total_time_comput %ld ms.\n", total_time / 1000);
printf("total_rtf %05lf .\n", (double)total_time/ (total_length*1000000));
RapidAsrUninit(AsrHanlde);
FunASRUninit(AsrHanlde);
return 0;
}

View File

@ -0,0 +1,62 @@
import grpc
import json
import time
import asyncio
import soundfile as sf
import argparse
from grpc_client import transcribe_audio_bytes
from paraformer_pb2_grpc import ASRStub


# send the audio data once
async def grpc_rec(wav_scp, grpc_uri, asr_user, language):
    with grpc.insecure_channel(grpc_uri) as channel:
        stub = ASRStub(channel)
        for line in wav_scp:
            wav_file = line.split()[1]
            wav, _ = sf.read(wav_file, dtype='int16')
            b = time.time()
            response = transcribe_audio_bytes(stub, wav.tobytes(), user=asr_user, language=language, speaking=False, isEnd=False)
            resp = response.next()
            text = ''
            if 'decoding' == resp.action:
                resp = response.next()
            if 'finish' == resp.action:
                text = json.loads(resp.sentence)['text']
            response = transcribe_audio_bytes(stub, None, user=asr_user, language=language, speaking=False, isEnd=True)
            res = {'text': text, 'time': time.time() - b}
            print(res)


async def test(args):
    wav_scp = open(args.wav_scp, "r").readlines()
    uri = '{}:{}'.format(args.host, args.port)
    res = await grpc_rec(wav_scp, uri, args.user_allowed, language='zh-CN')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--host",
                        type=str,
                        default="127.0.0.1",
                        required=False,
                        help="grpc server host ip")
    parser.add_argument("--port",
                        type=int,
                        default=10108,
                        required=False,
                        help="grpc server port")
    parser.add_argument("--user_allowed",
                        type=str,
                        default="project1_user1",
                        help="allowed user for grpc client")
    parser.add_argument("--sample_rate",
                        type=int,
                        default=16000,
                        help="audio sample_rate from client")
    parser.add_argument("--wav_scp",
                        type=str,
                        required=True,
                        help="audio wav scp")
    args = parser.parse_args()
    asyncio.run(test(args))
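Note on the --wav_scp argument of this client: each line must contain at least two whitespace-separated fields, since the script takes the wav path from `line.split()[1]` (for example `utt1 /path/to/utt1.wav`, where the first field is ignored here). Each file is read with soundfile as int16 and sent in a single request, with a final isEnd=True call to close the session.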

Some files were not shown because too many files have changed in this diff.