sdk utils revison

This commit is contained in:
游雁 2023-06-28 14:45:54 +08:00
parent 7da5b31e25
commit f4eb3174b3
2 changed files with 7 additions and 2 deletions

View File

@ -24,6 +24,7 @@ class ModelExport:
fallback_num: int = 0,
audio_in: str = None,
calib_num: int = 200,
model_revision: str = None,
):
assert check_argument_types()
self.set_all_random_seed(0)
@ -41,6 +42,7 @@ class ModelExport:
self.frontend = None
self.audio_in = audio_in
self.calib_num = calib_num
self.model_revision = model_revision
def _export(
@ -171,7 +173,7 @@ class ModelExport:
model_dir = tag_name
if model_dir.startswith('damo'):
from modelscope.hub.snapshot_download import snapshot_download
model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir)
model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir, revision=self.model_revision)
self.cache_dir = model_dir
if mode is None:
@ -271,6 +273,7 @@ if __name__ == '__main__':
parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
args = parser.parse_args()
export_model = ModelExport(
@ -281,5 +284,6 @@ if __name__ == '__main__':
fallback_num=args.fallback_num,
audio_in=args.audio_in,
calib_num=args.calib_num,
model_revision=args.model_revision,
)
export_model.export(args.model_name)

View File

@ -11,6 +11,7 @@ parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
args = parser.parse_args()
@ -18,7 +19,7 @@ model_dir = args.model_name
if not Path(args.model_name).exists():
from modelscope.hub.snapshot_download import snapshot_download
try:
model_dir = snapshot_download(args.model_name, cache_dir=args.export_dir)
model_dir = snapshot_download(args.model_name, cache_dir=args.export_dir, revision=args.model_revision)
except:
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format \
(model_dir)