This commit is contained in:
游雁 2025-02-11 10:08:19 +08:00
parent 6ebf6e48eb
commit 001a66bbfe
2 changed files with 8 additions and 8 deletions

View File

@ -221,10 +221,10 @@ def main(**kwargs):
)
trainer.start_step = 0
device = next(model.parameters()).device
if device.type == 'cuda':
with torch.cuda.device():
torch.cuda.empty_cache()
# device = next(model.parameters()).device
# if device.type == 'cuda':
# with torch.cuda.device():
# torch.cuda.empty_cache()
time_escaped = (time.perf_counter() - time_slice_i) / 3600.0
logging.info(

View File

@ -184,10 +184,10 @@ def main(**kwargs):
)
trainer.start_step = 0
device = next(model.parameters()).device
if device.type == 'cuda':
with torch.cuda.device():
torch.cuda.empty_cache()
# device = next(model.parameters()).device
# if device.type == 'cuda':
# with torch.cuda.device():
# torch.cuda.empty_cache()
time_escaped = (time.perf_counter() - time_slice_i) / 3600.0
logging.info(