mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge branch 'dev_wjm' of https://github.com/alibaba-damo-academy/FunASR into dev_wjm
This commit is contained in:
commit
6d17715edf
11
.github/workflows/main.yml
vendored
@ -12,6 +12,10 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v1
|
||||
- uses: ammaraskar/sphinx-action@master
|
||||
with:
|
||||
docs-folder: "docs/"
|
||||
pre-build-command: "pip install sphinx-markdown-tables nbsphinx jinja2 recommonmark sphinx_rtd_theme"
|
||||
- uses: ammaraskar/sphinx-action@master
|
||||
with:
|
||||
docs-folder: "docs_cn/"
|
||||
@ -22,7 +26,12 @@ jobs:
|
||||
run: |
|
||||
mkdir public
|
||||
touch public/.nojekyll
|
||||
cp -r docs_cn/_build/html/* public/
|
||||
mkdir public/en
|
||||
touch public/en/.nojekyll
|
||||
cp -r docs/_build/html/* public/en/
|
||||
mkdir public/cn
|
||||
touch public/cn/.nojekyll
|
||||
cp -r docs_cn/_build/html/* public/cn/
|
||||
|
||||
- name: deploy github.io pages
|
||||
if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev_wjm'
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
<div align="left"><img src="docs/images/funasr_logo.jpg" width="400"/></div>
|
||||
[//]: # (<div align="left"><img src="docs/images/funasr_logo.jpg" width="400"/></div>)
|
||||
|
||||
# FunASR: A Fundamental End-to-End Speech Recognition Toolkit
|
||||
|
||||
@ -7,7 +7,8 @@
|
||||
[**News**](https://github.com/alibaba-damo-academy/FunASR#whats-new)
|
||||
| [**Highlights**](#highlights)
|
||||
| [**Installation**](#installation)
|
||||
| [**Docs**](https://alibaba-damo-academy.github.io/FunASR/index.html)
|
||||
| [**Docs_CN**](https://alibaba-damo-academy.github.io/FunASR/cn/index.html)
|
||||
| [**Docs_EN**](https://alibaba-damo-academy.github.io/FunASR/en/index.html)
|
||||
| [**Tutorial**](https://github.com/alibaba-damo-academy/FunASR/wiki#funasr%E7%94%A8%E6%88%B7%E6%89%8B%E5%86%8C)
|
||||
| [**Papers**](https://github.com/alibaba-damo-academy/FunASR#citations)
|
||||
| [**Runtime**](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime)
|
||||
@ -42,7 +43,7 @@ pip install --editable ./
|
||||
For more details, please refer to [installation](https://github.com/alibaba-damo-academy/FunASR/wiki)
|
||||
|
||||
## Usage
|
||||
For users who are new to FunASR and ModelScope, please refer to [FunASR Docs](https://alibaba-damo-academy.github.io/FunASR/index.html).
|
||||
For users who are new to FunASR and ModelScope, please refer to FunASR Docs([CN](https://alibaba-damo-academy.github.io/FunASR/cn/index.html) / [EN](https://alibaba-damo-academy.github.io/FunASR/en/index.html))
|
||||
|
||||
## Contact
|
||||
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 143 KiB After Width: | Height: | Size: 188 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 182 KiB After Width: | Height: | Size: 184 KiB |
@ -0,0 +1,53 @@
|
||||
# ModelScope Model
|
||||
|
||||
## How to finetune and infer using a pretrained Paraformer-large Model
|
||||
|
||||
### Finetune
|
||||
|
||||
- Modify the finetuning parameters in `finetune.py`:
    - <strong>output_dir:</strong> # directory where results are saved
    - <strong>data_dir:</strong> # the dataset directory; it must contain `train/wav.scp`, `train/text`, `validation/wav.scp`, and `validation/text`
    - <strong>dataset_type:</strong> # set to `large` for datasets larger than 1000 hours, otherwise set to `small`
    - <strong>batch_bins:</strong> # batch size. When `dataset_type` is `small`, `batch_bins` is the number of feature frames; when `dataset_type` is `large`, `batch_bins` is the duration in ms
    - <strong>max_epoch:</strong> # number of training epochs
    - <strong>lr:</strong> # learning rate
|
||||
|
||||
- Then you can run the pipeline to finetune with:
|
||||
```sh
|
||||
python finetune.py
|
||||
```
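Before launching a finetuning run, it can be useful to confirm that the dataset directory has the layout listed above. The following is a minimal sketch using only the Python standard library; the helper name `missing_dataset_files` is hypothetical and is not part of FunASR.

```python
import os

# Files the finetuning recipe expects, relative to the dataset directory
# (taken from the parameter list above).
REQUIRED_FILES = [
    os.path.join("train", "wav.scp"),
    os.path.join("train", "text"),
    os.path.join("validation", "wav.scp"),
    os.path.join("validation", "text"),
]


def missing_dataset_files(data_dir):
    """Return the required files that are absent from data_dir (hypothetical helper)."""
    return [
        rel_path
        for rel_path in REQUIRED_FILES
        if not os.path.isfile(os.path.join(data_dir, rel_path))
    ]
```

An empty return value means all four files are present; any other result lists what still needs to be prepared.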
|
||||
|
||||
### Inference
|
||||
|
||||
Alternatively, you can skip finetuning and run inference directly with the model released on ModelScope.
|
||||
|
||||
- Set the parameters in `infer.py`:
    - <strong>data_dir:</strong> # the dataset directory; it must contain `test/wav.scp`. If `test/text` also exists, the CER is computed
    - <strong>output_dir:</strong> # directory where results are saved
    - <strong>ngpu:</strong> # number of GPUs used for decoding
    - <strong>njob:</strong> # number of jobs per GPU
|
||||
|
||||
- Then you can run the pipeline to infer with:
|
||||
```sh
|
||||
python infer.py
|
||||
```
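The `ngpu` and `njob` settings together determine how many decoding jobs run in parallel. The sketch below assumes the splitting logic of `infer.py` in this commit: the input list is cut into `ngpu * njob` contiguous shards, the last shard takes the remainder, and the 1-based job index `idx` runs on GPU `(idx - 1) // njob`. The function names are illustrative only.

```python
def split_for_jobs(lines, ngpu, njob):
    """Split lines into ngpu * njob contiguous shards; the last shard takes the remainder."""
    nj = ngpu * njob
    per_job = len(lines) // nj
    shards = []
    start = 0
    for i in range(nj):
        end = start + per_job
        if i == nj - 1:
            shards.append(lines[start:])
        else:
            shards.append(lines[start:end])
        start = end
    return shards


def gpu_for_job(idx, njob):
    """Map a 1-based job index to a 0-based GPU index."""
    return (int(idx) - 1) // njob


utterances = ["utt{}".format(i) for i in range(10)]
shards = split_for_jobs(utterances, ngpu=2, njob=2)
print([len(s) for s in shards])                            # prints [2, 2, 2, 4]
print([gpu_for_job(idx, njob=2) for idx in range(1, 5)])   # prints [0, 0, 1, 1]
```

Because the remainder goes to the last shard, the final job can be noticeably longer than the others when the number of utterances is not a multiple of `ngpu * njob`.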
|
||||
|
||||
- Results
|
||||
|
||||
The decoding results are written to `$output_dir/1best_recog/text.sp.cer` and `$output_dir/1best_recog/text.nosp.cer`. These files contain the recognition result of each sample, with and without the separating character (src) respectively, together with the CER over the whole test set.
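For reference, the character error rate is the edit distance between the reference and hypothesis token sequences divided by the reference length. The sketch below is a plain Levenshtein-based illustration, not FunASR's `compute_wer`; the helper `strip_separator` is hypothetical and only approximates how `src` is dropped before the "nosp" scoring (it removes whole `src` tokens rather than substrings).

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two token sequences."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i]
        for j, h in enumerate(hyp, start=1):
            cur.append(min(
                prev[j] + 1,             # deletion
                cur[j - 1] + 1,          # insertion
                prev[j - 1] + (r != h),  # substitution or match
            ))
        prev = cur
    return prev[-1]


def cer(ref_tokens, hyp_tokens):
    """Character error rate in percent, relative to the reference length."""
    if not ref_tokens:
        raise ValueError("reference must not be empty")
    return 100.0 * edit_distance(ref_tokens, hyp_tokens) / len(ref_tokens)


def strip_separator(line, sep="src"):
    """Drop separator tokens from a space-delimited token line (hypothetical helper)."""
    return [tok for tok in line.split() if tok != sep]


ref = strip_separator("a b src c d")
hyp = strip_separator("a x src c")
print(cer(ref, hyp))  # prints 50.0 (one substitution and one deletion over four tokens)
```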
|
||||
|
||||
### Inference using local finetuned model
|
||||
|
||||
- Modify the inference parameters in `infer_after_finetune.py`:
    - <strong>output_dir:</strong> # directory where results are saved
    - <strong>data_dir:</strong> # the dataset directory; it must contain `test/wav.scp`. If `test/text` also exists, the CER is computed
    - <strong>decoding_model_name:</strong> # name of the checkpoint used for decoding, e.g., `valid.cer_ctc.ave.pth`
|
||||
|
||||
- Then you can run the pipeline to infer with:
|
||||
```sh
|
||||
python infer_after_finetune.py
|
||||
```
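To decode with a local checkpoint, `infer_after_finetune.py` in this commit copies the pretrained model's `configuration.json` into the output directory and sets `model.am_model_name` to the chosen checkpoint. The sketch below shows that step on an in-memory dictionary; the helper name `point_config_at_checkpoint` and the placeholder value `model.pb` are assumptions made for illustration.

```python
import copy
import json


def point_config_at_checkpoint(config, checkpoint_name):
    """Return a copy of config whose acoustic-model entry names checkpoint_name."""
    patched = copy.deepcopy(config)
    patched["model"]["am_model_name"] = checkpoint_name
    return patched


# Stand-in for a loaded configuration.json; real files contain more fields,
# and "model.pb" is a made-up placeholder value.
original = {"model": {"am_model_name": "model.pb"}}
patched = point_config_at_checkpoint(original, "valid.acc.ave.pth")
print(json.dumps(patched, indent=4, separators=(",", ": ")))
```

Working on a deep copy keeps the pretrained model's configuration untouched, so the cached model can still be used as-is.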
|
||||
|
||||
- Results
|
||||
|
||||
The decoding results are written to `$output_dir/decode_results/text.sp.cer` and `$output_dir/decode_results/text.nosp.cer`. These files contain the recognition result of each sample, with and without the separating character (src) respectively, together with the CER over the whole test set.
|
||||
@ -0,0 +1,40 @@
|
||||
# MFCCA
|
||||
- Model link: <https://www.modelscope.cn/models/yufan6/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/summary>
|
||||
- Model size: 45M
|
||||
|
||||
# Environments
|
||||
- date: `Tue Feb 13 20:13:22 CST 2023`
|
||||
- python version: `3.7.12`
|
||||
- FunASR version: `0.1.0`
|
||||
- pytorch version: `pytorch 1.7.0`
|
||||
- Git hash: ``
|
||||
- Commit date: ``
|
||||
|
||||
# Benchmark Results
|
||||
|
||||
## result (paper)
|
||||
beam=20, CER tool: https://github.com/yufan-aslp/AliMeeting
|
||||
|
||||
| model | Params (M) | Data (hrs) | Eval (CER%) | Test (CER%) |
|:-----:|:----------:|:----------:|:-----------:|:-----------:|
| MFCCA | 45 | 917 | 16.1 | 17.5 |
|
||||
|
||||
## result (modelscope)
|
||||
|
||||
beam=10
|
||||
|
||||
with separating character (src)
|
||||
|
||||
| model | Params (M) | Data (hrs) | Eval_sp (CER%) | Test_sp (CER%) |
|:-----:|:----------:|:----------:|:--------------:|:--------------:|
| MFCCA | 45 | 917 | 17.1 | 18.6 |
|
||||
|
||||
without separating character (src)
|
||||
|
||||
| model | Params (M) | Data (hrs) | Eval_nosp (CER%) | Test_nosp (CER%) |
|:-----:|:----------:|:----------:|:----------------:|:----------------:|
| MFCCA | 45 | 917 | 16.4 | 18.0 |
|
||||
|
||||
## Deviation

Because the CER calculation tool and the decoding beam size differ between the two setups, the reported CER values deviate slightly (<0.5%).
|
||||
@ -0,0 +1,35 @@
|
||||
import os
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from funasr.datasets.ms_dataset import MsDataset
from funasr.utils.modelscope_param import modelscope_args


def modelscope_finetune(params):
    if not os.path.exists(params.output_dir):
        os.makedirs(params.output_dir, exist_ok=True)
    # dataset split ["train", "validation"]
    ds_dict = MsDataset.load(params.data_path)
    kwargs = dict(
        model=params.model,
        model_revision=params.model_revision,
        data_dir=ds_dict,
        dataset_type=params.dataset_type,
        work_dir=params.output_dir,
        batch_bins=params.batch_bins,
        max_epoch=params.max_epoch,
        lr=params.lr)
    trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
    trainer.train()


if __name__ == '__main__':

    params = modelscope_args(model="yufan6/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950")
    params.output_dir = "./checkpoint"  # path where the model is saved
    params.data_path = "./example_data/"  # path to the data
    params.dataset_type = "small"  # use "small" for small datasets; for more than 1000 hours of data, use "large"
    params.batch_bins = 1000  # batch size: in fbank feature frames if dataset_type="small", in milliseconds if dataset_type="large"
    params.max_epoch = 10  # maximum number of training epochs
    params.lr = 0.0001  # learning rate
    params.model_revision = 'v2.0.0'
    modelscope_finetune(params)
|
||||
@ -0,0 +1,103 @@
|
||||
import os
import shutil
from multiprocessing import Pool

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

from funasr.utils.compute_wer import compute_wer


def modelscope_infer_core(output_dir, split_dir, njob, idx):
    output_dir_job = os.path.join(output_dir, "output.{}".format(idx))
    # map the 1-based job index to a GPU: jobs 1..njob share the first GPU, and so on
    gpu_id = (int(idx) - 1) // njob
    if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
        gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[gpu_id])
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model='yufan6/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950',
        model_revision='v2.0.0',
        output_dir=output_dir_job,
        batch_size=1,
    )
    audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
    inference_pipeline(audio_in=audio_in)


def modelscope_infer(params):
    # prepare for multi-GPU decoding
    ngpu = params["ngpu"]
    njob = params["njob"]
    output_dir = params["output_dir"]
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.mkdir(output_dir)
    split_dir = os.path.join(output_dir, "split")
    os.mkdir(split_dir)
    nj = ngpu * njob
    wav_scp_file = os.path.join(params["data_dir"], "wav.scp")
    with open(wav_scp_file) as f:
        lines = f.readlines()
        num_lines = len(lines)
        num_job_lines = num_lines // nj
    # write one wav.{i}.scp per job; the last job takes the remaining lines
    start = 0
    for i in range(nj):
        end = start + num_job_lines
        file = os.path.join(split_dir, "wav.{}.scp".format(str(i + 1)))
        with open(file, "w") as f:
            if i == nj - 1:
                f.writelines(lines[start:])
            else:
                f.writelines(lines[start:end])
        start = end
    p = Pool(nj)
    for i in range(nj):
        p.apply_async(modelscope_infer_core,
                      args=(output_dir, split_dir, njob, str(i + 1)))
    p.close()
    p.join()

    # combine decoding results
    best_recog_path = os.path.join(output_dir, "1best_recog")
    os.mkdir(best_recog_path)
    files = ["text", "token", "score"]
    for file in files:
        with open(os.path.join(best_recog_path, file), "w") as f:
            for i in range(nj):
                job_file = os.path.join(output_dir, "output.{}/1best_recog".format(str(i + 1)), file)
                with open(job_file) as f_job:
                    lines = f_job.readlines()
                f.writelines(lines)

    # If text exists, compute CER
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(best_recog_path, "token")
        text_proc_file2 = os.path.join(best_recog_path, "token_nosep")
        # drop the "src" separator and collapse the double spaces it leaves behind
        with open(text_proc_file, 'r') as hyp_reader:
            with open(text_proc_file2, 'w') as hyp_writer:
                for line in hyp_reader:
                    new_context = line.strip().replace("src", "").replace("  ", " ").replace("  ", " ").strip()
                    hyp_writer.write(new_context + '\n')
        text_in2 = os.path.join(best_recog_path, "ref_text_nosep")
        with open(text_in, 'r') as ref_reader:
            with open(text_in2, 'w') as ref_writer:
                for line in ref_reader:
                    new_context = line.strip().replace("src", "").replace("  ", " ").replace("  ", " ").strip()
                    ref_writer.write(new_context + '\n')

        compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.sp.cer"))
        compute_wer(text_in2, text_proc_file2, os.path.join(best_recog_path, "text.nosp.cer"))


if __name__ == "__main__":
    params = {}
    params["data_dir"] = "./example_data/validation"
    params["output_dir"] = "./output_dir"
    params["ngpu"] = 1
    params["njob"] = 1
    modelscope_infer(params)
|
||||
@ -0,0 +1,67 @@
|
||||
import json
import os
import shutil

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

from funasr.utils.compute_wer import compute_wer


def modelscope_infer_after_finetune(params):
    # prepare for decoding
    pretrained_model_path = os.path.join(os.environ["HOME"], ".cache/modelscope/hub", params["modelscope_model_name"])
    for file_name in params["required_files"]:
        if file_name == "configuration.json":
            # point the copied configuration at the finetuned checkpoint
            with open(os.path.join(pretrained_model_path, file_name)) as f:
                config_dict = json.load(f)
                config_dict["model"]["am_model_name"] = params["decoding_model_name"]
            with open(os.path.join(params["output_dir"], "configuration.json"), "w") as f:
                json.dump(config_dict, f, indent=4, separators=(',', ': '))
        else:
            shutil.copy(os.path.join(pretrained_model_path, file_name),
                        os.path.join(params["output_dir"], file_name))
    decoding_path = os.path.join(params["output_dir"], "decode_results")
    if os.path.exists(decoding_path):
        shutil.rmtree(decoding_path)
    os.mkdir(decoding_path)

    # decoding
    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model=params["output_dir"],
        output_dir=decoding_path,
        batch_size=1
    )
    audio_in = os.path.join(params["data_dir"], "wav.scp")
    inference_pipeline(audio_in=audio_in)

    # compute CER if the ground-truth text file exists
    text_in = os.path.join(params["data_dir"], "text")
    if os.path.exists(text_in):
        text_proc_file = os.path.join(decoding_path, "1best_recog/token")
        text_proc_file2 = os.path.join(decoding_path, "1best_recog/token_nosep")
        # drop the "src" separator and collapse the double spaces it leaves behind
        with open(text_proc_file, 'r') as hyp_reader:
            with open(text_proc_file2, 'w') as hyp_writer:
                for line in hyp_reader:
                    new_context = line.strip().replace("src", "").replace("  ", " ").replace("  ", " ").strip()
                    hyp_writer.write(new_context + '\n')
        text_in2 = os.path.join(decoding_path, "1best_recog/ref_text_nosep")
        with open(text_in, 'r') as ref_reader:
            with open(text_in2, 'w') as ref_writer:
                for line in ref_reader:
                    new_context = line.strip().replace("src", "").replace("  ", " ").replace("  ", " ").strip()
                    ref_writer.write(new_context + '\n')

        compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.sp.cer"))
        compute_wer(text_in2, text_proc_file2, os.path.join(decoding_path, "text.nosp.cer"))


if __name__ == '__main__':
    params = {}
    params["modelscope_model_name"] = "yufan6/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950"
    params["required_files"] = ["feats_stats.npz", "decoding.yaml", "configuration.json"]
    params["output_dir"] = "./checkpoint"
    params["data_dir"] = "./example_data/validation"
    params["decoding_model_name"] = "valid.acc.ave.pth"
    modelscope_infer_after_finetune(params)
|
||||
@ -6,7 +6,7 @@ if __name__ == '__main__':
|
||||
param_dict = dict()
|
||||
param_dict['hotword'] = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/hotword.txt"
|
||||
|
||||
audio_in = "//isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_hotword.wav"
|
||||
audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_hotword.wav"
|
||||
output_dir = None
|
||||
batch_size = 1
|
||||
|
||||
|
||||
@ -0,0 +1,36 @@
|
||||
import os
|
||||
import tempfile
|
||||
import codecs
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.msdatasets import MsDataset
|
||||
|
||||
if __name__ == '__main__':
|
||||
param_dict = dict()
|
||||
param_dict['hotword'] = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/hotword.txt"
|
||||
|
||||
output_dir = "./output"
|
||||
batch_size = 1
|
||||
|
||||
# dataset split ['test']
|
||||
ds_dict = MsDataset.load(dataset_name='speech_asr_aishell1_hotwords_testsets', namespace='speech_asr')
|
||||
work_dir = tempfile.TemporaryDirectory().name
|
||||
if not os.path.exists(work_dir):
|
||||
os.makedirs(work_dir)
|
||||
wav_file_path = os.path.join(work_dir, "wav.scp")
|
||||
|
||||
with codecs.open(wav_file_path, 'w') as fin:
|
||||
for line in ds_dict:
|
||||
wav = line["Audio:FILE"]
|
||||
idx = wav.split("/")[-1].split(".")[0]
|
||||
fin.writelines(idx + " " + wav + "\n")
|
||||
audio_in = wav_file_path
|
||||
|
||||
inference_pipeline = pipeline(
|
||||
task=Tasks.auto_speech_recognition,
|
||||
model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
|
||||
output_dir=output_dir,
|
||||
batch_size=batch_size,
|
||||
param_dict=param_dict)
|
||||
|
||||
rec_result = inference_pipeline(audio_in=audio_in)
|
||||
@ -228,6 +228,9 @@ def inference_launch(**kwargs):
|
||||
elif mode == "vad":
|
||||
from funasr.bin.vad_inference import inference_modelscope
|
||||
return inference_modelscope(**kwargs)
|
||||
elif mode == "mfcca":
|
||||
from funasr.bin.asr_inference_mfcca import inference_modelscope
|
||||
return inference_modelscope(**kwargs)
|
||||
else:
|
||||
logging.info("Unknown decoding mode: {}".format(mode))
|
||||
return None
|
||||
@ -253,6 +256,9 @@ def inference_launch_funasr(**kwargs):
|
||||
elif mode == "vad":
|
||||
from funasr.bin.vad_inference import inference
|
||||
return inference(**kwargs)
|
||||
elif mode == "mfcca":
|
||||
from funasr.bin.asr_inference_mfcca import inference_modelscope
|
||||
return inference_modelscope(**kwargs)
|
||||
else:
|
||||
logging.info("Unknown decoding mode: {}".format(mode))
|
||||
return None
|
||||
|
||||
774
funasr/bin/asr_inference_mfcca.py
Normal file
@ -0,0 +1,774 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typeguard import check_argument_types
|
||||
from typeguard import check_return_type
|
||||
|
||||
from funasr.fileio.datadir_writer import DatadirWriter
|
||||
from funasr.modules.beam_search.batch_beam_search import BatchBeamSearch
|
||||
from funasr.modules.beam_search.beam_search import BeamSearch
|
||||
from funasr.modules.beam_search.beam_search import Hypothesis
|
||||
from funasr.modules.scorers.ctc import CTCPrefixScorer
|
||||
from funasr.modules.scorers.length_bonus import LengthBonus
|
||||
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
|
||||
from funasr.modules.subsampling import TooShortUttError
|
||||
from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
|
||||
from funasr.tasks.lm import LMTask
|
||||
from funasr.text.build_tokenizer import build_tokenizer
|
||||
from funasr.text.token_id_converter import TokenIDConverter
|
||||
from funasr.torch_utils.device_funcs import to_device
|
||||
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
|
||||
from funasr.utils import config_argparse
|
||||
from funasr.utils.cli_utils import get_commandline_args
|
||||
from funasr.utils.types import str2bool
|
||||
from funasr.utils.types import str2triple_str
|
||||
from funasr.utils.types import str_or_none
|
||||
from funasr.utils import asr_utils, wav_utils, postprocess_utils
|
||||
import pdb
|
||||
|
||||
header_colors = '\033[95m'
|
||||
end_colors = '\033[0m'
|
||||
|
||||
global_asr_language: str = 'zh-cn'
|
||||
global_sample_rate: Union[int, Dict[Any, int]] = {
|
||||
'audio_fs': 16000,
|
||||
'model_fs': 16000
|
||||
}
|
||||
|
||||
class Speech2Text:
|
||||
"""Speech2Text class
|
||||
|
||||
Examples:
|
||||
>>> import soundfile
|
||||
>>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
|
||||
>>> audio, rate = soundfile.read("speech.wav")
|
||||
>>> speech2text(audio)
|
||||
[(text, token, token_int, hypothesis object), ...]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
asr_train_config: Union[Path, str] = None,
|
||||
asr_model_file: Union[Path, str] = None,
|
||||
cmvn_file: Union[Path, str] = None,
|
||||
lm_train_config: Union[Path, str] = None,
|
||||
lm_file: Union[Path, str] = None,
|
||||
token_type: str = None,
|
||||
bpemodel: str = None,
|
||||
device: str = "cpu",
|
||||
maxlenratio: float = 0.0,
|
||||
minlenratio: float = 0.0,
|
||||
batch_size: int = 1,
|
||||
dtype: str = "float32",
|
||||
beam_size: int = 20,
|
||||
ctc_weight: float = 0.5,
|
||||
lm_weight: float = 1.0,
|
||||
ngram_weight: float = 0.9,
|
||||
penalty: float = 0.0,
|
||||
nbest: int = 1,
|
||||
streaming: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
assert check_argument_types()
|
||||
|
||||
# 1. Build ASR model
|
||||
scorers = {}
|
||||
asr_model, asr_train_args = ASRTask.build_model_from_file(
|
||||
asr_train_config, asr_model_file, cmvn_file, device
|
||||
)
|
||||
|
||||
logging.info("asr_model: {}".format(asr_model))
|
||||
logging.info("asr_train_args: {}".format(asr_train_args))
|
||||
asr_model.to(dtype=getattr(torch, dtype)).eval()
|
||||
|
||||
decoder = asr_model.decoder
|
||||
|
||||
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
|
||||
token_list = asr_model.token_list
|
||||
scorers.update(
|
||||
decoder=decoder,
|
||||
ctc=ctc,
|
||||
length_bonus=LengthBonus(len(token_list)),
|
||||
)
|
||||
|
||||
# 2. Build Language model
|
||||
if lm_train_config is not None:
|
||||
lm, lm_train_args = LMTask.build_model_from_file(
|
||||
lm_train_config, lm_file, device
|
||||
)
|
||||
lm.to(device)
|
||||
scorers["lm"] = lm.lm
|
||||
# 3. Build ngram model
|
||||
# ngram is not supported now
|
||||
ngram = None
|
||||
scorers["ngram"] = ngram
|
||||
|
||||
# 4. Build BeamSearch object
|
||||
# transducer is not supported now
|
||||
beam_search_transducer = None
|
||||
|
||||
weights = dict(
|
||||
decoder=1.0 - ctc_weight,
|
||||
ctc=ctc_weight,
|
||||
lm=lm_weight,
|
||||
ngram=ngram_weight,
|
||||
length_bonus=penalty,
|
||||
)
|
||||
beam_search = BeamSearch(
|
||||
beam_size=beam_size,
|
||||
weights=weights,
|
||||
scorers=scorers,
|
||||
sos=asr_model.sos,
|
||||
eos=asr_model.eos,
|
||||
vocab_size=len(token_list),
|
||||
token_list=token_list,
|
||||
pre_beam_score_key=None if ctc_weight == 1.0 else "full",
|
||||
)
|
||||
#beam_search.__class__ = BatchBeamSearch
|
||||
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
|
||||
if token_type is None:
|
||||
token_type = asr_train_args.token_type
|
||||
if bpemodel is None:
|
||||
bpemodel = asr_train_args.bpemodel
|
||||
|
||||
if token_type is None:
|
||||
tokenizer = None
|
||||
elif token_type == "bpe":
|
||||
if bpemodel is not None:
|
||||
tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
|
||||
else:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = build_tokenizer(token_type=token_type)
|
||||
converter = TokenIDConverter(token_list=token_list)
|
||||
logging.info(f"Text tokenizer: {tokenizer}")
|
||||
|
||||
self.asr_model = asr_model
|
||||
self.asr_train_args = asr_train_args
|
||||
self.converter = converter
|
||||
self.tokenizer = tokenizer
|
||||
self.beam_search = beam_search
|
||||
self.beam_search_transducer = beam_search_transducer
|
||||
self.maxlenratio = maxlenratio
|
||||
self.minlenratio = minlenratio
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
self.nbest = nbest
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
|
||||
) -> List[
|
||||
Tuple[
|
||||
Optional[str],
|
||||
List[str],
|
||||
List[int],
|
||||
Union[Hypothesis],
|
||||
]
|
||||
]:
|
||||
"""Inference
|
||||
|
||||
Args:
|
||||
speech: Input speech data
|
||||
Returns:
|
||||
text, token, token_int, hyp
|
||||
|
||||
"""
|
||||
assert check_argument_types()
|
||||
# Input as audio signal
|
||||
if isinstance(speech, np.ndarray):
|
||||
speech = torch.tensor(speech)
|
||||
|
||||
|
||||
#speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
|
||||
speech = speech.to(getattr(torch, self.dtype))
|
||||
# lengths: (1,)
|
||||
lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
|
||||
batch = {"speech": speech, "speech_lengths": lengths}
|
||||
|
||||
# a. To device
|
||||
batch = to_device(batch, device=self.device)
|
||||
|
||||
# b. Forward Encoder
|
||||
enc, _ = self.asr_model.encode(**batch)
|
||||
|
||||
assert len(enc) == 1, len(enc)
|
||||
|
||||
# c. Pass the encoder result to the beam search
|
||||
nbest_hyps = self.beam_search(
|
||||
x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
|
||||
)
|
||||
|
||||
nbest_hyps = nbest_hyps[: self.nbest]
|
||||
|
||||
results = []
|
||||
for hyp in nbest_hyps:
|
||||
assert isinstance(hyp, (Hypothesis)), type(hyp)
|
||||
|
||||
# remove sos/eos and get results
|
||||
last_pos = -1
|
||||
if isinstance(hyp.yseq, list):
|
||||
token_int = hyp.yseq[1:last_pos]
|
||||
else:
|
||||
token_int = hyp.yseq[1:last_pos].tolist()
|
||||
|
||||
# remove blank symbol id, which is assumed to be 0
|
||||
token_int = list(filter(lambda x: x != 0, token_int))
|
||||
|
||||
# Change integer-ids to tokens
|
||||
token = self.converter.ids2tokens(token_int)
|
||||
|
||||
if self.tokenizer is not None:
|
||||
text = self.tokenizer.tokens2text(token)
|
||||
else:
|
||||
text = None
|
||||
results.append((text, token, token_int, hyp))
|
||||
|
||||
assert check_return_type(results)
|
||||
return results
|
||||
|
||||
|
||||
# def inference(
|
||||
# maxlenratio: float,
|
||||
# minlenratio: float,
|
||||
# batch_size: int,
|
||||
# beam_size: int,
|
||||
# ngpu: int,
|
||||
# ctc_weight: float,
|
||||
# lm_weight: float,
|
||||
# penalty: float,
|
||||
# log_level: Union[int, str],
|
||||
# data_path_and_name_and_type,
|
||||
# asr_train_config: Optional[str],
|
||||
# asr_model_file: Optional[str],
|
||||
# cmvn_file: Optional[str] = None,
|
||||
# lm_train_config: Optional[str] = None,
|
||||
# lm_file: Optional[str] = None,
|
||||
# token_type: Optional[str] = None,
|
||||
# key_file: Optional[str] = None,
|
||||
# word_lm_train_config: Optional[str] = None,
|
||||
# bpemodel: Optional[str] = None,
|
||||
# allow_variable_data_keys: bool = False,
|
||||
# streaming: bool = False,
|
||||
# output_dir: Optional[str] = None,
|
||||
# dtype: str = "float32",
|
||||
# seed: int = 0,
|
||||
# ngram_weight: float = 0.9,
|
||||
# nbest: int = 1,
|
||||
# num_workers: int = 1,
|
||||
# **kwargs,
|
||||
# ):
|
||||
# assert check_argument_types()
|
||||
# if batch_size > 1:
|
||||
# raise NotImplementedError("batch decoding is not implemented")
|
||||
# if word_lm_train_config is not None:
|
||||
# raise NotImplementedError("Word LM is not implemented")
|
||||
# if ngpu > 1:
|
||||
# raise NotImplementedError("only single GPU decoding is supported")
|
||||
#
|
||||
# logging.basicConfig(
|
||||
# level=log_level,
|
||||
# format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
||||
# )
|
||||
#
|
||||
# if ngpu >= 1 and torch.cuda.is_available():
|
||||
# device = "cuda"
|
||||
# else:
|
||||
# device = "cpu"
|
||||
#
|
||||
# # 1. Set random-seed
|
||||
# set_all_random_seed(seed)
|
||||
#
|
||||
# # 2. Build speech2text
|
||||
# speech2text_kwargs = dict(
|
||||
# asr_train_config=asr_train_config,
|
||||
# asr_model_file=asr_model_file,
|
||||
# cmvn_file=cmvn_file,
|
||||
# lm_train_config=lm_train_config,
|
||||
# lm_file=lm_file,
|
||||
# token_type=token_type,
|
||||
# bpemodel=bpemodel,
|
||||
# device=device,
|
||||
# maxlenratio=maxlenratio,
|
||||
# minlenratio=minlenratio,
|
||||
# dtype=dtype,
|
||||
# beam_size=beam_size,
|
||||
# ctc_weight=ctc_weight,
|
||||
# lm_weight=lm_weight,
|
||||
# ngram_weight=ngram_weight,
|
||||
# penalty=penalty,
|
||||
# nbest=nbest,
|
||||
# streaming=streaming,
|
||||
# )
|
||||
# logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
|
||||
# speech2text = Speech2Text(**speech2text_kwargs)
|
||||
#
|
||||
# # 3. Build data-iterator
|
||||
# loader = ASRTask.build_streaming_iterator(
|
||||
# data_path_and_name_and_type,
|
||||
# dtype=dtype,
|
||||
# batch_size=batch_size,
|
||||
# key_file=key_file,
|
||||
# num_workers=num_workers,
|
||||
# preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
|
||||
# collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
|
||||
# allow_variable_data_keys=allow_variable_data_keys,
|
||||
# inference=True,
|
||||
# )
|
||||
#
|
||||
# finish_count = 0
|
||||
# file_count = 1
|
||||
# # 7 .Start for-loop
|
||||
# # FIXME(kamo): The output format should be discussed about
|
||||
# asr_result_list = []
|
||||
# if output_dir is not None:
|
||||
# writer = DatadirWriter(output_dir)
|
||||
# else:
|
||||
# writer = None
|
||||
#
|
||||
# for keys, batch in loader:
|
||||
# assert isinstance(batch, dict), type(batch)
|
||||
# assert all(isinstance(s, str) for s in keys), keys
|
||||
# _bs = len(next(iter(batch.values())))
|
||||
# assert len(keys) == _bs, f"{len(keys)} != {_bs}"
|
||||
# #batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
|
||||
#
|
||||
# # N-best list of (text, token, token_int, hyp_object)
|
||||
# try:
|
||||
# results = speech2text(**batch)
|
||||
# except TooShortUttError as e:
|
||||
# logging.warning(f"Utterance {keys} {e}")
|
||||
# hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
|
||||
# results = [[" ", ["<space>"], [2], hyp]] * nbest
|
||||
#
|
||||
# # Only supporting batch_size==1
|
||||
# key = keys[0]
|
||||
# for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
|
||||
# # Create a directory: outdir/{n}best_recog
|
||||
# if writer is not None:
|
||||
# ibest_writer = writer[f"{n}best_recog"]
|
||||
#
|
||||
# # Write the result to each file
|
||||
# ibest_writer["token"][key] = " ".join(token)
|
||||
# ibest_writer["token_int"][key] = " ".join(map(str, token_int))
|
||||
# ibest_writer["score"][key] = str(hyp.score)
|
||||
#
|
||||
# if text is not None:
|
||||
# text_postprocessed = postprocess_utils.sentence_postprocess(token)
|
||||
# item = {'key': key, 'value': text_postprocessed}
|
||||
# asr_result_list.append(item)
|
||||
# finish_count += 1
|
||||
# asr_utils.print_progress(finish_count / file_count)
|
||||
# if writer is not None:
|
||||
# ibest_writer["text"][key] = text
|
||||
# return asr_result_list
|
||||
|
||||
def inference(
|
||||
maxlenratio: float,
|
||||
minlenratio: float,
|
||||
batch_size: int,
|
||||
beam_size: int,
|
||||
ngpu: int,
|
||||
ctc_weight: float,
|
||||
lm_weight: float,
|
||||
penalty: float,
|
||||
log_level: Union[int, str],
|
||||
data_path_and_name_and_type,
|
||||
asr_train_config: Optional[str],
|
||||
asr_model_file: Optional[str],
|
||||
cmvn_file: Optional[str] = None,
|
||||
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
|
||||
lm_train_config: Optional[str] = None,
|
||||
lm_file: Optional[str] = None,
|
||||
token_type: Optional[str] = None,
|
||||
key_file: Optional[str] = None,
|
||||
word_lm_train_config: Optional[str] = None,
|
||||
bpemodel: Optional[str] = None,
|
||||
allow_variable_data_keys: bool = False,
|
||||
streaming: bool = False,
|
||||
output_dir: Optional[str] = None,
|
||||
dtype: str = "float32",
|
||||
seed: int = 0,
|
||||
ngram_weight: float = 0.9,
|
||||
nbest: int = 1,
|
||||
num_workers: int = 1,
|
||||
**kwargs,
|
||||
):
|
||||
inference_pipeline = inference_modelscope(
|
||||
maxlenratio=maxlenratio,
|
||||
minlenratio=minlenratio,
|
||||
batch_size=batch_size,
|
||||
beam_size=beam_size,
|
||||
ngpu=ngpu,
|
||||
ctc_weight=ctc_weight,
|
||||
lm_weight=lm_weight,
|
||||
penalty=penalty,
|
||||
log_level=log_level,
|
||||
asr_train_config=asr_train_config,
|
||||
asr_model_file=asr_model_file,
|
||||
cmvn_file=cmvn_file,
|
||||
raw_inputs=raw_inputs,
|
||||
lm_train_config=lm_train_config,
|
||||
lm_file=lm_file,
|
||||
token_type=token_type,
|
||||
key_file=key_file,
|
||||
word_lm_train_config=word_lm_train_config,
|
||||
bpemodel=bpemodel,
|
||||
allow_variable_data_keys=allow_variable_data_keys,
|
||||
streaming=streaming,
|
||||
output_dir=output_dir,
|
||||
dtype=dtype,
|
||||
seed=seed,
|
||||
ngram_weight=ngram_weight,
|
||||
nbest=nbest,
|
||||
num_workers=num_workers,
|
||||
**kwargs,
|
||||
)
|
||||
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
|
||||
|
||||
def inference_modelscope(
|
||||
maxlenratio: float,
|
||||
minlenratio: float,
|
||||
batch_size: int,
|
||||
beam_size: int,
|
||||
ngpu: int,
|
||||
ctc_weight: float,
|
||||
lm_weight: float,
|
||||
penalty: float,
|
||||
log_level: Union[int, str],
|
||||
# data_path_and_name_and_type,
|
||||
asr_train_config: Optional[str],
|
||||
asr_model_file: Optional[str],
|
||||
cmvn_file: Optional[str] = None,
|
||||
lm_train_config: Optional[str] = None,
|
||||
lm_file: Optional[str] = None,
|
||||
token_type: Optional[str] = None,
|
||||
key_file: Optional[str] = None,
|
||||
word_lm_train_config: Optional[str] = None,
|
||||
bpemodel: Optional[str] = None,
|
||||
allow_variable_data_keys: bool = False,
|
||||
streaming: bool = False,
|
||||
output_dir: Optional[str] = None,
|
||||
dtype: str = "float32",
|
||||
seed: int = 0,
|
||||
ngram_weight: float = 0.9,
|
||||
nbest: int = 1,
|
||||
num_workers: int = 1,
|
||||
param_dict: dict = None,
|
||||
**kwargs,
|
||||
):
|
||||
assert check_argument_types()
|
||||
if batch_size > 1:
|
||||
raise NotImplementedError("batch decoding is not implemented")
|
||||
if word_lm_train_config is not None:
|
||||
raise NotImplementedError("Word LM is not implemented")
|
||||
if ngpu > 1:
|
||||
raise NotImplementedError("only single GPU decoding is supported")
|
||||
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
||||
)
|
||||
|
||||
if ngpu >= 1 and torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
# 1. Set random-seed
|
||||
set_all_random_seed(seed)
|
||||
|
||||
# 2. Build speech2text
|
||||
speech2text_kwargs = dict(
|
||||
asr_train_config=asr_train_config,
|
||||
asr_model_file=asr_model_file,
|
||||
cmvn_file=cmvn_file,
|
||||
lm_train_config=lm_train_config,
|
||||
lm_file=lm_file,
|
||||
token_type=token_type,
|
||||
bpemodel=bpemodel,
|
||||
device=device,
|
||||
maxlenratio=maxlenratio,
|
||||
minlenratio=minlenratio,
|
||||
dtype=dtype,
|
||||
beam_size=beam_size,
|
||||
ctc_weight=ctc_weight,
|
||||
lm_weight=lm_weight,
|
||||
ngram_weight=ngram_weight,
|
||||
penalty=penalty,
|
||||
nbest=nbest,
|
||||
streaming=streaming,
|
||||
)
|
||||
logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
|
||||
speech2text = Speech2Text(**speech2text_kwargs)
|
||||
|
||||
def _forward(data_path_and_name_and_type,
|
||||
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
|
||||
output_dir_v2: Optional[str] = None,
|
||||
fs: dict = None,
|
||||
param_dict: dict = None,
|
||||
):
|
||||
# 3. Build data-iterator
|
||||
if data_path_and_name_and_type is None and raw_inputs is not None:
|
||||
if isinstance(raw_inputs, torch.Tensor):
|
||||
raw_inputs = raw_inputs.numpy()
|
||||
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
|
||||
loader = ASRTask.build_streaming_iterator(
|
||||
data_path_and_name_and_type,
|
||||
dtype=dtype,
|
||||
batch_size=batch_size,
|
||||
key_file=key_file,
|
||||
num_workers=num_workers,
|
||||
preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
|
||||
collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
|
||||
allow_variable_data_keys=allow_variable_data_keys,
|
||||
inference=True,
|
||||
)
|
||||
|
||||
finish_count = 0
|
||||
file_count = 1
|
||||
# 7. Start for-loop
|
||||
# FIXME(kamo): The output format should be discussed
|
||||
asr_result_list = []
|
||||
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
|
||||
if output_path is not None:
|
||||
writer = DatadirWriter(output_path)
|
||||
else:
|
||||
writer = None
|
||||
|
||||
for keys, batch in loader:
|
||||
assert isinstance(batch, dict), type(batch)
|
||||
assert all(isinstance(s, str) for s in keys), keys
|
||||
_bs = len(next(iter(batch.values())))
|
||||
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
|
||||
# batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
|
||||
|
||||
# N-best list of (text, token, token_int, hyp_object)
|
||||
try:
|
||||
results = speech2text(**batch)
|
||||
except TooShortUttError as e:
|
||||
logging.warning(f"Utterance {keys} {e}")
|
||||
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
|
||||
results = [[" ", ["<space>"], [2], hyp]] * nbest
|
||||
|
||||
# Only supporting batch_size==1
|
||||
key = keys[0]
|
||||
for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
|
||||
# Create a directory: outdir/{n}best_recog
|
||||
if writer is not None:
|
||||
ibest_writer = writer[f"{n}best_recog"]
|
||||
|
||||
# Write the result to each file
|
||||
ibest_writer["token"][key] = " ".join(token)
|
||||
# ibest_writer["token_int"][key] = " ".join(map(str, token_int))
|
||||
ibest_writer["score"][key] = str(hyp.score)
|
||||
|
||||
if text is not None:
|
||||
text_postprocessed = postprocess_utils.sentence_postprocess(token)
|
||||
item = {'key': key, 'value': text_postprocessed}
|
||||
asr_result_list.append(item)
|
||||
finish_count += 1
|
||||
asr_utils.print_progress(finish_count / file_count)
|
||||
if writer is not None:
|
||||
ibest_writer["text"][key] = text
|
||||
return asr_result_list
|
||||
|
||||
return _forward
|
||||
|
||||
def set_parameters(language: str = None,
|
||||
sample_rate: Union[int, Dict[Any, int]] = None):
|
||||
if language is not None:
|
||||
global global_asr_language
|
||||
global_asr_language = language
|
||||
if sample_rate is not None:
|
||||
global global_sample_rate
|
||||
global_sample_rate = sample_rate
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = config_argparse.ArgumentParser(
|
||||
description="ASR Decoding",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
|
||||
# Note(kamo): Use '_' instead of '-' as separator.
|
||||
# '-' is confusing if written in yaml.
|
||||
parser.add_argument(
|
||||
"--log_level",
|
||||
type=lambda x: x.upper(),
|
||||
default="INFO",
|
||||
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
|
||||
help="The verbose level of logging",
|
||||
)
|
||||
|
||||
parser.add_argument("--output_dir", type=str, required=True)
|
||||
parser.add_argument(
|
||||
"--ngpu",
|
||||
type=int,
|
||||
default=0,
|
||||
help="The number of gpus. 0 indicates CPU mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpuid_list",
|
||||
type=str,
|
||||
default="",
|
||||
help="The visible gpus",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
default="float32",
|
||||
choices=["float16", "float32", "float64"],
|
||||
help="Data type",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_workers",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The number of workers used for DataLoader",
|
||||
)
|
||||
|
||||
group = parser.add_argument_group("Input data related")
|
||||
group.add_argument(
|
||||
"--data_path_and_name_and_type",
|
||||
type=str2triple_str,
|
||||
required=False,
|
||||
action="append",
|
||||
)
|
||||
group.add_argument("--raw_inputs", type=list, default=None)
|
||||
# example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
|
||||
group.add_argument("--key_file", type=str_or_none)
|
||||
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
|
||||
|
||||
group = parser.add_argument_group("The model configuration related")
|
||||
group.add_argument(
|
||||
"--asr_train_config",
|
||||
type=str,
|
||||
help="ASR training configuration",
|
||||
)
|
||||
group.add_argument(
|
||||
"--asr_model_file",
|
||||
type=str,
|
||||
help="ASR model parameter file",
|
||||
)
|
||||
group.add_argument(
|
||||
"--cmvn_file",
|
||||
type=str,
|
||||
help="Global cmvn file",
|
||||
)
|
||||
group.add_argument(
|
||||
"--lm_train_config",
|
||||
type=str,
|
||||
help="LM training configuration",
|
||||
)
|
||||
group.add_argument(
|
||||
"--lm_file",
|
||||
type=str,
|
||||
help="LM parameter file",
|
||||
)
|
||||
group.add_argument(
|
||||
"--word_lm_train_config",
|
||||
type=str,
|
||||
help="Word LM training configuration",
|
||||
)
|
||||
group.add_argument(
|
||||
"--word_lm_file",
|
||||
type=str,
|
||||
help="Word LM parameter file",
|
||||
)
|
||||
group.add_argument(
|
||||
"--ngram_file",
|
||||
type=str,
|
||||
help="N-gram parameter file",
|
||||
)
|
||||
group.add_argument(
|
||||
"--model_tag",
|
||||
type=str,
|
||||
help="Pretrained model tag. If this option is specified, *_train_config and "
|
||||
"*_file will be overwritten",
|
||||
)
|
||||
|
||||
group = parser.add_argument_group("Beam-search related")
|
||||
group.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The batch size for inference",
|
||||
)
|
||||
group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
|
||||
group.add_argument("--beam_size", type=int, default=20, help="Beam size")
|
||||
group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
|
||||
group.add_argument(
|
||||
"--maxlenratio",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Input length ratio to obtain max output length. "
"If maxlenratio=0.0 (default), it uses an end-detect "
"function "
"to automatically find maximum hypothesis lengths. "
"If maxlenratio<0.0, its absolute value is interpreted "
"as a constant max output length",
|
||||
)
|
||||
group.add_argument(
|
||||
"--minlenratio",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Input length ratio to obtain min output length",
|
||||
)
|
||||
group.add_argument(
|
||||
"--ctc_weight",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="CTC weight in joint decoding",
|
||||
)
|
||||
group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
|
||||
group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
|
||||
group.add_argument("--streaming", type=str2bool, default=False)
|
||||
|
||||
group = parser.add_argument_group("Text converter related")
|
||||
group.add_argument(
|
||||
"--token_type",
|
||||
type=str_or_none,
|
||||
default=None,
|
||||
choices=["char", "bpe", None],
|
||||
help="The token type for ASR model. "
"If not given, it is taken from the training args",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bpemodel",
|
||||
type=str_or_none,
|
||||
default=None,
|
||||
help="The model path of sentencepiece. "
"If not given, it is taken from the training args",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main(cmd=None):
|
||||
print(get_commandline_args(), file=sys.stderr)
|
||||
parser = get_parser()
|
||||
args = parser.parse_args(cmd)
|
||||
kwargs = vars(args)
|
||||
kwargs.pop("config", None)
|
||||
inference(**kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
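The split between `inference` and `inference_modelscope` above looks like a build-once, call-many pattern: the outer function constructs the model and returns a `_forward` closure that can then be invoked repeatedly. Below is a minimal, framework-free sketch of that pattern. `FakeRecognizer` and `build_pipeline` are hypothetical stand-ins introduced only for illustration; they are not part of FunASR.

```python
from typing import Callable, Dict, List


class FakeRecognizer:
    """Hypothetical stand-in for an expensive-to-build model object."""

    instances = 0  # counts how many times the "model" was constructed

    def __init__(self, beam_size: int) -> None:
        FakeRecognizer.instances += 1
        self.beam_size = beam_size

    def __call__(self, utterance: str) -> str:
        # Pretend to "recognize" by upper-casing the input text.
        return utterance.upper()


def build_pipeline(beam_size: int) -> Callable[[Dict[str, str]], List[Dict[str, str]]]:
    # Built exactly once and captured by the closure below.
    recognizer = FakeRecognizer(beam_size)

    def _forward(inputs: Dict[str, str]) -> List[Dict[str, str]]:
        # Same result shape as the diff: a list of {'key': ..., 'value': ...} items.
        return [{"key": key, "value": recognizer(text)} for key, text in inputs.items()]

    return _forward


pipeline = build_pipeline(beam_size=20)
first = pipeline({"utt1": "hello"})
second = pipeline({"utt2": "world"})
```

Both calls reuse the same recognizer instance, which is the main reason to return a closure rather than rebuilding the model on every request.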
@ -27,6 +27,8 @@ def parse_args(mode):
         from funasr.tasks.asr import ASRTaskParaformer as ASRTask
     elif mode == "uniasr":
         from funasr.tasks.asr import ASRTaskUniASR as ASRTask
+    elif mode == "mfcca":
+        from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
     else:
         raise ValueError("Unknown mode: {}".format(mode))
     parser = ASRTask.get_parser()
@ -224,7 +224,7 @@ class IterableESPnetDataset(IterableDataset):
 name = self.path_name_type_list[i][1]
 _type = self.path_name_type_list[i][2]
 if _type == "sound":
-    audio_type = os.path.basename(value).split(".")[1].lower()
+    audio_type = os.path.basename(value).split(".")[-1].lower()
     if audio_type not in SUPPORT_AUDIO_TYPE_SETS:
         raise NotImplementedError(
             f'Not supported audio type: {audio_type}')
@ -326,7 +326,7 @@ class IterableESPnetDataset(IterableDataset):
 # 2.a. Load data streamingly
 for value, (path, name, _type) in zip(values, self.path_name_type_list):
     if _type == "sound":
-        audio_type = os.path.basename(value).split(".")[1].lower()
+        audio_type = os.path.basename(value).split(".")[-1].lower()
         if audio_type not in SUPPORT_AUDIO_TYPE_SETS:
             raise NotImplementedError(
                 f'Not supported audio type: {audio_type}')
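The two `IterableESPnetDataset` hunks above change the index used after `split(".")` from `1` to `-1`. The sketch below shows why that matters for a file name containing more than one dot. The path and the helper name `audio_extension` are made up for illustration.

```python
import os


def audio_extension(path: str) -> str:
    """Return the lower-cased text after the last dot in the file name."""
    return os.path.basename(path).split(".")[-1].lower()


path = "/data/session.01/utt.001.WAV"      # hypothetical path with extra dots
parts = os.path.basename(path).split(".")  # ['utt', '001', 'WAV']
second_field = parts[1].lower()            # what index 1 selects
last_field = audio_extension(path)         # what index -1 selects
```

With index `1` the middle field `001` would be compared against the supported audio types, while index `-1` yields `wav`. A file name without any dot still returns the whole name under either index, so the membership check that follows in the diff remains necessary.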
@ -11,33 +11,23 @@ The installation is the same as [funasr](../../README.md)
|
||||
|
||||
## Export onnx format model
|
||||
Export model from modelscope
|
||||
```python
|
||||
from funasr.export.export_model import ASRModelExportParaformer
|
||||
|
||||
output_dir = "../export" # onnx/torchscripts model save path
|
||||
export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=True)
|
||||
export_model.export('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
|
||||
```shell
|
||||
python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true
|
||||
```
|
||||
|
||||
|
||||
Export model from local path
|
||||
```python
|
||||
export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
|
||||
```shell
|
||||
python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true
|
||||
```
|
||||
|
||||
## Export torchscripts format model
|
||||
Export model from modelscope
|
||||
```python
|
||||
from funasr.export.export_model import ASRModelExportParaformer
|
||||
|
||||
output_dir = "../export" # onnx/torchscripts model save path
|
||||
export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=False)
|
||||
export_model.export('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
|
||||
```shell
|
||||
python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" false
|
||||
```
|
||||
|
||||
|
||||
Export model from local path
|
||||
```python
|
||||
|
||||
export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
|
||||
```shell
|
||||
python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" false
|
||||
```
|
||||
|
||||
|
||||
@ -24,7 +24,7 @@ class ASRModelExportParaformer:
             feats_dim=560,
             onnx=False,
         )
-        logging.info("output dir: {}".format(self.cache_dir))
+        print("output dir: {}".format(self.cache_dir))
         self.onnx = onnx


@ -50,7 +50,7 @@ class ASRModelExportParaformer:
         else:
             self._export_torchscripts(model, verbose, export_dir)

-        logging.info("output dir: {}".format(export_dir))
+        print("output dir: {}".format(export_dir))


     def _export_torchscripts(self, model, verbose, path, enc_size=None):
@ -117,7 +117,15 @@ class ASRModelExportParaformer:
         )

 if __name__ == '__main__':
-    output_dir = "../export"
-    export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=True)
-    export_model.export('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
+    import sys
+
+    model_path = sys.argv[1]
+    output_dir = sys.argv[2]
+    onnx = sys.argv[3]
+    onnx = onnx.lower()
+    onnx = onnx == 'true'
+    # model_path = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'
+    # output_dir = "../export"
+    export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=onnx)
+    export_model.export(model_path)
     # export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
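The new `__main__` block above reads three positional values from `sys.argv` and treats any third value other than `true` (case-insensitive) as `False`. As a hedged alternative, the sketch below uses `argparse` for the same three positionals and rejects unrecognized boolean spellings instead of silently mapping them to `False`. `parse_bool` and `parse_export_args` are hypothetical names for this illustration, not functions from the commit.

```python
import argparse
from typing import List, Optional


def parse_bool(value: str) -> bool:
    """Strict boolean parser: unknown spellings raise instead of becoming False."""
    lowered = value.lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def parse_export_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Export sketch (hypothetical CLI)")
    parser.add_argument("model_path", help="model name or local directory")
    parser.add_argument("output_dir", help="where the exported model is written")
    parser.add_argument("onnx", type=parse_bool, help="true for ONNX, false for TorchScript")
    return parser.parse_args(argv)


args = parse_export_args(["damo/some-model", "./export", "True"])
```

The positional order matches the `python -m funasr.export.export_model <model> <output_dir> <true|false>` usage shown in the README hunk, so existing command lines would keep working under this sketch.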
@ -59,7 +59,7 @@ class Paraformer(nn.Module):
         enc, enc_len = self.encoder(**batch)
         mask = self.make_pad_mask(enc_len)[:, None, :]
         pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
-        pre_token_length = pre_token_length.round().long()
+        pre_token_length = pre_token_length.round().type(torch.int32)

         decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
         decoder_out = torch.log_softmax(decoder_out, dim=-1)
@ -116,53 +116,3 @@ def cif(hidden, alphas, threshold: float):
|
||||
pad_l = torch.zeros([int(max_label_len - l.size(0)), int(hidden_size)], device=hidden.device)
|
||||
list_ls.append(torch.cat([l, pad_l], 0))
|
||||
return torch.stack(list_ls, 0), fires
|
||||
|
||||
|
||||
def CifPredictorV2_test():
|
||||
x = torch.rand([2, 21, 2])
|
||||
x_len = torch.IntTensor([6, 21])
|
||||
|
||||
mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
|
||||
x = x * mask[:, :, None]
|
||||
|
||||
predictor_scripts = torch.jit.script(CifPredictorV2(2, 1, 1))
|
||||
# cif_output, cif_length, alphas, cif_peak = predictor_scripts(x, mask=mask[:, None, :])
|
||||
predictor_scripts.save('test.pt')
|
||||
loaded = torch.jit.load('test.pt')
|
||||
cif_output, cif_length, alphas, cif_peak = loaded(x, mask=mask[:, None, :])
|
||||
# print(cif_output)
|
||||
print(predictor_scripts.code)
|
||||
# predictor = CifPredictorV2(2, 1, 1)
|
||||
# cif_output, cif_length, alphas, cif_peak = predictor(x, mask=mask[:, None, :])
|
||||
print(cif_output)
|
||||
|
||||
|
||||
def CifPredictorV2_export_test():
|
||||
x = torch.rand([2, 21, 2])
|
||||
x_len = torch.IntTensor([6, 21])
|
||||
|
||||
mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
|
||||
x = x * mask[:, :, None]
|
||||
|
||||
# predictor_scripts = torch.jit.script(CifPredictorV2(2, 1, 1))
|
||||
# cif_output, cif_length, alphas, cif_peak = predictor_scripts(x, mask=mask[:, None, :])
|
||||
predictor = CifPredictorV2(2, 1, 1)
|
||||
predictor_trace = torch.jit.trace(predictor, (x, mask[:, None, :]))
|
||||
predictor_trace.save('test_trace.pt')
|
||||
loaded = torch.jit.load('test_trace.pt')
|
||||
|
||||
x = torch.rand([3, 30, 2])
|
||||
x_len = torch.IntTensor([6, 20, 30])
|
||||
mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
|
||||
x = x * mask[:, :, None]
|
||||
cif_output, cif_length, alphas, cif_peak = loaded(x, mask=mask[:, None, :])
|
||||
print(cif_output)
|
||||
# print(predictor_trace.code)
|
||||
# predictor = CifPredictorV2(2, 1, 1)
|
||||
# cif_output, cif_length, alphas, cif_peak = predictor(x, mask=mask[:, None, :])
|
||||
# print(cif_output)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# CifPredictorV2_test()
|
||||
CifPredictorV2_export_test()
|
||||
322
funasr/models/e2e_asr_mfcca.py
Normal file
@ -0,0 +1,322 @@
|
||||
from contextlib import contextmanager
|
||||
from distutils.version import LooseVersion
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
import logging
|
||||
import torch
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from funasr.modules.e2e_asr_common import ErrorCalculator
|
||||
from funasr.modules.nets_utils import th_accuracy
|
||||
from funasr.modules.add_sos_eos import add_sos_eos
|
||||
from funasr.losses.label_smoothing_loss import (
|
||||
LabelSmoothingLoss, # noqa: H301
|
||||
)
|
||||
from funasr.models.ctc import CTC
|
||||
from funasr.models.decoder.abs_decoder import AbsDecoder
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
|
||||
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
||||
from funasr.layers.abs_normalize import AbsNormalize
|
||||
from funasr.torch_utils.device_funcs import force_gatherable
|
||||
from funasr.train.abs_espnet_model import AbsESPnetModel
|
||||
|
||||
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
||||
from torch.cuda.amp import autocast
|
||||
else:
|
||||
# Nothing to do if torch<1.6.0
|
||||
@contextmanager
|
||||
def autocast(enabled=True):
|
||||
yield
|
||||
import pdb
|
||||
import random
|
||||
import math
|
||||
class MFCCA(AbsESPnetModel):
|
||||
"""CTC-attention hybrid Encoder-Decoder model"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
token_list: Union[Tuple[str, ...], List[str]],
|
||||
frontend: Optional[AbsFrontend],
|
||||
specaug: Optional[AbsSpecAug],
|
||||
normalize: Optional[AbsNormalize],
|
||||
preencoder: Optional[AbsPreEncoder],
|
||||
encoder: AbsEncoder,
|
||||
decoder: AbsDecoder,
|
||||
ctc: CTC,
|
||||
rnnt_decoder: None,
|
||||
ctc_weight: float = 0.5,
|
||||
ignore_id: int = -1,
|
||||
lsm_weight: float = 0.0,
|
||||
mask_ratio: float = 0.0,
|
||||
length_normalized_loss: bool = False,
|
||||
report_cer: bool = True,
|
||||
report_wer: bool = True,
|
||||
sym_space: str = "<space>",
|
||||
sym_blank: str = "<blank>",
|
||||
):
|
||||
assert check_argument_types()
|
||||
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
|
||||
assert rnnt_decoder is None, "Not implemented"
|
||||
|
||||
super().__init__()
|
||||
# note that eos is the same as sos (equivalent ID)
|
||||
self.sos = vocab_size - 1
|
||||
self.eos = vocab_size - 1
|
||||
self.vocab_size = vocab_size
|
||||
self.ignore_id = ignore_id
|
||||
self.ctc_weight = ctc_weight
|
||||
self.token_list = token_list.copy()
|
||||
|
||||
self.mask_ratio = mask_ratio
|
||||
|
||||
|
||||
self.frontend = frontend
|
||||
self.specaug = specaug
|
||||
self.normalize = normalize
|
||||
self.preencoder = preencoder
|
||||
self.encoder = encoder
|
||||
# we set self.decoder = None in the CTC mode since
|
||||
# self.decoder parameters were never used and PyTorch complained
|
||||
# and threw an Exception in the multi-GPU experiment.
|
||||
# thanks Jeff Farris for pointing out the issue.
|
||||
if ctc_weight == 1.0:
|
||||
self.decoder = None
|
||||
else:
|
||||
self.decoder = decoder
|
||||
if ctc_weight == 0.0:
|
||||
self.ctc = None
|
||||
else:
|
||||
self.ctc = ctc
|
||||
self.rnnt_decoder = rnnt_decoder
|
||||
self.criterion_att = LabelSmoothingLoss(
|
||||
size=vocab_size,
|
||||
padding_idx=ignore_id,
|
||||
smoothing=lsm_weight,
|
||||
normalize_length=length_normalized_loss,
|
||||
)
|
||||
|
||||
if report_cer or report_wer:
|
||||
self.error_calculator = ErrorCalculator(
|
||||
token_list, sym_space, sym_blank, report_cer, report_wer
|
||||
)
|
||||
else:
|
||||
self.error_calculator = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Frontend + Encoder + Decoder + Calc loss
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
text: (Batch, Length)
|
||||
text_lengths: (Batch,)
|
||||
"""
|
||||
assert text_lengths.dim() == 1, text_lengths.shape
|
||||
# Check that batch_size is unified
|
||||
assert (
|
||||
speech.shape[0]
|
||||
== speech_lengths.shape[0]
|
||||
== text.shape[0]
|
||||
== text_lengths.shape[0]
|
||||
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
|
||||
#pdb.set_trace()
|
||||
if(speech.dim()==3 and speech.size(2)==8 and self.mask_ratio !=0):
|
||||
rate_num = random.random()
|
||||
#rate_num = 0.1
|
||||
if(rate_num<=self.mask_ratio):
|
||||
retain_channel = math.ceil(random.random() *8)
|
||||
if(retain_channel>1):
|
||||
speech = speech[:,:,torch.randperm(8)[0:retain_channel].sort().values]
|
||||
else:
|
||||
speech = speech[:,:,torch.randperm(8)[0]]
|
||||
#pdb.set_trace()
|
||||
batch_size = speech.shape[0]
|
||||
# for data-parallel
|
||||
text = text[:, : text_lengths.max()]
|
||||
|
||||
# 1. Encoder
|
||||
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
|
||||
|
||||
# 2a. Attention-decoder branch
|
||||
if self.ctc_weight == 1.0:
|
||||
loss_att, acc_att, cer_att, wer_att = None, None, None, None
|
||||
else:
|
||||
loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
|
||||
encoder_out, encoder_out_lens, text, text_lengths
|
||||
)
|
||||
|
||||
# 2b. CTC branch
|
||||
if self.ctc_weight == 0.0:
|
||||
loss_ctc, cer_ctc = None, None
|
||||
else:
|
||||
loss_ctc, cer_ctc = self._calc_ctc_loss(
|
||||
encoder_out, encoder_out_lens, text, text_lengths
|
||||
)
|
||||
|
||||
# 2c. RNN-T branch
|
||||
if self.rnnt_decoder is not None:
|
||||
_ = self._calc_rnnt_loss(encoder_out, encoder_out_lens, text, text_lengths)
|
||||
|
||||
if self.ctc_weight == 0.0:
|
||||
loss = loss_att
|
||||
elif self.ctc_weight == 1.0:
|
||||
loss = loss_ctc
|
||||
else:
|
||||
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
|
||||
|
||||
stats = dict(
|
||||
loss=loss.detach(),
|
||||
loss_att=loss_att.detach() if loss_att is not None else None,
|
||||
loss_ctc=loss_ctc.detach() if loss_ctc is not None else None,
|
||||
acc=acc_att,
|
||||
cer=cer_att,
|
||||
wer=wer_att,
|
||||
cer_ctc=cer_ctc,
|
||||
)
|
||||
|
||||
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
||||
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
||||
return loss, stats, weight
|
||||
|
||||
def collect_feats(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
|
||||
return {"feats": feats, "feats_lengths": feats_lengths}
|
||||
|
||||
def encode(
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Frontend + Encoder. Note that this method is used by asr_inference.py
|
||||
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
"""
|
||||
with autocast(False):
|
||||
# 1. Extract feats
|
||||
feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
|
||||
# 2. Data augmentation
|
||||
if self.specaug is not None and self.training:
|
||||
feats, feats_lengths = self.specaug(feats, feats_lengths)
|
||||
|
||||
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
|
||||
if self.normalize is not None:
|
||||
feats, feats_lengths = self.normalize(feats, feats_lengths)
|
||||
|
||||
# Pre-encoder, e.g. used for raw input data
|
||||
if self.preencoder is not None:
|
||||
feats, feats_lengths = self.preencoder(feats, feats_lengths)
|
||||
#pdb.set_trace()
|
||||
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, channel_size)
|
||||
|
||||
assert encoder_out.size(0) == speech.size(0), (
|
||||
encoder_out.size(),
|
||||
speech.size(0),
|
||||
)
|
||||
if(encoder_out.dim()==4):
|
||||
assert encoder_out.size(2) <= encoder_out_lens.max(), (
|
||||
encoder_out.size(),
|
||||
encoder_out_lens.max(),
|
||||
)
|
||||
else:
|
||||
assert encoder_out.size(1) <= encoder_out_lens.max(), (
|
||||
encoder_out.size(),
|
||||
encoder_out_lens.max(),
|
||||
)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
def _extract_feats(
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
||||
assert speech_lengths.dim() == 1, speech_lengths.shape
|
||||
# for data-parallel
|
||||
speech = speech[:, : speech_lengths.max()]
|
||||
if self.frontend is not None:
|
||||
# Frontend
|
||||
# e.g. STFT and Feature extract
|
||||
# data_loader may send time-domain signal in this case
|
||||
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
|
||||
feats, feats_lengths, channel_size = self.frontend(speech, speech_lengths)
|
||||
else:
|
||||
# No frontend and no feature extract
|
||||
feats, feats_lengths = speech, speech_lengths
|
||||
channel_size = 1
|
||||
return feats, feats_lengths, channel_size
|
||||
|
||||
def _calc_att_loss(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
ys_pad: torch.Tensor,
|
||||
ys_pad_lens: torch.Tensor,
|
||||
):
|
||||
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
|
||||
ys_in_lens = ys_pad_lens + 1
|
||||
|
||||
# 1. Forward decoder
|
||||
decoder_out, _ = self.decoder(
|
||||
encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
|
||||
)
|
||||
|
||||
# 2. Compute attention loss
|
||||
loss_att = self.criterion_att(decoder_out, ys_out_pad)
|
||||
acc_att = th_accuracy(
|
||||
decoder_out.view(-1, self.vocab_size),
|
||||
ys_out_pad,
|
||||
ignore_label=self.ignore_id,
|
||||
)
|
||||
|
||||
# Compute cer/wer using attention-decoder
|
||||
if self.training or self.error_calculator is None:
|
||||
cer_att, wer_att = None, None
|
||||
else:
|
||||
ys_hat = decoder_out.argmax(dim=-1)
|
||||
cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
|
||||
|
||||
return loss_att, acc_att, cer_att, wer_att
|
||||
|
||||
def _calc_ctc_loss(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
ys_pad: torch.Tensor,
|
||||
ys_pad_lens: torch.Tensor,
|
||||
):
|
||||
# Calc CTC loss
|
||||
if(encoder_out.dim()==4):
|
||||
encoder_out = encoder_out.mean(1)
|
||||
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
|
||||
|
||||
# Calc CER using CTC
|
||||
cer_ctc = None
|
||||
if not self.training and self.error_calculator is not None:
|
||||
ys_hat = self.ctc.argmax(encoder_out).data
|
||||
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
|
||||
return loss_ctc, cer_ctc
|
||||
|
||||
def _calc_rnnt_loss(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
ys_pad: torch.Tensor,
|
||||
ys_pad_lens: torch.Tensor,
|
||||
):
|
||||
raise NotImplementedError
|
||||
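The `forward` method above randomly drops input channels during training: with probability `mask_ratio` it keeps a random, sorted subset of the 8 channels, otherwise it keeps all of them. The selection logic can be sketched without PyTorch as follows. `pick_channels` is a hypothetical helper, and `random.sample` stands in for `torch.randperm(...)[0:retain_channel]`; this is an illustration of the idea, not the model's actual code path.

```python
import math
import random
from typing import List


def pick_channels(num_channels: int, mask_ratio: float, rng: random.Random) -> List[int]:
    """Return the sorted channel indices to keep for one training batch."""
    all_channels = list(range(num_channels))
    # With probability (1 - mask_ratio), keep every channel.
    if mask_ratio == 0 or rng.random() > mask_ratio:
        return all_channels
    # Otherwise draw how many channels to retain, between 1 and num_channels.
    retain = math.ceil(rng.random() * num_channels)
    retain = max(retain, 1)  # guard the rare case where random() returns 0.0
    # Random subset, re-sorted so the original channel order is preserved.
    return sorted(rng.sample(all_channels, retain))


rng = random.Random(0)
kept = [pick_channels(8, 0.5, rng) for _ in range(1000)]
```

In the diff, the single-channel case indexes the tensor with a scalar, which also removes the channel axis. This sketch only returns indices and does not model that shape change.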
270
funasr/models/encoder/encoder_layer_mfcca.py
Normal file
@ -0,0 +1,270 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
#                Northwestern Polytechnical University (Pengcheng Guo)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Encoder self-attention layer definition."""

import torch

from torch import nn

from funasr.modules.layer_norm import LayerNorm
from torch.autograd import Variable


class Encoder_Conformer_Layer(nn.Module):
    """Encoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
            can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)

    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        feed_forward_macaron,
        conv_module,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        cca_pos=0,
    ):
        """Construct an Encoder_Conformer_Layer object."""
        super(Encoder_Conformer_Layer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = LayerNorm(size)  # for the FFN module
        self.norm_mha = LayerNorm(size)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = LayerNorm(size)
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = LayerNorm(size)  # for the CNN module
            self.norm_final = LayerNorm(size)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        self.cca_pos = cca_pos

        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)

    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.

        Args:
            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
                - w/o pos emb: Tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).

        """
        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None
        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)

        if cache is None:
            x_q = x
        else:
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if self.cca_pos<2:
            if pos_emb is not None:
                x_att = self.self_attn(x_q, x, x, pos_emb, mask)
            else:
                x_att = self.self_attn(x_q, x, x, mask)
        else:
            x_att = self.self_attn(x_q, x, x, mask)

        if self.concat_after:
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x = residual + self.dropout(self.conv_module(x))
            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)
        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        if pos_emb is not None:
            return (x, pos_emb), mask

        return x, mask

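Every sub-block in `Encoder_Conformer_Layer.forward` follows the same residual pattern: normalize first when `normalize_before` is set, add the scaled sub-layer output to the residual, and normalize afterwards otherwise. With a macaron feed-forward present, each feed-forward contributes half a step (`ff_scale = 0.5`). The sketch below shows that pattern on plain floats. The helper name and the toy `norm`/`ffn` functions are hypothetical, and dropout is left out (it is the identity in evaluation mode).

```python
def residual_sublayer(x, sublayer, norm, normalize_before=True, scale=1.0):
    """Apply one residual sub-block to a scalar `x` (illustrative only).

    Pre-norm:  y = x + scale * sublayer(norm(x))
    Post-norm: y = norm(x + scale * sublayer(x))
    """
    residual = x
    if normalize_before:
        x = norm(x)
    x = residual + scale * sublayer(x)
    if not normalize_before:
        x = norm(x)
    return x


# Toy stand-ins: the "norm" halves its input, the "feed-forward" adds 4.
def toy_norm(v):
    return v / 2.0


def toy_ffn(v):
    return v + 4.0


# Macaron-style half step (scale=0.5) in both normalization modes.
pre = residual_sublayer(8.0, toy_ffn, toy_norm, normalize_before=True, scale=0.5)
post = residual_sublayer(8.0, toy_ffn, toy_norm, normalize_before=False, scale=0.5)
print(pre, post)
```

The two modes give different results for the same input, which is why a checkpoint trained with one setting of `normalize_before` generally cannot be loaded under the other.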
class EncoderLayer(nn.Module):
    """Encoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
            can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)

    """

    def __init__(
        self,
        size,
        self_attn_cros_channel,
        self_attn_conformer,
        feed_forward_csa,
        feed_forward_macaron_csa,
        conv_module_csa,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()

        self.encoder_cros_channel_atten = self_attn_cros_channel
        self.encoder_csa = Encoder_Conformer_Layer(
            size,
            self_attn_conformer,
            feed_forward_csa,
            feed_forward_macaron_csa,
            conv_module_csa,
            dropout_rate,
            normalize_before,
            concat_after,
            cca_pos=0)
        self.norm_mha = LayerNorm(size)  # for the MHA module
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x_input, mask, channel_size, cache=None):
        """Compute encoded features.

        Args:
            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
                - w/o pos emb: Tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            channel_size (int): Number of channels folded into the batch dimension.
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).

        """
        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None
        residual = x
        x = self.norm_mha(x)
        t_leng = x.size(1)
        d_dim = x.size(2)
        x_new = x.reshape(-1,channel_size,t_leng,d_dim).transpose(1,2)  # x_new: (B, T, C, D)
        x_k_v = x_new.new(x_new.size(0),x_new.size(1),5,x_new.size(2),x_new.size(3))
        pad_before = Variable(torch.zeros(x_new.size(0),2,x_new.size(2),x_new.size(3))).type(x_new.type())
        pad_after = Variable(torch.zeros(x_new.size(0),2,x_new.size(2),x_new.size(3))).type(x_new.type())
        x_pad = torch.cat([pad_before,x_new, pad_after], 1)
        x_k_v[:,:,0,:,:]=x_pad[:,0:-4,:,:]
        x_k_v[:,:,1,:,:]=x_pad[:,1:-3,:,:]
        x_k_v[:,:,2,:,:]=x_pad[:,2:-2,:,:]
        x_k_v[:,:,3,:,:]=x_pad[:,3:-1,:,:]
        x_k_v[:,:,4,:,:]=x_pad[:,4:,:,:]
        x_new = x_new.reshape(-1,channel_size,d_dim)
        x_k_v = x_k_v.reshape(-1,5*channel_size,d_dim)
        x_att = self.encoder_cros_channel_atten(x_new, x_k_v, x_k_v, None)
        x_att = x_att.reshape(-1,t_leng,channel_size,d_dim).transpose(1,2).reshape(-1,t_leng,d_dim)
        x = residual + self.dropout(x_att)
        if pos_emb is not None:
            x_input = (x, pos_emb)
        else:
            x_input = x
        x_input, mask = self.encoder_csa(x_input, mask)

        return x_input, mask , channel_size
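In `EncoderLayer.forward`, the keys and values for the cross-channel attention are built by zero-padding two frames on each side of the time axis and stacking five shifted copies, so each query frame sees every channel over the frames `t-2 .. t+2`. The following pure-Python sketch shows only the windowing step on a 1-D list. The function name `context_windows` and its parameters are assumptions for illustration; in the real code each "frame" is a `(channel, dim)` block.

```python
def context_windows(frames, left=2, right=2, pad=0.0):
    """Return, for every time step t, the frames in [t - left, t + right].

    Positions outside the sequence are filled with `pad`, mirroring the
    zero blocks concatenated before and after the input.
    (Hypothetical helper for illustration only.)
    """
    padded = [pad] * left + list(frames) + [pad] * right
    width = left + right + 1
    return [padded[t:t + width] for t in range(len(frames))]


windows = context_windows([1.0, 2.0, 3.0, 4.0])
for t, window in enumerate(windows):
    print(t, window)
```

Each window has five entries, which corresponds to the `5 * channel_size` keys per query frame after the reshape in the original code.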
450
funasr/models/encoder/mfcca_encoder.py
Normal file
@ -0,0 +1,450 @@
from typing import Optional
from typing import Tuple

import logging
import torch
from torch import nn

from typeguard import check_argument_types

from funasr.models.encoder.encoder_layer_mfcca import EncoderLayer
from funasr.modules.nets_utils import get_activation
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.attention import (
    MultiHeadedAttention,  # noqa: H301
    RelPositionMultiHeadedAttention,  # noqa: H301
    LegacyRelPositionMultiHeadedAttention,  # noqa: H301
)
from funasr.modules.embedding import (
    PositionalEncoding,  # noqa: H301
    ScaledPositionalEncoding,  # noqa: H301
    RelPositionalEncoding,  # noqa: H301
    LegacyRelPositionalEncoding,  # noqa: H301
)
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from funasr.modules.repeat import repeat
from funasr.modules.subsampling import Conv2dSubsampling
from funasr.modules.subsampling import Conv2dSubsampling2
from funasr.modules.subsampling import Conv2dSubsampling6
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.models.encoder.abs_encoder import AbsEncoder
import pdb
import math


class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model.

    Args:
        channels (int): The number of channels of conv layers.
        kernel_size (int): Kernel size of conv layers.

    """

    def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
        """Construct a ConvolutionModule object."""
        super(ConvolutionModule, self).__init__()
        # kernel_size should be an odd number for 'SAME' padding
        assert (kernel_size - 1) % 2 == 0

        self.pointwise_conv1 = nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.depthwise_conv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            groups=channels,
            bias=bias,
        )
        self.norm = nn.BatchNorm1d(channels)
        self.pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.activation = activation

    def forward(self, x):
        """Compute convolution module.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).

        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).

        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)

        # GLU mechanism
        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
        x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)

        # 1D Depthwise Conv
        x = self.depthwise_conv(x)
        x = self.activation(self.norm(x))

        x = self.pointwise_conv2(x)

        return x.transpose(1, 2)

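`ConvolutionModule` asserts that `kernel_size` is odd and pads the depthwise convolution with `(kernel_size - 1) // 2`, which keeps the time axis the same length at stride 1. The sketch below checks that arithmetic with the standard 1-D convolution output-length formula (the one given in the `torch.nn.Conv1d` documentation). Both helper names are hypothetical.

```python
def conv1d_output_length(length, kernel_size, padding, stride=1, dilation=1):
    """Output length of a 1-D convolution for the given hyperparameters."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def same_padding(kernel_size):
    """Padding that preserves the length at stride 1; requires an odd kernel."""
    if (kernel_size - 1) % 2 != 0:
        raise ValueError("kernel_size must be odd for 'SAME' padding")
    return (kernel_size - 1) // 2


# The encoder's default cnn_module_kernel is 31.
out_len = conv1d_output_length(100, kernel_size=31, padding=same_padding(31))
print(out_len)

# An even kernel cannot be padded symmetrically to keep the length:
print(conv1d_output_length(100, kernel_size=4, padding=1))
```

With an even kernel, symmetric padding always leaves the output one frame short or one frame long, which is the case the assertion rules out.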
class MFCCAEncoder(AbsEncoder):
    """Conformer encoder module.

    Args:
        input_size (int): Input dimension.
        output_size (int): Dimension of attention.
        attention_heads (int): The number of heads of multi head attention.
        linear_units (int): The number of units of position-wise feed forward.
        num_blocks (int): The number of encoder blocks.
        dropout_rate (float): Dropout rate.
        attention_dropout_rate (float): Dropout rate in attention.
        positional_dropout_rate (float): Dropout rate after adding positional encoding.
        input_layer (Union[str, torch.nn.Module]): Input layer type.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            If True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            If False, no additional linear will be applied. i.e. x -> x + att(x)
        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
        rel_pos_type (str): Whether to use the latest relative positional encoding or
            the legacy one. The legacy relative positional encoding will be deprecated
            in the future. More details can be found in
            https://github.com/espnet/espnet/pull/2816.
        pos_enc_layer_type (str): Encoder positional encoding layer type.
        selfattention_layer_type (str): Encoder attention layer type.
        activation_type (str): Encoder activation function type.
        macaron_style (bool): Whether to use macaron style for positionwise layer.
        use_cnn_module (bool): Whether to use convolution module.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
        cnn_module_kernel (int): Kernel size of convolution module.
        padding_idx (int): Padding idx for input_layer=embed.

    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 3,
        macaron_style: bool = False,
        rel_pos_type: str = "legacy",
        pos_enc_layer_type: str = "rel_pos",
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        zero_triu: bool = False,
        cnn_module_kernel: int = 31,
        padding_idx: int = -1,
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size

        if rel_pos_type == "legacy":
            if pos_enc_layer_type == "rel_pos":
                pos_enc_layer_type = "legacy_rel_pos"
            if selfattention_layer_type == "rel_selfattn":
                selfattention_layer_type = "legacy_rel_selfattn"
        elif rel_pos_type == "latest":
            assert selfattention_layer_type != "legacy_rel_selfattn"
            assert pos_enc_layer_type != "legacy_rel_pos"
        else:
            raise ValueError("unknown rel_pos_type: " + rel_pos_type)

        activation = get_activation(activation_type)
        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert selfattention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "legacy_rel_pos":
            assert selfattention_layer_type == "legacy_rel_selfattn"
            pos_enc_class = LegacyRelPositionalEncoding
            logging.warning(
                "Using legacy_rel_pos and it will be deprecated in the future."
            )
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            self.embed = torch.nn.Sequential(
                pos_enc_class(output_size, positional_dropout_rate)
            )
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
                activation,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")

        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "legacy_rel_selfattn":
            assert pos_enc_layer_type == "legacy_rel_pos"
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
            logging.warning(
                "Using legacy_rel_selfattn and it will be deprecated in the future."
            )
        elif selfattention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
                zero_triu,
            )
        else:
            raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)

        convolution_layer = ConvolutionModule
        convolution_layer_args = (output_size, cnn_module_kernel, activation)
        encoder_selfattn_layer_raw = MultiHeadedAttention
        encoder_selfattn_layer_args_raw = (
            attention_heads,
            output_size,
            attention_dropout_rate,
        )
        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                output_size,
                encoder_selfattn_layer_raw(*encoder_selfattn_layer_args_raw),
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)
        self.conv1 = torch.nn.Conv2d(8, 16, [5,7], stride=[1,1], padding=(2,3))

        self.conv2 = torch.nn.Conv2d(16, 32, [5,7], stride=[1,1], padding=(2,3))

        self.conv3 = torch.nn.Conv2d(32, 16, [5,7], stride=[1,1], padding=(2,3))

        self.conv4 = torch.nn.Conv2d(16, 1, [5,7], stride=[1,1], padding=(2,3))

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        channel_size: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            channel_size: Number of channels folded into the batch dimension.
            prev_states (torch.Tensor): Not to be used now.

        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size).
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.

        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)
        xs_pad, masks, channel_size = self.encoders(xs_pad, masks, channel_size)
        if isinstance(xs_pad, tuple):
            xs_pad = xs_pad[0]

        t_leng = xs_pad.size(1)
        d_dim = xs_pad.size(2)
        xs_pad = xs_pad.reshape(-1,channel_size,t_leng,d_dim)
        #pdb.set_trace()
        if(channel_size<8):
            repeat_num = math.ceil(8/channel_size)
            xs_pad = xs_pad.repeat(1,repeat_num,1,1)[:,0:8,:,:]
        xs_pad = self.conv1(xs_pad)
        xs_pad = self.conv2(xs_pad)
        xs_pad = self.conv3(xs_pad)
        xs_pad = self.conv4(xs_pad)
        xs_pad = xs_pad.squeeze().reshape(-1,t_leng,d_dim)
        mask_tmp = masks.size(1)
        masks = masks.reshape(-1,channel_size,mask_tmp,t_leng)[:,0,:,:]

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None

    def forward_hidden(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            prev_states (torch.Tensor): Not to be used now.

        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size).
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.

        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)
        num_layer = len(self.encoders)
        for idx, encoder in enumerate(self.encoders):
            xs_pad, masks = encoder(xs_pad, masks)
            if idx == num_layer // 2 - 1:
                hidden_feature = xs_pad
        if isinstance(xs_pad, tuple):
            xs_pad = xs_pad[0]
            hidden_feature = hidden_feature[0]
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
            self.hidden_feature = self.after_norm(hidden_feature)

        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None
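Before the four `Conv2d` layers, `MFCCAEncoder.forward` brings inputs with fewer than eight channels up to exactly eight by tiling the channel axis `ceil(8 / channel_size)` times and keeping the first eight. The sketch below reproduces that tiling on a plain list. The function name `tile_channels` is hypothetical, and the labels stand in for per-channel feature maps.

```python
import math


def tile_channels(channels, target=8):
    """Cyclically repeat `channels` until there are `target` of them.

    Inputs that already have `target` or more channels are returned unchanged,
    matching the `channel_size < 8` guard in the encoder.
    (Hypothetical helper for illustration only.)
    """
    channels = list(channels)
    if len(channels) < target:
        repeat_num = math.ceil(target / len(channels))
        channels = (channels * repeat_num)[:target]
    return channels


tiled = tile_channels(["ch0", "ch1", "ch2"])
print(tiled)
```

Because only the fewer-than-eight case is handled and the first convolution is built with eight input channels, inputs with more than eight channels would presumably need separate handling. That is an inference from the code above, not something the commit states.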
@ -131,3 +131,128 @@ class DefaultFrontend(AbsFrontend):
        # input_stft: (..., F, 2) -> (..., F)
        input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
        return input_stft, feats_lens


class MultiChannelFrontend(AbsFrontend):
    """Conventional frontend structure for ASR.

    Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
    """

    def __init__(
        self,
        fs: Union[int, str] = 16000,
        n_fft: int = 512,
        win_length: int = None,
        hop_length: int = 128,
        window: Optional[str] = "hann",
        center: bool = True,
        normalized: bool = False,
        onesided: bool = True,
        n_mels: int = 80,
        fmin: int = None,
        fmax: int = None,
        htk: bool = False,
        frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
        apply_stft: bool = True,
        frame_length: int = None,
        frame_shift: int = None,
        lfr_m: int = None,
        lfr_n: int = None,
    ):
        assert check_argument_types()
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)

        # Deepcopy (In general, dict shouldn't be used as default arg)
        frontend_conf = copy.deepcopy(frontend_conf)
        self.hop_length = hop_length

        if apply_stft:
            self.stft = Stft(
                n_fft=n_fft,
                win_length=win_length,
                hop_length=hop_length,
                center=center,
                window=window,
                normalized=normalized,
                onesided=onesided,
            )
        else:
            self.stft = None
        self.apply_stft = apply_stft

        if frontend_conf is not None:
            self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
        else:
            self.frontend = None

        self.logmel = LogMel(
            fs=fs,
            n_fft=n_fft,
            n_mels=n_mels,
            fmin=fmin,
            fmax=fmax,
            htk=htk,
        )
        self.n_mels = n_mels
        self.frontend_type = "multichannelfrontend"

    def output_size(self) -> int:
        return self.n_mels

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # 1. Domain-conversion: e.g. Stft: time -> time-freq
        #import pdb;pdb.set_trace()
        if self.stft is not None:
            input_stft, feats_lens = self._compute_stft(input, input_lengths)
        else:
            if isinstance(input, ComplexTensor):
                input_stft = input
            else:
                input_stft = ComplexTensor(input[..., 0], input[..., 1])
            feats_lens = input_lengths
        # 2. [Option] Speech enhancement
        if self.frontend is not None:
            assert isinstance(input_stft, ComplexTensor), type(input_stft)
            # input_stft: (Batch, Length, [Channel], Freq)
            input_stft, _, mask = self.frontend(input_stft, feats_lens)
        # 4. STFT -> Power spectrum
        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
        input_power = input_stft.real ** 2 + input_stft.imag ** 2

        # 5. Feature transform e.g. Stft -> Log-Mel-Fbank
        # input_power: (Batch, [Channel,] Length, Freq)
        #       -> input_feats: (Batch, Length, Dim)
        input_feats, _ = self.logmel(input_power, feats_lens)
        bt = input_feats.size(0)
        if input_feats.dim() ==4:
            channel_size = input_feats.size(2)
            # batch * channel * T * D
            #pdb.set_trace()
            input_feats = input_feats.transpose(1,2).reshape(bt*channel_size,-1,80).contiguous()
            # input_feats = input_feats.transpose(1,2)
            # batch * channel
            feats_lens = feats_lens.repeat(1,channel_size).squeeze()
        else:
            channel_size = 1
        return input_feats, feats_lens, channel_size

    def _compute_stft(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> torch.Tensor:
        input_stft, feats_lens = self.stft(input, input_lengths)

        assert input_stft.dim() >= 4, input_stft.shape
        # "2" refers to the real/imag parts of Complex
        assert input_stft.shape[-1] == 2, input_stft.shape

        # Change torch.Tensor to ComplexTensor
        # input_stft: (..., F, 2) -> (..., F)
        input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
        return input_stft, feats_lens

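For 4-D features, `MultiChannelFrontend.forward` folds the channel axis into the batch axis: `(batch, time, channel, dim)` is transposed and reshaped to `(batch * channel, time, dim)`, and the channel count is returned so the encoder can unfold it again. The sketch below shows the resulting row order with nested lists. The name `fold_channels` is hypothetical, and the per-channel length ordering it produces (each utterance's length repeated consecutively) is what the batch-major reshape implies. Whether `feats_lens.repeat(1, channel_size)` yields the same ordering for batches with differing lengths is worth checking separately.

```python
def fold_channels(feats, lengths):
    """Fold (batch, time, channel, dim) nested lists into (batch*channel, time, dim).

    Rows are batch-major: all channels of utterance 0, then utterance 1, and so on.
    Each utterance's length is repeated once per channel, in the same order.
    (Hypothetical helper for illustration only.)
    """
    folded, folded_lengths = [], []
    n_channels = 0
    for utt, length in zip(feats, lengths):
        n_channels = len(utt[0])  # channels in the first frame
        for c in range(n_channels):
            folded.append([frame[c] for frame in utt])
            folded_lengths.append(length)
    return folded, folded_lengths, n_channels


# Two utterances, two frames, two channels, feature dimension 1.
feats = [
    [[[1], [2]], [[3], [4]]],
    [[[5], [6]], [[7], [8]]],
]
folded, folded_lengths, n_channels = fold_channels(feats, [2, 1])
print(folded)
print(folded_lengths, n_channels)
```

Row `b * n_channels + c` of the folded result holds channel `c` of utterance `b`, which is the layout the encoder's `reshape(-1, channel_size, t_leng, d_dim)` call assumes.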
@ -1,3 +0,0 @@
**/__pycache__
*.onnx
*.pyc
@ -20,9 +20,19 @@ cd funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer
pip install -r requirements.txt
```
3. Export the model.
- Export your model ([docs](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export))

- Export model from modelscope
```shell
python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true
```
- Export model from local path
```shell
python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true
```
- For more details, refer to the [docs](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export).

4. Run the demo.

5. Run the demo.
- Model_dir: the model path, which contains `model.onnx`, `config.yaml`, `am.mvn`.
- Input: wav format file; supported input types: `str, np.ndarray, List[str]`
- Output: `List[str]`: recognition result.

@ -0,0 +1,9 @@
from paraformer_onnx import Paraformer

model_dir = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model = Paraformer(model_dir, batch_size=1)

wav_path = ['/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav']

result = model(wav_path)
print(result)
0
funasr/runtime/python/torchscripts/__init__.py
Normal file
@ -71,7 +71,7 @@ from funasr.utils.types import str2bool
|
||||
from funasr.utils.types import str2triple_str
|
||||
from funasr.utils.types import str_or_int
|
||||
from funasr.utils.types import str_or_none
|
||||
from funasr.utils.wav_utils import calc_shape, generate_data_list
|
||||
from funasr.utils.wav_utils import calc_shape, generate_data_list, filter_wav_text
|
||||
from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
|
||||
|
||||
try:
|
||||
@ -1153,6 +1153,14 @@ class AbsTask(ABC):
|
||||
if args.batch_bins is not None:
|
||||
args.batch_bins = args.batch_bins * args.ngpu
|
||||
|
||||
# Filter out samples whose wav.scp and text entries do not match
|
||||
if (args.train_shape_file is None and args.dataset_type == "small") or (args.train_data_file is None and args.dataset_type == "large"):
|
||||
if not args.simple_ddp or distributed_option.dist_rank == 0:
|
||||
filter_wav_text(args.data_dir, args.train_set)
|
||||
filter_wav_text(args.data_dir, args.dev_set)
|
||||
if args.simple_ddp:
|
||||
dist.barrier()
|
||||
|
||||
if args.train_shape_file is None and args.dataset_type == "small":
|
||||
if not args.simple_ddp or distributed_option.dist_rank == 0:
|
||||
calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
|
||||
|
||||
@ -40,6 +40,7 @@ from funasr.models.decoder.transformer_decoder import TransformerDecoder
|
||||
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
|
||||
from funasr.models.e2e_asr import ESPnetASRModel
|
||||
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
|
||||
from funasr.models.e2e_asr_mfcca import MFCCA
|
||||
from funasr.models.e2e_uni_asr import UniASR
|
||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||
from funasr.models.encoder.conformer_encoder import ConformerEncoder
|
||||
@ -47,8 +48,10 @@ from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
|
||||
from funasr.models.encoder.rnn_encoder import RNNEncoder
|
||||
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
|
||||
from funasr.models.encoder.transformer_encoder import TransformerEncoder
|
||||
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
|
||||
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||
from funasr.models.frontend.default import DefaultFrontend
|
||||
from funasr.models.frontend.default import MultiChannelFrontend
|
||||
from funasr.models.frontend.fused import FusedFrontends
|
||||
from funasr.models.frontend.s3prl import S3prlFrontend
|
||||
from funasr.models.frontend.wav_frontend import WavFrontend
|
||||
@ -86,6 +89,7 @@ frontend_choices = ClassChoices(
|
||||
s3prl=S3prlFrontend,
|
||||
fused=FusedFrontends,
|
||||
wav_frontend=WavFrontend,
|
||||
multichannelfrontend=MultiChannelFrontend,
|
||||
),
|
||||
type_check=AbsFrontend,
|
||||
default="default",
|
||||
@ -119,6 +123,7 @@ model_choices = ClassChoices(
|
||||
paraformer_bert=ParaformerBert,
|
||||
bicif_paraformer=BiCifParaformer,
|
||||
contextual_paraformer=ContextualParaformer,
|
||||
mfcca=MFCCA,
|
||||
),
|
||||
type_check=AbsESPnetModel,
|
||||
default="asr",
|
||||
@ -142,6 +147,7 @@ encoder_choices = ClassChoices(
|
||||
sanm=SANMEncoder,
|
||||
sanm_chunk_opt=SANMEncoderChunkOpt,
|
||||
data2vec_encoder=Data2VecEncoder,
|
||||
mfcca_enc=MFCCAEncoder,
|
||||
),
|
||||
type_check=AbsEncoder,
|
||||
default="rnn",
|
||||
@ -1106,3 +1112,135 @@ class ASRTaskParaformer(ASRTask):
|
||||
var_dict_torch_update.update(var_dict_torch_update_local)
|
||||
|
||||
return var_dict_torch_update
|
||||
|
||||
|
||||
|
||||
class ASRTaskMFCCA(ASRTask):
|
||||
# If you need more than one optimizer, change this value
|
||||
num_optimizers: int = 1
|
||||
|
||||
# Add variable objects configurations
|
||||
class_choices_list = [
|
||||
# --frontend and --frontend_conf
|
||||
frontend_choices,
|
||||
# --specaug and --specaug_conf
|
||||
specaug_choices,
|
||||
# --normalize and --normalize_conf
|
||||
normalize_choices,
|
||||
# --model and --model_conf
|
||||
model_choices,
|
||||
# --preencoder and --preencoder_conf
|
||||
preencoder_choices,
|
||||
# --encoder and --encoder_conf
|
||||
encoder_choices,
|
||||
# --decoder and --decoder_conf
|
||||
decoder_choices,
|
||||
]
|
||||
|
||||
# If you need to modify train() or eval() procedures, change Trainer class here
|
||||
trainer = Trainer
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, args: argparse.Namespace):
|
||||
assert check_argument_types()
|
||||
if isinstance(args.token_list, str):
|
||||
with open(args.token_list, encoding="utf-8") as f:
|
||||
token_list = [line.rstrip() for line in f]
|
||||
|
||||
# Overwriting token_list to keep it as "portable".
|
||||
args.token_list = list(token_list)
|
||||
elif isinstance(args.token_list, (tuple, list)):
|
||||
token_list = list(args.token_list)
|
||||
else:
|
||||
raise RuntimeError("token_list must be str or list")
|
||||
vocab_size = len(token_list)
|
||||
logging.info(f"Vocabulary size: {vocab_size}")
|
||||
|
||||
# 1. frontend
|
||||
if args.input_size is None:
|
||||
# Extract features in the model
|
||||
frontend_class = frontend_choices.get_class(args.frontend)
|
||||
if args.frontend == 'wav_frontend':
|
||||
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
|
||||
else:
|
||||
frontend = frontend_class(**args.frontend_conf)
|
||||
input_size = frontend.output_size()
|
||||
else:
|
||||
# Give features from data-loader
|
||||
args.frontend = None
|
||||
args.frontend_conf = {}
|
||||
frontend = None
|
||||
input_size = args.input_size
|
||||
|
||||
# 2. Data augmentation for spectrogram
|
||||
if args.specaug is not None:
|
||||
specaug_class = specaug_choices.get_class(args.specaug)
|
||||
specaug = specaug_class(**args.specaug_conf)
|
||||
else:
|
||||
specaug = None
|
||||
|
||||
# 3. Normalization layer
|
||||
if args.normalize is not None:
|
||||
normalize_class = normalize_choices.get_class(args.normalize)
|
||||
normalize = normalize_class(stats_file=args.cmvn_file, **args.normalize_conf)
|
||||
else:
|
||||
normalize = None
|
||||
|
||||
# 4. Pre-encoder input block
|
||||
# NOTE(kan-bayashi): Use getattr to keep the compatibility
|
||||
if getattr(args, "preencoder", None) is not None:
|
||||
preencoder_class = preencoder_choices.get_class(args.preencoder)
|
||||
preencoder = preencoder_class(**args.preencoder_conf)
|
||||
input_size = preencoder.output_size()
|
||||
else:
|
||||
preencoder = None
|
||||
|
||||
# 5. Encoder
|
||||
encoder_class = encoder_choices.get_class(args.encoder)
|
||||
encoder = encoder_class(input_size=input_size, **args.encoder_conf)
|
||||
|
||||
# 7. Decoder
|
||||
decoder_class = decoder_choices.get_class(args.decoder)
|
||||
decoder = decoder_class(
|
||||
vocab_size=vocab_size,
|
||||
encoder_output_size=encoder.output_size(),
|
||||
**args.decoder_conf,
|
||||
)
|
||||
|
||||
# 8. CTC
|
||||
ctc = CTC(
|
||||
odim=vocab_size, encoder_output_size=encoder.output_size(), **args.ctc_conf
|
||||
)
|
||||
|
||||
|
||||
# 9. Select the model class (falls back to "asr")
|
||||
try:
|
||||
model_class = model_choices.get_class(args.model)
|
||||
except AttributeError:
|
||||
model_class = model_choices.get_class("asr")
|
||||
|
||||
rnnt_decoder = None
|
||||
|
||||
# 10. Build model
|
||||
model = model_class(
|
||||
vocab_size=vocab_size,
|
||||
frontend=frontend,
|
||||
specaug=specaug,
|
||||
normalize=normalize,
|
||||
preencoder=preencoder,
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
ctc=ctc,
|
||||
rnnt_decoder=rnnt_decoder,
|
||||
token_list=token_list,
|
||||
**args.model_conf,
|
||||
)
|
||||
|
||||
# 11. Initialize
|
||||
if args.init is not None:
|
||||
initialize(model, args.init)
|
||||
|
||||
assert check_return_type(model)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
@ -287,3 +287,35 @@ def generate_data_list(data_dir, dataset, nj=100):
|
||||
wav_path = os.path.join(split_dir, str(i + 1), "wav.scp")
|
||||
text_path = os.path.join(split_dir, str(i + 1), "text")
|
||||
f_data.write(wav_path + " " + text_path + "\n")
|
||||
|
||||
def filter_wav_text(data_dir, dataset):
|
||||
wav_file = os.path.join(data_dir, dataset, "wav.scp")
|
||||
text_file = os.path.join(data_dir, dataset, "text")
|
||||
with open(wav_file) as f_wav, open(text_file) as f_text:
|
||||
wav_lines = f_wav.readlines()
|
||||
text_lines = f_text.readlines()
|
||||
os.rename(wav_file, "{}.bak".format(wav_file))
|
||||
os.rename(text_file, "{}.bak".format(text_file))
|
||||
wav_dict = {}
|
||||
for line in wav_lines:
|
||||
parts = line.strip().split(maxsplit=1)  # split once so a wav entry containing spaces does not break the unpacking below
|
||||
if len(parts) < 2:
|
||||
continue
|
||||
sample_name, wav_path = parts
|
||||
wav_dict[sample_name] = wav_path
|
||||
text_dict = {}
|
||||
for line in text_lines:
|
||||
parts = line.strip().split(" ", 1)
|
||||
if len(parts) < 2:
|
||||
continue
|
||||
sample_name, txt = parts
|
||||
text_dict[sample_name] = txt
|
||||
filter_count = 0
|
||||
with open(wav_file, "w") as f_wav, open(text_file, "w") as f_text:
|
||||
for sample_name, wav_path in wav_dict.items():
|
||||
if sample_name in text_dict:
|
||||
f_wav.write(sample_name + " " + wav_path + "\n")
|
||||
f_text.write(sample_name + " " + text_dict[sample_name] + "\n")
|
||||
else:
|
||||
filter_count += 1
|
||||
print("{}/{} samples in {} are filtered because of the mismatch between wav.scp and text".format(filter_count, len(wav_lines), dataset))
|
||||
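The filtering step in `filter_wav_text` can be sketched without touching the filesystem. The following is only a minimal illustration, not the code from this commit: `filter_pairs` is a hypothetical helper name, it works on in-memory lines rather than on `wav.scp` and `text` files, and it splits each wav line only once (an assumption of this sketch so that entries containing spaces survive).

```python
def filter_pairs(wav_lines, text_lines):
    """Keep only the utterances that appear in both wav.scp and text.

    Hypothetical helper for illustration; returns the kept wav lines,
    the kept text lines, and the number of wav entries that were dropped.
    """
    wav = {}
    for line in wav_lines:
        # "<sample_name> <wav_path>"; lines without both fields are skipped
        parts = line.strip().split(maxsplit=1)
        if len(parts) == 2:
            wav[parts[0]] = parts[1]

    text = {}
    for line in text_lines:
        # "<sample_name> <transcript>"; the transcript may contain spaces
        parts = line.strip().split(" ", 1)
        if len(parts) == 2:
            text[parts[0]] = parts[1]

    kept_wav, kept_text = [], []
    for name, path in wav.items():
        if name in text:
            kept_wav.append(f"{name} {path}")
            kept_text.append(f"{name} {text[name]}")

    dropped = len(wav) - len(kept_wav)
    return kept_wav, kept_text, dropped


if __name__ == "__main__":
    wav_lines = ["utt1 /data/a.wav\n", "utt2 /data/b.wav\n", "broken\n"]
    text_lines = ["utt1 hello world\n", "utt3 no audio\n"]
    print(filter_pairs(wav_lines, text_lines))
```

Because the loop iterates over the wav entries, a transcript without a matching wav entry (`utt3` here) is dropped silently and is not counted, which mirrors how the function in the diff counts only wav entries that lack a transcript.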