inference

This commit is contained in:
游雁 2023-05-16 19:34:45 +08:00
parent 0271fbe4fd
commit f01980d169
2 changed files with 52 additions and 1 deletions

View File

@ -1640,7 +1640,7 @@ class Speech2TextSAASR:
assert check_argument_types()
# 1. Build ASR model
from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch
from funasr.tasks.sa_asr import ASRTask
scorers = {}
asr_model, asr_train_args = ASRTask.build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
@ -1682,6 +1682,7 @@ class Speech2TextSAASR:
# 4. Build BeamSearch object
# transducer is not supported now
beam_search_transducer = None
from funasr.modules.beam_search.beam_search_sa_asr import BeamSearch
weights = dict(
decoder=1.0 - ctc_weight,

50
funasr/bin/sa_asr_train.py Executable file
View File

@ -0,0 +1,50 @@
# -*- encoding: utf-8 -*-
#!/usr/bin/env python3
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import os
from funasr.tasks.sa_asr import ASRTask
# for ASR Training
def parse_args():
parser = ASRTask.get_parser()
parser.add_argument(
"--gpu_id",
type=int,
default=0,
help="local gpu id.",
)
args = parser.parse_args()
return args
def main(args=None, cmd=None):
# for ASR Training
ASRTask.main(args=args, cmd=cmd)
if __name__ == '__main__':
args = parse_args()
# setup local gpu_id
if args.ngpu > 0:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
# DDP settings
if args.ngpu > 1:
args.distributed = True
else:
args.distributed = False
assert args.num_worker_count == 1
# re-compute batch size: when dataset type is small
if args.dataset_type == "small":
if args.batch_size is not None and args.ngpu > 0:
args.batch_size = args.batch_size * args.ngpu
if args.batch_bins is not None and args.ngpu > 0:
args.batch_bins = args.batch_bins * args.ngpu
main(args=args)