paraformer vad punc

This commit is contained in:
游雁 2023-05-11 19:35:10 +08:00
parent 9d6aad2a44
commit 475064f914
4 changed files with 13 additions and 21 deletions

View File

@ -4,8 +4,8 @@ from modelscope.utils.constant import Tasks
inference_pipeline = pipeline(
task=Tasks.auto_speech_recognition,
model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
vad_model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
batch_size=64,
)
audio_in = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav'
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
rec_result = inference_pipeline(audio_in=audio_in)
print(rec_result)

View File

@ -10,7 +10,7 @@ if __name__ == '__main__':
vad_model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
punc_model='damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',
output_dir=output_dir,
batch_size=8,
batch_size=64,
)
rec_result = inference_pipeline(audio_in=audio_in)
print(rec_result)

View File

@ -291,11 +291,11 @@ def inference_launch_funasr(**kwargs):
elif mode == "paraformer":
from funasr.bin.asr_inference_paraformer import inference_modelscope
inference_pipeline = inference_modelscope(**kwargs)
return inference_pipeline(kwargs["data_path_and_name_and_type"])
return inference_pipeline(kwargs["data_path_and_name_and_type"], hotword=kwargs.get("hotword", None))
elif mode.startswith("paraformer_vad"):
from funasr.bin.asr_inference_paraformer import inference_modelscope_vad_punc
inference_pipeline = inference_modelscope_vad_punc(**kwargs)
return inference_pipeline(kwargs["data_path_and_name_and_type"])
return inference_pipeline(kwargs["data_path_and_name_and_type"], hotword=kwargs.get("hotword", None))
elif mode == "mfcca":
from funasr.bin.asr_inference_mfcca import inference_modelscope
return inference_modelscope(**kwargs)

View File

@ -48,6 +48,8 @@ from funasr.bin.tp_inference import SpeechText2Timestamp
from funasr.bin.vad_inference import Speech2VadSegment
from funasr.bin.punctuation_infer import Text2Punc
from funasr.utils.vad_utils import slice_padding_fbank
from funasr.tasks.vad import VADTask
from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
class Speech2Text:
"""Speech2Text class
@ -293,15 +295,14 @@ class Speech2Text:
text = self.tokenizer.tokens2text(token)
else:
text = None
timestamp = []
if isinstance(self.asr_model, BiCifParaformer):
_, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:enc_len[i]*3],
us_peaks[i][:enc_len[i]*3],
copy.copy(token),
vad_offset=begin_time)
results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
else:
results.append((text, token, token_int, hyp, [], enc_len_batch_total, lfr_factor))
results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
# assert check_return_type(results)
return results
@ -471,7 +472,7 @@ def inference_modelscope(
hotword_list_or_file = None
if param_dict is not None:
hotword_list_or_file = param_dict.get('hotword')
if 'hotword' in kwargs:
if 'hotword' in kwargs and kwargs['hotword'] is not None:
hotword_list_or_file = kwargs['hotword']
if hotword_list_or_file is not None or 'hotword' in kwargs:
speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
@ -1018,18 +1019,9 @@ def main(cmd=None):
kwargs = vars(args)
kwargs.pop("config", None)
kwargs['param_dict'] = param_dict
inference(**kwargs)
inference_pipeline = inference_modelscope(**kwargs)
return inference_pipeline(kwargs["data_path_and_name_and_type"], param_dict=param_dict)
if __name__ == "__main__":
main()
# from modelscope.pipelines import pipeline
# from modelscope.utils.constant import Tasks
#
# inference_16k_pipline = pipeline(
# task=Tasks.auto_speech_recognition,
# model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
#
# rec_result = inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
# print(rec_result)