mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
increase vad realtime punc
This commit is contained in:
parent
33693c4182
commit
ce7914034d
@ -9,7 +9,7 @@ logger.setLevel(logging.CRITICAL)
|
|||||||
inference_pipeline = pipeline(
|
inference_pipeline = pipeline(
|
||||||
task=Tasks.punctuation,
|
task=Tasks.punctuation,
|
||||||
model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727',
|
model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727',
|
||||||
output_dir="./tmp/"
|
model_revision = 'v1.0.2'
|
||||||
)
|
)
|
||||||
|
|
||||||
##################text二进制数据#####################
|
##################text二进制数据#####################
|
||||||
|
|||||||
@ -12,7 +12,7 @@ if __name__ == '__main__':
|
|||||||
return {'inputs': np.ones((1, text_length), dtype=np.int64),
|
return {'inputs': np.ones((1, text_length), dtype=np.int64),
|
||||||
'text_lengths': np.array([text_length,], dtype=np.int32),
|
'text_lengths': np.array([text_length,], dtype=np.int32),
|
||||||
'vad_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32),
|
'vad_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32),
|
||||||
'sub_masks': np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
|
'sub_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32),
|
||||||
}
|
}
|
||||||
|
|
||||||
def _run(feed_dict):
|
def _run(feed_dict):
|
||||||
|
|||||||
@ -186,11 +186,12 @@ class CT_Transformer_VadRealtime(CT_Transformer):
|
|||||||
mini_sentence = cache_sent + mini_sentence
|
mini_sentence = cache_sent + mini_sentence
|
||||||
mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0,dtype='int32')
|
mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0,dtype='int32')
|
||||||
text_length = len(mini_sentence_id)
|
text_length = len(mini_sentence_id)
|
||||||
|
vad_mask = self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32)
|
||||||
data = {
|
data = {
|
||||||
"input": mini_sentence_id[None,:],
|
"input": mini_sentence_id[None,:],
|
||||||
"text_lengths": np.array([text_length], dtype='int32'),
|
"text_lengths": np.array([text_length], dtype='int32'),
|
||||||
"vad_mask": self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32),
|
"vad_mask": vad_mask
|
||||||
"sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
|
"sub_masks": vad_mask
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
outputs = self.infer(data['input'], data['text_lengths'], data['vad_mask'], data["sub_masks"])
|
outputs = self.infer(data['input'], data['text_lengths'], data['vad_mask'], data["sub_masks"])
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user