This commit is contained in:
游雁 2023-05-16 23:48:00 +08:00
parent dfab707845
commit 3d9f094e96
2 changed files with 35 additions and 25 deletions

View File

@ -34,6 +34,8 @@ def main(args=None, cmd=None):
from funasr.tasks.asr import ASRTask
if args.mode == "paraformer":
from funasr.tasks.asr import ASRTaskParaformer as ASRTask
if args.mode == "uniasr":
from funasr.tasks.asr import ASRTaskUniASR as ASRTask
ASRTask.main(args=args, cmd=cmd)
@ -42,8 +44,7 @@ if __name__ == '__main__':
args = parse_args()
# setup local gpu_id
if args.ngpu > 0:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
# DDP settings
if args.ngpu > 1:
@ -54,9 +55,10 @@ if __name__ == '__main__':
# re-compute batch size: when dataset type is small
if args.dataset_type == "small":
if args.batch_size is not None and args.ngpu > 0:
if args.batch_size is not None:
args.batch_size = args.batch_size * args.ngpu
if args.batch_bins is not None and args.ngpu > 0:
if args.batch_bins is not None:
args.batch_bins = args.batch_bins * args.ngpu
main(args=args)

View File

@ -282,6 +282,14 @@ class ASRTask(AbsTask):
decoder_choices,
# --predictor and --predictor_conf
predictor_choices,
# --encoder2 and --encoder2_conf
encoder_choices2,
# --decoder2 and --decoder2_conf
decoder_choices2,
# --predictor2 and --predictor2_conf
predictor_choices2,
# --stride_conv and --stride_conv_conf
stride_conv_choices,
]
# If you need to modify train() or eval() procedures, change Trainer class here
@ -901,27 +909,27 @@ class ASRTaskParaformer(ASRTask):
# If you need more than one optimizers, change this value
num_optimizers: int = 1
# Add variable objects configurations
class_choices_list = [
# --frontend and --frontend_conf
frontend_choices,
# --specaug and --specaug_conf
specaug_choices,
# --normalize and --normalize_conf
normalize_choices,
# --model and --model_conf
model_choices,
# --preencoder and --preencoder_conf
preencoder_choices,
# --encoder and --encoder_conf
encoder_choices,
# --postencoder and --postencoder_conf
postencoder_choices,
# --decoder and --decoder_conf
decoder_choices,
# --predictor and --predictor_conf
predictor_choices,
]
# # Add variable objects configurations
# class_choices_list = [
# # --frontend and --frontend_conf
# frontend_choices,
# # --specaug and --specaug_conf
# specaug_choices,
# # --normalize and --normalize_conf
# normalize_choices,
# # --model and --model_conf
# model_choices,
# # --preencoder and --preencoder_conf
# preencoder_choices,
# # --encoder and --encoder_conf
# encoder_choices,
# # --postencoder and --postencoder_conf
# postencoder_choices,
# # --decoder and --decoder_conf
# decoder_choices,
# # --predictor and --predictor_conf
# predictor_choices,
# ]
# If you need to modify train() or eval() procedures, change Trainer class here
trainer = Trainer