mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
train
This commit is contained in:
parent
dfab707845
commit
3d9f094e96
@ -34,6 +34,8 @@ def main(args=None, cmd=None):
|
||||
from funasr.tasks.asr import ASRTask
|
||||
if args.mode == "paraformer":
|
||||
from funasr.tasks.asr import ASRTaskParaformer as ASRTask
|
||||
if args.mode == "uniasr":
|
||||
from funasr.tasks.asr import ASRTaskUniASR as ASRTask
|
||||
|
||||
ASRTask.main(args=args, cmd=cmd)
|
||||
|
||||
@ -42,8 +44,7 @@ if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
|
||||
# setup local gpu_id
|
||||
if args.ngpu > 0:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
|
||||
|
||||
# DDP settings
|
||||
if args.ngpu > 1:
|
||||
@ -54,9 +55,10 @@ if __name__ == '__main__':
|
||||
|
||||
# re-compute batch size: when dataset type is small
|
||||
if args.dataset_type == "small":
|
||||
if args.batch_size is not None and args.ngpu > 0:
|
||||
if args.batch_size is not None:
|
||||
args.batch_size = args.batch_size * args.ngpu
|
||||
if args.batch_bins is not None and args.ngpu > 0:
|
||||
if args.batch_bins is not None:
|
||||
args.batch_bins = args.batch_bins * args.ngpu
|
||||
|
||||
main(args=args)
|
||||
|
||||
|
||||
@ -282,6 +282,14 @@ class ASRTask(AbsTask):
|
||||
decoder_choices,
|
||||
# --predictor and --predictor_conf
|
||||
predictor_choices,
|
||||
# --encoder2 and --encoder2_conf
|
||||
encoder_choices2,
|
||||
# --decoder2 and --decoder2_conf
|
||||
decoder_choices2,
|
||||
# --predictor2 and --predictor2_conf
|
||||
predictor_choices2,
|
||||
# --stride_conv and --stride_conv_conf
|
||||
stride_conv_choices,
|
||||
]
|
||||
|
||||
# If you need to modify train() or eval() procedures, change Trainer class here
|
||||
@ -901,27 +909,27 @@ class ASRTaskParaformer(ASRTask):
|
||||
# If you need more than one optimizers, change this value
|
||||
num_optimizers: int = 1
|
||||
|
||||
# Add variable objects configurations
|
||||
class_choices_list = [
|
||||
# --frontend and --frontend_conf
|
||||
frontend_choices,
|
||||
# --specaug and --specaug_conf
|
||||
specaug_choices,
|
||||
# --normalize and --normalize_conf
|
||||
normalize_choices,
|
||||
# --model and --model_conf
|
||||
model_choices,
|
||||
# --preencoder and --preencoder_conf
|
||||
preencoder_choices,
|
||||
# --encoder and --encoder_conf
|
||||
encoder_choices,
|
||||
# --postencoder and --postencoder_conf
|
||||
postencoder_choices,
|
||||
# --decoder and --decoder_conf
|
||||
decoder_choices,
|
||||
# --predictor and --predictor_conf
|
||||
predictor_choices,
|
||||
]
|
||||
# # Add variable objects configurations
|
||||
# class_choices_list = [
|
||||
# # --frontend and --frontend_conf
|
||||
# frontend_choices,
|
||||
# # --specaug and --specaug_conf
|
||||
# specaug_choices,
|
||||
# # --normalize and --normalize_conf
|
||||
# normalize_choices,
|
||||
# # --model and --model_conf
|
||||
# model_choices,
|
||||
# # --preencoder and --preencoder_conf
|
||||
# preencoder_choices,
|
||||
# # --encoder and --encoder_conf
|
||||
# encoder_choices,
|
||||
# # --postencoder and --postencoder_conf
|
||||
# postencoder_choices,
|
||||
# # --decoder and --decoder_conf
|
||||
# decoder_choices,
|
||||
# # --predictor and --predictor_conf
|
||||
# predictor_choices,
|
||||
# ]
|
||||
|
||||
# If you need to modify train() or eval() procedures, change Trainer class here
|
||||
trainer = Trainer
|
||||
|
||||
Loading…
Reference in New Issue
Block a user