Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15 14:48:36 +08:00)
decoding
This commit is contained in:
parent 664c400545
commit c553a8db17
@@ -29,8 +29,8 @@ def maybe_autocast(dtype=None, use_deepspeed=False):
         with torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False):
             yield
     else:
-        if dtype == torch.float16:
-            with autocast(enabled=True):
+        if dtype == torch.float16 or dtype == torch.bfloat16:
+            with autocast(enabled=True, dtype=dtype):
                 yield
         else:
             yield
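The first hunk widens the autocast branch so that bfloat16 takes the same path as float16, with the requested dtype now passed to autocast explicitly. A minimal sketch (illustrative, not repository code) of what that enables on a CUDA device, with a toy nn.Linear standing in for the model:

import torch
from torch.cuda.amp import autocast

linear = torch.nn.Linear(8, 8).cuda()            # master weights stay float32
x = torch.randn(4, 8, device="cuda")

with autocast(enabled=True, dtype=torch.bfloat16):
    y = linear(x)                                 # matmul runs in bfloat16

print(y.dtype)              # torch.bfloat16
print(linear.weight.dtype)  # torch.float32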
@@ -60,6 +60,7 @@ class Trainer:
         use_ddp: bool = False,
         use_fsdp: bool = False,
         use_fp16: bool = False,
+        use_bf16: bool = False,
         use_deepspeed: bool = False,
         output_dir: str = "./",
         **kwargs,
@@ -98,8 +99,11 @@ class Trainer:
         self.batch_total = 0
         self.dtype = torch.float32
         self.use_fp16 = use_fp16
+        self.use_bf16 = use_bf16
         if self.use_fp16:
             self.dtype = torch.float16
+        if self.use_bf16:
+            self.dtype = torch.bfloat16
         self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
         self.validate_interval = kwargs.get("validate_interval", 5000)
         self.keep_nbest_models = kwargs.get("keep_nbest_models", 500)
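Because the new use_bf16 check runs after the existing use_fp16 check, bfloat16 wins when both flags are set. A short sketch of that selection order, using an illustrative helper (pick_dtype is not a function in the repository):

import torch

def pick_dtype(use_fp16: bool, use_bf16: bool) -> torch.dtype:
    dtype = torch.float32
    if use_fp16:
        dtype = torch.float16
    if use_bf16:
        dtype = torch.bfloat16  # evaluated last, so it takes precedence
    return dtype

assert pick_dtype(False, False) is torch.float32
assert pick_dtype(True, False) is torch.float16
assert pick_dtype(True, True) is torch.bfloat16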
@@ -678,7 +682,7 @@ class Trainer:
             scaled_loss = model.backward(loss)
         else:
             loss = loss / self.accum_grad
-            if self.use_fp16:
+            if self.use_fp16 or self.use_bf16:
                 scaler.scale(loss).backward()
             else:
                 loss.backward()
@@ -706,7 +710,7 @@ class Trainer:
         # Execute an optimization step (update model parameters)
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-        if self.use_fp16:
+        if self.use_fp16 or self.use_bf16:
             scaler.step(optim)
             scaler.update()
         else:
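The last two hunks route the bf16 case through the same GradScaler calls already used for fp16. A minimal, self-contained sketch (not the Trainer code; the model, optimizer, and accum_grad below are stand-ins) of that shared scale, backward, step, and update flow on a CUDA device:

import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(8, 1).cuda()
optim = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()   # for bf16 the scaling is largely a formality, but it keeps one code path
accum_grad = 2

for _ in range(accum_grad):
    x = torch.randn(4, 8, device="cuda")
    with autocast(enabled=True, dtype=torch.bfloat16):
        loss = model(x).pow(2).mean()
    loss = loss / accum_grad          # normalize across accumulation steps
    scaler.scale(loss).backward()

scaler.step(optim)                    # unscales gradients, then calls optim.step()
scaler.update()
optim.zero_grad()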