diff --git a/funasr/build_utils/build_model.py b/funasr/build_utils/build_model.py index 5b1da0ca6..6029faea3 100644 --- a/funasr/build_utils/build_model.py +++ b/funasr/build_utils/build_model.py @@ -1,7 +1,8 @@ from funasr.build_utils.build_asr_model import build_asr_model -from funasr.build_utils.build_pretrain_model import build_pretrain_model from funasr.build_utils.build_lm_model import build_lm_model +from funasr.build_utils.build_pretrain_model import build_pretrain_model from funasr.build_utils.build_punc_model import build_punc_model +from funasr.build_utils.build_vad_model import build_vad_model def build_model(args): diff --git a/funasr/build_utils/build_vad_model.py b/funasr/build_utils/build_vad_model.py new file mode 100644 index 000000000..76eb09b22 --- /dev/null +++ b/funasr/build_utils/build_vad_model.py @@ -0,0 +1,77 @@ +import torch + +from funasr.models.e2e_vad import E2EVadModel +from funasr.models.encoder.fsmn_encoder import FSMN +from funasr.models.frontend.default import DefaultFrontend +from funasr.models.frontend.fused import FusedFrontends +from funasr.models.frontend.s3prl import S3prlFrontend +from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline +from funasr.models.frontend.windowing import SlidingWindow +from funasr.torch_utils.initialize import initialize +from funasr.train.class_choices import ClassChoices + +frontend_choices = ClassChoices( + name="frontend", + classes=dict( + default=DefaultFrontend, + sliding_window=SlidingWindow, + s3prl=S3prlFrontend, + fused=FusedFrontends, + wav_frontend=WavFrontend, + wav_frontend_online=WavFrontendOnline, + ), + default="default", +) +encoder_choices = ClassChoices( + "encoder", + classes=dict( + fsmn=FSMN, + ), + type_check=torch.nn.Module, + default="fsmn", +) +model_choices = ClassChoices( + "model", + classes=dict( + e2evad=E2EVadModel, + ), + default="e2evad", +) + +class_choices_list = [ + # --frontend and --frontend_conf + frontend_choices, + # --encoder and --encoder_conf + encoder_choices, + # --model and --model_conf + model_choices, +] + + +def build_vad_model(args): + # frontend + if args.input_size is None: + frontend_class = frontend_choices.get_class(args.frontend) + if args.frontend == 'wav_frontend': + frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf) + else: + frontend = frontend_class(**args.frontend_conf) + input_size = frontend.output_size() + else: + args.frontend = None + args.frontend_conf = {} + frontend = None + input_size = args.input_size + + # encoder + encoder_class = encoder_choices.get_class(args.encoder) + encoder = encoder_class(**args.encoder_conf) + + model_class = model_choices.get_class(args.model) + model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend) + + # initialize + if args.init is not None: + initialize(model, args.init) + + return model