fix: resolve CPU runtime error introduced by previous commit (c1e365f) (#2375)

Fixed a bug that caused a runtime error when running the model on CPU, which was introduced in commit c1e365fea0. The error was related to incorrect handling of device placement.
This commit is contained in:
BienBoy 2025-02-05 17:47:20 +08:00 committed by GitHub
parent c1e365fea0
commit 6ebf6e48eb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 17 additions and 8 deletions

View File

@ -366,8 +366,11 @@ class AutoModel:
if pbar:
# pbar.update(1)
pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
with torch.cuda.device(next(model.parameters()).device):
torch.cuda.empty_cache()
device = next(model.parameters()).device
if device.type == 'cuda':
with torch.cuda.device():
torch.cuda.empty_cache()
return asr_result_list
def inference_with_vad(self, input, input_len=None, **cfg):

View File

@ -221,8 +221,10 @@ def main(**kwargs):
)
trainer.start_step = 0
with torch.cuda.device(kwargs["device"]):
torch.cuda.empty_cache()
device = next(model.parameters()).device
if device.type == 'cuda':
with torch.cuda.device():
torch.cuda.empty_cache()
time_escaped = (time.perf_counter() - time_slice_i) / 3600.0
logging.info(

View File

@ -184,8 +184,10 @@ def main(**kwargs):
)
trainer.start_step = 0
with torch.cuda.device(kwargs["device"]):
torch.cuda.empty_cache()
device = next(model.parameters()).device
if device.type == 'cuda':
with torch.cuda.device():
torch.cuda.empty_cache()
time_escaped = (time.perf_counter() - time_slice_i) / 3600.0
logging.info(

View File

@ -873,8 +873,10 @@ class Decoder(torch.nn.Module, ScorerInterface):
ctc_state[idx], accum_best_ids
)
with torch.cuda.device(vscores.device):
torch.cuda.empty_cache()
device = vscores.device
if device.type == 'cuda':
with torch.cuda.device():
torch.cuda.empty_cache()
dummy_hyps = [{"yseq": [self.sos, self.eos], "score": np.array([-float("inf")])}]
ended_hyps = [