diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index f8736e93a..f6c5504b8 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -1640,7 +1640,7 @@ class Speech2TextSAASR:
         assert check_argument_types()
 
         # 1. Build ASR model
-        from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch
+        from funasr.tasks.sa_asr import ASRTask
         scorers = {}
         asr_model, asr_train_args = ASRTask.build_model_from_file(
             asr_train_config, asr_model_file, cmvn_file, device
@@ -1682,6 +1682,7 @@ class Speech2TextSAASR:
         # 4. Build BeamSearch object
         # transducer is not supported now
         beam_search_transducer = None
+        from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch
 
         weights = dict(
             decoder=1.0 - ctc_weight,
diff --git a/funasr/bin/sa_asr_train.py b/funasr/bin/sa_asr_train.py
new file mode 100755
index 000000000..67106cf48
--- /dev/null
+++ b/funasr/bin/sa_asr_train.py
@@ -0,0 +1,50 @@
+# -*- encoding: utf-8 -*-
+#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import os
+
+from funasr.tasks.sa_asr import ASRTask
+
+
+# for ASR Training
+def parse_args():
+    parser = ASRTask.get_parser()
+    parser.add_argument(
+        "--gpu_id",
+        type=int,
+        default=0,
+        help="local gpu id.",
+    )
+    args = parser.parse_args()
+    return args
+
+
+def main(args=None, cmd=None):
+    # for ASR Training
+    ASRTask.main(args=args, cmd=cmd)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    # setup local gpu_id
+    if args.ngpu > 0:
+        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+
+    # DDP settings
+    if args.ngpu > 1:
+        args.distributed = True
+    else:
+        args.distributed = False
+    assert args.num_worker_count == 1
+
+    # re-compute batch size: when dataset type is small
+    if args.dataset_type == "small":
+        if args.batch_size is not None and args.ngpu > 0:
+            args.batch_size = args.batch_size * args.ngpu
+        if args.batch_bins is not None and args.ngpu > 0:
+            args.batch_bins = args.batch_bins * args.ngpu
+
+    main(args=args)