Bugfix: Paraformer-large long-audio inference (token-budget batching of VAD segments)

This commit is contained in:
游雁 2023-05-22 16:25:48 +08:00
parent 5eb52f9c73
commit 3cde6fe8e0
3 changed files with 27 additions and 8 deletions

View File

@ -3,15 +3,14 @@ from modelscope.utils.constant import Tasks
if __name__ == '__main__':
    # Demo: long-audio ASR with Paraformer-large, using a VAD model to split
    # the audio into segments and a punctuation model to restore punctuation.
    # NOTE(review): diff rendering had left both the pre- and post-change lines
    # in this script (duplicate `output_dir` and `inference_pipeline(...)`
    # statements) and stripped the indentation; this is the reconstructed
    # post-commit version.
    audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav'
    output_dir = "./results"
    inference_pipeline = pipeline(
        task=Tasks.auto_speech_recognition,
        model='damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
        vad_model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
        punc_model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
        output_dir=output_dir,
        batch_size=64,
    )
    # batch_size_token caps the accumulated segment length (in token units)
    # per ASR batch; the backend defaults to 6000 when it is omitted.
    rec_result = inference_pipeline(audio_in=audio_in, batch_size_token=5000)
    print(rec_result)

View File

@ -600,6 +600,9 @@ def inference_paraformer_vad_punc(
if 'hotword' in kwargs:
hotword_list_or_file = kwargs['hotword']
batch_size_token = kwargs.get("batch_size_token", 6000)
print("batch_size_token: ", batch_size_token)
if speech2text.hotword_list is None:
speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
@ -642,8 +645,10 @@ def inference_paraformer_vad_punc(
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
beg_vad = time.time()
vad_results = speech2vadsegment(**batch)
end_vad = time.time()
print("time cost vad: ", end_vad-beg_vad)
_, vadsegments = vad_results[0], vad_results[1][0]
speech, speech_lengths = batch["speech"], batch["speech_lengths"]
@ -652,17 +657,29 @@ def inference_paraformer_vad_punc(
data_with_index = [(vadsegments[i], i) for i in range(n)]
sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
results_sorted = []
for j, beg_idx in enumerate(range(0, n, batch_size)):
end_idx = min(n, beg_idx + batch_size)
batch_size_token_ms = batch_size_token*60
batch_size_token_ms_cum = 0
beg_idx = 0
for j, _ in enumerate(range(0, n)):
batch_size_token_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
if j < n-1 and (batch_size_token_ms_cum + sorted_data[j+1][0][1] - sorted_data[j+1][0][0])<batch_size_token_ms:
continue
batch_size_token_ms_cum = 0
end_idx = j + 1
speech_j, speech_lengths_j = slice_padding_fbank(speech, speech_lengths, sorted_data[beg_idx:end_idx])
beg_idx = end_idx
batch = {"speech": speech_j, "speech_lengths": speech_lengths_j}
batch = to_device(batch, device=device)
print("batch: ", speech_j.shape[0])
beg_asr = time.time()
results = speech2text(**batch)
end_asr = time.time()
print("time cost asr: ", end_asr - beg_asr)
if len(results) < 1:
results = [["", [], [], [], [], [], []]]
results_sorted.extend(results)
restored_data = [0] * n
for j in range(n):
index = sorted_data[j][1]
@ -701,7 +718,10 @@ def inference_paraformer_vad_punc(
text_postprocessed_punc = text_postprocessed
punc_id_list = []
if len(word_lists) > 0 and text2punc is not None:
beg_punc = time.time()
text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
end_punc = time.time()
print("time cost punc: ", end_punc-beg_punc)
item = {'key': key, 'value': text_postprocessed_punc}
if text_postprocessed != "":

View File

@ -1 +1 @@
0.5.4
0.5.5