diff --git a/examples/aishell/transformer/run.sh b/examples/aishell/transformer/run.sh index 80f81b5f8..da6d4d6ab 100755 --- a/examples/aishell/transformer/run.sh +++ b/examples/aishell/transformer/run.sh @@ -39,7 +39,7 @@ train_set=train valid_set=dev test_sets="dev test" -config=paraformer_conformer_12e_6d_2048_256.yaml +config=transformer_12e_6d_2048_256.yaml model_dir="baseline_$(basename "${config}" .yaml)_${lang}_${token_type}_${tag}" diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py index cc7b215d2..f37538411 100644 --- a/funasr/train_utils/trainer.py +++ b/funasr/train_utils/trainer.py @@ -276,7 +276,7 @@ class Trainer: description = ( f"rank: {self.local_rank}, " f"epoch: {epoch}/{self.max_epoch}, " - f"step: {batch_idx}/{len(self.dataloader_train)}, total: {self.batch_total}, " + f"step: {batch_idx+1}/{len(self.dataloader_train)}, total: {self.batch_total}, " f"(loss: {loss.detach().cpu().item():.3f}), " f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, " f"{speed_stats}, " @@ -341,7 +341,7 @@ class Trainer: description = ( f"rank: {self.local_rank}, " f"validation epoch: {epoch}/{self.max_epoch}, " - f"step: {batch_idx}/{len(self.dataloader_val)}, " + f"step: {batch_idx+1}/{len(self.dataloader_val)}, " f"(loss: {loss.detach().cpu().item():.3f}), " f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, " f"{speed_stats}, "