tp_inference device bug

This commit is contained in:
shixian.shi 2023-03-09 15:26:03 +08:00
parent 3610b3f48a
commit c441eb08c4

View File

@ -112,6 +112,9 @@ class SpeechText2Timestamp:
tp_model, tp_train_args = ASRTask.build_model_from_file(
timestamp_infer_config, timestamp_model_file, device
)
if 'cuda' in device:
tp_model = tp_model.cuda()
frontend = None
if tp_train_args.frontend is not None:
frontend = WavFrontend(cmvn_file=timestamp_cmvn_file, **tp_train_args.frontend_conf)
@ -240,7 +243,6 @@ def inference_modelscope(
device = "cuda"
else:
device = "cpu"
# 1. Set random-seed
set_all_random_seed(seed)