From d0e4e2ad21d7d607501b11f1622079eb29ca3c11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Sun, 18 Aug 2024 14:04:02 +0800 Subject: [PATCH] add --- funasr/bin/train_ds.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/funasr/bin/train_ds.py b/funasr/bin/train_ds.py index e9835fb45..a4a73efe7 100644 --- a/funasr/bin/train_ds.py +++ b/funasr/bin/train_ds.py @@ -136,7 +136,7 @@ def main(**kwargs): **kwargs.get("train_conf"), ) - model = trainer.warp_model(model) + model = trainer.warp_model(model, **kwargs) kwargs["device"] = int(os.environ.get("LOCAL_RANK", 0)) trainer.device = int(os.environ.get("LOCAL_RANK", 0))