mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge pull request #2 from alibaba-damo-academy/main
merge from official
This commit is contained in:
commit
d836e403ad
Binary file not shown.
|
Before Width: | Height: | Size: 8.1 KiB After Width: | Height: | Size: 53 KiB |
@ -5,7 +5,7 @@ inputs = "跨境河流是养育沿岸|人民的生命之源长期以来为帮助
|
|||||||
from modelscope.pipelines import pipeline
|
from modelscope.pipelines import pipeline
|
||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
|
|
||||||
inference_pipline = pipeline(
|
inference_pipeline = pipeline(
|
||||||
task=Tasks.punctuation,
|
task=Tasks.punctuation,
|
||||||
model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727',
|
model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727',
|
||||||
model_revision="v1.0.0",
|
model_revision="v1.0.0",
|
||||||
@ -17,7 +17,7 @@ vads = inputs.split("|")
|
|||||||
cache_out = []
|
cache_out = []
|
||||||
rec_result_all="outputs:"
|
rec_result_all="outputs:"
|
||||||
for vad in vads:
|
for vad in vads:
|
||||||
rec_result = inference_pipline(text_in=vad, cache=cache_out)
|
rec_result = inference_pipeline(text_in=vad, cache=cache_out)
|
||||||
#print(rec_result)
|
#print(rec_result)
|
||||||
cache_out = rec_result['cache']
|
cache_out = rec_result['cache']
|
||||||
rec_result_all += rec_result['text']
|
rec_result_all += rec_result['text']
|
||||||
|
|||||||
@ -0,0 +1,26 @@
|
|||||||
|
from modelscope.pipelines import pipeline
|
||||||
|
from modelscope.utils.constant import Tasks
|
||||||
|
|
||||||
|
# 初始化推理 pipeline
|
||||||
|
# 当以原始音频作为输入时使用配置文件 sond.yaml,并设置 mode 为sond_demo
|
||||||
|
inference_diar_pipline = pipeline(
|
||||||
|
mode="sond_demo",
|
||||||
|
num_workers=0,
|
||||||
|
task=Tasks.speaker_diarization,
|
||||||
|
diar_model_config="sond.yaml",
|
||||||
|
model='damo/speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch',
|
||||||
|
sv_model="damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch",
|
||||||
|
sv_model_revision="master",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 以 audio_list 作为输入,其中第一个音频为待检测语音,后面的音频为不同说话人的声纹注册语音
|
||||||
|
audio_list = [[
|
||||||
|
"https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/record.wav",
|
||||||
|
"https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/spk_A.wav",
|
||||||
|
"https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/spk_B.wav",
|
||||||
|
"https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/spk_B1.wav"
|
||||||
|
]]
|
||||||
|
|
||||||
|
results = inference_diar_pipline(audio_in=audio_list)
|
||||||
|
for rst in results:
|
||||||
|
print(rst["value"])
|
||||||
@ -0,0 +1,25 @@
|
|||||||
|
# ModelScope Model
|
||||||
|
|
||||||
|
## How to finetune and infer using a pretrained ModelScope Model
|
||||||
|
|
||||||
|
### Inference
|
||||||
|
|
||||||
|
Or you can use the finetuned model for inference directly.
|
||||||
|
|
||||||
|
- Setting parameters in `infer.py`
|
||||||
|
- <strong>audio_in:</strong> # support wav, url, bytes, and parsed audio format.
|
||||||
|
- <strong>text_in:</strong> # support text, text url.
|
||||||
|
- <strong>output_dir:</strong> # If the input format is wav.scp, it needs to be set.
|
||||||
|
|
||||||
|
- Then you can run the pipeline to infer with:
|
||||||
|
```python
|
||||||
|
python infer.py
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
Modify inference related parameters in vad.yaml.
|
||||||
|
|
||||||
|
- max_end_silence_time: The end-point silence duration to judge the end of sentence, the parameter range is 500ms~6000ms, and the default value is 800ms
|
||||||
|
- speech_noise_thres: The balance of speech and silence scores, the parameter range is (-1,1)
|
||||||
|
- The value tends to -1, the greater probability of noise being judged as speech
|
||||||
|
- The value tends to 1, the greater probability of speech being judged as noise
|
||||||
@ -0,0 +1,12 @@
|
|||||||
|
from modelscope.pipelines import pipeline
|
||||||
|
from modelscope.utils.constant import Tasks
|
||||||
|
|
||||||
|
inference_pipline = pipeline(
|
||||||
|
task=Tasks.speech_timestamp,
|
||||||
|
model='damo/speech_timestamp_prediction-v1-16k-offline',
|
||||||
|
output_dir='./tmp')
|
||||||
|
|
||||||
|
rec_result = inference_pipline(
|
||||||
|
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_timestamps.wav',
|
||||||
|
text_in='一 个 东 太 平 洋 国 家 为 什 么 跑 到 西 太 平 洋 来 了 呢')
|
||||||
|
print(rec_result)
|
||||||
@ -7,7 +7,7 @@ if __name__ == '__main__':
|
|||||||
inference_pipline = pipeline(
|
inference_pipline = pipeline(
|
||||||
task=Tasks.voice_activity_detection,
|
task=Tasks.voice_activity_detection,
|
||||||
model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||||
model_revision=None,
|
model_revision='v1.2.0',
|
||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -0,0 +1,33 @@
|
|||||||
|
from modelscope.pipelines import pipeline
|
||||||
|
from modelscope.utils.constant import Tasks
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
output_dir = None
|
||||||
|
inference_pipline = pipeline(
|
||||||
|
task=Tasks.voice_activity_detection,
|
||||||
|
model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||||
|
model_revision='v1.2.0',
|
||||||
|
output_dir=output_dir,
|
||||||
|
batch_size=1,
|
||||||
|
mode='online',
|
||||||
|
)
|
||||||
|
speech, sample_rate = soundfile.read("./vad_example_16k.wav")
|
||||||
|
speech_length = speech.shape[0]
|
||||||
|
|
||||||
|
sample_offset = 0
|
||||||
|
|
||||||
|
step = 160 * 10
|
||||||
|
param_dict = {'in_cache': dict()}
|
||||||
|
for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)):
|
||||||
|
if sample_offset + step >= speech_length - 1:
|
||||||
|
step = speech_length - sample_offset
|
||||||
|
is_final = True
|
||||||
|
else:
|
||||||
|
is_final = False
|
||||||
|
param_dict['is_final'] = is_final
|
||||||
|
segments_result = inference_pipline(audio_in=speech[sample_offset: sample_offset + step],
|
||||||
|
param_dict=param_dict)
|
||||||
|
print(segments_result)
|
||||||
|
|
||||||
@ -7,8 +7,8 @@ if __name__ == '__main__':
|
|||||||
inference_pipline = pipeline(
|
inference_pipline = pipeline(
|
||||||
task=Tasks.voice_activity_detection,
|
task=Tasks.voice_activity_detection,
|
||||||
model="damo/speech_fsmn_vad_zh-cn-8k-common",
|
model="damo/speech_fsmn_vad_zh-cn-8k-common",
|
||||||
model_revision=None,
|
model_revision='v1.2.0',
|
||||||
output_dir='./output_dir',
|
output_dir=output_dir,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
)
|
)
|
||||||
segments_result = inference_pipline(audio_in=audio_in)
|
segments_result = inference_pipline(audio_in=audio_in)
|
||||||
|
|||||||
@ -0,0 +1,33 @@
|
|||||||
|
from modelscope.pipelines import pipeline
|
||||||
|
from modelscope.utils.constant import Tasks
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
output_dir = None
|
||||||
|
inference_pipline = pipeline(
|
||||||
|
task=Tasks.voice_activity_detection,
|
||||||
|
model="damo/speech_fsmn_vad_zh-cn-8k-common",
|
||||||
|
model_revision='v1.2.0',
|
||||||
|
output_dir=output_dir,
|
||||||
|
batch_size=1,
|
||||||
|
mode='online',
|
||||||
|
)
|
||||||
|
speech, sample_rate = soundfile.read("./vad_example_8k.wav")
|
||||||
|
speech_length = speech.shape[0]
|
||||||
|
|
||||||
|
sample_offset = 0
|
||||||
|
|
||||||
|
step = 80 * 10
|
||||||
|
param_dict = {'in_cache': dict()}
|
||||||
|
for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)):
|
||||||
|
if sample_offset + step >= speech_length - 1:
|
||||||
|
step = speech_length - sample_offset
|
||||||
|
is_final = True
|
||||||
|
else:
|
||||||
|
is_final = False
|
||||||
|
param_dict['is_final'] = is_final
|
||||||
|
segments_result = inference_pipline(audio_in=speech[sample_offset: sample_offset + step],
|
||||||
|
param_dict=param_dict)
|
||||||
|
print(segments_result)
|
||||||
|
|
||||||
@ -42,7 +42,7 @@ from funasr.utils import asr_utils, wav_utils, postprocess_utils
|
|||||||
from funasr.models.frontend.wav_frontend import WavFrontend
|
from funasr.models.frontend.wav_frontend import WavFrontend
|
||||||
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
|
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
|
||||||
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
|
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
|
||||||
from funasr.utils.timestamp_tools import time_stamp_lfr6_pl, time_stamp_sentence
|
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
|
||||||
|
|
||||||
|
|
||||||
class Speech2Text:
|
class Speech2Text:
|
||||||
@ -245,7 +245,7 @@ class Speech2Text:
|
|||||||
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
|
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
|
||||||
|
|
||||||
if isinstance(self.asr_model, BiCifParaformer):
|
if isinstance(self.asr_model, BiCifParaformer):
|
||||||
_, _, us_alphas, us_cif_peak = self.asr_model.calc_predictor_timestamp(enc, enc_len,
|
_, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
|
||||||
pre_token_length) # test no bias cif2
|
pre_token_length) # test no bias cif2
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
@ -291,7 +291,10 @@ class Speech2Text:
|
|||||||
text = None
|
text = None
|
||||||
|
|
||||||
if isinstance(self.asr_model, BiCifParaformer):
|
if isinstance(self.asr_model, BiCifParaformer):
|
||||||
timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
|
_, timestamp = ts_prediction_lfr6_standard(us_alphas[i],
|
||||||
|
us_peaks[i],
|
||||||
|
copy.copy(token),
|
||||||
|
vad_offset=begin_time)
|
||||||
results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
|
results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
|
||||||
else:
|
else:
|
||||||
results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))
|
results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))
|
||||||
|
|||||||
@ -44,11 +44,10 @@ from funasr.utils import asr_utils, wav_utils, postprocess_utils
|
|||||||
from funasr.models.frontend.wav_frontend import WavFrontend
|
from funasr.models.frontend.wav_frontend import WavFrontend
|
||||||
from funasr.tasks.vad import VADTask
|
from funasr.tasks.vad import VADTask
|
||||||
from funasr.bin.vad_inference import Speech2VadSegment
|
from funasr.bin.vad_inference import Speech2VadSegment
|
||||||
from funasr.utils.timestamp_tools import time_stamp_lfr6_pl
|
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
|
||||||
from funasr.bin.punctuation_infer import Text2Punc
|
from funasr.bin.punctuation_infer import Text2Punc
|
||||||
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
|
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
|
||||||
|
|
||||||
from funasr.utils.timestamp_tools import time_stamp_sentence
|
|
||||||
|
|
||||||
header_colors = '\033[95m'
|
header_colors = '\033[95m'
|
||||||
end_colors = '\033[0m'
|
end_colors = '\033[0m'
|
||||||
@ -257,7 +256,7 @@ class Speech2Text:
|
|||||||
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
|
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
|
||||||
|
|
||||||
if isinstance(self.asr_model, BiCifParaformer):
|
if isinstance(self.asr_model, BiCifParaformer):
|
||||||
_, _, us_alphas, us_cif_peak = self.asr_model.calc_predictor_timestamp(enc, enc_len,
|
_, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
|
||||||
pre_token_length) # test no bias cif2
|
pre_token_length) # test no bias cif2
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
@ -303,7 +302,10 @@ class Speech2Text:
|
|||||||
text = None
|
text = None
|
||||||
|
|
||||||
if isinstance(self.asr_model, BiCifParaformer):
|
if isinstance(self.asr_model, BiCifParaformer):
|
||||||
timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
|
_, timestamp = ts_prediction_lfr6_standard(us_alphas[i],
|
||||||
|
us_peaks[i],
|
||||||
|
copy.copy(token),
|
||||||
|
vad_offset=begin_time)
|
||||||
results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
|
results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
|
||||||
else:
|
else:
|
||||||
results.append((text, token, token_int, enc_len_batch_total, lfr_factor))
|
results.append((text, token, token_int, enc_len_batch_total, lfr_factor))
|
||||||
|
|||||||
@ -28,7 +28,9 @@ def parse_args(mode):
|
|||||||
elif mode == "uniasr":
|
elif mode == "uniasr":
|
||||||
from funasr.tasks.asr import ASRTaskUniASR as ASRTask
|
from funasr.tasks.asr import ASRTaskUniASR as ASRTask
|
||||||
elif mode == "mfcca":
|
elif mode == "mfcca":
|
||||||
from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
|
from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
|
||||||
|
elif mode == "tp":
|
||||||
|
from funasr.tasks.asr import ASRTaskAligner as ASRTask
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown mode: {}".format(mode))
|
raise ValueError("Unknown mode: {}".format(mode))
|
||||||
parser = ASRTask.get_parser()
|
parser = ASRTask.get_parser()
|
||||||
|
|||||||
413
funasr/bin/eend_ola_inference.py
Executable file
413
funasr/bin/eend_ola_inference.py
Executable file
@ -0,0 +1,413 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||||
|
# MIT License (https://opensource.org/licenses/MIT)
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from typing import List
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Sequence
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from typeguard import check_argument_types
|
||||||
|
|
||||||
|
from funasr.models.frontend.wav_frontend import WavFrontendMel23
|
||||||
|
from funasr.tasks.diar import EENDOLADiarTask
|
||||||
|
from funasr.torch_utils.device_funcs import to_device
|
||||||
|
from funasr.utils import config_argparse
|
||||||
|
from funasr.utils.cli_utils import get_commandline_args
|
||||||
|
from funasr.utils.types import str2bool
|
||||||
|
from funasr.utils.types import str2triple_str
|
||||||
|
from funasr.utils.types import str_or_none
|
||||||
|
|
||||||
|
|
||||||
|
class Speech2Diarization:
|
||||||
|
"""Speech2Diarlization class
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import soundfile
|
||||||
|
>>> import numpy as np
|
||||||
|
>>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pth")
|
||||||
|
>>> profile = np.load("profiles.npy")
|
||||||
|
>>> audio, rate = soundfile.read("speech.wav")
|
||||||
|
>>> speech2diar(audio, profile)
|
||||||
|
{"spk1": [(int, int), ...], ...}
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
diar_train_config: Union[Path, str] = None,
|
||||||
|
diar_model_file: Union[Path, str] = None,
|
||||||
|
device: str = "cpu",
|
||||||
|
dtype: str = "float32",
|
||||||
|
):
|
||||||
|
assert check_argument_types()
|
||||||
|
|
||||||
|
# 1. Build Diarization model
|
||||||
|
diar_model, diar_train_args = EENDOLADiarTask.build_model_from_file(
|
||||||
|
config_file=diar_train_config,
|
||||||
|
model_file=diar_model_file,
|
||||||
|
device=device
|
||||||
|
)
|
||||||
|
frontend = None
|
||||||
|
if diar_train_args.frontend is not None and diar_train_args.frontend_conf is not None:
|
||||||
|
frontend = WavFrontendMel23(**diar_train_args.frontend_conf)
|
||||||
|
|
||||||
|
# set up seed for eda
|
||||||
|
np.random.seed(diar_train_args.seed)
|
||||||
|
torch.manual_seed(diar_train_args.seed)
|
||||||
|
torch.cuda.manual_seed(diar_train_args.seed)
|
||||||
|
os.environ['PYTORCH_SEED'] = str(diar_train_args.seed)
|
||||||
|
logging.info("diar_model: {}".format(diar_model))
|
||||||
|
logging.info("diar_train_args: {}".format(diar_train_args))
|
||||||
|
diar_model.to(dtype=getattr(torch, dtype)).eval()
|
||||||
|
|
||||||
|
self.diar_model = diar_model
|
||||||
|
self.diar_train_args = diar_train_args
|
||||||
|
self.device = device
|
||||||
|
self.dtype = dtype
|
||||||
|
self.frontend = frontend
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
speech: Union[torch.Tensor, np.ndarray],
|
||||||
|
speech_lengths: Union[torch.Tensor, np.ndarray] = None
|
||||||
|
):
|
||||||
|
"""Inference
|
||||||
|
|
||||||
|
Args:
|
||||||
|
speech: Input speech data
|
||||||
|
Returns:
|
||||||
|
diarization results
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert check_argument_types()
|
||||||
|
# Input as audio signal
|
||||||
|
if isinstance(speech, np.ndarray):
|
||||||
|
speech = torch.tensor(speech)
|
||||||
|
|
||||||
|
if self.frontend is not None:
|
||||||
|
feats, feats_len = self.frontend.forward(speech, speech_lengths)
|
||||||
|
feats = to_device(feats, device=self.device)
|
||||||
|
feats_len = feats_len.int()
|
||||||
|
self.diar_model.frontend = None
|
||||||
|
else:
|
||||||
|
feats = speech
|
||||||
|
feats_len = speech_lengths
|
||||||
|
batch = {"speech": feats, "speech_lengths": feats_len}
|
||||||
|
batch = to_device(batch, device=self.device)
|
||||||
|
results = self.diar_model.estimate_sequential(**batch)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_pretrained(
|
||||||
|
model_tag: Optional[str] = None,
|
||||||
|
**kwargs: Optional[Any],
|
||||||
|
):
|
||||||
|
"""Build Speech2Diarization instance from the pretrained model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_tag (Optional[str]): Model tag of the pretrained models.
|
||||||
|
Currently, the tags of espnet_model_zoo are supported.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Speech2Diarization: Speech2Diarization instance.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if model_tag is not None:
|
||||||
|
try:
|
||||||
|
from espnet_model_zoo.downloader import ModelDownloader
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
logging.error(
|
||||||
|
"`espnet_model_zoo` is not installed. "
|
||||||
|
"Please install via `pip install -U espnet_model_zoo`."
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
d = ModelDownloader()
|
||||||
|
kwargs.update(**d.download_and_unpack(model_tag))
|
||||||
|
|
||||||
|
return Speech2Diarization(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def inference_modelscope(
|
||||||
|
diar_train_config: str,
|
||||||
|
diar_model_file: str,
|
||||||
|
output_dir: Optional[str] = None,
|
||||||
|
batch_size: int = 1,
|
||||||
|
dtype: str = "float32",
|
||||||
|
ngpu: int = 0,
|
||||||
|
num_workers: int = 0,
|
||||||
|
log_level: Union[int, str] = "INFO",
|
||||||
|
key_file: Optional[str] = None,
|
||||||
|
model_tag: Optional[str] = None,
|
||||||
|
allow_variable_data_keys: bool = True,
|
||||||
|
streaming: bool = False,
|
||||||
|
param_dict: Optional[dict] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
assert check_argument_types()
|
||||||
|
if batch_size > 1:
|
||||||
|
raise NotImplementedError("batch decoding is not implemented")
|
||||||
|
if ngpu > 1:
|
||||||
|
raise NotImplementedError("only single GPU decoding is supported")
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=log_level,
|
||||||
|
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
||||||
|
)
|
||||||
|
logging.info("param_dict: {}".format(param_dict))
|
||||||
|
|
||||||
|
if ngpu >= 1 and torch.cuda.is_available():
|
||||||
|
device = "cuda"
|
||||||
|
else:
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
# 1. Build speech2diar
|
||||||
|
speech2diar_kwargs = dict(
|
||||||
|
diar_train_config=diar_train_config,
|
||||||
|
diar_model_file=diar_model_file,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
streaming=streaming,
|
||||||
|
)
|
||||||
|
logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
|
||||||
|
speech2diar = Speech2Diarization.from_pretrained(
|
||||||
|
model_tag=model_tag,
|
||||||
|
**speech2diar_kwargs,
|
||||||
|
)
|
||||||
|
speech2diar.diar_model.eval()
|
||||||
|
|
||||||
|
def output_results_str(results: dict, uttid: str):
|
||||||
|
rst = []
|
||||||
|
mid = uttid.rsplit("-", 1)[0]
|
||||||
|
for key in results:
|
||||||
|
results[key] = [(x[0] / 100, x[1] / 100) for x in results[key]]
|
||||||
|
template = "SPEAKER {} 0 {:.2f} {:.2f} <NA> <NA> {} <NA> <NA>"
|
||||||
|
for spk, segs in results.items():
|
||||||
|
rst.extend([template.format(mid, st, ed, spk) for st, ed in segs])
|
||||||
|
|
||||||
|
return "\n".join(rst)
|
||||||
|
|
||||||
|
def _forward(
|
||||||
|
data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
|
||||||
|
raw_inputs: List[List[Union[np.ndarray, torch.Tensor, str, bytes]]] = None,
|
||||||
|
output_dir_v2: Optional[str] = None,
|
||||||
|
param_dict: Optional[dict] = None,
|
||||||
|
):
|
||||||
|
# 2. Build data-iterator
|
||||||
|
if data_path_and_name_and_type is None and raw_inputs is not None:
|
||||||
|
if isinstance(raw_inputs, torch.Tensor):
|
||||||
|
raw_inputs = raw_inputs.numpy()
|
||||||
|
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
|
||||||
|
loader = EENDOLADiarTask.build_streaming_iterator(
|
||||||
|
data_path_and_name_and_type,
|
||||||
|
dtype=dtype,
|
||||||
|
batch_size=batch_size,
|
||||||
|
key_file=key_file,
|
||||||
|
num_workers=num_workers,
|
||||||
|
preprocess_fn=EENDOLADiarTask.build_preprocess_fn(speech2diar.diar_train_args, False),
|
||||||
|
collate_fn=EENDOLADiarTask.build_collate_fn(speech2diar.diar_train_args, False),
|
||||||
|
allow_variable_data_keys=allow_variable_data_keys,
|
||||||
|
inference=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Start for-loop
|
||||||
|
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
|
||||||
|
if output_path is not None:
|
||||||
|
os.makedirs(output_path, exist_ok=True)
|
||||||
|
output_writer = open("{}/result.txt".format(output_path), "w")
|
||||||
|
result_list = []
|
||||||
|
for keys, batch in loader:
|
||||||
|
assert isinstance(batch, dict), type(batch)
|
||||||
|
assert all(isinstance(s, str) for s in keys), keys
|
||||||
|
_bs = len(next(iter(batch.values())))
|
||||||
|
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
|
||||||
|
# batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
|
||||||
|
|
||||||
|
results = speech2diar(**batch)
|
||||||
|
# Only supporting batch_size==1
|
||||||
|
key, value = keys[0], output_results_str(results, keys[0])
|
||||||
|
item = {"key": key, "value": value}
|
||||||
|
result_list.append(item)
|
||||||
|
if output_path is not None:
|
||||||
|
output_writer.write(value)
|
||||||
|
output_writer.flush()
|
||||||
|
|
||||||
|
if output_path is not None:
|
||||||
|
output_writer.close()
|
||||||
|
|
||||||
|
return result_list
|
||||||
|
|
||||||
|
return _forward
|
||||||
|
|
||||||
|
|
||||||
|
def inference(
|
||||||
|
data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
|
||||||
|
diar_train_config: Optional[str],
|
||||||
|
diar_model_file: Optional[str],
|
||||||
|
output_dir: Optional[str] = None,
|
||||||
|
batch_size: int = 1,
|
||||||
|
dtype: str = "float32",
|
||||||
|
ngpu: int = 0,
|
||||||
|
seed: int = 0,
|
||||||
|
num_workers: int = 1,
|
||||||
|
log_level: Union[int, str] = "INFO",
|
||||||
|
key_file: Optional[str] = None,
|
||||||
|
model_tag: Optional[str] = None,
|
||||||
|
allow_variable_data_keys: bool = True,
|
||||||
|
streaming: bool = False,
|
||||||
|
smooth_size: int = 83,
|
||||||
|
dur_threshold: int = 10,
|
||||||
|
out_format: str = "vad",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
inference_pipeline = inference_modelscope(
|
||||||
|
diar_train_config=diar_train_config,
|
||||||
|
diar_model_file=diar_model_file,
|
||||||
|
output_dir=output_dir,
|
||||||
|
batch_size=batch_size,
|
||||||
|
dtype=dtype,
|
||||||
|
ngpu=ngpu,
|
||||||
|
seed=seed,
|
||||||
|
num_workers=num_workers,
|
||||||
|
log_level=log_level,
|
||||||
|
key_file=key_file,
|
||||||
|
model_tag=model_tag,
|
||||||
|
allow_variable_data_keys=allow_variable_data_keys,
|
||||||
|
streaming=streaming,
|
||||||
|
smooth_size=smooth_size,
|
||||||
|
dur_threshold=dur_threshold,
|
||||||
|
out_format=out_format,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return inference_pipeline(data_path_and_name_and_type, raw_inputs=None)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = config_argparse.ArgumentParser(
|
||||||
|
description="Speaker verification/x-vector extraction",
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Note(kamo): Use '_' instead of '-' as separator.
|
||||||
|
# '-' is confusing if written in yaml.
|
||||||
|
parser.add_argument(
|
||||||
|
"--log_level",
|
||||||
|
type=lambda x: x.upper(),
|
||||||
|
default="INFO",
|
||||||
|
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
|
||||||
|
help="The verbose level of logging",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--output_dir", type=str, required=False)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngpu",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="The number of gpus. 0 indicates CPU mode",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--gpuid_list",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="The visible gpus",
|
||||||
|
)
|
||||||
|
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype",
|
||||||
|
default="float32",
|
||||||
|
choices=["float16", "float32", "float64"],
|
||||||
|
help="Data type",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_workers",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="The number of workers used for DataLoader",
|
||||||
|
)
|
||||||
|
|
||||||
|
group = parser.add_argument_group("Input data related")
|
||||||
|
group.add_argument(
|
||||||
|
"--data_path_and_name_and_type",
|
||||||
|
type=str2triple_str,
|
||||||
|
required=False,
|
||||||
|
action="append",
|
||||||
|
)
|
||||||
|
group.add_argument("--key_file", type=str_or_none)
|
||||||
|
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
|
||||||
|
|
||||||
|
group = parser.add_argument_group("The model configuration related")
|
||||||
|
group.add_argument(
|
||||||
|
"--diar_train_config",
|
||||||
|
type=str,
|
||||||
|
help="diarization training configuration",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--diar_model_file",
|
||||||
|
type=str,
|
||||||
|
help="diarization model parameter file",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--dur_threshold",
|
||||||
|
type=int,
|
||||||
|
default=10,
|
||||||
|
help="The threshold for short segments in number frames"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--smooth_size",
|
||||||
|
type=int,
|
||||||
|
default=83,
|
||||||
|
help="The smoothing window length in number frames"
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--model_tag",
|
||||||
|
type=str,
|
||||||
|
help="Pretrained model tag. If specify this option, *_train_config and "
|
||||||
|
"*_file will be overwritten",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_size",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="The batch size for inference",
|
||||||
|
)
|
||||||
|
parser.add_argument("--streaming", type=str2bool, default=False)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def main(cmd=None):
|
||||||
|
print(get_commandline_args(), file=sys.stderr)
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args(cmd)
|
||||||
|
kwargs = vars(args)
|
||||||
|
kwargs.pop("config", None)
|
||||||
|
logging.info("args: {}".format(kwargs))
|
||||||
|
if args.output_dir is None:
|
||||||
|
jobid, n_gpu = 1, 1
|
||||||
|
gpuid = args.gpuid_list.split(",")[jobid - 1]
|
||||||
|
else:
|
||||||
|
jobid = int(args.output_dir.split(".")[-1])
|
||||||
|
n_gpu = len(args.gpuid_list.split(","))
|
||||||
|
gpuid = args.gpuid_list.split(",")[(jobid - 1) % n_gpu]
|
||||||
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
|
||||||
|
results_list = inference(**kwargs)
|
||||||
|
for results in results_list:
|
||||||
|
print("{} {}".format(results["key"], results["value"]))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@ -28,6 +28,8 @@ from funasr.utils.types import str2triple_str
|
|||||||
from funasr.utils.types import str_or_none
|
from funasr.utils.types import str_or_none
|
||||||
from funasr.models.frontend.wav_frontend import WavFrontend
|
from funasr.models.frontend.wav_frontend import WavFrontend
|
||||||
from funasr.text.token_id_converter import TokenIDConverter
|
from funasr.text.token_id_converter import TokenIDConverter
|
||||||
|
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
|
||||||
|
|
||||||
|
|
||||||
header_colors = '\033[95m'
|
header_colors = '\033[95m'
|
||||||
end_colors = '\033[0m'
|
end_colors = '\033[0m'
|
||||||
@ -38,61 +40,6 @@ global_sample_rate: Union[int, Dict[Any, int]] = {
|
|||||||
'model_fs': 16000
|
'model_fs': 16000
|
||||||
}
|
}
|
||||||
|
|
||||||
def time_stamp_lfr6_advance(us_alphas, us_cif_peak, char_list):
|
|
||||||
START_END_THRESHOLD = 5
|
|
||||||
MAX_TOKEN_DURATION = 12
|
|
||||||
TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled
|
|
||||||
if len(us_cif_peak.shape) == 2:
|
|
||||||
alphas, cif_peak = us_alphas[0], us_cif_peak[0] # support inference batch_size=1 only
|
|
||||||
else:
|
|
||||||
alphas, cif_peak = us_alphas, us_cif_peak
|
|
||||||
num_frames = cif_peak.shape[0]
|
|
||||||
if char_list[-1] == '</s>':
|
|
||||||
char_list = char_list[:-1]
|
|
||||||
# char_list = [i for i in text]
|
|
||||||
timestamp_list = []
|
|
||||||
new_char_list = []
|
|
||||||
# for bicif model trained with large data, cif2 actually fires when a character starts
|
|
||||||
# so treat the frames between two peaks as the duration of the former token
|
|
||||||
fire_place = torch.where(cif_peak>1.0-1e-4)[0].cpu().numpy() - 3.2 # total offset
|
|
||||||
num_peak = len(fire_place)
|
|
||||||
assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
|
|
||||||
# begin silence
|
|
||||||
if fire_place[0] > START_END_THRESHOLD:
|
|
||||||
# char_list.insert(0, '<sil>')
|
|
||||||
timestamp_list.append([0.0, fire_place[0]*TIME_RATE])
|
|
||||||
new_char_list.append('<sil>')
|
|
||||||
# tokens timestamp
|
|
||||||
for i in range(len(fire_place)-1):
|
|
||||||
new_char_list.append(char_list[i])
|
|
||||||
if MAX_TOKEN_DURATION < 0 or fire_place[i+1] - fire_place[i] < MAX_TOKEN_DURATION:
|
|
||||||
timestamp_list.append([fire_place[i]*TIME_RATE, fire_place[i+1]*TIME_RATE])
|
|
||||||
else:
|
|
||||||
# cut the duration to token and sil of the 0-weight frames last long
|
|
||||||
_split = fire_place[i] + MAX_TOKEN_DURATION
|
|
||||||
timestamp_list.append([fire_place[i]*TIME_RATE, _split*TIME_RATE])
|
|
||||||
timestamp_list.append([_split*TIME_RATE, fire_place[i+1]*TIME_RATE])
|
|
||||||
new_char_list.append('<sil>')
|
|
||||||
# tail token and end silence
|
|
||||||
# new_char_list.append(char_list[-1])
|
|
||||||
if num_frames - fire_place[-1] > START_END_THRESHOLD:
|
|
||||||
_end = (num_frames + fire_place[-1]) * 0.5
|
|
||||||
# _end = fire_place[-1]
|
|
||||||
timestamp_list[-1][1] = _end*TIME_RATE
|
|
||||||
timestamp_list.append([_end*TIME_RATE, num_frames*TIME_RATE])
|
|
||||||
new_char_list.append("<sil>")
|
|
||||||
else:
|
|
||||||
timestamp_list[-1][1] = num_frames*TIME_RATE
|
|
||||||
assert len(new_char_list) == len(timestamp_list)
|
|
||||||
res_str = ""
|
|
||||||
for char, timestamp in zip(new_char_list, timestamp_list):
|
|
||||||
res_str += "{} {} {};".format(char, str(timestamp[0]+0.0005)[:5], str(timestamp[1]+0.0005)[:5])
|
|
||||||
res = []
|
|
||||||
for char, timestamp in zip(new_char_list, timestamp_list):
|
|
||||||
if char != '<sil>':
|
|
||||||
res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)])
|
|
||||||
return res_str, res
|
|
||||||
|
|
||||||
|
|
||||||
class SpeechText2Timestamp:
|
class SpeechText2Timestamp:
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -315,7 +262,7 @@ def inference_modelscope(
|
|||||||
for batch_id in range(_bs):
|
for batch_id in range(_bs):
|
||||||
key = keys[batch_id]
|
key = keys[batch_id]
|
||||||
token = speechtext2timestamp.converter.ids2tokens(batch['text'][batch_id])
|
token = speechtext2timestamp.converter.ids2tokens(batch['text'][batch_id])
|
||||||
ts_str, ts_list = time_stamp_lfr6_advance(us_alphas[batch_id], us_cif_peak[batch_id], token)
|
ts_str, ts_list = ts_prediction_lfr6_standard(us_alphas[batch_id], us_cif_peak[batch_id], token, force_time_shift=-3.0)
|
||||||
logging.warning(ts_str)
|
logging.warning(ts_str)
|
||||||
item = {'key': key, 'value': ts_str, 'timestamp':ts_list}
|
item = {'key': key, 'value': ts_str, 'timestamp':ts_list}
|
||||||
tp_result_list.append(item)
|
tp_result_list.append(item)
|
||||||
|
|||||||
345
funasr/bin/vad_inference_online.py
Normal file
345
funasr/bin/vad_inference_online.py
Normal file
@ -0,0 +1,345 @@
|
|||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from typing import List
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Sequence
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import Union
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from typeguard import check_argument_types
|
||||||
|
from typeguard import check_return_type
|
||||||
|
|
||||||
|
from funasr.fileio.datadir_writer import DatadirWriter
|
||||||
|
from funasr.tasks.vad import VADTask
|
||||||
|
from funasr.torch_utils.device_funcs import to_device
|
||||||
|
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
|
||||||
|
from funasr.utils import config_argparse
|
||||||
|
from funasr.utils.cli_utils import get_commandline_args
|
||||||
|
from funasr.utils.types import str2bool
|
||||||
|
from funasr.utils.types import str2triple_str
|
||||||
|
from funasr.utils.types import str_or_none
|
||||||
|
from funasr.models.frontend.wav_frontend import WavFrontendOnline
|
||||||
|
from funasr.models.frontend.wav_frontend import WavFrontend
|
||||||
|
from funasr.bin.vad_inference import Speech2VadSegment
|
||||||
|
|
||||||
|
header_colors = '\033[95m'
|
||||||
|
end_colors = '\033[0m'
|
||||||
|
|
||||||
|
global_asr_language: str = 'zh-cn'
|
||||||
|
global_sample_rate: Union[int, Dict[Any, int]] = {
|
||||||
|
'audio_fs': 16000,
|
||||||
|
'model_fs': 16000
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Speech2VadSegmentOnline(Speech2VadSegment):
|
||||||
|
"""Speech2VadSegmentOnline class
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> import soundfile
|
||||||
|
>>> speech2segment = Speech2VadSegmentOnline("vad_config.yml", "vad.pt")
|
||||||
|
>>> audio, rate = soundfile.read("speech.wav")
|
||||||
|
>>> speech2segment(audio)
|
||||||
|
[[10, 230], [245, 450], ...]
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super(Speech2VadSegmentOnline, self).__init__(**kwargs)
|
||||||
|
vad_cmvn_file = kwargs.get('vad_cmvn_file', None)
|
||||||
|
self.frontend = None
|
||||||
|
if self.vad_infer_args.frontend is not None:
|
||||||
|
self.frontend = WavFrontendOnline(cmvn_file=vad_cmvn_file, **self.vad_infer_args.frontend_conf)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(
|
||||||
|
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
|
||||||
|
in_cache: Dict[str, torch.Tensor] = dict(), is_final: bool = False
|
||||||
|
) -> Tuple[torch.Tensor, List[List[int]], torch.Tensor]:
|
||||||
|
"""Inference
|
||||||
|
|
||||||
|
Args:
|
||||||
|
speech: Input speech data
|
||||||
|
Returns:
|
||||||
|
text, token, token_int, hyp
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert check_argument_types()
|
||||||
|
|
||||||
|
# Input as audio signal
|
||||||
|
if isinstance(speech, np.ndarray):
|
||||||
|
speech = torch.tensor(speech)
|
||||||
|
batch_size = speech.shape[0]
|
||||||
|
segments = [[]] * batch_size
|
||||||
|
if self.frontend is not None:
|
||||||
|
feats, feats_len = self.frontend.forward(speech, speech_lengths, is_final)
|
||||||
|
fbanks, _ = self.frontend.get_fbank()
|
||||||
|
else:
|
||||||
|
raise Exception("Need to extract feats first, please configure frontend configuration")
|
||||||
|
if feats.shape[0]:
|
||||||
|
feats = to_device(feats, device=self.device)
|
||||||
|
feats_len = feats_len.int()
|
||||||
|
waveforms = self.frontend.get_waveforms()
|
||||||
|
|
||||||
|
batch = {
|
||||||
|
"feats": feats,
|
||||||
|
"waveform": waveforms,
|
||||||
|
"in_cache": in_cache,
|
||||||
|
"is_final": is_final
|
||||||
|
}
|
||||||
|
# a. To device
|
||||||
|
batch = to_device(batch, device=self.device)
|
||||||
|
segments, in_cache = self.vad_model.forward_online(**batch)
|
||||||
|
# in_cache.update(batch['in_cache'])
|
||||||
|
# in_cache = {key: value for key, value in batch['in_cache'].items()}
|
||||||
|
return fbanks, segments, in_cache
|
||||||
|
|
||||||
|
|
||||||
|
def inference(
|
||||||
|
batch_size: int,
|
||||||
|
ngpu: int,
|
||||||
|
log_level: Union[int, str],
|
||||||
|
data_path_and_name_and_type,
|
||||||
|
vad_infer_config: Optional[str],
|
||||||
|
vad_model_file: Optional[str],
|
||||||
|
vad_cmvn_file: Optional[str] = None,
|
||||||
|
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
|
||||||
|
key_file: Optional[str] = None,
|
||||||
|
allow_variable_data_keys: bool = False,
|
||||||
|
output_dir: Optional[str] = None,
|
||||||
|
dtype: str = "float32",
|
||||||
|
seed: int = 0,
|
||||||
|
num_workers: int = 1,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
inference_pipeline = inference_modelscope(
|
||||||
|
batch_size=batch_size,
|
||||||
|
ngpu=ngpu,
|
||||||
|
log_level=log_level,
|
||||||
|
vad_infer_config=vad_infer_config,
|
||||||
|
vad_model_file=vad_model_file,
|
||||||
|
vad_cmvn_file=vad_cmvn_file,
|
||||||
|
key_file=key_file,
|
||||||
|
allow_variable_data_keys=allow_variable_data_keys,
|
||||||
|
output_dir=output_dir,
|
||||||
|
dtype=dtype,
|
||||||
|
seed=seed,
|
||||||
|
num_workers=num_workers,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
|
||||||
|
|
||||||
|
|
||||||
|
def inference_modelscope(
|
||||||
|
batch_size: int,
|
||||||
|
ngpu: int,
|
||||||
|
log_level: Union[int, str],
|
||||||
|
# data_path_and_name_and_type,
|
||||||
|
vad_infer_config: Optional[str],
|
||||||
|
vad_model_file: Optional[str],
|
||||||
|
vad_cmvn_file: Optional[str] = None,
|
||||||
|
# raw_inputs: Union[np.ndarray, torch.Tensor] = None,
|
||||||
|
key_file: Optional[str] = None,
|
||||||
|
allow_variable_data_keys: bool = False,
|
||||||
|
output_dir: Optional[str] = None,
|
||||||
|
dtype: str = "float32",
|
||||||
|
seed: int = 0,
|
||||||
|
num_workers: int = 1,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
assert check_argument_types()
|
||||||
|
if batch_size > 1:
|
||||||
|
raise NotImplementedError("batch decoding is not implemented")
|
||||||
|
if ngpu > 1:
|
||||||
|
raise NotImplementedError("only single GPU decoding is supported")
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=log_level,
|
||||||
|
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
if ngpu >= 1 and torch.cuda.is_available():
|
||||||
|
device = "cuda"
|
||||||
|
else:
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
# 1. Set random-seed
|
||||||
|
set_all_random_seed(seed)
|
||||||
|
|
||||||
|
# 2. Build speech2vadsegment
|
||||||
|
speech2vadsegment_kwargs = dict(
|
||||||
|
vad_infer_config=vad_infer_config,
|
||||||
|
vad_model_file=vad_model_file,
|
||||||
|
vad_cmvn_file=vad_cmvn_file,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
|
||||||
|
speech2vadsegment = Speech2VadSegmentOnline(**speech2vadsegment_kwargs)
|
||||||
|
|
||||||
|
def _forward(
|
||||||
|
data_path_and_name_and_type,
|
||||||
|
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
|
||||||
|
output_dir_v2: Optional[str] = None,
|
||||||
|
fs: dict = None,
|
||||||
|
param_dict: dict = None,
|
||||||
|
):
|
||||||
|
# 3. Build data-iterator
|
||||||
|
if data_path_and_name_and_type is None and raw_inputs is not None:
|
||||||
|
if isinstance(raw_inputs, torch.Tensor):
|
||||||
|
raw_inputs = raw_inputs.numpy()
|
||||||
|
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
|
||||||
|
loader = VADTask.build_streaming_iterator(
|
||||||
|
data_path_and_name_and_type,
|
||||||
|
dtype=dtype,
|
||||||
|
batch_size=batch_size,
|
||||||
|
key_file=key_file,
|
||||||
|
num_workers=num_workers,
|
||||||
|
preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
|
||||||
|
collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
|
||||||
|
allow_variable_data_keys=allow_variable_data_keys,
|
||||||
|
inference=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
finish_count = 0
|
||||||
|
file_count = 1
|
||||||
|
# 7 .Start for-loop
|
||||||
|
# FIXME(kamo): The output format should be discussed about
|
||||||
|
output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
|
||||||
|
if output_path is not None:
|
||||||
|
writer = DatadirWriter(output_path)
|
||||||
|
ibest_writer = writer[f"1best_recog"]
|
||||||
|
else:
|
||||||
|
writer = None
|
||||||
|
ibest_writer = None
|
||||||
|
|
||||||
|
vad_results = []
|
||||||
|
batch_in_cache = param_dict['in_cache'] if param_dict is not None else dict()
|
||||||
|
is_final = param_dict['is_final'] if param_dict is not None else False
|
||||||
|
for keys, batch in loader:
|
||||||
|
assert isinstance(batch, dict), type(batch)
|
||||||
|
assert all(isinstance(s, str) for s in keys), keys
|
||||||
|
_bs = len(next(iter(batch.values())))
|
||||||
|
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
|
||||||
|
batch['in_cache'] = batch_in_cache
|
||||||
|
batch['is_final'] = is_final
|
||||||
|
|
||||||
|
# do vad segment
|
||||||
|
_, results, param_dict['in_cache'] = speech2vadsegment(**batch)
|
||||||
|
# param_dict['in_cache'] = batch['in_cache']
|
||||||
|
if results:
|
||||||
|
for i, _ in enumerate(keys):
|
||||||
|
if results[i]:
|
||||||
|
results[i] = json.dumps(results[i])
|
||||||
|
item = {'key': keys[i], 'value': results[i]}
|
||||||
|
vad_results.append(item)
|
||||||
|
if writer is not None:
|
||||||
|
results[i] = json.loads(results[i])
|
||||||
|
ibest_writer["text"][keys[i]] = "{}".format(results[i])
|
||||||
|
|
||||||
|
return vad_results
|
||||||
|
|
||||||
|
return _forward
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = config_argparse.ArgumentParser(
|
||||||
|
description="VAD Decoding",
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Note(kamo): Use '_' instead of '-' as separator.
|
||||||
|
# '-' is confusing if written in yaml.
|
||||||
|
parser.add_argument(
|
||||||
|
"--log_level",
|
||||||
|
type=lambda x: x.upper(),
|
||||||
|
default="INFO",
|
||||||
|
choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
|
||||||
|
help="The verbose level of logging",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--output_dir", type=str, required=False)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngpu",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="The number of gpus. 0 indicates CPU mode",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--gpuid_list",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="The visible gpus",
|
||||||
|
)
|
||||||
|
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype",
|
||||||
|
default="float32",
|
||||||
|
choices=["float16", "float32", "float64"],
|
||||||
|
help="Data type",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_workers",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="The number of workers used for DataLoader",
|
||||||
|
)
|
||||||
|
|
||||||
|
group = parser.add_argument_group("Input data related")
|
||||||
|
group.add_argument(
|
||||||
|
"--data_path_and_name_and_type",
|
||||||
|
type=str2triple_str,
|
||||||
|
required=False,
|
||||||
|
action="append",
|
||||||
|
)
|
||||||
|
group.add_argument("--raw_inputs", type=list, default=None)
|
||||||
|
# example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
|
||||||
|
group.add_argument("--key_file", type=str_or_none)
|
||||||
|
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
|
||||||
|
|
||||||
|
group = parser.add_argument_group("The model configuration related")
|
||||||
|
group.add_argument(
|
||||||
|
"--vad_infer_config",
|
||||||
|
type=str,
|
||||||
|
help="VAD infer configuration",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--vad_model_file",
|
||||||
|
type=str,
|
||||||
|
help="VAD model parameter file",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--vad_cmvn_file",
|
||||||
|
type=str,
|
||||||
|
help="Global cmvn file",
|
||||||
|
)
|
||||||
|
|
||||||
|
group = parser.add_argument_group("infer related")
|
||||||
|
group.add_argument(
|
||||||
|
"--batch_size",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="The batch size for inference",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def main(cmd=None):
|
||||||
|
print(get_commandline_args(), file=sys.stderr)
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args(cmd)
|
||||||
|
kwargs = vars(args)
|
||||||
|
kwargs.pop("config", None)
|
||||||
|
inference(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@ -926,10 +926,10 @@ class BiCifParaformer(Paraformer):
|
|||||||
def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
|
def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
|
||||||
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
|
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
|
||||||
encoder_out.device)
|
encoder_out.device)
|
||||||
ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = self.predictor.get_upsample_timestamp(encoder_out,
|
ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
|
||||||
encoder_out_mask,
|
encoder_out_mask,
|
||||||
token_num)
|
token_num)
|
||||||
return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
|
return ds_alphas, ds_cif_peak, us_alphas, us_peaks
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|||||||
242
funasr/models/e2e_diar_eend_ola.py
Normal file
242
funasr/models/e2e_diar_eend_ola.py
Normal file
@ -0,0 +1,242 @@
|
|||||||
|
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
|
||||||
|
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from distutils.version import LooseVersion
|
||||||
|
from typing import Dict
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from typeguard import check_argument_types
|
||||||
|
|
||||||
|
from funasr.models.frontend.wav_frontend import WavFrontendMel23
|
||||||
|
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
|
||||||
|
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
|
||||||
|
from funasr.modules.eend_ola.utils.power import generate_mapping_dict
|
||||||
|
from funasr.torch_utils.device_funcs import force_gatherable
|
||||||
|
from funasr.train.abs_espnet_model import AbsESPnetModel
|
||||||
|
|
||||||
|
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# Nothing to do if torch<1.6.0
|
||||||
|
@contextmanager
|
||||||
|
def autocast(enabled=True):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
def pad_attractor(att, max_n_speakers):
|
||||||
|
C, D = att.shape
|
||||||
|
if C < max_n_speakers:
|
||||||
|
att = torch.cat([att, torch.zeros(max_n_speakers - C, D).to(torch.float32).to(att.device)], dim=0)
|
||||||
|
return att
|
||||||
|
|
||||||
|
|
||||||
|
class DiarEENDOLAModel(AbsESPnetModel):
|
||||||
|
"""EEND-OLA diarization model"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
frontend: WavFrontendMel23,
|
||||||
|
encoder: EENDOLATransformerEncoder,
|
||||||
|
encoder_decoder_attractor: EncoderDecoderAttractor,
|
||||||
|
n_units: int = 256,
|
||||||
|
max_n_speaker: int = 8,
|
||||||
|
attractor_loss_weight: float = 1.0,
|
||||||
|
mapping_dict=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
assert check_argument_types()
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.frontend = frontend
|
||||||
|
self.encoder = encoder
|
||||||
|
self.encoder_decoder_attractor = encoder_decoder_attractor
|
||||||
|
self.attractor_loss_weight = attractor_loss_weight
|
||||||
|
self.max_n_speaker = max_n_speaker
|
||||||
|
if mapping_dict is None:
|
||||||
|
mapping_dict = generate_mapping_dict(max_speaker_num=self.max_n_speaker)
|
||||||
|
self.mapping_dict = mapping_dict
|
||||||
|
# PostNet
|
||||||
|
self.PostNet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
|
||||||
|
self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1)
|
||||||
|
|
||||||
|
def forward_encoder(self, xs, ilens):
|
||||||
|
xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1)
|
||||||
|
pad_shape = xs.shape
|
||||||
|
xs_mask = [torch.ones(ilen).to(xs.device) for ilen in ilens]
|
||||||
|
xs_mask = torch.nn.utils.rnn.pad_sequence(xs_mask, batch_first=True, padding_value=0).unsqueeze(-2)
|
||||||
|
emb = self.encoder(xs, xs_mask)
|
||||||
|
emb = torch.split(emb.view(pad_shape[0], pad_shape[1], -1), 1, dim=0)
|
||||||
|
emb = [e[0][:ilen] for e, ilen in zip(emb, ilens)]
|
||||||
|
return emb
|
||||||
|
|
||||||
|
def forward_post_net(self, logits, ilens):
|
||||||
|
maxlen = torch.max(ilens).to(torch.int).item()
|
||||||
|
logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
|
||||||
|
logits = nn.utils.rnn.pack_padded_sequence(logits, ilens, batch_first=True, enforce_sorted=False)
|
||||||
|
outputs, (_, _) = self.PostNet(logits)
|
||||||
|
outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
|
||||||
|
outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
|
||||||
|
outputs = [self.output_layer(output) for output in outputs]
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
speech: torch.Tensor,
|
||||||
|
speech_lengths: torch.Tensor,
|
||||||
|
text: torch.Tensor,
|
||||||
|
text_lengths: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||||
|
"""Frontend + Encoder + Decoder + Calc loss
|
||||||
|
|
||||||
|
Args:
|
||||||
|
speech: (Batch, Length, ...)
|
||||||
|
speech_lengths: (Batch, )
|
||||||
|
text: (Batch, Length)
|
||||||
|
text_lengths: (Batch,)
|
||||||
|
"""
|
||||||
|
assert text_lengths.dim() == 1, text_lengths.shape
|
||||||
|
# Check that batch_size is unified
|
||||||
|
assert (
|
||||||
|
speech.shape[0]
|
||||||
|
== speech_lengths.shape[0]
|
||||||
|
== text.shape[0]
|
||||||
|
== text_lengths.shape[0]
|
||||||
|
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
|
||||||
|
batch_size = speech.shape[0]
|
||||||
|
|
||||||
|
# for data-parallel
|
||||||
|
text = text[:, : text_lengths.max()]
|
||||||
|
|
||||||
|
# 1. Encoder
|
||||||
|
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
|
||||||
|
intermediate_outs = None
|
||||||
|
if isinstance(encoder_out, tuple):
|
||||||
|
intermediate_outs = encoder_out[1]
|
||||||
|
encoder_out = encoder_out[0]
|
||||||
|
|
||||||
|
loss_att, acc_att, cer_att, wer_att = None, None, None, None
|
||||||
|
loss_ctc, cer_ctc = None, None
|
||||||
|
stats = dict()
|
||||||
|
|
||||||
|
# 1. CTC branch
|
||||||
|
if self.ctc_weight != 0.0:
|
||||||
|
loss_ctc, cer_ctc = self._calc_ctc_loss(
|
||||||
|
encoder_out, encoder_out_lens, text, text_lengths
|
||||||
|
)
|
||||||
|
|
||||||
|
# Collect CTC branch stats
|
||||||
|
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
|
||||||
|
stats["cer_ctc"] = cer_ctc
|
||||||
|
|
||||||
|
# Intermediate CTC (optional)
|
||||||
|
loss_interctc = 0.0
|
||||||
|
if self.interctc_weight != 0.0 and intermediate_outs is not None:
|
||||||
|
for layer_idx, intermediate_out in intermediate_outs:
|
||||||
|
# we assume intermediate_out has the same length & padding
|
||||||
|
# as those of encoder_out
|
||||||
|
loss_ic, cer_ic = self._calc_ctc_loss(
|
||||||
|
intermediate_out, encoder_out_lens, text, text_lengths
|
||||||
|
)
|
||||||
|
loss_interctc = loss_interctc + loss_ic
|
||||||
|
|
||||||
|
# Collect Intermedaite CTC stats
|
||||||
|
stats["loss_interctc_layer{}".format(layer_idx)] = (
|
||||||
|
loss_ic.detach() if loss_ic is not None else None
|
||||||
|
)
|
||||||
|
stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
|
||||||
|
|
||||||
|
loss_interctc = loss_interctc / len(intermediate_outs)
|
||||||
|
|
||||||
|
# calculate whole encoder loss
|
||||||
|
loss_ctc = (
|
||||||
|
1 - self.interctc_weight
|
||||||
|
) * loss_ctc + self.interctc_weight * loss_interctc
|
||||||
|
|
||||||
|
# 2b. Attention decoder branch
|
||||||
|
if self.ctc_weight != 1.0:
|
||||||
|
loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
|
||||||
|
encoder_out, encoder_out_lens, text, text_lengths
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. CTC-Att loss definition
|
||||||
|
if self.ctc_weight == 0.0:
|
||||||
|
loss = loss_att
|
||||||
|
elif self.ctc_weight == 1.0:
|
||||||
|
loss = loss_ctc
|
||||||
|
else:
|
||||||
|
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
|
||||||
|
|
||||||
|
# Collect Attn branch stats
|
||||||
|
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
|
||||||
|
stats["acc"] = acc_att
|
||||||
|
stats["cer"] = cer_att
|
||||||
|
stats["wer"] = wer_att
|
||||||
|
|
||||||
|
# Collect total loss stats
|
||||||
|
stats["loss"] = torch.clone(loss.detach())
|
||||||
|
|
||||||
|
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
||||||
|
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
||||||
|
return loss, stats, weight
|
||||||
|
|
||||||
|
def estimate_sequential(self,
|
||||||
|
speech: torch.Tensor,
|
||||||
|
speech_lengths: torch.Tensor,
|
||||||
|
n_speakers: int = None,
|
||||||
|
shuffle: bool = True,
|
||||||
|
threshold: float = 0.5,
|
||||||
|
**kwargs):
|
||||||
|
if self.frontend is not None:
|
||||||
|
speech = self.frontend(speech)
|
||||||
|
speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
|
||||||
|
emb = self.forward_encoder(speech, speech_lengths)
|
||||||
|
if shuffle:
|
||||||
|
orders = [np.arange(e.shape[0]) for e in emb]
|
||||||
|
for order in orders:
|
||||||
|
np.random.shuffle(order)
|
||||||
|
attractors, probs = self.encoder_decoder_attractor.estimate(
|
||||||
|
[e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
|
||||||
|
else:
|
||||||
|
attractors, probs = self.encoder_decoder_attractor.estimate(emb)
|
||||||
|
attractors_active = []
|
||||||
|
for p, att, e in zip(probs, attractors, emb):
|
||||||
|
if n_speakers and n_speakers >= 0:
|
||||||
|
att = att[:n_speakers, ]
|
||||||
|
attractors_active.append(att)
|
||||||
|
elif threshold is not None:
|
||||||
|
silence = torch.nonzero(p < threshold)[0]
|
||||||
|
n_spk = silence[0] if silence.size else None
|
||||||
|
att = att[:n_spk, ]
|
||||||
|
attractors_active.append(att)
|
||||||
|
else:
|
||||||
|
NotImplementedError('n_speakers or threshold has to be given.')
|
||||||
|
raw_n_speakers = [att.shape[0] for att in attractors_active]
|
||||||
|
attractors = [
|
||||||
|
pad_attractor(att, self.max_n_speaker) if att.shape[0] <= self.max_n_speaker else att[:self.max_n_speaker]
|
||||||
|
for att in attractors_active]
|
||||||
|
ys = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(emb, attractors)]
|
||||||
|
logits = self.forward_post_net(ys, speech_lengths)
|
||||||
|
ys = [self.recover_y_from_powerlabel(logit, raw_n_speaker) for logit, raw_n_speaker in
|
||||||
|
zip(logits, raw_n_speakers)]
|
||||||
|
|
||||||
|
return ys, emb, attractors, raw_n_speakers
|
||||||
|
|
||||||
|
def recover_y_from_powerlabel(self, logit, n_speaker):
|
||||||
|
pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1)
|
||||||
|
oov_index = torch.where(pred == self.mapping_dict['oov'])[0]
|
||||||
|
for i in oov_index:
|
||||||
|
if i > 0:
|
||||||
|
pred[i] = pred[i - 1]
|
||||||
|
else:
|
||||||
|
pred[i] = 0
|
||||||
|
pred = [self.reporter.inv_mapping_func(i, self.mapping_dict) for i in pred]
|
||||||
|
decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
|
||||||
|
decisions = torch.from_numpy(
|
||||||
|
np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to(
|
||||||
|
torch.float32)
|
||||||
|
decisions = decisions[:, :n_speaker]
|
||||||
|
return decisions
|
||||||
175
funasr/models/e2e_tp.py
Normal file
175
funasr/models/e2e_tp.py
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
import logging
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from distutils.version import LooseVersion
|
||||||
|
from typing import Dict
|
||||||
|
from typing import List
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from typeguard import check_argument_types
|
||||||
|
|
||||||
|
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||||
|
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||||
|
from funasr.models.predictor.cif import mae_loss
|
||||||
|
from funasr.modules.add_sos_eos import add_sos_eos
|
||||||
|
from funasr.modules.nets_utils import make_pad_mask, pad_list
|
||||||
|
from funasr.torch_utils.device_funcs import force_gatherable
|
||||||
|
from funasr.train.abs_espnet_model import AbsESPnetModel
|
||||||
|
from funasr.models.predictor.cif import CifPredictorV3
|
||||||
|
|
||||||
|
|
||||||
|
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
||||||
|
from torch.cuda.amp import autocast
|
||||||
|
else:
|
||||||
|
# Nothing to do if torch<1.6.0
|
||||||
|
@contextmanager
|
||||||
|
def autocast(enabled=True):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
class TimestampPredictor(AbsESPnetModel):
|
||||||
|
"""
|
||||||
|
Author: Speech Lab, Alibaba Group, China
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
frontend: Optional[AbsFrontend],
|
||||||
|
encoder: AbsEncoder,
|
||||||
|
predictor: CifPredictorV3,
|
||||||
|
predictor_bias: int = 0,
|
||||||
|
token_list=None,
|
||||||
|
):
|
||||||
|
assert check_argument_types()
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
# note that eos is the same as sos (equivalent ID)
|
||||||
|
|
||||||
|
self.frontend = frontend
|
||||||
|
self.encoder = encoder
|
||||||
|
self.encoder.interctc_use_conditioning = False
|
||||||
|
|
||||||
|
self.predictor = predictor
|
||||||
|
self.predictor_bias = predictor_bias
|
||||||
|
self.criterion_pre = mae_loss()
|
||||||
|
self.token_list = token_list
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
speech: torch.Tensor,
|
||||||
|
speech_lengths: torch.Tensor,
|
||||||
|
text: torch.Tensor,
|
||||||
|
text_lengths: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||||
|
"""Frontend + Encoder + Decoder + Calc loss
|
||||||
|
|
||||||
|
Args:
|
||||||
|
speech: (Batch, Length, ...)
|
||||||
|
speech_lengths: (Batch, )
|
||||||
|
text: (Batch, Length)
|
||||||
|
text_lengths: (Batch,)
|
||||||
|
"""
|
||||||
|
assert text_lengths.dim() == 1, text_lengths.shape
|
||||||
|
# Check that batch_size is unified
|
||||||
|
assert (
|
||||||
|
speech.shape[0]
|
||||||
|
== speech_lengths.shape[0]
|
||||||
|
== text.shape[0]
|
||||||
|
== text_lengths.shape[0]
|
||||||
|
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
|
||||||
|
batch_size = speech.shape[0]
|
||||||
|
# for data-parallel
|
||||||
|
text = text[:, : text_lengths.max()]
|
||||||
|
speech = speech[:, :speech_lengths.max()]
|
||||||
|
|
||||||
|
# 1. Encoder
|
||||||
|
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
|
||||||
|
|
||||||
|
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
|
||||||
|
encoder_out.device)
|
||||||
|
if self.predictor_bias == 1:
|
||||||
|
_, text = add_sos_eos(text, 1, 2, -1)
|
||||||
|
text_lengths = text_lengths + self.predictor_bias
|
||||||
|
_, _, _, _, pre_token_length2 = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=-1)
|
||||||
|
|
||||||
|
# loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
|
||||||
|
loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length2), pre_token_length2)
|
||||||
|
|
||||||
|
loss = loss_pre
|
||||||
|
stats = dict()
|
||||||
|
|
||||||
|
# Collect Attn branch stats
|
||||||
|
stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
|
||||||
|
stats["loss"] = torch.clone(loss.detach())
|
||||||
|
|
||||||
|
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
||||||
|
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
||||||
|
return loss, stats, weight
|
||||||
|
|
||||||
|
def encode(
|
||||||
|
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Frontend + Encoder. Note that this method is used by asr_inference.py
|
||||||
|
|
||||||
|
Args:
|
||||||
|
speech: (Batch, Length, ...)
|
||||||
|
speech_lengths: (Batch, )
|
||||||
|
"""
|
||||||
|
with autocast(False):
|
||||||
|
# 1. Extract feats
|
||||||
|
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
||||||
|
|
||||||
|
# 4. Forward encoder
|
||||||
|
# feats: (Batch, Length, Dim)
|
||||||
|
# -> encoder_out: (Batch, Length2, Dim2)
|
||||||
|
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
|
||||||
|
|
||||||
|
return encoder_out, encoder_out_lens
|
||||||
|
|
||||||
|
def _extract_feats(
|
||||||
|
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
assert speech_lengths.dim() == 1, speech_lengths.shape
|
||||||
|
|
||||||
|
# for data-parallel
|
||||||
|
speech = speech[:, : speech_lengths.max()]
|
||||||
|
if self.frontend is not None:
|
||||||
|
# Frontend
|
||||||
|
# e.g. STFT and Feature extract
|
||||||
|
# data_loader may send time-domain signal in this case
|
||||||
|
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
|
||||||
|
feats, feats_lengths = self.frontend(speech, speech_lengths)
|
||||||
|
else:
|
||||||
|
# No frontend and no feature extract
|
||||||
|
feats, feats_lengths = speech, speech_lengths
|
||||||
|
return feats, feats_lengths
|
||||||
|
|
||||||
|
def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
|
||||||
|
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
|
||||||
|
encoder_out.device)
|
||||||
|
ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
|
||||||
|
encoder_out_mask,
|
||||||
|
token_num)
|
||||||
|
return ds_alphas, ds_cif_peak, us_alphas, us_peaks
|
||||||
|
|
||||||
|
def collect_feats(
|
||||||
|
self,
|
||||||
|
speech: torch.Tensor,
|
||||||
|
speech_lengths: torch.Tensor,
|
||||||
|
text: torch.Tensor,
|
||||||
|
text_lengths: torch.Tensor,
|
||||||
|
) -> Dict[str, torch.Tensor]:
|
||||||
|
if self.extract_feats_in_collect_stats:
|
||||||
|
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
||||||
|
else:
|
||||||
|
# Generate dummy stats if extract_feats_in_collect_stats is False
|
||||||
|
logging.warning(
|
||||||
|
"Generating dummy stats for feats and feats_lengths, "
|
||||||
|
"because encoder_conf.extract_feats_in_collect_stats is "
|
||||||
|
f"{self.extract_feats_in_collect_stats}"
|
||||||
|
)
|
||||||
|
feats, feats_lengths = speech, speech_lengths
|
||||||
|
return {"feats": feats, "feats_lengths": feats_lengths}
|
||||||
@ -215,6 +215,7 @@ class E2EVadModel(nn.Module):
|
|||||||
self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
|
self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
|
||||||
self.noise_average_decibel = -100.0
|
self.noise_average_decibel = -100.0
|
||||||
self.pre_end_silence_detected = False
|
self.pre_end_silence_detected = False
|
||||||
|
self.next_seg = True
|
||||||
|
|
||||||
self.output_data_buf = []
|
self.output_data_buf = []
|
||||||
self.output_data_buf_offset = 0
|
self.output_data_buf_offset = 0
|
||||||
@ -244,6 +245,7 @@ class E2EVadModel(nn.Module):
|
|||||||
self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
|
self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
|
||||||
self.noise_average_decibel = -100.0
|
self.noise_average_decibel = -100.0
|
||||||
self.pre_end_silence_detected = False
|
self.pre_end_silence_detected = False
|
||||||
|
self.next_seg = True
|
||||||
|
|
||||||
self.output_data_buf = []
|
self.output_data_buf = []
|
||||||
self.output_data_buf_offset = 0
|
self.output_data_buf_offset = 0
|
||||||
@ -441,7 +443,7 @@ class E2EVadModel(nn.Module):
|
|||||||
- 1)) / self.vad_opts.noise_frame_num_used_for_snr
|
- 1)) / self.vad_opts.noise_frame_num_used_for_snr
|
||||||
|
|
||||||
return frame_state
|
return frame_state
|
||||||
|
|
||||||
def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
|
def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
|
||||||
is_final: bool = False
|
is_final: bool = False
|
||||||
) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
|
) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
|
||||||
@ -470,6 +472,42 @@ class E2EVadModel(nn.Module):
|
|||||||
self.AllResetDetection()
|
self.AllResetDetection()
|
||||||
return segments, in_cache
|
return segments, in_cache
|
||||||
|
|
||||||
|
def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
|
||||||
|
is_final: bool = False
|
||||||
|
) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
|
||||||
|
self.waveform = waveform # compute decibel for each frame
|
||||||
|
self.ComputeDecibel()
|
||||||
|
self.ComputeScores(feats, in_cache)
|
||||||
|
if not is_final:
|
||||||
|
self.DetectCommonFrames()
|
||||||
|
else:
|
||||||
|
self.DetectLastFrames()
|
||||||
|
segments = []
|
||||||
|
for batch_num in range(0, feats.shape[0]): # only support batch_size = 1 now
|
||||||
|
segment_batch = []
|
||||||
|
if len(self.output_data_buf) > 0:
|
||||||
|
for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
|
||||||
|
if not self.output_data_buf[i].contain_seg_start_point:
|
||||||
|
continue
|
||||||
|
if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
|
||||||
|
continue
|
||||||
|
start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
|
||||||
|
if self.output_data_buf[i].contain_seg_end_point:
|
||||||
|
end_ms = self.output_data_buf[i].end_ms
|
||||||
|
self.next_seg = True
|
||||||
|
self.output_data_buf_offset += 1
|
||||||
|
else:
|
||||||
|
end_ms = -1
|
||||||
|
self.next_seg = False
|
||||||
|
segment = [start_ms, end_ms]
|
||||||
|
segment_batch.append(segment)
|
||||||
|
if segment_batch:
|
||||||
|
segments.append(segment_batch)
|
||||||
|
if is_final:
|
||||||
|
# reset class variables and clear the dict for the next query
|
||||||
|
self.AllResetDetection()
|
||||||
|
return segments, in_cache
|
||||||
|
|
||||||
def DetectCommonFrames(self) -> int:
|
def DetectCommonFrames(self) -> int:
|
||||||
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
|
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
|
||||||
return 0
|
return 0
|
||||||
|
|||||||
51
funasr/models/frontend/eend_ola_feature.py
Normal file
51
funasr/models/frontend/eend_ola_feature.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
#
|
||||||
|
# This module is for computing audio features
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def transform(Y, dtype=np.float32):
|
||||||
|
Y = np.abs(Y)
|
||||||
|
n_fft = 2 * (Y.shape[1] - 1)
|
||||||
|
sr = 8000
|
||||||
|
n_mels = 23
|
||||||
|
mel_basis = librosa.filters.mel(sr, n_fft, n_mels)
|
||||||
|
Y = np.dot(Y ** 2, mel_basis.T)
|
||||||
|
Y = np.log10(np.maximum(Y, 1e-10))
|
||||||
|
mean = np.mean(Y, axis=0)
|
||||||
|
Y = Y - mean
|
||||||
|
return Y.astype(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def subsample(Y, T, subsampling=1):
|
||||||
|
Y_ss = Y[::subsampling]
|
||||||
|
T_ss = T[::subsampling]
|
||||||
|
return Y_ss, T_ss
|
||||||
|
|
||||||
|
|
||||||
|
def splice(Y, context_size=0):
|
||||||
|
Y_pad = np.pad(
|
||||||
|
Y,
|
||||||
|
[(context_size, context_size), (0, 0)],
|
||||||
|
'constant')
|
||||||
|
Y_spliced = np.lib.stride_tricks.as_strided(
|
||||||
|
np.ascontiguousarray(Y_pad),
|
||||||
|
(Y.shape[0], Y.shape[1] * (2 * context_size + 1)),
|
||||||
|
(Y.itemsize * Y.shape[1], Y.itemsize), writeable=False)
|
||||||
|
return Y_spliced
|
||||||
|
|
||||||
|
|
||||||
|
def stft(
|
||||||
|
data,
|
||||||
|
frame_size=1024,
|
||||||
|
frame_shift=256):
|
||||||
|
fft_size = 1 << (frame_size - 1).bit_length()
|
||||||
|
if len(data) % frame_shift == 0:
|
||||||
|
return librosa.stft(data, n_fft=fft_size, win_length=frame_size,
|
||||||
|
hop_length=frame_shift).T[:-1]
|
||||||
|
else:
|
||||||
|
return librosa.stft(data, n_fft=fft_size, win_length=frame_size,
|
||||||
|
hop_length=frame_shift).T
|
||||||
@ -1,6 +1,6 @@
|
|||||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||||
# Part of the implementation is borrowed from espnet/espnet.
|
# Part of the implementation is borrowed from espnet/espnet.
|
||||||
|
from abc import ABC
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -33,9 +33,9 @@ def load_cmvn(cmvn_file):
|
|||||||
means = np.array(means_list).astype(np.float)
|
means = np.array(means_list).astype(np.float)
|
||||||
vars = np.array(vars_list).astype(np.float)
|
vars = np.array(vars_list).astype(np.float)
|
||||||
cmvn = np.array([means, vars])
|
cmvn = np.array([means, vars])
|
||||||
cmvn = torch.as_tensor(cmvn)
|
cmvn = torch.as_tensor(cmvn)
|
||||||
return cmvn
|
return cmvn
|
||||||
|
|
||||||
|
|
||||||
def apply_cmvn(inputs, cmvn_file): # noqa
|
def apply_cmvn(inputs, cmvn_file): # noqa
|
||||||
"""
|
"""
|
||||||
@ -78,21 +78,22 @@ def apply_lfr(inputs, lfr_m, lfr_n):
|
|||||||
class WavFrontend(AbsFrontend):
|
class WavFrontend(AbsFrontend):
|
||||||
"""Conventional frontend structure for ASR.
|
"""Conventional frontend structure for ASR.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
cmvn_file: str = None,
|
cmvn_file: str = None,
|
||||||
fs: int = 16000,
|
fs: int = 16000,
|
||||||
window: str = 'hamming',
|
window: str = 'hamming',
|
||||||
n_mels: int = 80,
|
n_mels: int = 80,
|
||||||
frame_length: int = 25,
|
frame_length: int = 25,
|
||||||
frame_shift: int = 10,
|
frame_shift: int = 10,
|
||||||
filter_length_min: int = -1,
|
filter_length_min: int = -1,
|
||||||
filter_length_max: int = -1,
|
filter_length_max: int = -1,
|
||||||
lfr_m: int = 1,
|
lfr_m: int = 1,
|
||||||
lfr_n: int = 1,
|
lfr_n: int = 1,
|
||||||
dither: float = 1.0,
|
dither: float = 1.0,
|
||||||
snip_edges: bool = True,
|
snip_edges: bool = True,
|
||||||
upsacle_samples: bool = True,
|
upsacle_samples: bool = True,
|
||||||
):
|
):
|
||||||
assert check_argument_types()
|
assert check_argument_types()
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -135,11 +136,11 @@ class WavFrontend(AbsFrontend):
|
|||||||
window_type=self.window,
|
window_type=self.window,
|
||||||
sample_frequency=self.fs,
|
sample_frequency=self.fs,
|
||||||
snip_edges=self.snip_edges)
|
snip_edges=self.snip_edges)
|
||||||
|
|
||||||
if self.lfr_m != 1 or self.lfr_n != 1:
|
if self.lfr_m != 1 or self.lfr_n != 1:
|
||||||
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
|
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
|
||||||
if self.cmvn_file is not None:
|
if self.cmvn_file is not None:
|
||||||
mat = apply_cmvn(mat, self.cmvn_file)
|
mat = apply_cmvn(mat, self.cmvn_file)
|
||||||
feat_length = mat.size(0)
|
feat_length = mat.size(0)
|
||||||
feats.append(mat)
|
feats.append(mat)
|
||||||
feats_lens.append(feat_length)
|
feats_lens.append(feat_length)
|
||||||
@ -171,7 +172,6 @@ class WavFrontend(AbsFrontend):
|
|||||||
window_type=self.window,
|
window_type=self.window,
|
||||||
sample_frequency=self.fs)
|
sample_frequency=self.fs)
|
||||||
|
|
||||||
|
|
||||||
feat_length = mat.size(0)
|
feat_length = mat.size(0)
|
||||||
feats.append(mat)
|
feats.append(mat)
|
||||||
feats_lens.append(feat_length)
|
feats_lens.append(feat_length)
|
||||||
@ -204,3 +204,243 @@ class WavFrontend(AbsFrontend):
|
|||||||
batch_first=True,
|
batch_first=True,
|
||||||
padding_value=0.0)
|
padding_value=0.0)
|
||||||
return feats_pad, feats_lens
|
return feats_pad, feats_lens
|
||||||
|
|
||||||
|
|
||||||
|
class WavFrontendOnline(AbsFrontend):
|
||||||
|
"""Conventional frontend structure for streaming ASR/VAD.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cmvn_file: str = None,
|
||||||
|
fs: int = 16000,
|
||||||
|
window: str = 'hamming',
|
||||||
|
n_mels: int = 80,
|
||||||
|
frame_length: int = 25,
|
||||||
|
frame_shift: int = 10,
|
||||||
|
filter_length_min: int = -1,
|
||||||
|
filter_length_max: int = -1,
|
||||||
|
lfr_m: int = 1,
|
||||||
|
lfr_n: int = 1,
|
||||||
|
dither: float = 1.0,
|
||||||
|
snip_edges: bool = True,
|
||||||
|
upsacle_samples: bool = True,
|
||||||
|
):
|
||||||
|
assert check_argument_types()
|
||||||
|
super().__init__()
|
||||||
|
self.fs = fs
|
||||||
|
self.window = window
|
||||||
|
self.n_mels = n_mels
|
||||||
|
self.frame_length = frame_length
|
||||||
|
self.frame_shift = frame_shift
|
||||||
|
self.frame_sample_length = int(self.frame_length * self.fs / 1000)
|
||||||
|
self.frame_shift_sample_length = int(self.frame_shift * self.fs / 1000)
|
||||||
|
self.filter_length_min = filter_length_min
|
||||||
|
self.filter_length_max = filter_length_max
|
||||||
|
self.lfr_m = lfr_m
|
||||||
|
self.lfr_n = lfr_n
|
||||||
|
self.cmvn_file = cmvn_file
|
||||||
|
self.dither = dither
|
||||||
|
self.snip_edges = snip_edges
|
||||||
|
self.upsacle_samples = upsacle_samples
|
||||||
|
self.waveforms = None
|
||||||
|
self.reserve_waveforms = None
|
||||||
|
self.fbanks = None
|
||||||
|
self.fbanks_lens = None
|
||||||
|
self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)
|
||||||
|
self.input_cache = None
|
||||||
|
self.lfr_splice_cache = []
|
||||||
|
|
||||||
|
def output_size(self) -> int:
|
||||||
|
return self.n_mels * self.lfr_m
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def apply_cmvn(inputs: torch.Tensor, cmvn: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Apply CMVN with mvn data
|
||||||
|
"""
|
||||||
|
|
||||||
|
device = inputs.device
|
||||||
|
dtype = inputs.dtype
|
||||||
|
frame, dim = inputs.shape
|
||||||
|
|
||||||
|
means = np.tile(cmvn[0:1, :dim], (frame, 1))
|
||||||
|
vars = np.tile(cmvn[1:2, :dim], (frame, 1))
|
||||||
|
inputs += torch.from_numpy(means).type(dtype).to(device)
|
||||||
|
inputs *= torch.from_numpy(vars).type(dtype).to(device)
|
||||||
|
|
||||||
|
return inputs.type(torch.float32)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
# inputs tensor has catted the cache tensor
|
||||||
|
# def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, inputs_lfr_cache: torch.Tensor = None,
|
||||||
|
# is_final: bool = False) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
||||||
|
def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, is_final: bool = False) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
||||||
|
"""
|
||||||
|
Apply lfr with data
|
||||||
|
"""
|
||||||
|
|
||||||
|
LFR_inputs = []
|
||||||
|
# inputs = torch.vstack((inputs_lfr_cache, inputs))
|
||||||
|
T = inputs.shape[0] # include the right context
|
||||||
|
T_lfr = int(np.ceil((T - (lfr_m - 1) // 2) / lfr_n)) # minus the right context: (lfr_m - 1) // 2
|
||||||
|
splice_idx = T_lfr
|
||||||
|
for i in range(T_lfr):
|
||||||
|
if lfr_m <= T - i * lfr_n:
|
||||||
|
LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).view(1, -1))
|
||||||
|
else: # process last LFR frame
|
||||||
|
if is_final:
|
||||||
|
num_padding = lfr_m - (T - i * lfr_n)
|
||||||
|
frame = (inputs[i * lfr_n:]).view(-1)
|
||||||
|
for _ in range(num_padding):
|
||||||
|
frame = torch.hstack((frame, inputs[-1]))
|
||||||
|
LFR_inputs.append(frame)
|
||||||
|
else:
|
||||||
|
# update splice_idx and break the circle
|
||||||
|
splice_idx = i
|
||||||
|
break
|
||||||
|
splice_idx = min(T - 1, splice_idx * lfr_n)
|
||||||
|
lfr_splice_cache = inputs[splice_idx:, :]
|
||||||
|
LFR_outputs = torch.vstack(LFR_inputs)
|
||||||
|
return LFR_outputs.type(torch.float32), lfr_splice_cache, splice_idx
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def compute_frame_num(sample_length: int, frame_sample_length: int, frame_shift_sample_length: int) -> int:
|
||||||
|
frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
|
||||||
|
return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0
|
||||||
|
|
||||||
|
def forward_fbank(
|
||||||
|
self,
|
||||||
|
input: torch.Tensor,
|
||||||
|
input_lengths: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
batch_size = input.size(0)
|
||||||
|
if self.input_cache is None:
|
||||||
|
self.input_cache = torch.empty(0)
|
||||||
|
input = torch.cat((self.input_cache, input), dim=1)
|
||||||
|
frame_num = self.compute_frame_num(input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length)
|
||||||
|
# update self.in_cache
|
||||||
|
self.input_cache = input[:, -(input.shape[-1] - frame_num * self.frame_shift_sample_length):]
|
||||||
|
waveforms = torch.empty(0)
|
||||||
|
feats_pad = torch.empty(0)
|
||||||
|
feats_lens = torch.empty(0)
|
||||||
|
if frame_num:
|
||||||
|
waveforms = []
|
||||||
|
feats = []
|
||||||
|
feats_lens = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
waveform = input[i]
|
||||||
|
# we need accurate wave samples that used for fbank extracting
|
||||||
|
waveforms.append(
|
||||||
|
waveform[:((frame_num - 1) * self.frame_shift_sample_length + self.frame_sample_length)])
|
||||||
|
waveform = waveform * (1 << 15)
|
||||||
|
waveform = waveform.unsqueeze(0)
|
||||||
|
mat = kaldi.fbank(waveform,
|
||||||
|
num_mel_bins=self.n_mels,
|
||||||
|
frame_length=self.frame_length,
|
||||||
|
frame_shift=self.frame_shift,
|
||||||
|
dither=self.dither,
|
||||||
|
energy_floor=0.0,
|
||||||
|
window_type=self.window,
|
||||||
|
sample_frequency=self.fs)
|
||||||
|
|
||||||
|
feat_length = mat.size(0)
|
||||||
|
feats.append(mat)
|
||||||
|
feats_lens.append(feat_length)
|
||||||
|
|
||||||
|
waveforms = torch.stack(waveforms)
|
||||||
|
feats_lens = torch.as_tensor(feats_lens)
|
||||||
|
feats_pad = pad_sequence(feats,
|
||||||
|
batch_first=True,
|
||||||
|
padding_value=0.0)
|
||||||
|
self.fbanks = feats_pad
|
||||||
|
import copy
|
||||||
|
self.fbanks_lens = copy.deepcopy(feats_lens)
|
||||||
|
return waveforms, feats_pad, feats_lens
|
||||||
|
|
||||||
|
def get_fbank(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
return self.fbanks, self.fbanks_lens
|
||||||
|
|
||||||
|
def forward_lfr_cmvn(
|
||||||
|
self,
|
||||||
|
input: torch.Tensor,
|
||||||
|
input_lengths: torch.Tensor,
|
||||||
|
is_final: bool = False
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
batch_size = input.size(0)
|
||||||
|
feats = []
|
||||||
|
feats_lens = []
|
||||||
|
lfr_splice_frame_idxs = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
mat = input[i, :input_lengths[i], :]
|
||||||
|
if self.lfr_m != 1 or self.lfr_n != 1:
|
||||||
|
# update self.lfr_splice_cache in self.apply_lfr
|
||||||
|
# mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, self.lfr_splice_cache[i],
|
||||||
|
mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, is_final)
|
||||||
|
if self.cmvn_file is not None:
|
||||||
|
mat = self.apply_cmvn(mat, self.cmvn)
|
||||||
|
feat_length = mat.size(0)
|
||||||
|
feats.append(mat)
|
||||||
|
feats_lens.append(feat_length)
|
||||||
|
lfr_splice_frame_idxs.append(lfr_splice_frame_idx)
|
||||||
|
|
||||||
|
feats_lens = torch.as_tensor(feats_lens)
|
||||||
|
feats_pad = pad_sequence(feats,
|
||||||
|
batch_first=True,
|
||||||
|
padding_value=0.0)
|
||||||
|
lfr_splice_frame_idxs = torch.as_tensor(lfr_splice_frame_idxs)
|
||||||
|
return feats_pad, feats_lens, lfr_splice_frame_idxs
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, input: torch.Tensor, input_lengths: torch.Tensor, is_final: bool = False
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
batch_size = input.shape[0]
|
||||||
|
assert batch_size == 1, 'we support to extract feature online only when the batch size is equal to 1 now'
|
||||||
|
waveforms, feats, feats_lengths = self.forward_fbank(input, input_lengths) # input shape: B T D
|
||||||
|
if feats.shape[0]:
|
||||||
|
#if self.reserve_waveforms is None and self.lfr_m > 1:
|
||||||
|
# self.reserve_waveforms = waveforms[:, :(self.lfr_m - 1) // 2 * self.frame_shift_sample_length]
|
||||||
|
self.waveforms = waveforms if self.reserve_waveforms is None else torch.cat((self.reserve_waveforms, waveforms), dim=1)
|
||||||
|
if not self.lfr_splice_cache: # 初始化splice_cache
|
||||||
|
for i in range(batch_size):
|
||||||
|
self.lfr_splice_cache.append(feats[i][0, :].unsqueeze(dim=0).repeat((self.lfr_m - 1) // 2, 1))
|
||||||
|
# need the number of the input frames + self.lfr_splice_cache[0].shape[0] is greater than self.lfr_m
|
||||||
|
if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
|
||||||
|
lfr_splice_cache_tensor = torch.stack(self.lfr_splice_cache) # B T D
|
||||||
|
feats = torch.cat((lfr_splice_cache_tensor, feats), dim=1)
|
||||||
|
feats_lengths += lfr_splice_cache_tensor[0].shape[0]
|
||||||
|
frame_from_waveforms = int((self.waveforms.shape[1] - self.frame_sample_length) / self.frame_shift_sample_length + 1)
|
||||||
|
minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
|
||||||
|
feats, feats_lengths, lfr_splice_frame_idxs = self.forward_lfr_cmvn(feats, feats_lengths, is_final)
|
||||||
|
if self.lfr_m == 1:
|
||||||
|
self.reserve_waveforms = None
|
||||||
|
else:
|
||||||
|
reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
|
||||||
|
# print('reserve_frame_idx: ' + str(reserve_frame_idx))
|
||||||
|
# print('frame_frame: ' + str(frame_from_waveforms))
|
||||||
|
self.reserve_waveforms = self.waveforms[:, reserve_frame_idx * self.frame_shift_sample_length:frame_from_waveforms * self.frame_shift_sample_length]
|
||||||
|
sample_length = (frame_from_waveforms - 1) * self.frame_shift_sample_length + self.frame_sample_length
|
||||||
|
self.waveforms = self.waveforms[:, :sample_length]
|
||||||
|
else:
|
||||||
|
# update self.reserve_waveforms and self.lfr_splice_cache
|
||||||
|
self.reserve_waveforms = self.waveforms[:, :-(self.frame_sample_length - self.frame_shift_sample_length)]
|
||||||
|
for i in range(batch_size):
|
||||||
|
self.lfr_splice_cache[i] = torch.cat((self.lfr_splice_cache[i], feats[i]), dim=0)
|
||||||
|
return torch.empty(0), feats_lengths
|
||||||
|
else:
|
||||||
|
if is_final:
|
||||||
|
self.waveforms = waveforms if self.reserve_waveforms is None else self.reserve_waveforms
|
||||||
|
feats = torch.stack(self.lfr_splice_cache)
|
||||||
|
feats_lengths = torch.zeros(batch_size, dtype=torch.int) + feats.shape[1]
|
||||||
|
feats, feats_lengths, _ = self.forward_lfr_cmvn(feats, feats_lengths, is_final)
|
||||||
|
if is_final:
|
||||||
|
self.cache_reset()
|
||||||
|
return feats, feats_lengths
|
||||||
|
|
||||||
|
def get_waveforms(self):
|
||||||
|
return self.waveforms
|
||||||
|
|
||||||
|
def cache_reset(self):
|
||||||
|
self.reserve_waveforms = None
|
||||||
|
self.input_cache = None
|
||||||
|
self.lfr_splice_cache = []
|
||||||
|
|||||||
@ -82,7 +82,7 @@ def windowed_statistic_pooling(
|
|||||||
tt = xs_pad.shape[2]
|
tt = xs_pad.shape[2]
|
||||||
num_chunk = int(math.ceil(tt / pooling_stride))
|
num_chunk = int(math.ceil(tt / pooling_stride))
|
||||||
pad = pooling_size // 2
|
pad = pooling_size // 2
|
||||||
if xs_pad.shape == 4:
|
if len(xs_pad.shape) == 4:
|
||||||
features = F.pad(xs_pad, (0, 0, pad, pad), "reflect")
|
features = F.pad(xs_pad, (0, 0, pad, pad), "reflect")
|
||||||
else:
|
else:
|
||||||
features = F.pad(xs_pad, (pad, pad), "reflect")
|
features = F.pad(xs_pad, (pad, pad), "reflect")
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import math
|
import math
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -81,10 +81,16 @@ class PositionalEncoding(torch.nn.Module):
|
|||||||
return self.dropout(x)
|
return self.dropout(x)
|
||||||
|
|
||||||
|
|
||||||
class TransformerEncoder(nn.Module):
|
class EENDOLATransformerEncoder(nn.Module):
|
||||||
def __init__(self, idim, n_layers, n_units,
|
def __init__(self,
|
||||||
e_units=2048, h=8, dropout_rate=0.1, use_pos_emb=False):
|
idim: int,
|
||||||
super(TransformerEncoder, self).__init__()
|
n_layers: int,
|
||||||
|
n_units: int,
|
||||||
|
e_units: int = 2048,
|
||||||
|
h: int = 8,
|
||||||
|
dropout_rate: float = 0.1,
|
||||||
|
use_pos_emb: bool = False):
|
||||||
|
super(EENDOLATransformerEncoder, self).__init__()
|
||||||
self.lnorm_in = nn.LayerNorm(n_units)
|
self.lnorm_in = nn.LayerNorm(n_units)
|
||||||
self.n_layers = n_layers
|
self.n_layers = n_layers
|
||||||
self.dropout = nn.Dropout(dropout_rate)
|
self.dropout = nn.Dropout(dropout_rate)
|
||||||
|
|||||||
@ -29,6 +29,7 @@ tester /path/to/models/dir /path/to/wave/file
|
|||||||
|
|
||||||
## 依赖
|
## 依赖
|
||||||
- fftw3
|
- fftw3
|
||||||
|
- openblas
|
||||||
- onnxruntime
|
- onnxruntime
|
||||||
|
|
||||||
## 导出onnx格式模型文件
|
## 导出onnx格式模型文件
|
||||||
@ -47,18 +48,22 @@ python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn
|
|||||||
## Building Guidance for Linux/Unix
|
## Building Guidance for Linux/Unix
|
||||||
|
|
||||||
```
|
```
|
||||||
git clone https://github.com/RapidAI/RapidASR.git
|
git clone https://github.com/alibaba-damo-academy/FunASR.git && cd funasr/runtime/onnxruntime
|
||||||
cd RapidASR/cpp_onnx/
|
|
||||||
mkdir build
|
mkdir build
|
||||||
cd build
|
cd build
|
||||||
# download an appropriate onnxruntime from https://github.com/microsoft/onnxruntime/releases/tag/v1.14.0
|
# download an appropriate onnxruntime from https://github.com/microsoft/onnxruntime/releases/tag/v1.14.0
|
||||||
# here we get a copy of onnxruntime for linux 64
|
# here we get a copy of onnxruntime for linux 64
|
||||||
wget https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz
|
wget https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz
|
||||||
|
tar -zxvf onnxruntime-linux-x64-1.14.0.tgz
|
||||||
# ls
|
# ls
|
||||||
# onnxruntime-linux-x64-1.14.0 onnxruntime-linux-x64-1.14.0.tgz
|
# onnxruntime-linux-x64-1.14.0 onnxruntime-linux-x64-1.14.0.tgz
|
||||||
|
|
||||||
#install fftw3-dev
|
#install fftw3-dev
|
||||||
apt install libfftw3-dev
|
ubuntu: apt install libfftw3-dev
|
||||||
|
centos: yum install fftw fftw-devel
|
||||||
|
|
||||||
|
#install openblas
|
||||||
|
bash ./third_party/install_openblas.sh
|
||||||
|
|
||||||
# build
|
# build
|
||||||
cmake -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/mnt/c/Users/ma139/RapidASR/cpp_onnx/build/onnxruntime-linux-x64-1.14.0
|
cmake -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/mnt/c/Users/ma139/RapidASR/cpp_onnx/build/onnxruntime-linux-x64-1.14.0
|
||||||
|
|||||||
39
funasr/runtime/onnxruntime/third_party/install_openblas.sh
vendored
Normal file
39
funasr/runtime/onnxruntime/third_party/install_openblas.sh
vendored
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
OPENBLAS_VERSION=0.3.13
|
||||||
|
|
||||||
|
WGET=${WGET:-wget}
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
if ! command -v gfortran 2>/dev/null; then
|
||||||
|
echo "$0: gfortran is not installed. Please install it, e.g. by:"
|
||||||
|
echo " apt-get install gfortran"
|
||||||
|
echo "(if on Debian or Ubuntu), or:"
|
||||||
|
echo " yum install gcc-gfortran"
|
||||||
|
echo "(if on RedHat/CentOS). On a Mac, if brew is installed, it's:"
|
||||||
|
echo " brew install gfortran"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
tarball=OpenBLAS-$OPENBLAS_VERSION.tar.gz
|
||||||
|
|
||||||
|
rm -rf xianyi-OpenBLAS-* OpenBLAS OpenBLAS-*.tar.gz
|
||||||
|
|
||||||
|
if [ -d "$DOWNLOAD_DIR" ]; then
|
||||||
|
cp -p "$DOWNLOAD_DIR/$tarball" .
|
||||||
|
else
|
||||||
|
url=$($WGET -qO- "https://api.github.com/repos/xianyi/OpenBLAS/releases/tags/v${OPENBLAS_VERSION}" | python -c 'import sys,json;print(json.load(sys.stdin)["tarball_url"])')
|
||||||
|
test -n "$url"
|
||||||
|
$WGET -t3 -nv -O $tarball "$url"
|
||||||
|
fi
|
||||||
|
|
||||||
|
tar xzf $tarball
|
||||||
|
mv xianyi-OpenBLAS-* OpenBLAS
|
||||||
|
|
||||||
|
make PREFIX=$(pwd)/OpenBLAS/install USE_LOCKING=1 USE_THREAD=0 -C OpenBLAS all install
|
||||||
|
if [ $? -eq 0 ]; then
|
||||||
|
echo "OpenBLAS is installed successfully."
|
||||||
|
rm $tarball
|
||||||
|
fi
|
||||||
@ -150,6 +150,7 @@ class OrtInferSession():
|
|||||||
def __init__(self, model_file, device_id=-1):
|
def __init__(self, model_file, device_id=-1):
|
||||||
device_id = str(device_id)
|
device_id = str(device_id)
|
||||||
sess_opt = SessionOptions()
|
sess_opt = SessionOptions()
|
||||||
|
sess_opt.intra_op_num_threads = 4
|
||||||
sess_opt.log_severity_level = 4
|
sess_opt.log_severity_level = 4
|
||||||
sess_opt.enable_cpu_mem_arena = False
|
sess_opt.enable_cpu_mem_arena = False
|
||||||
sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
|
sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||||
|
|||||||
7176
funasr/runtime/triton_gpu/client/aishell_test.txt
Normal file
7176
funasr/runtime/triton_gpu/client/aishell_test.txt
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,561 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
# 2023 Nvidia (authors: Yuekai Zhang)
|
||||||
|
# 2023 Recurrent.ai (authors: Songtao Shi)
|
||||||
|
# See LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
This script loads a manifest in nemo format and sends it to the server
|
||||||
|
for decoding, in parallel.
|
||||||
|
|
||||||
|
{'audio_filepath':'','text':'',duration:}\n
|
||||||
|
{'audio_filepath':'','text':'',duration:}\n
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# For aishell manifests:
|
||||||
|
apt-get install git-lfs
|
||||||
|
git-lfs install
|
||||||
|
git clone https://huggingface.co/csukuangfj/aishell-test-dev-manifests
|
||||||
|
sudo mkdir -p ./aishell-test-dev-manifests/aishell
|
||||||
|
tar xf ./aishell-test-dev-manifests/data_aishell.tar.gz -C ./aishell-test-dev-manifests/aishell # noqa
|
||||||
|
|
||||||
|
|
||||||
|
# cmd run
|
||||||
|
manifest_path='./client/aishell_test.txt'
|
||||||
|
serveraddr=localhost
|
||||||
|
num_task=60
|
||||||
|
python3 client/decode_manifest_triton_wo_cuts.py \
|
||||||
|
--server-addr $serveraddr \
|
||||||
|
--compute-cer \
|
||||||
|
--model-name infer_pipeline \
|
||||||
|
--num-tasks $num_task \
|
||||||
|
--manifest-filename $manifest_path \
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pydub import AudioSegment
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
import types
|
||||||
|
from pathlib import Path
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import tritonclient
|
||||||
|
import tritonclient.grpc.aio as grpcclient
|
||||||
|
from tritonclient.utils import np_to_triton_dtype
|
||||||
|
|
||||||
|
from icefall.utils import store_transcripts, write_error_stats
|
||||||
|
|
||||||
|
DEFAULT_MANIFEST_FILENAME = "./aishell_test.txt" # noqa
|
||||||
|
DEFAULT_ROOT = './'
|
||||||
|
DEFAULT_ROOT = '/mfs/songtao/researchcode/FunASR/data/'
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--server-addr",
|
||||||
|
type=str,
|
||||||
|
default="localhost",
|
||||||
|
help="Address of the server",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--server-port",
|
||||||
|
type=int,
|
||||||
|
default=8001,
|
||||||
|
help="Port of the server",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--manifest-filename",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_MANIFEST_FILENAME,
|
||||||
|
help="Path to the manifest for decoding",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--model-name",
|
||||||
|
type=str,
|
||||||
|
default="transducer",
|
||||||
|
help="triton model_repo module name to request",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-tasks",
|
||||||
|
type=int,
|
||||||
|
default=50,
|
||||||
|
help="Number of tasks to use for sending",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--log-interval",
|
||||||
|
type=int,
|
||||||
|
default=5,
|
||||||
|
help="Controls how frequently we print the log.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--compute-cer",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="""True to compute CER, e.g., for Chinese.
|
||||||
|
False to compute WER, e.g., for English words.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--streaming",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="""True for streaming ASR.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--simulate-streaming",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="""True for strictly simulate streaming ASR.
|
||||||
|
Threads will sleep to simulate the real speaking scene.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--chunk_size",
|
||||||
|
type=int,
|
||||||
|
required=False,
|
||||||
|
default=16,
|
||||||
|
help="chunk size default is 16",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--context",
|
||||||
|
type=int,
|
||||||
|
required=False,
|
||||||
|
default=-1,
|
||||||
|
help="subsampling context for wenet",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--encoder_right_context",
|
||||||
|
type=int,
|
||||||
|
required=False,
|
||||||
|
default=2,
|
||||||
|
help="encoder right context",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--subsampling",
|
||||||
|
type=int,
|
||||||
|
required=False,
|
||||||
|
default=4,
|
||||||
|
help="subsampling rate",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--stats_file",
|
||||||
|
type=str,
|
||||||
|
required=False,
|
||||||
|
default="./stats.json",
|
||||||
|
help="output of stats anaylasis",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def load_manifest(fp):
|
||||||
|
data = []
|
||||||
|
with open(fp) as f:
|
||||||
|
for i, dp in enumerate(f.readlines()):
|
||||||
|
dp = eval(dp)
|
||||||
|
dp['id'] = i
|
||||||
|
data.append(dp)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def split_dps(dps, num_tasks):
|
||||||
|
dps_splited = []
|
||||||
|
# import pdb;pdb.set_trace()
|
||||||
|
assert len(dps) > num_tasks
|
||||||
|
|
||||||
|
one_task_num = len(dps)//num_tasks
|
||||||
|
for i in range(0, len(dps), one_task_num):
|
||||||
|
if i+one_task_num >= len(dps):
|
||||||
|
for k, j in enumerate(range(i, len(dps))):
|
||||||
|
dps_splited[k].append(dps[j])
|
||||||
|
else:
|
||||||
|
dps_splited.append(dps[i:i+one_task_num])
|
||||||
|
return dps_splited
|
||||||
|
|
||||||
|
|
||||||
|
def load_audio(path):
|
||||||
|
audio = AudioSegment.from_wav(path).set_frame_rate(16000).set_channels(1)
|
||||||
|
audiop_np = np.array(audio.get_array_of_samples())/32768.0
|
||||||
|
return audiop_np.astype(np.float32), audio.duration_seconds
|
||||||
|
|
||||||
|
|
||||||
|
async def send(
|
||||||
|
dps: list,
|
||||||
|
name: str,
|
||||||
|
triton_client: tritonclient.grpc.aio.InferenceServerClient,
|
||||||
|
protocol_client: types.ModuleType,
|
||||||
|
log_interval: int,
|
||||||
|
compute_cer: bool,
|
||||||
|
model_name: str,
|
||||||
|
):
|
||||||
|
total_duration = 0.0
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for i, dp in enumerate(dps):
|
||||||
|
if i % log_interval == 0:
|
||||||
|
print(f"{name}: {i}/{len(dps)}")
|
||||||
|
|
||||||
|
waveform, duration = load_audio(
|
||||||
|
os.path.join(DEFAULT_ROOT, dp['audio_filepath']))
|
||||||
|
sample_rate = 16000
|
||||||
|
|
||||||
|
# padding to nearset 10 seconds
|
||||||
|
samples = np.zeros(
|
||||||
|
(
|
||||||
|
1,
|
||||||
|
10 * sample_rate *
|
||||||
|
(int(len(waveform) / sample_rate // 10) + 1),
|
||||||
|
),
|
||||||
|
dtype=np.float32,
|
||||||
|
)
|
||||||
|
samples[0, : len(waveform)] = waveform
|
||||||
|
|
||||||
|
lengths = np.array([[len(waveform)]], dtype=np.int32)
|
||||||
|
|
||||||
|
inputs = [
|
||||||
|
protocol_client.InferInput(
|
||||||
|
"WAV", samples.shape, np_to_triton_dtype(samples.dtype)
|
||||||
|
),
|
||||||
|
protocol_client.InferInput(
|
||||||
|
"WAV_LENS", lengths.shape, np_to_triton_dtype(lengths.dtype)
|
||||||
|
),
|
||||||
|
]
|
||||||
|
inputs[0].set_data_from_numpy(samples)
|
||||||
|
inputs[1].set_data_from_numpy(lengths)
|
||||||
|
outputs = [protocol_client.InferRequestedOutput("TRANSCRIPTS")]
|
||||||
|
sequence_id = 10086 + i
|
||||||
|
|
||||||
|
response = await triton_client.infer(
|
||||||
|
model_name, inputs, request_id=str(sequence_id), outputs=outputs
|
||||||
|
)
|
||||||
|
|
||||||
|
decoding_results = response.as_numpy("TRANSCRIPTS")[0]
|
||||||
|
if type(decoding_results) == np.ndarray:
|
||||||
|
decoding_results = b" ".join(decoding_results).decode("utf-8")
|
||||||
|
else:
|
||||||
|
# For wenet
|
||||||
|
decoding_results = decoding_results.decode("utf-8")
|
||||||
|
|
||||||
|
total_duration += duration
|
||||||
|
|
||||||
|
if compute_cer:
|
||||||
|
ref = dp['text'].split()
|
||||||
|
hyp = decoding_results.split()
|
||||||
|
ref = list("".join(ref))
|
||||||
|
hyp = list("".join(hyp))
|
||||||
|
results.append((dp['id'], ref, hyp))
|
||||||
|
else:
|
||||||
|
results.append(
|
||||||
|
(
|
||||||
|
dp['id'],
|
||||||
|
dp['text'].split(),
|
||||||
|
decoding_results.split(),
|
||||||
|
)
|
||||||
|
) # noqa
|
||||||
|
|
||||||
|
return total_duration, results
|
||||||
|
|
||||||
|
|
||||||
|
async def send_streaming(
|
||||||
|
dps: list,
|
||||||
|
name: str,
|
||||||
|
triton_client: tritonclient.grpc.aio.InferenceServerClient,
|
||||||
|
protocol_client: types.ModuleType,
|
||||||
|
log_interval: int,
|
||||||
|
compute_cer: bool,
|
||||||
|
model_name: str,
|
||||||
|
first_chunk_in_secs: float,
|
||||||
|
other_chunk_in_secs: float,
|
||||||
|
task_index: int,
|
||||||
|
simulate_mode: bool = False,
|
||||||
|
):
|
||||||
|
total_duration = 0.0
|
||||||
|
results = []
|
||||||
|
latency_data = []
|
||||||
|
|
||||||
|
for i, dp in enumerate(dps):
|
||||||
|
if i % log_interval == 0:
|
||||||
|
print(f"{name}: {i}/{len(dps)}")
|
||||||
|
|
||||||
|
waveform, duration = load_audio(dp['audio_filepath'])
|
||||||
|
sample_rate = 16000
|
||||||
|
|
||||||
|
wav_segs = []
|
||||||
|
|
||||||
|
j = 0
|
||||||
|
while j < len(waveform):
|
||||||
|
if j == 0:
|
||||||
|
stride = int(first_chunk_in_secs * sample_rate)
|
||||||
|
wav_segs.append(waveform[j: j + stride])
|
||||||
|
else:
|
||||||
|
stride = int(other_chunk_in_secs * sample_rate)
|
||||||
|
wav_segs.append(waveform[j: j + stride])
|
||||||
|
j += len(wav_segs[-1])
|
||||||
|
|
||||||
|
sequence_id = task_index + 10086
|
||||||
|
|
||||||
|
for idx, seg in enumerate(wav_segs):
|
||||||
|
chunk_len = len(seg)
|
||||||
|
|
||||||
|
if simulate_mode:
|
||||||
|
await asyncio.sleep(chunk_len / sample_rate)
|
||||||
|
|
||||||
|
chunk_start = time.time()
|
||||||
|
if idx == 0:
|
||||||
|
chunk_samples = int(first_chunk_in_secs * sample_rate)
|
||||||
|
expect_input = np.zeros((1, chunk_samples), dtype=np.float32)
|
||||||
|
else:
|
||||||
|
chunk_samples = int(other_chunk_in_secs * sample_rate)
|
||||||
|
expect_input = np.zeros((1, chunk_samples), dtype=np.float32)
|
||||||
|
|
||||||
|
expect_input[0][0:chunk_len] = seg
|
||||||
|
input0_data = expect_input
|
||||||
|
input1_data = np.array([[chunk_len]], dtype=np.int32)
|
||||||
|
|
||||||
|
inputs = [
|
||||||
|
protocol_client.InferInput(
|
||||||
|
"WAV",
|
||||||
|
input0_data.shape,
|
||||||
|
np_to_triton_dtype(input0_data.dtype),
|
||||||
|
),
|
||||||
|
protocol_client.InferInput(
|
||||||
|
"WAV_LENS",
|
||||||
|
input1_data.shape,
|
||||||
|
np_to_triton_dtype(input1_data.dtype),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
inputs[0].set_data_from_numpy(input0_data)
|
||||||
|
inputs[1].set_data_from_numpy(input1_data)
|
||||||
|
|
||||||
|
outputs = [protocol_client.InferRequestedOutput("TRANSCRIPTS")]
|
||||||
|
end = False
|
||||||
|
if idx == len(wav_segs) - 1:
|
||||||
|
end = True
|
||||||
|
|
||||||
|
response = await triton_client.infer(
|
||||||
|
model_name,
|
||||||
|
inputs,
|
||||||
|
outputs=outputs,
|
||||||
|
sequence_id=sequence_id,
|
||||||
|
sequence_start=idx == 0,
|
||||||
|
sequence_end=end,
|
||||||
|
)
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
decoding_results = response.as_numpy("TRANSCRIPTS")
|
||||||
|
if type(decoding_results) == np.ndarray:
|
||||||
|
decoding_results = b" ".join(decoding_results).decode("utf-8")
|
||||||
|
else:
|
||||||
|
# For wenet
|
||||||
|
decoding_results = response.as_numpy("TRANSCRIPTS")[0].decode(
|
||||||
|
"utf-8"
|
||||||
|
)
|
||||||
|
chunk_end = time.time() - chunk_start
|
||||||
|
latency_data.append((chunk_end, chunk_len / sample_rate))
|
||||||
|
|
||||||
|
total_duration += duration
|
||||||
|
|
||||||
|
if compute_cer:
|
||||||
|
ref = dp['text'].split()
|
||||||
|
hyp = decoding_results.split()
|
||||||
|
ref = list("".join(ref))
|
||||||
|
hyp = list("".join(hyp))
|
||||||
|
results.append((dp['id'], ref, hyp))
|
||||||
|
else:
|
||||||
|
results.append(
|
||||||
|
(
|
||||||
|
dp['id'],
|
||||||
|
dp['text'].split(),
|
||||||
|
decoding_results.split(),
|
||||||
|
)
|
||||||
|
) # noqa
|
||||||
|
|
||||||
|
return total_duration, results, latency_data
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
args = get_args()
|
||||||
|
filename = args.manifest_filename
|
||||||
|
server_addr = args.server_addr
|
||||||
|
server_port = args.server_port
|
||||||
|
url = f"{server_addr}:{server_port}"
|
||||||
|
num_tasks = args.num_tasks
|
||||||
|
log_interval = args.log_interval
|
||||||
|
compute_cer = args.compute_cer
|
||||||
|
|
||||||
|
dps = load_manifest(filename)
|
||||||
|
dps_list = split_dps(dps, num_tasks)
|
||||||
|
tasks = []
|
||||||
|
|
||||||
|
triton_client = grpcclient.InferenceServerClient(url=url, verbose=False)
|
||||||
|
protocol_client = grpcclient
|
||||||
|
|
||||||
|
if args.streaming or args.simulate_streaming:
|
||||||
|
frame_shift_ms = 10
|
||||||
|
frame_length_ms = 25
|
||||||
|
add_frames = math.ceil(
|
||||||
|
(frame_length_ms - frame_shift_ms) / frame_shift_ms
|
||||||
|
)
|
||||||
|
# decode_window_length: input sequence length of streaming encoder
|
||||||
|
if args.context > 0:
|
||||||
|
# decode window length calculation for wenet
|
||||||
|
decode_window_length = (
|
||||||
|
args.chunk_size - 1
|
||||||
|
) * args.subsampling + args.context
|
||||||
|
else:
|
||||||
|
# decode window length calculation for icefall
|
||||||
|
decode_window_length = (
|
||||||
|
args.chunk_size + 2 + args.encoder_right_context
|
||||||
|
) * args.subsampling + 3
|
||||||
|
|
||||||
|
first_chunk_ms = (decode_window_length + add_frames) * frame_shift_ms
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
for i in range(num_tasks):
|
||||||
|
if args.streaming:
|
||||||
|
assert not args.simulate_streaming
|
||||||
|
task = asyncio.create_task(
|
||||||
|
send_streaming(
|
||||||
|
dps=dps_list[i],
|
||||||
|
name=f"task-{i}",
|
||||||
|
triton_client=triton_client,
|
||||||
|
protocol_client=protocol_client,
|
||||||
|
log_interval=log_interval,
|
||||||
|
compute_cer=compute_cer,
|
||||||
|
model_name=args.model_name,
|
||||||
|
first_chunk_in_secs=first_chunk_ms / 1000,
|
||||||
|
other_chunk_in_secs=args.chunk_size
|
||||||
|
* args.subsampling
|
||||||
|
* frame_shift_ms
|
||||||
|
/ 1000,
|
||||||
|
task_index=i,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif args.simulate_streaming:
|
||||||
|
task = asyncio.create_task(
|
||||||
|
send_streaming(
|
||||||
|
dps=dps_list[i],
|
||||||
|
name=f"task-{i}",
|
||||||
|
triton_client=triton_client,
|
||||||
|
protocol_client=protocol_client,
|
||||||
|
log_interval=log_interval,
|
||||||
|
compute_cer=compute_cer,
|
||||||
|
model_name=args.model_name,
|
||||||
|
first_chunk_in_secs=first_chunk_ms / 1000,
|
||||||
|
other_chunk_in_secs=args.chunk_size
|
||||||
|
* args.subsampling
|
||||||
|
* frame_shift_ms
|
||||||
|
/ 1000,
|
||||||
|
task_index=i,
|
||||||
|
simulate_mode=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
task = asyncio.create_task(
|
||||||
|
send(
|
||||||
|
dps=dps_list[i],
|
||||||
|
name=f"task-{i}",
|
||||||
|
triton_client=triton_client,
|
||||||
|
protocol_client=protocol_client,
|
||||||
|
log_interval=log_interval,
|
||||||
|
compute_cer=compute_cer,
|
||||||
|
model_name=args.model_name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
tasks.append(task)
|
||||||
|
|
||||||
|
ans_list = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
elapsed = end_time - start_time
|
||||||
|
|
||||||
|
results = []
|
||||||
|
total_duration = 0.0
|
||||||
|
latency_data = []
|
||||||
|
for ans in ans_list:
|
||||||
|
total_duration += ans[0]
|
||||||
|
results += ans[1]
|
||||||
|
if args.streaming or args.simulate_streaming:
|
||||||
|
latency_data += ans[2]
|
||||||
|
|
||||||
|
rtf = elapsed / total_duration
|
||||||
|
|
||||||
|
s = f"RTF: {rtf:.4f}\n"
|
||||||
|
s += f"total_duration: {total_duration:.3f} seconds\n"
|
||||||
|
s += f"({total_duration/3600:.2f} hours)\n"
|
||||||
|
s += (
|
||||||
|
f"processing time: {elapsed:.3f} seconds "
|
||||||
|
f"({elapsed/3600:.2f} hours)\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.streaming or args.simulate_streaming:
|
||||||
|
latency_list = [
|
||||||
|
chunk_end for (chunk_end, chunk_duration) in latency_data
|
||||||
|
]
|
||||||
|
latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
|
||||||
|
latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
|
||||||
|
s += f"latency_variance: {latency_variance:.2f}\n"
|
||||||
|
s += f"latency_50_percentile: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
|
||||||
|
s += f"latency_90_percentile: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
|
||||||
|
s += f"latency_99_percentile: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
|
||||||
|
s += f"average_latency_ms: {latency_ms:.2f}\n"
|
||||||
|
|
||||||
|
print(s)
|
||||||
|
|
||||||
|
with open("rtf.txt", "w") as f:
|
||||||
|
f.write(s)
|
||||||
|
|
||||||
|
name = Path(filename).stem.split(".")[0]
|
||||||
|
results = sorted(results)
|
||||||
|
store_transcripts(filename=f"recogs-{name}.txt", texts=results)
|
||||||
|
|
||||||
|
with open(f"errs-{name}.txt", "w") as f:
|
||||||
|
write_error_stats(f, "test-set", results, enable_log=True)
|
||||||
|
|
||||||
|
with open(f"errs-{name}.txt", "r") as f:
|
||||||
|
print(f.readline()) # WER
|
||||||
|
print(f.readline()) # Detailed errors
|
||||||
|
|
||||||
|
if args.stats_file:
|
||||||
|
stats = await triton_client.get_inference_statistics(
|
||||||
|
model_name="", as_json=True
|
||||||
|
)
|
||||||
|
with open(args.stats_file, "w") as f:
|
||||||
|
json.dump(stats, f)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
@ -40,7 +40,7 @@ output [
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "token_num"
|
name: "token_num"
|
||||||
data_type: TYPE_INT64
|
data_type: TYPE_INT32
|
||||||
dims: [1]
|
dims: [1]
|
||||||
reshape: { shape: [ ] }
|
reshape: { shape: [ ] }
|
||||||
}
|
}
|
||||||
|
|||||||
@ -43,7 +43,7 @@ input [
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "token_num"
|
name: "token_num"
|
||||||
data_type: TYPE_INT64
|
data_type: TYPE_INT32
|
||||||
dims: [1]
|
dims: [1]
|
||||||
reshape: { shape: [ ] }
|
reshape: { shape: [ ] }
|
||||||
}
|
}
|
||||||
|
|||||||
@ -40,6 +40,7 @@ from funasr.models.decoder.transformer_decoder import TransformerDecoder
|
|||||||
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
|
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
|
||||||
from funasr.models.e2e_asr import ESPnetASRModel
|
from funasr.models.e2e_asr import ESPnetASRModel
|
||||||
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
|
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
|
||||||
|
from funasr.models.e2e_tp import TimestampPredictor
|
||||||
from funasr.models.e2e_asr_mfcca import MFCCA
|
from funasr.models.e2e_asr_mfcca import MFCCA
|
||||||
from funasr.models.e2e_uni_asr import UniASR
|
from funasr.models.e2e_uni_asr import UniASR
|
||||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||||
@ -124,6 +125,7 @@ model_choices = ClassChoices(
|
|||||||
bicif_paraformer=BiCifParaformer,
|
bicif_paraformer=BiCifParaformer,
|
||||||
contextual_paraformer=ContextualParaformer,
|
contextual_paraformer=ContextualParaformer,
|
||||||
mfcca=MFCCA,
|
mfcca=MFCCA,
|
||||||
|
timestamp_prediction=TimestampPredictor,
|
||||||
),
|
),
|
||||||
type_check=AbsESPnetModel,
|
type_check=AbsESPnetModel,
|
||||||
default="asr",
|
default="asr",
|
||||||
@ -1245,9 +1247,87 @@ class ASRTaskMFCCA(ASRTask):
|
|||||||
|
|
||||||
|
|
||||||
class ASRTaskAligner(ASRTaskParaformer):
|
class ASRTaskAligner(ASRTaskParaformer):
|
||||||
|
# If you need more than one optimizers, change this value
|
||||||
|
num_optimizers: int = 1
|
||||||
|
|
||||||
|
# Add variable objects configurations
|
||||||
|
class_choices_list = [
|
||||||
|
# --frontend and --frontend_conf
|
||||||
|
frontend_choices,
|
||||||
|
# --model and --model_conf
|
||||||
|
model_choices,
|
||||||
|
# --encoder and --encoder_conf
|
||||||
|
encoder_choices,
|
||||||
|
# --decoder and --decoder_conf
|
||||||
|
decoder_choices,
|
||||||
|
]
|
||||||
|
|
||||||
|
# If you need to modify train() or eval() procedures, change Trainer class here
|
||||||
|
trainer = Trainer
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_model(cls, args: argparse.Namespace):
|
||||||
|
assert check_argument_types()
|
||||||
|
if isinstance(args.token_list, str):
|
||||||
|
with open(args.token_list, encoding="utf-8") as f:
|
||||||
|
token_list = [line.rstrip() for line in f]
|
||||||
|
|
||||||
|
# Overwriting token_list to keep it as "portable".
|
||||||
|
args.token_list = list(token_list)
|
||||||
|
elif isinstance(args.token_list, (tuple, list)):
|
||||||
|
token_list = list(args.token_list)
|
||||||
|
else:
|
||||||
|
raise RuntimeError("token_list must be str or list")
|
||||||
|
|
||||||
|
# 1. frontend
|
||||||
|
if args.input_size is None:
|
||||||
|
# Extract features in the model
|
||||||
|
frontend_class = frontend_choices.get_class(args.frontend)
|
||||||
|
if args.frontend == 'wav_frontend':
|
||||||
|
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
|
||||||
|
else:
|
||||||
|
frontend = frontend_class(**args.frontend_conf)
|
||||||
|
input_size = frontend.output_size()
|
||||||
|
else:
|
||||||
|
# Give features from data-loader
|
||||||
|
args.frontend = None
|
||||||
|
args.frontend_conf = {}
|
||||||
|
frontend = None
|
||||||
|
input_size = args.input_size
|
||||||
|
|
||||||
|
# 2. Encoder
|
||||||
|
encoder_class = encoder_choices.get_class(args.encoder)
|
||||||
|
encoder = encoder_class(input_size=input_size, **args.encoder_conf)
|
||||||
|
|
||||||
|
# 3. Predictor
|
||||||
|
predictor_class = predictor_choices.get_class(args.predictor)
|
||||||
|
predictor = predictor_class(**args.predictor_conf)
|
||||||
|
|
||||||
|
# 10. Build model
|
||||||
|
try:
|
||||||
|
model_class = model_choices.get_class(args.model)
|
||||||
|
except AttributeError:
|
||||||
|
model_class = model_choices.get_class("asr")
|
||||||
|
|
||||||
|
# 8. Build model
|
||||||
|
model = model_class(
|
||||||
|
frontend=frontend,
|
||||||
|
encoder=encoder,
|
||||||
|
predictor=predictor,
|
||||||
|
token_list=token_list,
|
||||||
|
**args.model_conf,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 11. Initialize
|
||||||
|
if args.init is not None:
|
||||||
|
initialize(model, args.init)
|
||||||
|
|
||||||
|
assert check_return_type(model)
|
||||||
|
return model
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def required_data_names(
|
def required_data_names(
|
||||||
cls, train: bool = True, inference: bool = False
|
cls, train: bool = True, inference: bool = False
|
||||||
) -> Tuple[str, ...]:
|
) -> Tuple[str, ...]:
|
||||||
retval = ("speech", "text")
|
retval = ("speech", "text")
|
||||||
return retval
|
return retval
|
||||||
|
|||||||
@ -20,19 +20,19 @@ from funasr.datasets.collate_fn import CommonCollateFn
|
|||||||
from funasr.datasets.preprocessor import CommonPreprocessor
|
from funasr.datasets.preprocessor import CommonPreprocessor
|
||||||
from funasr.layers.abs_normalize import AbsNormalize
|
from funasr.layers.abs_normalize import AbsNormalize
|
||||||
from funasr.layers.global_mvn import GlobalMVN
|
from funasr.layers.global_mvn import GlobalMVN
|
||||||
from funasr.layers.utterance_mvn import UtteranceMVN
|
|
||||||
from funasr.layers.label_aggregation import LabelAggregate
|
from funasr.layers.label_aggregation import LabelAggregate
|
||||||
from funasr.models.ctc import CTC
|
from funasr.layers.utterance_mvn import UtteranceMVN
|
||||||
from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
|
|
||||||
from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
|
|
||||||
from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
|
|
||||||
from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
|
|
||||||
from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
|
|
||||||
from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
|
|
||||||
from funasr.models.e2e_diar_sond import DiarSondModel
|
from funasr.models.e2e_diar_sond import DiarSondModel
|
||||||
|
from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel
|
||||||
from funasr.models.encoder.abs_encoder import AbsEncoder
|
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||||
from funasr.models.encoder.conformer_encoder import ConformerEncoder
|
from funasr.models.encoder.conformer_encoder import ConformerEncoder
|
||||||
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
|
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
|
||||||
|
from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
|
||||||
|
from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
|
||||||
|
from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
|
||||||
|
from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
|
||||||
|
from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
|
||||||
|
from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
|
||||||
from funasr.models.encoder.rnn_encoder import RNNEncoder
|
from funasr.models.encoder.rnn_encoder import RNNEncoder
|
||||||
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
|
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
|
||||||
from funasr.models.encoder.transformer_encoder import TransformerEncoder
|
from funasr.models.encoder.transformer_encoder import TransformerEncoder
|
||||||
@ -41,17 +41,13 @@ from funasr.models.frontend.default import DefaultFrontend
|
|||||||
from funasr.models.frontend.fused import FusedFrontends
|
from funasr.models.frontend.fused import FusedFrontends
|
||||||
from funasr.models.frontend.s3prl import S3prlFrontend
|
from funasr.models.frontend.s3prl import S3prlFrontend
|
||||||
from funasr.models.frontend.wav_frontend import WavFrontend
|
from funasr.models.frontend.wav_frontend import WavFrontend
|
||||||
|
from funasr.models.frontend.wav_frontend import WavFrontendMel23
|
||||||
from funasr.models.frontend.windowing import SlidingWindow
|
from funasr.models.frontend.windowing import SlidingWindow
|
||||||
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
|
|
||||||
from funasr.models.postencoder.hugging_face_transformers_postencoder import (
|
|
||||||
HuggingFaceTransformersPostEncoder, # noqa: H301
|
|
||||||
)
|
|
||||||
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
|
|
||||||
from funasr.models.preencoder.linear import LinearProjection
|
|
||||||
from funasr.models.preencoder.sinc import LightweightSincConvs
|
|
||||||
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
||||||
from funasr.models.specaug.specaug import SpecAug
|
from funasr.models.specaug.specaug import SpecAug
|
||||||
from funasr.models.specaug.specaug import SpecAugLFR
|
from funasr.models.specaug.specaug import SpecAugLFR
|
||||||
|
from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
|
||||||
|
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
|
||||||
from funasr.tasks.abs_task import AbsTask
|
from funasr.tasks.abs_task import AbsTask
|
||||||
from funasr.torch_utils.initialize import initialize
|
from funasr.torch_utils.initialize import initialize
|
||||||
from funasr.train.abs_espnet_model import AbsESPnetModel
|
from funasr.train.abs_espnet_model import AbsESPnetModel
|
||||||
@ -70,6 +66,7 @@ frontend_choices = ClassChoices(
|
|||||||
s3prl=S3prlFrontend,
|
s3prl=S3prlFrontend,
|
||||||
fused=FusedFrontends,
|
fused=FusedFrontends,
|
||||||
wav_frontend=WavFrontend,
|
wav_frontend=WavFrontend,
|
||||||
|
wav_frontend_mel23=WavFrontendMel23,
|
||||||
),
|
),
|
||||||
type_check=AbsFrontend,
|
type_check=AbsFrontend,
|
||||||
default="default",
|
default="default",
|
||||||
@ -107,6 +104,7 @@ model_choices = ClassChoices(
|
|||||||
"model",
|
"model",
|
||||||
classes=dict(
|
classes=dict(
|
||||||
sond=DiarSondModel,
|
sond=DiarSondModel,
|
||||||
|
eend_ola=DiarEENDOLAModel,
|
||||||
),
|
),
|
||||||
type_check=AbsESPnetModel,
|
type_check=AbsESPnetModel,
|
||||||
default="sond",
|
default="sond",
|
||||||
@ -126,6 +124,7 @@ encoder_choices = ClassChoices(
|
|||||||
sanm_chunk_opt=SANMEncoderChunkOpt,
|
sanm_chunk_opt=SANMEncoderChunkOpt,
|
||||||
data2vec_encoder=Data2VecEncoder,
|
data2vec_encoder=Data2VecEncoder,
|
||||||
ecapa_tdnn=ECAPA_TDNN,
|
ecapa_tdnn=ECAPA_TDNN,
|
||||||
|
eend_ola_transformer=EENDOLATransformerEncoder,
|
||||||
),
|
),
|
||||||
type_check=torch.nn.Module,
|
type_check=torch.nn.Module,
|
||||||
default="resnet34",
|
default="resnet34",
|
||||||
@ -177,6 +176,15 @@ decoder_choices = ClassChoices(
|
|||||||
type_check=torch.nn.Module,
|
type_check=torch.nn.Module,
|
||||||
default="fsmn",
|
default="fsmn",
|
||||||
)
|
)
|
||||||
|
# encoder_decoder_attractor is used for EEND-OLA
|
||||||
|
encoder_decoder_attractor_choices = ClassChoices(
|
||||||
|
"encoder_decoder_attractor",
|
||||||
|
classes=dict(
|
||||||
|
eda=EncoderDecoderAttractor,
|
||||||
|
),
|
||||||
|
type_check=torch.nn.Module,
|
||||||
|
default="eda",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DiarTask(AbsTask):
|
class DiarTask(AbsTask):
|
||||||
@ -594,3 +602,294 @@ class DiarTask(AbsTask):
|
|||||||
var_dict_torch_update.update(var_dict_torch_update_local)
|
var_dict_torch_update.update(var_dict_torch_update_local)
|
||||||
|
|
||||||
return var_dict_torch_update
|
return var_dict_torch_update
|
||||||
|
|
||||||
|
|
||||||
|
class EENDOLADiarTask(AbsTask):
|
||||||
|
# If you need more than 1 optimizer, change this value
|
||||||
|
num_optimizers: int = 1
|
||||||
|
|
||||||
|
# Add variable objects configurations
|
||||||
|
class_choices_list = [
|
||||||
|
# --frontend and --frontend_conf
|
||||||
|
frontend_choices,
|
||||||
|
# --specaug and --specaug_conf
|
||||||
|
model_choices,
|
||||||
|
# --encoder and --encoder_conf
|
||||||
|
encoder_choices,
|
||||||
|
# --speaker_encoder and --speaker_encoder_conf
|
||||||
|
encoder_decoder_attractor_choices,
|
||||||
|
]
|
||||||
|
|
||||||
|
# If you need to modify train() or eval() procedures, change Trainer class here
|
||||||
|
trainer = Trainer
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_task_arguments(cls, parser: argparse.ArgumentParser):
|
||||||
|
group = parser.add_argument_group(description="Task related")
|
||||||
|
|
||||||
|
# NOTE(kamo): add_arguments(..., required=True) can't be used
|
||||||
|
# to provide --print_config mode. Instead of it, do as
|
||||||
|
# required = parser.get_default("required")
|
||||||
|
# required += ["token_list"]
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--token_list",
|
||||||
|
type=str_or_none,
|
||||||
|
default=None,
|
||||||
|
help="A text mapping int-id to token",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--split_with_space",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="whether to split text using <space>",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--seg_dict_file",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="seg_dict_file for text processing",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--init",
|
||||||
|
type=lambda x: str_or_none(x.lower()),
|
||||||
|
default=None,
|
||||||
|
help="The initialization method",
|
||||||
|
choices=[
|
||||||
|
"chainer",
|
||||||
|
"xavier_uniform",
|
||||||
|
"xavier_normal",
|
||||||
|
"kaiming_uniform",
|
||||||
|
"kaiming_normal",
|
||||||
|
None,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--input_size",
|
||||||
|
type=int_or_none,
|
||||||
|
default=None,
|
||||||
|
help="The number of input dimension of the feature",
|
||||||
|
)
|
||||||
|
|
||||||
|
group = parser.add_argument_group(description="Preprocess related")
|
||||||
|
group.add_argument(
|
||||||
|
"--use_preprocessor",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Apply preprocessing to data or not",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--token_type",
|
||||||
|
type=str,
|
||||||
|
default="char",
|
||||||
|
choices=["char"],
|
||||||
|
help="The text will be tokenized in the specified level token",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--speech_volume_normalize",
|
||||||
|
type=float_or_none,
|
||||||
|
default=None,
|
||||||
|
help="Scale the maximum amplitude to the given value.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--rir_scp",
|
||||||
|
type=str_or_none,
|
||||||
|
default=None,
|
||||||
|
help="The file path of rir scp file.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--rir_apply_prob",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help="THe probability for applying RIR convolution.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--cmvn_file",
|
||||||
|
type=str_or_none,
|
||||||
|
default=None,
|
||||||
|
help="The file path of noise scp file.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--noise_scp",
|
||||||
|
type=str_or_none,
|
||||||
|
default=None,
|
||||||
|
help="The file path of noise scp file.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--noise_apply_prob",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help="The probability applying Noise adding.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--noise_db_range",
|
||||||
|
type=str,
|
||||||
|
default="13_15",
|
||||||
|
help="The range of noise decibel level.",
|
||||||
|
)
|
||||||
|
|
||||||
|
for class_choices in cls.class_choices_list:
|
||||||
|
# Append --<name> and --<name>_conf.
|
||||||
|
# e.g. --encoder and --encoder_conf
|
||||||
|
class_choices.add_arguments(group)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_collate_fn(
|
||||||
|
cls, args: argparse.Namespace, train: bool
|
||||||
|
) -> Callable[
|
||||||
|
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
|
||||||
|
Tuple[List[str], Dict[str, torch.Tensor]],
|
||||||
|
]:
|
||||||
|
assert check_argument_types()
|
||||||
|
# NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
|
||||||
|
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_preprocess_fn(
|
||||||
|
cls, args: argparse.Namespace, train: bool
|
||||||
|
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
|
||||||
|
assert check_argument_types()
|
||||||
|
if args.use_preprocessor:
|
||||||
|
retval = CommonPreprocessor(
|
||||||
|
train=train,
|
||||||
|
token_type=args.token_type,
|
||||||
|
token_list=args.token_list,
|
||||||
|
bpemodel=None,
|
||||||
|
non_linguistic_symbols=None,
|
||||||
|
text_cleaner=None,
|
||||||
|
g2p_type=None,
|
||||||
|
split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
|
||||||
|
seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
|
||||||
|
# NOTE(kamo): Check attribute existence for backward compatibility
|
||||||
|
rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
|
||||||
|
rir_apply_prob=args.rir_apply_prob
|
||||||
|
if hasattr(args, "rir_apply_prob")
|
||||||
|
else 1.0,
|
||||||
|
noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
|
||||||
|
noise_apply_prob=args.noise_apply_prob
|
||||||
|
if hasattr(args, "noise_apply_prob")
|
||||||
|
else 1.0,
|
||||||
|
noise_db_range=args.noise_db_range
|
||||||
|
if hasattr(args, "noise_db_range")
|
||||||
|
else "13_15",
|
||||||
|
speech_volume_normalize=args.speech_volume_normalize
|
||||||
|
if hasattr(args, "rir_scp")
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
retval = None
|
||||||
|
assert check_return_type(retval)
|
||||||
|
return retval
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def required_data_names(
|
||||||
|
cls, train: bool = True, inference: bool = False
|
||||||
|
) -> Tuple[str, ...]:
|
||||||
|
if not inference:
|
||||||
|
retval = ("speech", "profile", "binary_labels")
|
||||||
|
else:
|
||||||
|
# Recognition mode
|
||||||
|
retval = ("speech")
|
||||||
|
return retval
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def optional_data_names(
|
||||||
|
cls, train: bool = True, inference: bool = False
|
||||||
|
) -> Tuple[str, ...]:
|
||||||
|
retval = ()
|
||||||
|
assert check_return_type(retval)
|
||||||
|
return retval
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_model(cls, args: argparse.Namespace):
|
||||||
|
assert check_argument_types()
|
||||||
|
|
||||||
|
# 1. frontend
|
||||||
|
if args.input_size is None or args.frontend == "wav_frontend_mel23":
|
||||||
|
# Extract features in the model
|
||||||
|
frontend_class = frontend_choices.get_class(args.frontend)
|
||||||
|
if args.frontend == 'wav_frontend':
|
||||||
|
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
|
||||||
|
else:
|
||||||
|
frontend = frontend_class(**args.frontend_conf)
|
||||||
|
input_size = frontend.output_size()
|
||||||
|
else:
|
||||||
|
# Give features from data-loader
|
||||||
|
args.frontend = None
|
||||||
|
args.frontend_conf = {}
|
||||||
|
frontend = None
|
||||||
|
input_size = args.input_size
|
||||||
|
|
||||||
|
# 2. Encoder
|
||||||
|
encoder_class = encoder_choices.get_class(args.encoder)
|
||||||
|
encoder = encoder_class(input_size=input_size, **args.encoder_conf)
|
||||||
|
|
||||||
|
# 3. EncoderDecoderAttractor
|
||||||
|
encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
|
||||||
|
encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf)
|
||||||
|
|
||||||
|
# 9. Build model
|
||||||
|
model_class = model_choices.get_class(args.model)
|
||||||
|
model = model_class(
|
||||||
|
frontend=frontend,
|
||||||
|
encoder=encoder,
|
||||||
|
encoder_decoder_attractor=encoder_decoder_attractor,
|
||||||
|
**args.model_conf,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 10. Initialize
|
||||||
|
if args.init is not None:
|
||||||
|
initialize(model, args.init)
|
||||||
|
|
||||||
|
assert check_return_type(model)
|
||||||
|
return model
|
||||||
|
|
||||||
|
# ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
|
||||||
|
@classmethod
|
||||||
|
def build_model_from_file(
|
||||||
|
cls,
|
||||||
|
config_file: Union[Path, str] = None,
|
||||||
|
model_file: Union[Path, str] = None,
|
||||||
|
cmvn_file: Union[Path, str] = None,
|
||||||
|
device: str = "cpu",
|
||||||
|
):
|
||||||
|
"""Build model from the files.
|
||||||
|
|
||||||
|
This method is used for inference or fine-tuning.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_file: The yaml file saved when training.
|
||||||
|
model_file: The model file saved when training.
|
||||||
|
cmvn_file: The cmvn file for front-end
|
||||||
|
device: Device type, "cpu", "cuda", or "cuda:N".
|
||||||
|
|
||||||
|
"""
|
||||||
|
assert check_argument_types()
|
||||||
|
if config_file is None:
|
||||||
|
assert model_file is not None, (
|
||||||
|
"The argument 'model_file' must be provided "
|
||||||
|
"if the argument 'config_file' is not specified."
|
||||||
|
)
|
||||||
|
config_file = Path(model_file).parent / "config.yaml"
|
||||||
|
else:
|
||||||
|
config_file = Path(config_file)
|
||||||
|
|
||||||
|
with config_file.open("r", encoding="utf-8") as f:
|
||||||
|
args = yaml.safe_load(f)
|
||||||
|
args = argparse.Namespace(**args)
|
||||||
|
model = cls.build_model(args)
|
||||||
|
if not isinstance(model, AbsESPnetModel):
|
||||||
|
raise RuntimeError(
|
||||||
|
f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
|
||||||
|
)
|
||||||
|
if model_file is not None:
|
||||||
|
if device == "cuda":
|
||||||
|
device = f"cuda:{torch.cuda.current_device()}"
|
||||||
|
checkpoint = torch.load(model_file, map_location=device)
|
||||||
|
if "state_dict" in checkpoint.keys():
|
||||||
|
model.load_state_dict(checkpoint["state_dict"])
|
||||||
|
else:
|
||||||
|
model.load_state_dict(checkpoint)
|
||||||
|
model.to(device)
|
||||||
|
return model, args
|
||||||
|
|||||||
@ -5,55 +5,69 @@ import numpy as np
|
|||||||
from typing import Any, List, Tuple, Union
|
from typing import Any, List, Tuple, Union
|
||||||
|
|
||||||
|
|
||||||
def time_stamp_lfr6_pl(us_alphas, us_cif_peak, char_list, begin_time=0.0, end_time=None):
|
def ts_prediction_lfr6_standard(us_alphas,
|
||||||
|
us_peaks,
|
||||||
|
char_list,
|
||||||
|
vad_offset=0.0,
|
||||||
|
force_time_shift=-1.5
|
||||||
|
):
|
||||||
if not len(char_list):
|
if not len(char_list):
|
||||||
return []
|
return []
|
||||||
START_END_THRESHOLD = 5
|
START_END_THRESHOLD = 5
|
||||||
|
MAX_TOKEN_DURATION = 12
|
||||||
TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled
|
TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled
|
||||||
if len(us_alphas.shape) == 3:
|
if len(us_alphas.shape) == 2:
|
||||||
alphas, cif_peak = us_alphas[0], us_cif_peak[0] # support inference batch_size=1 only
|
_, peaks = us_alphas[0], us_peaks[0] # support inference batch_size=1 only
|
||||||
else:
|
else:
|
||||||
alphas, cif_peak = us_alphas, us_cif_peak
|
_, peaks = us_alphas, us_peaks
|
||||||
num_frames = cif_peak.shape[0]
|
num_frames = peaks.shape[0]
|
||||||
if char_list[-1] == '</s>':
|
if char_list[-1] == '</s>':
|
||||||
char_list = char_list[:-1]
|
char_list = char_list[:-1]
|
||||||
# char_list = [i for i in text]
|
|
||||||
timestamp_list = []
|
timestamp_list = []
|
||||||
|
new_char_list = []
|
||||||
# for bicif model trained with large data, cif2 actually fires when a character starts
|
# for bicif model trained with large data, cif2 actually fires when a character starts
|
||||||
# so treat the frames between two peaks as the duration of the former token
|
# so treat the frames between two peaks as the duration of the former token
|
||||||
fire_place = torch.where(cif_peak>1.0-1e-4)[0].cpu().numpy() - 1.5
|
fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset
|
||||||
num_peak = len(fire_place)
|
num_peak = len(fire_place)
|
||||||
assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
|
assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
|
||||||
# begin silence
|
# begin silence
|
||||||
if fire_place[0] > START_END_THRESHOLD:
|
if fire_place[0] > START_END_THRESHOLD:
|
||||||
char_list.insert(0, '<sil>')
|
# char_list.insert(0, '<sil>')
|
||||||
timestamp_list.append([0.0, fire_place[0]*TIME_RATE])
|
timestamp_list.append([0.0, fire_place[0]*TIME_RATE])
|
||||||
|
new_char_list.append('<sil>')
|
||||||
# tokens timestamp
|
# tokens timestamp
|
||||||
for i in range(len(fire_place)-1):
|
for i in range(len(fire_place)-1):
|
||||||
# the peak is always a little ahead of the start time
|
new_char_list.append(char_list[i])
|
||||||
# timestamp_list.append([(fire_place[i]-1.2)*TIME_RATE, fire_place[i+1]*TIME_RATE])
|
if MAX_TOKEN_DURATION < 0 or fire_place[i+1] - fire_place[i] <= MAX_TOKEN_DURATION:
|
||||||
timestamp_list.append([(fire_place[i])*TIME_RATE, fire_place[i+1]*TIME_RATE])
|
timestamp_list.append([fire_place[i]*TIME_RATE, fire_place[i+1]*TIME_RATE])
|
||||||
# cut the duration to token and sil of the 0-weight frames last long
|
else:
|
||||||
|
# cut the duration to token and sil of the 0-weight frames last long
|
||||||
|
_split = fire_place[i] + MAX_TOKEN_DURATION
|
||||||
|
timestamp_list.append([fire_place[i]*TIME_RATE, _split*TIME_RATE])
|
||||||
|
timestamp_list.append([_split*TIME_RATE, fire_place[i+1]*TIME_RATE])
|
||||||
|
new_char_list.append('<sil>')
|
||||||
# tail token and end silence
|
# tail token and end silence
|
||||||
|
# new_char_list.append(char_list[-1])
|
||||||
if num_frames - fire_place[-1] > START_END_THRESHOLD:
|
if num_frames - fire_place[-1] > START_END_THRESHOLD:
|
||||||
_end = (num_frames + fire_place[-1]) / 2
|
_end = (num_frames + fire_place[-1]) * 0.5
|
||||||
|
# _end = fire_place[-1]
|
||||||
timestamp_list[-1][1] = _end*TIME_RATE
|
timestamp_list[-1][1] = _end*TIME_RATE
|
||||||
timestamp_list.append([_end*TIME_RATE, num_frames*TIME_RATE])
|
timestamp_list.append([_end*TIME_RATE, num_frames*TIME_RATE])
|
||||||
char_list.append("<sil>")
|
new_char_list.append("<sil>")
|
||||||
else:
|
else:
|
||||||
timestamp_list[-1][1] = num_frames*TIME_RATE
|
timestamp_list[-1][1] = num_frames*TIME_RATE
|
||||||
if begin_time: # add offset time in model with vad
|
if vad_offset: # add offset time in model with vad
|
||||||
for i in range(len(timestamp_list)):
|
for i in range(len(timestamp_list)):
|
||||||
timestamp_list[i][0] = timestamp_list[i][0] + begin_time / 1000.0
|
timestamp_list[i][0] = timestamp_list[i][0] + vad_offset / 1000.0
|
||||||
timestamp_list[i][1] = timestamp_list[i][1] + begin_time / 1000.0
|
timestamp_list[i][1] = timestamp_list[i][1] + vad_offset / 1000.0
|
||||||
res_txt = ""
|
res_txt = ""
|
||||||
for char, timestamp in zip(char_list, timestamp_list):
|
for char, timestamp in zip(new_char_list, timestamp_list):
|
||||||
res_txt += "{} {} {};".format(char, timestamp[0], timestamp[1])
|
res_txt += "{} {} {};".format(char, str(timestamp[0]+0.0005)[:5], str(timestamp[1]+0.0005)[:5])
|
||||||
res = []
|
res = []
|
||||||
for char, timestamp in zip(char_list, timestamp_list):
|
for char, timestamp in zip(new_char_list, timestamp_list):
|
||||||
if char != '<sil>':
|
if char != '<sil>':
|
||||||
res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)])
|
res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)])
|
||||||
return res
|
return res_txt, res
|
||||||
|
|
||||||
|
|
||||||
def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed):
|
def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed):
|
||||||
|
|||||||
@ -451,8 +451,8 @@ class TestUniasrInferencePipelines(unittest.TestCase):
|
|||||||
|
|
||||||
def test_uniasr_2pass_zhcn_16k_common_vocab8358_offline(self):
|
def test_uniasr_2pass_zhcn_16k_common_vocab8358_offline(self):
|
||||||
inference_pipeline = pipeline(
|
inference_pipeline = pipeline(
|
||||||
task=Tasks.auto_speech_recognition,
|
task=Tasks.,
|
||||||
model='damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline')
|
model='damo/speech_UniASauto_speech_recognitionR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline')
|
||||||
rec_result = inference_pipeline(
|
rec_result = inference_pipeline(
|
||||||
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav',
|
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav',
|
||||||
param_dict={"decoding_model": "offline"})
|
param_dict={"decoding_model": "offline"})
|
||||||
|
|||||||
32
tests/test_asr_vad_punc_inference_pipeline.py
Normal file
32
tests/test_asr_vad_punc_inference_pipeline.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
from modelscope.pipelines import pipeline
|
||||||
|
from modelscope.utils.constant import Tasks
|
||||||
|
from modelscope.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
class TestParaformerInferencePipelines(unittest.TestCase):
|
||||||
|
def test_funasr_path(self):
|
||||||
|
import funasr
|
||||||
|
import os
|
||||||
|
logger.info("run_dir:{0} ; funasr_path: {1}".format(os.getcwd(), funasr.__file__))
|
||||||
|
|
||||||
|
def test_inference_pipeline(self):
|
||||||
|
inference_pipeline = pipeline(
|
||||||
|
task=Tasks.auto_speech_recognition,
|
||||||
|
model='damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
|
||||||
|
model_revision="v1.2.1",
|
||||||
|
vad_model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
|
||||||
|
vad_model_revision="v1.1.8",
|
||||||
|
punc_model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
|
||||||
|
punc_model_revision="v1.1.6",
|
||||||
|
ngpu=1,
|
||||||
|
)
|
||||||
|
rec_result = inference_pipeline(
|
||||||
|
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
|
||||||
|
logger.info("asr_vad_punc inference result: {0}".format(rec_result))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
25
tests/test_lm_pipeline.py
Normal file
25
tests/test_lm_pipeline.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
from modelscope.pipelines import pipeline
|
||||||
|
from modelscope.utils.constant import Tasks
|
||||||
|
from modelscope.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
class TestTransformerInferencePipelines(unittest.TestCase):
|
||||||
|
def test_funasr_path(self):
|
||||||
|
import funasr
|
||||||
|
import os
|
||||||
|
logger.info("run_dir:{0} ; funasr_path: {1}".format(os.getcwd(), funasr.__file__))
|
||||||
|
|
||||||
|
def test_inference_pipeline(self):
|
||||||
|
inference_pipeline = pipeline(
|
||||||
|
task=Tasks.language_score_prediction,
|
||||||
|
model='damo/speech_transformer_lm_zh-cn-common-vocab8404-pytorch',
|
||||||
|
)
|
||||||
|
rec_result = inference_pipeline(text_in="hello 大 家 好 呀")
|
||||||
|
logger.info("lm inference result: {0}".format(rec_result))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
43
tests/test_punctuation_pipeline.py
Normal file
43
tests/test_punctuation_pipeline.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
from modelscope.pipelines import pipeline
|
||||||
|
from modelscope.utils.constant import Tasks
|
||||||
|
from modelscope.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
class TestTransformerInferencePipelines(unittest.TestCase):
|
||||||
|
def test_funasr_path(self):
|
||||||
|
import funasr
|
||||||
|
import os
|
||||||
|
logger.info("run_dir:{0} ; funasr_path: {1}".format(os.getcwd(), funasr.__file__))
|
||||||
|
|
||||||
|
def test_inference_pipeline(self):
|
||||||
|
inference_pipeline = pipeline(
|
||||||
|
task=Tasks.punctuation,
|
||||||
|
model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
|
||||||
|
model_revision="v1.1.7",
|
||||||
|
)
|
||||||
|
inputs = "./egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/data/punc_example.txt"
|
||||||
|
rec_result = inference_pipeline(text_in=inputs)
|
||||||
|
logger.info("punctuation inference result: {0}".format(rec_result))
|
||||||
|
|
||||||
|
def test_vadrealtime_inference_pipeline(self):
|
||||||
|
inference_pipeline = pipeline(
|
||||||
|
task=Tasks.punctuation,
|
||||||
|
model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727',
|
||||||
|
model_revision="v1.0.0",
|
||||||
|
)
|
||||||
|
inputs = "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切|愿意进一步完善双方联合工作机制|凡是|中方能做的我们|都会去做而且会做得更好我请印度朋友们放心中国在上游的|任何开发利用都会经过科学|规划和论证兼顾上下游的利益"
|
||||||
|
vads = inputs.split("|")
|
||||||
|
cache_out = []
|
||||||
|
rec_result_all = "outputs:"
|
||||||
|
for vad in vads:
|
||||||
|
rec_result = inference_pipeline(text_in=vad, cache=cache_out)
|
||||||
|
cache_out = rec_result['cache']
|
||||||
|
rec_result_all += rec_result['text']
|
||||||
|
logger.info("punctuation inference result: {0}".format(rec_result_all))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
47
tests/test_sv_inference_pipeline.py
Normal file
47
tests/test_sv_inference_pipeline.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
from modelscope.pipelines import pipeline
|
||||||
|
from modelscope.utils.constant import Tasks
|
||||||
|
from modelscope.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class TestXVectorInferencePipelines(unittest.TestCase):
|
||||||
|
def test_funasr_path(self):
|
||||||
|
import funasr
|
||||||
|
import os
|
||||||
|
logger.info("run_dir:{0} ; funasr_path: {1}".format(os.getcwd(), funasr.__file__))
|
||||||
|
|
||||||
|
def test_inference_pipeline(self):
|
||||||
|
inference_sv_pipline = pipeline(
|
||||||
|
task=Tasks.speaker_verification,
|
||||||
|
model='damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch'
|
||||||
|
)
|
||||||
|
# 提取不同句子的说话人嵌入码
|
||||||
|
rec_result = inference_sv_pipline(
|
||||||
|
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav')
|
||||||
|
enroll = rec_result["spk_embedding"]
|
||||||
|
|
||||||
|
rec_result = inference_sv_pipline(
|
||||||
|
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_same.wav')
|
||||||
|
same = rec_result["spk_embedding"]
|
||||||
|
|
||||||
|
rec_result = inference_sv_pipline(
|
||||||
|
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_different.wav')
|
||||||
|
different = rec_result["spk_embedding"]
|
||||||
|
|
||||||
|
# 对相同的说话人计算余弦相似度
|
||||||
|
sv_threshold = 0.9465
|
||||||
|
same_cos = np.sum(enroll * same) / (np.linalg.norm(enroll) * np.linalg.norm(same))
|
||||||
|
same_cos = max(same_cos - sv_threshold, 0.0) / (1.0 - sv_threshold) * 100.0
|
||||||
|
logger.info("Similarity: {}".format(same_cos))
|
||||||
|
|
||||||
|
# 对不同的说话人计算余弦相似度
|
||||||
|
diff_cos = np.sum(enroll * different) / (np.linalg.norm(enroll) * np.linalg.norm(different))
|
||||||
|
diff_cos = max(diff_cos - sv_threshold, 0.0) / (1.0 - sv_threshold) * 100.0
|
||||||
|
logger.info("Similarity: {}".format(diff_cos))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
35
tests/test_vad_inference_pipeline.py
Normal file
35
tests/test_vad_inference_pipeline.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
from modelscope.pipelines import pipeline
|
||||||
|
from modelscope.utils.constant import Tasks
|
||||||
|
from modelscope.utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
class TestFSMNInferencePipelines(unittest.TestCase):
|
||||||
|
def test_funasr_path(self):
|
||||||
|
import funasr
|
||||||
|
import os
|
||||||
|
logger.info("run_dir:{0} ; funasr_path: {1}".format(os.getcwd(), funasr.__file__))
|
||||||
|
|
||||||
|
def test_8k(self):
|
||||||
|
inference_pipeline = pipeline(
|
||||||
|
task=Tasks.voice_activity_detection,
|
||||||
|
model="damo/speech_fsmn_vad_zh-cn-8k-common",
|
||||||
|
)
|
||||||
|
rec_result = inference_pipeline(
|
||||||
|
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example_8k.wav')
|
||||||
|
logger.info("vad inference result: {0}".format(rec_result))
|
||||||
|
|
||||||
|
def test_16k(self):
|
||||||
|
inference_pipeline = pipeline(
|
||||||
|
task=Tasks.voice_activity_detection,
|
||||||
|
model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||||
|
)
|
||||||
|
rec_result = inference_pipeline(
|
||||||
|
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav')
|
||||||
|
logger.info("vad inference result: {0}".format(rec_result))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
Loading…
Reference in New Issue
Block a user