Guard torch.cuda.empty_cache() so CPU-only runs never enter a CUDA context.

Each call site resolves the device that owns the relevant tensors and clears
the CUDA caching allocator only when that device is a CUDA device. The
resolved device is passed to torch.cuda.device(...), whose constructor
requires a device argument; calling it with no argument raises TypeError.

diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index ec4e42013..60bfeff10 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -366,8 +366,11 @@ class AutoModel:
             if pbar:
                 # pbar.update(1)
                 pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
-        with torch.cuda.device(next(model.parameters()).device):
-            torch.cuda.empty_cache()
+
+        device = next(model.parameters()).device
+        if device.type == 'cuda':
+            with torch.cuda.device(device):
+                torch.cuda.empty_cache()
         return asr_result_list
 
     def inference_with_vad(self, input, input_len=None, **cfg):
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index d0f154a0a..b2b1dcdb2 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -221,8 +221,10 @@ def main(**kwargs):
             )
             trainer.start_step = 0
 
-            with torch.cuda.device(kwargs["device"]):
-                torch.cuda.empty_cache()
+            device = next(model.parameters()).device
+            if device.type == 'cuda':
+                with torch.cuda.device(device):
+                    torch.cuda.empty_cache()
 
             time_escaped = (time.perf_counter() - time_slice_i) / 3600.0
             logging.info(
diff --git a/funasr/bin/train_ds.py b/funasr/bin/train_ds.py
index 24e81f617..bfa1dff94 100644
--- a/funasr/bin/train_ds.py
+++ b/funasr/bin/train_ds.py
@@ -184,8 +184,10 @@ def main(**kwargs):
             )
             trainer.start_step = 0
 
-            with torch.cuda.device(kwargs["device"]):
-                torch.cuda.empty_cache()
+            device = next(model.parameters()).device
+            if device.type == 'cuda':
+                with torch.cuda.device(device):
+                    torch.cuda.empty_cache()
 
             time_escaped = (time.perf_counter() - time_slice_i) / 3600.0
             logging.info(
diff --git a/funasr/models/language_model/rnn/decoders.py b/funasr/models/language_model/rnn/decoders.py
index 314d49f3b..cdae9e376 100644
--- a/funasr/models/language_model/rnn/decoders.py
+++ b/funasr/models/language_model/rnn/decoders.py
@@ -873,8 +873,10 @@ class Decoder(torch.nn.Module, ScorerInterface):
                         ctc_state[idx], accum_best_ids
                     )
 
-        with torch.cuda.device(vscores.device):
-            torch.cuda.empty_cache()
+        device = vscores.device
+        if device.type == 'cuda':
+            with torch.cuda.device(device):
+                torch.cuda.empty_cache()
 
         dummy_hyps = [{"yseq": [self.sos, self.eos], "score": np.array([-float("inf")])}]
         ended_hyps = [