From 6ebf6e48eb0368518452312c803c58a65fe9bd26 Mon Sep 17 00:00:00 2001 From: BienBoy <92378515+BienBoy@users.noreply.github.com> Date: Wed, 5 Feb 2025 17:47:20 +0800 Subject: [PATCH] fix: resolve CPU runtime error introduced by previous commit (c1e365f) (#2375) Fixed a bug that caused a runtime error when running the model on CPU, which was introduced in commit c1e365fea09aafda387cac12fdff43d28c598979. The error was related to incorrect handling of device placement. --- funasr/auto/auto_model.py | 7 +++++-- funasr/bin/train.py | 6 ++++-- funasr/bin/train_ds.py | 6 ++++-- funasr/models/language_model/rnn/decoders.py | 6 ++++-- 4 files changed, 17 insertions(+), 8 deletions(-) diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index ec4e42013..60bfeff10 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -366,8 +366,11 @@ class AutoModel: if pbar: # pbar.update(1) pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}") - with torch.cuda.device(next(model.parameters()).device): - torch.cuda.empty_cache() + + device = next(model.parameters()).device + if device.type == 'cuda': + with torch.cuda.device(): + torch.cuda.empty_cache() return asr_result_list def inference_with_vad(self, input, input_len=None, **cfg): diff --git a/funasr/bin/train.py b/funasr/bin/train.py index d0f154a0a..b2b1dcdb2 100644 --- a/funasr/bin/train.py +++ b/funasr/bin/train.py @@ -221,8 +221,10 @@ def main(**kwargs): ) trainer.start_step = 0 - with torch.cuda.device(kwargs["device"]): - torch.cuda.empty_cache() + device = next(model.parameters()).device + if device.type == 'cuda': + with torch.cuda.device(): + torch.cuda.empty_cache() time_escaped = (time.perf_counter() - time_slice_i) / 3600.0 logging.info( diff --git a/funasr/bin/train_ds.py b/funasr/bin/train_ds.py index 24e81f617..bfa1dff94 100644 --- a/funasr/bin/train_ds.py +++ b/funasr/bin/train_ds.py @@ -184,8 +184,10 @@ def main(**kwargs): ) trainer.start_step = 0 - with torch.cuda.device(kwargs["device"]): - torch.cuda.empty_cache() + device = next(model.parameters()).device + if device.type == 'cuda': + with torch.cuda.device(): + torch.cuda.empty_cache() time_escaped = (time.perf_counter() - time_slice_i) / 3600.0 logging.info( diff --git a/funasr/models/language_model/rnn/decoders.py b/funasr/models/language_model/rnn/decoders.py index 314d49f3b..cdae9e376 100644 --- a/funasr/models/language_model/rnn/decoders.py +++ b/funasr/models/language_model/rnn/decoders.py @@ -873,8 +873,10 @@ class Decoder(torch.nn.Module, ScorerInterface): ctc_state[idx], accum_best_ids ) - with torch.cuda.device(vscores.device): - torch.cuda.empty_cache() + device = vscores.device + if device.type == 'cuda': + with torch.cuda.device(): + torch.cuda.empty_cache() dummy_hyps = [{"yseq": [self.sos, self.eos], "score": np.array([-float("inf")])}] ended_hyps = [