diff --git a/funasr/bin/train_ds.py b/funasr/bin/train_ds.py
index b610d9ede..f8a302d14 100644
--- a/funasr/bin/train_ds.py
+++ b/funasr/bin/train_ds.py
@@ -151,8 +151,7 @@ def main(**kwargs):
     dataloader = dataloader_class(**kwargs)
     # dataloader_tr, dataloader_val = dataloader_class(**kwargs)
 
-    scaler = GradScaler(enabled=True) if trainer.use_fp16 or trainer.use_bf16 else None
-    scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
+    scaler = GradScaler(enabled=True) if trainer.use_fp16 else None
 
     if kwargs.get("train_conf", {}).get("save_init_model", True):
diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index 174065f3b..00db9b216 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -703,11 +703,11 @@ class Trainer:
         if self.use_deepspeed:
             scaled_loss = model.backward(loss)
         else:
-            loss = loss / self.accum_grad
-            if self.use_fp16 or self.use_bf16:
-                scaler.scale(loss).backward()
+            scaled_loss = loss / self.accum_grad
+            if scaler is not None:
+                scaler.scale(scaled_loss).backward()
             else:
-                loss.backward()
+                scaled_loss.backward()
 
     def update_step(self, model, optim, scheduler, scaler, loss_dict=None):
         batch_idx = loss_dict["batch_idx"]
@@ -732,7 +732,7 @@ class Trainer:
         # Execute an optimization step (update model parameters)
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-        if self.use_fp16 or self.use_bf16:
+        if scaler is not None:
             scaler.step(optim)
             scaler.update()
         else:
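
The pattern this patch converges on is worth spelling out: a `GradScaler` is created only for fp16 (bf16 shares fp32's exponent range and does not need loss scaling), and every use site then gates on `scaler is not None` instead of re-testing the precision flags, so creation and use can never disagree. Below is a minimal standalone sketch of that pattern using PyTorch's `torch.cuda.amp.GradScaler`; it is not FunASR code, and the `use_fp16` / `use_bf16` / `accum_grad` variables are stand-ins for the trainer attributes touched in the diff.

```python
import torch
from torch.cuda.amp import GradScaler, autocast

# Toy model/optimizer; stand-ins for the real trainer state.
model = torch.nn.Linear(8, 1).cuda()
optim = torch.optim.SGD(model.parameters(), lr=0.1)

use_fp16 = True    # stand-in for trainer.use_fp16
use_bf16 = False   # stand-in for trainer.use_bf16
accum_grad = 2     # stand-in for self.accum_grad

# Scaler exists only for fp16; bf16 and fp32 run without one.
scaler = GradScaler(enabled=True) if use_fp16 else None
dtype = torch.float16 if use_fp16 else (torch.bfloat16 if use_bf16 else torch.float32)

for step in range(4):
    x = torch.randn(4, 8, device="cuda")
    with autocast(dtype=dtype, enabled=use_fp16 or use_bf16):
        loss = model(x).pow(2).mean()

    # Normalize for gradient accumulation, as in Trainer.backward.
    scaled_loss = loss / accum_grad
    if scaler is not None:                # fp16: scale loss before backward
        scaler.scale(scaled_loss).backward()
    else:                                 # bf16 / fp32: plain backward
        scaled_loss.backward()

    # Optimizer update every accum_grad micro-steps, as in update_step.
    if (step + 1) % accum_grad == 0:
        if scaler is not None:
            scaler.step(optim)            # unscales grads; skips step on inf/nan
            scaler.update()
        else:
            optim.step()
        optim.zero_grad()
```

Gating on `scaler is not None` also explains the removal of the FSDP `ShardedGradScaler` branch in `train_ds.py`: with a single creation site, there is exactly one source of truth for whether loss scaling is active.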