dev_yufan egs_mfcca update

This commit is contained in:
yufan-aslp 2023-02-14 14:54:04 +08:00
parent b3bfea34ad
commit a319bf6e5f
13 changed files with 2382 additions and 0 deletions

View File

@ -0,0 +1,53 @@
# ModelScope Model
## How to finetune and infer using a pretrained Paraformer-large Model
### Finetune
- Modify finetune training related parameters in `finetune.py`
- <strong>output_dir:</strong> # result dir
- <strong>data_dir:</strong> # the dataset dir needs to include files: `train/wav.scp`, `train/text`; `validation/wav.scp`, `validation/text`
- <strong>dataset_type:</strong> # for dataset larger than 1000 hours, set as `large`, otherwise set as `small`
- <strong>batch_bins:</strong> # batch size. For dataset_type is `small`, `batch_bins` indicates the feature frames. For dataset_type is `large`, `batch_bins` indicates the duration in ms
- <strong>max_epoch:</strong> # number of training epoch
- <strong>lr:</strong> # learning rate
- Then you can run the pipeline to finetune with:
```python
python finetune.py
```
### Inference
Or you can use the finetuned model for inference directly.
- Setting parameters in `infer.py`
- <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` is also exists, CER will be computed
- <strong>output_dir:</strong> # result dir
- <strong>ngpu:</strong> # the number of GPUs for decoding
- <strong>njob:</strong> # the number of jobs for each GPU
- Then you can run the pipeline to infer with:
```python
python infer.py
```
- Results
The decoding results can be found in `$output_dir/1best_recog/text.sp.cer` and `$output_dir/1best_recog/text.nosp.cer`, which includes recognition results with or without separating character (src) of each sample and the CER metric of the whole test set.
### Inference using local finetuned model
- Modify inference related parameters in `infer_after_finetune.py`
- <strong>output_dir:</strong> # result dir
- <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` is also exists, CER will be computed
- <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pth`
- Then you can run the pipeline to finetune with:
```python
python infer_after_finetune.py
```
- Results
The decoding results can be found in `$output_dir/1best_recog/text.sp.cer` and `$output_dir/1best_recog/text.nosp.cer`, which includes recognition results with or without separating character (src) of each sample and the CER metric of the whole test set.

View File

@ -0,0 +1,40 @@
# Paraformer-Large
- Model link: <https://www.modelscope.cn/models/yufan6/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/summary>
- Model size: 45M
# Environments
- date: `Tue Feb 13 20:13:22 CST 2023`
- python version: `3.7.12`
- FunASR version: `0.1.0`
- pytorch version: `pytorch 1.7.0`
- Git hash: ``
- Commit date: ``
# Beachmark Results
## result (paper)
beam=20CER toolhttps://github.com/yufan-aslp/AliMeeting
| model | Para (M) | Data (hrs) | Eval (CER%) | Test (CER%) |
|:-------------------:|:---------:|:---------:|:---------:| :---------:|
| MFCCA | 45 | 917 | 16.1 | 17.5 |
## resultmodelscope
beam=10
with separating character (src)
| model | Para (M) | Data (hrs) | Eval_sp (CER%) | Test_sp (CER%) |
|:-------------------:|:---------:|:---------:|:---------:| :---------:|
| MFCCA | 45 | 917 | 17.1 | 18.6 |
without separating character (src)
| model | Para (M) | Data (hrs) | Eval_nosp (CER%) | Test_nosp (CER%) |
|:-------------------:|:---------:|:---------:|:---------:| :---------:|
| MFCCA | 45 | 917 | 16.4 | 18.0 |
## 偏差
Considering the differences of the CER calculation tool and decoding beam size, the results of CER are biased (<0.5%).

View File

@ -0,0 +1,35 @@
import os
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from funasr.datasets.ms_dataset import MsDataset
from funasr.utils.modelscope_param import modelscope_args
def modelscope_finetune(params):
if not os.path.exists(params.output_dir):
os.makedirs(params.output_dir, exist_ok=True)
# dataset split ["train", "validation"]
ds_dict = MsDataset.load(params.data_path)
kwargs = dict(
model=params.model,
model_revision=params.model_revision,
data_dir=ds_dict,
dataset_type=params.dataset_type,
work_dir=params.output_dir,
batch_bins=params.batch_bins,
max_epoch=params.max_epoch,
lr=params.lr)
trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
trainer.train()
if __name__ == '__main__':
params = modelscope_args(model="yufan6/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950")
params.output_dir = "./checkpoint" # m模型保存路径
params.data_path = "./example_data/" # 数据路径
params.dataset_type = "small" # 小数据量设置small若数据量大于1000小时请使用large
params.batch_bins = 1000 # batch size如果dataset_type="small"batch_bins单位为fbank特征帧数如果dataset_type="large"batch_bins单位为毫秒
params.max_epoch = 10 # 最大训练轮数
params.lr = 0.0001 # 设置学习率
params.model_revision = 'v2.0.0'
modelscope_finetune(params)

View File

@ -0,0 +1,103 @@
import os
import shutil
from multiprocessing import Pool
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from funasr.utils.compute_wer import compute_wer
import pdb;
def modelscope_infer_core(output_dir, split_dir, njob, idx):
output_dir_job = os.path.join(output_dir, "output.{}".format(idx))
gpu_id = (int(idx) - 1) // njob
if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[gpu_id])
else:
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
inference_pipline = pipeline(
task=Tasks.auto_speech_recognition,
model='yufan6/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950',
model_revision='v2.0.0',
output_dir=output_dir_job,
batch_size=1,
)
audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
inference_pipline(audio_in=audio_in)
def modelscope_infer(params):
# prepare for multi-GPU decoding
ngpu = params["ngpu"]
njob = params["njob"]
output_dir = params["output_dir"]
if os.path.exists(output_dir):
shutil.rmtree(output_dir)
os.mkdir(output_dir)
split_dir = os.path.join(output_dir, "split")
os.mkdir(split_dir)
nj = ngpu * njob
wav_scp_file = os.path.join(params["data_dir"], "wav.scp")
with open(wav_scp_file) as f:
lines = f.readlines()
num_lines = len(lines)
num_job_lines = num_lines // nj
start = 0
for i in range(nj):
end = start + num_job_lines
file = os.path.join(split_dir, "wav.{}.scp".format(str(i + 1)))
with open(file, "w") as f:
if i == nj - 1:
f.writelines(lines[start:])
else:
f.writelines(lines[start:end])
start = end
p = Pool(nj)
for i in range(nj):
p.apply_async(modelscope_infer_core,
args=(output_dir, split_dir, njob, str(i + 1)))
p.close()
p.join()
# combine decoding results
best_recog_path = os.path.join(output_dir, "1best_recog")
os.mkdir(best_recog_path)
files = ["text", "token", "score"]
for file in files:
with open(os.path.join(best_recog_path, file), "w") as f:
for i in range(nj):
job_file = os.path.join(output_dir, "output.{}/1best_recog".format(str(i + 1)), file)
with open(job_file) as f_job:
lines = f_job.readlines()
f.writelines(lines)
# If text exists, compute CER
text_in = os.path.join(params["data_dir"], "text")
if os.path.exists(text_in):
text_proc_file = os.path.join(best_recog_path, "token")
text_proc_file2 = os.path.join(best_recog_path, "token_nosep")
with open(text_proc_file, 'r') as hyp_reader:
with open(text_proc_file2, 'w') as hyp_writer:
for line in hyp_reader:
new_context = line.strip().replace("src","").replace(" "," ").replace(" "," ").strip()
hyp_writer.write(new_context+'\n')
text_in2 = os.path.join(best_recog_path, "ref_text_nosep")
with open(text_in, 'r') as ref_reader:
with open(text_in2, 'w') as ref_writer:
for line in ref_reader:
new_context = line.strip().replace("src","").replace(" "," ").replace(" "," ").strip()
ref_writer.write(new_context+'\n')
compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.sp.cer"))
compute_wer(text_in2, text_proc_file2, os.path.join(best_recog_path, "text.nosp.cer"))
if __name__ == "__main__":
params = {}
params["data_dir"] = "./example_data/validation"
params["output_dir"] = "./output_dir"
params["ngpu"] = 1
params["njob"] = 1
modelscope_infer(params)

View File

@ -0,0 +1,67 @@
import json
import os
import shutil
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from funasr.utils.compute_wer import compute_wer
def modelscope_infer_after_finetune(params):
# prepare for decoding
pretrained_model_path = os.path.join(os.environ["HOME"], ".cache/modelscope/hub", params["modelscope_model_name"])
for file_name in params["required_files"]:
if file_name == "configuration.json":
with open(os.path.join(pretrained_model_path, file_name)) as f:
config_dict = json.load(f)
config_dict["model"]["am_model_name"] = params["decoding_model_name"]
with open(os.path.join(params["output_dir"], "configuration.json"), "w") as f:
json.dump(config_dict, f, indent=4, separators=(',', ': '))
else:
shutil.copy(os.path.join(pretrained_model_path, file_name),
os.path.join(params["output_dir"], file_name))
decoding_path = os.path.join(params["output_dir"], "decode_results")
if os.path.exists(decoding_path):
shutil.rmtree(decoding_path)
os.mkdir(decoding_path)
# decoding
inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model=params["output_dir"],
output_dir=decoding_path,
batch_size=1
)
audio_in = os.path.join(params["data_dir"], "wav.scp")
inference_pipeline(audio_in=audio_in)
# computer CER if GT text is set
text_in = os.path.join(params["data_dir"], "text")
if text_in is not None:
text_proc_file = os.path.join(decoding_path, "1best_recog/token")
text_proc_file2 = os.path.join(decoding_path, "1best_recog/token_nosep")
with open(text_proc_file, 'r') as hyp_reader:
with open(text_proc_file2, 'w') as hyp_writer:
for line in hyp_reader:
new_context = line.strip().replace("src","").replace(" "," ").replace(" "," ").strip()
hyp_writer.write(new_context+'\n')
text_in2 = os.path.join(decoding_path, "1best_recog/ref_text_nosep")
with open(text_in, 'r') as ref_reader:
with open(text_in2, 'w') as ref_writer:
for line in ref_reader:
new_context = line.strip().replace("src","").replace(" "," ").replace(" "," ").strip()
ref_writer.write(new_context+'\n')
compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.sp.cer"))
compute_wer(text_in2, text_proc_file2, os.path.join(decoding_path, "text.nosp.cer"))
if __name__ == '__main__':
params = {}
params["modelscope_model_name"] = "yufan6/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950"
params["required_files"] = ["feats_stats.npz", "decoding.yaml", "configuration.json"]
params["output_dir"] = "./checkpoint"
params["data_dir"] = "./example_data/validation"
params["decoding_model_name"] = "valid.acc.ave.pth"
modelscope_infer_after_finetune(params)

View File

@ -228,6 +228,9 @@ def inference_launch(**kwargs):
elif mode == "vad":
from funasr.bin.vad_inference import inference_modelscope
return inference_modelscope(**kwargs)
elif mode == "mfcca":
from funasr.bin.asr_inference_mfcca import inference_modelscope
return inference_modelscope(**kwargs)
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
@ -253,6 +256,9 @@ def inference_launch_funasr(**kwargs):
elif mode == "vad":
from funasr.bin.vad_inference import inference
return inference(**kwargs)
elif mode == "mfcca":
from funasr.bin.asr_inference_mfcca import inference_modelscope
return inference_modelscope(**kwargs)
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None

View File

@ -0,0 +1,771 @@
#!/usr/bin/env python3
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import logging
import sys
from pathlib import Path
from typing import Any
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
from typing import Dict
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.fileio.datadir_writer import DatadirWriter
from funasr.modules.beam_search.batch_beam_search import BatchBeamSearch
from funasr.modules.beam_search.beam_search import BeamSearch
from funasr.modules.beam_search.beam_search import Hypothesis
from funasr.modules.scorers.ctc import CTCPrefixScorer
from funasr.modules.scorers.length_bonus import LengthBonus
from funasr.modules.scorers.scorer_interface import BatchScorerInterface
from funasr.modules.subsampling import TooShortUttError
from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
from funasr.tasks.lm import LMTask
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
import pdb
header_colors = '\033[95m'
end_colors = '\033[0m'
global_asr_language: str = 'zh-cn'
global_sample_rate: Union[int, Dict[Any, int]] = {
'audio_fs': 16000,
'model_fs': 16000
}
class Speech2Text:
"""Speech2Text class
Examples:
>>> import soundfile
>>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
>>> audio, rate = soundfile.read("speech.wav")
>>> speech2text(audio)
[(text, token, token_int, hypothesis object), ...]
"""
def __init__(
self,
asr_train_config: Union[Path, str] = None,
asr_model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
lm_train_config: Union[Path, str] = None,
lm_file: Union[Path, str] = None,
token_type: str = None,
bpemodel: str = None,
device: str = "cpu",
maxlenratio: float = 0.0,
minlenratio: float = 0.0,
batch_size: int = 1,
dtype: str = "float32",
beam_size: int = 20,
ctc_weight: float = 0.5,
lm_weight: float = 1.0,
ngram_weight: float = 0.9,
penalty: float = 0.0,
nbest: int = 1,
streaming: bool = False,
**kwargs,
):
assert check_argument_types()
# 1. Build ASR model
scorers = {}
asr_model, asr_train_args = ASRTask.build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
logging.info("asr_model: {}".format(asr_model))
logging.info("asr_train_args: {}".format(asr_train_args))
asr_model.to(dtype=getattr(torch, dtype)).eval()
decoder = asr_model.decoder
ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
token_list = asr_model.token_list
scorers.update(
decoder=decoder,
ctc=ctc,
length_bonus=LengthBonus(len(token_list)),
)
# 2. Build Language model
if lm_train_config is not None:
lm, lm_train_args = LMTask.build_model_from_file(
lm_train_config, lm_file, device
)
lm.to(device)
scorers["lm"] = lm.lm
# 3. Build ngram model
# ngram is not supported now
ngram = None
scorers["ngram"] = ngram
# 4. Build BeamSearch object
# transducer is not supported now
beam_search_transducer = None
weights = dict(
decoder=1.0 - ctc_weight,
ctc=ctc_weight,
lm=lm_weight,
ngram=ngram_weight,
length_bonus=penalty,
)
beam_search = BeamSearch(
beam_size=beam_size,
weights=weights,
scorers=scorers,
sos=asr_model.sos,
eos=asr_model.eos,
vocab_size=len(token_list),
token_list=token_list,
pre_beam_score_key=None if ctc_weight == 1.0 else "full",
)
#beam_search.__class__ = BatchBeamSearch
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
token_type = asr_train_args.token_type
if bpemodel is None:
bpemodel = asr_train_args.bpemodel
if token_type is None:
tokenizer = None
elif token_type == "bpe":
if bpemodel is not None:
tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
else:
tokenizer = None
else:
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.converter = converter
self.tokenizer = tokenizer
self.beam_search = beam_search
self.beam_search_transducer = beam_search_transducer
self.maxlenratio = maxlenratio
self.minlenratio = minlenratio
self.device = device
self.dtype = dtype
self.nbest = nbest
@torch.no_grad()
def __call__(
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
) -> List[
Tuple[
Optional[str],
List[str],
List[int],
Union[Hypothesis],
]
]:
"""Inference
Args:
speech: Input speech data
Returns:
text, token, token_int, hyp
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
#speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
speech = speech.to(getattr(torch, self.dtype))
# lenghts: (1,)
lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
batch = {"speech": speech, "speech_lengths": lengths}
# a. To device
batch = to_device(batch, device=self.device)
# b. Forward Encoder
enc, _ = self.asr_model.encode(**batch)
assert len(enc) == 1, len(enc)
# c. Passed the encoder result and the beam search
nbest_hyps = self.beam_search(
x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
)
nbest_hyps = nbest_hyps[: self.nbest]
results = []
for hyp in nbest_hyps:
assert isinstance(hyp, (Hypothesis)), type(hyp)
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1:last_pos]
else:
token_int = hyp.yseq[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != 0, token_int))
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
if self.tokenizer is not None:
text = self.tokenizer.tokens2text(token)
else:
text = None
results.append((text, token, token_int, hyp))
assert check_return_type(results)
return results
# def inference(
# maxlenratio: float,
# minlenratio: float,
# batch_size: int,
# beam_size: int,
# ngpu: int,
# ctc_weight: float,
# lm_weight: float,
# penalty: float,
# log_level: Union[int, str],
# data_path_and_name_and_type,
# asr_train_config: Optional[str],
# asr_model_file: Optional[str],
# cmvn_file: Optional[str] = None,
# lm_train_config: Optional[str] = None,
# lm_file: Optional[str] = None,
# token_type: Optional[str] = None,
# key_file: Optional[str] = None,
# word_lm_train_config: Optional[str] = None,
# bpemodel: Optional[str] = None,
# allow_variable_data_keys: bool = False,
# streaming: bool = False,
# output_dir: Optional[str] = None,
# dtype: str = "float32",
# seed: int = 0,
# ngram_weight: float = 0.9,
# nbest: int = 1,
# num_workers: int = 1,
# **kwargs,
# ):
# assert check_argument_types()
# if batch_size > 1:
# raise NotImplementedError("batch decoding is not implemented")
# if word_lm_train_config is not None:
# raise NotImplementedError("Word LM is not implemented")
# if ngpu > 1:
# raise NotImplementedError("only single GPU decoding is supported")
#
# logging.basicConfig(
# level=log_level,
# format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
# )
#
# if ngpu >= 1 and torch.cuda.is_available():
# device = "cuda"
# else:
# device = "cpu"
#
# # 1. Set random-seed
# set_all_random_seed(seed)
#
# # 2. Build speech2text
# speech2text_kwargs = dict(
# asr_train_config=asr_train_config,
# asr_model_file=asr_model_file,
# cmvn_file=cmvn_file,
# lm_train_config=lm_train_config,
# lm_file=lm_file,
# token_type=token_type,
# bpemodel=bpemodel,
# device=device,
# maxlenratio=maxlenratio,
# minlenratio=minlenratio,
# dtype=dtype,
# beam_size=beam_size,
# ctc_weight=ctc_weight,
# lm_weight=lm_weight,
# ngram_weight=ngram_weight,
# penalty=penalty,
# nbest=nbest,
# streaming=streaming,
# )
# logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
# speech2text = Speech2Text(**speech2text_kwargs)
#
# # 3. Build data-iterator
# loader = ASRTask.build_streaming_iterator(
# data_path_and_name_and_type,
# dtype=dtype,
# batch_size=batch_size,
# key_file=key_file,
# num_workers=num_workers,
# preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
# collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
# allow_variable_data_keys=allow_variable_data_keys,
# inference=True,
# )
#
# finish_count = 0
# file_count = 1
# # 7 .Start for-loop
# # FIXME(kamo): The output format should be discussed about
# asr_result_list = []
# if output_dir is not None:
# writer = DatadirWriter(output_dir)
# else:
# writer = None
#
# for keys, batch in loader:
# assert isinstance(batch, dict), type(batch)
# assert all(isinstance(s, str) for s in keys), keys
# _bs = len(next(iter(batch.values())))
# assert len(keys) == _bs, f"{len(keys)} != {_bs}"
# #batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
#
# # N-best list of (text, token, token_int, hyp_object)
# try:
# results = speech2text(**batch)
# except TooShortUttError as e:
# logging.warning(f"Utterance {keys} {e}")
# hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
# results = [[" ", ["<space>"], [2], hyp]] * nbest
#
# # Only supporting batch_size==1
# key = keys[0]
# for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
# # Create a directory: outdir/{n}best_recog
# if writer is not None:
# ibest_writer = writer[f"{n}best_recog"]
#
# # Write the result to each file
# ibest_writer["token"][key] = " ".join(token)
# ibest_writer["token_int"][key] = " ".join(map(str, token_int))
# ibest_writer["score"][key] = str(hyp.score)
#
# if text is not None:
# text_postprocessed = postprocess_utils.sentence_postprocess(token)
# item = {'key': key, 'value': text_postprocessed}
# asr_result_list.append(item)
# finish_count += 1
# asr_utils.print_progress(finish_count / file_count)
# if writer is not None:
# ibest_writer["text"][key] = text
# return asr_result_list
def inference(
maxlenratio: float,
minlenratio: float,
batch_size: int,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
log_level: Union[int, str],
data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
cmvn_file: Optional[str] = None,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
token_type: Optional[str] = None,
key_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
streaming: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
**kwargs,
):
inference_pipeline = inference_modelscope(
maxlenratio=maxlenratio,
minlenratio=minlenratio,
batch_size=batch_size,
beam_size=beam_size,
ngpu=ngpu,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
penalty=penalty,
log_level=log_level,
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
raw_inputs=raw_inputs,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
key_file=key_file,
word_lm_train_config=word_lm_train_config,
bpemodel=bpemodel,
allow_variable_data_keys=allow_variable_data_keys,
streaming=streaming,
output_dir=output_dir,
dtype=dtype,
seed=seed,
ngram_weight=ngram_weight,
nbest=nbest,
num_workers=num_workers,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
def inference_modelscope(
maxlenratio: float,
minlenratio: float,
batch_size: int,
beam_size: int,
ngpu: int,
ctc_weight: float,
lm_weight: float,
penalty: float,
log_level: Union[int, str],
# data_path_and_name_and_type,
asr_train_config: Optional[str],
asr_model_file: Optional[str],
cmvn_file: Optional[str] = None,
lm_train_config: Optional[str] = None,
lm_file: Optional[str] = None,
token_type: Optional[str] = None,
key_file: Optional[str] = None,
word_lm_train_config: Optional[str] = None,
bpemodel: Optional[str] = None,
allow_variable_data_keys: bool = False,
streaming: bool = False,
output_dir: Optional[str] = None,
dtype: str = "float32",
seed: int = 0,
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
**kwargs,
):
assert check_argument_types()
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
if word_lm_train_config is not None:
raise NotImplementedError("Word LM is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
logging.basicConfig(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
# 1. Set random-seed
set_all_random_seed(seed)
# 2. Build speech2text
speech2text_kwargs = dict(
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
bpemodel=bpemodel,
device=device,
maxlenratio=maxlenratio,
minlenratio=minlenratio,
dtype=dtype,
beam_size=beam_size,
ctc_weight=ctc_weight,
lm_weight=lm_weight,
ngram_weight=ngram_weight,
penalty=penalty,
nbest=nbest,
streaming=streaming,
)
logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
speech2text = Speech2Text(**speech2text_kwargs)
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
):
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
loader = ASRTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
finish_count = 0
file_count = 1
# 7 .Start for-loop
# FIXME(kamo): The output format should be discussed about
asr_result_list = []
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
if output_path is not None:
writer = DatadirWriter(output_path)
else:
writer = None
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
# batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
# N-best list of (text, token, token_int, hyp_object)
try:
results = speech2text(**batch)
except TooShortUttError as e:
logging.warning(f"Utterance {keys} {e}")
hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
results = [[" ", ["<space>"], [2], hyp]] * nbest
# Only supporting batch_size==1
key = keys[0]
for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
# Create a directory: outdir/{n}best_recog
if writer is not None:
ibest_writer = writer[f"{n}best_recog"]
# Write the result to each file
ibest_writer["token"][key] = " ".join(token)
# ibest_writer["token_int"][key] = " ".join(map(str, token_int))
ibest_writer["score"][key] = str(hyp.score)
if text is not None:
text_postprocessed = postprocess_utils.sentence_postprocess(token)
item = {'key': key, 'value': text_postprocessed}
asr_result_list.append(item)
finish_count += 1
asr_utils.print_progress(finish_count / file_count)
if writer is not None:
ibest_writer["text"][key] = text
return asr_result_list
return _forward
def set_parameters(language: str = None,
sample_rate: Union[int, Dict[Any, int]] = None):
if language is not None:
global global_asr_language
global_asr_language = language
if sample_rate is not None:
global global_sample_rate
global_sample_rate = sample_rate
def get_parser():
parser = config_argparse.ArgumentParser(
description="ASR Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Note(kamo): Use '_' instead of '-' as separator.
# '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
default="INFO",
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
help="The verbose level of logging",
)
parser.add_argument("--output_dir", type=str, required=True)
parser.add_argument(
"--ngpu",
type=int,
default=0,
help="The number of gpus. 0 indicates CPU mode",
)
parser.add_argument(
"--gpuid_list",
type=str,
default="",
help="The visible gpus",
)
parser.add_argument("--seed", type=int, default=0, help="Random seed")
parser.add_argument(
"--dtype",
default="float32",
choices=["float16", "float32", "float64"],
help="Data type",
)
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="The number of workers used for DataLoader",
)
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
required=False,
action="append",
)
group.add_argument("--raw_inputs", type=list, default=None)
# example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--asr_train_config",
type=str,
help="ASR training configuration",
)
group.add_argument(
"--asr_model_file",
type=str,
help="ASR model parameter file",
)
group.add_argument(
"--cmvn_file",
type=str,
help="Global cmvn file",
)
group.add_argument(
"--lm_train_config",
type=str,
help="LM training configuration",
)
group.add_argument(
"--lm_file",
type=str,
help="LM parameter file",
)
group.add_argument(
"--word_lm_train_config",
type=str,
help="Word LM training configuration",
)
group.add_argument(
"--word_lm_file",
type=str,
help="Word LM parameter file",
)
group.add_argument(
"--ngram_file",
type=str,
help="N-gram parameter file",
)
group.add_argument(
"--model_tag",
type=str,
help="Pretrained model tag. If specify this option, *_train_config and "
"*_file will be overwritten",
)
group = parser.add_argument_group("Beam-search related")
group.add_argument(
"--batch_size",
type=int,
default=1,
help="The batch size for inference",
)
group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
group.add_argument("--beam_size", type=int, default=20, help="Beam size")
group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
group.add_argument(
"--maxlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain max output length. "
"If maxlenratio=0.0 (default), it uses a end-detect "
"function "
"to automatically find maximum hypothesis lengths."
"If maxlenratio<0.0, its absolute value is interpreted"
"as a constant max output length",
)
group.add_argument(
"--minlenratio",
type=float,
default=0.0,
help="Input length ratio to obtain min output length",
)
group.add_argument(
"--ctc_weight",
type=float,
default=0.5,
help="CTC weight in joint decoding",
)
group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
group.add_argument("--streaming", type=str2bool, default=False)
group = parser.add_argument_group("Text converter related")
group.add_argument(
"--token_type",
type=str_or_none,
default=None,
choices=["char", "bpe", None],
help="The token type for ASR model. "
"If not given, refers from the training args",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The model path of sentencepiece. "
"If not given, refers from the training args",
)
return parser
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
parser = get_parser()
args = parser.parse_args(cmd)
kwargs = vars(args)
kwargs.pop("config", None)
inference(**kwargs)
if __name__ == "__main__":
main()

View File

@ -27,6 +27,8 @@ def parse_args(mode):
from funasr.tasks.asr import ASRTaskParaformer as ASRTask
elif mode == "uniasr":
from funasr.tasks.asr import ASRTaskUniASR as ASRTask
elif mode == "mfcca":
from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
else:
raise ValueError("Unknown mode: {}".format(mode))
parser = ASRTask.get_parser()

View File

@ -0,0 +1,322 @@
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import logging
import torch
from typeguard import check_argument_types
from funasr.modules.e2e_asr_common import ErrorCalculator
from funasr.modules.nets_utils import th_accuracy
from funasr.modules.add_sos_eos import add_sos_eos
from funasr.losses.label_smoothing_loss import (
LabelSmoothingLoss, # noqa: H301
)
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
else:
# Nothing to do if torch<1.6.0
@contextmanager
def autocast(enabled=True):
yield
import pdb
import random
import math
class MFCCA(AbsESPnetModel):
"""CTC-attention hybrid Encoder-Decoder model"""
def __init__(
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
decoder: AbsDecoder,
ctc: CTC,
rnnt_decoder: None,
ctc_weight: float = 0.5,
ignore_id: int = -1,
lsm_weight: float = 0.0,
mask_ratio: float = 0.0,
length_normalized_loss: bool = False,
report_cer: bool = True,
report_wer: bool = True,
sym_space: str = "<space>",
sym_blank: str = "<blank>",
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert rnnt_decoder is None, "Not implemented"
super().__init__()
# note that eos is the same as sos (equivalent ID)
self.sos = vocab_size - 1
self.eos = vocab_size - 1
self.vocab_size = vocab_size
self.ignore_id = ignore_id
self.ctc_weight = ctc_weight
self.token_list = token_list.copy()
self.mask_ratio = mask_ratio
self.frontend = frontend
self.specaug = specaug
self.normalize = normalize
self.preencoder = preencoder
self.encoder = encoder
# we set self.decoder = None in the CTC mode since
# self.decoder parameters were never used and PyTorch complained
# and threw an Exception in the multi-GPU experiment.
# thanks Jeff Farris for pointing out the issue.
if ctc_weight == 1.0:
self.decoder = None
else:
self.decoder = decoder
if ctc_weight == 0.0:
self.ctc = None
else:
self.ctc = ctc
self.rnnt_decoder = rnnt_decoder
self.criterion_att = LabelSmoothingLoss(
size=vocab_size,
padding_idx=ignore_id,
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
if report_cer or report_wer:
self.error_calculator = ErrorCalculator(
token_list, sym_space, sym_blank, report_cer, report_wer
)
else:
self.error_calculator = None
def forward(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
text: (Batch, Length)
text_lengths: (Batch,)
"""
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
speech.shape[0]
== speech_lengths.shape[0]
== text.shape[0]
== text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
#pdb.set_trace()
if(speech.dim()==3 and speech.size(2)==8 and self.mask_ratio !=0):
rate_num = random.random()
#rate_num = 0.1
if(rate_num<=self.mask_ratio):
retain_channel = math.ceil(random.random() *8)
if(retain_channel>1):
speech = speech[:,:,torch.randperm(8)[0:retain_channel].sort().values]
else:
speech = speech[:,:,torch.randperm(8)[0]]
#pdb.set_trace()
batch_size = speech.shape[0]
# for data-parallel
text = text[:, : text_lengths.max()]
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# 2a. Attention-decoder branch
if self.ctc_weight == 1.0:
loss_att, acc_att, cer_att, wer_att = None, None, None, None
else:
loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# 2b. CTC branch
if self.ctc_weight == 0.0:
loss_ctc, cer_ctc = None, None
else:
loss_ctc, cer_ctc = self._calc_ctc_loss(
encoder_out, encoder_out_lens, text, text_lengths
)
# 2c. RNN-T branch
if self.rnnt_decoder is not None:
_ = self._calc_rnnt_loss(encoder_out, encoder_out_lens, text, text_lengths)
if self.ctc_weight == 0.0:
loss = loss_att
elif self.ctc_weight == 1.0:
loss = loss_ctc
else:
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
stats = dict(
loss=loss.detach(),
loss_att=loss_att.detach() if loss_att is not None else None,
loss_ctc=loss_ctc.detach() if loss_ctc is not None else None,
acc=acc_att,
cer=cer_att,
wer=wer_att,
cer_ctc=cer_ctc,
)
# force_gatherable: to-device and to-tensor if scalar for DataParallel
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
def collect_feats(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
text: torch.Tensor,
text_lengths: torch.Tensor,
) -> Dict[str, torch.Tensor]:
feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
return {"feats": feats, "feats_lengths": feats_lengths}
def encode(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
"""
with autocast(False):
# 1. Extract feats
feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
# 2. Data augmentation
if self.specaug is not None and self.training:
feats, feats_lengths = self.specaug(feats, feats_lengths)
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
# Pre-encoder, e.g. used for raw input data
if self.preencoder is not None:
feats, feats_lengths = self.preencoder(feats, feats_lengths)
#pdb.set_trace()
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, channel_size)
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
speech.size(0),
)
if(encoder_out.dim()==4):
assert encoder_out.size(2) <= encoder_out_lens.max(), (
encoder_out.size(),
encoder_out_lens.max(),
)
else:
assert encoder_out.size(1) <= encoder_out_lens.max(), (
encoder_out.size(),
encoder_out_lens.max(),
)
return encoder_out, encoder_out_lens
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
assert speech_lengths.dim() == 1, speech_lengths.shape
# for data-parallel
speech = speech[:, : speech_lengths.max()]
if self.frontend is not None:
# Frontend
# e.g. STFT and Feature extract
# data_loader may send time-domain signal in this case
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
feats, feats_lengths, channel_size = self.frontend(speech, speech_lengths)
else:
# No frontend and no feature extract
feats, feats_lengths = speech, speech_lengths
channel_size = 1
return feats, feats_lengths, channel_size
def _calc_att_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_in_lens = ys_pad_lens + 1
# 1. Forward decoder
decoder_out, _ = self.decoder(
encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
)
# 2. Compute attention loss
loss_att = self.criterion_att(decoder_out, ys_out_pad)
acc_att = th_accuracy(
decoder_out.view(-1, self.vocab_size),
ys_out_pad,
ignore_label=self.ignore_id,
)
# Compute cer/wer using attention-decoder
if self.training or self.error_calculator is None:
cer_att, wer_att = None, None
else:
ys_hat = decoder_out.argmax(dim=-1)
cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
return loss_att, acc_att, cer_att, wer_att
def _calc_ctc_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
# Calc CTC loss
if(encoder_out.dim()==4):
encoder_out = encoder_out.mean(1)
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
# Calc CER using CTC
cer_ctc = None
if not self.training and self.error_calculator is not None:
ys_hat = self.ctc.argmax(encoder_out).data
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
return loss_ctc, cer_ctc
def _calc_rnnt_loss(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
ys_pad: torch.Tensor,
ys_pad_lens: torch.Tensor,
):
raise NotImplementedError

View File

@ -0,0 +1,270 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Northwestern Polytechnical University (Pengcheng Guo)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Encoder self-attention layer definition."""
import torch
from torch import nn
from funasr.modules.layer_norm import LayerNorm
from torch.autograd import Variable
class Encoder_Conformer_Layer(nn.Module):
"""Encoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
can be used as the argument.
feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
can be used as the argument.
conv_module (torch.nn.Module): Convolution module instance.
`ConvlutionModule` instance can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool): Whether to use layer_norm before the first block.
concat_after (bool): Whether to concat attention layer's input and output.
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
"""
def __init__(
self,
size,
self_attn,
feed_forward,
feed_forward_macaron,
conv_module,
dropout_rate,
normalize_before=True,
concat_after=False,
cca_pos=0,
):
"""Construct an Encoder_Conformer_Layer object."""
super(Encoder_Conformer_Layer, self).__init__()
self.self_attn = self_attn
self.feed_forward = feed_forward
self.feed_forward_macaron = feed_forward_macaron
self.conv_module = conv_module
self.norm_ff = LayerNorm(size) # for the FNN module
self.norm_mha = LayerNorm(size) # for the MHA module
if feed_forward_macaron is not None:
self.norm_ff_macaron = LayerNorm(size)
self.ff_scale = 0.5
else:
self.ff_scale = 1.0
if self.conv_module is not None:
self.norm_conv = LayerNorm(size) # for the CNN module
self.norm_final = LayerNorm(size) # for the final output of the block
self.dropout = nn.Dropout(dropout_rate)
self.size = size
self.normalize_before = normalize_before
self.concat_after = concat_after
self.cca_pos = cca_pos
if self.concat_after:
self.concat_linear = nn.Linear(size + size, size)
def forward(self, x_input, mask, cache=None):
"""Compute encoded features.
Args:
x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
- w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
- w/o pos emb: Tensor (#batch, time, size).
mask (torch.Tensor): Mask tensor for the input (#batch, time).
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time).
"""
if isinstance(x_input, tuple):
x, pos_emb = x_input[0], x_input[1]
else:
x, pos_emb = x_input, None
# whether to use macaron style
if self.feed_forward_macaron is not None:
residual = x
if self.normalize_before:
x = self.norm_ff_macaron(x)
x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
if not self.normalize_before:
x = self.norm_ff_macaron(x)
# multi-headed self-attention module
residual = x
if self.normalize_before:
x = self.norm_mha(x)
if cache is None:
x_q = x
else:
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
x_q = x[:, -1:, :]
residual = residual[:, -1:, :]
mask = None if mask is None else mask[:, -1:, :]
if self.cca_pos<2:
if pos_emb is not None:
x_att = self.self_attn(x_q, x, x, pos_emb, mask)
else:
x_att = self.self_attn(x_q, x, x, mask)
else:
x_att = self.self_attn(x_q, x, x, mask)
if self.concat_after:
x_concat = torch.cat((x, x_att), dim=-1)
x = residual + self.concat_linear(x_concat)
else:
x = residual + self.dropout(x_att)
if not self.normalize_before:
x = self.norm_mha(x)
# convolution module
if self.conv_module is not None:
residual = x
if self.normalize_before:
x = self.norm_conv(x)
x = residual + self.dropout(self.conv_module(x))
if not self.normalize_before:
x = self.norm_conv(x)
# feed forward module
residual = x
if self.normalize_before:
x = self.norm_ff(x)
x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
if not self.normalize_before:
x = self.norm_ff(x)
if self.conv_module is not None:
x = self.norm_final(x)
if cache is not None:
x = torch.cat([cache, x], dim=1)
if pos_emb is not None:
return (x, pos_emb), mask
return x, mask
class EncoderLayer(nn.Module):
"""Encoder layer module.
Args:
size (int): Input dimension.
self_attn (torch.nn.Module): Self-attention module instance.
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
can be used as the argument.
feed_forward (torch.nn.Module): Feed-forward module instance.
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
can be used as the argument.
feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
`PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
can be used as the argument.
conv_module (torch.nn.Module): Convolution module instance.
`ConvlutionModule` instance can be used as the argument.
dropout_rate (float): Dropout rate.
normalize_before (bool): Whether to use layer_norm before the first block.
concat_after (bool): Whether to concat attention layer's input and output.
if True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
if False, no additional linear will be applied. i.e. x -> x + att(x)
"""
def __init__(
self,
size,
self_attn_cros_channel,
self_attn_conformer,
feed_forward_csa,
feed_forward_macaron_csa,
conv_module_csa,
dropout_rate,
normalize_before=True,
concat_after=False,
):
"""Construct an EncoderLayer object."""
super(EncoderLayer, self).__init__()
self.encoder_cros_channel_atten = self_attn_cros_channel
self.encoder_csa = Encoder_Conformer_Layer(
size,
self_attn_conformer,
feed_forward_csa,
feed_forward_macaron_csa,
conv_module_csa,
dropout_rate,
normalize_before,
concat_after,
cca_pos=0)
self.norm_mha = LayerNorm(size) # for the MHA module
self.dropout = nn.Dropout(dropout_rate)
def forward(self, x_input, mask, channel_size, cache=None):
"""Compute encoded features.
Args:
x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
- w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
- w/o pos emb: Tensor (#batch, time, size).
mask (torch.Tensor): Mask tensor for the input (#batch, time).
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
Returns:
torch.Tensor: Output tensor (#batch, time, size).
torch.Tensor: Mask tensor (#batch, time).
"""
if isinstance(x_input, tuple):
x, pos_emb = x_input[0], x_input[1]
else:
x, pos_emb = x_input, None
residual = x
x = self.norm_mha(x)
t_leng = x.size(1)
d_dim = x.size(2)
x_new = x.reshape(-1,channel_size,t_leng,d_dim).transpose(1,2) # x_new B*T * C * D
x_k_v = x_new.new(x_new.size(0),x_new.size(1),5,x_new.size(2),x_new.size(3))
pad_before = Variable(torch.zeros(x_new.size(0),2,x_new.size(2),x_new.size(3))).type(x_new.type())
pad_after = Variable(torch.zeros(x_new.size(0),2,x_new.size(2),x_new.size(3))).type(x_new.type())
x_pad = torch.cat([pad_before,x_new, pad_after], 1)
x_k_v[:,:,0,:,:]=x_pad[:,0:-4,:,:]
x_k_v[:,:,1,:,:]=x_pad[:,1:-3,:,:]
x_k_v[:,:,2,:,:]=x_pad[:,2:-2,:,:]
x_k_v[:,:,3,:,:]=x_pad[:,3:-1,:,:]
x_k_v[:,:,4,:,:]=x_pad[:,4:,:,:]
x_new = x_new.reshape(-1,channel_size,d_dim)
x_k_v = x_k_v.reshape(-1,5*channel_size,d_dim)
x_att = self.encoder_cros_channel_atten(x_new, x_k_v, x_k_v, None)
x_att = x_att.reshape(-1,t_leng,channel_size,d_dim).transpose(1,2).reshape(-1,t_leng,d_dim)
x = residual + self.dropout(x_att)
if pos_emb is not None:
x_input = (x, pos_emb)
else:
x_input = x
x_input, mask = self.encoder_csa(x_input, mask)
return x_input, mask , channel_size

View File

@ -0,0 +1,450 @@
from typing import Optional
from typing import Tuple
import logging
import torch
from torch import nn
from typeguard import check_argument_types
from funasr.models.encoder.encoder_layer_mfcca import EncoderLayer
from funasr.modules.nets_utils import get_activation
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.attention import (
MultiHeadedAttention, # noqa: H301
RelPositionMultiHeadedAttention, # noqa: H301
LegacyRelPositionMultiHeadedAttention, # noqa: H301
)
from funasr.modules.embedding import (
PositionalEncoding, # noqa: H301
ScaledPositionalEncoding, # noqa: H301
RelPositionalEncoding, # noqa: H301
LegacyRelPositionalEncoding, # noqa: H301
)
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
from funasr.modules.repeat import repeat
from funasr.modules.subsampling import Conv2dSubsampling
from funasr.modules.subsampling import Conv2dSubsampling2
from funasr.modules.subsampling import Conv2dSubsampling6
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.models.encoder.abs_encoder import AbsEncoder
import pdb
import math
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernerl size of conv layers.
"""
def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.pointwise_conv1 = nn.Conv1d(
channels,
2 * channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.depthwise_conv = nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
groups=channels,
bias=bias,
)
self.norm = nn.BatchNorm1d(channels)
self.pointwise_conv2 = nn.Conv1d(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=bias,
)
self.activation = activation
def forward(self, x):
"""Compute convolution module.
Args:
x (torch.Tensor): Input tensor (#batch, time, channels).
Returns:
torch.Tensor: Output tensor (#batch, time, channels).
"""
# exchange the temporal dimension and the feature dimension
x = x.transpose(1, 2)
# GLU mechanism
x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
# 1D Depthwise Conv
x = self.depthwise_conv(x)
x = self.activation(self.norm(x))
x = self.pointwise_conv2(x)
return x.transpose(1, 2)
class MFCCAEncoder(AbsEncoder):
"""Conformer encoder module.
Args:
input_size (int): Input dimension.
output_size (int): Dimention of attention.
attention_heads (int): The number of heads of multi head attention.
linear_units (int): The number of units of position-wise feed forward.
num_blocks (int): The number of decoder blocks.
dropout_rate (float): Dropout rate.
attention_dropout_rate (float): Dropout rate in attention.
positional_dropout_rate (float): Dropout rate after adding positional encoding.
input_layer (Union[str, torch.nn.Module]): Input layer type.
normalize_before (bool): Whether to use layer_norm before the first block.
concat_after (bool): Whether to concat attention layer's input and output.
If True, additional linear will be applied.
i.e. x -> x + linear(concat(x, att(x)))
If False, no additional linear will be applied. i.e. x -> x + att(x)
positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
rel_pos_type (str): Whether to use the latest relative positional encoding or
the legacy one. The legacy relative positional encoding will be deprecated
in the future. More Details can be found in
https://github.com/espnet/espnet/pull/2816.
encoder_pos_enc_layer_type (str): Encoder positional encoding layer type.
encoder_attn_layer_type (str): Encoder attention layer type.
activation_type (str): Encoder activation function type.
macaron_style (bool): Whether to use macaron style for positionwise layer.
use_cnn_module (bool): Whether to use convolution module.
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
cnn_module_kernel (int): Kernerl size of convolution module.
padding_idx (int): Padding idx for input_layer=embed.
"""
def __init__(
self,
input_size: int,
output_size: int = 256,
attention_heads: int = 4,
linear_units: int = 2048,
num_blocks: int = 6,
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
input_layer: str = "conv2d",
normalize_before: bool = True,
concat_after: bool = False,
positionwise_layer_type: str = "linear",
positionwise_conv_kernel_size: int = 3,
macaron_style: bool = False,
rel_pos_type: str = "legacy",
pos_enc_layer_type: str = "rel_pos",
selfattention_layer_type: str = "rel_selfattn",
activation_type: str = "swish",
use_cnn_module: bool = True,
zero_triu: bool = False,
cnn_module_kernel: int = 31,
padding_idx: int = -1,
):
assert check_argument_types()
super().__init__()
self._output_size = output_size
if rel_pos_type == "legacy":
if pos_enc_layer_type == "rel_pos":
pos_enc_layer_type = "legacy_rel_pos"
if selfattention_layer_type == "rel_selfattn":
selfattention_layer_type = "legacy_rel_selfattn"
elif rel_pos_type == "latest":
assert selfattention_layer_type != "legacy_rel_selfattn"
assert pos_enc_layer_type != "legacy_rel_pos"
else:
raise ValueError("unknown rel_pos_type: " + rel_pos_type)
activation = get_activation(activation_type)
if pos_enc_layer_type == "abs_pos":
pos_enc_class = PositionalEncoding
elif pos_enc_layer_type == "scaled_abs_pos":
pos_enc_class = ScaledPositionalEncoding
elif pos_enc_layer_type == "rel_pos":
assert selfattention_layer_type == "rel_selfattn"
pos_enc_class = RelPositionalEncoding
elif pos_enc_layer_type == "legacy_rel_pos":
assert selfattention_layer_type == "legacy_rel_selfattn"
pos_enc_class = LegacyRelPositionalEncoding
logging.warning(
"Using legacy_rel_pos and it will be deprecated in the future."
)
else:
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
if input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(input_size, output_size),
torch.nn.LayerNorm(output_size),
torch.nn.Dropout(dropout_rate),
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(
input_size,
output_size,
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d6":
self.embed = Conv2dSubsampling6(
input_size,
output_size,
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d8":
self.embed = Conv2dSubsampling8(
input_size,
output_size,
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "embed":
self.embed = torch.nn.Sequential(
torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
pos_enc_class(output_size, positional_dropout_rate),
)
elif isinstance(input_layer, torch.nn.Module):
self.embed = torch.nn.Sequential(
input_layer,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer is None:
self.embed = torch.nn.Sequential(
pos_enc_class(output_size, positional_dropout_rate)
)
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
if positionwise_layer_type == "linear":
positionwise_layer = PositionwiseFeedForward
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
activation,
)
elif positionwise_layer_type == "conv1d":
positionwise_layer = MultiLayeredConv1d
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
elif positionwise_layer_type == "conv1d-linear":
positionwise_layer = Conv1dLinear
positionwise_layer_args = (
output_size,
linear_units,
positionwise_conv_kernel_size,
dropout_rate,
)
else:
raise NotImplementedError("Support only linear or conv1d.")
if selfattention_layer_type == "selfattn":
encoder_selfattn_layer = MultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
)
elif selfattention_layer_type == "legacy_rel_selfattn":
assert pos_enc_layer_type == "legacy_rel_pos"
encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
)
logging.warning(
"Using legacy_rel_selfattn and it will be deprecated in the future."
)
elif selfattention_layer_type == "rel_selfattn":
assert pos_enc_layer_type == "rel_pos"
encoder_selfattn_layer = RelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
zero_triu,
)
else:
raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
convolution_layer = ConvolutionModule
convolution_layer_args = (output_size, cnn_module_kernel, activation)
encoder_selfattn_layer_raw = MultiHeadedAttention
encoder_selfattn_layer_args_raw = (
attention_heads,
output_size,
attention_dropout_rate,
)
self.encoders = repeat(
num_blocks,
lambda lnum: EncoderLayer(
output_size,
encoder_selfattn_layer_raw(*encoder_selfattn_layer_args_raw),
encoder_selfattn_layer(*encoder_selfattn_layer_args),
positionwise_layer(*positionwise_layer_args),
positionwise_layer(*positionwise_layer_args) if macaron_style else None,
convolution_layer(*convolution_layer_args) if use_cnn_module else None,
dropout_rate,
normalize_before,
concat_after,
),
)
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
self.conv1 = torch.nn.Conv2d(8, 16, [5,7], stride=[1,1], padding=(2,3))
self.conv2 = torch.nn.Conv2d(16, 32, [5,7], stride=[1,1], padding=(2,3))
self.conv3 = torch.nn.Conv2d(32, 16, [5,7], stride=[1,1], padding=(2,3))
self.conv4 = torch.nn.Conv2d(16, 1, [5,7], stride=[1,1], padding=(2,3))
def output_size(self) -> int:
return self._output_size
def forward(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
channel_size: torch.Tensor,
prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Calculate forward propagation.
Args:
xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
ilens (torch.Tensor): Input length (#batch).
prev_states (torch.Tensor): Not to be used now.
Returns:
torch.Tensor: Output tensor (#batch, L, output_size).
torch.Tensor: Output length (#batch).
torch.Tensor: Not to be used now.
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
if (
isinstance(self.embed, Conv2dSubsampling)
or isinstance(self.embed, Conv2dSubsampling6)
or isinstance(self.embed, Conv2dSubsampling8)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
raise TooShortUttError(
f"has {xs_pad.size(1)} frames and is too short for subsampling "
+ f"(it needs more than {limit_size} frames), return empty results",
xs_pad.size(1),
limit_size,
)
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
xs_pad, masks, channel_size = self.encoders(xs_pad, masks, channel_size)
if isinstance(xs_pad, tuple):
xs_pad = xs_pad[0]
t_leng = xs_pad.size(1)
d_dim = xs_pad.size(2)
xs_pad = xs_pad.reshape(-1,channel_size,t_leng,d_dim)
#pdb.set_trace()
if(channel_size<8):
repeat_num = math.ceil(8/channel_size)
xs_pad = xs_pad.repeat(1,repeat_num,1,1)[:,0:8,:,:]
xs_pad = self.conv1(xs_pad)
xs_pad = self.conv2(xs_pad)
xs_pad = self.conv3(xs_pad)
xs_pad = self.conv4(xs_pad)
xs_pad = xs_pad.squeeze().reshape(-1,t_leng,d_dim)
mask_tmp = masks.size(1)
masks = masks.reshape(-1,channel_size,mask_tmp,t_leng)[:,0,:,:]
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
olens = masks.squeeze(1).sum(1)
return xs_pad, olens, None
def forward_hidden(
self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Calculate forward propagation.
Args:
xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
ilens (torch.Tensor): Input length (#batch).
prev_states (torch.Tensor): Not to be used now.
Returns:
torch.Tensor: Output tensor (#batch, L, output_size).
torch.Tensor: Output length (#batch).
torch.Tensor: Not to be used now.
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
if (
isinstance(self.embed, Conv2dSubsampling)
or isinstance(self.embed, Conv2dSubsampling6)
or isinstance(self.embed, Conv2dSubsampling8)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
raise TooShortUttError(
f"has {xs_pad.size(1)} frames and is too short for subsampling "
+ f"(it needs more than {limit_size} frames), return empty results",
xs_pad.size(1),
limit_size,
)
xs_pad, masks = self.embed(xs_pad, masks)
else:
xs_pad = self.embed(xs_pad)
num_layer = len(self.encoders)
for idx, encoder in enumerate(self.encoders):
xs_pad, masks = encoder(xs_pad, masks)
if idx == num_layer // 2 - 1:
hidden_feature = xs_pad
if isinstance(xs_pad, tuple):
xs_pad = xs_pad[0]
hidden_feature = hidden_feature[0]
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
self.hidden_feature = self.after_norm(hidden_feature)
olens = masks.squeeze(1).sum(1)
return xs_pad, olens, None

View File

@ -131,3 +131,128 @@ class DefaultFrontend(AbsFrontend):
# input_stft: (..., F, 2) -> (..., F)
input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
return input_stft, feats_lens
class MultiChannelFrontend(AbsFrontend):
"""Conventional frontend structure for ASR.
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
"""
def __init__(
self,
fs: Union[int, str] = 16000,
n_fft: int = 512,
win_length: int = None,
hop_length: int = 128,
window: Optional[str] = "hann",
center: bool = True,
normalized: bool = False,
onesided: bool = True,
n_mels: int = 80,
fmin: int = None,
fmax: int = None,
htk: bool = False,
frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
apply_stft: bool = True,
frame_length: int = None,
frame_shift: int = None,
lfr_m: int = None,
lfr_n: int = None,
):
assert check_argument_types()
super().__init__()
if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)
# Deepcopy (In general, dict shouldn't be used as default arg)
frontend_conf = copy.deepcopy(frontend_conf)
self.hop_length = hop_length
if apply_stft:
self.stft = Stft(
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
center=center,
window=window,
normalized=normalized,
onesided=onesided,
)
else:
self.stft = None
self.apply_stft = apply_stft
if frontend_conf is not None:
self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
else:
self.frontend = None
self.logmel = LogMel(
fs=fs,
n_fft=n_fft,
n_mels=n_mels,
fmin=fmin,
fmax=fmax,
htk=htk,
)
self.n_mels = n_mels
self.frontend_type = "multichannelfrontend"
def output_size(self) -> int:
return self.n_mels
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Domain-conversion: e.g. Stft: time -> time-freq
#import pdb;pdb.set_trace()
if self.stft is not None:
input_stft, feats_lens = self._compute_stft(input, input_lengths)
else:
if isinstance(input, ComplexTensor):
input_stft = input
else:
input_stft = ComplexTensor(input[..., 0], input[..., 1])
feats_lens = input_lengths
# 2. [Option] Speech enhancement
if self.frontend is not None:
assert isinstance(input_stft, ComplexTensor), type(input_stft)
# input_stft: (Batch, Length, [Channel], Freq)
input_stft, _, mask = self.frontend(input_stft, feats_lens)
# 4. STFT -> Power spectrum
# h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
input_power = input_stft.real ** 2 + input_stft.imag ** 2
# 5. Feature transform e.g. Stft -> Log-Mel-Fbank
# input_power: (Batch, [Channel,] Length, Freq)
# -> input_feats: (Batch, Length, Dim)
input_feats, _ = self.logmel(input_power, feats_lens)
bt = input_feats.size(0)
if input_feats.dim() ==4:
channel_size = input_feats.size(2)
# batch * channel * T * D
#pdb.set_trace()
input_feats = input_feats.transpose(1,2).reshape(bt*channel_size,-1,80).contiguous()
# input_feats = input_feats.transpose(1,2)
# batch * channel
feats_lens = feats_lens.repeat(1,channel_size).squeeze()
else:
channel_size = 1
return input_feats, feats_lens, channel_size
def _compute_stft(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> torch.Tensor:
input_stft, feats_lens = self.stft(input, input_lengths)
assert input_stft.dim() >= 4, input_stft.shape
# "2" refers to the real/imag parts of Complex
assert input_stft.shape[-1] == 2, input_stft.shape
# Change torch.Tensor to ComplexTensor
# input_stft: (..., F, 2) -> (..., F)
input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
return input_stft, feats_lens

View File

@ -40,6 +40,7 @@ from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.e2e_asr import ESPnetASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder
@ -47,8 +48,10 @@ from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.default import MultiChannelFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend
@ -86,6 +89,7 @@ frontend_choices = ClassChoices(
s3prl=S3prlFrontend,
fused=FusedFrontends,
wav_frontend=WavFrontend,
multichannelfrontend=MultiChannelFrontend,
),
type_check=AbsFrontend,
default="default",
@ -119,6 +123,7 @@ model_choices = ClassChoices(
paraformer_bert=ParaformerBert,
bicif_paraformer=BiCifParaformer,
contextual_paraformer=ContextualParaformer,
mfcca=MFCCA,
),
type_check=AbsESPnetModel,
default="asr",
@ -142,6 +147,7 @@ encoder_choices = ClassChoices(
sanm=SANMEncoder,
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
mfcca_enc=MFCCAEncoder,
),
type_check=AbsEncoder,
default="rnn",
@ -1106,3 +1112,135 @@ class ASRTaskParaformer(ASRTask):
var_dict_torch_update.update(var_dict_torch_update_local)
return var_dict_torch_update
class ASRTaskMFCCA(ASRTask):
# If you need more than one optimizers, change this value
num_optimizers: int = 1
# Add variable objects configurations
class_choices_list = [
# --frontend and --frontend_conf
frontend_choices,
# --specaug and --specaug_conf
specaug_choices,
# --normalize and --normalize_conf
normalize_choices,
# --model and --model_conf
model_choices,
# --preencoder and --preencoder_conf
preencoder_choices,
# --encoder and --encoder_conf
encoder_choices,
# --decoder and --decoder_conf
decoder_choices,
]
# If you need to modify train() or eval() procedures, change Trainer class here
trainer = Trainer
@classmethod
def build_model(cls, args: argparse.Namespace):
assert check_argument_types()
if isinstance(args.token_list, str):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
# Overwriting token_list to keep it as "portable".
args.token_list = list(token_list)
elif isinstance(args.token_list, (tuple, list)):
token_list = list(args.token_list)
else:
raise RuntimeError("token_list must be str or list")
vocab_size = len(token_list)
logging.info(f"Vocabulary size: {vocab_size}")
# 1. frontend
if args.input_size is None:
# Extract features in the model
frontend_class = frontend_choices.get_class(args.frontend)
if args.frontend == 'wav_frontend':
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
else:
frontend = frontend_class(**args.frontend_conf)
input_size = frontend.output_size()
else:
# Give features from data-loader
args.frontend = None
args.frontend_conf = {}
frontend = None
input_size = args.input_size
# 2. Data augmentation for spectrogram
if args.specaug is not None:
specaug_class = specaug_choices.get_class(args.specaug)
specaug = specaug_class(**args.specaug_conf)
else:
specaug = None
# 3. Normalization layer
if args.normalize is not None:
normalize_class = normalize_choices.get_class(args.normalize)
normalize = normalize_class(stats_file=args.cmvn_file,**args.normalize_conf)
else:
normalize = None
# 4. Pre-encoder input block
# NOTE(kan-bayashi): Use getattr to keep the compatibility
if getattr(args, "preencoder", None) is not None:
preencoder_class = preencoder_choices.get_class(args.preencoder)
preencoder = preencoder_class(**args.preencoder_conf)
input_size = preencoder.output_size()
else:
preencoder = None
# 5. Encoder
encoder_class = encoder_choices.get_class(args.encoder)
encoder = encoder_class(input_size=input_size, **args.encoder_conf)
# 7. Decoder
decoder_class = decoder_choices.get_class(args.decoder)
decoder = decoder_class(
vocab_size=vocab_size,
encoder_output_size=encoder.output_size(),
**args.decoder_conf,
)
# 8. CTC
ctc = CTC(
odim=vocab_size, encoder_output_size=encoder.output_size(), **args.ctc_conf
)
# 10. Build model
try:
model_class = model_choices.get_class(args.model)
except AttributeError:
model_class = model_choices.get_class("asr")
rnnt_decoder = None
# 8. Build model
model = model_class(
vocab_size=vocab_size,
frontend=frontend,
specaug=specaug,
normalize=normalize,
preencoder=preencoder,
encoder=encoder,
decoder=decoder,
ctc=ctc,
rnnt_decoder=rnnt_decoder,
token_list=token_list,
**args.model_conf,
)
# 11. Initialize
if args.init is not None:
initialize(model, args.init)
assert check_return_type(model)
return model