diff --git a/funasr/bin/punc_train_vadrealtime.py b/funasr/bin/punc_train_vadrealtime.py deleted file mode 100644 index c5afaad80..000000000 --- a/funasr/bin/punc_train_vadrealtime.py +++ /dev/null @@ -1,44 +0,0 @@ -#!/usr/bin/env python3 -import os -from funasr.tasks.punctuation import PunctuationTask - - -def parse_args(): - parser = PunctuationTask.get_parser() - parser.add_argument( - "--gpu_id", - type=int, - default=0, - help="local gpu id.", - ) - parser.add_argument( - "--punc_list", - type=str, - default=None, - help="Punctuation list", - ) - args = parser.parse_args() - return args - - -def main(args=None, cmd=None): - """ - punc training. - """ - PunctuationTask.main(args=args, cmd=cmd) - - -if __name__ == "__main__": - args = parse_args() - - # setup local gpu_id - os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) - - # DDP settings - if args.ngpu > 1: - args.distributed = True - else: - args.distributed = False - assert args.num_worker_count == 1 - - main(args=args) diff --git a/funasr/bin/punctuation_infer_vadrealtime.py b/funasr/bin/punctuation_infer_vadrealtime.py index 81f9d7ae8..5157eeb29 100644 --- a/funasr/bin/punctuation_infer_vadrealtime.py +++ b/funasr/bin/punctuation_infer_vadrealtime.py @@ -90,7 +90,7 @@ class Text2Punc: data = { "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0), "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')), - "vad_indexes": torch.from_numpy(np.array([len(cache)-1], dtype='int32')), + "vad_indexes": torch.from_numpy(np.array([len(cache)], dtype='int32')), } data = to_device(data, self.device) y, _ = self.wrapped_model(**data) diff --git a/funasr/datasets/large_datasets/utils/tokenize.py b/funasr/datasets/large_datasets/utils/tokenize.py index d8ceff218..5a2f921f4 100644 --- a/funasr/datasets/large_datasets/utils/tokenize.py +++ b/funasr/datasets/large_datasets/utils/tokenize.py @@ -47,8 +47,8 @@ def tokenize(data, length = len(text) for i in range(length): x = text[i] - if i == length-1 and "punc" in data and text[i].startswith("vad:"): - vad = x[-1][4:] + if i == length-1 and "punc" in data and x.startswith("vad:"): + vad = x[4:] if len(vad) == 0: vad = -1 else: diff --git a/funasr/datasets/preprocessor.py b/funasr/datasets/preprocessor.py index afeff4ee6..1adca0597 100644 --- a/funasr/datasets/preprocessor.py +++ b/funasr/datasets/preprocessor.py @@ -786,6 +786,7 @@ class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor): ) -> Dict[str, np.ndarray]: for i in range(self.num_tokenizer): text_name = self.text_name[i] + #import pdb; pdb.set_trace() if text_name in data and self.tokenizer[i] is not None: text = data[text_name] text = self.text_cleaner(text) @@ -800,7 +801,7 @@ class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor): data[self.vad_name] = np.array([vad], dtype=np.int64) text_ints = self.token_id_converter[i].tokens2ids(tokens) data[text_name] = np.array(text_ints, dtype=np.int64) - + return data def split_to_mini_sentence(words: list, word_limit: int = 20): assert word_limit > 1 @@ -813,4 +814,4 @@ def split_to_mini_sentence(words: list, word_limit: int = 20): sentences.append(words[i * word_limit:(i + 1) * word_limit]) if length % word_limit > 0: sentences.append(words[sentence_len * word_limit:]) - return sentences \ No newline at end of file + return sentences diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py index 0dc728ad3..0eb764f65 100644 --- a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py +++ b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py @@ -159,7 +159,7 @@ class CT_Transformer_VadRealtime(CT_Transformer): data = { "input": mini_sentence_id[None,:], "text_lengths": np.array([text_length], dtype='int32'), - "vad_mask": self.vad_mask(text_length, len(cache) - 1)[None, None, :, :].astype(np.float32), + "vad_mask": self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32), "sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32) } try: diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py index 8d63b27d9..777513e7e 100644 --- a/funasr/tasks/abs_task.py +++ b/funasr/tasks/abs_task.py @@ -1587,6 +1587,8 @@ class AbsTask(ABC): dest_sample_rate = args.frontend_conf["fs"] else: dest_sample_rate = 16000 + else: + dest_sample_rate = 16000 dataset = ESPnetDataset( iter_options.data_path_and_name_and_type,