#!/usr/bin/env python3
# Mirror of https://github.com/modelscope/FunASR (synced 2025-09-15)
#!/usr/bin/env python3
import os


# for ASR Training
def parse_args():
    """Build and parse command-line arguments for ASR training.

    Starts from the standard ``ASRTask`` argument parser and adds a
    ``--gpu_id`` option so the launcher can pin this process to one
    local GPU.

    Returns:
        argparse.Namespace: The parsed command-line arguments.
    """
    parser = ASRTask.get_parser()
    parser.add_argument(
        "--gpu_id",
        type=int,
        default=0,
        help="local gpu id.",
    )
    return parser.parse_args()
|
|
|
|
|
|
def main(args=None, cmd=None):
    """Run ASR training by delegating to :meth:`ASRTask.main`.

    Args:
        args: Pre-parsed argument namespace, or ``None`` to let the
            task parse from ``cmd`` / ``sys.argv``.
        cmd: Optional list of command-line tokens forwarded to the task.
    """
    # for ASR Training
    ASRTask.main(args=args, cmd=cmd)


if __name__ == '__main__':
    args = parse_args()

    # setup local gpu_id: restrict this process to the requested GPU
    # before any CUDA context is created.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)

    # DDP settings: distributed mode only makes sense with >1 GPU.
    if args.ngpu > 1:
        args.distributed = True
    else:
        args.distributed = False
    # NOTE(review): original used `assert`, which is stripped under
    # `python -O`; raise explicitly so the constraint always holds.
    # (Original indentation was mangled; this check is assumed to apply
    # unconditionally — confirm against upstream.)
    if args.num_worker_count != 1:
        raise ValueError(
            "num_worker_count must be 1, got %d" % args.num_worker_count
        )

    # re-compute batch size: when dataset type is small, each GPU worker
    # consumes the configured batch, so scale the totals by ngpu.
    if args.dataset_type == "small":
        if args.batch_size is not None:
            args.batch_size = args.batch_size * args.ngpu
        if args.batch_bins is not None:
            args.batch_bins = args.batch_bins * args.ngpu

    main(args=args)