游雁 2024-06-13 17:38:01 +08:00
parent 664c400545
commit c553a8db17

@@ -29,8 +29,8 @@ def maybe_autocast(dtype=None, use_deepspeed=False):
         with torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False):
             yield
     else:
-        if dtype == torch.float16:
-            with autocast(enabled=True):
+        if dtype == torch.float16 or dtype == torch.bfloat16:
+            with autocast(enabled=True, dtype=dtype):
                 yield
         else:
             yield
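For context, a minimal sketch of what maybe_autocast looks like once this hunk is applied. The contextmanager decorator, the imports, and the name of the outer branch (use_deepspeed) are assumptions based on the function signature and the context lines; only the edited branch is taken verbatim from the diff.

from contextlib import contextmanager

import torch
from torch.cuda.amp import autocast


@contextmanager
def maybe_autocast(dtype=None, use_deepspeed=False):
    if use_deepspeed:  # assumed branch name; the hunk starts inside it
        # deepspeed path: autocast to the requested dtype directly
        with torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False):
            yield
    else:
        # after this commit both half-precision dtypes go through autocast,
        # and the target dtype is forwarded instead of the fp16 default
        if dtype == torch.float16 or dtype == torch.bfloat16:
            with autocast(enabled=True, dtype=dtype):
                yield
        else:
            # full precision: run the block unchanged
            yield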
@@ -60,6 +60,7 @@ class Trainer:
         use_ddp: bool = False,
         use_fsdp: bool = False,
         use_fp16: bool = False,
+        use_bf16: bool = False,
         use_deepspeed: bool = False,
         output_dir: str = "./",
         **kwargs,
@@ -98,8 +99,11 @@ class Trainer:
         self.batch_total = 0
         self.dtype = torch.float32
         self.use_fp16 = use_fp16
+        self.use_bf16 = use_bf16
         if self.use_fp16:
             self.dtype = torch.float16
+        if self.use_bf16:
+            self.dtype = torch.bfloat16
         self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
         self.validate_interval = kwargs.get("validate_interval", 5000)
         self.keep_nbest_models = kwargs.get("keep_nbest_models", 500)
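One point worth noting about this hunk: the use_bf16 check runs after the use_fp16 check, so torch.bfloat16 wins if both flags are passed. A hypothetical, stripped-down illustration (the _DtypeDemo class exists only for this sketch):

import torch


class _DtypeDemo:
    """Reproduces only the dtype-selection lines from the hunk above."""

    def __init__(self, use_fp16=False, use_bf16=False):
        self.dtype = torch.float32
        self.use_fp16 = use_fp16
        self.use_bf16 = use_bf16
        if self.use_fp16:
            self.dtype = torch.float16
        if self.use_bf16:
            self.dtype = torch.bfloat16


# bf16 takes precedence because its assignment runs last
assert _DtypeDemo(use_fp16=True, use_bf16=True).dtype is torch.bfloat16
assert _DtypeDemo(use_fp16=True).dtype is torch.float16
assert _DtypeDemo().dtype is torch.float32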
@@ -678,7 +682,7 @@ class Trainer:
             scaled_loss = model.backward(loss)
         else:
             loss = loss / self.accum_grad
-            if self.use_fp16:
+            if self.use_fp16 or self.use_bf16:
                 scaler.scale(loss).backward()
             else:
                 loss.backward()
@@ -706,7 +710,7 @@ class Trainer:
         # Execute an optimization step (update model parameters)
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-        if self.use_fp16:
+        if self.use_fp16 or self.use_bf16:
             scaler.step(optim)
             scaler.update()
         else: