mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
add
This commit is contained in:
parent
c62d0537a6
commit
4580da347d
@ -95,7 +95,6 @@ class Trainer:
|
||||
self.start_epoch = 0
|
||||
self.max_epoch = kwargs.get("max_epoch", 100)
|
||||
|
||||
# self.kwargs = kwargs
|
||||
self.log_interval = kwargs.get("log_interval", 50)
|
||||
self.batch_total = 0
|
||||
self.dtype = torch.float32
|
||||
@ -151,6 +150,8 @@ class Trainer:
|
||||
self.writer = None
|
||||
|
||||
self.use_deepspeed = use_deepspeed
|
||||
if self.use_deepspeed:
|
||||
self.accum_grad = 1
|
||||
self.deepspeed_config = kwargs.get("deepspeed_config", "")
|
||||
excludes = kwargs.get("excludes", None)
|
||||
if excludes is not None:
|
||||
@ -163,6 +164,7 @@ class Trainer:
|
||||
effective_save_name_excludes = effective_save_name_excludes.split(",")
|
||||
self.effective_save_name_excludes = effective_save_name_excludes
|
||||
self.use_lora = kwargs.get("use_lora", False)
|
||||
self.kwargs = kwargs
|
||||
|
||||
def save_checkpoint(
|
||||
self,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user