mirror of https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00

Merge pull request #363 from alibaba-damo-academy/main

update with main

This commit is contained in:
commit 937e507977

37 README.md
@@ -7,12 +7,11 @@
[**News**](https://github.com/alibaba-damo-academy/FunASR#whats-new)
| [**Highlights**](#highlights)
| [**Installation**](#installation)
| [**Docs_CN**](https://alibaba-damo-academy.github.io/FunASR/cn/index.html)
| [**Docs_EN**](https://alibaba-damo-academy.github.io/FunASR/en/index.html)
| [**Tutorial**](https://github.com/alibaba-damo-academy/FunASR/wiki#funasr%E7%94%A8%E6%88%B7%E6%89%8B%E5%86%8C)
| [**Papers**](https://github.com/alibaba-damo-academy/FunASR#citations)
| [**Runtime**](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime)
| [**Model Zoo**](https://github.com/alibaba-damo-academy/FunASR/blob/main/docs/modelscope_models.md)
| [**Contact**](#contact)
@@ -29,15 +28,37 @@ For the release notes, please refer to [news](https://github.com/alibaba-damo-acad
## Installation

Install from pip:

```shell
pip install -U funasr
# For the users in China, you could install with the command:
# pip install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
```

Or install from source code:

```sh
git clone https://github.com/alibaba/FunASR.git && cd FunASR
pip install -e ./
# For the users in China, you could install with the command:
# pip install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
```

If you want to use the pretrained models in ModelScope, you should also install modelscope:

```shell
pip install -U modelscope
# For the users in China, you could install with the command:
# pip install -U modelscope -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -i https://mirror.sjtu.edu.cn/pypi/web/simple
```

For more details, please refer to [installation](https://github.com/alibaba-damo-academy/FunASR/wiki).
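As a quick post-install sanity check, one can query the installed package versions with only the standard library (an illustrative sketch, not part of FunASR; the package names are the ones installed above):

```python
from importlib import metadata

def check_install(packages):
    """Map each package name to its installed version, or None if it is missing."""
    versions = {}
    for pkg in packages:
        try:
            versions[pkg] = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            versions[pkg] = None
    return versions

print(check_install(["funasr", "modelscope"]))
```

A `None` entry means the corresponding `pip install` step above has not been run in the current environment.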
## Usage

For users who are new to FunASR and ModelScope, please refer to the FunASR Docs ([CN](https://alibaba-damo-academy.github.io/FunASR/cn/index.html) / [EN](https://alibaba-damo-academy.github.io/FunASR/en/index.html)).

## Contact
@@ -6,29 +6,70 @@

## Model Zoo

Here we provide several pretrained models on different datasets. The details of the models and datasets can be found on [ModelScope](https://www.modelscope.cn/models?page=1&tasks=auto-speech-recognition).

| Datasets | Hours | Model | Online/Offline | Language | Framework | Checkpoint |
|:---:|:---:|:---:|:---:|:---:|:---:|:---|
| Alibaba Speech Data | 60000 | Paraformer | Offline | CN | Pytorch | [speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) |
| Alibaba Speech Data | 50000 | Paraformer | Offline | CN | Tensorflow | [speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) |
| Alibaba Speech Data | 50000 | Paraformer | Offline | CN | Tensorflow | [speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) |
| Alibaba Speech Data | 50000 | Paraformer | Online | CN | Tensorflow | [speech_paraformer_asr_nat-zh-cn-16k-common-vocab3444-tensorflow1-online](http://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab3444-tensorflow1-online/summary) |
| Alibaba Speech Data | 50000 | UniASR | Online | CN | Tensorflow | [speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/summary) |
| Alibaba Speech Data | 50000 | UniASR | Offline | CN | Tensorflow | [speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline](https://www.modelscope.cn/models/damo/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/summary) |
| Alibaba Speech Data | 50000 | UniASR | Online | CN&EN | Tensorflow | [speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-online/summary) |
| Alibaba Speech Data | 50000 | UniASR | Offline | CN&EN | Tensorflow | [speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-offline/summary) |
| Alibaba Speech Data | 20000 | UniASR | Online | CN-Accent | Tensorflow | [speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online/summary) |
| Alibaba Speech Data | 20000 | UniASR | Offline | CN-Accent | Tensorflow | [speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline/summary) |
| Alibaba Speech Data | 30000 | Paraformer-8K | Online | CN | Tensorflow | [speech_paraformer_asr_nat-zh-cn-8k-common-vocab3444-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-8k-common-vocab3444-tensorflow1-online/summary) |
| Alibaba Speech Data | 30000 | Paraformer-8K | Offline | CN | Tensorflow | [speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1/summary) |
| Alibaba Speech Data | 30000 | UniASR-8K | Online | CN | Tensorflow | [speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online/summary) |
| Alibaba Speech Data | 30000 | UniASR-8K | Offline | CN | Tensorflow | [speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline/summary) |
| Alibaba Speech Data | 30000 | UniASR-8K | Online | CN | Pytorch | [speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-online/summary) |
| Alibaba Speech Data | 30000 | UniASR-8K | Offline | CN | Pytorch | [speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline](https://www.modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline/summary) |
| AISHELL-1 | 178 | Paraformer | Offline | CN | Pytorch | [speech_paraformer_asr_nat-aishell1-pytorch](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell1-pytorch/summary) |
| AISHELL-2 | 1000 | Paraformer | Offline | CN | Pytorch | [speech_paraformer_asr_nat-aishell2-pytorch](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell2-pytorch/summary) |
| AISHELL-1 | 178 | ParaformerBert | Offline | CN | Pytorch | [speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch](https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) |
| AISHELL-2 | 1000 | ParaformerBert | Offline | CN | Pytorch | [speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch](https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) |
| AISHELL-1 | 178 | Conformer | Offline | CN | Pytorch | [speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch](https://modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) |
| AISHELL-2 | 1000 | Conformer | Offline | CN | Pytorch | [speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch](https://modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) |
### Speech Recognition Models

#### Paraformer Models

| Model Name | Language | Training Data | Vocab Size | Parameters | Offline/Online | Notes |
|:---:|:---:|:---:|:---:|:---:|:---:|:---|
| [Paraformer-large](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) | CN & EN | Alibaba Speech Data (60000 hours) | 8404 | 220M | Offline | Duration of input wav <= 20s |
| [Paraformer-large-long](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) | CN & EN | Alibaba Speech Data (60000 hours) | 8404 | 220M | Offline | Can deal with input wav of arbitrary length |
| [Paraformer-large-contextual](https://www.modelscope.cn/models/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/summary) | CN & EN | Alibaba Speech Data (60000 hours) | 8404 | 220M | Offline | Supports hotword customization based on incentive enhancement, improving the recall and precision of hotwords |
| [Paraformer](https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) | CN & EN | Alibaba Speech Data (50000 hours) | 8358 | 68M | Offline | Duration of input wav <= 20s |
| [Paraformer-online](https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary) | CN & EN | Alibaba Speech Data (50000 hours) | 8404 | 68M | Online | Can deal with streaming input |
| [Paraformer-tiny](https://www.modelscope.cn/models/damo/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch/summary) | CN | Alibaba Speech Data (200 hours) | 544 | 5.2M | Offline | Lightweight Paraformer model supporting Mandarin command-word recognition |
| [Paraformer-aishell](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell1-pytorch/summary) | CN | AISHELL (178 hours) | 4234 | 43M | Offline | |
| [ParaformerBert-aishell](https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) | CN | AISHELL (178 hours) | 4234 | 43M | Offline | |
| [Paraformer-aishell2](https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) | CN | AISHELL-2 (1000 hours) | 5212 | 64M | Offline | |
| [ParaformerBert-aishell2](https://www.modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) | CN | AISHELL-2 (1000 hours) | 5212 | 64M | Offline | |
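The checkpoint names in these tables follow a loose convention: a sample-rate tag (`16k`/`8k`), a `vocabNNNN` vocabulary size, and a framework suffix. The helper below is illustrative only (not part of FunASR or ModelScope) and simply pulls those hints out of a model id:

```python
import re

def parse_model_id(model_id):
    """Extract sample-rate, vocab-size, and framework hints encoded in a checkpoint name."""
    info = {}
    rate = re.search(r"-(8k|16k)-", model_id)
    if rate:
        info["sample_rate_hz"] = 8000 if rate.group(1) == "8k" else 16000
    vocab = re.search(r"vocab(\d+)", model_id)
    if vocab:
        info["vocab_size"] = int(vocab.group(1))
    for framework in ("pytorch", "tensorflow1"):
        if framework in model_id:
            info["framework"] = framework
            break
    return info

print(parse_model_id("speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"))
```

Matching the sample rate of a checkpoint against the audio you feed it (16 kHz vs. 8 kHz telephony data) is the most common source of silent accuracy loss.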
#### UniASR Models

| Model Name | Language | Training Data | Vocab Size | Parameters | Offline/Online | Notes |
|:---:|:---:|:---:|:---:|:---:|:---:|:---|
| [UniASR](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/summary) | CN & EN | Alibaba Speech Data (60000 hours) | 8358 | 100M | Online | UniASR streaming/offline unified model |
| [UniASR-large](https://modelscope.cn/models/damo/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/summary) | CN & EN | Alibaba Speech Data (60000 hours) | 8358 | 220M | Offline | UniASR streaming/offline unified model |
| [UniASR Burmese](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-my-16k-common-vocab696-pytorch/summary) | Burmese | Alibaba Speech Data (? hours) | 696 | 95M | Online | UniASR streaming/offline unified model |
| [UniASR Hebrew](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-he-16k-common-vocab1085-pytorch/summary) | Hebrew | Alibaba Speech Data (? hours) | 1085 | 95M | Online | UniASR streaming/offline unified model |
| [UniASR Urdu](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-ur-16k-common-vocab877-pytorch/summary) | Urdu | Alibaba Speech Data (? hours) | 877 | 95M | Online | UniASR streaming/offline unified model |
#### Conformer Models

| Model Name | Language | Training Data | Vocab Size | Parameters | Offline/Online | Notes |
|:---:|:---:|:---:|:---:|:---:|:---:|:---|
| [Conformer](https://modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary) | CN | AISHELL (178 hours) | 4234 | 44M | Offline | Duration of input wav <= 20s |
| [Conformer](https://www.modelscope.cn/models/damo/speech_conformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary) | CN | AISHELL-2 (1000 hours) | 5212 | 44M | Offline | Duration of input wav <= 20s |

#### RNN-T Models
### Voice Activity Detection Models

| Model Name | Training Data | Parameters | Sampling Rate | Notes |
|:---:|:---:|:---:|:---:|:---|
| [FSMN-VAD](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) | Alibaba Speech Data (5000 hours) | 0.4M | 16000 | |
| [FSMN-VAD](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-8k-common/summary) | Alibaba Speech Data (5000 hours) | 0.4M | 8000 | |
### Punctuation Restoration Models

| Model Name | Training Data | Parameters | Vocab Size | Offline/Online | Notes |
|:---:|:---:|:---:|:---:|:---:|:---|
| [CT-Transformer](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary) | Alibaba Text Data | 70M | 272727 | Offline | Offline punctuation model |
| [CT-Transformer](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727/summary) | Alibaba Text Data | 70M | 272727 | Online | Online punctuation model |
### Language Models

| Model Name | Training Data | Parameters | Vocab Size | Notes |
|:---:|:---:|:---:|:---:|:---|
| [Transformer](https://www.modelscope.cn/models/damo/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/summary) | Alibaba Speech Data (? hours) | 57M | 8404 | |
### Speaker Verification Models

| Model Name | Training Data | Parameters | Vocab Size | Notes |
|:---:|:---:|:---:|:---:|:---|
| [Xvector](https://www.modelscope.cn/models/damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch/summary) | CNCeleb (? hours) | 17.5M | 3465 | |
| [Xvector](https://www.modelscope.cn/models/damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/summary) | CallHome (? hours) | 61M | 6135 | |
### Speaker Diarization Models

| Model Name | Training Data | Parameters | Notes |
|:---:|:---:|:---:|:---|
| [SOND](https://www.modelscope.cn/models/damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/summary) | AliMeeting (? hours) | 40.5M | |
| [SOND](https://www.modelscope.cn/models/damo/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch/summary) | CallHome (? hours) | 12M | |
@@ -45,8 +45,8 @@ def compute_wer(ref_file,
        if out_item['wrong'] > 0:
            rst['wrong_sentences'] += 1
            cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
            cer_detail_writer.write("ref:" + '\t' + " ".join(list(map(lambda x: x.lower(), ref_dict[hyp_key]))) + '\n')
            cer_detail_writer.write("hyp:" + '\t' + " ".join(list(map(lambda x: x.lower(), hyp_dict[hyp_key]))) + '\n')

    if rst['Wrd'] > 0:
        rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
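The `rst['Err']` line above is errors per reference word, in percent. The sketch below is a self-contained illustration of the underlying edit-distance metric (it is not FunASR's actual `compute_wer`, which additionally tracks insertion, deletion, and substitution counts per utterance):

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two token sequences (words for WER, characters for CER)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, 1):
            cur[j] = min(prev[j] + 1,              # deletion
                         cur[j - 1] + 1,           # insertion
                         prev[j - 1] + (r != h))   # substitution (free when tokens match)
        prev = cur
    return prev[-1]

def error_rate(ref_tokens, hyp_tokens):
    """Error rate in percent, rounded to two decimals like rst['Err'] above."""
    return round(edit_distance(ref_tokens, hyp_tokens) * 100 / len(ref_tokens), 2)
```

Lower-casing both sides before scoring, as the rewritten `write` calls above do, keeps casing differences from being counted as substitutions.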
@@ -74,7 +74,7 @@ def modelscope_infer(params):
    # If text exists, compute CER
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(best_recog_path, "text")
        compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))
@@ -38,7 +38,7 @@ def modelscope_infer_after_finetune(params):
    # compute CER if GT text is set
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(decoding_path, "1best_recog/text")
        compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
@@ -74,7 +74,7 @@ def modelscope_infer(params):
    # If text exists, compute CER
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(best_recog_path, "text")
        compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))
@@ -38,7 +38,7 @@ def modelscope_infer_after_finetune(params):
    # compute CER if GT text is set
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(decoding_path, "1best_recog/text")
        compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
@@ -63,8 +63,8 @@ fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ];then
    echo "Computing WER ..."
    cp ${output_dir}/1best_recog/text ${output_dir}/1best_recog/text.proc
    cp ${data_dir}/text ${output_dir}/1best_recog/text.ref
    python utils/compute_wer.py ${output_dir}/1best_recog/text.ref ${output_dir}/1best_recog/text.proc ${output_dir}/1best_recog/text.cer
    tail -n 3 ${output_dir}/1best_recog/text.cer
fi
@@ -34,7 +34,7 @@ def modelscope_infer_after_finetune(params):
    # compute CER if GT text is set
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(decoding_path, "1best_recog/text")
        compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
@@ -63,8 +63,8 @@ fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ];then
    echo "Computing WER ..."
    cp ${output_dir}/1best_recog/text ${output_dir}/1best_recog/text.proc
    cp ${data_dir}/text ${output_dir}/1best_recog/text.ref
    python utils/compute_wer.py ${output_dir}/1best_recog/text.ref ${output_dir}/1best_recog/text.proc ${output_dir}/1best_recog/text.cer
    tail -n 3 ${output_dir}/1best_recog/text.cer
fi
@@ -34,7 +34,7 @@ def modelscope_infer_after_finetune(params):
    # compute CER if GT text is set
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(decoding_path, "1best_recog/text")
        compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
@@ -23,8 +23,7 @@ def modelscope_infer_core(output_dir, split_dir, njob, idx):
        batch_size=1
    )
    audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
    inference_pipline(audio_in=audio_in)


def modelscope_infer(params):
    # prepare for multi-GPU decoding
@@ -75,7 +74,7 @@ def modelscope_infer(params):
    # If text exists, compute CER
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(best_recog_path, "text")
        compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))
@@ -2,52 +2,103 @@ import json
import os
import shutil

from multiprocessing import Pool
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

from funasr.utils.compute_wer import compute_wer


def modelscope_infer_after_finetune_core(model_dir, output_dir, split_dir, njob, idx):
    output_dir_job = os.path.join(output_dir, "output.{}".format(idx))
    gpu_id = (int(idx) - 1) // njob
    if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
        gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[gpu_id])
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model=model_dir,
        output_dir=output_dir_job,
        batch_size=1
    )
    audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
    inference_pipeline(audio_in=audio_in)


def modelscope_infer_after_finetune(params):
    # prepare for multi-GPU decoding
    model_dir = params["model_dir"]
    pretrained_model_path = os.path.join(os.environ["HOME"], ".cache/modelscope/hub", params["modelscope_model_name"])
    for file_name in params["required_files"]:
        if file_name == "configuration.json":
            with open(os.path.join(pretrained_model_path, file_name)) as f:
                config_dict = json.load(f)
                config_dict["model"]["am_model_name"] = params["decoding_model_name"]
            with open(os.path.join(model_dir, "configuration.json"), "w") as f:
                json.dump(config_dict, f, indent=4, separators=(',', ': '))
        else:
            shutil.copy(os.path.join(pretrained_model_path, file_name),
                        os.path.join(model_dir, file_name))
    ngpu = params["ngpu"]
    njob = params["njob"]
    output_dir = params["output_dir"]
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.mkdir(output_dir)
    split_dir = os.path.join(output_dir, "split")
    os.mkdir(split_dir)
    nj = ngpu * njob
    wav_scp_file = os.path.join(params["data_dir"], "wav.scp")
    with open(wav_scp_file) as f:
        lines = f.readlines()
        num_lines = len(lines)
        num_job_lines = num_lines // nj
    start = 0
    for i in range(nj):
        end = start + num_job_lines
        file = os.path.join(split_dir, "wav.{}.scp".format(str(i + 1)))
        with open(file, "w") as f:
            if i == nj - 1:
                f.writelines(lines[start:])
            else:
                f.writelines(lines[start:end])
        start = end

    # decoding
    p = Pool(nj)
    for i in range(nj):
        p.apply_async(modelscope_infer_after_finetune_core,
                      args=(model_dir, output_dir, split_dir, njob, str(i + 1)))
    p.close()
    p.join()

    # combine decoding results
    best_recog_path = os.path.join(output_dir, "1best_recog")
    os.mkdir(best_recog_path)
    files = ["text", "token", "score"]
    for file in files:
        with open(os.path.join(best_recog_path, file), "w") as f:
            for i in range(nj):
                job_file = os.path.join(output_dir, "output.{}/1best_recog".format(str(i + 1)), file)
                with open(job_file) as f_job:
                    lines = f_job.readlines()
                f.writelines(lines)

    # If text exists, compute CER
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(best_recog_path, "token")
        compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))


if __name__ == '__main__':
    params = {}
    params["modelscope_model_name"] = "damo/speech_UniASR_asr_2pass-zh-cn-8k-common-vocab3445-pytorch-offline"
    params["required_files"] = ["am.mvn", "decoding.yaml", "configuration.json"]
    params["model_dir"] = "./checkpoint"
    params["output_dir"] = "./results"
    params["data_dir"] = "./data/test"
    params["decoding_model_name"] = "20epoch.pb"
    params["ngpu"] = 1
    params["njob"] = 1
    modelscope_infer_after_finetune(params)
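Two pieces of the multi-GPU logic above are easy to get wrong: mapping a 1-based job index to a GPU id, and splitting `wav.scp` so the last job absorbs the remainder lines. An isolated sketch of both (hypothetical helper names, same arithmetic as the code above):

```python
def job_gpu(idx, njob):
    """GPU id for a 1-based job index, njob jobs per GPU, as in the core function above."""
    return (int(idx) - 1) // njob

def split_scp_lines(lines, nj):
    """Split lines into nj contiguous chunks; the last chunk takes the remainder."""
    per_job = len(lines) // nj
    chunks = [lines[i * per_job:(i + 1) * per_job] for i in range(nj - 1)]
    chunks.append(lines[(nj - 1) * per_job:])
    return chunks
```

With `ngpu=2, njob=2`, jobs 1-2 land on GPU 0 and jobs 3-4 on GPU 1, and no wav.scp line is dropped regardless of how `len(lines)` divides by `nj`.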
@@ -75,7 +75,7 @@ def modelscope_infer(params):
    # If text exists, compute CER
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(best_recog_path, "text")
        compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))
@@ -39,7 +39,7 @@ def modelscope_infer_after_finetune(params):
    # compute CER if GT text is set
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(decoding_path, "1best_recog/text")
        compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
@@ -1,3 +1,9 @@
"""
Author: Speech Lab, Alibaba Group, China

TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
https://arxiv.org/abs/2303.05397
"""

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
@@ -0,0 +1,32 @@
"""
Author: Speech Lab, Alibaba Group, China

SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
https://arxiv.org/abs/2211.10243
"""

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Initialize the inference pipeline.
# When taking raw audio as input, use the config file sond.yaml and set mode to sond_demo.
inference_diar_pipline = pipeline(
    mode="sond_demo",
    num_workers=0,
    task=Tasks.speaker_diarization,
    diar_model_config="sond.yaml",
    model='damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch',
    sv_model="damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch",
    sv_model_revision="master",
)

# Use audio_list as input: the first audio is the speech to be analyzed,
# and the following audios are voiceprint enrollment audios for different speakers.
audio_list = [
    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/record.wav",
    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk1.wav",
    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk2.wav",
    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk3.wav",
    "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/speaker_diarization/spk4.wav",
]

results = inference_diar_pipline(audio_in=audio_list)
print(results)
@@ -2,6 +2,9 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import torch
torch.set_num_threads(1)

import argparse
import logging
import os
@ -797,7 +797,7 @@ def inference_modelscope(
|
||||
finish_count += 1
|
||||
# asr_utils.print_progress(finish_count / file_count)
|
||||
if writer is not None:
|
||||
ibest_writer["text"][key] = text_postprocessed
|
||||
ibest_writer["text"][key] = " ".join(word_lists)
|
||||
|
||||
logging.info("decoding, utt: {}, predictions: {}".format(key, text))
|
||||
rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor))
|
||||
|
||||
@ -19,6 +19,7 @@ from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr.fileio.datadir_writer import DatadirWriter
|
||||
@ -607,17 +608,21 @@ def inference_modelscope(
|
||||
):
|
||||
|
||||
# 3. Build data-iterator
|
||||
if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
|
||||
raw_inputs = _load_bytes(data_path_and_name_and_type[0])
|
||||
raw_inputs = torch.tensor(raw_inputs)
|
||||
if data_path_and_name_and_type is None and raw_inputs is not None:
|
||||
if isinstance(raw_inputs, np.ndarray):
|
||||
raw_inputs = torch.tensor(raw_inputs)
|
||||
is_final = False
|
||||
if param_dict is not None and "cache" in param_dict:
|
||||
cache = param_dict["cache"]
|
||||
if param_dict is not None and "is_final" in param_dict:
|
||||
is_final = param_dict["is_final"]
|
||||
|
||||
if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes":
|
||||
raw_inputs = _load_bytes(data_path_and_name_and_type[0])
|
||||
raw_inputs = torch.tensor(raw_inputs)
|
||||
if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
|
||||
raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
|
||||
is_final = True
|
||||
if data_path_and_name_and_type is None and raw_inputs is not None:
|
||||
if isinstance(raw_inputs, np.ndarray):
|
||||
raw_inputs = torch.tensor(raw_inputs)
|
||||
# 7 .Start for-loop
|
||||
# FIXME(kamo): The output format should be discussed about
|
||||
asr_result_list = []
|
||||
|
||||
@@ -338,7 +338,7 @@ def inference_modelscope(
                ibest_writer["token"][key] = " ".join(token)
                ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                ibest_writer["vad"][key] = "{}".format(vadsegments)
                ibest_writer["text"][key] = text_postprocessed
                ibest_writer["text"][key] = " ".join(word_lists)
                ibest_writer["text_with_punc"][key] = text_postprocessed_punc
                if time_stamp_postprocessed is not None:
                    ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)

@@ -670,7 +670,7 @@ def inference_modelscope(
                ibest_writer["token"][key] = " ".join(token)
                ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                ibest_writer["vad"][key] = "{}".format(vadsegments)
                ibest_writer["text"][key] = text_postprocessed
                ibest_writer["text"][key] = " ".join(word_lists)
                ibest_writer["text_with_punc"][key] = text_postprocessed_punc
                if time_stamp_postprocessed is not None:
                    ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)

@@ -738,13 +738,13 @@ def inference_modelscope(
                ibest_writer["rtf"][key] = rtf_cur

            if text is not None:
                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
                item = {'key': key, 'value': text_postprocessed}
                asr_result_list.append(item)
                finish_count += 1
                # asr_utils.print_progress(finish_count / file_count)
                if writer is not None:
                    ibest_writer["text"][key] = text_postprocessed
                    ibest_writer["text"][key] = " ".join(word_lists)

            logging.info("decoding, utt: {}, predictions: {}".format(key, text))
        rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor))
@@ -37,9 +37,6 @@ from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend


header_colors = '\033[95m'
end_colors = '\033[0m'


class Speech2Text:
    """Speech2Text class

@@ -507,13 +504,13 @@ def inference_modelscope(
                ibest_writer["score"][key] = str(hyp.score)

            if text is not None:
                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
                item = {'key': key, 'value': text_postprocessed}
                asr_result_list.append(item)
                finish_count += 1
                asr_utils.print_progress(finish_count / file_count)
                if writer is not None:
                    ibest_writer["text"][key] = text_postprocessed
                    ibest_writer["text"][key] = " ".join(word_lists)
        return asr_result_list

    return _forward

@@ -507,13 +507,13 @@ def inference_modelscope(
                ibest_writer["score"][key] = str(hyp.score)

            if text is not None:
                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
                item = {'key': key, 'value': text_postprocessed}
                asr_result_list.append(item)
                finish_count += 1
                asr_utils.print_progress(finish_count / file_count)
                if writer is not None:
                    ibest_writer["text"][key] = text_postprocessed
                    ibest_writer["text"][key] = " ".join(word_lists)
        return asr_result_list

    return _forward
@@ -2,6 +2,9 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import torch
torch.set_num_threads(1)

import argparse
import logging
import os

@@ -2,6 +2,9 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import torch
torch.set_num_threads(1)

import argparse
import logging
import os

@@ -2,6 +2,9 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import torch
torch.set_num_threads(1)

import argparse
import logging
import os
@@ -1,44 +0,0 @@
#!/usr/bin/env python3
import os
from funasr.tasks.punctuation import PunctuationTask


def parse_args():
    parser = PunctuationTask.get_parser()
    parser.add_argument(
        "--gpu_id",
        type=int,
        default=0,
        help="local gpu id.",
    )
    parser.add_argument(
        "--punc_list",
        type=str,
        default=None,
        help="Punctuation list",
    )
    args = parser.parse_args()
    return args


def main(args=None, cmd=None):
    """
    punc training.
    """
    PunctuationTask.main(args=args, cmd=cmd)


if __name__ == "__main__":
    args = parse_args()

    # setup local gpu_id
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)

    # DDP settings
    if args.ngpu > 1:
        args.distributed = True
    else:
        args.distributed = False
    assert args.num_worker_count == 1

    main(args=args)
@@ -23,7 +23,7 @@ from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.punctuation.text_preprocessor import split_to_mini_sentence
from funasr.datasets.preprocessor import split_to_mini_sentence


class Text2Punc:

@@ -23,7 +23,7 @@ from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.punctuation.text_preprocessor import split_to_mini_sentence
from funasr.datasets.preprocessor import split_to_mini_sentence


class Text2Punc:

@@ -90,7 +90,7 @@ class Text2Punc:
            data = {
                "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
                "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
                "vad_indexes": torch.from_numpy(np.array([len(cache)-1], dtype='int32')),
                "vad_indexes": torch.from_numpy(np.array([len(cache)], dtype='int32')),
            }
            data = to_device(data, self.device)
            y, _ = self.wrapped_model(**data)
@@ -2,6 +2,9 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import torch
torch.set_num_threads(1)

import argparse
import logging
import os

@@ -2,6 +2,9 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import torch
torch.set_num_threads(1)

import argparse
import logging
import os

@@ -2,6 +2,9 @@
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import torch
torch.set_num_threads(1)

import argparse
import logging
import os
@@ -115,7 +115,7 @@ def sound_loader(path, dest_sample_rate=16000, float_dtype=None):
    # NOTE(kamo): SoundScpReader doesn't support pipe-fashion
    # like Kaldi e.g. "cat a.wav |".
    # NOTE(kamo): The audio signal is normalized to [-1,1] range.
    loader = SoundScpReader(path, dest_sample_rate, normalize=True, always_2d=False)
    loader = SoundScpReader(path, normalize=True, always_2d=False, dest_sample_rate=dest_sample_rate)

    # SoundScpReader.__getitem__() returns Tuple[int, ndarray],
    # but ndarray is desired, so Adapter class is inserted here

@@ -47,8 +47,8 @@ def tokenize(data,
            length = len(text)
            for i in range(length):
                x = text[i]
                if i == length-1 and "punc" in data and text[i].startswith("vad:"):
                    vad = x[-1][4:]
                if i == length-1 and "punc" in data and x.startswith("vad:"):
                    vad = x[4:]
                if len(vad) == 0:
                    vad = -1
                else:
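The tokenize hunk above fixes how a trailing `vad:` marker is read from the last token: `x` is already the token string, so the payload is `x[4:]` rather than `x[-1][4:]`, and an empty payload falls back to `-1`. A minimal sketch of that parsing rule (the helper name `parse_vad_token` is mine, not from the source):

```python
def parse_vad_token(x: str) -> int:
    # Tokens like "vad:123" carry a VAD index after the "vad:" prefix;
    # an empty payload ("vad:") maps to -1, mirroring the diff's fallback.
    payload = x[4:] if x.startswith("vad:") else ""
    return int(payload) if payload else -1

print(parse_vad_token("vad:123"))  # 123
print(parse_vad_token("vad:"))     # -1
```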
@@ -786,6 +786,7 @@ class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor):
    ) -> Dict[str, np.ndarray]:
        for i in range(self.num_tokenizer):
            text_name = self.text_name[i]
            # import pdb; pdb.set_trace()
            if text_name in data and self.tokenizer[i] is not None:
                text = data[text_name]
                text = self.text_cleaner(text)

@@ -800,3 +801,17 @@ class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor):
                data[self.vad_name] = np.array([vad], dtype=np.int64)
            text_ints = self.token_id_converter[i].tokens2ids(tokens)
            data[text_name] = np.array(text_ints, dtype=np.int64)
        return data


def split_to_mini_sentence(words: list, word_limit: int = 20):
    assert word_limit > 1
    if len(words) <= word_limit:
        return [words]
    sentences = []
    length = len(words)
    sentence_len = length // word_limit
    for i in range(sentence_len):
        sentences.append(words[i * word_limit:(i + 1) * word_limit])
    if length % word_limit > 0:
        sentences.append(words[sentence_len * word_limit:])
    return sentences
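`split_to_mini_sentence` chunks a token list into consecutive slices of at most `word_limit` items, with any remainder kept as a final shorter slice. A small self-contained usage sketch (restating the function so the example runs on its own):

```python
def split_to_mini_sentence(words: list, word_limit: int = 20):
    # Split `words` into consecutive chunks of `word_limit` items;
    # the remainder (if any) becomes a final, shorter chunk.
    assert word_limit > 1
    if len(words) <= word_limit:
        return [words]
    sentences = []
    length = len(words)
    sentence_len = length // word_limit
    for i in range(sentence_len):
        sentences.append(words[i * word_limit:(i + 1) * word_limit])
    if length % word_limit > 0:
        sentences.append(words[sentence_len * word_limit:])
    return sentences

print(split_to_mini_sentence(list("abcde"), word_limit=2))  # [['a', 'b'], ['c', 'd'], ['e']]
```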
@@ -7,7 +7,7 @@

## Install modelscope and funasr

The installation is the same as [funasr](../../README.md)
The installation is the same as [funasr](https://github.com/alibaba-damo-academy/FunASR/blob/main/README.md#installation)

## Export model
`Tips`: torch>=1.11.0
@@ -167,31 +167,57 @@ class ModelExport:

    def export(self,
               tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
               mode: str = 'paraformer',
               mode: str = None,
               ):

        model_dir = tag_name
        if model_dir.startswith('damo/'):
        if model_dir.startswith('damo'):
            from modelscope.hub.snapshot_download import snapshot_download
            model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir)
        asr_train_config = os.path.join(model_dir, 'config.yaml')
        asr_model_file = os.path.join(model_dir, 'model.pb')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        json_file = os.path.join(model_dir, 'configuration.json')

        if mode is None:
            import json
            json_file = os.path.join(model_dir, 'configuration.json')
            with open(json_file, 'r') as f:
                config_data = json.load(f)
            mode = config_data['model']['model_config']['mode']
            if config_data['task'] == "punctuation":
                mode = config_data['model']['punc_model_config']['mode']
            else:
                mode = config_data['model']['model_config']['mode']
        if mode.startswith('paraformer'):
            from funasr.tasks.asr import ASRTaskParaformer as ASRTask
        elif mode.startswith('uniasr'):
            from funasr.tasks.asr import ASRTaskUniASR as ASRTask
            config = os.path.join(model_dir, 'config.yaml')
            model_file = os.path.join(model_dir, 'model.pb')
            cmvn_file = os.path.join(model_dir, 'am.mvn')
            model, asr_train_args = ASRTask.build_model_from_file(
                config, model_file, cmvn_file, 'cpu'
            )
            self.frontend = model.frontend
        elif mode.startswith('offline'):
            from funasr.tasks.vad import VADTask
            config = os.path.join(model_dir, 'vad.yaml')
            model_file = os.path.join(model_dir, 'vad.pb')
            cmvn_file = os.path.join(model_dir, 'vad.mvn')

            model, asr_train_args = ASRTask.build_model_from_file(
                asr_train_config, asr_model_file, cmvn_file, 'cpu'
            )
            self.frontend = model.frontend
            model, vad_infer_args = VADTask.build_model_from_file(
                config, model_file, cmvn_file=cmvn_file, device='cpu'
            )
            self.export_config["feats_dim"] = 400
            self.frontend = model.frontend
        elif mode.startswith('punc'):
            from funasr.tasks.punctuation import PunctuationTask as PUNCTask
            punc_train_config = os.path.join(model_dir, 'config.yaml')
            punc_model_file = os.path.join(model_dir, 'punc.pb')
            model, punc_train_args = PUNCTask.build_model_from_file(
                punc_train_config, punc_model_file, 'cpu'
            )
        elif mode.startswith('punc_VadRealtime'):
            from funasr.tasks.punctuation import PunctuationTask as PUNCTask
            punc_train_config = os.path.join(model_dir, 'config.yaml')
            punc_model_file = os.path.join(model_dir, 'punc.pb')
            model, punc_train_args = PUNCTask.build_model_from_file(
                punc_train_config, punc_model_file, 'cpu'
            )
        self._export(model, tag_name)
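When `mode` is not given, `export()` now resolves it from the model's `configuration.json`: punctuation models keep their mode under `punc_model_config`, all other tasks under `model_config`. A minimal sketch of that lookup on an in-memory dict (the sample config values below are illustrative, not taken from a real model card):

```python
def resolve_mode(config_data: dict) -> str:
    # Punctuation models store their mode under 'punc_model_config';
    # every other task stores it under 'model_config'.
    if config_data['task'] == "punctuation":
        return config_data['model']['punc_model_config']['mode']
    return config_data['model']['model_config']['mode']

asr_cfg = {'task': 'auto-speech-recognition',
           'model': {'model_config': {'mode': 'paraformer'}}}
punc_cfg = {'task': 'punctuation',
            'model': {'punc_model_config': {'mode': 'punc'}}}
print(resolve_mode(asr_cfg))   # paraformer
print(resolve_mode(punc_cfg))  # punc
```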
funasr/export/models/CT_Transformer.py (new file, 162 lines)
@@ -0,0 +1,162 @@
from typing import Tuple

import torch
import torch.nn as nn

from funasr.models.encoder.sanm_encoder import SANMEncoder
from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
from funasr.models.encoder.sanm_encoder import SANMVadEncoder
from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export


class CT_Transformer(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
    https://arxiv.org/pdf/2003.01309.pdf
    """
    def __init__(
            self,
            model,
            max_seq_len=512,
            model_name='punc_model',
            **kwargs,
    ):
        super().__init__()
        onnx = False
        if "onnx" in kwargs:
            onnx = kwargs["onnx"]
        self.embed = model.embed
        self.decoder = model.decoder
        # self.model = model
        self.feats_dim = self.embed.embedding_dim
        self.num_embeddings = self.embed.num_embeddings
        self.model_name = model_name

        if isinstance(model.encoder, SANMEncoder):
            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
        else:
            assert False, "Only the SANM encoder is supported."

    def forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Compute punctuation logits for a batch of token sequences.

        Args:
            inputs (torch.Tensor): Input ids. (batch, len)
            text_lengths (torch.Tensor): Sequence lengths. (batch,)
        """
        x = self.embed(inputs)
        # mask = self._target_mask(input)
        h, _ = self.encoder(x, text_lengths)
        y = self.decoder(h)
        return y

    def get_dummy_inputs(self):
        length = 120
        text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
        text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
        return (text_indexes, text_lengths)

    def get_input_names(self):
        return ['inputs', 'text_lengths']

    def get_output_names(self):
        return ['logits']

    def get_dynamic_axes(self):
        return {
            'inputs': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'text_lengths': {
                0: 'batch_size',
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
        }


class CT_Transformer_VadRealtime(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
    https://arxiv.org/pdf/2003.01309.pdf
    """
    def __init__(
            self,
            model,
            max_seq_len=512,
            model_name='punc_model',
            **kwargs,
    ):
        super().__init__()
        onnx = False
        if "onnx" in kwargs:
            onnx = kwargs["onnx"]

        self.embed = model.embed
        if isinstance(model.encoder, SANMVadEncoder):
            self.encoder = SANMVadEncoder_export(model.encoder, onnx=onnx)
        else:
            assert False, "Only the SANM encoder is supported."
        self.decoder = model.decoder
        self.model_name = model_name

    def forward(self, inputs: torch.Tensor,
                text_lengths: torch.Tensor,
                vad_indexes: torch.Tensor,
                sub_masks: torch.Tensor,
                ) -> Tuple[torch.Tensor, None]:
        """Compute punctuation logits for a token sequence with VAD context.

        Args:
            inputs (torch.Tensor): Input ids. (batch, len)
            text_lengths (torch.Tensor): Sequence lengths. (batch,)
        """
        x = self.embed(inputs)
        # mask = self._target_mask(input)
        h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks)
        y = self.decoder(h)
        return y

    def with_vad(self):
        return True

    def get_dummy_inputs(self):
        length = 120
        text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length))
        text_lengths = torch.tensor([length], dtype=torch.int32)
        vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
        sub_masks = torch.ones(length, length, dtype=torch.float32)
        sub_masks = torch.tril(sub_masks).type(torch.float32)
        return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])

    def get_input_names(self):
        return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks']

    def get_output_names(self):
        return ['logits']

    def get_dynamic_axes(self):
        return {
            'inputs': {
                1: 'feats_length'
            },
            'vad_masks': {
                2: 'feats_length1',
                3: 'feats_length2'
            },
            'sub_masks': {
                2: 'feats_length1',
                3: 'feats_length2'
            },
            'logits': {
                1: 'logits_length'
            },
        }
@@ -1,13 +1,25 @@
from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
from funasr.models.e2e_uni_asr import UniASR

from funasr.models.e2e_vad import E2EVadModel
from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
from funasr.models.target_delay_transformer import TargetDelayTransformer
from funasr.export.models.CT_Transformer import CT_Transformer as CT_Transformer_export
from funasr.train.abs_model import PunctuationModel
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
from funasr.export.models.CT_Transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export


def get_model(model, export_config=None):
    if isinstance(model, BiCifParaformer):
        return BiCifParaformer_export(model, **export_config)
    elif isinstance(model, Paraformer):
        return Paraformer_export(model, **export_config)
    elif isinstance(model, E2EVadModel):
        return E2EVadModel_export(model, **export_config)
    elif isinstance(model, PunctuationModel):
        if isinstance(model.punc_model, TargetDelayTransformer):
            return CT_Transformer_export(model.punc_model, **export_config)
        elif isinstance(model.punc_model, VadRealtimeTransformer):
            return CT_Transformer_VadRealtime_export(model.punc_model, **export_config)
        else:
            raise TypeError("FunASR does not support the given model type currently.")
    raise TypeError("FunASR does not support the given model type currently.")
@@ -19,7 +19,7 @@ from funasr.export.models.decoder.transformer_decoder import ParaformerDecoderSA

class Paraformer(nn.Module):
    """
    Author: Speech Lab, Alibaba Group, China
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

@@ -112,7 +112,7 @@ class Paraformer(nn.Module):

class BiCifParaformer(nn.Module):
    """
    Author: Speech Lab, Alibaba Group, China
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """
funasr/export/models/e2e_vad.py (new file, 60 lines)
@@ -0,0 +1,60 @@
from enum import Enum
from typing import List, Tuple, Dict, Any

import torch
from torch import nn
import math

from funasr.models.encoder.fsmn_encoder import FSMN
from funasr.export.models.encoder.fsmn_encoder import FSMN as FSMN_export


class E2EVadModel(nn.Module):
    def __init__(self, model,
                 max_seq_len=512,
                 feats_dim=400,
                 model_name='model',
                 **kwargs,):
        super(E2EVadModel, self).__init__()
        self.feats_dim = feats_dim
        self.max_seq_len = max_seq_len
        self.model_name = model_name
        if isinstance(model.encoder, FSMN):
            self.encoder = FSMN_export(model.encoder)
        else:
            raise TypeError("unsupported encoder")

    def forward(self, feats: torch.Tensor, *args, ):
        scores, out_caches = self.encoder(feats, *args)
        return scores, out_caches

    def get_dummy_inputs(self, frame=30):
        speech = torch.randn(1, frame, self.feats_dim)
        in_cache0 = torch.randn(1, 128, 19, 1)
        in_cache1 = torch.randn(1, 128, 19, 1)
        in_cache2 = torch.randn(1, 128, 19, 1)
        in_cache3 = torch.randn(1, 128, 19, 1)
        return (speech, in_cache0, in_cache1, in_cache2, in_cache3)

    # def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
    #     import numpy as np
    #     fbank = np.loadtxt(txt_file)
    #     fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
    #     speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
    #     speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
    #     return (speech, speech_lengths)

    def get_input_names(self):
        return ['speech', 'in_cache0', 'in_cache1', 'in_cache2', 'in_cache3']

    def get_output_names(self):
        return ['logits', 'out_cache0', 'out_cache1', 'out_cache2', 'out_cache3']

    def get_dynamic_axes(self):
        return {
            'speech': {
                1: 'feats_length'
            },
        }
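The dummy caches above are shaped `(1, 128, 19, 1)`. In the FSMN block defined in the next file, the left-context cache keeps `(lorder - 1) * lstride` frames per layer, so 19 matches `lorder=20, lstride=1` (my reading of this diff, where `DFSMN` defaults to `lorder=20, lstride=1`). A quick sanity check of that arithmetic:

```python
def fsmn_cache_frames(lorder: int, lstride: int = 1) -> int:
    # The FSMN left-context cache holds (lorder - 1) * lstride frames,
    # matching the slice cache = y_left[:, :, -(lorder - 1) * lstride:, :].
    return (lorder - 1) * lstride

print(fsmn_cache_frames(20, 1))  # 19, the cache length used in get_dummy_inputs
```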
296
funasr/export/models/encoder/fsmn_encoder.py
Executable file
296
funasr/export/models/encoder/fsmn_encoder.py
Executable file
@ -0,0 +1,296 @@
|
||||
from typing import Tuple, Dict
|
||||
import copy
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from funasr.models.encoder.fsmn_encoder import BasicBlock
|
||||
|
||||
class LinearTransform(nn.Module):
|
||||
|
||||
def __init__(self, input_dim, output_dim):
|
||||
super(LinearTransform, self).__init__()
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.linear = nn.Linear(input_dim, output_dim, bias=False)
|
||||
|
||||
def forward(self, input):
|
||||
output = self.linear(input)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class AffineTransform(nn.Module):
|
||||
|
||||
def __init__(self, input_dim, output_dim):
|
||||
super(AffineTransform, self).__init__()
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.linear = nn.Linear(input_dim, output_dim)
|
||||
|
||||
def forward(self, input):
|
||||
output = self.linear(input)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class RectifiedLinear(nn.Module):
|
||||
|
||||
def __init__(self, input_dim, output_dim):
|
||||
super(RectifiedLinear, self).__init__()
|
||||
self.dim = input_dim
|
||||
self.relu = nn.ReLU()
|
||||
self.dropout = nn.Dropout(0.1)
|
||||
|
||||
def forward(self, input):
|
||||
out = self.relu(input)
|
||||
return out
|
||||
|
||||
|
||||
class FSMNBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
output_dim: int,
|
||||
lorder=None,
|
||||
rorder=None,
|
||||
lstride=1,
|
||||
rstride=1,
|
||||
):
|
||||
super(FSMNBlock, self).__init__()
|
||||
|
||||
self.dim = input_dim
|
||||
|
||||
if lorder is None:
|
||||
return
|
||||
|
||||
self.lorder = lorder
|
||||
self.rorder = rorder
|
||||
self.lstride = lstride
|
||||
self.rstride = rstride
|
||||
|
||||
self.conv_left = nn.Conv2d(
|
||||
self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False)
|
||||
|
||||
if self.rorder > 0:
|
||||
self.conv_right = nn.Conv2d(
|
||||
self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False)
|
||||
else:
|
||||
self.conv_right = None
|
||||
|
||||
def forward(self, input: torch.Tensor, cache: torch.Tensor):
|
||||
x = torch.unsqueeze(input, 1)
|
||||
x_per = x.permute(0, 3, 2, 1) # B D T C
|
||||
|
||||
cache = cache.to(x_per.device)
|
||||
y_left = torch.cat((cache, x_per), dim=2)
|
||||
cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
|
||||
y_left = self.conv_left(y_left)
|
||||
out = x_per + y_left
|
||||
|
||||
if self.conv_right is not None:
|
||||
# maybe need to check
|
||||
y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
|
||||
y_right = y_right[:, :, self.rstride:, :]
|
||||
y_right = self.conv_right(y_right)
|
||||
out += y_right
|
||||
|
||||
out_per = out.permute(0, 3, 2, 1)
|
||||
output = out_per.squeeze(1)
|
||||
|
||||
return output, cache
|
||||
|
||||
|
||||
class BasicBlock_export(nn.Module):
|
||||
def __init__(self,
|
||||
model,
|
||||
):
|
||||
super(BasicBlock_export, self).__init__()
|
||||
self.linear = model.linear
|
||||
self.fsmn_block = model.fsmn_block
|
||||
self.affine = model.affine
|
||||
self.relu = model.relu
|
||||
|
||||
def forward(self, input: torch.Tensor, in_cache: torch.Tensor):
|
||||
x = self.linear(input) # B T D
|
||||
# cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
|
||||
# if cache_layer_name not in in_cache:
|
||||
# in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
|
||||
x, out_cache = self.fsmn_block(x, in_cache)
|
||||
x = self.affine(x)
|
||||
x = self.relu(x)
|
||||
return x, out_cache
|
||||
|
||||
|
||||
# class FsmnStack(nn.Sequential):
|
||||
# def __init__(self, *args):
|
||||
# super(FsmnStack, self).__init__(*args)
|
||||
#
|
||||
# def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
|
||||
# x = input
|
||||
# for module in self._modules.values():
|
||||
# x = module(x, in_cache)
|
||||
# return x
|
||||
|
||||
|
||||
'''
|
||||
FSMN net for keyword spotting
|
||||
input_dim: input dimension
|
||||
linear_dim: fsmn input dimensionll
|
||||
    proj_dim: fsmn projection dimension
    lorder: fsmn left order
    rorder: fsmn right order
    num_syn: output dimension
    fsmn_layers: no. of sequential fsmn layers
'''


class FSMN(nn.Module):

    def __init__(
        self, model,
    ):
        super(FSMN, self).__init__()

        # self.input_dim = input_dim
        # self.input_affine_dim = input_affine_dim
        # self.fsmn_layers = fsmn_layers
        # self.linear_dim = linear_dim
        # self.proj_dim = proj_dim
        # self.output_affine_dim = output_affine_dim
        # self.output_dim = output_dim
        #
        # self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
        # self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
        # self.relu = RectifiedLinear(linear_dim, linear_dim)
        # self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
        #                         range(fsmn_layers)])
        # self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
        # self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
        # self.softmax = nn.Softmax(dim=-1)
        self.in_linear1 = model.in_linear1
        self.in_linear2 = model.in_linear2
        self.relu = model.relu
        # self.fsmn = model.fsmn
        self.out_linear1 = model.out_linear1
        self.out_linear2 = model.out_linear2
        self.softmax = model.softmax
        self.fsmn = model.fsmn
        for i, d in enumerate(model.fsmn):
            if isinstance(d, BasicBlock):
                self.fsmn[i] = BasicBlock_export(d)

    def fuse_modules(self):
        pass

    def forward(
        self,
        input: torch.Tensor,
        *args,
    ):
        """
        Args:
            input (torch.Tensor): Input tensor (B, T, D)
            in_cache: when in_cache is not None, the forward is in streaming mode. The type of in_cache is a dict, e.g.,
                {'cache_layer_1': torch.Tensor(B, T1, D)}, where T1 is equal to self.lorder. It is {} for the 1st frame.
        """

        x = self.in_linear1(input)
        x = self.in_linear2(x)
        x = self.relu(x)
        # x4 = self.fsmn(x3, in_cache)  # self.in_cache will update automatically in self.fsmn
        out_caches = list()
        for i, d in enumerate(self.fsmn):
            in_cache = args[i]
            x, out_cache = d(x, in_cache)
            out_caches.append(out_cache)
        x = self.out_linear1(x)
        x = self.out_linear2(x)
        x = self.softmax(x)

        return x, out_caches


'''
one deep fsmn layer
dimproj: projection dimension, input and output dimension of memory blocks
dimlinear: dimension of mapping layer
lorder: left order
rorder: right order
lstride: left stride
rstride: right stride
'''
class DFSMN(nn.Module):

    def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
        super(DFSMN, self).__init__()

        self.lorder = lorder
        self.rorder = rorder
        self.lstride = lstride
        self.rstride = rstride

        self.expand = AffineTransform(dimproj, dimlinear)
        self.shrink = LinearTransform(dimlinear, dimproj)

        self.conv_left = nn.Conv2d(
            dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False)

        if rorder > 0:
            self.conv_right = nn.Conv2d(
                dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False)
        else:
            self.conv_right = None

    def forward(self, input):
        f1 = F.relu(self.expand(input))
        p1 = self.shrink(f1)

        x = torch.unsqueeze(p1, 1)
        x_per = x.permute(0, 3, 2, 1)

        y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])

        if self.conv_right is not None:
            y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
            y_right = y_right[:, :, self.rstride:, :]
            out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
        else:
            out = x_per + self.conv_left(y_left)

        out1 = out.permute(0, 3, 2, 1)
        output = input + out1.squeeze(1)

        return output
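The causal (left) context in `DFSMN.forward` comes from zero-padding `(lorder - 1) * lstride` frames on the left before the depthwise convolution, so each output frame sees only the current and past frames. A minimal NumPy sketch of that padding arithmetic (toy values, not a FunASR API):

```python
import numpy as np

lorder, lstride = 20, 1            # left order and stride, as in the DFSMN defaults
left_pad = (lorder - 1) * lstride  # frames of zero left-context

frames = np.arange(1, 6, dtype=np.float32)  # a toy 5-frame signal
padded = np.pad(frames, (left_pad, 0))      # zeros on the left only

# a window of length lorder ending at the first real frame
# covers only zeros plus that frame, i.e. no future context leaks in
window = padded[:lorder]
assert window[-1] == frames[0] and not window[:-1].any()
```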


'''
build stacked dfsmn layers
'''


def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
    repeats = [
        nn.Sequential(
            DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
        for i in range(fsmn_layers)
    ]

    return nn.Sequential(*repeats)


if __name__ == '__main__':
    # NOTE: this demo constructs the original (training-time) FSMN, whose
    # signature differs from the export wrapper defined above.
    fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
    print(fsmn)

    num_params = sum(p.numel() for p in fsmn.parameters())
    print('the number of model params: {}'.format(num_params))
    x = torch.zeros(128, 200, 400)  # batch-size * time * dim
    y, _ = fsmn(x)  # batch-size * time * dim
    print('input shape: {}'.format(x.shape))
    print('output shape: {}'.format(y.shape))

    print(fsmn.to_kaldi_net())
@ -9,6 +9,7 @@ from funasr.export.models.modules.encoder_layer import EncoderLayerSANM as Encod
from funasr.modules.positionwise_feed_forward import PositionwiseFeedForward
from funasr.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export


class SANMEncoder(nn.Module):
    def __init__(
        self,
@ -107,3 +108,106 @@ class SANMEncoder(nn.Module):
            }
        }


class SANMVadEncoder(nn.Module):
    def __init__(
        self,
        model,
        max_seq_len=512,
        feats_dim=560,
        model_name='encoder',
        onnx: bool = True,
    ):
        super().__init__()
        self.embed = model.embed
        self.model = model
        self.feats_dim = feats_dim
        self._output_size = model._output_size

        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        if hasattr(model, 'encoders0'):
            for i, d in enumerate(self.model.encoders0):
                if isinstance(d.self_attn, MultiHeadedAttentionSANM):
                    d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
                if isinstance(d.feed_forward, PositionwiseFeedForward):
                    d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
                self.model.encoders0[i] = EncoderLayerSANM_export(d)

        for i, d in enumerate(self.model.encoders):
            if isinstance(d.self_attn, MultiHeadedAttentionSANM):
                d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
            if isinstance(d.feed_forward, PositionwiseFeedForward):
                d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
            self.model.encoders[i] = EncoderLayerSANM_export(d)

        self.model_name = model_name
        self.num_heads = model.encoders[0].self_attn.h
        self.hidden_size = model.encoders[0].self_attn.linear_out.out_features

    def prepare_mask(self, mask, sub_masks):
        mask_3d_btd = mask[:, :, None]
        mask_4d_bhlt = (1 - sub_masks) * -10000.0

        return mask_3d_btd, mask_4d_bhlt
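`prepare_mask` turns a 0/1 attention mask into the additive form attention layers expect: allowed positions contribute 0 and disallowed ones a large negative value that vanishes under softmax. A small NumPy illustration of the `(1 - mask) * -10000.0` transform (toy sizes only):

```python
import numpy as np

sub_masks = np.tril(np.ones((4, 4), dtype=np.float32))  # causal 0/1 mask
additive = (1.0 - sub_masks) * -10000.0                 # 0 where allowed, -1e4 where masked

# softmax over a row of zero logits plus the additive mask:
# in row 0 only position 0 is allowed, so it takes ~all the probability
logits = additive[0]
probs = np.exp(logits) / np.exp(logits).sum()
assert probs[0] > 0.999
```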
    def forward(self,
                speech: torch.Tensor,
                speech_lengths: torch.Tensor,
                vad_masks: torch.Tensor,
                sub_masks: torch.Tensor,
                ):
        speech = speech * self._output_size ** 0.5
        mask = self.make_pad_mask(speech_lengths)
        vad_masks = self.prepare_mask(mask, vad_masks)
        mask = self.prepare_mask(mask, sub_masks)

        if self.embed is None:
            xs_pad = speech
        else:
            xs_pad = self.embed(speech)

        encoder_outs = self.model.encoders0(xs_pad, mask)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]

        # encoder_outs = self.model.encoders(xs_pad, mask)
        for layer_idx, encoder_layer in enumerate(self.model.encoders):
            if layer_idx == len(self.model.encoders) - 1:
                mask = vad_masks
            encoder_outs = encoder_layer(xs_pad, mask)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        xs_pad = self.model.after_norm(xs_pad)

        return xs_pad, speech_lengths

    def get_output_size(self):
        return self.model.encoders[0].size

    # def get_dummy_inputs(self):
    #     feats = torch.randn(1, 100, self.feats_dim)
    #     return (feats)
    #
    # def get_input_names(self):
    #     return ['feats']
    #
    # def get_output_names(self):
    #     return ['encoder_out', 'encoder_out_lens', 'predictor_weight']
    #
    # def get_dynamic_axes(self):
    #     return {
    #         'feats': {
    #             1: 'feats_length'
    #         },
    #         'encoder_out': {
    #             1: 'enc_out_length'
    #         },
    #         'predictor_weight': {
    #             1: 'pre_out_length'
    #         }
    #
    #     }

funasr/export/test/test_onnx_punc.py (new file, 18 lines)
@ -0,0 +1,18 @@
import onnxruntime
import numpy as np


if __name__ == '__main__':
    onnx_path = "../damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/model.onnx"
    sess = onnxruntime.InferenceSession(onnx_path)
    input_name = [nd.name for nd in sess.get_inputs()]
    output_name = [nd.name for nd in sess.get_outputs()]

    def _get_feed_dict(text_length):
        return {'inputs': np.ones((1, text_length), dtype=np.int64),
                'text_lengths': np.array([text_length, ], dtype=np.int32)}

    def _run(feed_dict):
        output = sess.run(output_name, input_feed=feed_dict)
        for name, value in zip(output_name, output):
            print('{}: {}'.format(name, value))

    _run(_get_feed_dict(10))

funasr/export/test/test_onnx_punc_vadrealtime.py (new file, 22 lines)
@ -0,0 +1,22 @@
import onnxruntime
import numpy as np


if __name__ == '__main__':
    onnx_path = "./export/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727/model.onnx"
    sess = onnxruntime.InferenceSession(onnx_path)
    input_name = [nd.name for nd in sess.get_inputs()]
    output_name = [nd.name for nd in sess.get_outputs()]

    def _get_feed_dict(text_length):
        return {'inputs': np.ones((1, text_length), dtype=np.int64),
                'text_lengths': np.array([text_length, ], dtype=np.int32),
                'vad_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32),
                'sub_masks': np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
                }

    def _run(feed_dict):
        output = sess.run(output_name, input_feed=feed_dict)
        for name, value in zip(output_name, output):
            print('{}: {}'.format(name, value))

    _run(_get_feed_dict(10))
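The `sub_masks` input fed to the exported model above is a lower-triangular (causal) mask broadcast to shape `(1, 1, T, T)`. A standalone NumPy check of that construction, independent of the ONNX model:

```python
import numpy as np

T = 4
sub_masks = np.tril(np.ones((T, T), dtype=np.float32))[None, None, :, :]

assert sub_masks.shape == (1, 1, T, T)
# row i may attend to positions 0..i only
assert sub_masks[0, 0, 0].tolist() == [1.0, 0.0, 0.0, 0.0]
assert sub_masks[0, 0, 3].tolist() == [1.0, 1.0, 1.0, 1.0]
```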

funasr/export/test/test_onnx_vad.py (new file, 26 lines)
@ -0,0 +1,26 @@
import onnxruntime
import numpy as np


if __name__ == '__main__':
    onnx_path = "/mnt/workspace/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/model.onnx"
    sess = onnxruntime.InferenceSession(onnx_path)
    input_name = [nd.name for nd in sess.get_inputs()]
    output_name = [nd.name for nd in sess.get_outputs()]

    def _get_feed_dict(feats_length):
        return {'speech': np.random.rand(1, feats_length, 400).astype(np.float32),
                'in_cache0': np.random.rand(1, 128, 19, 1).astype(np.float32),
                'in_cache1': np.random.rand(1, 128, 19, 1).astype(np.float32),
                'in_cache2': np.random.rand(1, 128, 19, 1).astype(np.float32),
                'in_cache3': np.random.rand(1, 128, 19, 1).astype(np.float32),
                }

    def _run(feed_dict):
        output = sess.run(output_name, input_feed=feed_dict)
        for name, value in zip(output_name, output):
            print('{}: {}'.format(name, value.shape))

    _run(_get_feed_dict(100))
    _run(_get_feed_dict(200))
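The four `in_cache*` tensors above have shape `(1, 128, 19, 1)`: 128 matches the FSMN projection dimension and 19 the frames of left context each memory block carries between streaming chunks, which equals `(lorder - 1) * lstride` for the `lorder=20, lstride=1` defaults shown earlier. This connection is an inference from the DFSMN code, not a documented FunASR guarantee; `fsmn_cache_frames` below is a hypothetical helper:

```python
# hypothetical helper: left-context frames an FSMN memory block must cache
def fsmn_cache_frames(lorder: int, lstride: int = 1) -> int:
    return (lorder - 1) * lstride

assert fsmn_cache_frames(20, 1) == 19  # matches the (1, 128, 19, 1) caches above
```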

@ -5,7 +5,17 @@ from typing import Tuple
import torch

from funasr.modules.scorers.scorer_interface import BatchScorerInterface
from typing import Dict
from typing import Optional
from typing import Tuple

import torch
import torch.nn.functional as F
from typeguard import check_argument_types

from funasr.modules.nets_utils import make_pad_mask
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel

class AbsLM(torch.nn.Module, BatchScorerInterface, ABC):
    """The abstract LM class
@ -27,3 +37,122 @@ class AbsLM(torch.nn.Module, BatchScorerInterface, ABC):
        self, input: torch.Tensor, hidden: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError


class LanguageModel(AbsESPnetModel):
    def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
        assert check_argument_types()
        super().__init__()
        self.lm = lm
        self.sos = 1
        self.eos = 2

        # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
        self.ignore_id = ignore_id

    def nll(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        max_length: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute negative log likelihood (nll)

        Normally, this function is called in batchify_nll.
        Args:
            text: (Batch, Length)
            text_lengths: (Batch,)
            max_length: int
        """
        batch_size = text.size(0)
        # For data parallel
        if max_length is None:
            text = text[:, : text_lengths.max()]
        else:
            text = text[:, :max_length]

        # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
        # text: (Batch, Length) -> x, y: (Batch, Length + 1)
        x = F.pad(text, [1, 0], "constant", self.sos)
        t = F.pad(text, [0, 1], "constant", self.ignore_id)
        for i, l in enumerate(text_lengths):
            t[i, l] = self.eos
        x_lengths = text_lengths + 1

        # 2. Forward the language model
        # x: (Batch, Length) -> y: (Batch, Length, NVocab)
        y, _ = self.lm(x, None)

        # 3. Calc negative log likelihood
        # nll: (BxL,)
        nll = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
        # nll: (BxL,) -> (BxL,)
        if max_length is None:
            nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0)
        else:
            nll.masked_fill_(
                make_pad_mask(x_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
                0.0,
            )
        # nll: (BxL,) -> (B, L)
        nll = nll.view(batch_size, -1)
        return nll, x_lengths
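The `<sos>`/`<eos>` pairing in `nll` shifts the sequence: input `x` is the text with `<sos>` prepended, target `t` is the text with the padding id appended and `<eos>` written at each sequence's true end. A NumPy sketch of the same construction, using the class's `sos=1`, `eos=2`, `ignore_id=0` with toy word ids:

```python
import numpy as np

sos, eos, ignore_id = 1, 2, 0
text = np.array([[5, 6, 7]])           # one sequence of word ids
text_lengths = np.array([3])

x = np.pad(text, ((0, 0), (1, 0)), constant_values=sos)        # '<sos> w1 w2 w3'
t = np.pad(text, ((0, 0), (0, 1)), constant_values=ignore_id)  # 'w1 w2 w3 <pad>'
for i, l in enumerate(text_lengths):
    t[i, l] = eos                                              # -> 'w1 w2 w3 <eos>'

assert x.tolist() == [[1, 5, 6, 7]]
assert t.tolist() == [[5, 6, 7, 2]]
```

At every position the model therefore predicts the next token: `x[:, k]` is the input and `t[:, k]` the label, which is exactly what the flattened cross-entropy above consumes.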

    def batchify_nll(
        self, text: torch.Tensor, text_lengths: torch.Tensor, batch_size: int = 100
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute negative log likelihood (nll) from a transformer language model

        To avoid OOM, this function separates the input into batches,
        then calls nll for each batch and combines and returns the results.
        Args:
            text: (Batch, Length)
            text_lengths: (Batch,)
            batch_size: int, number of samples each batch contains when computing nll;
                you may change this to avoid OOM or to increase throughput
        """
        total_num = text.size(0)
        if total_num <= batch_size:
            nll, x_lengths = self.nll(text, text_lengths)
        else:
            nlls = []
            x_lengths = []
            max_length = text_lengths.max()

            start_idx = 0
            while True:
                end_idx = min(start_idx + batch_size, total_num)
                batch_text = text[start_idx:end_idx, :]
                batch_text_lengths = text_lengths[start_idx:end_idx]
                # batch_nll: [B * T]
                batch_nll, batch_x_lengths = self.nll(
                    batch_text, batch_text_lengths, max_length=max_length
                )
                nlls.append(batch_nll)
                x_lengths.append(batch_x_lengths)
                start_idx = end_idx
                if start_idx == total_num:
                    break
            nll = torch.cat(nlls)
            x_lengths = torch.cat(x_lengths)
            assert nll.size(0) == total_num
            assert x_lengths.size(0) == total_num
        return nll, x_lengths
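The batching loop in `batchify_nll` is a plain fixed-size chunking of the first dimension. The same control flow in pure Python (`chunk_ranges` is a hypothetical helper, not a FunASR function):

```python
def chunk_ranges(total_num: int, batch_size: int):
    """Yield (start, end) index pairs covering 0..total_num in batch_size steps."""
    start_idx = 0
    while start_idx < total_num:
        end_idx = min(start_idx + batch_size, total_num)
        yield start_idx, end_idx
        start_idx = end_idx

assert list(chunk_ranges(250, 100)) == [(0, 100), (100, 200), (200, 250)]
assert list(chunk_ranges(50, 100)) == [(0, 50)]
```

Passing `max_length=text_lengths.max()` to each per-chunk `nll` call keeps every chunk padded to the same width, so the per-chunk results can be concatenated with `torch.cat`.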

    def forward(
        self, text: torch.Tensor, text_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        nll, y_lengths = self.nll(text, text_lengths)
        ntokens = y_lengths.sum()
        loss = nll.sum() / ntokens
        stats = dict(loss=loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
        return loss, stats, weight

    def collect_feats(
        self, text: torch.Tensor, text_lengths: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        return {}

@ -1,131 +0,0 @@
from typing import Dict
from typing import Optional
from typing import Tuple

import torch
import torch.nn.functional as F
from typeguard import check_argument_types

from funasr.modules.nets_utils import make_pad_mask
from funasr.lm.abs_model import AbsLM
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel


class ESPnetLanguageModel(AbsESPnetModel):
    def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
        assert check_argument_types()
        super().__init__()
        self.lm = lm
        self.sos = 1
        self.eos = 2

        # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
        self.ignore_id = ignore_id

    def nll(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        max_length: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute negative log likelihood(nll)

        Normally, this function is called in batchify_nll.
        Args:
            text: (Batch, Length)
            text_lengths: (Batch,)
            max_lengths: int
        """
        batch_size = text.size(0)
        # For data parallel
        if max_length is None:
            text = text[:, : text_lengths.max()]
        else:
            text = text[:, :max_length]

        # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
        # text: (Batch, Length) -> x, y: (Batch, Length + 1)
        x = F.pad(text, [1, 0], "constant", self.sos)
        t = F.pad(text, [0, 1], "constant", self.ignore_id)
        for i, l in enumerate(text_lengths):
            t[i, l] = self.eos
        x_lengths = text_lengths + 1

        # 2. Forward Language model
        # x: (Batch, Length) -> y: (Batch, Length, NVocab)
        y, _ = self.lm(x, None)

        # 3. Calc negative log likelihood
        # nll: (BxL,)
        nll = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
        # nll: (BxL,) -> (BxL,)
        if max_length is None:
            nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0)
        else:
            nll.masked_fill_(
                make_pad_mask(x_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
                0.0,
            )
        # nll: (BxL,) -> (B, L)
        nll = nll.view(batch_size, -1)
        return nll, x_lengths

    def batchify_nll(
        self, text: torch.Tensor, text_lengths: torch.Tensor, batch_size: int = 100
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute negative log likelihood(nll) from transformer language model

        To avoid OOM, this fuction seperate the input into batches.
        Then call nll for each batch and combine and return results.
        Args:
            text: (Batch, Length)
            text_lengths: (Batch,)
            batch_size: int, samples each batch contain when computing nll,
                        you may change this to avoid OOM or increase

        """
        total_num = text.size(0)
        if total_num <= batch_size:
            nll, x_lengths = self.nll(text, text_lengths)
        else:
            nlls = []
            x_lengths = []
            max_length = text_lengths.max()

            start_idx = 0
            while True:
                end_idx = min(start_idx + batch_size, total_num)
                batch_text = text[start_idx:end_idx, :]
                batch_text_lengths = text_lengths[start_idx:end_idx]
                # batch_nll: [B * T]
                batch_nll, batch_x_lengths = self.nll(
                    batch_text, batch_text_lengths, max_length=max_length
                )
                nlls.append(batch_nll)
                x_lengths.append(batch_x_lengths)
                start_idx = end_idx
                if start_idx == total_num:
                    break
            nll = torch.cat(nlls)
            x_lengths = torch.cat(x_lengths)
            assert nll.size(0) == total_num
            assert x_lengths.size(0) == total_num
        return nll, x_lengths

    def forward(
        self, text: torch.Tensor, text_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        nll, y_lengths = self.nll(text, text_lengths)
        ntokens = y_lengths.sum()
        loss = nll.sum() / ntokens
        stats = dict(loss=loss.detach())

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
        return loss, stats, weight

    def collect_feats(
        self, text: torch.Tensor, text_lengths: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        return {}
@ -102,7 +102,7 @@ class ContextualBiasDecoder(nn.Module):

class ContextualParaformerDecoder(ParaformerSANMDecoder):
    """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713
    """

@ -104,7 +104,6 @@ class DecoderLayerSANM(nn.Module):

        x = residual + self.dropout(self.src_attn(x, memory, memory_mask))

        return x, tgt_mask, memory, memory_mask, cache

    def forward_chunk(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
@ -152,7 +151,7 @@ class DecoderLayerSANM(nn.Module):

class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
    """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
    https://arxiv.org/abs/2006.01713

@ -400,7 +399,7 @@ class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
        for i in range(self.att_layer_num):
            decoder = self.decoders[i]
            c = cache[i]
-            x, tgt_mask, memory, memory_mask, c_ret = decoder(
+            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                x, tgt_mask, memory, memory_mask, cache=c
            )
            new_cache.append(c_ret)
@ -410,13 +409,13 @@ class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
            j = i + self.att_layer_num
            decoder = self.decoders2[i]
            c = cache[j]
-            x, tgt_mask, memory, memory_mask, c_ret = decoder(
+            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                x, tgt_mask, memory, memory_mask, cache=c
            )
            new_cache.append(c_ret)

        for decoder in self.decoders3:
-            x, tgt_mask, memory, memory_mask, _ = decoder(
+            x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
                x, tgt_mask, memory, None, cache=None
            )

@ -813,7 +812,7 @@ class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):

class ParaformerSANMDecoder(BaseTransformerDecoder):
    """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713
    """
@ -1077,7 +1076,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
        for i in range(self.att_layer_num):
            decoder = self.decoders[i]
            c = cache[i]
-            x, tgt_mask, memory, memory_mask, c_ret = decoder(
+            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                x, tgt_mask, memory, None, cache=c
            )
            new_cache.append(c_ret)
@ -1087,14 +1086,14 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
            j = i + self.att_layer_num
            decoder = self.decoders2[i]
            c = cache[j]
-            x, tgt_mask, memory, memory_mask, c_ret = decoder(
+            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_chunk(
                x, tgt_mask, memory, None, cache=c
            )
            new_cache.append(c_ret)

        for decoder in self.decoders3:
-            x, tgt_mask, memory, memory_mask, _ = decoder(
+            x, tgt_mask, memory, memory_mask, _ = decoder.forward_chunk(
                x, tgt_mask, memory, None, cache=None
            )

@ -405,7 +405,7 @@ class TransformerDecoder(BaseTransformerDecoder):

class ParaformerDecoderSAN(BaseTransformerDecoder):
    """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713
    """

@ -36,7 +36,11 @@ import pdb
import random
import math
class MFCCA(AbsESPnetModel):
-    """CTC-attention hybrid Encoder-Decoder model"""
+    """
+    Author: Audio, Speech and Language Processing Group (ASLP@NPU), Northwestern Polytechnical University
+    MFCCA: Multi-Frame Cross-Channel attention for multi-speaker ASR in multi-party meeting scenario
+    https://arxiv.org/abs/2210.05265
+    """

    def __init__(
        self,

@ -44,7 +44,7 @@ else:

class Paraformer(AbsESPnetModel):
    """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """
@ -612,7 +612,7 @@ class Paraformer(AbsESPnetModel):

class ParaformerBert(Paraformer):
    """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer2: advanced paraformer with LFMMI and bert for non-autoregressive end-to-end speech recognition
    """


@ -36,8 +36,12 @@ else:


class DiarSondModel(AbsESPnetModel):
    """Speaker overlap-aware neural diarization model
    reference: https://arxiv.org/abs/2211.10243
    """
    Author: Speech Lab, Alibaba Group, China
    SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
    https://arxiv.org/abs/2211.10243
    TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
    https://arxiv.org/abs/2303.05397
    """

    def __init__(

@ -1,3 +1,7 @@
+"""
+Author: Speech Lab, Alibaba Group, China
+"""

import logging
from contextlib import contextmanager
from distutils.version import LooseVersion

@ -32,7 +32,7 @@ else:

class TimestampPredictor(AbsESPnetModel):
    """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
    """

    def __init__(

@ -40,7 +40,7 @@ else:

class UniASR(AbsESPnetModel):
    """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
    """

    def __init__(

@ -35,6 +35,11 @@ class VadDetectMode(Enum):


class VADXOptions:
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
    def __init__(
        self,
        sample_rate: int = 16000,
@ -99,6 +104,11 @@ class VADXOptions:


class E2EVadSpeechBufWithDoa(object):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
    def __init__(self):
        self.start_ms = 0
        self.end_ms = 0
@ -117,6 +127,11 @@ class E2EVadSpeechBufWithDoa(object):


class E2EVadFrameProb(object):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
    def __init__(self):
        self.noise_prob = 0.0
        self.speech_prob = 0.0
@ -126,6 +141,11 @@ class E2EVadFrameProb(object):


class WindowDetector(object):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
    def __init__(self, window_size_ms: int, sil_to_speech_time: int,
                 speech_to_sil_time: int, frame_size_ms: int):
        self.window_size_ms = window_size_ms
@ -192,7 +212,12 @@ class WindowDetector(object):


class E2EVadModel(nn.Module):
-    def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any]):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+    def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], frontend=None):
        super(E2EVadModel, self).__init__()
        self.vad_opts = VADXOptions(**vad_post_args)
        self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
@ -229,6 +254,7 @@ class E2EVadModel(nn.Module):
        self.data_buf_all = None
        self.waveform = None
        self.ResetDetection()
+        self.frontend = frontend

    def AllResetDetection(self):
        self.is_final = False
@ -459,8 +485,8 @@ class E2EVadModel(nn.Module):
            segment_batch = []
            if len(self.output_data_buf) > 0:
                for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
-                    if not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
-                            i].contain_seg_end_point:
+                    if not is_final and (not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
+                            i].contain_seg_end_point):
                        continue
                    segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
                    segment_batch.append(segment)
@ -477,8 +503,9 @@ class E2EVadModel(nn.Module):
    ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
        self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
        self.waveform = waveform  # compute decibel for each frame
        self.ComputeDecibel()

        self.ComputeScores(feats, in_cache)
        self.ComputeDecibel()
        if not is_final:
            self.DetectCommonFrames()
        else:

@ -67,7 +67,7 @@ class EncoderLayer(nn.Module):

class ConvEncoder(AbsEncoder):
    """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
    Convolution encoder in OpenNMT framework
    """


@ -117,7 +117,7 @@ class EncoderLayer(nn.Module):

class SelfAttentionEncoder(AbsEncoder):
    """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
    Self attention encoder in OpenNMT framework
    """


@ -406,6 +406,12 @@ class ResNet34Diar(ResNet34):
        tf2torch_tensor_name_prefix_torch="encoder",
        tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder"
    ):
+        """
+        Author: Speech Lab, Alibaba Group, China
+        SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
+        https://arxiv.org/abs/2211.10243
+        """
        super(ResNet34Diar, self).__init__(
            input_size,
            use_head_conv=use_head_conv,
@ -633,6 +639,12 @@ class ResNet34SpL2RegDiar(ResNet34_SP_L2Reg):
        tf2torch_tensor_name_prefix_torch="encoder",
        tf2torch_tensor_name_prefix_tf="seq2seq/speech_encoder"
    ):
+        """
+        Author: Speech Lab, Alibaba Group, China
+        TOLD: A Novel Two-Stage Overlap-Aware Framework for Speaker Diarization
+        https://arxiv.org/abs/2303.05397
+        """
        super(ResNet34SpL2RegDiar, self).__init__(
            input_size,
            use_head_conv=use_head_conv,

@ -10,7 +10,7 @@ from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
-from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
+from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
@ -27,7 +27,7 @@ from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder

from funasr.modules.mask import subsequent_mask, vad_mask

class EncoderLayerSANM(nn.Module):
    def __init__(
@ -117,7 +117,7 @@ class EncoderLayerSANM(nn.Module):

class SANMEncoder(AbsEncoder):
    """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
    San-m: Memory equipped self-attention for end-to-end speech recognition
    https://arxiv.org/abs/2006.01713

@ -549,7 +549,7 @@ class SANMEncoder(AbsEncoder):

class SANMEncoderChunkOpt(AbsEncoder):
    """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
    https://arxiv.org/abs/2006.01713

@ -958,3 +958,231 @@ class SANMEncoderChunkOpt(AbsEncoder):
                                var_dict_tf[name_tf].shape))

        return var_dict_torch_update


class SANMVadEncoder(AbsEncoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group

    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
||||
interctc_layer_idx: List[int] = [],
|
||||
interctc_use_conditioning: bool = False,
|
||||
kernel_size : int = 11,
|
||||
sanm_shfit : int = 0,
|
||||
selfattention_layer_type: str = "sanm",
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
self._output_size = output_size
|
||||
|
||||
if input_layer == "linear":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Linear(input_size, output_size),
|
||||
torch.nn.LayerNorm(output_size),
|
||||
torch.nn.Dropout(dropout_rate),
|
||||
torch.nn.ReLU(),
|
||||
pos_enc_class(output_size, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == "conv2d":
|
||||
self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
|
||||
elif input_layer == "conv2d2":
|
||||
self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
|
||||
elif input_layer == "conv2d6":
|
||||
self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
|
||||
elif input_layer == "conv2d8":
|
||||
self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
|
||||
elif input_layer == "embed":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
|
||||
SinusoidalPositionEncoder(),
|
||||
)
|
||||
elif input_layer is None:
|
||||
if input_size == output_size:
|
||||
self.embed = None
|
||||
else:
|
||||
self.embed = torch.nn.Linear(input_size, output_size)
|
||||
elif input_layer == "pe":
|
||||
self.embed = SinusoidalPositionEncoder()
|
||||
else:
|
||||
raise ValueError("unknown input_layer: " + input_layer)
|
||||
self.normalize_before = normalize_before
|
||||
if positionwise_layer_type == "linear":
|
||||
positionwise_layer = PositionwiseFeedForward
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
dropout_rate,
|
||||
)
|
||||
elif positionwise_layer_type == "conv1d":
|
||||
positionwise_layer = MultiLayeredConv1d
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
elif positionwise_layer_type == "conv1d-linear":
|
||||
positionwise_layer = Conv1dLinear
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Support only linear or conv1d.")
|
||||
|
||||
if selfattention_layer_type == "selfattn":
|
||||
encoder_selfattn_layer = MultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
)
|
||||
|
||||
elif selfattention_layer_type == "sanm":
|
||||
self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
|
||||
encoder_selfattn_layer_args0 = (
|
||||
attention_heads,
|
||||
input_size,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
kernel_size,
|
||||
sanm_shfit,
|
||||
)
|
||||
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
kernel_size,
|
||||
sanm_shfit,
|
||||
)
|
||||
|
||||
self.encoders0 = repeat(
|
||||
1,
|
||||
lambda lnum: EncoderLayerSANM(
|
||||
input_size,
|
||||
output_size,
|
||||
self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
|
||||
self.encoders = repeat(
|
||||
num_blocks-1,
|
||||
lambda lnum: EncoderLayerSANM(
|
||||
output_size,
|
||||
output_size,
|
||||
self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
if self.normalize_before:
|
||||
self.after_norm = LayerNorm(output_size)
|
||||
|
||||
self.interctc_layer_idx = interctc_layer_idx
|
||||
if len(interctc_layer_idx) > 0:
|
||||
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
|
||||
self.interctc_use_conditioning = interctc_use_conditioning
|
||||
self.conditioning_layer = None
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
def output_size(self) -> int:
|
||||
return self._output_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
xs_pad: torch.Tensor,
|
||||
ilens: torch.Tensor,
|
||||
vad_indexes: torch.Tensor,
|
||||
prev_states: torch.Tensor = None,
|
||||
ctc: CTC = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Embed positions in tensor.
|
||||
|
||||
Args:
|
||||
xs_pad: input tensor (B, L, D)
|
||||
ilens: input length (B)
|
||||
prev_states: Not to be used now.
|
||||
Returns:
|
||||
position embedded tensor and mask
|
||||
"""
|
||||
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
|
||||
sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
|
||||
no_future_masks = masks & sub_masks
|
||||
xs_pad *= self.output_size()**0.5
|
||||
if self.embed is None:
|
||||
xs_pad = xs_pad
|
||||
elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
|
||||
or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
|
||||
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
|
||||
if short_status:
|
||||
raise TooShortUttError(
|
||||
f"has {xs_pad.size(1)} frames and is too short for subsampling " +
|
||||
f"(it needs more than {limit_size} frames), return empty results",
|
||||
xs_pad.size(1),
|
||||
limit_size,
|
||||
)
|
||||
xs_pad, masks = self.embed(xs_pad, masks)
|
||||
else:
|
||||
xs_pad = self.embed(xs_pad)
|
||||
|
||||
# xs_pad = self.dropout(xs_pad)
|
||||
mask_tup0 = [masks, no_future_masks]
|
||||
encoder_outs = self.encoders0(xs_pad, mask_tup0)
|
||||
xs_pad, _ = encoder_outs[0], encoder_outs[1]
|
||||
intermediate_outs = []
|
||||
|
||||
|
||||
for layer_idx, encoder_layer in enumerate(self.encoders):
|
||||
if layer_idx + 1 == len(self.encoders):
|
||||
# This is last layer.
|
||||
coner_mask = torch.ones(masks.size(0),
|
||||
masks.size(-1),
|
||||
masks.size(-1),
|
||||
device=xs_pad.device,
|
||||
dtype=torch.bool)
|
||||
for word_index, length in enumerate(ilens):
|
||||
coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
|
||||
vad_indexes[word_index],
|
||||
device=xs_pad.device)
|
||||
layer_mask = masks & coner_mask
|
||||
else:
|
||||
layer_mask = no_future_masks
|
||||
mask_tup1 = [masks, layer_mask]
|
||||
encoder_outs = encoder_layer(xs_pad, mask_tup1)
|
||||
xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
|
||||
|
||||
if self.normalize_before:
|
||||
xs_pad = self.after_norm(xs_pad)
|
||||
|
||||
olens = masks.squeeze(1).sum(1)
|
||||
if len(intermediate_outs) > 0:
|
||||
return (xs_pad, intermediate_outs), olens, None
|
||||
return xs_pad, olens, None
|
||||
|
||||
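The SANMVadEncoder hunk above combines three boolean masks: a padding mask, a causal "no future" mask for all but the last layer, and a VAD-derived mask for the last layer. As a rough standalone illustration of the first combination (`no_future_masks = masks & sub_masks`) in plain Python — helper names here are hypothetical, not FunASR's API:

```python
def make_pad_mask_row(length, max_len):
    # True at valid (non-padded) positions, like ~make_pad_mask(ilens)
    return [t < length for t in range(max_len)]

def subsequent_mask(size):
    # sub[i][j] is True when position i may attend to position j (j <= i)
    return [[j <= i for j in range(size)] for i in range(size)]

def no_future_mask(length, max_len):
    # elementwise AND of padding mask and causal mask,
    # mirroring `no_future_masks = masks & sub_masks` above
    pad = make_pad_mask_row(length, max_len)
    sub = subsequent_mask(max_len)
    return [[pad[j] and sub[i][j] for j in range(max_len)] for i in range(max_len)]

mask = no_future_mask(3, 4)
# row 2 can attend to positions 0..2; position 3 is padding for every row
```

The VAD variant replaces the causal mask in the last layer with one built from `vad_indexes`, so attention is confined to the current VAD segment instead of strictly to the past.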
@@ -38,7 +38,7 @@ def load_cmvn(cmvn_file):
    return cmvn


def apply_cmvn(inputs, cmvn_file):  # noqa
def apply_cmvn(inputs, cmvn):  # noqa
    """
    Apply CMVN with mvn data
    """

@@ -47,7 +47,6 @@ def apply_cmvn(inputs, cmvn_file):  # noqa
    dtype = inputs.dtype
    frame, dim = inputs.shape

    cmvn = load_cmvn(cmvn_file)
    means = np.tile(cmvn[0:1, :dim], (frame, 1))
    vars = np.tile(cmvn[1:2, :dim], (frame, 1))
    inputs += torch.from_numpy(means).type(dtype).to(device)

@@ -111,6 +110,7 @@ class WavFrontend(AbsFrontend):
        self.dither = dither
        self.snip_edges = snip_edges
        self.upsacle_samples = upsacle_samples
        self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)

    def output_size(self) -> int:
        return self.n_mels * self.lfr_m

@@ -140,8 +140,8 @@ class WavFrontend(AbsFrontend):

            if self.lfr_m != 1 or self.lfr_n != 1:
                mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
            if self.cmvn_file is not None:
                mat = apply_cmvn(mat, self.cmvn_file)
            if self.cmvn is not None:
                mat = apply_cmvn(mat, self.cmvn)
            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)

@@ -194,8 +194,8 @@ class WavFrontend(AbsFrontend):
            mat = input[i, :input_lengths[i], :]
            if self.lfr_m != 1 or self.lfr_n != 1:
                mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
            if self.cmvn_file is not None:
                mat = apply_cmvn(mat, self.cmvn_file)
            if self.cmvn is not None:
                mat = apply_cmvn(mat, self.cmvn)
            feat_length = mat.size(0)
            feats.append(mat)
            feats_lens.append(feat_length)
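The WavFrontend hunks above move CMVN loading out of the per-call path: the stats are parsed once in `__init__` (`self.cmvn`) and passed into `apply_cmvn` on every utterance, instead of re-reading `cmvn_file` for each call. A minimal sketch of that caching pattern — names and the stats layout here are illustrative assumptions, not FunASR's actual `load_cmvn` format:

```python
def parse_cmvn_stats(neg_means, inv_stds):
    # stands in for load_cmvn(cmvn_file): one row of negated means,
    # one row of inverse standard deviations
    return [list(neg_means), list(inv_stds)]

def apply_cmvn(frames, cmvn):
    # add the negated mean, then scale by the inverse std per dimension,
    # mirroring `inputs += means` then `inputs *= vars` in the diff
    neg_means, inv_stds = cmvn
    return [[(x + m) * v for x, m, v in zip(row, neg_means, inv_stds)]
            for row in frames]

cmvn = parse_cmvn_stats([-1.0, -2.0], [0.5, 0.25])  # cached once, like self.cmvn
normalized = apply_cmvn([[3.0, 6.0]], cmvn)          # reused for every utterance
```

Caching the parsed stats avoids repeated file I/O and parsing on the hot path, which matters for streaming inference.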
@@ -234,6 +234,7 @@ class CifPredictorV2(nn.Module):
            last_fire_place = len_time - 1
            last_fire_remainds = 0.0
            pre_alphas_length = 0
            last_fire = False

        mask_chunk_peak_predictor = None
        if cache is not None:

@@ -251,10 +252,15 @@ class CifPredictorV2(nn.Module):
                if cif_peak[0][len_time - 1 - i] > self.threshold or cif_peak[0][len_time - 1 - i] == self.threshold:
                    last_fire_place = len_time - 1 - i
                    last_fire_remainds = cif_peak[0][len_time - 1 - i] - self.threshold
                    last_fire = True
                    break
            last_fire_remainds = torch.tensor([last_fire_remainds], dtype=alphas.dtype).to(alphas.device)
            cache["cif_hidden"] = hidden[:, last_fire_place:, :]
            cache["cif_alphas"] = torch.cat((last_fire_remainds.unsqueeze(0), alphas[:, last_fire_place+1:]), -1)
            if last_fire:
                last_fire_remainds = torch.tensor([last_fire_remainds], dtype=alphas.dtype).to(alphas.device)
                cache["cif_hidden"] = hidden[:, last_fire_place:, :]
                cache["cif_alphas"] = torch.cat((last_fire_remainds.unsqueeze(0), alphas[:, last_fire_place+1:]), -1)
            else:
                cache["cif_hidden"] = hidden
                cache["cif_alphas"] = alphas
        token_num_int = token_num.floor().type(torch.int32).item()
        return acoustic_embeds[:, 0:token_num_int, :], token_num, alphas, cif_peak


@@ -5,16 +5,19 @@ from typing import Tuple
import torch
import torch.nn as nn

from funasr.modules.embedding import PositionalEncoding
from funasr.modules.embedding import SinusoidalPositionEncoder
#from funasr.models.encoder.transformer_encoder import TransformerEncoder as Encoder
from funasr.punctuation.sanm_encoder import SANMEncoder as Encoder
from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
#from funasr.modules.mask import subsequent_n_mask
from funasr.punctuation.abs_model import AbsPunctuation
from funasr.train.abs_model import AbsPunctuation


class TargetDelayTransformer(AbsPunctuation):

    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
    https://arxiv.org/pdf/2003.01309.pdf
    """
    def __init__(
        self,
        vocab_size: int,

@@ -6,12 +6,16 @@ import torch
import torch.nn as nn

from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder
from funasr.punctuation.abs_model import AbsPunctuation
from funasr.models.encoder.sanm_encoder import SANMVadEncoder as Encoder
from funasr.train.abs_model import AbsPunctuation


class VadRealtimeTransformer(AbsPunctuation):

    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
    https://arxiv.org/pdf/2003.01309.pdf
    """
    def __init__(
        self,
        vocab_size: int,

@@ -11,7 +11,7 @@ from funasr.modules.streaming_utils.utils import sequence_mask

class overlap_chunk():
    """
    author: Speech Lab, Alibaba Group, China
    Author: Speech Lab of DAMO Academy, Alibaba Group
    San-m: Memory equipped self-attention for end-to-end speech recognition
    https://arxiv.org/abs/2006.01713


@@ -1,31 +0,0 @@
from abc import ABC
from abc import abstractmethod
from typing import Tuple

import torch

from funasr.modules.scorers.scorer_interface import BatchScorerInterface


class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
    """The abstract class

    To share the loss calculation way among different models,
    We uses delegate pattern here:
    The instance of this class should be passed to "LanguageModel"

    >>> from funasr.punctuation.abs_model import AbsPunctuation
    >>> punc = AbsPunctuation()
    >>> model = ESPnetPunctuationModel(punc=punc)

    This "model" is one of mediator objects for "Task" class.

    """

    @abstractmethod
    def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError

    @abstractmethod
    def with_vad(self) -> bool:
        raise NotImplementedError

@@ -1,590 +0,0 @@
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import logging
import torch
import torch.nn as nn
from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
from typeguard import check_argument_types
import numpy as np
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from funasr.modules.repeat import repeat
from funasr.modules.subsampling import Conv2dSubsampling
from funasr.modules.subsampling import Conv2dSubsampling2
from funasr.modules.subsampling import Conv2dSubsampling6
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder

from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.mask import subsequent_mask, vad_mask


class EncoderLayerSANM(nn.Module):
    def __init__(
        self,
        in_size,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayerSANM, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
        self.dropout_rate = dropout_rate

    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute encoded features.

        Args:
            x_input (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).

        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if self.concat_after:
            x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
            else:
                x = stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.dropout(
                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
                )
            else:
                x = stoch_layer_coeff * self.dropout(
                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
                )
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)


        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder

class SANMEncoder(AbsEncoder):
    """
    author: Speech Lab, Alibaba Group, China

    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size : int = 11,
        sanm_shfit : int = 0,
        selfattention_layer_type: str = "sanm",
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                SinusoidalPositionEncoder(),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            self.embed = SinusoidalPositionEncoder()
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )

        elif selfattention_layer_type == "sanm":
            self.encoder_selfattn_layer = MultiHeadedAttentionSANM
            encoder_selfattn_layer_args0 = (
                attention_heads,
                input_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )

            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )

        self.encoders0 = repeat(
            1,
            lambda lnum: EncoderLayerSANM(
                input_size,
                output_size,
                self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )

        self.encoders = repeat(
            num_blocks-1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
                self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None
        self.dropout = nn.Dropout(dropout_rate)

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.
        Returns:
            position embedded tensor and mask
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        xs_pad *= self.output_size()**0.5
        if self.embed is None:
            xs_pad = xs_pad
        elif (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        # xs_pad = self.dropout(xs_pad)
        encoder_outs = self.encoders0(xs_pad, masks)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            encoder_outs = self.encoders(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, masks)
                xs_pad, masks = encoder_outs[0], encoder_outs[1]

                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad

                    # intermediate outputs are also normalized
                    if self.normalize_before:
                        encoder_out = self.after_norm(encoder_out)

                    intermediate_outs.append((layer_idx + 1, encoder_out))

                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)
                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None

class SANMVadEncoder(AbsEncoder):
    """
    author: Speech Lab, Alibaba Group, China

    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size : int = 11,
        sanm_shfit : int = 0,
        selfattention_layer_type: str = "sanm",
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                SinusoidalPositionEncoder(),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        elif input_layer == "pe":
            self.embed = SinusoidalPositionEncoder()
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )

        elif selfattention_layer_type == "sanm":
            self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
            encoder_selfattn_layer_args0 = (
                attention_heads,
                input_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )

            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                output_size,
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
            )

        self.encoders0 = repeat(
            1,
            lambda lnum: EncoderLayerSANM(
                input_size,
                output_size,
                self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )

        self.encoders = repeat(
            num_blocks-1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
                self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None
        self.dropout = nn.Dropout(dropout_rate)

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        vad_indexes: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.
        Returns:
            position embedded tensor and mask
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
        no_future_masks = masks & sub_masks
        xs_pad *= self.output_size()**0.5
        if self.embed is None:
            xs_pad = xs_pad
        elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
              or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling " +
                    f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
|
||||
limit_size,
|
||||
)
|
||||
xs_pad, masks = self.embed(xs_pad, masks)
|
||||
else:
|
||||
xs_pad = self.embed(xs_pad)
|
||||
|
||||
# xs_pad = self.dropout(xs_pad)
|
||||
mask_tup0 = [masks, no_future_masks]
|
||||
encoder_outs = self.encoders0(xs_pad, mask_tup0)
|
||||
xs_pad, _ = encoder_outs[0], encoder_outs[1]
|
||||
intermediate_outs = []
|
||||
#if len(self.interctc_layer_idx) == 0:
|
||||
if False:
|
||||
# Here, we should not use the repeat operation to do it for all layers.
|
||||
encoder_outs = self.encoders(xs_pad, masks)
|
||||
xs_pad, masks = encoder_outs[0], encoder_outs[1]
|
||||
else:
|
||||
for layer_idx, encoder_layer in enumerate(self.encoders):
|
||||
if layer_idx + 1 == len(self.encoders):
|
||||
# This is last layer.
|
||||
coner_mask = torch.ones(masks.size(0),
|
||||
masks.size(-1),
|
||||
masks.size(-1),
|
||||
device=xs_pad.device,
|
||||
dtype=torch.bool)
|
||||
for word_index, length in enumerate(ilens):
|
||||
coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
|
||||
vad_indexes[word_index],
|
||||
device=xs_pad.device)
|
||||
layer_mask = masks & coner_mask
|
||||
else:
|
||||
layer_mask = no_future_masks
|
||||
mask_tup1 = [masks, layer_mask]
|
||||
encoder_outs = encoder_layer(xs_pad, mask_tup1)
|
||||
xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
|
||||
|
||||
if layer_idx + 1 in self.interctc_layer_idx:
|
||||
encoder_out = xs_pad
|
||||
|
||||
# intermediate outputs are also normalized
|
||||
if self.normalize_before:
|
||||
encoder_out = self.after_norm(encoder_out)
|
||||
|
||||
intermediate_outs.append((layer_idx + 1, encoder_out))
|
||||
|
||||
if self.interctc_use_conditioning:
|
||||
ctc_out = ctc.softmax(encoder_out)
|
||||
xs_pad = xs_pad + self.conditioning_layer(ctc_out)
|
||||
|
||||
if self.normalize_before:
|
||||
xs_pad = self.after_norm(xs_pad)
|
||||
|
||||
olens = masks.squeeze(1).sum(1)
|
||||
if len(intermediate_outs) > 0:
|
||||
return (xs_pad, intermediate_outs), olens, None
|
||||
return xs_pad, olens, None
|
||||
|
||||
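The mask construction at the top of `forward` intersects a padding mask with a causal ("no future") mask, as in `masks & sub_masks`. A minimal pure-Python sketch of that intersection (an illustration only, using nested lists of booleans instead of `torch` tensors, with hypothetical toy lengths):

```python
# Sketch (assumption: illustrative only, not FunASR code): combine a padding
# mask with a lower-triangular "no future" mask, mirroring masks & sub_masks.

def make_pad_mask(lengths, max_len):
    # True marks padded positions, mirroring espnet-style make_pad_mask.
    return [[t >= l for t in range(max_len)] for l in lengths]

def subsequent_mask(size):
    # True where attention is allowed: position i may see positions <= i.
    return [[j <= i for j in range(size)] for i in range(size)]

def no_future_mask(lengths):
    max_len = max(lengths)
    # Invert the pad mask so True marks valid (non-padded) positions.
    valid = [[not p for p in row] for row in make_pad_mask(lengths, max_len)]
    sub = subsequent_mask(max_len)
    # For each batch element b and query position i, position j is visible
    # iff it is valid AND not in the future.
    return [
        [[valid[b][j] and sub[i][j] for j in range(max_len)] for i in range(max_len)]
        for b in range(len(lengths))
    ]

mask = no_future_mask([2, 3])
print(mask[0][0])  # [True, False, False]
print(mask[0][1])  # [True, True, False]
```

The last layer then replaces the no-future mask with a VAD-segment mask (`vad_mask`), but the intersection with the padding mask works the same way.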
@ -1,12 +0,0 @@
def split_to_mini_sentence(words: list, word_limit: int = 20):
    assert word_limit > 1
    if len(words) <= word_limit:
        return [words]
    sentences = []
    length = len(words)
    sentence_len = length // word_limit
    for i in range(sentence_len):
        sentences.append(words[i * word_limit:(i + 1) * word_limit])
    if length % word_limit > 0:
        sentences.append(words[sentence_len * word_limit:])
    return sentences
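The removed `split_to_mini_sentence` helper chunks a token list into pieces of at most `word_limit` items, keeping any remainder as a final shorter chunk. A standalone copy with a quick usage check:

```python
def split_to_mini_sentence(words: list, word_limit: int = 20):
    # Return the list wrapped unchanged if it already fits the limit.
    assert word_limit > 1
    if len(words) <= word_limit:
        return [words]
    sentences = []
    sentence_len = len(words) // word_limit
    # Full chunks of exactly word_limit items.
    for i in range(sentence_len):
        sentences.append(words[i * word_limit:(i + 1) * word_limit])
    # Trailing remainder, if any.
    if len(words) % word_limit > 0:
        sentences.append(words[sentence_len * word_limit:])
    return sentences

print(split_to_mini_sentence(list("abcdefg"), word_limit=3))
# → [['a', 'b', 'c'], ['d', 'e', 'f'], ['g']]
```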
@ -74,7 +74,7 @@ foreach(_target
    "${_target}.cc")
  target_link_libraries(${_target}
    rg_grpc_proto
    rapidasr
    funasr
    ${EXTRA_LIBS}
    ${_REFLECTION}
    ${_GRPC_GRPCPP}
@ -53,6 +53,68 @@ cd ../python/grpc
python grpc_main_client_mic.py --host $server_ip --port 10108
```

The `grpc_main_client_mic.py` follows the [original design](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/grpc#workflow-in-desgin) by sending audio_data in chunks. If you want to send the audio data in a single request, here is an example:

```python
import asyncio
import json
import queue
import time

import grpc
import soundfile as sf

# go to ../python/grpc to find this package
import paraformer_pb2


class RecognizeStub:
    def __init__(self, channel):
        self.Recognize = channel.stream_stream(
            '/paraformer.ASR/Recognize',
            request_serializer=paraformer_pb2.Request.SerializeToString,
            response_deserializer=paraformer_pb2.Response.FromString,
        )


async def send(channel, data, speaking, isEnd):
    stub = RecognizeStub(channel)
    req = paraformer_pb2.Request()
    if data:
        req.audio_data = data
    req.user = 'zz'
    req.language = 'zh-CN'
    req.speaking = speaking
    req.isEnd = isEnd
    q = queue.SimpleQueue()
    q.put(req)
    return stub.Recognize(iter(q.get, None))


# send the audio data once
async def grpc_rec(data, grpc_uri):
    with grpc.insecure_channel(grpc_uri) as channel:
        b = time.time()
        response = await send(channel, data, False, False)
        resp = next(response)
        text = ''
        if 'decoding' == resp.action:
            resp = next(response)
            if 'finish' == resp.action:
                text = json.loads(resp.sentence)['text']
        response = await send(channel, None, False, True)
        return {
            'text': text,
            'time': time.time() - b,
        }


async def test():
    # fc = FunAsrGrpcClient('127.0.0.1', 9900)
    # t = await fc.rec(wav.tobytes())
    # print(t)
    wav, _ = sf.read('z-10s.wav', dtype='int16')
    uri = '127.0.0.1:9900'
    res = await grpc_rec(wav.tobytes(), uri)
    print(res)


if __name__ == '__main__':
    asyncio.run(test())
```

## Acknowledge
1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
2. We acknowledge [DeepScience](https://www.deepscience.cn) for contributing the grpc service.
@ -15,7 +15,6 @@
#include "paraformer.grpc.pb.h"
#include "paraformer_server.h"

using grpc::Server;
using grpc::ServerBuilder;
using grpc::ServerContext;
@ -24,48 +23,35 @@ using grpc::ServerReaderWriter;
using grpc::ServerWriter;
using grpc::Status;

using paraformer::Request;
using paraformer::Response;
using paraformer::ASR;

ASRServicer::ASRServicer(const char* model_path, int thread_num, bool quantize) {
    AsrHanlde=RapidAsrInit(model_path, thread_num, quantize);
    AsrHanlde=FunASRInit(model_path, thread_num, quantize);
    std::cout << "ASRServicer init" << std::endl;
    init_flag = 0;
}

void ASRServicer::clear_states(const std::string& user) {
    clear_buffers(user);
    clear_transcriptions(user);
}

void ASRServicer::clear_buffers(const std::string& user) {
    if (client_buffers.count(user)) {
        client_buffers.erase(user);
    }
}

void ASRServicer::clear_transcriptions(const std::string& user) {
    if (client_transcription.count(user)) {
        client_transcription.erase(user);
    }
}

void ASRServicer::disconnect(const std::string& user) {
    clear_states(user);
    std::cout << "Disconnecting user: " << user << std::endl;
}

grpc::Status ASRServicer::Recognize(
    grpc::ServerContext* context,
    grpc::ServerReaderWriter<Response, Request>* stream) {

    Request req;
    std::unordered_map<std::string, std::string> client_buffers;
    std::unordered_map<std::string, std::string> client_transcription;

    while (stream->Read(&req)) {
        if (req.isend()) {
            std::cout << "asr end" << std::endl;
            disconnect(req.user());
            // disconnect
            if (client_buffers.count(req.user())) {
                client_buffers.erase(req.user());
            }
            if (client_transcription.count(req.user())) {
                client_transcription.erase(req.user());
            }

            Response res;
            res.set_sentence(
                R"({"success": true, "detail": "asr end"})"
@ -88,7 +74,7 @@ grpc::Status ASRServicer::Recognize(
            res.set_language(req.language());
            stream->Write(res);
        } else if (!req.speaking()) {
            if (client_buffers.count(req.user()) == 0) {
            if (client_buffers.count(req.user()) == 0 && req.audio_data().size() == 0) {
                Response res;
                res.set_sentence(
                    R"({"success": true, "detail": "waiting_for_voice"})"
@ -99,14 +85,24 @@ grpc::Status ASRServicer::Recognize(
                stream->Write(res);
            } else {
                auto begin_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
                std::string tmp_data = this->client_buffers[req.user()];
                this->clear_states(req.user());

                if (req.audio_data().size() > 0) {
                    auto& buf = client_buffers[req.user()];
                    buf.insert(buf.end(), req.audio_data().begin(), req.audio_data().end());
                }
                std::string tmp_data = client_buffers[req.user()];
                // clear_states
                if (client_buffers.count(req.user())) {
                    client_buffers.erase(req.user());
                }
                if (client_transcription.count(req.user())) {
                    client_transcription.erase(req.user());
                }

                Response res;
                res.set_sentence(
                    R"({"success": true, "detail": "decoding data: " + std::to_string(tmp_data.length()) + " bytes"})"
                );
                int data_len_int = tmp_data.length();
                int data_len_int = tmp_data.length();
                std::string data_len = std::to_string(data_len_int);
                std::stringstream ss;
                ss << R"({"success": true, "detail": "decoding data: )" << data_len << R"( bytes")" << R"("})";
@ -129,18 +125,15 @@ grpc::Status ASRServicer::Recognize(
                res.set_user(req.user());
                res.set_action("finish");
                res.set_language(req.language());

                stream->Write(res);
            }
            else {
                RPASR_RESULT Result= RapidAsrRecogPCMBuffer(AsrHanlde, tmp_data.c_str(), data_len_int, RASR_NONE, NULL);
                std::string asr_result = ((RPASR_RECOG_RESULT*)Result)->msg;
                FUNASR_RESULT Result= FunASRRecogPCMBuffer(AsrHanlde, tmp_data.c_str(), data_len_int, 16000, RASR_NONE, NULL);
                std::string asr_result = ((FUNASR_RECOG_RESULT*)Result)->msg;

                auto end_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
                std::string delay_str = std::to_string(end_time - begin_time);

                std::cout << "user: " << req.user() << " , delay(ms): " << delay_str << ", text: " << asr_result << std::endl;
                Response res;
                std::stringstream ss;
@ -150,8 +143,7 @@ grpc::Status ASRServicer::Recognize(
                res.set_user(req.user());
                res.set_action("finish");
                res.set_language(req.language());

                stream->Write(res);
            }
        }
@ -165,11 +157,10 @@ grpc::Status ASRServicer::Recognize(
        res.set_language(req.language());
        stream->Write(res);
    }
    }
    return Status::OK;
}

void RunServer(const std::string& port, int thread_num, const char* model_path, bool quantize) {
    std::string server_address;
    server_address = "0.0.0.0:" + port;
@ -15,7 +15,7 @@
#include <chrono>

#include "paraformer.grpc.pb.h"
#include "librapidasrapi.h"
#include "libfunasrapi.h"

using grpc::Server;
@ -35,22 +35,16 @@ typedef struct
{
    std::string msg;
    float snippet_time;
}RPASR_RECOG_RESULT;
}FUNASR_RECOG_RESULT;

class ASRServicer final : public ASR::Service {
  private:
    int init_flag;
    std::unordered_map<std::string, std::string> client_buffers;
    std::unordered_map<std::string, std::string> client_transcription;

  public:
    ASRServicer(const char* model_path, int thread_num, bool quantize);
    void clear_states(const std::string& user);
    void clear_buffers(const std::string& user);
    void clear_transcriptions(const std::string& user);
    void disconnect(const std::string& user);
    grpc::Status Recognize(grpc::ServerContext* context, grpc::ServerReaderWriter<Response, Request>* stream);
    RPASR_HANDLE AsrHanlde;
    FUNASR_HANDLE AsrHanlde;
};
@ -2,24 +2,27 @@ cmake_minimum_required(VERSION 3.10)

project(FunASRonnx)

set(CMAKE_CXX_STANDARD 11)
# set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

include(TestBigEndian)
test_big_endian(BIG_ENDIAN)
if(BIG_ENDIAN)
    message("Big endian system")
else()
    message("Little endian system")
endif()

# for onnxruntime
IF(WIN32)
    if(CMAKE_CL_64)
        link_directories(${ONNXRUNTIME_DIR}\\lib)
    else()
        add_definitions(-D_WIN_X86)
    endif()
ELSE()
    link_directories(${ONNXRUNTIME_DIR}/lib)
    link_directories(${ONNXRUNTIME_DIR}/lib)
endif()

add_subdirectory("./third_party/yaml-cpp")
@ -6,6 +6,13 @@
#include <queue>
#include <stdint.h>

#ifndef model_sample_rate
#define model_sample_rate 16000
#endif
#ifndef WAV_HEADER_SIZE
#define WAV_HEADER_SIZE 44
#endif

using namespace std;

class AudioFrame {
@ -32,7 +39,6 @@ class Audio {
    int16_t *speech_buff;
    int speech_len;
    int speech_align_len;
    int16_t sample_rate;
    int offset;
    float align_size;
    int data_type;
@ -43,10 +49,11 @@ class Audio {
    Audio(int data_type, int size);
    ~Audio();
    void disp();
    bool loadwav(const char* filename);
    bool loadwav(const char* buf, int nLen);
    bool loadpcmwav(const char* buf, int nFileLen);
    bool loadpcmwav(const char* filename);
    bool loadwav(const char* filename, int32_t* sampling_rate);
    void wavResample(int32_t sampling_rate, const float *waveform, int32_t n);
    bool loadwav(const char* buf, int nLen, int32_t* sampling_rate);
    bool loadpcmwav(const char* buf, int nFileLen, int32_t* sampling_rate);
    bool loadpcmwav(const char* filename, int32_t* sampling_rate);
    int fetch_chunck(float *&dout, int len);
    int fetch(float *&dout, int &len, int &flag);
    void padding();
funasr/runtime/onnxruntime/include/libfunasrapi.h (new file)
@ -0,0 +1,77 @@
#pragma once

#ifdef WIN32
#ifdef _FUNASR_API_EXPORT
#define _FUNASRAPI __declspec(dllexport)
#else
#define _FUNASRAPI __declspec(dllimport)
#endif
#else
#define _FUNASRAPI
#endif

#ifndef _WIN32
#define FUNASR_CALLBCK_PREFIX __attribute__((__stdcall__))
#else
#define FUNASR_CALLBCK_PREFIX __stdcall
#endif

#ifdef __cplusplus
extern "C" {
#endif

typedef void* FUNASR_HANDLE;
typedef void* FUNASR_RESULT;
typedef unsigned char FUNASR_BOOL;

#define FUNASR_TRUE 1
#define FUNASR_FALSE 0
#define QM_DEFAULT_THREAD_NUM 4

typedef enum
{
    RASR_NONE=-1,
    RASRM_CTC_GREEDY_SEARCH=0,
    RASRM_CTC_RPEFIX_BEAM_SEARCH = 1,
    RASRM_ATTENSION_RESCORING = 2,
}FUNASR_MODE;

typedef enum {
    FUNASR_MODEL_PADDLE = 0,
    FUNASR_MODEL_PADDLE_2 = 1,
    FUNASR_MODEL_K2 = 2,
    FUNASR_MODEL_PARAFORMER = 3,
}FUNASR_MODEL_TYPE;

typedef void (* QM_CALLBACK)(int nCurStep, int nTotal); // nTotal: total steps; nCurStep: current step.

// APIs for qmasr
_FUNASRAPI FUNASR_HANDLE FunASRInit(const char* szModelDir, int nThread, bool quantize);

// if no fnCallback is given, pass NULL
_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback);

_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback);

_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback);

_FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback);

_FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT Result,int nIndex);

_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT Result);

_FUNASRAPI void FunASRFreeResult(FUNASR_RESULT Result);

_FUNASRAPI void FunASRUninit(FUNASR_HANDLE Handle);

_FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT Result);

#ifdef __cplusplus
}
#endif
@ -1,77 +0,0 @@
#pragma once

#ifdef WIN32
#ifdef _RPASR_API_EXPORT
#define _RAPIDASRAPI __declspec(dllexport)
#else
#define _RAPIDASRAPI __declspec(dllimport)
#endif
#else
#define _RAPIDASRAPI
#endif

#ifndef _WIN32
#define RPASR_CALLBCK_PREFIX __attribute__((__stdcall__))
#else
#define RPASR_CALLBCK_PREFIX __stdcall
#endif

#ifdef __cplusplus
extern "C" {
#endif

typedef void* RPASR_HANDLE;
typedef void* RPASR_RESULT;
typedef unsigned char RPASR_BOOL;

#define RPASR_TRUE 1
#define RPASR_FALSE 0
#define QM_DEFAULT_THREAD_NUM 4

typedef enum
{
    RASR_NONE=-1,
    RASRM_CTC_GREEDY_SEARCH=0,
    RASRM_CTC_RPEFIX_BEAM_SEARCH = 1,
    RASRM_ATTENSION_RESCORING = 2,
}RPASR_MODE;

typedef enum {
    RPASR_MODEL_PADDLE = 0,
    RPASR_MODEL_PADDLE_2 = 1,
    RPASR_MODEL_K2 = 2,
    RPASR_MODEL_PARAFORMER = 3,
}RPASR_MODEL_TYPE;

typedef void (* QM_CALLBACK)(int nCurStep, int nTotal); // nTotal: total steps; nCurStep: current step.

// APIs for qmasr
_RAPIDASRAPI RPASR_HANDLE RapidAsrInit(const char* szModelDir, int nThread, bool quantize);

// if not give a fnCallback ,it should be NULL
_RAPIDASRAPI RPASR_RESULT RapidAsrRecogBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback);

_RAPIDASRAPI RPASR_RESULT RapidAsrRecogPCMBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback);

_RAPIDASRAPI RPASR_RESULT RapidAsrRecogPCMFile(RPASR_HANDLE handle, const char* szFileName, RPASR_MODE Mode, QM_CALLBACK fnCallback);

_RAPIDASRAPI RPASR_RESULT RapidAsrRecogFile(RPASR_HANDLE handle, const char* szWavfile, RPASR_MODE Mode, QM_CALLBACK fnCallback);

_RAPIDASRAPI const char* RapidAsrGetResult(RPASR_RESULT Result,int nIndex);

_RAPIDASRAPI const int RapidAsrGetRetNumber(RPASR_RESULT Result);

_RAPIDASRAPI void RapidAsrFreeResult(RPASR_RESULT Result);

_RAPIDASRAPI void RapidAsrUninit(RPASR_HANDLE Handle);

_RAPIDASRAPI const float RapidAsrGetRetSnippetTime(RPASR_RESULT Result);

#ifdef __cplusplus
}
#endif
@ -1,83 +1,70 @@

## Quick start

### Windows

Install VS2022, open the cmake project under the cpp_onnx directory, and build it directly. All required dependency libraries are already bundled in this repository; fftw3 and onnxruntime are pre-built for Windows.

### Linux
See the bottom of this page: Building Guidance

### Run the program

```shell
tester /path/to/models_dir /path/to/wave_file quantize(true or false)
# e.g. tester /data/models /data/test.wav false
```

The structure of /path/to/models_dir
```
config.yaml, am.mvn, model.onnx (or model_quant.onnx)
```

## Supported platforms
- Windows
- Linux/Unix

## Dependencies
- fftw3
- openblas
- onnxruntime

## Steps

### Export onnx
#### Install [modelscope and funasr](https://github.com/alibaba-damo-academy/FunASR#installation)

```shell
pip3 install torch torchaudio
pip install -U modelscope
pip install -U funasr
```
#### Export [onnx model](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export)

```shell
python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize True
```

### Building for Linux/Unix

#### Download onnxruntime
```shell
# download an appropriate onnxruntime from https://github.com/microsoft/onnxruntime/releases/tag/v1.14.0
# here we get a copy of onnxruntime for linux 64
wget https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz
tar -zxvf onnxruntime-linux-x64-1.14.0.tgz
# ls
# onnxruntime-linux-x64-1.14.0  onnxruntime-linux-x64-1.14.0.tgz
```

#### Install fftw3
```shell
sudo apt install libfftw3-dev   # ubuntu
# sudo yum install fftw fftw-devel   # centos
```

#### Install openblas
```shell
sudo apt-get install libopenblas-dev   # ubuntu
# sudo yum -y install openblas-devel   # centos
```

#### Build runtime
```shell
git clone https://github.com/alibaba-damo-academy/FunASR.git && cd funasr/runtime/onnxruntime
mkdir build && cd build
cmake -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0
make
# the tester program is then built in the tester subfolder of the build directory
```

#### The structure of a qualified onnxruntime package
```
onnxruntime_xxx
├───include
└───lib
```

Note: this program only supports **mono** audio with a 16000 Hz sample rate and 16-bit depth.

### Building for Windows

Ref to win/

## Acknowledge
1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
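The note above about mono, 16 kHz, 16-bit input can be checked before feeding a file to `tester` with Python's standard `wave` module. A sketch; the file path and helper name are hypothetical:

```python
import os
import tempfile
import wave

def check_wav(path):
    """Return True if `path` is mono, 16-bit, 16000 Hz PCM."""
    with wave.open(path, 'rb') as w:
        return (w.getnchannels() == 1
                and w.getsampwidth() == 2   # 2 bytes per sample = 16 bits
                and w.getframerate() == 16000)

# Demonstrate on a tiny conforming file written to a temp path.
tmp = os.path.join(tempfile.gettempdir(), 'funasr_check.wav')
with wave.open(tmp, 'wb') as w:
    w.setnchannels(1)
    w.setsampwidth(2)
    w.setframerate(16000)
    w.writeframes(b'\x00\x00' * 160)  # 10 ms of silence
print(check_wav(tmp))  # True
```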
@ -3,11 +3,96 @@
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <fstream>
|
||||
#include <assert.h>
|
||||
|
||||
#include "Audio.h"
|
||||
#include "precomp.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
// see http://soundfile.sapp.org/doc/WaveFormat/
|
||||
// Note: We assume little endian here
|
||||
struct WaveHeader {
|
||||
bool Validate() const {
|
||||
// F F I R
|
||||
if (chunk_id != 0x46464952) {
|
||||
printf("Expected chunk_id RIFF. Given: 0x%08x\n", chunk_id);
|
||||
return false;
|
||||
}
|
||||
// E V A W
|
||||
if (format != 0x45564157) {
|
||||
printf("Expected format WAVE. Given: 0x%08x\n", format);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (subchunk1_id != 0x20746d66) {
|
||||
printf("Expected subchunk1_id 0x20746d66. Given: 0x%08x\n",
|
||||
subchunk1_id);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (subchunk1_size != 16) { // 16 for PCM
|
||||
printf("Expected subchunk1_size 16. Given: %d\n",
|
||||
subchunk1_size);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (audio_format != 1) { // 1 for PCM
|
||||
printf("Expected audio_format 1. Given: %d\n", audio_format);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (num_channels != 1) { // we support only single channel for now
|
||||
printf("Expected single channel. Given: %d\n", num_channels);
|
||||
return false;
|
||||
}
|
||||
if (byte_rate != (sample_rate * num_channels * bits_per_sample / 8)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (block_align != (num_channels * bits_per_sample / 8)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (bits_per_sample != 16) { // we support only 16 bits per sample
|
||||
printf("Expected bits_per_sample 16. Given: %d\n",
|
||||
bits_per_sample);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// See https://en.wikipedia.org/wiki/WAV#Metadata and
|
||||
// https://www.robotplanet.dk/audio/wav_meta_data/riff_mci.pdf
|
||||
void SeekToDataChunk(std::istream &is) {
|
||||
// a t a d
|
||||
while (is && subchunk2_id != 0x61746164) {
|
||||
// const char *p = reinterpret_cast<const char *>(&subchunk2_id);
|
||||
// printf("Skip chunk (%x): %c%c%c%c of size: %d\n", subchunk2_id, p[0],
|
||||
// p[1], p[2], p[3], subchunk2_size);
|
||||
is.seekg(subchunk2_size, std::istream::cur);
|
||||
is.read(reinterpret_cast<char *>(&subchunk2_id), sizeof(int32_t));
|
||||
is.read(reinterpret_cast<char *>(&subchunk2_size), sizeof(int32_t));
|
||||
}
|
||||
}
|
||||
|
||||
int32_t chunk_id;
|
||||
int32_t chunk_size;
|
||||
int32_t format;
|
||||
int32_t subchunk1_id;
|
||||
int32_t subchunk1_size;
|
||||
int16_t audio_format;
|
||||
int16_t num_channels;
|
||||
int32_t sample_rate;
|
||||
int32_t byte_rate;
|
||||
int16_t block_align;
|
||||
int16_t bits_per_sample;
|
||||
int32_t subchunk2_id; // a tag of this chunk
|
||||
int32_t subchunk2_size; // size of subchunk2
|
||||
};
|
||||
static_assert(sizeof(WaveHeader) == WAV_HEADER_SIZE, "");
|
||||
|
||||
class AudioWindow {
|
||||
private:
|
||||
int *window;
|
||||
@ -56,7 +141,7 @@ int AudioFrame::set_end(int val, int max_len)
|
||||
float frame_length = 400;
|
||||
float frame_shift = 160;
|
||||
float num_new_samples =
|
||||
ceil((num_samples - 400) / frame_shift) * frame_shift + frame_length;
|
||||
ceil((num_samples - frame_length) / frame_shift) * frame_shift + frame_length;
|
||||
|
||||
end = start + num_new_samples;
|
||||
len = (int)num_new_samples;
|
||||
@ -111,120 +196,150 @@ Audio::~Audio()
|
||||
|
||||
void Audio::disp()
|
||||
{
|
||||
printf("Audio time is %f s. len is %d\n", (float)speech_len / 16000,
|
||||
printf("Audio time is %f s. len is %d\n", (float)speech_len / model_sample_rate,
|
||||
speech_len);
|
||||
}
|
||||
|
||||
float Audio::get_time_len()
|
||||
{
|
||||
return (float)speech_len / 16000;
|
||||
//speech_len);
|
||||
return (float)speech_len / model_sample_rate;
|
||||
}
|
||||
|
||||
bool Audio::loadwav(const char *filename)
|
||||
void Audio::wavResample(int32_t sampling_rate, const float *waveform,
|
||||
int32_t n)
|
||||
{
|
||||
printf(
|
||||
"Creating a resampler:\n"
|
||||
" in_sample_rate: %d\n"
|
||||
" output_sample_rate: %d\n",
|
||||
sampling_rate, static_cast<int32_t>(model_sample_rate));
|
||||
float min_freq =
|
||||
std::min<int32_t>(sampling_rate, model_sample_rate);
|
||||
float lowpass_cutoff = 0.99 * 0.5 * min_freq;
|
||||
|
||||
int32_t lowpass_filter_width = 6;
|
||||
//FIXME
|
||||
//auto resampler = new LinearResample(
|
||||
// sampling_rate, model_sample_rate, lowpass_cutoff, lowpass_filter_width);
|
||||
auto resampler = std::make_unique<LinearResample>(
|
||||
sampling_rate, model_sample_rate, lowpass_cutoff, lowpass_filter_width);
|
||||
std::vector<float> samples;
|
||||
resampler->Resample(waveform, n, true, &samples);
|
||||
//reset speech_data
|
||||
speech_len = samples.size();
|
||||
if (speech_data != NULL) {
|
||||
free(speech_data);
|
||||
}
|
||||
speech_data = (float*)malloc(sizeof(float) * speech_len);
|
||||
memset(speech_data, 0, sizeof(float) * speech_len);
|
||||
copy(samples.begin(), samples.end(), speech_data);
|
||||
}

bool Audio::loadwav(const char *filename, int32_t* sampling_rate)
{
WaveHeader header;
if (speech_data != NULL) {
free(speech_data);
}
if (speech_buff != NULL) {
free(speech_buff);
}

offset = 0;

FILE *fp;
fp = fopen(filename, "rb");
if (fp == nullptr)
std::ifstream is(filename, std::ifstream::binary);
is.read(reinterpret_cast<char *>(&header), sizeof(header));
if(!is){
fprintf(stderr, "Failed to read %s\n", filename);
return false;
fseek(fp, 0, SEEK_END); /* seek to the end of the file */
uint32_t nFileLen = ftell(fp); /* get the file size */
fseek(fp, 44, SEEK_SET); /* skip the wav header */

speech_len = (nFileLen - 44) / 2;
speech_align_len = (int)(ceil((float)speech_len / align_size) * align_size);
speech_buff = (int16_t *)malloc(sizeof(int16_t) * speech_align_len);
}

*sampling_rate = header.sample_rate;
// header.subchunk2_size contains the number of bytes in the data.
// As we assume each sample contains two bytes, so it is divided by 2 here
speech_len = header.subchunk2_size / 2;
speech_buff = (int16_t *)malloc(sizeof(int16_t) * speech_len);

if (speech_buff)
{
memset(speech_buff, 0, sizeof(int16_t) * speech_align_len);
int ret = fread(speech_buff, sizeof(int16_t), speech_len, fp);
fclose(fp);
memset(speech_buff, 0, sizeof(int16_t) * speech_len);
is.read(reinterpret_cast<char *>(speech_buff), header.subchunk2_size);
if (!is) {
fprintf(stderr, "Failed to read %s\n", filename);
return false;
}
speech_data = (float*)malloc(sizeof(float) * speech_len);
memset(speech_data, 0, sizeof(float) * speech_len);

speech_data = (float*)malloc(sizeof(float) * speech_align_len);
memset(speech_data, 0, sizeof(float) * speech_align_len);
int i;
float scale = 1;

if (data_type == 1) {
scale = 32768;
}

for (i = 0; i < speech_len; i++) {
for (int32_t i = 0; i != speech_len; ++i) {
speech_data[i] = (float)speech_buff[i] / scale;
}

//resample
if(*sampling_rate != model_sample_rate){
wavResample(*sampling_rate, speech_data, speech_len);
}

AudioFrame* frame = new AudioFrame(speech_len);
frame_queue.push(frame);

return true;
}
else
return false;
}
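`loadwav` takes `header.subchunk2_size / 2` as the sample count because each PCM sample is a 16-bit (2-byte) value. A minimal sketch of the canonical 44-byte RIFF/WAVE header layout this code assumes; the field names are illustrative, the project's actual `WaveHeader` struct may name them differently:

```cpp
#include <cstdint>

// Canonical 44-byte RIFF/WAVE header for 16-bit PCM audio.
// Field names are illustrative, not the project's exact struct.
struct WavHeader44 {
  char     riff[4];          // "RIFF"
  uint32_t chunk_size;       // file size - 8
  char     wave[4];          // "WAVE"
  char     fmt[4];           // "fmt "
  uint32_t subchunk1_size;   // 16 for PCM
  uint16_t audio_format;     // 1 = PCM
  uint16_t num_channels;
  uint32_t sample_rate;
  uint32_t byte_rate;        // sample_rate * num_channels * bits / 8
  uint16_t block_align;
  uint16_t bits_per_sample;  // 16 here, hence the "/ 2" above
  char     data[4];          // "data"
  uint32_t subchunk2_size;   // bytes of PCM payload
};
static_assert(sizeof(WavHeader44) == 44, "no padding expected");
```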


bool Audio::loadwav(const char* buf, int nFileLen)
bool Audio::loadwav(const char* buf, int nFileLen, int32_t* sampling_rate)
{

WaveHeader header;
if (speech_data != NULL) {
free(speech_data);
}
if (speech_buff != NULL) {
free(speech_buff);
}

offset = 0;

size_t nOffset = 0;
std::memcpy(&header, buf, sizeof(header));

#define WAV_HEADER_SIZE 44

speech_len = (nFileLen - WAV_HEADER_SIZE) / 2;
speech_align_len = (int)(ceil((float)speech_len / align_size) * align_size);
speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_align_len);
*sampling_rate = header.sample_rate;
speech_len = header.subchunk2_size / 2;
speech_buff = (int16_t *)malloc(sizeof(int16_t) * speech_len);
if (speech_buff)
{
memset(speech_buff, 0, sizeof(int16_t) * speech_align_len);
memset(speech_buff, 0, sizeof(int16_t) * speech_len);
memcpy((void*)speech_buff, (const void*)(buf + WAV_HEADER_SIZE), speech_len * sizeof(int16_t));

speech_data = (float*)malloc(sizeof(float) * speech_len);
memset(speech_data, 0, sizeof(float) * speech_len);

speech_data = (float*)malloc(sizeof(float) * speech_align_len);
memset(speech_data, 0, sizeof(float) * speech_align_len);
int i;
float scale = 1;

if (data_type == 1) {
scale = 32768;
}

for (i = 0; i < speech_len; i++) {
for (int32_t i = 0; i != speech_len; ++i) {
speech_data[i] = (float)speech_buff[i] / scale;
}

//resample
if(*sampling_rate != model_sample_rate){
wavResample(*sampling_rate, speech_data, speech_len);
}

AudioFrame* frame = new AudioFrame(speech_len);
frame_queue.push(frame);

return true;
}
else
return false;

}


bool Audio::loadpcmwav(const char* buf, int nBufLen)
bool Audio::loadpcmwav(const char* buf, int nBufLen, int32_t* sampling_rate)
{
if (speech_data != NULL) {
free(speech_data);
@@ -234,33 +349,29 @@ bool Audio::loadpcmwav(const char* buf, int nBufLen)
}
offset = 0;

size_t nOffset = 0;

speech_len = nBufLen / 2;
speech_align_len = (int)(ceil((float)speech_len / align_size) * align_size);
speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_align_len);
speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len);
if (speech_buff)
{
memset(speech_buff, 0, sizeof(int16_t) * speech_align_len);
memset(speech_buff, 0, sizeof(int16_t) * speech_len);
memcpy((void*)speech_buff, (const void*)buf, speech_len * sizeof(int16_t));

speech_data = (float*)malloc(sizeof(float) * speech_len);
memset(speech_data, 0, sizeof(float) * speech_len);

speech_data = (float*)malloc(sizeof(float) * speech_align_len);
memset(speech_data, 0, sizeof(float) * speech_align_len);

int i;
float scale = 1;

if (data_type == 1) {
scale = 32768;
}

for (i = 0; i < speech_len; i++) {
for (int32_t i = 0; i != speech_len; ++i) {
speech_data[i] = (float)speech_buff[i] / scale;
}

//resample
if(*sampling_rate != model_sample_rate){
wavResample(*sampling_rate, speech_data, speech_len);
}

AudioFrame* frame = new AudioFrame(speech_len);
frame_queue.push(frame);
@@ -269,13 +380,10 @@ bool Audio::loadpcmwav(const char* buf, int nBufLen)
}
else
return false;

}
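The conversion loop above divides each 16-bit sample by `scale = 32768` so the float signal lies in [-1.0, 1.0). The normalization in isolation:

```cpp
#include <cstdint>

// Normalize one signed 16-bit PCM sample to a float in [-1.0, 1.0),
// matching the speech_buff[i] / scale conversion above.
inline float PcmToFloat(int16_t s) { return (float)s / 32768.0f; }
```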

bool Audio::loadpcmwav(const char* filename)
bool Audio::loadpcmwav(const char* filename, int32_t* sampling_rate)
{

if (speech_data != NULL) {
free(speech_data);
}
@@ -293,34 +401,31 @@ bool Audio::loadpcmwav(const char* filename)
fseek(fp, 0, SEEK_SET);

speech_len = (nFileLen) / 2;
speech_align_len = (int)(ceil((float)speech_len / align_size) * align_size);
speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_align_len);
speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len);
if (speech_buff)
{
memset(speech_buff, 0, sizeof(int16_t) * speech_align_len);
memset(speech_buff, 0, sizeof(int16_t) * speech_len);
int ret = fread(speech_buff, sizeof(int16_t), speech_len, fp);
fclose(fp);

speech_data = (float*)malloc(sizeof(float) * speech_align_len);
memset(speech_data, 0, sizeof(float) * speech_align_len);
speech_data = (float*)malloc(sizeof(float) * speech_len);
memset(speech_data, 0, sizeof(float) * speech_len);

int i;
float scale = 1;

if (data_type == 1) {
scale = 32768;
}

for (i = 0; i < speech_len; i++) {
for (int32_t i = 0; i != speech_len; ++i) {
speech_data[i] = (float)speech_buff[i] / scale;
}

//resample
if(*sampling_rate != model_sample_rate){
wavResample(*sampling_rate, speech_data, speech_len);
}

AudioFrame* frame = new AudioFrame(speech_len);
frame_queue.push(frame);

return true;
}
@@ -329,7 +434,6 @@ bool Audio::loadpcmwav(const char* filename)

}

int Audio::fetch_chunck(float *&dout, int len)
{
if (offset >= speech_align_len) {

@@ -1,43 +1,44 @@

file(GLOB files1 "*.cpp")
file(GLOB files2 "*.cc")
file(GLOB files4 "paraformer/*.cpp")

set(files ${files1} ${files2} ${files3} ${files4})

# message("${files}")

add_library(rapidasr ${files})
add_library(funasr ${files})

if(WIN32)

set(EXTRA_LIBS libfftw3f-3 yaml-cpp)
if(CMAKE_CL_64)
target_link_directories(rapidasr PUBLIC ${CMAKE_SOURCE_DIR}/win/lib/x64)
target_link_directories(funasr PUBLIC ${CMAKE_SOURCE_DIR}/win/lib/x64)
else()
target_link_directories(rapidasr PUBLIC ${CMAKE_SOURCE_DIR}/win/lib/x86)
target_link_directories(funasr PUBLIC ${CMAKE_SOURCE_DIR}/win/lib/x86)
endif()
target_include_directories(rapidasr PUBLIC ${CMAKE_SOURCE_DIR}/win/include )
target_include_directories(funasr PUBLIC ${CMAKE_SOURCE_DIR}/win/include )

target_compile_definitions(rapidasr PUBLIC -D_RPASR_API_EXPORT)
target_compile_definitions(funasr PUBLIC -D_FUNASR_API_EXPORT)
else()

set(EXTRA_LIBS fftw3f pthread yaml-cpp)
target_include_directories(rapidasr PUBLIC "/usr/local/opt/fftw/include")
target_link_directories(rapidasr PUBLIC "/usr/local/opt/fftw/lib")
target_include_directories(funasr PUBLIC "/usr/local/opt/fftw/include")
target_link_directories(funasr PUBLIC "/usr/local/opt/fftw/lib")

target_include_directories(rapidasr PUBLIC "/usr/local/opt/openblas/include")
target_link_directories(rapidasr PUBLIC "/usr/local/opt/openblas/lib")
target_include_directories(funasr PUBLIC "/usr/local/opt/openblas/include")
target_link_directories(funasr PUBLIC "/usr/local/opt/openblas/lib")

target_include_directories(rapidasr PUBLIC "/usr/include")
target_link_directories(rapidasr PUBLIC "/usr/lib64")
target_include_directories(funasr PUBLIC "/usr/include")
target_link_directories(funasr PUBLIC "/usr/lib64")

target_include_directories(rapidasr PUBLIC ${FFTW3F_INCLUDE_DIR})
target_link_directories(rapidasr PUBLIC ${FFTW3F_LIBRARY_DIR})
target_include_directories(funasr PUBLIC ${FFTW3F_INCLUDE_DIR})
target_link_directories(funasr PUBLIC ${FFTW3F_LIBRARY_DIR})
include_directories(${ONNXRUNTIME_DIR}/include)
endif()

include_directories(${CMAKE_SOURCE_DIR}/include)
target_link_libraries(rapidasr PUBLIC onnxruntime ${EXTRA_LIBS})
target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS})

@@ -5,14 +5,10 @@ using namespace std;

FeatureExtract::FeatureExtract(int mode) : mode(mode)
{
fftw_init();
}

FeatureExtract::~FeatureExtract()
{
fftwf_free(fft_input);
fftwf_free(fft_out);
fftwf_destroy_plan(p);
}

void FeatureExtract::reset()
@@ -26,34 +22,25 @@ int FeatureExtract::size()
return fqueue.size();
}

void FeatureExtract::fftw_init()
void FeatureExtract::insert(fftwf_plan plan, float *din, int len, int flag)
{
int fft_size = 512;
fft_input = (float *)fftwf_malloc(sizeof(float) * fft_size);
fft_out = (fftwf_complex *)fftwf_malloc(sizeof(fftwf_complex) * fft_size);
float* fft_input = (float *)fftwf_malloc(sizeof(float) * fft_size);
fftwf_complex* fft_out = (fftwf_complex *)fftwf_malloc(sizeof(fftwf_complex) * fft_size);
memset(fft_input, 0, sizeof(float) * fft_size);
p = fftwf_plan_dft_r2c_1d(fft_size, fft_input, fft_out, FFTW_ESTIMATE);
}

void FeatureExtract::insert(float *din, int len, int flag)
{
const float *window = (const float *)&window_hex;
if (mode == 3)
window = (const float *)&window_hamm_hex;

int window_size = 400;
int fft_size = 512;
int window_shift = 160;

speech.load(din, len);
int i, j;
float tmp_feature[80];
if (mode == 0 || mode == 2 || mode == 3) {
int ll = (speech.size() - 400) / 160 + 1;
int ll = (speech.size() - window_size) / window_shift + 1;
fqueue.reinit(ll);
}

for (i = 0; i <= speech.size() - 400; i = i + window_shift) {
for (i = 0; i <= speech.size() - window_size; i = i + window_shift) {
float tmp_mean = 0;
for (j = 0; j < window_size; j++) {
tmp_mean += speech[i + j];
@@ -70,7 +57,7 @@ void FeatureExtract::insert(float *din, int len, int flag)
pre_val = cur_val;
}

fftwf_execute(p);
fftwf_execute_dft_r2c(plan, fft_input, fft_out);

melspect((float *)fft_out, tmp_feature);
int tmp_flag = S_MIDDLE;
@@ -80,6 +67,8 @@ void FeatureExtract::insert(float *din, int len, int flag)
fqueue.push(tmp_feature, tmp_flag);
}
speech.update(i);
fftwf_free(fft_input);
fftwf_free(fft_out);
}
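`insert` frames the signal with a 400-sample window and a 160-sample shift (25 ms / 10 ms at 16 kHz), so the frame count used for `fqueue.reinit` is `(len - window) / shift + 1`. A standalone check of that arithmetic:

```cpp
// Number of complete analysis frames for a signal of `len` samples,
// matching the (speech.size() - window_size) / window_shift + 1 above.
int NumFrames(int len, int window_size = 400, int window_shift = 160) {
  if (len < window_size) return 0;  // not even one full window
  return (len - window_size) / window_shift + 1;
}
```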

bool FeatureExtract::fetch(Tensor<float> *&dout)
@@ -128,7 +117,6 @@ void FeatureExtract::global_cmvn(float *din)
void FeatureExtract::melspect(float *din, float *dout)
{
float fftmag[256];
// float tmp;
const float *melcoe = (const float *)melcoe_hex;
int i;
for (i = 0; i < 256; i++) {

@@ -14,12 +14,11 @@ class FeatureExtract {
SpeechWrap speech;
FeatureQueue fqueue;
int mode;
int fft_size = 512;
int window_size = 400;
int window_shift = 160;

float *fft_input;
fftwf_complex *fft_out;
fftwf_plan p;

void fftw_init();
//void fftw_init();
void melspect(float *din, float *dout);
void global_cmvn(float *din);

@@ -27,9 +26,9 @@ class FeatureExtract {
FeatureExtract(int mode);
~FeatureExtract();
int size();
int status();
//int status();
void reset();
void insert(float *din, int len, int flag);
void insert(fftwf_plan plan, float *din, int len, int flag);
bool fetch(Tensor<float> *&dout);
};

@@ -13,21 +13,6 @@ Vocab::Vocab(const char *filename)
{
ifstream in(filename);
loadVocabFromYaml(filename);

/*
string line;
if (in) // the file exists
{
while (getline(in, line)) // line does not include the trailing newline
{
vocab.push_back(line);
}
}
else{
printf("Cannot load vocab from: %s, there must be file vocab.txt", filename);
exit(-1);
}
*/
}
Vocab::~Vocab()
{

@@ -5,7 +5,7 @@ typedef struct
{
std::string msg;
float snippet_time;
}RPASR_RECOG_RESULT;
}FUNASR_RECOG_RESULT;

#ifdef _WIN32
@@ -53,4 +53,4 @@ inline void getOutputName(Ort::Session* session, string& outputName, int nIndex

}
}
}
}

@@ -5,32 +5,33 @@ extern "C" {
#endif

// APIs for qmasr
_RAPIDASRAPI RPASR_HANDLE RapidAsrInit(const char* szModelDir, int nThreadNum, bool quantize)
_FUNASRAPI FUNASR_HANDLE FunASRInit(const char* szModelDir, int nThreadNum, bool quantize)
{
Model* mm = create_model(szModelDir, nThreadNum, quantize);
return mm;
}

_RAPIDASRAPI RPASR_RESULT RapidAsrRecogBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback)
_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
{
Model* pRecogObj = (Model*)handle;
if (!pRecogObj)
return nullptr;

int32_t sampling_rate = -1;
Audio audio(1);
if (!audio.loadwav(szBuf, nLen))
if (!audio.loadwav(szBuf, nLen, &sampling_rate))
return nullptr;
//audio.split();

float* buff;
int len;
int flag=0;
RPASR_RECOG_RESULT* pResult = new RPASR_RECOG_RESULT;
FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
pResult->snippet_time = audio.get_time_len();
int nStep = 0;
int nTotal = audio.get_queue_size();
while (audio.fetch(buff, len, flag) > 0) {
pRecogObj->reset();
//pRecogObj->reset();
string msg = pRecogObj->forward(buff, len, flag);
pResult->msg += msg;
nStep++;
@@ -41,26 +42,26 @@ extern "C" {
return pResult;
}

_RAPIDASRAPI RPASR_RESULT RapidAsrRecogPCMBuffer(RPASR_HANDLE handle, const char* szBuf, int nLen, RPASR_MODE Mode, QM_CALLBACK fnCallback)
_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
{
Model* pRecogObj = (Model*)handle;
if (!pRecogObj)
return nullptr;

Audio audio(1);
if (!audio.loadpcmwav(szBuf, nLen))
if (!audio.loadpcmwav(szBuf, nLen, &sampling_rate))
return nullptr;
//audio.split();

float* buff;
int len;
int flag = 0;
RPASR_RECOG_RESULT* pResult = new RPASR_RECOG_RESULT;
FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
pResult->snippet_time = audio.get_time_len();
int nStep = 0;
int nTotal = audio.get_queue_size();
while (audio.fetch(buff, len, flag) > 0) {
pRecogObj->reset();
//pRecogObj->reset();
string msg = pRecogObj->forward(buff, len, flag);
pResult->msg += msg;
nStep++;
@@ -71,26 +72,26 @@ extern "C" {
return pResult;
}

_RAPIDASRAPI RPASR_RESULT RapidAsrRecogPCMFile(RPASR_HANDLE handle, const char* szFileName, RPASR_MODE Mode, QM_CALLBACK fnCallback)
_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
{
Model* pRecogObj = (Model*)handle;
if (!pRecogObj)
return nullptr;

Audio audio(1);
if (!audio.loadpcmwav(szFileName))
if (!audio.loadpcmwav(szFileName, &sampling_rate))
return nullptr;
//audio.split();

float* buff;
int len;
int flag = 0;
RPASR_RECOG_RESULT* pResult = new RPASR_RECOG_RESULT;
FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
pResult->snippet_time = audio.get_time_len();
int nStep = 0;
int nTotal = audio.get_queue_size();
while (audio.fetch(buff, len, flag) > 0) {
pRecogObj->reset();
//pRecogObj->reset();
string msg = pRecogObj->forward(buff, len, flag);
pResult->msg += msg;
nStep++;
@@ -101,14 +102,15 @@ extern "C" {
return pResult;
}

_RAPIDASRAPI RPASR_RESULT RapidAsrRecogFile(RPASR_HANDLE handle, const char* szWavfile, RPASR_MODE Mode, QM_CALLBACK fnCallback)
_FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
{
Model* pRecogObj = (Model*)handle;
if (!pRecogObj)
return nullptr;

int32_t sampling_rate = -1;
Audio audio(1);
if(!audio.loadwav(szWavfile))
if(!audio.loadwav(szWavfile, &sampling_rate))
return nullptr;
//audio.split();

@@ -117,10 +119,10 @@ extern "C" {
int flag = 0;
int nStep = 0;
int nTotal = audio.get_queue_size();
RPASR_RECOG_RESULT* pResult = new RPASR_RECOG_RESULT;
FUNASR_RECOG_RESULT* pResult = new FUNASR_RECOG_RESULT;
pResult->snippet_time = audio.get_time_len();
while (audio.fetch(buff, len, flag) > 0) {
pRecogObj->reset();
//pRecogObj->reset();
string msg = pRecogObj->forward(buff, len, flag);
pResult->msg+= msg;
nStep++;
@@ -131,7 +133,7 @@ extern "C" {
return pResult;
}

_RAPIDASRAPI const int RapidAsrGetRetNumber(RPASR_RESULT Result)
_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT Result)
{
if (!Result)
return 0;
@@ -140,32 +142,32 @@ extern "C" {
}

_RAPIDASRAPI const float RapidAsrGetRetSnippetTime(RPASR_RESULT Result)
_FUNASRAPI const float FunASRGetRetSnippetTime(FUNASR_RESULT Result)
{
if (!Result)
return 0.0f;

return ((RPASR_RECOG_RESULT*)Result)->snippet_time;
return ((FUNASR_RECOG_RESULT*)Result)->snippet_time;
}

_RAPIDASRAPI const char* RapidAsrGetResult(RPASR_RESULT Result,int nIndex)
_FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT Result,int nIndex)
{
RPASR_RECOG_RESULT * pResult = (RPASR_RECOG_RESULT*)Result;
FUNASR_RECOG_RESULT * pResult = (FUNASR_RECOG_RESULT*)Result;
if(!pResult)
return nullptr;

return pResult->msg.c_str();
}

_RAPIDASRAPI void RapidAsrFreeResult(RPASR_RESULT Result)
_FUNASRAPI void FunASRFreeResult(FUNASR_RESULT Result)
{
if (Result)
{
delete (RPASR_RECOG_RESULT*)Result;
delete (FUNASR_RECOG_RESULT*)Result;
}
}
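The exported functions above follow a C-style opaque-handle pattern: `FUNASR_RESULT` is a handle whose concrete type (`FUNASR_RECOG_RESULT`) is known only inside the library, and callers go through accessor and free functions. A minimal self-contained sketch of that pattern; the names here are illustrative stand-ins, not the library's exact typedefs:

```cpp
#include <string>

// Opaque-handle accessor pattern, as used by FunASRGetResult /
// FunASRFreeResult above. RESULT_HANDLE stands in for FUNASR_RESULT.
typedef void* RESULT_HANDLE;

struct RecogResult {            // library-internal, invisible to callers
  std::string msg;
  float snippet_time;
};

RESULT_HANDLE MakeResult(const char* text, float t) {
  return new RecogResult{text, t};
}
const char* GetResultText(RESULT_HANDLE h) {
  return h ? ((RecogResult*)h)->msg.c_str() : nullptr;
}
void FreeResult(RESULT_HANDLE h) {
  delete (RecogResult*)h;       // safe on nullptr, like FunASRFreeResult
}
```

The design keeps the C ABI stable: callers never see `std::string` or the struct layout, so the library can change `RecogResult` without breaking binaries.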

_RAPIDASRAPI void RapidAsrUninit(RPASR_HANDLE handle)
_FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle)
{
Model* pRecogObj = (Model*)handle;

@@ -4,7 +4,7 @@ using namespace std;
using namespace paraformer;

ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
{
:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),sessionOptions{}{
string model_path;
string cmvn_path;
string config_path;
@@ -18,7 +18,10 @@ ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
cmvn_path = pathAppend(path, "am.mvn");
config_path = pathAppend(path, "config.yaml");

fe = new FeatureExtract(3);
fft_input = (float *)fftwf_malloc(sizeof(float) * fft_size);
fft_out = (fftwf_complex *)fftwf_malloc(sizeof(fftwf_complex) * fft_size);
memset(fft_input, 0, sizeof(float) * fft_size);
plan = fftwf_plan_dft_r2c_1d(fft_size, fft_input, fft_out, FFTW_ESTIMATE);

//sessionOptions.SetInterOpNumThreads(1);
sessionOptions.SetIntraOpNumThreads(nNumThread);
@@ -26,20 +29,20 @@ ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)

#ifdef _WIN32
wstring wstrPath = strToWstr(model_path);
m_session = new Ort::Session(env, wstrPath.c_str(), sessionOptions);
m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
#else
m_session = new Ort::Session(env, model_path.c_str(), sessionOptions);
m_session = std::make_unique<Ort::Session>(env_, model_path.c_str(), sessionOptions);
#endif

string strName;
getInputName(m_session, strName);
getInputName(m_session.get(), strName);
m_strInputNames.push_back(strName.c_str());
getInputName(m_session, strName,1);
getInputName(m_session.get(), strName,1);
m_strInputNames.push_back(strName);

getOutputName(m_session, strName);
getOutputName(m_session.get(), strName);
m_strOutputNames.push_back(strName);
getOutputName(m_session, strName,1);
getOutputName(m_session.get(), strName,1);
m_strOutputNames.push_back(strName);

for (auto& item : m_strInputNames)
@@ -52,20 +55,16 @@ ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)

ModelImp::~ModelImp()
{
if(fe)
delete fe;
if (m_session)
{
delete m_session;
m_session = nullptr;
}
if(vocab)
delete vocab;
fftwf_free(fft_input);
fftwf_free(fft_out);
fftwf_destroy_plan(plan);
fftwf_cleanup();
}

void ModelImp::reset()
{
fe->reset();
}

void ModelImp::apply_lfr(Tensor<float>*& din)
@@ -159,14 +158,21 @@ string ModelImp::greedy_search(float * in, int nLen )

string ModelImp::forward(float* din, int len, int flag)
{

Tensor<float>* in;
fe->insert(din, len, flag);
FeatureExtract* fe = new FeatureExtract(3);
fe->reset();
fe->insert(plan, din, len, flag);
fe->fetch(in);
apply_lfr(in);
apply_cmvn(in);
Ort::RunOptions run_option;

#ifdef _WIN_X86
Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
#else
Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
#endif

std::array<int64_t, 3> input_shape_{ in->size[0],in->size[2],in->size[3] };
Ort::Value onnx_feats = Ort::Value::CreateTensor<float>(m_memoryInfo,
in->buff,
@@ -192,7 +198,6 @@ string ModelImp::forward(float* din, int len, int flag)
auto outputTensor = m_session->Run(run_option, m_szInputNames.data(), input_onnx.data(), m_szInputNames.size(), m_szOutputNames.data(), m_szOutputNames.size());
std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();

int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
float* floatData = outputTensor[0].GetTensorMutableData<float>();
auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
@@ -203,9 +208,14 @@ string ModelImp::forward(float* din, int len, int flag)
result = "";
}

if(in)
if(in){
delete in;
in = nullptr;
}
if(fe){
delete fe;
fe = nullptr;
}

return result;
}

@@ -8,7 +8,10 @@ namespace paraformer {

class ModelImp : public Model {
private:
FeatureExtract* fe;
int fft_size=512;
float *fft_input;
fftwf_complex *fft_out;
fftwf_plan plan;

Vocab* vocab;
vector<float> means_list;
@@ -21,21 +24,13 @@ namespace paraformer {

string greedy_search( float* in, int nLen);

#ifdef _WIN_X86
Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
#else
Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
#endif

Ort::Session* m_session = nullptr;
Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "paraformer");
Ort::SessionOptions sessionOptions = Ort::SessionOptions();
std::unique_ptr<Ort::Session> m_session;
Ort::Env env_;
Ort::SessionOptions sessionOptions;

vector<string> m_strInputNames, m_strOutputNames;
vector<const char*> m_szInputNames;
vector<const char*> m_szOutputNames;
//string m_strInputName, m_strInputNameLen;
//string m_strOutputName, m_strOutputNameLen;

public:
ModelImp(const char* path, int nNumThread=0, bool quantize=false);

@@ -44,9 +44,10 @@ using namespace std;
#include "FeatureQueue.h"
#include "SpeechWrap.h"
#include <Audio.h>
#include "resample.h"
#include "Model.h"
#include "paraformer_onnx.h"
#include "librapidasrapi.h"
#include "libfunasrapi.h"

using namespace paraformer;

305	funasr/runtime/onnxruntime/src/resample.cc	Normal file
@@ -0,0 +1,305 @@

/**
 * Copyright 2013 Pegah Ghahremani
 *           2014 IMSL, PKU-HKUST (author: Wei Shi)
 *           2014 Yanqing Sun, Junjie Wang
 *           2014 Johns Hopkins University (author: Daniel Povey)
 * Copyright 2023 Xiaomi Corporation (authors: Fangjun Kuang)
 *
 * See LICENSE for clarification regarding multiple authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
// this file is copied and modified from
// kaldi/src/feat/resample.cc

#include "resample.h"

#include <assert.h>
#include <math.h>
#include <stdio.h>

#include <cstdlib>
#include <type_traits>

#ifndef M_2PI
#define M_2PI 6.283185307179586476925286766559005
#endif

#ifndef M_PI
#define M_PI 3.1415926535897932384626433832795
#endif

template <class I>
I Gcd(I m, I n) {
// this function is copied from kaldi/src/base/kaldi-math.h
if (m == 0 || n == 0) {
if (m == 0 && n == 0) { // gcd not defined, as all integers are divisors.
fprintf(stderr, "Undefined GCD since m = 0, n = 0.\n");
exit(-1);
}
return (m == 0 ? (n > 0 ? n : -n) : (m > 0 ? m : -m));
// return absolute value of whichever is nonzero
}
// could use compile-time assertion
// but involves messing with complex template stuff.
static_assert(std::is_integral<I>::value, "");
while (1) {
m %= n;
if (m == 0) return (n > 0 ? n : -n);
n %= m;
if (n == 0) return (m > 0 ? m : -m);
}
}

/// Returns the least common multiple of two integers. Will
/// crash unless the inputs are positive.
template <class I>
I Lcm(I m, I n) {
// This function is copied from kaldi/src/base/kaldi-math.h
assert(m > 0 && n > 0);
I gcd = Gcd(m, n);
return gcd * (m / gcd) * (n / gcd);
}
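`Gcd` handles zeros and signs explicitly, then runs the Euclidean algorithm; `Lcm` is computed as `gcd * (m/gcd) * (n/gcd)` rather than `m * n / gcd` to reduce overflow risk. For the sample rates the resampler sees, the gcd fixes the size of the repeating unit used below; a standalone restatement of the same Euclidean algorithm:

```cpp
#include <cstdint>

// Same Euclidean algorithm as Gcd above, restated for a standalone check.
int32_t gcd32(int32_t m, int32_t n) {
  while (n != 0) { int32_t t = m % n; m = n; n = t; }
  return m < 0 ? -m : m;
}
```

For example, 44100 Hz to 16000 Hz gives gcd 100, so one repeating unit holds 441 input samples and 160 output samples.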

static float DotProduct(const float *a, const float *b, int32_t n) {
float sum = 0;
for (int32_t i = 0; i != n; ++i) {
sum += a[i] * b[i];
}
return sum;
}

LinearResample::LinearResample(int32_t samp_rate_in_hz,
int32_t samp_rate_out_hz, float filter_cutoff_hz,
int32_t num_zeros)
: samp_rate_in_(samp_rate_in_hz),
samp_rate_out_(samp_rate_out_hz),
filter_cutoff_(filter_cutoff_hz),
num_zeros_(num_zeros) {
assert(samp_rate_in_hz > 0.0 && samp_rate_out_hz > 0.0 &&
filter_cutoff_hz > 0.0 && filter_cutoff_hz * 2 <= samp_rate_in_hz &&
filter_cutoff_hz * 2 <= samp_rate_out_hz && num_zeros > 0);

// base_freq is the frequency of the repeating unit, which is the gcd
// of the input frequencies.
int32_t base_freq = Gcd(samp_rate_in_, samp_rate_out_);
input_samples_in_unit_ = samp_rate_in_ / base_freq;
output_samples_in_unit_ = samp_rate_out_ / base_freq;

SetIndexesAndWeights();
Reset();
}

void LinearResample::SetIndexesAndWeights() {
first_index_.resize(output_samples_in_unit_);
weights_.resize(output_samples_in_unit_);

double window_width = num_zeros_ / (2.0 * filter_cutoff_);

for (int32_t i = 0; i < output_samples_in_unit_; i++) {
double output_t = i / static_cast<double>(samp_rate_out_);
double min_t = output_t - window_width, max_t = output_t + window_width;
// we do ceil on the min and floor on the max, because if we did it
// the other way around we would unnecessarily include indexes just
// outside the window, with zero coefficients. It's possible
// if the arguments to the ceil and floor expressions are integers
// (e.g. if filter_cutoff_ has an exact ratio with the sample rates),
// that we unnecessarily include something with a zero coefficient,
// but this is only a slight efficiency issue.
int32_t min_input_index = ceil(min_t * samp_rate_in_),
max_input_index = floor(max_t * samp_rate_in_),
num_indices = max_input_index - min_input_index + 1;
first_index_[i] = min_input_index;
weights_[i].resize(num_indices);
for (int32_t j = 0; j < num_indices; j++) {
int32_t input_index = min_input_index + j;
double input_t = input_index / static_cast<double>(samp_rate_in_),
delta_t = input_t - output_t;
// sign of delta_t doesn't matter.
weights_[i][j] = FilterFunc(delta_t) / samp_rate_in_;
}
}
}
|
||||
|
||||
/** Here, t is a time in seconds representing an offset from
    the center of the windowed filter function, and FilterFunction(t)
    returns the windowed filter function, described
    in the header as h(t) = f(t)g(t), evaluated at t.
*/
float LinearResample::FilterFunc(float t) const {
  float window,  // raised-cosine (Hanning) window of width
                 // num_zeros_/2*filter_cutoff_
      filter;    // sinc filter function
  if (fabs(t) < num_zeros_ / (2.0 * filter_cutoff_))
    window = 0.5 * (1 + cos(M_2PI * filter_cutoff_ / num_zeros_ * t));
  else
    window = 0.0;  // outside support of window function
  if (t != 0)
    filter = sin(M_2PI * filter_cutoff_ * t) / (M_PI * t);
  else
    filter = 2 * filter_cutoff_;  // limit of the function at t = 0
  return filter * window;
}

void LinearResample::Reset() {
  input_sample_offset_ = 0;
  output_sample_offset_ = 0;
  input_remainder_.resize(0);
}

void LinearResample::Resample(const float *input, int32_t input_dim, bool flush,
                              std::vector<float> *output) {
  int64_t tot_input_samp = input_sample_offset_ + input_dim,
          tot_output_samp = GetNumOutputSamples(tot_input_samp, flush);

  assert(tot_output_samp >= output_sample_offset_);

  output->resize(tot_output_samp - output_sample_offset_);

  // samp_out is the index into the total output signal, not just the part
  // of it we are producing here.
  for (int64_t samp_out = output_sample_offset_; samp_out < tot_output_samp;
       samp_out++) {
    int64_t first_samp_in;
    int32_t samp_out_wrapped;
    GetIndexes(samp_out, &first_samp_in, &samp_out_wrapped);
    const std::vector<float> &weights = weights_[samp_out_wrapped];
    // first_input_index is the first index into "input" that we have a weight
    // for.
    int32_t first_input_index =
        static_cast<int32_t>(first_samp_in - input_sample_offset_);
    float this_output;
    if (first_input_index >= 0 &&
        first_input_index + static_cast<int32_t>(weights.size()) <= input_dim) {
      this_output =
          DotProduct(input + first_input_index, weights.data(), weights.size());
    } else {  // Handle edge cases.
      this_output = 0.0;
      for (int32_t i = 0; i < static_cast<int32_t>(weights.size()); i++) {
        float weight = weights[i];
        int32_t input_index = first_input_index + i;
        if (input_index < 0 &&
            static_cast<int32_t>(input_remainder_.size()) + input_index >= 0) {
          this_output +=
              weight * input_remainder_[input_remainder_.size() + input_index];
        } else if (input_index >= 0 && input_index < input_dim) {
          this_output += weight * input[input_index];
        } else if (input_index >= input_dim) {
          // We're past the end of the input and are adding zero; should only
          // happen if the user specified flush == true, or else we would not
          // be trying to output this sample.
          assert(flush);
        }
      }
    }
    int32_t output_index =
        static_cast<int32_t>(samp_out - output_sample_offset_);
    (*output)[output_index] = this_output;
  }

  if (flush) {
    Reset();  // Reset the internal state.
  } else {
    SetRemainder(input, input_dim);
    input_sample_offset_ = tot_input_samp;
    output_sample_offset_ = tot_output_samp;
  }
}

int64_t LinearResample::GetNumOutputSamples(int64_t input_num_samp,
                                            bool flush) const {
  // For exact computation, we measure time in "ticks" of 1.0 / tick_freq,
  // where tick_freq is the least common multiple of samp_rate_in_ and
  // samp_rate_out_.
  int32_t tick_freq = Lcm(samp_rate_in_, samp_rate_out_);
  int32_t ticks_per_input_period = tick_freq / samp_rate_in_;

  // work out the number of ticks in the time interval
  // [ 0, input_num_samp/samp_rate_in_ ).
  int64_t interval_length_in_ticks = input_num_samp * ticks_per_input_period;
  if (!flush) {
    float window_width = num_zeros_ / (2.0 * filter_cutoff_);
    // To count the window-width in ticks we take the floor.  This
    // is because since we're looking for the largest integer num-out-samp
    // that fits in the interval, which is open on the right, a reduction
    // in interval length of less than a tick will never make a difference.
    // For example, the largest integer in the interval [ 0, 2 ) and the
    // largest integer in the interval [ 0, 2 - 0.9 ) are the same (both one).
    // So when we're subtracting the window-width we can ignore the fractional
    // part.
    int32_t window_width_ticks = floor(window_width * tick_freq);
    // The time-period of the output that we can sample gets reduced
    // by the window-width (which is actually the distance from the
    // center to the edge of the windowing function) if we're not
    // "flushing the output".
    interval_length_in_ticks -= window_width_ticks;
  }
  if (interval_length_in_ticks <= 0) return 0;

  int32_t ticks_per_output_period = tick_freq / samp_rate_out_;
  // Get the last output-sample in the closed interval, i.e. replacing [ ) with
  // [ ].  Note: integer division rounds down.  See
  // http://en.wikipedia.org/wiki/Interval_(mathematics) for an explanation of
  // the notation.
  int64_t last_output_samp = interval_length_in_ticks / ticks_per_output_period;
  // We need the last output-sample in the open interval, so if it takes us to
  // the end of the interval exactly, subtract one.
  if (last_output_samp * ticks_per_output_period == interval_length_in_ticks)
    last_output_samp--;

  // First output-sample index is zero, so the number of output samples
  // is the last output-sample plus one.
  int64_t num_output_samp = last_output_samp + 1;
  return num_output_samp;
}

// inline
void LinearResample::GetIndexes(int64_t samp_out, int64_t *first_samp_in,
                                int32_t *samp_out_wrapped) const {
  // A unit is the smallest nonzero amount of time that is an exact
  // multiple of the input and output sample periods.  The unit index
  // is the answer to "which numbered unit we are in".
  int64_t unit_index = samp_out / output_samples_in_unit_;
  // samp_out_wrapped is equal to samp_out % output_samples_in_unit_
  *samp_out_wrapped =
      static_cast<int32_t>(samp_out - unit_index * output_samples_in_unit_);
  *first_samp_in =
      first_index_[*samp_out_wrapped] + unit_index * input_samples_in_unit_;
}

void LinearResample::SetRemainder(const float *input, int32_t input_dim) {
  std::vector<float> old_remainder(input_remainder_);
  // max_remainder_needed is the width of the filter from side to side,
  // measured in input samples.  you might think it should be half that,
  // but you have to consider that you might be wanting to output samples
  // that are "in the past" relative to the beginning of the latest
  // input... anyway, storing more remainder than needed is not harmful.
  int32_t max_remainder_needed =
      ceil(samp_rate_in_ * num_zeros_ / filter_cutoff_);
  input_remainder_.resize(max_remainder_needed);
  for (int32_t index = -static_cast<int32_t>(input_remainder_.size());
       index < 0; index++) {
    // we interpret "index" as an offset from the end of "input" and
    // from the end of input_remainder_.
    int32_t input_index = index + input_dim;
    if (input_index >= 0) {
      input_remainder_[index + static_cast<int32_t>(input_remainder_.size())] =
          input[input_index];
    } else if (input_index + static_cast<int32_t>(old_remainder.size()) >= 0) {
      input_remainder_[index + static_cast<int32_t>(input_remainder_.size())] =
          old_remainder[input_index +
                        static_cast<int32_t>(old_remainder.size())];
      // else leave it at zero.
    }
  }
}
funasr/runtime/onnxruntime/src/resample.h  (new file, 137 lines)
@@ -0,0 +1,137 @@
/**
 * Copyright      2013  Pegah Ghahremani
 *                2014  IMSL, PKU-HKUST (author: Wei Shi)
 *                2014  Yanqing Sun, Junjie Wang
 *                2014  Johns Hopkins University (author: Daniel Povey)
 * Copyright      2023  Xiaomi Corporation (authors: Fangjun Kuang)
 *
 * See LICENSE for clarification regarding multiple authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
// this file is copied and modified from
// kaldi/src/feat/resample.h

#include <cstdint>
#include <vector>


/*
   We require that the input and output sampling rate be specified as
   integers, as this is an easy way to specify that their ratio be rational.
*/

class LinearResample {
 public:
  /// Constructor.  We make the input and output sample rates integers, because
  /// we are going to need to find a common divisor.  This should just remind
  /// you that they need to be integers.  The filter cutoff needs to be less
  /// than samp_rate_in_hz/2 and less than samp_rate_out_hz/2.  num_zeros
  /// controls the sharpness of the filter, more == sharper but less efficient.
  /// We suggest around 4 to 10 for normal use.
  LinearResample(int32_t samp_rate_in_hz, int32_t samp_rate_out_hz,
                 float filter_cutoff_hz, int32_t num_zeros);

  /// Calling the function Reset() resets the state of the object prior to
  /// processing a new signal; it is only necessary if you have called
  /// Resample(x, x_size, false, y) for some signal, leading to a remainder of
  /// the signal being stored, but then abandon processing the signal before
  /// calling Resample(x, x_size, true, y) for the last piece.  Calling it
  /// unnecessarily between signals will not do any harm.
  void Reset();

  /// This function does the resampling.  If you call it with flush == true and
  /// you have never called it with flush == false, it just resamples the input
  /// signal (it resizes the output to a suitable number of samples).
  ///
  /// You can also use this function to process a signal a piece at a time.
  /// Suppose you break it into piece1, piece2, ... pieceN.  You can call
  /// \code{.cc}
  /// Resample(piece1, piece1_size, false, &output1);
  /// Resample(piece2, piece2_size, false, &output2);
  /// Resample(piece3, piece3_size, true, &output3);
  /// \endcode
  /// If you call it with flush == false, it won't output the last few samples
  /// but will remember them, so that if you later give it a second piece of
  /// the input signal it can process it correctly.
  /// If your most recent call to the object was with flush == false, it will
  /// have internal state; you can remove this by calling Reset().
  /// Empty input is acceptable.
  void Resample(const float *input, int32_t input_dim, bool flush,
                std::vector<float> *output);

  /// Return the input and output sampling rates (for checks, for example)
  int32_t GetInputSamplingRate() const { return samp_rate_in_; }
  int32_t GetOutputSamplingRate() const { return samp_rate_out_; }

 private:
  void SetIndexesAndWeights();

  float FilterFunc(float) const;

  /// This function outputs the number of output samples we will output
  /// for a signal with "input_num_samp" input samples.  If flush == true,
  /// we return the largest n such that
  /// (n/samp_rate_out_) is in the interval [ 0, input_num_samp/samp_rate_in_ ),
  /// and note that the interval is half-open.  If flush == false,
  /// define window_width as num_zeros / (2.0 * filter_cutoff_);
  /// we return the largest n such that (n/samp_rate_out_) is in the interval
  /// [ 0, input_num_samp/samp_rate_in_ - window_width ).
  int64_t GetNumOutputSamples(int64_t input_num_samp, bool flush) const;

  /// Given an output-sample index, this function outputs to *first_samp_in the
  /// first input-sample index that we have a weight on (may be negative),
  /// and to *samp_out_wrapped the index into weights_ where we can get the
  /// corresponding weights on the input.
  inline void GetIndexes(int64_t samp_out, int64_t *first_samp_in,
                         int32_t *samp_out_wrapped) const;

  void SetRemainder(const float *input, int32_t input_dim);

 private:
  // The following variables are provided by the user.
  int32_t samp_rate_in_;
  int32_t samp_rate_out_;
  float filter_cutoff_;
  int32_t num_zeros_;

  int32_t input_samples_in_unit_;   ///< The number of input samples in the
                                    ///< smallest repeating unit: num_samp_in_ =
                                    ///< samp_rate_in_hz / Gcd(samp_rate_in_hz,
                                    ///< samp_rate_out_hz)

  int32_t output_samples_in_unit_;  ///< The number of output samples in the
                                    ///< smallest repeating unit: num_samp_out_
                                    ///< = samp_rate_out_hz /
                                    ///< Gcd(samp_rate_in_hz, samp_rate_out_hz)

  /// The first input-sample index that we sum over, for this output-sample
  /// index.  May be negative; any truncation at the beginning is handled
  /// separately.  This is just for the first few output samples, but we can
  /// extrapolate the correct input-sample index for arbitrary output samples.
  std::vector<int32_t> first_index_;

  /// Weights on the input samples, for this output-sample index.
  std::vector<std::vector<float>> weights_;

  // the following variables keep track of where we are in a particular signal,
  // if it is being provided over multiple calls to Resample().

  int64_t input_sample_offset_;   ///< The number of input samples we have
                                  ///< already received for this signal
                                  ///< (including anything in remainder_)
  int64_t output_sample_offset_;  ///< The number of samples we have already
                                  ///< output for this signal.
  std::vector<float> input_remainder_;  ///< A small trailing part of the
                                        ///< previously seen input signal.
};
@@ -8,7 +8,7 @@ if(WIN32)
     endif()
 endif()

-set(EXTRA_LIBS rapidasr)
+set(EXTRA_LIBS funasr)


 include_directories(${CMAKE_SOURCE_DIR}/include)

@@ -5,7 +5,7 @@
 #include <win_func.h>
 #endif

-#include "librapidasrapi.h"
+#include "libfunasrapi.h"

 #include <iostream>
 #include <fstream>
@@ -26,7 +26,7 @@ int main(int argc, char *argv[])
     // is quantize
     bool quantize = false;
     istringstream(argv[3]) >> boolalpha >> quantize;
-    RPASR_HANDLE AsrHanlde=RapidAsrInit(argv[1], nThreadNum, quantize);
+    FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], nThreadNum, quantize);

     if (!AsrHanlde)
     {
@@ -42,72 +42,32 @@ int main(int argc, char *argv[])
     gettimeofday(&start, NULL);
     float snippet_time = 0.0f;

-    RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL);
+    FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL);

     gettimeofday(&end, NULL);

     if (Result)
     {
-        string msg = RapidAsrGetResult(Result, 0);
+        string msg = FunASRGetResult(Result, 0);
-        setbuf(stdout, NULL);
-        cout << "Result: \"";
-        cout << msg << "\"." << endl;
-        snippet_time = RapidAsrGetRetSnippetTime(Result);
-        RapidAsrFreeResult(Result);
+        printf("Result: %s \n", msg.c_str());
+        snippet_time = FunASRGetRetSnippetTime(Result);
+        FunASRFreeResult(Result);
     }
     else
     {
         cout <<"no return data!";
     }

-    //char* buff = nullptr;
-    //int len = 0;
-    //ifstream ifs(argv[2], std::ios::binary | std::ios::in);
-    //if (ifs.is_open())
-    //{
-    //    ifs.seekg(0, std::ios::end);
-    //    len = ifs.tellg();
-    //    ifs.seekg(0, std::ios::beg);
-
-    //    buff = new char[len];
-
-    //    ifs.read(buff, len);
-
-
-    //    //RPASR_RESULT Result = RapidAsrRecogPCMFile(AsrHanlde, argv[2], RASR_NONE, NULL);
-
-    //    RPASR_RESULT Result=RapidAsrRecogPCMBuffer(AsrHanlde, buff,len, RASR_NONE, NULL);
-    //    //RPASR_RESULT Result = RapidAsrRecogPCMFile(AsrHanlde, argv[2], RASR_NONE, NULL);
-    //    gettimeofday(&end, NULL);
-    //
-    //    if (Result)
-    //    {
-    //        string msg = RapidAsrGetResult(Result, 0);
-    //        setbuf(stdout, NULL);
-    //        cout << "Result: \"";
-    //        cout << msg << endl;
-    //        cout << "\"." << endl;
-    //        snippet_time = RapidAsrGetRetSnippetTime(Result);
-    //        RapidAsrFreeResult(Result);
-    //    }
-    //    else
-    //    {
-    //        cout <<"no return data!";
-    //    }
-
-    //
-    //delete[]buff;
-    //}
-
     printf("Audio length %lfs.\n", (double)snippet_time);
     seconds = (end.tv_sec - start.tv_sec);
     long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
     printf("Model inference takes %lfs.\n", (double)taking_micros / 1000000);
     printf("Model inference RTF: %04lf.\n", (double)taking_micros/ (snippet_time*1000000));

-    RapidAsrUninit(AsrHanlde);
+    FunASRUninit(AsrHanlde);

     return 0;
 }

@@ -5,7 +5,7 @@
 #include <win_func.h>
 #endif

-#include "librapidasrapi.h"
+#include "libfunasrapi.h"

 #include <iostream>
 #include <fstream>
@@ -47,7 +47,7 @@ int main(int argc, char *argv[])
     bool quantize = false;
     istringstream(argv[3]) >> boolalpha >> quantize;

-    RPASR_HANDLE AsrHanlde=RapidAsrInit(argv[1], nThreadNum, quantize);
+    FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], nThreadNum, quantize);
     if (!AsrHanlde)
     {
         printf("Cannot load ASR Model from: %s, there must be files model.onnx and vocab.txt", argv[1]);
@@ -61,7 +61,7 @@ int main(int argc, char *argv[])
     // warm up
     for (size_t i = 0; i < 30; i++)
     {
-        RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, wav_list[0].c_str(), RASR_NONE, NULL);
+        FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, wav_list[0].c_str(), RASR_NONE, NULL);
     }

     // forward
@@ -72,19 +72,19 @@ int main(int argc, char *argv[])
     for (size_t i = 0; i < wav_list.size(); i++)
     {
         gettimeofday(&start, NULL);
-        RPASR_RESULT Result=RapidAsrRecogFile(AsrHanlde, wav_list[i].c_str(), RASR_NONE, NULL);
+        FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, wav_list[i].c_str(), RASR_NONE, NULL);
         gettimeofday(&end, NULL);
         seconds = (end.tv_sec - start.tv_sec);
         long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
         total_time += taking_micros;

         if(Result){
-            string msg = RapidAsrGetResult(Result, 0);
-            printf("Result: %s \n", msg);
+            string msg = FunASRGetResult(Result, 0);
+            printf("Result: %s \n", msg.c_str());

-            snippet_time = RapidAsrGetRetSnippetTime(Result);
+            snippet_time = FunASRGetRetSnippetTime(Result);
             total_length += snippet_time;
-            RapidAsrFreeResult(Result);
+            FunASRFreeResult(Result);
         }else{
             cout <<"No return data!";
         }
@@ -94,6 +94,6 @@ int main(int argc, char *argv[])
     printf("total_time_comput %ld ms.\n", total_time / 1000);
     printf("total_rtf %05lf .\n", (double)total_time/ (total_length*1000000));

-    RapidAsrUninit(AsrHanlde);
+    FunASRUninit(AsrHanlde);
     return 0;
 }

funasr/runtime/python/grpc/grpc_main_client.py  (new file, 62 lines)
@@ -0,0 +1,62 @@
import grpc
import json
import time
import asyncio
import soundfile as sf
import argparse

from grpc_client import transcribe_audio_bytes
from paraformer_pb2_grpc import ASRStub

# send the audio data once
async def grpc_rec(wav_scp, grpc_uri, asr_user, language):
    with grpc.insecure_channel(grpc_uri) as channel:
        stub = ASRStub(channel)
        for line in wav_scp:
            wav_file = line.split()[1]
            wav, _ = sf.read(wav_file, dtype='int16')

            b = time.time()
            response = transcribe_audio_bytes(stub, wav.tobytes(), user=asr_user, language=language, speaking=False, isEnd=False)
            resp = response.next()
            text = ''
            if 'decoding' == resp.action:
                resp = response.next()
                if 'finish' == resp.action:
                    text = json.loads(resp.sentence)['text']
            response = transcribe_audio_bytes(stub, None, user=asr_user, language=language, speaking=False, isEnd=True)
            res = {'text': text, 'time': time.time() - b}
            print(res)

async def test(args):
    wav_scp = open(args.wav_scp, "r").readlines()
    uri = '{}:{}'.format(args.host, args.port)
    res = await grpc_rec(wav_scp, uri, args.user_allowed, language='zh-CN')

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--host",
                        type=str,
                        default="127.0.0.1",
                        required=False,
                        help="grpc server host ip")
    parser.add_argument("--port",
                        type=int,
                        default=10108,
                        required=False,
                        help="grpc server port")
    parser.add_argument("--user_allowed",
                        type=str,
                        default="project1_user1",
                        help="allowed user for grpc client")
    parser.add_argument("--sample_rate",
                        type=int,
                        default=16000,
                        help="audio sample_rate from client")
    parser.add_argument("--wav_scp",
                        type=str,
                        required=True,
                        help="audio wav scp")
    args = parser.parse_args()

    asyncio.run(test(args))
Some files were not shown because too many files have changed in this diff.