This commit is contained in:
游雁 2023-03-29 15:57:22 +08:00
parent 02652ef989
commit 1f8b46402c
3 changed files with 29 additions and 5 deletions

View File

@ -191,9 +191,10 @@ class ModelExport:
cmvn_file = os.path.join(model_dir, 'vad.mvn')
model, vad_infer_args = VADTask.build_model_from_file(
config, model_file, 'cpu'
config, model_file, cmvn_file=cmvn_file, device='cpu'
)
self.export_config["feats_dim"] = 400
self.frontend = model.frontend
self._export(model, tag_name)

View File

@ -192,7 +192,7 @@ class WindowDetector(object):
class E2EVadModel(nn.Module):
def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any]):
def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], frontend=None):
super(E2EVadModel, self).__init__()
self.vad_opts = VADXOptions(**vad_post_args)
self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
@ -229,6 +229,7 @@ class E2EVadModel(nn.Module):
self.data_buf_all = None
self.waveform = None
self.ResetDetection()
self.frontend = frontend
def AllResetDetection(self):
self.is_final = False
@ -477,8 +478,9 @@ class E2EVadModel(nn.Module):
) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
self.waveform = waveform # compute decibel for each frame
self.ComputeDecibel()
self.ComputeScores(feats, in_cache)
self.ComputeDecibel()
if not is_final:
self.DetectCommonFrames()
else:

View File

@ -40,7 +40,7 @@ from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
@ -81,6 +81,7 @@ frontend_choices = ClassChoices(
s3prl=S3prlFrontend,
fused=FusedFrontends,
wav_frontend=WavFrontend,
wav_frontend_online=WavFrontendOnline,
),
type_check=AbsFrontend,
default="default",
@ -291,7 +292,24 @@ class VADTask(AbsTask):
model_class = model_choices.get_class(args.model)
except AttributeError:
model_class = model_choices.get_class("e2evad")
model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf)
# 1. frontend
if args.input_size is None:
# Extract features in the model
frontend_class = frontend_choices.get_class(args.frontend)
if args.frontend == 'wav_frontend':
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
else:
frontend = frontend_class(**args.frontend_conf)
input_size = frontend.output_size()
else:
# Give features from data-loader
args.frontend = None
args.frontend_conf = {}
frontend = None
input_size = args.input_size
model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend)
return model
@ -301,6 +319,7 @@ class VADTask(AbsTask):
cls,
config_file: Union[Path, str] = None,
model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
device: str = "cpu",
):
"""Build model from the files.
@ -325,6 +344,8 @@ class VADTask(AbsTask):
with config_file.open("r", encoding="utf-8") as f:
args = yaml.safe_load(f)
if cmvn_file is not None:
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
model.to(device)