diff --git a/funasr/bin/asr_inference.py b/funasr/bin/asr_inference.py index f3b4d560a..47226021f 100644 --- a/funasr/bin/asr_inference.py +++ b/funasr/bin/asr_inference.py @@ -346,6 +346,8 @@ def inference_modelscope( **kwargs, ): assert check_argument_types() + ncpu = kwargs.get("ncpu", 1) + torch.set_num_threads(ncpu) if batch_size > 1: raise NotImplementedError("batch decoding is not implemented") if word_lm_train_config is not None: diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py index 2b6716ed8..e10ebf404 100644 --- a/funasr/bin/asr_inference_launch.py +++ b/funasr/bin/asr_inference_launch.py @@ -1,9 +1,4 @@ #!/usr/bin/env python3 -# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -import torch -torch.set_num_threads(1) import argparse import logging diff --git a/funasr/bin/asr_inference_mfcca.py b/funasr/bin/asr_inference_mfcca.py index 6f3dbb113..e83286958 100644 --- a/funasr/bin/asr_inference_mfcca.py +++ b/funasr/bin/asr_inference_mfcca.py @@ -472,6 +472,8 @@ def inference_modelscope( **kwargs, ): assert check_argument_types() + ncpu = kwargs.get("ncpu", 1) + torch.set_num_threads(ncpu) if batch_size > 1: raise NotImplementedError("batch decoding is not implemented") if word_lm_train_config is not None: diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py index 8cbd41905..a8ac99d55 100644 --- a/funasr/bin/asr_inference_paraformer.py +++ b/funasr/bin/asr_inference_paraformer.py @@ -612,7 +612,9 @@ def inference_modelscope( **kwargs, ): assert check_argument_types() - + ncpu = kwargs.get("ncpu", 1) + torch.set_num_threads(ncpu) + if word_lm_train_config is not None: raise NotImplementedError("Word LM is not implemented") if ngpu > 1: diff --git a/funasr/bin/asr_inference_paraformer_streaming.py b/funasr/bin/asr_inference_paraformer_streaming.py index 944685f1d..821f69429 100644 --- a/funasr/bin/asr_inference_paraformer_streaming.py +++ b/funasr/bin/asr_inference_paraformer_streaming.py @@ -536,6 +536,8 @@ def inference_modelscope( **kwargs, ): assert check_argument_types() + ncpu = kwargs.get("ncpu", 1) + torch.set_num_threads(ncpu) if word_lm_train_config is not None: raise NotImplementedError("Word LM is not implemented") diff --git a/funasr/bin/asr_inference_paraformer_vad.py b/funasr/bin/asr_inference_paraformer_vad.py index 1548f9ff1..977dc9bb3 100644 --- a/funasr/bin/asr_inference_paraformer_vad.py +++ b/funasr/bin/asr_inference_paraformer_vad.py @@ -157,6 +157,8 @@ def inference_modelscope( **kwargs, ): assert check_argument_types() + ncpu = kwargs.get("ncpu", 1) + torch.set_num_threads(ncpu) if word_lm_train_config is not None: raise NotImplementedError("Word LM is not implemented") diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py index 9dc0b79ce..197930f47 100644 --- a/funasr/bin/asr_inference_paraformer_vad_punc.py +++ b/funasr/bin/asr_inference_paraformer_vad_punc.py @@ -484,6 +484,8 @@ def inference_modelscope( **kwargs, ): assert check_argument_types() + ncpu = kwargs.get("ncpu", 1) + torch.set_num_threads(ncpu) if word_lm_train_config is not None: raise NotImplementedError("Word LM is not implemented") diff --git a/funasr/bin/asr_inference_uniasr.py b/funasr/bin/asr_inference_uniasr.py index 4aea72079..35ecdc24b 100644 --- a/funasr/bin/asr_inference_uniasr.py +++ b/funasr/bin/asr_inference_uniasr.py @@ -379,6 +379,8 @@ def inference_modelscope( **kwargs, ): assert check_argument_types() + ncpu = kwargs.get("ncpu", 1) + torch.set_num_threads(ncpu) if batch_size > 1: raise NotImplementedError("batch decoding is not implemented") if word_lm_train_config is not None: diff --git a/funasr/bin/diar_inference_launch.py b/funasr/bin/diar_inference_launch.py index 83436e8a7..07974c072 100755 --- a/funasr/bin/diar_inference_launch.py +++ b/funasr/bin/diar_inference_launch.py @@ -2,8 +2,6 @@ # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. # MIT License (https://opensource.org/licenses/MIT) -import torch -torch.set_num_threads(1) import argparse import logging diff --git a/funasr/bin/eend_ola_inference.py b/funasr/bin/eend_ola_inference.py index 01d3f296a..87816dd22 100755 --- a/funasr/bin/eend_ola_inference.py +++ b/funasr/bin/eend_ola_inference.py @@ -158,6 +158,8 @@ def inference_modelscope( **kwargs, ): assert check_argument_types() + ncpu = kwargs.get("ncpu", 1) + torch.set_num_threads(ncpu) if batch_size > 1: raise NotImplementedError("batch decoding is not implemented") if ngpu > 1: diff --git a/funasr/bin/lm_inference.py b/funasr/bin/lm_inference.py index 15c56caef..76de6df7a 100644 --- a/funasr/bin/lm_inference.py +++ b/funasr/bin/lm_inference.py @@ -89,10 +89,9 @@ def inference_modelscope( **kwargs, ): assert check_argument_types() - logging.basicConfig( - level=log_level, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) + ncpu = kwargs.get("ncpu", 1) + torch.set_num_threads(ncpu) + if ngpu >= 1 and torch.cuda.is_available(): device = "cuda" diff --git a/funasr/bin/lm_inference_launch.py b/funasr/bin/lm_inference_launch.py index d229cc6c1..dc6414f6a 100644 --- a/funasr/bin/lm_inference_launch.py +++ b/funasr/bin/lm_inference_launch.py @@ -1,9 +1,6 @@ #!/usr/bin/env python3 -# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) -import torch -torch.set_num_threads(1) + import argparse import logging diff --git a/funasr/bin/punc_inference_launch.py b/funasr/bin/punc_inference_launch.py index 2c5a2865f..b1d923553 100755 --- a/funasr/bin/punc_inference_launch.py +++ b/funasr/bin/punc_inference_launch.py @@ -1,9 +1,5 @@ #!/usr/bin/env python3 -# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) -import torch -torch.set_num_threads(1) import argparse import logging diff --git a/funasr/bin/punctuation_infer_vadrealtime.py b/funasr/bin/punctuation_infer_vadrealtime.py index 5157eeb29..b2db1bf17 100644 --- a/funasr/bin/punctuation_infer_vadrealtime.py +++ b/funasr/bin/punctuation_infer_vadrealtime.py @@ -203,10 +203,8 @@ def inference_modelscope( **kwargs, ): assert check_argument_types() - logging.basicConfig( - level=log_level, - format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", - ) + ncpu = kwargs.get("ncpu", 1) + torch.set_num_threads(ncpu) if ngpu >= 1 and torch.cuda.is_available(): device = "cuda" diff --git a/funasr/bin/sond_inference.py b/funasr/bin/sond_inference.py index 5a0a8e28f..c55bc3544 100755 --- a/funasr/bin/sond_inference.py +++ b/funasr/bin/sond_inference.py @@ -252,6 +252,8 @@ def inference_modelscope( **kwargs, ): assert check_argument_types() + ncpu = kwargs.get("ncpu", 1) + torch.set_num_threads(ncpu) if batch_size > 1: raise NotImplementedError("batch decoding is not implemented") if ngpu > 1: diff --git a/funasr/bin/sv_inference.py b/funasr/bin/sv_inference.py index 7e63bbd2d..76b1dfbb8 100755 --- a/funasr/bin/sv_inference.py +++ b/funasr/bin/sv_inference.py @@ -179,6 +179,9 @@ def inference_modelscope( **kwargs, ): assert check_argument_types() + ncpu = kwargs.get("ncpu", 1) + torch.set_num_threads(ncpu) + if batch_size > 1: raise NotImplementedError("batch decoding is not implemented") if ngpu > 1: diff --git a/funasr/bin/sv_inference_launch.py b/funasr/bin/sv_inference_launch.py index 64a3cff2e..880607013 100755 --- a/funasr/bin/sv_inference_launch.py +++ b/funasr/bin/sv_inference_launch.py @@ -2,8 +2,6 @@ # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. # MIT License (https://opensource.org/licenses/MIT) -import torch -torch.set_num_threads(1) import argparse import logging diff --git a/funasr/bin/tp_inference.py b/funasr/bin/tp_inference.py index 6360b17db..191bbf325 100644 --- a/funasr/bin/tp_inference.py +++ b/funasr/bin/tp_inference.py @@ -179,6 +179,9 @@ def inference_modelscope( **kwargs, ): assert check_argument_types() + ncpu = kwargs.get("ncpu", 1) + torch.set_num_threads(ncpu) + if batch_size > 1: raise NotImplementedError("batch decoding is not implemented") if ngpu > 1: diff --git a/funasr/bin/tp_inference_launch.py b/funasr/bin/tp_inference_launch.py index 55debac6d..6cdff057d 100644 --- a/funasr/bin/tp_inference_launch.py +++ b/funasr/bin/tp_inference_launch.py @@ -1,9 +1,5 @@ #!/usr/bin/env python3 -# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) -import torch -torch.set_num_threads(1) import argparse import logging diff --git a/funasr/bin/vad_inference.py b/funasr/bin/vad_inference.py index 08d65a4e7..aff0a443b 100644 --- a/funasr/bin/vad_inference.py +++ b/funasr/bin/vad_inference.py @@ -192,6 +192,9 @@ def inference_modelscope( **kwargs, ): assert check_argument_types() + ncpu = kwargs.get("ncpu", 1) + torch.set_num_threads(ncpu) + if batch_size > 1: raise NotImplementedError("batch decoding is not implemented") if ngpu > 1: diff --git a/funasr/bin/vad_inference_launch.py b/funasr/bin/vad_inference_launch.py index 8fea8db16..4a1f334cf 100644 --- a/funasr/bin/vad_inference_launch.py +++ b/funasr/bin/vad_inference_launch.py @@ -1,9 +1,4 @@ #!/usr/bin/env python3 -# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved. -# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) - -import torch -torch.set_num_threads(1) import argparse import logging diff --git a/funasr/bin/vad_inference_online.py b/funasr/bin/vad_inference_online.py index 9ed072199..4d026207d 100644 --- a/funasr/bin/vad_inference_online.py +++ b/funasr/bin/vad_inference_online.py @@ -151,6 +151,9 @@ def inference_modelscope( **kwargs, ): assert check_argument_types() + ncpu = kwargs.get("ncpu", 1) + torch.set_num_threads(ncpu) + if batch_size > 1: raise NotImplementedError("batch decoding is not implemented") if ngpu > 1: