commit c553a8db17 (parent 664c400545)
Author: 游雁
Date: 2024-06-13 17:38:01 +08:00

@@ -29,8 +29,8 @@ def maybe_autocast(dtype=None, use_deepspeed=False):
         with torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False):
             yield
     else:
-        if dtype == torch.float16:
-            with autocast(enabled=True):
+        if dtype == torch.float16 or dtype == torch.bfloat16:
+            with autocast(enabled=True, dtype=dtype):
                 yield
         else:
             yield
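
After this hunk, bf16 takes the same autocast path as fp16, with the requested dtype forwarded explicitly instead of relying on the fp16 default. A minimal usage sketch of the helper after the change (assuming maybe_autocast is wrapped with contextlib.contextmanager, as the yield statements suggest; model, criterion, batch, and target are illustrative stand-ins):

    import torch

    with maybe_autocast(dtype=torch.bfloat16):
        logits = model(batch)             # matmul-heavy ops run in bf16
        loss = criterion(logits, target)  # autocast keeps precision-sensitive ops in fp32
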
@@ -60,6 +60,7 @@ class Trainer:
         use_ddp: bool = False,
         use_fsdp: bool = False,
         use_fp16: bool = False,
+        use_bf16: bool = False,
         use_deepspeed: bool = False,
         output_dir: str = "./",
         **kwargs,
@@ -98,8 +99,11 @@ class Trainer:
         self.batch_total = 0
         self.dtype = torch.float32
         self.use_fp16 = use_fp16
+        self.use_bf16 = use_bf16
         if self.use_fp16:
             self.dtype = torch.float16
+        if self.use_bf16:
+            self.dtype = torch.bfloat16
         self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
         self.validate_interval = kwargs.get("validate_interval", 5000)
         self.keep_nbest_models = kwargs.get("keep_nbest_models", 500)
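
The dtype resolution above is order-dependent: fp32 is the default, use_fp16 switches to float16, and the use_bf16 check runs last, so if a caller set both flags, bfloat16 would win. A standalone sketch of that logic (resolve_dtype is a hypothetical name, not part of the Trainer):

    import torch

    def resolve_dtype(use_fp16: bool, use_bf16: bool) -> torch.dtype:
        # Mirrors the constructor logic above; the later bf16 check
        # takes precedence when both flags are set.
        dtype = torch.float32
        if use_fp16:
            dtype = torch.float16
        if use_bf16:
            dtype = torch.bfloat16
        return dtype

    assert resolve_dtype(True, True) is torch.bfloat16
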
@@ -678,7 +682,7 @@ class Trainer:
             scaled_loss = model.backward(loss)
         else:
             loss = loss / self.accum_grad
-            if self.use_fp16:
+            if self.use_fp16 or self.use_bf16:
                 scaler.scale(loss).backward()
             else:
                 loss.backward()
@@ -706,7 +710,7 @@ class Trainer:
         # Execute an optimization step (update model parameters)
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-        if self.use_fp16:
+        if self.use_fp16 or self.use_bf16:
             scaler.step(optim)
             scaler.update()
         else:
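
Taken together, the backward and optimizer-step hunks route bf16 through the same GradScaler path as fp16. A self-contained sketch of that pattern (requires a CUDA device; the model, optimizer, and input here are illustrative, not the Trainer's):

    import torch

    model = torch.nn.Linear(8, 8).cuda()
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()

    x = torch.randn(4, 8, device="cuda")
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        loss = model(x).pow(2).mean()

    # Same sequence the Trainer runs when use_fp16 or use_bf16 is set.
    scaler.scale(loss).backward()
    scaler.step(optim)
    scaler.update()
    optim.zero_grad()

Strictly speaking, loss scaling exists to keep float16 gradients from underflowing; bfloat16 shares float32's exponent range, so the scaler is not required for bf16, though reusing it as this commit does is harmless and keeps the two mixed-precision paths uniform.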