diff --git a/funasr/bin/train.py b/funasr/bin/train.py index 9c8f672e7..27d5c4ad6 100755 --- a/funasr/bin/train.py +++ b/funasr/bin/train.py @@ -444,7 +444,7 @@ if __name__ == '__main__': # ddp init os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) - args.distributed = args.dist_world_size > 1 + args.distributed = args.ngpu > 1 or args.dist_world_size > 1 distributed_option = build_distributed(args) # for logging