Merge branch 'main' into dev_zly

zhifu gao 2023-02-17 10:31:05 +08:00 committed by GitHub
commit 6baf10d5d1
11 changed files with 193 additions and 119 deletions

@ -17,21 +17,23 @@
## What's new:
### 2023.2.16, funasr-0.2.0
- We support a new feature: exporting paraformer models to [onnx and torchscript](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export) from modelscope. Locally finetuned models are also supported.
- We support a new feature, [onnxruntime](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer): you can deploy the runtime without modelscope or funasr; for the [paraformer-large](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) model, onnxruntime gives a 3x RTF speedup (0.110 -> 0.038) on CPU.
- We support a new feature, [grpc](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/grpc): you can build an ASR service with grpc by deploying either the modelscope pipeline or onnxruntime.
### 2023.2.17, funasr-0.2.0, modelscope-1.3.0
- We support a new feature: exporting paraformer models to [onnx and torchscript](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export) from modelscope. Locally finetuned models are also supported.
- We support a new feature, [onnxruntime](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer): you can deploy the runtime without modelscope or funasr; for the [paraformer-large](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) model, onnxruntime gives a 3x RTF speedup (0.110 -> 0.038) on CPU, see [details](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer#speed) and the usage sketch after this list.
- We support a new feature, [grpc](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/grpc): you can build an ASR service with grpc by deploying either the modelscope pipeline or onnxruntime.
- We release a new model, [paraformer-large-contextual](https://www.modelscope.cn/models/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/summary), which supports hotword customization based on incentive enhancement and improves the recall and precision of hotwords.
- We optimize the timestamp alignment of [Paraformer-large-long](https://modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary): timestamp prediction accuracy is much improved, achieving an accumulated average shift (AAS) of 74.7 ms, see [details](https://arxiv.org/abs/2301.12343).
- We release a new model, [8k VAD model](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary), which predicts the duration of non-silence speech. It can be freely integrated with any ASR model in [modelscope](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary).
- We release a new model, [MFCCA](https://www.modelscope.cn/models/yufan6/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/summary), a multi-channel multi-speaker model which is independent of the number and geometry of microphones and supports Mandarin meeting transcription.
- We release a new model, [MFCCA](https://www.modelscope.cn/models/NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/summary), a multi-channel multi-speaker model which is independent of the number and geometry of microphones and supports Mandarin meeting transcription.
- We release several new UniASR models:
[Southern Fujian Dialect model](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-minnan-16k-common-vocab3825/summary),
[French model](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-fr-16k-common-vocab3472-tensorflow1-online/summary),
[German model](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-de-16k-common-vocab3690-tensorflow1-online/summary),
[Vietnamese model](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-vi-16k-common-vocab1001-pytorch-online/summary),
[Persian model](https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-fa-16k-common-vocab1257-pytorch-online/summary).
- We release a new model, [paraformer-data2vec model](https://www.modelscope.cn/models/damo/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/summary), an unsupervised pretraining model trained on AISHELL-2, which is used to initialize the paraformer model and is then finetuned on AISHELL-1.
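
As a companion to the onnxruntime entry above, here is a minimal usage sketch for the exported paraformer-large model with the Python runtime. The import path and the exact constructor signature are not documented in this changelog; the `Paraformer` class and its `model_file`/`device_id`/`batch_size` arguments appear in the runtime code later in this diff, so treat the details below as assumptions rather than the documented API:

```python
# Usage sketch only: import path and constructor signature are assumptions
# based on the rapid_paraformer runtime code shown later in this diff.
from rapid_paraformer import Paraformer

model = Paraformer(
    model_file="/path/to/exported/paraformer-large/model.onnx",  # placeholder path
    device_id=-1,   # assumption: -1 selects CPU inference
    batch_size=1,
)

# __call__ accepts a wav path, a numpy waveform, or a list of wav paths.
result = model("asr_example.wav")  # placeholder wav file
print(result)
```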
### 2023.1.16, funasr-0.1.6
- Various new audio input types are now supported by the modelscope inference pipeline, including mp3, flac, ogg, opus, ...
### 2023.1.16, funasr-0.1.6, modelscope-1.2.0
- We release a new version of [Paraformer-large-long](https://modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary), which integrates the [VAD](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary) model, the [ASR](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) model,
the [Punctuation](https://www.modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary) model and timestamp prediction. The model can take inputs several hours long; see the pipeline sketch after this list.
- We release a new model, [16k VAD model](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary), which predicts the duration of non-silence speech. It can be freely integrated with any ASR model in [modelscope](https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary).
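
A minimal sketch of running the integrated Paraformer-large-long pipeline (VAD + ASR + punctuation + timestamps) through modelscope, as mentioned above. The `pipeline(task=Tasks.auto_speech_recognition, model=...)` call matches the recipe code shown later in this diff; the `audio_in` keyword and the wav path are assumptions/placeholders:

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Paraformer-large-long bundles VAD, ASR, punctuation and timestamp prediction,
# so a recording several hours long can be passed in directly.
inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model='damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
)
rec_result = inference_pipeline(audio_in='long_audio_example.wav')  # placeholder path
print(rec_result)
```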
@ -101,4 +103,10 @@ This project is licensed under the [The MIT License](https://opensource.org/lice
booktitle={INTERSPEECH},
year={2022}
}
@inproceedings{Shi2023AchievingTP,
title={Achieving Timestamp Prediction While Recognizing with Non-Autoregressive End-to-End ASR Model},
author={Xian Shi and Yanni Chen and Shiliang Zhang and Zhijie Yan},
booktitle={arXiv preprint arXiv:2301.12343},
year={2023}
}
```

@ -1,5 +1,5 @@
# MFCCA
- Model link: <https://www.modelscope.cn/models/yufan6/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/summary>
- Model link: <https://www.modelscope.cn/models/NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/summary>
- Model size: 45M
# Environments

@ -24,12 +24,12 @@ def modelscope_finetune(params):
if __name__ == '__main__':
params = modelscope_args(model="yufan6/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950")
params = modelscope_args(model="NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950")
params.output_dir = "./checkpoint" # m模型保存路径
params.data_path = "./example_data/" # 数据路径
params.dataset_type = "small" # 小数据量设置small若数据量大于1000小时请使用large
params.batch_bins = 1000 # batch size如果dataset_type="small"batch_bins单位为fbank特征帧数如果dataset_type="large"batch_bins单位为毫秒
params.max_epoch = 10 # 最大训练轮数
params.lr = 0.0001 # 设置学习率
params.model_revision = 'v2.0.0'
params.model_revision = 'v1.0.0'
modelscope_finetune(params)

@ -18,8 +18,8 @@ def modelscope_infer_core(output_dir, split_dir, njob, idx):
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
inference_pipline = pipeline(
task=Tasks.auto_speech_recognition,
model='yufan6/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950',
model_revision='v2.0.0',
model='NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950',
model_revision='v1.0.0',
output_dir=output_dir_job,
batch_size=1,
)

@ -59,7 +59,7 @@ def modelscope_infer_after_finetune(params):
if __name__ == '__main__':
params = {}
params["modelscope_model_name"] = "yufan6/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950"
params["modelscope_model_name"] = "NPU-ASLP/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950"
params["required_files"] = ["feats_stats.npz", "decoding.yaml", "configuration.json"]
params["output_dir"] = "./checkpoint"
params["data_dir"] = "./example_data/validation"

@ -41,16 +41,7 @@ from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
header_colors = '\033[95m'
end_colors = '\033[0m'
global_asr_language: str = 'zh-cn'
global_sample_rate: Union[int, Dict[Any, int]] = {
'audio_fs': 16000,
'model_fs': 16000
}
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
class Speech2Text:
@ -346,6 +337,160 @@ class Speech2Text:
# assert check_return_type(results)
return results
class Speech2TextExport:
"""Speech2TextExport class
"""
def __init__(
self,
asr_train_config: Union[Path, str] = None,
asr_model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
lm_train_config: Union[Path, str] = None,
lm_file: Union[Path, str] = None,
token_type: str = None,
bpemodel: str = None,
device: str = "cpu",
maxlenratio: float = 0.0,
minlenratio: float = 0.0,
dtype: str = "float32",
beam_size: int = 20,
ctc_weight: float = 0.5,
lm_weight: float = 1.0,
ngram_weight: float = 0.9,
penalty: float = 0.0,
nbest: int = 1,
frontend_conf: dict = None,
hotword_list_or_file: str = None,
**kwargs,
):
# 1. Build ASR model
asr_model, asr_train_args = ASRTask.build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
logging.info("asr_model: {}".format(asr_model))
logging.info("asr_train_args: {}".format(asr_train_args))
asr_model.to(dtype=getattr(torch, dtype)).eval()
token_list = asr_model.token_list
logging.info(f"Decoding device={device}, dtype={dtype}")
# 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
token_type = asr_train_args.token_type
if bpemodel is None:
bpemodel = asr_train_args.bpemodel
if token_type is None:
tokenizer = None
elif token_type == "bpe":
if bpemodel is not None:
tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
else:
tokenizer = None
else:
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
# self.asr_model = asr_model
self.asr_train_args = asr_train_args
self.converter = converter
self.tokenizer = tokenizer
self.device = device
self.dtype = dtype
self.nbest = nbest
self.frontend = frontend
model = Paraformer_export(asr_model, onnx=False)
self.asr_model = model
@torch.no_grad()
def __call__(
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
):
"""Inference
Args:
speech: Input speech data
Returns:
text, token, token_int, hyp
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
if self.frontend is not None:
feats, feats_len = self.frontend.forward(speech, speech_lengths)
feats = to_device(feats, device=self.device)
feats_len = feats_len.int()
self.asr_model.frontend = None
else:
feats = speech
feats_len = speech_lengths
enc_len_batch_total = feats_len.sum()
lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
batch = {"speech": feats, "speech_lengths": feats_len}
# a. To device
batch = to_device(batch, device=self.device)
decoder_outs = self.asr_model(**batch)
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
results = []
b, n, d = decoder_out.size()
for i in range(b):
am_scores = decoder_out[i, :ys_pad_lens[i], :]
yseq = am_scores.argmax(dim=-1)
score = am_scores.max(dim=-1)[0]
score = torch.sum(score, dim=-1)
# rebuild yseq as a plain tensor; unlike the standard decoding path, no sos/eos tokens are added here
yseq = torch.tensor(
yseq.tolist(), device=yseq.device
)
nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
for hyp in nbest_hyps:
assert isinstance(hyp, (Hypothesis)), type(hyp)
# remove sos/eos and get results
last_pos = -1
if isinstance(hyp.yseq, list):
token_int = hyp.yseq[1:last_pos]
else:
token_int = hyp.yseq[1:last_pos].tolist()
# remove the blank symbol id (0) and the special symbol id 2
token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
if self.tokenizer is not None:
text = self.tokenizer.tokens2text(token)
else:
text = None
results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))
return results
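
A minimal sketch of driving the new `Speech2TextExport` class directly, outside the modelscope pipeline. The module path and the file names under `exp/` are placeholders/assumptions; the keyword arguments and the `(speech, speech_lengths)` call follow the definition above, and the training config is assumed to define a `WavFrontend` so a raw waveform can be passed in:

```python
import soundfile
import torch

# module path assumed from the imports at the top of this file
from funasr.bin.asr_inference_paraformer import Speech2TextExport

speech2text = Speech2TextExport(
    asr_train_config="exp/paraformer/config.yaml",  # placeholder paths
    asr_model_file="exp/paraformer/model.pb",
    cmvn_file="exp/paraformer/am.mvn",
    device="cpu",
)

waveform, fs = soundfile.read("asr_example.wav")  # 16 kHz mono assumed
speech = torch.tensor(waveform, dtype=torch.float32).unsqueeze(0)
speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.int32)

# each result is (text, token, token_int, hyp, enc_len_batch_total, lfr_factor)
for text, token, token_int, *rest in speech2text(speech, speech_lengths):
    print(text)
```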
def inference(
maxlenratio: float,
@ -454,9 +599,11 @@ def inference_modelscope(
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
export_mode = False
if param_dict is not None:
hotword_list_or_file = param_dict.get('hotword')
export_mode = param_dict.get("export_mode", False)
else:
hotword_list_or_file = None
@ -490,7 +637,10 @@ def inference_modelscope(
nbest=nbest,
hotword_list_or_file=hotword_list_or_file,
)
speech2text = Speech2Text(**speech2text_kwargs)
if export_mode:
speech2text = Speech2TextExport(**speech2text_kwargs)
else:
speech2text = Speech2Text(**speech2text_kwargs)
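
With the branch above, the export-oriented decoder is selected by passing `export_mode` through `param_dict`. A minimal sketch via the modelscope pipeline; exactly how `param_dict` reaches `inference_modelscope` depends on the pipeline wrapper, so the `param_dict=` argument below follows the FunASR/modelscope examples of this period and should be read as an assumption:

```python
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

param_dict = {"export_mode": True}  # routes decoding through Speech2TextExport

inference_pipeline = pipeline(
    task=Tasks.auto_speech_recognition,
    model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
    param_dict=param_dict,  # assumption: forwarded to inference_modelscope
)
rec_result = inference_pipeline(audio_in='asr_example.wav')  # placeholder path
```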
def _forward(
data_path_and_name_and_type,

@ -38,7 +38,6 @@ from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.vad import VADTask
from funasr.utils.timestamp_tools import time_stamp_lfr6
from funasr.bin.punctuation_infer import Text2Punc
from funasr.bin.asr_inference_paraformer_vad_punc import Speech2Text
from funasr.bin.asr_inference_paraformer_vad_punc import Speech2VadSegment

@ -39,7 +39,7 @@ from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.vad import VADTask
from funasr.utils.timestamp_tools import time_stamp_lfr6, time_stamp_lfr6_pl
from funasr.utils.timestamp_tools import time_stamp_lfr6_pl
from funasr.bin.punctuation_infer import Text2Punc
from funasr.models.e2e_asr_paraformer import BiCifParaformer
@ -282,13 +282,10 @@ class Speech2Text:
else:
text = None
if isinstance(self.asr_model, BiCifParaformer):
timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
else:
time_stamp = time_stamp_lfr6(alphas[i:i + 1, ], enc_len[i:i + 1, ], copy.copy(token), begin_time,
end_time)
results.append((text, token, token_int, time_stamp, enc_len_batch_total, lfr_factor))
timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
# assert check_return_type(results)
return results
@ -636,7 +633,8 @@ def inference_modelscope(
text, token, token_int = result[0], result[1], result[2]
time_stamp = None if len(result) < 4 else result[3]
if use_timestamp and time_stamp is not None:
postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
else:
postprocessed_result = postprocess_utils.sentence_postprocess(token)

@ -44,6 +44,7 @@ class ASRModelExportParaformer:
model,
self.export_config,
)
model.eval()
# self._export_onnx(model, verbose, export_dir)
if self.onnx:
self._export_onnx(model, verbose, export_dir)

@ -41,8 +41,8 @@ class Paraformer():
self.ort_infer = OrtInferSession(model_file, device_id)
self.batch_size = batch_size
def __call__(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
waveform_list = self.load_data(wav_content, fs)
def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
waveform_list = self.load_data(wav_content, self.frontend.opts.samp_freq)
waveform_nums = len(waveform_list)
asr_res = []

@ -4,88 +4,6 @@ import logging
import numpy as np
from typing import Any, List, Tuple, Union
def cut_interval(alphas: torch.Tensor, start: int, end: int, tail: bool):
if not tail:
if end == start + 1:
cut = (end + start) / 2.0
else:
alpha = alphas[start+1: end].tolist()
reverse_steps = 1
for reverse_alpha in alpha[::-1]:
if reverse_alpha > 0.35:
reverse_steps += 1
else:
break
cut = end - reverse_steps
else:
if end != len(alphas) - 1:
cut = end + 1
else:
cut = start + 1
return float(cut)
def time_stamp_lfr6(alphas: torch.Tensor, speech_lengths: torch.Tensor, raw_text: List[str], begin: int = 0, end: int = None):
time_stamp_list = []
alphas = alphas[0]
text = copy.deepcopy(raw_text)
if end is None:
time = speech_lengths * 60 / 1000
sacle_rate = (time / speech_lengths[0]).tolist()
else:
time = (end - begin) / 1000
sacle_rate = (time / speech_lengths[0]).tolist()
predictor = (alphas > 0.5).int()
fire_places = torch.nonzero(predictor == 1).squeeze(1).tolist()
cuts = []
npeak = int(predictor.sum())
nchar = len(raw_text)
if npeak - 1 == nchar:
fire_places = torch.where((alphas > 0.5) == 1)[0].tolist()
for i in range(len(fire_places)):
if fire_places[i] < len(alphas) - 1:
if 0.05 < alphas[fire_places[i]+1] < 0.5:
fire_places[i] += 1
elif npeak < nchar:
lost_num = nchar - npeak
lost_fire = speech_lengths[0].tolist() - fire_places[-1]
interval_distance = lost_fire // (lost_num + 1)
for i in range(1, lost_num + 1):
fire_places.append(fire_places[-1] + interval_distance)
elif npeak - 1 > nchar:
redundance_num = npeak - 1 - nchar
for i in range(redundance_num):
fire_places.pop()
cuts.append(0)
start_sil = True
if start_sil:
text.insert(0, '<sil>')
for i in range(len(fire_places)-1):
cuts.append(cut_interval(alphas, fire_places[i], fire_places[i+1], tail=(i==len(fire_places)-2)))
for i in range(2, len(fire_places)-2):
if fire_places[i-2] == fire_places[i-1] - 1 and fire_places[i-1] != fire_places[i] - 1:
cuts[i-1] += 1
if cuts[-1] != len(alphas) - 1:
text.append('<sil>')
cuts.append(speech_lengths[0].tolist())
cuts.insert(-1, (cuts[-1] + cuts[-2]) * 0.5)
sec_fire_places = np.array(cuts) * sacle_rate
for i in range(1, len(sec_fire_places) - 1):
start, end = sec_fire_places[i], sec_fire_places[i+1]
if i == len(sec_fire_places) - 2:
end = time
time_stamp_list.append([int(round(start, 2) * 1000) + begin, int(round(end, 2) * 1000) + begin])
text = text[1:]
if npeak - 1 == nchar or npeak > nchar:
return time_stamp_list[:-1]
else:
return time_stamp_list
def time_stamp_lfr6_pl(us_alphas, us_cif_peak, char_list, begin_time=0.0, end_time=None):
START_END_THRESHOLD = 5
TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled
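# i.e. 10 ms fbank frame shift x LFR 6 = 60 ms per encoder frame, /1000 converts ms to s, /3 for the 3x upsampled CIF output: 0.02 s per upsampled frame (10 ms frame shift assumed)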