diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py index 59e61ee64..f34bfb2bb 100644 --- a/funasr/bin/asr_inference_launch.py +++ b/funasr/bin/asr_inference_launch.py @@ -952,10 +952,13 @@ def inference_paraformer_vad_speaker( ##### speaker_verification ##### ################################## # load sv model - sv_model_dict = torch.load(sv_model_file) - sv_model = getattr(sv_module, sv_model_config['sv_model_class'])(**sv_model_config['models_config']) if ngpu > 0: + sv_model_dict = torch.load(sv_model_file) + sv_model = getattr(sv_module, sv_model_config['sv_model_class'])(**sv_model_config['models_config']) sv_model.cuda() + else: + sv_model_dict = torch.load(sv_model_file, map_location=torch.device('cpu')) + sv_model = getattr(sv_module, sv_model_config['sv_model_class'])(**sv_model_config['models_config']) sv_model.load_state_dict(sv_model_dict) print(f'load sv model params: {sv_model_file}') sv_model.eval()