This commit is contained in:
游雁 2024-08-12 10:36:41 +08:00
parent cf85b11448
commit dccbae4f7b
2 changed files with 6 additions and 7 deletions

View File

@@ -151,8 +151,7 @@ def main(**kwargs):
dataloader = dataloader_class(**kwargs)
# dataloader_tr, dataloader_val = dataloader_class(**kwargs)
scaler = GradScaler(enabled=True) if trainer.use_fp16 or trainer.use_bf16 else None
scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
scaler = GradScaler(enabled=True) if trainer.use_fp16 else None
if kwargs.get("train_conf", {}).get("save_init_model", True):

View File

@@ -703,11 +703,11 @@ class Trainer:
if self.use_deepspeed:
scaled_loss = model.backward(loss)
else:
loss = loss / self.accum_grad
if self.use_fp16 or self.use_bf16:
scaler.scale(loss).backward()
scaled_loss = loss / self.accum_grad
if scaler is not None:
scaler.scale(scaled_loss).backward()
else:
loss.backward()
scaled_loss.backward()
def update_step(self, model, optim, scheduler, scaler, loss_dict=None):
batch_idx = loss_dict["batch_idx"]
@@ -732,7 +732,7 @@ class Trainer:
# Execute an optimization step (update model parameters)
if self.use_ddp or self.use_fsdp:
dist.barrier()
if self.use_fp16 or self.use_bf16:
if scaler is not None:
scaler.step(optim)
scaler.update()
else: