Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00
Commit message: decoding
Commit: c553a8db17 (parent: 664c400545)
@@ -29,8 +29,8 @@ def maybe_autocast(dtype=None, use_deepspeed=False):
         with torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False):
             yield
     else:
-        if dtype == torch.float16:
-            with autocast(enabled=True):
+        if dtype == torch.float16 or dtype == torch.bfloat16:
+            with autocast(enabled=True, dtype=dtype):
                 yield
         else:
             yield
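The first hunk widens the half-precision branch so bfloat16 also enables autocast, and forwards the requested dtype instead of relying on the float16 default. A minimal sketch of an equivalent context manager under that logic (the helper name is illustrative, not FunASR's exact code):

from contextlib import contextmanager

import torch
from torch.cuda.amp import autocast


@contextmanager
def autocast_if_half(dtype=None):
    # Mirrors the patched branch: autocast is enabled for float16 and
    # bfloat16 alike, and the dtype is passed through explicitly.
    if dtype == torch.float16 or dtype == torch.bfloat16:
        with autocast(enabled=True, dtype=dtype):
            yield
    else:
        yield

Entering the context with torch.bfloat16 makes CUDA ops inside the block run in bfloat16; any other dtype falls through and leaves the computation in full precision.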
@@ -60,6 +60,7 @@ class Trainer:
         use_ddp: bool = False,
         use_fsdp: bool = False,
         use_fp16: bool = False,
+        use_bf16: bool = False,
         use_deepspeed: bool = False,
         output_dir: str = "./",
         **kwargs,
@@ -98,8 +99,11 @@ class Trainer:
         self.batch_total = 0
         self.dtype = torch.float32
         self.use_fp16 = use_fp16
+        self.use_bf16 = use_bf16
         if self.use_fp16:
             self.dtype = torch.float16
+        if self.use_bf16:
+            self.dtype = torch.bfloat16
         self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
         self.validate_interval = kwargs.get("validate_interval", 5000)
         self.keep_nbest_models = kwargs.get("keep_nbest_models", 500)
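In the constructor, the new use_bf16 flag is stored and, because it is checked after use_fp16, it takes precedence when both flags are set. A standalone restatement of that dtype-selection logic (the function name is illustrative):

import torch


def select_dtype(use_fp16: bool = False, use_bf16: bool = False) -> torch.dtype:
    # Same ordering as Trainer.__init__ above: float32 by default,
    # float16 if requested, and bfloat16 last so it wins over float16.
    dtype = torch.float32
    if use_fp16:
        dtype = torch.float16
    if use_bf16:
        dtype = torch.bfloat16
    return dtype


print(select_dtype(use_bf16=True))                 # torch.bfloat16
print(select_dtype(use_fp16=True, use_bf16=True))  # torch.bfloat16 (bf16 wins)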
@@ -678,7 +682,7 @@ class Trainer:
                scaled_loss = model.backward(loss)
            else:
                loss = loss / self.accum_grad
-               if self.use_fp16:
+               if self.use_fp16 or self.use_bf16:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()
@@ -706,7 +710,7 @@ class Trainer:
                # Execute an optimization step (update model parameters)
                if self.use_ddp or self.use_fsdp:
                    dist.barrier()
-               if self.use_fp16:
+               if self.use_fp16 or self.use_bf16:
                    scaler.step(optim)
                    scaler.update()
                else:
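The last two hunks let bf16 reuse the fp16 path: the loss goes through scaler.scale(...).backward() and the optimizer is driven via scaler.step and scaler.update. A self-contained sketch of that pattern, assuming a CUDA device (the model, optimizer, and accum_grad below are placeholders, not the Trainer's real state):

import torch
from torch.cuda.amp import GradScaler, autocast

use_fp16, use_bf16 = False, True
dtype = torch.float16 if use_fp16 else torch.bfloat16
accum_grad = 1

model = torch.nn.Linear(8, 2).cuda()
optim = torch.optim.SGD(model.parameters(), lr=0.1)
# Loss scaling mainly matters for float16; it is harmless for bfloat16,
# which is why both flags can share the same scaler-based branch.
scaler = GradScaler()

x = torch.randn(4, 8, device="cuda")
with autocast(enabled=True, dtype=dtype):
    loss = model(x).pow(2).mean()

loss = loss / accum_grad
if use_fp16 or use_bf16:
    # Same branch the diff takes: scale, backward, then step via the scaler.
    scaler.scale(loss).backward()
    scaler.step(optim)
    scaler.update()
else:
    loss.backward()
    optim.step()
optim.zero_grad()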