mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
vad + asr
This commit is contained in:
parent
a1b0cd33d5
commit
5a8f379084
@ -0,0 +1,31 @@
|
||||
|
||||
cmd="funasr/bin/inference.py"
|
||||
|
||||
python $cmd \
|
||||
+model="/Users/zhifu/Downloads/modelscope_models/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
|
||||
+vad_model="/Users/zhifu/Downloads/modelscope_models/speech_fsmn_vad_zh-cn-16k-common-pytorch" \
|
||||
+input="/Users/zhifu/funasr_github/test_local/vad_example.wav" \
|
||||
+output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2" \
|
||||
+device="cpu" \
|
||||
+batch_size_s=300 \
|
||||
+batch_size_threshold_s=60 \
|
||||
+debug="true"
|
||||
|
||||
#python $cmd \
|
||||
#+model="/Users/zhifu/Downloads/modelscope_models/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" \
|
||||
#+input="/Users/zhifu/Downloads/asr_example.wav" \
|
||||
#+output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2" \
|
||||
#+device="cpu" \
|
||||
#+"hotword='达魔院 魔搭'"
|
||||
|
||||
#+input="/Users/zhifu/funasr_github/test_local/wav.scp"
|
||||
#+input="/Users/zhifu/funasr_github/test_local/asr_example.wav" \
|
||||
#+input="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len.jsonl" \
|
||||
#+input="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \
|
||||
#+model="/Users/zhifu/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
|
||||
|
||||
#+model="/Users/zhifu/modelscope_models/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
|
||||
#+model="/Users/zhifu/modelscope_models/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" \
|
||||
#+"hotword='达魔院 魔搭'"
|
||||
|
||||
#+vad_model="/Users/zhifu/Downloads/modelscope_models/speech_fsmn_vad_zh-cn-16k-common-pytorch" \
|
||||
@ -16,7 +16,8 @@ import time
|
||||
import random
|
||||
import string
|
||||
from funasr.register import tables
|
||||
|
||||
from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio
|
||||
from funasr.utils.vad_utils import slice_padding_audio_samples
|
||||
|
||||
def build_iter_for_infer(data_in, input_len=None, data_type="sound"):
|
||||
"""
|
||||
@ -73,15 +74,44 @@ def main_hydra(kwargs: DictConfig):
|
||||
|
||||
logging.basicConfig(level=log_level)
|
||||
|
||||
import pdb;
|
||||
pdb.set_trace()
|
||||
if kwargs.get("debug", False):
|
||||
import pdb; pdb.set_trace()
|
||||
model = AutoModel(**kwargs)
|
||||
res = model.generate(input=kwargs["input"])
|
||||
res = model(input=kwargs["input"])
|
||||
print(res)
|
||||
|
||||
class AutoModel:
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
tables.print()
|
||||
|
||||
model, kwargs = self.build_model(**kwargs)
|
||||
|
||||
# if vad_model is not None, build vad model else None
|
||||
vad_model = kwargs.get("vad_model", None)
|
||||
vad_kwargs = kwargs.get("vad_model_revision", None)
|
||||
if vad_model is not None:
|
||||
print("build vad model")
|
||||
vad_kwargs = {"model": vad_model, "model_revision": vad_kwargs}
|
||||
vad_model, vad_kwargs = self.build_model(**vad_kwargs)
|
||||
|
||||
# if punc_model is not None, build punc model else None
|
||||
punc_model = kwargs.get("punc_model", None)
|
||||
punc_kwargs = kwargs.get("punc_model_revision", None)
|
||||
if punc_model is not None:
|
||||
punc_kwargs = {"model": punc_model, "model_revision": punc_kwargs}
|
||||
punc_model, punc_kwargs = self.build_model(**punc_kwargs)
|
||||
|
||||
self.kwargs = kwargs
|
||||
self.model = model
|
||||
self.vad_model = vad_model
|
||||
self.vad_kwargs = vad_kwargs
|
||||
self.punc_model = punc_model
|
||||
self.punc_kwargs = punc_kwargs
|
||||
|
||||
|
||||
|
||||
def build_model(self, **kwargs):
|
||||
assert "model" in kwargs
|
||||
if "model_conf" not in kwargs:
|
||||
logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
|
||||
@ -94,7 +124,7 @@ class AutoModel:
|
||||
device = "cpu"
|
||||
kwargs["batch_size"] = 1
|
||||
kwargs["device"] = device
|
||||
|
||||
|
||||
# build tokenizer
|
||||
tokenizer = kwargs.get("tokenizer", None)
|
||||
if tokenizer is not None:
|
||||
@ -113,7 +143,8 @@ class AutoModel:
|
||||
|
||||
# build model
|
||||
model_class = tables.model_classes.get(kwargs["model"].lower())
|
||||
model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list) if tokenizer is not None else -1)
|
||||
model = model_class(**kwargs, **kwargs["model_conf"],
|
||||
vocab_size=len(tokenizer.token_list) if tokenizer is not None else -1)
|
||||
model.eval()
|
||||
model.to(device)
|
||||
|
||||
@ -127,23 +158,34 @@ class AutoModel:
|
||||
ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
|
||||
oss_bucket=kwargs.get("oss_bucket", None),
|
||||
)
|
||||
self.kwargs = kwargs
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
return model, kwargs
|
||||
|
||||
def generate(self, input, input_len=None, **cfg):
|
||||
self.kwargs.update(cfg)
|
||||
data_type = self.kwargs.get("data_type", "sound")
|
||||
batch_size = self.kwargs.get("batch_size", 1)
|
||||
if self.kwargs.get("device", "cpu") == "cpu":
|
||||
batch_size = 1
|
||||
def __call__(self, input, input_len=None, **cfg):
|
||||
if self.vad_model is None:
|
||||
return self.generate(input, input_len=input_len, **cfg)
|
||||
|
||||
else:
|
||||
return self.generate_with_vad(input, input_len=input_len, **cfg)
|
||||
|
||||
def generate(self, input, input_len=None, model=None, kwargs=None, **cfg):
|
||||
kwargs = self.kwargs if kwargs is None else kwargs
|
||||
kwargs.update(cfg)
|
||||
model = self.model if model is None else model
|
||||
|
||||
data_type = kwargs.get("data_type", "sound")
|
||||
batch_size = kwargs.get("batch_size", 1)
|
||||
# if kwargs.get("device", "cpu") == "cpu":
|
||||
# batch_size = 1
|
||||
|
||||
key_list, data_list = build_iter_for_infer(input, input_len=input_len, data_type=data_type)
|
||||
|
||||
speed_stats = {}
|
||||
asr_result_list = []
|
||||
num_samples = len(data_list)
|
||||
pbar = tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
|
||||
pbar = tqdm(colour="blue", total=num_samples+1, dynamic_ncols=True)
|
||||
time_speech_total = 0.0
|
||||
time_escape_total = 0.0
|
||||
for beg_idx in range(0, num_samples, batch_size):
|
||||
end_idx = min(num_samples, beg_idx + batch_size)
|
||||
data_batch = data_list[beg_idx:end_idx]
|
||||
@ -154,25 +196,139 @@ class AutoModel:
|
||||
batch["data_lengths"] = input_len
|
||||
|
||||
time1 = time.perf_counter()
|
||||
results, meta_data = self.model.generate(**batch, **self.kwargs)
|
||||
results, meta_data = model.generate(**batch, **kwargs)
|
||||
time2 = time.perf_counter()
|
||||
|
||||
asr_result_list.append(results)
|
||||
asr_result_list.extend(results)
|
||||
pbar.update(1)
|
||||
|
||||
# batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
|
||||
batch_data_time = meta_data.get("batch_data_time", -1)
|
||||
time_escape = time2 - time1
|
||||
speed_stats["load_data"] = meta_data.get("load_data", 0.0)
|
||||
speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
|
||||
speed_stats["forward"] = f"{time2 - time1:0.3f}"
|
||||
speed_stats["rtf"] = f"{(time2 - time1) / batch_data_time:0.3f}"
|
||||
speed_stats["forward"] = f"{time_escape:0.3f}"
|
||||
speed_stats["batch_size"] = f"{len(results)}"
|
||||
speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
|
||||
description = (
|
||||
f"{speed_stats}, "
|
||||
)
|
||||
pbar.set_description(description)
|
||||
|
||||
time_speech_total += batch_data_time
|
||||
time_escape_total += time_escape
|
||||
|
||||
pbar.update(1)
|
||||
pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
|
||||
torch.cuda.empty_cache()
|
||||
return asr_result_list
|
||||
|
||||
def generate_with_vad(self, input, input_len=None, **cfg):
|
||||
|
||||
# step.1: compute the vad model
|
||||
model = self.vad_model
|
||||
kwargs = self.vad_kwargs
|
||||
beg_vad = time.time()
|
||||
res = self.generate(input, input_len=input_len, model=model, kwargs=kwargs, **cfg)
|
||||
end_vad = time.time()
|
||||
print(f"time cost vad: {end_vad - beg_vad:0.3f}")
|
||||
|
||||
|
||||
# step.2 compute asr model
|
||||
model = self.model
|
||||
kwargs = self.kwargs
|
||||
kwargs.update(cfg)
|
||||
batch_size = int(kwargs.get("batch_size_s", 300))*1000
|
||||
batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60))*1000
|
||||
kwargs["batch_size"] = batch_size
|
||||
data_type = kwargs.get("data_type", "sound")
|
||||
key_list, data_list = build_iter_for_infer(input, input_len=input_len, data_type=data_type)
|
||||
results_ret_list = []
|
||||
time_speech_total_all_samples = 0.0
|
||||
|
||||
beg_total = time.time()
|
||||
pbar_total = tqdm(colour="red", total=len(res) + 1, dynamic_ncols=True)
|
||||
for i in range(len(res)):
|
||||
key = res[i]["key"]
|
||||
vadsegments = res[i]["value"]
|
||||
input_i = data_list[i]
|
||||
speech = load_audio(input_i, fs=kwargs["frontend"].fs, audio_fs=kwargs.get("fs", 16000))
|
||||
speech_lengths = len(speech)
|
||||
n = len(vadsegments)
|
||||
data_with_index = [(vadsegments[i], i) for i in range(n)]
|
||||
sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
|
||||
results_sorted = []
|
||||
|
||||
if not len(sorted_data):
|
||||
logging.info("decoding, utt: {}, empty speech".format(key))
|
||||
continue
|
||||
|
||||
|
||||
# if kwargs["device"] == "cpu":
|
||||
# batch_size = 0
|
||||
if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
|
||||
batch_size = max(batch_size, sorted_data[0][0][1] - sorted_data[0][0][0])
|
||||
|
||||
batch_size_ms_cum = 0
|
||||
beg_idx = 0
|
||||
beg_asr_total = time.time()
|
||||
time_speech_total_per_sample = speech_lengths/16000
|
||||
time_speech_total_all_samples += time_speech_total_per_sample
|
||||
|
||||
for j, _ in enumerate(range(0, n)):
|
||||
batch_size_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
|
||||
if j < n - 1 and (
|
||||
batch_size_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size and (
|
||||
sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_threshold_ms:
|
||||
continue
|
||||
batch_size_ms_cum = 0
|
||||
end_idx = j + 1
|
||||
speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx])
|
||||
beg_idx = end_idx
|
||||
|
||||
results = self.generate(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg)
|
||||
|
||||
if len(results) < 1:
|
||||
continue
|
||||
results_sorted.extend(results)
|
||||
|
||||
|
||||
pbar_total.update(1)
|
||||
end_asr_total = time.time()
|
||||
time_escape_total_per_sample = end_asr_total - beg_asr_total
|
||||
pbar_total.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
|
||||
f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
|
||||
f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
|
||||
|
||||
restored_data = [0] * n
|
||||
for j in range(n):
|
||||
index = sorted_data[j][1]
|
||||
restored_data[index] = results_sorted[j]
|
||||
result = {}
|
||||
|
||||
for j in range(n):
|
||||
for k, v in restored_data[j].items():
|
||||
if not k.startswith("timestamp"):
|
||||
if k not in result:
|
||||
result[k] = restored_data[j][k]
|
||||
else:
|
||||
result[k] += restored_data[j][k]
|
||||
else:
|
||||
result[k] = []
|
||||
for t in restored_data[j][k]:
|
||||
t[0] += vadsegments[j][0]
|
||||
t[1] += vadsegments[j][0]
|
||||
result[k] += restored_data[j][k]
|
||||
|
||||
result["key"] = key
|
||||
results_ret_list.append(result)
|
||||
pbar_total.update(1)
|
||||
pbar_total.update(1)
|
||||
end_total = time.time()
|
||||
time_escape_total_all_samples = end_total - beg_total
|
||||
pbar_total.set_description(f"rtf_avg_all_samples: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
|
||||
f"time_speech_total_all_samples: {time_speech_total_all_samples: 0.3f}, "
|
||||
f"time_escape_total_all_samples: {time_escape_total_all_samples:0.3f}")
|
||||
return results_ret_list
|
||||
|
||||
if __name__ == '__main__':
|
||||
main_hydra()
|
||||
@ -25,7 +25,9 @@ from funasr.register import tables
|
||||
|
||||
@hydra.main(config_name=None, version_base=None)
|
||||
def main_hydra(kwargs: DictConfig):
|
||||
import pdb; pdb.set_trace()
|
||||
if kwargs.get("debug", False):
|
||||
import pdb; pdb.set_trace()
|
||||
|
||||
assert "model" in kwargs
|
||||
if "model_conf" not in kwargs:
|
||||
logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
|
||||
|
||||
@ -24,11 +24,10 @@ def download_fr_ms(**kwargs):
|
||||
kwargs["init_param"] = init_param
|
||||
if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
|
||||
kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
|
||||
if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
|
||||
if os.path.exists(os.path.join(model_or_path, "seg_dict")):
|
||||
kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
|
||||
if os.path.exists(os.path.join(model_or_path, "bpe.model")):
|
||||
kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
|
||||
|
||||
kwargs["model"] = cfg["model"]
|
||||
kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
|
||||
|
||||
|
||||
@ -29,6 +29,7 @@ from funasr.utils.datadir_writer import DatadirWriter
|
||||
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
|
||||
from funasr.register import tables
|
||||
from funasr.models.ctc.ctc import CTC
|
||||
from funasr.utils.timestamp_tools import time_stamp_sentence
|
||||
|
||||
from funasr.models.paraformer.model import Paraformer
|
||||
|
||||
@ -211,10 +212,11 @@ class BiCifParaformer(Paraformer):
|
||||
|
||||
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
||||
return loss, stats, weight
|
||||
|
||||
|
||||
|
||||
def generate(self,
|
||||
data_in: list,
|
||||
data_lengths: list = None,
|
||||
data_in,
|
||||
data_lengths=None,
|
||||
key: list = None,
|
||||
tokenizer=None,
|
||||
frontend=None,
|
||||
@ -230,17 +232,23 @@ class BiCifParaformer(Paraformer):
|
||||
self.nbest = kwargs.get("nbest", 1)
|
||||
|
||||
meta_data = {}
|
||||
# extract fbank feats
|
||||
time1 = time.perf_counter()
|
||||
audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
|
||||
time2 = time.perf_counter()
|
||||
meta_data["load_data"] = f"{time2 - time1:0.3f}"
|
||||
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
|
||||
frontend=self.frontend)
|
||||
time3 = time.perf_counter()
|
||||
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
|
||||
meta_data[
|
||||
"batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
|
||||
if isinstance(data_in, torch.Tensor): # fbank
|
||||
speech, speech_lengths = data_in, data_lengths
|
||||
if len(speech.shape) < 3:
|
||||
speech = speech[None, :, :]
|
||||
if speech_lengths is None:
|
||||
speech_lengths = speech.shape[1]
|
||||
else:
|
||||
# extract fbank feats
|
||||
time1 = time.perf_counter()
|
||||
audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
|
||||
time2 = time.perf_counter()
|
||||
meta_data["load_data"] = f"{time2 - time1:0.3f}"
|
||||
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
|
||||
frontend=frontend)
|
||||
time3 = time.perf_counter()
|
||||
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
|
||||
meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
|
||||
|
||||
speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
|
||||
|
||||
@ -261,9 +269,8 @@ class BiCifParaformer(Paraformer):
|
||||
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
|
||||
|
||||
# BiCifParaformer, test no bias cif2
|
||||
|
||||
_, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens,
|
||||
pre_token_length)
|
||||
pre_token_length)
|
||||
|
||||
results = []
|
||||
b, n, d = decoder_out.size()
|
||||
@ -302,27 +309,32 @@ class BiCifParaformer(Paraformer):
|
||||
# remove blank symbol id, which is assumed to be 0
|
||||
token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
|
||||
|
||||
# Change integer-ids to tokens
|
||||
token = tokenizer.ids2tokens(token_int)
|
||||
text = tokenizer.tokens2text(token)
|
||||
|
||||
_, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3],
|
||||
us_peaks[i][:encoder_out_lens[i] * 3],
|
||||
copy.copy(token),
|
||||
vad_offset=kwargs.get("begin_time", 0))
|
||||
|
||||
text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token, timestamp)
|
||||
|
||||
result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed,
|
||||
"time_stamp_postprocessed": time_stamp_postprocessed,
|
||||
"word_lists": word_lists
|
||||
}
|
||||
results.append(result_i)
|
||||
|
||||
if ibest_writer is not None:
|
||||
ibest_writer["token"][key[i]] = " ".join(token)
|
||||
ibest_writer["text"][key[i]] = text
|
||||
ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
|
||||
if tokenizer is not None:
|
||||
# Change integer-ids to tokens
|
||||
token = tokenizer.ids2tokens(token_int)
|
||||
text = tokenizer.tokens2text(token)
|
||||
|
||||
_, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3],
|
||||
us_peaks[i][:encoder_out_lens[i] * 3],
|
||||
copy.copy(token),
|
||||
vad_offset=kwargs.get("begin_time", 0))
|
||||
|
||||
text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(
|
||||
token, timestamp)
|
||||
sentences = time_stamp_sentence(None, time_stamp_postprocessed, text_postprocessed)
|
||||
result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed,
|
||||
"timestamp": time_stamp_postprocessed,
|
||||
"word_lists": word_lists,
|
||||
"sentences": sentences
|
||||
}
|
||||
|
||||
if ibest_writer is not None:
|
||||
ibest_writer["token"][key[i]] = " ".join(token)
|
||||
ibest_writer["text"][key[i]] = text
|
||||
ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed
|
||||
ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
|
||||
else:
|
||||
result_i = {"key": key[i], "token_int": token_int}
|
||||
results.append(result_i)
|
||||
|
||||
return results, meta_data
|
||||
return results, meta_data
|
||||
134
funasr/models/bici_paraformer/template.yaml
Normal file
134
funasr/models/bici_paraformer/template.yaml
Normal file
@ -0,0 +1,134 @@
|
||||
# This is an example that demonstrates how to configure a model file.
|
||||
# You can modify the configuration according to your own requirements.
|
||||
|
||||
# to print the register_table:
|
||||
# from funasr.register import tables
|
||||
# tables.print()
|
||||
|
||||
# network architecture
|
||||
#model: funasr.models.paraformer.model:Paraformer
|
||||
model: BiCifParaformer
|
||||
model_conf:
|
||||
ctc_weight: 0.0
|
||||
lsm_weight: 0.1
|
||||
length_normalized_loss: true
|
||||
predictor_weight: 1.0
|
||||
predictor_bias: 1
|
||||
sampling_ratio: 0.75
|
||||
|
||||
# encoder
|
||||
encoder: SANMEncoder
|
||||
encoder_conf:
|
||||
output_size: 512
|
||||
attention_heads: 4
|
||||
linear_units: 2048
|
||||
num_blocks: 50
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.1
|
||||
input_layer: pe
|
||||
pos_enc_class: SinusoidalPositionEncoder
|
||||
normalize_before: true
|
||||
kernel_size: 11
|
||||
sanm_shfit: 0
|
||||
selfattention_layer_type: sanm
|
||||
|
||||
# decoder
|
||||
decoder: ParaformerSANMDecoder
|
||||
decoder_conf:
|
||||
attention_heads: 4
|
||||
linear_units: 2048
|
||||
num_blocks: 16
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
self_attention_dropout_rate: 0.1
|
||||
src_attention_dropout_rate: 0.1
|
||||
att_layer_num: 16
|
||||
kernel_size: 11
|
||||
sanm_shfit: 0
|
||||
|
||||
predictor: CifPredictorV3
|
||||
predictor_conf:
|
||||
idim: 512
|
||||
threshold: 1.0
|
||||
l_order: 1
|
||||
r_order: 1
|
||||
tail_threshold: 0.45
|
||||
smooth_factor2: 0.25
|
||||
noise_threshold2: 0.01
|
||||
upsample_times: 3
|
||||
use_cif1_cnn: false
|
||||
upsample_type: cnn_blstm
|
||||
|
||||
# frontend related
|
||||
frontend: WavFrontend
|
||||
frontend_conf:
|
||||
fs: 16000
|
||||
window: hamming
|
||||
n_mels: 80
|
||||
frame_length: 25
|
||||
frame_shift: 10
|
||||
lfr_m: 7
|
||||
lfr_n: 6
|
||||
|
||||
specaug: SpecAugLFR
|
||||
specaug_conf:
|
||||
apply_time_warp: false
|
||||
time_warp_window: 5
|
||||
time_warp_mode: bicubic
|
||||
apply_freq_mask: true
|
||||
freq_mask_width_range:
|
||||
- 0
|
||||
- 30
|
||||
lfr_rate: 6
|
||||
num_freq_mask: 1
|
||||
apply_time_mask: true
|
||||
time_mask_width_range:
|
||||
- 0
|
||||
- 12
|
||||
num_time_mask: 1
|
||||
|
||||
train_conf:
|
||||
accum_grad: 1
|
||||
grad_clip: 5
|
||||
max_epoch: 150
|
||||
val_scheduler_criterion:
|
||||
- valid
|
||||
- acc
|
||||
best_model_criterion:
|
||||
- - valid
|
||||
- acc
|
||||
- max
|
||||
keep_nbest_models: 10
|
||||
log_interval: 50
|
||||
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.0005
|
||||
scheduler: warmuplr
|
||||
scheduler_conf:
|
||||
warmup_steps: 30000
|
||||
|
||||
dataset: AudioDataset
|
||||
dataset_conf:
|
||||
index_ds: IndexDSJsonl
|
||||
batch_sampler: DynamicBatchLocalShuffleSampler
|
||||
batch_type: example # example or length
|
||||
batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
|
||||
max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
|
||||
buffer_size: 500
|
||||
shuffle: True
|
||||
num_workers: 0
|
||||
|
||||
tokenizer: CharTokenizer
|
||||
tokenizer_conf:
|
||||
unk_symbol: <unk>
|
||||
split_with_space: true
|
||||
|
||||
|
||||
ctc_conf:
|
||||
dropout_rate: 0.0
|
||||
ctc_type: builtin
|
||||
reduce: true
|
||||
ignore_nan_grad: true
|
||||
normalize: null
|
||||
@ -857,7 +857,7 @@ class BiCifParaformer(Paraformer):
|
||||
return results, meta_data
|
||||
|
||||
|
||||
class ParaformerOnline(Paraformer):
|
||||
class ParaformerStreaming(Paraformer):
|
||||
"""
|
||||
Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
|
||||
@ -15,4 +15,17 @@ def slice_padding_fbank(speech, speech_lengths, vad_segments):
|
||||
feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0)
|
||||
speech_lengths_pad = torch.Tensor(speech_lengths_list).int()
|
||||
return feats_pad, speech_lengths_pad
|
||||
|
||||
|
||||
|
||||
def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
|
||||
speech_list = []
|
||||
speech_lengths_list = []
|
||||
for i, segment in enumerate(vad_segments):
|
||||
bed_idx = int(segment[0][0] * 16)
|
||||
end_idx = min(int(segment[0][1] * 16), speech_lengths)
|
||||
speech_i = speech[bed_idx: end_idx]
|
||||
speech_lengths_i = end_idx - bed_idx
|
||||
speech_list.append(speech_i)
|
||||
speech_lengths_list.append(speech_lengths_i)
|
||||
|
||||
return speech_list, speech_lengths_list
|
||||
Loading…
Reference in New Issue
Block a user