mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
77e4b6899b
commit
b8201e02ba
@ -1,7 +1,8 @@
|
|||||||
from funasr.build_utils.build_asr_model import build_asr_model
|
from funasr.build_utils.build_asr_model import build_asr_model
|
||||||
from funasr.build_utils.build_pretrain_model import build_pretrain_model
|
|
||||||
from funasr.build_utils.build_lm_model import build_lm_model
|
from funasr.build_utils.build_lm_model import build_lm_model
|
||||||
|
from funasr.build_utils.build_pretrain_model import build_pretrain_model
|
||||||
from funasr.build_utils.build_punc_model import build_punc_model
|
from funasr.build_utils.build_punc_model import build_punc_model
|
||||||
|
from funasr.build_utils.build_vad_model import build_vad_model
|
||||||
|
|
||||||
|
|
||||||
def build_model(args):
|
def build_model(args):
|
||||||
|
|||||||
77
funasr/build_utils/build_vad_model.py
Normal file
77
funasr/build_utils/build_vad_model.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from funasr.models.e2e_vad import E2EVadModel
|
||||||
|
from funasr.models.encoder.fsmn_encoder import FSMN
|
||||||
|
from funasr.models.frontend.default import DefaultFrontend
|
||||||
|
from funasr.models.frontend.fused import FusedFrontends
|
||||||
|
from funasr.models.frontend.s3prl import S3prlFrontend
|
||||||
|
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
|
||||||
|
from funasr.models.frontend.windowing import SlidingWindow
|
||||||
|
from funasr.torch_utils.initialize import initialize
|
||||||
|
from funasr.train.class_choices import ClassChoices
|
||||||
|
|
||||||
|
frontend_choices = ClassChoices(
|
||||||
|
name="frontend",
|
||||||
|
classes=dict(
|
||||||
|
default=DefaultFrontend,
|
||||||
|
sliding_window=SlidingWindow,
|
||||||
|
s3prl=S3prlFrontend,
|
||||||
|
fused=FusedFrontends,
|
||||||
|
wav_frontend=WavFrontend,
|
||||||
|
wav_frontend_online=WavFrontendOnline,
|
||||||
|
),
|
||||||
|
default="default",
|
||||||
|
)
|
||||||
|
encoder_choices = ClassChoices(
|
||||||
|
"encoder",
|
||||||
|
classes=dict(
|
||||||
|
fsmn=FSMN,
|
||||||
|
),
|
||||||
|
type_check=torch.nn.Module,
|
||||||
|
default="fsmn",
|
||||||
|
)
|
||||||
|
model_choices = ClassChoices(
|
||||||
|
"model",
|
||||||
|
classes=dict(
|
||||||
|
e2evad=E2EVadModel,
|
||||||
|
),
|
||||||
|
default="e2evad",
|
||||||
|
)
|
||||||
|
|
||||||
|
class_choices_list = [
|
||||||
|
# --frontend and --frontend_conf
|
||||||
|
frontend_choices,
|
||||||
|
# --encoder and --encoder_conf
|
||||||
|
encoder_choices,
|
||||||
|
# --model and --model_conf
|
||||||
|
model_choices,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_vad_model(args):
|
||||||
|
# frontend
|
||||||
|
if args.input_size is None:
|
||||||
|
frontend_class = frontend_choices.get_class(args.frontend)
|
||||||
|
if args.frontend == 'wav_frontend':
|
||||||
|
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
|
||||||
|
else:
|
||||||
|
frontend = frontend_class(**args.frontend_conf)
|
||||||
|
input_size = frontend.output_size()
|
||||||
|
else:
|
||||||
|
args.frontend = None
|
||||||
|
args.frontend_conf = {}
|
||||||
|
frontend = None
|
||||||
|
input_size = args.input_size
|
||||||
|
|
||||||
|
# encoder
|
||||||
|
encoder_class = encoder_choices.get_class(args.encoder)
|
||||||
|
encoder = encoder_class(**args.encoder_conf)
|
||||||
|
|
||||||
|
model_class = model_choices.get_class(args.model)
|
||||||
|
model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend)
|
||||||
|
|
||||||
|
# initialize
|
||||||
|
if args.init is not None:
|
||||||
|
initialize(model, args.init)
|
||||||
|
|
||||||
|
return model
|
||||||
Loading…
Reference in New Issue
Block a user