diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py index e0b6def54..01e2924d9 100644 --- a/funasr/train_utils/trainer.py +++ b/funasr/train_utils/trainer.py @@ -465,7 +465,8 @@ class Trainer: batch_num_epoch = len(dataloader_train) self.log( epoch, - batch_idx + kwargs.get("start_step", 0), + batch_idx, + log_step=batch_idx + kwargs.get("start_step", 0), step_in_epoch=self.step_in_epoch, batch_num_epoch=batch_num_epoch, lr=lr, @@ -634,11 +635,12 @@ class Trainer: tag="train", data_split_i=0, data_split_num=1, + log_step=None, **kwargs, ): if (batch_idx + 1) % self.log_interval == 0: - + batch_idx = log_step if log_step is not None else batch_idx gpu_info = ( "GPU, memory: usage: {:.3f} GB, " "peak: {:.3f} GB, "