diff --git a/funasr/bin/train.py b/funasr/bin/train.py index 7695e51c5..643df71d3 100644 --- a/funasr/bin/train.py +++ b/funasr/bin/train.py @@ -223,6 +223,7 @@ def main(**kwargs): torch.cuda.empty_cache() + trainer.start_data_split_i = 0 trainer.validate_epoch( model=model, dataloader_val=dataloader_val, epoch=epoch + 1, writer=writer )