mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
increase vad realtime punc
This commit is contained in:
parent
33693c4182
commit
ce7914034d
@ -9,7 +9,7 @@ logger.setLevel(logging.CRITICAL)
|
||||
inference_pipeline = pipeline(
|
||||
task=Tasks.punctuation,
|
||||
model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727',
|
||||
output_dir="./tmp/"
|
||||
model_revision = 'v1.0.2'
|
||||
)
|
||||
|
||||
##################text二进制数据#####################
|
||||
|
||||
@ -12,7 +12,7 @@ if __name__ == '__main__':
|
||||
return {'inputs': np.ones((1, text_length), dtype=np.int64),
|
||||
'text_lengths': np.array([text_length,], dtype=np.int32),
|
||||
'vad_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32),
|
||||
'sub_masks': np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
|
||||
'sub_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32),
|
||||
}
|
||||
|
||||
def _run(feed_dict):
|
||||
|
||||
@ -186,11 +186,12 @@ class CT_Transformer_VadRealtime(CT_Transformer):
|
||||
mini_sentence = cache_sent + mini_sentence
|
||||
mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0,dtype='int32')
|
||||
text_length = len(mini_sentence_id)
|
||||
vad_mask = self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32)
|
||||
data = {
|
||||
"input": mini_sentence_id[None,:],
|
||||
"text_lengths": np.array([text_length], dtype='int32'),
|
||||
"vad_mask": self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32),
|
||||
"sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
|
||||
"vad_mask": vad_mask
|
||||
"sub_masks": vad_mask
|
||||
}
|
||||
try:
|
||||
outputs = self.infer(data['input'], data['text_lengths'], data['vad_mask'], data["sub_masks"])
|
||||
|
||||
Loading…
Reference in New Issue
Block a user