From b7cb19b01a1454f7a1388e24dcd4e10fc654bd7c Mon Sep 17 00:00:00 2001
From: "shixian.shi"
Date: Tue, 16 Jan 2024 11:30:25 +0800
Subject: [PATCH] update demo, readme

---
 README.md                            | 24 ++++++++++++-------
 README_zh.md                         | 18 +++++++++-----
 .../bicif_paraformer/demo.py         | 16 ++++++-------
 .../campplus_sv/demo.py              |  2 +-
 .../contextual_paraformer/demo.py    |  2 +-
 .../ct_transformer/demo.py           |  4 ++--
 .../ct_transformer_streaming/demo.py |  2 +-
 .../emotion2vec/demo.py              |  2 +-
 .../fsmn_vad_streaming/demo.py       |  4 ++--
 .../monotonic_aligner/demo.py        |  2 +-
 .../paraformer-zh-spk/demo.py        |  4 ++--
 .../paraformer/demo.py               |  4 ++--
 .../paraformer_streaming/demo.py     | 16 ++++++-------
 .../seaco_paraformer/demo.py         |  4 ++--
 funasr/auto/auto_model.py            |  8 +++----
 funasr/bin/inference.py              | 21 +---------------
 16 files changed, 63 insertions(+), 70 deletions(-)

diff --git a/README.md b/README.md
index a53ce4d3e..2bd28e218 100644
--- a/README.md
+++ b/README.md
@@ -95,9 +95,9 @@ model = AutoModel(model="paraformer-zh", model_revision="v2.0.2", \
                   vad_model="fsmn-vad", vad_model_revision="v2.0.2", \
                   punc_model="ct-punc-c", punc_model_revision="v2.0.2", \
                   spk_model="cam++", spk_model_revision="v2.0.2")
-res = model(input=f"{model.model_path}/example/asr_example.wav", 
-            batch_size=64, 
-            hotword='魔搭')
+res = model.generate(input=f"{model.model_path}/example/asr_example.wav", 
+                     batch_size=64, 
+                     hotword='魔搭')
 print(res)
 ```
 Note: `model_hub` specifies the model repository: `ms` selects download from ModelScope, `hf` selects download from Hugging Face.
@@ -124,7 +124,7 @@ total_chunk_num = int(len((speech)-1)/chunk_stride+1)
 for i in range(total_chunk_num):
     speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
     is_final = i == total_chunk_num - 1
-    res = model(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, decoder_chunk_look_back=decoder_chunk_look_back)
+    res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, decoder_chunk_look_back=decoder_chunk_look_back)
     print(res)
 ```
 Note: `chunk_size` is the configuration for streaming latency. `[0,10,5]` indicates that the real-time display granularity is `10*60=600ms`, and the lookahead information is `5*60=300ms`. Each inference input is `600ms` (sample points are `16000*0.6=9600`), and the output is the corresponding text. For the last speech segment input, `is_final=True` needs to be set to force the output of the last word.
@@ -135,7 +135,7 @@ from funasr import AutoModel
 
 model = AutoModel(model="fsmn-vad", model_revision="v2.0.2")
 
 wav_file = f"{model.model_path}/example/asr_example.wav"
-res = model(input=wav_file)
+res = model.generate(input=wav_file)
 print(res)
 ```
 ### Voice Activity Detection (Streaming)
@@ -156,7 +156,7 @@ total_chunk_num = int(len((speech)-1)/chunk_stride+1)
 for i in range(total_chunk_num):
     speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
     is_final = i == total_chunk_num - 1
-    res = model(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
+    res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
     if len(res[0]["value"]):
         print(res)
 ```
@@ -165,7 +165,7 @@ for i in range(total_chunk_num):
 from funasr import AutoModel
 
 model = AutoModel(model="ct-punc", model_revision="v2.0.2")
-res = model(input="那今天的会就到这里吧 happy new year 明年见")
+res = model.generate(input="那今天的会就到这里吧 happy new year 明年见")
 print(res)
 ```
 ### Timestamp Prediction
@@ -175,7 +175,7 @@ from funasr import AutoModel
 model = AutoModel(model="fa-zh", model_revision="v2.0.2")
 wav_file = f"{model.model_path}/example/asr_example.wav"
 text_file = f"{model.model_path}/example/text.txt"
-res = model(input=(wav_file, text_file), data_type=("sound", "text"))
+res = model.generate(input=(wav_file, text_file), data_type=("sound", "text"))
 print(res)
 ```
 
 [//]: # (FunASR supports inference and fine-tuning of models trained on industrial datasets of tens of thousands of hours. For more details, please refer to ([modelscope_egs](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_pipeline/quick_start.html)). It also supports training and fine-tuning of models on academic standard datasets. For more details, please refer to ([egs](https://alibaba-damo-academy.github.io/FunASR/en/academic_recipe/asr_recipe.html)). The models include speech recognition (ASR), speech activity detection (VAD), punctuation recovery, language model, speaker verification, speaker separation, and multi-party conversation speech recognition. For a detailed list of models, please refer to the [Model Zoo](https://github.com/alibaba-damo-academy/FunASR/blob/main/docs/model_zoo/modelscope_models.md):)
@@ -229,10 +229,16 @@ The use of pretraining model is subject to [model license](./MODEL_LICENSE)
 }
 @inproceedings{gao22b_interspeech,
   author={Zhifu Gao and ShiLiang Zhang and Ian McLoughlin and Zhijie Yan},
-  title={{Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition}},
+  title={Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition},
   year=2022,
   booktitle={Proc. Interspeech 2022},
   pages={2063--2067},
   doi={10.21437/Interspeech.2022-9996}
 }
+@inproceedings{shi2023seaco,
+  author={Xian Shi and Yexin Yang and Zerui Li and Yanni Chen and Zhifu Gao and Shiliang Zhang},
+  title={SeACo-Paraformer: A Non-Autoregressive ASR System with Flexible and Effective Hotword Customization Ability},
+  year={2023},
+  booktitle={ICASSP 2024}
+}
 ```
diff --git a/README_zh.md b/README_zh.md
index 861e61c2e..dc2030222 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -91,7 +91,7 @@ model = AutoModel(model="paraformer-zh", model_revision="v2.0.2", \
                   vad_model="fsmn-vad", vad_model_revision="v2.0.2", \
                   punc_model="ct-punc-c", punc_model_revision="v2.0.2", \
                   spk_model="cam++", spk_model_revision="v2.0.2")
-res = model(input=f"{model.model_path}/example/asr_example.wav", 
+res = model.generate(input=f"{model.model_path}/example/asr_example.wav", 
             batch_size=64, 
             hotword='魔搭')
 print(res)
@@ -121,7 +121,7 @@ total_chunk_num = int(len((speech)-1)/chunk_stride+1)
 for i in range(total_chunk_num):
     speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
     is_final = i == total_chunk_num - 1
-    res = model(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, decoder_chunk_look_back=decoder_chunk_look_back)
+    res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, decoder_chunk_look_back=decoder_chunk_look_back)
     print(res)
 ```
 
@@ -134,7 +134,7 @@ from funasr import AutoModel
 model = AutoModel(model="fsmn-vad", model_revision="v2.0.2")
 
 wav_file = f"{model.model_path}/example/asr_example.wav"
-res = model(input=wav_file)
+res = model.generate(input=wav_file)
 print(res)
 
@@ -156,7 +156,7 @@ total_chunk_num = int(len((speech)-1)/chunk_stride+1)
 for i in range(total_chunk_num):
     speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
     is_final = i == total_chunk_num - 1
-    res = model(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
+    res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size)
     if len(res[0]["value"]):
         print(res)
 ```
 
@@ -167,7 +167,7 @@ for i in range(total_chunk_num):
 from funasr import AutoModel
 
 model = AutoModel(model="ct-punc", model_revision="v2.0.2")
-res = model(input="那今天的会就到这里吧 happy new year 明年见")
+res = model.generate(input="那今天的会就到这里吧 happy new year 明年见")
 print(res)
 ```
 
@@ -179,7 +179,7 @@ model = AutoModel(model="fa-zh", model_revision="v2.0.0")
 wav_file = f"{model.model_path}/example/asr_example.wav"
 text_file = f"{model.model_path}/example/text.txt"
-res = model(input=(wav_file, text_file), data_type=("sound", "text"))
+res = model.generate(input=(wav_file, text_file), data_type=("sound", "text"))
 print(res)
 ```
 更多详细用法（[示例](examples/industrial_data_pretraining)）
@@ -242,4 +242,10 @@ FunASR支持预训练或者进一步微调的模型进行服务部署。目前
   pages={2063--2067},
   doi={10.21437/Interspeech.2022-9996}
 }
+@article{shi2023seaco,
+  author={Xian Shi and Yexin Yang and Zerui Li and Yanni Chen and Zhifu Gao and Shiliang Zhang},
+  title={{SeACo-Paraformer: A Non-Autoregressive ASR System with Flexible and Effective Hotword Customization Ability}},
+  year=2023,
+  journal={arXiv preprint arXiv:2308.03266 (accepted by ICASSP 2024)},
+}
 ```
diff --git a/examples/industrial_data_pretraining/bicif_paraformer/demo.py b/examples/industrial_data_pretraining/bicif_paraformer/demo.py
index 60718de08..a06b308d1 100644
--- a/examples/industrial_data_pretraining/bicif_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/bicif_paraformer/demo.py
@@ -6,14 +6,14 @@ from funasr import AutoModel
 
 model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
-                model_revision="v2.0.2",
-                vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
-                vad_model_revision="v2.0.2",
-                punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                punc_model_revision="v2.0.2",
-                spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
-                spk_model_revision="v2.0.2",
+                  model_revision="v2.0.2",
+                  vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+                  vad_model_revision="v2.0.2",
+                  punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+                  punc_model_revision="v2.0.2",
+                  spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
+                  spk_model_revision="v2.0.2",
                   )
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_vad_punc_example.wav", batch_size_s=300, batch_size_threshold_s=60)
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_vad_punc_example.wav", batch_size_s=300, batch_size_threshold_s=60)
 print(res)
diff --git a/examples/industrial_data_pretraining/campplus_sv/demo.py b/examples/industrial_data_pretraining/campplus_sv/demo.py
index 6a7f10548..16d629b29 100644
--- a/examples/industrial_data_pretraining/campplus_sv/demo.py
+++ b/examples/industrial_data_pretraining/campplus_sv/demo.py
@@ -9,5 +9,5 @@ model = AutoModel(model="damo/speech_campplus_sv_zh-cn_16k-common",
                   model_revision="v2.0.2",
                   )
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
 print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/contextual_paraformer/demo.py b/examples/industrial_data_pretraining/contextual_paraformer/demo.py
index 78693c527..d1378ca3a 100644
--- a/examples/industrial_data_pretraining/contextual_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/contextual_paraformer/demo.py
@@ -7,6 +7,6 @@ from funasr import AutoModel
 
 model = AutoModel(model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
                   model_revision="v2.0.2")
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
             hotword='达摩院 魔搭')
 print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/ct_transformer/demo.py b/examples/industrial_data_pretraining/ct_transformer/demo.py
index d648f3d27..f547f0342 100644
--- a/examples/industrial_data_pretraining/ct_transformer/demo.py
+++ b/examples/industrial_data_pretraining/ct_transformer/demo.py
@@ -7,7 +7,7 @@ from funasr import AutoModel
 
 model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                   model_revision="v2.0.2")
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
 print(res)
 
@@ -15,5 +15,5 @@ from funasr import AutoModel
 model = AutoModel(model="damo/punc_ct-transformer_cn-en-common-vocab471067-large",
                   model_revision="v2.0.2")
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt") +res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt") print(res) \ No newline at end of file diff --git a/examples/industrial_data_pretraining/ct_transformer_streaming/demo.py b/examples/industrial_data_pretraining/ct_transformer_streaming/demo.py index 5ef83813e..081fd1928 100644 --- a/examples/industrial_data_pretraining/ct_transformer_streaming/demo.py +++ b/examples/industrial_data_pretraining/ct_transformer_streaming/demo.py @@ -12,7 +12,7 @@ vads = inputs.split("|") rec_result_all = "outputs: " cache = {} for vad in vads: - rec_result = model(input=vad, cache=cache) + rec_result = model.generate(input=vad, cache=cache) print(rec_result) rec_result_all += rec_result[0]['text'] diff --git a/examples/industrial_data_pretraining/emotion2vec/demo.py b/examples/industrial_data_pretraining/emotion2vec/demo.py index abaa9f40e..ea8da99dd 100644 --- a/examples/industrial_data_pretraining/emotion2vec/demo.py +++ b/examples/industrial_data_pretraining/emotion2vec/demo.py @@ -7,5 +7,5 @@ from funasr import AutoModel model = AutoModel(model="damo/emotion2vec_base", model_revision="v2.0.1") -res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", output_dir="./outputs") +res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", output_dir="./outputs") print(res) \ No newline at end of file diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py index 459dfff41..8084dec26 100644 --- a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py +++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py @@ -9,7 +9,7 @@ wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audi chunk_size = 60000 # ms model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.2") -res = model(input=wav_file, chunk_size=chunk_size, ) +res = model.generate(input=wav_file, chunk_size=chunk_size, ) print(res) @@ -28,7 +28,7 @@ total_chunk_num = int(len((speech)-1)/chunk_stride+1) for i in range(total_chunk_num): speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride] is_final = i == total_chunk_num - 1 - res = model(input=speech_chunk, + res = model.generate(input=speech_chunk, cache=cache, is_final=is_final, chunk_size=chunk_size, diff --git a/examples/industrial_data_pretraining/monotonic_aligner/demo.py b/examples/industrial_data_pretraining/monotonic_aligner/demo.py index def6b7de8..cad9aab91 100644 --- a/examples/industrial_data_pretraining/monotonic_aligner/demo.py +++ b/examples/industrial_data_pretraining/monotonic_aligner/demo.py @@ -7,7 +7,7 @@ from funasr import AutoModel model = AutoModel(model="damo/speech_timestamp_prediction-v1-16k-offline", model_revision="v2.0.2") -res = model(input=("https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", +res = model.generate(input=("https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", "欢迎大家来到魔搭社区进行体验"), data_type=("sound", "text"), batch_size=2, diff --git a/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py b/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py index aa895eb85..b4453e927 100644 --- 
+++ b/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
@@ -15,6 +15,6 @@ model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-co
                   spk_model_revision="v2.0.2"
                   )
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
-            hotword='达摩院 磨搭')
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
+                     hotword='达摩院 磨搭')
 print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/paraformer/demo.py b/examples/industrial_data_pretraining/paraformer/demo.py
index 6dbe33d06..ef33bf40d 100644
--- a/examples/industrial_data_pretraining/paraformer/demo.py
+++ b/examples/industrial_data_pretraining/paraformer/demo.py
@@ -7,7 +7,7 @@ from funasr import AutoModel
 
 model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                   model_revision="v2.0.2")
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
 print(res)
 
@@ -18,5 +18,5 @@ frontend = AutoFrontend(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-co
 fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", batch_size=2)
 
 for batch_idx, fbank_dict in enumerate(fbanks):
-    res = model(**fbank_dict)
+    res = model.generate(**fbank_dict)
     print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/paraformer_streaming/demo.py b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
index 8f7eef350..07efde67c 100644
--- a/examples/industrial_data_pretraining/paraformer_streaming/demo.py
+++ b/examples/industrial_data_pretraining/paraformer_streaming/demo.py
@@ -11,7 +11,7 @@ decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cr
 model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", model_revision="v2.0.2")
 
 cache = {}
-res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
             chunk_size=chunk_size,
             encoder_chunk_look_back=encoder_chunk_look_back,
             decoder_chunk_look_back=decoder_chunk_look_back,
@@ -32,11 +32,11 @@ total_chunk_num = int(len((speech)-1)/chunk_stride+1)
 for i in range(total_chunk_num):
     speech_chunk = speech[i*chunk_stride:(i+1)*chunk_stride]
     is_final = i == total_chunk_num - 1
-    res = model(input=speech_chunk,
-                cache=cache,
-                is_final=is_final,
-                chunk_size=chunk_size,
-                encoder_chunk_look_back=encoder_chunk_look_back,
-                decoder_chunk_look_back=decoder_chunk_look_back,
-                )
+    res = model.generate(input=speech_chunk,
+                         cache=cache,
+                         is_final=is_final,
+                         chunk_size=chunk_size,
+                         encoder_chunk_look_back=encoder_chunk_look_back,
+                         decoder_chunk_look_back=decoder_chunk_look_back,
+                         )
     print(res)
diff --git a/examples/industrial_data_pretraining/seaco_paraformer/demo.py b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
index cf49e42c4..5f17252f9 100644
--- a/examples/industrial_data_pretraining/seaco_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-co spk_model_revision="v2.0.2", ) -res = model(input=f"{model.model_path}/example/asr_example.wav", - hotword='达摩院 魔搭') +res = model.generate(input=f"{model.model_path}/example/asr_example.wav", + hotword='达摩院 魔搭') print(res) \ No newline at end of file diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index 25edeb708..580cca8d4 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -264,7 +264,7 @@ class AutoModel: # step.1: compute the vad model self.vad_kwargs.update(cfg) beg_vad = time.time() - res = self.generate(input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg) + res = self.inference(input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg) end_vad = time.time() print(f"time cost vad: {end_vad - beg_vad:0.3f}") @@ -316,7 +316,7 @@ class AutoModel: batch_size_ms_cum = 0 end_idx = j + 1 speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx]) - results = self.generate(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg) + results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg) if self.spk_model is not None: all_segments = [] # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]] @@ -327,7 +327,7 @@ class AutoModel: segments = sv_chunk(vad_segments) all_segments.extend(segments) speech_b = [i[2] for i in segments] - spk_res = self.generate(speech_b, input_len=None, model=self.spk_model, kwargs=kwargs, **cfg) + spk_res = self.inference(speech_b, input_len=None, model=self.spk_model, kwargs=kwargs, **cfg) results[_b]['spk_embedding'] = spk_res[0]['spk_embedding'] beg_idx = end_idx if len(results) < 1: @@ -378,7 +378,7 @@ class AutoModel: # step.3 compute punc model if self.punc_model is not None: self.punc_kwargs.update(cfg) - punc_res = self.generate(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg) + punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg) result["text_with_punc"] = punc_res[0]["text"] # speaker embedding cluster after resorted diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py index bc435c43c..d2f0c149d 100644 --- a/funasr/bin/inference.py +++ b/funasr/bin/inference.py @@ -1,25 +1,7 @@ -import json -import time -import torch import hydra -import random -import string import logging -import os.path -from tqdm import tqdm from omegaconf import DictConfig, OmegaConf, ListConfig -from funasr.register import tables -from funasr.utils.load_utils import load_bytes -from funasr.download.file import download_from_url -from funasr.download.download_from_hub import download_model -from funasr.utils.vad_utils import slice_padding_audio_samples -from funasr.train_utils.set_all_random_seed import set_all_random_seed -from funasr.train_utils.load_pretrained_model import load_pretrained_model -from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank -from funasr.utils.timestamp_tools import timestamp_sentence -from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk -from funasr.models.campplus.cluster_backend import ClusterBackend from funasr.auto.auto_model import AutoModel @@ -41,10 +23,9 @@ def main_hydra(cfg: DictConfig): if kwargs.get("debug", False): import pdb; pdb.set_trace() model = AutoModel(**kwargs) - res = model(input=kwargs["input"]) + res = model.generate(input=kwargs["input"]) print(res) - 
 if __name__ == '__main__':
     main_hydra()
\ No newline at end of file
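
For reference, the calling convention this patch establishes can be exercised end to end with the sketch below: `AutoModel`'s old callable form `model(...)` becomes the public `model.generate(...)`, while the internal per-model helper formerly also named `generate` is renamed `inference` inside `auto_model.py`, removing the name collision. This is a minimal sketch rather than part of the patch: the streaming alias `paraformer-zh-streaming`, the `soundfile` loader, and `encoder_chunk_look_back=4` are assumptions drawn from the surrounding README and demo snippets (only `decoder_chunk_look_back = 1` appears above).

```python
# Minimal usage sketch for the post-patch API -- not part of the patch itself.
# Assumed: the "paraformer-zh-streaming" alias, the soundfile loader, and
# encoder_chunk_look_back=4; other names and values come from the snippets above.
import soundfile

from funasr import AutoModel

# Offline: the old callable form `model(...)` is now `model.generate(...)`.
model = AutoModel(model="paraformer-zh", model_revision="v2.0.2")
res = model.generate(input=f"{model.model_path}/example/asr_example.wav")
print(res)

# Streaming: generate() threads the recognizer state through `cache` between chunks.
streaming_model = AutoModel(model="paraformer-zh-streaming", model_revision="v2.0.2")
chunk_size = [0, 10, 5]              # 600 ms display granularity, 300 ms lookahead
speech, _ = soundfile.read(f"{model.model_path}/example/asr_example.wav")
chunk_stride = chunk_size[1] * 960   # 10 frames x 60 ms at 16 kHz = 9600 samples

cache = {}
total_chunk_num = (len(speech) - 1) // chunk_stride + 1
for i in range(total_chunk_num):
    speech_chunk = speech[i * chunk_stride:(i + 1) * chunk_stride]
    is_final = i == total_chunk_num - 1   # forces the last partial word out
    res = streaming_model.generate(input=speech_chunk, cache=cache, is_final=is_final,
                                   chunk_size=chunk_size,
                                   encoder_chunk_look_back=4,
                                   decoder_chunk_look_back=1)
    print(res)
```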