This commit is contained in:
游雁 2024-08-08 19:36:36 +08:00
parent c62d0537a6
commit 4580da347d

View File

@ -95,7 +95,6 @@ class Trainer:
self.start_epoch = 0
self.max_epoch = kwargs.get("max_epoch", 100)
# self.kwargs = kwargs
self.log_interval = kwargs.get("log_interval", 50)
self.batch_total = 0
self.dtype = torch.float32
@ -151,6 +150,8 @@ class Trainer:
self.writer = None
self.use_deepspeed = use_deepspeed
if self.use_deepspeed:
self.accum_grad = 1
self.deepspeed_config = kwargs.get("deepspeed_config", "")
excludes = kwargs.get("excludes", None)
if excludes is not None:
@ -163,6 +164,7 @@ class Trainer:
effective_save_name_excludes = effective_save_name_excludes.split(",")
self.effective_save_name_excludes = effective_save_name_excludes
self.use_lora = kwargs.get("use_lora", False)
self.kwargs = kwargs
def save_checkpoint(
self,