diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py index 66ae7ed62..9ef9dc9e4 100644 --- a/funasr/train_utils/trainer_ds.py +++ b/funasr/train_utils/trainer_ds.py @@ -593,6 +593,8 @@ class Trainer: time_beg = time.perf_counter() time5 = time_beg for batch_idx, batch in enumerate(dataloader_train): + if batch_idx == 0 and (self.use_ddp or self.use_fsdp or self.use_deepspeed): + dist.barrier() self.batch_total += 1 self.step_in_epoch += 1 loss_dict = {