mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
bugfix paraformer large long
This commit is contained in:
parent
5eb52f9c73
commit
3cde6fe8e0
@ -3,15 +3,14 @@ from modelscope.utils.constant import Tasks
|
||||
|
||||
if __name__ == '__main__':
|
||||
audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav'
|
||||
output_dir = None
|
||||
output_dir = "./results"
|
||||
inference_pipeline = pipeline(
|
||||
task=Tasks.auto_speech_recognition,
|
||||
model='damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
|
||||
vad_model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
|
||||
punc_model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
|
||||
output_dir=output_dir,
|
||||
batch_size=64,
|
||||
)
|
||||
rec_result = inference_pipeline(audio_in=audio_in)
|
||||
rec_result = inference_pipeline(audio_in=audio_in, batch_size_token=5000)
|
||||
print(rec_result)
|
||||
|
||||
|
||||
@ -600,6 +600,9 @@ def inference_paraformer_vad_punc(
|
||||
if 'hotword' in kwargs:
|
||||
hotword_list_or_file = kwargs['hotword']
|
||||
|
||||
batch_size_token = kwargs.get("batch_size_token", 6000)
|
||||
print("batch_size_token: ", batch_size_token)
|
||||
|
||||
if speech2text.hotword_list is None:
|
||||
speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
|
||||
|
||||
@ -642,8 +645,10 @@ def inference_paraformer_vad_punc(
|
||||
assert all(isinstance(s, str) for s in keys), keys
|
||||
_bs = len(next(iter(batch.values())))
|
||||
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
|
||||
|
||||
beg_vad = time.time()
|
||||
vad_results = speech2vadsegment(**batch)
|
||||
end_vad = time.time()
|
||||
print("time cost vad: ", end_vad-beg_vad)
|
||||
_, vadsegments = vad_results[0], vad_results[1][0]
|
||||
|
||||
speech, speech_lengths = batch["speech"], batch["speech_lengths"]
|
||||
@ -652,17 +657,29 @@ def inference_paraformer_vad_punc(
|
||||
data_with_index = [(vadsegments[i], i) for i in range(n)]
|
||||
sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
|
||||
results_sorted = []
|
||||
for j, beg_idx in enumerate(range(0, n, batch_size)):
|
||||
end_idx = min(n, beg_idx + batch_size)
|
||||
batch_size_token_ms = batch_size_token*60
|
||||
batch_size_token_ms_cum = 0
|
||||
beg_idx = 0
|
||||
for j, _ in enumerate(range(0, n)):
|
||||
batch_size_token_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
|
||||
if j < n-1 and (batch_size_token_ms_cum + sorted_data[j+1][0][1] - sorted_data[j+1][0][0])<batch_size_token_ms:
|
||||
continue
|
||||
batch_size_token_ms_cum = 0
|
||||
end_idx = j + 1
|
||||
speech_j, speech_lengths_j = slice_padding_fbank(speech, speech_lengths, sorted_data[beg_idx:end_idx])
|
||||
|
||||
beg_idx = end_idx
|
||||
batch = {"speech": speech_j, "speech_lengths": speech_lengths_j}
|
||||
batch = to_device(batch, device=device)
|
||||
print("batch: ", speech_j.shape[0])
|
||||
beg_asr = time.time()
|
||||
results = speech2text(**batch)
|
||||
end_asr = time.time()
|
||||
print("time cost asr: ", end_asr - beg_asr)
|
||||
|
||||
if len(results) < 1:
|
||||
results = [["", [], [], [], [], [], []]]
|
||||
results_sorted.extend(results)
|
||||
|
||||
restored_data = [0] * n
|
||||
for j in range(n):
|
||||
index = sorted_data[j][1]
|
||||
@ -701,7 +718,10 @@ def inference_paraformer_vad_punc(
|
||||
text_postprocessed_punc = text_postprocessed
|
||||
punc_id_list = []
|
||||
if len(word_lists) > 0 and text2punc is not None:
|
||||
beg_punc = time.time()
|
||||
text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
|
||||
end_punc = time.time()
|
||||
print("time cost punc: ", end_punc-beg_punc)
|
||||
|
||||
item = {'key': key, 'value': text_postprocessed_punc}
|
||||
if text_postprocessed != "":
|
||||
|
||||
@ -1 +1 @@
|
||||
0.5.4
|
||||
0.5.5
|
||||
|
||||
Loading…
Reference in New Issue
Block a user