diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py index 9e13260c8..b14153233 100644 --- a/funasr/export/export_model.py +++ b/funasr/export/export_model.py @@ -24,6 +24,7 @@ class ModelExport: fallback_num: int = 0, audio_in: str = None, calib_num: int = 200, + model_revision: str = None, ): assert check_argument_types() self.set_all_random_seed(0) @@ -41,6 +42,7 @@ class ModelExport: self.frontend = None self.audio_in = audio_in self.calib_num = calib_num + self.model_revision = model_revision def _export( @@ -171,7 +173,7 @@ class ModelExport: model_dir = tag_name if model_dir.startswith('damo'): from modelscope.hub.snapshot_download import snapshot_download - model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir) + model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir, revision=self.model_revision) self.cache_dir = model_dir if mode is None: @@ -271,6 +273,7 @@ if __name__ == '__main__': parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number') parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]') parser.add_argument('--calib_num', type=int, default=200, help='calib max num') + parser.add_argument('--model_revision', type=str, default=None, help='model_revision') args = parser.parse_args() export_model = ModelExport( @@ -281,5 +284,6 @@ if __name__ == '__main__': fallback_num=args.fallback_num, audio_in=args.audio_in, calib_num=args.calib_num, + model_revision=args.model_revision, ) export_model.export(args.model_name) diff --git a/funasr/utils/runtime_sdk_download_tool.py b/funasr/utils/runtime_sdk_download_tool.py index dbddd553e..f8d4bc921 100644 --- a/funasr/utils/runtime_sdk_download_tool.py +++ b/funasr/utils/runtime_sdk_download_tool.py @@ -11,6 +11,7 @@ parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]') parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model') parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number') parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]') +parser.add_argument('--model_revision', type=str, default=None, help='model_revision') parser.add_argument('--calib_num', type=int, default=200, help='calib max num') args = parser.parse_args() @@ -18,7 +19,7 @@ model_dir = args.model_name if not Path(args.model_name).exists(): from modelscope.hub.snapshot_download import snapshot_download try: - model_dir = snapshot_download(args.model_name, cache_dir=args.export_dir) + model_dir = snapshot_download(args.model_name, cache_dir=args.export_dir, revision=args.model_revision) except: raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format \ (model_dir)