Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)
trainer

commit dccbae4f7b
parent cf85b11448
@@ -151,8 +151,7 @@ def main(**kwargs):
     dataloader = dataloader_class(**kwargs)
     # dataloader_tr, dataloader_val = dataloader_class(**kwargs)

-    scaler = GradScaler(enabled=True) if trainer.use_fp16 or trainer.use_bf16 else None
-    scaler = ShardedGradScaler(enabled=trainer.use_fp16) if trainer.use_fsdp else scaler
+    scaler = GradScaler(enabled=True) if trainer.use_fp16 else None

     if kwargs.get("train_conf", {}).get("save_init_model", True):
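The new branch only constructs a loss scaler for fp16; bf16 shares fp32's exponent range, so it normally trains without loss scaling, and later uses of the scaler are guarded by checking that scaler is not None rather than by the precision flags. A minimal sketch of that pattern, with a hypothetical TrainerCfg standing in for the real trainer flags (not the FunASR code itself):

    from torch.cuda.amp import GradScaler

    class TrainerCfg:
        # hypothetical stand-in for the trainer's precision flags
        use_fp16 = True
        use_bf16 = False

    trainer = TrainerCfg()
    # fp16 needs dynamic loss scaling; bf16/fp32 runs get no scaler at all
    scaler = GradScaler(enabled=True) if trainer.use_fp16 else None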
@@ -703,11 +703,11 @@ class Trainer:
         if self.use_deepspeed:
             scaled_loss = model.backward(loss)
         else:
-            loss = loss / self.accum_grad
-            if self.use_fp16 or self.use_bf16:
-                scaler.scale(loss).backward()
+            scaled_loss = loss / self.accum_grad
+            if scaler is not None:
+                scaler.scale(scaled_loss).backward()
             else:
-                loss.backward()
+                scaled_loss.backward()

     def update_step(self, model, optim, scheduler, scaler, loss_dict=None):
         batch_idx = loss_dict["batch_idx"]
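The backward step above divides the loss by the gradient-accumulation factor and then branches on whether a scaler exists. A minimal, self-contained sketch of that accumulate-then-backward pattern, using a toy model and optimizer rather than FunASR objects:

    import torch

    accum_grad = 4                         # micro-batches per parameter update
    model = torch.nn.Linear(8, 1)
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = None                          # would be a GradScaler under fp16

    x = torch.randn(2, 8)
    loss = model(x).pow(2).mean()

    scaled_loss = loss / accum_grad        # keep accumulated grads comparable to a full batch
    if scaler is not None:
        scaler.scale(scaled_loss).backward()   # fp16: scale the loss before backward
    else:
        scaled_loss.backward()                 # bf16/fp32: plain backward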
@@ -732,7 +732,7 @@ class Trainer:
         # Execute an optimization step (update model parameters)
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-        if self.use_fp16 or self.use_bf16:
+        if scaler is not None:
             scaler.step(optim)
             scaler.update()
         else:
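The update step takes the same guard: with a scaler it unscales the gradients, steps, and adjusts the loss scale; otherwise it is a plain optimizer step. A minimal sketch with toy objects (not the FunASR trainer):

    import torch

    model = torch.nn.Linear(8, 1)
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = None                          # a GradScaler when fp16 is enabled

    model(torch.randn(2, 8)).sum().backward()
    if scaler is not None:
        scaler.step(optim)                 # unscales grads, skips the step on inf/nan
        scaler.update()                    # adjusts the loss scale for the next iteration
    else:
        optim.step()
    optim.zero_grad()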