update spk inference

This commit is contained in:
shixian.shi 2023-12-06 11:29:43 +08:00
parent b3fcd42bf6
commit e54535e5eb

View File

@ -952,10 +952,13 @@ def inference_paraformer_vad_speaker(
##### speaker_verification #####
##################################
# load sv model
sv_model_dict = torch.load(sv_model_file)
sv_model = getattr(sv_module, sv_model_config['sv_model_class'])(**sv_model_config['models_config'])
if ngpu > 0:
sv_model_dict = torch.load(sv_model_file)
sv_model = getattr(sv_module, sv_model_config['sv_model_class'])(**sv_model_config['models_config'])
sv_model.cuda()
else:
sv_model_dict = torch.load(sv_model_file, map_location=torch.device('cpu'))
sv_model = getattr(sv_module, sv_model_config['sv_model_class'])(**sv_model_config['models_config'])
sv_model.load_state_dict(sv_model_dict)
print(f'load sv model params: {sv_model_file}')
sv_model.eval()