This commit is contained in:
游雁 2024-06-13 17:52:07 +08:00
parent d72df6cd2f
commit caf70826a6

View File

@ -146,7 +146,7 @@ def main(**kwargs):
dataloader = dataloader_class(**kwargs)
# dataloader_tr, dataloader_val = dataloader_class(**kwargs)
scaler = GradScaler(enabled=trainer.use_fp16) if trainer.use_fp16 else None
scaler = GradScaler(enabled=True) if trainer.use_fp16 or trainer.use_bf16 else None
scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
trainer.resume_checkpoint(