mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
export
This commit is contained in:
parent
242431452b
commit
cf00b4a13f
@ -23,7 +23,7 @@ class Fsmn_vad():
|
|||||||
device_id: Union[str, int] = "-1",
|
device_id: Union[str, int] = "-1",
|
||||||
quantize: bool = False,
|
quantize: bool = False,
|
||||||
intra_op_num_threads: int = 4,
|
intra_op_num_threads: int = 4,
|
||||||
max_end_sil: int = 800,
|
max_end_sil: int = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
if not Path(model_dir).exists():
|
if not Path(model_dir).exists():
|
||||||
@ -43,14 +43,17 @@ class Fsmn_vad():
|
|||||||
self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
|
self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.vad_scorer = E2EVadModel(config["vad_post_conf"])
|
self.vad_scorer = E2EVadModel(config["vad_post_conf"])
|
||||||
self.max_end_sil = max_end_sil
|
self.max_end_sil = max_end_sil if max_end_sil is not None else config["vad_post_conf"]["max_end_silence_time"]
|
||||||
|
self.encoder_conf = config["encoder_conf"]
|
||||||
|
|
||||||
def prepare_cache(self, in_cache: list = []):
|
def prepare_cache(self, in_cache: list = []):
|
||||||
if len(in_cache) > 0:
|
if len(in_cache) > 0:
|
||||||
return in_cache
|
return in_cache
|
||||||
|
fsmn_layers = self.encoder_conf["fsmn_layers"]
|
||||||
for i in range(4):
|
proj_dim = self.encoder_conf["proj_dim"]
|
||||||
cache = np.random.rand(1, 128, 19, 1).astype(np.float32)
|
lorder = self.encoder_conf["lorder"]
|
||||||
|
for i in range(fsmn_layers):
|
||||||
|
cache = np.random.rand(1, proj_dim, lorder-1, 1).astype(np.float32)
|
||||||
in_cache.append(cache)
|
in_cache.append(cache)
|
||||||
return in_cache
|
return in_cache
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user