increase vad realtime punc

This commit is contained in:
mengzhe.cmz 2023-05-17 17:12:14 +08:00
parent 33693c4182
commit ce7914034d
3 changed files with 5 additions and 4 deletions

View File

@ -9,7 +9,7 @@ logger.setLevel(logging.CRITICAL)
inference_pipeline = pipeline(
task=Tasks.punctuation,
model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727',
output_dir="./tmp/"
model_revision = 'v1.0.2'
)
##################text二进制数据#####################

View File

@ -12,7 +12,7 @@ if __name__ == '__main__':
return {'inputs': np.ones((1, text_length), dtype=np.int64),
'text_lengths': np.array([text_length,], dtype=np.int32),
'vad_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32),
'sub_masks': np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
'sub_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32),
}
def _run(feed_dict):

View File

@ -186,11 +186,12 @@ class CT_Transformer_VadRealtime(CT_Transformer):
mini_sentence = cache_sent + mini_sentence
mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0,dtype='int32')
text_length = len(mini_sentence_id)
vad_mask = self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32)
data = {
"input": mini_sentence_id[None,:],
"text_lengths": np.array([text_length], dtype='int32'),
"vad_mask": self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32),
"sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
"vad_mask": vad_mask
"sub_masks": vad_mask
}
try:
outputs = self.infer(data['input'], data['text_lengths'], data['vad_mask'], data["sub_masks"])