mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
sdk utils revison
This commit is contained in:
parent
7da5b31e25
commit
f4eb3174b3
@ -24,6 +24,7 @@ class ModelExport:
|
||||
fallback_num: int = 0,
|
||||
audio_in: str = None,
|
||||
calib_num: int = 200,
|
||||
model_revision: str = None,
|
||||
):
|
||||
assert check_argument_types()
|
||||
self.set_all_random_seed(0)
|
||||
@ -41,6 +42,7 @@ class ModelExport:
|
||||
self.frontend = None
|
||||
self.audio_in = audio_in
|
||||
self.calib_num = calib_num
|
||||
self.model_revision = model_revision
|
||||
|
||||
|
||||
def _export(
|
||||
@ -171,7 +173,7 @@ class ModelExport:
|
||||
model_dir = tag_name
|
||||
if model_dir.startswith('damo'):
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir)
|
||||
model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir, revision=self.model_revision)
|
||||
self.cache_dir = model_dir
|
||||
|
||||
if mode is None:
|
||||
@ -271,6 +273,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
|
||||
parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
|
||||
parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
|
||||
parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
|
||||
args = parser.parse_args()
|
||||
|
||||
export_model = ModelExport(
|
||||
@ -281,5 +284,6 @@ if __name__ == '__main__':
|
||||
fallback_num=args.fallback_num,
|
||||
audio_in=args.audio_in,
|
||||
calib_num=args.calib_num,
|
||||
model_revision=args.model_revision,
|
||||
)
|
||||
export_model.export(args.model_name)
|
||||
|
||||
@ -11,6 +11,7 @@ parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
|
||||
parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
|
||||
parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
|
||||
parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
|
||||
parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
|
||||
parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -18,7 +19,7 @@ model_dir = args.model_name
|
||||
if not Path(args.model_name).exists():
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
try:
|
||||
model_dir = snapshot_download(args.model_name, cache_dir=args.export_dir)
|
||||
model_dir = snapshot_download(args.model_name, cache_dir=args.export_dir, revision=args.model_revision)
|
||||
except:
|
||||
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format \
|
||||
(model_dir)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user